In [None]:
# %pip (unlike !pip) installs into the environment of the running kernel.
# matplotlib is listed explicitly because train.py imports it for the accuracy plots.
%pip install -q torch torchvision timm albumentations opencv-python tqdm matplotlib



In [None]:
import os

from google.colab import drive

drive.mount('/content/drive')
DATA_DIR = "/content/drive/MyDrive/datasets/tiny-imagenet-200"
DATASETS_DIR = os.path.dirname(DATA_DIR)   # the archive unpacks into <DATASETS_DIR>/tiny-imagenet-200
ZIP_PATH = DATA_DIR + ".zip"
ZIP_URL = "https://cs231n.stanford.edu/tiny-imagenet-200.zip"

!mkdir -p {DATASETS_DIR}

# Download the archive only once. It is fetched to a temporary name and renamed
# on success, so an interrupted download is never mistaken for a complete one.
if not os.path.isfile(ZIP_PATH):
    !wget {ZIP_URL} -O {ZIP_PATH}.part && mv {ZIP_PATH}.part {ZIP_PATH}

# -n never overwrites existing files, so re-running this cell does not prompt
# and only extracts whatever is still missing.
!unzip -q -n {ZIP_PATH} -d {DATASETS_DIR}


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
URL transformed to HTTPS due to an HSTS policy
--2025-10-16 04:08:50--  https://cs231n.stanford.edu/tiny-imagenet-200.zip
Resolving cs231n.stanford.edu (cs231n.stanford.edu)... 171.64.64.64
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.64.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 248100043 (237M) [application/zip]
Saving to: ‘/content/drive/MyDrive/datasets/tiny-imagenet-200.zip’


2025-10-16 04:09:13 (10.7 MB/s) - ‘/content/drive/MyDrive/datasets/tiny-imagenet-200.zip’ saved [248100043/248100043]



In [None]:
%%writefile train.py
import os
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from timm.models import create_model
from timm.data.mixup import Mixup
from timm.utils import accuracy, AverageMeter, ModelEmaV2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import matplotlib.pyplot as plt

# Enable cuDNN auto-tuning for faster convolutions. This is a process-wide
# setting that takes effect as soon as this module is executed.
torch.backends.cudnn.benchmark = True


# ===================== DATASET WRAPPER =====================
class AlbumentationsImageDataset(Dataset):
    """Image-folder dataset whose samples are transformed with Albumentations.

    ``root`` is indexed with ``torchvision.datasets.ImageFolder``. Each image is
    loaded, converted to an RGB ``numpy`` array and passed through ``transform``
    (called as ``transform(image=...)["image"]``).

    If ``root`` contains a ``val_annotations.txt`` file (the layout of the
    Tiny-ImageNet validation split, where every image sits in a single
    ``images/`` folder), labels are read from that file instead of being derived
    from the folder name. Without this, ``ImageFolder`` sees one class and gives
    every validation image label 0, which makes validation accuracy meaningless.

    Args:
        root: Directory to index.
        transform: Optional Albumentations transform/pipeline.
    """

    # Name of the per-image label file shipped with the Tiny-ImageNet val split.
    ANNOTATION_FILE = 'val_annotations.txt'

    def __init__(self, root, transform=None):
        self.dataset = datasets.ImageFolder(root)
        self.transform = transform
        # (path, label) pairs actually served by __getitem__.
        self.samples = list(self.dataset.imgs)
        if os.path.isfile(os.path.join(root, self.ANNOTATION_FILE)):
            self.samples = self._relabel_from_annotations(root, self.samples)

    @staticmethod
    def _relabel_from_annotations(root, samples):
        """Return ``samples`` with labels taken from ``root/val_annotations.txt``.

        Each annotation line starts with ``<file name> <class name>``; any
        further columns are ignored. Class indices follow the sorted
        sub-directory names of the sibling ``train`` directory when it exists
        (the same ordering ``ImageFolder`` uses), otherwise the sorted class
        names found in the annotation file.

        Raises:
            ValueError: If an image has no annotation or names an unknown class.
        """
        annotation_path = os.path.join(root, AlbumentationsImageDataset.ANNOTATION_FILE)
        class_by_file = {}
        with open(annotation_path, 'r') as handle:
            for line in handle:
                fields = line.split()
                if len(fields) >= 2:
                    class_by_file[fields[0]] = fields[1]

        train_dir = os.path.join(os.path.dirname(os.path.abspath(root)), 'train')
        if os.path.isdir(train_dir):
            class_names = sorted(entry.name for entry in os.scandir(train_dir) if entry.is_dir())
        else:
            class_names = sorted(set(class_by_file.values()))
        class_to_idx = {name: index for index, name in enumerate(class_names)}

        relabelled = []
        for path, _ in samples:
            class_name = class_by_file.get(os.path.basename(path))
            if class_name is None or class_name not in class_to_idx:
                raise ValueError(
                    f"Cannot determine the label of {path!r} from {annotation_path!r}"
                )
            relabelled.append((path, class_to_idx[class_name]))
        return relabelled

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        # Albumentations operates on numpy arrays, not PIL images.
        image = np.array(self.dataset.loader(path).convert('RGB'))
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label


