## Introduction
At this point, we wish to add some interactivity to the portable figures and HTML exports using controls provided by Plotly. In these outputs, all interactions happen on the front-end. We experiment with the following:
* A dropdown menu to color nodes according to type / parameters / FLOPS
* A dropdown menu to size nodes according to parameters / FLOPS
* An overview panel with a rectangle that responds to the slider



## Dropdown menus for marker styling
* Reference: [here](https://plotly.com/python/dropdowns/)
* More properties: [here](https://plotly.com/python/reference/layout/updatemenus/)
* Everything happens on the front-end (i.e. no need for `go.FigureWidget` or `ipywidgets`)

### With a single trace
The example below uses IDLMAV for tracing and layout, but the plotting itself is written from scratch in Plotly

In [1]:
import sys, importlib
sys.path.append('..')
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import plotly.graph_objects as go
from idlmav import MAV, color_graph_nodes, FigureRenderer

def reload_imports():
    """Reload the idlmav modules and refresh the names imported from them.

    Meant for the edit-and-rerun loop while working on idlmav itself: the
    listed submodules are reloaded first, then the package, and finally every
    name this notebook imports from idlmav is re-bound to the reloaded code.

    Raises:
        KeyError: If one of the listed modules has not been imported yet.
    """
    global MAV, color_graph_nodes, FigureRenderer
    for module_name in ('idlmav.renderers.figure_renderer',
                        'idlmav.tracing',
                        'idlmav.layout',
                        'idlmav.idlmav',
                        'idlmav'):
        importlib.reload(sys.modules[module_name])
    # Re-bind all three imported names; a name left out here would keep
    # pointing at the object from before the reload.
    from idlmav import MAV, color_graph_nodes, FigureRenderer

In [2]:
class MnistCnn(nn.Module):
    """Small convolutional classifier used as the example model to visualise.

    Two 3x3 convolutions, a 2x2 max-pool and two fully connected layers that
    end in a 10-way log-softmax. `fc1` expects 9216 = 64*12*12 features, which
    is what a 1x28x28 input produces after the convolutions and pooling.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Dropout layers (one after pooling, one before the last linear layer)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Classifier head
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities for the image batch `x`."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(F.max_pool2d(x, 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(self.dropout2(x))
        return F.log_softmax(x, dim=1)
    
# Instantiate the example network and a random batch of 16 single-channel 28x28 inputs
model = MnistCnn()
inputs = torch.randn((16,1,28,28))
# Build the MAV object from the model and example inputs; the cells below read its graph via mav.tracer.g
mav = MAV(model, inputs)
# Presumably assigns the node colors later read through renderer.get_node_color — confirm in idlmav
color_graph_nodes(mav.tracer.g)

INFO:2025-02-24 06:47:07 590671:590671 init.cpp:181] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti


In [3]:
# Display the traced graph's nodes (name, layout coordinates, incoming and outgoing connections)
mav.tracer.g.nodes

[MavNode: name=x; (x,y)=(0.0,0); in=[], out=['conv1'],
 MavNode: name=conv1; (x,y)=(0.0,1); in=['x'], out=['relu'],
 MavNode: name=relu; (x,y)=(0.0,1.15); in=['conv1'], out=['conv2'],
 MavNode: name=conv2; (x,y)=(0.0,2); in=['relu'], out=['relu_1'],
 MavNode: name=relu_1; (x,y)=(0.0,2.15); in=['conv2'], out=['max_pool2d'],
 MavNode: name=max_pool2d; (x,y)=(0.0,2.3); in=['relu_1'], out=['dropout1'],
 MavNode: name=dropout1; (x,y)=(0.0,2.45); in=['max_pool2d'], out=['flatten'],
 MavNode: name=flatten; (x,y)=(0.0,2.6); in=['dropout1'], out=['fc1'],
 MavNode: name=fc1; (x,y)=(0.0,3); in=['flatten'], out=['relu_2'],
 MavNode: name=relu_2; (x,y)=(0.0,3.15); in=['fc1'], out=['dropout2'],
 MavNode: name=dropout2; (x,y)=(0.0,3.3); in=['relu_2'], out=['fc2'],
 MavNode: name=fc2; (x,y)=(0.0,4); in=['dropout2'], out=['log_softmax'],
 MavNode: name=log_softmax; (x,y)=(0.0,4.15); in=['fc2'], out=['output'],
 MavNode: name=output; (x,y)=(0.0,5); in=['log_softmax'], out=[]]

In [None]:
# Build a bare Plotly figure with a single marker trace: one marker per graph node.
g = mav.tracer.g
renderer = FigureRenderer(g)  # For helper functions e.g. "params_to_dot_size"
fig = go.Figure(layout=dict(width=500, height=300, margin=dict(l=0, r=0, t=0, b=0)))

node_trace = go.Scatter(
    x=[n.x for n in g.nodes], 
    y=[n.y for n in g.nodes], 
    mode='markers', 
    # Initial styling: marker size derived from parameter count, color from the renderer's default
    marker=dict(
        size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
        color=[renderer.get_node_color(n) for n in g.nodes],
        colorscale='Bluered'
    ),
    # Each %{customdata[i]} placeholder is filled from the matching column of `customdata` below
    hovertemplate=(
        'Name: %{customdata[0]}<br>' +
        'Operation: %{customdata[1]}<br>' +
        'Activations: %{customdata[2]}<br>' +
        'Parameters: %{customdata[3]}<br>' +
        'FLOPS: %{customdata[4]}<br>' +
        '<br>' +
        'args: %{customdata[5]}<br>' +
        'kwargs: %{customdata[6]}<br>' +
        '<extra></extra>'
    ),
    # One row per node: the renderer's node data followed by its argument data
    customdata=[renderer.node_data(n) + renderer.node_arg_data(n) for n in g.nodes],
    showlegend=False
)
fig.add_trace(node_trace)
# No grid, zero line or tick values on either axis
fig.update_xaxes(showgrid=False, zeroline=False, tickmode='array', tickvals=[])
fig.update_yaxes(showgrid=False, zeroline=False, tickmode='array', tickvals=[])

fig.show(renderer='notebook_connected')

In [5]:
# Add a dropdown whose buttons replace the whole `marker` dict of the node trace.
# Every button restates size, color and colorscale, because attributes that are
# left out are not carried over from the previous selection.
fig.update_layout(
    updatemenus=[
        dict(
            buttons=list([
                dict(
                    args=[dict(
                        marker=dict(
                            size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
                            color=[renderer.get_node_color(n) for n in g.nodes],
                            colorscale='Bluered'
                        )
                    )],
                    label="Size by params, color by operation",
                    method="update"
                ),
                dict(
                    args=[dict(
                        marker=dict(
                            # FLOPS-based sizing must be fed the node's FLOPS, not its parameter count
                            size=[renderer.flops_to_dot_size(n.flops) for n in g.nodes],
                            color=[renderer.get_node_color(n) for n in g.nodes],
                            colorscale='Bluered'
                        )
                    )],
                    label="Size by FLOPS, color by operation",
                    method="update"
                ),
                dict(
                    args=[dict(
                        marker=dict(
                            size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
                            color=[renderer.get_node_color(n, 'flops') for n in g.nodes],
                            colorscale='Bluered'
                        )
                    )],
                    label="Size by params, color by FLOPS",
                    method="update"
                ),
                dict(
                    args=[dict(
                        marker=dict(
                            # FLOPS-based sizing must be fed the node's FLOPS, not its parameter count
                            size=[renderer.flops_to_dot_size(n.flops) for n in g.nodes],
                            color=[renderer.get_node_color(n, 'params') for n in g.nodes],
                            colorscale='Bluered'
                        )
                    )],
                    label="Size by FLOPS, color by params",
                    method="update"
                )
            ]),
            # Anchor the menu's top-left corner to the bottom-left corner of the plot
            # area and let the item list open upwards
            direction="up", showactive=True,
            pad={"l": 0, "t": 0},
            x=0, xanchor="left",
            y=0, yanchor="top"
        ),
    ],
    margin=dict(l=0, r=0, t=0, b=0)
)
fig.show(renderer='notebook_connected')

### With multiple traces

In [6]:
import sys, importlib
sys.path.append('..')
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import plotly.graph_objects as go
from idlmav import MAV, color_graph_nodes, FigureRenderer

def reload_imports():
    """Reload the idlmav modules and refresh the names imported from them.

    Meant for the edit-and-rerun loop while working on idlmav itself: the
    listed submodules are reloaded first, then the package, and finally every
    name this notebook imports from idlmav is re-bound to the reloaded code.

    Raises:
        KeyError: If one of the listed modules has not been imported yet.
    """
    global MAV, color_graph_nodes, FigureRenderer
    for module_name in ('idlmav.renderers.figure_renderer',
                        'idlmav.tracing',
                        'idlmav.layout',
                        'idlmav.idlmav',
                        'idlmav'):
        importlib.reload(sys.modules[module_name])
    # Re-bind all three imported names; a name left out here would keep
    # pointing at the object from before the reload.
    from idlmav import MAV, color_graph_nodes, FigureRenderer

In [7]:
class MnistCnn(nn.Module):
    """Small convolutional classifier used as the example model to visualise.

    Two 3x3 convolutions, a 2x2 max-pool and two fully connected layers that
    end in a 10-way log-softmax. `fc1` expects 9216 = 64*12*12 features, which
    is what a 1x28x28 input produces after the convolutions and pooling.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Dropout layers (one after pooling, one before the last linear layer)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Classifier head
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities for the image batch `x`."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(F.max_pool2d(x, 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(self.dropout2(x))
        return F.log_softmax(x, dim=1)
    
# Instantiate the example network and a random batch of 16 single-channel 28x28 inputs
model = MnistCnn()
inputs = torch.randn((16,1,28,28))
# Build the MAV object from the model and example inputs; the cells below read its graph via mav.tracer.g
mav = MAV(model, inputs)
# Presumably assigns the node colors later read through renderer.get_node_color — confirm in idlmav
color_graph_nodes(mav.tracer.g)

In [None]:
# Build a figure with three traces: the node markers, a copy of them shifted
# one unit to the right, and the connection lines.
g = mav.tracer.g
renderer = FigureRenderer(g)  # For helper functions e.g. "params_to_dot_size"
fig = go.Figure(layout=dict(width=500, height=300, margin=dict(l=0, r=0, t=0, b=0)))

# Trace 0: one marker per graph node
node_trace = go.Scatter(
    x=[n.x for n in g.nodes], 
    y=[n.y for n in g.nodes], 
    mode='markers', 
    marker=dict(
        size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
        color=[renderer.get_node_color(n) for n in g.nodes],
        colorscale='Bluered'
    ),
    # Each %{customdata[i]} placeholder is filled from the matching column of `customdata` below
    hovertemplate=(
        'Name: %{customdata[0]}<br>' +
        'Operation: %{customdata[1]}<br>' +
        'Activations: %{customdata[2]}<br>' +
        'Parameters: %{customdata[3]}<br>' +
        'FLOPS: %{customdata[4]}<br>' +
        '<br>' +
        'args: %{customdata[5]}<br>' +
        'kwargs: %{customdata[6]}<br>' +
        '<extra></extra>'
    ),
    customdata=[renderer.node_data(n) + renderer.node_arg_data(n) for n in g.nodes],
    showlegend=False
)
fig.add_trace(node_trace)

# Trace 1: the same markers with x shifted by 1, used to test per-trace restyling
scatter_trace_2 = go.Scatter(
    x=[n.x+1 for n in g.nodes], 
    y=[n.y for n in g.nodes], 
    mode='markers', 
    marker=dict(
        size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
        color=[renderer.get_node_color(n) for n in g.nodes],
        colorscale='Bluered'
    ),
    hovertemplate=(
        'Name: %{customdata[0]}<br>' +
        'Operation: %{customdata[1]}<br>' +
        'Activations: %{customdata[2]}<br>' +
        'Parameters: %{customdata[3]}<br>' +
        'FLOPS: %{customdata[4]}<br>' +
        '<br>' +
        'args: %{customdata[5]}<br>' +
        'kwargs: %{customdata[6]}<br>' +
        '<extra></extra>'
    ),
    customdata=[renderer.node_data(n) + renderer.node_arg_data(n) for n in g.nodes],
    showlegend=False
)
fig.add_trace(scatter_trace_2)

# Trace 2: all connections in a single line trace; a None entry between
# consecutive connections keeps them from being joined to each other
x_coords, y_coords = [],[]
for c in g.connections:
    xs, ys = renderer.get_connection_coords(c)
    if x_coords: x_coords.append(None)
    if y_coords: y_coords.append(None)
    x_coords += xs
    y_coords += ys

line_trace = go.Scatter(
    x=x_coords, y=y_coords, mode="lines",
    line=dict(color="gray", width=1),            
    hoverinfo='skip',
    showlegend=False
)
fig.add_trace(line_trace)

# No grid, zero line or tick values on either axis
fig.update_xaxes(showgrid=False, zeroline=False, tickmode='array', tickvals=[])
fig.update_yaxes(showgrid=False, zeroline=False, tickmode='array', tickvals=[])

fig.show(renderer='notebook_connected')

In [9]:
# Add a dropdown that restyles the markers per trace: `marker` is a list whose
# i-th element is applied to the i-th trace. Trace 0 changes with the selection;
# trace 1 is re-sent with its original styling so that it stays unchanged.
fig.update_layout(
    updatemenus=[
        dict(
            buttons=list([
                dict(
                    args=[dict(
                        marker=[
                            dict(
                                size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
                                color=[renderer.get_node_color(n) for n in g.nodes],
                                colorscale='Bluered'
                            ),
                            dict(
                                size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
                                color=[renderer.get_node_color(n) for n in g.nodes],
                                colorscale='Bluered'
                            ),
                            # Empty entry for the third (line) trace; only this button provides one
                            {}
                        ]
                    )],
                    label="Size by params, color by operation",
                    method="restyle"
                ),
                dict(
                    args=[dict(
                        marker=[
                            dict(
                                # FLOPS-based sizing must be fed the node's FLOPS, not its parameter count
                                size=[renderer.flops_to_dot_size(n.flops) for n in g.nodes],
                                color=[renderer.get_node_color(n) for n in g.nodes],
                                colorscale='Bluered'
                            ),
                            dict(
                                # Setting this to `params_to_dot_size` on purpose to examine the possibility
                                # of changing one trace while keeping another trace unchanged
                                size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
                                color=[renderer.get_node_color(n) for n in g.nodes],
                                colorscale='Bluered'
                            )
                        ]
                    )],
                    label="Size by FLOPS, color by operation",
                    method="restyle"
                ),
                dict(
                    args=[dict(
                        marker=[
                            dict(
                                size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
                                color=[renderer.get_node_color(n, 'flops') for n in g.nodes],
                                colorscale='Bluered'
                            ),
                            dict(
                                # Setting this to defaults on purpose to examine the possibility
                                # of changing one trace while keeping another trace unchanged
                                size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
                                color=[renderer.get_node_color(n) for n in g.nodes],
                                colorscale='Bluered'
                            )
                        ]
                    )],
                    label="Size by params, color by FLOPS",
                    method="restyle"
                ),
                dict(
                    args=[dict(
                        marker=[
                            dict(
                                # FLOPS-based sizing must be fed the node's FLOPS, not its parameter count
                                size=[renderer.flops_to_dot_size(n.flops) for n in g.nodes],
                                color=[renderer.get_node_color(n, 'params') for n in g.nodes],
                                colorscale='Bluered'
                            ),
                            dict(
                                # Setting this to defaults on purpose to examine the possibility
                                # of changing one trace while keeping another trace unchanged
                                size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
                                color=[renderer.get_node_color(n) for n in g.nodes],
                                colorscale='Bluered'
                            )
                        ]
                    )],
                    label="Size by FLOPS, color by params",
                    method="restyle"
                )
            ]),
            # Anchor the menu's top-left corner to the bottom-left corner of the plot
            # area and let the item list open upwards
            direction="up", showactive=True,
            pad=dict(l=0, r=0, t=0, b=0),
            x=0, xanchor="left",
            y=0, yanchor="top"
        ),
    ],
    margin=dict(l=0, r=0, t=0, b=0)
)
fig.show(renderer='notebook_connected')

### Observations
* Interactions work after saving and reloading, so they may be expected to work in front-end only deployments
* When restyling markers, one can either provide a single `marker` dictionary or an array of dictionaries
  * If a single dictionary is provided, it is applied to all traces
  * If an array of dictionaries is provided, the first element is applied to the first trace, etc., until either the dictionaries or the traces run out
* There is no persistence of marker styling attributes
  * If no parameters are provided for a certain trace containing markers, its markers are reset to default parameters (sizing, coloring, etc.)
  * If only colors are provided for a certain trace containing markers, the sizes are reset to default parameters
  * To change marker styling for one trace without changing another trace, the marker style for the other trace must be specified again
  * To change one attribute (e.g. color, size) without changing another attribute, the values for the other attribute must be specified again, unless default values are desired
* The magic underscore notation available in other parts of Plotly is not available here
  * e.g. Replacing `marker=[dict(size=[...])]` by `marker_size=[...]` does not work
* For `direction="down"` and a dropdown menu below the figure, extra padding must be provided
* The `x` and `y` values are in normalized coordinates, and `xanchor` and `yanchor` specify which edges of the menu are tied to these coordinates
  * The following configurations place the menu just inside the figure:
    * Top left corner:     `dict(x=0, y=1, xanchor='left',  yanchor='top',    direction='down')`
    * Bottom left corner:  `dict(x=0, y=0, xanchor='left',  yanchor='bottom', direction='up')`
    * Top right corner:    `dict(x=1, y=1, xanchor='right', yanchor='top',    direction='down')`
    * Bottom right corner: `dict(x=1, y=0, xanchor='right', yanchor='bottom', direction='up')`
  * The following configurations place the menu just outside the figure:
    * Top left corner:     `dict(x=0, y=1, xanchor='left',  yanchor='bottom', direction='down')`
    * Bottom left corner:  `dict(x=0, y=0, xanchor='left',  yanchor='top',    direction='up')`
    * Top right corner:    `dict(x=1, y=1, xanchor='right', yanchor='bottom', direction='down')`
    * Bottom right corner: `dict(x=1, y=0, xanchor='right', yanchor='top',    direction='up')`
* The `active` parameter may be used if the first item in the list is not the default

### Decisions
* Use one dropdown menu to change both colors and sizes of markers
  * Because of the lack of persistence (which is probably the price of not wanting to rely on a back-end), changing them separately will require additional state and maintenance overhead 
  * Some combinations of sizing and coloring (e.g both based on FLOPS / both based on params) are undesirable anyway
  * This approach works well with more thorough descriptions in the dropdown menu items, saving us the trouble of adding labels, laying out the controls and taking different screen sizes into account
* Provide the following marker styling combinations:
  * Size by params, color by operation
  * Size by FLOPS, color by operation
  * Size by params, color by FLOPS
  * Size by FLOPS, color by params
* Use functions to build up marker and menu dictionaries
  * This will result in simpler high-level code for a reader interested in the different buttons and the different updates performed by each button
* Make `Size by FLOPS, color by operation` the default when the `show_param_nodes` user variable is set to `True`
  * Keep the same order of items, but set `active=1`

### Building up marker and menu dictionaries

In [None]:
def build_marker_dict(renderer:FigureRenderer, trace_idx:int, size_by:str, color_by:str):
    """Return the Plotly `marker` dict for one trace of the styling dropdown.

    Args:
        renderer: Supplies the graph (`renderer.g`) and the size/color helpers.
        trace_idx: Index of the trace being styled. Traces with an index above
            0 always receive the initial styling (size by params, default
            color) so that they stay unchanged when a button is selected.
        size_by: 'params' or 'flops'; selects the quantity that drives marker size.
        color_by: Forwarded to `renderer.get_node_color` as its second argument.

    Returns:
        A dict with `size`, `color` and `colorscale` entries, one size and one
        color per node of the graph.

    Raises:
        ValueError: If `trace_idx` is not above 0 and `size_by` is not recognised.
    """
    g = renderer.g
    if trace_idx > 0:
        # The trace that does not change
        return dict(size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
                    color=[renderer.get_node_color(n) for n in g.nodes],
                    colorscale='Bluered')
    
    if size_by=='flops':
        # Size by FLOPS must read each node's FLOPS, not its parameter count
        sizes = [renderer.flops_to_dot_size(n.flops) for n in g.nodes]
    elif size_by=='params':
        sizes = [renderer.params_to_dot_size(n.params) for n in g.nodes]
    else:
        raise ValueError(f'Unknown size_by: {size_by}')
    
    colors = [renderer.get_node_color(n, color_by) for n in g.nodes]

    return dict(size=sizes, color=colors, colorscale='Bluered')

def build_menu_button(renderer:FigureRenderer, size_by:str, color_by:str):
    """Return one dropdown button that restyles the markers of traces 0 and 1.

    The button carries one marker dict per trace (built by
    `build_marker_dict`) and a label of the form
    "Size by <size_by>, color by <color_by>".
    """
    # Marker dicts are built before the label so that an invalid `size_by`
    # is reported by `build_marker_dict` first
    per_trace_markers = [build_marker_dict(renderer, idx, size_by, color_by) for idx in (0, 1)]
    display_names = {'operation': 'operation', 'params': 'params', 'flops': 'FLOPS'}
    button_label = f'Size by {display_names[size_by]}, color by {display_names[color_by]}'
    return {
        'args': [{'marker': per_trace_markers}],
        'label': button_label,
        'method': "restyle",
    }

def build_styling_menu(renderer:FigureRenderer):
    """Return the `updatemenus` entry offering the four size/color combinations.

    The menu is anchored with its top-left corner at the bottom-left corner of
    the plot area (x=0, y=0) and opens upwards.
    """
    # (size_by, color_by) pairs in the order in which they appear in the menu
    combos = (
        ('params', 'operation'),
        ('flops', 'operation'),
        ('params', 'flops'),
        ('flops', 'params'),
    )
    menu = {
        'buttons': [build_menu_button(renderer, size_key, color_key) for size_key, color_key in combos],
        'showactive': True,
        'direction': "up",
        'pad': {'l': 0, 'r': 0, 't': 0, 'b': 0},
        'x': 0, 'xanchor': "left",
        'y': 0, 'yanchor': "top",
    }
    return menu

# Set the figure's dropdown to the menu assembled by the builder functions and redraw
fig.update_layout(
    updatemenus=[build_styling_menu(renderer)],
    margin=dict(l=0, r=0, t=0, b=0)
)
fig.show(renderer='notebook_connected')

## Slider moving a box around
Here, the idea is to use a [slider](https://plotly.com/python/reference/layout/sliders/) to pan one zoomed-in panel (by changing the axis limits) while showing a fully zoomed-out copy of the figure in another panel, along with a rectangle highlighting the area displayed on the zoomed-in panel

We start with a single panel to experiment with rectangle updates based on slider events. There are two possible approaches to draw the rectangle:
* A normal scatter trace with [mode='lines'](https://plotly.com/python/reference/scatter/#scatter-mode)
* The shapes layer: [tutorial](https://plotly.com/python/shapes/) | [reference](https://plotly.com/python/reference/layout/shapes/)

### Using a scatter trace with mode='lines'

In [10]:
import sys, importlib
sys.path.append('..')
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import plotly.graph_objects as go
import numpy as np
from idlmav import MAV, color_graph_nodes, FigureRenderer

def reload_imports():
    """Reload the idlmav modules and refresh the names imported from them.

    Meant for the edit-and-rerun loop while working on idlmav itself: the
    listed submodules are reloaded first, then the package, and finally every
    name this notebook imports from idlmav is re-bound to the reloaded code.

    Raises:
        KeyError: If one of the listed modules has not been imported yet.
    """
    global MAV, color_graph_nodes, FigureRenderer
    for module_name in ('idlmav.renderers.figure_renderer',
                        'idlmav.tracing',
                        'idlmav.layout',
                        'idlmav.idlmav',
                        'idlmav'):
        importlib.reload(sys.modules[module_name])
    # Re-bind all three imported names; a name left out here would keep
    # pointing at the object from before the reload.
    from idlmav import MAV, color_graph_nodes, FigureRenderer

In [11]:
class MnistCnn(nn.Module):
    """Small convolutional classifier used as the example model to visualise.

    Two 3x3 convolutions, a 2x2 max-pool and two fully connected layers that
    end in a 10-way log-softmax. `fc1` expects 9216 = 64*12*12 features, which
    is what a 1x28x28 input produces after the convolutions and pooling.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Dropout layers (one after pooling, one before the last linear layer)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Classifier head
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities for the image batch `x`."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(F.max_pool2d(x, 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(self.dropout2(x))
        return F.log_softmax(x, dim=1)
    
# Instantiate the example network and a random batch of 16 single-channel 28x28 inputs
model = MnistCnn()
inputs = torch.randn((16,1,28,28))
# Build the MAV object from the model and example inputs; the cells below read its graph via mav.tracer.g
mav = MAV(model, inputs)
# Presumably assigns the node colors later read through renderer.get_node_color — confirm in idlmav
color_graph_nodes(mav.tracer.g)

In [None]:
# Build a figure with the node markers (trace 0) and a rectangle drawn as a
# closed line trace (trace 1) that the slider will move.
g = mav.tracer.g
renderer = FigureRenderer(g)  # For helper functions e.g. "params_to_dot_size"
fig = go.Figure(layout=dict(width=500, height=300, margin=dict(l=0, r=0, t=0, b=0)))

node_trace = go.Scatter(
    x=[n.x for n in g.nodes], 
    y=[n.y for n in g.nodes], 
    mode='markers', 
    marker=dict(
        size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
        color=[renderer.get_node_color(n) for n in g.nodes],
        colorscale='Bluered'
    ),
    # Each %{customdata[i]} placeholder is filled from the matching column of `customdata` below
    hovertemplate=(
        'Name: %{customdata[0]}<br>' +
        'Operation: %{customdata[1]}<br>' +
        'Activations: %{customdata[2]}<br>' +
        'Parameters: %{customdata[3]}<br>' +
        'FLOPS: %{customdata[4]}<br>' +
        '<br>' +
        'args: %{customdata[5]}<br>' +
        'kwargs: %{customdata[6]}<br>' +
        '<extra></extra>'
    ),
    customdata=[renderer.node_data(n) + renderer.node_arg_data(n) for n in g.nodes],
    showlegend=False
)
fig.add_trace(node_trace)

# Rectangle from (-0.1, 0) to (0.1, 1); the first corner is repeated to close the outline
box_trace = go.Scatter(
    x=[-0.1,-0.1,0.1,0.1,-0.1], y=[0,1,1,0,0],
    mode='lines',
    line=dict(color='#000000'),
    showlegend=False
)
fig.add_trace(box_trace)

# No grid, zero line or tick values; the x range is fixed to [-1, 1]
fig.update_xaxes(showgrid=False, zeroline=False, tickmode='array', tickvals=[], range=[-1,1])
fig.update_yaxes(showgrid=False, zeroline=False, tickmode='array', tickvals=[])

fig.show(renderer='notebook_connected')

In [13]:
# Coordinates that are the same for every slider step. `restyle` receives x/y
# for both traces, so the node coordinates are sent again with each step.
node_xs = [n.x for n in g.nodes]
node_ys = [n.y for n in g.nodes]
box_xs = [-0.1, -0.1, 0.1, 0.1, -0.1]

# One step per box position: 20 evenly spaced bottom edges between 0 and the largest node y
steps = []
for i in np.linspace(0, max(node_ys), 20):
    step = {
        "method": "restyle",
        "args": [{
            "x": [node_xs, box_xs],
            "y": [node_ys, [i, i+1, i+1, i, i]],
        }],
        "label": "",
    }
    steps.append(step)

# Slider anchored with its top-left corner at the bottom-left corner of the plot area
box_slider = {
    "steps": steps,
    "active": 0,
    "currentvalue": {"visible": False},
    "pad": {"t": 10, "l": 0, "b": 0, "r": 0},
    "x": 0, "xanchor": "left",
    "y": 0, "yanchor": "top",
}
fig.update_layout(sliders=[box_slider], margin={"l": 0, "r": 0, "t": 0, "b": 0})
fig.show(renderer='notebook_connected')

### Using the shapes layer

In [14]:
import sys, importlib
sys.path.append('..')
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import plotly.graph_objects as go
import numpy as np
from idlmav import MAV, color_graph_nodes, FigureRenderer

def reload_imports():
    """Reload the idlmav modules and refresh the names imported from them.

    Meant for the edit-and-rerun loop while working on idlmav itself: the
    listed submodules are reloaded first, then the package, and finally every
    name this notebook imports from idlmav is re-bound to the reloaded code.

    Raises:
        KeyError: If one of the listed modules has not been imported yet.
    """
    global MAV, color_graph_nodes, FigureRenderer
    for module_name in ('idlmav.renderers.figure_renderer',
                        'idlmav.tracing',
                        'idlmav.layout',
                        'idlmav.idlmav',
                        'idlmav'):
        importlib.reload(sys.modules[module_name])
    # Re-bind all three imported names; a name left out here would keep
    # pointing at the object from before the reload.
    from idlmav import MAV, color_graph_nodes, FigureRenderer

In [15]:
class MnistCnn(nn.Module):
    """Small convolutional classifier used as the example model to visualise.

    Two 3x3 convolutions, a 2x2 max-pool and two fully connected layers that
    end in a 10-way log-softmax. `fc1` expects 9216 = 64*12*12 features, which
    is what a 1x28x28 input produces after the convolutions and pooling.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Dropout layers (one after pooling, one before the last linear layer)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Classifier head
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities for the image batch `x`."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.dropout1(F.max_pool2d(x, 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(self.dropout2(x))
        return F.log_softmax(x, dim=1)
    
# Instantiate the example network and a random batch of 16 single-channel 28x28 inputs
model = MnistCnn()
inputs = torch.randn((16,1,28,28))
# Build the MAV object from the model and example inputs; the cells below read its graph via mav.tracer.g
mav = MAV(model, inputs)
# Presumably assigns the node colors later read through renderer.get_node_color — confirm in idlmav
color_graph_nodes(mav.tracer.g)

In [None]:
# Build a figure with the node markers as its only trace and draw the
# rectangle on the layout's shapes layer instead of as a trace.
g = mav.tracer.g
renderer = FigureRenderer(g)  # For helper functions e.g. "params_to_dot_size"
fig = go.Figure(layout=dict(width=500, height=300, margin=dict(l=0, r=0, t=0, b=0)))

node_trace = go.Scatter(
    x=[n.x for n in g.nodes], 
    y=[n.y for n in g.nodes], 
    mode='markers', 
    marker=dict(
        size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
        color=[renderer.get_node_color(n) for n in g.nodes],
        colorscale='Bluered'
    ),
    # Each %{customdata[i]} placeholder is filled from the matching column of `customdata` below
    hovertemplate=(
        'Name: %{customdata[0]}<br>' +
        'Operation: %{customdata[1]}<br>' +
        'Activations: %{customdata[2]}<br>' +
        'Parameters: %{customdata[3]}<br>' +
        'FLOPS: %{customdata[4]}<br>' +
        '<br>' +
        'args: %{customdata[5]}<br>' +
        'kwargs: %{customdata[6]}<br>' +
        '<extra></extra>'
    ),
    customdata=[renderer.node_data(n) + renderer.node_arg_data(n) for n in g.nodes],
    showlegend=False
)
fig.add_trace(node_trace)

# No grid, zero line or tick values; the x range is fixed to [-1, 1]
fig.update_xaxes(showgrid=False, zeroline=False, tickmode='array', tickvals=[], range=[-1,1])
fig.update_yaxes(showgrid=False, zeroline=False, tickmode='array', tickvals=[])

# Rectangle from (-0.1, 0) to (0.1, 1), positioned in data coordinates of the x and y axes
fig.add_shape(type="rect",
    xref="x", yref="y",
    x0=-0.1, y0=0, x1=0.1, y1=1,
    line=dict(color="#000000"),
)

fig.show(renderer='notebook_connected')

In [17]:
# One step per box position: 20 evenly spaced bottom edges between 0 and the
# largest node y. Each step replaces the layout's `shapes` list with a single
# rectangle, so no trace data has to be sent again.
steps = []
for i in np.linspace(0, max([n.y for n in g.nodes]), 20):
    box_shape = {
        "type": "rect",
        "xref": "x", "yref": "y",
        "x0": -0.1, "y0": i, "x1": 0.1, "y1": i+1,
        "line": {"color": "#000000"},
    }
    step = {
        "method": "relayout",
        "args": [{"shapes": [box_shape]}],
        "label": "",
    }
    steps.append(step)

# Slider anchored with its top-left corner at the bottom-left corner of the plot area
box_slider = {
    "steps": steps,
    "active": 0,
    "currentvalue": {"visible": False},
    "pad": {"t": 10, "l": 0, "b": 0, "r": 0},
    "x": 0, "xanchor": "left",
    "y": 0, "yanchor": "top",
}
fig.update_layout(sliders=[box_slider], margin={"l": 0, "r": 0, "t": 0, "b": 0})
fig.show(renderer='notebook_connected')

### Cell output sizes

In [19]:
# Report the output size of every cell in this notebook, to compare the cost of the approaches above
from idlmav.mavdbgutils import display_nb_cell_output_sizes
display_nb_cell_output_sizes('18_explore_plotly_frontend_controls.ipynb')

|Cell number|Output size|Cell type  |First line  |
|-----------|-----------|-----------|------------|
|0|0|markdown|## Introduction|
|1|0|markdown|## Dropdown menus for marker styling|
|2|0|markdown|### For a single trace|
|3|0|code|import sys, importlib|
|4|532|code|class MnistCnn(nn.Module):|
|5|1138|code|mav.tracer.g.nodes|
|6|11522|code|g = mav.tracer.g|
|7|12352|code|fig.update_layout(|
|8|0|markdown|### For multiple traces|
|9|0|code|import sys, importlib|
|10|0|code|class MnistCnn(nn.Module):|
|11|13931|code|g = mav.tracer.g|
|12|15320|code|fig.update_layout(|
|13|0|markdown|### Observations|
|14|0|markdown|### Decisions|
|15|0|markdown|## Slider moving a box around|
|16|0|markdown|### Scatter trace with mode='lines'|
|17|0|code|import sys, importlib|
|18|0|code|class MnistCnn(nn.Module):|
|19|11662|code|g = mav.tracer.g|
|20|17349|code|steps = []|
|21|0|markdown|### Shapes layer|
|22|0|code|import sys, importlib|
|23|0|code|class MnistCnn(nn.Module):|
|24|11646|code|g = mav.tracer.g|
|25|15459|code|steps = []|
|26|0|markdown|### Observations|

### Observations
* There is no way to change the orientation of the slider to vertical
* Data are written to the DOM for every slider step 
  * When the user drags a slider, values to update are retrieved via indexing, not calculated
  * As such, fine-grained slider steps are expensive in terms of DOM size
* Updating coordinates of `go.Scatter` objects is doable, but requires specifying the coordinates of all `go.Scatter` traces for every slider step
  * This may be very expensive if the figure contains many scatter data points
  * Updating the shapes layer may be more efficient

### With multiple panels
Now we can proceed to add two panels: one containing the zoomed-out overview with the rectangular box and one containing the zoomed-in data panned by the slider

In [12]:
import sys, importlib
sys.path.append('..')
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from idlmav import MAV, color_graph_nodes, FigureRenderer

def reload_imports():
    """Reload the idlmav modules and refresh the names imported from them.

    Meant for the edit-and-rerun loop while working on idlmav itself: the
    listed submodules are reloaded first, then the package, and finally every
    name this notebook imports from idlmav is re-bound to the reloaded code.

    Raises:
        KeyError: If one of the listed modules has not been imported yet.
    """
    global MAV, color_graph_nodes, FigureRenderer
    for module_name in ('idlmav.renderers.figure_renderer',
                        'idlmav.tracing',
                        'idlmav.layout',
                        'idlmav.idlmav',
                        'idlmav'):
        importlib.reload(sys.modules[module_name])
    # Re-bind all three imported names; a name left out here would keep
    # pointing at the object from before the reload.
    from idlmav import MAV, color_graph_nodes, FigureRenderer

In [13]:
class MnistCnn(nn.Module):
    """Small CNN classifier: two conv layers, one max-pool and two linear layers.

    `fc1` takes 9216 flattened features, which is what the conv/pool stack
    produces for a 1x28x28 input (64 channels x 12 x 12). The output has 10
    entries per sample.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in the same order as they are registered
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the log-softmax (over dim 1) of the network output for batch `x`."""
        # Convolutional stage followed by pooling and dropout
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        # Fully connected stage on the flattened features
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)
    
# Instantiate the example model and a random input batch of shape (16, 1, 28, 28)
model = MnistCnn()
inputs = torch.randn((16,1,28,28))
# Build the MAV object from the model and example inputs; later cells read
# the resulting graph from `mav.tracer.g`
mav = MAV(model, inputs)
# Apply idlmav's node coloring to that graph (presumably in place -- the return value is unused)
color_graph_nodes(mav.tracer.g)

In [None]:
# Build a two-panel figure: an overview panel (col 1) and a detail panel (col 2),
# both showing the same node markers, plus a rectangle in the overview panel.
g = mav.tracer.g
renderer = FigureRenderer(g)  # For helper functions e.g. "params_to_dot_size"
max_y = max([n.y for n in g.nodes])  # Largest node y-coordinate; sets the overview panel's y-range

# One row, two scatter subplots; the second column is three times as wide as the first
fig = make_subplots(rows=1, cols=2, vertical_spacing=0.03, 
                    specs=[[{"type": "scatter"},{"type": "scatter"}]], column_widths=[1,3])
fig.update_layout(width=500, height=300, margin=dict(l=0, r=0, t=0, b=0))

# A single marker trace definition, sized by parameter count and colored via the renderer
node_trace = go.Scatter(
    x=[n.x for n in g.nodes], 
    y=[n.y for n in g.nodes], 
    mode='markers', 
    marker=dict(
        size=[renderer.params_to_dot_size(n.params) for n in g.nodes],
        color=[renderer.get_node_color(n) for n in g.nodes],
        colorscale='Bluered'
    ),
    # Hover text is filled from the per-node `customdata` entries by index
    hovertemplate=(
        'Name: %{customdata[0]}<br>' +
        'Operation: %{customdata[1]}<br>' +
        'Activations: %{customdata[2]}<br>' +
        'Parameters: %{customdata[3]}<br>' +
        'FLOPS: %{customdata[4]}<br>' +
        '<br>' +
        'args: %{customdata[5]}<br>' +
        'kwargs: %{customdata[6]}<br>' +
        '<extra></extra>'
    ),
    # Concatenation of the two renderer helpers' outputs; presumably 5 + 2 entries
    # to match indices 0-6 above -- TODO confirm against FigureRenderer
    customdata=[renderer.node_data(n) + renderer.node_arg_data(n) for n in g.nodes],
    showlegend=False
)
# The same trace definition is added to both panels
fig.add_trace(node_trace, row=1, col=1)
fig.add_trace(node_trace, row=1, col=2)

# Hide grid lines, zero lines and ticks on all axes
fig.update_xaxes(showgrid=False, zeroline=False, tickmode='array', tickvals=[])
fig.update_yaxes(showgrid=False, zeroline=False, tickmode='array', tickvals=[])
# y-ranges are listed high-to-low. The overview panel spans all levels (up to max_y),
# the detail panel starts on the window [-0.1, 1.9]
fig.update_layout(xaxis_range=[-0.11,0.11], xaxis2_range=[-0.1,0.1],
                  yaxis_range=[max_y+0.2,-0.2], yaxis2_range=[1.9,-0.1])

# Rectangle in the overview panel whose extents equal the detail panel's initial axis ranges
fig.add_shape(type="rect",
    xref="x", yref="y",
    x0=-0.1, y0=-0.1, x1=0.1, y1=1.9,
    line=dict(color="#000000"),
    row=1, col=1
)

fig.show(renderer='notebook_connected')


In [None]:
# Build one slider step per window position and attach the slider to `fig`
# (the figure created in the previous cell).
steps = []
num_levels_displayed = 2  # Height of the window, in y-units, shown in the detail panel
# 20 evenly spaced window tops, from -0.1 to the position where the window bottom reaches max_y+0.1
for i in np.linspace(-0.1, max_y+0.1-num_levels_displayed, 20):
    step = dict(
        method="relayout",
        args=[dict(
            # The whole `shapes` list is supplied for every step, with the rectangle at its new position
            shapes=[dict(type="rect",
                xref="x", yref="y",
                x0=-0.1, y0=i, x1=0.1, y1=i+num_levels_displayed,
                line=dict(color="#000000"),
            )],
            # Detail panel y-range follows the rectangle. The grid/tick settings are repeated here,
            # presumably because passing a whole `yaxis2` dict would otherwise drop them -- TODO confirm
            yaxis2=dict(range=[i+num_levels_displayed, i], showgrid=False, zeroline=False, tickmode='array', tickvals=[]),
        )],
        label="",  # No text label per step
    )
    steps.append(step)

fig.update_layout(
    sliders=[dict(
        steps=steps,
        active=0,
        currentvalue={"visible": False},
        pad=dict(t=10, l=0, b=0, r=0),
        x=0, xanchor="left",
        y=0, yanchor="top"
    )],
    margin=dict(l=0, r=0, t=0, b=0)
)
fig.show(renderer='notebook_connected')

## Both controls together
In this section, both controls developed above are implemented on the same figure to check for possible unwanted interactions

In [1]:
import sys, importlib
sys.path.append('..')
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from idlmav import MAV, color_graph_nodes, FigureRenderer

def reload_imports():
    """Reload the idlmav modules and rebind `MAV` and `FigureRenderer`.

    Every listed module is fetched from `sys.modules` and handed to
    `importlib.reload`, in the order given. Afterwards both names are
    re-imported from `idlmav` into the global namespace of this notebook.
    """
    global MAV, FigureRenderer
    module_names = (
        'idlmav.renderers.figure_renderer',
        'idlmav.tracing',
        'idlmav.layout',
        'idlmav.idlmav',
        'idlmav',
    )
    for module_name in module_names:
        importlib.reload(sys.modules[module_name])
    from idlmav import MAV, FigureRenderer

In [2]:
class MnistCnn(nn.Module):
    """Small CNN classifier: two conv layers, one max-pool and two linear layers.

    `fc1` takes 9216 flattened features, which is what the conv/pool stack
    produces for a 1x28x28 input (64 channels x 12 x 12). The output has 10
    entries per sample.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in the same order as they are registered
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the log-softmax (over dim 1) of the network output for batch `x`."""
        # Convolutional stage followed by pooling and dropout
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        # Fully connected stage on the flattened features
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)
    
# Instantiate the example model and a random input batch of shape (16, 1, 28, 28)
model = MnistCnn()
inputs = torch.randn((16,1,28,28))
# Build the MAV object from the model and example inputs; later cells read
# the resulting graph from `mav.tracer.g`
mav = MAV(model, inputs)
# Apply idlmav's node coloring to that graph (presumably in place -- the return value is unused)
color_graph_nodes(mav.tracer.g)

INFO:2025-02-24 20:53:56 611783:611783 init.cpp:181] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti


In [13]:
def build_marker_dict(renderer:FigureRenderer, trace_idx:int, size_by:str, color_by:str):
    """Build the plotly `marker` dict (size, color, colorscale) for one node trace.

    Args:
        renderer: Supplies the graph (`renderer.g`), the dot-size helpers and
            `get_node_color`.
        trace_idx: Index of the target trace in the figure. 3 selects the
            regular dot-size helpers, 2 selects the `*_overview` helpers.
        size_by: Either 'params' or 'flops'; selects the quantity that drives
            the marker size.
        color_by: Passed through unchanged to `renderer.get_node_color`.

    Returns:
        A dict with the keys `size`, `color` and `colorscale`.

    Raises:
        ValueError: If `trace_idx` is not 2 or 3, or `size_by` is not
            'params' or 'flops'.
    """
    g = renderer.g
    if trace_idx==3:
        if size_by=='flops':
            # Size by FLOPS must be fed the node's FLOPS, not its parameter count
            sizes = [renderer.flops_to_dot_size(n.flops) for n in g.nodes]
        elif size_by=='params':
            sizes = [renderer.params_to_dot_size(n.params) for n in g.nodes]
        else:
            raise ValueError(f'Unknown size_by: {size_by}')
    elif trace_idx==2:
        if size_by=='flops':
            # Same as above, using the overview-panel sizing helper
            sizes = [renderer.flops_to_dot_size_overview(n.flops) for n in g.nodes]
        elif size_by=='params':
            sizes = [renderer.params_to_dot_size_overview(n.params) for n in g.nodes]
        else:
            raise ValueError(f'Unknown size_by: {size_by}')
    else:
        raise ValueError(f'Unknown trace_idx: {trace_idx}')

    colors = [renderer.get_node_color(n, color_by) for n in g.nodes]

    return dict(size=sizes, color=colors, colorscale='Bluered')

def build_menu_button(renderer:FigureRenderer, size_by:str, color_by:str):
    """Build one dropdown button that restyles the node markers.

    The button's `marker` list has one entry per trace index: two empty
    dicts for indices 0 and 1, then the marker dicts built for trace
    indices 2 and 3.

    Args:
        renderer: Forwarded to `build_marker_dict`.
        size_by: Sizing criterion; also used to look up the label text.
        color_by: Coloring criterion; also used to look up the label text.

    Returns:
        A dict with the keys `args`, `label` and `method` ("restyle").
    """
    overview_marker = build_marker_dict(renderer, 2, size_by, color_by)
    main_marker = build_marker_dict(renderer, 3, size_by, color_by)

    # Text shown in the button label for each criterion
    display_names = {'operation': 'operation', 'params': 'params', 'flops': 'FLOPS'}
    button_label = f'Size by {display_names[size_by]}, color by {display_names[color_by]}'

    return {
        'args': [{'marker': [{}, {}, overview_marker, main_marker]}],
        'label': button_label,
        'method': "restyle",
    }

def build_styling_menu(renderer:FigureRenderer, pad_t=0):
    """Build the dropdown menu dict offering four size/color combinations.

    Args:
        renderer: Forwarded to `build_menu_button` for every button.
        pad_t: Top padding of the menu.

    Returns:
        A dict suitable as one entry of the figure's `updatemenus` list,
        anchored at x=0 / y=0 and opening upwards.
    """
    # (size_by, color_by) pairs, in the order the buttons appear
    combinations = (
        ('params', 'operation'),
        ('flops', 'operation'),
        ('params', 'flops'),
        ('flops', 'params'),
    )

    buttons = []
    for size_by, color_by in combinations:
        buttons.append(build_menu_button(renderer, size_by, color_by))

    return {
        'buttons': buttons,
        'showactive': True,
        'direction': "up",
        'pad': {'l': 0, 'r': 0, 't': pad_t, 'b': 0},
        'x': 0, 'xanchor': "left",
        'y': 0, 'yanchor': "top",
    }

def build_overview_slider(renderer:FigureRenderer, pad_t=0, margin=0.1, num_levels_displayed=2, num_steps=20):
    """Build the slider dict that moves the overview rectangle and pans `yaxis2`.

    Args:
        renderer: Supplies `in_level`, `out_level`, `min_x` and `max_x`.
        pad_t: Top padding of the slider.
        margin: Extra space around the node extents, applied to both the
            rectangle's x-extents and the first/last window position.
        num_levels_displayed: Height of the window in y-units.
        num_steps: Number of evenly spaced slider positions.

    Returns:
        A dict suitable as one entry of the figure's `sliders` list. Each
        step is a "relayout" that supplies a one-rectangle `shapes` list and
        a `yaxis2` dict whose range matches the rectangle's y-extents.
    """
    def make_step(top):
        # Window covers [top, bottom] in y; the axis range is listed bottom-first
        bottom = top + num_levels_displayed
        window = {
            "type": "rect", "xref": "x", "yref": "y",
            "line": {"color": "#000000"},
            "x0": renderer.min_x - margin, "y0": top,
            "x1": renderer.max_x + margin, "y1": bottom,
        }
        detail_axis = {
            "range": [bottom, top],
            "showgrid": False, "zeroline": False,
            "tickmode": 'array', "tickvals": [],
        }
        return {
            "method": "relayout",
            "args": [{"shapes": [window], "yaxis2": detail_axis}],
            "label": "",
        }

    first_top = renderer.in_level - margin
    last_top = renderer.out_level + margin - num_levels_displayed
    slider_steps = [make_step(top) for top in np.linspace(first_top, last_top, num_steps)]

    return {
        "steps": slider_steps,
        "active": 0,
        "currentvalue": {"visible": False},
        "pad": {"t": pad_t, "l": 0, "b": 0, "r": 0},
        "x": 0, "xanchor": "left",
        "y": 0, "yanchor": "top",
    }

def build_node_trace(renderer:FigureRenderer, trace_idx:int, size_by:str, color_by:str):
    """Build the marker scatter trace for the graph nodes.

    Args:
        renderer: Supplies the graph (`renderer.g`) and the per-node hover
            data helpers `node_data` and `node_arg_data`.
        trace_idx: Forwarded to `build_marker_dict` to select the marker
            sizing for the target trace.
        size_by: Forwarded to `build_marker_dict`.
        color_by: Forwarded to `build_marker_dict`.

    Returns:
        A `go.Scatter` in 'markers' mode with one point per graph node.
    """
    # Use the renderer's own graph rather than a notebook-level global `g`,
    # so the function works regardless of which cells have been run
    g = renderer.g
    return go.Scatter(
        x=[n.x for n in g.nodes], 
        y=[n.y for n in g.nodes], 
        mode='markers', 
        marker=build_marker_dict(renderer, trace_idx, size_by, color_by),
        # Hover text is filled from the per-node `customdata` entries by index
        hovertemplate=(
            'Name: %{customdata[0]}<br>' +
            'Operation: %{customdata[1]}<br>' +
            'Activations: %{customdata[2]}<br>' +
            'Parameters: %{customdata[3]}<br>' +
            'FLOPS: %{customdata[4]}<br>' +
            '<br>' +
            'args: %{customdata[5]}<br>' +
            'kwargs: %{customdata[6]}<br>' +
            '<extra></extra>'
        ),
        # Concatenation of both helpers' outputs; presumably 5 + 2 entries to
        # match indices 0-6 above -- TODO confirm against FigureRenderer
        customdata=[renderer.node_data(n) + renderer.node_arg_data(n) for n in g.nodes],
        showlegend=False
    )

In [14]:
# Build the two-panel figure with both controls: the styling dropdown and the overview slider.
# Trace order matters: the menu buttons address the node traces by index (2 and 3).
g = mav.tracer.g
renderer = FigureRenderer(g)  # For helper functions e.g. "params_to_dot_size"
max_y = max([n.y for n in g.nodes])  # Largest node y-coordinate; sets the overview panel's y-range

# One row, two scatter subplots; the second column is three times as wide as the first
fig = make_subplots(rows=1, cols=2, vertical_spacing=0.03, 
                    specs=[[{"type": "scatter"},{"type": "scatter"}]], column_widths=[1,3])
fig.update_layout(width=500, height=300, margin=dict(l=0, r=0, t=0, b=0))

# Add connection lines between the nodes
# * These traces are added before the node traces, presumably so that they are drawn behind the nodes
x_coords, y_coords = [],[]
for c in g.connections:
    xs, ys = renderer.get_connection_coords(c)
    # A None entry is inserted between consecutive connections so that all of them
    # can share one trace (presumably rendered as a break in the line)
    if x_coords: x_coords.append(None)
    if y_coords: y_coords.append(None)
    x_coords += xs
    y_coords += ys

line_trace = go.Scatter(
    x=x_coords, y=y_coords, mode="lines",
    line=dict(color="gray", width=1),            
    hoverinfo='skip',
    showlegend=False
)
# Trace indices 0 and 1: the same line trace in the overview and detail panels
fig.add_trace(line_trace, row=1, col=1)
fig.add_trace(line_trace, row=1, col=2)

# Add node markers
# Trace index 2 (overview panel) and trace index 3 (detail panel)
overview_node_trace = build_node_trace(renderer, trace_idx=2, size_by='params', color_by='operation')
node_trace = build_node_trace(renderer, trace_idx=3, size_by='params', color_by='operation')
fig.add_trace(overview_node_trace, row=1, col=1)
fig.add_trace(node_trace, row=1, col=2)

# Layout and additional controls
fig.update_xaxes(showgrid=False, zeroline=False, tickmode='array', tickvals=[])
fig.update_yaxes(showgrid=False, zeroline=False, tickmode='array', tickvals=[])
# y-ranges are listed high-to-low. The overview panel spans all levels (up to max_y),
# the detail panel starts on the window [-0.1, 1.9]
fig.update_layout(xaxis_range=[-0.11,0.11], xaxis2_range=[-0.1,0.1],
                  yaxis_range=[max_y+0.2,-0.2], yaxis2_range=[1.9,-0.1])

# Rectangle in the overview panel whose extents equal the detail panel's initial axis ranges
fig.add_shape(type="rect",
    xref="x", yref="y",
    x0=-0.1, y0=-0.1, x1=0.1, y1=1.9,
    line=dict(color="#000000"),
    row=1, col=1
)

# Both controls are anchored at x=0 / y=0; the slider gets a larger top padding than the menu
fig.update_layout(
    updatemenus=[build_styling_menu(renderer, pad_t=0)],
    sliders=[build_overview_slider(renderer, pad_t=40)]
)

fig.show(renderer='notebook_connected')

In [16]:
# Check the trace count: two line traces plus two node traces were added above
len(fig.data)

4

## Updates to FigureRenderer
At this point, the controls developed above are being applied to `FigureRenderer`. This section runs the updated renderer to test the changes as they are made.

In [1]:
import sys, importlib
sys.path.append('..')
import torch
from torch import nn
import torch.nn.functional as F
from idlmav import MAV, color_graph_nodes, FigureRenderer

def reload_imports():
    """Reload the idlmav modules and rebind `MAV` and `FigureRenderer`.

    Every listed module is fetched from `sys.modules` and handed to
    `importlib.reload`, in the order given. Afterwards both names are
    re-imported from `idlmav` into the global namespace of this notebook.
    """
    global MAV, FigureRenderer
    module_names = (
        'idlmav.renderers.figure_renderer',
        'idlmav.renderers',
        'idlmav.tracing',
        'idlmav.layout',
        'idlmav.idlmav',
        'idlmav',
    )
    for module_name in module_names:
        importlib.reload(sys.modules[module_name])
    from idlmav import MAV, FigureRenderer

In [2]:
class MnistCnn(nn.Module):
    """Small CNN classifier: two conv layers, one max-pool and two linear layers.

    `fc1` takes 9216 flattened features, which is what the conv/pool stack
    produces for a 1x28x28 input (64 channels x 12 x 12). The output has 10
    entries per sample.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in the same order as they are registered
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the log-softmax (over dim 1) of the network output for batch `x`."""
        # Convolutional stage followed by pooling and dropout
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        # Fully connected stage on the flattened features
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)
    
# Instantiate the example model and a random input batch of shape (16, 1, 28, 28)
model = MnistCnn()
inputs = torch.randn((16,1,28,28))
# Build the MAV object from the model and example inputs
mav = MAV(model, inputs)
# Apply idlmav's node coloring to the traced graph (presumably in place -- the return value is unused)
color_graph_nodes(mav.tracer.g)

INFO:2025-02-25 00:22:37 655250:655250 init.cpp:181] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti


In [19]:
# Re-run after editing the idlmav source so the reloaded modules are used by the next cell
reload_imports()

In [20]:
# Show the figure with the table, slider and overview options enabled
# (num_levels_displayed=2 presumably matches the window height used in the experiments above)
mav.show_figure(add_table=True, add_slider=True, add_overview=True, num_levels_displayed=2)