# Example: building and training a classic (fully connected) neural network
In this example, a classic fully connected neural network is created using the `neural` framework. The network is then trained on the MNIST dataset of handwritten digits.
You can skip training and use a pretrained network to see the performance results by setting the `usePretrained` variable to `True`.

In [None]:
# Set to True to skip training and instead load the weights previously saved
# to "classic.pkl" by the training/saving cells further down.
usePretrained: bool = False

In [None]:
import matplotlib.pyplot as plt

In [None]:
import sys
sys.path.append("..")

In [None]:
import numpy as np
import time

from neural import MNIST, Tensor, nn, optim
from utils import *

## Importing MNIST training data

In [None]:
# Load the MNIST training split: images and their digit labels.
rawTrainImages, allTrainLabels = MNIST.get("train")
# Rescale pixels with `normalize(x, 0.5, 0.5)` (presumably (x - 0.5) / 0.5).
# Assuming raw pixel values lie in [0, 1], the result lies in [-1, 1].
allTrainImages = normalize(rawTrainImages, 0.5, 0.5)
del rawTrainImages  # keep only the normalized images in the namespace

## Defining the Neural Network architecture

In [None]:
class Network(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10.

    Hidden layers use ReLU. The output is passed through LogSoftmax along
    dim 1, so the network returns log-probabilities, which is what the
    NLLLoss criterion defined below expects.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept unchanged: saved models
        # ("classic.pkl") and the parameter order seen by the optimizer
        # presumably depend on them.
        self.fc1 = nn.Linear(784, 128)  # 784 inputs = flattened image (presumably 28x28)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)  # one score per digit class
        self.logSoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Map a batch of flattened images to per-class log-probabilities.

        Assumes x has shape (batch, 784); callers flatten images with
        reshape(batch, -1) before calling the model.
        """
        # Both hidden layers apply a freshly constructed ReLU to their output.
        for hiddenLayer in (self.fc1, self.fc2):
            x = nn.ReLU()(hiddenLayer(x))
        return self.logSoftmax(self.fc3(x))


model = Network()

## Choosing training criterion (loss function) and optimizer

In [None]:
# Negative log-likelihood loss. It consumes log-probabilities, which is what
# Network.forward produces; "mean" presumably averages over the batch.
reduction = "mean"
criterion = nn.NLLLoss(reduction=reduction)

# Adam hyper-parameters. They are kept in a dict so the same values can also
# label the loss plot; key order (lr, betas, eps) is preserved for the legend.
optimizerSetup = {
    "lr": 0.001,
    "betas": (0.9, 0.999),
    "eps": 1e-08,
}
optimizer = optim.Adam(model.parameters(), **optimizerSetup)

## Training

### Choosing training parameters

In [None]:
# Number of full passes over the training data, and images per mini-batch.
epochs, batchSize = 5, 300

In [None]:
# Only whole batches are used; any trailing partial batch is dropped.
numBatches = allTrainImages.shape[0] // batchSize
numTraining = int(batchSize * numBatches)

print("\n".join([
    f"Number of epochs: {epochs}",
    f"Batch size: {batchSize}",
    f"Total number of train images: {numTraining}",
    f"Total number of batches: {numBatches}",
]))

# Group the kept samples by batch. Images become
# (numBatches, -1, H, W), where -1 resolves to batchSize when each image is HxW
# (or 1xHxW). Labels become (numBatches, batchSize).
trainImages = allTrainImages[:numTraining].reshape(
    numBatches, -1, allTrainImages.shape[-2], allTrainImages.shape[-1]
)
trainLabels = allTrainLabels[:numTraining].reshape(numBatches, -1)

### Running epochs

In [None]:
if not usePretrained:
    # lossTrack[e, i] holds the loss of batch i during epoch e (used for plotting).
    lossTrack = np.zeros((epochs, numBatches))
    for e in range(epochs):
        # perf_counter is monotonic and high-resolution. time.time() is
        # wall-clock time and can jump if the system clock is adjusted.
        startTime = time.perf_counter()
        for i, (images, labels) in enumerate(zip(trainImages, trainLabels)):
            # Flatten each image so the batch is (batchSize, pixels) for fc1.
            images = images.reshape(images.shape[0], -1)
            optimizer.zeroGrad()  # reset gradients before this batch's backward pass
            out = model(images)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()
            lossTrack[e, i] = loss.item()
        # Epoch finished: report timing once, after all batches have run.
        endTime = time.perf_counter()
        print(f"Finished epoch {e} in {endTime - startTime:.2f}s")
else:
    # Reuse a model saved by an earlier run of this notebook.
    # NOTE(review): the ".pkl" name suggests pickle-based loading; only load
    # files you trust.
    model = nn.Module.load("classic.pkl")

In [None]:
if not usePretrained:
    # One legend line per optimizer hyper-parameter, e.g. "lr = 0.001".
    legend = "\n".join(f"{name} = {value}" for name, value in optimizerSetup.items())
    plotLossTrack([(lossTrack, batchSize, legend)])

### Saving the trained model

In [None]:
# Persist the freshly trained model so later runs can set usePretrained = True.
# Skipped when the model was itself loaded from "classic.pkl".
if not usePretrained:
    nn.Module.save(model, "classic.pkl")

The saved model can be loaded with:
```python
model = nn.Module.load("classic.pkl")
```

## Performance evaluation

In [None]:
# Lazily yields one image at a time for the spot-check cell that follows.
# NOTE(review): these are *training* images the network has already seen, so
# the predictions overstate real-world accuracy; a held-out split would be a
# fairer check.
imgIter = (image for image in allTrainImages)

Run the cell below multiple times to check the model's predictions on different images. (Note: these images come from the training set.)

In [None]:
img = next(imgIter)  # every run of this cell advances to the next image

# The model expects a batch of flattened images, so use a batch of one.
img_ = img.reshape(1, -1)
logps = model(img_)

# The network outputs log-probabilities; exponentiate to get per-class
# probabilities, then drop singleton dimensions for display.
ps = np.exp(logps)
showMNIST(img.squeeze(), ps.squeeze())