# ===================== TRAINING FUNCTION =====================
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch, mixup_fn=None, ema=None,
                    scaler=None):
    """Train ``model`` for one pass over ``dataloader`` with mixed precision.

    Args:
        model: Network to optimise; switched to train mode.
        dataloader: Iterable yielding ``(images, targets)`` batches.
        criterion: Loss called as ``criterion(output, targets)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to (``torch.device`` or string).
        epoch: Epoch number, used only for progress/log output.
        mixup_fn: Optional callable ``(images, targets) -> (images, targets)``
            applied after the batch is on ``device``.
        ema: Optional object whose ``update(model)`` is called after each step.
        scaler: Optional gradient scaler to reuse across calls. When omitted a
            fresh one is created for this epoch.

    Returns:
        Tuple ``(average loss, average top-1 accuracy, elapsed seconds)``.
    """
    start_time = time.time()
    model.train()
    losses = AverageMeter()
    top1 = AverageMeter()

    # Autocast and loss scaling are only enabled on CUDA; elsewhere both are no-ops.
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'
    if scaler is None:
        grad_scaler_cls = getattr(torch.amp, 'GradScaler', None)
        if grad_scaler_cls is not None:
            scaler = grad_scaler_cls(device_type, enabled=use_amp)
        else:
            # Older PyTorch releases only ship the CUDA-specific scaler.
            scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # ✅ tqdm progress bar per batch
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}", ncols=100, leave=False)
    for images, targets in progress_bar:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        if mixup_fn is not None:
            images, targets = mixup_fn(images, targets)

        optimizer.zero_grad()
        with torch.autocast(device_type=device_type, enabled=use_amp):
            output = model(images)
            loss = criterion(output, targets)
        # Scale the loss so small fp16 gradients do not underflow to zero; the
        # scaler unscales them again (and skips the step on inf/NaN) in step().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if ema is not None:
            ema.update(model)

        # Accuracy calc: when targets are 2-D (e.g. after mixup) the argmax is
        # used as the reference label.
        acc_targets = torch.argmax(targets, dim=1) if targets.ndim > 1 else targets
        acc1, _ = accuracy(output, acc_targets, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1.item(), images.size(0))
        progress_bar.set_postfix(loss=f"{losses.avg:.3f}", acc=f"{top1.avg:.2f}")

    epoch_time = time.time() - start_time
    print(f"\nTrain Epoch {epoch}: Loss {losses.avg:.4f}, Acc@1 {top1.avg:.2f}, Time {epoch_time/60:.2f} min")
    return losses.avg, top1.avg, epoch_time


