https://depyf.readthedocs.io/en/latest/opt_tutorial.html

In [1]:
import torch


class F(torch.nn.Module):
    """Add a fixed offset ``i`` to the input.

    Args:
        i: Offset kept on the module as the plain Python attribute ``self.i``.
    """

    def __init__(self, i):
        super().__init__()
        self.i = i

    def forward(self, x):
        """Return ``x`` shifted by ``self.i``."""
        offset = self.i
        return x + offset


class Mod(torch.nn.Module):
    """Chain of ``n`` ``F`` submodules applied one after another.

    ``forward`` is wrapped with ``torch.compile``, so every call goes through
    the compiled entry point (including its guard checks).
    """

    def __init__(self, n: int):
        super().__init__()
        layers = (F(k) for k in range(n))
        self.fs = torch.nn.ModuleList(layers)

    @torch.compile
    def forward(self, x):
        """Feed ``x`` through every submodule in ``self.fs`` in order."""
        out = x
        for layer in self.fs:
            out = layer(out)
        return out


import time

# Benchmark: time 10,000 calls of the torch.compile'd module after a warm-up
# call, so that compilation cost is excluded from the measurement.
mod = Mod(100)
mod(torch.tensor([1]))  # Warm-up: compile the function outside the timed loop

x = torch.tensor([2])  # Same shape/dtype as the warm-up input; guards should pass without recompiling
# perf_counter is monotonic and high-resolution; time.time() follows the
# wall clock (which can be adjusted) and is not intended for timing intervals.
start = time.perf_counter()
for _ in range(10000):
    y = mod(x)
    # do something with y
end = time.perf_counter()
total_time = end - start
print(total_time)

0.09114861488342285


### Inspecting torch.compile output with depyf

In [2]:
# Install depyf quietly (IPython/Jupyter shell escape; not plain Python syntax).
!pip install depyf -q

In [3]:
import torch
import depyf


class F(torch.nn.Module):
    """Module that adds the constant ``i`` to whatever it receives."""

    def __init__(self, i):
        super().__init__()
        # Stored as a plain Python attribute, not as a parameter or buffer.
        self.i = i

    def forward(self, x):
        """Compute ``x + self.i``."""
        result = x + self.i
        return result


class Mod(torch.nn.Module):
    """Sequentially apply ``n`` ``F`` modules (offsets 0..n-1) to the input.

    The ``forward`` method is compiled with ``torch.compile``.
    """

    def __init__(self, n: int):
        super().__init__()
        self.fs = torch.nn.ModuleList(list(map(F, range(n))))

    @torch.compile
    def forward(self, x):
        """Return the input after passing it through each submodule in turn."""
        hidden = x
        for block in self.fs:
            hidden = block(hidden)
        return hidden


import time

# Benchmark the compiled module again, but run the warm-up (compiling) call
# inside depyf's debug context, which is given the directory to dump the
# compiled sources into.
mod = Mod(100)
with depyf.prepare_debug("pytorch/torch_compile/dump_src_dir/"):
    mod(torch.tensor([1]))


x = torch.tensor([2])  # Create input tensor
# perf_counter is monotonic and high-resolution; time.time() follows the
# wall clock (which can be adjusted) and is not intended for timing intervals.
start = time.perf_counter()
for _ in range(10000):
    y = mod(x)
    # do something with y
end = time.perf_counter()
total_time = end - start
print(total_time)



0.09391260147094727


In [5]:
import torch
import depyf
from depyf.optimization import TorchCompileWrapperWithCustomDispatcher


class F(torch.nn.Module):
    """Elementwise ``+ i`` as an ``nn.Module``.

    NOTE(review): ``i`` is a plain Python int attribute rather than a tensor,
    so Dynamo presumably specializes the compiled code on its value (and
    guards on it); confirm with the depyf source dump.
    """

    def __init__(self, i):
        super().__init__()
        self.i = i

    def forward(self, x):
        """Add this module's offset to ``x``."""
        shift = self.i
        shifted = x + shift
        return shifted


class Mod(TorchCompileWrapperWithCustomDispatcher):
    """Chain of ``n`` ``F`` modules that skips guard checks once compiled.

    The depyf base class is handed ``torch.compile(self.forward)``. Once
    exactly one compiled code object is recorded in ``compiled_codes``,
    ``__call__`` runs it directly via ``dispatch_to_code`` instead of going
    through the compiled callable.
    """

    def __init__(self, n: int):
        self.fs = torch.nn.ModuleList(F(idx) for idx in range(n))
        super().__init__(torch.compile(self.forward))

    def forward(self, x):
        """Apply every submodule in ``self.fs`` to ``x`` in sequence."""
        out = x
        for layer in self.fs:
            out = layer(out)
        return out

    def __call__(self, x):
        """Dispatch to the single compiled code object when available."""
        if len(self.compiled_codes) != 1:
            # No compiled code yet (e.g. the first call) or several versions:
            # use the regular torch.compile entry point.
            return self.compiled_callable(x)
        with self.dispatch_to_code(0):
            return self.forward(x)


import time

# Benchmark the guard-skipping wrapper with the same workload as the plain
# torch.compile version, after one warm-up call that triggers compilation.
mod = Mod(100)
mod(torch.tensor([1]))  # Compile

x = torch.tensor([2])  # Input tensor
# perf_counter is monotonic and high-resolution; time.time() follows the
# wall clock (which can be adjusted) and is not intended for timing intervals.
start = time.perf_counter()
for _ in range(10000):
    y = mod(x)
end = time.perf_counter()
total_time = end - start
print(total_time)

0.026301145553588867


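In this code, the `TorchCompileWrapperWithCustomDispatcher` class is used to bypass Dynamo's guard checks: once a single compiled code object exists, each call is dispatched straight to it. In the runs above, the loop time drops from about 0.09 seconds (plain `torch.compile`) to about 0.026 seconds. The depyf tutorial quotes roughly 0.7 s → 0.05 s on its machine. This shows that the guard checks were responsible for most of the overhead. Skipping guards is only safe when nothing they would check changes between calls, such as input shapes/dtypes or module attributes like `self.i`.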

Реальный пример из vLLM, где отказ от проверки guards повышает производительность:
- https://github.com/vllm-project/vllm/pull/7898

Официальная документация по TorchDynamo: https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html