In [63]:
import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Lambda
from tqdm import tqdm

In [64]:
# Preprocessing pipeline: convert each image to a tensor, then flatten it
# into a single 784-element vector so it can feed a fully connected layer.
flatten = Lambda(lambda image: image.view(784))
transform = Compose([ToTensor(), flatten])

# Both splits live under the current directory and are downloaded if missing.
data_train = MNIST(root="./", train=True, download=True, transform=transform)
data_test = MNIST(root="./", train=False, download=True, transform=transform)

In [65]:
# Sanity check: the image tensor of the first training sample should be flat.
first_image = data_train[0][0]
first_image.shape

torch.Size([784])

In [66]:
class MNISTModel(nn.Module):
    """Fully connected classifier for flattened MNIST digits.

    The network is a one-hidden-layer MLP (784 -> 512 -> 10) that produces raw
    class logits. The loss function and the optimizer are stored on the model,
    so a single optimisation step is one call to :meth:`fit`.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )
        # CrossEntropyLoss consumes raw logits, so `layers` has no softmax.
        self.loss = nn.CrossEntropyLoss()
        # Adam with default hyperparameters over all parameters registered above.
        self.optimizer = optim.Adam(self.parameters())

    def forward(self, X):
        """Return the class logits produced by the layer stack for ``X``.

        Args:
            X: Input tensor whose last dimension has 784 features.

        Returns:
            Tensor of logits whose last dimension has 10 entries.
        """
        return self.layers(X)

    def predict(self, X):
        """Return the predicted class index for every row of ``X``.

        Runs in eval mode with gradient tracking disabled, then restores the
        training/eval mode the module was in before the call.

        Args:
            X: Input tensor whose last dimension has 784 features.

        Returns:
            Integer tensor holding the argmax over the logits' last dimension.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Call the module (not .forward) so registered hooks still run.
                return torch.argmax(self(X), dim=-1)
        finally:
            # Leave the module in whatever mode the caller had set.
            self.train(was_training)

    def fit(self, X, Y):
        """Run one optimisation step on a batch and return its loss.

        Args:
            X: Input batch whose last dimension has 784 features.
            Y: Target class indices, one per row of ``X``.

        Returns:
            The batch loss as a Python float.
        """
        # Make sure training-only behaviour is active even after an eval() call.
        self.train()
        self.optimizer.zero_grad()
        y_pred = self(X)
        loss = self.loss(y_pred, Y)
        loss.backward()
        self.optimizer.step()
        return loss.item()

In [67]:
# Seed PyTorch's global RNG before building the model so the initial weights
# are the same on every Restart & Run All. Shuffling by DataLoaders that use
# the default generator is presumably fixed by this too; confirm if exact
# batch order matters.
SEED = 0
torch.manual_seed(SEED)
mnist_model = MNISTModel()

In [68]:
BATCH_SIZE = 16
# Training batches are reshuffled on every pass over the data.
dataloader_train = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation results do not depend on sample order, so keep the order fixed
# to make the test pass deterministic.
dataloader_test = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=False)

In [69]:
EPOCHS = 5
# The number of batches per epoch is constant, so look it up once.
num_batches = len(dataloader_train)

for epoch in range(EPOCHS):
    progress = tqdm(dataloader_train, desc=f"FITTING EPOCH {epoch}")
    # Accumulate the per-batch losses returned by each optimisation step.
    total_loss = 0
    for inputs, labels in progress:
        total_loss = total_loss + mnist_model.fit(inputs, labels)
    # Report the mean of the per-batch losses for this epoch.
    total_loss = total_loss / num_batches
    print(f"EPOCH {epoch}: {total_loss:.4f}")

FITTING EPOCH 0: 100%|██████████| 3750/3750 [00:15<00:00, 241.54it/s]


EPOCH 0: 0.2003


FITTING EPOCH 1: 100%|██████████| 3750/3750 [00:17<00:00, 219.57it/s]


EPOCH 1: 0.0807


FITTING EPOCH 2: 100%|██████████| 3750/3750 [00:17<00:00, 210.88it/s]


EPOCH 2: 0.0531


FITTING EPOCH 3: 100%|██████████| 3750/3750 [00:17<00:00, 211.95it/s]


EPOCH 3: 0.0393


FITTING EPOCH 4: 100%|██████████| 3750/3750 [00:17<00:00, 210.64it/s]

EPOCH 4: 0.0280





In [73]:
# Measure classification accuracy on the held-out test set.
correct = 0
total = 0
for xs, ys in dataloader_test:
    y_pred = mnist_model.predict(xs)
    # .item() keeps the running count a plain Python int instead of a tensor.
    correct += (ys == y_pred).sum().item()
    # Count the samples actually seen; the last batch may be smaller than
    # BATCH_SIZE, so len(dataloader) * BATCH_SIZE would over-count.
    total += ys.size(0)
acc = correct / total
print(f"ACCURACY: {acc}")

ACCURACY: 0.9790999889373779
