## Introduction
In this notebook, a simple model is trained while the following graphs are visualized:
* Line graphs: training loss, validation loss and accuracy as a function of epoch
* Scatter plot: Validation loss for individual validation samples
* Sample image: the validation sample selected by clicking its point on the scatter plot
* Bar graph: class probabilities for the selected sample

The MNIST dataset is used in this notebook to facilitate fast iteration on plot functionality updates.

While writing this notebook, the functionality developed in [03_architecture.ipynb](./03_architecture.ipynb) was expanded directly in [subplots.py](../trainplotkit/subplots.py) using the workflow described in [development_workflow.ipynb](./development_workflow.ipynb).

## Imports

In [1]:
import importlib
import sys

import torch
import torch.nn.functional as F
from torch import nn, optim, Tensor
from torch.optim.lr_scheduler import StepLR
from torcheval.metrics import MulticlassAccuracy
from torchvision import datasets, transforms

sys.path.append('..')  # make the local trainplotkit package importable
from trainplotkit.layout import place_subplots
from trainplotkit.subplots import PlotGrid, SubPlot, TrainingCurveSP, MetricSP, ValidLossSP, ImageSP

In [2]:
def reload_imports():
    """Reload the trainplotkit modules and re-bind the names imported from them.

    Lets edits to ``trainplotkit/layout.py`` and ``trainplotkit/subplots.py``
    take effect without restarting the kernel.
    """
    importlib.reload(sys.modules['trainplotkit.layout'])
    importlib.reload(sys.modules['trainplotkit.subplots'])
    # Reloading re-executes the modules and creates new objects; the notebook's
    # globals must be re-bound, otherwise they keep pointing at the old ones.
    # Every sub-plot class used in the notebook (including ImageSP) is listed here.
    global place_subplots, PlotGrid, SubPlot, TrainingCurveSP, MetricSP, ValidLossSP, ImageSP
    from trainplotkit.layout import place_subplots
    from trainplotkit.subplots import PlotGrid, SubPlot, TrainingCurveSP, MetricSP, ValidLossSP, ImageSP

