## Introduction

Before we try to `jit` compile a PyTorch module, it is important to understand
what PyTorch 2.0 brings with `torch.compile`, and why `torch.jit.script` and
FX tracing were not good enough.

Refer to: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#comparison-to-torchscript-and-fx-tracing

Essentially, scripting and tracing either error out on code they cannot handle, or capture only the
branch that was taken during tracing, so the result is either non-functional or silently wrong for other inputs.
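
To make the "only the activated path" problem concrete, here is a minimal sketch using `torch.jit.trace` on a toy function. The function `branchy` and its inputs are made up for illustration and are not part of this notebook's model.

```python
import warnings

import torch


def branchy(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow: the branch depends on tensor values.
    if x.sum() > 0:
        return x + 1
    return x - 1


with warnings.catch_warnings():
    # Tracing warns about the tensor-to-bool conversion; silence it here.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(branchy, torch.ones(3))

negative = -torch.ones(3)
eager_out = branchy(negative)   # eager Python takes the `x - 1` branch
traced_out = traced(negative)   # the trace replays the recorded `x + 1` branch
print(eager_out, traced_out)
```

The trace was recorded with a positive input, so it keeps adding 1 even when the input would send eager Python down the other branch.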

------------------------------------------------------------------------------------------------------------------------

We use TorchDynamo (`torch._dynamo`) here to capture the graphs generated for the corresponding `nn.Module` and to
count the graph breaks in the code.


In [3]:
import os
from huggingface_hub import login

access_token = os.environ["HF_ACCESS_TOKEN"]
login(token=access_token)

os.environ["TORCH_COMPILE_DEBUG"] = "1"  # Dumps files in `torch_compile_debug/`

# Choose which logs to enable
# os.environ["TORCH_LOGS"] = "+dynamo,+aot_graphs,+inductor,+guards,+graph"
# os.environ["TORCH_LOGS"] = "+dynamo,guards,bytecode,graph_code"
os.environ["TORCH_LOGS"] = "+dynamo,graph_code"

import torch.nn as nn
from torch._dynamo import optimize

from transformers import AutoModelForCausalLM, AutoTokenizer

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/gaurav/.cache/huggingface/token
Login successful


------------------------------------------------------------------------------------------------------------------------------------------------------

#### First things first

Start with the TorchDynamo deep dive: https://pytorch.org/docs/stable/torch.compiler_dynamo_deepdive.html#torch-compiler-dynamo-deepdive

![TorchDynamo Guards](https://i.imgur.com/Pvq65gt.png)

In [4]:
import torch
# Set the environment variable
# os.environ["TORCHDYNAMO_REPORT_GUARD_FAILURES"] = "1"

# The other interesting thing to note here is that Dynamo removed the second argument to the function. 
# Instead, it treated it as a constant and recorded the result of the operation n + 1 in the graph. 
# This is another feature of Dynamo: Dynamo will treat as constant any non-tensor value… other than ints.

# The last defining property of Dynamo is that it knows how to handle dynamic shapes. Symbolic shapes refer to 
# Dynamo’s ability of tracing shapes, and more generally, integers, rather than leaving them as constants. 
# This allows for avoiding recompilations and deploying generic models that work for any size in production.

@torch.compile
def fn(x, n):
    y = x ** 2
    if n >= 0:
        return (n + 1) * y
    else:
        return y / n

x = torch.randn(200)
fn(x, 2)
fn(x, 3)
fn(x, -2)

[2024-12-28 16:21:29,600] [1/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing fn /tmp/ipykernel_414314/2999559549.py:13
[2024-12-28 16:21:29,601] [1/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line fn /tmp/ipykernel_414314/2999559549.py:13
[2024-12-28 16:21:29,601] [1/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]     @torch.compile
[2024-12-28 16:21:29,601] [1/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['x'] (200,) [<DimDynamic.STATIC: 2>] [None]
[2024-12-28 16:21:29,602] [1/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RESUME 0 []
[2024-12-28 16:21:29,602] [1/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line fn /tmp/ipykernel_414314/2999559549.py:15
[2024-12-28 16:21:29,602] [1/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         y = x ** 2
[2024-12-28 16:21:29,602] [1/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x []
[2024-12-28 16:21:29,603] [1/0] 

tensor([-6.6696e-01, -1.0884e-01, -3.6186e-02, -1.1740e+00, -9.0413e-01,
        -5.6791e-01, -1.9477e-01, -3.4753e-01, -1.9683e-02, -1.7803e-01,
        -3.9002e-01, -3.0301e-01, -1.3585e+00, -5.2352e-02, -2.9549e+00,
        -1.2119e-02, -6.8487e-02, -1.2304e-01, -4.4547e-04, -2.8063e-02,
        -1.9392e-01, -2.5780e-01, -4.5149e-01, -5.4279e-01, -2.1302e-01,
        -1.2001e-01, -1.3565e-01, -3.5765e-02, -3.9820e-02, -1.5678e+00,
        -1.0912e-01, -6.2852e-01, -3.2199e-01, -3.4568e-01, -4.8182e-01,
        -1.0612e+00, -5.6914e-01, -6.5895e-03, -4.0533e-01, -5.1158e-01,
        -1.7885e+00, -1.0898e-02, -4.7453e-01, -4.7165e-02, -1.4123e+00,
        -3.1337e-02, -1.8508e+00, -2.6575e-02, -6.9351e-01, -6.3217e-03,
        -3.2985e-02, -1.2599e-02, -7.2915e-02, -2.5334e-01, -8.3970e-02,
        -2.8962e-01, -8.3909e-01, -5.3914e-02, -7.5858e-01, -2.8807e-02,
        -9.1632e-05, -9.2425e-02, -9.5559e-02, -6.1646e-01, -5.6891e-01,
        -2.3647e+00, -4.0835e-02, -2.0778e-01, -1.3
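
To see how the integer argument drives recompilation without reading the full debug log, one option is a small recording backend. This is only a sketch: `recording_backend` and `graph_sources` are names introduced here, and the exact number of backend invocations can vary across PyTorch versions, so it is printed rather than assumed.

```python
import torch
import torch._dynamo

graph_sources = []  # generated Python source, one entry per backend invocation


def recording_backend(gm, example_inputs):
    # Record the code of each captured FX graph, then run it unmodified.
    graph_sources.append(gm.code)
    return gm.forward


def fn(x, n):
    y = x ** 2
    if n >= 0:
        return (n + 1) * y
    return y / n


torch._dynamo.reset()
compiled_fn = torch.compile(fn, backend=recording_backend)

x = torch.randn(8)
results = [compiled_fn(x, n) for n in (2, 3, -2)]
expected = [fn(x, n) for n in (2, 3, -2)]

print("backend invocations:", len(graph_sources))
for source in graph_sources:
    print(source)
```

The printed graphs show whether `n` was baked in as a constant or passed in as a symbolic integer, and the call with `n = -2` needs its own graph because it fails the `n >= 0` guard.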

In [5]:
import torch

@torch.compile(dynamic=True)
def fn(a):
    if a.shape[0] * 2 < 16:
        return a
    else:
        return a + 1

fn(torch.randn(8))

[2024-12-28 16:21:32,072] torch._dynamo.eval_frame: [DEBUG] skipping patch /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/_dynamo/config_utils.py
[2024-12-28 16:21:32,072] torch._dynamo.eval_frame: [DEBUG] skipping ConfigPatch /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/_dynamo/config_utils.py
[2024-12-28 16:21:32,073] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/_dynamo/config_utils.py
[2024-12-28 16:21:32,074] [2/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing fn /tmp/ipykernel_414314/1791670339.py:3
[2024-12-28 16:21:32,074] [2/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line fn /tmp/ipykernel_414314/1791670339.py:3
[2024-12-28 16:21:32,074] [2/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]     @torch.compile(dynamic=True)
[2024-12-28 16:21:32,075] [2/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['a'] (8,) [<DimDyna

tensor([ 2.0000,  0.4921, -1.7825,  0.2682, -0.6509,  0.4737,  0.8816, -0.9086])
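
A similar sketch for dynamic shapes, assuming a custom recording backend (the names `recording_backend`, `graph_sources`, and `shape_branch` are introduced here). Both branches do real work so that a graph is always handed to the backend. Sizes 4 and 6 fall on the same side of the shape check and can in principle share one compiled graph, while size 10 needs the other branch. The exact number of compilations depends on the PyTorch version.

```python
import torch
import torch._dynamo

graph_sources = []  # one entry per backend invocation


def recording_backend(gm, example_inputs):
    graph_sources.append(gm.code)
    return gm.forward


def shape_branch(a):
    # Shape-dependent branch, guarded on the symbolic size of dim 0.
    if a.shape[0] * 2 < 16:
        return a * 2
    return a + 1


torch._dynamo.reset()
compiled = torch.compile(shape_branch, backend=recording_backend, dynamic=True)

inputs = [torch.randn(n) for n in (4, 6, 10)]
outputs = [compiled(t) for t in inputs]
expected = [shape_branch(t) for t in inputs]
print("backend invocations:", len(graph_sources))
```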

In [6]:
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
import torch

class LLaMAFirstLayerModel(nn.Module):
    def __init__(self, llama_model_name: str, output_dim: int):
        super(LLaMAFirstLayerModel, self).__init__()

        # Load the LLaMA model
        full_llama = AutoModel.from_pretrained(llama_model_name)

        # Extract and store the embedding layer
        self.embed_tokens = full_llama.embed_tokens

        # Extract and store the first decoder layer
        self.first_layer = full_llama.layers[0]

        # Linear layer to map to output dimensions
        llama_hidden_dim = full_llama.config.hidden_size
        self.linear = nn.Linear(llama_hidden_dim, output_dim)

        # Softmax for output probabilities
        self.softmax = nn.Softmax(dim=-1)

    # Only input_ids is passed in; position_ids are built inside forward
    def forward(self, input_ids):
        # Generate embeddings
        embeddings = self.embed_tokens(input_ids)

        # Build explicit position_ids (0 .. seq_len - 1) for the decoder layer
        position_ids = torch.arange(0, input_ids.size(1), device=input_ids.device).unsqueeze(0)

        # Pass through the first layer with position_ids
        layer_output = self.first_layer(embeddings, position_ids=position_ids)[0]

        # Pool the output (mean along sequence dimension)
        pooled_output = torch.mean(layer_output, dim=1)

        # Map to output dimension
        logits = self.linear(pooled_output)

        # Apply softmax
        probs = self.softmax(logits)

        return probs


In [7]:
llama_model_name = "meta-llama/Llama-3.2-1B-Instruct"  # Replace with actual model name
output_dim = 10  # Number of classes for classification

# Initialize the model
model = LLaMAFirstLayerModel(llama_model_name, output_dim).to("cuda")

# Example tokenizer and input
tokenizer = AutoTokenizer.from_pretrained(llama_model_name)
tokenizer.pad_token = tokenizer.eos_token
example_text = ["This is an example input."]
inputs = tokenizer(example_text, return_tensors="pt", padding=True, truncation=True).to("cuda")

In [8]:
# Use TorchDynamo's explain to capture the graphs and graph breaks
# Extract the input_ids tensor from BatchEncoding
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]  # not used by the model's forward
explanation = torch._dynamo.explain(model)(input_ids)

# Print the explanation
print(explanation)

[2024-12-28 16:21:37,925] torch._dynamo.eval_frame: [DEBUG] skipping _wrapped_call_impl /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py
[2024-12-28 16:21:37,925] torch._dynamo.eval_frame: [DEBUG] skipping _call_impl /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py
[2024-12-28 16:21:37,927] [3/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward /tmp/ipykernel_414314/343866.py:26
[2024-12-28 16:21:37,927] [3/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_414314/343866.py:26
[2024-12-28 16:21:37,927] [3/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         def forward(self, input_ids):
[2024-12-28 16:21:37,928] [3/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['input_ids'] (1, 7) [<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>] [None, None]
[2024-12-28 16:21:37,929] [3/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RESUM

Graph Count: 6
Graph Break Count: 5
Op Count: 44
Break Reasons:
  Break Reason 1:
    User Stack:
      <FrameSummary file /tmp/ipykernel_414314/343866.py, line 34 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 734 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 405 in forward>
  Break Reason 2:
    User Stack:
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 734 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.

### Graph break to Python `forward`

Having used `torch._dynamo.explain` to evaluate the graphs and breaks generated by `torch._dynamo`, we now use `torch._dynamo.optimize`
with a custom backend callback to print the Python `forward` function generated for each graph.
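
The same idea on a toy scale, as a sketch: a callable passed as `backend=` to `torch.compile` receives each captured `GraphModule`, whose `.code` attribute holds the generated `forward` source. The module `Tiny` and the backend `print_code_backend` are hypothetical names, and the graph break is forced explicitly so that exactly two graphs are produced.

```python
import torch
import torch._dynamo
import torch.nn as nn

captured_code = []  # generated forward source for each captured graph


def print_code_backend(gm, example_inputs):
    # Called once per captured graph; keep its source and run it as-is.
    captured_code.append(gm.code)
    print(gm.code)
    return gm.forward


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        h = self.linear(x)
        torch._dynamo.graph_break()  # forces a second graph for the softmax
        return torch.softmax(h, dim=-1)


torch._dynamo.reset()
tiny = Tiny()
compiled_tiny = torch.compile(tiny, backend=print_code_backend)
out = compiled_tiny(torch.randn(3, 4))
```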

In [9]:
# Generate Python code from the torch._dynamo graph
def debug_callback(graph_module, example_inputs):
    # Generate Python code for the traced graph
    print(graph_module.code)
    return graph_module

# Wrap your model with the debug callback
model_optimized = torch._dynamo.optimize(debug_callback)(model)

# Run your model to trigger the tracing
model_optimized(input_ids)

[2024-12-28 16:21:43,339] [9/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward /tmp/ipykernel_414314/343866.py:26
[2024-12-28 16:21:43,339] [9/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_414314/343866.py:26
[2024-12-28 16:21:43,339] [9/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         def forward(self, input_ids):
[2024-12-28 16:21:43,340] [9/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['input_ids'] (1, 7) [<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>] [None, None]
[2024-12-28 16:21:43,341] [9/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RESUME 0 []
[2024-12-28 16:21:43,341] [9/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_414314/343866.py:28
[2024-12-28 16:21:43,341] [9/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             embeddings = self.embed_tokens(input_ids)
[2024-12-28 16:21:43,341] [9/0




def forward(self, L_input_ids_ : torch.Tensor):
    l_input_ids_ = L_input_ids_
    l__self___embed_tokens = self.L__self___embed_tokens(l_input_ids_);  l_input_ids_ = None
    arange = torch.arange(0, 7, device = device(type='cuda', index=0))
    unsqueeze = arange.unsqueeze(0);  arange = None
    return (l__self___embed_tokens, unsqueeze)
    



def forward(self, L_hidden_states_ : torch.Tensor):
    l_hidden_states_ = L_hidden_states_
    to = l_hidden_states_.to(torch.float32);  l_hidden_states_ = None
    pow_1 = to.pow(2)
    mean = pow_1.mean(-1, keepdim = True);  pow_1 = None
    add = mean + 1e-05;  mean = None
    rsqrt = torch.rsqrt(add);  add = None
    mul = to * rsqrt;  to = rsqrt = None
    l__self___input_layernorm_weight = self.L__self___input_layernorm_weight
    to_1 = mul.to(torch.float32);  mul = None
    mul_1 = l__self___input_layernorm_weight * to_1;  l__self___input_layernorm_weight = to_1 = None
    return (mul_1,)
    


[2024-12-28 16:21:43,721] [11/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE PRECALL 1 [NullVariable, NNModuleVariable(), TensorVariable()]
[2024-12-28 16:21:43,722] [11/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL 1 [NullVariable, NNModuleVariable(), TensorVariable()]
[2024-12-28 16:21:43,722] [11/0] torch._dynamo.output_graph.__trace_call: [DEBUG] TRACE FX call l__self___v_proj from forward /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:398
[2024-12-28 16:21:43,722] [11/0] torch._dynamo.output_graph.__trace_call: [DEBUG]             value_states = self.v_proj(hidden_states)
[2024-12-28 16:21:43,722] [11/0] torch._dynamo.output_graph.__trace_call: [DEBUG]                            ~~~~~~~~~~~^^^^^^^^^^^^^^^
[2024-12-28 16:21:43,724] [11/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST value_states [TensorVariable()]
[2024-12-28 16:21:43,724] [11/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE start




def forward(self, L_hidden_states_ : torch.Tensor):
    l_hidden_states_ = L_hidden_states_
    l__self___q_proj = self.L__self___q_proj(l_hidden_states_)
    l__self___k_proj = self.L__self___k_proj(l_hidden_states_)
    l__self___v_proj = self.L__self___v_proj(l_hidden_states_);  l_hidden_states_ = None
    view = l__self___q_proj.view(1, 7, 32, 64);  l__self___q_proj = None
    transpose = view.transpose(1, 2);  view = None
    view_1 = l__self___k_proj.view(1, 7, 8, 64);  l__self___k_proj = None
    transpose_1 = view_1.transpose(1, 2);  view_1 = None
    view_2 = l__self___v_proj.view(1, 7, 8, 64);  l__self___v_proj = None
    transpose_2 = view_2.transpose(1, 2);  view_2 = None
    return (transpose, transpose_1, transpose_2)
    


[2024-12-28 16:21:43,952] [12/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST Ellipsis [TensorVariable()]
[2024-12-28 16:21:43,953] [12/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST None [TensorVariable(), ConstantVariable(ellipsis)]
[2024-12-28 16:21:43,953] [12/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x [TensorVariable(), ConstantVariable(ellipsis), ConstantVariable(NoneType)]
[2024-12-28 16:21:43,953] [12/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR shape [TensorVariable(), ConstantVariable(ellipsis), ConstantVariable(NoneType), TensorVariable()]
[2024-12-28 16:21:43,954] [12/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST -1 [TensorVariable(), ConstantVariable(ellipsis), ConstantVariable(NoneType), ShapeVariable()]
[2024-12-28 16:21:43,955] [12/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE BINARY_SUBSCR None [TensorVariable(), ConstantVariable(ellipsis), ConstantVariable(NoneType), ShapeVariable(), ConstantVar




def forward(self, L_position_ids_ : torch.Tensor, L_query_states_ : torch.Tensor, L_key_states_ : torch.Tensor, L_value_states_ : torch.Tensor):
    l_position_ids_ = L_position_ids_
    l_query_states_ = L_query_states_
    l_key_states_ = L_key_states_
    l_value_states_ = L_value_states_
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    l__self___rotary_emb_inv_freq = self.L__self___rotary_emb_inv_freq
    getitem = l__self___rotary_emb_inv_freq[(None, slice(None, None, None), None)];  l__self___rotary_emb_inv_freq = None
    float_1 = getitem.float();  getitem = None
    expand = float_1.expand(1, -1, 1);  float_1 = None
    getitem_1 = l_position_ids_[(slice(None, None, None), None, slice(None, None, None))];  l_position_ids_ = None
    float_2 = getitem_1.float();  getitem_1 = None
    _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', None, False, None)
    float_3 = expand.float();  expand = None
    float_4 = float_2.float();  float_2 = None
    

tensor([[0.1006, 0.0982, 0.1000, 0.0982, 0.1100, 0.0975, 0.0986, 0.0959, 0.1019,
         0.0992]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

## Next Steps - Capturing `nn.Module` with Custom ops

There are cases where a PyTorch implementation contains custom ops. For instance, a programmer who wants to force `kernel` fusion can define a custom op and needs `torch.compile` to respect it. The same applies when ops are implemented with `numpy` or `scipy`.

In [10]:
# Step 1: Create a library (namespace) that will own the custom op definition
my_lib = torch.library.Library("scale_custom_op", "DEF")

# Step 2: Define the custom op schema
my_lib.define("scale_by_max(Tensor input) -> Tensor")

# Plain Python implementation: scale every element by the tensor's maximum
def scale_by_max(input: torch.Tensor) -> torch.Tensor:
    max_value = torch.max(input)
    return input * max_value

# Step 3: Use IMPL to register the implementation (CPU dispatch key only)
impl_lib = torch.library.Library("scale_custom_op", "IMPL")
impl_lib.impl("scale_by_max", scale_by_max, "CPU")

In [11]:
# Verify that the op was registered:
# list all operators registered for the CPU backend
torch._C._dispatch_print_registrations_for_dispatch_key("CPU")

aten::hardsigmoid.out
aten::reflection_pad3d_backward
aten::logical_and.out
aten::quantized_lstm.input_legacy
aten::linalg_eig.out
aten::log_sigmoid_backward.grad_input
prepacked::conv2d_transpose_clamp_run
aten::leaky_relu.out
aten::_log_softmax_backward_data
aten::_foreach_pow_.ScalarList
aten::atanh.out
aten::sgn
aten::_to_sparse_bsr
aten::ge.Scalar
aten::mean.out
aten::geqrf.a
aten::_foreach_pow_.Scalar
aten::xlogy.Tensor
aten::isin.Scalar_Tensor
aten::_foreach_cos_
aten::std.correction_out
aten::silu_backward.grad_input
aten::atan2
aten::_foreach_sinh
aten::_test_optional_filled_intlist
aten::bucketize.Tensor
aten::_foreach_log_
aten::bitwise_and.Tensor_out
aten::_foreach_acos
aten::lt.Tensor
aten::pow.Tensor_Scalar
aten::_foreach_maximum_.Scalar
aten::round
aten::__irshift__.Scalar
aten::special_laguerre_polynomial_l.out
aten::pow.Scalar
aten::sin.out
aten::masked_fill_.Tensor
aten::softplus_backward.grad_input
aten::div.Tensor_mode
aten::linalg_matrix_exp
aten::atanh
aten::_fore

In [12]:
# Test your custom operator
x = torch.tensor([1.0, 2.0, 3.0])
result = torch.ops.scale_custom_op.scale_by_max(x)
print(result)  # Output: tensor([3.0, 6.0, 9.0])

tensor([3., 6., 9.])
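
To check whether a registered op survives as a single node in a captured graph, a sketch along these lines may help. It uses a separate, hypothetical namespace (`demo_scale_ops`) to avoid clashing with the library defined above, and it adds a `Meta` kernel, which is an assumption of this sketch: Dynamo traces with fake tensors and needs a shape-only implementation. The recording backend is also an illustrative helper.

```python
import torch
import torch._dynamo

# Hypothetical namespace, separate from the notebook's "scale_custom_op".
demo_lib = torch.library.Library("demo_scale_ops", "DEF")
demo_lib.define("scale_by_max(Tensor input) -> Tensor")


def scale_by_max_cpu(input):
    return input * torch.max(input)


def scale_by_max_meta(input):
    # Shape/dtype-only kernel used while tracing with fake tensors.
    return torch.empty_like(input)


demo_lib.impl("scale_by_max", scale_by_max_cpu, "CPU")
demo_lib.impl("scale_by_max", scale_by_max_meta, "Meta")

graph_sources = []


def recording_backend(gm, example_inputs):
    graph_sources.append(gm.code)
    return gm.forward


def uses_custom_op(x):
    return torch.ops.demo_scale_ops.scale_by_max(x) + 1


torch._dynamo.reset()
compiled = torch.compile(uses_custom_op, backend=recording_backend)
out = compiled(torch.tensor([1.0, 2.0, 3.0]))
print(out)
print(graph_sources[0])
```

If the op is captured as expected, the printed graph calls `scale_by_max` directly rather than the separate `max` and `mul` ops it is built from.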


In [13]:
class LlamaWithCustomOp(nn.Module):
    def __init__(self, llama_model_name: str, output_dim: int):
        super(LlamaWithCustomOp, self).__init__()

        # Load the LLaMA model
        full_llama = AutoModel.from_pretrained(llama_model_name)

        # Extract and store the embedding layer
        self.embed_tokens = full_llama.embed_tokens

        # Extract and store the first decoder layer
        self.first_layer = full_llama.layers[0]

        # Linear layer to map to output dimensions
        llama_hidden_dim = full_llama.config.hidden_size
        self.linear = nn.Linear(llama_hidden_dim, output_dim)

        # Softmax for output probabilities
        self.softmax = nn.Softmax(dim=-1)

    # input_ids plus an optional callable that is applied to the logits before softmax
    def forward(self, input_ids, custom_forward_fn=None):
        # Generate embeddings
        embeddings = self.embed_tokens(input_ids)

        # Build explicit position_ids (0 .. seq_len - 1) for the decoder layer
        position_ids = torch.arange(0, input_ids.size(1), device=input_ids.device).unsqueeze(0)

        # Pass through the first layer with position_ids
        layer_output = self.first_layer(embeddings, position_ids=position_ids)[0]

        # Pool the output (mean along sequence dimension)
        pooled_output = torch.mean(layer_output, dim=1)

        # Map to output dimension
        logits = self.linear(pooled_output)

        if custom_forward_fn is not None:
            # Apply the custom operation
            custom_logits = custom_forward_fn(logits)
        else:
            custom_logits = logits

        # Apply softmax
        probs = self.softmax(custom_logits)

        return probs


In [14]:
# Initialize the model
torch._dynamo.reset()
model_w_custom_op = LlamaWithCustomOp(llama_model_name, output_dim).to("cuda")

# Step 1: Analyze with custom_forward_fn=None
print("=== Explanation for custom_forward_fn=None ===")
explanation_none = torch._dynamo.explain(model_w_custom_op)(input_ids, custom_forward_fn=None)
print(explanation_none)


[2024-12-28 16:22:37,730] [15/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward /tmp/ipykernel_414314/1905900009.py:22
[2024-12-28 16:22:37,731] [15/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_414314/1905900009.py:22
[2024-12-28 16:22:37,731] [15/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         def forward(self, input_ids, custom_forward_fn=None):
[2024-12-28 16:22:37,731] [15/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['input_ids'] (1, 7) [<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>] [None, None]
[2024-12-28 16:22:37,732] [15/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RESUME 0 []
[2024-12-28 16:22:37,732] [15/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_414314/1905900009.py:24
[2024-12-28 16:22:37,732] [15/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             embeddings = self.embed_token

=== Explanation for custom_forward_fn=None ===


[2024-12-28 16:22:37,929] [16/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR float32 [NullVariable, GetAttrVariable(TensorVariable(), to), TorchVariable(<module 'torch' from '/home/gaurav/anaconda3/lib/python3.11/site-packages/torch/__init__.py'>)]
[2024-12-28 16:22:37,930] [16/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE PRECALL 1 [NullVariable, GetAttrVariable(TensorVariable(), to), ConstantVariable(dtype)]
[2024-12-28 16:22:37,930] [16/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL 1 [NullVariable, GetAttrVariable(TensorVariable(), to), ConstantVariable(dtype)]
[2024-12-28 16:22:37,930] [16/0] torch._dynamo.output_graph.__trace_call: [DEBUG] TRACE FX call to from forward /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:122 (inline depth: 2)
[2024-12-28 16:22:37,930] [16/0] torch._dynamo.output_graph.__trace_call: [DEBUG]         hidden_states = hidden_states.to(torch.float32)
[2024-12-28 16:22:37,930] [16/0] torch

Graph Count: 6
Graph Break Count: 5
Op Count: 44
Break Reasons:
  Break Reason 1:
    User Stack:
      <FrameSummary file /tmp/ipykernel_414314/1905900009.py, line 30 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 734 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 405 in forward>
  Break Reason 2:
    User Stack:
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 734 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/mod

In [15]:
# Step 2: Analyze with custom_forward_fn=scale_by_max, the plain Python function
# (not torch.ops.scale_custom_op.scale_by_max); no reset here
print("\n=== Explanation for custom_forward_fn=custom_scale_fn ===")
explanation_custom = torch._dynamo.explain(model_w_custom_op)(input_ids, custom_forward_fn=scale_by_max)
print(explanation_custom)


[2024-12-28 16:23:43,457] [21/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward /tmp/ipykernel_414314/1905900009.py:22
[2024-12-28 16:23:43,458] [21/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_414314/1905900009.py:22
[2024-12-28 16:23:43,458] [21/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         def forward(self, input_ids, custom_forward_fn=None):
[2024-12-28 16:23:43,458] [21/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['input_ids'] (1, 7) [<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>] [None, None]
[2024-12-28 16:23:43,459] [21/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RESUME 0 []
[2024-12-28 16:23:43,459] [21/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_414314/1905900009.py:24
[2024-12-28 16:23:43,459] [21/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             embeddings = self.embed_token


=== Explanation for custom_forward_fn=custom_scale_fn ===


[2024-12-28 16:23:43,658] [22/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:695
[2024-12-28 16:23:43,659] [22/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:695
[2024-12-28 16:23:43,659] [22/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         def forward(
[2024-12-28 16:23:43,659] [22/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['hidden_states'] (1, 7, 2048) [<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>] [None, None, None]
[2024-12-28 16:23:43,660] [22/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['position_ids'] (1, 7) [<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>] [None, None]
[2024-12-28 16:23:43,661] [22/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE

Graph Count: 6
Graph Break Count: 5
Op Count: 46
Break Reasons:
  Break Reason 1:
    User Stack:
      <FrameSummary file /tmp/ipykernel_414314/1905900009.py, line 30 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 734 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 405 in forward>
  Break Reason 2:
    User Stack:
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 734 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/mod

In [16]:
# Wrap your model with the debug callback
model_custom_optimized = torch._dynamo.optimize(debug_callback)(model_w_custom_op)

# Run your model to trigger the tracing
model_custom_optimized(input_ids)

[2024-12-28 16:23:48,606] [27/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward /tmp/ipykernel_414314/1905900009.py:22
[2024-12-28 16:23:48,607] [27/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_414314/1905900009.py:22
[2024-12-28 16:23:48,607] [27/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         def forward(self, input_ids, custom_forward_fn=None):
[2024-12-28 16:23:48,607] [27/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['input_ids'] (1, 7) [<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>] [None, None]
[2024-12-28 16:23:48,608] [27/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RESUME 0 []
[2024-12-28 16:23:48,609] [27/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_414314/1905900009.py:24
[2024-12-28 16:23:48,609] [27/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             embeddings = self.embed_token




def forward(self, L_input_ids_ : torch.Tensor):
    l_input_ids_ = L_input_ids_
    l__self___embed_tokens = self.L__self___embed_tokens(l_input_ids_);  l_input_ids_ = None
    arange = torch.arange(0, 7, device = device(type='cuda', index=0))
    unsqueeze = arange.unsqueeze(0);  arange = None
    return (l__self___embed_tokens, unsqueeze)
    



def forward(self, L_hidden_states_ : torch.Tensor):
    l_hidden_states_ = L_hidden_states_
    to = l_hidden_states_.to(torch.float32);  l_hidden_states_ = None
    pow_1 = to.pow(2)
    mean = pow_1.mean(-1, keepdim = True);  pow_1 = None
    add = mean + 1e-05;  mean = None
    rsqrt = torch.rsqrt(add);  add = None
    mul = to * rsqrt;  to = rsqrt = None
    l__self___input_layernorm_weight = self.L__self___input_layernorm_weight
    to_1 = mul.to(torch.float32);  mul = None
    mul_1 = l__self___input_layernorm_weight * to_1;  l__self___input_layernorm_weight = to_1 = None
    return (mul_1,)
    


[2024-12-28 16:23:48,987] [28/0] torch._dynamo.guards.__guards: [DEBUG] str(G['torch'].float32) == 'torch.float32'                    # hidden_states = hidden_states.to(torch.float32)  # transformers/models/llama/modeling_llama.py:122 in forward
[2024-12-28 16:23:48,988] [28/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(G['__import_torch_dot_nn_dot_modules_dot_module']._global_forward_hooks, 8829024)  # if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks  # nn/modules/module.py:1524 in _call_impl
[2024-12-28 16:23:48,988] [28/0] torch._dynamo.guards.__guards: [DEBUG] set(G['__import_torch_dot_nn_dot_modules_dot_module']._global_forward_hooks.keys()) == set()  # if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks  # nn/modules/module.py:1524 in _call_impl
[2024-12-28 16:23:48,989] [28/0] torch._dynamo.guards.__guards: [DEBUG] ___check_type_id(G['__import_torch_dot_n




def forward(self, L_hidden_states_ : torch.Tensor):
    l_hidden_states_ = L_hidden_states_
    l__self___q_proj = self.L__self___q_proj(l_hidden_states_)
    l__self___k_proj = self.L__self___k_proj(l_hidden_states_)
    l__self___v_proj = self.L__self___v_proj(l_hidden_states_);  l_hidden_states_ = None
    view = l__self___q_proj.view(1, 7, 32, 64);  l__self___q_proj = None
    transpose = view.transpose(1, 2);  view = None
    view_1 = l__self___k_proj.view(1, 7, 8, 64);  l__self___k_proj = None
    transpose_1 = view_1.transpose(1, 2);  view_1 = None
    view_2 = l__self___v_proj.view(1, 7, 8, 64);  l__self___v_proj = None
    transpose_2 = view_2.transpose(1, 2);  view_2 = None
    return (transpose, transpose_1, transpose_2)
    


[2024-12-28 16:23:49,263] [30/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE UNARY_NEGATIVE None [NullVariable, TorchVariable(<built-in method cat of type object at 0x7bf5a531cde0>), TensorVariable()]
[2024-12-28 16:23:49,263] [30/0] torch._dynamo.output_graph.__trace_call: [DEBUG] TRACE FX call neg_1 from rotate_half /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:250 (inline depth: 2)
[2024-12-28 16:23:49,263] [30/0] torch._dynamo.output_graph.__trace_call: [DEBUG]     return torch.cat((-x2, x1), dim=-1)
[2024-12-28 16:23:49,263] [30/0] torch._dynamo.output_graph.__trace_call: [DEBUG]                       ^^^
[2024-12-28 16:23:49,264] [30/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x1 [NullVariable, TorchVariable(<built-in method cat of type object at 0x7bf5a531cde0>), TensorVariable()]
[2024-12-28 16:23:49,264] [30/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE BUILD_TUPLE 2 [NullVariable, TorchVariable(<built-in me




def forward(self, L_position_ids_ : torch.Tensor, L_query_states_ : torch.Tensor, L_key_states_ : torch.Tensor, L_value_states_ : torch.Tensor):
    l_position_ids_ = L_position_ids_
    l_query_states_ = L_query_states_
    l_key_states_ = L_key_states_
    l_value_states_ = L_value_states_
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    l__self___rotary_emb_inv_freq = self.L__self___rotary_emb_inv_freq
    getitem = l__self___rotary_emb_inv_freq[(None, slice(None, None, None), None)];  l__self___rotary_emb_inv_freq = None
    float_1 = getitem.float();  getitem = None
    expand = float_1.expand(1, -1, 1);  float_1 = None
    getitem_1 = l_position_ids_[(slice(None, None, None), None, slice(None, None, None))];  l_position_ids_ = None
    float_2 = getitem_1.float();  getitem_1 = None
    _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', None, False, None)
    float_3 = expand.float();  expand = None
    float_4 = float_2.float();  float_2 = None
    

tensor([[0.0991, 0.0959, 0.1000, 0.0948, 0.1064, 0.0984, 0.0997, 0.1023, 0.1015,
         0.1020]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

### Understanding the TorchDynamo behavior further

As its implementation shows, the custom op `scale_by_max` is registered only for the CPU backend.

So the next thing I wanted to understand was what happens when I move the model to a CUDA device (an Nvidia GPU) and `TorchDynamo` encounters a custom op that has only a CPU implementation.

The next few cells look at this in detail.
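To make the question concrete outside the model, here is a minimal sketch of how the dispatcher treats an op that has a kernel for the CPU key only. The namespace `demo_cpu_only` and the op name `scale_by_max_demo` are hypothetical, chosen so they do not collide with the op defined in this notebook. A `meta` tensor stands in for a non-CPU device so the sketch runs without a GPU; I would expect a CUDA tensor to fail in the same way, but that is an assumption, not something shown here.

```python
import torch

# Hypothetical namespace and op name, used only for this illustration.
def_lib = torch.library.Library("demo_cpu_only", "DEF")
def_lib.define("scale_by_max_demo(Tensor input) -> Tensor")


def scale_by_max_demo(input: torch.Tensor) -> torch.Tensor:
    # Scale every element by the maximum value of the tensor.
    return input * torch.max(input)


# Register a kernel for the CPU dispatch key only.
impl_lib = torch.library.Library("demo_cpu_only", "IMPL")
impl_lib.impl("scale_by_max_demo", scale_by_max_demo, "CPU")

# A CPU tensor finds the registered kernel.
cpu_out = torch.ops.demo_cpu_only.scale_by_max_demo(torch.tensor([1.0, 2.0, 3.0]))

# A tensor on another backend has no kernel to dispatch to.
try:
    torch.ops.demo_cpu_only.scale_by_max_demo(torch.empty(3, device="meta"))
    non_cpu_failed = False
except NotImplementedError as err:
    non_cpu_failed = True
    print(type(err).__name__)
```

The point of the sketch is that the failure comes from eager dispatch at call time, which is why it matters whether TorchDynamo actually executes the op or only records it.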

In [12]:
# 1. TorchDynamo captures graphs instead of executing operations directly:
# When we wrap the model with torch._dynamo.optimize(debug_callback), TorchDynamo intercepts the Python bytecode and traces the computation graph.
# During this tracing phase, TorchDynamo does not execute the operations on real data. Instead, it records them (including any custom operator) into an FX graph.
# Because the actual execution of scale_by_max is deferred, no error is triggered at this stage.
#
# Decorating a function with `@torch._dynamo.disable` forces TorchDynamo to insert a graph break at the `custom_forward_fn` call.

In [17]:
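import os
import torch

# Enable only graph-break logging via TORCH_LOGS.
# Note: TORCH_LOGS appears to be read when torch is imported, so this
# assignment may not change logging in an already-running kernel.
os.environ["TORCH_LOGS"] = "+graph_breaks"

# Step 1: Create a library (namespace) that will hold the op definition
custom_lib = torch.library.Library("scale_op_eager", "DEF")

# Step 2: Define the custom op schema
custom_lib.define("scale_by_max_eager(Tensor input) -> Tensor")

# Step 3: Implement the op; torch._dynamo.disable keeps TorchDynamo from tracing into it
@torch._dynamo.disable
def scale_by_max_eager(input: torch.Tensor) -> torch.Tensor:
    max_value = torch.max(input)
    return input * max_value

# Step 4: Use IMPL to register the implementation for the CPU backend only
impl_lib = torch.library.Library("scale_op_eager", "IMPL")
impl_lib.impl("scale_by_max_eager", scale_by_max_eager, "CPU")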

In [18]:
# Explain the model when the Dynamo-disabled (eager) custom op is passed as `custom_forward_fn`
explanation_custom = torch._dynamo.explain(model_w_custom_op)(input_ids, custom_forward_fn=scale_by_max_eager)
print(explanation_custom)


[2024-12-28 16:24:08,900] [33/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward /tmp/ipykernel_414314/1905900009.py:22
[2024-12-28 16:24:08,901] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_414314/1905900009.py:22
[2024-12-28 16:24:08,901] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         def forward(self, input_ids, custom_forward_fn=None):
[2024-12-28 16:24:08,901] [33/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['input_ids'] (1, 7) [<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>] [None, None]
[2024-12-28 16:24:08,902] [33/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RESUME 0 []
[2024-12-28 16:24:08,902] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line forward /tmp/ipykernel_414314/1905900009.py:24
[2024-12-28 16:24:08,902] [33/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             embeddings = self.embed_token

Graph Count: 7
Graph Break Count: 6
Op Count: 44
Break Reasons:
  Break Reason 1:
    User Stack:
      <FrameSummary file /tmp/ipykernel_414314/1905900009.py, line 30 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 734 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1527 in _call_impl>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 405 in forward>
  Break Reason 2:
    User Stack:
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py, line 734 in forward>
      <FrameSummary file /home/gaurav/anaconda3/lib/python3.11/site-packages/torch/nn/modules/mod

##### Capturing expected output for Graph break at `custom_forward_fn`

[2024-12-28 16:24:09,795] [38/0] torch._dynamo.output_graph: [DEBUG] COMPILING GRAPH due to GraphCompileReason(reason='call torch._dynamo.disable() wrapped function <function scale_by_max_eager at 0x7bf4c59de0c0>', user_stack=[<FrameSummary file /tmp/ipykernel_414314/1905900009.py, line 40 in <resume in forward>>], graph_break=True)