# Columnformer training playground

## Installation

If you are working on Colab, you need to install the project first: set `INSTALL=true` in the cell below and run it. Skip this step if you are working from a local installation.

In [1]:
%%bash
# Set to true to clone and install the project (e.g. on Colab); left false for local installs.
INSTALL=false

if [[ $INSTALL == true ]]; then
    # Fetch the project sources and move into the checkout.
    git clone https://github.com/clane9/columnformers.git
    cd columnformers

    # Upgrade pip, install the listed requirements, then install the package in editable mode.
    pip install -U pip
    pip install -r requirements.txt
    pip install -e .
fi

## Setup

In [2]:
import math
import time
from dataclasses import dataclass
from typing import Literal, Optional, Tuple, Union

import torch
import yaml
from fvcore.nn import FlopCountAnalysis
from timm.utils import AverageMeter, random_seed
from torch.utils.data import DataLoader

import columnformers.utils as ut
from columnformers.data import create_dataset, create_loader, list_datasets
from columnformers.inspection.metrics import Accuracy
from columnformers.models import create_model, list_models
from columnformers.tasks import ImageClassification, Task

In [3]:
@dataclass
class Args:
    """Configuration for a single training run in this notebook.

    Fields are grouped by purpose (model, dataset, optimization, logistics).
    The ``Optional`` model fields default to ``None`` and are forwarded
    unchanged to ``create_model``.
    """

    # Model
    # Name of the model to build; see ``list_models()`` for valid names.
    model: str = "vision_columnformer_r_tiny_patch16_128"
    num_heads: Optional[int] = None
    mlp_ratio: Optional[int] = None
    untied: Optional[Union[bool, Tuple[bool, bool, bool]]] = None
    skip_attn: Optional[bool] = None
    attn_bias: Optional[bool] = None
    qk_head_dim: Optional[int] = None
    no_vp: Optional[bool] = None
    init_local_attn: Optional[bool] = None
    global_pool: Literal["avg", "spatial"] = "avg"
    pos_embed: bool = True
    drop_rate: float = 0.0
    proj_drop_rate: float = 0.0
    attn_drop_rate: float = 0.0
    # Passed to the ``ImageClassification`` task as ``wiring_lambd``.
    wiring_lambd: float = 0.0
    # Dataset
    # Name of the dataset to load; must be one of the names from ``list_datasets()``.
    dataset: str = "imagenet-100"
    # Passed to ``create_dataset`` as ``min_scale``.
    crop_min_scale: float = 1.0
    hflip: float = 0.01
    color_jitter: Optional[float] = None
    keep_in_memory: bool = False
    # Number of data loader workers.
    workers: int = 4
    # Optimization
    epochs: int = 100
    batch_size: int = 256
    # Base learning rate for the optimizer and the LR schedule.
    lr: float = 6e-4
    decay_lr: bool = True
    warmup_fraction: float = 0.1
    weight_decay: float = 0.05
    # Maximum gradient norm handed to ``ut.backward_step``.
    clip_grad: Optional[float] = 1.0
    # Logistics
    # Use CUDA when it is available; otherwise the run falls back to CPU.
    use_cuda: bool = True
    # Print progress every ``log_interval`` steps.
    log_interval: int = 10
    # When true, every loop stops after its first iteration.
    debug: bool = False
    seed: int = 42

In [4]:
# Show the names returned by list_models() and list_datasets(), one per line.
model_names = "\n".join(list_models())
print("Available models:", "\n" + model_names, "\n")

dataset_names = "\n".join(list_datasets())
print("Available datasets:", "\n" + dataset_names)

Available models: 
vision_transformer_tiny_patch16_128
vision_columnformer_ff_tiny_patch16_128
vision_columnformer_r_tiny_patch16_128 

Available datasets: 
imagenet-100
micro-imagenet-100
debug-100


In [5]:
# Short run configuration: vision transformer model on the micro dataset,
# three epochs at a constant learning rate (no warmup, no decay).
overrides = dict(
    model="vision_transformer_tiny_patch16_128",
    dataset="micro-imagenet-100",
    keep_in_memory=True,
    epochs=3,
    lr=1e-3,
    decay_lr=False,
    warmup_fraction=0.0,
)
args = Args(**overrides)

