In [1]:
import torch
import torch._dynamo
from torch import nn

In [2]:
class MLP(nn.Module):
    """Toy network: one linear layer followed by a GELU activation.

    Args:
        in_features: Size of the last dimension of the input. Defaults to 32.
        out_features: Size of the last dimension of the output. Defaults to 64.
    """

    def __init__(self, in_features=32, out_features=64):
        super().__init__()
        # Layer sizes are parameters (with the original values as defaults)
        # so the same class can be reused for other input/output widths.
        self.fc1 = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply ``fc1`` to ``x``, then GELU, and return the result."""
        x = self.fc1(x)
        x = torch.nn.functional.gelu(x)
        return x

# Instantiate the toy model with its default constructor arguments.
model = MLP()

batch_size = 8
# Random batch of shape (batch_size, 32).
# NOTE(review): `input` shadows the Python builtin of the same name; the name is
# kept because later cells pass `input` to the compiled model.
input = torch.randn(batch_size, 32)

In [3]:
# Pick the first CUDA device when one is available, otherwise the CPU.
# NOTE(review): `device` is not referenced by any other cell visible here;
# neither the model nor the input is moved to it.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

### Invoking `torch.compile` produces an FX graph in Torch IR

In [4]:
def toy_backend(gm, sample_inputs):
    """Backend for ``torch.compile`` that only shows what it was given.

    Prints a header, asks ``gm`` to print itself, prints ``sample_inputs``,
    and returns ``gm.forward`` unchanged as the callable to run.
    """
    print("Dynamo produced a fx Graph in Torch IR:")
    gm.print_readable()

    # Header and inputs are emitted on two separate lines.
    print(
        "Notice that sample_inputs is a list of flattened FakeTensor:",
        sample_inputs,
        sep="\n",
    )

    unmodified_forward = gm.forward
    return unmodified_forward

# Reset Dynamo before compiling with the `toy_backend` defined in this cell
# (presumably this discards anything compiled earlier — confirm against the torch._dynamo docs).
torch._dynamo.reset()
# dynamic=True: the recorded output of this cell shows a symbolic batch dimension (s0).
cmodel = torch.compile(model, backend=toy_backend, dynamic=True)

# triggers compilation of forward graph on the first run
out = cmodel(input)

Dynamo produced a fx Graph in Torch IR:
class GraphModule(torch.nn.Module):
    def forward(self, x : torch.Tensor):
        # File: /tmp/ipykernel_8670/1490842273.py:7, code: x = self.fc1(x)
        self_fc1 = self.self_fc1(x);  x = None
        
        # File: /tmp/ipykernel_8670/1490842273.py:8, code: x = torch.nn.functional.gelu(x)
        gelu = torch._C._nn.gelu(self_fc1);  self_fc1 = None
        return (gelu,)
        
Notice that sample_inputs is a list of flattened FakeTensor:
[FakeTensor(FakeTensor(..., device='meta', size=(s0, 32)), cpu)]


## Invoking AOTAutograd produces forward + backward FX graphs in Aten IR
* Captures forward + backward
* Lowering from Torch IR to Aten/Prims IR

### Core Aten IR (https://pytorch.org/docs/master/ir.html#core-aten-ir)

* A strict subset of aten operators (< 250) after decompositions
* Purely functional (no input mutations)
* Guaranteed metadata information, e.g. dtype and shape propagation

In [5]:
import torch._dynamo
from torch._functorch.aot_autograd import aot_module_simplified

def toy_backend(gm, sample_inputs):
    """Backend for ``torch.compile`` that hands the graph to AOTAutograd.

    The forward compiler given to ``aot_module_simplified`` prints the graph
    module it receives and returns that module's ``forward`` unchanged.
    """

    def print_and_passthrough(aten_gm, aten_inputs):
        # <implement your compiler here>
        print("AOTAutograd produced a fx Graph in Aten IR:")
        aten_gm.print_readable()
        passthrough = aten_gm.forward
        return passthrough

    # Invoke AOTAutograd
    compiled = aot_module_simplified(
        gm, sample_inputs, fw_compiler=print_and_passthrough
    )
    return compiled

# Reset Dynamo before compiling with the AOTAutograd-based `toy_backend` from this cell
# (presumably this discards anything compiled earlier — confirm against the torch._dynamo docs).
torch._dynamo.reset()
# dynamic=True: the recorded output of this cell shows a symbolic batch dimension (s0).
cmodel = torch.compile(model, backend=toy_backend, dynamic=True)

# triggers compilation of forward graph on the first run
out = cmodel(input)

AOTAutograd produced a fx Graph in Aten IR:
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[64, 32], primals_2: f32[64], primals_3: f32[s0, 32]):
        # File: /tmp/ipykernel_8670/1490842273.py:7, code: x = self.fc1(x)
        t: f32[32, 64] = torch.ops.aten.t.default(primals_1);  primals_1 = None
        addmm: f32[s0, 64] = torch.ops.aten.addmm.default(primals_2, primals_3, t);  primals_2 = t = None
        
        # File: /tmp/ipykernel_8670/1490842273.py:8, code: x = torch.nn.functional.gelu(x)
        gelu: f32[s0, 64] = torch.ops.aten.gelu.default(addmm)
        return [gelu, addmm, primals_3]
        




In [6]:
from torch._inductor.decomposition import decompositions as default_decompositions

# Take a copy of the imported default table, so that changes made to
# `decompositions` do not touch the imported object itself.
decompositions = default_decompositions.copy()

def toy_backend(gm, sample_inputs):
    """Backend for ``torch.compile``: AOTAutograd plus a decomposition table.

    ``aot_module_simplified`` is called with the module-level
    ``decompositions`` table; the forward compiler prints the graph module it
    receives and returns that graph module itself as the callable to run.
    """

    def print_decomposed(decomposed_gm, decomposed_inputs):
        # <implement your compiler here>
        print("Decomposed fx Graph in Aten IR:")
        decomposed_gm.print_readable()
        return decomposed_gm

    # Invoke AOTAutograd
    aot_kwargs = dict(
        decompositions=decompositions,
        fw_compiler=print_decomposed,
    )
    return aot_module_simplified(gm, sample_inputs, **aot_kwargs)

# Reset Dynamo before compiling with the decomposition-aware `toy_backend` from this cell
# (presumably this discards anything compiled earlier — confirm against the torch._dynamo docs).
torch._dynamo.reset()
# dynamic=True: the recorded output of this cell shows a symbolic batch dimension (s0).
cmodel = torch.compile(model, backend=toy_backend, dynamic=True)

# triggers compilation of forward graph on the first run
out = cmodel(input)

Decomposed fx Graph in Aten IR:
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[64, 32], primals_2: f32[64], primals_3: f32[s0, 32]):
        # File: /tmp/ipykernel_8670/1490842273.py:7, code: x = self.fc1(x)
        permute: f32[32, 64] = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        addmm: f32[s0, 64] = torch.ops.aten.addmm.default(primals_2, primals_3, permute);  primals_2 = permute = None
        
        # File: /tmp/ipykernel_8670/1490842273.py:8, code: x = torch.nn.functional.gelu(x)
        mul: f32[s0, 64] = torch.ops.aten.mul.Tensor(addmm, 0.5)
        mul_1: f32[s0, 64] = torch.ops.aten.mul.Tensor(addmm, 0.7071067811865476)
        erf: f32[s0, 64] = torch.ops.aten.erf.default(mul_1);  mul_1 = None
        add: f32[s0, 64] = torch.ops.aten.add.Tensor(erf, 1);  erf = None
        mul_2: f32[s0, 64] = torch.ops.aten.mul.Tensor(mul, add);  mul = add = None
        return [mul_2, addmm, primals_3]
        




### Prims IR (https://pytorch.org/docs/master/ir.html#prims-ir)

* Explicit type promotion and broadcasting
* prims.convert_element_type
* prims.broadcast_in_dim
* For backends with a powerful compiler that can reclaim the performance through fusion, e.g. nvFuser

In [7]:
# Request decompositions only for `aten.add` and `aten.expand.default`.
# The recorded output of this cell shows the addition lowered to
# `prims.broadcast_in_dim` followed by `prims.add`.
prims_decomp = torch._decomp.get_decompositions([
    torch.ops.aten.add,
    torch.ops.aten.expand.default,
])

def fn(a, b):
    """Return ``a + b``."""
    total = a + b
    return total

def toy_backend(gm, sample_inputs):
    """Backend for ``torch.compile``: AOTAutograd with the ``prims_decomp`` table.

    ``aot_module_simplified`` is called with the module-level ``prims_decomp``
    table; the forward compiler prints the graph module it receives and
    returns that graph module itself as the callable to run.
    """

    def print_prims_graph(prims_gm, prims_inputs):
        # <implement your compiler here>
        print("Further decomposed fx Graph in Prims IR:")
        prims_gm.print_readable()
        return prims_gm

    # Invoke AOTAutograd
    compiled = aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=print_prims_graph,
        decompositions=prims_decomp,
    )
    return compiled

# Reset Dynamo before compiling with the Prims-IR `toy_backend` from this cell
# (presumably this discards anything compiled earlier — confirm against the torch._dynamo docs).
torch._dynamo.reset()
# `torch.compile(backend=...)` is called without a function and its result is
# applied to `fn`, which rebinds the name `fn` to the compiled callable.
fn = torch.compile(backend=toy_backend)(fn)
# Inputs differ in shape and dtype; the recorded output shows the f16 [3, 3]
# argument converted to f32 and the f32 [3] argument broadcast to [3, 3].
out = fn(torch.rand(3, dtype=torch.float), torch.rand(3, 3, dtype=torch.half))

Further decomposed fx Graph in Prims IR:
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: f32[3], arg1_1: f16[3, 3]):
        # File: /tmp/ipykernel_8670/2178452752.py:7, code: return a + b
        _to_copy: f32[3, 3] = torch.ops.aten._to_copy.default(arg1_1, dtype = torch.float32);  arg1_1 = None
        broadcast_in_dim: f32[3, 3] = torch.ops.prims.broadcast_in_dim.default(arg0_1, [3, 3], [1]);  arg0_1 = None
        add: f32[3, 3] = torch.ops.prims.add.default(broadcast_in_dim, _to_copy);  broadcast_in_dim = _to_copy = None
        return (add,)
        
