In [1]:
import torch
# torch.set_float32_matmul_precision('high')

import warnings

# The speedups this tutorial quotes were measured on these compute capabilities
# (V100 = 7.0, A100 = 8.0, H100 = 9.0); any other device only gets a warning.
REFERENCE_GPU_CAPABILITIES = frozenset({(7, 0), (8, 0), (9, 0)})

if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    gpu_ok = device_cap in REFERENCE_GPU_CAPABILITIES
else:
    gpu_ok = False

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )




In [2]:
def timed(fn):
    """Call ``fn()`` once and report how long it took on the GPU.

    Two timing-enabled CUDA events are recorded on the current stream around
    the call, then the device is synchronized so both events have completed
    before the elapsed time is read.

    Returns:
        ``(result, seconds)`` -- whatever ``fn()`` returned, and the elapsed
        time converted from the events' milliseconds to seconds.
    """
    t_begin, t_end = (torch.cuda.Event(enable_timing=True) for _ in range(2))

    t_begin.record()
    outcome = fn()
    t_end.record()

    # elapsed_time() is only valid once both events have actually completed.
    torch.cuda.synchronize()
    millis = t_begin.elapsed_time(t_end)
    return outcome, millis / 1000

def generate_data(b):
    """Create one random batch for the classifier benchmark.

    Args:
        b: batch size.

    Returns:
        ``(images, labels)`` moved to the current CUDA device:
        ``images`` is a float32 tensor of shape ``(b, 3, 128, 128)`` sampled
        from a standard normal, ``labels`` has shape ``(b,)`` with integer
        class ids in ``[0, 1000)``.
    """
    # Both tensors are sampled on the CPU (images first, then labels) and only
    # afterwards copied to the GPU.
    images = torch.randn(b, 3, 128, 128).to(torch.float32)
    labels = torch.randint(1000, (b,))
    return images.cuda(), labels.cuda()

# Number of timed forward passes per benchmark loop.
N_ITERS = 10

from torchvision.models import densenet121
def init_model():
    """Build a new torchvision DenseNet-121 in float32 on the current CUDA device.

    Every call constructs a fresh module, so callers never share parameters.
    """
    net = densenet121()
    net = net.to(torch.float32)
    return net.cuda()

In [3]:
# Fresh model for the inference benchmark. `.eval()` is required here: in the
# default training mode every BatchNorm layer normalizes with per-batch
# statistics and updates its running-mean/var buffers in place on each forward,
# which is not the inference path these "(eval)" timings are meant to measure.
# `no_grad()` below only disables autograd; it does not switch layer modes.
model = init_model().eval()

# Clear any graphs TorchDynamo cached from earlier compiles so this
# "reduce-overhead" compile starts from scratch.
import torch._dynamo
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    # The first call through `model_opt` triggers compilation (and, in this
    # mode, CUDA-graph capture), so this number is dominated by one-time
    # compile cost rather than steady-state speed.
    print("compile:", timed(lambda: model_opt(inp))[1])

eager: 2.996556884765625




compile: 35.974046875


In [4]:
import warnings

import numpy as np


def _bench_eval(fn_model, label, batch_size=16):
    """Time `N_ITERS` no-grad forward passes of `fn_model`, one fresh batch each.

    Args:
        fn_model: callable taking an input batch (eager module or compiled module).
        label: prefix for the per-iteration log lines.
        batch_size: batch size passed to `generate_data`.

    Prints ``"<label> eval time <i>: <seconds>"`` per iteration followed by a
    ``~~~~~~~~~~`` separator, and returns the per-iteration times (seconds,
    as reported by `timed`).
    """
    times = []
    for i in range(N_ITERS):
        # Input generation happens outside the timed region.
        batch = generate_data(batch_size)[0]
        with torch.no_grad():
            _, elapsed = timed(lambda: fn_model(batch))
        times.append(elapsed)
        print(f"{label} eval time {i}: {elapsed}")
    print("~" * 10)
    return times


eager_times = _bench_eval(model, "eager")
compile_times = _bench_eval(model_opt, "compile")

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
# A median over only N_ITERS wall-clock samples is noisy and the gain can be
# marginal (speedups barely above 1x are possible), so a hard
# `assert speedup > 1` would make Restart & Run All fail at random. Warn instead.
if speedup <= 1:
    warnings.warn(
        f"torch.compile was not faster than eager in eval mode "
        f"(speedup {speedup:.3f}x); timings may be noisy on this device."
    )
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)
print("~" * 10)

eager eval time 0: 0.019986431121826173
eager eval time 1: 0.01766307258605957
eager eval time 2: 0.017738752365112305
eager eval time 3: 0.01589964771270752
eager eval time 4: 0.016094207763671875
eager eval time 5: 0.015523615837097169
eager eval time 6: 0.015485695838928223
eager eval time 7: 0.01773353576660156
eager eval time 8: 0.01754649543762207
eager eval time 9: 0.017763328552246094
~~~~~~~~~~
compile eval time 0: 0.01704960060119629
compile eval time 1: 0.017118207931518553
compile eval time 2: 0.01782067108154297
compile eval time 3: 0.01662156867980957
compile eval time 4: 0.01719910430908203
compile eval time 5: 0.017386335372924805
compile eval time 6: 0.019425344467163087
compile eval time 7: 0.01880166435241699
compile eval time 8: 0.019574783325195313
compile eval time 9: 0.016857088088989256
~~~~~~~~~~
(eval) eager median: 0.01760478401184082, compile median: 0.017292719841003418, speedup: 1.0180459854613184x
~~~~~~~~~~