# ===================== VALIDATION FUNCTION =====================
@torch.no_grad()
def validate(model, dataloader, criterion, device, ema=None):
    """Compute average loss and top-1 accuracy over ``dataloader``.

    ``model`` is put in eval mode. When ``ema`` is given, the forward passes
    use ``ema.module`` instead of ``model``.

    Returns:
        Tuple ``(average loss, average top-1 accuracy)``.
    """
    model.eval()
    eval_model = ema.module if ema else model
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    batches = tqdm(dataloader, desc="Validating", ncols=100, leave=False)
    for batch_images, batch_targets in batches:
        batch_images = batch_images.to(device)
        batch_targets = batch_targets.to(device)

        logits = eval_model(batch_images)
        batch_loss = criterion(logits, batch_targets)
        top1_acc, _top5_acc = accuracy(logits, batch_targets, topk=(1, 5))

        # Weight each batch by its size so the averages are per-sample.
        batch_size = batch_images.size(0)
        loss_meter.update(batch_loss.item(), batch_size)
        acc_meter.update(top1_acc.item(), batch_size)

    print(f"Validation: Loss {loss_meter.avg:.4f}, Acc@1 {acc_meter.avg:.2f}")
    return loss_meter.avg, acc_meter.avg


# ===================== MAIN =====================
def _build_coarse_dropout(fill, p=0.3):
    """Build a CoarseDropout that cuts a single 8x8 hole filled with ``fill``.

    The range-based argument names are tried first; releases that do not accept
    them fall back to the legacy ``max_*`` / ``fill_value`` names.
    """
    try:
        return A.CoarseDropout(num_holes_range=(1, 1), hole_height_range=(8, 8),
                               hole_width_range=(8, 8), fill=fill, p=p)
    except TypeError:
        return A.CoarseDropout(max_holes=1, max_height=8, max_width=8, fill_value=fill, p=p)


