In [6]:
# Import dependencies
import torch
from PIL import Image
from torch import nn, save, load
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Get data
train = datasets.MNIST(root="data", download=True, train=True, transform=ToTensor())
dataset = DataLoader(train, 32)
#1,28,28 - classes 0-9

# Image Classifier Neural Network
class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, (3,3)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (3,3)),
            nn.ReLU(),
            nn.Conv2d(64, 64, (3,3)),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64*(28-6)*(28-6), 10)
        )

    def forward(self, x):
        return self.model(x)

# Instance of the neural network, loss, optimizer
# Changed device to 'cpu' if CUDA is not available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clf = ImageClassifier().to(device)
opt = Adam(clf.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Training flow
if __name__ == "__main__":
    for epoch in range(10): # train for 10 epochs
        for batch in dataset:
            X,y = batch
            # Move X and y to the selected device (CPU or GPU)
            X, y = X.to(device), y.to(device)
            yhat = clf(X)
            loss = loss_fn(yhat, y)


            # Apply backprop
            opt.zero_grad()
            loss.backward()
            opt.step()

        print(f"Epoch:{epoch} loss is {loss.item()}")

    with open('model_state.pt', 'wb') as f:
        save(clf.state_dict(), f)

    with open('model_state.pt', 'rb') as f:
        clf.load_state_dict(load(f))

Epoch:0 loss is 0.023627182468771935
Epoch:1 loss is 0.0011317746248096228
Epoch:2 loss is 0.0003202045918442309
Epoch:3 loss is 0.0003820345737040043
Epoch:4 loss is 0.00013877030869480222
Epoch:5 loss is 4.081029328517616e-05
Epoch:6 loss is 4.710096982307732e-05
Epoch:7 loss is 6.233017484191805e-05
Epoch:8 loss is 0.0018645740346983075
Epoch:9 loss is 0.001552947680465877


  clf.load_state_dict(load(f))


In [9]:
#img = Image.open('img_3.jpg')
#img_tensor = ToTensor()(img).unsqueeze(0).to('cuda')

#print(torch.argmax(clf(img_tensor)))

# Change the device to 'cpu'
img_tensor = ToTensor()(img).unsqueeze(0).to('cpu')

print(torch.argmax(clf(img_tensor)))

tensor(9)
