In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity
device = "cuda" if torch.cuda.is_available() else "cpu"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def fn1(x: Tensor, a: Tensor) -> Tensor:
    mu = 1.0
    x = torch.exp(mu * x)
    a = torch.exp(mu * a)
    y = F.linear(x, a)
    y = torch.log(y) / mu
    return y


def fn2(x: Tensor, a: Tensor) -> Tensor:
    return checkpoint(fn1, x, a)


x = torch.randn(10000, 1000, requires_grad=True).to(device)
a = torch.randn(500, 1000, requires_grad=True).to(device)
grad_y = torch.rand(10000, 500).to(device)

In [3]:
%%timeit -n100 -r1 
x1=x.detach().requires_grad_(True)
a1=a.detach().requires_grad_(True)

y1 = fn1(x1, a1)
y1.backward(grad_y)
torch.cuda.synchronize()

5.34 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 100 loops each)


In [4]:
%%timeit -n100 -r1
x2=x.detach().requires_grad_(True)
a2=a.detach().requires_grad_(True)

y2 = fn2(x2, a2)
y2.backward(grad_y)
torch.cuda.synchronize()

6.8 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 100 loops each)


In [7]:
y1 = fn1(x.detach(), a.detach())
y2 = fn2(x.detach(), a.detach())
assert y1.allclose(y2)



In [13]:
def tropical1(x:Tensor, a:Tensor) -> Tensor:
    return torch.max(x.unsqueeze(-2) + a, dim=-1)[0]

def tropical2(x: Tensor, a: Tensor) -> Tensor:
    return checkpoint(tropical1, x, a)

x = torch.randn(100, 1000, requires_grad=True).to(device)
a = torch.randn(500, 1000, requires_grad=True).to(device)
grad_y = torch.rand(100, 500).to(device)

In [14]:
%%timeit -n100 -r1 
x1=x.detach().requires_grad_(True)
a1=a.detach().requires_grad_(True)

y1 = tropical1(x1, a1)
y1.backward(grad_y)
torch.cuda.synchronize()


3.22 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 100 loops each)


In [15]:
%%timeit -n100 -r1 
x1=x.detach().requires_grad_(True)
a1=a.detach().requires_grad_(True)

y1 = tropical2(x1, a1)
y1.backward(grad_y)
torch.cuda.synchronize()

4.86 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 100 loops each)


In [16]:
x = torch.randn(100, 1000, requires_grad=True).to(device)
a = torch.randn(500, 1000, requires_grad=True).to(device)
grad_y = torch.rand(100, 500).to(device)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True) as prof:
    y = tropical1(x, a)

print(prof.key_averages())

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        aten::unsqueeze         0.63%      35.000us         0.74%      41.000us      41.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 

STAGE:2023-07-12 15:38:10 250074:250074 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-07-12 15:38:10 250074:250074 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-07-12 15:38:10 250074:250074 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [17]:
x = torch.randn(100, 1000, requires_grad=True).to(device)
a = torch.randn(500, 1000, requires_grad=True).to(device)
grad_y = torch.rand(100, 500).to(device)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True) as prof:
    y = tropical2(x, a)

print(prof.key_averages())

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                     CheckpointFunction        28.30%       2.568ms        30.14%       2.735ms       2.735ms       0.000us         0.00%       6.803ms       6.803ms       5.73 Kb       5.73 Kb     195.50 Kb    -191.12 M

STAGE:2023-07-12 15:38:32 250074:250074 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
[W CPUAllocator.cpp:235] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event
STAGE:2023-07-12 15:38:32 250074:250074 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-07-12 15:38:32 250074:250074 ActivityProfilerController.cpp:321] Completed Stage: Post Processing
