In [14]:
!rm -rf torch_compile_debug
!rm *.svg

rm: cannot remove '*.svg': No such file or directory


In [15]:
import torch
import math
import os
import matplotlib.pyplot as plt
import torch._dynamo
from torchvision import models
from torch.fx.passes.graph_drawer import FxGraphDrawer
from IPython.display import Markdown as md

device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

In [16]:
def f(x):
    return torch.sin(x)**2 + torch.cos(x)**2

md('''
# $ y = f(x) = sin^2(x) + cos^2(x)$
''')


# $ y = f(x) = sin^2(x) + cos^2(x)$


In [17]:
md('''
## Optimization problem:
### find $w^*$ that $\displaystyle \min_{w} f(w)$

## Gradient update
### $ w_{i+1} = w_{i} - g(\\nabla f(w))$
## For SGD
### $  g(\\nabla f(w)) = \\alpha*\\nabla f(w)$
## Which makes the update for SGD:
### $ w_{i+1} = w_{i} - \\alpha*\\nabla f(w)$

## Loss function
### $loss(w): loss(model(w,batch_{inputs}), batch_{outputs})$ 
''')


## Optimization problem:
### find $w^*$ that $\displaystyle \min_{w} f(w)$

## Gradient update
### $ w_{i+1} = w_{i} - g(\nabla f(w))$
## For SGD
### $  g(\nabla f(w)) = \alpha*\nabla f(w)$
## Which makes the update for SGD:
### $ w_{i+1} = w_{i} - \alpha*\nabla f(w)$

## Loss function
### $loss(w): loss(model(w,batch_{inputs}), batch_{outputs})$ 


In [18]:
md('''
# **Forward graph:** $f(x) = sin^2(x)+cos^2(x)$ \n
# **Backward graph:** $\\frac {df(x)}{d\\vec{w}} = f\'(x) = 2sin(x)cos(x) + 2cos(x)(-sin(x))$
''')


# **Forward graph:** $f(x) = sin^2(x)+cos^2(x)$ 

# **Backward graph:** $\frac {df(x)}{d\vec{w}} = f'(x) = 2sin(x)cos(x) + 2cos(x)(-sin(x))$


In [19]:
torch.manual_seed(0)
x = torch.rand(1000, requires_grad=True).to(device)
torch.nn.functional.mse_loss(f(x),torch.ones_like(x)) < 1e-10

tensor(True, device='cuda:0')

In [20]:
torch.manual_seed(0)
x = torch.rand(1000, requires_grad=True).to(device)

compiled_f = torch.compile(f)
torch.nn.functional.mse_loss(compiled_f(x),torch.ones_like(x)) < 1e-10



tensor(True, device='cuda:0')

In [21]:
def inspect_backend(gm, sample_inputs):
    code = gm.print_readable()
    with open("forward.svg", "wb") as file:
        file.write(FxGraphDrawer(gm,'f').get_dot_graph().create_svg())
    return gm.forward

torch._dynamo.reset()
compiled_f = torch.compile(f, backend=inspect_backend)

x = torch.rand(1000, requires_grad=True).to(device)
out = compiled_f(x)

md(f'''
### Graph
![]({'forward.svg'})
''')

class GraphModule(torch.nn.Module):
    def forward(self, x : torch.Tensor):
        # File: /tmp/ipykernel_3344/1097326487.py:2, code: return torch.sin(x)**2 + torch.cos(x)**2
        sin = torch.sin(x)
        pow_1 = sin ** 2;  sin = None
        cos = torch.cos(x);  x = None
        pow_2 = cos ** 2;  cos = None
        add = pow_1 + pow_2;  pow_1 = pow_2 = None
        return (add,)
        


BackendCompilerFailed: inspect_backend raised RuntimeError: FXGraphDrawer requires the pydot package to be installed. Please install pydot through your favorite Python package manager.

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True


# AOTAutograd and Aten IR

In [22]:
import torch._dynamo
from torch.fx.passes.graph_drawer import FxGraphDrawer
from functorch.compile import make_boxed_func
from torch._functorch.aot_autograd import aot_module_simplified

def f(x):
    return torch.sin(x)**2 + torch.cos(x)**2

