<a href="https://colab.research.google.com/github/d112358/idlmav/blob/main/environments/explore_colab_widget_renderer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports and setup

In [None]:
# Install idlmav from the default branch on GitHub (unpinned, so behavior may
# change as the repository evolves)
%pip install git+https://github.com/d112358/idlmav.git

import torch, torchvision
# NOTE(review): `plotly_renderer` appears unused in this notebook
from idlmav import MAV, plotly_renderer
from IPython.display import display

# Enable Colab's custom widget manager so that third-party ipywidgets (such as
# plotly's FigureWidget, used by the widget renderer) can be displayed
from google.colab import output
output.enable_custom_widget_manager()

In [2]:
# Trace a randomly initialised ResNet-18 on the CPU, using a dummy batch
# of 16 RGB images of 160x160 pixels as input
device = 'cpu'
model = torchvision.models.resnet18()
model = model.to(device)
x = torch.randn(16, 3, 160, 160)
x = x.to(device)
mav = MAV(model, x, device=device)

## Issue: blank plotly graph subplots
* Initially, the table and slider displayed correctly, but the plotly graphs were just blank
* After some experimentation with the rendering function, it was found that if the number of connection lines was reduced to around 20 (the exact number varied from run to run), everything would display correctly
* Eventually, this was resolved by using a single `go.Scatter` trace object for all the connection lines, with `None` values wherever a discontinuity is needed. This reduced the number of trace objects; with one trace per line there were probably simply too many for the Colab widget display to handle.

### High-level call via MAV object

In [None]:
# Construct a renderer directly from the traced graph. Because the object is
# the last expression in the cell, Jupyter displays its representation.
# NOTE(review): what is actually shown depends on the display hooks defined by
# `WidgetRenderer`, which are not visible in this notebook
from idlmav import WidgetRenderer
WidgetRenderer(mav.tracer.g)

In [None]:
# Build the widget container through the high-level `MAV` interface and keep
# a reference to it in `container1` before displaying it
container1 = mav.render_widget(add_table=True, add_overview=True, add_slider=False)
display(container1)

### Via intermediate object

In [None]:
# Same view as above, but rendered from an explicitly created renderer that
# stays referenced by the notebook variable `viewer`
from idlmav import WidgetRenderer
viewer = WidgetRenderer(mav.tracer.g)
container2 = viewer.render(add_table=True, add_overview=True, add_slider=False)
display(container2)

### Using copied contents of rendering method
* Used built-in refactoring to replace `self` with `viewer`
* Other methods of `WidgetRenderer` referenced from here are resolved in the imported library

In [None]:
# Additional imports
import time
import ipywidgets as widgets
import plotly.graph_objects as go
from IPython.display import display, HTML, Javascript

# Arguments passed to method
# * Same names as the keyword arguments of `WidgetRenderer.render`; the values
#   match the calls in the cells above
add_table:bool=True
add_slider:bool=False
add_overview:bool=True
num_levels_displayed:float=10  # number of graph levels initially visible
height_px:int=400  # height of the figures, table and slider in pixels

# Create viewer
# * A fresh renderer for the traced graph; the rest of this cell stores its
#   figures and widgets as attributes on `viewer`, as the original method
#   does on `self`
viewer = WidgetRenderer(mav.tracer.g)

# Setup parameters
# * Initial y range: `num_levels_displayed` levels starting at the input level,
#   padded by 0.5 on either side, passed through `fit_range` together with
#   `full_y_range` (presumably to keep it within the graph extent -- TODO confirm)
# * The range is given largest value first; a descending range makes plotly
#   draw the y axis reversed
# * The x range initially covers the full graph width
g = viewer.g
initial_y_range = viewer.fit_range([viewer.in_level+num_levels_displayed-0.5, viewer.in_level-0.5], viewer.full_y_range)
initial_x_range = viewer.full_x_range

# Create a new unique ID every time this is called
# * NOTE(review): presumably used to build unique HTML element IDs (e.g.
#   `viewer.html_scrolling_table_id()` below) so that repeated renders do not clash
viewer.unique_id = f'{id(viewer)}_{int(time.time() * 1000)}'

