In [None]:
!pip3 uninstall torch -y
!pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 --force

In [2]:
import torch
from torch._dynamo import optimize
from typing import *
from torch import _dynamo

All the valid backends

In [None]:
_dynamo.list_backends()

['aot_ts_nvfuser',
 'cudagraphs',
 'inductor',
 'ipex',
 'nvprims_nvfuser',
 'onnxrt',
 'tensorrt',
 'tvm']

# optimizer() usage with inductor as backend

A naive example

In [3]:
from torch import nn
import torch.nn.functional as F


class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.fc = nn.Linear(128, 10)

  def forward(self, x):
    x = self.fc(x)
    x = F.relu(x)
    return x

In [4]:
from torch._inductor import config as inductor_config
from torch._dynamo import config as dynamo_config

inductor_config.debug = True
dynamo_config.verbose = True

In [5]:
foo = Net()
foo = torch.compile(foo)

In [6]:
foo

OptimizedModule(
  (_orig_mod): Net(
    (fc): Linear(in_features=128, out_features=10, bias=True)
  )
)

When enable `inductor.debug`, it could dump the python code it codegened.

In [7]:
a = torch.randn((2, 128))

foo(a)

[2023-02-24 02:00:07,038] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 0
DEBUG:filelock:Attempting to acquire lock 140315786009520 on /tmp/torchinductor_root/locks/cpqwqezuqcpct4tbqmlhjjp3hbo6hvqmumb4g7x4z3r6m7zzu7wx.lock
DEBUG:filelock:Lock 140315786009520 acquired on /tmp/torchinductor_root/locks/cpqwqezuqcpct4tbqmlhjjp3hbo6hvqmumb4g7x4z3r6m7zzu7wx.lock
DEBUG:filelock:Attempting to release lock 140315786009520 on /tmp/torchinductor_root/locks/cpqwqezuqcpct4tbqmlhjjp3hbo6hvqmumb4g7x4z3r6m7zzu7wx.lock
DEBUG:filelock:Lock 140315786009520 released on /tmp/torchinductor_root/locks/cpqwqezuqcpct4tbqmlhjjp3hbo6hvqmumb4g7x4z3r6m7zzu7wx.lock
DEBUG:filelock:Attempting to acquire lock 140315784350448 on /tmp/torchinductor_root/locks/cdpfcbmnbo2tiqauemmeoyifop6bs6dplr52qcc4drscewsrxkauciazhaeabewqhtl5n6yhs73v5bs7xy4mhm2p4fkphmc2mhuz7lpa.lock
DEBUG:filelock:Lock 140315784350448 acquired on /tmp/torchinductor_root/locks/cdpfcbmnbo2tiqauemmeoyifop6bs6dplr52qcc4d


from ctypes import c_void_p, c_long
import torch
import math
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()

