In [6]:
import os
import psutil
import gc

import torch
import torch._dynamo
from torchvision import transforms, models
from torchdrive.models.semantic import BDD100KSemSeg

# Device object used for the model and its input in the benchmark loop.
device = torch.device("cuda")


def _compile_with(backend):
    """Return a wrapper that calls ``torch.compile`` on its argument with ``backend``."""
    return lambda f: torch.compile(f, backend=backend)


# (label, wrapper) pairs, in the order they are benchmarked. The final "none"
# entry is the uncompiled baseline: it hands the callable back untouched.
compile_fns = [
    (backend, _compile_with(backend))
    for backend in ("inductor", "cudagraphs", "aot_eager")
]
compile_fns.append(("none", lambda f: f))

def get_mem():
    """
    Return the resident set size (``rss``) psutil reports for this process.

    A garbage collection is forced first so that unreachable Python objects
    are not counted as growth between two readings.
    """
    gc.collect()
    return psutil.Process(os.getpid()).memory_info().rss


def f_as_tensor(x):
    """
    Build a fixed 8-element tensor that matches ``x``'s dtype and device.

    ``x`` is only consulted for its ``dtype`` and ``device``; its values are
    not read.

    Author's measurement note: leaks 32 bytes (the original note does not
    record whether that is per call or in total).
    """
    values = (1, 2, 3, 6, 7, 8, 9, 10)
    return torch.as_tensor(values, dtype=x.dtype, device=x.device)

def f_mul(x):
    """
    Multiply ``x`` by the constant 10 and return the product.

    Author's measurement note: leaks 64 bytes (the original note does not
    record whether that is per call or in total).
    """
    scaled = x * 10
    return scaled

def f_view(x):
    """
    Return ``x`` viewed as a single flat dimension via ``x.view(-1)``.

    Author's measurement note: no leak observed.
    """
    flat = x.view(-1)
    return flat

# Number of measured forward passes per backend.
N = 10000
if N <= 0:
    # Fail before any compilation work: the per-iteration average below
    # divides by N.
    raise ValueError("N must be a positive iteration count")

# Print the running RSS roughly ten times per backend. Clamped to at least 1
# so that shrinking N below 10 for a quick run does not turn the modulo in
# the loop into a division by zero.
report_every = max(1, N // 10)

for name, compile_fn in compile_fns:
    gc.collect()
    print("###############")
    print("## compile_fn", name)
    # Presumably clears TorchDynamo state left by the previous backend so each
    # backend starts from a fresh compile — confirm against the torch docs.
    torch._dynamo.reset()

    # Alternative subjects tried by the author (noted as leaking 22 kbytes):
    #m = BDD100KSemSeg(device=device, compile_fn=compile_fn)
    #m = compile_fn(transforms.Normalize(mean=(1, 2, 3), std=(4,5,6)))
    #m = compile_fn(f_as_tensor)

    # Current subject: resnet18, noted by the author as leaking ~11kb/it.
    m = compile_fn(models.resnet18().to(device))
    inp = torch.rand(2, 3, 240, 320, device=device)

    # Warmup: two calls before the baseline reading, so that one-time
    # first-call costs are not counted as per-iteration growth.
    m(inp)
    m(inp)

    start_mem = get_mem()

    for i in range(N):
        m(inp)

        if i % report_every == 0:
            print(i, get_mem())

    end_mem = get_mem()
    total_diff = end_mem - start_mem
    # Average RSS growth per forward pass, then the total over all N passes.
    print("bytes/it", total_diff / N)
    print("bytes", total_diff)


###############
## compile_fn inductor
0 3857846272
1000 3864399872
2000 3875803136
3000 3887206400
4000 3898609664
5000 3910012928
6000 3921416192
7000 3932950528
8000 3944353792
9000 3955888128
bytes/it 10944.512
bytes 109445120
###############
## compile_fn cudagraphs




0 3995996160
1000 3995996160
2000 3995996160
3000 3995996160
4000 3995996160
5000 3995996160
6000 3995996160
7000 3995996160
8000 3995996160
9000 3995996160
bytes/it 0.0
bytes 0
###############
## compile_fn aot_eager
0 3995996160
1000 3995996160
2000 3995996160
3000 3995996160
4000 3995996160
5000 3995996160
6000 3995996160
7000 3995996160
8000 3995996160
9000 3995996160
bytes/it 0.0
bytes 0
###############
## compile_fn none
0 4010414080
1000 4010414080
2000 4010414080
3000 4010414080
4000 4010414080
5000 4010414080
6000 4010414080
7000 4010414080
8000 4010414080
9000 4010414080
bytes/it 0.0
bytes 0


In [2]:
# Display the backend names returned by torch._dynamo.list_backends().
# NOTE(review): the recorded output does not include "aot_eager", which the
# benchmark cell passes to torch.compile — presumably some backends are hidden
# from this listing by default; confirm against the installed torch version.
import torch._dynamo
torch._dynamo.list_backends()

['aot_ts_nvfuser',
 'cudagraphs',
 'inductor',
 'ipex',
 'nvprims_nvfuser',
 'onnxrt',
 'tvm']