In [2]:
import torchvision
import torch

In [3]:
torchvision.__version__

'0.18.1+cpu'

In [4]:
torch.__version__

'2.3.1+cpu'

In [5]:
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Lambda
transform = Compose([
    ToTensor(),
    Lambda(lambda image: image / 255),
    Lambda(lambda image: image.view(784))

])
data_train = MNIST(root="./",download=True,train=True,transform=transform)
data_test = MNIST(root="./",download=True,train=False,transform=transform)

In [6]:
data_train[0][0].shape

torch.Size([784])

In [7]:
data_train[0][1]

5

In [8]:
from torch import nn
from torch import optim
import torch
class MNISTModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784,512),
            nn.ReLU(),
            nn.Linear(512,512),
            nn.ReLU(),
            nn.Linear(512,10),
        )
        self.loss = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.parameters())
    def forward(self,X):
        return self.layers(X)
    def predict(self,X):
        with torch.no_grad():
            return torch.argmax(self.forward(X),axis=-1)
    def fit(self,X,Y):
        self.optimizer.zero_grad()
        y_pred = self.forward(X)
        loss = self.loss(y_pred,Y)
        loss.backward()
        self.optimizer.step()
        return loss.item()
    

In [9]:
mnist_model = MNISTModel()

In [10]:
from torch.utils.data import DataLoader
BATCH_SIZE = 16
dataloader_train = DataLoader(data_train,batch_size=BATCH_SIZE,shuffle=True)
dataloader_test = DataLoader(data_test,batch_size=BATCH_SIZE,shuffle=True)

In [11]:
from tqdm import tqdm
EPOCHS = 5
for i in range(EPOCHS):
    total_loss = 0
    for xs, ys in tqdm(dataloader_train,desc=f"FITTING EPOCH {i}"):
        total_loss += mnist_model.fit(xs,ys)
    total_loss / len(dataloader_train)
    print(f"EPOCH {i}: {total_loss:.4f}")

FITTING EPOCH 0: 100%|██████████| 3750/3750 [00:20<00:00, 186.11it/s]


EPOCH 0: 2173.4513


FITTING EPOCH 1: 100%|██████████| 3750/3750 [00:19<00:00, 187.85it/s]


EPOCH 1: 967.3809


FITTING EPOCH 2: 100%|██████████| 3750/3750 [00:19<00:00, 187.85it/s]


EPOCH 2: 663.3571


FITTING EPOCH 3: 100%|██████████| 3750/3750 [00:20<00:00, 183.63it/s]


EPOCH 3: 502.0361


FITTING EPOCH 4: 100%|██████████| 3750/3750 [00:20<00:00, 181.80it/s]

EPOCH 4: 400.8455





In [13]:
correct = 0
for xs, ys in dataloader_test:
   #print(mnist_model.predict(xs))
   y_pred = mnist_model.predict(xs)
   correct += (ys == y_pred).sum()
acc = (correct / len(dataloader_test) * BATCH_SIZE)/255
print(f"ACCURACY: {acc}") 

ACCURACY: 0.972900390625
