In [2]:
import torch
import torch._dynamo
from torch import nn

In [6]:
class MLP(nn.Module):
    """Tiny demo network: Linear(32 -> 64), then the custom `luis_add` op, then GELU.

    `luis_add` is not a stock PyTorch operator; it comes from the custom build
    this notebook runs on (`torch.__version__` prints 2.1.0a0+gitd70ddad). The
    traced ATen graphs show it keeping the f32[8, 64] shape of its input.
    """

    def __init__(self):
        super().__init__()
        # Single projection layer; the attribute name `fc1` is what appears in
        # the Dynamo graph as `self.L__self___fc1`.
        self.fc1 = nn.Linear(in_features=32, out_features=64)

    def forward(self, x):
        projected = self.fc1(x)
        # Custom op from this PyTorch build, applied to the tensor with itself.
        combined = projected.luis_add(projected)
        return torch.nn.functional.gelu(combined)

# Seed the global RNG so the weight init of `model` and the random `input`
# below (and the tensors printed from them) are reproducible on re-run.
SEED = 0
torch.manual_seed(SEED)

model = MLP()

batch_size = 8
# NOTE: `input` shadows the Python builtin; it is kept because later cells use it.
input = torch.randn(batch_size, 32)

In [7]:
# Pick the first GPU when one is present, otherwise run on the CPU.
# NOTE(review): `device` is not used by any visible cell — `model` and `input`
# stay on the CPU; move them with `.to(device)` if GPU execution is intended.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [8]:
# Show which PyTorch build is running: the `luis_add` op used by `MLP` is not
# part of stock PyTorch, so this notebook needs the matching custom build.
print(torch.__version__)

# Plain eager forward pass, before any `torch.compile` is involved.
out = model(input)

print(out)

2.1.0a0+gitd70ddad
tensor([[ 6.1200e-02,  6.1260e-01,  4.9810e-01, -8.8650e-02, -1.2920e-01,
          2.2570e-02,  2.9933e-01,  1.0666e-01,  3.2522e-01,  9.4256e-01,
          3.9640e-01,  1.2687e+00, -2.3221e-02,  1.3839e-01, -9.4085e-02,
         -4.0286e-02, -1.5851e-01,  4.7221e-01,  4.5965e-01, -1.6994e-01,
          6.4940e-01,  1.3770e+00, -4.4846e-02, -1.1805e-01,  4.3350e-01,
          1.2265e-01, -5.2011e-02, -8.6966e-02,  8.9819e-01,  1.8202e+00,
          4.9735e-01, -3.1347e-02, -1.4244e-01, -7.9711e-02,  2.9998e-01,
         -1.1166e-01, -1.4799e-01,  2.9165e-02,  1.0200e+00,  3.6100e-01,
         -1.2993e-01,  1.2405e+00, -9.4801e-02,  4.6645e-01, -1.5835e-01,
          1.7928e+00,  3.1570e+00, -8.8806e-02,  3.2276e-02, -8.4513e-02,
          2.2140e+00,  2.1544e-02, -4.7591e-02,  1.8657e+00,  2.5779e-01,
         -8.6051e-02,  2.8156e-01,  1.0803e-01, -4.0472e-02,  6.8307e-01,
          3.2200e+00,  7.2199e-01, -1.2213e-01, -8.5889e-04],
        [-1.5169e-01,  2.2543e+

### Invoking `torch.compile` produces an FX graph in Torch IR

In [9]:
def toy_backend(gm, sample_inputs):
    """Minimal Dynamo backend: dump the captured graph and run it unchanged.

    Dynamo calls this with the captured FX GraphModule and its example inputs;
    returning `gm.forward` means the graph is executed as-is (eagerly).
    """
    banner = "Dynamo produced a fx Graph in Torch IR:"
    print(banner)
    gm.print_readable()

    # NOTE(review): in the captured output for this build the printed values are
    # concrete tensors, so these may be real example inputs rather than
    # FakeTensors — confirm for the torch version in use.
    print("Notice that sample_inputs is a list of flattened FakeTensor:")
    print(sample_inputs)

    run_graph = gm.forward
    return run_graph

# Clear Dynamo's compile caches so the backend above is invoked again on re-run.
torch._dynamo.reset()
# dynamic=True requests dynamic-shape compilation (see the torch.compile docs).
cmodel = torch.compile(model, backend=toy_backend, dynamic=True)

# triggers compilation of forward graph on the first run
out = cmodel(input)

Dynamo produced a fx Graph in Torch IR:
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_
        
        # File: /tmp/ipykernel_7672/678154743.py:7, code: x = self.fc1(x)
        l__self___fc1 = self.L__self___fc1(l_x_);  l_x_ = None
        
        # File: /tmp/ipykernel_7672/678154743.py:8, code: x = x.luis_add(x)
        luis_add = l__self___fc1.luis_add(l__self___fc1);  l__self___fc1 = None
        
        # File: /tmp/ipykernel_7672/678154743.py:9, code: x = torch.nn.functional.gelu(x)
        gelu = torch._C._nn.gelu(luis_add);  luis_add = None
        return (gelu,)
        
Notice that sample_inputs is a list of flattened FakeTensor:
[tensor([[-0.1829,  0.5798, -0.9840,  0.8186, -1.6522,  0.4352, -0.6053,  0.1342,
          0.8324,  0.3900, -1.8516,  0.2012,  0.3902, -1.2492,  1.0215,  1.0482,
          2.3499, -1.7583,  0.9609,  0.8368,  1.7870,  0.6308,  0.9022, -0.9496,
         -0.4872,  1.0258,  1.4198,  0.0269,  0.4601

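## Invoking AOTAutograd produces forward + backward FX graphs in ATen IR
* Captures both the forward and the backward graph (only the forward graph is printed below, because no `.backward()` call is made)
* Lowers from Torch IR to ATen/Prims IR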

### Core ATen IR (https://pytorch.org/docs/master/ir.html#core-aten-ir)

* A strict subset of ATen operators (fewer than 250) that remains after decompositions
* Purely functional (no input mutations)
* Guaranteed metadata, e.g. dtype and shape propagation

In [10]:
import torch._dynamo
from torch._functorch.aot_autograd import aot_module_simplified

def toy_backend(gm, sample_inputs):
    """Dynamo backend that routes the captured graph through AOTAutograd.

    AOTAutograd lowers the Torch-IR graph to ATen IR and hands the forward
    graph to the inner `my_compiler`. No `bw_compiler` is passed; presumably
    `aot_module_simplified` then reuses `fw_compiler` for the backward graph
    (its default) — confirm for the torch version in use.
    """
    def my_compiler(aten_gm, aten_inputs):
        # <implement your compiler here>
        print("AOTAutograd produced a fx Graph in Aten IR:")
        aten_gm.print_readable()
        # Run the ATen graph as-is.
        return aten_gm.forward

    # Invoke AOTAutograd
    return aot_module_simplified(gm, sample_inputs, fw_compiler=my_compiler)

# Clear Dynamo's compile caches so the new `toy_backend` is actually used.
torch._dynamo.reset()
cmodel = torch.compile(model, backend=toy_backend, dynamic=True)

# triggers compilation of forward graph on the first run
# (the output shows only the forward graph; no `.backward()` is run in this cell)
out = cmodel(input)

AOTAutograd produced a fx Graph in Aten IR:
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[64, 32], primals_2: f32[64], primals_3: f32[8, 32]):
        # File: /tmp/ipykernel_7672/678154743.py:7, code: x = self.fc1(x)
        t: f32[32, 64] = torch.ops.aten.t.default(primals_1);  primals_1 = None
        addmm: f32[8, 64] = torch.ops.aten.addmm.default(primals_2, primals_3, t);  primals_2 = t = None
        
        # File: /tmp/ipykernel_7672/678154743.py:8, code: x = x.luis_add(x)
        luis_add: f32[8, 64] = torch.ops.aten.luis_add.Tensor(addmm, addmm);  addmm = None
        
        # File: /tmp/ipykernel_7672/678154743.py:9, code: x = torch.nn.functional.gelu(x)
        gelu: f32[8, 64] = torch.ops.aten.gelu.default(luis_add)
        return [gelu, primals_3, luis_add]
        




In [11]:
from torch._inductor.decomposition import decompositions as default_decompositions

# Start from Inductor's default decomposition table; copy it so the table can be
# extended here without mutating Inductor's own module-level dict.
decompositions = default_decompositions.copy()

def toy_backend(gm, sample_inputs):
    """Dynamo backend: AOTAutograd plus Inductor's decomposition table.

    Passing `decompositions` makes AOTAutograd break composite ops into simpler
    ATen ops before the forward graph reaches `my_compiler` (in the output,
    `gelu` becomes `mul`/`erf`/`add`).
    """
    def my_compiler(aten_gm, aten_inputs):
        # <implement your compiler here>
        print("Decomposed fx Graph in Aten IR:")
        aten_gm.print_readable()
        # A GraphModule is itself callable, so it can be returned directly.
        return aten_gm

    # Invoke AOTAutograd. `decompositions` is the module-level table, looked up
    # when this backend is called (not when it is defined).
    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=my_compiler,
        decompositions=decompositions,
    )

# Clear Dynamo's compile caches so the decomposing `toy_backend` is actually used.
torch._dynamo.reset()
cmodel = torch.compile(model, backend=toy_backend, dynamic=True)

# triggers compilation of forward graph on the first run
out = cmodel(input)

Decomposed fx Graph in Aten IR:
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[64, 32], primals_2: f32[64], primals_3: f32[8, 32]):
        # File: /tmp/ipykernel_7672/678154743.py:7, code: x = self.fc1(x)
        permute: f32[32, 64] = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        addmm: f32[8, 64] = torch.ops.aten.addmm.default(primals_2, primals_3, permute);  primals_2 = permute = None
        
        # File: /tmp/ipykernel_7672/678154743.py:8, code: x = x.luis_add(x)
        luis_add: f32[8, 64] = torch.ops.aten.luis_add.Tensor(addmm, addmm);  addmm = None
        
        # File: /tmp/ipykernel_7672/678154743.py:9, code: x = torch.nn.functional.gelu(x)
        mul: f32[8, 64] = torch.ops.aten.mul.Tensor(luis_add, 0.5)
        mul_1: f32[8, 64] = torch.ops.aten.mul.Tensor(luis_add, 0.7071067811865476)
        erf: f32[8, 64] = torch.ops.aten.erf.default(mul_1);  mul_1 = None
        add: f32[8, 64] = torch.ops.aten.add.Tenso



### Prims IR (https://pytorch.org/docs/master/ir.html#prims-ir)

* Explicit type promotion and broadcasting
* `prims.convert_element_type` (explicit type promotion)
* `prims.broadcast_in_dim` (explicit broadcasting)
* Intended for backends with a powerful compiler that can recover the performance through fusion, e.g. nvFuser

In [15]:
# Small decomposition table for the Prims IR demo: decompose the custom
# `luis_add` op and `aten.expand`. In the output below, `luis_add` then lowers to
# `prims.broadcast_in_dim` + `prims.luis_add`, with the f16 argument cast via
# `aten._to_copy`.
prims_decomp = torch._decomp.get_decompositions([
    torch.ops.aten.luis_add,
    torch.ops.aten.expand.default,
])

def fn(a, b):
    """Apply the custom `luis_add` op to `a` and `b` (called below with mixed dtypes and shapes)."""
    return a.luis_add(b)

def toy_backend(gm, sample_inputs):
    """Dynamo backend: AOTAutograd with the small `prims_decomp` table.

    Decomposing `luis_add` and `aten.expand` surfaces Prims-level ops
    (e.g. `prims.broadcast_in_dim`) in the forward graph given to `my_compiler`.
    """
    def my_compiler(aten_gm, aten_inputs):
        # <implement your compiler here>
        print("Further decomposed fx Graph in Prims IR:")
        aten_gm.print_readable()
        # A GraphModule is itself callable, so it can be returned directly.
        return aten_gm

    # Invoke AOTAutograd; `prims_decomp` is read from the notebook's globals at call time.
    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=my_compiler,
        decompositions=prims_decomp,
    )

torch._dynamo.reset()
# Apply torch.compile in decorator form by hand; rebinding `fn` means later
# calls go through Dynamo and the Prims-decomposing backend.
fn = torch.compile(backend=toy_backend)(fn)
# float32 [3] with float16 [3, 3]: exercises both type promotion and broadcasting.
out = fn(torch.rand(3, dtype=torch.float), torch.rand(3, 3, dtype=torch.half))

Further decomposed fx Graph in Prims IR:
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: f32[3], arg1_1: f16[3, 3]):
        # File: /tmp/ipykernel_2581/2846203934.py:7, code: return a.luis_add(b)
        _to_copy: f32[3, 3] = torch.ops.aten._to_copy.default(arg1_1, dtype = torch.float32);  arg1_1 = None
        broadcast_in_dim: f32[3, 3] = torch.ops.prims.broadcast_in_dim.default(arg0_1, [3, 3], [1]);  arg0_1 = None
        luis_add: f32[3, 3] = torch.ops.prims.luis_add.default(broadcast_in_dim, _to_copy);  broadcast_in_dim = _to_copy = None
        return (luis_add,)
        


In [31]:
from typing import List
import torch
from torch._dynamo.backends.registry import _BACKENDS as BACKENDS
from torch._decomp import core_aten_decompositions
from torch._functorch.aot_autograd import aot_module_simplified

# List the Dynamo backend names registered in this build (output below); these
# are the names the fallback compilers in this cell try.
print(torch._dynamo.list_backends())

def _compile_with_fallback(gm, example_inputs, backends, announce=False):
    """Compile `gm` with the first backend in `backends` that succeeds.

    Args:
        gm: the FX GraphModule to compile.
        example_inputs: example inputs forwarded to each backend.
        backends: Dynamo backend names (str, resolved via `lookup_backend`) or
            compiler callables taking `(gm, example_inputs)`, tried in order.
        announce: if True, print which backend was chosen.

    Returns:
        The first non-None compiled callable. If every backend raises or
        returns None, prints the graph and returns `gm.forward` (eager).
    """
    for backend in backends:
        if isinstance(backend, str):
            label = backend
        else:
            label = getattr(backend, "__name__", repr(backend))
        try:
            if isinstance(backend, str):
                # Resolve names through the registry API instead of indexing the
                # private `_BACKENDS` dict. NOTE: in newer PyTorch releases that
                # dict holds entry points (or None) rather than compiler
                # functions, so indexing it makes every backend fail and silently
                # fall back to eager — confirm for the torch version in use.
                from torch._dynamo.backends.registry import lookup_backend
                compiler_fn = lookup_backend(backend)
            else:
                compiler_fn = backend
            compiled = compiler_fn(gm, example_inputs)
        except Exception as e:
            # Best effort: report the failure and try the next backend.
            print(f'Failed to compile with backend={label}, err={e}')
            continue
        if compiled is not None:
            if announce:
                print(f'Going with {label}...')
            return compiled

    # Every backend failed: show the graph and run it eagerly.
    gm.print_readable()
    return gm.forward


def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor],
                backends=('cudagraphs', 'inductor')):
    """Top-level Dynamo backend: try cudagraphs, then inductor, else run eagerly.

    Usable directly as `torch.compile(backend=my_compiler)`. `backends` may be
    overridden with other backend names or compiler callables.
    """
    return _compile_with_fallback(gm, example_inputs, backends)


# Decomposition table for toy_backend_2: the core ATen decompositions plus
# `aten.addmm`. NOTE: this rebinds the module-level name `decompositions`, which
# the earlier Inductor-decomposition cell also defines and reads at call time.
decompositions = core_aten_decompositions()
decompositions.update(
    torch._decomp.get_decompositions([
        torch.ops.aten.addmm,
    ])
)


def toy_backend_2(gm, sample_inputs):
    """Dynamo backend: lower through AOTAutograd, then try onnxrt, then inductor.

    AOTAutograd applies `decompositions` and passes the resulting ATen forward
    graph to `my_compiler_2`, which tries each backend in order and falls back
    to eager execution if none succeeds.
    """
    def my_compiler_2(aten_gm, aten_inputs):
        # NOTE(review): in the captured output, onnxrt rejects these ATen graphs
        # ("attribute lookup is not defined on builtin"), so inductor is used.
        return _compile_with_fallback(
            aten_gm, aten_inputs, ('onnxrt', 'inductor'), announce=True
        )

    # Invoke AOTAutograd
    return aot_module_simplified(
        gm,
        sample_inputs,
        decompositions=decompositions,
        fw_compiler=my_compiler_2
    )


['aot_ts_nvfuser', 'cudagraphs', 'inductor', 'ipex', 'nvprims_nvfuser', 'onnxrt', 'tvm']


In [32]:
torch._dynamo.reset()

# Seed the RNG so the random inputs, and therefore the printed result, are
# reproducible when the notebook is re-run.
torch.manual_seed(0)


# Use `backend=my_compiler` instead to skip AOTAutograd and try cudagraphs/inductor directly.
@torch.compile(backend=toy_backend_2)
def toy_example(a, b):
    """Toy function with a branch that depends on the value of `b.sum()`.

    The branch cannot be captured in a single graph, so Dynamo compiles several
    smaller graphs (the output below shows separate backend invocations for the
    `abs/add/div` part and for each `mul`).
    """
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


a = torch.randn(10)
for _ in range(100):
    a += toy_example(torch.randn(10), torch.randn(10))

print(a)

Failed to compile with backend=onnxrt, err=
attribute lookup is not defined on builtin:
  File "<eval_with_key>.726", line 5
def forward(self, arg0_1, arg1_1):
    abs_1 = torch.ops.aten.abs.default(arg0_1)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    add = torch.ops.aten.add.Tensor(abs_1, 1);  abs_1 = None
    div = torch.ops.aten.div.Tensor(arg0_1, add);  arg0_1 = add = None

Going with inductor...
Failed to compile with backend=onnxrt, err=
attribute lookup is not defined on builtin:
  File "<eval_with_key>.733", line 5
def forward(self, arg0_1, arg1_1):
    mul = torch.ops.aten.mul.Tensor(arg1_1, arg0_1);  arg1_1 = arg0_1 = None
          ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    return (mul,)

Going with inductor...
Failed to compile with backend=onnxrt, err=
attribute lookup is not defined on builtin:
  File "<eval_with_key>.740", line 5
def forward(self, arg0_1, arg1_1):
    mul = torch.ops.aten.mul.Tensor(arg0_1, -1);  arg0_1 = None
          ~~~~~~~~~~~~~~~~~~~~~~~~~ <--

