In [10]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from tqdm import tqdm

In [11]:
training_data = datasets.MNIST(
    root = "data",
    train = True,
    download = True,
    transform = ToTensor(),
)

test_data = datasets.MNIST(
    root = "data",
    train = False,
    download = True,
    transform = ToTensor(),
)

In [12]:
batch_size = 64

training_dataloader = DataLoader(training_data, batch_size= batch_size)
test_dataloader = DataLoader(test_data, batch_size= batch_size)

for X, y in test_dataloader:
    print(f"Shape of X: [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.type}")
    break

Shape of X: [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) <built-in method type of Tensor object at 0x7f3aab3014e0>


In [13]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


In [14]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, X):
        x = self.flatten(X)
        logits = self.linear_relu_stack(x)

        return logits

model = NeuralNetwork()
print(model)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=10, bias=True)
  )
)


In [15]:
epochs = 15

learning_rate = 1e-3

loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [16]:
def train(dataloader, model, loss_fn, optimizer):
    model.train()

    for X, y in tqdm(dataloader):
        #X, y = X.to(device), y.to(device)

        # predicition 
        pred = model(X)
        loss = loss_fn(pred, y)

        # backprod
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [17]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    model.eval()

    correct = 0

    with torch.no_grad():
        
        for X, y in tqdm(dataloader):
            pred = model(X)
            x = (pred.argmax(1) == y)

            x= x.type(torch.float)

            x= x.sum()

            x = x.item()
            
            correct += x
            
    print(f"Accuracy(%): {correct/size*100}")

In [18]:
for i in range(epochs):
    train(training_dataloader, model, loss_fn, optim)
    test(test_dataloader, model, loss_fn)
    print(i)


100%|██████████| 938/938 [00:07<00:00, 127.76it/s]
100%|██████████| 157/157 [00:00<00:00, 223.06it/s]


Accuracy(%): 95.73
0


100%|██████████| 938/938 [00:07<00:00, 126.41it/s]
100%|██████████| 157/157 [00:00<00:00, 242.36it/s]


Accuracy(%): 97.05
1


100%|██████████| 938/938 [00:07<00:00, 129.79it/s]
100%|██████████| 157/157 [00:00<00:00, 244.58it/s]


Accuracy(%): 97.24000000000001
2


100%|██████████| 938/938 [00:07<00:00, 129.44it/s]
100%|██████████| 157/157 [00:00<00:00, 240.15it/s]


Accuracy(%): 97.37
3


100%|██████████| 938/938 [00:07<00:00, 128.99it/s]
100%|██████████| 157/157 [00:00<00:00, 239.44it/s]


Accuracy(%): 96.44
4


100%|██████████| 938/938 [00:07<00:00, 129.95it/s]
100%|██████████| 157/157 [00:00<00:00, 239.43it/s]


Accuracy(%): 97.77
5


100%|██████████| 938/938 [00:07<00:00, 127.15it/s]
100%|██████████| 157/157 [00:00<00:00, 230.47it/s]


Accuracy(%): 97.92
6


100%|██████████| 938/938 [00:08<00:00, 111.93it/s]
100%|██████████| 157/157 [00:00<00:00, 185.51it/s]


Accuracy(%): 97.78999999999999
7


100%|██████████| 938/938 [00:07<00:00, 117.93it/s]
100%|██████████| 157/157 [00:00<00:00, 214.18it/s]


Accuracy(%): 97.78
8


100%|██████████| 938/938 [00:08<00:00, 116.84it/s]
100%|██████████| 157/157 [00:00<00:00, 211.52it/s]


Accuracy(%): 98.22999999999999
9


100%|██████████| 938/938 [00:08<00:00, 116.58it/s]
100%|██████████| 157/157 [00:00<00:00, 207.45it/s]


Accuracy(%): 97.78
10


100%|██████████| 938/938 [00:07<00:00, 119.96it/s]
100%|██████████| 157/157 [00:00<00:00, 201.39it/s]


Accuracy(%): 98.19
11


100%|██████████| 938/938 [00:08<00:00, 116.29it/s]
100%|██████████| 157/157 [00:00<00:00, 212.44it/s]


Accuracy(%): 98.03
12


100%|██████████| 938/938 [00:09<00:00, 102.36it/s]
100%|██████████| 157/157 [00:00<00:00, 205.07it/s]


Accuracy(%): 98.11
13


100%|██████████| 938/938 [00:08<00:00, 111.75it/s]
100%|██████████| 157/157 [00:00<00:00, 171.74it/s]

Accuracy(%): 98.11999999999999
14





In [19]:
torch.save(model, f"./models/model_adam_epochs:{epochs}.pth")
print("Saving full model")

Saving full model
