In [1]:
import torch
from torch import nn, save, load
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [2]:
%%capture
# MNIST training split, downloaded into /content when not already present;
# ToTensor converts each image into a tensor.
train = datasets.MNIST(root='/content', download=True,
                        train=True, transform=ToTensor())

# Mini-batches of 32, reshuffled at the start of every epoch so the optimizer
# does not see the samples in the same fixed order on each pass.
dataset = DataLoader(train, batch_size=32, shuffle=True)

In [3]:
class ImageClassifier(nn.Module):
  """Small convolutional network that maps an image to 10 output scores.

  Three unpadded 3x3 convolutions (1 -> 32 -> 64 -> 64 channels), each
  followed by a ReLU, then a flatten and a single linear layer. The linear
  layer's input size is hard-wired for 28x28 single-channel inputs: every
  unpadded 3x3 convolution trims 2 pixels from each spatial dimension,
  leaving a 22x22 feature map.
  """

  def __init__(self):
    super().__init__()
    side = 28 - 3 * 2  # spatial size left after three unpadded 3x3 convs
    layers = [
        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3)),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3)),
        nn.ReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3)),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(in_features=64 * side * side, out_features=10),
    ]
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Run ``x`` through the layer stack and return its output."""
    scores = self.model(x)
    return scores

In [4]:
# Prefer the GPU, but fall back to the CPU so this cell also runs on machines
# without CUDA instead of raising at `.to('cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
clf = ImageClassifier().to(device)
opt = Adam(clf.parameters(), lr=1e-3)  # optimizes every parameter of clf
loss_fn = nn.CrossEntropyLoss()  # applied to the model's raw outputs and the labels

In [5]:
if __name__ == '__main__':
  # Move every batch to whatever device the model's parameters live on rather
  # than hard-coding 'cuda', so the loop follows the model wherever it was put.
  model_device = next(clf.parameters()).device

  for epoch in range(10):
    for X, y in dataset:
      X, y = X.to(model_device), y.to(model_device)
      yhat = clf(X)
      loss = loss_fn(yhat, y)

      # Backpropagation: clear stale gradients, compute new ones, update weights
      opt.zero_grad()
      loss.backward()
      opt.step()

    # Visualization: this is the loss of the epoch's LAST batch only,
    # not an average over the whole epoch.
    print(f"Epoch: {epoch} loss is {loss.item()}")

  # Persist only the learned parameters (state dict), not the module object.
  with open('model_state.pt', 'wb') as f:
    save(clf.state_dict(), f)

Epoch: 0 loss is 0.0213587898761034
Epoch: 1 loss is 0.0042461068369448185
Epoch: 2 loss is 0.00234182458370924
Epoch: 3 loss is 0.000462794560007751
Epoch: 4 loss is 0.00015365681611001492
Epoch: 5 loss is 0.0014855499612167478
Epoch: 6 loss is 6.202309123182204e-06
Epoch: 7 loss is 3.41706327162683e-05
Epoch: 8 loss is 1.4118555782260955e-06
Epoch: 9 loss is 3.356361503392691e-06