def inspect_backend(gm, sample_inputs): 
    # Forward compiler capture
    def fw(gm, sample_inputs):
        gm.print_readable()
        g = FxGraphDrawer(gm, 'fn')
        with open("forward_aot.svg", "wb") as file:
            file.write(g.get_dot_graph().create_svg())
        return make_boxed_func(gm.forward)
    
    # Backward compiler capture
    def bw(gm, sample_inputs):
        gm.print_readable()
        g = FxGraphDrawer(gm, 'fn')
        with open("backward_aot.svg", "wb") as file:
            file.write(g.get_dot_graph().create_svg())
        return make_boxed_func(gm.forward)
    
    # Call AOTAutograd
    gm_forward = aot_module_simplified(gm,sample_inputs,
                                       fw_compiler=fw,
                                       bw_compiler=bw)

    return gm_forward

torch.manual_seed(0)
x = torch.rand(1000, requires_grad=True).to(device)
y = torch.ones_like(x)

torch._dynamo.reset()
compiled_f = torch.compile(f, backend=inspect_backend)
out = torch.nn.functional.mse_loss(compiled_f(x), y).backward()

class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[1000]):
        # File: /tmp/ipykernel_3344/2663716573.py:7, code: return torch.sin(x)**2 + torch.cos(x)**2
        sin: f32[1000] = torch.ops.aten.sin.default(primals_1)
        pow_1: f32[1000] = torch.ops.aten.pow.Tensor_Scalar(sin, 2)
        cos: f32[1000] = torch.ops.aten.cos.default(primals_1)
        pow_2: f32[1000] = torch.ops.aten.pow.Tensor_Scalar(cos, 2)
        add: f32[1000] = torch.ops.aten.add.Tensor(pow_1, pow_2);  pow_1 = pow_2 = None
        return [add, primals_1, sin, cos]
        


BackendCompilerFailed: inspect_backend raised RuntimeError: FXGraphDrawer requires the pydot package to be installed. Please install pydot through your favorite Python package manager.

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True


In [11]:
md(f'''
|![]({'forward_aot.svg'}) | < Forward graph <br><br><br> Backward graph >|![]({'backward_aot.svg'})|
|---|---|---|
''')


|![](forward_aot.svg) | < Forward graph <br><br><br> Backward graph >|![](backward_aot.svg)|
|---|---|---|


# Decomposition to Core Aten IR

In [13]:
import torch._dynamo
from torch.fx.passes.graph_drawer import FxGraphDrawer
from functorch.compile import make_boxed_func
from torch._functorch.aot_autograd import aot_module_simplified
from torch._decomp import core_aten_decompositions

def f_loss(x, y):
    f_x = torch.sin(x)**2 + torch.cos(x)**2
    return torch.nn.functional.mse_loss(f_x, y)

# decompositions = core_aten_decompositions() # Use decomposition to Core Aten IR
decompositions = {} # Don't use decomposition to Core Aten IR

def inspect_backend(gm, sample_inputs): 
    def fw(gm, sample_inputs):
        gm.print_readable()
        g = FxGraphDrawer(gm, 'fn')
        with open("forward_decomp.svg", "wb") as file:
            file.write(g.get_dot_graph().create_svg())
        return make_boxed_func(gm.forward)
    
    def bw(gm, sample_inputs):
        gm.print_readable()
        g = FxGraphDrawer(gm, 'fn')
        with open("backward_decomp.svg", "wb") as file:
            file.write(g.get_dot_graph().create_svg())
        return make_boxed_func(gm.forward)

    # Invoke AOTAutograd
    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=fw,
        bw_compiler=bw,
        decompositions=decompositions
    )

torch.manual_seed(0)
x = torch.rand(1000, requires_grad=True).to(device)
y = torch.ones_like(x)

torch._dynamo.reset()
compiled_f = torch.compile(f_loss, backend=inspect_backend)
out = compiled_f(x,y).backward()


md('''
# $MSE = (\\frac{1}{n})(\\vec{y}-\\vec{x})^2$

''')

