In [1]:
# Automatically reload imported modules (e.g. the local `nx_utils`) before each cell runs,
# so edits to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Data handling
import pandas as pd
import numpy as np
import torch.nn as nn
import torch
from pathlib import Path

# Bokeh libraries
from bokeh.io import output_file, output_notebook 
from bokeh.plotting import figure, show, from_networkx
from bokeh.models import ColumnDataSource, Circle, MultiLine
from bokeh.layouts import row, column, gridplot
from bokeh.models import ColumnDataSource, CustomJS, Slider
from bokeh.colors import Color

import networkx as nx

import nx_utils

output_notebook()  # Send Bokeh `show()` output inline to this notebook instead of a standalone HTML file

In [3]:
# Locate the checkpoints of one training run and build the network graph from them.
MODELDIR = Path("models")
TIME = "2023-09-12_162539"  # run timestamp (sub-directory of the model directory)
NAME = 'Number-1'           # model / experiment name
# Named RUN_DIR rather than `dir`, which would shadow the builtin `dir()`.
RUN_DIR = MODELDIR / NAME / TIME
assert RUN_DIR.is_dir(), f"model run directory not found: {RUN_DIR}"

# Sorted for a deterministic processing order.
# NOTE(review): paths sort as strings, so e.g. 'ep10.pt' precedes 'ep2.pt' unless names are zero-padded.
list_of_files = sorted(RUN_DIR.glob("*.pt"))
# Fail loudly instead of silently building an empty graph.
assert list_of_files, f"no *.pt checkpoints found in {RUN_DIR}"

# Attribute functions handed to nx_utils.checkpoints_to_networkx (see nx_utils for what each adds).
attribute_functions = [
    nx_utils.positive_negative,
    nx_utils.alpha_value,
]

G = nx_utils.checkpoints_to_networkx(nx.DiGraph(), list_of_files, attribute_functions)

TypeError: checkpoints_to_networkx() missing 1 required positional argument: 'attribute_functions'

In [139]:
# Render the checkpoint graph with Bokeh. A slider selects which iteration's
# edge column ('w1', 'w2', ...) is copied into the column the edges are coloured by.
node_positions = nx.multipartite_layout(G, subset_key="layer")  # nodes grouped by their "layer" attribute
graph = from_networkx(G, layout_function=node_positions)

graph.node_renderer.glyph = Circle(size=15, fill_color="lightblue")
graph.edge_renderer.glyph = MultiLine(
    line_color="weight-display",  # name of an edge data column, not a colour; the JS callback writes to it
    line_alpha=1,
    line_width=2,
)

plot = figure(title=f"{NAME} ({TIME})")

# Slider bounds: the callback reads edge column 'w<iteration>' for each slider value.
FIRST_ITERATION = 1
LAST_ITERATION = 2  # TODO(review): presumably the number of checkpoints -- confirm against nx_utils

callback = CustomJS(
    args=dict(source=graph.edge_renderer.data_source),
    code="""
        // make a shallow copy of the current data dict
        const new_data = Object.assign({}, source.data)

        // copy the selected iteration's column ('w1', 'w2', ...) into the column the
        // edge glyph reads its colour from; bracket notation because of the hyphen
        new_data['weight-display'] = source.data['w' + cb_obj.value]

        // set the new data on source, BokehJS will pick this up automatically
        source.data = new_data
    """,
)

slider = Slider(
    start=FIRST_ITERATION,
    end=LAST_ITERATION,
    value=FIRST_ITERATION,
    step=1,
    title="iteration",
)
slider.js_on_change('value', callback)

plot.renderers.append(graph)
layout = column(slider, plot)
show(layout)