In [1]:
import torch
from torch import nn
import torchvision
from torch.utils.data.dataloader import DataLoader

In [2]:
device = torch.device("cuda:0")

In [3]:
# Used to transform the data to a transor
transform = torchvision.transforms.Compose(
    [
        # Transform to a tensor
        torchvision.transforms.ToTensor(),
    ]
)

In [4]:
# Load in MNIST
MNIST_dataset = torchvision.datasets.MNIST("./", train=True, transform=transform, download=True)

In [12]:
# Used to load in the dataset
data_loader = DataLoader(MNIST_dataset, batch_size=256,
        pin_memory=True, num_workers=1, 
        drop_last=False, shuffle=True
    )

In [23]:
# Model with 1x28x28 input and 10 output
class Model(nn.Module):
    def __init__(self, device):
        super(Model, self).__init__()
        
        # convolution layers
        self.convs = nn.Sequential( # 1x28x28
            nn.Conv2d(1, 32, 5), # 32x24x24
            nn.ReLU(),
            
            nn.Conv2d(32, 32, 5), # 32x20x20
            nn.ReLU(),
            nn.MaxPool2d(2), # 64x10x10
            
            nn.Conv2d(32, 64, 5), # 64x6x6
            nn.ReLU(),
            nn.MaxPool2d(2), # 64x3x3
            
            nn.Flatten(1, -1), # 3*3*64
            nn.Linear(3*3*64, 256), # 256
            nn.ReLU(),
            nn.Linear(256, 10), # 10
            nn.LogSoftmax(-1)
        ).to(device)
        
    def forward(self, X):
        return self.convs(X)

In [24]:
# Create the model
model = Model(device)

In [25]:
# Optimizer
optim = torch.optim.AdamW(model.parameters())

In [26]:
# Loss function
loss_funct = nn.CrossEntropyLoss()

In [27]:
# Training loop
epochs = 10
steps = 0
for epoch in range(0, epochs):
    # Iterate over all data
    for X,labels in data_loader:
        # Send the data through the model
        y_hat = model(X.to(device))
        
        # Get the loss
        loss = loss_funct(y_hat, labels.to(device))
        
        # Backprop the loss
        loss.backward()
        
        # Update model
        optim.step()
        optim.zero_grad()
        steps += 1
    print(f"Epoch {epoch}: {loss.detach().item()}")

Epoch 0: 0.08258404582738876
Epoch 1: 0.04608510434627533
Epoch 2: 0.05246753990650177
Epoch 3: 0.006525761913508177
Epoch 4: 0.011287428438663483
Epoch 5: 0.03101402521133423
Epoch 6: 0.05616789683699608
Epoch 7: 0.017455078661441803
Epoch 8: 0.0007231218623928726
Epoch 9: 0.0005497061065398157
