## Introduction
In this notebook, a simple model is trained while the following graphs are visualized:
* Line graphs: training loss, validation loss and accuracy as a function of epoch
* Scatter plot: validation loss for individual validation samples
* A sample image, selectable by clicking a sample on the scatter plot
* Bar graph: class probabilities for the selected sample

The MNIST dataset is used in this notebook to facilitate fast iteration on plot functionality updates.

While writing this notebook, the functionality developed in [03_architecture.ipynb](./03_architecture.ipynb) was expanded directly in [trainplotkit/subplots.py](../trainplotkit/subplots.py) using the workflow described in [development_workflow.ipynb](./development_workflow.ipynb).

## Imports

In [1]:
import sys, importlib
import torch
from torch import nn, optim, Tensor
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torcheval.metrics import MulticlassAccuracy

sys.path.append('..')
from trainplotkit.layout import place_subplots
from trainplotkit.plotgrid import PlotGrid
from trainplotkit.subplots.basic import TrainingCurveSP, MetricSP, ValidLossSP, ImageSP, ClassProbsSP

In [2]:
def reload_imports():
    importlib.reload(sys.modules['trainplotkit.layout'])
    importlib.reload(sys.modules['trainplotkit.plotgrid'])
    importlib.reload(sys.modules['trainplotkit.subplots.basic'])
    global place_subplots, PlotGrid, TrainingCurveSP, MetricSP, ValidLossSP, ImageSP, ClassProbsSP
    from trainplotkit.layout import place_subplots
    from trainplotkit.plotgrid import PlotGrid
    from trainplotkit.subplots.basic import TrainingCurveSP, MetricSP, ValidLossSP, ImageSP, ClassProbsSP

## Data preparation
Based on [PyTorch Basic MNIST Example](https://github.com/pytorch/examples/blob/main/mnist/main.py)

In [3]:
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('../data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=64, num_workers=1, pin_memory=True, shuffle=True)
test_loader  = torch.utils.data.DataLoader(dataset2, batch_size=500, num_workers=1, pin_memory=True, shuffle=False)

In [4]:
dataset1[0][0].shape

torch.Size([1, 28, 28])

## Model
Based on [PyTorch Basic MNIST Example](https://github.com/pytorch/examples/blob/main/mnist/main.py)

In [5]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

In [6]:
class SmallNet(nn.Module):
    """Smaller model to iterate faster while developing plots"""
    def __init__(self):
        super(SmallNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, 3, 2)
        self.conv2 = nn.Conv2d(8, 8, 3, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(72, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
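
The `in_features` values of `fc1` in the two models (9216 and 72) follow from the tensor shapes after the convolution and pooling layers. A minimal sketch of that arithmetic in plain Python, assuming no padding and a dilation of 1 as in the layers above; `conv_out` and `flat_features` are illustrative helper names, not part of PyTorch or `trainplotkit`:

```python
def conv_out(size: int, kernel: int, stride: int) -> int:
    """Output length of an unpadded convolution or pooling window along one dimension."""
    return (size - kernel) // stride + 1

def flat_features(size: int, convs: list[tuple[int, int, int]], pool: int) -> int:
    """Flattened feature count after the given (out_channels, kernel, stride) convolutions
    followed by one max pool whose kernel and stride both equal `pool`."""
    channels = 1  # MNIST images have a single channel
    for channels, kernel, stride in convs:
        size = conv_out(size, kernel, stride)
    size = conv_out(size, pool, pool)
    return channels * size * size

# Net: 28 -> 26 -> 24 -> pool -> 12, with 64 channels
net_features = flat_features(28, [(32, 3, 1), (64, 3, 1)], pool=2)
# SmallNet: 28 -> 13 -> 6 -> pool -> 3, with 8 channels
small_features = flat_features(28, [(8, 3, 2), (8, 3, 2)], pool=2)
print(net_features, small_features)
```

The stride of 2 in `SmallNet` is what shrinks the flattened vector from 9216 to 72 features, which accounts for most of the speed-up.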

## Training loop
* Based on [PyTorch Basic MNIST Example](https://github.com/pytorch/examples/blob/main/mnist/main.py)
* Small edits to add calls to `PlotGrid` methods

In [7]:
def train_epoch(model, device, train_loader, optimizer, pg:PlotGrid):
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        pg.after_batch(training=True, inputs=data, targets=target, predictions=output, loss=loss)
    pg.after_epoch(training=True)

def test_epoch(model, device, test_loader, pg:PlotGrid):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.nll_loss(output, target)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            pg.after_batch(training=False, inputs=data, targets=target, predictions=output, loss=loss)
    pg.after_epoch(training=False)

def fit(n_epochs, model, device, train_loader, test_loader, optimizer, scheduler, pg:PlotGrid):
    pg.before_fit()
    for _ in range(n_epochs):
        train_epoch(model, device, train_loader, optimizer, pg)
        test_epoch(model, device, test_loader, pg)
        scheduler.step()
    pg.after_fit()
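
The order in which the loop above calls the `PlotGrid` hooks can be checked without a model. The sketch below uses `RecordingGrid`, a hypothetical stand-in written for this illustration (it is not part of `trainplotkit`), and mirrors only the control flow of `fit`, `train_epoch` and `test_epoch`:

```python
class RecordingGrid:
    """Hypothetical stand-in for PlotGrid that only records which hooks were called."""
    def __init__(self):
        self.calls = []

    def before_fit(self):
        self.calls.append("before_fit")

    def after_batch(self, training, **kwargs):
        self.calls.append(f"after_batch(training={training})")

    def after_epoch(self, training):
        self.calls.append(f"after_epoch(training={training})")

    def after_fit(self):
        self.calls.append("after_fit")

def fit_skeleton(n_epochs, n_train_batches, n_valid_batches, pg):
    """Same hook sequence as the training loop above, with the model work removed."""
    pg.before_fit()
    for _ in range(n_epochs):
        for _ in range(n_train_batches):
            pg.after_batch(training=True)
        pg.after_epoch(training=True)
        for _ in range(n_valid_batches):
            pg.after_batch(training=False)
        pg.after_epoch(training=False)
    pg.after_fit()

grid = RecordingGrid()
fit_skeleton(n_epochs=1, n_train_batches=2, n_valid_batches=1, pg=grid)
print(grid.calls)
```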

In [8]:
logits = Tensor([[-1,-2,-3,-4],[5,6,7,8],[3,4,5,6]])
torch.softmax(logits, dim=1)

tensor([[0.6439, 0.2369, 0.0871, 0.0321],
        [0.0321, 0.0871, 0.2369, 0.6439],
        [0.0321, 0.0871, 0.2369, 0.6439]])
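
The last two rows of this output are identical because softmax is unchanged when a constant is added to every logit in a row. The same property means that applying softmax to `log_softmax` outputs, such as those returned by the models above, recovers the class probabilities. A plain-Python sketch of both facts (the helper functions are written for this illustration and do not call PyTorch):

```python
import math

def softmax(row):
    """Numerically stable softmax of a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(row):
    """Log-probabilities: each logit minus the log of the summed exponentials."""
    m = max(row)
    log_sum = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_sum for v in row]

a = softmax([5, 6, 7, 8])
b = softmax([3, 4, 5, 6])               # same logits shifted by -2
c = softmax(log_softmax([5, 6, 7, 8]))  # softmax applied to log-probabilities
print([round(p, 4) for p in a])
```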

## Define plots and run training loop

In [9]:
reload_imports()

In [10]:
device = 'cuda'
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

# Subplots
batch_loss_fn = lambda preds,targs: F.nll_loss(preds,targs,reduction='none')
probs_fn = lambda preds: torch.softmax(preds, dim=1)
sps = [
    TrainingCurveSP(colspan=2), 
    ValidLossSP(batch_loss_fn, remember_past_epochs=True, colspan=2), 
    ImageSP(dataset2, rowspan=2),
    MetricSP("Accuracy", MulticlassAccuracy(), colspan=2), 
    ClassProbsSP(probs_fn, remember_past_epochs=True, class_names=[f"{i}" for i in range(10)], colspan=2),
]
pg = PlotGrid(num_grid_cols=5, subplots=sps)
pg.show()

[Output: interactive Plotly `FigureWidget` (repr truncated); its traces include "Training loss" and "Validation loss"]

In [11]:
n_epochs = 5
fit(n_epochs, model, device, train_loader, test_loader, optimizer, scheduler, pg)
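
For reference, the `StepLR(optimizer, step_size=1, gamma=0.7)` schedule configured above multiplies the learning rate by 0.7 after every epoch. A small sketch of the resulting values over the 5 epochs, computed in plain Python; `step_lr` is an illustrative helper, not a PyTorch function:

```python
def step_lr(base_lr: float, gamma: float, step_size: int, epoch: int) -> float:
    """Learning rate in effect during a 0-indexed epoch under a step decay schedule."""
    return base_lr * gamma ** (epoch // step_size)

lrs = [step_lr(base_lr=1.0, gamma=0.7, step_size=1, epoch=e) for e in range(5)]
print([round(lr, 4) for lr in lrs])
```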

### Saving a static copy
Once you are done interacting with the figure, add a static copy to the notebook DOM so that the notebook can be shared.

In [19]:
pg.show_static()

## Debugging
* Two issues were discovered by recalculating predictions and loss for individual samples:
  * Initially, some of the subplots did not call `.detach()`
  * Initially, the validation set was shuffled after each epoch, causing strange behavior in the `ValidLossSP` and `ClassProbsSP` subplots
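
The shuffling issue can be illustrated without PyTorch. The sketch below assumes that per-sample values are stored in the order in which validation batches arrive, so that position `i` is expected to refer to the same sample in every epoch; the helper names are hypothetical and do not come from `trainplotkit`:

```python
import random

def epoch_order(n, shuffle, rng):
    """Dataset indices in the order a loader would yield them for one epoch."""
    order = list(range(n))
    if shuffle:
        rng.shuffle(order)
    return order

def stored_position_to_sample(n, shuffle, n_epochs, seed=0):
    """For each epoch, the dataset index that ends up at each stored position."""
    rng = random.Random(seed)
    return [epoch_order(n, shuffle, rng) for _ in range(n_epochs)]

fixed = stored_position_to_sample(10, shuffle=False, n_epochs=3)
shuffled = stored_position_to_sample(10, shuffle=True, n_epochs=3)

# Without shuffling, position i holds sample i in every epoch
stable = all(order == list(range(10)) for order in fixed)
# With shuffling, every epoch still covers all samples, but in a different order
is_permutation = all(sorted(order) == list(range(10)) for order in shuffled)
print(stable, is_permutation, shuffled[0][:3], shuffled[1][:3])
```

With `shuffle=False` on the validation loader, as in the data preparation cell above, a clicked point can be mapped back to its dataset index by position alone.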

In [13]:
if False:
    it = iter(test_loader)
    data, target = next(it)
    data.shape, target.shape, [int(i) for i in target]

In [14]:
data = torch.cat([dataset2[i][0][None] for i in range(20)], dim=0)
target = Tensor([dataset2[i][1] for i in range(20)]).to(torch.int64)
data.shape, target.shape, [int(i) for i in target]


(torch.Size([20, 1, 28, 28]),
 torch.Size([20]),
 [7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4])

In [15]:
with torch.no_grad():
    data, target = data.to(device), target.to(device)
    output = model(data)
    loss = F.nll_loss(output, target, reduction='none')

In [16]:
idx=0
target[idx], output[idx:idx+1], torch.softmax(output[idx:idx+1], dim=1), loss[idx]

(tensor(7, device='cuda:0'),
 tensor([[-24.9671, -22.8873, -19.9265, -19.0947, -25.5827, -27.2579, -33.4926,
            0.0000, -27.5218, -20.7243]], device='cuda:0'),
 tensor([[1.4353e-11, 1.1486e-10, 2.2184e-09, 5.0965e-09, 7.7546e-12, 1.4523e-12,
          2.8468e-15, 1.0000e+00, 1.1154e-12, 9.9894e-10]], device='cuda:0'),
 tensor(-0., device='cuda:0'))

In [17]:
sps[4].probs[-1][idx]

[1.4352536520378933e-11,
 1.1485572215530482e-10,
 2.2183626047223015e-09,
 5.0965218711951366e-09,
 7.754569555928903e-12,
 1.4523319612907981e-12,
 2.8468228903383687e-15,
 1.0,
 1.1154269782126525e-12,
 9.989392646403417e-10]

### Updating subplot titles
Subplot title annotations are created and positioned upon subplot creation using the `subplot_titles` parameter passed to `make_subplots`. This was a quick experiment to see how these titles can be updated later in response to user interactions, and how multi-line titles are displayed.

In [18]:
pg.widget.layout.annotations[2].text = 'Input: sample 0<br>Target=7'

### Showing images the right way up in static figure
* Sometimes after using the Plotly controls on the widget (especially zoom and auto-scale) before calling `PlotGrid.show_static()`, the images were displayed upside down in the static figure
* After some experimentation, it was found that simply setting the `yaxis.range` of the appropriate subplot to the value already assigned to it triggers enough internal state updates to ensure that images are shown the right way up in the static figure
* This has the side effect of setting `autorange` to `False` for the particular axis, but this is something we can live with in terms of UX, because it only happens after training has completed and all other interactions at that point are still user-friendly

In [21]:
yaxis3 = pg.widget.layout['yaxis3']
yaxis3

layout.YAxis({
    'anchor': 'x3', 'range': [27.5, -0.5], 'title': {'text': ''}
})

In [25]:
yaxis3['scaleanchor'], yaxis3['constrain'], yaxis3['autorange']

('x3', 'domain', False)

In [27]:
yaxis3['range'] = yaxis3['range']

In [28]:
pg.show_static()