In [1]:
from dataclasses import asdict
from ddpm import DDPM, ModelConfig, TrainerConfig
from model import UNet



In [2]:
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

[A100 Spec](https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf)

In [3]:
A100_FP32_TFLOPS_PER_SEC = 19.5
A100_FP16_TFLOPS_PER_SEC = 312
A100_MEMBW_TB_PER_SEC = 1.555

In [4]:
BASELINE_TFLOPS_PER_SEC = 2.91

## Baseline

2.91 TFLOP/s per iteration

In [23]:
img2tensor = T.Compose([
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=0.5, std=0.5)  # [0, 1] -> [-1, 1]
])
ds = CIFAR10('./cifar10', train=True, transform=img2tensor, download=True)
cfg_m = ModelConfig()
cfg_t = TrainerConfig()

Files already downloaded and verified


In [24]:
dataloader = DataLoader(ds, batch_size=cfg_t.bs, num_workers=cfg_t.num_workers, drop_last=True)
ddpm = DDPM(**asdict(cfg_m)).to('cuda')
optimizer = torch.optim.AdamW(ddpm.parameters(), lr=cfg_t.lr)

In [25]:
x0, _ = next(iter(dataloader))
x0 = x0.to('cuda')
eps = torch.randn_like(x0)
t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)

Using `torchinfo` to estimate the FLOPs. ([`torchinfo` docs](https://pypi.org/project/torchinfo/))

In [26]:
from torchinfo import summary
summary(ddpm, input_data=[x0, eps, t])

Layer (type:depth-idx)                        Output Shape              Param #
DDPM                                          [128, 3, 32, 32]          --
├─UNet: 1-1                                   [128, 3, 32, 32]          --
│    └─Sequential: 2-1                        [128, 128]                --
│    │    └─Linear: 3-1                       [128, 128]                4,224
│    │    └─GELU: 3-2                         [128, 128]                --
│    │    └─Linear: 3-3                       [128, 128]                16,512
│    └─Conv2d: 2-2                            [128, 32, 32, 32]         128
│    └─UNetDownsample: 2-3                    [128, 32, 16, 16]         --
│    │    └─TimeResNetBlock: 3-4              [128, 32, 32, 32]         26,880
│    │    └─TimeResNetBlock: 3-5              [128, 32, 32, 32]         26,880
│    │    └─GroupNorm: 3-6                    [128, 32, 32, 32]         64
│    │    └─Attention: 3-7                    [128, 32, 32, 32]         16,416


Forward pass FLOPs $= 61.22 \times 2 = 132.44 \implies$ forward + backward pass $= 61.22 \times 3 = 183.66$ GFLOPs  
[Forward and Backward Pass FLOP estimation](https://epochai.org/blog/backward-forward-FLOP-ratio)

## Profiling

- [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
- [Using CUDA events to measure time](https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964/2)

### PyTorch Profiler

In [7]:
from torch.profiler import profile, record_function
from torch.profiler import ProfilerActivity as PA

In [34]:
ddpm.train()

with profile(activities=[PA.CPU, PA.CUDA], record_shapes=True) as prof:
    for x0, _ in dataloader:
        optimizer.zero_grad()
        
        x0 = x0.to('cuda')
        eps = torch.randn_like(x0)
        t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)
        
        eps_pred = ddpm(x0, eps, t)
        loss = F.smooth_l1_loss(eps, eps_pred)
        
        loss.backward()
        optimizer.step()
        break

STAGE:2024-06-18 07:27:37 1493321:1493321 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2024-06-18 07:27:38 1493321:1493321 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-06-18 07:27:38 1493321:1493321 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [35]:
print(prof.key_averages().table(sort_by='cuda_time_total'))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
autograd::engine::evaluate_function: ConvolutionBack...         0.62%     855.000us         9.27%      12.747ms     169.960us       0.000us         0.00%      17.319ms     230.920us            75  
                                   ConvolutionBackward0         0.30%     406.000us         8.49%      11.674ms     155.653us       0.000us         0.00%      16.942ms     225.893us            75  
         

In [36]:
print(prof.key_averages().table(sort_by='cpu_time_total'))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...        34.51%      47.469ms        34.53%      47.492ms      47.492ms       0.000us         0.00%       0.000us       0.000us             1  
                                       cudaLaunchKernel         9.91%      13.632ms         9.91%      13.632ms       5.418us       0.000us         0.00%       0.000us       0.000us          2516  
         

### Measure Time with CUDA Events

In [14]:
ddpm.train()

start_e = torch.cuda.Event(enable_timing=True)
end_e = torch.cuda.Event(enable_timing=True)

start_e.record()

for _, (x0, _) in zip(range(100), dataloader):
    optimizer.zero_grad()
    
    x0 = x0.to('cuda')
    eps = torch.randn_like(x0)
    t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)
    
    eps_pred = ddpm(x0, eps, t)
    loss = F.smooth_l1_loss(eps, eps_pred)
    
    loss.backward()
    optimizer.step()

end_e.record()

torch.cuda.synchronize()

t = start_e.elapsed_time(end_e) / 100.0
print(t)

63.03197265625


In [16]:
tflop_per_sec = (183.66 / (t / 1000)) / 1000
tflop_per_sec

2.9137593551387764

A100 has 19.5 TFLOP/s of FP32 compute ([Spec](https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf))

In [18]:
mfu = tflop_per_sec / 19.5
print(f'{mfu=:.2%}')

mfu=14.94%


Wow 14.94% MFU is pretty bad

## Optimization 1: Scale Up Batch Size

Scaling up to batch size 1024 roughly saturates the FP32 compute, but still leaves FP16 compute on the table.  
20.84 TFLOP/s per iteration (7.16x)

In [3]:
img2tensor = T.Compose([
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=0.5, std=0.5)  # [0, 1] -> [-1, 1]
])
ds = CIFAR10('./cifar10', train=True, transform=img2tensor, download=True)
cfg_m = ModelConfig()
cfg_t = TrainerConfig(bs=1024)

Files already downloaded and verified


In [4]:
dataloader = DataLoader(ds, batch_size=cfg_t.bs, num_workers=cfg_t.num_workers, drop_last=True)
ddpm = DDPM(**asdict(cfg_m)).to('cuda')
optimizer = torch.optim.AdamW(ddpm.parameters(), lr=cfg_t.lr)

In [5]:
x0, _ = next(iter(dataloader))
x0 = x0.to('cuda')
eps = torch.randn_like(x0)
t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)

