In [4]:
from nnsight import LanguageModel
from nnsight.tracing.Graph import Graph
from typing import Any, Callable, Dict, List, Type, Union
from nnsight.tracing.Node import Node
from nnsight.module import Module
from nnsight.util import apply

import torch 

In [5]:
# Build the nnsight LanguageModel wrapper for "gpt2"; the cells below read
# tracing graphs off its submodules (e.g. `model.transformer.graph`).
model = LanguageModel("gpt2")

In [6]:
def vis(model_graph, filename: str = "graph", format: str = "png"):
    """Render a tracing graph to an image file using graphviz.

    Every entry of ``model_graph.nodes`` becomes a graphviz node coloured by
    its ``target``. Each positional and keyword argument of a node becomes an
    incoming edge. Arguments that are not ``Node`` instances (tensors,
    strings, other literals) are drawn as grey boxes, one box per occurrence.

    Args:
        model_graph: Object exposing ``nodes``, a mapping whose values have
            ``name``, ``target``, ``args`` and ``kwargs`` attributes.
        filename: Forwarded to ``graphviz.Digraph.render`` as ``filename``.
        format: Forwarded to ``graphviz.Digraph.render`` as ``format``.

    Returns:
        None. The rendering is written to disk by ``Digraph.render``.
    """
    # Imported lazily so graphviz is only required when a graph is rendered.
    import graphviz

    # Colours for the string targets that get special treatment.
    target_colors = {"null": "red", "argument": "green", "module": "green4"}

    # Names given to the first positional args of an "argument" node.
    argument_arg_names = ("key", "batch_size", "batch_start")

    def style(value: Any) -> Dict[str, Any]:
        """Return the graphviz attributes used to draw ``value``."""
        if not isinstance(value, Node):
            return {"color": "grey", "shape": "box"}

        target = value.target
        # Only string targets can match the table; callables fall back to black.
        color = target_colors.get(target, "black") if isinstance(target, str) else "black"
        return {"color": color}

    # Ids already declared on the graph. Tracked explicitly because
    # ``graph.body`` stores complete DOT statement lines (id plus attribute
    # list), so it cannot be searched by bare node id. Without this, a Node
    # reused as an argument would be re-declared and its label overwritten.
    declared = set()

    # Counter used to mint unique ids for non-Node (literal) arguments.
    literal_count = 0

    def add_node(value: Any, graph: graphviz.Digraph, kname: Union[str, None] = None) -> str:
        """Declare ``value`` on ``graph`` if needed and return its node id.

        ``kname`` (when given) is prefixed to the label as ``kname=label``.
        A ``Node`` is declared only once; the first declaration wins.
        """
        nonlocal literal_count

        if isinstance(value, Node):
            name = value.name
            target = value.target
            if isinstance(target, str):
                label = target
            else:
                # Some callables (e.g. functools.partial) have no __name__.
                label = getattr(target, "__name__", repr(target))
        else:
            # Every literal occurrence gets its own fresh id.
            name = str(literal_count)
            literal_count += 1

            if isinstance(value, torch.Tensor):
                label = "Tensor"
            elif isinstance(value, str):
                label = f'"{value}"'
            else:
                label = str(value)

        if name not in declared:
            if kname is not None:
                label = f"{kname}={label}"

            declared.add(name)
            graph.node(name, label=label, **style(value))

        return name

    # NOTE(review): the graph name/comment look like leftovers from the
    # graphviz documentation example; kept as-is since they end up in the
    # generated DOT source.
    graph = graphviz.Digraph("round-table", comment="The Round Table")

    for node in model_graph.nodes.values():
        add_node(node, graph)

        for i, arg in enumerate(node.args):
            kname = None

            if node.target == "argument" and i < len(argument_arg_names):
                kname = argument_arg_names[i]

            graph.edge(add_node(arg, graph, kname=kname), node.name)

        for kname, arg in node.kwargs.items():
            graph.edge(add_node(arg, graph, kname=kname), node.name)

    graph.render(filename=filename, format=format)

In [7]:
# Render the graph attached to the model's `transformer` submodule, using the
# defaults of `vis` (filename "graph", format "png").
vis(model.transformer.graph)