**This notebook demonstrates how to extract and visualize computational graphs.**

In [None]:
import torch
from transformers import AutoModel, AutoTokenizer
import onnx
import matplotlib.pyplot as plt
import networkx as nx

# Converting pytorch model to onnx

In [18]:
def convert_pytorch_to_onnx(model_name, output_path, input_shape=None, text_example=None):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    
    model.eval()
    
    if text_example:
        inputs = tokenizer(text_example, return_tensors="pt")
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
    elif input_shape:
        batch_size, seq_length = input_shape
        input_ids = torch.randint(0, tokenizer.vocab_size, (batch_size, seq_length))
        attention_mask = torch.ones(batch_size, seq_length)
    else:
        batch_size, seq_length = 1, 16
        input_ids = torch.randint(0, tokenizer.vocab_size, (batch_size, seq_length))
        attention_mask = torch.ones(batch_size, seq_length)
    
    torch.onnx.export(
        model,                                         # PyTorch model
        (input_ids, attention_mask),                   # Model input
        output_path,                                   # Output file
        export_params=True,                            # Store the trained weights
        opset_version=14,                              # ONNX version
        do_constant_folding=True,                      # Optimization
        input_names=["input_ids", "attention_mask"],   # Input names
        output_names=["last_hidden_state"],            # Output names
        dynamic_axes={                                  # Dynamic axes
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
            "last_hidden_state": {0: "batch_size", 1: "sequence_length"}
        }
    )
    
    print(f"Model exported to {output_path}")
    return output_path

In [19]:
model_name = "bert-base-uncased"
onnx_path = "bert_model.onnx"
text_input = "This is a sample text to analyze the model's computational graph."

convert_pytorch_to_onnx(model_name, onnx_path, text_example=text_input)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Model exported to bert_model.onnx


'bert_model.onnx'

# Analyzing the computational graph

In [40]:
# Using onnx_model.graph you can acces the computational graph of the model.

def analyze_onnx_model(onnx_path):
    onnx_model = onnx.load(onnx_path)
    
    print(f"Model IR version: {onnx_model.ir_version}")
    print(f"Producer name: {onnx_model.producer_name}")
    print(f"Producer version: {onnx_model.producer_version}")
    print(f"Model version: {onnx_model.model_version}")
    
    op_types = {}
    for node in onnx_model.graph.node:
        op_type = node.op_type
        op_types[op_type] = op_types.get(op_type, 0) + 1
    
    print("\nOperation types:")
    for op_type, count in sorted(op_types.items(), key=lambda x: x[1], reverse=True):
        print(f"  {op_type}: {count}")
    
    print("\nInputs:")
    for inp in onnx_model.graph.input:
        print(f"  {inp.name}")
    
    print("\nOutputs:")
    for outp in onnx_model.graph.output:
        print(f"  {outp.name}")
    
    print("\nModel structure:")
    for i, node in enumerate(onnx_model.graph.node[:20]):  # Print first 20 nodes
        print(f"  Node {i}: {node.op_type} - {node.name}")
        print(f"    Inputs: {', '.join(node.input)}")
        print(f"    Outputs: {', '.join(node.output)}")
    
    if len(onnx_model.graph.node) > 20:
        print(f"  ... and {len(onnx_model.graph.node) - 20} more nodes")

In [None]:
onnx_path = "bert_model.onnx"
analyze_onnx_model(onnx_path)

Model IR version: 7
Producer name: pytorch
Producer version: 2.6.0
Model version: 0

Operation types:
  Constant: 427
  Add: 172
  Shape: 114
  Unsqueeze: 105
  Gather: 104
  MatMul: 96
  Mul: 75
  Sqrt: 61
  Concat: 50
  Reshape: 50
  ReduceMean: 50
  Div: 49
  Transpose: 48
  Cast: 27
  Sub: 26
  Pow: 25
  Slice: 14
  Softmax: 12
  Erf: 12
  Where: 3
  ConstantOfShape: 2
  Equal: 2
  Expand: 2
  Gemm: 1
  Tanh: 1

Inputs:
  input_ids
  attention_mask

Outputs:
  last_hidden_state
  1895

Model structure:
  Node 0: Constant - /Constant
    Inputs: 
    Outputs: /Constant_output_0
  Node 1: Shape - /Shape
    Inputs: input_ids
    Outputs: /Shape_output_0
  Node 2: Constant - /Constant_1
    Inputs: 
    Outputs: /Constant_1_output_0
  Node 3: Gather - /Gather
    Inputs: /Shape_output_0, /Constant_1_output_0
    Outputs: /Gather_output_0
  Node 4: Shape - /Shape_1
    Inputs: input_ids
    Outputs: /Shape_1_output_0
  Node 5: Constant - /Constant_2
    Inputs: 
    Outputs: /Constant_

# Extracting subgraphs with specific operators involved

In [42]:
def extract_subgraph(onnx_path, target_ops=None, output_path="onnx_subgraph.png", viz=False):
    if target_ops is None:
        target_ops = ["Attention", "MatMul", "Add"]
    
    onnx_model = onnx.load(onnx_path)
    
    G = nx.DiGraph()
    
    for i, node in enumerate(onnx_model.graph.node):
        op_type = node.op_type
        if op_type in target_ops:
            node_name = node.name or f"{op_type}_{i}"
            
            G.add_node(node_name, op_type=op_type, shape="box", style="filled", color="lightblue")
            
            for input_name in node.input:
                input_base = input_name.split(".")[0] if "." in input_name else input_name
                G.add_edge(input_base, node_name)
            
            for output_name in node.output:
                output_base = output_name.split(".")[0] if "." in output_name else output_name
                G.add_edge(node_name, output_base)
    
    if viz:
        plt.figure(figsize=(15, 8))
        pos = nx.spring_layout(G, k=0.5)
        
        nx.draw(
            G, pos, with_labels=True, arrows=True, 
            node_color="lightgreen", node_size=2000, 
            font_size=8, font_weight="bold", 
            edge_color="gray", width=0.5
        )
        
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        print(f"Subgraph visualization saved to {output_path}")
    return G

# Visualization

In [47]:
!python ./net_drawer.py --input ./bert_model.onnx --output bert_model.dot --embed_docstring

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [48]:
# Open svg file in your browser
!dot -Tsvg ./bert_model.dot -o bert_model.svg

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
