In [31]:
import torch
from torch import nn
import numpy as np

import torchvision.datasets as datasets
from torchvision.transforms import ToTensor

In [32]:
# Fetch both MNIST splits into ./data (downloaded if not already present).
# ToTensor() converts every image to a torch tensor when a sample is accessed.
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=ToTensor())  # training split
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=ToTensor())  # test split

In [33]:
from torch.utils.data import DataLoader

# Mini-batch iterators over the two MNIST splits (32 samples per batch).
# Training batches are reshuffled every epoch so the optimizer sees the
# samples in a different order each time.
train_dataloader = DataLoader(mnist_train, batch_size=32, shuffle=True)
# Evaluation does not depend on sample order, so the test split is left
# unshuffled: this avoids needless work and keeps the batch order reproducible.
test_dataloader = DataLoader(mnist_test, batch_size=32, shuffle=False)

In [34]:
# Fully connected classifier: 784 input features -> 100 hidden units (ReLU)
# -> 10 output logits, one per class.
layers = [
    nn.Linear(784, 100),
    nn.ReLU(),
    nn.Linear(100, 10),
]
model = nn.Sequential(*layers)

# Cross-entropy loss computed on the raw logits produced by the model.
criterion = nn.CrossEntropyLoss()

# Adam optimizer over every model parameter, learning rate 0.001.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Train for 10 epochs; after each epoch print the sum of the per-batch losses.
# Explicitly enable training mode so that re-running this cell after the
# evaluation cell (which calls model.eval()) still trains in the right mode.
model.train()
for _ in range(10):
    loss_sum = 0
    for X, y in train_dataloader:
        # Flatten every image into a 784-element vector, the input size of
        # the model's first Linear layer.
        X = X.reshape(-1, 784)

        optimizer.zero_grad()
        y_pred = model(X)
        # CrossEntropyLoss takes the integer class labels directly, so the
        # labels do not need to be one-hot encoded and cast to float first.
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()

        # .item() extracts a Python float, so no autograd graph is kept alive.
        loss_sum += loss.item()

    print(loss_sum)

571.0778968911618
254.76616612635553
177.13995666895062
136.30453493562527
106.59862545889337
85.42660710081691
71.17847316083498
57.68613145544077
48.7996904116153
40.510199388707406


In [37]:
# Measure classification accuracy of the trained model on the test set.
model.eval()  # switch layers to inference behaviour
with torch.no_grad():  # gradients are not needed while evaluating
    accurate = 0
    total = 0
    for X, y in test_dataloader:
        # Flatten every image into a 784-element vector, as in training.
        X = X.reshape(-1, 784)

        y_pred = model(X)
        # The predicted class is the index of the largest logit. Softmax
        # preserves the ordering of the logits, so it is not needed here.
        correct_pred = (y == y_pred.argmax(dim=1))

        total += correct_pred.size(0)
        accurate += correct_pred.sum().item()

    print(f"Accuracy: {(accurate/total)*100:.2f} %")

Accuracy: 97.74 %
