In [1]:
import torch
import torch.nn as nn
from torchvision.models import resnet50

torch.set_float32_matmul_precision('high')

class DDQN(nn.Module):
    def __init__(self):
        super().__init__()
    
        self.online = nn.Sequential(
            resnet50(),
            nn.Linear(1000, 100)
        )

    def forward(self, input):
        return self.online(input)

In [2]:
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

def generate_data():
    return torch.randn(1, 3, 64, 64).to(torch.float32).cuda()

def evaluate(mod, inp):
    return mod(inp)

In [3]:
model = DDQN().cuda()

iters = 10

elapsed_times = []
for i in range(iters):
    inp = generate_data()
    _, eager_time = timed(lambda: evaluate(model, inp))
    print(f"normal eval time {i}: {eager_time}")


#compiledModel = torch.compile(model)
newModel =  torch.compile(DDQN().cuda(), mode="reduce-overhead")

iters = 10

elapsed_times = []
for i in range(iters):
    inp = generate_data()
    _, eager_time = timed(lambda: evaluate(newModel, inp))
    print(f"compiled eval time {i}: {eager_time}")



normal eval time 0: 2.298642333984375
normal eval time 1: 0.00820531177520752
normal eval time 2: 0.0076605439186096195
normal eval time 3: 0.007527423858642578
normal eval time 4: 0.008177663803100586
normal eval time 5: 0.006677504062652588
normal eval time 6: 0.00637440013885498
normal eval time 7: 0.006554624080657959
normal eval time 8: 0.0064542717933654785
normal eval time 9: 0.0066344962120056155
compiled eval time 0: 14.5871083984375
compiled eval time 1: 0.0030658559799194335
compiled eval time 2: 0.0025661439895629883
compiled eval time 3: 0.002099200010299683
compiled eval time 4: 0.002085887908935547
compiled eval time 5: 0.002091007947921753
compiled eval time 6: 0.0020797441005706786
compiled eval time 7: 0.0021975040435791016
compiled eval time 8: 0.002401279926300049
compiled eval time 9: 0.004622335910797119
