In [36]:
import json
import networkx as nx
from pyvis.network import Network

In [37]:
graph_data = json.load(open("attribution_graph.json"))
all_nodes = graph_data["nodes"]
all_edges = graph_data["edges"]



In [38]:
# count how many unique nodes there are that have edges
unique_nodes = set()
for edge in all_edges:
    unique_nodes.add(edge["source"]["id"])
    unique_nodes.add(edge["target"]["id"])

print(len(unique_nodes))

299


In [39]:
print(len(all_edges))

389


In [40]:
# find a input node in all_nodes
input_node = next((node for node in all_nodes if node["node_type"] == "InputNode"), None)
print(input_node)


{'id': 'input_0', 'layer_index': 0, 'token_position': 0, 'token_str': 'When', 'node_type': 'InputNode'}


In [41]:
G = nx.DiGraph()
for edge in all_edges:
    source = edge["source"]
    target = edge["target"]
    weight = edge["weight"]
    #if weight > 0.01:
    G.add_edge(source["id"], target["id"], weight=weight)
    if target["id"] in G.nodes:
        #G.add_node(target["id"], **target)
        G.nodes[target["id"]].update(**target)
    if source["id"] in G.nodes:
        #G.add_node(source["id"], **source)
        G.nodes[source["id"]].update(**source)




In [42]:
# Create a PyVis network
nt = Network(notebook=True, height='750px', width='100%', heading='Attribution Graph')
nt.options.physics.enabled = False




In [43]:
X_SCALE = 500  # Pixels per token position
Y_SCALE = 100  # Pixels per layer index
Y_OFFSET = 50 # Optional offset to push graph down slightly
position_counters = {}
for node_id, node_attrs in G.nodes(data=True):
    # Determine position based on node attributes
    x_pos = None
    y_pos = None
    print(node_attrs)
    if 'token_position' in node_attrs and node_attrs['token_position'] is not None:
         # Use token_position for x, scaled
        x_pos = node_attrs['token_position'] * X_SCALE 
        if node_attrs['token_position'] not in position_counters:
            position_counters[node_attrs['token_position']] = {}
        if node_attrs['layer_index'] not in position_counters[node_attrs['token_position']]:
            position_counters[node_attrs['token_position']][node_attrs['layer_index']] = 0
        position_counters[node_attrs['token_position']][node_attrs['layer_index']] += 1
    if 'feature_index' in node_attrs:
        x_pos += position_counters[node_attrs['token_position']][node_attrs['layer_index']]/10 * X_SCALE 

    if 'layer_index' in node_attrs and node_attrs['layer_index'] is not None:
         # Use layer_index for y, scaled and potentially inverted
         # Assuming layer 0 at top, higher layers go down
        y_pos = node_attrs['layer_index'] * Y_SCALE + Y_OFFSET
        # If you want layer 0 at the bottom, invert:
        # y_pos = (MAX_LAYER_INDEX - node_attrs['layer_index']) * Y_SCALE + Y_OFFSET
        # (You'd need to determine MAX_LAYER_INDEX first)
    print(x_pos, y_pos)
    # Add the node to Pyvis network with calculated position
    # Also add other attributes like title for hover info
    #title_str = f"ID: {node_id}<br>" + "<br>".join(f"{k}: {v}" for k, v in node_attrs.items())
    title_str = f"ID: {node_id}"
    # Check if positions are valid before adding
    if "input" in node_id or "output" in node_id:
        label = node_attrs.get('token_str', node_id)
        color = "red"
    elif "skip" or "error":
        label = node_id
        color = "green"
    else:
        color = "blue"
        label = node_id
    if x_pos is not None and y_pos is not None:
        nt.add_node(node_id, label=label, title=title_str, x=x_pos, y=y_pos, physics=False,color=color)
    else:
        # Add node without fixed position if attributes are missing (it might float)
        nt.add_node(node_id, label=label, title=title_str, physics=True,color=color) # Allow physics for nodes without position


# --- Add Edges ---
# Add edges from the NetworkX graph G
# Pyvis automatically connects nodes added above
for source, target, edge_attrs in G.edges(data=True):
    print(edge_attrs)
    weight = edge_attrs.get('weight')
    # You can customize edge appearance based on weight, etc.
    title_str = f"Weight: {weight:.4f}" if weight is not None else "No weight"
    nt.add_edge(source, target, value=weight, title=title_str)



{'id': 'intermediate_3_1_27883', 'layer_index': 1, 'token_position': 3, 'feature_index': 27883.0, 'activation': 0.3965, 'node_type': 'IntermediateNode'}
1550.0 150
{'id': 'intermediate_3_23_23237', 'layer_index': 23, 'token_position': 3, 'feature_index': 23237.0, 'activation': 0.0918, 'node_type': 'IntermediateNode'}
1550.0 2350
{'id': 'intermediate_3_2_27759', 'layer_index': 2, 'token_position': 3, 'feature_index': 27759.0, 'activation': 0.1777, 'node_type': 'IntermediateNode'}
1550.0 250
{'id': 'intermediate_3_23_69355', 'layer_index': 23, 'token_position': 3, 'feature_index': 69355.0, 'activation': 0.1289, 'node_type': 'IntermediateNode'}
1600.0 2350
{'id': 'intermediate_3_2_46883', 'layer_index': 2, 'token_position': 3, 'feature_index': 46883.0, 'activation': 0.2217, 'node_type': 'IntermediateNode'}
1600.0 250
{'id': 'intermediate_3_23_48774', 'layer_index': 23, 'token_position': 3, 'feature_index': 48774.0, 'activation': 0.1069, 'node_type': 'IntermediateNode'}
1650.0 2350
{'id': 

In [44]:
nt.show('attribution_graph_visualization.html')

attribution_graph_visualization.html


In [45]:
nt.from_nx(G)