# U-Net Throughput Benchmark

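This notebook measures the training throughput (images/second) of the real-valued and complex-valued U-Nets as a function of batch size and model width. Inputs are synthetic random tensors, so the results isolate the compute efficiency of a full training step (forward pass, backward pass, and optimizer update) on the current hardware and exclude any data-loading cost.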

In [1]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from pathlib import Path
import time
import itertools

import torch
import pandas as pd

import sys

# Resolve the project root: the working directory itself, or its parent when
# the kernel was launched from inside the `notebooks/` folder.
PROJECT_ROOT = Path.cwd().resolve()
PROJECT_ROOT = PROJECT_ROOT.parent if PROJECT_ROOT.name == 'notebooks' else PROJECT_ROOT
SRC_ROOT = PROJECT_ROOT.joinpath('src')

# Register `src/` on the import path only when it is not already listed, so
# re-running this cell does not pile up duplicate entries.
if sys.path.count(str(SRC_ROOT)) == 0:
    sys.path.append(str(SRC_ROOT))

from models.real_unet import RealUnet
from models.cx_unet import ComplexUnet

# Fix the global RNG seed so the synthetic batches are reproducible.
torch.manual_seed(0)

# Pick the accelerator in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
DEVICE

device(type='cuda')

In [2]:
# Batch sizes timed for every model / configuration pair.
BATCH_SIZES = [1, 2, 4]

# Named `features` lists handed to the U-Net constructors.
FEATURE_CONFIGS = dict(
    tiny=[8, 16, 32, 64, 128],
    base=[16, 32, 64, 128, 256],
    # large=[32, 64, 128, 256, 512],
)

WARMUP_STEPS = 5  # untimed iterations run before the measurement starts
BENCH_STEPS = 10  # iterations measured
INPUT_SHAPE = (1, 320, 256)  # (channels, H, W)

# Echo the run settings so they are recorded next to the results.
print(
    f"Device: {DEVICE}",
    f"Benchmark steps: {BENCH_STEPS}, warmup: {WARMUP_STEPS}",
    sep="\n",
)

Device: cuda
Benchmark steps: 10, warmup: 5


In [3]:
def make_model(kind: str, features: list[int]):
    """Build a single-channel U-Net of the requested kind on ``DEVICE``.

    Args:
        kind: ``'real'`` for ``RealUnet`` or ``'complex'`` for ``ComplexUnet``.
        features: Forwarded unchanged as the model's ``features`` argument.

    Returns:
        The constructed model, already moved to ``DEVICE``.

    Raises:
        ValueError: If ``kind`` is neither ``'real'`` nor ``'complex'``.
    """
    # Reject unknown kinds up front so the construction below has two cases only.
    if kind not in ('real', 'complex'):
        raise ValueError(f"Unknown model kind: {kind}")

    if kind == 'real':
        net = RealUnet(in_channels=1, out_channels=1, features=features, width_scale=1.0)
    else:
        net = ComplexUnet(in_channels=1, out_channels=1, features=features)
    return net.to(DEVICE)


def make_batch(batch_size: int):
    """Create a random complex tensor of shape ``(batch_size, *INPUT_SHAPE)`` on ``DEVICE``.

    The real and imaginary parts are drawn independently (real part first)
    from ``torch.randn`` as float32 and then combined with ``torch.complex``.
    """
    full_shape = (batch_size,) + tuple(INPUT_SHAPE)
    # Generator unpacking keeps the draw order: real part, then imaginary part.
    re_part, im_part = (
        torch.randn(full_shape, dtype=torch.float32, device=DEVICE)
        for _ in range(2)
    )
    return torch.complex(re_part, im_part)


def _synchronize_device() -> None:
    """Block until all work queued on ``DEVICE`` has finished (no-op on CPU)."""
    if DEVICE.type == 'cuda':
        torch.cuda.synchronize()
    elif DEVICE.type == 'mps' and hasattr(torch, 'mps'):
        # MPS also executes asynchronously; without this the timer would stop
        # before the queued kernels complete. `torch.mps` is absent in older
        # PyTorch builds, in which case the call is skipped.
        torch.mps.synchronize()


def benchmark_step(model, optimizer, batch_size: int):
    """Measure training throughput of ``model`` for one batch size.

    Runs ``WARMUP_STEPS`` untimed training iterations followed by
    ``BENCH_STEPS`` timed ones. Every iteration reuses the same synthetic
    input/target pair from ``make_batch`` and performs zero-grad, forward,
    L1 loss, backward and an optimizer step.

    Args:
        model: Module called as ``model(inputs)``; updated in place by ``optimizer``.
        optimizer: Optimizer over the model's parameters.
        batch_size: Number of samples per synthetic batch.

    Returns:
        Throughput in images per second over the timed iterations.
    """
    inputs = make_batch(batch_size)
    targets = make_batch(batch_size)

    def train_iteration() -> None:
        """Run one full optimisation step on the fixed synthetic batch."""
        optimizer.zero_grad(set_to_none=True)
        pred = model(inputs)
        # simple L1 loss to match training behavior
        loss = torch.mean(torch.abs(pred - targets))
        loss.backward()
        optimizer.step()

    # Warmup: the first iterations are excluded from the timed region.
    for _ in range(WARMUP_STEPS):
        train_iteration()

    # Fence the timed region on both sides so only completed work is counted.
    _synchronize_device()
    start = time.perf_counter()
    for _ in range(BENCH_STEPS):
        train_iteration()
    _synchronize_device()
    elapsed = time.perf_counter() - start

    images = batch_size * BENCH_STEPS
    return images / elapsed

In [4]:
def _release_device_cache() -> None:
    """Return cached, unused accelerator memory to the device (no-op on CPU)."""
    if DEVICE.type == 'cuda':
        torch.cuda.empty_cache()
    elif DEVICE.type == 'mps' and hasattr(torch, 'mps'):
        # `torch.mps` is absent in older PyTorch builds; skip in that case.
        torch.mps.empty_cache()


# Sweep every (model kind, width config) pair; one model/optimizer per pair is
# reused across all batch sizes.
results = []
for model_kind, config_name in itertools.product(['real', 'complex'], FEATURE_CONFIGS.keys()):
    features = FEATURE_CONFIGS[config_name]
    print(f"Benchmarking {model_kind} | config {config_name}")
    model = make_model(model_kind, features)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    for batch_size in BATCH_SIZES:
        step_failed = False
        try:
            imgs_per_sec = benchmark_step(model, optimizer, batch_size)
        except RuntimeError as err:
            # Record the failure as NaN so the sweep continues with the next size.
            step_failed = True
            imgs_per_sec = float('nan')
            print(f"  batch {batch_size} OOM/error: {err}")
        if step_failed:
            # Clean up *after* the except block: while the exception is alive its
            # traceback keeps the failed step's tensors referenced, so they
            # cannot be freed from inside the handler.
            optimizer.zero_grad(set_to_none=True)
            _release_device_cache()
        results.append({
            'model': model_kind,
            'config': config_name,
            'features': features,
            'batch_size': batch_size,
            'images_per_sec': imgs_per_sec,
        })
    # Drop this configuration's model before building the next one.
    del model, optimizer
    _release_device_cache()

df = pd.DataFrame(results)
df

Benchmarking real | config tiny
Benchmarking real | config base
Benchmarking complex | config tiny
Benchmarking complex | config base


Unnamed: 0,model,config,features,batch_size,images_per_sec
0,real,tiny,"[8, 16, 32, 64, 128]",1,179.051148
1,real,tiny,"[8, 16, 32, 64, 128]",2,268.967635
2,real,tiny,"[8, 16, 32, 64, 128]",4,306.878584
3,real,base,"[16, 32, 64, 128, 256]",1,186.811581
4,real,base,"[16, 32, 64, 128, 256]",2,205.689029
5,real,base,"[16, 32, 64, 128, 256]",4,221.419786
6,complex,tiny,"[8, 16, 32, 64, 128]",1,17.029726
7,complex,tiny,"[8, 16, 32, 64, 128]",2,21.04433
8,complex,tiny,"[8, 16, 32, 64, 128]",4,22.588525
9,complex,base,"[16, 32, 64, 128, 256]",1,13.31492


In [5]:
# Wide view of the sweep: one row per (model, config), one column per batch size.
pd.pivot_table(
    df,
    values='images_per_sec',
    index=['model', 'config'],
    columns='batch_size',
)

Unnamed: 0_level_0,batch_size,1,2,4
model,config,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
complex,tiny,1.008742,1.040376,1.076109
real,tiny,34.318817,64.119999,89.939659
