# Let's try coding a custom backend for `torch.compile`

In [2]:
from typing import List
import torch
from torchvision.models import densenet121

device = None

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")


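def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # TorchDynamo calls the backend once for every FX graph it captures.
    print("custom backend called with FX graph:")
    gm.graph.print_tabular()
    # Returning gm.forward runs the captured graph as-is, without optimizing it.
    return gm.forward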


# Reset since we are using a different backend.
torch._dynamo.reset()


def init_model():
    return densenet121().to(torch.float32).to(device)


# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).to(device),
        torch.randint(1000, (b,)).to(device),
    )


opt_model = torch.compile(init_model(), backend=custom_backend)
opt_model(generate_data(16)[0])

Using device: mps
custom backend called with FX graph:
opcode         name      target    args      kwargs
-------------  --------  --------  --------  --------
[FX graph table for DenseNet-121 truncated]

tensor([[-0.1761,  0.2459, -1.0244,  ...,  0.0164,  0.2590,  0.0744],
        [-0.2247,  0.2604, -0.9793,  ...,  0.0730,  0.1597, -0.1108],
        [-0.0173,  0.2761, -0.7957,  ...,  0.0842,  0.2744, -0.0740],
        ...,
        [-0.2606,  0.2906, -0.8976,  ...,  0.1672,  0.1921, -0.1895],
        [-0.3237,  0.1724, -0.9114,  ..., -0.0017,  0.2648, -0.0521],
        [-0.2970,  0.2771, -0.9087,  ...,  0.1074,  0.1609,  0.0564]],
       device='mps:0', grad_fn=<LinearBackward0>)
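The backend contract is small: TorchDynamo passes in a `torch.fx.GraphModule` plus example inputs and expects a callable back. As a minimal sketch, the hypothetical `recording_backend` below stores every graph it receives instead of printing it, which makes it easy to count captured graphs and list their operations. The function `smooth` is an illustrative example with no control flow, so we would expect a single graph.

```python
from typing import Callable, List

import torch
import torch._dynamo

# Graphs handed to the backend, in the order TorchDynamo compiles them.
captured_graphs: List[torch.fx.GraphModule] = []


def recording_backend(
    gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
) -> Callable:
    """Hypothetical backend: record each FX graph, then run it unchanged."""
    captured_graphs.append(gm)
    # Returning gm.forward means "execute the captured graph as-is".
    return gm.forward


def smooth(a: torch.Tensor) -> torch.Tensor:
    # No data-dependent control flow, so TorchDynamo should capture one graph.
    return torch.sin(a) + torch.cos(a)


torch._dynamo.reset()
opt_smooth = torch.compile(smooth, backend=recording_backend)
x = torch.randn(10)
out = opt_smooth(x)

# The compiled function must agree with plain eager execution.
assert torch.allclose(out, smooth(x))
print(f"graphs captured: {len(captured_graphs)}")
for gm in captured_graphs:
    # Keep only nodes that perform computation (skip placeholders and output).
    ops = [n.target for n in gm.graph.nodes if n.op in ("call_function", "call_method")]
    print(ops)
```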

In [3]:
def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


opt_bar = torch.compile(bar, backend=custom_backend)
inp1 = torch.randn(10)
inp2 = torch.randn(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)

custom backend called with FX graph:
opcode         name    target                                               args         kwargs
-------------  ------  ---------------------------------------------------  -----------  --------
placeholder    l_a_    L_a_                                                 ()           {}
placeholder    l_b_    L_b_                                                 ()           {}
call_function  abs_1   <built-in method abs of type object at 0x124e20088>  (l_a_,)      {}
call_function  add     <built-in function add>                              (abs_1, 1)   {}
call_function  x       <built-in function truediv>                          (l_a_, add)  {}
call_method    sum_1   sum                                                  (l_b_,)      {}
call_function  lt      <built-in function lt>                               (sum_1, 0)   {}
output         output  output                                               ((x, lt),)   {}
custom backend called with FX graph:
[remaining output truncated]

tensor([-0.3239, -0.0289,  0.1068, -0.1800,  0.0932,  0.0667,  0.7168,  0.6543,
         0.2256, -0.7977])

---

The output reveals that TorchDynamo extracted 3 different FX graphs, corresponding to the following code (the order may differ from the output above):

1. `x = a / (torch.abs(a) + 1)`, together with evaluating the condition `b.sum() < 0`

2. `b = b * -1; return x * b`

3. `return x * b`
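One way to confirm the count of three is to compile `bar` with a backend that only records graphs and then drive it down both branches. This is a sketch: `counting_backend` is a hypothetical name, and fixed inputs (`torch.ones` and its negation) replace random ones so that each branch is guaranteed to run once.

```python
from typing import Callable, List

import torch
import torch._dynamo

graphs: List[torch.fx.GraphModule] = []


def counting_backend(
    gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
) -> Callable:
    """Hypothetical backend: remember each captured graph and run it unchanged."""
    graphs.append(gm)
    return gm.forward


def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


torch._dynamo.reset()
opt_bar = torch.compile(bar, backend=counting_backend)
a = torch.randn(10)
opt_bar(a, torch.ones(10))   # b.sum() >= 0: graph 1, then graph 3
opt_bar(a, -torch.ones(10))  # b.sum() < 0: graph 1 is reused, graph 2 is new
print(f"graphs captured: {len(graphs)}")
```

Each branch's graph is compiled lazily, only the first time Python takes that branch, which is why a single call would show just two graphs.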

When TorchDynamo encounters unsupported Python features, such as data-dependent control flow, it breaks the computation graph, lets the default Python interpreter handle the unsupported code, then resumes capturing the graph.

Let's walk through how TorchDynamo steps through `bar`. If `b.sum() < 0`, TorchDynamo runs graph 1, lets Python evaluate the conditional, then runs graph 2. Otherwise (`b.sum() >= 0`), it runs graph 1, lets Python evaluate the conditional, then runs graph 3.
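The same walkthrough can be written out by hand. The sketch below is an illustrative decomposition, not code that TorchDynamo generates: `graph_1`, `graph_2`, and `graph_3` are hypothetical stand-ins for the three captured graphs, and plain Python handles the branch between them, just as it does at the graph break.

```python
import torch


def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


# Hypothetical stand-ins for the three FX graphs TorchDynamo captured.
def graph_1(a, b):
    x = a / (torch.abs(a) + 1)
    return x, b.sum() < 0  # the condition is computed inside the graph


def graph_2(b, x):
    b = b * -1
    return x * b


def graph_3(b, x):
    return x * b


def bar_as_dynamo_runs_it(a, b):
    x, lt = graph_1(a, b)
    # Graph break: ordinary Python turns the tensor `lt` into a bool and branches.
    if lt:
        return graph_2(b, x)
    return graph_3(b, x)


a = torch.randn(10)
for b in (torch.ones(10), -torch.ones(10)):
    # Same operations in the same order, so the results match exactly.
    assert torch.equal(bar_as_dynamo_runs_it(a, b), bar(a, b))
```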

This highlights a major difference between TorchDynamo and previous PyTorch compiler solutions. When encountering unsupported Python features, previous solutions either raise an error or silently fail. TorchDynamo, on the other hand, will break the computation graph.

We can see where TorchDynamo breaks the graph by using `torch._dynamo.explain`:

In [4]:
# Reset since we are using a different backend.
torch._dynamo.reset()
explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
print(explain_output)

Graph Count: 2
Graph Break Count: 1
Op Count: 5
Break Reasons:
  Break Reason 1:
    Reason: generic_jump TensorVariable()
    User Stack:
      <FrameSummary file /var/folders/62/sqk699ld69v13ny7hdx8n4980000gn/T/ipykernel_86204/966290996.py, line 3 in bar>
Ops per Graph:
  Ops 1:
    <built-in method abs of type object at 0x124e20088>
    <built-in function add>
    <built-in function truediv>
    <built-in function lt>
  Ops 2:
    <built-in function mul>
Out Guards:
  Guard 1:
    Name: "G['torch']"
    Source: global
    Create Function: FUNCTION_MATCH
    Guard Types: ['ID_MATCH']
    Code List: ["___check_obj_id(G['torch'], 4356027104)"]
    Object Weakref: None
    Guarded Class Weakref: <weakref at 0x100a8fbf0; to 'type' at 0x10083a488 (module)>
  Guard 2:
    Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 3

To maximize speedup, graph breaks should be kept to a minimum. We can force TorchDynamo to raise an error at the first graph break it encounters by passing `fullgraph=True`:

In [6]:
opt_bar = torch.compile(bar, fullgraph=True)

opt_bar(torch.randn(10), torch.randn(10))


UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/var/folders/62/sqk699ld69v13ny7hdx8n4980000gn/T/ipykernel_86204/966290996.py", line 3, in bar
    if b.sum() < 0:

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
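The error message points to a `cond` operator for capturing control flow. A simpler workaround for this particular function, swapped in here and not the approach the message suggests, is to express the branch as data with `torch.where`, so the decision stays inside the graph. This is a minimal sketch: `bar_branchless` is a hypothetical name, and `backend="eager"` is used only to keep the check lightweight (Dynamo still traces the function, and `fullgraph=True` still raises on any graph break).

```python
import torch
import torch._dynamo


def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


def bar_branchless(a, b):
    x = a / (torch.abs(a) + 1)
    # torch.where selects -b or b using a 0-dim boolean tensor, so no Python
    # `if` ever needs the tensor's value and the graph does not break.
    b = torch.where(b.sum() < 0, -b, b)
    return x * b


torch._dynamo.reset()
opt_bar_branchless = torch.compile(bar_branchless, fullgraph=True, backend="eager")

a = torch.randn(10)
for b in (torch.ones(10), -torch.ones(10)):
    # Both signs of b.sum() go through the same single graph.
    assert torch.allclose(opt_bar_branchless(a, b), bar(a, b))
```

The trade-off is that both `-b` and `b` are always computed; for cheap branches like this one, that is usually preferable to a graph break.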


Below, we show that TorchDynamo captures a simple `nn.Module` (a linear layer followed by ReLU) as a single graph: with `fullgraph=True`, compilation succeeds without raising.

In [11]:
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 2)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))


opt_model = torch.compile(MyModule(), fullgraph=True)
inp = torch.randn(8, 10)

print(opt_model(inp))

tensor([[0.0969, 0.0000],
        [0.1884, 0.0357],
        [1.3537, 0.0000],
        [0.0000, 0.5696],
        [0.0000, 0.4626],
        [0.0000, 0.0000],
        [0.0000, 0.2022],
        [0.0000, 0.1956]], grad_fn=<CompiledFunctionBackward>)


---

## torch.export

- We can use `torch.export` (available since PyTorch 2.1) to extract a single, exportable FX graph from the input PyTorch program.
- The exported graph is intended to be **run in different environments, including ones without Python**.
- One important restriction is that **`torch.export` does not support graph breaks**.