Src:
* [Building a Neural Network with PyTorch in 15 Minutes](https://www.youtube.com/watch?v=mozBidd58VQ)

In [1]:
import torch
from torch import nn, save, load
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from PIL import Image


In [2]:
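# Download the MNIST training split (60,000 28x28 grayscale digits) into ./data.
# ToTensor() turns each PIL image into a float tensor of shape [1, 28, 28] with values in [0, 1].
train = MNIST(root="data", download=True, train=True, transform=ToTensor())
# Yield (images, labels) batches of 32; without shuffle=True the order is the same every epoch.
dataset = DataLoader(train, batch_size=32)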

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9.91M/9.91M [00:00<00:00, 17.6MB/s]
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28.9k/28.9k [00:00<00:00, 486kB/s]
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.43MB/s]
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.20MB/s]
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

In [3]:
print(dataset)

<torch.utils.data.dataloader.DataLoader object at 0x7e87aa512650>
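`print(dataset)` only shows the `DataLoader` object itself. To see what the training loop actually receives, here is a small self-contained sketch that swaps MNIST for random tensors of the same shape (an assumption made so it runs without the download). Each iteration yields an `(images, labels)` pair, and with `batch_size=32` the final batch holds whatever is left over. For the real 60,000-image training split, 60000 / 32 = 1875 exactly, so every batch there is full.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST: 100 fake 1x28x28 "images" with integer labels 0-9.
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32)

batch_sizes = []
for X, y in loader:
    # X: [batch, 1, 28, 28] float tensor, y: [batch] int64 labels
    batch_sizes.append(X.shape[0])

print(batch_sizes)  # [32, 32, 32, 4]: the last batch holds the 100 - 96 leftovers
first_X, first_y = next(iter(loader))
print(first_X.shape, first_y.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])
```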


In [4]:
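class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # Each 3x3 convolution (stride 1, no padding) shrinks the
            # spatial size by 2: 28 -> 26 -> 24 -> 22.
            nn.Conv2d(1, 32, (3, 3)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (3, 3)),
            nn.ReLU(),
            nn.Conv2d(64, 64, (3, 3)),
            nn.ReLU(),
            nn.Flatten(),
            # 64 channels x 22 x 22 features -> 10 class logits (digits 0-9).
            # No softmax here: nn.CrossEntropyLoss expects raw logits.
            nn.Linear(64*(28-6)*(28-6), 10))
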
    def forward(self, x):
        return self.model(x)
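The `64*(28-6)*(28-6)` input size of the final `nn.Linear` follows from the usual convolution output-size formula, out = (n + 2·padding − kernel) // stride + 1. A stdlib-only sketch of that arithmetic (the helper names `conv_out` and `conv_params` are illustrative, not PyTorch functions), which also counts the trainable parameters layer by layer, weights plus biases:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_params(in_ch, out_ch, k):
    """Weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * k * k + out_ch

size = 28
for _ in range(3):           # three 3x3 convs, stride 1, no padding
    size = conv_out(size, 3)
flat = 64 * size * size      # channels * height * width after nn.Flatten

params = (conv_params(1, 32, 3) + conv_params(32, 64, 3)
          + conv_params(64, 64, 3) + flat * 10 + 10)

print(size, flat, params)    # 22 30976 365514
```

For this architecture, `sum(p.numel() for p in clf.parameters())` should report the same total; most of it (309,770) sits in the final `Linear` layer.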

In [5]:
# This notebook targets a CUDA GPU; on a CPU-only machine, use "cpu" here and in the cells below.
clf = ImageClassifier().to("cuda")
opt = Adam(clf.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
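`nn.CrossEntropyLoss` applies log-softmax to the raw logits and takes the negative log-probability of the correct class, averaged over the batch by default (`reduction="mean"`). That is why the model ends in a plain `nn.Linear` with no softmax. A stdlib sketch of the same computation, for intuition only (not PyTorch's implementation). It also gives a sanity check: uniform logits over 10 classes cost ln(10) ≈ 2.303, roughly where an untrained model should start.

```python
import math

def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def mean_cross_entropy(batch_logits, targets):
    """Batch average, matching the default reduction='mean'."""
    losses = [cross_entropy(z, t) for z, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

uniform = [0.0] * 10
print(cross_entropy(uniform, 3))     # ln(10) ≈ 2.302585

confident = [0.0] * 10
confident[7] = 10.0
print(cross_entropy(confident, 7))   # ≈ 0.000409: the correct class dominates
```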

In [6]:
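for epoch in range(10):
    for batch in dataset:
        X, y = batch                       # images [32, 1, 28, 28], labels [32]
        X, y = X.to("cuda"), y.to("cuda")
        yhat = clf(X)                      # logits [32, 10]
        loss = loss_fn(yhat, y)

        # Standard update: clear old gradients, backpropagate, take an Adam step.
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Note: this reports the loss of the epoch's last batch only, not the epoch average.
    print(f"Epoch: {epoch+1} loss is {loss.item()}")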

Epoch: 1 loss is 0.04879789054393768
Epoch: 2 loss is 0.0023995456285774708
Epoch: 3 loss is 0.0014650262892246246
Epoch: 4 loss is 0.00042135323747061193
Epoch: 5 loss is 0.0012685342226177454
Epoch: 6 loss is 5.772522126790136e-05
Epoch: 7 loss is 0.0001810850517358631
Epoch: 8 loss is 1.9706237708305707e-06
Epoch: 9 loss is 1.9073006569669815e-06
Epoch: 10 loss is 3.9487997582909884e-07
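These values are single-batch training losses, so they say little about how well the model generalizes. A minimal evaluation sketch, assuming a held-out loader such as `DataLoader(MNIST(root="data", train=False, download=True, transform=ToTensor()), batch_size=32)` (the test split was already downloaded above). The function name `evaluate` is illustrative; it switches to `eval()` mode, disables gradient tracking, and reports the sample-weighted mean loss and the accuracy.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, loss_fn, device):
    """Return (mean loss per sample, accuracy) of `model` over `loader`."""
    model.eval()                      # no dropout/batch-norm here, but good practice
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():             # no autograd graph needed for evaluation
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            # loss_fn averages over the batch, so re-weight by batch size
            total_loss += loss_fn(logits, y).item() * y.size(0)
            correct += (logits.argmax(dim=1) == y).sum().item()
            count += y.size(0)
    model.train()                     # restore training mode for further updates
    return total_loss / count, correct / count

# Self-contained demo on random data (a stand-in for the MNIST test split).
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
toy_loader = DataLoader(
    TensorDataset(torch.rand(50, 1, 28, 28), torch.randint(0, 10, (50,))),
    batch_size=32,
)
mean_loss, acc = evaluate(toy_model, toy_loader, nn.CrossEntropyLoss(), "cpu")
print(f"loss {mean_loss:.4f}, accuracy {acc:.2%}")
```

With the notebook's objects, the call would be `evaluate(clf, test_loader, loss_fn, "cuda")`, where `test_loader` is the hypothetical held-out loader described above.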


In [7]:
with open("model_state.pt", "wb") as f:
    save(clf.state_dict(), f)

In [8]:
with open("model_state.pt", "rb") as f:
    clf.load_state_dict(load(f, weights_only=True))

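`state_dict()` is an ordered mapping from parameter names to tensors, so the file holds weights only, not the class definition: `ImageClassifier` must be defined before loading. A small round-trip sketch, using an in-memory buffer instead of `model_state.pt` and a tiny `nn.Linear` stand-in (both assumptions for illustration). `map_location="cpu"` lets a checkpoint saved from a GPU load on a CPU-only machine.

```python
import io
import torch
from torch import nn

torch.manual_seed(0)
src = nn.Linear(4, 2)               # stand-in for the trained classifier
buffer = io.BytesIO()               # stands in for open("model_state.pt", "wb")
torch.save(src.state_dict(), buffer)
buffer.seek(0)                      # rewind before reading it back

dst = nn.Linear(4, 2)               # same architecture, different random init
state = torch.load(buffer, map_location="cpu", weights_only=True)
dst.load_state_dict(state)

print(list(state.keys()))           # ['weight', 'bias']
x = torch.rand(3, 4)
print(torch.equal(src(x), dst(x)))  # True: identical weights give identical outputs
```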
In [10]:
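# img_1.jpg must already be a 28x28 single-channel (grayscale) image, matching MNIST's input shape.
img1 = Image.open("img_1.jpg")
img1_tensor = ToTensor()(img1).unsqueeze(0).to("cuda")  # add a batch dimension: [1, 1, 28, 28]
print(torch.argmax(clf(img1_tensor)))  # index of the largest logit = predicted digit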

tensor(2, device='cuda:0')


In [11]:
img2 = Image.open("img_2.jpg")
img2_tensor = ToTensor()(img2).unsqueeze(0).to("cuda")
print(torch.argmax(clf(img2_tensor)))

tensor(0, device='cuda:0')


In [12]:
img3 = Image.open("img_3.jpg")
img3_tensor = ToTensor()(img3).unsqueeze(0).to("cuda")
print(torch.argmax(clf(img3_tensor)))

tensor(9, device='cuda:0')
