In [1]:
import os

# Debug/logging switches for torch.compile. They are set before `import torch`
# below — presumably so torch picks them up when it initialises its logging
# (TODO confirm whether setting them after the import would still take effect).
os.environ["TORCH_COMPILE_DEBUG"] = "1"  # Dumps files in `torch_compile_debug/`

# Choose which logs to enable
# NOTE(review): the captured output of the first compiled call (cell 4) shows
# `torch._dynamo.output_graph.__graph` entries, which look like they came from
# the `+graph` setting in the commented-out line rather than the active one —
# confirm which TORCH_LOGS value produced the saved output.
# os.environ["TORCH_LOGS"] = "+dynamo,+aot_graphs,+inductor,+guards,+graph"
os.environ["TORCH_LOGS"] = "+inductor,+aot_graphs"

import torch
import torch.nn as nn
# NOTE(review): `optimize` is not used anywhere in this notebook section.
from torch._dynamo import optimize

In [2]:
# Eager reference model: the module that torch.compile traces below.
class SimpleMLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: Size of the last dimension of the input.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output values produced per row.
    """

    def __init__(self, input_size=32, hidden_size=64, num_classes=10):
        super().__init__()
        # Registration order (fc1, relu, fc2) also fixes the order in which the
        # two Linear layers draw their random initial weights.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [3]:
# 1. Build the eager model and a dummy input batch holding a single sample.
model = SimpleMLP(input_size=32, hidden_size=64, num_classes=10)
input_data = torch.randn(1, model.fc1.in_features)

# 2. Wrap the model with torch.compile to enable JIT compilation.
#    TorchDynamo captures the graph; backend="inductor" hands that graph to
#    TorchInductor, which performs the actual code generation. Nothing is
#    traced yet — that happens on the first call (next cell).
compiled_model = torch.compile(model, backend="inductor")

In [4]:
# 3. Dynamic graph capture (TorchDynamo).
# The first call with a given input signature is what triggers compilation:
# TorchDynamo intercepts the execution and builds a graph of the operations
# (printed in the logs below), which Inductor then compiles. Later calls with
# the same signature reuse that compiled code.
output = compiled_model(input_data)
print("Output from the compiled model:", output)

[2024-12-20 23:50:49,981] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-12-20 23:50:49,981] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]  __compiled_fn_0 <eval_with_key>.0 opcode       name            target          args                 kwargs
[2024-12-20 23:50:49,981] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] -----------  --------------  --------------  -------------------  --------
[2024-12-20 23:50:49,981] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder  l_x_            L_x_            ()                   {}
[2024-12-20 23:50:49,981] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] call_module  l__self___fc1   L__self___fc1   (l_x_,)              {}
[2024-12-20 23:50:49,981] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] call_module  l__self___relu  L__self___relu  (l__self___fc1,)     {}
[2024-12-20 23:50:49,981] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] call_module  l__self___fc2   L__self___fc2   (l__self___relu,)   

Output from the compiled model: tensor([[ 0.2117,  0.1899,  0.0066, -0.1359, -0.1234, -0.0568, -0.2092, -0.0407,
         -0.0576,  0.1012]], grad_fn=<CompiledFunctionBackward>)


### Generating Python code from the `aot_graphs` log output

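Capture the logs produced by enabling `aot_graphs` in `TORCH_LOGS` and feed them to the Gemini 2.0 Thinking model and ChatGPT to translate the traced forward and backward graphs into plain PyTorch code.
Then compare the resulting low-level `ForwardGraph` and `BackwardGraph` against the eager `SimpleMLP` to verify that the translated code is correct.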

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ForwardGraph(nn.Module):
    """Hand-written equivalent of the AOTAutograd forward graph of ``SimpleMLP``.

    In the traced graph the weights arrive as inputs ``primals_1``..``primals_4``
    and the data as ``primals_5``. Here the weights are held as parameters, so
    the module is called with the data alone. Besides the output, ``forward``
    returns the tensors that ``BackwardGraph`` consumes.

    Args:
        input_size: Number of input features (fc1 in-features).
        hidden_size: Width of the hidden layer.
        num_classes: Number of outputs (fc2 out-features).
    """

    def __init__(self, input_size=32, hidden_size=64, num_classes=10):
        super().__init__()
        # Assignment order mirrors the primals numbering of the traced graph.
        self.fc1_weight = nn.Parameter(torch.empty(hidden_size, input_size))   # primals_1
        self.fc1_bias = nn.Parameter(torch.empty(hidden_size))                 # primals_2
        self.fc2_weight = nn.Parameter(torch.empty(num_classes, hidden_size))  # primals_3
        self.fc2_bias = nn.Parameter(torch.empty(num_classes))                 # primals_4
        # torch.empty leaves the memory uninitialised; give the parameters
        # well-defined values so the module is usable as constructed.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every weight/bias pair exactly as ``nn.Linear`` does.

        Weights use Kaiming-uniform with ``a=sqrt(5)``. Biases are drawn
        uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
        """
        import math

        for weight, bias in ((self.fc1_weight, self.fc1_bias),
                             (self.fc2_weight, self.fc2_bias)):
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            fan_in = weight.shape[1]
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(bias, -bound, bound)

    def forward(self, primals_5):
        """Compute the MLP output plus the activations saved for backward.

        Args:
            primals_5: 2-D input batch ``(batch, input_size)``; ``torch.addmm``
                requires a 2-D matrix here.

        Returns:
            Tuple ``(addmm_1, primals_5, relu, permute_2)``, which contains:

            - the output, of shape ``(batch, num_classes)``;
            - the unchanged input;
            - the hidden activation, of shape ``(batch, hidden_size)``;
            - the fc2 weight, of shape ``(num_classes, hidden_size)``.
        """
        # SimpleMLP.forward: out = self.fc1(x)  ->  x @ W1^T + b1
        permute = torch.permute(self.fc1_weight, [1, 0])
        addmm = torch.addmm(self.fc1_bias, primals_5, permute)

        # SimpleMLP.forward: out = self.relu(out)
        relu = torch.relu(addmm)

        # SimpleMLP.forward: out = self.fc2(out)  ->  relu @ W2^T + b2
        permute_1 = torch.permute(self.fc2_weight, [1, 0])
        addmm_1 = torch.addmm(self.fc2_bias, relu, permute_1)
        # Transposing back yields W2 itself; it is returned for the backward graph.
        permute_2 = torch.permute(permute_1, [1, 0])

        return addmm_1, primals_5, relu, permute_2

class BackwardGraph(nn.Module):
    """Hand-written equivalent of the AOTAutograd backward graph of ``SimpleMLP``.

    Consumes the tensors returned by ``ForwardGraph`` plus the upstream
    gradient, and returns gradients in primals order:
    ``(d fc1_weight, d fc1_bias, d fc2_weight, d fc2_bias, d input)``.
    The input gradient is always ``None``. Presumably this is because the traced
    input did not require grad (TODO confirm).
    """

    def __init__(self):
        super().__init__()

    def forward(self, primals_5, relu, permute_2, tangents_1):
        """Backpropagate ``tangents_1`` through fc2, ReLU and fc1.

        Args:
            primals_5: Forward input, ``(batch, input_size)``.
            relu: Hidden activation saved by the forward graph, ``(batch, hidden_size)``.
            permute_2: fc2 weight, ``(num_classes, hidden_size)``.
            tangents_1: Gradient of the loss w.r.t. the output, ``(batch, num_classes)``.

        Returns:
            5-tuple of gradients in primals order; the last entry is ``None``.
        """
        # fc2 backward (SimpleMLP.forward: out = self.fc2(out))
        mm = torch.mm(tangents_1, permute_2)                    # dL/d relu
        permute_3 = torch.permute(tangents_1, [1, 0])
        mm_1 = torch.mm(permute_3, relu)                        # dL/d W2, (num_classes, hidden)
        permute_4 = torch.permute(mm_1, [1, 0])
        sum_1 = torch.sum(tangents_1, dim=[0], keepdim=True)
        # The traced graph reshapes with a hard-coded size ([10]); flattening with
        # -1 yields the same (num_classes,) bias gradient for any layer width.
        view = sum_1.view(-1)                                   # dL/d b2
        permute_5 = torch.permute(permute_4, [1, 0])

        # ReLU backward (SimpleMLP.forward: out = self.relu(out)): the gradient
        # is zeroed wherever the activation was not positive.
        le = torch.le(relu, 0)
        full_default = torch.full([], 0.0, dtype=tangents_1.dtype, layout=torch.strided, device=tangents_1.device, pin_memory=False)
        where = torch.where(le, full_default, mm)               # dL/d (fc1 output)

        # fc1 backward (SimpleMLP.forward: out = self.fc1(x))
        permute_6 = torch.permute(where, [1, 0])
        mm_2 = torch.mm(permute_6, primals_5)                   # dL/d W1, (hidden, input_size)
        permute_7 = torch.permute(mm_2, [1, 0])
        sum_2 = torch.sum(where, dim=[0], keepdim=True)
        view_1 = sum_2.view(-1)                                 # dL/d b1, (hidden,)
        permute_8 = torch.permute(permute_7, [1, 0])

        return permute_8, view_1, permute_5, view, None

# Example usage: run the forward graph on a random input, then push a random
# upstream gradient through the backward graph.
model = ForwardGraph()
# Overwrite every parameter with standard-normal values so the example does not
# depend on how ForwardGraph initialises them. parameters() follows assignment
# order: fc1_weight, fc1_bias, fc2_weight, fc2_bias.
with torch.no_grad():
    for param in model.parameters():
        param.copy_(torch.randn_like(param))

x = torch.randn(1, 32)
output, primals_5_saved, relu_saved, permute_2_saved = model(x)

# Stand-in for dL/d(output), e.g. what a loss function would produce.
output_gradients = torch.randn_like(output)

backward_graph = BackwardGraph()
gradients = backward_graph(primals_5_saved, relu_saved, permute_2_saved, output_gradients)

# Gradients come back in parameter order followed by the input gradient (None);
# check that each one has the shape of the parameter it belongs to.
*param_gradients, input_gradient = gradients
assert input_gradient is None, "expected no gradient for the input"
params = list(model.parameters())
assert len(param_gradients) == len(params), "one gradient per parameter expected"
for param, grad in zip(params, param_gradients):
    assert grad.shape == param.shape, f"gradient shape {tuple(grad.shape)} != parameter shape {tuple(param.shape)}"



In [6]:
# 1. Build an eager SimpleMLP alongside the hand-written low-level graphs.
simple_model = SimpleMLP()
low_level_forward = ForwardGraph()
low_level_backward = BackwardGraph()

# 2. Copy the eager weights into the forward graph so both compute the same
#    function (each pair below is destination, source).
with torch.no_grad():
    for dst, src in (
        (low_level_forward.fc1_weight, simple_model.fc1.weight),
        (low_level_forward.fc1_bias, simple_model.fc1.bias),
        (low_level_forward.fc2_weight, simple_model.fc2.weight),
        (low_level_forward.fc2_bias, simple_model.fc2.bias),
    ):
        dst.copy_(src)

# requires_grad=True so autograd records the computations on this input.
input_tensor = torch.randn(1, 32, requires_grad=True)

# 3. Compare forward pass outputs.
output_simple = simple_model(input_tensor)
# Keep dL/d(output) after backward() so it can be fed to the low-level backward graph.
output_simple.retain_grad()
output_low_level, primals_5_saved, relu_saved, permute_2_saved = low_level_forward(input_tensor)
assert torch.allclose(output_simple, output_low_level, atol=1e-5), "Forward pass outputs do not match!"
print("Forward pass outputs match!")

Forward pass outputs match!


In [7]:
# 4. Compare backward pass gradients.
# Mean-squared error against a random target gives a scalar loss to backprop.
loss_fn = nn.MSELoss()
target = torch.randn_like(output_simple)
loss_simple = loss_fn(output_simple, target)

# Reset every gradient this cell produces so re-running it does not accumulate.
# zero_grad() clears the parameter gradients. output_simple.grad (kept through
# retain_grad() in the previous cell) is a non-leaf gradient that zero_grad()
# does not touch, so it is cleared explicitly. Otherwise it would double on a
# re-run and no longer match the freshly zeroed parameter gradients.
simple_model.zero_grad()
output_simple.grad = None

# retain_graph=True keeps the graph behind output_simple alive. That graph was
# built in the previous cell, so freeing it here would make a second run of
# this cell fail to backpropagate through it again.
loss_simple.backward(retain_graph=True)

print(output_simple.grad)

tensor([[-0.2062, -0.2366,  0.1161, -0.0224, -0.1292, -0.2070, -0.1095,  0.0493,
          0.2279,  0.3504]])


In [8]:
# Gradients computed by autograd for the eager SimpleMLP in the previous cell.
grad_fc1_weight_simple = simple_model.fc1.weight.grad
grad_fc1_bias_simple = simple_model.fc1.bias.grad
grad_fc2_weight_simple = simple_model.fc2.weight.grad
grad_fc2_bias_simple = simple_model.fc2.bias.grad

# Reuse the upstream gradient dL/d(output), captured via retain_grad(), as the
# tangent input of the low-level backward pass.
output_gradients_low_level = output_simple.grad

# Run the low-level backward pass. Gradients come back in primals order and the
# trailing input gradient is None.
gradients_low_level = low_level_backward(primals_5_saved, relu_saved, permute_2_saved, output_gradients_low_level)
grad_fc1_weight_low_level, grad_fc1_bias_low_level, grad_fc2_weight_low_level, grad_fc2_bias_low_level, _ = gradients_low_level

print(f"grad_fc1_weight_simple = {grad_fc1_weight_simple.shape}")
print(f"grad_fc1_weight_low_level = {grad_fc1_weight_low_level.shape}")
print(f"grad_fc2_weight_simple = {grad_fc2_weight_simple.shape}")
print(f"grad_fc2_weight_low_level = {grad_fc2_weight_low_level.shape}")


def _assert_grads_match(expected, actual, message, atol=1e-5):
    """Assert that two gradients have the same shape and element-wise close values.

    torch.allclose broadcasts its inputs, so on its own it would not catch a
    shape mismatch such as ``(64,)`` vs ``(1, 64)``. The explicit shape check
    covers that case.
    """
    assert expected.shape == actual.shape, (
        f"{message} (shape {tuple(expected.shape)} vs {tuple(actual.shape)})"
    )
    assert torch.allclose(expected, actual, atol=atol), message


# Compare gradients. Both fc2 weight gradients are (num_classes, hidden_size),
# so they are compared directly without any transpose.
_assert_grads_match(grad_fc1_weight_simple, grad_fc1_weight_low_level, "fc1 weight gradients do not match!")
_assert_grads_match(grad_fc1_bias_simple, grad_fc1_bias_low_level, "fc1 bias gradients do not match!")
_assert_grads_match(grad_fc2_weight_simple, grad_fc2_weight_low_level, "fc2 weight gradients do not match!")
_assert_grads_match(grad_fc2_bias_simple, grad_fc2_bias_low_level, "fc2 bias gradients do not match!")

print("Backward pass gradients match!")

grad_fc1_weight_simple = torch.Size([64, 32])
grad_fc1_weight_low_level = torch.Size([64, 32])
grad_fc2_weight_simple = torch.Size([10, 64])
grad_fc2_weight_low_level = torch.Size([10, 64])
Backward pass gradients match!
