Let's build a simple neural network to classify images from the FashionMNIST dataset.

**1. Import Libraries**

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [3]:
%%capture
!pip install lightning

In [4]:
import lightning as L

*Checking for GPU Availability*

This code checks if a CUDA-enabled GPU is available and sets the `device` accordingly. If no GPU is available, it defaults to the CPU.

In [5]:
# Select the compute device, preferring an accelerator when PyTorch exposes one:
# NVIDIA CUDA first, then Apple-Silicon MPS, and finally the CPU as a fallback.
if torch.cuda.is_available():
    device = torch.device("cuda")
# `torch.backends.mps` does not exist on older PyTorch builds, hence the hasattr guard.
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


**2. Data Preparation**

In [6]:
# The only preprocessing step: turn each image into a tensor
transform = transforms.ToTensor()

# FashionMNIST training and test splits, stored under ./data (download=True
# fetches the files when they are not already there)
train_set = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_set = datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Both loaders serve batches of 256; only the training loader shuffles
train_loader = torch.utils.data.DataLoader(
    train_set,
    shuffle=True,
    batch_size=256,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    shuffle=False,
    batch_size=256,
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:05<00:00, 4.68MB/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 199kB/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:01<00:00, 3.78MB/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 5.03MB/s]


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw



**3. Neural Network Model**

In [13]:
class NN(L.LightningModule):
    """Lightning wrapper that trains a classifier with cross-entropy and Adam.

    Args:
        model: Network called on the inputs of each batch; its output is passed
            to ``F.cross_entropy`` together with the batch targets.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, model, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate
        self.model = model

    def _common_eval(self, batch, batch_idx):
        """Return the cross-entropy loss of ``self.model`` on one ``(x, y)`` batch.

        ``batch_idx`` is accepted for symmetry with the step hooks but is unused.
        """
        inputs, targets = batch
        return F.cross_entropy(self.model(inputs), targets)

    def training_step(self, batch, batch_idx):
        """Compute the batch loss, log it as ``train_loss`` and return it."""
        train_loss = self._common_eval(batch, batch_idx)
        self.log('train_loss', train_loss)
        return train_loss

    def test_step(self, batch, batch_idx):
        """Compute the batch loss and log it as ``test_loss`` (returns nothing)."""
        self.log('test_loss', self._common_eval(batch, batch_idx))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)


In [18]:
def get_model():
    """Build a fresh, untrained convolutional classifier.

    Layout: two 3x3 convolutions (1 -> 32 -> 64 channels), each followed by a
    ReLU, then flatten -> Linear(24*24*64, 128) -> ReLU -> Dropout(0.5) ->
    Linear(128, 10). The 24*24*64 input size of the first linear layer matches
    1x28x28 images under ``nn.Conv2d``'s default stride and padding.

    Returns:
        nn.Sequential: The nine layers above, in order.
    """
    feature_layers = [
        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3)),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3)),
        nn.ReLU(),
    ]
    classifier_layers = [
        nn.Flatten(),
        nn.Linear(in_features=24*24*64, out_features=128),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(in_features=128, out_features=10),
    ]
    return nn.Sequential(*feature_layers, *classifier_layers)

In [19]:
def evaluate(model, test_loader, device=None):
    """Compute the classification accuracy of ``model`` over ``test_loader``.

    The model is switched to evaluation mode (and left in that mode) and
    gradient tracking is disabled while the batches are scored.

    Args:
        model: ``torch.nn.Module`` that maps a batch of inputs to per-class
            scores, with the classes along dimension 1.
        test_loader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device each batch is moved to before the forward pass. When
            omitted, the device of the model's first parameter is used (CPU if
            the model has no parameters), so the function does not rely on a
            notebook-level ``device`` variable.

    Returns:
        float: Fraction of samples whose highest-scoring class equals the label.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    if device is None:
        # Score the batches wherever the model's weights already live.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():  # No gradients are needed for evaluation
        for images, labels in test_loader:
            # Move images and labels to the evaluation device
            images = images.to(device)
            labels = labels.to(device)

            # Index of the largest score along the class dimension
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [20]:
import logging

# Raise the threshold of the "lightning.pytorch" logger to ERROR so that
# lower-severity records from it are not emitted.
logging.getLogger(name="lightning.pytorch").setLevel(level=logging.ERROR)


In [21]:
import time

# Seed PyTorch's global RNG before the model is built so that weight
# initialisation, batch shuffling and dropout start from the same state on
# every "Restart & Run All" (GPU kernels may still add small nondeterminism).
torch.manual_seed(42)

# perf_counter() is a monotonic clock meant for measuring intervals. The timed
# span covers model construction, training and test-set evaluation.
start = time.perf_counter()
model = NN(
    model=get_model(),
    learning_rate=1e-4,
).to(device)
trainer = L.Trainer(max_epochs=10, enable_model_summary=False)
trainer.fit(model=model, train_dataloaders=train_loader)
# NOTE(review): the wrapped network is moved to `device` again before scoring,
# presumably because the trainer may leave it on another device after fit();
# confirm against the Lightning version in use.
acc = evaluate(model.model.to(device), test_loader)
end = time.perf_counter()
print(f'Accuracy: {acc} in {end - start} s', flush=True)


133597082275744


Training: |          | 0/? [00:00<?, ?it/s]

Accuracy: 0.8905 in 85.71896243095398 s
