# Torch Compile

torch.compile is the latest method to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling PyTorch code into optimized kernels, all while requiring minimal code changes.

In [1]:
import torch
import warnings

# Choose the backend in priority order: CUDA, then MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"using {device=}")

using device='mps'


---

## torch.compile basic usage

Arbitrary Python functions can be optimized by passing the callable to torch.compile.

We can then call the returned optimized function in place of the original function.

In [2]:
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the elementwise sum ``sin(x) + cos(y)``."""
    return x.sin() + y.cos()


# Passing the callable to torch.compile returns the optimized function.
opt_foo1 = torch.compile(foo)

# The returned function is called in place of foo, here on two random 2x2 inputs.
print(opt_foo1(torch.randn(2, 2), torch.randn(2, 2)))

tensor([[ 1.3567,  0.1011],
        [-1.6241,  0.3115]])


- We can also use `torch.compile` as a decorator to compile a function.

In [3]:
# Random 2x2 inputs kept at module level; outer_function is later called with
# the same two tensors.
t1 = torch.randn(2, 2)
t2 = torch.randn(2, 2)


@torch.compile
def opt_foo2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``sin(x) + cos(y)`` elementwise; compiled via the decorator form."""
    return x.sin() + y.cos()


# The decorated function is called directly; no separate torch.compile call is needed.
print(opt_foo2(t1, t2))

tensor([[ 0.5304, -0.2428],
        [ 1.2837,  1.6413]])


- We can also optimize torch.nn.Module instances.

In [4]:
# Random input of shape (10, 100); the last dimension matches the
# Linear(100, 10) layer inside MyModule.
t = torch.randn(10, 100)


class MyModule(torch.nn.Module):
    """A single linear layer (100 -> 10 features) followed by a ReLU."""

    def __init__(self) -> None:
        super().__init__()
        self.lin = torch.nn.Linear(in_features=100, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and clamp negatives to zero."""
        projected = self.lin(x)
        return torch.nn.functional.relu(projected)


# Build the module, then pass the instance to torch.compile to get an
# optimized version of it.
mod = MyModule()
opt_mod = torch.compile(mod)
# Call the optimized module on the (10, 100) input defined above.
print(opt_mod(t))

tensor([[0.4956, 0.7377, 0.2008, 0.0000, 0.2288, 0.6336, 1.4262, 0.0000, 0.0000,
         0.4541],
        [0.8629, 0.1461, 0.5833, 0.0000, 0.0544, 0.0000, 0.4984, 0.5596, 0.0000,
         0.4246],
        [0.5437, 0.0811, 0.0756, 0.0083, 0.8152, 0.2054, 0.2765, 0.0000, 0.0000,
         0.0901],
        [0.4975, 0.2585, 0.0000, 0.2944, 0.6465, 0.0000, 0.5874, 0.0357, 0.1657,
         0.0000],
        [0.2115, 0.2578, 0.0000, 0.4508, 0.0000, 0.0000, 0.4244, 0.0000, 0.0000,
         0.3522],
        [0.0000, 0.1485, 0.0000, 0.0000, 0.0000, 0.0000, 0.5496, 1.2266, 0.0000,
         0.0000],
        [0.0000, 0.7922, 0.2276, 0.0351, 0.0000, 0.0142, 0.0000, 0.4307, 0.4509,
         0.0000],
        [0.3721, 0.9828, 0.0000, 0.3258, 0.0000, 0.0000, 0.5848, 0.0000, 0.0000,
         0.2040],
        [0.0000, 0.0000, 0.0000, 0.5841, 0.0000, 0.0000, 0.0000, 0.9089, 0.0982,
         0.2046],
        [0.1742, 0.0000, 0.0000, 0.0000, 0.5193, 0.0000, 0.1036, 0.1628, 0.0000,
         0.0000]], grad_fn=<CompiledFunctionBackward>)

---

## torch.compile and Nested Calls

Nested function calls within the decorated function will also be compiled.

In [5]:
def nested_function(x: torch.Tensor) -> torch.Tensor:
    """Return the elementwise sine of ``x``."""
    return x.sin()


@torch.compile
def outer_function(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``nested_function(x) + cos(y)`` elementwise.

    Only this function is decorated; the call into ``nested_function`` is
    compiled along with it.
    """
    return nested_function(x) + torch.cos(y)


# Same t1/t2 inputs as the opt_foo2 call, so the two printed results can be compared.
print(outer_function(t1, t2))

tensor([[ 0.5304, -0.2428],
        [ 1.2837,  1.6413]])


- In the same fashion, when a module is compiled, all of its sub-modules and methods that are not in a skip list are compiled as well.

In [6]:
class OuterModule(torch.nn.Module):
    """Run the input through a nested ``MyModule``, then Linear(10, 2) and ReLU."""

    def __init__(self) -> None:
        super().__init__()
        self.inner_module = MyModule()
        self.outer_lin = torch.nn.Linear(in_features=10, out_features=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the inner module, then the outer linear layer and a ReLU."""
        hidden = self.inner_module(x)
        projected = self.outer_lin(hidden)
        return torch.nn.functional.relu(projected)


# Only the outer module is passed to torch.compile; the nested MyModule is
# reached through OuterModule.forward.
outer_mod = OuterModule()
opt_outer_mod = torch.compile(outer_mod)
# Reuses the (10, 100) input t from the MyModule example.
print(opt_outer_mod(t))

tensor([[0.2111, 0.0000],
        [0.3138, 0.0000],
        [0.0000, 0.0000],
        [0.2170, 0.0000],
        [0.1764, 0.0000],
        [0.0862, 0.1966],
        [0.0000, 0.3254],
        [0.0000, 0.3184],
        [0.0594, 0.0000],
        [0.2973, 0.0000]], grad_fn=<CompiledFunctionBackward>)


We can also prevent some functions from being compiled by using torch.compiler.disable. Suppose you want to disable tracing for just the complex_function function, but want tracing to resume inside complex_conjugate. In this case, you can use the torch.compiler.disable(recursive=False) option. Otherwise, the default is recursive=True.

In [7]:
def complex_conjugate(z: torch.Tensor) -> torch.Tensor:
    """Return the complex conjugate of ``z``."""
    return z.conj()


@torch.compiler.disable(recursive=False)
def complex_function(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:
    """Combine ``real`` and ``imag`` into a complex tensor and return its conjugate.

    Compilation is disabled for this function only (``recursive=False``).
    """
    # Assume this function causes problems during compilation.
    return complex_conjugate(torch.complex(real, imag))


def outer_function() -> torch.Tensor:
    """Return the magnitudes of ``complex_function`` applied to fixed inputs.

    The inputs are the float32 tensors [2, 3] (real part) and [4, 5]
    (imaginary part). Note that this definition rebinds the name of the
    earlier two-argument ``outer_function``.
    """
    real_part = torch.tensor([2, 3], dtype=torch.float32)
    imag_part = torch.tensor([4, 5], dtype=torch.float32)
    return torch.abs(complex_function(real_part, imag_part))


# Try to compile outer_function and run it once. Both the torch.compile call
# and the first invocation sit inside the try, so an exception from either
# step is printed instead of propagating.
try:
    opt_outer_function = torch.compile(outer_function)
    print(opt_outer_function())
except Exception as e:
    # Deliberately broad: this cell only demonstrates whether the call succeeds.
    print("Compilation of outer_function failed:", e)

tensor([4.4721, 5.8310])