## Data preparation
Based on [PyTorch Basic MNIST Example](https://github.com/pytorch/examples/blob/main/mnist/main.py)

In [3]:
# (mean, std) normalization constants taken from the upstream PyTorch MNIST example
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('../data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=64, num_workers=1, pin_memory=True, shuffle=True)
# Validation data is kept in its natural order. Evaluation does not need shuffling,
# and the per-sample validation-loss scatter plot and the image sub-plot (built
# from dataset2) presumably identify samples by position. Shuffling would scramble
# that mapping every epoch. TODO confirm against ValidLossSP/ImageSP.
test_loader  = torch.utils.data.DataLoader(dataset2, batch_size=500, num_workers=1, pin_memory=True, shuffle=False)

In [4]:
# Shape of the first training image after ToTensor: (channels, height, width)
dataset1[0][0].size()

torch.Size([1, 28, 28])

## Model
Based on [PyTorch Basic MNIST Example](https://github.com/pytorch/examples/blob/main/mnist/main.py)

In [5]:
class Net(nn.Module):
    """Reference CNN from the PyTorch basic MNIST example.

    Two 3x3 convolutions, a 2x2 max-pool and two fully connected layers with
    dropout. ``forward`` returns log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the upstream order so that parameter
        # initialisation consumes the RNG identically.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12 for a 1x28x28 input (28 -> 26 -> 24 -> pool 12)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities of shape (N, 10)."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        features = self.dropout1(F.max_pool2d(features, 2))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

In [6]:
class SmallNet(nn.Module):
    """Lightweight variant of ``Net`` for fast iteration.

    Uses stride-2 convolutions with fewer channels, then the same pool /
    dropout / fully connected head. ``forward`` returns log-probabilities
    over 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Creation order matches the original so parameter init is unchanged.
        self.conv1 = nn.Conv2d(1, 8, 3, 2)
        self.conv2 = nn.Conv2d(8, 16, 3, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 144 = 16 channels * 3 * 3 for a 1x28x28 input (28 -> 13 -> 6 -> pool 3)
        self.fc1 = nn.Linear(144, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities of shape (N, 10)."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        features = self.dropout1(F.max_pool2d(features, 2))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

## Training loop
* Based on [PyTorch Basic MNIST Example](https://github.com/pytorch/examples/blob/main/mnist/main.py)
* Small edits to add calls to `PlotGrid` methods

In [7]:
def train_epoch(model, device, train_loader, optimizer, pg:PlotGrid):
    """Run one optimisation pass over ``train_loader``, reporting each batch to ``pg``.

    Args:
        model: network returning log-probabilities (consumed by ``F.nll_loss``).
        device: device every batch is moved to.
        train_loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        pg: plot grid receiving ``after_batch``/``after_epoch`` callbacks.
    """
    model.train()
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        predictions = model(inputs)
        batch_loss = F.nll_loss(predictions, targets)
        batch_loss.backward()
        optimizer.step()
        pg.after_batch(training=True, inputs=inputs, targets=targets, predictions=predictions, loss=batch_loss)
    pg.after_epoch(training=True)

def test_epoch(model, device, test_loader, pg:PlotGrid):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.nll_loss(output, target)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            pg.after_batch(training=False, inputs=data, targets=target, predictions=output, loss=loss)
    pg.after_epoch(training=False)

def fit(n_epochs, model, device, train_loader, test_loader, optimizer, scheduler, pg:PlotGrid):
    """Alternate training and validation for ``n_epochs``, then finalise the plots.

    Each epoch consists of one training pass, one validation pass and one
    ``scheduler.step()``. ``pg.after_fit()`` is called once after the last epoch.
    """
    epochs = range(n_epochs)
    for _epoch in epochs:
        train_epoch(model, device, train_loader, optimizer, pg)
        test_epoch(model, device, test_loader, pg)
        # the learning-rate schedule advances once per epoch, after validation
        scheduler.step()
    pg.after_fit()

## Run it all

In [12]:
# Re-import trainplotkit so edits made to its source since the imports cell ran are picked up
reload_imports()

In [None]:
# Fall back to the CPU so the notebook also runs on machines without a GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(1)  # seed the global torch RNG (weight init, dropout, shuffling) for reproducible runs
model = SmallNet().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
# Multiply the learning rate by 0.7 every epoch (fit() calls scheduler.step() once per epoch)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

# Sub-plots
def batch_loss_fn(preds, targs):
    """Per-sample NLL loss (no reduction), passed to the per-sample validation-loss plot."""
    return F.nll_loss(preds, targs, reduction='none')

sps = [TrainingCurveSP(), ValidLossSP(batch_loss_fn), MetricSP("Accuracy", MulticlassAccuracy()), ImageSP(dataset2)]
pg = PlotGrid(num_grid_cols=2, subplots=sps)
pg.show()

FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Training loss',
              'type': 'scatter',
              'uid': '9cb66070-ee92-4df4-bc31-afbd154f4b05',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'mode': 'lines+markers',
              'name': 'Validation loss',
              'type': 'scatter',
              'uid': 'bebefacb-ea25-4a1d-ac1c-b571f82cfaac',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'hoverinfo': 'skip',
              'marker': {'color': 'rgba(0,0,0,0.2)', 'line': {'color': 'black', 'width': 2}},
              'mode': 'markers',
              'showlegend': False,
              'type': 'scatter',
              'uid': '9cb5de49-8713-45f6-af2c-3186ed1276ef',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'mode': 'markers',
      

In [16]:
n_epochs = 5  # each epoch: one training pass, one validation pass, one LR-scheduler step
fit(n_epochs=n_epochs, model=model, device=device,
    train_loader=train_loader, test_loader=test_loader,
    optimizer=optimizer, scheduler=scheduler, pg=pg)

In [15]:
# Spot-check another training image; its shape should match the first one's
dataset1[14][0].size()

torch.Size([1, 28, 28])