- https://docs.pytorch.org/docs/stable/torch.compiler_profiling_torch_compile.html

In [1]:
import torch
from torchvision.models import resnet18

# NOTE(review): private profiler helper; judging by its name it prepares the
# profiler for use together with CUDA graphs -- confirm against the PyTorch
# docs before relying on it, since private APIs may change without notice.
torch.profiler._utils._init_for_cuda_graphs()

# Device on which both the model and the input batches are placed.
device = "cuda"  # or 'cpu', 'xpu', etc.
model = resnet18().to(device)

# Ten random batches of shape (5, 3, 224, 224), allocated directly on `device`.
inputs = [torch.randn((5, 3, 224, 224), device=device) for _ in range(10)]

# Compiled wrapper around `model`; this is what `fwd_bwd` calls.
model_c = torch.compile(model)


def fwd_bwd(inp):
    """Run the compiled model on ``inp`` and backpropagate the summed output.

    Reads the module-level ``model_c``. Returns nothing; the effect is the
    backward pass triggered on the scalar sum of the model output.
    """
    loss = model_c(inp).sum()
    loss.backward()


# Warm up: run one forward/backward step on the first batch before the
# profiler starts, so this first call is not part of the recorded trace.
fwd_bwd(inputs[0])

# Profile three forward/backward steps on inputs[1], inputs[2] and inputs[3].
with torch.profiler.profile() as prof:
    for i in range(1, 4):
        fwd_bwd(inputs[i])
        # Called once after every iteration to mark the step for the profiler.
        prof.step()

# Write the collected trace to "trace.json" in the current working directory.
prof.export_chrome_trace("trace.json")

In [2]:
import torch
from torchvision.models import resnet18

# Device on which both the model and the input batches are placed;
# user can switch between cuda and xpu
device = "cuda"
model = resnet18().to(device)
# Ten random batches of shape (5, 3, 224, 224), allocated directly on `device`.
inputs = [torch.randn((5, 3, 224, 224), device=device) for _ in range(10)]

# Compiled wrapper around `model`; this is what `fwd_bwd` calls.
model_c = torch.compile(model)


def fwd_bwd(inp):
    """Run the compiled model on ``inp`` and backpropagate the summed output.

    Reads the module-level ``model_c``. Returns nothing; the effect is the
    backward pass triggered on the scalar sum of the model output.
    """
    loss = model_c(inp).sum()
    loss.backward()


def warmup_compile():
    """Compile and run a tiny throwaway function, forward and backward.

    The function is ``sin`` followed by ``relu`` applied to a random 2x2
    tensor created on the module-level ``device``. Nothing is returned.
    Presumably this exists to trigger one-time ``torch.compile`` start-up
    work separately from the real model -- confirm against the caller.
    """

    def fn(x):
        return x.sin().relu()

    # requires_grad=True so that the backward pass below has something to do.
    sample = torch.rand((2, 2), device=device, requires_grad=True)
    compiled_fn = torch.compile(fn)
    compiled_fn(sample).sum().backward()


# No warm-up call precedes the profiler in this cell, so the first call of
# each compiled function happens inside the recorded trace.
with torch.profiler.profile() as prof:
    # Label the range covering the tiny throwaway compile as "warmup compile".
    with torch.profiler.record_function("warmup compile"):
        warmup_compile()

    # Label the range covering the first compiled resnet18 step.
    with torch.profiler.record_function("resnet18 compile"):
        fwd_bwd(inputs[0])

# Write the collected trace to "trace_compile.json" in the working directory.
prof.export_chrome_trace("trace_compile.json")

Logging tools can also identify graph breaks, but the profiler offers a quick visual way to spot them. There are two profiler events to look for: `Torch-Compiled Region` and `CompiledFunction`.

### Example of a graph break

In [3]:
import torch
import torch._dynamo

# Device on which the model and the input batches below are placed;
# user can switch between cuda and xpu
device = "cuda"


class ModelWithBreaks(torch.nn.Module):
    """Toy network of four MLP stages with a forced graph break between stages.

    Every stage is ``Linear(128, 128) -> ReLU -> Linear(128, 128) -> ReLU``.
    The stages are registered as ``mod1`` .. ``mod4`` and applied in that
    order, with ``torch._dynamo.graph_break()`` called between consecutive
    stages.
    """

    def __init__(self):
        super().__init__()
        # Build and register the stages in order so that submodule names and
        # parameter initialisation order are mod1, mod2, mod3, mod4.
        for stage_name in ("mod1", "mod2", "mod3", "mod4"):
            setattr(self, stage_name, self._build_stage())

    @staticmethod
    def _build_stage():
        """Return a fresh 128-wide Linear/ReLU/Linear/ReLU stack."""
        layers = [
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
        ]
        return torch.nn.Sequential(*layers)

    def forward(self, inp):
        """Apply the four stages in sequence and return the last stage's output."""
        # Written out stage by stage (not as a loop) so that every graph break
        # sits at its own call site between two stages.
        hidden = self.mod1(inp)
        torch._dynamo.graph_break()
        hidden = self.mod2(hidden)
        torch._dynamo.graph_break()
        hidden = self.mod3(hidden)
        torch._dynamo.graph_break()
        return self.mod4(hidden)


# Instantiate the graph-break demo model and move it to `device`.
model = ModelWithBreaks().to(device)
# Ten random (128, 128) inputs, allocated directly on `device`.
inputs = [torch.randn((128, 128), device=device) for _ in range(10)]

# Compiled wrapper around `model`; this is what `fwd_bwd` calls.
model_c = torch.compile(model)


def fwd_bwd(inp):
    """Run the compiled model on ``inp`` and backpropagate the summed output.

    Reads the module-level ``model_c``. Returns nothing; the effect is the
    backward pass triggered on the scalar sum of the model output.
    """
    loss = model_c(inp).sum()
    loss.backward()


# Warm up: run one forward/backward step on the first input before the
# profiler starts, so this first call is not part of the recorded trace.
fwd_bwd(inputs[0])

# Profile three forward/backward steps on inputs[1], inputs[2] and inputs[3].
with torch.profiler.profile() as prof:
    for i in range(1, 4):
        fwd_bwd(inputs[i])
        # Called once after every iteration to mark the step for the profiler.
        prof.step()

# Write the collected trace to "trace_break.json" in the working directory.
prof.export_chrome_trace("trace_break.json")



![alt text](image.png)

### Launch overhead

![alt text](image-1.png)