https://shashankprasanna.com/workshops/a-tour-of-pytorch2/4_inspecting_torch_compile/inspecting_torch_compile/

In [2]:
!pip install matplotlib -q

In [7]:
import torch
import math
import os
import matplotlib.pyplot as plt
from torch import optim
import torch._dynamo
from torchvision import models
from torch.profiler import profile, record_function, ProfilerActivity

pi = math.pi
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"


def fn(x):
    return torch.sin(x) ** 2 + torch.cos(x) ** 2


torch.manual_seed(0)
x = torch.rand(1000000, requires_grad=True).to(device)

out = fn(x)
torch.linalg.norm(out - 1) <= 1e-4

tensor(True, device='cuda:0')

In [8]:
torch.manual_seed(0)
x = torch.rand(1000000, requires_grad=True).to(device)


def inspect_backend(gm, sample_inputs):
    gm.print_readable()
    return gm.forward


torch._dynamo.reset()
compiled_model = torch.compile(fn, backend=inspect_backend)

out = compiled_model(x)

class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[1000000]"):
        l_x_ = L_x_
        
         # File: /tmp/ipykernel_745/1194271619.py:15 in fn, code: return torch.sin(x) ** 2 + torch.cos(x) ** 2
        sin: "f32[1000000]" = torch.sin(l_x_)
        pow_1: "f32[1000000]" = sin ** 2;  sin = None
        cos: "f32[1000000]" = torch.cos(l_x_);  l_x_ = None
        pow_2: "f32[1000000]" = cos ** 2;  cos = None
        add: "f32[1000000]" = pow_1 + pow_2;  pow_1 = pow_2 = None
        return (add,)
        


In [9]:
import torch._dynamo
from torch.fx.passes.graph_drawer import FxGraphDrawer
from torch._functorch.aot_autograd import aot_module_simplified


def inspect_backend(gm, sample_inputs):
    def fw(gm, sample_inputs):
        gm.print_readable()
        # g = FxGraphDrawer(gm, "fn")
        # with open("forward.svg", "wb") as f:
        #     f.write(g.get_dot_graph().create_svg())
        return gm.forward

    def bw(gm, sample_inputs):
        gm.print_readable()
        # g = FxGraphDrawer(gm, "fn")
        # with open("backward.svg", "wb") as f:
        #     f.write(g.get_dot_graph().create_svg())
        return gm.forward

    # Invoke AOTAutograd
    return aot_module_simplified(gm, sample_inputs, fw_compiler=fw, bw_compiler=bw)


torch._dynamo.reset()
compiled_model = torch.compile(fn, backend=inspect_backend)

out = compiled_model(x).sum().backward()

class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[1000000]"):
         # File: /tmp/ipykernel_745/1194271619.py:15 in fn, code: return torch.sin(x) ** 2 + torch.cos(x) ** 2
        sin: "f32[1000000]" = torch.ops.aten.sin.default(primals_1)
        pow_1: "f32[1000000]" = torch.ops.aten.pow.Tensor_Scalar(sin, 2)
        cos: "f32[1000000]" = torch.ops.aten.cos.default(primals_1)
        pow_2: "f32[1000000]" = torch.ops.aten.pow.Tensor_Scalar(cos, 2)
        add: "f32[1000000]" = torch.ops.aten.add.Tensor(pow_1, pow_2);  pow_1 = pow_2 = None
        return (add, primals_1, sin, cos)
        
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[1000000]", sin: "f32[1000000]", cos: "f32[1000000]", tangents_1: "f32[1000000]"):
         # File: /tmp/ipykernel_745/1194271619.py:15 in fn, code: return torch.sin(x) ** 2 + torch.cos(x) ** 2
        pow_3: "f32[1000000]" = torch.ops.aten.pow.Tensor_Scalar(cos, 1.0);  cos = None
        mul: "f32[1000



In [8]:
def forward(
    self,
    primals_1: "f32[1000000]",
    sin: "f32[1000000]",
    cos: "f32[1000000]",
    tangents_1: "f32[1000000]",
):
    # File: /tmp/ipykernel_704917/1194271619.py:15 in fn, code: return torch.sin(x) ** 2 + torch.cos(x) ** 2
    pow_3: "f32[1000000]" = torch.ops.aten.pow.Tensor_Scalar(cos, 1.0)
    cos = None
    mul: "f32[1000000]" = torch.ops.aten.mul.Scalar(pow_3, 2.0)
    pow_3 = None
    mul_1: "f32[1000000]" = torch.ops.aten.mul.Tensor(tangents_1, mul)
    mul = None
    sin_1: "f32[1000000]" = torch.ops.aten.sin.default(primals_1)
    neg: "f32[1000000]" = torch.ops.aten.neg.default(sin_1)
    sin_1 = None
    mul_2: "f32[1000000]" = torch.ops.aten.mul.Tensor(mul_1, neg)
    mul_1 = neg = None
    pow_4: "f32[1000000]" = torch.ops.aten.pow.Tensor_Scalar(sin, 1.0)
    sin = None
    mul_3: "f32[1000000]" = torch.ops.aten.mul.Scalar(pow_4, 2.0)
    pow_4 = None
    mul_4: "f32[1000000]" = torch.ops.aten.mul.Tensor(tangents_1, mul_3)
    tangents_1 = mul_3 = None
    cos_1: "f32[1000000]" = torch.ops.aten.cos.default(primals_1)
    primals_1 = None
    mul_5: "f32[1000000]" = torch.ops.aten.mul.Tensor(mul_4, cos_1)
    mul_4 = cos_1 = None

    # File: /tmp/ipykernel_704917/1194271619.py:15 in fn, code: return torch.sin(x) ** 2 + torch.cos(x) ** 2
    add_1: "f32[1000000]" = torch.ops.aten.add.Tensor(mul_2, mul_5)
    mul_2 = mul_5 = None
    return (add_1,)

In [10]:
torch._dynamo.reset()
x = x.to(device)
compiled_model = torch.compile(
    fn,
    # backend="inductor",
    options={"trace.enabled": True, "trace.graph_diagram": True},
)

out = compiled_model(x).sum().backward()
#

W0615 23:35:22.046000 745 site-packages/torch/_inductor/debug.py:72] [0/0] draw_buffers() requires `graphviz` package
W0615 23:35:22.261000 745 site-packages/torch/_inductor/debug.py:454] [0/0] model__6_forward_11 debug trace: /code/torch_compile_debug/run_2025_06_15_23_35_22_005250-pid_745/torchinductor/model__6_forward_11.0
W0615 23:35:22.316000 745 site-packages/torch/_inductor/debug.py:72] [0/0] draw_buffers() requires `graphviz` package
W0615 23:35:22.513000 745 site-packages/torch/_inductor/debug.py:454] [0/0] model__6_backward_13 debug trace: /code/torch_compile_debug/run_2025_06_15_23_35_22_005250-pid_745/torchinductor/model__6_backward_13.1


In [11]:
device

device(type='cuda')

In [12]:
x

tensor([0.4963, 0.7682, 0.0885,  ..., 0.6331, 0.9980, 0.5057], device='cuda:0',
       grad_fn=<ToCopyBackward0>)