# Create the main panel
# * flex '0 1 auto': this panel may shrink inside the container, but not grow
# * Width: 100 px per graph column, with a minimum of 180 px
main_panel_layout = widgets.Layout(flex = '0 1 auto', margin='0px', padding='0px', overflow='hidden')
main_fig_layout = go.Layout(
    width=max((viewer.graph_num_cols*100, 180)), height=height_px,
    plot_bgcolor='#e5ecf6',
    autosize=True,
    xaxis=dict(range=initial_x_range, showgrid=False, zeroline=False, visible=False),
    yaxis=dict(range=initial_y_range, showgrid=False, zeroline=False, visible=False),
    margin=dict(l=0, r=2, t=1, b=1),
    showlegend=False,
    title=dict(text=None)
)
viewer.main_fig = go.FigureWidget(layout=main_fig_layout)
viewer.main_panel = widgets.Box(children=[viewer.main_fig], layout=main_panel_layout)
panels = [viewer.main_panel]

# Add a selection marker (behind nodes for hover purposes)
# * Translucent dot with a thick black outline, initially placed on the first
#   node and sized like that node's marker
# * Added before the connection and node traces, so it is drawn underneath
#   them; its hover label is empty (`<extra></extra>`)
# * Its trace index is kept in `viewer.sel_marker_idx`, presumably so the
#   event handlers can move it to the selected node
node = g.nodes[0]
sel_marker = go.Scatter(
    x=[node.x], y=[node.y],
    mode='markers',
    marker=dict(
        size=[viewer.params_to_dot_size(node.params)],
        color='rgba(0,0,0,0.1)',
        line=dict(color='black', width=3)
    ),
    hovertemplate='<extra></extra>', showlegend=False
)
viewer.main_fig.add_trace(sel_marker)
viewer.sel_marker_idx = len(viewer.main_fig.data)-1

# Add the connection lines between the nodes
# * All connections share one trace: a `None` entry in both coordinate lists
#   breaks the polyline, so consecutive connections are not joined together
# * Using a separate trace for every line caused a blank display on Colab,
#   and many traces may also slow down interactions such as pan & zoom
x_coords, y_coords = [], []
for conn in g.connections:
    conn_xs, conn_ys = viewer.get_connection_coords(conn)
    if x_coords:
        x_coords.append(None)
    if y_coords:
        y_coords.append(None)
    x_coords.extend(conn_xs)
    y_coords.extend(conn_ys)

line_trace = go.Scatter(
    x=x_coords,
    y=y_coords,
    mode="lines",
    line=dict(color="gray", width=1),
    showlegend=False,
)
viewer.main_fig.add_trace(line_trace)

# Add the node markers
# * One marker per node: size derived from `n.params`, color from `get_node_color`
# * The hover label shows the first five fields returned by `node_data`
#   (name, operation, activations, parameters, FLOPS)
# * NOTE(review): `colorscale` only takes effect if `get_node_color` returns
#   numbers rather than color strings -- confirm
# * `node_trace_idx` is used below to attach the click handler
node_trace = go.Scatter(
    x=[n.x for n in g.nodes],
    y=[n.y for n in g.nodes],
    mode='markers',
    marker=dict(
        size=[viewer.params_to_dot_size(n.params) for n in g.nodes],
        color=[viewer.get_node_color(n) for n in g.nodes],
        colorscale='Bluered'
    ),
    hovertemplate=(
        'Name: %{customdata[0]}<br>' +
        'Operation: %{customdata[1]}<br>' +
        'Activations: %{customdata[2]}<br>' +
        'Parameters: %{customdata[3]}<br>' +
        'FLOPS: %{customdata[4]}<br>' +
        '<extra></extra>'
    ),
    customdata=[viewer.node_data(n) for n in g.nodes],
    showlegend=False
)
viewer.main_fig.add_trace(node_trace)
node_trace_idx = len(viewer.main_fig.data)-1

# Add table if selected
# * The table CSS and HTML are generated by the renderer and shown inside an
#   `Output` widget
# * The table is wrapped in a div as tall as the graph, which scrolls when the
#   table does not fit; `width: fit-content` keeps the div as narrow as the table
# * Appended to the right of the main panel
if add_table:
    table_panel_layout = widgets.Layout(flex='0 0 auto', margin='0px', padding='0px', overflow='visible')
    table_style = viewer.write_table_style()
    table_html = viewer.write_table_html(g)
    scrolling_table_html = f'<div id="{viewer.html_scrolling_table_id()}" style="height: {height_px}px; overflow: auto; width: fit-content">{table_html}</div>'
    viewer.table_widget = widgets.Output()
    with viewer.table_widget:
        display(HTML(table_style))
        display(HTML(scrolling_table_html))
    viewer.table_panel = widgets.Box(children=[viewer.table_widget], layout=table_panel_layout)
    panels.append(viewer.table_panel)