class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[1000], primals_2: f32[1000]):
        # File: /tmp/ipykernel_3344/1876138810.py:8, code: f_x = torch.sin(x)**2 + torch.cos(x)**2
        sin: f32[1000] = torch.ops.aten.sin.default(primals_1)
        pow_1: f32[1000] = torch.ops.aten.pow.Tensor_Scalar(sin, 2)
        cos: f32[1000] = torch.ops.aten.cos.default(primals_1)
        pow_2: f32[1000] = torch.ops.aten.pow.Tensor_Scalar(cos, 2)
        add: f32[1000] = torch.ops.aten.add.Tensor(pow_1, pow_2);  pow_1 = pow_2 = None
        
        # File: /tmp/ipykernel_3344/1876138810.py:9, code: return torch.nn.functional.mse_loss(f_x, y)
        mse_loss: f32[] = torch.ops.aten.mse_loss.default(add, primals_2)
        return [mse_loss, sin, primals_2, add, cos, primals_1]
        


BackendCompilerFailed: inspect_backend raised RuntimeError: FXGraphDrawer requires the pydot package to be installed. Please install pydot through your favorite Python package manager.

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True


In [None]:
md(f'''
|![]({'forward_decomp.svg'}) | < Forward graph <br><br><br> Backward graph >|![]({'backward_decomp.svg'})|
|---|---|---|
''')

# Decomposition to prim IR

In [None]:
import torch._dynamo
from torch.fx.passes.graph_drawer import FxGraphDrawer
from functorch.compile import make_boxed_func
from torch._functorch.aot_autograd import aot_module_simplified
from torch._decomp import core_aten_decompositions

def f_loss(x, y):
    f_x = torch.sin(x)**2 + torch.cos(x)**2
    return torch.nn.functional.mse_loss(f_x, y)

decompositions = core_aten_decompositions()
decompositions.update(
    torch._decomp.get_decompositions([
        torch.ops.aten.sin,
        torch.ops.aten.cos,
        torch.ops.aten.add,
        torch.ops.aten.sub,
        torch.ops.aten.mul,
        torch.ops.aten.sum,
        torch.ops.aten.mean,
        torch.ops.aten.pow.Tensor_Scalar,
    ])
)

def inspect_backend(gm, sample_inputs): 
    def fw(gm, sample_inputs):
        gm.print_readable()
        g = FxGraphDrawer(gm, 'fn')
        with open("forward_decomp_prims.svg", "wb") as f:
            f.write(g.get_dot_graph().create_svg())
        return make_boxed_func(gm.forward)
    
    def bw(gm, sample_inputs):
        gm.print_readable()
        g = FxGraphDrawer(gm, 'fn')
        with open("backward_decomp_prims.svg", "wb") as f:
            f.write(g.get_dot_graph().create_svg())
        return make_boxed_func(gm.forward)

    # Invoke AOTAutograd
    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=fw,
        bw_compiler=bw,
        decompositions=decompositions
    )

torch.manual_seed(0)
x = torch.rand(1000, requires_grad=True).to(device)
y = torch.ones_like(x)

torch._dynamo.reset()
compiled_f = torch.compile(f_loss, backend=inspect_backend)
out = compiled_f(x,y).backward()


In [None]:
md(f'''
|![]({'forward_decomp_prims.svg'}) | < Forward graph <br><br><br> Backward graph >|![]({'backward_decomp_prims.svg'})|
|---|---|---|
''')

In [25]:

import torch
import torch._dynamo
def f(x):
    return torch.sin(x)**2 + torch.cos(x)**2 

torch._dynamo.reset()
compiled_f = torch.compile(f, backend='inductor',
                              options={'trace.enabled':True,
                                       'trace.graph_diagram':True})


# device = 'cpu'
device = 'cuda'

torch.manual_seed(0)
x = torch.rand(1000, requires_grad=True).to(device)
y = torch.ones_like(x)

out = torch.nn.functional.mse_loss(compiled_f(x),y).backward()



Writing FX graph to file: /workspaces/devspace/pytorch-compile-blogpost/torch_compile_debug/run_2023_06_21_21_31_46_238774-pid_3344/aot_torchinductor/model__7_forward_10.2/graph_diagram.svg


BackendCompilerFailed: debug_wrapper raised RuntimeError: FXGraphDrawer requires the pydot package to be installed. Please install pydot through your favorite Python package manager.

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True


In [None]:
import glob
fwd = glob.glob('torch_compile_debug/run_*/aot_torchinductor/*forward*/graph_diagram.svg')[-1]
bwd = glob.glob('torch_compile_debug/run_*/aot_torchinductor/*backward*/graph_diagram.svg')[-1]

md(f'''
|![]({fwd}) | < Forward graph <br><br><br> Backward graph >|![]({bwd})|
|---|---|---|
''')