In [1]:
# standard imports
import torch
from shark.iree_utils import get_iree_compiled_module

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# torch dynamo related imports
import torchdynamo
from torchdynamo.optimizations.backends import create_backend
from torchdynamo.optimizations.subgraph import SubGraph

# torch-mlir imports for compiling
from torch_mlir import compile, OutputType

[TorchDynamo](https://github.com/pytorch/torchdynamo) is a compiler for PyTorch programs that uses the [frame evaluation API](https://www.python.org/dev/peps/pep-0523/) in CPython to dynamically modify Python bytecode right before it is executed. It creates an FX graph through bytecode analysis and is designed to mix Python execution with compiled backends.

In [3]:
def toy_example(*args):
    """Small function with a data-dependent branch, used to exercise TorchDynamo.

    Args:
        *args: Exactly two tensors, ``(a, b)``.

    Returns:
        ``a / (|a| + 1)`` multiplied element-wise by ``b``, where ``b`` is
        negated first if its sum is negative.
    """
    lhs, rhs = args

    # Squash the first operand: lhs / (|lhs| + 1).
    squashed = lhs / (torch.abs(lhs) + 1)

    # Branch on a runtime tensor value; flip the sign when the sum is negative.
    if rhs.sum() < 0:
        rhs = rhs * -1

    return squashed * rhs

In [4]:
# Compiler callback that lowers a TorchDynamo FX graph through MLIR.
def __torch_mlir(fx_graph, *args, **kwargs):
    """Compile an FX graph via torch-mlir and return a function that runs it.

    Args:
        fx_graph: The ``torch.fx.GraphModule`` to compile. It is modified in
            place: a one-element tuple return is replaced by its element and
            the module is recompiled.
        *args: Example inputs, given either as separate positional arguments
            or as a single list holding all of them.
        **kwargs: Optional ``device`` (default ``"cuda"``), passed as the
            target to ``get_iree_compiled_module``. Other keys are ignored.

    Returns:
        A function ``forward(*inputs)`` that calls the compiled module's
        ``"forward"`` entry point with ``inputs``.

    Raises:
        AssertionError: If ``fx_graph`` is not a ``torch.fx.GraphModule`` or
            an output node does not have exactly one argument.
    """
    assert isinstance(
        fx_graph, torch.fx.GraphModule
    ), "Model must be an FX GraphModule."

    def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule):
        """Replace tuple with tuple element in functions that return one-element tuples."""
        for node in fx_g.graph.nodes:
            if node.op == "output":
                assert len(node.args) == 1, "Output node must have a single argument"
                node_arg = node.args[0]
                if isinstance(node_arg, tuple) and len(node_arg) == 1:
                    node.args = (node_arg[0],)
        # Validate the edited graph and regenerate the module's code from it.
        fx_g.graph.lint()
        fx_g.recompile()
        return fx_g

    fx_graph = _unwrap_single_tuple_return(fx_graph)
    ts_graph = torch.jit.script(fx_graph)

    # torchdynamo munges the args differently depending on whether you use
    # the @torchdynamo.optimize decorator or the context manager, so accept
    # both shapes: inputs passed separately, or one list holding all of them.
    example_inputs = list(args)
    if len(example_inputs) == 1 and isinstance(example_inputs[0], list):
        example_inputs = example_inputs[0]

    # Target device is overridable by keyword; "cuda" is the original default.
    device = kwargs.get("device", "cuda")

    linalg_module = compile(
        ts_graph, example_inputs, output_type=OutputType.LINALG_ON_TENSORS
    )
    # The second returned value is unused here.
    compiled_forward, _ = get_iree_compiled_module(
        linalg_module, device, func_name="forward"
    )

    def forward(*inputs):
        return compiled_forward(*inputs)

    return forward

The simplest way to use TorchDynamo is with the `torchdynamo.optimize` context manager:

In [5]:
# Run `toy_example` inside the TorchDynamo context manager, passing
# `__torch_mlir` as the compiler for the graphs it captures.
with torchdynamo.optimize(__torch_mlir):
    for _ in range(10):
        # Fresh random 10-element inputs on every iteration; no seed is set,
        # so the printed values differ from run to run.
        print(toy_example(torch.randn(10), torch.randn(10)))

Found 1 device(s).
Device: 0
  Name: NVIDIA GeForce RTX 3080
  Compute Capability: 8.6