# Add overview window if selected
# * A narrow figure of the whole graph, inserted to the left of the main panel
overview_trace_idx = None
if add_overview:
    # Overview panel
    # * Width: 15 px per graph column, with a minimum of 45 px
    # * The y axis always spans the full graph (`full_y_range`)
    overview_panel_layout = widgets.Layout(flex = '0 0 auto', margin='0px', padding='0px', overflow='hidden')
    overview_fig_layout = go.Layout(
        width=max((viewer.graph_num_cols*15, 45)),
        height=height_px,
        plot_bgcolor='#dfdfdf',
        xaxis=dict(showgrid=False, zeroline=False, visible=False),
        yaxis=dict(range=viewer.full_y_range,
                    showgrid=False, zeroline=False, visible=False),
        margin=dict(l=0, r=4, t=1, b=1),
        showlegend=False,
        title=dict(text=None),
        hoverdistance=-1,  # Always hover over something
    )
    viewer.overview_fig = go.FigureWidget(layout=overview_fig_layout)
    viewer.overview_panel = widgets.Box(children=[viewer.overview_fig], layout=overview_panel_layout)
    panels.insert(0, viewer.overview_panel)

    # Connection lines (the same trace object that was added to the main figure)
    viewer.overview_fig.add_trace(line_trace)

    # Nodes
    # * Smaller markers than in the main figure (`params_to_dot_size_overview`)
    # * The hover label shows only the first `node_data` field (the node name)
    # * `overview_trace_idx` is used below to attach the click handler
    overview_nodes_trace = go.Scatter(
        x=[n.x for n in g.nodes],
        y=[n.y for n in g.nodes],
        mode='markers',
        marker=dict(
            size=[viewer.params_to_dot_size_overview(n.params) for n in g.nodes],
            color=[viewer.get_node_color(n) for n in g.nodes],
            colorscale='Bluered'
        ),
        hovertemplate='%{customdata[0]}<extra></extra>',
        customdata=[viewer.node_data(n) for n in g.nodes],
        showlegend=False
    )
    viewer.overview_fig.add_trace(overview_nodes_trace)
    overview_trace_idx = len(viewer.overview_fig.data)-1

    # Rectangle
    # * Marks the region initially shown in the main figure: horizontally from
    #   min_x-0.5 to max_x+0.5, vertically `initial_y_range`
    # * fill='toself' shades the interior; hoveron='points' makes only the
    #   corner points, not the shaded area, respond to hover
    # * The trace index is kept in `viewer.overview_rect_idx`, presumably so the
    #   rectangle can follow pan/zoom of the main figure
    x0, y0, x1, y1 = viewer.min_x-0.5, initial_y_range[0], viewer.max_x+0.5, initial_y_range[1]
    rect_trace = go.Scatter(
        x=[x0, x0, x1, x1, x0],
        y=[y0, y1, y1, y0, y0],
        fill='toself',
        mode='lines',
        line=dict(color="#3d6399", width=1),
        fillcolor='rgba(112,133,161,0.25)',
        hoveron='points',
        hovertemplate='<extra></extra>',
        showlegend=False
    )
    viewer.overview_fig.add_trace(rect_trace)
    viewer.overview_rect_idx = len(viewer.overview_fig.data)-1

# Add slider if selected
# * Use negative values everywhere, because ipywidgets does not support
#   inverting the direction of vertical sliders
# * The slider value and limits are the negated initial and full y ranges;
#   value changes are handled by `viewer.on_slider_value_change` (connected below)
# * The slider is as tall as the graph and is inserted as the leftmost panel
if add_slider:
    slider_panel_layout = widgets.Layout(flex = '0 0 auto', margin='0px', padding='0px', overflow='visible')
    viewer.slider_widget = widgets.FloatRangeSlider(
        value=[-initial_y_range[0], -initial_y_range[1]], min=-viewer.full_y_range[0], max=-viewer.full_y_range[1],
        step=0.01, description='', orientation='vertical', continuous_update=True,
        layout=widgets.Layout(height=f'{height_px}px')
    )
    viewer.slider_widget.readout = False  # For some reason it does not seem to work if set during construction
    viewer.slider_panel = widgets.Box(children=[viewer.slider_widget], layout=slider_panel_layout)
    panels.insert(0, viewer.slider_panel)

