## Introduction
This notebook builds up the `PlotGrid` and `SubPlot` classes, which provide extensible, dynamically updated and interactive subplots.
* `SubPlot` is the base class for all built-in and user-defined subplots supported by this library
* `PlotGrid` is the main functional entry point into this library and contains a list of `SubPlot` instances

As an example, a simple training curve plot (loss vs epoch) is developed, first in pieces and then wrapped into a class. The class is designed to be easy to customize for other plot types.

## Imports

In [1]:
import sys, importlib
from typing import List, Tuple, Mapping, Any
import torch
from torch import nn, Tensor
from torch.utils.data import Dataset
from torcheval.metrics import MulticlassAccuracy
from torcheval.metrics.metric import Metric
import plotly.graph_objects as go
import plotly.callbacks as cb
from plotly.basedatatypes import BaseTraceType
from plotly.subplots import make_subplots
from IPython.display import display

sys.path.append('..')
from trainplotkit.layout import place_subplots

In [2]:
def reload_imports():
    importlib.reload(sys.modules['trainplotkit.layout'])
    global place_subplots
    from trainplotkit.layout import place_subplots

## Incrementally updated subplots
* First, we'll add hard-coded data incrementally to our graph. 
* Once we have that working, we'll wrap it into a class and call it from a custom training loop. 
* After that, we'll attach it to callback mechanisms of high-level training frameworks

In [3]:
# Hard-coded data
epochs = list(range(5))
train_losses = [0.100, 0.070, 0.060, 0.055, 0.052]
valid_losses = [0.090, 0.075, 0.056, 0.054, 0.050]
accuracies   = [0.300, 0.600, 0.800, 0.900, 0.950]
data_gen = zip(epochs, train_losses, valid_losses, accuracies)  # Yields tuples of the 4 upon calling `next`

# Create the figure widget
specs = [[{}],[{}]]  # Two vertically arranged subplots, each spanning a single grid cell
widget = go.FigureWidget(make_subplots(rows=2, cols=1, specs=specs))
widget.update_layout(height=500, width=800, margin=dict(l=8, r=8, t=16, b=16))

# Loss subplot with 2 traces
widget.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name='Training loss'), row=1, col=1)
metric_trace = widget.data[-1]
widget.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name='Validation loss'), row=1, col=1)
valid_loss_trace = widget.data[-1]
widget.update_layout(dict(xaxis_title_text='Epoch', yaxis_title_text='Loss'))

# Accuracy subplot with 1 trace
widget.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name='Accuracy'), row=2, col=1)
accuracy_trace = widget.data[-1]
widget.update_layout(dict(xaxis2_title_text='Epoch', yaxis2_title_text='Accuracy'))

# Display the widget for subsequent cells to add data points
display(widget)

FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Training loss',
              'type': 'scatter',
              'uid': 'c96f4f1f-760a-448b-a040-8ccfccafbc4d',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'mode': 'lines+markers',
              'name': 'Validation loss',
              'type': 'scatter',
              'uid': '46b96e1d-8947-48c0-9322-3b5624cf7029',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'mode': 'lines+markers',
              'name': 'Accuracy',
              'type': 'scatter',
              'uid': 'ea1af14d-e97f-484e-bf50-a9dcd5a3c970',
              'x': [],
              'xaxis': 'x2',
              'y': [],
              'yaxis': 'y2'}],
    'layout': {'height': 500,
               'margin': {'b': 16, 'l': 8, 'r': 8, 't': 16},
               'template': '...',
               'width': 800,
      

Now add data points. Run this cell multiple times to add one point at a time.

In [19]:
try:
    epoch, train_loss, valid_loss, accuracy = next(data_gen)
    new_x       = tuple(metric_trace.x) + (epoch,)
    new_y_train = tuple(metric_trace.y) + (train_loss,)
    new_y_valid = tuple(valid_loss_trace.y) + (valid_loss,)
    new_y_acc   = tuple(accuracy_trace.y)   + (accuracy,)
    metric_trace.update(x=new_x, y=new_y_train)
    valid_loss_trace.update(x=new_x, y=new_y_valid)
    accuracy_trace.update(x=new_x, y=new_y_acc)
except StopIteration:
    print("No more data points to add")

No more data points to add


## Click events
For now, we'll just add additional markers to the widgets to show where the user clicked
* A click on one subplot updates markers on both subplots
* This can easily be extended to update other subplots in more useful ways, such as selecting an epoch or validation sample to visualize more detailed information at that epoch / sample

The widget is displayed again below to reduce scrolling fatigue. Two further interesting observations were made:
* Click events may be performed on either of the two copies of the widget
* Updates in response to click events are applied to both copies

This is probably because both copies of the widget refer to the same ipywidgets model ID in the notebook DOM.

In [5]:
# Add marker to loss subplot (initially empty)
scatter = go.Scatter(x=[], y=[], mode='markers', showlegend=False, hoverinfo='skip',
                    marker=dict(color='rgba(0,0,0,0.2)', line=dict(color='black', width=2)))
widget.add_trace(scatter, row=1, col=1)
loss_marker_trace = widget.data[-1]

# Add marker to accuracy subplot (initially empty)
scatter = go.Scatter(x=[], y=[], mode='markers', showlegend=False, hoverinfo='skip',
                    marker=dict(color='rgba(0,0,0,0.2)', line=dict(color='black', width=2)))
widget.add_trace(scatter, row=2, col=1)
accuracy_marker_trace = widget.data[-1]

# Define click events
def on_trace_click(clicked_trace:BaseTraceType, points:cb.Points, selector):
    if not points.point_inds: return
    epoch = points.point_inds[0]

    # Update loss curve marker
    if clicked_trace == valid_loss_trace:
        y_loss = valid_loss_trace.y[epoch]
    else:
        y_loss = metric_trace.y[epoch]
    loss_marker_trace.update(x=[epoch], y=[y_loss])

    # Update accuracy curve marker
    y_acc = accuracy_trace.y[epoch]
    accuracy_marker_trace.update(x=[epoch], y=[y_acc])

# Register click event
metric_trace.on_click(on_trace_click)
valid_loss_trace.on_click(on_trace_click)
accuracy_trace.on_click(on_trace_click)

# Display the widget again to reduce scrolling fatigue
# * Note that subsequent updates occur on both copies of the widget
# * Click events can also be performed on either copy
display(widget)


FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Training loss',
              'type': 'scatter',
              'uid': 'c96f4f1f-760a-448b-a040-8ccfccafbc4d',
              'x': [0],
              'xaxis': 'x',
              'y': [0.1],
              'yaxis': 'y'},
             {'mode': 'lines+markers',
              'name': 'Validation loss',
              'type': 'scatter',
              'uid': '46b96e1d-8947-48c0-9322-3b5624cf7029',
              'x': [0],
              'xaxis': 'x',
              'y': [0.09],
              'yaxis': 'y'},
             {'mode': 'lines+markers',
              'name': 'Accuracy',
              'type': 'scatter',
              'uid': 'ea1af14d-e97f-484e-bf50-a9dcd5a3c970',
              'x': [0],
              'xaxis': 'x2',
              'y': [0.3],
              'yaxis': 'y2'},
             {'hoverinfo': 'skip',
              'marker': {'color': 'rgba(0,0,0,0.2)', 'line': {'color': 'black', 'width': 2}},
              'mode

## PlotGrid and Subplot classes
We need a generic interface that handles the machinery common to all plots, while letting the user specify anything plot-specific. The `SubPlot` class will provide this interface.

We also need a class to manage the figure widget that contains the different subplots. The `PlotGrid` class will take on this role.

At the end of a batch or epoch, the user should be able to call events on `PlotGrid`, which `PlotGrid` then forwards to all subplots. 

The architecture must support calling these events from either a custom training loop or from a callback registered with a high-level training framework.

The implementation in this notebook is a minimal subset of the full implementation: just enough to illustrate the core concepts. In subsequent notebooks, the implementation will be developed further in [subplots.py](../trainplotkit/subplots.py) using the workflow discussed in [development_workflow.ipynb](./development_workflow.ipynb)

### Class definitions
During the initial setup phase, the typical order of operations is as follows:
* The user first creates several instances of `SubPlot` subclasses to specify the plots to generate. At this point, each object's context parameters `parent`, `spi` and `position` are still `None`, because it does not yet know where it will be placed in the grid
* The user creates the `PlotGrid` object, passing in a list of `SubPlot` objects
* `PlotGrid.__init__` calls `PlotGrid.create_empty`, which creates the `go.FigureWidget` object and allocates the subplots their positions in the grid
* `PlotGrid.create_empty` calls `SubPlot.create_empty`, which populates the `SubPlot` object's context parameters `parent`, `spi` and `position` and creates objects for all traces with empty data vectors
* `SubPlot.create_empty` calls `PlotGrid.add_trace`, which adds the trace to the `go.FigureWidget` object and returns a `BaseTraceType` object that can be used to update the trace data later
  * Note that although the trace object passed into `add_trace()` is of the same type as the trace returned by it, calling the update methods on the former does not actually update the widget

When data becomes available, it is added from the `after_batch` and `after_epoch` methods, which are typically invoked from a custom training loop or from callbacks of higher-level training frameworks.

After training has finished, the user calls `PlotGrid.after_fit`, which forwards the call to `SubPlot.after_fit` on every subplot. Some subplots then call `PlotGrid.register_user_epoch_event`, which sets the `on_click` events of the appropriate traces, making them interactive.
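
The click path described above (trace click, then `PlotGrid.on_user_epoch`, then every subplot's `on_user_epoch`) can be sketched without Plotly. The classes `FakeTrace`, `FakeSubplot` and `FakeGrid` below are hypothetical stand-ins invented for this illustration; a real Plotly callback receives `(trace, points, selector)` rather than a bare index.

```python
from typing import Callable, List, Optional


class FakeTrace:
    """Hypothetical stand-in for a Plotly trace; it only mimics on_click registration."""

    def __init__(self) -> None:
        self.callbacks: List[Callable] = []

    def on_click(self, callback: Callable) -> None:
        self.callbacks.append(callback)

    def simulate_click(self, point_index: int) -> None:
        # Plotly would pass (trace, points, selector); a bare index is enough here
        for callback in self.callbacks:
            callback(self, point_index)


class FakeSubplot:
    """Hypothetical stand-in for a SubPlot subclass; it remembers the selected epoch."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.selected: Optional[int] = None

    def on_user_epoch(self, epoch: int) -> None:
        self.selected = epoch


class FakeGrid:
    """Hypothetical stand-in for PlotGrid; it forwards one click to all subplots."""

    def __init__(self, subplots: List[FakeSubplot]) -> None:
        self.subplots = subplots

    def register_user_epoch_event(self, trace: FakeTrace) -> None:
        trace.on_click(self.on_user_epoch)

    def on_user_epoch(self, trace: FakeTrace, epoch: int) -> None:
        for sp in self.subplots:
            sp.on_user_epoch(epoch)


loss_sp, acc_sp = FakeSubplot("loss"), FakeSubplot("accuracy")
grid = FakeGrid([loss_sp, acc_sp])
trace = FakeTrace()
grid.register_user_epoch_event(trace)  # what a subplot would do in after_fit
trace.simulate_click(3)                # the user clicks the point for epoch 3
print(loss_sp.selected, acc_sp.selected)  # 3 3
```

Because the callback is registered on the grid rather than on an individual subplot, a click on any registered trace reaches every subplot, including those that did not register a trace themselves.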

The `append_spi` helper function ensures that the correct object is targeted. Plotly provides the `go.Figure` and `BaseTraceType` objects for the figure and the traces, but there is no object for a subplot. Instead, a specific subplot is targeted by appending a numeric suffix to a dictionary key, e.g. `xaxis2`. `append_spi` appends the correct numeric suffix where appropriate.
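
As a quick illustration of the suffix rule, the sketch below re-implements `append_spi` as a free function. Taking `spi` as an argument, instead of reading it from the object, is a change made purely for demonstration.

```python
def append_spi(name: str, spi: int) -> str:
    """Standalone version of SubPlot.append_spi, for illustration only."""
    # Plotly leaves the first subplot's keys unsuffixed ('xaxis'),
    # then counts from 2 ('xaxis2', 'xaxis3', ...)
    return name if spi < 1 else f"{name}{spi + 1}"


keys = [(append_spi("xaxis", spi), append_spi("yaxis", spi)) for spi in range(3)]
print(keys)  # [('xaxis', 'yaxis'), ('xaxis2', 'yaxis2'), ('xaxis3', 'yaxis3')]
```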

In [6]:
class SubPlot():
    def __init__(self, colspan:int=1, rowspan:int=1):
        self.span = (rowspan, colspan)
        self.parent:"PlotGrid" = None
        self.spi: int = None  # Index within parent's subplot list
        self.position: Tuple[int,int] = None  # (row,col) position in grid

    # Subplot labels and contents
    def title(self) -> str: return ''
    def xlabel(self) -> str: return ''
    def ylabel(self) -> str: return ''
    def create_empty(self, parent, spi, position):
        self.parent = parent
        self.spi = spi
        self.position = position

    # Events
    def after_batch(self, training:bool, inputs:Tensor, targets:Tensor, predictions:Tensor, loss:Tensor): pass
    def after_epoch(self, training:bool): pass
    def after_fit(self): pass
    def on_user_epoch(self, epoch:int): pass

    # Helpers
    def append_spi(self, name):
        """Ensure that the correct subplot is targeted, e.g. 'xaxis' -> 'xaxis2'"""
        return name if self.spi < 1 else f"{name}{self.spi+1}"

class PlotGrid():
    def __init__(self, num_grid_cols, subplots:List[SubPlot], fig_height=500):
        self.num_grid_cols = num_grid_cols
        self.subplots = subplots
        self.fig_height = fig_height
        self.widget: go.FigureWidget = None
        self.create_empty()

    def show(self): display(self.widget)
    def show_static(self, renderer='notebook_connected'): self.widget.show(renderer)

    def create_empty(self):
        spans = [sp.span for sp in self.subplots]
        num_rows, positions, specs, matrix = place_subplots(self.num_grid_cols, spans)
        sp_titles = [sp.title() for sp in self.subplots]
        self.widget = go.FigureWidget(make_subplots(rows=num_rows, cols=self.num_grid_cols, specs=specs, subplot_titles=sp_titles))
        self.widget.update_layout(height=self.fig_height)
        for spi, sp in enumerate(self.subplots):
            sp.create_empty(self, spi, positions[spi])
            self.update_ax_titles(sp)

    def add_trace(self, sp:SubPlot, trace:BaseTraceType): 
        self.widget.add_trace(trace, row=sp.spi+1, col=1)
        return self.widget.data[-1]  # Object reference to the trace just added
    
    def update_ax_titles(self, sp:SubPlot):
        xaxis_name = sp.append_spi('xaxis')
        yaxis_name = sp.append_spi('yaxis')
        kwargs = {xaxis_name: dict(title_text=sp.xlabel()),
                  yaxis_name: dict(title_text=sp.ylabel())}
        self.widget.update_layout(**kwargs)

    def register_user_epoch_event(self, trace:BaseTraceType): 
        trace.on_click(self.on_user_epoch)

    def after_batch(self, training:bool, inputs, targets, predictions, loss):
        for sp in self.subplots: sp.after_batch(training, inputs, targets, predictions, loss)
    def after_epoch(self, training:bool):
        for sp in self.subplots: sp.after_epoch(training)
    def after_fit(self):
        for sp in self.subplots: sp.after_fit()
    def on_user_epoch(self, trace, points:cb.Points, selector):
        if not points.point_inds: return
        epoch = points.point_inds[0]
        for sp in self.subplots: sp.on_user_epoch(epoch)


### Specifying plots in subclasses
The following methods of the two classes below are of note:
* `create_empty()` creates and stores the empty trace objects
* `after_batch()` updates state as part of the computation of data to be plotted
* `after_epoch()` finishes the computation and updates the traces
* `after_fit()` registers click events on the traces to make them interactive
* `on_user_epoch()` contains the implementation of updates that occur in response to user click events
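
The split between `after_batch()` and `after_epoch()` amounts to a running mean: each batch adds to a pair of accumulators, and the end of the epoch divides and resets them. The sketch below isolates that logic in a hypothetical `RunningMean` helper, which is not part of the notebook's classes and uses plain floats instead of tensors.

```python
class RunningMean:
    """Hypothetical helper: mean of per-batch values over one epoch."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def add(self, value: float) -> None:
        """Accumulate one batch value (the after_batch step)."""
        self.total += value
        self.count += 1

    def finish(self) -> float:
        """Return the epoch mean and reset the accumulators (the after_epoch step)."""
        if self.count == 0:
            raise ValueError("no batches were recorded this epoch")
        mean = self.total / self.count
        self.total, self.count = 0.0, 0
        return mean


epoch_means = []
mean = RunningMean()
for batch_losses in ([0.5, 0.25, 0.75], [0.25, 0.25]):  # two epochs of made-up losses
    for loss in batch_losses:
        mean.add(loss)
    epoch_means.append(mean.finish())
print(epoch_means)  # [0.5, 0.25]
```

Like the subplot classes below, this weights every batch equally, so a smaller final batch counts as much as a full one.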

In [7]:
class TrainingCurveSP(SubPlot):
    def __init__(self, colspan=1, rowspan=1):
        super().__init__(colspan, rowspan)
        self.train_loss_trace: BaseTraceType = None
        self.valid_loss_trace: BaseTraceType = None
        self.marker_trace: BaseTraceType = None
        self.epoch = 0
        self.train_num = 0
        self.train_denom = 0
        self.valid_num = 0
        self.valid_denom = 0
        
    def title(self) -> str: return 'Training curve'
    def xlabel(self) -> str: return 'Epoch'
    def ylabel(self) -> str: return 'Loss'
    def create_empty(self, parent:PlotGrid, spi, position):
        super().create_empty(parent, spi, position)
        train_loss_trace = go.Scatter(x=[], y=[], mode='lines+markers', name='Training loss')
        self.train_loss_trace = parent.add_trace(self, train_loss_trace)
        valid_loss_trace = go.Scatter(x=[], y=[], mode='lines+markers', name='Validation loss')
        self.valid_loss_trace = parent.add_trace(self, valid_loss_trace)
        marker_trace = go.Scatter(x=[], y=[], mode='markers', showlegend=False, hoverinfo='skip',
                                  marker=dict(color='rgba(0,0,0,0.2)', line=dict(color='black', width=2)))
        self.marker_trace = parent.add_trace(self, marker_trace)
    
    def after_batch(self, training, inputs, targets, predictions, loss):
        if training:
            self.train_num += float(loss.detach().cpu())
            self.train_denom += 1
        else:
            self.valid_num += float(loss.detach().cpu())
            self.valid_denom += 1

    def after_epoch(self, training):
        if training:
            loss = self.train_num / self.train_denom
            new_x = tuple(self.train_loss_trace.x) + (self.epoch,)
            new_y = tuple(self.train_loss_trace.y) + (loss,)
            self.train_loss_trace.update(x=new_x, y=new_y)
            self.train_num = 0
            self.train_denom = 0
        else:
            loss = self.valid_num / self.valid_denom
            new_x = tuple(self.valid_loss_trace.x) + (self.epoch,)
            new_y = tuple(self.valid_loss_trace.y) + (loss,)
            self.valid_loss_trace.update(x=new_x, y=new_y)
            self.valid_num = 0
            self.valid_denom = 0
            self.epoch += 1

    def after_fit(self):
        self.parent.register_user_epoch_event(self.train_loss_trace)
        self.parent.register_user_epoch_event(self.valid_loss_trace)
        
    def on_user_epoch(self, epoch:int):
        y = self.train_loss_trace.y[epoch]
        self.marker_trace.update(x=[epoch], y=[y])

In [8]:
class MetricSP(SubPlot):
    def __init__(self, metric_name:str, metric:Metric[Tensor], colspan=1, rowspan=1):
        super().__init__(colspan, rowspan)
        self.metric_name = metric_name
        self.metric = metric
        self.metric_trace: BaseTraceType = None
        self.marker_trace: BaseTraceType = None
        self.epoch = 0
        
    def title(self) -> str: return self.metric_name
    def xlabel(self) -> str: return 'Epoch'
    def ylabel(self) -> str: return self.metric_name
    def create_empty(self, parent:PlotGrid, spi, position):
        super().create_empty(parent, spi, position)
        metric_trace = go.Scatter(x=[], y=[], mode='lines+markers', name=self.metric_name)
        self.metric_trace = parent.add_trace(self, metric_trace)
        marker_trace = go.Scatter(x=[], y=[], mode='markers', showlegend=False, hoverinfo='skip',
                                  marker=dict(color='rgba(0,0,0,0.2)', line=dict(color='black', width=2)))
        self.marker_trace = parent.add_trace(self, marker_trace)
    
    def after_batch(self, training, inputs, targets, predictions, loss):
        if training: return  # Only interested in validation samples
        self.metric.update(predictions.detach().cpu(), targets.detach().cpu())

    def after_epoch(self, training):
        if training: return  # Only interested in validation samples
        value = self.metric.compute()
        new_x = tuple(self.metric_trace.x) + (self.epoch,)  # Own counter, not a notebook-level global
        new_y = tuple(self.metric_trace.y) + (value,)
        self.metric_trace.update(x=new_x, y=new_y)
        self.metric.reset()
        self.epoch += 1

    def after_fit(self):
        self.parent.register_user_epoch_event(self.metric_trace)
        
    def on_user_epoch(self, epoch:int):
        y = self.metric_trace.y[epoch]
        self.marker_trace.update(x=[epoch], y=[y])

### Usage example
This section shows how the classes developed above might be called. Tensors passed to `after_batch` are simulated to produce the same accuracy and loss values as in the earlier examples. The first two cells below contain only this simulation and are not needed in normal use.

In [9]:
targets=Tensor([1,1,1,1,1,1,1,1,1,1])
predictions=Tensor([0,0,0,0,0,0,1,1,1,1])
metric = MulticlassAccuracy()
metric.update(predictions, targets)
metric.compute()

tensor(0.4000)

In [10]:
train_losses = [0.100, 0.070, 0.060, 0.055, 0.052]
valid_losses = [0.090, 0.075, 0.056, 0.054, 0.050]
accuracies   = [0.300, 0.600, 0.800, 0.900, 0.950]

num_epochs = len(accuracies)
batch_size, num_train_batches, num_valid_batches = 10, 20, 10
inputs=torch.full((batch_size,1,4,4), fill_value=1)
targets=torch.ones((batch_size,))
def predictions_gen(accuracies, batch_size, num_batches):
    for epoch_accuracy in accuracies:
        num_ones = int(epoch_accuracy * batch_size * num_batches)
        num_zeros = batch_size * num_batches - num_ones
        epoch_preds = torch.cat((torch.ones((num_ones,)), torch.zeros((num_zeros,))), dim=0)
        for batch_idx in range(num_batches):
            yield epoch_preds[batch_idx*batch_size:batch_idx*batch_size+batch_size]

Create the empty widget

In [None]:
sp0 = TrainingCurveSP()
sp1 = MetricSP('Accuracy', MulticlassAccuracy())
pg = PlotGrid(num_grid_cols=1, subplots=[sp0,sp1])
pg.show()

FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Training loss',
              'type': 'scatter',
              'uid': '576168d0-f355-4fa6-af10-31e274ffbc3b',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'mode': 'lines+markers',
              'name': 'Validation loss',
              'type': 'scatter',
              'uid': '4618f7ee-242f-40a7-a0ed-54035a32cf84',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'hoverinfo': 'skip',
              'marker': {'color': 'rgba(0,0,0,0.2)', 'line': {'color': 'black', 'width': 2}},
              'mode': 'markers',
              'showlegend': False,
              'type': 'scatter',
              'uid': '7b9537ea-f9ab-48a6-aceb-1b267d3aa824',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'mode': 'lines+markers',


Send the simulated batches through

In [12]:
train_preds_gen = predictions_gen(accuracies, batch_size, num_train_batches)
valid_preds_gen = predictions_gen(accuracies, batch_size, num_valid_batches)

for epoch in range(num_epochs):
    for bi in range(num_train_batches):
        pg.after_batch(training=True, inputs=inputs, targets=targets, predictions=next(train_preds_gen), loss=Tensor([train_losses[epoch]]))
    pg.after_epoch(training=True)
    for bi in range(num_valid_batches):
        pg.after_batch(training=False, inputs=inputs, targets=targets, predictions=next(valid_preds_gen), loss=Tensor([valid_losses[epoch]]))
    pg.after_epoch(training=False)

Create click events for interactivity

In [13]:
pg.after_fit()

### Saving a static copy
When done interacting with the figure, add a static copy to the notebook DOM for sharing purposes.

In [14]:
pg.show_static()

## Discussion

### Architectural considerations
A design aim was to support easy integration from both custom training loops and high-level frameworks like `pytorch_lightning`, `fastai` and HuggingFace `Trainer` without requiring the user to write much code in either case.

The following decisions were made:
* Training loop variables such as `inputs`, `targets`, `predictions` and `loss` are passed as arguments to `PlotGrid.after_batch`
* If subplots require access to activations and/or parameters of specific modules, references to these modules are passed to the constructor of the `SubPlot` subclass

### Using a custom training loop
User code is responsible for the following steps:
* Constructing instances of `SubPlot` subclasses to specify the desired plots
* Constructing a `PlotGrid` instance with a list of these `SubPlot` instances as argument
* Calling `PlotGrid.show()` from the cell(s) below which the widget should be displayed
* Calling `PlotGrid.after_batch()`, `PlotGrid.after_epoch()`, etc. from the appropriate places in the training loop

Optional user steps for additional functionality:
* Interacting with graphs after calling `PlotGrid.after_fit()` where specific subplots define events
* Adding a static copy to the notebook DOM using `PlotGrid.show_static()`
* Adding and customizing subplots by subclassing `SubPlot` or one of its other subclasses

An example of how `trainplotkit` may be used from a custom training loop is provided in [04_using_custom_training_loop.ipynb](./04_using_custom_training_loop.ipynb)
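
The order in which a custom loop calls the events can be sketched with a stand-in object that records calls instead of plotting. `RecordingGrid` and `fit` are hypothetical names invented for this sketch; the placeholders for `predictions` and `loss` stand in for the model's forward pass and loss computation.

```python
from typing import List, Tuple


class RecordingGrid:
    """Hypothetical stand-in for PlotGrid that records events instead of plotting."""

    def __init__(self) -> None:
        self.events: List[Tuple] = []

    def after_batch(self, training, inputs, targets, predictions, loss) -> None:
        self.events.append(("batch", training))

    def after_epoch(self, training) -> None:
        self.events.append(("epoch", training))

    def after_fit(self) -> None:
        self.events.append(("fit",))


def fit(grid, num_epochs, train_batches, valid_batches) -> None:
    for _ in range(num_epochs):
        for inputs, targets in train_batches:
            predictions, loss = inputs, 0.0  # Placeholders for forward pass and loss
            grid.after_batch(True, inputs, targets, predictions, loss)
        grid.after_epoch(training=True)
        for inputs, targets in valid_batches:
            predictions, loss = inputs, 0.0  # Placeholders for forward pass and loss
            grid.after_batch(False, inputs, targets, predictions, loss)
        grid.after_epoch(training=False)
    grid.after_fit()  # Once, after the last epoch


grid = RecordingGrid()
fit(grid, num_epochs=1, train_batches=[(0, 0), (1, 1)], valid_batches=[(2, 2)])
print(grid.events)
# [('batch', True), ('batch', True), ('epoch', True), ('batch', False), ('epoch', False), ('fit',)]
```

In this notebook's `TrainingCurveSP`, the epoch counter advances in the validation branch of `after_epoch`, so the training `after_epoch` call should come before the validation one within each epoch.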

### Using a high-level framework
* Most high-level frameworks like `pytorch_lightning`, `fastai` and HuggingFace `Trainer` allow users to provide callbacks that will be executed by the framework after batches, epochs, etc.
* The steps for using a framework with `trainplotkit` are similar to those for a custom training loop above, except that the `PlotGrid` object must be passed to the constructor of such a callback and events like `after_batch`, `after_epoch`, etc. must be called from the appropriate callback methods
* For `pytorch_lightning` and `fastai`, `trainplotkit` already provides callback adapters, which can be constructed as-is and passed as callbacks to the high-level framework. These adapters take care of calling the right methods of `PlotGrid` at the right time
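
The adapter idea can be sketched in a framework-agnostic way. Everything below is hypothetical: the hook names `on_batch_end`, `on_epoch_end` and `on_fit_end` are invented for this illustration and do not belong to any real framework, and the class is not the adapter that `trainplotkit` ships. `LoggingGrid` stands in for `PlotGrid` so that the sketch runs without Plotly.

```python
from typing import List


class LoggingGrid:
    """Hypothetical stand-in for PlotGrid; it logs which event was called."""

    def __init__(self) -> None:
        self.log: List[str] = []

    def after_batch(self, training, inputs, targets, predictions, loss) -> None:
        self.log.append("after_batch")

    def after_epoch(self, training) -> None:
        self.log.append("after_epoch")

    def after_fit(self) -> None:
        self.log.append("after_fit")


class PlotGridCallback:
    """Hypothetical adapter: translates framework hooks into PlotGrid events."""

    def __init__(self, grid) -> None:
        self.grid = grid  # The grid is handed over once, at construction time

    def on_batch_end(self, training, inputs, targets, predictions, loss) -> None:
        self.grid.after_batch(training, inputs, targets, predictions, loss)

    def on_epoch_end(self, training) -> None:
        self.grid.after_epoch(training)

    def on_fit_end(self) -> None:
        self.grid.after_fit()


grid = LoggingGrid()
callback = PlotGridCallback(grid)
# A framework would invoke these hooks itself; here they are called by hand
callback.on_batch_end(True, inputs=None, targets=None, predictions=None, loss=0.0)
callback.on_epoch_end(True)
callback.on_fit_end()
print(grid.log)  # ['after_batch', 'after_epoch', 'after_fit']
```

A real adapter would subclass the framework's own callback base class and unpack `inputs`, `targets`, `predictions` and `loss` from whatever arguments that framework passes to its hooks.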

Examples of how `trainplotkit` may be used with these high-level frameworks are provided in [05_using_pyorch_lightning.ipynb](./05_using_pyorch_lightning.ipynb) and [06_using_fastai](./06_using_fastai.ipynb)

The diagram below shows the information path from the custom training loop or high-level framework to the plot implementation.

![trainplotkit architecture diagram](../resources/trainplotkit_arch.png) 