In [None]:
import sys
sys.path.append("../../")

import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision import transforms

from utils.train import train_one_epoch, evaluate

In [None]:
def _sync_device(device):
    """Block until work queued on an asynchronous accelerator has finished.

    CUDA and MPS launch kernels asynchronously, so wall-clock timings are only
    meaningful once the device queue has drained. Other device types (e.g. CPU)
    need no synchronization.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "mps" and hasattr(torch, "mps"):
        # torch.mps is missing on older PyTorch builds; skip rather than fail.
        torch.mps.synchronize()


def timed_epoch(fn, device):
    """Measure the wall-clock time taken by a single call to ``fn``.

    Args:
        fn: Zero-argument callable to time. Its return value is discarded.
        device: Device the work runs on; only its ``type`` attribute
            ("cuda", "mps", ...) is used to decide how to synchronize.

    Returns:
        Elapsed time in seconds as a float.
    """
    # Drain work queued before the timer starts (e.g. a preceding model.to(device))
    # so it is not charged to the measured region.
    _sync_device(device)
    start = time.perf_counter()
    fn()
    # Wait for the kernels launched by fn() to actually complete.
    _sync_device(device)
    return time.perf_counter() - start

def peak_memory_mb(device):
    """Report the peak CUDA memory allocated by tensors on ``device``.

    Args:
        device: Device to query; only CUDA devices are supported.

    Returns:
        ``torch.cuda.max_memory_allocated(device)`` divided by 1024**2
        (i.e. bytes converted to mebibytes), or ``None`` when ``device`` is
        not a CUDA device.
    """
    bytes_per_mb = 1024**2
    if device.type == "cuda":
        peak_bytes = torch.cuda.max_memory_allocated(device)
        return peak_bytes / bytes_per_mb
    return None

## Gradient Accumulation

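Gradient accumulation simulates a larger batch size, without holding that larger batch in memory, by:
- running several forward/backward passes on small mini-batches
- not updating the weights after every mini-batch
- accumulating (summing) the gradients across those mini-batches
- performing a single optimizer step after every N mini-batches

The effective batch size is `batch_size × N`; for example, a batch size of 64 with N = 4 corresponds to an effective batch size of 256.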

<p align="center">
  <img src="../../assets/img/efficiency/grad_accum.png" width="400">
</p>

In [None]:
def train_one_epoch_accum(
    model,
    dataloader,
    optimizer,
    loss_fn,
    device,
    accum_steps=4,
):
    """Train for one epoch, stepping the optimizer once per ``accum_steps`` batches.

    Gradients from consecutive mini-batches are summed in ``.grad`` and applied
    in a single optimizer step, emulating a batch ``accum_steps`` times larger.
    If the number of batches is not a multiple of ``accum_steps``, the trailing
    partial group is still applied at the end of the epoch instead of being
    silently discarded.

    Args:
        model: Module to train; switched to training mode.
        dataloader: Iterable of ``(images, labels)`` batches exposing a
            ``dataset`` attribute with a length.
        optimizer: Optimizer updating the model's parameters.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor. Assumed to use mean reduction (as ``nn.CrossEntropyLoss``
            does by default) -- the reported loss weights it by batch size.
        device: Device the batches are moved to.
        accum_steps: Number of mini-batches accumulated per optimizer step.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by
        ``len(dataloader.dataset)``.

    Raises:
        ValueError: If ``accum_steps`` is smaller than 1.
    """
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")

    model.train()
    running_loss = 0.0
    pending = 0  # backward passes accumulated since the last optimizer step
    optimizer.zero_grad()

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = loss_fn(outputs, labels)

        # Scale each contribution so the summed gradient of a full group is the
        # mean over the group rather than accum_steps times larger.
        (loss / accum_steps).backward()
        pending += 1

        if pending == accum_steps:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

        # Use the unscaled loss for reporting.
        running_loss += loss.item() * images.size(0)

    if pending:
        # Flush the final, shorter group. Its gradients were divided by
        # accum_steps but only `pending` batches contributed, so rescale them
        # to a true mean before stepping.
        rescale = accum_steps / pending
        for param in model.parameters():
            if param.grad is not None:
                param.grad.mul_(rescale)
        optimizer.step()
        optimizer.zero_grad()

    return running_loss / len(dataloader.dataset)

In [None]:
# Per-channel mean / std passed to the Normalize transform of both pipelines.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std  = (0.2023, 0.1994, 0.2010)


def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize steps shared by train and val pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
    ]


# Training pipeline: light augmentation (flip, colour jitter, small rotation)
# followed by tensor conversion and normalization.
_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomRotation(5),
]
train_transforms = transforms.Compose(_augmentations + _tensor_and_normalize())

# Validation pipeline: no augmentation, only tensor conversion and normalization.
val_transforms = transforms.Compose(_tensor_and_normalize())

# Both splits live under the same root and are downloaded on demand.
_cifar10_kwargs = {"root": "../../assets/cifar10", "download": True}
train_dataset = CIFAR10(train=True, transform=train_transforms, **_cifar10_kwargs)
val_dataset = CIFAR10(train=False, transform=val_transforms, **_cifar10_kwargs)

# comparing two batch sizes: the same training set served in batches of 256 and 64
train_loader_bs256, train_loader_bs64 = (
    DataLoader(train_dataset, batch_size=size, shuffle=True) for size in (256, 64)
)

# Validation data is served in a fixed order.
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)



In [None]:
def make_model(device):
    """Create a fresh small CNN on ``device`` together with its Adam optimizer.

    The network is two conv -> batch-norm -> ReLU -> 2x2 max-pool stages
    (3 -> 32 -> 64 channels) followed by a two-layer classifier with 10
    outputs. The first linear layer expects 64 * 8 * 8 features, which matches
    3-channel 32x32 inputs after the two poolings.

    Args:
        device: Device the model is moved to.

    Returns:
        A ``(model, optimizer)`` tuple; the optimizer is Adam with
        ``lr=0.01`` and ``betas=(0.9, 0.999)`` over the model's parameters.
    """

    def conv_stage(in_channels, out_channels):
        # 3x3 conv keeps the spatial size (padding=1); the pool halves it.
        # The conv bias is disabled because batch norm follows immediately.
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        ]

    layers = [
        # feature extractor
        *conv_stage(3, 32),
        *conv_stage(32, 64),
        # classifier
        nn.Flatten(),
        nn.Linear(64 * 8 * 8, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    ]

    net = nn.Sequential(*layers)
    net.to(device)

    adam = torch.optim.Adam(net.parameters(), lr=0.01, betas=(0.9, 0.999))
    return net, adam

In [None]:
# Compare one epoch of plain training at batch size 256 against one epoch of
# gradient accumulation (batch size 64, 4 accumulation steps).

# for reproducibility: the same seed is applied before EACH run below, so both
# models are built from the same RNG state and start from identical weights.
SEED = 0
loss_fn = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# baseline model and optimizer
torch.manual_seed(SEED)
model, optimizer = make_model(device)

# baseline
baseline_time = timed_epoch(
    lambda: train_one_epoch(
        model, train_loader_bs256, optimizer, loss_fn, device
    ),
    device,
)

_, baseline_acc = evaluate(
    model, val_loader, loss_fn, device
)

# reset model and optimizer
# Re-seed first: the baseline epoch advanced the global RNG, and without this
# the second model would start from different initial weights, confounding
# the comparison.
torch.manual_seed(SEED)
model, optimizer = make_model(device)

# gradient accumulation
accum_time = timed_epoch(
    lambda: train_one_epoch_accum(
        model, train_loader_bs64, optimizer, loss_fn, device, accum_steps=4
    ),
    device,
)

_, accum_acc = evaluate(
    model, val_loader, loss_fn, device
)

print(f"Baseline:\n\ttime: {baseline_time:.2f}s  |  acc: {baseline_acc*100:.2f}%")
print(f"Grad Accum:\n\ttime: {accum_time:.2f}s  |  acc: {accum_acc*100:.2f}%")