# TorchDynamo Demo

This notebook shows examples of how to use TorchDynamo with Torch-MLIR+IREE.

In [1]:
import torchdynamo
import torch
from utils import make_torch_mlir_compiler
import torch_mlir
import iree_torch

In [2]:
import warnings, logging
warnings.simplefilter("ignore")
torchdynamo.config.log_level = logging.ERROR

# Current Steps for Compiling in Torch-MLIR

To see what TorchDynamo adds, we first look at the compilation process that Torch-MLIR users currently go through to compile PyTorch code. We then look at the limitations of this approach that TorchDynamo lets us overcome.

To compile PyTorch code using Torch-MLIR+IREE, you must:

- Create a `torch.nn.Module`
- Ensure that the module is scriptable or traceable (this could require code changes)
- Compile the module using `torch_mlir.compile` + `iree_torch.compile_to_vmfb`

In [3]:
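def compile_and_load(module: torch.nn.Module, example_inputs):
    """Compile `module` with Torch-MLIR+IREE and return the loaded result."""
    # Lower the PyTorch module to linalg-on-tensors MLIR.
    linalg_module = torch_mlir.compile(module, example_inputs,
                                       output_type="linalg-on-tensors")
    # Compile the MLIR to an IREE VM flatbuffer, then load it for execution.
    compiled_module = iree_torch.compile_to_vmfb(linalg_module)
    return iree_torch.load_vmfb(compiled_module)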

In [4]:
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, t):
        return 2 * t

In [5]:
myModule = MyModule()
compiled_module = compile_and_load(myModule, torch.ones((2, 3)))
compiled_module.forward(torch.ones((2, 3)))

tensor([[2., 2., 2.],
        [2., 2., 2.]])

## Limitations of Current Approach

- Input to `torch_mlir.compile` must be a `torch.nn.Module`
- Module must be scriptable or traceable
- Torch-MLIR is expected to support all of TorchScript (loops, control flow, etc.)

Note: Both TorchScript and Torch-MLIR support single-function workloads, but doing so requires a [different path in Torch-MLIR](https://github.com/llvm/torch-mlir/blob/4d47f1671a6020ed43af6e71631e932ac56b1f46/lib/Dialect/Torch/Transforms/Passes.cpp#L30) that the Python API does not currently support.

# Steps for Compiling with TorchDynamo

- Add the `torchdynamo.optimize` decorator to the function you want to compile (it is not limited to `torch.nn.Module`s)

In [6]:
torchdynamo.reset()
@torchdynamo.optimize(make_torch_mlir_compiler(use_tracing=False, device="cpu", verbose=True))
def foo(t):
    return 2 * t

In [7]:
foo(torch.ones((2, 3)))

Compiling graph...
torch.fx graph:
graph():
    %arg0_1 : [#users=1] = placeholder[target=arg0_1]
    %mul : [#users=1] = call_function[target=torch.ops.aten.mul](args = (%arg0_1, 2), kwargs = {})
    return mul


torch-mlir backend contract graph:
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
    %int2 = torch.constant.int 2
    %0 = torch.aten.mul.Scalar %arg0, %int2 : !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
    return %0 : !torch.vtensor<[2,3],f32>
  }
}



tensor([[2., 2., 2.],
        [2., 2., 2.]])

The verbose output above shows what happens inside the Torch-MLIR compiler passed to TorchDynamo. TorchDynamo hands the Torch-MLIR compiler the [`torch.fx.GraphModule`](https://pytorch.org/docs/stable/fx.html#torch.fx.GraphModule) representing the computation performed in `foo`. This graph is then imported into Torch-MLIR, where several simplification passes are run to reach the [backend contract](https://github.com/llvm/torch-mlir/blob/main/docs/architecture.md#the-backend-contract), producing the second graph. The MLIR graph is then compiled further.

Because TorchDynamo caches computations it has already compiled, running `foo` again with the same inputs no longer produces the verbose compilation output.

In [8]:
foo(torch.ones((2, 3)))

tensor([[2., 2., 2.],
        [2., 2., 2.]])
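
The reuse seen above can be pictured with a small pure-Python sketch. It is only a toy model, assuming that a compiled result is reused whenever the argument shapes match; TorchDynamo's real guards check more than shapes. `ToyCompileCache`, `shape_of`, and `fake_compiler` are hypothetical names that appear in neither TorchDynamo nor Torch-MLIR, and nested lists stand in for tensors.

```python
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]


def shape_of(value) -> Shape:
    """Return the shape of a rectangular nested list (a scalar has shape ())."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


class ToyCompileCache:
    """Toy model of compile-once-then-reuse; not TorchDynamo's real mechanism."""

    def __init__(self, fn: Callable, compiler: Callable[[Callable], Callable]):
        self.fn = fn
        self.compiler = compiler
        self.cache: Dict[Tuple[Shape, ...], Callable] = {}
        self.compile_count = 0

    def __call__(self, *args):
        # Stand-in for guards: the key records the shape of every argument.
        key = tuple(shape_of(arg) for arg in args)
        if key not in self.cache:
            self.compile_count += 1
            self.cache[key] = self.compiler(self.fn)
        return self.cache[key](*args)


def fake_compiler(fn: Callable) -> Callable:
    # Hypothetical backend: announces the compilation and returns fn unchanged.
    print("Compiling graph...")
    return fn


def double(rows):
    return [[2 * x for x in row] for row in rows]


cached_double = ToyCompileCache(double, fake_compiler)
ones = [[1.0] * 3 for _ in range(2)]
first = cached_double(ones)   # first call: the toy compiler runs
second = cached_double(ones)  # same shapes: the cached result is reused
```

In this sketch, calling `cached_double` with a differently shaped input misses the cache and triggers a second compilation.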

## Graph Breaks

TorchDynamo automatically inserts graph breaks to separate code that is expected to run on the backend from code that is expected to run at the Python level. This means Torch-MLIR does not have to handle things like print statements or data-dependent control flow.
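
As a rough mental model, a graph break splits a function into runs of backend-friendly operations separated by Python-level steps. The sketch below illustrates that idea on a hand-written list of steps. It is an assumption-laden toy, not how TorchDynamo analyzes bytecode, and `Step` and `split_at_graph_breaks` are hypothetical names.

```python
from itertools import groupby
from typing import List, Tuple

# Each step is (kind, description). "tensor" steps could be sent to a backend;
# "python" steps (such as print) have to stay in the Python interpreter.
Step = Tuple[str, str]


def split_at_graph_breaks(steps: List[Step]) -> List[Tuple[str, List[str]]]:
    """Group consecutive steps of the same kind into segments."""
    return [
        (kind, [description for _, description in run])
        for kind, run in groupby(steps, key=lambda step: step[0])
    ]


steps = [
    ("tensor", "x = a / 2"),
    ("python", 'print("Hello!")'),
    ("tensor", "y = x + b"),
    ("tensor", "z = y * 2"),
]
segments = split_at_graph_breaks(steps)

for kind, descriptions in segments:
    print(kind, descriptions)
```

Only the `tensor` segments would reach a compiler such as Torch-MLIR in this model; the `python` segment between them runs as ordinary Python.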

### Graph Break Example: Print statements

In [9]:
torchdynamo.reset()
@torchdynamo.optimize(make_torch_mlir_compiler(use_tracing=False, device="cpu", verbose=True))
def foo(a, b):
    print("Hello!")
    return a + b

In [10]:
foo(torch.ones((2, 3)), torch.ones((2, 3)))

Hello!
Compiling graph...
torch.fx graph:
graph():
    %arg0_1 : [#users=1] = placeholder[target=arg0_1]
    %arg1_1 : [#users=1] = placeholder[target=arg1_1]
    %add : [#users=1] = call_function[target=torch.ops.aten.add](args = (%arg0_1, %arg1_1), kwargs = {})
    return add


torch-mlir backend contract graph:
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
    %int1 = torch.constant.int 1
    %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
    return %0 : !torch.vtensor<[2,3],f32>
  }
}



tensor([[2., 2., 2.],
        [2., 2., 2.]])

Thanks to the graph break, the compiled graph contains no print statement: `Hello!` is printed at the Python level, and only the addition is sent to Torch-MLIR. The same computation fails in the current Torch-MLIR compilation flow, as the next cell shows.

In [11]:
class PrintStatementModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, a, b):
        print("Hello!")
        return a + b
    
printStatementModule = PrintStatementModule()
compiled_module = compile_and_load(printStatementModule, [torch.ones((2, 3)), torch.ones((2, 3))])
compiled_module.forward(torch.ones((2, 3)), torch.ones((2, 3)))

TorchMlirCompilerError: Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed with the following diagnostics:
error: failed to legalize operation 'torch.constant.str'
note: see current operation: %0 = "torch.constant.str"() {value = "Hello!"} : () -> !torch.str
error: Module does not conform to the linalg-on-tensors backend contract. See dialect conversion legality information above.


Error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='torch-backend-to-linalg-on-tensors-backend-pipeline' /tmp/PrintStatementModule.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.


### Graph Break Example: Control flow

In [12]:
torchdynamo.reset()
@torchdynamo.optimize(make_torch_mlir_compiler(use_tracing=False, device="cpu", verbose=True))
def foo(a, b):
    x = a / 2
    if b.max() < 0:
        b = b * -1
    return x * b

In [13]:
foo(torch.ones((2, 3)), torch.ones((2, 3)))

Compiling graph...
torch.fx graph:
graph():
    %arg0_1 : [#users=1] = placeholder[target=arg0_1]
    %arg1_1 : [#users=1] = placeholder[target=arg1_1]
    %div : [#users=1] = call_function[target=torch.ops.aten.div](args = (%arg0_1, 2), kwargs = {})
    %max_1 : [#users=1] = call_function[target=torch.ops.aten.max](args = (%arg1_1,), kwargs = {})
    %lt : [#users=1] = call_function[target=torch.ops.aten.lt](args = (%max_1, 0), kwargs = {})
    return (div, lt)


torch-mlir backend contract graph:
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[],i1>) {
    %int0 = torch.constant.int 0
    %int2 = torch.constant.int 2
    %0 = torch.aten.div.Scalar %arg0, %int2 : !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
    %1 = torch.aten.max %arg1 : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[],f32>
    %2 = torch.aten.lt.Scalar 

tensor([[0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000]])

Notice that because the body of the `if` statement is not executed for these inputs, it is not compiled. If we change the inputs so that the body of the `if` statement is needed, TorchDynamo reuses the compiled graphs for the other parts of the function and compiles only the body of the `if` statement.

In [14]:
foo(torch.ones((2, 3)), -torch.ones((2, 3)))

Compiling graph...
torch.fx graph:
graph():
    %arg0_1 : [#users=1] = placeholder[target=arg0_1]
    %arg1_1 : [#users=1] = placeholder[target=arg1_1]
    %mul : [#users=1] = call_function[target=torch.ops.aten.mul](args = (%arg0_1, -1), kwargs = {})
    %mul_1 : [#users=1] = call_function[target=torch.ops.aten.mul](args = (%arg1_1, %mul), kwargs = {})
    return mul_1


torch-mlir backend contract graph:
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
    %int-1 = torch.constant.int -1
    %0 = torch.aten.mul.Scalar %arg0, %int-1 : !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
    %1 = torch.aten.mul.Tensor %arg1, %0 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32>
    return %1 : !torch.vtensor<[2,3],f32>
  }
}



tensor([[0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000]])
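
The lazy compilation in the two runs above can be sketched in pure Python. This is a toy model, not TorchDynamo's implementation: it assumes the function is split into a prefix that ends at the `if` condition plus one continuation per branch. `run_segment`, `compile_log`, `prefix`, `branch_taken`, `branch_skipped`, and `toy_foo` are hypothetical names, and lists of floats stand in for tensors.

```python
from typing import Callable, Dict, List

compile_log: List[str] = []
_compiled: Dict[str, Callable] = {}


def run_segment(name: str, fn: Callable, *args):
    """Compile `fn` the first time `name` is needed, then reuse it."""
    if name not in _compiled:
        compile_log.append(name)  # stand-in for invoking a real backend
        _compiled[name] = fn
    return _compiled[name](*args)


def prefix(a, b):
    # Everything up to the data-dependent branch.
    return [v / 2 for v in a], max(b) < 0


def branch_taken(x, b):
    # Body of the `if`, followed by the rest of the function.
    b = [v * -1 for v in b]
    return [p * q for p, q in zip(x, b)]


def branch_skipped(x, b):
    # The rest of the function when the `if` body is not executed.
    return [p * q for p, q in zip(x, b)]


def toy_foo(a, b):
    x, is_negative = run_segment("prefix", prefix, a, b)
    # The branch is decided in plain Python, outside any compiled segment.
    if is_negative:
        return run_segment("branch_taken", branch_taken, x, b)
    return run_segment("branch_skipped", branch_skipped, x, b)


first = toy_foo([1.0, 1.0], [1.0, 1.0])
log_after_first = list(compile_log)
second = toy_foo([1.0, 1.0], [-1.0, -1.0])
```

After the first call, `compile_log` holds only `prefix` and `branch_skipped`; the second call adds `branch_taken` and reuses the already compiled `prefix`.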

This kind of data-dependent control flow does not work with the current Torch-MLIR compilation flow, as the error below shows.

In [15]:
class ControlFlowModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, a, b):
        x = a / 2
        if b.max() < 0:
            b = b * -1
        return x * b
    
controlFlowModule = ControlFlowModule()
compiled_module = compile_and_load(controlFlowModule, [torch.ones((2, 3)), torch.ones((2, 3))])
compiled_module.forward(torch.ones((2, 3)), torch.ones((2, 3)))

TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
error: 'torch.copy.to_vtensor' op failed to verify that operand is corresponding !torch.tensor
note: see current operation: %22 = "torch.copy.to_vtensor"(%21) : (!torch.tensor<*,f32>) -> !torch.vtensor


Error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints}' /tmp/ControlFlowModule.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
