## Introduction

Before we try to compile a PyTorch module just-in-time, it is important
to understand what PyTorch 2.0 brings with `torch.compile`, and why the earlier
approaches, `torch.jit.script` and FX tracing, fell short.

Refer to: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#comparison-to-torchscript-and-fx-tracing

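In short, scripting and tracing either raise an error on Python they do not support, or silently
record only the branch of a control-flow statement that the example input happened to take. The
captured program is then either unusable or wrong for other inputs.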
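
The "activated path" problem can be reproduced without PyTorch. The sketch below is a toy tracer in plain Python, not `torch.jit.trace` itself; `Traced`, `trace` and `replay` are hypothetical names used only for illustration. It records the arithmetic applied to one example input and then replays that recording on a different input.

```python
class Traced:
    """Toy traced value: records arithmetic on a shared tape."""

    def __init__(self, value, tape):
        self.value, self.tape = value, tape

    def __mul__(self, k):
        self.tape.append(("mul", k))
        return Traced(self.value * k, self.tape)

    def __add__(self, k):
        self.tape.append(("add", k))
        return Traced(self.value + k, self.tape)

    def __gt__(self, k):
        # The comparison collapses to a plain bool, so the branch
        # decision itself never reaches the tape.
        return self.value > k


def f(x):
    if x > 0:          # data-dependent control flow
        return x * 2
    return x + 100


def trace(fn, example):
    """Run fn once on an example input and return the recorded ops."""
    tape = []
    fn(Traced(example, tape))
    return tape


def replay(tape, x):
    """Re-run the recorded ops on a new input."""
    for op, k in tape:
        x = x * k if op == "mul" else x + k
    return x


tape = trace(f, 3)           # only the `x > 0` branch is recorded
print(tape)                  # [('mul', 2)]
print(replay(tape, -1))      # -2, the recorded branch applied blindly
print(f(-1))                 # 99, what the original function returns
```

The replayed program is silently wrong for `-1` because the `else` branch was never seen during tracing.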

------------------------------------------------------------------------------------------------------------------------

We use `torch._dynamo` (TorchDynamo) here to capture the graphs generated for the corresponding `nn.Module` and to
count the graph breaks in its code.
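
As a rough mental model of a graph break (a toy in plain Python, not Dynamo's actual implementation): the tracer collects operations into a graph until it meets something it cannot capture, closes the graph collected so far, lets that operation run in regular Python, and starts a new graph afterwards. The op names and the `unsupported` set below are made up for illustration.

```python
def split_into_graphs(ops, unsupported):
    """Partition a linear op sequence into graphs at unsupported ops."""
    graphs, current, breaks = [], [], 0
    for op in ops:
        if op in unsupported:
            breaks += 1
            if current:              # close the graph collected so far
                graphs.append(current)
                current = []
            # the unsupported op itself would run eagerly in Python
        else:
            current.append(op)
    if current:
        graphs.append(current)
    return graphs, breaks


# Hypothetical op sequence; "print" and "item" stand in for calls
# that the tracer cannot put into a graph.
ops = ["embed", "arange", "print", "rms_norm", "linear", "item", "softmax"]
graphs, breaks = split_into_graphs(ops, unsupported={"print", "item"})

for i, g in enumerate(graphs):
    print(f"graph {i}: {g}")
print("graph breaks:", breaks)
```

In this toy model two breaks yield three graphs. Fewer breaks mean larger graphs and more room for the compiler to optimize.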


In [1]:
import os
from huggingface_hub import login

access_token = os.environ["HF_ACCESS_TOKEN"]
login(token=access_token)

os.environ["TORCH_COMPILE_DEBUG"] = "1"  # Dumps files in `torch_compile_debug/`

# Choose which logs to enable
# os.environ["TORCH_LOGS"] = "+dynamo,+aot_graphs,+inductor,+guards,+graph"
os.environ["TORCH_LOGS"] = "+dynamo"

import torch.nn as nn
from torch._dynamo import optimize

from transformers import AutoModelForCausalLM, AutoTokenizer

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/gaurav/.cache/huggingface/token
Login successful


In [2]:
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
import torch

class LLaMAFirstLayerModel(nn.Module):
    def __init__(self, llama_model_name: str, output_dim: int):
        super(LLaMAFirstLayerModel, self).__init__()

        # Load the LLaMA model
        full_llama = AutoModel.from_pretrained(llama_model_name)

        # Extract and store the embedding layer
        self.embed_tokens = full_llama.embed_tokens

        # Extract and store the first decoder layer
        self.first_layer = full_llama.layers[0]

        # Linear layer to map to output dimensions
        llama_hidden_dim = full_llama.config.hidden_size
        self.linear = nn.Linear(llama_hidden_dim, output_dim)

        # Softmax for output probabilities
        self.softmax = nn.Softmax(dim=-1)

    # forward takes only input_ids; no attention_mask is passed to the decoder layer
    def forward(self, input_ids):
        # Generate embeddings
        embeddings = self.embed_tokens(input_ids)

        # Build position_ids (0..seq_len-1) explicitly; the decoder layer uses them for rotary embeddings
        position_ids = torch.arange(0, input_ids.size(1), device=input_ids.device).unsqueeze(0)

        # Pass through the first layer with position_ids
        layer_output = self.first_layer(embeddings, position_ids=position_ids)[0]

        # Pool the output (mean along sequence dimension)
        pooled_output = torch.mean(layer_output, dim=1)

        # Map to output dimension
        logits = self.linear(pooled_output)

        # Apply softmax
        probs = self.softmax(logits)

        return probs


In [3]:
llama_model_name = "meta-llama/Llama-3.2-1B-Instruct"  # Replace with actual model name
output_dim = 10  # Number of classes for classification

# Initialize the model
model = LLaMAFirstLayerModel(llama_model_name, output_dim).to("cuda")

# Example tokenizer and input
tokenizer = AutoTokenizer.from_pretrained(llama_model_name)
tokenizer.pad_token = tokenizer.eos_token
example_text = ["This is an example input."]
inputs = tokenizer(example_text, return_tensors="pt", padding=True, truncation=True).to("cuda")

In [4]:
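# Use TorchDynamo's explain to count the captured graphs and graph breaks
# Extract the input_ids tensor from the BatchEncoding
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]  # not used by this model's forward
# Note: newer PyTorch releases prefer the form explain(model)(input_ids)
explanation = torch._dynamo.explain(model, input_ids)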

# Print the explanation
print(explanation)

[2024-12-24 18:09:11,453] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /home/gaurav/anaconda3/lib/python3.11/contextlib.py
[2024-12-24 18:09:11,453] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /home/gaurav/anaconda3/lib/python3.11/contextlib.py
[2024-12-24 18:09:11,454] torch._dynamo.eval_frame: [DEBUG] skipping helper /home/gaurav/anaconda3/lib/python3.11/contextlib.py
[2024-12-24 18:09:11,454] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /home/gaurav/anaconda3/lib/python3.11/contextlib.py
[2024-12-24 18:09:11,454] torch._dynamo.eval_frame: [DEBUG] skipping enable_dynamic /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py
[2024-12-24 18:09:11,454] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /home/gaurav/anaconda3/lib/python3.11/contextlib.py
[2024-12-24 18:09:11,454] torch._dynamo.eval_frame: [DEBUG] skipping _wrapped_call_impl /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py
[2024-12-24 18:09:1

Graph Count: 6
Graph Break Count: 5
Op Count: 44
Break Reasons:
  Break Reason 1:
    User Stack:
      <FrameSummary file /tmp/ipykernel_1020204/343866.py, line 34 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 734 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 405 in forward>
  Break Reason 2:
    User Stack:
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 734 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module

### Graph break to Python `forward`

Having used `torch._dynamo.explain` to count the graphs and graph breaks, we now use `torch._dynamo.optimize`
with a custom backend callback to print the generated Python `forward` function for each captured graph.

In [5]:
# Generate Python code from the torch._dynamo graph
def debug_callback(graph_module, example_inputs):
    # Generate Python code for the traced graph
    print(graph_module.code)
    return graph_module

# Wrap your model with the debug callback
model_optimized = torch._dynamo.optimize(debug_callback)(model)

# Run your model to trigger the tracing
model_optimized(input_ids)

[2024-12-24 18:15:17,315] [6/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward /tmp/ipykernel_1020204/343866.py:26
[2024-12-24 18:15:17,315] [6/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_1020204/343866.py:26
[2024-12-24 18:15:17,315] [6/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         def forward(self, input_ids):
[2024-12-24 18:15:17,316] [6/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['input_ids'] (1, 7) [<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>] [None, None]
[2024-12-24 18:15:17,317] [6/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RESUME 0 []
[2024-12-24 18:15:17,317] [6/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_1020204/343866.py:28
[2024-12-24 18:15:17,317] [6/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             embeddings = self.embed_tokens(input_ids)
[2024-12-24 18:15:17,318] [




def forward(self, L_input_ids_ : torch.Tensor):
    l_input_ids_ = L_input_ids_
    l__self___embed_tokens = self.L__self___embed_tokens(l_input_ids_);  l_input_ids_ = None
    arange = torch.arange(0, 7, device = device(type='cuda', index=0))
    unsqueeze = arange.unsqueeze(0);  arange = None
    return (l__self___embed_tokens, unsqueeze)
    



def forward(self, L_hidden_states_ : torch.Tensor):
    l_hidden_states_ = L_hidden_states_
    to = l_hidden_states_.to(torch.float32);  l_hidden_states_ = None
    pow_1 = to.pow(2)
    mean = pow_1.mean(-1, keepdim = True);  pow_1 = None
    add = mean + 1e-05;  mean = None
    rsqrt = torch.rsqrt(add);  add = None
    mul = to * rsqrt;  to = rsqrt = None
    l__self___input_layernorm_weight = self.L__self___input_layernorm_weight
    to_1 = mul.to(torch.float32);  mul = None
    mul_1 = l__self___input_layernorm_weight * to_1;  l__self___input_layernorm_weight = to_1 = None
    return (mul_1,)
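
The graph above is the decoder layer's input normalization: square, mean over the last dimension, add `1e-05`, `rsqrt`, then scale by the layer-norm weight. A minimal plain-Python sketch of the same sequence of operations on a single hidden vector follows; `rms_norm` is an illustrative name, and the two-element inputs are made up.

```python
import math


def rms_norm(hidden, weight, eps=1e-5):
    """Mirror the traced ops on one hidden vector (lists of floats)."""
    # pow(2) followed by mean(-1): mean of squares over the hidden dim
    mean_sq = sum(h * h for h in hidden) / len(hidden)
    # rsqrt(mean + eps)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # weight * (hidden * rsqrt)
    return [w * (h * inv_rms) for w, h in zip(weight, hidden)]


out = rms_norm([3.0, 4.0], [1.0, 1.0])
print(out)

# With unit weights the output has a root-mean-square close to 1.
rms_out = math.sqrt(sum(o * o for o in out) / len(out))
print(rms_out)
```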
    


[2024-12-24 18:15:17,714] [8/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_METHOD q_proj [NNModuleVariable()]
[2024-12-24 18:15:17,714] [8/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_DEREF hidden_states [NullVariable, NNModuleVariable()]
[2024-12-24 18:15:17,715] [8/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE PRECALL 1 [NullVariable, NNModuleVariable(), TensorVariable()]
[2024-12-24 18:15:17,715] [8/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL 1 [NullVariable, NNModuleVariable(), TensorVariable()]
[2024-12-24 18:15:17,716] [8/0] torch._dynamo.output_graph.__trace_call: [DEBUG] TRACE FX call l__self___q_proj from forward /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:396
[2024-12-24 18:15:17,716] [8/0] torch._dynamo.output_graph.__trace_call: [DEBUG]             query_states = self.q_proj(hidden_states)
[2024-12-24 18:15:17,716] [8/0] torch._dynamo.output_graph.__trace_call: [DEBUG]                         




def forward(self, L_hidden_states_ : torch.Tensor):
    l_hidden_states_ = L_hidden_states_
    l__self___q_proj = self.L__self___q_proj(l_hidden_states_)
    l__self___k_proj = self.L__self___k_proj(l_hidden_states_)
    l__self___v_proj = self.L__self___v_proj(l_hidden_states_);  l_hidden_states_ = None
    view = l__self___q_proj.view(1, 7, 32, 64);  l__self___q_proj = None
    transpose = view.transpose(1, 2);  view = None
    view_1 = l__self___k_proj.view(1, 7, 8, 64);  l__self___k_proj = None
    transpose_1 = view_1.transpose(1, 2);  view_1 = None
    view_2 = l__self___v_proj.view(1, 7, 8, 64);  l__self___v_proj = None
    transpose_2 = view_2.transpose(1, 2);  view_2 = None
    return (transpose, transpose_1, transpose_2)
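
The `view` shapes in this graph show 32 query heads but only 8 key/value heads, each of size 64, so every key/value head is shared by several query heads. The sketch below works through that arithmetic in plain Python. `repeat_kv` here is an illustrative stand-in, assuming the library function of the same name (seen in the trace log) repeats each key/value head in adjacent copies. The string labels stand in for per-head tensors.

```python
def repeat_kv(heads, n_rep):
    """Repeat each key/value head n_rep times, keeping copies adjacent."""
    return [head for head in heads for _ in range(n_rep)]


# Shapes read off the traced graph: view(1, 7, 32, 64) and view(1, 7, 8, 64)
num_q_heads, num_kv_heads, head_dim = 32, 8, 64

q_width = num_q_heads * head_dim       # output width of q_proj
kv_width = num_kv_heads * head_dim     # output width of k_proj / v_proj
n_rep = num_q_heads // num_kv_heads    # query heads per key/value head

kv_heads = [f"kv{i}" for i in range(num_kv_heads)]
expanded = repeat_kv(kv_heads, n_rep)

print(q_width, kv_width, n_rep)        # 2048 512 4
print(len(expanded), expanded[:5])
```

After the repeat, keys and values have one entry per query head, which is what the attention matmul expects.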
    


[2024-12-24 18:15:17,961] [9/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line repeat_kv /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:319 (inline depth: 1)
[2024-12-24 18:15:17,961] [9/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         batch, num_key_value_heads, slen, head_dim = hidden_states.shape
[2024-12-24 18:15:17,962] [9/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST hidden_states []
[2024-12-24 18:15:17,962] [9/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR shape [TensorVariable()]
[2024-12-24 18:15:17,962] [9/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE UNPACK_SEQUENCE 4 [ShapeVariable()]
[2024-12-24 18:15:17,963] [9/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST batch [ConstantVariable(int), ConstantVariable(int), ConstantVariable(int), ConstantVariable(int)]
[2024-12-24 18:15:17,963] [9/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAS




def forward(self, L_position_ids_ : torch.Tensor, L_query_states_ : torch.Tensor, L_key_states_ : torch.Tensor, L_value_states_ : torch.Tensor):
    l_position_ids_ = L_position_ids_
    l_query_states_ = L_query_states_
    l_key_states_ = L_key_states_
    l_value_states_ = L_value_states_
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    l__self___rotary_emb_inv_freq = self.L__self___rotary_emb_inv_freq
    getitem = l__self___rotary_emb_inv_freq[(None, slice(None, None, None), None)];  l__self___rotary_emb_inv_freq = None
    float_1 = getitem.float();  getitem = None
    expand = float_1.expand(1, -1, 1);  float_1 = None
    getitem_1 = l_position_ids_[(slice(None, None, None), None, slice(None, None, None))];  l_position_ids_ = None
    float_2 = getitem_1.float();  getitem_1 = None
    _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', None, False, None)
    float_3 = expand.float();  expand = None
    float_4 = float_2.float();  float_2 = None
    

tensor([[0.0962, 0.1018, 0.1038, 0.0996, 0.1004, 0.1000, 0.1052, 0.0968, 0.0978,
         0.0983]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

## Next Steps - Capturing `nn.Module` with Custom ops

A PyTorch implementation can also contain custom ops. For instance, a programmer who wants to force kernel fusion can define the fused computation as a custom op, and `torch.compile` then needs to respect it. The same applies when the code calls operations defined in `numpy` or `scipy`.