# Create container for all panels
# * To be displayed in Notebook using `display`
# * Panel order from left to right: slider, overview, main figure, table
#   (for whichever of these were added)
container_layout = widgets.Layout(
    width='100%',
    margin='0px', padding='0px')
container2 = widgets.HBox(panels, layout=container_layout)

# Set up event handlers
# * Node clicks in either figure, pan/zoom of the main figure and slider
#   changes are all handled by methods of the library's `WidgetRenderer`
viewer.main_fig.data[node_trace_idx].on_click(viewer.on_main_panel_click)
viewer.main_fig.layout.on_change(viewer.on_main_panel_pan_zoom, 'xaxis.range', 'yaxis.range')
# NOTE(review): these checks rely on `WidgetRenderer` initialising
# `overview_fig` and `slider_widget` (presumably to None) when the
# corresponding panel is not added
if viewer.overview_fig:
    viewer.overview_fig.data[overview_trace_idx].on_click(viewer.on_overview_panel_click)
if viewer.slider_widget:
    viewer.slider_widget.observe(viewer.on_slider_value_change, names="value")

# Restrict actions on plots
# * NOTE(review): the next commented-out line is a disabled alternative left
#   over from the original method (it still refers to `self`); the list after
#   it appears to be the set of button names accepted by `modebar_remove`
# self.main_fig.update_layout(config=dict(displayModeBar=False))
# [ "autoScale2d", "autoscale", "editInChartStudio", "editinchartstudio", "hoverCompareCartesian", "hovercompare", "lasso", "lasso2d", "orbitRotation", "orbitrotation", "pan", "pan2d", "pan3d", "reset", "resetCameraDefault3d", "resetCameraLastSave3d", "resetGeo", "resetSankeyGroup", "resetScale2d", "resetViewMap", "resetViewMapbox", "resetViews", "resetcameradefault", "resetcameralastsave", "resetsankeygroup", "resetscale", "resetview", "resetviews", "select", "select2d", "sendDataToCloud", "senddatatocloud", "tableRotation", "tablerotation", "toImage", "toggleHover", "toggleSpikelines", "togglehover", "togglespikelines", "toimage", "zoom", "zoom2d", "zoom3d", "zoomIn2d", "zoomInGeo", "zoomInMap", "zoomInMapbox", "zoomOut2d", "zoomOutGeo", "zoomOutMap", "zoomOutMapbox", "zoomin", "zoomout"]
# * Main figure: remove the image, reset and selection buttons; start in zoom mode
viewer.main_fig.update_layout(modebar_remove=["toimage", "resetscale", "select", "lasso", "reset"])
viewer.main_fig.layout.dragmode = 'zoom'
# * Overview figure: remove almost all buttons and disable dragging
if viewer.overview_fig:
    viewer.overview_fig.update_layout(modebar_remove=["toimage", "autoscale", "select", "lasso", "pan", "reset", "resetscale", "zoom", "zoomin", "zoomout"])
    viewer.overview_fig.layout.dragmode = False

# Display the container
display(container2)