# Echo the full configuration as YAML, keeping the field declaration order.
print(yaml.safe_dump(vars(args), sort_keys=False))

model: vision_transformer_tiny_patch16_128
num_heads: null
mlp_ratio: null
untied: null
skip_attn: null
attn_bias: null
qk_head_dim: null
no_vp: null
init_local_attn: null
global_pool: avg
pos_embed: true
drop_rate: 0.0
proj_drop_rate: 0.0
attn_drop_rate: 0.0
wiring_lambd: 0.0
dataset: micro-imagenet-100
crop_min_scale: 1.0
hflip: 0.01
color_jitter: null
keep_in_memory: true
workers: 4
epochs: 3
batch_size: 256
lr: 0.001
decay_lr: false
warmup_fraction: 0.0
weight_decay: 0.05
clip_grad: 1.0
use_cuda: true
log_interval: 10
debug: false
seed: 42



In [6]:
# Use the GPU only when it is both requested in the config and actually available.
want_cuda = args.use_cuda and torch.cuda.is_available()
device = torch.device("cuda" if want_cuda else "cpu")
print("Running on:", device)

Running on: cuda


In [7]:
# Seed the random number generators via timm's helper so the run is repeatable.
# NOTE(review): presumably covers Python, NumPy and torch RNGs — confirm against timm.
random_seed(args.seed)

## Datasets

In [8]:
# Load the dataset named in the config; the augmentation settings
# (min crop scale, horizontal flip, color jitter) come from ``args``.
dataset = create_dataset(
    args.dataset,
    min_scale=args.crop_min_scale,
    hflip=args.hflip,
    color_jitter=args.color_jitter,
    keep_in_memory=args.keep_in_memory,
)
# Printing shows the available splits and their sizes.
print(dataset)

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 120000
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 5000
    })
})


In [9]:
# Build one data loader per dataset split, keyed by split name.
# Only the training split is shuffled: evaluation does not benefit from random
# order, and a fixed order keeps per-batch validation logs (and the single-batch
# debug metric) comparable across epochs and runs.
loaders = {}
for split, ds in dataset.items():
    loaders[split] = create_loader(
        ds,
        shuffle=(split == "train"),
        batch_size=args.batch_size,
        # Keep the final partial batch so every example is used.
        drop_last=False,
        num_workers=args.workers,
        device=device,
    )

In [10]:
# Pull a single training batch to inspect; it is reused later as the FLOP-count input.
first_batch = next(iter(loaders["train"]))

# Summarize each entry of the batch by shape, dtype and device.
print({k: (v.shape, v.dtype, v.device) for k, v in first_batch.items()})

{'image': (torch.Size([256, 3, 128, 128]), torch.float32, device(type='cuda', index=0)), 'label': (torch.Size([256]), torch.int64, device(type='cuda', index=0))}


## Model

In [11]:
# Number of target classes, read from the label feature of the training split.
num_classes = dataset["train"].features["label"].num_classes

In [12]:
# Build the model named in the config. Every architecture option is forwarded
# from ``args`` as-is, including those left at ``None``.
model = create_model(
    args.model,
    num_heads=args.num_heads,
    mlp_ratio=args.mlp_ratio,
    untied=args.untied,
    skip_attn=args.skip_attn,
    attn_bias=args.attn_bias,
    qk_head_dim=args.qk_head_dim,
    no_vp=args.no_vp,
    init_local_attn=args.init_local_attn,
    num_classes=num_classes,
    pos_embed=args.pos_embed,
    global_pool=args.global_pool,
    drop_rate=args.drop_rate,
    proj_drop_rate=args.proj_drop_rate,
    attn_drop_rate=args.attn_drop_rate,
)
# Move the model to the selected device and print its module structure.
model = model.to(device)
print(model)

