In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from tqdm import tqdm

from alexnet import AlexNet

In [3]:
mnist_tr = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
trainset = MNIST('./data', train=True, transform=mnist_tr, download=True)
testset = MNIST('./data', train=False, transform=mnist_tr, download=False)

train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
test_loader = DataLoader(trainset, batch_size=64, shuffle=False)

100%|██████████| 9.91M/9.91M [00:03<00:00, 2.82MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 115kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 986kB/s] 
100%|██████████| 4.54k/4.54k [00:00<?, ?B/s]


In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [14]:
model = AlexNet(num_classes=10, in_channels=1).to(device)

In [17]:
optimizer = optim.AdamW(model.parameters())
criterion = nn.CrossEntropyLoss()

In [19]:
epochs = 1

for epoch in range(1, epochs+1):
    model.train()

    total_loss = 0
    for X, Y in tqdm(train_loader):
        X, Y = X.to(device), Y.to(device)
        logits = model(X)
        loss = criterion(logits, Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Loss for {epoch=} is {total_loss/len(train_loader)}.")


100%|██████████| 938/938 [1:14:34<00:00,  4.77s/it]

Loss for epoch=1 is 0.2963668534967568.





In [20]:
torch.save(model.state_dict(), 'alexnet_mnist.pth')

In [21]:
from evaluate import load
accuracy = load('accuracy')

In [22]:
model.eval()
for X, Y in tqdm(test_loader):
    X, Y = X.to(device), Y.to(device)
    with torch.no_grad():
        logits = model(X)
    preds = logits.argmax(dim=-1)

    accuracy.add_batch(predictions=preds, references=Y)

acc = accuracy.compute()

100%|██████████| 938/938 [25:33<00:00,  1.63s/it]  


In [23]:
acc

{'accuracy': 0.9779333333333333}

In [27]:
print("Actual:".rjust(10), Y.tolist())
print("Predicted:", preds.tolist())

   Actual: [5, 9, 2, 2, 0, 9, 2, 4, 6, 7, 3, 1, 3, 6, 6, 2, 1, 2, 6, 0, 7, 8, 9, 2, 9, 5, 1, 8, 3, 5, 6, 8]
Predicted: [5, 9, 2, 2, 0, 9, 2, 4, 6, 7, 3, 1, 3, 6, 6, 2, 1, 2, 6, 0, 7, 8, 9, 2, 9, 5, 1, 8, 3, 5, 6, 8]