In [6]:
from torchinfo import summary
summary(ddpm, input_data=[x0, eps, t])

Layer (type:depth-idx)                        Output Shape              Param #
DDPM                                          [1024, 3, 32, 32]         --
├─UNet: 1-1                                   [1024, 3, 32, 32]         --
│    └─Sequential: 2-1                        [1024, 128]               --
│    │    └─Linear: 3-1                       [1024, 128]               4,224
│    │    └─GELU: 3-2                         [1024, 128]               --
│    │    └─Linear: 3-3                       [1024, 128]               16,512
│    └─Conv2d: 2-2                            [1024, 32, 32, 32]        128
│    └─UNetDownsample: 2-3                    [1024, 32, 16, 16]        --
│    │    └─TimeResNetBlock: 3-4              [1024, 32, 32, 32]        26,880
│    │    └─TimeResNetBlock: 3-5              [1024, 32, 32, 32]        26,880
│    │    └─GroupNorm: 3-6                    [1024, 32, 32, 32]        64
│    │    └─Attention: 3-7                    [1024, 32, 32, 32]        16,416


In [12]:
tflops_per_iter = (489.73 * 2 * 3) / 1000
tflops_per_iter

2.93838

In [7]:
ddpm.train()

start_e = torch.cuda.Event(enable_timing=True)
end_e = torch.cuda.Event(enable_timing=True)

start_e.record()

for _, (x0, _) in zip(range(100), dataloader):
    optimizer.zero_grad()
    
    x0 = x0.to('cuda')
    eps = torch.randn_like(x0)
    t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)
    
    eps_pred = ddpm(x0, eps, t)
    loss = F.smooth_l1_loss(eps, eps_pred)
    
    loss.backward()
    optimizer.step()

end_e.record()

torch.cuda.synchronize()

t = start_e.elapsed_time(end_e) / 100.0
print(t)

141.0171875


In [13]:
tflop_per_sec = tflops_per_iter / (t / 1000)
tflop_per_sec

20.837034492692602

In [18]:
mfu = tflop_per_sec / A100_FP32_TFLOPS_PER_SEC
print(f'{mfu=:.2%}')

