## Introduction
In [Lesson 16 of the 2022 fastai course](https://course.fast.ai/Lessons/lesson16.html), at 1:14:28, Jeremy discusses the "colorful dimension" visualization. It is similar to a histogram, with a few changes:
* A time (batch/epoch) dimension is added and becomes the new x-axis
* The value axis (absolute value of activations of a particular module) moves from the x-axis to the y-axis
* The frequency of occurrences moves from the y-axis to the color axis
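
The three changes above can be sketched in a few lines of PyTorch. This is a minimal illustration rather than the `trainplotkit` implementation: random tensors stand in for real activations, and the bin count (40) and value range (0 to 10) are assumptions borrowed from the course notebook.

```python
import torch

torch.manual_seed(0)
n_batches, n_bins, max_val = 50, 40, 10.0

hists = []
for step in range(n_batches):
    # Stand-in for one batch of activations from a single module;
    # the spread grows over time so the picture changes along the x-axis.
    acts = torch.randn(256, 8, 14, 14) * (1 + step / 10)
    # Histogram of absolute values: one count per value bin.
    # Values above max_val are ignored by histc.
    hists.append(acts.abs().histc(bins=n_bins, min=0, max=max_val))

# Rows are value bins (y-axis), columns are batches (x-axis).
img = torch.stack(hists).t()
# log1p compresses the counts so sparsely populated bins stay visible as color.
img = img.log1p()
print(img.shape)  # torch.Size([40, 50])
```

Passing `img` to an image plot with the origin at the bottom, for example `matplotlib.pyplot.imshow(img, origin='lower')`, gives the colorful dimension view.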

This notebook aims to recreate this visualization as a `trainplotkit` subplot using the same data and model. Reference plots can be found in [10_activations.ipynb](https://github.com/fastai/course22p2/blob/master/nbs/10_activations.ipynb) in the course material.

Adding this visualization to `trainplotkit` provides two benefits:
* The visualization can be updated in real time after every batch / epoch
* If `trainplotkit` is already being used, this visualization can be added by adding a single element to an existing list of subplots

## Imports

In [1]:
import sys, importlib
import torch
from torch import nn, optim, Tensor
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from datasets import load_dataset, DatasetDict, Dataset
from torch.utils.data import DataLoader, default_collate
from torcheval.metrics import MulticlassAccuracy
from operator import itemgetter

sys.path.append('..')
from trainplotkit.plotgrid import PlotGrid
from trainplotkit.subplots.basic import TrainingCurveSP, MetricSP, ValidLossSP, ImageSP, ClassProbsSP
from trainplotkit.subplots.activations import ColorfulDimensionSP, ActivationStatsSP

In [2]:
def reload_imports():
    importlib.reload(sys.modules['trainplotkit.layout'])
    importlib.reload(sys.modules['trainplotkit.plotgrid'])
    importlib.reload(sys.modules['trainplotkit.subplots.basic'])
    importlib.reload(sys.modules['trainplotkit.subplots.activations'])
    global PlotGrid, TrainingCurveSP, MetricSP, ValidLossSP, ImageSP, ClassProbsSP, ColorfulDimensionSP, ActivationStatsSP
    from trainplotkit.plotgrid import PlotGrid
    from trainplotkit.subplots.basic import TrainingCurveSP, MetricSP, ValidLossSP, ImageSP, ClassProbsSP
    from trainplotkit.subplots.activations import ColorfulDimensionSP, ActivationStatsSP

## miniai snippets
During the course, a `miniai` high-level framework is developed on top of PyTorch. This section contains snippets from this framework required to run this experiment.

In [3]:
def get_dls(train_ds, valid_ds, bs, **kwargs):
    return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),
            DataLoader(valid_ds, batch_size=bs*2, **kwargs))

In [4]:
def collate_dict(ds:Dataset):
    get = itemgetter(*ds.features)
    def _f(b): return get(default_collate(b))
    return _f

In [5]:
class DataLoaders:
    def __init__(self, *dls): self.train,self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd:DatasetDict, batch_size, as_tuple=True, **kwargs):
        f = collate_dict(dd['train'])
        return cls(*get_dls(*dd.values(), bs=batch_size, collate_fn=f, **kwargs))

In [6]:
def inplace(f):
    def _f(b):
        f(b)
        return b
    return _f

@inplace
def transformi(b): b['image'] = [TF.to_tensor(o) for o in b['image']]

## Data

In [7]:
bs = 256  # Modified from 1024 used in the course for a smaller GPU 

In [8]:
ds_name = "fashion_mnist"
dsd = load_dataset(ds_name)
dsd

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})

In [9]:
dsd['train'][0]

{'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
 'label': 9}

In [10]:
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=4)

## Models

### Baseline
This is the model used in [10_activations.ipynb](https://github.com/fastai/course22p2/blob/master/nbs/10_activations.ipynb) in the fastai course, before Jeremy introduces the initialization and normalization techniques that improve training dynamics. It produces a spiky colorful dimension histogram and high values on the dead chart (the `'dead'` `ActivationStatsSP` subplot used below).
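
As a rough guide to what the dead chart measures: in the course notebook it is the fraction of activations that fall into the lowest histogram bin, meaning values close to zero. The sketch below assumes that definition and the same 40 bins over the range 0 to 10; `ActivationStatsSP` may compute it differently, and `dead_fraction` is a hypothetical helper name.

```python
import torch

torch.manual_seed(0)

def dead_fraction(acts: torch.Tensor, bins: int = 40, max_val: float = 10.0) -> float:
    """Fraction of histogrammed activations that land in the lowest bin."""
    hist = acts.abs().histc(bins=bins, min=0, max=max_val)
    return (hist[0] / hist.sum()).item()

pre_acts = torch.randn(10_000)
healthy = torch.relu(pre_acts)          # about half of the values are exactly zero
mostly_dead = torch.relu(pre_acts - 3)  # almost every value is zero

print(dead_fraction(healthy), dead_fraction(mostly_dead))
```

A value near 1 means that almost all of the layer's outputs are near zero for that batch, which is the pattern the baseline model shows.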

In [11]:
def conv(ni, nf, ks=3, act=True):
    res = nn.Conv2d(ni, nf, stride=2, kernel_size=ks, padding=ks//2)
    if act: res = nn.Sequential(res, nn.ReLU())
    return res

def cnn_layers():
    return [
        conv(1 ,8, ks=5),        #14x14
        conv(8 ,16),             #7x7
        conv(16,32),             #4x4
        conv(32,64),             #2x2
        conv(64,10, act=False),  #1x1
        nn.Flatten()]

def get_baseline_model(): 
    return nn.Sequential(*cnn_layers())
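
The shape comments in `cnn_layers` follow from the standard convolution output-size formula, `floor((n + 2*padding - kernel) / stride) + 1`. A quick check in plain Python, assuming 28x28 Fashion-MNIST inputs and the `stride=2`, `padding=ks//2` settings used in `conv` (`conv_out_size` is a helper introduced here for illustration):

```python
def conv_out_size(n: int, ks: int, stride: int = 2) -> int:
    """Output width/height of a conv layer with padding = ks // 2."""
    padding = ks // 2
    return (n + 2 * padding - ks) // stride + 1

size = 28
sizes = []
for ks in [5, 3, 3, 3, 3]:  # kernel sizes of the five conv layers
    size = conv_out_size(size, ks)
    sizes.append(size)

print(sizes)  # [14, 7, 4, 2, 1]
```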

## Training loop

In [12]:
def train_epoch(model:nn.Module, device:str, train_loader:DataLoader, optimizer:optim.Optimizer, pg:PlotGrid):
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        pg.after_batch(training=True, inputs=data, targets=target, predictions=output, loss=loss)
    pg.after_epoch(training=True)

def test_epoch(model:nn.Module, device:str, test_loader:DataLoader, pg:PlotGrid):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output:Tensor = model(data)
            loss = F.cross_entropy(output, target)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the largest logit, i.e. the predicted class
            correct += pred.eq(target.view_as(pred)).sum().item()
            pg.after_batch(training=False, inputs=data, targets=target, predictions=output, loss=loss)
    pg.after_epoch(training=False)

def fit(n_epochs, model, device, train_loader, test_loader, optimizer, pg:PlotGrid):
    pg.before_fit()
    for _ in range(n_epochs):
        train_epoch(model, device, train_loader, optimizer, pg)
        test_epoch(model, device, test_loader, pg)
    pg.after_fit()

## Plots and training

In [13]:
reload_imports()

In [14]:
lr=0.4
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to the CPU when no GPU is present
model = get_baseline_model().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr)

# Subplots
torch_ds = dsd['test'].with_format('torch')
batch_loss_fn = lambda preds,targs: F.cross_entropy(preds,targs,reduction='none')
probs_fn = lambda preds: torch.softmax(preds, dim=1)
conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
sps = [
    TrainingCurveSP(), 
    ColorfulDimensionSP(conv_layers[0], remember_past_epochs=True),
    ColorfulDimensionSP(conv_layers[1], remember_past_epochs=True),
    MetricSP("Accuracy", MulticlassAccuracy()),
    ColorfulDimensionSP(conv_layers[2], remember_past_epochs=True),
    ColorfulDimensionSP(conv_layers[3], remember_past_epochs=True),
    ActivationStatsSP(conv_layers, 'mean', remember_past_epochs=True),
    ActivationStatsSP(conv_layers, 'std', remember_past_epochs=True),
    ActivationStatsSP(conv_layers, 'dead', remember_past_epochs=True)
]
pg = PlotGrid(num_grid_cols=3, subplots=sps, fig_height=800)
pg.show()

[Output: interactive Plotly `FigureWidget`; its first traces are 'Training loss' and 'Validation loss'. The repr is truncated in this export.]

In [15]:
n_epochs = 3
fit(n_epochs, model, device, dls.train, dls.valid, optimizer, pg)

In [16]:
pg.show_static()

## Debugging note
If the following message pops up during a VS Code debugging session:
```
Couldn't find a debug adapter descriptor for debug type 'Python Kernel Debug Adapter' (extension might have failed to activate)
```

Here are some references:
* ['debug adapter' error in jupyter notebook  #1166](https://github.com/microsoft/debugpy/issues/1166)
* [Subprocess support not being disabled #1168](https://github.com/microsoft/debugpy/issues/1168#issuecomment-1377998813)

Suggested workarounds:
* Update [debugpy/server/api.py](../.venv/lib/python3.10/site-packages/debugpy/server/api.py) to set `"subProcess": False`
  * Locate the installed package with `sys.modules['debugpy'].__file__`
* Set `num_workers=0` or omit it entirely for the DataLoader
* Try debugging on the CPU
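
The first workaround needs the location of `api.py`. A small sketch using only the standard library, assuming the file sits at `server/api.py` inside the installed `debugpy` package as in the path above; `find_debugpy_api` is a hypothetical helper and returns `None` when `debugpy` is not installed:

```python
import importlib.util
from pathlib import Path
from typing import Optional

def find_debugpy_api() -> Optional[Path]:
    """Return the expected path of debugpy/server/api.py, or None if debugpy is missing."""
    spec = importlib.util.find_spec("debugpy")
    if spec is None or spec.origin is None:
        return None
    # spec.origin points at debugpy/__init__.py, so its parent is the package directory
    return Path(spec.origin).parent / "server" / "api.py"

print(find_debugpy_api())
```

Unlike `sys.modules['debugpy']`, this lookup does not require `debugpy` to have been imported already.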

In [2]:
import sys; sys.modules['debugpy'].__file__

'/home/dev/ai/trainplotkit/.venv/lib/python3.10/site-packages/debugpy/__init__.py'