import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_root/zt/cztcl2vp5yqlnhofzpqfficjcxgyict6e3xhfdd7sdbkipp4p44x.h"
extern "C" void kernel(float* __restrict__ in_out_ptr0,
                       bool* __restrict__ out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long i0=0; i0<20; i0+=1)
        {
            auto tmp0 = in_out_ptr0[i0];
            auto tmp1 = tmp0 * (tmp0>0);
            auto tmp2 = static_cast<float>(0);
            auto tmp3 = tmp1 <

DEBUG:filelock:Attempting to release lock 140315784350448 on /tmp/torchinductor_root/locks/cdpfcbmnbo2tiqauemmeoyifop6bs6dplr52qcc4drscewsrxkauciazhaeabewqhtl5n6yhs73v5bs7xy4mhm2p4fkphmc2mhuz7lpa.lock
DEBUG:filelock:Lock 140315784350448 released on /tmp/torchinductor_root/locks/cdpfcbmnbo2tiqauemmeoyifop6bs6dplr52qcc4drscewsrxkauciazhaeabewqhtl5n6yhs73v5bs7xy4mhm2p4fkphmc2mhuz7lpa.lock
[2023-02-24 02:00:14,842] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 0


tensor([[0.1216, 0.0000, 0.0000, 0.4786, 0.0000, 0.0000, 0.6883, 0.2427, 0.0000,
         0.0000],
        [0.8512, 0.0000, 0.0000, 0.7956, 0.0000, 0.2226, -0.0000, -0.0000, 0.2839,
         -0.0000]], grad_fn=<CompiledFunctionBackward>)

# Dive into dynamo

According to the definition of `_dynamo.optimize`: 

```python
def optimize(
    backend="inductor",
    *,
    nopython=False,
    guard_export_fn=None,
    guard_fail_fn=None,
    disable=False,
    dynamic=False,
):
```

The `backend` argument could be either a `str` or a `callable`.
Let's hack it with a custom callable to dump something.

In [33]:
my_graph_id = 0
def my_compiler( 
        gm: torch.fx.GraphModule,
        inputs: List[torch.Tensor]):
    global my_graph_id
    print(f"my_compiler() called with FX graph-{my_graph_id}:")
    my_graph_id += 1
    gm.print_readable()
    print()
    #print("tabular:")
    #gm.graph.print_tabular(); print()
    #print(f"code: {gm.graph.python_code()}")
    return gm.forward  # python callable

## Example 1

In [9]:
def foo1(a:torch.tensor, b:torch.tensor):
  x = a + b
  if b.sum() < 0:
    x = x * -1
  return x

foo1_ = optimize(my_compiler)(foo1)

Note that, this kernel contains a `if` `if b.sum() < 0`, since the `b.sum()` is determined by its value(dynamic), so it should break the graph into two cases:

The first, when the condition is true:

```python
x = a + b
x = x * -1
return x
```

The second, when the condition is false:

```python
x = a + b
return x
```

In [16]:
torch._dynamo.reset() # reset all che compilation cache

a = torch.randn((2, 3))
b = torch.randn((2, 3))

# It should tigger both cases of the if-else
foo1_(a, b)
foo1_(a, -b)

my_compiler() called with FX graph:
class GraphModule(torch.nn.Module):
    def forward(self, a : torch.Tensor, b : torch.Tensor):
        # File: <ipython-input-9-dd0c51f20b11>:2, code: x = a + b
        add = a + b;  a = None
        
        # File: <ipython-input-9-dd0c51f20b11>:3, code: if b.sum() < 0:
        sum_1 = b.sum();  b = None
        lt = sum_1 < 0;  sum_1 = None
        return (add, lt)
        

tabular:
opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    a       a                        ()            {}
placeholder    b       b                        ()            {}
call_function  add     <built-in function add>  (a, b)        {}
call_method    sum_1   sum                      (b,)          {}
call_function  lt      <built-in function lt>   (sum_1, 0)    {}
output         output  output                   ((add, lt),)  {}

my_compiler() called with FX graph:
class G

tensor([[ 0.8036, -0.5172, -0.2097],
        [ 0.0542,  2.3458,  0.0871]])

In the exaple above, it do break into two graphs, but not from expected:

- graph1: the expressions before the if, with the condition computation
- graph2: the expressions after the if

## Example 2

### Execute once case

In [43]:
def foo2(a:torch.tensor, b:torch.tensor):
  x = a + b
  if b.sum() < 0:
    x = x * -1
  if a.sum() < 0:
    x = x * -1
  x = 2 * x
  return x

foo2_ = optimize(my_compiler)(foo2)

In [45]:
torch._dynamo.reset() # reset all che compilation cache
my_graph_id = 0

a = torch.ones((2, 3))
b = torch.ones((2, 3))

# It should tigger only one case of the if-else
foo2_(a, b)

my_compiler() called with FX graph-0:
class GraphModule(torch.nn.Module):
    def forward(self, a : torch.Tensor, b : torch.Tensor):
        # File: <ipython-input-43-f6e4dc936826>:2, code: x = a + b
        add = a + b;  a = None
        
        # File: <ipython-input-43-f6e4dc936826>:3, code: if b.sum() < 0:
        sum_1 = b.sum();  b = None
        lt = sum_1 < 0;  sum_1 = None
        return (add, lt)
        

my_compiler() called with FX graph-1:
class GraphModule(torch.nn.Module):
    def forward(self, a : torch.Tensor):
        # File: <ipython-input-43-f6e4dc936826>:5, code: if a.sum() < 0:
        sum_1 = a.sum();  a = None
        lt = sum_1 < 0;  sum_1 = None
        return (lt,)
        

my_compiler() called with FX graph-2:
class GraphModule(torch.nn.Module):
    def forward(self, x : torch.Tensor):
        # File: <ipython-input-43-f6e4dc936826>:7, code: x = 2 * x
        mul = 2 * x;  x = None
        return (mul,)
        



tensor([[4., 4., 4.],
        [4., 4., 4.]])

### Exectue all the cases

In [46]:
torch._dynamo.reset() # reset all che compilation cache
my_graph_id = 0

# It should tigger all the four combinations of the if-conditions
foo2_(a, b)
foo2_(a, -b)
foo2_(-a, b)
foo2_(-a, -b)

my_compiler() called with FX graph-0:
class GraphModule(torch.nn.Module):
    def forward(self, a : torch.Tensor, b : torch.Tensor):
        # File: <ipython-input-43-f6e4dc936826>:2, code: x = a + b
        add = a + b;  a = None
        
        # File: <ipython-input-43-f6e4dc936826>:3, code: if b.sum() < 0:
        sum_1 = b.sum();  b = None
        lt = sum_1 < 0;  sum_1 = None
        return (add, lt)
        

my_compiler() called with FX graph-1:
class GraphModule(torch.nn.Module):
    def forward(self, a : torch.Tensor):
        # File: <ipython-input-43-f6e4dc936826>:5, code: if a.sum() < 0:
        sum_1 = a.sum();  a = None
        lt = sum_1 < 0;  sum_1 = None
        return (lt,)
        

my_compiler() called with FX graph-2:
class GraphModule(torch.nn.Module):
    def forward(self, x : torch.Tensor):
        # File: <ipython-input-43-f6e4dc936826>:7, code: x = 2 * x
        mul = 2 * x;  x = None
        return (mul,)
        

my_compiler() called with FX graph-3:
clas

tensor([[-4., -4., -4.],
        [-4., -4., -4.]])

In [None]:
torch._dynamo.reset() # reset all che compilation cache
my_graph_id = 0

a = torch.randn((2, 3))
b = torch.randn((2, 3))

# It should tigger only one case of the if-else
foo2_(a, b)

my_compiler() called with FX graph-0:
class GraphModule(torch.nn.Module):
    def forward(self, a : torch.Tensor, b : torch.Tensor):
        # File: <ipython-input-24-9cf32d8775c0>:2, code: x = a + b
        add = a + b;  a = None
        
        # File: <ipython-input-24-9cf32d8775c0>:3, code: if b.sum() < 0:
        sum_1 = b.sum();  b = None
        lt = sum_1 < 0;  sum_1 = None
        return (add, lt)
        

tabular:
opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    a       a                        ()            {}
placeholder    b       b                        ()            {}
call_function  add     <built-in function add>  (a, b)        {}
call_method    sum_1   sum                      (b,)          {}
call_function  lt      <built-in function lt>   (sum_1, 0)    {}
output         output  output                   ((add, lt),)  {}

my_compiler() called with FX graph-1:
c

tensor([[-1.3537, -0.0531,  0.9512],
        [-1.3276,  1.3997, -0.3662]])

## Non-torch function call

In [37]:
import scipy

In [41]:
def func(a, b):
    import numpy as np
    aa = np.randn((2,3))
    sum = a + b
    return sum.numpy() + aa
    return aa

  func(a, b)

IndentationError: ignored

In [None]:
torch._dynamo.reset()
torch._dynamo.config.verbose=True
func = optimize(my_compiler)(draw_example)

In [None]:
func(a, b)

InternalTorchDynamoError: 