mfu=106.86%


**All parameters are in either FP32 or INT32 format.**

In [23]:
all(p.dtype in [torch.float32, torch.int32] for p in ddpm.model.parameters())

True

Profiling shows that CPU is no longer an overhead with this batch size.

In [24]:
from torch.profiler import profile, record_function
from torch.profiler import ProfilerActivity as PA

In [28]:
ddpm.train()

with profile(activities=[PA.CPU, PA.CUDA], record_shapes=True) as prof:
    for _, (x0, _) in zip(range(100), dataloader):
        optimizer.zero_grad()
        
        x0 = x0.to('cuda')
        eps = torch.randn_like(x0)
        t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)
        
        eps_pred = ddpm(x0, eps, t)
        loss = F.smooth_l1_loss(eps, eps_pred)
        
        loss.backward()
        optimizer.step()

STAGE:2024-06-18 08:00:15 1518246:1518246 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2024-06-18 08:00:30 1518246:1518246 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-06-18 08:00:30 1518246:1518246 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [29]:
print(prof.key_averages().table(sort_by='cuda_time_total'))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
autograd::engine::evaluate_function: ConvolutionBack...         0.33%      41.929ms        19.85%        2.547s     707.427us       0.000us         0.00%        5.017s       1.394ms          3600  
                             aten::convolution_backward         2.25%     288.713ms        18.96%        2.432s     675.461us        4.157s        32.31%        4.855s       1.349ms          3600  
         

In [30]:
print(prof.key_averages().table(sort_by='cpu_time_total'))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       cudaLaunchKernel        47.03%        6.032s        47.03%        6.032s      49.851us     897.041ms         6.97%     897.054ms       7.413us        121008  
                                            aten::copy_         0.16%      20.969ms        22.50%        2.886s     178.439us     788.119ms         6.13%     806.185ms      49.838us         16176  
         

## Optimization 2 - `torch.compile`

8.3x Speedup

In [5]:
img2tensor = T.Compose([
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=0.5, std=0.5)  # [0, 1] -> [-1, 1]
])
ds = CIFAR10('./cifar10', train=True, transform=img2tensor, download=True)
cfg_m = ModelConfig()
cfg_t = TrainerConfig(bs=1024)

Files already downloaded and verified


In [6]:
dataloader = DataLoader(ds, batch_size=cfg_t.bs, num_workers=cfg_t.num_workers, drop_last=True)
ddpm = torch.compile(DDPM(**asdict(cfg_m)).to('cuda'))
optimizer = torch.optim.AdamW(ddpm.parameters(), lr=cfg_t.lr)

In [7]:
x0, _ = next(iter(dataloader))
x0 = x0.to('cuda')
eps = torch.randn_like(x0)
t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)

In [7]:
eps_pred = ddpm(x0, eps, t)
loss = F.smooth_l1_loss(eps, eps_pred)    
loss.backward()



In [8]:
ddpm.train()

start_e = torch.cuda.Event(enable_timing=True)
end_e = torch.cuda.Event(enable_timing=True)

start_e.record()

for _, (x0, _) in zip(range(100), dataloader):
    optimizer.zero_grad()
    
    x0 = x0.to('cuda')
    eps = torch.randn_like(x0)
    t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)
    
    eps_pred = ddpm(x0, eps, t)
    loss = F.smooth_l1_loss(eps, eps_pred)
    
    loss.backward()
    optimizer.step()

end_e.record()

torch.cuda.synchronize()

t = start_e.elapsed_time(end_e) / 100.0
print(t)

121.651767578125


In [10]:
tflops_per_iter = (489.73 * 2 * 3) / 1000
tflop_per_sec = tflops_per_iter / (t / 1000)
tflop_per_sec

24.154026353237875

In [13]:
tflop_per_sec / BASELINE_TFLOPS_PER_SEC

8.300352698707172

## Optimization 3 - Use PyTorch SDPA

