In [1]:
from nnsight import LanguageModel
from nnsight.tracing.Graph import Graph
from typing import Any, Callable, Dict, List, Type, Union
from nnsight.tracing.Node import Node
from nnsight.module import Module
from nnsight.util import apply

import torch 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Wrap the "gpt2" checkpoint in nnsight's LanguageModel; later cells inspect
# the tracing graphs attached to its sub-modules.
model = LanguageModel("gpt2")

In [3]:
def vis(model_graph, filename: str = "graph", format: str = "png"):
    """Render an nnsight tracing graph to an image file using graphviz.

    Every entry of ``model_graph.nodes`` becomes a graphviz node and every
    positional argument of a node becomes an incoming edge. Arguments that
    are not ``Node`` objects (strings, tensors, ints, slices, ...) are drawn
    as grey boxes, one fresh box per occurrence.

    ``fetch_attr`` nodes are drawn in collapsed form: a single node labelled
    with the attribute name, with one edge coming from the object the
    attribute is read from. Each one keeps its own unique node name, so the
    same attribute name fetched from two different objects stays two
    separate nodes, and consumers of the fetched attribute connect to it.

    Args:
        model_graph: Object exposing ``nodes``, a mapping whose values are
            ``Node`` objects with ``name``, ``target`` and ``args``.
        filename: Output path without extension, passed to ``Digraph.render``.
        format: Output format passed to ``Digraph.render`` (e.g. ``"png"``).

    Returns:
        None. The rendered file is written by graphviz as a side effect.

    Note:
        ``node.kwargs`` are not drawn.
    """
    import graphviz
    from typing import Optional

    # Colours for the string-valued targets that get special treatment.
    target_colors = {"null": "red", "argument": "green", "module": "green4"}

    # Names used to annotate the positional args of an "argument" node.
    argument_fields = ("key", "batch_size", "batch_start")

    def is_fetch_attr(value: Any) -> bool:
        """True for Nodes that represent an attribute lookup ``obj.attr``."""
        return (
            isinstance(value, Node)
            and "fetch_attr" in value.name
            and len(value.args) >= 2
        )

    def style(value: Any) -> Dict[str, Any]:
        """Return the graphviz attributes used to draw ``value``."""
        if not isinstance(value, Node):
            return {"color": "grey", "shape": "box"}

        target = value.target
        # Only string targets can match the table; callables fall back to black.
        color = target_colors.get(target, "black") if isinstance(target, str) else "black"
        return {"color": color}

    def node_label(value: Node) -> str:
        """Return the text shown inside the graphviz node for a Node."""
        if is_fetch_attr(value):
            # Show the attribute being fetched rather than the generic target.
            return str(value.args[1])

        target = value.target
        if isinstance(target, str):
            return target
        # Not every callable has __name__ (e.g. functools.partial objects).
        return getattr(target, "__name__", str(target))

    literal_count = 0
    # Names already declared on the Digraph. ``graph.body`` holds complete DOT
    # statements, so testing a bare name against it never matches; keep an
    # explicit set instead.
    declared = set()

    def add_node(value: Any, graph: graphviz.Digraph, kname: Optional[str] = None) -> str:
        """Declare ``value`` on ``graph`` (once per name) and return its name."""
        nonlocal literal_count

        if isinstance(value, Node):
            name = value.name
            label = node_label(value)
        else:
            # Literals have no identity of their own: give each occurrence a
            # fresh numeric name so equal values are not merged into one box.
            name = str(literal_count)
            literal_count += 1

            if isinstance(value, torch.Tensor):
                label = "Tensor"
            elif isinstance(value, str):
                label = f'"{value}"'
            else:
                label = str(value)

        if kname is not None:
            label = f"{kname}={label}"

        if name not in declared:
            declared.add(name)
            graph.node(name, label=label, **style(value))

        return name

    graph = graphviz.Digraph("round-table", comment="The Round Table")

    for node in model_graph.nodes.values():
        add_node(node, graph)

        if is_fetch_attr(node):
            # Collapsed form: only the edge from the object the attribute is
            # read from; the attribute name is already the node's label.
            graph.edge(add_node(node.args[0], graph), node.name)
            continue

        for i, arg in enumerate(node.args):
            kname = None
            if node.target == "argument" and i < len(argument_fields):
                kname = argument_fields[i]

            graph.edge(add_node(arg, graph, kname=kname), node.name)

    graph.render(filename=filename, format=format)

