In [None]:
pip install wandb

## 01. Configuration For Hyperparameters

In [1]:
# Central hyperparameter dictionary consumed by train_full().
config = {
    # ========================
    # Dataset
    # ========================
    'dataset': 'CIFAR100',
    'num_classes': 100,   # single source of truth for the number of output classes
    'batch_size': 256,
    'val_ratio': 0.1,
    'num_workers': 4,
    'pin_memory': True,

    # ========================
    # Transfer Learning
    # ========================
    # When use_pretrained=True, get_dataloaders() decides the following on its own:
    #   - mean/std: ImageNet statistics
    #   - input_size: looked up from the pretrained model (falls back to 224)
    #   - strong_aug / flip / erase: all forced to True
    #   - augmentation: timm transform with bicubic interpolation and auto_augment
    'use_pretrained': True,  # True enables transfer-learning mode
    # timm model name; used by get_dataloaders() only to look up the input size.
    # NOTE(review): the network itself is built from 'model' below, so keep the two consistent.
    'pretrained_model_name': 'vit_base_patch16_224',

    # Supported pretrained models (timm):
    #  - 'resnet50', 'resnet101': Classical ResNet
    #  - 'convnext_small', 'convnext_base': Modern ConvNet
    #  - 'vit_base_patch16_224': Vision Transformer
    #  - 'deit_base_patch16_224': Data-efficient ViT
    #  - 'vit_large_patch14_224_in21k': Larger ViT

    # ========================
    # Augmentation (intended for use_pretrained=False)
    # ========================
    # NOTE(review): train_full() does not forward these four keys; get_dataloaders()
    # hard-codes its augmentation flags in both modes.
    'strong_aug': False,
    'flip': True,
    'erase': False,
    'color_jitter': False,

    # ========================
    # Model Architecture
    # ========================
    'model': 'convnext_s',  # key understood by get_model()
    'dropout': 0.0,

    # ========================
    # Training
    # ========================
    'epochs': 1,
    'lr': 0.001,
    'lr_scheduler': 'cosine',  # 'step', 'cosine', 'onecycle', 'exponential'
    'lr_decay_factor': 0.9,    # Multiplicative decay when lr_scheduler='exponential'
    'weight_decay': 1e-4,

    # ========================
    # Fine-tuning Strategy
    # ========================
    # Settings that matter for transfer learning
    'freeze_backbone': True,        # freeze the backbone at the start of training
    'freeze_epochs': 10,            # passed to train_engine() as unfreeze_epoch
    # NOTE(review): the three keys below are not read by train_full() in the code shown.
    'unfreeze_strategy': 'gradual', # 'gradual': unfreeze step by step, 'all_at_once': everything after freeze_epochs

    # Gradual unfreezing
    'unfreeze_every_n_epochs': 5,   # unfreeze one more layer every 5 epochs
    'layer_lr_decay': 0.1,          # learning rate shrinks by 0.1x per layer going deeper

    # ========================
    # Checkpointing
    # ========================
    'checkpoint_dir': 'checkpoints',
    'resume_from_checkpoint': False,
    'checkpoint_interval': 10,

    # ========================
    # Logging
    # ========================
    'project': 'transfer_learning',
    'run_name': None,  # Auto-generated if None
    'log_interval': 50,  # Log every N batches
}

## 02. Load and Transform Dataset

In [2]:
import torch
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as T
from typing import Optional
import timm
from timm.data import create_transform

# Per-channel (mean, std) normalization statistics keyed by torchvision dataset name.
# get_dataloaders() uses these when use_pretrained is False; a dataset missing from
# this table has its statistics computed with compute_mean_std() instead.
DEFAULT_STATS = {
    "CIFAR10":  ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    "CIFAR100": ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    "FashionMNIST": ((0.2860,), (0.3530,)),
}

def get_class_names(dataset_name, root="./data", train=True):
    """Look up a torchvision dataset by name and return its label names.

    The dataset is instantiated (downloading it if necessary) and its
    ``classes`` attribute is returned, or ``labels`` when ``classes`` is absent.

    Raises:
        AttributeError: if the dataset exposes neither attribute.
    """
    dataset_cls = getattr(torchvision.datasets, dataset_name)
    ds = dataset_cls(root=root, train=train, download=True)

    # Preference order matters: `classes` wins over `labels`.
    for attr in ("classes", "labels"):
        if hasattr(ds, attr):
            return getattr(ds, attr)

    raise AttributeError(f"{dataset_name} has no attribute 'classes' or 'labels'")

def compute_mean_std(dataset):
    """Compute per-channel mean and std of a dataset of image tensors.

    The dataset is expected to yield ``(image, label)`` pairs whose images are
    tensors (e.g. only ``ToTensor()`` applied); batches are unpacked as
    ``(batch, channel, height, width)``.

    Args:
        dataset: a map-style dataset compatible with ``DataLoader``.

    Returns:
        (mean, std): two tuples of Python floats, one entry per channel.

    Raises:
        ValueError: if the dataset yields no samples.
    """
    loader = DataLoader(dataset, batch_size=1024, shuffle=False)
    pixel_count = 0
    channel_sum = None
    channel_sq_sum = None

    for x, _ in loader:
        b, c, h, w = x.shape
        # Accumulate in float64: summing tens of millions of float32 pixels loses
        # enough precision to make E[x^2] - mean^2 unreliable.
        flat = x.reshape(b, c, -1).to(torch.float64)
        pixel_count += b * h * w
        batch_sum = flat.sum(dim=(0, 2))
        batch_sq_sum = (flat ** 2).sum(dim=(0, 2))
        channel_sum = batch_sum if channel_sum is None else channel_sum + batch_sum
        channel_sq_sum = batch_sq_sum if channel_sq_sum is None else channel_sq_sum + batch_sq_sum

    if pixel_count == 0:
        raise ValueError("compute_mean_std() received an empty dataset")

    mean = channel_sum / pixel_count
    # Rounding can push the variance of a (near-)constant channel slightly below
    # zero, which would turn the std into NaN; clamp it first.
    variance = (channel_sq_sum / pixel_count - mean ** 2).clamp_min(0.0)
    std = torch.sqrt(variance)
    return tuple(mean.tolist()), tuple(std.tolist())