- [PyTorch SDPA doc](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
- [NanoGPT Usage](https://github.com/karpathy/nanoGPT/blob/master/model.py)

8.64x speedup

In [5]:
img2tensor = T.Compose([
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=0.5, std=0.5)  # [0, 1] -> [-1, 1]
])
ds = CIFAR10('./cifar10', train=True, transform=img2tensor, download=True)
cfg_m = ModelConfig()
cfg_t = TrainerConfig(bs=1024)

Files already downloaded and verified


In [6]:
dataloader = DataLoader(ds, batch_size=cfg_t.bs, num_workers=cfg_t.num_workers, drop_last=True)
ddpm = torch.compile(DDPM(**asdict(cfg_m)).to('cuda'))
optimizer = torch.optim.AdamW(ddpm.parameters(), lr=cfg_t.lr)

In [7]:
x0, _ = next(iter(dataloader))
x0 = x0.to('cuda')
eps = torch.randn_like(x0)
t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)

eps_pred = ddpm(x0, eps, t)
loss = F.smooth_l1_loss(eps, eps_pred)    
loss.backward()



In [8]:
ddpm.train()

start_e = torch.cuda.Event(enable_timing=True)
end_e = torch.cuda.Event(enable_timing=True)

start_e.record()

for _, (x0, _) in zip(range(100), dataloader):
    optimizer.zero_grad()
    
    x0 = x0.to('cuda')
    eps = torch.randn_like(x0)
    t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)
    
    eps_pred = ddpm(x0, eps, t)
    loss = F.smooth_l1_loss(eps, eps_pred)
    
    loss.backward()
    optimizer.step()

end_e.record()

torch.cuda.synchronize()

t = start_e.elapsed_time(end_e) / 100.0
print(t)

116.928603515625


In [9]:
tflops_per_iter = (489.73 * 2 * 3) / 1000
tflop_per_sec = tflops_per_iter / (t / 1000)
tflop_per_sec

25.129693775974573

In [10]:
tflop_per_sec / BASELINE_TFLOPS_PER_SEC

8.635633599991262

## Optimization 3.5 - Use TF32 Precision

In [5]:
torch.set_float32_matmul_precision('high')

In [6]:
img2tensor = T.Compose([
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=0.5, std=0.5)  # [0, 1] -> [-1, 1]
])
ds = CIFAR10('./cifar10', train=True, transform=img2tensor, download=True)
cfg_m = ModelConfig()
cfg_t = TrainerConfig(bs=1024)

Files already downloaded and verified


In [7]:
dataloader = DataLoader(ds, batch_size=cfg_t.bs, num_workers=cfg_t.num_workers, drop_last=True)
ddpm = torch.compile(DDPM(**asdict(cfg_m)).to('cuda'))
optimizer = torch.optim.AdamW(ddpm.parameters(), lr=cfg_t.lr)

In [8]:
x0, _ = next(iter(dataloader))
x0 = x0.to('cuda')
eps = torch.randn_like(x0)
t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)

eps_pred = ddpm(x0, eps, t)
loss = F.smooth_l1_loss(eps, eps_pred)    
loss.backward()

In [9]:
ddpm.train()

start_e = torch.cuda.Event(enable_timing=True)
end_e = torch.cuda.Event(enable_timing=True)

start_e.record()

for _, (x0, _) in zip(range(100), dataloader):
    optimizer.zero_grad()
    
    x0 = x0.to('cuda')
    eps = torch.randn_like(x0)
    t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)
    
    eps_pred = ddpm(x0, eps, t)
    loss = F.smooth_l1_loss(eps, eps_pred)
    
    loss.backward()
    optimizer.step()

end_e.record()

torch.cuda.synchronize()

t = start_e.elapsed_time(end_e) / 100.0
print(t)

116.965390625


In [10]:
tflops_per_iter = (489.73 * 2 * 3) / 1000
tflop_per_sec = tflops_per_iter / (t / 1000)
tflop_per_sec

25.121790166295185

In [11]:
tflop_per_sec / BASELINE_TFLOPS_PER_SEC

8.632917582919307

## Optimization 4 - Mixed-Precision Training

Lowering to 16-bit precision lowers the memory requirements, so we further increase the batch size.  
At batch size 1024, it achieves 10.99x speedup.  
At batch size 3584, it achieves 11.92x speedup.

- [Tutorial 1](https://pytorch.org/blog/what-every-user-should-know-about-mixed-precision-training-in-pytorch/)
- [Tutorial 2](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/)

In [5]:
img2tensor = T.Compose([
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=0.5, std=0.5)  # [0, 1] -> [-1, 1]
])
ds = CIFAR10('./cifar10', train=True, transform=img2tensor, download=True)
cfg_m = ModelConfig()
cfg_t = TrainerConfig(bs=3072)

