In [None]:
import numpy as np
import torch
import torchvision
from torch import nn
import datasets
from torch.nn import functional as F
from torch.optim import AdamW
from sklearn.metrics import accuracy_score

# Per-image preprocessing: tensor conversion followed by single-channel
# normalisation. NOTE(review): 0.1307 / 0.3081 are presumably the MNIST
# training-set mean and std — confirm.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Train and test splits share the same preprocessing; download=True stores the
# data under ./mnist.
mnist_train = torchvision.datasets.MNIST("./mnist", train=True, transform=transform, download=True)
mnist_test = torchvision.datasets.MNIST("./mnist", train=False, transform=transform, download=True)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda"

class Encoder(nn.Module):
    """Small convolutional classifier mapping (B, 1, 28, 28) images to 10 logits.

    Layout: two 3x3 convolutions (1 -> 32 -> 64 channels), each followed by a
    ReLU, then a 2x2 max-pool, a flatten, and a two-layer fully connected head
    (9216 -> 2048 -> 10) with a ReLU between the two linear layers.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, (3, 3))
        self.conv2 = nn.Conv2d(32, 64, (3, 3))
        self.max_pool = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(9216, 2048)
        self.fc2 = nn.Linear(2048, 10)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: image batch of shape (B, 1, 28, 28).

        Returns:
            Tensor of shape (B, 10) holding unnormalised class scores.
        """
        x = F.relu(self.conv1(x))  # B, 1, 28, 28 ==> B, 32, 26, 26
        x = F.relu(self.conv2(x))  # B, 32, 26, 26 ==> B, 64, 24, 24
        x = self.max_pool(x)  # B, 64, 12, 12
        x = self.flatten(x)  # B, 9216
        # The ReLU is required here: without a non-linearity fc1 and fc2 would
        # compose into a single affine map and the hidden layer would add no
        # representational capacity.
        x = F.relu(self.fc1(x))  # B, 2048
        x = self.fc2(x)  # B, 10
        return x

# Training configuration.
batch_size = 256
lr = 3e-3
num_epochs = 10
eval_every = 100  # run a full test-set evaluation every `eval_every` training steps

model = Encoder()
model = model.to(device)
optim = AdamW(model.parameters(), lr=lr)
ce_loss = nn.CrossEntropyLoss()

train_dl = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
# Evaluation metrics do not depend on sample order, so the test loader is not shuffled.
test_dl  = torch.utils.data.DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=False)


def evaluate(model, dataloader, loss_fn, device):
    """Compute the average loss and accuracy of `model` over `dataloader`.

    The model is switched to eval mode and gradients are disabled for the
    duration of the pass; the previous train/eval mode is restored afterwards.

    Args:
        model: the network to evaluate.
        dataloader: iterable yielding (inputs, labels) batches.
        loss_fn: callable (logits, labels) -> scalar loss tensor. It is assumed
            to return the batch mean, as `nn.CrossEntropyLoss()` does with its
            default reduction.
        device: device the batches are moved to before the forward pass.

    Returns:
        (avg_loss, accuracy) as Python floats; both are NaN if the dataloader
        yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss, n_seen = 0.0, 0
    preds_arr, labels_arr = [], []
    try:
        with torch.no_grad():
            for test_data, test_labels in dataloader:
                test_data = test_data.to(device)
                test_labels = test_labels.to(device)
                test_logits = model(test_data)
                n_batch = test_labels.size(0)
                # Weight each batch mean by its size so a short final batch
                # does not skew the average.
                total_loss += loss_fn(test_logits, test_labels).item() * n_batch
                n_seen += n_batch
                preds_arr.extend(test_logits.argmax(dim=1).tolist())
                labels_arr.extend(test_labels.tolist())
    finally:
        model.train(was_training)
    if n_seen == 0:
        return float("nan"), float("nan")
    return total_loss / n_seen, accuracy_score(labels_arr, preds_arr)


for epoch in range(num_epochs):
    for idx, (data, labels) in enumerate(train_dl):
        data = data.to(device)
        labels = labels.to(device)
        optim.zero_grad()
        logits = model(data)
        loss = ce_loss(logits, labels)
        loss.backward()
        optim.step()
        if idx % eval_every == 0:
            # Report the loss measured on the test set itself, not the loss of
            # the most recent training batch.
            test_loss, test_acc = evaluate(model, test_dl, ce_loss, device)
            print(f"test avg loss: {test_loss:.02f} test acc: {test_acc}")

test avg loss: 2.29 test acc: 0.2347
test avg loss: 0.11 test acc: 0.9712
test avg loss: 0.19 test acc: 0.9742
test avg loss: 0.03 test acc: 0.9801
test avg loss: 0.08 test acc: 0.9829
test avg loss: 0.14 test acc: 0.9789
test avg loss: 0.06 test acc: 0.976
test avg loss: 0.06 test acc: 0.9749
test avg loss: 0.07 test acc: 0.9814
test avg loss: 0.02 test acc: 0.984
test avg loss: 0.04 test acc: 0.9843
test avg loss: 0.04 test acc: 0.9819
test avg loss: 0.04 test acc: 0.9822
test avg loss: 0.08 test acc: 0.9773
test avg loss: 0.04 test acc: 0.9831
test avg loss: 0.03 test acc: 0.9801
test avg loss: 0.02 test acc: 0.983
test avg loss: 0.02 test acc: 0.9858
test avg loss: 0.02 test acc: 0.9836
test avg loss: 0.04 test acc: 0.9832
test avg loss: 0.06 test acc: 0.9798
test avg loss: 0.02 test acc: 0.9812
test avg loss: 0.02 test acc: 0.9807
test avg loss: 0.08 test acc: 0.9804
test avg loss: 0.02 test acc: 0.984
test avg loss: 0.05 test acc: 0.9846
test avg loss: 0.02 test acc: 0.9829
test 

In [None]:
# No-op placeholder cell.
pass