VisionColumnformer(
  pos_embed=True
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (encoder): Columnformer(
    depth=6, recurrent=False, geometry=(64, 64), init_local_attn=False, local_attn_sigma=2.0
    (blocks): ModuleList(
      (0-5): 6 x Block(
        skip_attn=True
        (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          bias=False
          (q): Linear(in_features=384, out_features=384, bias=True)
          (k): Linear(in_features=384, out_features=384, bias=True)
          (v): Linear(in_features=384, out_features=384, bias=True)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_fe

In [13]:
# The task object computes the loss: the training and validation loops call
# ``task.forward(model, batch)`` and get back ``(loss, state)``.
task = ImageClassification(wiring_lambd=args.wiring_lambd)
# Moved to the same device as the model.
task = task.to(device)
print(task)

ImageClassification(wiring_lambd=0.0)


In [14]:
# Total number of elements across all model parameters.
param_count = sum(p.numel() for p in model.parameters())
# FLOPs reported by fvcore for one image; the ``[:1]`` slice keeps the batch dimension.
# NOTE(review): ``flops`` is a redundant alias of ``flop_count`` and is not used
# in the visible cells — drop it if nothing later in the notebook relies on it.
flop_count = flops = FlopCountAnalysis(
    model, first_batch["image"][:1]
).total()

# Both numbers are reported in millions, rounded to whole units.
print(f"Params: {param_count / 1e6:.0f}M, FLOPs: {flop_count / 1e6:.0f}M")

Unsupported operator aten::add encountered 13 time(s)
Unsupported operator aten::mul encountered 6 time(s)
Unsupported operator aten::softmax encountered 6 time(s)
Unsupported operator aten::gelu encountered 6 time(s)
Unsupported operator aten::mean encountered 1 time(s)


Params: 11M, FLOPs: 719M


## Optimizer

In [15]:
# Names of the parameters that the optimizer should not apply weight decay to.
no_decay_keys = ut.get_no_decay_keys(model)
optimizer = ut.create_optimizer(
    model,
    no_decay_keys=no_decay_keys,
    lr=args.lr,
    weight_decay=args.weight_decay,
)
# The schedule is indexed by global step, so it spans all epochs of the run.
epoch_steps = len(loaders["train"])
lr_schedule = ut.CosineDecaySchedule(
    base_lr=args.lr,
    total_steps=args.epochs * epoch_steps,
    do_decay=args.decay_lr,
    warmup_fraction=args.warmup_fraction,
)
print(optimizer)
print("No decay keys:", no_decay_keys)

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0.05

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0.0
)
No decay keys: ['pos_embed', 'patch_embed.proj.bias', 'encoder.blocks.0.norm1.weight', 'encoder.blocks.0.norm1.bias', 'encoder.blocks.0.attn.q.bias', 'encoder.blocks.0.attn.k.bias', 'encoder.blocks.0.attn.v.bias', 'encoder.blocks.0.attn.proj.bias', 'encoder.blocks.0.norm2.weight', 'encoder.blocks.0.norm2.bias', 'encoder.blocks.0.mlp.fc1.bias', 'encoder.blocks.0.mlp.fc2.bias', 'encoder.blocks.1.norm1.weight', 'encoder.blocks.1.norm1.bias', 'encoder.blocks.1.attn.q.bias', 'encoder.blocks.1.attn.k.bias', 'encoder.blocks.1.attn.v.bias', 'encoder.blocks.1.attn.

## Training

In [16]:
def train_one_epoch(
    *,
    args: Args,
    epoch: int,
    model: torch.nn.Module,
    task: Task,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    lr_schedule: ut.LRSchedule,
    device: torch.device,
) -> None:
    """Run one pass over ``train_loader``, updating ``model`` in place.

    For each batch this computes the task loss, sets the learning rate given
    by ``lr_schedule`` for the global step, takes an optimization step through
    ``ut.backward_step`` and updates running loss / accuracy / timing meters.
    A progress line is printed whenever the global step is a multiple of
    ``args.log_interval``, on the last batch, and on every batch in debug mode.

    Args:
        args: Run configuration; reads ``clip_grad``, ``log_interval``,
            ``batch_size`` and ``debug``.
        epoch: Zero-based epoch index, used to derive the global step.
        model: Model being trained.
        task: Task whose ``forward(model, batch)`` returns ``(loss, state)``.
        train_loader: Loader yielding dict batches with an ``"image"`` entry.
        optimizer: Optimizer whose learning rate is overwritten every step.
        lr_schedule: Callable mapping the global step to a learning rate.
        device: Device the run executes on; CUDA enables memory statistics
            and synchronization before timing.

    Raises:
        RuntimeError: If the loss of any batch is NaN or infinite.
    """
    model.train()
    task.train()
    optimizer.zero_grad()

    is_cuda = device.type == "cuda"
    if is_cuda:
        # Start from a clean slate so the reported peak memory covers this epoch only.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    accuracy = Accuracy()

    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()
    acc_m = AverageMeter()

    epoch_batches = len(train_loader)
    # Global step of the first batch in this epoch.
    first_step = epoch * epoch_batches

    end = time.monotonic()
    for batch_idx, batch in enumerate(train_loader):
        step = first_step + batch_idx
        is_last_batch = batch_idx + 1 == epoch_batches
        batch_size = len(batch["image"])
        # Time spent waiting for the loader to produce this batch.
        data_time = time.monotonic() - end

        # forward pass
        loss, state = task.forward(model, batch)
        loss_item = loss.item()

        # Abort on NaN/Inf instead of continuing with a corrupted model.
        if not math.isfinite(loss_item):
            raise RuntimeError(f"NaN/Inf loss encountered on step {step}; exiting")

        # update lr
        lr = lr_schedule(step)
        ut.update_lr_(optimizer, lr)

        # backward and optimization step
        total_norm = ut.backward_step(loss, optimizer, max_grad_norm=args.clip_grad)

        # end of iteration timing
        if is_cuda:
            # Wait for queued GPU work so the step time is not under-reported.
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        acc_item = accuracy(state)
        acc_m.update(acc_item, batch_size)

        if step % args.log_interval == 0 or is_last_batch or args.debug:
            # Throughput in samples per second, based on the average step time.
            tput = args.batch_size / step_time_m.avg
            if is_cuda:
                alloc_mem_gb = torch.cuda.max_memory_allocated() / 1e9
                res_mem_gb = torch.cuda.max_memory_reserved() / 1e9
            else:
                alloc_mem_gb = res_mem_gb = 0.0

            print(
                f"Train: {epoch:>3d} [{batch_idx:>3d}/{epoch_batches}][{step:>6d}]"
                f"  Loss: {loss_m.val:#.3g} ({loss_m.avg:#.3g})"
                f"  Acc: {acc_m.val:#.3g} ({acc_m.avg:#.3g})"
                f"  LR: {lr:.3e}"
                f"  Grad: {total_norm:.3e}"
                f"  Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
                f"  Mem: {alloc_mem_gb:.2f},{res_mem_gb:.2f} GB"
            )

        # Restart timer for next iteration
        end = time.monotonic()

        # Debug mode runs a single batch only.
        if args.debug:
            break

In [17]:
@torch.no_grad()
def validate(
    *,
    args: Args,
    epoch: int,
    model: torch.nn.Module,
    task: Task,
    val_loader: DataLoader,
    device: torch.device,
) -> float:
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Running loss, accuracy and timing meters are updated for each batch, and
    a progress line is printed every ``args.log_interval`` batches, on the
    final batch, and on every batch in debug mode (which stops after one).

    Returns:
        The average loss over the batches seen, weighted by batch size.
    """
    model.eval()
    task.eval()

    on_gpu = device.type == "cuda"
    if on_gpu:
        # Reset so the reported peak memory reflects this validation pass only.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    accuracy = Accuracy()

    loss_m, data_time_m, step_time_m, acc_m = (AverageMeter() for _ in range(4))

    epoch_batches = len(val_loader)
    tick = time.monotonic()
    for batch_idx, batch in enumerate(val_loader):
        n_samples = len(batch["image"])
        # Time spent waiting on the loader for this batch.
        data_time = time.monotonic() - tick

        loss, state = task.forward(model, batch)
        loss_item = loss.item()

        # Synchronize before reading the clock so queued GPU work is counted.
        if on_gpu:
            torch.cuda.synchronize()
        step_time = time.monotonic() - tick

        loss_m.update(loss_item, n_samples)
        data_time_m.update(data_time, n_samples)
        step_time_m.update(step_time, n_samples)
        acc_m.update(accuracy(state), n_samples)

        on_interval = batch_idx % args.log_interval == 0
        on_final_batch = batch_idx + 1 == epoch_batches
        if on_interval or on_final_batch or args.debug:
            tput = args.batch_size / step_time_m.avg
            alloc_mem_gb = torch.cuda.max_memory_allocated() / 1e9 if on_gpu else 0.0
            res_mem_gb = torch.cuda.max_memory_reserved() / 1e9 if on_gpu else 0.0

            print(
                f"Val: {epoch:>3d} [{batch_idx:>3d}/{epoch_batches}]"
                f"  Loss: {loss_m.val:#.3g} ({loss_m.avg:#.3g})"
                f"  Acc: {acc_m.val:#.3g} ({acc_m.avg:#.3g})"
                f"  Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
                f"  Mem: {alloc_mem_gb:.2f},{res_mem_gb:.2f} GB"
            )

        # Debug mode evaluates a single batch only.
        if args.debug:
            break

        # Restart the clock for the next batch.
        tick = time.monotonic()

    return loss_m.avg

In [18]:
start_time = time.monotonic()

# Validation loss of the most recent epoch. Initialised up front so the final
# report still works when the loop body never runs (args.epochs == 0).
metric = float("nan")

for epoch in range(args.epochs):
    print(f"Starting epoch {epoch:d}")

    train_one_epoch(
        args=args,
        epoch=epoch,
        model=model,
        task=task,
        train_loader=loaders["train"],
        optimizer=optimizer,
        lr_schedule=lr_schedule,
        device=device,
    )

    # validate() returns the average validation loss for this epoch.
    metric = validate(
        args=args,
        epoch=epoch,
        model=model,
        task=task,
        val_loader=loaders["validation"],
        device=device,
    )

    # Debug mode stops after a single epoch.
    if args.debug:
        break

print(f"Done! Run time: {time.monotonic() - start_time:.0f}s")
print(f"*** Final metric: {metric:.3f}")

Starting epoch 0
Train:   0 [  0/469][     0]  Loss: 4.65 (4.65)  Acc: 0.00 (0.00)  LR: 1.000e-03  Grad: 1.440e+00  Time: 1.047,1.365 187/s  Mem: 3.26,3.43 GB
Train:   0 [ 10/469][    10]  Loss: 4.48 (4.55)  Acc: 1.95 (2.38)  LR: 1.000e-03  Grad: 9.749e-01  Time: 0.115,0.248 1031/s  Mem: 3.60,3.78 GB
Train:   0 [ 20/469][    20]  Loss: 4.32 (4.48)  Acc: 3.52 (3.09)  LR: 1.000e-03  Grad: 9.567e-01  Time: 0.071,0.196 1307/s  Mem: 3.60,3.78 GB
Train:   0 [ 30/469][    30]  Loss: 4.43 (4.44)  Acc: 3.91 (3.19)  LR: 1.000e-03  Grad: 8.931e-01  Time: 0.057,0.178 1438/s  Mem: 3.60,3.78 GB
Train:   0 [ 40/469][    40]  Loss: 4.30 (4.42)  Acc: 4.69 (3.54)  LR: 1.000e-03  Grad: 9.022e-01  Time: 0.050,0.170 1509/s  Mem: 3.60,3.78 GB
Train:   0 [ 50/469][    50]  Loss: 4.32 (4.40)  Acc: 5.86 (3.84)  LR: 1.000e-03  Grad: 9.025e-01  Time: 0.045,0.164 1560/s  Mem: 3.60,3.78 GB
Train:   0 [ 60/469][    60]  Loss: 4.32 (4.38)  Acc: 4.30 (3.97)  LR: 1.000e-03  Grad: 9.220e-01  Time: 0.042,0.161 1593/s  M