Files already downloaded and verified


In [6]:
dataloader = DataLoader(ds, batch_size=cfg_t.bs, num_workers=cfg_t.num_workers, drop_last=True)
ddpm = torch.compile(DDPM(**asdict(cfg_m)).to('cuda'))
optimizer = torch.optim.AdamW(ddpm.parameters(), lr=cfg_t.lr)
scaler = torch.cuda.amp.GradScaler()

In [8]:
from torchinfo import summary
x0, _ = next(iter(dataloader))
x0 = x0.to('cuda')
eps = torch.randn_like(x0)
t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)
summary(DDPM(**asdict(cfg_m)).to('cuda'), input_data=[x0, eps, t])

Layer (type:depth-idx)                        Output Shape              Param #
DDPM                                          [3072, 3, 32, 32]         --
├─UNet: 1-1                                   [3072, 3, 32, 32]         --
│    └─Sequential: 2-1                        [3072, 128]               --
│    │    └─Linear: 3-1                       [3072, 128]               4,224
│    │    └─GELU: 3-2                         [3072, 128]               --
│    │    └─Linear: 3-3                       [3072, 128]               16,512
│    └─Conv2d: 2-2                            [3072, 32, 32, 32]        128
│    └─UNetDownsample: 2-3                    [3072, 32, 16, 16]        --
│    │    └─TimeResNetBlock: 3-4              [3072, 32, 32, 32]        26,880
│    │    └─TimeResNetBlock: 3-5              [3072, 32, 32, 32]        26,880
│    │    └─GroupNorm: 3-6                    [3072, 32, 32, 32]        64
│    │    └─Attention: 3-7                    [3072, 32, 32, 32]        16,416


In [9]:
tflops_per_iter = 1.47 * 2 * 3
tflops_per_iter

8.82

In [None]:
x0, _ = next(iter(dataloader))
x0 = x0.to('cuda')
eps = torch.randn_like(x0)
t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)

with torch.cuda.amp.autocast():
    eps_pred = ddpm(x0, eps, t)

loss = F.smooth_l1_loss(eps, eps_pred)
scaler.scale(loss).backward()

In [19]:
import numpy as np

ddpm.train()

n_steps = 16
starts = [torch.cuda.Event(enable_timing=True) for _ in range(n_steps)]
ends = [torch.cuda.Event(enable_timing=True) for _ in range(n_steps)]

for i, (x0, _) in zip(range(n_steps), dataloader):
    starts[i].record()

    optimizer.zero_grad()
    
    x0 = x0.to('cuda')
    eps = torch.randn_like(x0)
    t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)
    
    with torch.cuda.amp.autocast():
        eps_pred = ddpm(x0, eps, t)
    loss = F.smooth_l1_loss(eps, eps_pred)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    ends[i].record()

torch.cuda.synchronize()
t = np.mean([s.elapsed_time(e) for s, e in zip(starts, ends)])

In [22]:
[s.elapsed_time(e) for s, e in zip(starts, ends)]

[479.6488037109375,
 460.52862548828125,
 460.4241943359375,
 461.0775146484375,
 460.2716064453125,
 460.305419921875,
 460.3299865722656,
 460.33099365234375,
 460.2429504394531,
 460.2245178222656,
 460.9413146972656,
 460.6423034667969,
 460.0217590332031,
 460.24090576171875,
 460.5327453613281,
 460.4999694824219]

In [11]:
ddpm.train()

start_e = torch.cuda.Event(enable_timing=True)
end_e = torch.cuda.Event(enable_timing=True)
start_e.record()

for _, (x0, _) in zip(range(100), dataloader):
    optimizer.zero_grad()
    
    x0 = x0.to('cuda')
    eps = torch.randn_like(x0)
    t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)
    
    with torch.cuda.amp.autocast():
        eps_pred = ddpm(x0, eps, t)
    loss = F.smooth_l1_loss(eps, eps_pred)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

end_e.record()
torch.cuda.synchronize()

t = start_e.elapsed_time(end_e) / 100.0
print(t)

90.327724609375


In [20]:
tflop_per_sec = tflops_per_iter / (t / 1000)
tflop_per_sec

19.105735651364622

