In [10]:
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("mps")
torch.set_default_device(device)

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True,
                                transform=transforms.Compose([transforms.ToTensor()]))
features = mnist_trainset.data.float().to(device)
labels = mnist_trainset.targets.to(device)

train_size = int(features.shape[0] * 0.85)
train_feat = features[:train_size,:]
train_labels = labels[:train_size:]
test_feat = features[train_size:,:]
test_labels = labels[train_size:]



In [11]:
class Mnist(nn.Module):
    def __init__(self):
        super().__init__()
        # ~97.4% accuracy
        d = 28
        self.dense1 = nn.Linear(d * d, 1024)
        self.dense2 = nn.Linear(1024, 512)
        self.logits = nn.Linear(512, 10)

    def forward(self, x):
        d = torch.flatten(x, 1)
        dd1 = self.dense1(d)
        dd1 = F.sigmoid(dd1)
        dd2 = self.dense2(dd1)
        dd2 = F.sigmoid(dd2)
        logits = self.logits(dd2)
        return F.log_softmax(logits, 1)
    
class MnistConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, 3, padding="same")
        self.conv2 = torch.nn.Conv2d(16, 32, 3, padding="same")
        # 28 / 2 / 2 * 64 = 448
        # 99.1%
        n = 28 * 28 // 4 // 4 * 32
        self.dense1 = torch.nn.Linear(n, n)
        self.dense2 = torch.nn.Linear(n, n)
        self.logits = torch.nn.Linear(n, 10)

    def forward(self, x):
        x = x.unsqueeze(1)
        v = self.conv1(x)
        v = F.relu(v)
        v = F.max_pool2d(v, 2)
        v = self.conv2(v)
        v = F.relu(v)
        v = F.max_pool2d(v, 2)
        v = torch.flatten(v, 1)
        v = self.dense1(v)
        v = F.tanh(v)
        v = self.dense2(v)
        v = F.tanh(v)
        v = self.logits(v)
        return F.log_softmax(v)

In [12]:
learning_rate = 1e-4
classifier = MnistConv()
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate, weight_decay=0.1)
batch_size = 2000

for epoch in range(500):
    for i in range((train_feat.shape[0]//batch_size)):
        optimizer.zero_grad()
        batch_samples = train_feat[i * batch_size:(i + 1) * batch_size, :]
        batch_labels = train_labels[i * batch_size:(i + 1) * batch_size]
        # print(batch_samples)
        outputs = classifier(batch_samples)
        loss = F.cross_entropy(outputs, batch_labels)
        loss.backward()
        optimizer.step()
#        print(loss.item())

    predicted_labels = torch.argmax(classifier(test_feat), 1)
    corrects = torch.eq(predicted_labels, test_labels).float().sum()
    print(f'accuracy (epoch {epoch}): {corrects/predicted_labels.shape[0]}')




  return func(*args, **kwargs)


accuracy (epoch 0): 0.9181111454963684
accuracy (epoch 1): 0.9448888897895813
accuracy (epoch 2): 0.9542222619056702
accuracy (epoch 3): 0.9610000252723694
accuracy (epoch 4): 0.9678888916969299
accuracy (epoch 5): 0.9725555777549744
accuracy (epoch 6): 0.975777804851532
accuracy (epoch 7): 0.9786666631698608
accuracy (epoch 8): 0.9811111092567444
accuracy (epoch 9): 0.9821110963821411
accuracy (epoch 10): 0.9824444651603699
accuracy (epoch 11): 0.9818888902664185
accuracy (epoch 12): 0.9816666841506958
accuracy (epoch 13): 0.9822222590446472
accuracy (epoch 14): 0.9825555682182312
accuracy (epoch 15): 0.9819999933242798
accuracy (epoch 16): 0.9819999933242798
accuracy (epoch 17): 0.9819999933242798


KeyboardInterrupt: 