In [1]:
import torch
import warnings

# The reference speedup numbers were measured on V100 (7.0), A100 (8.0) and
# H100 (9.0) compute capabilities; anything else still runs, just warn.
_BENCHMARKED_CAPABILITIES = ((7, 0), (8, 0), (9, 0))

gpu_ok = (
    torch.cuda.is_available()
    and torch.cuda.get_device_capability() in _BENCHMARKED_CAPABILITIES
)

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower than expected."
    )

In [6]:
# Allow TensorFloat-32 math for float32 matmuls and for cuDNN kernels on GPUs
# that support it, trading a little float32 precision for speed.
torch.set_float32_matmul_precision("high")
for _tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
    _tf32_backend.allow_tf32 = True
del _tf32_backend

In [2]:
def foo(x, y):
    """Toy function for torch.compile: elementwise ``sin(x) + cos(y)``."""
    return torch.sin(x) + torch.cos(y)
opt_foo1 = torch.compile(foo)
# Two independent 10x10 inputs, drawn in the same order as before.
x_demo, y_demo = torch.randn(10, 10), torch.randn(10, 10)
print(opt_foo1(x_demo, y_demo))

tensor([[ 1.7872e+00,  1.9614e+00,  9.0110e-04, -2.9500e-01,  7.7799e-01,
          3.5593e-02, -1.3396e-02,  6.4578e-01, -1.2947e-01,  1.8579e+00],
        [ 1.6040e+00,  1.8856e+00,  1.2606e+00,  1.7045e+00,  2.5772e-01,
          7.4184e-01,  2.9223e-01, -3.8714e-01,  1.8560e+00,  3.5174e-01],
        [-1.0209e+00,  2.4348e-02, -1.2797e-01,  1.8819e+00, -4.7803e-01,
          9.3266e-01,  1.3870e+00,  1.1974e+00, -2.0989e-01,  1.6325e+00],
        [ 8.4868e-01, -3.9857e-03,  8.0966e-01,  1.3976e+00,  1.2699e+00,
         -4.5516e-01,  4.0862e-01,  5.2027e-01,  3.0611e-02,  6.7047e-01],
        [-1.8843e+00,  1.9921e-01,  8.6538e-01,  7.9285e-01,  5.3565e-01,
         -8.6505e-01,  7.4809e-01,  1.1497e+00, -8.3933e-03, -2.2558e-01],
        [ 1.4868e+00, -8.5710e-02,  1.7295e+00,  1.5299e+00,  3.6676e-01,
          9.1303e-01,  1.6716e+00,  1.4785e+00,  9.4002e-01,  6.5890e-01],
        [ 1.7643e-01, -1.6536e-01,  9.6033e-01,  9.1788e-01,  1.0181e+00,
          1.8507e+00,  1.0062e+0

In [3]:
class MyModule(torch.nn.Module):
    """Minimal model: a ``Linear(100, 10)`` layer followed by ReLU."""

    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        # x's last dimension must be 100 (the Linear layer's in_features).
        hidden = self.lin(x)
        return torch.relu(hidden)

mod = MyModule()
opt_mod = torch.compile(mod)
sample_input = torch.randn(10, 100)  # batch of 10, feature size matches Linear(100, 10)
print(opt_mod(sample_input))

tensor([[-0.0000, 0.0078, -0.0000, 0.1264, -0.0000, -0.0000, 0.1965, 1.0125, -0.0000,
         0.0904],
        [0.5082, 0.2965, -0.0000, 1.0539, -0.0000, 0.3981, -0.0000, 0.0936, -0.0000,
         -0.0000],
        [-0.0000, -0.0000, -0.0000, -0.0000, 0.5585, 0.4620, 0.1551, 0.5080, -0.0000,
         -0.0000],
        [-0.0000, -0.0000, 0.2071, -0.0000, 1.1365, 0.4965, 0.5573, -0.0000, -0.0000,
         -0.0000],
        [-0.0000, 0.2231, -0.0000, 0.4177, -0.0000, 0.3016, -0.0000, 0.3646, -0.0000,
         0.4215],
        [-0.0000, 0.6417, -0.0000, 0.7434, 0.1864, 0.4046, -0.0000, -0.0000, 0.8087,
         0.0656],
        [0.0594, 0.5815, 0.4926, -0.0000, -0.0000, -0.0000, 0.3768, -0.0000, -0.0000,
         -0.0000],
        [-0.0000, -0.0000, -0.0000, -0.0000, 1.3516, 1.1902, 0.4950, -0.0000, -0.0000,
         0.3907],
        [0.3714, -0.0000, -0.0000, -0.0000, 0.0324, 0.6633, -0.0000, 0.3417, -0.0000,
         0.7288],
        [0.0350, -0.0000, 0.4155, 0.6461, -0.0000, 0.1019, -0

In [4]:
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. On CUDA machines we use CUDA events and synchronization for the
# most accurate measurements; without CUDA we fall back to wall-clock time.
def timed(fn):
    """Run ``fn()`` once and measure its duration.

    Args:
        fn: zero-argument callable to execute.

    Returns:
        ``(result, seconds)``: the value returned by ``fn()`` and the elapsed
        time in seconds (float).
    """
    if not torch.cuda.is_available():
        # CPU-only fallback so the helper is usable (and testable) without a GPU.
        import time

        t0 = time.perf_counter()
        result = fn()
        return result, time.perf_counter() - t0

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    # Wait for all queued GPU work so both events have actually been reached.
    torch.cuda.synchronize()
    # Event.elapsed_time() reports milliseconds; convert to seconds.
    return result, start.elapsed_time(end) / 1000

# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b, device="cuda", image_size=128, num_classes=1000):
    """Create a random ``(images, labels)`` batch for the benchmark.

    Args:
        b: batch size.
        device: device to place both tensors on (defaults to ``"cuda"``).
        image_size: height and width of the square 3-channel images.
        num_classes: labels are drawn uniformly from ``[0, num_classes)``.
            The default of 1000 presumably matches the classifier width of
            ``densenet121()`` used below — confirm if the model changes.

    Returns:
        ``(images, labels)``: a float32 tensor of shape
        ``(b, 3, image_size, image_size)`` and an int64 tensor of shape ``(b,)``.
    """
    # Sampled on the CPU generator and then moved, exactly as before; the
    # explicit float32 cast guards against a changed default dtype.
    images = torch.randn(b, 3, image_size, image_size).to(torch.float32).to(device)
    labels = torch.randint(num_classes, (b,)).to(device)
    return images, labels

N_ITERS = 10  # number of timed iterations in each benchmark loop

from torchvision.models import densenet121


def init_model():
    """Build a new ``densenet121()`` (no arguments passed) as float32 on the GPU."""
    net = densenet121()
    return net.to(device="cuda", dtype=torch.float32)

In [7]:
model = init_model()
# This is an inference ("eval") benchmark: switch to eval mode so that
# mode-dependent layers (e.g. the BatchNorm layers in torchvision's DenseNet)
# behave as at inference time instead of using/updating batch statistics.
model.eval()

# Reset Dynamo's compilation state, since earlier cells compiled with the
# default mode and this one uses "reduce-overhead".
import torch._dynamo
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    # torch.compile compiles lazily, so this first compiled call includes
    # compilation time and is far slower than steady-state runs.
    print("compile:", timed(lambda: model_opt(inp))[1])

eager: 0.021341184616088867
compile: 60.21555078125


In [8]:
import numpy as np


def _time_eval(label, fn):
    """Time ``fn(inputs)`` on N_ITERS fresh batches of 16 under ``torch.no_grad()``.

    Prints one ``"<label> eval time <i>: <seconds>"`` line per iteration,
    then a separator, and returns the list of per-iteration times in seconds.
    """
    times = []
    for i in range(N_ITERS):
        inp = generate_data(16)[0]
        with torch.no_grad():
            _, elapsed = timed(lambda: fn(inp))
        times.append(elapsed)
        print(f"{label} eval time {i}: {elapsed}")
    print("~" * 10)
    return times


eager_times = _time_eval("eager", model)
compile_times = _time_eval("compile", model_opt)

# Medians keep slow warm-up iterations (e.g. the first compiled call) from
# skewing the comparison.
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1, f"expected torch.compile to speed up eval, got {speedup}x"
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
print("~" * 10)

eager eval time 0: 0.027048959732055664
eager eval time 1: 0.018497535705566406
eager eval time 2: 0.016726015090942382
eager eval time 3: 0.01577676773071289
eager eval time 4: 0.01558835220336914
eager eval time 5: 0.015185919761657715
eager eval time 6: 0.015246335983276366
eager eval time 7: 0.015119359970092774
eager eval time 8: 0.015210495948791505
eager eval time 9: 0.018497535705566406
~~~~~~~~~~
compile eval time 0: 1.0827969970703124
compile eval time 1: 0.005703680038452149
compile eval time 2: 0.005573631763458252
compile eval time 3: 0.005536767959594726
compile eval time 4: 0.005467135906219483
compile eval time 5: 0.0054609918594360355
compile eval time 6: 0.005463039875030518
compile eval time 7: 0.0054579200744628905
compile eval time 8: 0.005453824043273926
compile eval time 9: 0.005475327968597412
~~~~~~~~~~
(eval) eager median: 0.015682559967041015, compile median: 0.005471231937408448, speedup: 2.8663672362004373x
~~~~~~~~~~


In [9]:
# Imported here so this cell does not depend on the eval cell having run.
import numpy as np

model = init_model()
opt = torch.optim.Adam(model.parameters())


def train(mod, data):
    """Run one optimization step of ``mod`` on ``data = (inputs, targets)``.

    Uses the module-level optimizer ``opt``, looked up at call time, so ``opt``
    must be rebuilt whenever ``model`` is re-initialized.
    """
    opt.zero_grad(set_to_none=True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()


def _time_train(label, step_fn, mod):
    """Time ``step_fn(mod, batch)`` on N_ITERS fresh batches of 16.

    Prints one ``"<label> train time <i>: <seconds>"`` line per iteration,
    then a separator, and returns the list of per-iteration times in seconds.
    """
    times = []
    for i in range(N_ITERS):
        inp = generate_data(16)
        _, elapsed = timed(lambda: step_fn(mod, inp))
        times.append(elapsed)
        print(f"{label} train time {i}: {elapsed}")
    print("~" * 10)
    return times


eager_times = _time_train("eager", train, model)

# Fresh model and optimizer so the compiled run does not continue from the
# weights and Adam state already updated by the eager run.
model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = _time_train("compile", train_opt, model)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1, f"expected torch.compile to speed up training, got {speedup}x"
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
print("~" * 10)

eager train time 0: 0.5032314758300781
eager train time 1: 0.06248550415039063
eager train time 2: 0.05783039855957031
eager train time 3: 0.06603059387207032
eager train time 4: 0.05302374267578125
eager train time 5: 0.052890625
eager train time 6: 0.05248716735839844
eager train time 7: 0.060796928405761716
eager train time 8: 0.059061248779296874
eager train time 9: 0.05958553695678711
~~~~~~~~~~
compile train time 0: 172.408875
compile train time 1: 5.86939306640625
compile train time 2: 0.05554687881469727
compile train time 3: 0.0372305908203125
compile train time 4: 0.038247425079345705
compile train time 5: 0.04015513610839844
compile train time 6: 0.042434558868408204
compile train time 7: 0.03907174301147461
compile train time 8: 0.04237823867797851
compile train time 9: 0.043717632293701174
~~~~~~~~~~
(train) eager median: 0.05932339286804199, compile median: 0.042406398773193354, speedup: 1.3989255061559833x
~~~~~~~~~~
