In [1]:
from torch import nn, save, load
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [2]:
train = datasets.MNIST(root="data", download=True, train=True, transform=ToTensor())

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:02<00:00, 4.57MB/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 134kB/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:06<00:00, 244kB/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 10.5MB/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






In [3]:
dataset = DataLoader(train, 32)

In [4]:
# image.shape(1, 28, 28) - classes (0 to 9)
class ImageClassifier(nn.Module):
  def __init__(self):
    super().__init__()
    self.model = nn.Sequential(
        nn.Conv2d(1, 32, (3,3)),
        nn.ReLU(),
        nn.Conv2d(32, 64, (3,3)),
        nn.ReLU(),
        nn.Conv2d(64, 64, (3,3)),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(64*(28-6)*(28-6), 10) # every layer shifts by 2 pixels
    )

  def forward(self, x):
    return self.model(x)

In [5]:
# Instance of the neural network, loss, and optimizers

classifier = ImageClassifier().to('cuda')
optimizer = Adam(classifier.parameters(), lr=1e-3)
loss_function = nn.CrossEntropyLoss()

In [7]:
# Trianing Flow

if __name__ == "__main__":
  for epoch in range(10):
    for batch in dataset:
      X, y = batch
      X, y = X.to('cuda'), y.to('cuda')
      yhat = classifier(X)
      loss = loss_function(yhat, y)

      # Apply backprop
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    print(f'Epoch: {epoch} loss is {loss.item()}')

Epoch: 0 loss is 0.013951292261481285
Epoch: 1 loss is 0.0020412798039615154
Epoch: 2 loss is 0.0003558704338502139
Epoch: 3 loss is 0.00014147607726044953
Epoch: 4 loss is 6.818849942646921e-05
Epoch: 5 loss is 0.0002973915543407202
Epoch: 6 loss is 1.7695011820251239e-06
Epoch: 7 loss is 1.583219955136883e-06
Epoch: 8 loss is 2.9073775294818915e-05
Epoch: 9 loss is 7.204234862001613e-06


In [8]:
# save model
with open('model_state.pt', 'wb') as f:
  save(classifier.state_dict(), f)

In [9]:
# load model
if __name__ == "main":
  with open('model_state.pt', 'rb') as f:
    classifier.load_dict(load(f))

In [11]:
# predictions
import torch
from PIL import Image

In [12]:
img = Image.open('img_1.jpg')
img_tensor = ToTensor()(img).unsqueeze(0).to('cuda')

In [13]:
print(torch.argmax(classifier(img_tensor)))

tensor(2, device='cuda:0')