In [21]:
tflop_per_sec / BASELINE_TFLOPS_PER_SEC

6.565544897376158

## Optimization 5 - Faster Data Loading FFCV

In [17]:
from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField

ds = CIFAR10('./cifar10', train=True, download=True)
writer = DatasetWriter('./cifar10.beton', {'image': RGBImageField(max_resolution=32)})
writer.from_indexed_dataset(ds)

Files already downloaded and verified







  0%|                                                                                                                               | 0/50000 [00:00<?, ?it/s][A[A[A[A[A




100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 124044.86it/s][A[A[A[A[A


In [None]:
import numpy as np
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import SimpleRGBImageDecoder
from ffcv.transforms import ToTensor, ToTorchImage, ToDevice, NormalizeImage, RandomHorizontalFlip, Convert

img_tsfms = [
    SimpleRGBImageDecoder(),
    RandomHorizontalFlip(),
    ToTensor(),
    ToDevice(torch.device('cuda')),
    ToTorchImage(),
    NormalizeImage(
        mean=np.array([127.5, 127.5, 127.5]), std=np.array([127.5, 127.5, 127.5]),  # [0, 255] -> [-1, 1]
        type=np.float32
    )
]
dataloader = Loader(
    './cifar10.beton', batch_size=1, num_workers=4, # os_cache=True,
    order=OrderOption.RANDOM, pipelines={'image': img_tsfms}
)
x0 = next(iter(dataloader))

In [14]:
x0[0].dtype, x0[0].size(), x0[0].min(), x0[0].max()

(torch.float32,
 torch.Size([1, 3, 32, 32]),
 tensor(-0.8824, device='cuda:0'),
 tensor(0.8980, device='cuda:0'))

### Start Here

In [5]:
cfg_m = ModelConfig()
cfg_t = TrainerConfig(bs=3072)

In [None]:
import numpy as np
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import SimpleRGBImageDecoder
from ffcv.transforms import ToTensor, ToTorchImage, ToDevice, NormalizeImage, RandomHorizontalFlip, Convert

img_tsfms = [
    SimpleRGBImageDecoder(),
    RandomHorizontalFlip(),
    ToTensor(),
    ToDevice(torch.device(cfg_t.device)),
    ToTorchImage(),
    NormalizeImage(
        mean=np.array([127.5, 127.5, 127.5]), std=np.array([127.5, 127.5, 127.5]),  # [0, 255] -> [-1, 1]
        type=np.float32
    )
]
dataloader = Loader(
    './cifar10.beton', batch_size=cfg_t.bs, num_workers=cfg_t.num_workers, drop_last=True, os_cache=True,
    order=OrderOption.RANDOM, pipelines={'image': img_tsfms}
)

In [7]:
ddpm = torch.compile(DDPM(**asdict(cfg_m)).to('cuda'))
optimizer = torch.optim.AdamW(ddpm.parameters(), lr=cfg_t.lr)
scaler = torch.cuda.amp.GradScaler()

In [None]:
x0, = next(iter(dataloader))
# x0 = x0.to('cuda')
eps = torch.randn_like(x0)
t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)

with torch.cuda.amp.autocast():
    eps_pred = ddpm(x0, eps, t)
    loss = F.smooth_l1_loss(eps, eps_pred)
scaler.scale(loss).backward()

In [12]:
ddpm.train()

start_e = torch.cuda.Event(enable_timing=True)
end_e = torch.cuda.Event(enable_timing=True)
start_e.record()

for _, (x0, ) in zip(range(100), dataloader):
    optimizer.zero_grad()
    
    # x0 = x0.to('cuda')
    eps = torch.randn_like(x0)
    t = torch.randint(0, cfg_m.nT, [cfg_t.bs], device=x0.device)

    with torch.cuda.amp.autocast():
        eps_pred = ddpm(x0, eps, t)
        loss = F.smooth_l1_loss(eps, eps_pred)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

end_e.record()
torch.cuda.synchronize()

t = start_e.elapsed_time(end_e) / 100.0
print(t)

72.9346142578125


In [13]:
tflops_per_iter = (489.73 * 2 * 3) / 1000
tflop_per_sec = tflops_per_iter / (t / 1000)
tflop_per_sec

40.28786646643916

In [14]:
tflop_per_sec / BASELINE_TFLOPS_PER_SEC

13.844627651697305