In [1]:
from typing import TYPE_CHECKING

import torch
from functorch.compile import aot_function, aot_module
from torch import Tensor, nn
from torch._subclasses import FakeTensorMode

In [2]:
# Display all cell outputs: have IPython show the value of every bare expression in a cell,
# not only the value of the last one.
# NOTE(review): this assigns on the `InteractiveShell` class rather than on the running shell
# instance (presumably why the `type: ignore` is needed). `%config
# InteractiveShell.ast_node_interactivity = "all"` is the documented route — confirm the
# class-level assignment still takes effect with the installed IPython version.
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all" # type: ignore  # noqa: PGH003

In [3]:
# Floating-point dtype shared by the model weights (`Layer`) and the example tensors below.
dtype: torch.dtype = torch.float32

In [4]:
class Layer(nn.Module):
    """Single bias-free linear layer computing ``x @ weights``.

    ``weights`` is a trainable 2x2 parameter initialised to ``[[-1, 1], [-2, 2]]`` using the
    notebook-level ``dtype``.
    """

    def __init__(self) -> None:
        super().__init__()
        initial_weights = torch.tensor([[-1, 1], [-2, 2]], dtype=dtype)
        self.weights = nn.Parameter(initial_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Right-multiply ``x`` by ``self.weights``.

        ``x`` needs a trailing dimension of size 2 to match the 2x2 weights, e.g. ``(batch, 2)``.
        """
        weight_matrix = self.weights
        return torch.matmul(x, weight_matrix)

    if TYPE_CHECKING:
        # Static-typing-only signature so `model(x)` is typed as Tensor -> Tensor.
        # TYPE_CHECKING is False at runtime, so nn.Module.__call__ is used unchanged.
        def __call__(self, x: Tensor) -> Tensor: ...

In [5]:
# One example: an input row `data` and a same-shaped (1, 2) `target`, both in `dtype`.
data, target = (
    torch.tensor([[1.0, 2.0]], dtype=dtype),
    torch.tensor([[0.0, 1.0]], dtype=dtype),
)

# Fresh model with the fixed initial weights defined in `Layer`.
model = Layer()

In [6]:
# Eager reference run: forward pass, a simple signed-sum "loss", then backward.
# `.backward()` accumulates into `.grad`, so clear it first; otherwise re-running this cell
# (or running it after another backward through `model`) changes the printed gradient.
model.zero_grad()

output = model(data)
loss = (output - target).sum()
print(f"loss: {loss}")

loss.backward()

print(output)
print(model.weights.detach())
print(model.weights.grad)

# For loss = sum(x @ W - target), d(loss)/dW = x^T @ ones, independent of `target`.
assert torch.allclose(model.weights.grad, data.T @ torch.ones_like(output))

loss: -1.0
tensor([[-5.,  5.]], grad_fn=<MmBackward0>)
tensor([[-1.,  1.],
        [-2.,  2.]])
tensor([[1., 1.],
        [2., 2.]])


In [7]:
# AOT Autograd hands each traced graph (forward and backward, as wired up in the next
# statement) to a user-supplied compiler callback, which must return a callable to run in its
# place. This one only prints the graph's generated code and returns the graph unchanged.
def compiler_fn(fx_module: torch.fx.GraphModule, _) -> torch.fx.GraphModule:
    """Print the Python source FX generated for ``fx_module`` and hand the module back as-is.

    The second positional argument is ignored. It is presumably the example inputs AOT
    Autograd passes to compilers; confirm against the AOT Autograd docs.
    """
    generated_source = fx_module.code
    print(generated_source)
    return fx_module

# Route the traced forward and backward graphs through `compiler_fn` via `aot_module`.
# The module's parameters become graph inputs: in the printed graphs below the weight matrix
# is `primals_1` and the input is `primals_2`.
aot_print_fn = aot_module(model, fw_compiler=compiler_fn, bw_compiler=compiler_fn)

# Alternatives kept for reference (not executed):
#   aot_print_fn = aot_function(model, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
#   with FakeTensorMode(allow_non_fake_inputs=True):
#       res = aot_print_fn(cloned_inputs)

# Drop gradients left in `model.weights.grad` by the eager cell, so the checks below see
# only this backward pass and re-running this cell does not keep accumulating.
model.zero_grad()

# Run once to trigger compilation (which prints the graphs), then backpropagate.
# The input is a detached copy with requires_grad set, so the backward graph also computes
# the gradient with respect to the input.
cloned_inputs = data.detach().clone().requires_grad_(True)
res = aot_print_fn(cloned_inputs)
res.sum().backward()

# The compiled module must agree with the eager math, and its backward must produce the
# analytic gradients of sum(x @ W): d/dx = ones @ W^T and d/dW = x^T @ ones.
layer_weights = model.weights.detach()
upstream = torch.ones_like(res)
assert torch.allclose(res.detach(), data @ layer_weights)
assert torch.allclose(cloned_inputs.grad, upstream @ layer_weights.T)
assert torch.allclose(model.weights.grad, data.T @ upstream)




def forward(self, primals_1, primals_2):
    mm = torch.ops.aten.mm.default(primals_2, primals_1)
    return [mm, primals_1, primals_2]
    



def forward(self, primals_1, primals_2, tangents_1):
    t = torch.ops.aten.t.default(primals_2);  primals_2 = None
    mm_1 = torch.ops.aten.mm.default(t, tangents_1);  t = None
    t_1 = torch.ops.aten.t.default(primals_1);  primals_1 = None
    mm_2 = torch.ops.aten.mm.default(tangents_1, t_1);  tangents_1 = t_1 = None
    return [mm_1, mm_2]
    