In [4]:
# Render the tracing graph of transformer.h[0].attn with vis()'s default
# filename ("graph") and format ("png").
vis(model.transformer.h[0].attn.graph)

In [33]:
# Show the raw Node objects stored in the attention module's tracing graph.
model.transformer.h[0].attn.graph.nodes.values()

dict_values([<nnsight.tracing.Node.Node object at 0x7f54cd676350>, <nnsight.tracing.Node.Node object at 0x7f559f081ea0>, <nnsight.tracing.Node.Node object at 0x7f54c5e0b5b0>, <nnsight.tracing.Node.Node object at 0x7f54c5e0b730>, <nnsight.tracing.Node.Node object at 0x7f54c5e0ba30>, <nnsight.tracing.Node.Node object at 0x7f54c5e09930>, <nnsight.tracing.Node.Node object at 0x7f54c5e0a0b0>, <nnsight.tracing.Node.Node object at 0x7f54c5e0a0e0>, <nnsight.tracing.Node.Node object at 0x7f54c5e0a1d0>, <nnsight.tracing.Node.Node object at 0x7f559f082fb0>, <nnsight.tracing.Node.Node object at 0x7f54c5e0b790>, <nnsight.tracing.Node.Node object at 0x7f54c5e0a140>, <nnsight.tracing.Node.Node object at 0x7f54c5e0b760>, <nnsight.tracing.Node.Node object at 0x7f54c5e0afe0>, <nnsight.tracing.Node.Node object at 0x7f54c5e0b010>, <nnsight.tracing.Node.Node object at 0x7f54c5e0b0a0>, <nnsight.tracing.Node.Node object at 0x7f54c5e0b100>, <nnsight.tracing.Node.Node object at 0x7f54c5e0b1c0>, <nnsight.tracin

In [29]:
# Print the graph's text form: one line per node, listing its name and args.
print(model.transformer.h[0].attn.graph)

  %module_0:[  args:() l:23 d:0]
  %argument_0:[  args:('hidden_states') l:1 d:0]
  %fetch_attr_0:[  args:(module_0,'c_attn') l:1 d:1]
  %proxy_call_0:[  args:(fetch_attr_0,argument_0) l:1 d:2]
  %fetch_attr_1:[  args:(proxy_call_0,'split') l:1 d:1]
  %fetch_attr_2:[  args:(module_0,'split_size') l:1 d:1]
  %proxy_call_1:[  args:(fetch_attr_1,fetch_attr_2) l:3 d:2]
  %getitem_0:[  args:(proxy_call_1,0) l:2 d:1]
  %getitem_1:[  args:(proxy_call_1,1) l:2 d:1]
  %getitem_2:[  args:(proxy_call_1,2) l:2 d:1]
  %fetch_attr_3:[  args:(module_0,'_split_heads') l:0 d:1]
  %fetch_attr_4:[  args:(module_0,'num_heads') l:1 d:1]
  %fetch_attr_5:[  args:(module_0,'head_dim') l:1 d:1]
  %fetch_attr_6:[  args:(getitem_0,'size') l:1 d:1]
  %proxy_call_2:[  args:(fetch_attr_6) l:1 d:1]
  %getitem_3:[  args:(proxy_call_2,slice(None, -1, None)) l:1 d:1]
  %add_0:[  args:(getitem_3,('fetch_attr_4', 'fetch_attr_5')) l:1 d:3]
  %fetch_attr_7:[  args:(getitem_0,'view') l:1 d:1]
  %proxy_call_3:[  args:(fetch_

In [73]:
# Look one level down: the args of the first argument of the node at index 3.
# Per the printed graph this appears to be proxy_call_0 -> fetch_attr_0, whose
# args are the module node and the attribute name 'c_attn'.
list(model.transformer.h[0].attn.graph.nodes.values())[3].args[0].args

[<nnsight.tracing.Node.Node at 0x7f54cd676350>, 'c_attn']

In [66]:
# Dump the raw ``args`` list of every node in the attention graph, one per line.
# NOTE: ``i`` is a Node here, not an integer index.
for i in list(model.transformer.h[0].attn.graph.nodes.values()):
    print(i.args)

[]
['hidden_states']
[<nnsight.tracing.Node.Node object at 0x7f54cd676350>, 'c_attn']
[<nnsight.tracing.Node.Node object at 0x7f54c5e0b5b0>, <nnsight.tracing.Node.Node object at 0x7f559f081ea0>]
[<nnsight.tracing.Node.Node object at 0x7f54c5e0b730>, 'split']
[<nnsight.tracing.Node.Node object at 0x7f54cd676350>, 'split_size']
[<nnsight.tracing.Node.Node object at 0x7f54c5e0ba30>, <nnsight.tracing.Node.Node object at 0x7f54c5e09930>]
[<nnsight.tracing.Node.Node object at 0x7f54c5e0a0b0>, 0]
[<nnsight.tracing.Node.Node object at 0x7f54c5e0a0b0>, 1]
[<nnsight.tracing.Node.Node object at 0x7f54c5e0a0b0>, 2]
[<nnsight.tracing.Node.Node object at 0x7f54cd676350>, '_split_heads']
[<nnsight.tracing.Node.Node object at 0x7f54cd676350>, 'num_heads']
[<nnsight.tracing.Node.Node object at 0x7f54cd676350>, 'head_dim']
[<nnsight.tracing.Node.Node object at 0x7f54c5e0a0e0>, 'size']
[<nnsight.tracing.Node.Node object at 0x7f54c5e0afe0>]
[<nnsight.tracing.Node.Node object at 0x7f54c5e0b010>, slice(None

In [80]:
import networkx as nx

def add_node_and_args_to_graph(graph, node):
    """
    Add ``node`` and every Node reachable through its ``args`` to ``graph``.

    For each Node ``a`` found directly in ``n.args`` an edge
    ``n.name -> a.name`` is added, i.e. edges point from a node to the
    arguments it consumes. Non-Node arguments (strings, ints, slices, ...)
    are ignored.

    The traversal is an explicit-stack depth-first walk that expands each
    Node object at most once. It adds the same nodes and edges, in the same
    order, as a plain recursive walk, but does not re-expand shared
    sub-graphs, does not depend on the interpreter recursion limit, and
    terminates even if the args form a cycle.

    Args:
        graph: Graph object supporting ``name in graph``, ``add_node(name)``
            and ``add_edge(src, dst)`` (e.g. ``networkx.DiGraph``).
        node: Root Node exposing ``name`` and ``args``.

    Returns:
        None. ``graph`` is modified in place.
    """
    # NOTE(review): only Nodes that appear directly in ``args`` are followed;
    # Nodes nested inside a tuple/list/dict argument would be missed. Confirm
    # whether nnsight ever nests Nodes that way.
    if node.name not in graph:
        graph.add_node(node.name)

    # Track expanded Nodes by object identity so distinct Node objects are
    # never conflated, even if they happened to share a name.
    expanded = {id(node)}

    # Each entry pairs a node with an iterator over its not-yet-processed
    # args. Iterators are stateful, so resuming a parent after finishing a
    # child continues with the parent's next argument.
    pending = [(node, iter(node.args))]

    while pending:
        current, remaining_args = pending[-1]

        for arg in remaining_args:
            if not isinstance(arg, Node):
                continue

            # add_edge also creates ``arg.name`` in the graph if missing.
            graph.add_edge(current.name, arg.name)

            if id(arg) not in expanded:
                expanded.add(id(arg))
                pending.append((arg, iter(arg.args)))
                # Descend into the child before the parent's remaining args.
                break
        else:
            # All args of ``current`` have been processed.
            pending.pop()

# All nodes of the attention module's tracing graph, in the mapping's order.
nodes = list(model.transformer.h[0].attn.graph.nodes.values())

# Directed graph keyed by node name; filled in by the loop below.
G = nx.DiGraph()

# Add every node together with edges to the Node arguments it consumes.
for node in nodes:
    add_node_and_args_to_graph(G, node)

# G now holds one vertex per node name and one edge node -> argument.


In [81]:
# Node names collected in G.
list(G.nodes)

['module_0',
 'argument_0',
 'fetch_attr_0',
 'proxy_call_0',
 'fetch_attr_1',
 'fetch_attr_2',
 'proxy_call_1',
 'getitem_0',
 'getitem_1',
 'getitem_2',
 'fetch_attr_3',
 'fetch_attr_4',
 'fetch_attr_5',
 'fetch_attr_6',
 'proxy_call_2',
 'getitem_3',
 'add_0',
 'fetch_attr_7',
 'proxy_call_3',
 'fetch_attr_8',
 'proxy_call_4',
 'fetch_attr_9',
 'fetch_attr_10',
 'fetch_attr_11',
 'fetch_attr_12',
 'proxy_call_5',
 'getitem_4',
 'add_1',
 'fetch_attr_13',
 'proxy_call_6',
 'fetch_attr_14',
 'proxy_call_7',
 'fetch_attr_15',
 'fetch_attr_16',
 'fetch_attr_17',
 'fetch_attr_18',
 'proxy_call_8',
 'getitem_5',
 'add_2',
 'fetch_attr_19',
 'proxy_call_9',
 'fetch_attr_20',
 'proxy_call_10',
 'fetch_attr_21',
 'fetch_attr_22',
 'fetch_attr_23',
 'proxy_call_11',
 'matmul_0',
 'fetch_attr_24',
 'fetch_attr_25',
 'proxy_call_12',
 'pow_0',
 'fetch_attr_26',
 'fetch_attr_27',
 'full_0',
 'truediv_0',
 'fetch_attr_28',
 'fetch_attr_29',
 'fetch_attr_30',
 'proxy_call_13',
 'fetch_attr_31',
 '

In [79]:
# Edges of G as (node, argument) name pairs: each edge points from a node to
# one of the Node arguments it consumes.
list(G.edges)

[('fetch_attr_0', 'module_0'),
 ('proxy_call_0', 'fetch_attr_0'),
 ('proxy_call_0', 'argument_0'),
 ('fetch_attr_1', 'proxy_call_0'),
 ('fetch_attr_2', 'module_0'),
 ('proxy_call_1', 'fetch_attr_1'),
 ('proxy_call_1', 'fetch_attr_2'),
 ('getitem_0', 'proxy_call_1'),
 ('getitem_1', 'proxy_call_1'),
 ('getitem_2', 'proxy_call_1'),
 ('fetch_attr_3', 'module_0'),
 ('fetch_attr_4', 'module_0'),
 ('fetch_attr_5', 'module_0'),
 ('fetch_attr_6', 'getitem_0'),
 ('proxy_call_2', 'fetch_attr_6'),
 ('getitem_3', 'proxy_call_2'),
 ('add_0', 'getitem_3'),
 ('fetch_attr_7', 'getitem_0'),
 ('proxy_call_3', 'fetch_attr_7'),
 ('proxy_call_3', 'add_0'),
 ('fetch_attr_8', 'proxy_call_3'),
 ('proxy_call_4', 'fetch_attr_8'),
 ('fetch_attr_9', 'module_0'),
 ('fetch_attr_10', 'module_0'),
 ('fetch_attr_11', 'module_0'),
 ('fetch_attr_12', 'getitem_1'),
 ('proxy_call_5', 'fetch_attr_12'),
 ('getitem_4', 'proxy_call_5'),
 ('add_1', 'getitem_4'),
 ('fetch_attr_13', 'getitem_1'),
 ('proxy_call_6', 'fetch_attr_13'