def get_transforms(image_size, mean, std, *, 
                  strong_aug=True, flip=True,
                  erase=False, pretrained_input_size=None, use_pretrained=False):
    """
    Build the (train, test) transform pair.

    Args:
        image_size: Output size of the random crop used in the standard branch
            when no pretrained_input_size is given.
        mean, std: Per-channel normalization statistics.
        strong_aug: Enable the stronger augmentation policy (timm auto-augment in
            the pretrained branch, TrivialAugmentWide in the standard branch).
        flip: Enable random horizontal flipping of training images.
        erase: Enable random erasing on training images.
        pretrained_input_size: Input size of the pretrained model, or None.
        use_pretrained: If True (and pretrained_input_size is given), build the
            transforms with timm's create_transform.

    Returns:
        (transform_train, transform_test)
    """

    # Pretrained branch: let timm build the transforms.
    if use_pretrained and pretrained_input_size is not None:
        transform_train = create_transform(
            input_size=pretrained_input_size,
            is_training=True,
            mean=mean,
            std=std,
            auto_augment='rand-m9-mstd0.5' if strong_aug else None,
            interpolation='bicubic',
            # timm's `hflip` is a flip probability, not an on/off switch: passing
            # True would flip every training image.
            hflip=0.5 if flip else 0.0,
            re_prob=0.25 if erase else 0.0,
        )
        transform_test = create_transform(
            input_size=pretrained_input_size,
            is_training=False,
            mean=mean,
            std=std,
            interpolation='bicubic',
        )
        return transform_train, transform_test

    # Standard branch: plain torchvision transforms.
    train_trans = []
    test_trans = []

    if pretrained_input_size is not None:
        train_trans.append(T.Resize(pretrained_input_size, interpolation=T.InterpolationMode.BICUBIC))
        test_trans.append(T.Resize(pretrained_input_size, interpolation=T.InterpolationMode.BICUBIC))
    else:
        train_trans.append(T.RandomCrop(image_size, padding=4))

    if flip:
        train_trans.append(T.RandomHorizontalFlip())
    if strong_aug:
        # Inserted at the front so it runs before the resize/crop.
        train_trans.insert(0, T.TrivialAugmentWide())

    train_trans += [T.ToTensor(), T.Normalize(mean, std)]
    test_trans += [T.ToTensor(), T.Normalize(mean, std)]

    if erase:
        # RandomErasing operates on tensors, so it must come after ToTensor().
        train_trans.append(T.RandomErasing(p=0.5))

    return T.Compose(train_trans), T.Compose(test_trans)

def get_dataloaders(
    dataset_name: str = "CIFAR10",
    batch_size: int = 128,
    data_root: str = "./data",
    val_ratio: float = 0.1,
    num_workers: int = 4,
    pin_memory: bool = True,
    use_pretrained: bool = False,
    pretrained_model_name: Optional[str] = None,
):
    """
    Build train_loader, val_loader and test_loader for a torchvision dataset.

    With use_pretrained=True the following are applied automatically:
        - ImageNet statistics (mean/std)
        - timm transforms at the pretrained model's input size
        - strong augmentation (auto-augment, flip, erase) on the training split

    The validation split is carved out of the training set with random_split,
    but is read through the test-time transform so validation metrics are not
    distorted by random augmentation.

    Args:
        dataset_name: Name of a class in torchvision.datasets.
        batch_size: Batch size used by all three loaders.
        data_root: Directory where the dataset is stored / downloaded.
        val_ratio: Fraction of the training set held out for validation.
        num_workers: Worker processes per loader.
        pin_memory: Forwarded to every DataLoader.
        use_pretrained: Whether a pretrained model is being fine-tuned.
        pretrained_model_name: timm model name (e.g. 'resnet50', 'vit_base_patch16_224').

    Returns:
        (train_loader, val_loader, test_loader)
    """

    # ========================
    # Transfer-learning settings
    # ========================
    if use_pretrained:
        # Ask timm for the model's expected input size.
        # NOTE(review): confirm `timm.get_model_pretrained_cfg` exists in the installed
        # timm version; if the lookup fails the fallback below is used.
        try:
            model_info = timm.get_model_pretrained_cfg(pretrained_model_name)
            pretrained_input_size = model_info.input_size[-1]  # (3, 224, 224) -> 224
        except Exception as exc:
            # Keep the best-effort fallback, but make the failure visible.
            pretrained_input_size = 224
            print(f"[WARNING] Could not read the pretrained config for "
                  f"{pretrained_model_name!r} ({exc!r}); falling back to input size 224.")

        # ImageNet statistics
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        # Pretrained mode always uses strong augmentation
        strong_aug = True
        flip = True
        erase = True

        print(f"\n{'='*60}")
        print(f"[Transfer Learning Mode]")
        print(f"Model: {pretrained_model_name}")
        print(f"Input size: {pretrained_input_size}x{pretrained_input_size}")
        print(f"Statistics: ImageNet (mean={mean}, std={std})")
        print(f"Augmentation: Strong (auto_augment, flip, erase)")
        print(f"{'='*60}\n")

    # ========================
    # Standard training settings
    # ========================
    else:
        pretrained_input_size = None
        strong_aug = True
        flip = True
        erase = False

        if dataset_name in DEFAULT_STATS:
            mean, std = DEFAULT_STATS[dataset_name]
        else:
            tmp = getattr(torchvision.datasets, dataset_name)(
                root=data_root, train=True, download=True, transform=T.ToTensor()
            )
            mean, std = compute_mean_std(tmp)

        print(f"\n[Standard Training Mode] Dataset: {dataset_name}")

    # Build transforms (from this function's own arguments, not the global config)
    image_size = pretrained_input_size if pretrained_input_size else 32
    transform_train, transform_test = get_transforms(
        image_size, mean, std,
        strong_aug=strong_aug,
        flip=flip,
        erase=erase,
        pretrained_input_size=pretrained_input_size,
        use_pretrained=use_pretrained,
    )

    # Load datasets. The training images are opened twice: once with the augmenting
    # transform (for training) and once with the deterministic one (for validation).
    ds_class = getattr(torchvision.datasets, dataset_name)
    full_train = ds_class(root=data_root, train=True, download=True, transform=transform_train)
    full_train_eval = ds_class(root=data_root, train=True, download=True, transform=transform_test)
    test_set = ds_class(root=data_root, train=False, download=True, transform=transform_test)

    # Validation split: reuse the indices chosen by random_split on the eval view.
    val_size = int(len(full_train) * val_ratio)
    train_size = len(full_train) - val_size
    train_set, val_split = random_split(full_train, [train_size, val_size])
    val_set = torch.utils.data.Subset(full_train_eval, val_split.indices)

    # DataLoaders
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # persistent_workers is only valid when worker processes are used
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

    print(f"Dataset: {dataset_name}")
    print(f"Mean/Std: {mean} / {std}")
    print(f"Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}")
    print(f"Batch size: {batch_size}, Num workers: {num_workers}, Pin memory: {pin_memory}\n")

    return train_loader, val_loader, test_loader



