In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor


## download dataset

In [None]:
def download_mnist_datasets(root="data"):
    """Download (if not already present) and return the MNIST train/validation splits.

    Args:
        root: directory where torchvision stores / looks for the MNIST files.
            Defaults to "data", relative to the current working directory.

    Returns:
        (train_data, validation_data): the training split (``train=True``) and
        the test split (``train=False``), the latter used here as validation data.
    """
    # ToTensor converts each image into a float tensor with values scaled to [0, 1].
    train_data = datasets.MNIST(
        root=root,
        download=True,
        train=True,
        transform=ToTensor(),
    )

    validation_data = datasets.MNIST(
        root=root,
        download=True,
        train=False,
        transform=ToTensor(),
    )
    return train_data, validation_data

In [None]:
# Fetch both MNIST splits (downloaded on first run, read from disk afterwards).
train_data, validation_data = download_mnist_datasets()

## data loader

In [None]:
# Feed the training set to the model in mini-batches of BATCH_SIZE samples
# instead of all at once, which keeps memory usage bounded.
BATCH_SIZE = 128
train_data_loader = DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
)


## build model

In [None]:
class FeedForwardNet(nn.Module):
    """Small fully connected classifier for 28x28 MNIST images.

    Pipeline: flatten -> Linear(28*28 -> 256) -> ReLU -> Linear(256 -> 10)
    -> softmax over the 10 digit classes.

    NOTE: ``forward`` returns softmax *probabilities*, not raw logits.
    ``nn.CrossEntropyLoss`` applies log-softmax internally and expects logits,
    so do not feed this model's output straight into it; use a loss that
    accepts probabilities (e.g. NLL on their log) instead.
    """

    def __init__(self):
        super().__init__()
        # The attribute names and the layer order inside the Sequential define
        # the state_dict keys (e.g. "dense_layers.0.weight"); keep them stable
        # so previously saved checkpoints still load.
        self.flatten = nn.Flatten()
        hidden_units = 256
        num_classes = 10
        self.dense_layers = nn.Sequential(
            # each flattened image contributes 28 * 28 input features
            nn.Linear(in_features=28 * 28, out_features=hidden_units),
            nn.ReLU(),
            nn.Linear(in_features=hidden_units, out_features=num_classes),
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_data):
        """Map a batch of images to per-class probabilities of shape (batch, 10).

        Assumes dim 0 of ``input_data`` is the batch dimension and the remaining
        dims flatten to 28 * 28 values (e.g. (batch, 1, 28, 28)).
        """
        return self.softmax(self.dense_layers(self.flatten(input_data)))


In [None]:
# Prefer the GPU when one is available; the model below is placed on `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}")

# Seed PyTorch's random number generators so the randomly initialised weights
# (and therefore the whole training run) are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

feed_forward_net = FeedForwardNet().to(device)


## train

In [None]:
def train_one_epoch(model, data_loader, loss_fn, optimiser, device):
    """Run one full pass over ``data_loader``, updating ``model`` in place.

    Args:
        model: the network to train.
        data_loader: iterable yielding ``(inputs, targets)`` batches.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        optimiser: optimiser over ``model``'s parameters.
        device: device the batches are moved to before the forward pass.

    Returns:
        The mean training loss over all batches as a float, or ``None`` if
        ``data_loader`` yielded no batches. The value is also printed.
    """
    # Make sure training-mode behaviour is active in case the model was left in
    # eval mode by an earlier evaluation.
    model.train()

    total_loss = 0.0
    num_batches = 0
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # forward pass and loss
        predictions = model(inputs)
        loss = loss_fn(predictions, targets)

        # PyTorch accumulates gradients across backward() calls, so clear them
        # before back-propagating this batch's loss, then update the weights.
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Previously this raised UnboundLocalError on `loss`.
        print("Loss: n/a (empty data loader)")
        return None

    mean_loss = total_loss / num_batches
    print(f"Loss: {mean_loss}")
    return mean_loss


def train(model, data_loader, loss_fn, optimiser, device, epochs):
    """Train ``model`` for ``epochs`` full passes over ``data_loader``.

    Each pass is delegated to ``train_one_epoch``; an "Epoch N" header and a
    separator line are printed around it, and a final message when all
    epochs have finished.
    """
    for epoch_number in range(1, epochs + 1):
        print(f"Epoch {epoch_number}")
        train_one_epoch(model, data_loader, loss_fn, optimiser, device)
        print("--------------")
    print("Train is done")

In [None]:

# FeedForwardNet.forward already ends with a softmax, so the model outputs
# probabilities. nn.CrossEntropyLoss expects raw (unnormalised) logits and
# applies log-softmax itself; feeding it probabilities would apply softmax
# twice, squashing the gradients so the loss can never approach zero.
# Taking the log of the probabilities and using the negative log-likelihood
# loss gives the proper cross-entropy for probability inputs.
def loss_fn(probabilities, targets):
    """Cross-entropy loss for inputs that are already softmax probabilities."""
    # clamp avoids log(0) = -inf if a probability underflows to zero
    log_probabilities = torch.log(probabilities.clamp_min(1e-12))
    return nn.functional.nll_loss(log_probabilities, targets)


LEARNING_RATE = 0.01
EPOCHS = 10

optimiser = torch.optim.Adam(feed_forward_net.parameters(), lr=LEARNING_RATE)
train(feed_forward_net, train_data_loader, loss_fn, optimiser, device, EPOCHS)

In [None]:
# Persist only the learned parameters: the state_dict is a dictionary mapping
# each layer's parameter/buffer name to its tensor.
model_state = feed_forward_net.state_dict()
torch.save(model_state, "feedfwnet.pth")


## predictions


In [None]:
# Entry i is the human-readable label of model output class i; for MNIST the
# labels are simply the digits "0" through "9".
class_mapping = [str(digit) for digit in range(10)]

In [None]:
def predict(model, input, target, class_mapping):
    """Classify one example and return ``(predicted_label, expected_label)``.

    Args:
        model: trained network; only the first row of its output is used.
        input: the example tensor. It is passed to ``model`` as-is, so it must
            already have whatever leading dimension the model treats as batch.
        target: integer class index of the true label.
        class_mapping: sequence mapping class index -> human-readable label.

    Note: the model is switched to eval mode and left in it.
    """
    # Eval mode changes the behaviour of layers such as dropout and batch-norm.
    # Call model.train() before training again.
    model.eval()

    # Run on the same device as the model's parameters; otherwise a model moved
    # to CUDA fails when given a CPU tensor. Parameter-free models are left as-is.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input = input.to(first_param.device)

    # No gradients are needed for inference, which saves memory and time.
    with torch.no_grad():
        predictions = model(input)
        # predictions[0] holds the scores of the first sample; argmax picks its class.
        predicted_index = predictions[0].argmax(0).item()

    predicted = class_mapping[predicted_index]
    expected = class_mapping[target]
    return predicted, expected

In [None]:
# Batched loader over the validation split, for evaluating on the whole set.
validation_data_loader = DataLoader(validation_data, batch_size=BATCH_SIZE)

# Take the first validation sample. Each dataset item is an (image, label)
# pair, so unpack it once instead of fetching the same sample twice.
input, target = validation_data[0]

# Make an inference and show the result.
predicted, expected = predict(feed_forward_net, input, target, class_mapping)
print(f"Predicted: '{predicted}', expected: '{expected}'")