[ 0.40569967 -0.09795299  0.15944454 -0.11522183  0.13940483  0.6483943
  0.04897427  0.20021795  0.4110793  -0.01060459]
[-2.36014053e-01  1.02099776e-01 -8.32196176e-02  5.48950136e-01
 -1.22762606e-01  6.83019171e-05 -1.87891126e-01  2.95851409e-01
 -7.94005573e-01 -7.86187351e-02]
[-0.08547013 -0.03790672 -0.67750883  0.07134506  0.48344284 -0.04401336
  0.5358189   0.19252774  0.01672608  0.16548733]
[ 0.16950132 -0.14072983  0.0850194   0.51586574  0.6814878   0.09228899
  0.00628967  0.04618661  0.33402687  0.0672036 ]
[-3.0006915e-01 -5.6649814e-03  1.0971012e-02  6.7839026e-01
  1.4477329e-01  7.1921291e-05 -1.2694100e-01 -1.0598335e-01
  4.5776103e-02 -3.7474141e-02]
[ 0.1161904   0.11104004  0.03108321 -0.01897361 -0.2773486  -0.1210255
 -0.10480757  0.15325065  0.07355037  0.43414077]
[-0.15983443  0.18079512 -0.05479247  0.06110435  0.12209348  0.12046977
  0.20978567 -0.43570745 -0.9095

It can also be used through a decorator:

In [6]:
@create_backend
def torch_mlir(subgraph, *args, **kwargs):
    """TorchDynamo backend that compiles a subgraph with ``__torch_mlir``.

    Args:
        subgraph: A TorchDynamo ``SubGraph``; its ``model`` and
            ``example_inputs`` attributes are handed to the compiler.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        Whatever ``__torch_mlir`` returns for the subgraph's model.

    Raises:
        AssertionError: If ``subgraph`` is not a ``SubGraph``.
    """
    # NOTE(review): `create_backend` presumably registers this function under
    # its own name so it can be selected by string; confirm against torchdynamo.
    assert isinstance(subgraph, SubGraph), "Model must be a dynamo SubGraph."

    graph_module = subgraph.model
    example_inputs = list(subgraph.example_inputs)
    return __torch_mlir(graph_module, *example_inputs)

@torchdynamo.optimize("torch_mlir")
def toy_example2(*args):
    """Decorator-based variant of the toy function with a data-dependent branch.

    Args:
        *args: Exactly two tensors, ``(a, b)``.

    Returns:
        ``a / (|a| + 1)`` multiplied element-wise by ``b``, where ``b`` is
        negated first if its sum is negative.
    """
    lhs, rhs = args

    # Squash the first operand: lhs / (|lhs| + 1).
    squashed = lhs / (torch.abs(lhs) + 1)

    # Branch on a runtime tensor value; flip the sign when the sum is negative.
    if rhs.sum() < 0:
        rhs = rhs * -1

    return squashed * rhs

In [7]:
# Call the decorated function ten times with fresh random 10-element inputs.
# No seed is set, so the printed values differ from run to run.
for _ in range(10):
    print(toy_example2(torch.randn(10), torch.randn(10)))

Found 1 device(s).
Device: 0
  Name: NVIDIA GeForce RTX 3080
  Compute Capability: 8.6
[-0.18768834 -0.13050991  0.16192573 -0.08606607 -0.7383352   0.21919324
  0.1471572   0.1957912   0.78911537  0.3079384 ]
[-0.6180763   0.3083624   0.30140203 -0.13603576  0.13938724  0.06829462
 -0.5264937   0.47425482  0.631954    0.07812975]
[ 0.02927845 -0.21328457  0.03714288 -0.04158077  0.00811315  0.06475347
  0.16807333 -0.2835701   0.07122174 -0.25891435]
[-0.7808262  -0.00184676 -0.42828155  0.02376047  0.23778288 -0.2332218
  0.35119227 -0.45859754 -0.16244853  0.08230756]
[ 0.65825784 -0.43039966  0.49089798  0.16756855  0.17000133  0.1523097
  0.00477562  0.05351321 -0.16297375  0.42369154]
[-0.35411894  0.34467864  0.19818862 -0.26733887 -0.36235648  1.570275
  0.08005163 -0.00406713 -0.7041876   0.2678179 ]
[-0.8275841   0.01846866  0.27031392  0.07428868 -0.3472132   0.72673005
  0.04936812 -0.0711254  -0.15575752 -0.1963107 ]
[0.2804986  0.2999007  0.59554356 0.00538198 0.06104416 