<a href="https://colab.research.google.com/github/ksharat45/Pytorch/blob/main/torch_compile.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch

In [2]:
import torch.nn as nn

In [4]:
class MyModel(nn.Module):
    def forward(self, x):
        print("➡️ Python forward executing")
        y = x * 2
        z = torch.relu(y)
        return z.sum()

model = MyModel()
x = torch.randn(4, 4)

print("Eager output:", model(x))

➡️ Python forward executing
Eager output: tensor(13.7301)


In [5]:
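import torch._dynamo as dynamo

dynamo.config.verbose = True           # print full stack traces on Dynamo warnings and errors
dynamo.config.suppress_errors = False  # raise compile errors instead of silently falling back to eager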

In [6]:
compiled_model = torch.compile(
    model,
    backend="inductor",   # TorchInductor
    fullgraph=False       # allow graph breaks
)

In [7]:
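print("\n--- First compiled run ---")
# print() is a Python side effect that Dynamo does not put into the graph,
# so the message still appears when the compiled model runs.
out = compiled_model(x)
print("Compiled output:", out)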


--- First compiled run ---
➡️ Python forward executing
Compiled output: tensor(13.7301)
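A `torch.compile` backend is simply a callable that receives the FX graph Dynamo captured, plus example inputs, and returns something callable. The sketch below assumes a hypothetical `recording_backend` (not part of PyTorch) that only records and prints each graph and then runs it unoptimized, with no Inductor involved. It is one way to see exactly what Dynamo hands to a backend such as `"inductor"`:

```python
import torch
import torch._dynamo
import torch.fx

captured_graphs = []  # every GraphModule Dynamo passes to the backend

def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical backend: called once per captured graph
    # (one per graph-break segment or recompilation).
    captured_graphs.append(gm)
    print(gm.graph)    # the FX graph Dynamo traced
    return gm.forward  # run the captured graph as-is, without optimization

def f(x):
    # Same math as MyModel, without the print
    return torch.relu(x * 2).sum()

torch._dynamo.reset()  # start from an empty compile cache
compiled_f = torch.compile(f, backend=recording_backend)

x_demo = torch.randn(4, 4)
out_demo = compiled_f(x_demo)
print("graphs captured:", len(captured_graphs))
```

Because `f` contains nothing Dynamo has to fall back on, a single graph is expected here. Adding a `print` or `.item()` inside `f` is an easy way to watch the count change.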


In [8]:
from torch._dynamo import explain

explanation = explain(model)(x)  # explain(f)(*args); the older explain(f, *args) form is deprecated
print(explanation)


➡️ Python forward executing
Graph Count: 1
Graph Break Count: 0
Op Count: 2
Break Reasons:
Ops per Graph:
  Ops 1:
    <built-in function mul>
    <built-in method relu of type object at 0x7f02ef3828c0>
Out Guards:
  Guard 1:
    Name: ''
    Source: global
    Create Function: AUTOGRAD_SAVED_TENSORS_HOOKS
    Guard Types: ['AUTOGRAD_SAVED_TENSORS_HOOKS']
    Code List: ['torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 2:
    Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 3:
    Name: "L['x']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['x'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f02cd8cf0b0; to 'Tensor' at 0x7f02d2e6d270>
    Guarded Class Weakref: <weakref at 0x7f03041c54e0; 



In [9]:
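import torch.fx as fx

# symbolic_trace runs forward() once with proxy inputs (hence the print below)
# and records each operation into a GraphModule.
gm = fx.symbolic_trace(model)
print("\nFX Graph:")
print(gm.graph)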

➡️ Python forward executing

FX Graph:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.relu](args = (%mul,), kwargs = {})
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%relu,), kwargs = {})
    return sum_1
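`symbolic_trace` returns a `GraphModule`, which holds the graph printed above plus Python code regenerated from it, and it can be called like the original module. The sketch below uses a hypothetical `Tiny` module, equivalent to `MyModel` without the print. It walks the nodes and checks that the traced module computes the same result:

```python
import torch
import torch.fx as fx
import torch.nn as nn

class Tiny(nn.Module):
    # Hypothetical module: MyModel's computation without the print
    def forward(self, x):
        return torch.relu(x * 2).sum()

gm = fx.symbolic_trace(Tiny())

# Each node records its kind (placeholder, call_function, ...) and its target
for node in gm.graph.nodes:
    print(f"{node.op:15} {node.target}")

print(gm.code)  # Python source regenerated from the graph

x = torch.randn(4, 4)
traced_out = gm(x)
```

The node kinds line up with the printout above: `x * 2` becomes a `call_function` on `operator.mul`, while `.sum()` is a `call_method`.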


In [10]:
# A model with two graph-break triggers: a Python print() and a branch on .item()

class BreakModel(nn.Module):
    def forward(self, x):
        print("Python print → graph break")
        if x.sum().item() > 0:
            x = x * 2
        return x.relu().sum()

bm = BreakModel()
compiled_bm = torch.compile(bm)

compiled_bm(torch.randn(4,4))

W1222 09:32:29.530000 169 torch/_dynamo/variables/tensor.py:1048] [1/0] Graph break from `Tensor.item()`, consider setting:
W1222 09:32:29.530000 169 torch/_dynamo/variables/tensor.py:1048] [1/0]     torch._dynamo.config.capture_scalar_outputs = True
W1222 09:32:29.530000 169 torch/_dynamo/variables/tensor.py:1048] [1/0] or:
W1222 09:32:29.530000 169 torch/_dynamo/variables/tensor.py:1048] [1/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W1222 09:32:29.530000 169 torch/_dynamo/variables/tensor.py:1048] [1/0] to include these operations in the captured graph.
W1222 09:32:29.530000 169 torch/_dynamo/variables/tensor.py:1048] [1/0] 
W1222 09:32:29.530000 169 torch/_dynamo/variables/tensor.py:1048] [1/0] Graph break: from user code at:
W1222 09:32:29.530000 169 torch/_dynamo/variables/tensor.py:1048] [1/0]   File "/tmp/ipython-input-2944648084.py", line 6, in torch_dynamo_resume_in_forward_at_5
W1222 09:32:29.530000 169 torch/_dynamo/variables/tensor.py:1048] [1/0]     if x.sum().item() 

Python print → graph break


tensor(4.3360)

In [11]:
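# fullgraph=True asks Dynamo to capture forward() as a single graph and to raise
# an error on any graph break instead of falling back to Python.
# The run below did not raise, even though BreakModel has two break points. A likely
# cause is that bm.forward already had compiled code cached by compiled_bm,
# which this call reused; calling torch._dynamo.reset() before compiling should
# force a fresh, strict compile.
compiled_strict = torch.compile(
    bm,
    fullgraph=True
)

compiled_strict(torch.randn(4,4))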

Python print → graph break


tensor(20.4656)
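With `fullgraph=True`, a graph break is supposed to be a hard error. The sketch below checks this from a clean cache (`torch._dynamo.reset()`), assuming a hypothetical `BreakDemo` copy of `BreakModel`. It uses the `"eager"` backend because the error comes from Dynamo's tracing, before any backend is invoked:

```python
import torch
import torch._dynamo
import torch.nn as nn

class BreakDemo(nn.Module):
    # Hypothetical copy of BreakModel's forward
    def forward(self, x):
        print("Python print → graph break")
        if x.sum().item() > 0:
            x = x * 2
        return x.relu().sum()

torch._dynamo.reset()  # clear previously compiled code so nothing is reused
strict = torch.compile(BreakDemo(), fullgraph=True, backend="eager")

raised = False
try:
    strict(torch.randn(4, 4))
except Exception as err:  # typically a torch._dynamo.exc.Unsupported error
    raised = True
    print(f"{type(err).__name__}: {str(err)[:120]}")
```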

In [12]:
import os
print(os.environ.get("TORCHINDUCTOR_CACHE_DIR"))  # Inductor's on-disk cache directory

/tmp/torchinductor_root
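To see what Inductor has actually written there, the small standard-library sketch below counts the cached files by suffix. The helper `summarize_cache` is hypothetical, and the fallback path is simply the value this notebook printed above:

```python
import os
from collections import Counter
from pathlib import Path

def summarize_cache(root):
    """Hypothetical helper: count files under a directory tree by file suffix."""
    root = Path(root)
    if not root.is_dir():
        return Counter()
    return Counter(p.suffix or "<no suffix>" for p in root.rglob("*") if p.is_file())

# Fallback is the path printed in this notebook; adjust for your machine.
cache_root = os.environ.get("TORCHINDUCTOR_CACHE_DIR", "/tmp/torchinductor_root")
for suffix, count in summarize_cache(cache_root).most_common(10):
    print(f"{suffix:12} {count}")
```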


In [13]:
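import torch._inductor.config

# Inductor's debug printouts only apply to compilations that happen after this;
# a call served by already-compiled code (as in the next cell) may show nothing extra.
torch._inductor.config.debug = True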

In [14]:
compiled_bm(torch.randn(4,4))

Python print → graph break


tensor(13.5818)