# PyTorch Graph Optimizer Essentials

**Goal**: Learn ONLY what you need to build a graph optimizer using torch.fx

**Time**: ~30 minutes

**Environment**: M2 Mac with MPS support

In [1]:
import torch
import torch.nn as nn
import torch.fx as fx
from typing import Dict

# Set up MPS device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


## 1. COMPUTATIONAL GRAPHS (5 min)

### What is a Computation Graph?

A computation graph is a directed acyclic graph (DAG) where:
- **Nodes** = operations (add, multiply, conv2d, relu, etc.)
- **Edges** = data flowing between operations (tensors)
- **Graph** = complete sequence of operations from inputs to outputs

PyTorch builds this graph **dynamically** during the forward pass. When you compute `z = x + y` on tensors that require gradients, PyTorch records a node for the addition and links it to the nodes that produced `x` and `y`.

### How PyTorch Builds Graphs During Forward Pass

When `requires_grad=True`, PyTorch builds an **autograd graph** to track operations for backpropagation. Each tensor remembers the operation that created it via `.grad_fn`. This is NOT the same as torch.fx's symbolic graph (we'll get to that), but it's the foundation.

In [2]:
# Simple example: autograd graph
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)
z = x * y  # multiplication node
w = z + x  # addition node

print(f"z.grad_fn: {z.grad_fn}")  # Shows the operation that created z
print(f"w.grad_fn: {w.grad_fn}")  # Shows the operation that created w

z.grad_fn: <MulBackward0 object at 0x10adc8ca0>
w.grad_fn: <AddBackward0 object at 0x1127d6cb0>
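
`.grad_fn` only shows the last operation. As a rough sketch of how the edges can be followed, each `grad_fn` exposes `next_functions`, which points at the nodes that produced its inputs. The helper name `collect_autograd_nodes` is made up for this illustration, and the tensors live on the CPU to keep the snippet self-contained.

```python
import torch


def collect_autograd_nodes(fn, depth=0, out=None):
    """Walk the autograd graph backward from `fn`, recording (depth, node type name)."""
    if out is None:
        out = []
    if fn is None:  # inputs that do not require grad have no node
        return out
    out.append((depth, type(fn).__name__))
    for next_fn, _input_index in fn.next_functions:
        collect_autograd_nodes(next_fn, depth + 1, out)
    return out


x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)
w = x * y + x  # same computation as above: multiply, then add

autograd_nodes = collect_autograd_nodes(w.grad_fn)
for depth, name in autograd_nodes:
    print("  " * depth + name)
```

Leaf tensors such as `x` and `y` show up as `AccumulateGrad` nodes, which is where gradients are written into `.grad` during `backward()`.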


### Simple Example: 2-Layer MLP Graph

Let's build a tiny 2-layer MLP and understand what graph PyTorch conceptually builds. The graph has: Input → Linear1 → ReLU → Linear2 → Output.

In [3]:
class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)
    
    def forward(self, x):
        # Graph: x → fc1 → relu → fc2 → out
        x = self.fc1(x)      # Node 1: Linear transformation
        x = torch.relu(x)    # Node 2: ReLU activation
        x = self.fc2(x)      # Node 3: Linear transformation
        return x

model = SimpleMLP().to(device)
input_tensor = torch.randn(1, 10, device=device)
output = model(input_tensor)

print(f"Input shape: {input_tensor.shape}")
print(f"Output shape: {output.shape}")
print(f"\nConceptual graph: Input(1,10) → Linear(10→20) → ReLU → Linear(20→5) → Output(1,5)")

Input shape: torch.Size([1, 10])
Output shape: torch.Size([1, 5])

Conceptual graph: Input(1,10) → Linear(10→20) → ReLU → Linear(20→5) → Output(1,5)


### Exercise 1: Create a Tiny Model and Explain Its Graph

**Task**: Create a model with Conv2d → BatchNorm2d → ReLU → MaxPool2d. Describe the computation graph in words.

In [6]:
# Your code here
class TinyConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        # YOUR CODE
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
    
    def forward(self, x):
        # YOUR CODE
        print(f"Input:          {x.shape}") 
        
        x = self.conv1(x)
        print(f"After Conv1:    {x.shape}")

        x = self.bn1(x)  # BatchNorm2d keeps the shape, so nothing extra is printed
        
        x = self.relu1(x)
        print(f"After ReLU:     {x.shape}")  
        
        x = self.pool1(x)
        print(f"After Pool1:    {x.shape}")  
        
        return x

conv_model = TinyConvNet().to(device)
test_input = torch.randn(1, 3, 32, 32, device=device)
test_output = conv_model(test_input)

Input:          torch.Size([1, 3, 32, 32])
After Conv1:    torch.Size([1, 16, 32, 32])
After ReLU:     torch.Size([1, 16, 32, 32])
After Pool1:    torch.Size([1, 16, 16, 16])


In [5]:
# SOLUTION
class TinyConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.pool = nn.MaxPool2d(2, 2)
    
    def forward(self, x):
        # Graph: x → conv2d → batch_norm → relu → max_pool → out
        x = self.conv(x)      # Node 1: Convolution
        x = self.bn(x)        # Node 2: Batch normalization
        x = torch.relu(x)     # Node 3: ReLU activation
        x = self.pool(x)      # Node 4: Max pooling
        return x

conv_model = TinyConvNet().to(device)
test_input = torch.randn(1, 3, 32, 32, device=device)
test_output = conv_model(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")
print(f"\nGraph: Input(1,3,32,32) → Conv2d → BatchNorm2d → ReLU → MaxPool2d → Output(1,16,16,16)")

# Verify output shape
assert test_output.shape == torch.Size([1, 16, 16, 16]), "Output shape mismatch!"

Input shape: torch.Size([1, 3, 32, 32])
Output shape: torch.Size([1, 16, 16, 16])

Graph: Input(1,3,32,32) → Conv2d → BatchNorm2d → ReLU → MaxPool2d → Output(1,16,16,16)


---

## 2. TORCH.FX BASICS (10 min)

### What is Symbolic Tracing?

**Symbolic tracing** captures the operations of your model into an explicit, inspectable graph WITHOUT actually executing them with real data. torch.fx runs your model with "symbolic" inputs (Proxy objects) and records every operation into a `GraphModule`. This gives you a **static representation** of the computation graph you can modify before execution.
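
To make the "symbolic inputs" idea concrete, here is a small sketch that records what `forward` actually receives during tracing. The `Probe` module and the `seen_types` list are hypothetical names used only for this illustration.

```python
import torch
import torch.nn as nn
import torch.fx as fx

seen_types = []  # filled in as a side effect each time forward() runs


class Probe(nn.Module):
    def forward(self, x):
        seen_types.append(type(x).__name__)  # plain Python, not a tensor op
        return torch.relu(x) + 1


gm = fx.symbolic_trace(Probe())
print(seen_types)  # forward() ran once, with a Proxy instead of a tensor

# The GraphModule runs generated code that contains only the recorded tensor ops
result = gm(torch.tensor([-1.0, 2.0]))
print(result)
print(seen_types)  # unchanged: the append() was never part of the graph
```

One consequence worth noting: Python side effects inside `forward` happen once at trace time and are not replayed when the traced module runs.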

### How torch.fx Differs from Regular PyTorch

- **Regular PyTorch**: Dynamic execution, operations run immediately, graph exists only for autograd
- **torch.fx**: Symbolic execution, captures operations into a graph IR, lets you inspect/modify before running
- **Key benefit**: You can analyze and transform the graph structure programmatically

In [7]:
# Example: Trace the SimpleMLP
mlp = SimpleMLP().to(device)

# Symbolic tracing: forward() runs on Proxy inputs, so ops are recorded rather than computed
traced_mlp = fx.symbolic_trace(mlp)

print("=== Traced Graph ===")
print(traced_mlp.graph)

print("\n=== Generated Python Code ===")
print(traced_mlp.code)

# Verify it still works
test_input = torch.randn(1, 10, device=device)
original_output = mlp(test_input)
traced_output = traced_mlp(test_input)

assert torch.allclose(original_output, traced_output), "Traced model output differs!"
print("\n✓ Traced model produces identical output")

=== Traced Graph ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %fc1 : [num_users=1] = call_module[target=fc1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%fc1,), kwargs = {})
    %fc2 : [num_users=1] = call_module[target=fc2](args = (%relu,), kwargs = {})
    return fc2

=== Generated Python Code ===



def forward(self, x):
    fc1 = self.fc1(x);  x = None
    relu = torch.relu(fc1);  fc1 = None
    fc2 = self.fc2(relu);  relu = None
    return fc2
    

✓ Traced model produces identical output


### Exercise 2: Trace a CNN and Inspect Its Nodes

**Task**: Create a simple CNN with Conv2d → ReLU → Conv2d → ReLU. Trace it and print:
1. The complete graph
2. The number of nodes
3. The generated Python code

In [8]:
# Your code here
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # YOUR CODE
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
    def forward(self, x):
        # YOUR CODE
        x = self.conv1(x) 
        x = torch.relu(x)
        x = self.conv2(x) 
        x = torch.relu(x)
        return x

my_cnn = SimpleCNN().to(device)  
trace_cnn = fx.symbolic_trace(my_cnn) 

print("=== Traced Graph ===")
print(trace_cnn.graph)

print("\n=== Generated Python Code ===")
print(trace_cnn.code)

print(f"\n=== Number of Nodes: {len(trace_cnn.graph.nodes)} ===")

=== Traced Graph ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%conv1,), kwargs = {})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
    %relu_1 : [num_users=1] = call_function[target=torch.relu](args = (%conv2,), kwargs = {})
    return relu_1

=== Generated Python Code ===



def forward(self, x):
    conv1 = self.conv1(x);  x = None
    relu = torch.relu(conv1);  conv1 = None
    conv2 = self.conv2(relu);  relu = None
    relu_1 = torch.relu(conv2);  conv2 = None
    return relu_1
    

=== Number of Nodes: 6 ===


In [9]:
# SOLUTION
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = torch.relu(x)
        return x

cnn = SimpleCNN().to(device)
traced_cnn = fx.symbolic_trace(cnn)

print("=== Graph ===")
print(traced_cnn.graph)

print(f"\n=== Number of Nodes: {len(traced_cnn.graph.nodes)} ===")

print("\n=== Generated Code ===")
print(traced_cnn.code)

# Verify
test_input = torch.randn(1, 3, 32, 32, device=device)
original = cnn(test_input)
traced = traced_cnn(test_input)
assert torch.allclose(original, traced), "Outputs differ!"
print("\n✓ Traced CNN works correctly")

=== Graph ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%conv1,), kwargs = {})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
    %relu_1 : [num_users=1] = call_function[target=torch.relu](args = (%conv2,), kwargs = {})
    return relu_1

=== Number of Nodes: 6 ===

=== Generated Code ===



def forward(self, x):
    conv1 = self.conv1(x);  x = None
    relu = torch.relu(conv1);  conv1 = None
    conv2 = self.conv2(relu);  relu = None
    relu_1 = torch.relu(conv2);  conv2 = None
    return relu_1
    

✓ Traced CNN works correctly


---

## 3. GRAPH INSPECTION (10 min)

### How to Iterate Over Graph Nodes

Every `GraphModule` has a `.graph` attribute containing nodes. You iterate with `graph.nodes`. Each node has properties: `op` (operation type), `target` (what's being called), `args` (inputs), and `name` (unique identifier).

In [10]:
# Iterate over nodes in traced MLP
print("=== Node Iteration ===")
for node in traced_mlp.graph.nodes:
    print(f"Node: {node.name:15} | Op: {node.op:15} | Target: {node.target}")

=== Node Iteration ===
Node: x               | Op: placeholder     | Target: x
Node: fc1             | Op: call_module     | Target: fc1
Node: relu            | Op: call_function   | Target: <built-in method relu of type object at 0x113f135f8>
Node: fc2             | Op: call_module     | Target: fc2
Node: output          | Op: output          | Target: output


### Understanding Node Types, Inputs, and Outputs

**Node types (op field)**:
- `placeholder`: Input to the graph
- `call_module`: Calls an nn.Module (e.g., nn.Linear)
- `call_function`: Calls a function (e.g., torch.relu)
- `call_method`: Calls a tensor method (e.g., .view())
- `get_attr`: Reads a parameter or buffer from the module hierarchy
- `output`: Output of the graph

**Node connections**:
- `node.args`: Tuple of positional inputs (other nodes, plus any constants); `node.kwargs` holds keyword inputs
- `node.users`: Dict of nodes that use this node's output
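
The MLP above only exercises three of the op types. As a sketch, the hypothetical `Mixed` module below is written so that tracing it should also produce a `call_method` node (from `.view()`) and a `get_attr` node (FX's op for reading a parameter or buffer directly), and the nodes are then tallied by `op`.

```python
from collections import Counter

import torch
import torch.nn as nn
import torch.fx as fx


class Mixed(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(4))
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = x.view(-1, 4)     # call_method (target is the string "view")
        x = x * self.scale    # get_attr for self.scale, call_function for the multiply
        x = self.fc(x)        # call_module
        return torch.relu(x)  # call_function


gm = fx.symbolic_trace(Mixed())
op_counts = Counter(node.op for node in gm.graph.nodes)

for op, count in sorted(op_counts.items()):
    print(f"{op:15} {count}")
```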

In [11]:
# Detailed node inspection
print("=== Detailed Node Information ===")
for node in traced_mlp.graph.nodes:
    print(f"\nNode: {node.name}")
    print(f"  Op type: {node.op}")
    print(f"  Target: {node.target}")
    print(f"  Args: {node.args}")
    print(f"  Users: {list(node.users.keys())}")

=== Detailed Node Information ===

Node: x
  Op type: placeholder
  Target: x
  Args: ()
  Users: [fc1]

Node: fc1
  Op type: call_module
  Target: fc1
  Args: (x,)
  Users: [relu]

Node: relu
  Op type: call_function
  Target: <built-in method relu of type object at 0x113f135f8>
  Args: (fc1,)
  Users: [fc2]

Node: fc2
  Op type: call_module
  Target: fc2
  Args: (relu,)
  Users: [output]

Node: output
  Op type: output
  Target: output
  Args: (fc2,)
  Users: []


### Understanding Graph Structure

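The graph stores its nodes as a linked list in execution (topological) order. From any node you can traverse forward (via `.users`) or backward (via `.args`, or `.all_input_nodes` when you only want the inputs that are nodes). This structure lets you analyze data flow and dependencies.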

In [13]:
# Find all ReLU operations
print("=== Finding ReLU Nodes ===")
relu_nodes = []
for node in traced_cnn.graph.nodes:
    if node.op == 'call_function' and node.target == torch.relu:
        relu_nodes.append(node)
        print(f"Found ReLU: {node.name}")
        print(f"  Input from: {node.args[0].name if node.args else 'None'}")
        print(f"  Used by: {[user.name for user in node.users.keys()]}")

print(f"\nTotal ReLU nodes: {len(relu_nodes)}")
assert len(relu_nodes) == 2, "Should have 2 ReLU nodes"

=== Finding ReLU Nodes ===
Found ReLU: relu
  Input from: conv1
  Used by: ['conv2']
Found ReLU: relu_1
  Input from: conv2
  Used by: ['output']

Total ReLU nodes: 2
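
Forward and backward traversal can be sketched as two small helpers. The `Diamond` module and the function names `upstream` and `downstream` are illustrative, not part of torch.fx; the only FX attributes used are `node.all_input_nodes` and `node.users`.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class Diamond(nn.Module):
    def forward(self, x):
        a = torch.relu(x)
        b = torch.sigmoid(x)
        return a + b  # two branches that both start at x


def upstream(node):
    """Every node that `node` depends on, found by walking inputs backward."""
    seen, stack = set(), list(node.all_input_nodes)
    while stack:
        cur = stack.pop()
        if cur not in seen:
            seen.add(cur)
            stack.extend(cur.all_input_nodes)
    return seen


def downstream(node):
    """Every node that depends on `node`, found by walking users forward."""
    seen, stack = set(), list(node.users)
    while stack:
        cur = stack.pop()
        if cur not in seen:
            seen.add(cur)
            stack.extend(cur.users)
    return seen


gm = fx.symbolic_trace(Diamond())
by_name = {n.name: n for n in gm.graph.nodes}

up_names = sorted(n.name for n in upstream(by_name["add"]))
down_names = sorted(n.name for n in downstream(by_name["x"]))
print(f"add depends on: {up_names}")
print(f"x feeds into:   {down_names}")
```

Helpers like these are a common starting point for dependency checks, for example confirming that a node is safe to move or fuse.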


### Exercise 3: Find All 'conv2d' Operations in a Traced ResNet Block

**Task**: Create a basic ResNet block (two conv layers with a skip connection). Trace it and find all Conv2d operations. Print their names and input/output connections.

In [17]:
# Your code here
class ResNetBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # YOUR CODE
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    
    def forward(self, x):
        # YOUR CODE
        output = self.conv1(x)
        output = torch.relu(output)
        output = self.conv2(output)
        output = output + x
        output = torch.relu(output)
        return output
        
resnet = ResNetBlock(64).to(device)
trace_resnet = fx.symbolic_trace(resnet)

print("=== Traced Graph ===")
print(trace_resnet.graph)

print("\n=== Generated Python Code ===")
print(trace_resnet.code)

print(f"\n=== Find Conv2d Operations ===")

# Define precisely what counts as a Conv2d operation (no string guessing)
def is_conv2d_call(node: fx.Node, gm: fx.GraphModule) -> bool:
    if node.op != "call_module":
        return False
    submod = gm.get_submodule(node.target)  # <- API, not string parsing
    return isinstance(submod, nn.Conv2d)

# Walk the graph and print connections using FX graph APIs
for node in trace_resnet.graph.nodes:
    if not is_conv2d_call(node, trace_resnet):
        continue

    conv_name = node.target  # e.g., "conv1", "conv2"

    # upstream connections (nodes that feed into this conv)
    in_nodes = list(node.all_input_nodes)   # Nodes only (cleaner than .args)
    in_names = [n.name for n in in_nodes]

    # downstream connections (nodes that consume this conv's output)
    out_nodes = list(node.users.keys())
    out_names = [n.name for n in out_nodes]

    print(f"Conv2d module: {conv_name}")
    print(f"  node.name: {node.name}")
    print(f"  inputs from: {in_names}")
    print(f"  outputs to : {out_names}")

test_input = torch.randn(1, 64, 32, 32, device=device)
original = resnet(test_input)
traced = trace_resnet(test_input)
assert torch.allclose(original, traced), "Outputs differ!"
print("\n✓ Traced ResNet block works correctly")


=== Traced Graph ===
graph():
    %x : [num_users=2] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%conv1,), kwargs = {})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%conv2, %x), kwargs = {})
    %relu_1 : [num_users=1] = call_function[target=torch.relu](args = (%add,), kwargs = {})
    return relu_1

=== Generated Python Code ===



def forward(self, x):
    conv1 = self.conv1(x)
    relu = torch.relu(conv1);  conv1 = None
    conv2 = self.conv2(relu);  relu = None
    add = conv2 + x;  conv2 = x = None
    relu_1 = torch.relu(add);  add = None
    return relu_1
    

=== Find Conv2d Operations ===
Conv2d module: conv1
  node.name: conv1
  inputs from: ['x']
  outputs to : ['relu']
Conv2d module: conv2
  node.name: conv2
  inputs from: ['relu']
  outputs to : ['add']

✓ Traced ResNet block works correctly

In [16]:
# SOLUTION
class ResNetBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    
    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = torch.relu(out)
        out = self.conv2(out)
        out = out + identity  # Skip connection
        out = torch.relu(out)
        return out

resnet_block = ResNetBlock(64).to(device)
traced_resnet = fx.symbolic_trace(resnet_block)

print("=== Finding Conv2d Operations ===")
conv_nodes = []
for node in traced_resnet.graph.nodes:
    # Conv2d appears as a call_module node whose target names an nn.Conv2d submodule
    if node.op == 'call_module' and isinstance(traced_resnet.get_submodule(node.target), nn.Conv2d):
        conv_nodes.append(node)
        print(f"\nConv2d node: {node.name}")
        print(f"  Module: {node.target}")
        print(f"  Input from: {node.args[0].name if node.args else 'None'}")
        print(f"  Used by: {[user.name for user in node.users.keys()]}")

print(f"\nTotal Conv2d nodes: {len(conv_nodes)}")
assert len(conv_nodes) == 2, "Should have 2 Conv2d nodes"

# Verify traced model works
test_input = torch.randn(1, 64, 32, 32, device=device)
original = resnet_block(test_input)
traced = traced_resnet(test_input)
assert torch.allclose(original, traced), "Outputs differ!"
print("\n✓ Traced ResNet block works correctly")

=== Finding Conv2d Operations ===

Conv2d node: conv1
  Module: conv1
  Input from: x
  Used by: ['relu']

Conv2d node: conv2
  Module: conv2
  Input from: relu
  Used by: ['add']

Total Conv2d nodes: 2

✓ Traced ResNet block works correctly


---

## 4. BASIC GRAPH MANIPULATION (5 min)

### Core Operations: Add, Remove, Replace Nodes

Graph manipulation requires:
1. **Adding nodes**: Use `graph.call_function()`, `graph.call_module()`, etc.
2. **Removing nodes**: Use `graph.erase_node()` once the node has no remaining users
3. **Replacing nodes**: Use `node.replace_all_uses_with()` then erase old node
4. **Finalize**: Call `graph.lint()` to validate and `traced.recompile()` to update code
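
The ordering in steps 2 and 3 matters. The sketch below, using a throwaway `Tiny` module invented for this illustration, first tries to erase a node that still has users (which `erase_node` rejects with a `RuntimeError`), then does it in the supported order.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class Tiny(nn.Module):
    def forward(self, x):
        return torch.relu(x) * 2


gm = fx.symbolic_trace(Tiny())
relu_node = next(n for n in gm.graph.nodes if n.target is torch.relu)

# Attempt 1: erase while the multiply still consumes relu's output
try:
    gm.graph.erase_node(relu_node)
    erase_failed = False
except RuntimeError as err:
    erase_failed = True
    print(f"erase_node refused: {err}")

# Attempt 2: reroute users to relu's input first, then erase
relu_node.replace_all_uses_with(relu_node.args[0])
gm.graph.erase_node(relu_node)

gm.graph.lint()  # structural validation
gm.recompile()   # regenerate forward() from the edited graph
print(gm.code)

result = gm(torch.tensor([-1.0, 3.0]))
print(result)
```

Without the ReLU, the negative input is no longer clamped to zero before the multiply, which is a quick way to confirm the edit took effect.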

### Example: Adding a Node

Let's insert a `clone` operation (which leaves the values unchanged) right after the ReLU in our graph.

In [18]:
# Create a fresh traced model
mlp_for_add = SimpleMLP().to(device)
traced_add = fx.symbolic_trace(mlp_for_add)

print("=== Original Graph ===")
print(traced_add.graph)

# Find the relu node and insert a clone operation after it
for node in traced_add.graph.nodes:
    if node.op == 'call_function' and node.target == torch.relu:
        with traced_add.graph.inserting_after(node):
            # Add a clone operation (identity operation for demonstration)
            new_node = traced_add.graph.call_method('clone', args=(node,))
            # Redirect every user of relu to the clone...
            node.replace_all_uses_with(new_node)
            # ...which also rewired the clone's own input to itself, so point it back at relu
            new_node.args = (node,)
        break

traced_add.graph.lint()  # Validate graph
traced_add.recompile()   # Regenerate Python code

print("\n=== Modified Graph ===")
print(traced_add.graph)

# Verify it still works
test_input = torch.randn(1, 10, device=device)
output = traced_add(test_input)
print(f"\n✓ Modified model output shape: {output.shape}")

=== Original Graph ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %fc1 : [num_users=1] = call_module[target=fc1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%fc1,), kwargs = {})
    %fc2 : [num_users=1] = call_module[target=fc2](args = (%relu,), kwargs = {})
    return fc2

=== Modified Graph ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %fc1 : [num_users=1] = call_module[target=fc1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%fc1,), kwargs = {})
    %clone : [num_users=1] = call_method[target=clone](args = (%relu,), kwargs = {})
    %fc2 : [num_users=1] = call_module[target=fc2](args = (%clone,), kwargs = {})
    return fc2

✓ Modified model output shape: torch.Size([1, 5])


### Example: Removing a Node

Let's remove the ReLU activation from a model. This requires rerouting every user of the ReLU to the ReLU's input, then erasing the node.

In [19]:
# Create a fresh traced model
mlp_for_remove = SimpleMLP().to(device)
traced_remove = fx.symbolic_trace(mlp_for_remove)

print("=== Original Graph ===")
print(traced_remove.graph)

# Find and remove relu
for node in traced_remove.graph.nodes:
    if node.op == 'call_function' and node.target == torch.relu:
        # Replace relu with its input (bypass it)
        node.replace_all_uses_with(node.args[0])
        traced_remove.graph.erase_node(node)
        break

traced_remove.graph.lint()
traced_remove.recompile()

print("\n=== Graph After ReLU Removal ===")
print(traced_remove.graph)

# Verify - output will differ since we removed activation
test_input = torch.randn(1, 10, device=device)
output = traced_remove(test_input)
print(f"\n✓ Modified model output shape: {output.shape}")

=== Original Graph ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %fc1 : [num_users=1] = call_module[target=fc1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%fc1,), kwargs = {})
    %fc2 : [num_users=1] = call_module[target=fc2](args = (%relu,), kwargs = {})
    return fc2

=== Graph After ReLU Removal ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %fc1 : [num_users=1] = call_module[target=fc1](args = (%x,), kwargs = {})
    %fc2 : [num_users=1] = call_module[target=fc2](args = (%fc1,), kwargs = {})
    return fc2

✓ Modified model output shape: torch.Size([1, 5])


### Exercise 4: Replace All ReLU with GELU

**Task**: Take the SimpleCNN model, trace it, and replace all ReLU activations with GELU. Verify the transformation works correctly.

In [24]:
# Your code here
cnn_for_replace = SimpleCNN().to(device)
traced_replace = fx.symbolic_trace(cnn_for_replace)

print("=== Original Graph ===")
print(traced_replace.graph)

# Find all ReLU nodes and replace with GELU
nodes_to_replace = []
for node in traced_replace.graph.nodes:
    if node.op == 'call_function' and node.target == torch.relu:
        nodes_to_replace.append(node)

print(f"\nFound {len(nodes_to_replace)} ReLU nodes to replace")

for relu_node in nodes_to_replace:
    with traced_replace.graph.inserting_after(relu_node):
        # Create GELU node with same input as ReLU
        gelu_node = traced_replace.graph.call_function(
            torch.nn.functional.gelu,
            args=relu_node.args
        )
        # Replace all uses of ReLU with GELU
        relu_node.replace_all_uses_with(gelu_node)
    
    # Erase the old ReLU node
    traced_replace.graph.erase_node(relu_node)

traced_replace.graph.lint()
traced_replace.recompile()

print("\n=== Modified Graph ===")
print(traced_replace.graph)

print("\n=== Generated Code ===")
print(traced_replace.code)

# Verify ReLU is gone and GELU is present
has_relu = any(node.target == torch.relu for node in traced_replace.graph.nodes if node.op == 'call_function')
has_gelu = any(node.target == torch.nn.functional.gelu for node in traced_replace.graph.nodes if node.op == 'call_function')

assert not has_relu, "ReLU still present in graph!"
assert has_gelu, "GELU not found in graph!"

# Test execution
test_input = torch.randn(1, 3, 32, 32, device=device)
output = traced_replace(test_input)
print(f"\n✓ Modified model works! Output shape: {output.shape}")
print(f"✓ Successfully replaced {len(nodes_to_replace)} ReLU nodes with GELU")


=== Original Graph ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%conv1,), kwargs = {})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
    %relu_1 : [num_users=1] = call_function[target=torch.relu](args = (%conv2,), kwargs = {})
    return relu_1

Found 2 ReLU nodes to replace

=== Modified Graph ===
graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %gelu : [num_users=1] = call_function[target=torch._C._nn.gelu](args = (%conv1,), kwargs = {})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%gelu,), kwargs = {})
    %gelu_1 : [num_users=1] = call_function[target=torch._C._nn.gelu](args = (%conv2,), kwargs = {})
    return gelu_1

=== Generated Code ===



def forward(self, x):
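    conv1 = self.conv1(x);  x = None
    gelu = torch._C._nn.gelu(conv1);  conv1 = None
    conv2 = self.conv2(gelu);  gelu = None
    gelu_1 = torch._C._nn.gelu(conv2);  conv2 = None
    return gelu_1
    

✓ Modified model works! Output shape: torch.Size([1, 32, 32, 32])
✓ Successfully replaced 2 ReLU nodes with GELU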

---

## Summary

You now know:

1. **Computation Graphs**: DAGs where nodes are operations and edges are tensors
2. **torch.fx**: Symbolic tracing captures model operations into inspectable/modifiable graphs
3. **Graph Inspection**: Iterate nodes, check types (placeholder, call_module, call_function, call_method, get_attr, output), analyze connections
4. **Graph Manipulation**: Add nodes with `graph.call_*()`, remove with `erase_node()`, replace with `replace_all_uses_with()`

**Next Steps for Building a Graph Optimizer**:
- Learn pattern matching: `torch.fx.subgraph_rewriter.replace_pattern`
- Study fusion patterns: combine ops like conv+bn+relu into single optimized kernels
- Explore compiler passes: dead code elimination, constant folding
- Understand memory optimization: inplace operations, buffer reuse
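
As a taste of the pattern-matching step, here is a minimal sketch using `torch.fx.subgraph_rewriter.replace_pattern`, which takes a pattern and a replacement written as ordinary functions. The module `TwoRelus` and the function names `pattern` and `replacement` are made up for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx
from torch.fx import subgraph_rewriter


class TwoRelus(nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.relu(x * 2)


def pattern(x):
    return torch.relu(x)


def replacement(x):
    return F.gelu(x)


gm = fx.symbolic_trace(TwoRelus())

# Rewrites every subgraph matching `pattern` in place and returns the matches
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(f"Rewrote {len(matches)} matches")
print(gm.code)

remaining_relu = [n for n in gm.graph.nodes if n.target is torch.relu]

sample = torch.randn(4)
expected = F.gelu(sample) + F.gelu(sample * 2)
rewritten_output = gm(sample)
```

Compared with the manual loop in Exercise 4, the matcher handles multi-node patterns too, which is what fusion passes typically need.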

**Key Workflow**:
1. Trace model with `fx.symbolic_trace()`
2. Iterate and analyze with `graph.nodes`
3. Modify graph (add/remove/replace nodes)
4. Validate with `graph.lint()`
5. Recompile with `module.recompile()`
6. Test modified model
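
The six steps above can be sketched as one small pass. This example strips `nn.Dropout` modules for inference; the function name `remove_dropout` and the `Net` module are hypothetical, chosen only to give the workflow something to operate on.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.drop = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(self.drop(torch.relu(self.fc1(x))))


def remove_dropout(model: nn.Module) -> fx.GraphModule:
    gm = fx.symbolic_trace(model)                      # 1. trace
    for node in list(gm.graph.nodes):                  # 2. iterate and analyze
        is_dropout = node.op == "call_module" and isinstance(
            gm.get_submodule(node.target), nn.Dropout
        )
        if is_dropout:
            node.replace_all_uses_with(node.args[0])   # 3. modify: bypass the node
            gm.graph.erase_node(node)
    gm.graph.lint()                                    # 4. validate
    gm.recompile()                                     # 5. recompile
    return gm


model = Net().eval()  # in eval mode Dropout is already a no-op, so outputs should match
optimized = remove_dropout(model)

sample = torch.randn(2, 8)
outputs_match = torch.allclose(model(sample), optimized(sample))  # 6. test
dropout_nodes = [
    n for n in optimized.graph.nodes
    if n.op == "call_module"
    and isinstance(optimized.get_submodule(n.target), nn.Dropout)
]
print(f"Outputs match: {outputs_match}, dropout nodes left: {len(dropout_nodes)}")
```

Iterating over `list(gm.graph.nodes)` rather than the live node list is a deliberate choice here, since nodes are erased during the loop.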

## Quick Reference Card

```python
# Trace a model
traced = fx.symbolic_trace(model)

# Iterate nodes
for node in traced.graph.nodes:
    print(node.op, node.target, node.args)

# Node types
node.op in ['placeholder', 'get_attr', 'call_module', 'call_function', 'call_method', 'output']

# Add node after current node
graph = traced.graph
with graph.inserting_after(node):
    new = graph.call_function(torch.relu, args=(node,))

# Replace node
old_node.replace_all_uses_with(new_node)
graph.erase_node(old_node)

# Finalize changes
graph.lint()
traced.recompile()
```