## Introduction
At this point, `idlmav` was tested with a more diverse set of models than the convolutional image classification models it was initially developed with (see [15_explore_misc_models.ipynb](./15_explore_misc_models.ipynb)). These tests revealed that the documented limitations of `torch.fx.symbolic_trace` are more restrictive for our use case than initially thought. 

This notebook experiments with the following remedies, which seemed promising to address the issues noted in [15_explore_misc_models.ipynb](./15_explore_misc_models.ipynb):
* Using the `concrete_args` input to `fx.symbolic_trace` to fix some positional arguments to the `forward` function
* Using `torch.compile` with a custom back-end based on [this tutorial](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#torchdynamo-and-fx-graphs)

### Workflow note
Sections that start with `import sys` are suitable starting points to resume from after restarting the kernel. 

Restarting from these points kept the experimental process manageable whenever exploratory code bloated or corrupted the notebook DOM, or used enough memory to slow down notebook navigation and execution.

## Observations
* Specifying `concrete_args` to `torch.fx.symbolic_trace` removed some instances of `TraceError: symbolically traced variables cannot be used as inputs to control flow`, but not all of them. In fact, every model that raised this error simply raised it again elsewhere after `concrete_args` was introduced
* Using `torch.compile` with a custom back-end seems much more promising, running without errors on all models tested
* In some cases, `torch.compile` produces multiple graphs, for example:
  * When data-dependent control flow is detected
  * When unsupported functions are called, the graphs are broken into parts before and after the unsupported functions
  * For most models tested, there is one dominant graph, detectable based on number of `call_function` operations
    * If the need arises later, an interactive control may be added to `idlmav` to allow the user to select different graphs
* `torch.compile` optimizes out `call_module` operations and replaces modules by direct function calls
  * Trainable parameters are replaced by `placeholder` operations in the graph
  * It is possible to detect these and infer which trainable parameters belong to which functions based on the graph
  * Activation shapes, number of parameters and FLOPS can still be calculated from the graph without relying on any model-specific knowledge
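
The multi-graph behaviour noted above can be reproduced on a toy model. The following is a minimal sketch rather than the `idlmav` implementation: `Toy`, `capturing_backend` and `count_ops` are hypothetical names introduced for illustration. Whether the data-dependent branch actually splits the graph, and whether parameters appear as `placeholder` nodes, may depend on the PyTorch version.

```python
import torch
import torch._dynamo
from torch import nn

captured = []  # (GraphModule, example_inputs) for every graph handed to the backend

def capturing_backend(gm, example_inputs):
    # Hypothetical backend: record the graph and run it unmodified
    captured.append((gm, example_inputs))
    return gm.forward

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        if h.sum() > 0:  # data-dependent branch, typically a graph break
            h = h * 2
        return self.fc2(h)

def count_ops(gm, op):
    # Number of nodes of a given opcode in an FX graph
    return sum(1 for n in gm.graph.nodes if n.op == op)

torch._dynamo.reset()
compiled = torch.compile(Toy().eval(), backend=capturing_backend)
with torch.no_grad():
    out = compiled(torch.randn(3, 4))

# Pick the "dominant" graph as the one with the most call_function nodes
dominant_gm, _ = max(captured, key=lambda item: count_ops(item[0], "call_function"))
for gm, _ in captured:
    print(count_ops(gm, "call_function"), "call_function nodes,",
          count_ops(gm, "placeholder"), "placeholders")
```

Selecting by `call_function` count is only a heuristic; it assumes that the graph of interest is also the largest one.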

## CLIP model

### Without any adjustments

In [None]:
import sys
sys.path.append('..')
from transformers import CLIPProcessor, CLIPModel
import torch
from idlmav import MAV, plotly_renderer

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
model.eval()
image_inputs = torch.randn(1, 3, 224, 224)
device = 'cpu'

mav = MAV(model.vision_model, image_inputs, device=device)
with plotly_renderer('notebook_connected'): mav.show_figure()

TraceError: symbolically traced variables cannot be used as inputs to control flow

### With `concrete_args`
Data-dependent control flow was detected in `CLIPVisionEmbeddings.forward`, at the following line:
```python
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
```
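
To illustrate why `concrete_args` only helps with part of such a condition, here is a small sketch on toy modules (`FlagOnly`, `ShapeCheck` and `try_trace` are hypothetical names, not part of `transformers` or `idlmav`). A flag argument can be specialized away, but a value derived from the input's shape remains a symbolic proxy and still cannot be used in an `if` statement.

```python
import torch
from torch import fx, nn
from torch.fx.proxy import TraceError

class FlagOnly(nn.Module):
    def forward(self, x, double=False):
        if double:  # control flow on a plain argument
            x = x * 2
        return torch.relu(x)

class ShapeCheck(nn.Module):
    def __init__(self, image_size=8):
        super().__init__()
        self.image_size = image_size

    def forward(self, pixel_values, interpolate_pos_encoding=False):
        height, width = pixel_values.shape[-2], pixel_values.shape[-1]
        # Same structure as the CLIP condition: a flag combined with shape-derived values
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError("Input image size does not match the model")
        return pixel_values.flatten(2)

def try_trace(module, concrete_args=None):
    # Report whether symbolic tracing succeeds or hits a TraceError
    try:
        fx.symbolic_trace(module, concrete_args=concrete_args)
        return "ok"
    except TraceError:
        return "TraceError"

results = {
    # The flag is a proxy during tracing, so `if double:` fails
    "flag_plain": try_trace(FlagOnly()),
    # Specializing the flag removes the control flow on a proxy
    "flag_concrete": try_trace(FlagOnly(), {"double": False}),
    # The flag is fixed, but `height` and `width` are still proxies
    "shape_concrete": try_trace(ShapeCheck(), {"interpolate_pos_encoding": False}),
}
print(results)
```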

In [None]:
import sys
sys.path.append('..')
from transformers import CLIPProcessor, CLIPModel
import torch
from idlmav import MAV, plotly_renderer

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
model.eval()
image_inputs = torch.randn(1, 3, 224, 224)
device = 'cpu'

none_args = ['input_ids','attention_mask','position_ids','return_loss','output_attentions','output_hidden_states','interpolate_pos_encoding','return_dict']
concrete_args={arg_name:None for arg_name in none_args}
mav = MAV(model.vision_model, image_inputs, device=device, concrete_args=concrete_args)
with plotly_renderer('notebook_connected'): mav.show_figure()

TraceError: symbolically traced variables cannot be used as inputs to control flow

### Using `torch.compile` with custom `fx` backend

In [1]:
import sys
sys.path.append('..')
import torch.compiler
from transformers import CLIPProcessor, CLIPModel
import torch

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
model.eval()
image_inputs = torch.randn(1, 3, 224, 224)
device = 'cpu'

def custom_backend(gm, example_inputs):
    print("")
    print("====================================")
    print("custom backend called with FX graph:")
    print("====================================")
    
    for node in gm.graph.nodes:
        if node.op == "call_module":
            print("call_module: ", node.target)

    gm.graph.print_tabular()
    return gm.forward

vision_model = model.vision_model
torch._dynamo.reset()
compiled_model = torch.compile(vision_model, backend=custom_backend)
outputs = compiled_model(image_inputs)


custom backend called with FX graph:
opcode         name                                                                                                    target                                                                                                  args                                                                                                                                                                                                                            kwargs
-------------  ------------------------------------------------------------------------------------------------------  ------------------------------------------------------------------------------------------------------  ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  -----------------------------------------------------
placeho

In [4]:
compiled_model

OptimizedModule(
  (_orig_mod): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
      (position_embedding): Embedding(197, 768)
    )
    (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
      

### Building the MavGraph
* For this model, `torch.compile` returns a single FX graph with 204 trainable model parameters as additional `placeholder` nodes. 
  * All of these have names starting with `l_self_*`, most ending in `*_weight_` or `*_bias_`
* One possible strategy is to just filter out any nodes starting with `l_self_*` during `MavTracer.build_graph`. 
  * This will cause these nodes to be omitted from connections as well
  * **_Added later:_** These placeholders should actually be used to calculate parameter counts for modules that have been optimized to direct function calls. See the next section for details.
* The methods of `MavTracer2` below were copied as-is from `MavTracer`
  * The key change is the line `if n.name.startswith('l_self_'): continue`
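
The filtering strategy can be sketched in isolation on a hand-built FX graph that mimics the naming seen above. This is an illustration under assumptions: the graph is a stand-in for what `torch.compile` produces, and `PARAM_PREFIX` and `is_param_placeholder` are hypothetical names.

```python
import torch.nn.functional as F
from torch import fx

PARAM_PREFIX = "l_self_"  # naming convention observed in the compiled graphs

# Hand-built stand-in for a graph produced by torch.compile
graph = fx.Graph()
x = graph.placeholder("l_pixel_values_")
w = graph.placeholder("l_self_modules_fc_parameters_weight_")
b = graph.placeholder("l_self_modules_fc_parameters_bias_")
y = graph.call_function(F.linear, (x, w, b))
graph.output(y)

def is_param_placeholder(n: fx.Node) -> bool:
    # Heuristic: placeholders whose names start with the prefix are model parameters
    return n.op == "placeholder" and n.name.startswith(PARAM_PREFIX)

# Nodes that remain visible after filtering
visible = [n for n in graph.nodes if not is_param_placeholder(n)]

# Connections between visible nodes only; edges from parameters are dropped
edges = [
    (src.name, n.name)
    for n in visible
    for src in n.all_input_nodes
    if not is_param_placeholder(src)
]
print([n.name for n in visible])
print(edges)
```

Because the filter relies purely on a name prefix, a model input that happens to be named with the same prefix would be hidden as well.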

In [None]:
import sys
sys.path.append('..')
from typing import List, Dict, Set, Tuple, Optional, Any, Union
from transformers import CLIPProcessor, CLIPModel
import torch
from torch import nn, fx, Tensor
from idlmav import merge_graph_nodes, layout_graph_nodes, color_graph_nodes, FigureRenderer, WidgetRenderer, plotly_renderer
from idlmav.tracing import MavTracer, ShapeMacInterpreter, get_num_trainable_params, warnings
from idlmav.mavtypes import MavNode, MavGraph, MavConnection

class MavTracer2(MavTracer):
    def custom_backend(self, gm:fx.GraphModule, example_inputs):
        self.gms.append(gm)
        return gm.forward
    
    def run(self, concrete_args: Optional[Dict[str, Any]]=None):
        # Have to initialize this here, because run() is called from super().__init__()
        self.gms:List[fx.GraphModule] = []

        # 1st pass: symbolic tracing using torch.compile
        torch._dynamo.reset()
        compiled_model = torch.compile(model.vision_model.to(device), backend=self.custom_backend)
        outputs = compiled_model(self.inputs)
        self.gm = self.gms[0]
        self.interp = ShapeMacInterpreter(self.gm)

        # 2nd pass: iterate through `nn.Module` and update module types and parameter counts
        try:
            for n in self.gm.graph.nodes:
                if n.op == 'call_module':
                    m:nn.Module = self.interp.fetch_attr(n.target)
                    self.target_names[n] = m.__class__.__name__
                    self.param_counts[n] = get_num_trainable_params(m)
                elif n.op == 'call_function':
                    self.target_names[n] = n.target.__name__
        except Exception as e:
            self.err_node = n
            warnings.warn(f'2nd tracing pass failed for module {n.target}: {e}')

        # 3rd pass: forward pass using torch.fx.Interpreter
        try:
            self.interp.run(self.inputs)
        except Exception as e:
            msg = 'Forward pass failed.'
            n1 = self.interp.last_successful_node
            if n1:
                target_name = self.target_names.get(n1, None)
                node_name = f'{n1.name}:{target_name}' if target_name else n1.name
                msg += f' Last successful node: "{node_name}".'
            n2 = self.interp.running_node
            if n2:
                target_name = self.target_names.get(n2, None)
                node_name = f'{n2.name}:{target_name}' if target_name else n2.name
                msg += f' Possible error node: "{node_name}".'
            self.err_node = self.interp.running_node
            warnings.warn(f'{msg}: {e}')

    def build_graph(self):
        nodes: List[MavNode] = []
        nodes_by_name: Dict[str, MavNode] = {}
        connections: List[MavConnection] = []
        existing_connections: Set[Tuple[MavNode, MavNode]] = set([])
        for n in self.gm.graph.nodes:
            # Create a new node and append it to the list
            if n.name.startswith('l_self_'): continue
            target_name = self.target_names.get(n, '')
            node = MavNode(n.name, 0, 0)
            node.operation = self.get_operation(n.op, target_name)
            node.activations = self.interp.shapes.get(n, (0,))
            node.params = self.param_counts.get(n, 0)
            node.flops = self.interp.flops.get(n, 0)
            node.metadata['kwargs'] = n.kwargs 
            if n == self.err_node: node.error = True
            nodes.append(node)
            nodes_by_name[n.name] = node

            # Find connections to and from this node
            in_nodes = n.all_input_nodes
            in_node_names = [n2.name for n2 in in_nodes]
            for in_node_name in in_node_names:
                if in_node_name not in nodes_by_name: continue
                from_node = nodes_by_name[in_node_name]
                to_node = node
                if (from_node, to_node) in existing_connections: continue
                c = MavConnection(from_node, node)
                connections.append(c)
                existing_connections.add((from_node, to_node))
                    
            out_nodes = list(n.users.keys())
            out_node_names = [n2.name for n2 in out_nodes]
            for out_node_name in out_node_names:
                if out_node_name not in nodes_by_name: continue
                from_node = node
                to_node = nodes_by_name[out_node_name]
                if (from_node, to_node) in existing_connections: continue
                c = MavConnection(from_node, to_node)
                connections.append(c)
                existing_connections.add((from_node, to_node)) 
        
        # Assemble into graph
        self.g = MavGraph(nodes, connections)

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
model.eval()
image_inputs = torch.randn(1, 3, 224, 224)
device = 'cpu'

tracer = MavTracer2(model.vision_model, image_inputs, device=device)
merge_graph_nodes(tracer.g)
layout_graph_nodes(tracer.g)
color_graph_nodes(tracer.g)
with plotly_renderer('notebook_connected'):
    renderer = WidgetRenderer(tracer.g)
    display(renderer.render())



Forward pass failed. Last successful node: "l_pixel_values_". Possible error node: "l_self_modules_embeddings_modules_patch_embedding_parameters_weight_".: Expected positional argument for parameter L_self_modules_embeddings_modules_patch_embedding_parameters_weight_, but one was not passed in!

While executing %l_self_modules_embeddings_modules_patch_embedding_parameters_weight_ : torch.nn.parameter.Parameter [num_users=1] = placeholder[target=L_self_modules_embeddings_modules_patch_embedding_parameters_weight_]
Original traceback:
  File "/home/dev/ai/idlmav/.venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 1093, in forward
    hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  File "/home/dev/ai/idlmav/.venv/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 247, in forward
    target_dtype = self.patch_embedding.weight.dtype




HBox(children=(Box(children=(FloatRangeSlider(value=(-9.5, 0.5), layout=Layout(height='400px'), max=0.5, min=-…

In [2]:
with plotly_renderer('notebook_connected'):
    renderer = FigureRenderer(tracer.g)
    renderer.render().show()

### Tracing the activations and parameters
* `torch.compile` tends to optimize out `call_module`
  * See [this issue](https://github.com/pytorch/pytorch/issues/126566) referring to [this pull request](https://github.com/pytorch/pytorch/pull/116312)
  * It tends to replace the module with the equivalent function and move trainable parameters out to the top-level as `placeholder` operations
* After much browsing, asking LLMs for help, reading through the `torch.compile` code-base and running experiments, it seemed like trying to prevent this optimization may be futile
* Instead of walking through all `call_module` nodes, one can just calculate the number of parameters from the `placeholder` inputs that are clearly model parameters
  * These predictably start with `l_self_`
  * The `example_inputs` input to `custom_backend` already receives tensors of the appropriate shape for each trainable parameter
* The `ShapeMacInterpreter` forward pass also fails for the graph generated by `torch.compile`
  * This prevents the activations and FLOPS from being calculated
  * Using the `example_inputs` input to `custom_backend` for this forward pass also seems promising
  * More work may be needed to calculate the FLOPS, as `call_module` will no longer be called on `ShapeMacInterpreter`
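
The two ideas above (feeding all `example_inputs` to the interpreter, and counting parameters from `placeholder` tensors) can be sketched on a hand-built graph. This is a simplified illustration, not the `MavTracer` code: the graph and tensors are stand-ins for what `torch.compile` would pass to the backend, and `params_per_node` is a hypothetical name.

```python
import torch
import torch.nn.functional as F
from torch import fx, nn

# Hand-built stand-in for a compiled graph: one input and two parameter placeholders
graph = fx.Graph()
x = graph.placeholder("l_x_")
w = graph.placeholder("l_self_modules_fc_parameters_weight_")
b = graph.placeholder("l_self_modules_fc_parameters_bias_")
y = graph.call_function(F.linear, (x, w, b))
graph.output(y)
gm = fx.GraphModule(nn.Module(), graph)

# Stand-in for the `example_inputs` argument of the backend
example_inputs = [torch.randn(2, 4), torch.randn(3, 4), torch.randn(3)]

# Passing only the model input leaves the parameter placeholders without a value
try:
    fx.Interpreter(gm).run(example_inputs[0])
    failed = False
except RuntimeError:
    failed = True

# Placeholders consume positional arguments left to right, so all inputs are passed
out = fx.Interpreter(gm).run(*example_inputs)

# Attribute each parameter placeholder's size to the nodes that consume it
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
sizes = {n.name: t.nelement() for n, t in zip(placeholders, example_inputs)}
params_per_node = {
    n.name: sum(sizes[i.name] for i in n.all_input_nodes if i.name.startswith("l_self_"))
    for n in gm.graph.nodes
    if n.op == "call_function"
}
print(failed, tuple(out.shape), params_per_node)
```

A parameter that feeds more than one node would be counted once per consumer with this scheme, so shared weights may need separate handling.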

In [None]:
import sys
sys.path.append('..')
from typing import List, Dict, Set, Mapping, Tuple, Optional, Any, Union
from transformers import CLIPProcessor, CLIPModel
import torch
from torch import nn, fx, Tensor
from idlmav import merge_graph_nodes, layout_graph_nodes, color_graph_nodes, FigureRenderer, WidgetRenderer, plotly_renderer
from idlmav.tracing import MavTracer, ShapeMacInterpreter, get_num_trainable_params, warnings
from idlmav.mavtypes import MavNode, MavGraph, MavConnection

class MavTracer3(MavTracer):
    def custom_backend(self, gm:fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.graphs.append((gm, example_inputs))
        # According to `fx.Interpreter.run()`, positional function args are consumed left-to-right by `placeholder` nodes.
        x_iter = iter(example_inputs)
        for n in gm.graph.nodes:
            if n.op != 'placeholder': continue
            x = next(x_iter)
            self.input_sizes[n.name] = x.nelement()
        return gm.forward
    
    def run(self, concrete_args: Optional[Dict[str, Any]]=None):
        # Have to initialize them here, because run() is called from super().__init__()
        self.graphs:List[Tuple[fx.GraphModule, List[torch.Tensor]]] = []
        self.input_sizes:Dict[str, int] = {}

        # Compile to intercept all fx graphs
        torch._dynamo.reset()
        compiled_model = torch.compile(model.vision_model.to(device), backend=self.custom_backend)
        outputs = compiled_model(self.inputs)

        for self.gm, xs in self.graphs:
            # Create an interpreter for this graph (tracing was already done by torch.compile above)
            self.interp = ShapeMacInterpreter(self.gm)

            # 2nd pass: iterate through `nn.Module` and update module types and parameter counts
            try:
                for n in self.gm.graph.nodes:
                    if n.op == 'call_module':
                        m:nn.Module = self.interp.fetch_attr(n.target)
                        self.target_names[n] = m.__class__.__name__
                        self.param_counts[n] = get_num_trainable_params(m)
                    elif n.op == 'call_function':
                        self.target_names[n] = n.target.__name__
            except Exception as e:
                self.err_node = n
                warnings.warn(f'2nd tracing pass failed for module {n.target}: {e}')

            # 3rd pass: forward pass using torch.fx.Interpreter
            try:
                self.interp.run(*xs)
            except Exception as e:
                msg = 'Forward pass failed.'
                n1 = self.interp.last_successful_node
                if n1:
                    target_name = self.target_names.get(n1, None)
                    node_name = f'{n1.name}:{target_name}' if target_name else n1.name
                    msg += f' Last successful node: "{node_name}".'
                n2 = self.interp.running_node
                if n2:
                    target_name = self.target_names.get(n2, None)
                    node_name = f'{n2.name}:{target_name}' if target_name else n2.name
                    msg += f' Possible error node: "{node_name}".'
                self.err_node = self.interp.running_node
                warnings.warn(f'{msg}: {e}')

    def build_graph(self):
        nodes: List[MavNode] = []
        nodes_by_name: Dict[str, MavNode] = {}
        connections: List[MavConnection] = []
        existing_connections: Set[Tuple[MavNode, MavNode]] = set([])
        for n in self.gm.graph.nodes:
            # Create a new node and append it to the list
            if n.name.startswith('l_self_'): continue
            target_name = self.target_names.get(n, '')
            node = MavNode(n.name, 0, 0)
            node.operation = self.get_operation(n.op, target_name)
            node.activations = self.interp.shapes.get(n, (0,))
            node.params = self.param_counts.get(n, 0)
            node.flops = self.interp.flops.get(n, 0)
            node.metadata['kwargs'] = n.kwargs 
            if n == self.err_node: node.error = True
            nodes.append(node)
            nodes_by_name[n.name] = node

            # Find connections to and from this node
            in_nodes = n.all_input_nodes
            in_node_names = [n2.name for n2 in in_nodes]
            for in_node_name in in_node_names:
                if in_node_name.startswith('l_self_') and in_node_name in self.input_sizes:
                    node.params += self.input_sizes[in_node_name]
                if in_node_name not in nodes_by_name: continue
                from_node = nodes_by_name[in_node_name]
                to_node = node
                if (from_node, to_node) in existing_connections: continue
                c = MavConnection(from_node, node)
                connections.append(c)
                existing_connections.add((from_node, to_node))
                    
            out_nodes = list(n.users.keys())
            out_node_names = [n2.name for n2 in out_nodes]
            for out_node_name in out_node_names:
                if out_node_name not in nodes_by_name: continue
                from_node = node
                to_node = nodes_by_name[out_node_name]
                if (from_node, to_node) in existing_connections: continue
                c = MavConnection(from_node, to_node)
                connections.append(c)
                existing_connections.add((from_node, to_node)) 
        
        # Assemble into graph
        self.g = MavGraph(nodes, connections)

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
model.eval()
image_inputs = torch.randn(1, 3, 224, 224)
device = 'cpu'

tracer = MavTracer3(model.vision_model, image_inputs, device=device)
merge_graph_nodes(tracer.g)
layout_graph_nodes(tracer.g)
color_graph_nodes(tracer.g)
with plotly_renderer('notebook_connected'):
    renderer = WidgetRenderer(tracer.g)
    display(renderer.render(add_overview=True))


HBox(children=(Box(children=(FloatRangeSlider(value=(-9.5, 0.5), layout=Layout(height='400px'), max=0.5, min=-…

In [2]:
with plotly_renderer('notebook_connected'):
    renderer = FigureRenderer(tracer.g)
    renderer.render().show()

#### Experimental cells

In [3]:
# Verify total number of parameters
sum(p.numel() for p in model.vision_model.parameters() if p.requires_grad)

85799424
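
As a cross-check, the same total can be reproduced by hand from the layer shapes. The sketch below assumes the usual ViT-B/16 encoder layout: the module printout above is truncated, so `fc2`, `layer_norm2` and the final layer norm are assumptions here rather than values read from the output.

```python
hidden, mlp, patch, channels, layers = 768, 3072, 16, 3, 12
num_positions = 197  # from Embedding(197, 768)

def linear(n_in, n_out):
    # Weight matrix plus bias vector
    return n_in * n_out + n_out

layer_norm = 2 * hidden  # weight and bias

# Patch embedding (Conv2d without bias), class embedding and position embedding
embeddings = channels * hidden * patch * patch + hidden + num_positions * hidden

# q, k, v and out projections, two layer norms and the two MLP layers
# (fc2 and the second layer norm are assumed, see above)
encoder_layer = (
    4 * linear(hidden, hidden)
    + 2 * layer_norm
    + linear(hidden, mlp)
    + linear(mlp, hidden)
)

# Pre-layer norm plus an assumed final layer norm around the encoder
total = embeddings + 2 * layer_norm + layers * encoder_layer
print(total)
```

The result agrees with the count reported by `model.vision_model.parameters()`, which suggests that the per-node parameter counts should sum to this value as well.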

### Tracing the FLOPS for inlined modules
* After having some trouble getting this going with `torchprofiler`, I discovered `torch.profiler.profile`
  * `torch.profiler.profile` provides a context manager that profiles the code executed within the managed context, including FLOPS estimation when `with_flops=True` is passed
  * This may provide an opportunity to remove `torchprofiler` as a dependency from the project
  * Before this, I would like to at least compare estimates from `torchprofiler` and `torch.profiler.profile`
* Note that some methods below are called directly on `fx.Interpreter`
  * This is the superclass of `ShapeMacInterpreter`
  * We do this because we're experimenting with possible updates to `ShapeMacInterpreter`. After updating `ShapeMacInterpreter`, we would still like to be able to run this code, if feasible.
  * When calling a method directly on an ancestor class instead of going through `super()`, `self` must be passed explicitly, as in `fx.Interpreter.run_node(self, n)`
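
The last point is plain Python behaviour and can be shown without PyTorch. In this sketch, `Base`, `Middle` and `Leaf` are hypothetical stand-ins for `fx.Interpreter`, `ShapeMacInterpreter` and the experimental subclass, respectively.

```python
class Base:
    def run_node(self, n):
        return f"Base({n})"

class Middle(Base):
    def run_node(self, n):
        # Normal cooperative call through super()
        return f"Middle({super().run_node(n)})"

class Leaf(Middle):
    def run_node(self, n):
        # Skip Middle.run_node entirely by calling the ancestor's method
        # as a plain function, which requires passing `self` explicitly
        return Base.run_node(self, n)

via_ancestor = Leaf().run_node("x")
via_super = Middle().run_node("x")
print(via_ancestor, via_super)
```

Calling the ancestor directly keeps the experiment independent of later changes to the intermediate class, at the cost of also skipping any bookkeeping that class performs.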

In [None]:
import sys
sys.path.append('..')
from typing import List, Dict, Set, Mapping, Tuple, Optional, Any, Union
from transformers import CLIPProcessor, CLIPModel
import torch
from torch import nn, fx, Tensor
from idlmav import merge_graph_nodes, layout_graph_nodes, color_graph_nodes, FigureRenderer, WidgetRenderer, plotly_renderer
from idlmav.tracing import MavTracer, ShapeMacInterpreter, get_num_trainable_params, warnings
from idlmav.mavtypes import MavNode, MavGraph, MavConnection

class ShapeMacInterpreter4(ShapeMacInterpreter):
    def run_node(self, n:fx.Node) -> Any:
        # Run the node
        self.cur_flops = None
        self.running_node = n
        result = fx.Interpreter.run_node(self,n)
        self.running_node = None

        # Retrieve the shape
        if isinstance(result, Tensor):
            shape = tuple(result.shape)
        else:
            shape = (0,0,0,0)
        self.shapes[n] = shape

        # Store the number of FLOPS if calculated
        if n.op == 'call_module' or n.op == 'call_function':
            if self.cur_flops is not None: self.flops[n] = self.cur_flops

        # Update the state and return the result
        self.last_successful_node = n
        return result
    
    def call_function(self, target, args, kwargs):
        # Profile the actual call, so that the function is executed only once.
        # Executing it a second time is wasteful and unsafe for in-place functions.
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],  # Use ProfilerActivity.CUDA for GPU
            record_shapes=True,
            with_flops=True
        ) as prof:
            result = fx.Interpreter.call_function(self, target, args, kwargs)
        self.cur_flops = prof.key_averages().total_average().flops
        return result
    
class MavTracer4(MavTracer):
    def custom_backend(self, gm:fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.graphs.append((gm, example_inputs))
        # According to `fx.Interpreter.run()`, positional function args are consumed left-to-right by `placeholder` nodes.
        x_iter = iter(example_inputs)
        for n in gm.graph.nodes:
            if n.op != 'placeholder': continue
            x = next(x_iter)
            self.input_sizes[n.name] = x.nelement()
        return gm.forward
    
    def run(self, concrete_args: Optional[Dict[str, Any]]=None):
        # Have to initialize them here, because run() is called from super().__init__()
        self.graphs:List[Tuple[fx.GraphModule, List[torch.Tensor]]] = []
        self.input_sizes:Dict[str, int] = {}

        # Compile to intercept all fx graphs
        torch._dynamo.reset()
        compiled_model = torch.compile(model.vision_model.to(device), backend=self.custom_backend)
        outputs = compiled_model(self.inputs)

        for self.gm, xs in self.graphs:
            # Create an interpreter for this graph (tracing was already done by torch.compile above)
            self.interp = ShapeMacInterpreter4(self.gm)

            # 2nd pass: iterate through `nn.Module` and update module types and parameter counts
            try:
                for n in self.gm.graph.nodes:
                    if n.op == 'call_module':
                        m:nn.Module = self.interp.fetch_attr(n.target)
                        self.target_names[n] = m.__class__.__name__
                        self.param_counts[n] = get_num_trainable_params(m)
                    elif n.op == 'call_function':
                        self.target_names[n] = n.target.__name__
            except Exception as e:
                self.err_node = n
                warnings.warn(f'2nd tracing pass failed for module {n.target}: {e}')

            # 3rd pass: forward pass using torch.fx.Interpreter
            try:
                self.interp.run(*xs)
            except Exception as e:
                msg = 'Forward pass failed.'
                n1 = self.interp.last_successful_node
                if n1:
                    target_name = self.target_names.get(n1, None)
                    node_name = f'{n1.name}:{target_name}' if target_name else n1.name
                    msg += f' Last successful node: "{node_name}".'
                n2 = self.interp.running_node
                if n2:
                    target_name = self.target_names.get(n2, None)
                    node_name = f'{n2.name}:{target_name}' if target_name else n2.name
                    msg += f' Possible error node: "{node_name}".'
                self.err_node = self.interp.running_node
                warnings.warn(f'{msg}: {e}')

    def build_graph(self):
        nodes: List[MavNode] = []
        nodes_by_name: Dict[str, MavNode] = {}
        connections: List[MavConnection] = []
        existing_connections: Set[Tuple[MavNode, MavNode]] = set([])
        for n in self.gm.graph.nodes:
            # Create a new node and append it to the list
            if n.name.startswith('l_self_'): continue
            target_name = self.target_names.get(n, '')
            node = MavNode(n.name, 0, 0)
            node.operation = self.get_operation(n.op, target_name)
            node.activations = self.interp.shapes.get(n, (0,))
            node.params = self.param_counts.get(n, 0)
            node.flops = self.interp.flops.get(n, 0)
            node.metadata['kwargs'] = n.kwargs 
            if n == self.err_node: node.error = True
            nodes.append(node)
            nodes_by_name[n.name] = node

            # Find connections to and from this node
            in_nodes = n.all_input_nodes
            in_node_names = [n2.name for n2 in in_nodes]
            for in_node_name in in_node_names:
                if in_node_name.startswith('l_self_') and in_node_name in self.input_sizes:
                    node.params += self.input_sizes[in_node_name]
                if in_node_name not in nodes_by_name: continue
                from_node = nodes_by_name[in_node_name]
                to_node = node
                if (from_node, to_node) in existing_connections: continue
                c = MavConnection(from_node, node)
                connections.append(c)
                existing_connections.add((from_node, to_node))
                    
            out_nodes = list(n.users.keys())
            out_node_names = [n2.name for n2 in out_nodes]
            for out_node_name in out_node_names:
                if out_node_name not in nodes_by_name: continue
                from_node = node
                to_node = nodes_by_name[out_node_name]
                if (from_node, to_node) in existing_connections: continue
                c = MavConnection(from_node, to_node)
                connections.append(c)
                existing_connections.add((from_node, to_node)) 
        
        # Assemble into graph
        self.g = MavGraph(nodes, connections)

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
model.eval()
image_inputs = torch.randn(1, 3, 224, 224)
device = 'cpu'

tracer = MavTracer4(model.vision_model, image_inputs, device=device)
merge_graph_nodes(tracer.g)
layout_graph_nodes(tracer.g)
color_graph_nodes(tracer.g)
with plotly_renderer('notebook_connected'):
    renderer = WidgetRenderer(tracer.g)
    display(renderer.render(add_overview=True))


INFO:2025-02-11 12:59:54 3817:3817 init.cpp:181] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti


HBox(children=(Box(children=(FloatRangeSlider(value=(-9.5, 0.5), layout=Layout(height='400px'), max=0.5, min=-…

In [2]:
with plotly_renderer('notebook_connected'):
    renderer = FigureRenderer(tracer.g)
    renderer.render().show()

#### Experimental cells

In [3]:
# Creating nn.Module to estimate FLOPS
gm, xs = tracer.graphs[0]
placeholders = [n for n in gm.graph.nodes if n.op=='placeholder']
len(placeholders), len(xs)

(201, 201)

In [4]:
[f"{i}: {n.name}: {x.nelement()}" for i,(n,x) in enumerate(zip(placeholders,xs))]

['0: l_pixel_values_: 150528',
 '1: l_self_modules_embeddings_modules_patch_embedding_parameters_weight_: 589824',
 '2: l_self_modules_embeddings_parameters_class_embedding_: 768',
 '3: l_self_modules_embeddings_buffers_position_ids_: 197',
 '4: l_self_modules_embeddings_modules_position_embedding_parameters_weight_: 151296',
 '5: l_self_modules_pre_layrnorm_parameters_weight_: 768',
 '6: l_self_modules_pre_layrnorm_parameters_bias_: 768',
 '7: l_self_modules_encoder_modules_layers_modules_0_modules_layer_norm1_parameters_weight_: 768',
 '8: l_self_modules_encoder_modules_layers_modules_0_modules_layer_norm1_parameters_bias_: 768',
 '9: l_self_modules_encoder_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_: 589824',
 '10: l_self_modules_encoder_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_bias_: 768',
 '11: l_self_modules_encoder_modules_layers_modules_0_modules_self_attn_modules_k_proj_parameters_weight_: 589824',
 '12: l_self_modul

In [5]:
# Lookup tensor by placeholder name
placeholder_tensors = {n.name:x for n,x in zip(placeholders,xs)}
placeholder_tensors['l_self_modules_encoder_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_'].shape

torch.Size([768, 768])

In [6]:
out_node = list(placeholders[9].users.keys())[0]
out_node.name

'query_states'

In [7]:
out_node.all_input_nodes

[hidden_states_1,
 l_self_modules_encoder_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_weight_,
 l_self_modules_encoder_modules_layers_modules_0_modules_self_attn_modules_q_proj_parameters_bias_]

In [8]:
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],  # Use ProfilerActivity.CUDA for GPU
    record_shapes=True,
    with_flops=True
) as prof:
    torch.matmul(torch.randn(128, 128), torch.randn(128, 128))

print(prof.key_averages().table(sort_by="flops", row_limit=10))

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  Total MFLOPs  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
              aten::mm        34.56%     951.394us        34.63%     953.349us     953.349us             1         4.194  
           aten::randn         1.41%      38.916us        17.34%     477.408us     238.704us             2            --  
           aten::empty         1.49%      40.967us         1.49%      40.967us      20.484us             2            --  
         aten::normal_        14.44%     397.525us        14.44%     397.525us     198.763us             2            --  
          aten::matmul        48.03%       1.322ms        82.66%       2.276ms       2.276ms             1            --  
    aten::resolv

In [9]:
prof.key_averages().total_average().flops

4194304

In [None]:
from idlmav.tracing import ShapeMacInterpreter
class FlopCalcInterpreter(ShapeMacInterpreter):
    def run_node(self, n:fx.Node) -> Any:
        # Run the node
        self.cur_flops = None
        self.running_node = n
        result = fx.Interpreter.run_node(self,n)
        self.running_node = None

        # Retrieve the shape
        if isinstance(result, Tensor):
            shape = tuple(result.shape)
        else:
            shape = (0,0,0,0)
        self.shapes[n] = shape

        # Store the number of FLOPS if calculated
        if n.op == 'call_module' or n.op == 'call_function':
            if self.cur_flops is not None: self.flops[n] = self.cur_flops

        # Update the state and return the result
        self.last_successful_node = n
        return result
    
    def call_function(self, target, args, kwargs):
        # mod = FlopCalcModule(target, args, kwargs)
        # self.cur_flops = torchprofile.profile_macs(mod) * 2
        result = fx.Interpreter.call_function(self, target, args, kwargs)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],  # Use ProfilerActivity.CUDA for GPU
            record_shapes=True,
            with_flops=True
        ) as prof:
            target(*args, **kwargs)
        flops = prof.key_averages().total_average().flops
        self.cur_flops = flops
        return result

interp = FlopCalcInterpreter(gm)
interp.run(*xs);
interp.flops

{patch_embeds: 115605504,
 embeddings: 0,
 embedding: 0,
 embeddings_1: 75648,
 hidden_states: 0,
 hidden_states_1: 0,
 query_states: 116195328,
 key_states: 116195328,
 value_states: 116195328,
 attn_output: 0,
 attn_output_3: 116195328,
 hidden_states_2: 75648,
 hidden_states_3: 0,
 hidden_states_4: 464781312,
 mul: 302592,
 sigmoid: 0,
 hidden_states_5: 302592,
 hidden_states_6: 464781312,
 hidden_states_7: 75648,
 hidden_states_8: 0,
 query_states_2: 116195328,
 key_states_2: 116195328,
 value_states_2: 116195328,
 attn_output_4: 0,
 attn_output_7: 116195328,
 hidden_states_9: 75648,
 hidden_states_10: 0,
 hidden_states_11: 464781312,
 mul_2: 302592,
 sigmoid_1: 0,
 hidden_states_12: 302592,
 hidden_states_13: 464781312,
 hidden_states_14: 75648,
 hidden_states_15: 0,
 query_states_4: 116195328,
 key_states_4: 116195328,
 value_states_4: 116195328,
 attn_output_8: 0,
 attn_output_11: 116195328,
 hidden_states_16: 75648,
 hidden_states_17: 0,
 hidden_states_18: 464781312,
 mul_4: 30

## YOLOv11n model
* For YOLOv11n, `torch.compile` calls our custom back-end 8 times, each time with a different fx graph. One graph (447 nodes, of which 241 are `call_function` operations) is clearly the main one
* `torch.compile` also produces the following warnings where the graph is broken:
  * `UserWarning: Graph break due to unsupported builtin time.time`
  * `UserWarning: Graph break due to unsupported builtin sys.intern`

### Using `torch.compile` with custom `fx` backend

In [1]:
import sys
sys.path.append('..')
import torch.compiler
from ultralytics import YOLO
import torch

model = YOLO("yolo11n.pt")
inputs = torch.rand(1, 3, 224, 224)
device = 'cpu'
model.predict(inputs)  # Causes internal model to be loaded and configured dynamically

def custom_backend(gm, example_inputs):
    print("")
    print("====================================")
    print("custom backend called with FX graph:")
    print("====================================")
    
    for node in gm.graph.nodes:
        if node.op == "call_module":
            print("call_module: ", node.target)

    gm.graph.print_tabular()
    return gm.forward

torch._dynamo.reset()
compiled_model = torch.compile(model.to(device), backend=custom_backend)
outputs = compiled_model(inputs)


0: 224x224 (no detections), 157.3ms
Speed: 0.6ms preprocess, 157.3ms inference, 70.2ms postprocess per image at shape (1, 3, 224, 224)

custom backend called with FX graph:
opcode       name                                       target                                     args                                                                          kwargs
-----------  -----------------------------------------  -----------------------------------------  ----------------------------------------------------------------------------  --------
placeholder  l_stack0_modules_model_modules_23_stride   L_stack0_modules_model_modules_23_stride   ()                                                                            {}
placeholder  l_stack0_modules_model_modules_23_anchors  L_stack0_modules_model_modules_23_anchors  ()                                                                            {}
placeholder  l_stack0_modules_model_modules_23_strides  L_stack0_modules_model_modules_23_stride

  torch._dynamo.utils.warn_once(msg)



custom backend called with FX graph:
opcode         name                                                                                                                             target                                                                                                                           args                                                                                                                                                                                                                                                                                                  kwargs
-------------  -------------------------------------------------------------------------------------------------------------------------------  -------------------------------------------------------------------------------------------------------------------------------  -------------------------------------------------------------------------------------------------------------------

  torch._dynamo.utils.warn_once(msg)


In [2]:
import sys
sys.path.append('..')
import torch.compiler
from ultralytics import YOLO
import torch

model = YOLO("yolo11n.pt")
inputs = torch.rand(1, 3, 224, 224)
device = 'cpu'
model.predict(inputs)  # Causes internal model to be loaded and configured dynamically

gms = []
def graph_collecting_backend(gm:torch.fx.GraphModule, example_inputs):
    gms.append(gm)
    return gm.forward

torch._dynamo.reset()
compiled_model = torch.compile(model.to(device), backend=graph_collecting_backend)
outputs = compiled_model(inputs)


0: 224x224 (no detections), 20.9ms
Speed: 0.9ms preprocess, 20.9ms inference, 1.9ms postprocess per image at shape (1, 3, 224, 224)



  torch._dynamo.utils.warn_once(msg)


0: 224x224 (no detections), 3583.3ms
Speed: 15.1ms preprocess, 3583.3ms inference, 211.7ms postprocess per image at shape (1, 3, 224, 224)


  torch._dynamo.utils.warn_once(msg)


In [3]:
len(gms)

8

In [4]:
[len(gm.graph.nodes) for gm in gms]

[7, 7, 4, 1, 4, 447, 16, 41]

In [5]:
[len([n for n in gm.graph.nodes if n.op == 'call_function']) for gm in gms]

[0, 0, 1, 0, 0, 241, 12, 28]

In [6]:
[len([n for n in gm.graph.nodes if n.op == 'placeholder']) for gm in gms]

[3, 3, 1, 0, 1, 178, 1, 2]

In [7]:
[len([n for n in gm.graph.nodes if n.op == 'call_method']) for gm in gms]

[3, 3, 1, 0, 2, 27, 2, 10]

In [8]:
[len([n for n in gm.graph.nodes if n.op == 'output']) for gm in gms]

[1, 1, 1, 1, 1, 1, 1, 1]

### Choosing a control path
* For this model, `torch.compile` returns 8 FX graphs of widely varying sizes. 
* Based on this limited sample, it seems the best graph to display is the longest one that also contains the input `l_im_` as a placeholder
  * `im` is the name of the first positional argument to `AutoBackend.forward` in [ultralytics/nn/autobackend.py](../.venv/lib/python3.10/site-packages/ultralytics/nn/autobackend.py)
  * `AutoBackend` is loaded dynamically by the Ultralytics machinery, though, which makes it difficult to locate analytically in a model-generic way
  * We need to find a way to check whether a graph intercepted by our custom `torch.compile` backend contains the model input
  * Failing this, we could just default to choosing the longest graph, based on the number of `call_function` and `call_module` operations
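
The selection rule described above can be sketched as follows. This is a minimal illustration rather than the `idlmav` implementation: `pick_main_graph`, `count_compute_nodes` and the `node` helper are hypothetical names, and the stand-in nodes only mimic the `op` and `name` attributes of `fx.Node`. The sketch assumes the name of the input placeholder (e.g. `l_im_`) is already known, which is exactly the part that remains unsolved in a model-generic way.

```python
from types import SimpleNamespace
from typing import Optional, Sequence

COMPUTE_OPS = {"call_function", "call_module"}


def count_compute_nodes(nodes: Sequence) -> int:
    """Count the nodes whose `op` marks an actual operation."""
    return sum(1 for n in nodes if n.op in COMPUTE_OPS)


def pick_main_graph(graphs: Sequence[Sequence], input_name: Optional[str] = None) -> int:
    """Return the index of the graph to display.

    Graphs with a placeholder named `input_name` are preferred. If no graph has
    one (or no name is given), every graph is a candidate. Ties go to the first.
    """
    candidates = list(range(len(graphs)))
    if input_name is not None:
        with_input = [
            i for i in candidates
            if any(n.op == "placeholder" and n.name == input_name for n in graphs[i])
        ]
        candidates = with_input or candidates
    return max(candidates, key=lambda i: count_compute_nodes(graphs[i]))


# Hypothetical stand-in nodes. With real graphs, the assumption is that
# [list(gm.graph.nodes) for gm in gms] can be passed in directly.
def node(op: str, name: str) -> SimpleNamespace:
    return SimpleNamespace(op=op, name=name)


small = [node("placeholder", "l_im_"), node("call_function", "conv"), node("output", "output")]
large = [node("placeholder", "l_x_")] + [node("call_function", f"f{i}") for i in range(3)]

print(pick_main_graph([small, large]))           # largest graph overall
print(pick_main_graph([small, large], "l_im_"))  # largest graph that sees the input
```

Falling back to all graphs when the input placeholder is not found keeps the behaviour identical to the "longest graph" default.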

In [9]:
modules = list(compiled_model.modules())
modules[0]

OptimizedModule(
  (_orig_mod): YOLO(
    (model): DetectionModel(
      (model): Sequential(
        (0): Conv(
          (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (act): SiLU(inplace=True)
        )
        (1): Conv(
          (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (act): SiLU(inplace=True)
        )
        (2): C3k2(
          (cv1): Conv(
            (conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
            (act): SiLU(inplace=True)
          )
          (cv2): Conv(
            (conv): Conv2d(48, 64, kernel_size=(1, 1), stride=(1, 1))
            (act): SiLU(inplace=True)
          )
          (m): ModuleList(
            (0): Bottleneck(
              (cv1): Conv(
                (conv): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (act): SiLU(inplace=True)
              )
              (cv2): Conv(
                (conv): Conv2d(8, 16, kernel_si

In [10]:
import inspect
from typing import get_type_hints
from torch.nn import Module

In [11]:
signature = inspect.signature(modules[0].forward)
signature

<Signature (source: Union[str, pathlib.Path, int, PIL.Image.Image, list, tuple, numpy.ndarray, torch.Tensor] = None, stream: bool = False, **kwargs: Any) -> list>

In [12]:
dict(signature.parameters.items())

{'source': <Parameter "source: Union[str, pathlib.Path, int, PIL.Image.Image, list, tuple, numpy.ndarray, torch.Tensor] = None">,
 'stream': <Parameter "stream: bool = False">,
 'kwargs': <Parameter "**kwargs: Any">}

In [13]:
compiled_model.forward.__annotations__

{'dump_patches': bool,
 '_version': int,
 'training': bool,
 '_parameters': typing.Dict[str, typing.Optional[torch.nn.parameter.Parameter]],
 '_buffers': typing.Dict[str, typing.Optional[torch.Tensor]],
 '_non_persistent_buffers_set': typing.Set[str],
 '_backward_pre_hooks': typing.Dict[int, typing.Callable],
 '_backward_hooks': typing.Dict[int, typing.Callable],
 '_is_full_backward_hook': typing.Optional[bool],
 '_forward_hooks': typing.Dict[int, typing.Callable],
 '_forward_hooks_with_kwargs': typing.Dict[int, bool],
 '_forward_hooks_always_called': typing.Dict[int, bool],
 '_forward_pre_hooks': typing.Dict[int, typing.Callable],
 '_forward_pre_hooks_with_kwargs': typing.Dict[int, bool],
 '_state_dict_hooks': typing.Dict[int, typing.Callable],
 '_load_state_dict_pre_hooks': typing.Dict[int, typing.Callable],
 '_state_dict_pre_hooks': typing.Dict[int, typing.Callable],
 '_load_state_dict_post_hooks': typing.Dict[int, typing.Callable],
 '_modules': typing.Dict[str, typing.Optional[Forw

### Building the `MavGraph`
* Copied from the final tracer for the CLIP model (which includes the new FLOPS calculation) and updated to select the largest graph

In [None]:
import sys
sys.path.append('..')
from typing import List, Dict, Set, Mapping, Tuple, Optional, Any, Union
from ultralytics import YOLO
import torch
from torch import nn, fx, Tensor
from idlmav import merge_graph_nodes, layout_graph_nodes, color_graph_nodes, FigureRenderer, WidgetRenderer, plotly_renderer
from idlmav.tracing import MavTracer, ShapeMacInterpreter, get_num_trainable_params, warnings
from idlmav.mavtypes import MavNode, MavGraph, MavConnection

class ShapeMacInterpreter5(ShapeMacInterpreter):
    def run_node(self, n):
        result = fx.Interpreter.run_node(self, n)
        if n.op == 'call_function':
            if self.cur_flops is not None: self.flops[n] = self.cur_flops
        return result
    
    def call_function(self, target, args, kwargs):
        result = fx.Interpreter.call_function(self, target, args, kwargs)  # TODO: Call it this way above as well to improve
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],  # Use ProfilerActivity.CUDA for GPU
            record_shapes=True,
            with_flops=True
        ) as prof:
            target(*args, **kwargs)
        flops = prof.key_averages().total_average().flops
        self.cur_flops = flops
        return result
    
class MavTracer5(MavTracer):
    def custom_backend(self, gm:fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.graphs.append((gm, example_inputs))
        # According to `fx.Interpreter.run()`, positional function args are consumed left-to-right by `placeholder` nodes.
        x_iter = iter(example_inputs)
        for n in gm.graph.nodes:
            if n.op != 'placeholder': continue
            x = next(x_iter)
            self.input_sizes[n.name] = x.nelement()
        return gm.forward
    
    def run(self, concrete_args: Optional[Dict[str, Any]]=None):
        # Have to initialize them here, because run() is called from super().__init__()
        self.graphs:List[Tuple[fx.GraphModule, List[torch.Tensor]]] = []
        self.input_sizes:Mapping[str, int] = {}

        # Compile to intercept all fx graphs
        torch._dynamo.reset()
        compiled_model = torch.compile(model, backend=self.custom_backend)
        outputs = compiled_model(self.inputs)

        gm_lengths = [len([n for n in gm.graph.nodes if n.op == 'call_function']) for gm,xs in self.graphs]
        gm_idx = gm_lengths.index(max(gm_lengths))
        self.gm, xs = self.graphs[gm_idx]
        
        # 1st pass: symbolic tracing using torch.compile
        self.interp = ShapeMacInterpreter5(self.gm)

        # 2nd pass: iterate through `nn.Module` and update module types and parameter counts
        try:
            for n in self.gm.graph.nodes:
                if n.op == 'call_module':
                    m:nn.Module = self.interp.fetch_attr(n.target)
                    self.target_names[n] = m.__class__.__name__
                    self.param_counts[n] = get_num_trainable_params(m)
                elif n.op == 'call_function':
                    self.target_names[n] = n.target.__name__
        except Exception as e:
            self.err_node = n
            warnings.warn(f'2nd tracing pass failed for module {n.target}: {e}')

        # 3rd pass: forward pass using torch.fx.Interpreter
        try:
            self.interp.run(*xs)
        except Exception as e:
            msg = 'Forward pass failed.'
            n1 = self.interp.last_successful_node
            if n1:
                target_name = self.target_names.get(n1, None)
                node_name = f'{n1.name}:{target_name}' if target_name else n1.name
                msg += f' Last successful node: "{node_name}".'
            n2 = self.interp.running_node
            if n2:
                target_name = self.target_names.get(n2, None)
                node_name = f'{n2.name}:{target_name}' if target_name else n2.name
                msg += f' Possible error node: "{node_name}".'
            self.err_node = self.interp.running_node
            warnings.warn(f'{msg}: {e}')

    def build_graph(self):
        nodes: List[MavNode] = []
        nodes_by_name: Dict[str, MavNode] = {}
        connections: List[MavConnection] = []
        existing_connections: Set[Tuple[MavNode, MavNode]] = set([])
        for n in self.gm.graph.nodes:
            # Create a new node and append it to the list
            if n.name.startswith('l_self_'): continue
            target_name = self.target_names.get(n, '')
            node = MavNode(n.name, 0, 0)
            node.operation = self.get_operation(n.op, target_name)
            node.activations = self.interp.shapes.get(n, (0,))
            node.params = self.param_counts.get(n, 0)
            node.flops = self.interp.flops.get(n, 0)
            node.metadata['kwargs'] = n.kwargs 
            if n == self.err_node: node.error = True
            nodes.append(node)
            nodes_by_name[n.name] = node

            # Find connections to and from this node
            in_nodes = n.all_input_nodes
            in_node_names = [n2.name for n2 in in_nodes]
            for in_node_name in in_node_names:
                if in_node_name.startswith('l_self_') and in_node_name in self.input_sizes:
                    node.params += self.input_sizes[in_node_name]
                if in_node_name not in nodes_by_name: continue
                from_node = nodes_by_name[in_node_name]
                to_node = node
                if (from_node, to_node) in existing_connections: continue
                c = MavConnection(from_node, node)
                connections.append(c)
                existing_connections.add((from_node, to_node))
                    
            out_nodes = list(n.users.keys())
            out_node_names = [n2.name for n2 in out_nodes]
            for out_node_name in out_node_names:
                if out_node_name not in nodes_by_name: continue
                from_node = node
                to_node = nodes_by_name[out_node_name]
                if (from_node, to_node) in existing_connections: continue
                c = MavConnection(from_node, to_node)  # Connect to the user node, not back to this node
                connections.append(c)
                existing_connections.add((from_node, to_node)) 
        
        # Assemble into graph
        self.g = MavGraph(nodes, connections)

model = YOLO("yolo11n.pt")
inputs = torch.rand(1, 3, 224, 224)
device = 'cpu'
model.predict(inputs)  # Causes internal model to be loaded and configured dynamically

tracer = MavTracer5(model, inputs, device=device)
merge_graph_nodes(tracer.g)
layout_graph_nodes(tracer.g)
color_graph_nodes(tracer.g)
with plotly_renderer('notebook_connected'):
    renderer = WidgetRenderer(tracer.g)
    display(renderer.render(add_overview=True))



0: 224x224 (no detections), 99.3ms
Speed: 1.6ms preprocess, 99.3ms inference, 47.8ms postprocess per image at shape (1, 3, 224, 224)







0: 224x224 (no detections), 3407.8ms
Speed: 13.2ms preprocess, 3407.8ms inference, 248.2ms postprocess per image at shape (1, 3, 224, 224)






HBox(children=(Box(children=(FloatRangeSlider(value=(-9.5, 0.5), layout=Layout(height='400px'), max=0.5, min=-…


In [2]:
with plotly_renderer('notebook_connected'):
    renderer = FigureRenderer(tracer.g)
    renderer.render().show()

## ModernBERT
* For ModernBERT, `torch.compile` produces a single fx graph
* Building the `MavGraph` with the rules defined up to now, however, results in 92 input nodes and 67 output nodes
  * This breaks the current layout algorithm, because the number of permutations of 92 nodes on the same level is a 143-digit number
* Many of the input nodes in the `MavGraph` are not placeholder nodes in the fx graph, e.g.
  * Function calls that take constant arguments, e.g. `arange(9)`
  * Transformations and item access functions on learnable parameters
  * Control functions, e.g. `enter_autocast`, `set_grad_enabled`
* Many of the output nodes in the `MavGraph` are not output nodes in the fx graph, e.g.
  * Control functions, e.g. `exit_autocast`, `set_grad_enabled`

To address this challenge, we introduce the following definitions:
* Input placeholder node: An fx node with `op=="placeholder"` and `name` not starting with `l_self_`
  * Thus far, these have always been used to represent model inputs
* Learnable placeholder node: An fx node with `op=="placeholder"` and `name` starting with `l_self_`
  * Thus far, these have always been used to represent learnable model parameters
* Normal entry nodes: all nodes in the graph that can be reached via an input placeholder node
* Learnable entry nodes: nodes that can be reached via a learnable placeholder node, but not via an input placeholder node
* Special entry nodes: nodes that cannot be reached from either type of placeholder node
* Normal exit nodes: all nodes in the graph from which an output node can be reached  
* Special exit nodes: nodes from which no output nodes can be reached
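
The definitions above amount to three reachability queries, which can be sketched as follows. This is a hedged illustration, not the final `idlmav` code: `reachable`, `classify_nodes` and `make_nodes` are hypothetical helpers, and the stand-in nodes only mimic the `op`, `name`, `users` and `all_input_nodes` attributes of `fx.Node`. One assumption made here is that a placeholder counts as reachable from itself.

```python
from collections import deque
from types import SimpleNamespace
from typing import Callable, Dict, Iterable, List, Set, Tuple


def reachable(starts: Iterable, neighbors: Callable) -> Set[str]:
    """Names of all nodes reachable from `starts` (inclusive) via `neighbors`."""
    seen: Set[str] = set()
    queue = deque(starts)
    while queue:
        n = queue.popleft()
        if n.name in seen:
            continue
        seen.add(n.name)
        queue.extend(neighbors(n))
    return seen


def classify_nodes(nodes: Iterable, learnable_prefix: str = "l_self_") -> Tuple[Dict[str, str], Dict[str, str]]:
    """Return (entry_kind, exit_kind), each mapping node name to a category."""
    nodes = list(nodes)
    placeholders = [n for n in nodes if n.op == "placeholder"]
    input_phs = [n for n in placeholders if not n.name.startswith(learnable_prefix)]
    learnable_phs = [n for n in placeholders if n.name.startswith(learnable_prefix)]
    outputs = [n for n in nodes if n.op == "output"]

    from_input = reachable(input_phs, lambda n: n.users)          # forward search
    from_learnable = reachable(learnable_phs, lambda n: n.users)  # forward search
    to_output = reachable(outputs, lambda n: n.all_input_nodes)   # backward search

    entry_kind, exit_kind = {}, {}
    for n in nodes:
        if n.name in from_input:
            entry_kind[n.name] = "normal"
        elif n.name in from_learnable:
            entry_kind[n.name] = "learnable"
        else:
            entry_kind[n.name] = "special"
        exit_kind[n.name] = "normal" if n.name in to_output else "special"
    return entry_kind, exit_kind


def make_nodes(ops: Dict[str, str], edges: List[Tuple[str, str]]) -> List[SimpleNamespace]:
    """Build hypothetical stand-in nodes from an op table and an edge list."""
    nodes = {k: SimpleNamespace(name=k, op=v, users=[], all_input_nodes=[]) for k, v in ops.items()}
    for src, dst in edges:
        nodes[src].users.append(nodes[dst])
        nodes[dst].all_input_nodes.append(nodes[src])
    return list(nodes.values())


demo = make_nodes(
    {"l_input_ids_": "placeholder", "l_self_weight_": "placeholder", "t": "call_method",
     "arange": "call_function", "linear": "call_function", "add": "call_function",
     "_set_grad_enabled": "call_function", "output": "output"},
    [("l_input_ids_", "linear"), ("l_self_weight_", "t"), ("t", "linear"),
     ("linear", "add"), ("arange", "add"), ("add", "output")],
)
entry_kind, exit_kind = classify_nodes(demo)
print(entry_kind)
print(exit_kind)
```

In this toy graph, `arange` is a special entry node but a normal exit node, whereas `_set_grad_enabled` is special in both directions.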

For now, the following rules seem promising:
* Omit special entry and special exit nodes from the `MavGraph`
* For learnable entry paths that join a normal entry node without any splits, propagate the parameter count to the normal entry node and omit the nodes on the learnable entry path
* For learnable entry paths that split before joining the normal entry path, introduce a special "Learnable" node just before the split containing the parameter count
  * This learnable node should not move to the first level during the layout step, but remain just enough levels above where it is used to allow the necessary operations to take place
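
The parameter-propagation rules above could look roughly like the sketch below. It is an assumption-laden illustration: `propagate_learnable_params` and `make_nodes` are hypothetical names, the nodes are stand-ins exposing only `name` and `users`, the node list is assumed to be in topological order (as fx graphs are), and a split is simply recorded at the node where it occurs instead of inserting a new "Learnable" node.

```python
from types import SimpleNamespace
from typing import Dict, Iterable, List, Set, Tuple


def propagate_learnable_params(
    nodes: Iterable, param_sizes: Dict[str, int], normal_names: Set[str]
) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Push parameter counts down learnable entry paths.

    Returns (params, learnable): `params` maps normal entry nodes to the counts
    that joined them without a split; `learnable` maps nodes where a learnable
    path splits to the count carried up to that point.
    """
    carried = dict(param_sizes)  # count carried by each node on a learnable path
    params: Dict[str, int] = {}
    learnable: Dict[str, int] = {}
    for n in nodes:  # assumed topological order
        if n.name in normal_names or n.name not in carried:
            continue
        users = list(n.users)
        if len(users) > 1:
            learnable[n.name] = carried[n.name]  # split: keep a "Learnable" node here
        elif len(users) == 1:
            user = users[0]
            target = params if user.name in normal_names else carried
            target[user.name] = target.get(user.name, 0) + carried[n.name]
    return params, learnable


def make_nodes(names: List[str], edges: List[Tuple[str, str]]) -> List[SimpleNamespace]:
    """Build hypothetical stand-in nodes, keeping the order given in `names`."""
    nodes = {k: SimpleNamespace(name=k, users=[]) for k in names}
    for src, dst in edges:
        nodes[src].users.append(nodes[dst])
    return list(nodes.values())


demo = make_nodes(
    ["x", "l_self_w_", "l_self_b_", "l_self_rope_", "t", "linear", "cos", "sin", "mul"],
    [("x", "linear"), ("l_self_w_", "t"), ("t", "linear"), ("l_self_b_", "linear"),
     ("l_self_rope_", "cos"), ("l_self_rope_", "sin"),
     ("linear", "mul"), ("cos", "mul"), ("sin", "mul")],
)
params, learnable = propagate_learnable_params(
    demo,
    param_sizes={"l_self_w_": 6, "l_self_b_": 3, "l_self_rope_": 4},
    normal_names={"x", "linear", "mul"},
)
print(params)     # counts merged into normal entry nodes
print(learnable)  # nodes that would become "Learnable" nodes
```

Here the weight reaches `linear` through a transpose and the bias reaches it directly, so both counts land on `linear`, while `l_self_rope_` feeds two branches and is kept as a separate learnable node.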

### Using `torch.compile` with custom fx backend

In [1]:
import sys
sys.path.append('..')
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from torch import fx
from typing import List

model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")

gms:List[fx.GraphModule] = []
def custom_backend(gm:fx.GraphModule, example_inputs):
    gms.append(gm)
    return gm.forward
torch._dynamo.reset()
compiled_model = torch.compile(model, backend=custom_backend)
outputs = compiled_model(**inputs)

total_nodes = [len(gm.graph.nodes) for gm in gms]
total_placeholder_nodes = [len([n for n in gm.graph.nodes if n.op == 'placeholder']) for gm in gms]
total_function_nodes = [len([n for n in gm.graph.nodes if n.op == 'call_function']) for gm in gms]
print(f'total_nodes: {total_nodes}')
print(f'total_placeholder_nodes: {total_placeholder_nodes}')
print(f'total_function_nodes: {total_function_nodes}')

total_nodes: [1508]
total_placeholder_nodes: [161]
total_function_nodes: [917]


In [2]:
for gm in gms:
    print("")
    print("====================================")
    print("custom backend called with FX graph:")
    print("====================================")
    gm.graph.print_tabular()


custom backend called with FX graph:
opcode         name                                                                                              target                                                                                            args                                                                                                                                         kwargs
-------------  ------------------------------------------------------------------------------------------------  ------------------------------------------------------------------------------------------------  -------------------------------------------------------------------------------------------------------------------------------------------  ------------------------------------------------------
placeholder    l_input_ids_                                                                                      L_input_ids_                                                                       

### Building the `MavGraph`

In [None]:
import sys
sys.path.append('..')
from typing import List, Dict, Set, Mapping, Tuple, Optional, Any, Union
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import math
from torch import nn, fx, Tensor
from idlmav import merge_graph_nodes, layout_graph_nodes, color_graph_nodes, FigureRenderer, WidgetRenderer, plotly_renderer
from idlmav.tracing import MavTracer, ShapeMacInterpreter, get_num_trainable_params, warnings
from idlmav.mavtypes import MavNode, MavGraph, MavConnection

class ShapeMacInterpreter6(ShapeMacInterpreter):
    def run_node(self, n):
        result = fx.Interpreter.run_node(self, n)
        if n.op == 'call_function':
            if self.cur_flops is not None: self.flops[n] = self.cur_flops
        return result
    
    def call_function(self, target, args, kwargs):
        result = fx.Interpreter.call_function(self, target, args, kwargs)  # TODO: Call it this way above as well to improve
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],  # Use ProfilerActivity.CUDA for GPU
            record_shapes=True,
            with_flops=True
        ) as prof:
            target(*args, **kwargs)
        flops = prof.key_averages().total_average().flops
        self.cur_flops = flops
        return result
    
class MavTracer6(MavTracer):
    def custom_backend(self, gm:fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.graphs.append((gm, example_inputs))
        # According to `fx.Interpreter.run()`, positional function args are consumed left-to-right by `placeholder` nodes.
        x_iter = iter(example_inputs)
        for n in gm.graph.nodes:
            if n.op != 'placeholder': continue
            x = next(x_iter)
            self.input_sizes[n.name] = x.nelement()
        return gm.forward
    
    def run(self, concrete_args: Optional[Dict[str, Any]]=None):
        # Have to initialize them here, because run() is called from super().__init__()
        self.graphs:List[Tuple[fx.GraphModule, List[torch.Tensor]]] = []
        self.input_sizes:Mapping[str, int] = {}

        # 1st pass: compile to intercept all fx graphs
        torch._dynamo.reset()
        compiled_model = torch.compile(model, backend=self.custom_backend)
        if isinstance(self.inputs, Mapping):
            outputs = compiled_model(**self.inputs)
        elif isinstance(self.inputs, Tuple):
            outputs = compiled_model(*self.inputs)
        else:
            outputs = compiled_model(self.inputs)

        # For now, select the largest graph based on number of `call_function` operations
        # * TODO: If valuable for many popular models, implement user controls to manually select graph
        gm_lengths = [len([n for n in gm.graph.nodes if n.op == 'call_function']) for gm,xs in self.graphs]
        gm_idx = gm_lengths.index(max(gm_lengths))
        self.gm, xs = self.graphs[gm_idx]
        self.interp = ShapeMacInterpreter6(self.gm)

        # 2nd pass: iterate through `nn.Module` and update module types and parameter counts
        try:
            for n in self.gm.graph.nodes:
                if n.op == 'call_module':
                    m:nn.Module = self.interp.fetch_attr(n.target)
                    self.target_names[n] = m.__class__.__name__
                    self.param_counts[n] = get_num_trainable_params(m)
                elif n.op == 'call_function':
                    self.target_names[n] = n.target.__name__
        except Exception as e:
            self.err_node = n
            warnings.warn(f'2nd tracing pass failed for module {n.target}: {e}')

        # 3rd pass: forward pass using torch.fx.Interpreter
        try:
            self.interp.run(*xs)
        except Exception as e:
            msg = 'Forward pass failed.'
            n1 = self.interp.last_successful_node
            if n1:
                target_name = self.target_names.get(n1, None)
                node_name = f'{n1.name}:{target_name}' if target_name else n1.name
                msg += f' Last successful node: "{node_name}".'
            n2 = self.interp.running_node
            if n2:
                target_name = self.target_names.get(n2, None)
                node_name = f'{n2.name}:{target_name}' if target_name else n2.name
                msg += f' Possible error node: "{node_name}".'
            self.err_node = self.interp.running_node
            warnings.warn(f'{msg}: {e}')

    def build_graph(self):
        nodes: List[MavNode] = []
        nodes_by_name: Dict[str, MavNode] = {}
        connections: List[MavConnection] = []
        existing_connections: Set[Tuple[MavNode, MavNode]] = set([])
        for n in self.gm.graph.nodes:
            # Create a new node and append it to the list
            if n.name.startswith('l_self_'): continue
            target_name = self.target_names.get(n, '')
            node = MavNode(n.name, 0, 0)
            node.operation = self.get_operation(n.op, target_name)
            node.activations = self.interp.shapes.get(n, (0,))
            node.params = self.param_counts.get(n, 0)
            node.flops = self.interp.flops.get(n, 0)
            node.metadata['kwargs'] = n.kwargs 
            if n == self.err_node: node.error = True
            nodes.append(node)
            nodes_by_name[n.name] = node

            # Find connections to and from this node
            in_nodes = n.all_input_nodes
            in_node_names = [n2.name for n2 in in_nodes]
            for in_node_name in in_node_names:
                if in_node_name.startswith('l_self_') and in_node_name in self.input_sizes:
                    node.params += self.input_sizes[in_node_name]
                if in_node_name not in nodes_by_name: continue
                from_node = nodes_by_name[in_node_name]
                to_node = node
                if (from_node, to_node) in existing_connections: continue
                c = MavConnection(from_node, node)
                connections.append(c)
                existing_connections.add((from_node, to_node))
                    
            out_nodes = list(n.users.keys())
            out_node_names = [n2.name for n2 in out_nodes]
            for out_node_name in out_node_names:
                if out_node_name not in nodes_by_name: continue
                from_node = node
                to_node = nodes_by_name[out_node_name]
                if (from_node, to_node) in existing_connections: continue
                c = MavConnection(from_node, to_node)  # Connect to the user node, not back to this node
                connections.append(c)
                existing_connections.add((from_node, to_node)) 
        
        # Assemble into graph
        self.g = MavGraph(nodes, connections)

model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
device = 'cpu'

tracer = MavTracer6(model, inputs, device=device)

print(f'Total nodes: {len(tracer.g.nodes)}. Input nodes: {len(tracer.g.in_nodes)}. Output nodes: {len(tracer.g.out_nodes)}')

INFO:2025-02-12 11:13:19 25146:25146 init.cpp:181] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti


Total nodes: 1349. Input nodes: 92. Output nodes: 67


In [7]:
math.factorial(len(tracer.g.in_nodes))

12438414054641307255475324325873553077577991715875414356840239582938137710983519518443046123837041347353107486982656753664000000000000000000000

In [8]:
[n.name for n in tracer.g.in_nodes]

['l_input_ids_',
 'l_attention_mask_',
 'arange',
 'arange_1',
 '_set_grad_enabled',
 'getitem_1',
 '_enter_autocast',
 '_set_grad_enabled_1',
 '_set_grad_enabled_2',
 'getitem_12',
 '_enter_autocast_1',
 '_set_grad_enabled_3',
 '_set_grad_enabled_4',
 'getitem_23',
 '_enter_autocast_2',
 '_set_grad_enabled_5',
 '_set_grad_enabled_6',
 'getitem_34',
 '_enter_autocast_3',
 '_set_grad_enabled_7',
 '_set_grad_enabled_8',
 'getitem_45',
 '_enter_autocast_4',
 '_set_grad_enabled_9',
 '_set_grad_enabled_10',
 'getitem_56',
 '_enter_autocast_5',
 '_set_grad_enabled_11',
 '_set_grad_enabled_12',
 'getitem_67',
 '_enter_autocast_6',
 '_set_grad_enabled_13',
 '_set_grad_enabled_14',
 'getitem_78',
 '_enter_autocast_7',
 '_set_grad_enabled_15',
 '_set_grad_enabled_16',
 'getitem_89',
 '_enter_autocast_8',
 '_set_grad_enabled_17',
 '_set_grad_enabled_18',
 'getitem_100',
 '_enter_autocast_9',
 '_set_grad_enabled_19',
 '_set_grad_enabled_20',
 'getitem_111',
 '_enter_autocast_10',
 '_set_grad_enabl

In [9]:
[n.name for n in tracer.g.out_nodes]

['_set_grad_enabled',
 '_exit_autocast',
 '_set_grad_enabled_1',
 '_set_grad_enabled_2',
 '_exit_autocast_1',
 '_set_grad_enabled_3',
 '_set_grad_enabled_4',
 '_exit_autocast_2',
 '_set_grad_enabled_5',
 '_set_grad_enabled_6',
 '_exit_autocast_3',
 '_set_grad_enabled_7',
 '_set_grad_enabled_8',
 '_exit_autocast_4',
 '_set_grad_enabled_9',
 '_set_grad_enabled_10',
 '_exit_autocast_5',
 '_set_grad_enabled_11',
 '_set_grad_enabled_12',
 '_exit_autocast_6',
 '_set_grad_enabled_13',
 '_set_grad_enabled_14',
 '_exit_autocast_7',
 '_set_grad_enabled_15',
 '_set_grad_enabled_16',
 '_exit_autocast_8',
 '_set_grad_enabled_17',
 '_set_grad_enabled_18',
 '_exit_autocast_9',
 '_set_grad_enabled_19',
 '_set_grad_enabled_20',
 '_exit_autocast_10',
 '_set_grad_enabled_21',
 '_set_grad_enabled_22',
 '_exit_autocast_11',
 '_set_grad_enabled_23',
 '_set_grad_enabled_24',
 '_exit_autocast_12',
 '_set_grad_enabled_25',
 '_set_grad_enabled_26',
 '_exit_autocast_13',
 '_set_grad_enabled_27',
 '_set_grad_enab

In [None]:
# merge_graph_nodes(tracer.g)
# layout_graph_nodes(tracer.g)
# color_graph_nodes(tracer.g)
# with plotly_renderer('notebook_connected'):
#     renderer = WidgetRenderer(tracer.g)
#     display(renderer.render(add_overview=True))

### Identifying entry and exit types of fx nodes 

In [1]:
import sys
sys.path.append('..')
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from torch import fx
from typing import List, Mapping

model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")

gms:List[fx.GraphModule] = []
input_sizes:Mapping[str, int] = {}
def custom_backend(gm:fx.GraphModule, example_inputs):
    gms.append(gm)
    # According to `fx.Interpreter.run()`, positional function args are consumed left-to-right by `placeholder` nodes.
    x_iter = iter(example_inputs)
    for n in gm.graph.nodes:
        if n.op != 'placeholder': continue
        x = next(x_iter)
        input_sizes[n.name] = x.nelement()
    return gm.forward

torch._dynamo.reset()
compiled_model = torch.compile(model, backend=custom_backend)
outputs = compiled_model(**inputs)

# Select the largest graph based on number of `call_function` operations
gm_lengths = [len([n for n in gm.graph.nodes if n.op == 'call_function']) for gm in gms]
gm_idx = gm_lengths.index(max(gm_lengths))
gm = gms[gm_idx]

total_nodes = len(gm.graph.nodes)
total_placeholder_nodes = len([n for n in gm.graph.nodes if n.op == 'placeholder'])
total_function_nodes = len([n for n in gm.graph.nodes if n.op == 'call_function'])
print(f'total_nodes: {total_nodes}, total_placeholder_nodes: {total_placeholder_nodes}, total_function_nodes: {total_function_nodes}')

total_nodes: 1508, total_placeholder_nodes: 161, total_function_nodes: 917


Breadth-first searches
* Perform BFS from all input placeholder nodes to identify all nodes on normal entry paths
* Perform BFS from all learnable placeholder nodes:
  * Identify all nodes on learnable entry paths
  * Stop when a node on normal entry path is reached
  * Propagate learnable parameter sizes
  * Identify where new learnable nodes must be created (i.e. where learnable entry path splits)
* Mark all nodes not on normal or learnable entry paths as "special entry"
* Perform reverse BFS from all output nodes to identify all nodes on normal exit paths
* Mark all nodes not on normal exit paths as "special exit"
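
The steps above can be sketched on a toy graph before applying them to the real `fx` graph. This is a simplified illustration, not the implementation used in `idlmav`: the node names and the `bfs_label` helper are hypothetical, plain dictionaries stand in for `fx.Node.users` and `fx.Node.all_input_nodes`, and the size propagation and "learn-split" detection are left out.

```python
from collections import deque

# Hypothetical toy graph: node name -> names of the nodes that consume it
users = {
    "x": ["embed"],          # model input placeholder
    "w_embed": ["embed"],    # learnable parameter placeholder
    "w_shared": ["t"],       # learnable parameter placeholder
    "t": ["linear", "head"],
    "arange": ["pos"],       # created inside the graph, no placeholder upstream
    "pos": ["add"],
    "embed": ["add"],
    "add": ["linear"],
    "linear": ["head"],
    "head": ["output"],
    "grad_flag": [],         # side-effect node that never reaches the output
    "output": [],
}

# Reverse adjacency for the exit-path search
inputs = {n: [] for n in users}
for n, outs in users.items():
    for o in outs:
        inputs[o].append(n)

def bfs_label(starts, neighbors, label, labels):
    """Label every not-yet-labelled node reachable from `starts`."""
    for s in starts:
        labels.setdefault(s, label)
    queue = deque(starts)
    while queue:
        n = queue.popleft()
        for m in neighbors[n]:
            if m in labels:
                continue  # Already traversed, possibly by an earlier search
            labels[m] = label
            queue.append(m)

entry_types = {}
bfs_label(["x"], users, "normal", entry_types)
bfs_label(["w_embed", "w_shared"], users, "learnable", entry_types)

exit_types = {}
bfs_label(["output"], inputs, "normal", exit_types)

# Everything not reached by a search is "special"
for n in users:
    entry_types.setdefault(n, "special")
    exit_types.setdefault(n, "special")

print(sorted(n for n, v in entry_types.items() if v == "special"))
print(sorted(n for n, v in exit_types.items() if v == "special"))
```

Because the normal search runs first, the learnable search stops as soon as it reaches a node that already carries a label, which is the "stop when a node on normal entry path is reached" rule.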

In [2]:
nodes_by_name = {n.name:n for n in gm.graph.nodes}
learnable_placeholder_nodes = [n for n in gm.graph.nodes if n.op == 'placeholder' and n.name.startswith('l_self_')]
input_placeholder_nodes = [n for n in gm.graph.nodes if n.op == 'placeholder' and not n.name.startswith('l_self_')]
output_nodes = [n for n in gm.graph.nodes if n.op == 'output']

print(f'{len(input_placeholder_nodes)} input placeholder nodes')
print(f'{len(learnable_placeholder_nodes)} learnable placeholder nodes')
print(f'{len(output_nodes)} output nodes')
print(f'input placeholder nodes:{input_placeholder_nodes}')

2 input placeholder nodes
159 learnable placeholder nodes
1 output nodes
input placeholder nodes:[l_input_ids_, l_attention_mask_]


In [3]:
from collections import deque

verbose = False

# BFS on input placeholder nodes
entry_types:Mapping[fx.Node, str] = {}  # ['normal','learnable','learn-split','special']
propagated_sizes:Mapping[fx.Node, int] = {}  # Assign to first 'normal' or 'split' node
exit_types:Mapping[fx.Node, str] = {}  # ['normal','special']
for n in input_placeholder_nodes: entry_types[n] = 'normal'
queue = deque(input_placeholder_nodes)  # Initialize to contain all input placeholder nodes
while queue:
    if verbose: print([n.name for n in queue])
    n = queue.popleft()

    out_nodes = list(n.users.keys())
    for out_node in out_nodes:
        if out_node in entry_types: continue  # Already traversed
        entry_types[out_node] = 'normal'
        queue.append(out_node)

In [4]:
# BFS on learnable placeholder nodes
for n in learnable_placeholder_nodes: entry_types[n] = 'learnable'
queue = deque(learnable_placeholder_nodes)  # Initialize to contain all learnable placeholder nodes
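# `input_sizes` is keyed by node name, but the lookups below are keyed by node, so convert the keys
temp_prop_sizes = {nodes_by_name[k]:v for k,v in input_sizes.items() if k in nodes_by_name}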
while queue:
    if verbose: print([n.name for n in queue])
    n = queue.popleft()
    input_size = temp_prop_sizes.get(n,0)
    out_nodes = list(n.users.keys())

    # Detect first split
    if len(out_nodes) > 1 and entry_types[n] == 'learnable':
        entry_types[n] = 'learn-split'
        propagated_sizes[n] = input_size

    # Propagate to output nodes and queue them
    for out_node in out_nodes:
        if out_node in entry_types:
            # Already traversed, so stop processing, but assign propagated input size if not assigned upstream yet
            if entry_types[n] == 'learnable': propagated_sizes[out_node] = input_size
            continue
        entry_types[out_node] = entry_types[n]
        if entry_types[n] == 'learnable': temp_prop_sizes[out_node] = input_size
        queue.append(out_node)

In [5]:
# BFS on output nodes
for n in output_nodes: exit_types[n] = 'normal'
queue = deque(output_nodes)  # Initialize to contain all output nodes
while queue:
    if verbose: print([n.name for n in queue])
    n = queue.popleft()

    in_nodes = n.all_input_nodes
    for in_node in in_nodes:
        if in_node in exit_types: continue  # Already traversed
        exit_types[in_node] = 'normal'
        queue.append(in_node)

In [6]:
for n in gm.graph.nodes:
    if n not in entry_types: entry_types[n] = 'special'
    if n not in exit_types: exit_types[n] = 'special'

In [7]:
[n.name for n,v in entry_types.items() if v=='special']

['arange',
 'position_ids',
 'arange_1',
 'rows',
 'getattr_1',
 'sub_1',
 'distance',
 'le',
 'unsqueeze_2',
 'unsqueeze_3',
 'window_mask',
 'logical_not',
 '_set_grad_enabled',
 'getitem_2',
 'position_ids_expanded',
 '_enter_autocast',
 'float_4',
 '_exit_autocast',
 '_set_grad_enabled_1',
 '_set_grad_enabled_2',
 'getitem_13',
 'position_ids_expanded_1',
 '_enter_autocast_1',
 'float_8',
 '_exit_autocast_1',
 '_set_grad_enabled_3',
 '_set_grad_enabled_4',
 'getitem_24',
 'position_ids_expanded_2',
 '_enter_autocast_2',
 'float_12',
 '_exit_autocast_2',
 '_set_grad_enabled_5',
 '_set_grad_enabled_6',
 'getitem_35',
 'position_ids_expanded_3',
 '_enter_autocast_3',
 'float_16',
 '_exit_autocast_3',
 '_set_grad_enabled_7',
 '_set_grad_enabled_8',
 'getitem_46',
 'position_ids_expanded_4',
 '_enter_autocast_4',
 'float_20',
 '_exit_autocast_4',
 '_set_grad_enabled_9',
 '_set_grad_enabled_10',
 'getitem_57',
 'position_ids_expanded_5',
 '_enter_autocast_5',
 'float_24',
 '_exit_autocas

In [8]:
[n.name for n,v in exit_types.items() if v=='special']

['_set_grad_enabled',
 '_enter_autocast',
 '_exit_autocast',
 '_set_grad_enabled_1',
 '_set_grad_enabled_2',
 '_enter_autocast_1',
 '_exit_autocast_1',
 '_set_grad_enabled_3',
 '_set_grad_enabled_4',
 '_enter_autocast_2',
 '_exit_autocast_2',
 '_set_grad_enabled_5',
 '_set_grad_enabled_6',
 '_enter_autocast_3',
 '_exit_autocast_3',
 '_set_grad_enabled_7',
 '_set_grad_enabled_8',
 '_enter_autocast_4',
 '_exit_autocast_4',
 '_set_grad_enabled_9',
 '_set_grad_enabled_10',
 '_enter_autocast_5',
 '_exit_autocast_5',
 '_set_grad_enabled_11',
 '_set_grad_enabled_12',
 '_enter_autocast_6',
 '_exit_autocast_6',
 '_set_grad_enabled_13',
 '_set_grad_enabled_14',
 '_enter_autocast_7',
 '_exit_autocast_7',
 '_set_grad_enabled_15',
 '_set_grad_enabled_16',
 '_enter_autocast_8',
 '_exit_autocast_8',
 '_set_grad_enabled_17',
 '_set_grad_enabled_18',
 '_enter_autocast_9',
 '_exit_autocast_9',
 '_set_grad_enabled_19',
 '_set_grad_enabled_20',
 '_enter_autocast_10',
 '_exit_autocast_10',
 '_set_grad_enab

Color nodes by entry type to verify whether they have been identified correctly (https://plotly.com/python/discrete-color/#color-sequences-in-plotly-express)
* Normal: blue
* Learnable: yellow
* Learnable-split: green
* Special: pink

> Note: the plots below were based on `idlmav` components that have since been upgraded to filter out special entry, learnable entry and special exit nodes. To view these plots as originally intended, comment out the following lines in [tracing.py](../idlmav/tracing.py) before running this code:
> ```python
> if entry_type != 'normal' and entry_type != 'learn-split' and not self.show_param_nodes: continue
> if exit_type != 'normal': continue
> ```
> To draw all inputs at the top, remove the second part of this condition in [layout.py](../idlmav/layout.py)
> ```python
> if n in self.g.in_nodes and n.metadata['entry_type'] == 'normal':
> ```

In [9]:
import importlib
def reload_idlmav():
    idlmav_modules = {k:v for k,v in sys.modules.items() if k.startswith('idlmav')}
    for v in idlmav_modules.values(): importlib.reload(v)
    global MavTracer, merge_graph_nodes, layout_graph_nodes, color_graph_nodes, FigureRenderer, WidgetRenderer, plotly_renderer
    from idlmav import MavTracer, merge_graph_nodes, layout_graph_nodes, color_graph_nodes, FigureRenderer, WidgetRenderer, plotly_renderer

In [10]:
# reload_idlmav()

In [11]:
from idlmav import MavTracer, merge_graph_nodes, layout_graph_nodes, color_graph_nodes, FigureRenderer, WidgetRenderer, plotly_renderer
from plotly import colors

# https://plotly.com/python/discrete-color/#color-sequences-in-plotly-express
entry_type_palette = {'normal':colors.qualitative.Bold[2],
                      'learnable':colors.qualitative.Bold[3],
                      'learn-split':colors.qualitative.Bold[1],
                      'special':colors.qualitative.Bold[4]}

class EntryTypeRenderer(FigureRenderer):
    def get_node_color(self, node):
        entry_type = entry_types[nodes_by_name[node.metadata['fx_name']]]
        if entry_type in entry_type_palette: return entry_type_palette[entry_type]
        return '#7F7F7F'
    
tracer = MavTracer(model, inputs)
merge_graph_nodes(tracer.g)
layout_graph_nodes(tracer.g)
color_graph_nodes(tracer.g)
with plotly_renderer('notebook_connected'):
    renderer = EntryTypeRenderer(tracer.g)
    renderer.render(add_table=False).show()

Tracing failed with torch.fx.symbolic_trace: You must specify exactly one of input_ids or inputs_embeds
Tracing with torch.compile


INFO:2025-02-22 13:07:33 50658:50658 init.cpp:181] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti


Total nodes: 1508. Input nodes: 229. Output nodes: 67. Largest level nodes: 229



The largest level has 229 nodes. This is an indication that something may have gone wrong during the tracing step


Falling back to greedy layout algorithm



In [12]:
exit_type_palette = {'normal':colors.qualitative.Bold[2],
                     'special':colors.qualitative.Bold[4]}
class ExitTypeRenderer(FigureRenderer):
    def get_node_color(self, node):
        exit_type = exit_types[nodes_by_name[node.metadata['fx_name']]]
        if exit_type in exit_type_palette: return exit_type_palette[exit_type]
        return '#7F7F7F'
    
with plotly_renderer('notebook_connected'):
    renderer = ExitTypeRenderer(tracer.g)
    renderer.render(add_table=False).show()

### Shortening looooooong names
* These are typically produced by `torch.compile` for "placeholder" nodes representing learnable parameters
  * From ModernBERT: `l_self_modules_model_modules_embeddings_modules_tok_embeddings_parameters_weight_`
  * From ModernBERT: `l_self_modules_model_modules_layers_modules_0_modules_attn_modules_wqkv_parameters_weight_`
  * From CLIP: `l_self_modules_embeddings_modules_patch_embedding_parameters_weight_`
  * From CLIP: `l_self_modules_encoder_modules_layers_modules_0_modules_layer_norm1_parameters_bias_`
* Remove common parts that don't add much to the meaning of a parameter, e.g. `_self`, `_modules`, `_parameters`
* Ensure that the name is still unique by appending a number or incrementing an existing numeric suffix

In [19]:
import re
def ensure_unique(name, existing_names):
    match = re.match(r"^(.*?)(\d+)?$", name)
    base, num = match.groups()
    num = int(num) if num else 0
    new_name = name

    while new_name in existing_names:
        num += 1
        new_name = f"{base}{num}"
    
    return new_name

print(ensure_unique('John', ['John','John1']))
print(ensure_unique('John1', ['John1']))
print(ensure_unique('John501', ['John501','John502']))
print(ensure_unique('John99', ['John99']))

John2
John2
John503
John100


In [18]:
long_strings = ['l_self_modules_model_modules_embeddings_modules_tok_embeddings_parameters_weight_',
                'l_self_modules_model_modules_layers_modules_0_modules_attn_modules_wqkv_parameters_weight_',
                'l_self_modules_embeddings_modules_patch_embedding_parameters_weight_',
                'l_self_modules_encoder_modules_layers_modules_0_modules_layer_norm1_parameters_bias_']
for s in long_strings:
    shortened = re.sub('_self|_modules|_parameters','',s)
    print(shortened)


l_model_embeddings_tok_embeddings_weight_
l_model_layers_0_attn_wqkv_weight_
l_embeddings_patch_embedding_weight_
l_encoder_layers_0_layer_norm1_bias_


In [20]:
arg_strings = ['(attn_output_2, l_self_modules_encoder_modules_layers_modules_0_modules_self_attn_modules_out_proj_parameters_weight_, l_self_modules_encoder_modules_layers_modules_0_modules_self_attn_modules_out_proj_parameters_bias_)',
               '(layer_norm_45, l_self_modules_model_modules_embeddings_modules_tok_embeddings_parameters_weight_, l_self_modules_decoder_parameters_bias_)',
               '(hidden_states_2, (768,), l_self_modules_encoder_modules_layers_modules_0_modules_layer_norm2_parameters_weight_, l_self_modules_encoder_modules_layers_modules_0_modules_layer_norm2_parameters_bias_, 1e-05)'
               ]

def build_replacement_dict(long_names):
    replacements:Mapping[str,str] = {}
    for long_name in long_names:
        short_name = re.sub('_self|_modules|_parameters','',long_name)
        short_name = ensure_unique(short_name, replacements.values())
        replacements[long_name] = short_name
    return replacements

replacements = build_replacement_dict(['l_self_modules_encoder_modules_layers_modules_0_modules_self_attn_modules_out_proj_parameters_weight_',
                                       'l_self_modules_encoder_modules_layers_modules_0_modules_self_attn_modules_out_proj_parameters_bias_',
                                       'l_self_modules_model_modules_embeddings_modules_tok_embeddings_parameters_weight_',
                                       'l_self_modules_decoder_parameters_bias_',
                                       'l_self_modules_encoder_modules_layers_modules_0_modules_layer_norm2_parameters_weight_',
                                       'l_self_modules_encoder_modules_layers_modules_0_modules_layer_norm2_parameters_bias_'])
def replace_match(match):
    return replacements[match.group(0)]
pattern = re.compile("|".join(map(re.escape, replacements.keys())))

def shorten(text):
    return pattern.sub(replace_match, text)

for arg_str in arg_strings:
    print(shorten(arg_str))

(attn_output_2, l_encoder_layers_0_attn_out_proj_weight_, l_encoder_layers_0_attn_out_proj_bias_)
(layer_norm_45, l_model_embeddings_tok_embeddings_weight_, l_decoder_bias_)
(hidden_states_2, (768,), l_encoder_layers_0_layer_norm2_weight_, l_encoder_layers_0_layer_norm2_bias_, 1e-05)


In [22]:
type(pattern)

re.Pattern
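
One detail of the combined pattern may be worth guarding against. Python's `re` alternation tries alternatives from left to right and takes the first one that matches, so if one long name were a prefix of another, the shorter name could be replaced first and leave a dangling suffix. None of the names used above are prefixes of each other, and whether `torch.compile` ever produces such pairs is an assumption here. The sketch below uses hypothetical names and a hypothetical `build_pattern` helper to show the effect and a possible guard: sorting the keys longest-first.

```python
import re

def build_pattern(replacements):
    """Compile an alternation that tries longer names before shorter ones."""
    keys = sorted(replacements, key=len, reverse=True)
    return re.compile("|".join(map(re.escape, keys)))

# Hypothetical names where the first key is a prefix of the second
replacements = {
    "l_self_modules_fc_": "l_fc_",
    "l_self_modules_fc_parameters_weight_": "l_fc_weight_",
}
text = "(x, l_self_modules_fc_parameters_weight_)"

def substitute(pattern):
    return pattern.sub(lambda m: replacements[m.group(0)], text)

# Keys joined in insertion order: the shorter prefix matches first
naive_out = substitute(re.compile("|".join(map(re.escape, replacements))))
# Keys joined longest-first: the full name matches
safe_out = substitute(build_pattern(replacements))

print(naive_out)
print(safe_out)
```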

## Other models
This section briefly runs through some other models to determine whether they can be traced by `torch.compile` with a custom back-end. Graphs will be drawn by re-running the notebooks indexed from [15_explore_misc_models.ipynb](./15_explore_misc_models.ipynb) after updating the library.
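
Each subsection below repeats the same node counting. As a possible refactoring, the counting could be collected in one helper. The sketch below is not part of `idlmav`: `summarize_graphs`, `dominant_graph_index` and `fake_gm` are hypothetical names. The helpers only rely on `gm.graph.nodes` and `node.op`, so they should accept the `fx.GraphModule` objects collected by `custom_backend`. Stand-in objects are used here to keep the example independent of any model download.

```python
from collections import Counter
from types import SimpleNamespace

def summarize_graphs(gms):
    """Return one Counter of node opcodes per captured graph module."""
    return [Counter(n.op for n in gm.graph.nodes) for gm in gms]

def dominant_graph_index(gms):
    """Index of the graph with the most `call_function` nodes."""
    counts = [c["call_function"] for c in summarize_graphs(gms)]
    return counts.index(max(counts))

def fake_gm(ops):
    """Stand-in exposing only `graph.nodes` and `node.op` (illustration only)."""
    nodes = [SimpleNamespace(op=op) for op in ops]
    return SimpleNamespace(graph=SimpleNamespace(nodes=nodes))

gms_demo = [
    fake_gm(["placeholder", "call_function", "output"]),
    fake_gm(["placeholder", "placeholder", "call_function",
             "call_function", "call_function", "output"]),
]

summary = summarize_graphs(gms_demo)
for i, counts in enumerate(summary):
    print(f"graph {i}: total_nodes: {sum(counts.values())}, "
          f"placeholder: {counts['placeholder']}, call_function: {counts['call_function']}")
print(f"dominant graph: {dominant_graph_index(gms_demo)}")
```

Selecting the dominant graph by its `call_function` count follows the same rule used for ModernBERT earlier in this notebook.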

### DistilBERT

In [1]:
import sys
sys.path.append('..')
from transformers import DistilBertModel, DistilBertTokenizer
import torch
from torch import fx
from typing import List

model = DistilBertModel.from_pretrained("distilbert-base-uncased")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model.eval()
inputs = tokenizer("Hello world", return_tensors="pt")

gms:List[fx.GraphModule] = []
def custom_backend(gm:fx.GraphModule, example_inputs):
    gms.append(gm)
    return gm.forward
torch._dynamo.reset()
compiled_model = torch.compile(model, backend=custom_backend)
outputs = compiled_model(**inputs)

total_nodes = [len(gm.graph.nodes) for gm in gms]
total_placeholder_nodes = [len([n for n in gm.graph.nodes if n.op == 'placeholder']) for gm in gms]
total_function_nodes = [len([n for n in gm.graph.nodes if n.op == 'call_function']) for gm in gms]
print(f'total_nodes: {total_nodes}')
print(f'total_placeholder_nodes: {total_placeholder_nodes}')
print(f'total_function_nodes: {total_function_nodes}')

total_nodes: [248]
total_placeholder_nodes: [103]
total_function_nodes: [86]


In [2]:
for gm in gms:
    print("")
    print("====================================")
    print("custom backend called with FX graph:")
    print("====================================")
    gm.graph.print_tabular()
    


custom backend called with FX graph:
opcode         name                                                                                                     target                                                                                                   args                                                                                                                                                                                                                              kwargs
-------------  -------------------------------------------------------------------------------------------------------  -------------------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  -----------------------------------------------------

### T5-small encoder

In [1]:
import sys
sys.path.append('..')
from transformers import T5Model, T5Tokenizer
import torch
from torch import fx
from typing import List

model = T5Model.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model.eval()
inputs = tokenizer.encode("translate English to French: Hello, how are you?", return_tensors="pt")

gms:List[fx.GraphModule] = []
def custom_backend(gm:fx.GraphModule, example_inputs):
    gms.append(gm)
    return gm.forward
torch._dynamo.reset()
compiled_model = torch.compile(model.encoder, backend=custom_backend)
outputs = compiled_model(inputs)

total_nodes = [len(gm.graph.nodes) for gm in gms]
total_placeholder_nodes = [len([n for n in gm.graph.nodes if n.op == 'placeholder']) for gm in gms]
total_function_nodes = [len([n for n in gm.graph.nodes if n.op == 'call_function']) for gm in gms]
print(f'total_nodes: {total_nodes}')
print(f'total_placeholder_nodes: {total_placeholder_nodes}')
print(f'total_function_nodes: {total_function_nodes}')

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


total_nodes: [351]
total_placeholder_nodes: [52]
total_function_nodes: [180]


In [2]:
for gm in gms:
    print("")
    print("====================================")
    print("custom backend called with FX graph:")
    print("====================================")
    gm.graph.print_tabular()


custom backend called with FX graph:
opcode         name                                                                                                                              target                                                                                                                           args                                                                                                                                                                                   kwargs
-------------  --------------------------------------------------------------------------------------------------------------------------------  -------------------------------------------------------------------------------------------------------------------------------  -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  -----------------------------------------

### Wav2Vec

In [5]:
import sys
sys.path.append('..')
from transformers import Wav2Vec2Model, Wav2Vec2Processor
import torch
from torch import fx
from typing import List

model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
model.eval()
inputs = torch.randn(1, 16000)

gms:List[fx.GraphModule] = []
def custom_backend(gm:fx.GraphModule, example_inputs):
    gms.append(gm)
    return gm.forward
torch._dynamo.reset()
compiled_model = torch.compile(model, backend=custom_backend)
outputs = compiled_model(inputs)

total_nodes = [len(gm.graph.nodes) for gm in gms]
total_placeholder_nodes = [len([n for n in gm.graph.nodes if n.op == 'placeholder']) for gm in gms]
total_function_nodes = [len([n for n in gm.graph.nodes if n.op == 'call_function']) for gm in gms]
print(f'total_nodes: {total_nodes}')
print(f'total_placeholder_nodes: {total_placeholder_nodes}')
print(f'total_function_nodes: {total_function_nodes}')



total_nodes: [565]
total_placeholder_nodes: [211]
total_function_nodes: [218]


In [6]:
for gm in gms:
    print("")
    print("====================================")
    print("custom backend called with FX graph:")
    print("====================================")
    gm.graph.print_tabular()


custom backend called with FX graph:
opcode         name                                                                                                                      target                                                                                                                    args                                                                                                                                                                                                                                                         kwargs
-------------  ------------------------------------------------------------------------------------------------------------------------  ------------------------------------------------------------------------------------------------------------------------  ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

### Whisper Tiny encoder

In [32]:
import sys
sys.path.append('..')
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
from torch import fx
from typing import List

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.eval()

inputs = torch.randn((1,80,3000))
outputs = model.model.encoder(inputs)

In [35]:
gms:List[fx.GraphModule] = []
def custom_backend(gm:fx.GraphModule, example_inputs):
    gms.append(gm)
    return gm.forward
torch._dynamo.reset()
compiled_model = torch.compile(model.model.encoder, backend=custom_backend)
outputs = compiled_model(inputs)

total_nodes = [len(gm.graph.nodes) for gm in gms]
total_placeholder_nodes = [len([n for n in gm.graph.nodes if n.op == 'placeholder']) for gm in gms]
total_function_nodes = [len([n for n in gm.graph.nodes if n.op == 'call_function']) for gm in gms]
print(f'total_nodes: {total_nodes}')
print(f'total_placeholder_nodes: {total_placeholder_nodes}')
print(f'total_function_nodes: {total_function_nodes}')

total_nodes: [181]
total_placeholder_nodes: [68]
total_function_nodes: [67]


In [36]:
for gm in gms:
    print("")
    print("====================================")
    print("custom backend called with FX graph:")
    print("====================================")
    gm.graph.print_tabular()


custom backend called with FX graph:
opcode         name                                                                                   target                                                                                 args                                                                                                                                                                                               kwargs
-------------  -------------------------------------------------------------------------------------  -------------------------------------------------------------------------------------  -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  ---------------------------------------------------------
placeholder    l_input_features_                                                                      L_input_features_          