def main():
    """Parse command-line options, then train and validate a ResNet-50.

    Builds the Albumentations pipelines and data loaders from ``--data/train``
    and ``--data/val``, trains for ``--epochs`` epochs with SGD, and writes
    ``best.pth``, periodic checkpoints and accuracy plots to ``--save``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='./tiny-imagenet-200')
    parser.add_argument('--epochs', type=int, default=10)
    # Both spellings are accepted so that a hyphen/underscore mix-up is not
    # silently dropped by parse_known_args below.
    parser.add_argument('--batch_size', '--batch-size', dest='batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--num-classes', '--num_classes', dest='num_classes', type=int, default=200)
    parser.add_argument('--mixup', type=float, default=0.3)
    parser.add_argument('--cutmix', type=float, default=0.3)
    parser.add_argument('--ema-decay', '--ema_decay', dest='ema_decay', type=float, default=0.9998)
    parser.add_argument('--label-smoothing', '--label_smoothing', dest='label_smoothing',
                        type=float, default=0.1)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--save', type=str, default='./checkpoints')
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        print(f"⚠️ Ignoring unrecognized arguments: {unknown_args}")

    os.makedirs(args.save, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)
    # Dropout holes are filled with the mean colour, expressed in 0-255 pixel
    # values because the dropout runs before Normalize.
    COARSE_DROPOUT_FILL = tuple(int(x * 255) for x in MEAN)

    train_tfms = A.Compose([
        A.RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0)),
        A.HorizontalFlip(p=0.5),
        _build_coarse_dropout(COARSE_DROPOUT_FILL, p=0.3),
        A.Normalize(mean=MEAN, std=STD),
        ToTensorV2(),
    ])

    val_tfms = A.Compose([
        A.Resize(height=256, width=256),
        A.CenterCrop(height=224, width=224),
        A.Normalize(mean=MEAN, std=STD),
        ToTensorV2(),
    ])

    train_dataset = AlbumentationsImageDataset(os.path.join(args.data, 'train'), transform=train_tfms)
    val_dataset = AlbumentationsImageDataset(os.path.join(args.data, 'val'), transform=val_tfms)

    # Mixup/CutMix is only instantiated when at least one of them is enabled.
    mixup_active = args.mixup > 0 or args.cutmix > 0

    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=device.type == 'cuda',
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=args.workers > 0,
    )
    # timm's Mixup requires an even batch, so an odd-sized final batch is dropped.
    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=mixup_active, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    model = create_model('resnet50', pretrained=False, num_classes=args.num_classes).to(device)

    if mixup_active:
        mixup_fn = Mixup(mixup_alpha=args.mixup, cutmix_alpha=args.cutmix,
                         label_smoothing=args.label_smoothing, num_classes=args.num_classes)
        # Mixup already smooths the soft targets it produces; smoothing again in
        # the loss would apply label smoothing twice.
        train_criterion = nn.CrossEntropyLoss()
    else:
        mixup_fn = None
        train_criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
    val_criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
    # NOTE(review): with step_size=100 the LR is first halved after 100 epochs,
    # so it stays constant for the default --epochs; confirm this is intended.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
    ema = ModelEmaV2(model, decay=args.ema_decay, device=device)

    best_acc = 0
    train_losses, val_losses, train_accs, val_accs, epoch_times = [], [], [], [], []

    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc, epoch_time = train_one_epoch(
            model, train_loader, train_criterion, optimizer, device, epoch, mixup_fn, ema)
        # validate() evaluates ema.module, so val metrics describe the EMA weights.
        val_loss, val_acc = validate(model, val_loader, val_criterion, device, ema)
        scheduler.step()

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        epoch_times.append(epoch_time)

        # ✅ Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save({
                'model': model.state_dict(),
                'ema': ema.module.state_dict(),
                'epoch': epoch,
                'val_acc': val_acc
            }, os.path.join(args.save, 'best.pth'))

        # ✅ Save checkpoint every 10 epochs (and after the final epoch)
        if epoch % 10 == 0 or epoch == args.epochs:
            ckpt_path = os.path.join(args.save, f'checkpoint_epoch{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'ema': ema.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'val_acc': val_acc,
            }, ckpt_path)
            print(f"💾 Saved checkpoint: {ckpt_path}")

            # Plot progress on an explicit figure so it can be closed afterwards.
            fig, ax = plt.subplots(figsize=(8, 5))
            epochs_axis = range(1, len(train_accs) + 1)
            ax.plot(epochs_axis, train_accs, label='Train Acc', marker='o')
            ax.plot(epochs_axis, val_accs, label='Val Acc', marker='o')
            ax.set_title('Accuracy vs Epochs')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Accuracy (%)')
            ax.legend()
            ax.grid(True)
            fig.savefig(os.path.join(args.save, f'plot_epoch{epoch}.png'))
            plt.close(fig)

        print(f"✅ Epoch {epoch} done. Best Acc so far: {best_acc:.2f}%")

    total_time = sum(epoch_times)
    print(f"🏁 Training complete! Best Top-1: {best_acc:.2f}%, Total Time: {total_time/60:.2f} min")


# Run training only when this file is executed as a script
# (e.g. `python train.py ...`), not when it is imported.
if __name__ == "__main__":
    main()


Overwriting train.py


## Sanity Check

In [None]:
# !python train.py --data ./tiny-imagenet-200 --epochs 2 --batch_size 32

## Tiny ImageNet-200 training run

In [None]:
# Full training run: 50 epochs on the Drive copy of the dataset, batch size 128,
# learning rate 1e-3, 2 DataLoader workers; checkpoints and plots go to ./runs/tiny_albu.
!python train.py \
  --data /content/drive/MyDrive/datasets/tiny-imagenet-200 \
  --epochs 50 \
  --batch_size 128 \
  --lr 1e-3 \
  --workers 2 \
  --save ./runs/tiny_albu


  A.CoarseDropout(max_holes=1, max_height=8, max_width=8, fill_value=COARSE_DROPOUT_FILL, p=0.3),
  with torch.cuda.amp.autocast():
                                                                                                    
Train Epoch 1: Loss 5.1943, Acc@1 2.71, Time 9.35 min
Validation: Loss 5.3322, Acc@1 0.00
✅ Epoch 1 done. Best Acc so far: 0.00%
                                                                                                    
Train Epoch 2: Loss 5.1012, Acc@1 4.12, Time 8.88 min
Validation: Loss 5.3587, Acc@1 0.00
✅ Epoch 2 done. Best Acc so far: 0.00%
                                                                                                    
Train Epoch 3: Loss 5.0633, Acc@1 4.67, Time 8.91 min
Validation: Loss 5.3908, Acc@1 0.00
✅ Epoch 3 done. Best Acc so far: 0.00%
                                                                                                    
Train Epoch 4: Loss 5.0415, Acc@1 5.11, Time 8.85 min
Validation: Loss 5.4273