In [1]:
import torchdynamo
import torch
from utils import make_torch_mlir_compiler
import torch_mlir
import iree_torch

In [2]:
# Keep the demo output readable: ignore Python warnings and restrict
# TorchDynamo's log level to errors.
import logging
import warnings

warnings.simplefilter("ignore")
torchdynamo.config.log_level = logging.ERROR

# Current Steps for Compiling in Torch-MLIR

- Create a `torch.nn.Module`
- Ensure that module is scriptable or traceable (could require code changes)
- Compile module using `torch_mlir.compile` + `iree_torch.compile_to_vmfb`

In [3]:
class MyModule(torch.nn.Module):
    """Minimal module whose forward pass doubles its input."""

    # No __init__ override: the module defines no parameters or buffers, so
    # the constructor inherited from torch.nn.Module is sufficient.

    def forward(self, t):
        """Return ``2 * t``."""
        return 2 * t

In [4]:
# Random 2x3 tensor used as the sample input in the following cells.
example_input = torch.rand(2, 3)
example_input

tensor([[0.5207, 0.9647, 0.7091],
        [0.8046, 0.4800, 0.7829]])

In [5]:
# Step 1: compile the module with Torch-MLIR, requesting linalg-on-tensors
# output; the example input is passed alongside the module.
myModule = MyModule()
linalg_module = torch_mlir.compile(
    myModule,
    example_input,
    output_type="linalg-on-tensors",
)

# Step 2: hand the result to IREE (compile to VMFB, load it), then run the
# loaded module's `forward` on the same input.
compiled_module = iree_torch.compile_to_vmfb(linalg_module)
loaded_module = iree_torch.load_vmfb(compiled_module)
loaded_module.forward(example_input)

tensor([[1.0413, 1.9294, 1.4181],
        [1.6091, 0.9600, 1.5658]])

## Limitations

- Input to `torch_mlir.compile` must be a `torch.nn.Module`
- Module must be scriptable or traceable
- Torch-MLIR is expected to support all of TorchScript (loops, control flow, etc.)

Note: TorchScript and Torch-MLIR do support single-function workloads, but doing so requires a different path in Torch-MLIR, which the API currently does not support.

# Steps using TorchDynamo

- Add the `torchdynamo.optimize` decorator to the code you want compiled (not limited to `torch.nn.Module`s)

In [6]:
# Reset TorchDynamo before defining the decorated function.
torchdynamo.reset()


@torchdynamo.optimize(
    make_torch_mlir_compiler(use_tracing=False, device="cpu", verbose=True)
)
def foo(t):
    """Return the input scaled by two."""
    return 2 * t

In [7]:
# Calling the decorated function prints the verbose compiler output
# (torch.fx graph and Torch-MLIR module) before the result is displayed.
foo(example_input)

Compiling graph...
torch.fx graph:
graph():
    %arg0_1 : [#users=1] = placeholder[target=arg0_1]
    %mul : [#users=1] = call_function[target=torch.ops.aten.mul](args = (%arg0_1, 2), kwargs = {})
    return mul


torch-mlir backend contract graph:
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
    %int2 = torch.constant.int 2
    %0 = torch.aten.mul.Scalar %arg0, %int2 : !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
    return %0 : !torch.vtensor<[2,3],f32>
  }
}



tensor([[1.0413, 1.9294, 1.4181],
        [1.6091, 0.9600, 1.5658]])

## Graph Breaks

Graph breaks make it possible to run modules and functions that mix code meant to run on the backend with code meant to run at the Python level. This means Torch-MLIR does not have to handle things like data-dependent control flow.

### Example 1: Print statements

In [8]:
# Reset TorchDynamo before redefining the decorated function.
torchdynamo.reset()


@torchdynamo.optimize(
    make_torch_mlir_compiler(use_tracing=False, device="cpu", verbose=True)
)
def foo(a, b):
    """Print a greeting, then return the sum of ``a`` and ``b``."""
    print("Hello!")
    return a + b

In [9]:
# "Hello!" is printed ahead of the compiler output, and the graph that gets
# compiled contains only the addition.
foo(torch.rand((2, 3)), torch.rand((2, 3)))

Hello!
Compiling graph...
torch.fx graph:
graph():
    %arg0_1 : [#users=1] = placeholder[target=arg0_1]
    %arg1_1 : [#users=1] = placeholder[target=arg1_1]
    %add : [#users=1] = call_function[target=torch.ops.aten.add](args = (%arg0_1, %arg1_1), kwargs = {})
    return add


torch-mlir backend contract graph:
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
    %int1 = torch.constant.int 1
    %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
    return %0 : !torch.vtensor<[2,3],f32>
  }
}



tensor([[1.5328, 1.4332, 1.1009],
        [1.0809, 0.7956, 0.9636]])

A function containing a `print` statement, like the one above, would not work in the current Torch-MLIR compilation flow.

### Example 2: Control flow

In [10]:
# Reset TorchDynamo before redefining the decorated function.
torchdynamo.reset()


@torchdynamo.optimize(
    make_torch_mlir_compiler(use_tracing=False, device="cpu", verbose=True)
)
def foo(a, b):
    """Return ``a / (a + 1)`` multiplied by ``b``, flipping the sign of ``b``
    first when its sum is negative."""
    ratio = a / (a + 1)
    # The branch taken depends on the runtime values stored in `b`.
    if b.sum() < 0:
        b = b * -1
    return ratio * b

In [11]:
# The second argument is a negated random tensor, so its sum is negative and
# the `b * -1` branch inside `foo` is exercised.
foo(torch.rand((2, 3)), -torch.rand((2, 3)))

Compiling graph...
torch.fx graph:
graph():
    %arg0_1 : [#users=2] = placeholder[target=arg0_1]
    %arg1_1 : [#users=1] = placeholder[target=arg1_1]
    %add : [#users=1] = call_function[target=torch.ops.aten.add](args = (%arg0_1, 1), kwargs = {})
    %div : [#users=1] = call_function[target=torch.ops.aten.div](args = (%arg0_1, %add), kwargs = {})
    %sum_1 : [#users=1] = call_function[target=torch.ops.aten.sum](args = (%arg1_1,), kwargs = {})
    %lt : [#users=1] = call_function[target=torch.ops.aten.lt](args = (%sum_1, 0), kwargs = {})
    return (div, lt)


torch-mlir backend contract graph:
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[],i1>) {
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %none = torch.constant.none
    %0 = torch.aten.add.Scalar %arg0, %int1, %int1 : !torch.vtensor<[2,3],f32>, !torch.int, !

tensor([[0.0182, 0.2515, 0.0553],
        [0.0088, 0.2397, 0.0663]])

# Using `torch.fx` to Handle Torch-MLIR Limitations

- Returning a single-element tuple vs. returning a single tensor
- Functionalizing in-place reshapes
- Decomposing complex ops at the Python level with a few lines of code