In [3]:
import torch
import torch.nn as nn
import torchvision

# WideResNet (from scratch)
class Block(nn.Module):
    """Pre-activation residual block.

    Main path: BN -> ReLU -> 3x3 conv -> dropout -> BN -> ReLU -> 3x3 conv
    (the stride is applied by the second conv). The shortcut is a strided 1x1
    conv when the shape changes and an identity otherwise.
    """

    def __init__(self, in_ch, out_ch, dropout_rate=0.0, stride=1):
        super().__init__()
        # First pre-activation stage
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.dropout = nn.Dropout(p=dropout_rate)
        # Second pre-activation stage (carries the stride)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)

        # A projection is only needed when resolution or channel count changes.
        shape_changes = stride != 1 or in_ch != out_ch
        if shape_changes:
            self.shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        skip = self.shortcut(x)

        out = self.bn1(x)
        out = self.relu1(out)
        out = self.conv1(out)
        out = self.dropout(out)
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv2(out)

        return out + skip

class WideResNet(nn.Module):
    """Wide residual network built from three groups of `Block`s.

    Args:
        depth: Total depth; must be of the form 6n+4 (n blocks per group).
        widen_factor: Channel multiplier applied to the 16/32/64 base widths.
        dropout_rate: Dropout probability used inside every block.
        num_classes: Size of the final linear layer.
    """

    def __init__(self, depth=28, widen_factor=10, dropout_rate=0.3, num_classes=10):
        super().__init__()
        assert ((depth - 4) % 6 == 0), "depth should be 6n+4"
        blocks_per_group = (depth - 4) // 6
        w1, w2, w3 = 16 * widen_factor, 32 * widen_factor, 64 * widen_factor

        self.stem = nn.Conv2d(3, 16, 3, stride=1, padding=1, bias=False)
        # Groups 2 and 3 halve the spatial resolution.
        self.layer1 = self._make_group(16, w1, blocks_per_group, dropout_rate, stride=1)
        self.layer2 = self._make_group(w1, w2, blocks_per_group, dropout_rate, stride=2)
        self.layer3 = self._make_group(w2, w3, blocks_per_group, dropout_rate, stride=2)

        # Final BN/ReLU, global average pooling, then the classifier.
        self.head = nn.Sequential(
            nn.BatchNorm2d(w3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(w3, num_classes)
        )

    def _make_group(self, in_ch, out_ch, N, dropout_rate, stride):
        """Stack N blocks; only the first changes width and applies the stride."""
        first = Block(in_ch, out_ch, dropout_rate, stride)
        rest = [Block(out_ch, out_ch, dropout_rate, 1) for _ in range(N - 1)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        features = self.stem(x)
        for group in (self.layer1, self.layer2, self.layer3):
            features = group(features)
        return self.head(features)


def _make_resnet18_cifar(num_classes=10, pretrained=False, dropout=0.0):
    """Build a torchvision ResNet-18 with a 3x3 stride-1 stem and no max-pool.

    Args:
        num_classes: Output size of the classification layer.
        pretrained: Load the IMAGENET1K_V1 weights before modifying the network.
        dropout: If > 0, a Dropout layer is placed in front of the classifier.

    Returns:
        The modified model; its `fc` attribute is an nn.Sequential.
    """
    weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    net = torchvision.models.resnet18(weights=weights)

    # Replace the stem conv with a 3x3/stride-1 conv and drop the max-pool.
    # The replacement conv is freshly initialized.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()

    feat_dim = net.fc.in_features
    classifier = [nn.Linear(feat_dim, num_classes)]
    if dropout > 0:
        # Rebuild the list with a fresh Linear so Dropout precedes the projection.
        classifier = [nn.Dropout(dropout), nn.Linear(feat_dim, num_classes)]

    # Wrap in nn.Sequential whether or not Dropout is present.
    net.fc = nn.Sequential(*classifier)
    return net


def get_model(model_name, num_classes=10, dropout=0.0, pretrained=True):
    """
    Build a model from a short, case-insensitive name.

    Supported models:
      - 'wrn28_10': good for cifar-10/100 (only when pixels are 32x32 or 64x64)
      - 'convnext_s': small gpu, fast baseline
      - 'vit_b16_aug': good accuracy, moderate gpu
      - 'deit_b16d': good accuracy, moderate gpu
      - 'eva_b14': optional heavy
      - 'resnet18_cifar': classic baseline for cifar-10/100

    Raises:
        ValueError: if model_name is not one of the names above.
    """
    key = model_name.lower()

    # Short name -> timm model identifier
    timm_ids = {
        'convnext_s':  'convnext_small.fb_in22k_ft_in1k',
        'deit_b16d':   'deit_base_distilled_patch16_224',
        'vit_b16_aug': 'vit_base_patch16_224.augreg_in21k_ft_in1k',
        'eva_b14':     'eva02_base_patch14_224',
    }
    timm_id = timm_ids.get(key)

    if timm_id is not None:
        # Both drop_rate and drop_path_rate are passed to timm; a drop-path rate of
        # 0.1-0.2 is the suggested range for custom datasets.
        return timm.create_model(
            timm_id,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=dropout,
            drop_path_rate=0.1
        )

    if key == 'wrn28_10':
        return WideResNet(depth=28, widen_factor=10, dropout_rate=dropout, num_classes=num_classes)

    if key == 'resnet18_cifar':
        return _make_resnet18_cifar(num_classes=num_classes, pretrained=pretrained, dropout=dropout)

    raise ValueError(f"Unknown model: {model_name}. Choose from "
                     f"['wrn28_10','convnext_s','deit_b16d','vit_b16_aug','eva_b14','resnet18_cifar'].")

### 03-1. Helper Functions for Transfer Learning

In [4]:
import torch.nn as nn
from torchvision import transforms

# 1) output fc layer label 맞추기
def replace_classifier(model, num_classes: int) -> None:
    """
    Swap the model's final classification layer for one with num_classes outputs.

    The attributes `fc`, `classifier` and `head` are tried in that order
    (covering torchvision-style and timm-style naming); the first one that is an
    nn.Linear is replaced in place by a new nn.Linear with the same in_features.

    Raises:
        RuntimeError: if none of the three attributes is an nn.Linear.
    """
    for attr_name in ("fc", "classifier", "head"):
        current = getattr(model, attr_name, None)
        if isinstance(current, nn.Linear):
            setattr(model, attr_name, nn.Linear(current.in_features, num_classes))
            return

    raise RuntimeError("No replaceable classifier head found (fc/classifier/head).")

# 2) freeze 정도 (백본만 얼리기 or 해제)
def _is_head_param(param_name, head_names):
    """Return True if a dotted parameter name belongs to one of the head modules.

    Matching is done per dotted component, not by substring: a component matches
    a head name when it equals it or starts with it followed by an underscore
    (so a sibling head such as `head_dist` still counts). Plain substring matching
    would classify backbone layers like `blocks.0.mlp.fc1` as head parameters,
    because 'fc' is a substring of 'fc1'.
    """
    return any(
        part == head or part.startswith(head + "_")
        for part in param_name.split(".")
        for head in head_names
    )


def apply_freeze(model, freeze_backbone: bool, head_names=("fc","classifier","head")) -> None:
    """
    Set requires_grad on every parameter of the model.

    Args:
        model: Model whose parameters are updated in place.
        freeze_backbone: If True only head parameters stay trainable; if False
            every parameter is trainable.
        head_names: Module names that identify the classification head.
    """
    for name, p in model.named_parameters():
        p.requires_grad = (not freeze_backbone) or _is_head_param(name, head_names)


# 3) Different learning rates for the head (trained part) and the backbone (fine-tuned part)
def make_param_groups(model, lr_head: float, lr_backbone: float, weight_decay: float = 5e-4, momentum: float = 0.9,
                      head_names=("fc","classifier","head")):
    """
    Build optimizer param_groups with separate learning rates for head and backbone.

    Only parameters with requires_grad=True are included. Each group carries
    `lr`, `weight_decay` and `momentum` entries taken from this function's
    arguments (not from the notebook-level config).

    Args:
        model: Model to collect parameters from.
        lr_head: Learning rate for head parameters.
        lr_backbone: Learning rate for all other trainable parameters.
        weight_decay: Weight decay stored in both groups.
        momentum: Momentum value stored in both groups.
        head_names: Module names that identify the classification head.

    Returns:
        A list with the backbone group first (if non-empty) and the head group
        second (if non-empty).
    """
    head_params, backbone_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (head_params if _is_head_param(n, head_names) else backbone_params).append(p)

    pg = []
    if backbone_params:
        pg.append(dict(params=backbone_params, lr=lr_backbone, weight_decay=weight_decay, momentum=momentum))
    if head_params:
        pg.append(dict(params=head_params, lr=lr_head, weight_decay=weight_decay, momentum=momentum))
    return pg

### 03-2. Save and Load Utility Functions

In [5]:
"""
Contains various utility functions for PyTorch model training and saving.
"""
import torch
from pathlib import Path

def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str):
  """Write a model's state_dict to `target_dir / model_name`.

  The target directory (including parents) is created if it does not exist.

  Args:
    model: The PyTorch model whose weights are saved.
    target_dir: Directory to save into.
    model_name: File name for the weights; must end in ".pth" or ".pt".

  Example usage:
    save_model(model=model_0,
               target_dir="models",
               model_name="05_going_modular_tingvgg_model.pth")
  """
  # Make sure the output directory exists before anything else.
  out_dir = Path(target_dir)
  out_dir.mkdir(parents=True, exist_ok=True)

  # Validate the extension, then resolve the destination path.
  has_valid_suffix = model_name.endswith((".pth", ".pt"))
  assert has_valid_suffix, "model_name should end with '.pt' or '.pth'"
  destination = out_dir / model_name

  # Only the weights (state_dict) are stored, not the module object.
  print(f"[INFO] Saving model to: {destination}")
  torch.save(obj=model.state_dict(), f=destination)


def save_checkpoint(model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    epoch: int,
                    target_dir: str,
                    checkpoint_name: str,
                    scheduler=None,
                    best_acc: float = 0.0):
  """Write a resumable training checkpoint to `target_dir / checkpoint_name`.

  The checkpoint is a dict with the keys 'epoch', 'model_state_dict',
  'optimizer_state_dict' and 'best_acc', plus 'scheduler_state_dict' when a
  scheduler is supplied.

  Args:
    model: Model whose weights are stored.
    optimizer: Optimizer whose state is stored.
    epoch: Epoch number recorded in the checkpoint.
    target_dir: Directory to save into (created if missing).
    checkpoint_name: File name; must end in ".pth" or ".pt".
    scheduler: Optional learning-rate scheduler to store as well.
    best_acc: Best validation accuracy so far.

  Example usage:
    save_checkpoint(model=model, optimizer=optimizer, epoch=20,
                    target_dir="checkpoints", checkpoint_name="checkpoint_epoch_20.pth")
  """
  # Ensure the directory exists.
  out_dir = Path(target_dir)
  out_dir.mkdir(parents=True, exist_ok=True)

  # Validate the extension and resolve the output file.
  has_valid_suffix = checkpoint_name.endswith((".pth", ".pt"))
  assert has_valid_suffix, "checkpoint_name should end with '.pt' or '.pth'"
  destination = out_dir / checkpoint_name

  # Assemble the payload; the scheduler entry is optional.
  payload = dict(
      epoch=epoch,
      model_state_dict=model.state_dict(),
      optimizer_state_dict=optimizer.state_dict(),
      best_acc=best_acc,
  )
  if scheduler is not None:
      payload['scheduler_state_dict'] = scheduler.state_dict()

  print(f"[INFO] Saving checkpoint to: {destination}")
  torch.save(payload, destination)


def load_checkpoint(model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer,
                    checkpoint_path: str,
                    scheduler=None,
                    device: str = 'cpu'):
  """Restore model/optimizer (and optionally scheduler) state from a checkpoint.

  Args:
    model: Model that receives the stored weights.
    optimizer: Optimizer that receives the stored state.
    checkpoint_path: Path to the checkpoint file.
    scheduler: Optional scheduler; restored only if the checkpoint contains
      'scheduler_state_dict'.
    device: map_location passed to torch.load.

  Returns:
    (epoch, best_acc) read from the checkpoint, each defaulting to 0 / 0.0 when
    the key is absent. If the file does not exist, a warning is printed and
    (0, 0.0) is returned without touching the model or optimizer.

  Example usage:
    epoch, best_acc = load_checkpoint(model=model, optimizer=optimizer,
                                       checkpoint_path="checkpoints/checkpoint_epoch_20.pth")
  """
  ckpt_file = Path(checkpoint_path)

  # Missing file: report it and signal "nothing to resume".
  if not ckpt_file.exists():
      print(f"[WARNING] Checkpoint not found at: {ckpt_file}")
      return 0, 0.0

  print(f"[INFO] Loading checkpoint from: {ckpt_file}")
  state = torch.load(ckpt_file, map_location=device)

  model.load_state_dict(state['model_state_dict'])
  optimizer.load_state_dict(state['optimizer_state_dict'])

  # The scheduler entry is optional on both sides.
  can_restore_scheduler = scheduler is not None and 'scheduler_state_dict' in state
  if can_restore_scheduler:
      scheduler.load_state_dict(state['scheduler_state_dict'])

  epoch, best_acc = state.get('epoch', 0), state.get('best_acc', 0.0)

  print(f"[INFO] Resumed from epoch {epoch} with best accuracy: {best_acc:.2f}%")
  return epoch, best_acc

## 04. Train a Model

In [6]:
"""
Training engine for model training and validation.
This defines the training and validation loops, logging, and checkpointing.
"""
import torch
import torch.nn as nn
from tqdm.auto import tqdm
import wandb
from pathlib import Path

def run_epoch(model, loader, criterion, optimizer, device, is_train=True, scheduler=None, step_per_batch=False):
    """
    Run one pass over `loader`, either training or evaluating the model.

    Args:
        model: PyTorch model to train or validate
        loader: Iterable of (inputs, targets) batches
        criterion: Loss function
        optimizer: Optimizer (may be None when is_train is False)
        device: torch.device the batches are moved to
        is_train: If True, gradients are enabled and the optimizer is stepped
        scheduler: Learning rate scheduler (only used when step_per_batch is True)
        step_per_batch: Whether to step the scheduler after every batch

    Returns:
        (avg_loss, accuracy):
            avg_loss: Mean of the per-batch losses
            accuracy: Percentage (0-100) of correctly classified samples

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train() if is_train else model.eval()
    running_loss, correct, total = 0.0, 0, 0
    num_batches = 0

    desc = "Train" if is_train else "Val"
    with torch.set_grad_enabled(is_train):
        for inputs, targets in tqdm(loader, desc=desc):
            inputs, targets = inputs.to(device), targets.to(device)

            if is_train:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            if is_train:
                loss.backward()
                optimizer.step()
                # Per-batch schedules (e.g. OneCycleLR) advance here; per-epoch
                # schedules are stepped by the caller.
                if scheduler is not None and step_per_batch:
                    scheduler.step()

            running_loss += loss.item()
            num_batches += 1
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Counting batches (instead of len(loader)) also supports loaders without
    # __len__ and lets us fail with a clear message on empty input.
    if num_batches == 0 or total == 0:
        raise ValueError("run_epoch() received an empty loader")

    return running_loss / num_batches, 100. * correct / total

def _build_scheduler(optimizer, lr_scheduler, lr, epochs, steps_per_epoch):
    """Create the LR scheduler named by `lr_scheduler`; return (scheduler, step_per_batch)."""
    if lr_scheduler == 'onecycle':
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=lr, epochs=epochs, steps_per_epoch=steps_per_epoch
        )
        return scheduler, True
    if lr_scheduler == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs), False
    if lr_scheduler == 'step':
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1), False
    if lr_scheduler is not None:
        # Unknown names used to be ignored silently; keep the behaviour but say so.
        print(f"[WARNING] Unknown lr_scheduler {lr_scheduler!r}; training with a constant learning rate.")
    return None, False


def train_engine(model, train_loader, val_loader, device, epochs=100, lr=0.05,
                 weight_decay=5e-4, lr_scheduler='onecycle',
                 freeze_backbone=False, unfreeze_epoch=0,
                 project_name="hackathon", run_name=None,
                 checkpoint_dir="checkpoints", resume_from_checkpoint=True,
                 checkpoint_interval=20):
    """
    Main training loop with configurable scheduler.

    All hyperparameters are taken from this function's arguments; the
    notebook-level `config` dict is not read here.

    Args:
        model: PyTorch model to train
        train_loader: DataLoader for training set (e.g., 45k images from CIFAR train split)
        val_loader: DataLoader for validation set (e.g., 5k images from CIFAR train split)
                    Used for monitoring, model selection, and hyperparameter tuning
        device: torch.device for training
        epochs: Number of training epochs
        lr: Learning rate (also the peak LR for 'onecycle')
        weight_decay: Weight decay for regularization
        lr_scheduler: Learning rate scheduler type ('onecycle', 'cosine', 'step', None)
        freeze_backbone: Whether to freeze the backbone initially (via apply_freeze)
        unfreeze_epoch: Epoch at which the backbone is unfrozen (if freeze_backbone=True);
                        0 keeps it frozen for the whole run
        project_name: WandB project name
        run_name: WandB run name
        checkpoint_dir: Directory to save/load checkpoints
        resume_from_checkpoint: Whether to resume from checkpoint if available
        checkpoint_interval: Save checkpoint every N epochs (default: 20)

    Returns:
        best_acc: Best validation accuracy achieved during training

    Note:
        - This function uses val_loader for validation during training
        - The best weights are written to 'best_model.pth' in the working directory
        - DO NOT pass test_loader here - keep test set for final evaluation only!
    """

    # Log plain hyperparameters only (locals() would also include the model and loaders).
    hyperparams = dict(
        epochs=epochs, lr=lr, weight_decay=weight_decay, lr_scheduler=lr_scheduler,
        freeze_backbone=freeze_backbone, unfreeze_epoch=unfreeze_epoch,
        project_name=project_name, run_name=run_name,
        checkpoint_dir=checkpoint_dir, resume_from_checkpoint=resume_from_checkpoint,
        checkpoint_interval=checkpoint_interval,
    )
    wandb.init(project=project_name, name=run_name, config=hyperparams)

    try:
        backbone_frozen = False
        if freeze_backbone:
            # apply_freeze recognises fc / classifier / head as the head, so the
            # classifier stays trainable regardless of how the model names it.
            apply_freeze(model, freeze_backbone=True)
            backbone_frozen = True
            print("✓ Backbone frozen for initial training.")

        criterion = nn.CrossEntropyLoss()
        # The optimizer is given every parameter up front. Frozen parameters have
        # no gradient and are skipped, so unfreezing later needs no new optimizer
        # and the scheduler stays attached to this one.
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)

        scheduler, step_per_batch = _build_scheduler(
            optimizer, lr_scheduler, lr, epochs, len(train_loader)
        )

        # Initialize checkpoint directory
        checkpoint_path = Path(checkpoint_dir)
        checkpoint_path.mkdir(parents=True, exist_ok=True)

        # Try to resume from checkpoint
        start_epoch = 1
        best_acc = 0.0
        latest_checkpoint = checkpoint_path / "latest_checkpoint.pth"

        if resume_from_checkpoint and latest_checkpoint.exists():
            start_epoch, best_acc = load_checkpoint(
                model=model,
                optimizer=optimizer,
                checkpoint_path=str(latest_checkpoint),
                scheduler=scheduler,
                device=device
            )
            start_epoch += 1  # Start from next epoch
            print(f"[INFO] Resuming training from epoch {start_epoch}")

        for epoch in range(start_epoch, epochs + 1):
            print(f"\nEpoch {epoch}/{epochs}")

            # Unfreeze once the threshold is reached. `>=` (not `==`) so a run
            # resumed past unfreeze_epoch does not stay frozen forever.
            if backbone_frozen and unfreeze_epoch > 0 and epoch >= unfreeze_epoch:
                for param in model.parameters():
                    param.requires_grad = True
                backbone_frozen = False
                print(f"Backbone unfrozen at epoch {epoch} "
                      f"(lr={optimizer.param_groups[0]['lr']:.6g}).")

            train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer, device, True, scheduler, step_per_batch)
            val_loss, val_acc = run_epoch(model, val_loader, criterion, None, device, False)

            wandb.log({
                'train/loss': train_loss, 'train/acc': train_acc,
                'val/loss': val_loss, 'val/acc': val_acc,
                'lr': optimizer.param_groups[0]['lr']
            })

            # Step scheduler per epoch if not per batch
            if scheduler and not step_per_batch:
                scheduler.step()

            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), 'best_model.pth')
                print(f"Best model saved: {val_acc:.2f}%")

            # Save checkpoint every checkpoint_interval epochs
            if epoch % checkpoint_interval == 0:
                # Numbered snapshot plus a fixed-name copy used for resuming.
                for checkpoint_name in (f"checkpoint_epoch_{epoch}.pth", "latest_checkpoint.pth"):
                    save_checkpoint(
                        model=model,
                        optimizer=optimizer,
                        epoch=epoch,
                        target_dir=checkpoint_dir,
                        checkpoint_name=checkpoint_name,
                        scheduler=scheduler,
                        best_acc=best_acc
                    )
    finally:
        # Close the wandb run even when training is interrupted or fails.
        wandb.finish()

    return best_acc

In [11]:
def train_full(config):
    """Run the whole pipeline described by `config`: data, model, sanity check, training.

    Args:
        config: Hyperparameter dictionary (dataset, model, training, freezing,
            checkpointing and logging keys as read below).

    Returns:
        The trained model, moved to the selected device.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Get data (the test loader is not used during training)
    train_loader, val_loader, _test_loader = get_dataloaders(
        dataset_name=config['dataset'],
        batch_size=config['batch_size'],
        val_ratio=config['val_ratio'],
        num_workers=config['num_workers'],
        pin_memory=config['pin_memory'],
        use_pretrained=config['use_pretrained'],
        pretrained_model_name=config['pretrained_model_name']
    )

    # Get model (built once; it is still on the CPU at this point)
    model = get_model(config['model'], config['num_classes'], config['dropout'], config['use_pretrained'])

    # DEBUG: check batch shapes, label range and the model's output shape
    print("\n" + "="*60)
    print("DEBUG INFO")
    print("="*60)
    sample_inputs, sample_labels = next(iter(train_loader))
    print(f"Sample input shape: {sample_inputs.shape}")
    print(f"Sample labels shape: {sample_labels.shape}")
    print(f"Label range: [{sample_labels.min()}, {sample_labels.max()}]")
    print(f"Config num_classes: {config['num_classes']}")

    # Forward two samples on the CPU before moving to the device. eval() + no_grad()
    # keep this probe from touching normalization statistics or building a graph;
    # run_epoch() switches the model back to train mode.
    model.eval()
    with torch.no_grad():
        test_output = model(sample_inputs[:2])
    print(f"Model output shape: {test_output.shape}")
    print(f"Expected: (2, {config['num_classes']})")
    print("="*60 + "\n")
    model = model.to(device)

    # Train
    best_acc = train_engine(
        model, train_loader, val_loader, device,
        epochs=config['epochs'],
        lr=config['lr'],
        lr_scheduler=config['lr_scheduler'],
        weight_decay=config['weight_decay'],
        project_name=config['project'],
        run_name=config['run_name'],

        freeze_backbone=config['freeze_backbone'],
        unfreeze_epoch=config['freeze_epochs'],

        checkpoint_dir=config['checkpoint_dir'],
        resume_from_checkpoint=config['resume_from_checkpoint'],
        checkpoint_interval=config['checkpoint_interval']
    )

    print(f"Training complete! Best accuracy: {best_acc:.2f}%")
    return model

In [None]:
# Run the full pipeline with the notebook-level config and keep the trained model.
model = train_full(config)

Device: cuda

[Transfer Learning Mode]
Model: vit_base_patch16_224
Input size: 224x224
Statistics: ImageNet (mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
Augmentation: Strong (auto_augment, flip, erase)

Dataset: CIFAR100
Mean/Std: (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225)
Train: 45000, Val: 5000, Test: 10000
Batch size: 256, Num workers: 4, Pin memory: True


DEBUG INFO
Sample input shape: torch.Size([256, 3, 224, 224])
Sample labels shape: torch.Size([256])
Label range: [0, 99]
Config num_classes: 100


## 05. Evaluate with Test Dataset

In [None]:
"""
Model evaluation with essential metrics
"""
import torch
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import numpy as np

def evaluate(model, test_loader, device, class_names=None):
    """
    Evaluate model and return comprehensive metrics.

    Args:
        model: Trained PyTorch model to evaluate (switched to eval mode here)
        test_loader: DataLoader yielding (inputs, targets) batches
                     (use ONLY for final evaluation!)
        device: torch.device (or device string) the input batches are moved to
        class_names: Optional list of class names for the classification report.
                     When given, class index i is assumed to correspond to
                     class_names[i]; the confusion matrix, per-class accuracy and
                     report are then built over exactly len(class_names) classes,
                     so they stay aligned with class_names even if a class never
                     occurs in the evaluated data.

    Returns:
        dict: {'accuracy',              # Overall accuracy (percent)
               'f1_macro',              # f1 score for macro average
               'f1_weighted',           # f1 score for weighted average
               'confusion_matrix',      # Confusion matrix
               'per_class_acc',         # Per-class accuracy (NaN for classes without samples)
               'classification_report', # Text report from sklearn
               'predictions',           # Predicted labels
               'targets'}               # True labels

    Raises:
        ValueError: If test_loader yields no batches.

    Note:
        - Use this function ONLY for final test set evaluation
        - Load the best model checkpoint before calling this
        - Do NOT use this during training/validation - it should be called once at the end
    """
    model.eval()
    pred_batches = []
    target_batches = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)

            # Collect whole batches; .cpu() makes this safe if targets are not on the CPU
            pred_batches.append(predicted.cpu().numpy())
            target_batches.append(targets.cpu().numpy())

    if not target_batches:
        raise ValueError("test_loader yielded no samples; nothing to evaluate")

    all_preds = np.concatenate(pred_batches)
    all_targets = np.concatenate(target_batches)

    # Fix the label set when class names are known so that row/column i of the
    # confusion matrix always refers to class_names[i]
    labels = list(range(len(class_names))) if class_names is not None else None

    # Calculate metrics
    accuracy = 100. * (all_preds == all_targets).sum() / len(all_targets)
    f1_macro = f1_score(all_targets, all_preds, average='macro')
    f1_weighted = f1_score(all_targets, all_preds, average='weighted')
    cm = confusion_matrix(all_targets, all_preds, labels=labels)

    # Per-class accuracy; classes with no true samples stay NaN instead of
    # triggering a divide-by-zero warning
    support = cm.sum(axis=1)
    per_class_acc = np.full(support.shape, np.nan, dtype=float)
    np.divide(cm.diagonal(), support, out=per_class_acc, where=support > 0)

    # Classification report
    report = classification_report(
        all_targets, all_preds,
        labels=labels,
        target_names=class_names,
        digits=3
    )

    print(f"\n{'='*60}")
    print(f"EVALUATION RESULTS")
    print(f"{'='*60}")
    print(f"Overall Accuracy: {accuracy:.2f}%")
    print(f"F1 Score (Macro): {f1_macro:.3f}")
    print(f"F1 Score (Weighted): {f1_weighted:.3f}")
    print(f"\n{report}")
    print(f"{'='*60}\n")

    return {
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'confusion_matrix': cm,
        'per_class_acc': per_class_acc,
        'classification_report': report,
        'predictions': all_preds,
        'targets': all_targets
    }

In [None]:
# NOTE(review): `eval_dict` is only assigned by the evaluation cell that follows,
# so on a fresh kernel (Restart & Run All) this cell raises NameError.
# Run it (or move it) after that cell.
eval_dict

In [None]:
# Build the loaders from the shared config; only the third returned loader
# (the test split) is used in this cell.
_, _, test_loader = get_dataloaders(
    dataset_name=config['dataset'],
    batch_size=config['batch_size'],
    val_ratio=config['val_ratio'],
    num_workers=config['num_workers'],
    pin_memory=config['pin_memory'],
    use_pretrained=config['use_pretrained'],
    pretrained_model_name=config['pretrained_model_name']
)

# Evaluate on whichever device the trained model already lives on instead of
# hard-coding 'cuda', so the cell also works on CPU-only machines.
eval_device = next(model.parameters()).device
eval_dict = evaluate(model, test_loader, device=eval_device)

## 06. Visualize

In [None]:
"""
Visualization utilities using wandb
"""
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import wandb
import pandas as pd

def _require_run():
    """Guard for the logging helpers: raise RuntimeError unless a W&B run is active."""
    run_is_active = wandb.run is not None
    if run_is_active:
        return
    raise RuntimeError("No active W&B run. Call wandb.init(...) before logging.")

def log_confusion_matrix_from_matrix(cm, class_names, title="Confusion Matrix"):
    """
    Log a precomputed confusion matrix to wandb as a heatmap image.

    Args:
        cm: 2D confusion matrix; rows are plotted on the "True" axis and
            columns on the "Predicted" axis
        class_names: list of class names used as row/column labels
        title: title drawn above the heatmap

    Raises:
        RuntimeError: if there is no active W&B run
    """
    _require_run()
    df = pd.DataFrame(cm, index=class_names, columns=class_names)

    # wandb.plot offers no `heatmap` helper, so draw the matrix with seaborn and
    # log the rendered figure as an image. Scale the canvas with the number of
    # classes (bounded) so tick labels stay legible for large label sets.
    side = min(max(6.0, 0.25 * len(df)), 30.0)
    fig, ax = plt.subplots(figsize=(side, side))
    try:
        sns.heatmap(df, ax=ax, cmap="Blues", square=True)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_title(title)
        fig.tight_layout()
        wandb.log({"confusion_matrix": wandb.Image(fig)})
    finally:
        # Always release the figure so repeated calls do not accumulate open figures
        plt.close(fig)
    print("✓ Confusion matrix logged to wandb")

def log_per_class_metrics(per_class_acc, class_names):
    """
    Send per-class accuracy to wandb as a bar chart.

    Args:
        per_class_acc: array of per-class accuracies (fractions; scaled to percent here)
        class_names: list of class names, indexed in the same order as per_class_acc
    """
    _require_run()

    # One [class name, accuracy in percent] row per class
    rows = []
    for position, name in enumerate(class_names):
        rows.append([name, per_class_acc[position] * 100])

    accuracy_table = wandb.Table(data=rows, columns=["Class", "Accuracy (%)"])
    bar_chart = wandb.plot.bar(
        accuracy_table, "Class", "Accuracy (%)",
        title="Per-Class Accuracy"
    )

    wandb.log({"per_class_accuracy": bar_chart})
    print("✓ Per-class accuracy logged to wandb")

def log_sample_predictions(model, test_loader, device, class_names, num_samples=16):
    """
    Log sample predictions to wandb

    Takes the first batch from test_loader, runs the model on up to
    num_samples images and logs each image with a
    "True: ... | Pred: ... (confidence)" caption.

    Args:
        model: trained model
        test_loader: test dataloader yielding (images, labels) batches
        device: torch device the images are moved to for inference
        class_names: list of class names, indexed by class id
        num_samples: maximum number of samples to log; fewer are logged when
                     the first batch is smaller

    Raises:
        RuntimeError: if there is no active W&B run
    """
    # Fail fast: no point running inference when nothing can be logged
    _require_run()
    model.eval()

    # Get one batch
    images, labels = next(iter(test_loader))
    images, labels = images[:num_samples].to(device), labels[:num_samples]
    # The batch may contain fewer than num_samples items
    num_logged = len(images)

    with torch.no_grad():
        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)
        _, predicted = outputs.max(1)

    # Move results to the CPU once instead of once per sample
    images = images.cpu()
    probs = probs.cpu()
    predicted = predicted.cpu()

    # Create wandb images with predictions
    wandb_images = []
    for idx in range(num_logged):
        true_idx = int(labels[idx])
        pred_idx = int(predicted[idx])
        true_label = class_names[true_idx]
        pred_label = class_names[pred_idx]
        pred_prob = probs[idx, pred_idx].item()

        caption = f"True: {true_label} | Pred: {pred_label} ({pred_prob:.2%})"
        wandb_images.append(wandb.Image(images[idx], caption=caption))

    wandb.log({"predictions": wandb_images})
    print(f"✓ {num_logged} sample predictions logged to wandb")

def log_all_metrics(results, class_names):
    """
    Log all evaluation metrics to wandb

    Summary metrics and per-class accuracies are sent in a single wandb.log
    call so they land on one step instead of adding one step per class.

    Args:
        results: dict from evaluate() function; reads 'accuracy', 'f1_macro',
                 'f1_weighted' and 'per_class_acc'
        class_names: list of class names, indexed in the same order as
                     results['per_class_acc']

    Raises:
        RuntimeError: if there is no active W&B run
    """
    _require_run()

    # Summary metrics
    metrics = {
        "test/accuracy": results['accuracy'],
        "test/f1_macro": results['f1_macro'],
        "test/f1_weighted": results['f1_weighted'],
    }

    # Per-class accuracies (fractions scaled to percent)
    per_class_acc = results['per_class_acc']
    for i, class_name in enumerate(class_names):
        metrics[f"test/accuracy_{class_name}"] = float(per_class_acc[i] * 100)

    wandb.log(metrics)

    print("✓ All metrics logged to wandb")