HBox(children=(Box(children=(FigureWidget({
    'data': [{'line': {'color': 'gray', 'width': 1},
             …

## Issue: blank plotly graph subplots when using high-level interface
* After fixing the previous issue, the graphs were rendered correctly when generated directly from `WidgetRenderer`, but not when generated from the high-level `MAV` object
  - What's more: once the graphs had been rendered successfully via `WidgetRenderer`, they were rendered correctly from any class until the kernel was restarted
* To investigate this, we rewrite the MAV class here under a different name
* The issue appeared to be that the `WidgetRenderer` object goes out of scope when `MAV.render_widget` returns. Assigning this object to an attribute such as `self.viewer` seemed to fix this.
  - One possible explanation is that `WidgetRenderer` stores references to the `go.FigureWidget` objects it uses, and that some clean-up action is triggered when these references go out of scope
  - It was later discovered that this fix only works if the creation of the `MAV` object (and specifically the tracer inside it) and the rendering take place in different notebook cells. This is not yet fully understood.
* For now, the recommendation is to call the 5 steps (tracing, node merging, coloring, layout and rendering, as done in `MAV1` below) directly when widgets are desired on Colab.

Since debugging this requires frequent kernel restarts, the setup cell from the top of the notebook is repeated here to speed up the workflow.

In [None]:
# Repeat of the first setup cell, so that debugging after a kernel restart can
# start from this section; the `idlmav` imports are done in the next cell
%pip install git+https://github.com/d112358/idlmav.git

import torch, torchvision
from IPython.display import display

# Enable Colab's custom widget manager so that third-party ipywidgets (such as
# plotly's FigureWidget) can be displayed
from google.colab import output
output.enable_custom_widget_manager()

In [None]:
from typing import Tuple, List, Dict, Set, Union, overload
from torch import nn, Tensor
import plotly.graph_objects as go
import ipywidgets as widgets
from idlmav.tracing import MavTracer
from idlmav.merging import merge_graph_nodes
from idlmav.coloring import color_graph_nodes
from idlmav.layout import layout_graph_nodes
from idlmav.renderers.widget_renderer import WidgetRenderer


In [None]:
class MAV1:
    """Local copy of the high-level `MAV` class, used to debug blank widget graphs on Colab.

    The constructor runs the graph pipeline (tracing, node merging, coloring
    and layout) on the traced graph `self.tracer.g`. `render_widget` then builds
    the interactive ipywidgets view with `WidgetRenderer`.
    """

    def __init__(self, model:nn.Module, inputs:Union[Tensor, Tuple[Tensor, ...]], device=None,
                 merge_threshold=0.01,
                 palette:Union[str, List[str]]='large',
                 avoid_palette_idxs:Union[Set[int], None]=None,
                 fixed_color_map:Union[Dict[str,int], None]=None):
        """Trace `model` on `inputs` and prepare the resulting graph for rendering.

        Args:
            model: Module to trace.
            inputs: Input tensor, or tuple of input tensors, passed to `MavTracer`.
            device: Passed to `MavTracer`.
            merge_threshold: Passed to `merge_graph_nodes` as `cumul_param_threshold`.
            palette: Palette name or list of colors, passed to `color_graph_nodes`.
            avoid_palette_idxs: Passed to `color_graph_nodes`. Defaults to a new
                empty set for every instance.
            fixed_color_map: Passed to `color_graph_nodes`. Defaults to a new
                empty dict for every instance.
        """
        # Create the default containers per call. A mutable default in the
        # signature would be a single object shared by every instance, so any
        # change the pipeline makes to it would leak into later instances.
        if avoid_palette_idxs is None:
            avoid_palette_idxs = set()
        if fixed_color_map is None:
            fixed_color_map = {}

        self.tracer = MavTracer(model, inputs, device=device)
        merge_graph_nodes(self.tracer.g,
                          cumul_param_threshold=merge_threshold)
        color_graph_nodes(self.tracer.g,
                          palette=palette,
                          avoid_palette_idxs=avoid_palette_idxs,
                          fixed_color_map=fixed_color_map)
        layout_graph_nodes(self.tracer.g)

    def render_widget(self,
                      add_table:bool=True,
                      add_slider:bool=True,
                      add_overview:bool=False,
                      num_levels_displayed:float=10,
                      height_px=400
                      ):
        """Build the interactive widget view of the traced graph.

        All arguments are forwarded unchanged to `WidgetRenderer.render`.

        Args:
            add_table: Add the node table panel.
            add_slider: Add the vertical range slider panel.
            add_overview: Add the overview panel.
            num_levels_displayed: Number of graph levels initially visible.
            height_px: Height of the panels in pixels.

        Returns:
            The value returned by `WidgetRenderer.render`: the widget container
            that the notebook passes to `display`.
        """
        # Keep the renderer alive as an attribute rather than a local variable.
        # When it went out of scope at the end of this method, the plotly
        # graphs rendered blank on Colab.
        self.viewer = WidgetRenderer(self.tracer.g)
        return self.viewer.render(add_table=add_table, add_slider=add_slider, add_overview=add_overview, num_levels_displayed=num_levels_displayed, height_px=height_px)

In [21]:
# Trace a randomly initialised ResNet-18 on the CPU through the local `MAV1`
# copy, using a dummy batch of 16 RGB images of 160x160 pixels as input
device = 'cpu'
model = torchvision.models.resnet18()
model = model.to(device)
x = torch.randn(16, 3, 160, 160)
x = x.to(device)
mav = MAV1(model, x, device=device)

In [None]:
# Render through the local `MAV1.render_widget`, which keeps its
# `WidgetRenderer` alive as `mav.viewer`
container1 = mav.render_widget(add_table=True, add_overview=True, add_slider=False)
display(container1)

HBox(children=(Box(children=(FigureWidget({
    'data': [{'line': {'color': 'gray', 'width': 1},
             …