# In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""

Notebook use: main() parses flags with argparse's parse_known_args, so extra
arguments injected by a Jupyter kernel (e.g. `-f <connection-file>.json`) are
ignored instead of raising an error.

DocBoost A→Z Training Pipeline (No Pseudo-Labeling)
- Focus: robust macro-F1 for mixed classes with weak document-type categories (receipts, certificates, etc.)
- Uses ImageNet-pretrained backbones (timm) only. No pseudo labels.

Highlights
* Deterministic seeding + reproducible DataLoaders
* Dataset with adaptive hard-augmentation schedule and clean validation transforms
* RGB forcing across train/val/test paths (PIL) and safe OpenCV interop
* AdamW + Warmup → Cosine; AMP; Grad clip; EMA with decay ramp (0.99→0.9995)
* Optional class weighting (slight boost for weak classes)
* Validation with inference_mode+AMP; macro-F1/acc/loss; confusion-matrix util
* Orientation-search TTA (choice, not average) for documents: 0/90/180/270 + optional refine; optional flip as selection
* 5-fold Stratified CV + best-ckpt per fold; summary
* Clean, dependency-light logging (print); W&B options are parsed into the config, but no W&B logging is wired up yet

Assumptions
- train_df CSV: has columns [image | img_path | image_path | filename] and [target | label | class | y]
- images live under Config.data_dir/train, or the CSV holds absolute paths
- test_dir holds test images for final inference (run with --do_infer; predictions go to --out_csv, default ./predict.csv)

Run example
python DocBoost_AtoZ_Training_Pipeline.py \
  --data_dir ./data \
  --train_csv ./data/train.csv \
  --test_dir ./data/test \
  --ckpt_dir ./checkpoints \
  --model convnextv2_base.fcmae_ft_in22k_in1k \
  --img_size 384 --epochs 50 --batch_size 24 --lr 2e-4 --ema --label_smoothing 0.025

Note: --wandb / --wandb_entity / --wandb_project are accepted and stored in Config, but this script does not log to W&B yet.
"""

from __future__ import annotations
import os, sys, math, time, random, argparse, json, glob
from dataclasses import dataclass, asdict
from typing import Optional, Tuple, List

import numpy as np
import pandas as pd
from PIL import Image

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR, CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix, classification_report

import albumentations as A
from albumentations.pytorch import ToTensorV2

import timm
from timm.utils import ModelEmaV2

# ===============
# Utils & Config
# ===============

def seed_everything(seed: int = 42):
    """Seed every RNG the pipeline touches and force deterministic cuDNN kernels.

    Covers Python's ``random``, NumPy's global RNG and torch (CPU plus every
    CUDA device). ``PYTHONHASHSEED`` is exported for child processes; it does
    not change string hashing in the interpreter that is already running.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    # Trade autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


def build_generator(seed: int = 42) -> torch.Generator:
    """Return a fresh CPU ``torch.Generator`` seeded with *seed* (used for DataLoader shuffling)."""
    # Generator.manual_seed returns the generator itself, so this chains.
    return torch.Generator().manual_seed(seed)


def discover_cols(df: pd.DataFrame) -> Tuple[str, Optional[str]]:
    """Guess the image-name column and the target column of *df*.

    Known names are tried first: image in ``image, img_path, image_path,
    filename, file, name`` and target in ``target, label, class, y``. Without a
    known name, the image column falls back to the first column that is not the
    target. The target falls back to the first column that is not the image
    column, or ``None`` for an image-only table such as an unlabeled test list.

    Raises:
        ValueError: if *df* has no columns.
    """
    cols = list(df.columns)
    if not cols:
        raise ValueError("discover_cols: DataFrame has no columns")

    tgt_col = next((c for c in ("target", "label", "class", "y") if c in cols), None)
    img_col = next((c for c in ("image", "img_path", "image_path", "filename", "file", "name") if c in cols), None)

    if img_col is None:
        # Avoid picking the target column as the image column.
        img_col = next((c for c in cols if c != tgt_col), cols[0])
    if tgt_col is None:
        # Lazy fallback: a single-column table has no target (previously IndexError).
        tgt_col = next((c for c in cols if c != img_col), None)
    return img_col, tgt_col


@dataclass
class Config:
    """Run configuration for the CV training and inference pipeline.

    ``main()`` fills every field from command-line flags; the defaults below
    apply only when ``Config`` is constructed directly.
    """
    data_dir: str = "../data"               # root; training images are resolved under <data_dir>/train
    train_csv: str = "../data/train.csv"    # CSV with an image-name column and a target column
    test_dir: str = "../data/test"          # directory of images for final inference
    ckpt_dir: str = "./checkpoints"         # per-fold best checkpoints are written here
    out_csv: str = "./predict.csv"          # inference output CSV

    model_name: str = "convnextv2_base.fcmae_ft_in22k_in1k"  # timm model name (created with pretrained=True)
    img_size: int = 384                     # square input side after resize + pad
    num_classes: int = 17                   # classifier head size; targets are used as class indices

    epochs: int = 50
    batch_size: int = 24
    num_workers: int = 8                    # DataLoader worker processes
    lr: float = 2e-4                        # AdamW base learning rate
    weight_decay: float = 2e-2              # AdamW weight decay
    warmup_epochs: int = 3                  # length handed to the LinearLR warmup that precedes cosine annealing

    label_smoothing: float = 0.025          # CrossEntropyLoss label smoothing
    mixup_p: float = 0.10                   # per-batch probability of applying mixup; 0 disables it
    mixup_alpha: float = 0.4                # Beta(alpha, alpha) parameter for the mixup ratio

    n_folds: int = 5                        # StratifiedKFold splits
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # evaluated once, when the class body runs

    use_ema: bool = True                    # NOTE: main() sets this from --ema, which defaults to False
    ema_decay_min: float = 0.99             # initial ModelEmaV2 decay
    ema_decay_max: float = 0.9995           # intended final decay of the EMA ramp

    use_wandb: bool = False                 # NOTE(review): not read by any code visible in this file
    wandb_entity: Optional[str] = None
    wandb_project: Optional[str] = None

    use_class_weight: bool = False        # inverse-frequency CrossEntropy class weights (normalized to mean 1)
    weight_boost_indices: Tuple[int,...] = ()  # class ids whose weight is multiplied by weight_boost_factor
    weight_boost_factor: float = 1.2

    tta_refine: bool = True               # also try ±15° around the best coarse angle (NOTE: main() default is False)
    tta_try_hflip: bool = True            # flip as CHOICE, not average: used only if clearly more confident


# ======================
# Dataset & Transforms
# ======================

def load_rgb(path: str) -> np.ndarray:
    """Read the image at *path* as an ``H x W x 3`` uint8 RGB array.

    ``convert("RGB")`` normalizes grayscale, palette, CMYK and alpha images to
    three channels. ``Image.open`` is lazy and keeps the file handle open until
    the object is garbage-collected, so the context manager closes it once the
    pixels have been copied into NumPy. This matters in long-running DataLoader
    workers, where leaked handles can exhaust the file-descriptor limit.
    """
    with Image.open(path) as im:
        return np.array(im.convert("RGB"))


class ImageDataset(Dataset):
    """DataFrame-backed image dataset for training, validation and test inference.

    Each item is ``(image_tensor, target)`` when labels are available and
    ``(image_tensor, path)`` otherwise. Images are decoded to RGB with
    ``load_rgb``, resized so the longest side is ``img_size``, zero-padded to
    ``img_size x img_size``, normalized with ImageNet mean/std and converted to
    a tensor.

    Args:
        df: table with an image-name column and, for labeled data, a target
            column (both resolved with ``discover_cols``).
        base_dir: directory prefixed to relative image names; absolute names
            are used unchanged.
        is_train: apply random augmentation (normal or hard) if True; otherwise
            apply the deterministic ``val_aug`` pipeline.
        img_size: output side length in pixels.
        total_epochs: length of the hard-augmentation ramp driven by ``set_epoch``.
        return_target: return integer targets (True) or image paths (False).
            ``None`` (default) returns targets when ``is_train`` is set or the
            target column resolved by ``discover_cols`` exists in ``df``. Labeled
            validation folds therefore yield targets (previously they yielded
            paths and broke validation), and image-only test tables yield paths.
    """

    def __init__(self, df: pd.DataFrame, base_dir: str, is_train: bool, img_size: int, total_epochs: int = 10,
                 return_target: Optional[bool] = None):
        self.df = df.reset_index(drop=True).copy()
        self.base_dir = base_dir
        self.is_train = is_train
        self.total_epochs = int(total_epochs)
        self.current_epoch = 0

        self.img_col, self.tgt_col = discover_cols(self.df)
        if return_target is None:
            has_labels = (
                self.tgt_col is not None
                and self.tgt_col != self.img_col
                and self.tgt_col in self.df.columns
            )
            return_target = bool(is_train) or has_labels
        self.return_target = bool(return_target)

        self.MEAN = (0.485, 0.456, 0.406)
        self.STD  = (0.229, 0.224, 0.225)
        self.img_size = img_size

        # adaptive hard-aug schedule (p_hard ramps from .2 -> .5)
        self.base_p_hard = 0.2
        self.extra_p_hard = 0.3
        self.p_hard = self.base_p_hard

        self._build_transforms()

    def set_epoch(self, epoch: int):
        """Set the hard-augmentation probability for *epoch* and rebuild transforms.

        ``p_hard`` ramps linearly from 0.2 (first epoch) to 0.5 (last epoch) and
        is clamped to [0, 0.5]. DataLoader workers hold their own copy of the
        dataset, so this only reaches workers started after the call (i.e. not
        persistent workers).
        """
        self.current_epoch = int(epoch)
        p = self.base_p_hard + self.extra_p_hard * (self.current_epoch / max(1, self.total_epochs - 1))
        self.p_hard = float(min(0.5, max(0.0, p)))
        self._build_transforms()

    def _build_transforms(self):
        """(Re)build ``val_aug`` and, for training, ``normal_aug`` / ``hard_aug``."""
        size = self.img_size
        self.val_aug = A.Compose([
            A.LongestMaxSize(max_size=size),
            A.PadIfNeeded(size, size, border_mode=cv2.BORDER_CONSTANT, value=0),
            A.Normalize(mean=self.MEAN, std=self.STD),
            ToTensorV2(),
        ])
        if not self.is_train:
            self.normal_aug = None
            self.hard_aug = None
            return
        # conservative normal aug (documents prefer mild aug)
        self.normal_aug = A.Compose([
            A.LongestMaxSize(max_size=size),
            A.PadIfNeeded(size, size, border_mode=cv2.BORDER_CONSTANT, value=0),
            A.Rotate(limit=10, p=0.25),                   # small skew
            A.Perspective(scale=(0.02,0.05), p=0.1),     # light perspective
            A.RandomBrightnessContrast(0.2, 0.2, p=0.4),
            A.GaussNoise(var_limit=(10.0, 40.0), p=0.15),
            A.HorizontalFlip(p=0.1),                     # documents: keep small
            A.Normalize(mean=self.MEAN, std=self.STD),
            ToTensorV2(),
        ])
        # hard aug (selected per sample with probability p_hard)
        self.hard_aug = A.Compose([
            A.LongestMaxSize(max_size=size),
            A.PadIfNeeded(size, size, border_mode=cv2.BORDER_CONSTANT, value=0),
            A.Rotate(limit=25, p=0.5),
            A.Perspective(scale=(0.05,0.1), p=0.2),
            A.OneOf([
                A.MotionBlur(15, p=1.0),
                A.GaussianBlur(15, p=1.0),
            ], p=0.25),
            A.RandomBrightnessContrast(0.35, 0.35, p=0.4),
            A.GaussNoise(var_limit=(30.0, 100.0), p=0.25),
            A.JpegCompression(quality_lower=60, quality_upper=100, p=0.2),
            A.Normalize(mean=self.MEAN, std=self.STD),
            ToTensorV2(),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        name = str(row[self.img_col])
        path = name if os.path.isabs(name) else os.path.join(self.base_dir, name)

        img = load_rgb(path)
        if not self.is_train:
            aug = self.val_aug
        elif random.random() < self.p_hard:
            aug = self.hard_aug
        else:
            aug = self.normal_aug
        img = aug(image=img)['image']

        if self.return_target:
            return img, int(row[self.tgt_col])
        return img, path


# ==============
# Mixup (light)
# ==============

def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.4, p: float = 0.1) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """Apply batch mixup with probability *p*.

    Returns ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x + (1 - lam)
    * x[perm]``, ``y_a = y`` and ``y_b = y[perm]``, with ``lam ~ Beta(alpha,
    alpha)``. When mixup is skipped (``p <= 0``, the coin flip fails, or
    ``alpha <= 0``) the inputs are returned unchanged with ``lam == 1.0``.
    """
    # Short-circuit order matters: random.random() is drawn only when p > 0.
    skip = p <= 0.0 or random.random() >= p or alpha <= 0.0
    if skip:
        return x, y, y, 1.0

    lam = float(np.random.beta(alpha, alpha))
    perm = torch.randperm(x.size(0), device=x.device)
    mixed = lam * x + (1.0 - lam) * x[perm]
    return mixed, y, y[perm], lam


def mixup_criterion(criterion, preds, y_a, y_b, lam):
    """Combine the losses against both mixup targets: ``lam*L(y_a) + (1-lam)*L(y_b)``.

    With ``lam == 1.0`` (no mixup applied) only ``criterion(preds, y_a)`` is evaluated.
    """
    loss_a = criterion(preds, y_a)
    if lam == 1.0:
        return loss_a
    loss_b = criterion(preds, y_b)
    return lam * loss_a + (1.0 - lam) * loss_b


# ============
# Build model
# ============

def build_model(cfg: Config) -> nn.Module:
    """Create the ``cfg.model_name`` timm model with pretrained weights and a ``cfg.num_classes``-way head."""
    return timm.create_model(
        cfg.model_name,
        pretrained=True,
        num_classes=cfg.num_classes,
    )


# =====================
# Train/Validate Epochs
# =====================

def train_one_epoch(loader: DataLoader, model: nn.Module, criterion, optimizer, scheduler, device: str, scaler: GradScaler,
                    use_mixup: bool, mix_alpha: float, mix_p: float, grad_clip: Optional[float] = 1.0) -> dict:
    """Run one optimization pass over *loader*; return train loss, accuracy and macro-F1.

    Per batch: optional mixup, AMP forward (when CUDA is available), scaled
    backward, optional grad-norm clipping on unscaled gradients, optimizer step
    through *scaler*, then ``scheduler.step()``. The scheduler therefore
    advances once per iteration, not once per epoch.

    Accuracy/F1 compare argmax predictions with the original (un-mixed)
    targets, so they are approximate on mixup batches.
    """
    model.train()
    amp_enabled = torch.cuda.is_available()
    loss_sum = 0.0
    seen = 0
    all_preds: List[int] = []
    all_targets: List[int] = []

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        if use_mixup:
            inputs, y_a, y_b, lam = mixup_data(batch_x, batch_y, alpha=mix_alpha, p=mix_p)
        else:
            inputs, y_a, y_b, lam = batch_x, batch_y, batch_y, 1.0

        with autocast(enabled=amp_enabled):
            logits = model(inputs)
            if use_mixup:
                loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
            else:
                loss = criterion(logits, batch_y)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        if grad_clip is not None:
            # Clip the real gradient magnitudes, not the loss-scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        n = batch_x.size(0)
        loss_sum += float(loss.item()) * n
        seen += n
        all_preds += logits.detach().argmax(1).cpu().tolist()
        all_targets += batch_y.cpu().tolist()

    return {
        "train_loss": loss_sum / max(1, seen),
        "train_acc": accuracy_score(all_targets, all_preds),
        "train_f1": f1_score(all_targets, all_preds, average='macro'),
    }


def validate_one_epoch(loader: DataLoader, model: nn.Module, criterion, device: str) -> dict:
    """Evaluate *model* on *loader*; return loss, accuracy and macro-F1.

    Runs under ``torch.inference_mode()`` with AMP autocast when CUDA is
    available. The loss is the batch-size-weighted mean of the per-batch
    criterion values. The loader must yield ``(images, integer_targets)``.
    """
    model.eval()
    use_amp = torch.cuda.is_available()
    running = 0.0
    count = 0
    y_pred: List[int] = []
    y_true: List[int] = []

    with torch.inference_mode():
        for xb, yb in loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            with autocast(enabled=use_amp):
                out = model(xb)
                batch_loss = criterion(out, yb)
            n = xb.size(0)
            running += float(batch_loss.item()) * n
            count += n
            y_pred += out.argmax(1).cpu().tolist()
            y_true += yb.cpu().tolist()

    return {
        "val_loss": running / max(1, count),
        "val_acc": accuracy_score(y_true, y_pred),
        "val_f1": f1_score(y_true, y_pred, average='macro'),
    }


# ========================
# Orientation-Search TTA
# ========================

def _rotate_bgr(img_bgr: np.ndarray, angle_deg: float, expand: bool = True) -> np.ndarray:
    """Rotate an image counter-clockwise by *angle_deg* degrees about its centre.

    The sign convention matches ``cv2.getRotationMatrix2D`` (positive is
    counter-clockwise).

    With ``expand=True`` (default) nothing is cropped:
      * multiples of 90 degrees use an exact array rotation (no interpolation;
        width/height swap for 90/270), so portrait documents are not cut off;
      * other angles are warped bilinearly onto a canvas enlarged to hold the
        whole rotated image, with black (0) fill.

    ``expand=False`` reproduces the old behaviour: the output keeps the
    input's ``(h, w)``, so corners (or, at 90/270 on non-square images, whole
    strips) are clipped.

    Returns a new C-contiguous array; the input is never aliased.
    """
    h, w = img_bgr.shape[:2]
    if expand:
        angle = float(angle_deg) % 360.0
        if angle % 90.0 == 0.0:
            # np.rot90 with k > 0 rotates counter-clockwise, like OpenCV's positive angles.
            return np.rot90(img_bgr, k=int(angle // 90.0)).copy()

    M = cv2.getRotationMatrix2D((w / 2, h / 2), angle_deg, 1.0)
    out_w, out_h = w, h
    if expand:
        cos_a, sin_a = abs(M[0, 0]), abs(M[0, 1])
        out_w = int(math.ceil(h * sin_a + w * cos_a))
        out_h = int(math.ceil(h * cos_a + w * sin_a))
        # Shift so the rotated centre lands in the middle of the enlarged canvas.
        M[0, 2] += out_w / 2.0 - w / 2.0
        M[1, 2] += out_h / 2.0 - h / 2.0
    return cv2.warpAffine(img_bgr, M, (out_w, out_h), flags=cv2.INTER_LINEAR,
                          borderMode=cv2.BORDER_CONSTANT, borderValue=0)


@torch.inference_mode()
def predict_with_orientation_search(model: nn.Module, path: str, val_aug: A.Compose, device: str,
                                    base_angles=(0, 90, 180, 270), refine_window=15, refine_step=15,
                                    try_hflip=True, conf_margin=0.05) -> Tuple[int, float, float, bool]:
    """Choice-based TTA: pick the most confident orientation (not average).

    1. Score each angle in *base_angles*; keep the view with the highest max
       softmax probability (earlier angles win ties).
    2. If both *refine_window* and *refine_step* are non-zero, also score
       ``best + d`` for ``d`` in ``range(-refine_window, refine_window + 1,
       refine_step)`` (mod 360, skipping the coarse winner) and keep the most
       confident view (the coarse winner wins ties).
    3. If *try_hflip*, score the horizontally flipped winner. Use it only if
       its confidence exceeds the current best by more than *conf_margin*.

    Each view goes through *val_aug* as a batch of one. The model is not put
    in eval mode here; callers are expected to do that.

    Returns:
        ``(pred_class, confidence, angle_deg, used_hflip)``.
    Raises:
        RuntimeError: if OpenCV cannot decode the file at *path*.
    """
    encoded = np.fromfile(path, dtype=np.uint8)
    img_bgr = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise RuntimeError(f"Failed to read {path}")
    amp_enabled = torch.cuda.is_available()

    def _score(view_bgr):
        """Return ``(max softmax prob, argmax class)`` for one BGR view."""
        rgb = cv2.cvtColor(view_bgr, cv2.COLOR_BGR2RGB)
        batch = val_aug(image=rgb)['image'].unsqueeze(0).to(device)
        with autocast(enabled=amp_enabled):
            probs = model(batch).softmax(1)
        top_p, top_cls = probs.max(1)
        return float(top_p.item()), int(top_cls.item())

    def _evaluate(angle):
        conf, pred = _score(_rotate_bgr(img_bgr, angle))
        return angle, conf, pred

    def by_conf(cand):
        return cand[1]

    # 1) coarse search
    best_angle, best_conf, best_pred = max((_evaluate(a) for a in base_angles), key=by_conf)

    # 2) optional refinement around the coarse winner
    if refine_window and refine_step:
        centre = best_angle
        pool = [(best_angle, best_conf, best_pred)]
        for offset in range(-refine_window, refine_window + 1, refine_step):
            angle = (centre + offset) % 360
            if angle != centre:
                pool.append(_evaluate(angle))
        best_angle, best_conf, best_pred = max(pool, key=by_conf)

    # 3) optional mirror check, accepted only with a clear confidence gain
    used_hflip = False
    if try_hflip:
        mirrored = cv2.flip(_rotate_bgr(img_bgr, best_angle), 1)
        flip_conf, flip_pred = _score(mirrored)
        if flip_conf > best_conf + conf_margin:
            best_conf, best_pred, used_hflip = flip_conf, flip_pred, True

    return best_pred, best_conf, best_angle, used_hflip


# =============
# Train a fold
# =============

def _seed_worker(worker_id: int) -> None:
    """DataLoader ``worker_init_fn``: seed NumPy from this worker's torch seed.

    PyTorch gives each worker the torch seed ``base_seed + worker_id``.
    ``base_seed`` is drawn from the loader's generator whenever a new iterator
    starts (every epoch for non-persistent workers). NumPy streams are
    therefore reproducible but differ from epoch to epoch; a fixed
    ``seed + worker_id`` would replay the same augmentations every epoch.
    """
    np.random.seed(torch.initial_seed() % 2**32)


def _fold_loaders(cfg: Config, trn_ds: Dataset, val_ds: Dataset) -> Tuple[DataLoader, DataLoader]:
    """Build the train/val DataLoaders for one fold."""
    g = build_generator(cfg.seed)
    workers = max(0, int(cfg.num_workers))
    common = dict(batch_size=cfg.batch_size, num_workers=workers, pin_memory=True,
                  worker_init_fn=_seed_worker, generator=g)
    # Train workers must NOT persist: each epoch's workers need a fresh copy of the
    # dataset so the p_hard ramp applied by trn_ds.set_epoch() actually takes effect.
    trn_loader = DataLoader(trn_ds, shuffle=True, drop_last=False, persistent_workers=False, **common)
    # persistent_workers=True is only legal with worker processes.
    val_loader = DataLoader(val_ds, shuffle=False, persistent_workers=workers > 0, **common)
    return trn_loader, val_loader


def _fold_criterion(cfg: Config, trn_df: pd.DataFrame, device: str) -> nn.Module:
    """CrossEntropy with label smoothing and optional inverse-frequency class weights."""
    loss_kwargs = {"label_smoothing": cfg.label_smoothing}
    if cfg.use_class_weight:
        _, tgt_col = discover_cols(trn_df)
        labels = trn_df[tgt_col].astype(int).to_numpy()
        # bincount keeps index == class id even when a class is absent from this fold
        # (value_counts().sort_index() would shift the weights and shorten the vector).
        counts = np.bincount(labels, minlength=cfg.num_classes)[: cfg.num_classes]
        class_w = 1.0 / np.clip(counts, 1, None)
        class_w = class_w / class_w.mean()
        for c in cfg.weight_boost_indices:
            if 0 <= c < cfg.num_classes:
                class_w[c] *= cfg.weight_boost_factor
        loss_kwargs["weight"] = torch.tensor(class_w, dtype=torch.float32, device=device)
    return nn.CrossEntropyLoss(**loss_kwargs)


def _fold_scheduler(optimizer, cfg: Config, steps_per_epoch: int):
    """Linear warmup (start_factor 0.2) then cosine decay, counted in optimizer steps.

    train_one_epoch calls ``scheduler.step()`` once per batch, so every length
    here is expressed in iterations rather than epochs.
    """
    steps_per_epoch = max(1, int(steps_per_epoch))
    total_iters = max(1, cfg.epochs * steps_per_epoch)
    warmup_iters = max(0, cfg.warmup_epochs) * steps_per_epoch
    if 0 < warmup_iters < total_iters:
        warmup = LinearLR(optimizer, start_factor=0.2, total_iters=warmup_iters)
        cosine = CosineAnnealingLR(optimizer, T_max=total_iters - warmup_iters)
        return SequentialLR(optimizer, [warmup, cosine], milestones=[warmup_iters])
    return CosineAnnealingLR(optimizer, T_max=total_iters)


def _attach_ema(cfg: Config, model: nn.Module, optimizer, ramp_steps: int):
    """Create an EMA of *model* that is updated after every real optimizer step.

    The decay ramps linearly from ``cfg.ema_decay_min`` to ``cfg.ema_decay_max``
    over *ramp_steps* steps. A step-post hook fires only when GradScaler really
    steps (skipped inf/NaN steps are not counted).

    Returns ``(ema, per_step)``. ``ema`` is None when EMA is disabled.
    ``per_step`` is False on torch builds without step hooks, in which case the
    caller must update the EMA itself.
    """
    if not cfg.use_ema:
        return None, False
    ema = ModelEmaV2(model, decay=cfg.ema_decay_min)
    if not hasattr(optimizer, "register_step_post_hook"):
        return ema, False

    state = {"step": 0}
    span = cfg.ema_decay_max - cfg.ema_decay_min

    def _hook(_opt, _args, _kwargs):
        state["step"] += 1
        frac = min(1.0, state["step"] / max(1, ramp_steps))
        ema.decay = cfg.ema_decay_min + span * frac
        ema.update(model)

    optimizer.register_step_post_hook(_hook)
    return ema, True


def train_single_fold(cfg: Config, fold: int, train_idx: np.ndarray, val_idx: np.ndarray, df: pd.DataFrame) -> Tuple[float, nn.Module]:
    """Train one CV fold; return ``(best_val_macro_f1, model_with_best_weights)``.

    Each epoch evaluates the raw model and, if enabled, its EMA (EMA wins
    ties). Whichever scores higher is the epoch's candidate. Whenever the
    candidate beats the best macro-F1 so far, its weights are saved to
    ``<ckpt_dir>/fold{fold}_best_{raw|ema}.pth``; a checkpoint of the other
    flavour written earlier in this run is removed. At the end the best
    checkpoint is loaded back into the raw model, which is returned. If no
    epoch ran, the result is ``(-1.0, untrained model)``.
    """
    device = cfg.device
    trn_df = df.iloc[train_idx].reset_index(drop=True)
    val_df = df.iloc[val_idx].reset_index(drop=True)

    img_root = os.path.join(cfg.data_dir, 'train')
    trn_ds = ImageDataset(trn_df, base_dir=img_root, is_train=True,  img_size=cfg.img_size, total_epochs=cfg.epochs)
    val_ds = ImageDataset(val_df, base_dir=img_root, is_train=False, img_size=cfg.img_size, total_epochs=cfg.epochs)
    trn_loader, val_loader = _fold_loaders(cfg, trn_ds, val_ds)

    model = build_model(cfg).to(device)
    criterion = _fold_criterion(cfg, trn_df, device)
    optimizer = AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = _fold_scheduler(optimizer, cfg, len(trn_loader))
    scaler = GradScaler(enabled=torch.cuda.is_available())
    ema, ema_per_step = _attach_ema(cfg, model, optimizer, ramp_steps=max(1000, len(trn_loader) * 2))

    os.makedirs(cfg.ckpt_dir, exist_ok=True)
    best_f1, best_path = -1.0, None

    for epoch in range(cfg.epochs):
        trn_ds.set_epoch(epoch)
        train_ret = train_one_epoch(trn_loader, model, criterion, optimizer, scheduler, device, scaler,
                                    use_mixup=(cfg.mixup_p > 0), mix_alpha=cfg.mixup_alpha, mix_p=cfg.mixup_p,
                                    grad_clip=1.0)
        if ema is not None and not ema_per_step:
            ema.update(model)  # coarse per-epoch fallback when step hooks are unavailable

        val_ret = validate_one_epoch(val_loader, model, criterion, device)
        src, src_model = 'raw', model
        if ema is not None:
            val_ret_ema = validate_one_epoch(val_loader, ema.module, criterion, device)
            if val_ret_ema["val_f1"] >= val_ret["val_f1"]:
                val_ret, src, src_model = val_ret_ema, 'ema', ema.module

        print(f"[F{fold}][E{epoch+1}/{cfg.epochs}] train_f1={train_ret['train_f1']:.4f} val_f1={val_ret['val_f1']:.4f} "
              f"train_loss={train_ret['train_loss']:.4f} val_loss={val_ret['val_loss']:.4f}")

        if val_ret['val_f1'] > best_f1:
            best_f1 = val_ret['val_f1']
            ckpt_path = os.path.join(cfg.ckpt_dir, f"fold{fold}_best_{src}.pth")
            torch.save(src_model.state_dict(), ckpt_path)
            if best_path is not None and best_path != ckpt_path and os.path.exists(best_path):
                os.remove(best_path)  # drop the stale best of the other flavour
            best_path = ckpt_path

    if best_path is not None:
        # ModelEmaV2.module is a copy of `model`, so both checkpoints share its keys.
        model.load_state_dict(torch.load(best_path, map_location=device))
    return best_f1, model


# ==================
# Cross-Validation
# ==================

def run_cv(cfg: Config, df: pd.DataFrame) -> Tuple[List[float], List[nn.Module]]:
    """Stratified K-fold cross-validation over *df*.

    Seeds all RNGs once, trains one model per fold with ``train_single_fold``,
    prints per-fold F1 plus mean/std, and returns ``(fold_f1s, fold_models)``
    in fold order.
    """
    seed_everything(cfg.seed)
    _, target_col = discover_cols(df)
    splitter = StratifiedKFold(n_splits=cfg.n_folds, shuffle=True, random_state=cfg.seed)

    scores: List[float] = []
    models: List[nn.Module] = []
    for fold, (trn_idx, val_idx) in enumerate(splitter.split(df, df[target_col])):
        print(f"\n===== Fold {fold} / {cfg.n_folds} =====")
        fold_f1, fold_model = train_single_fold(cfg, fold, trn_idx, val_idx, df)
        scores.append(fold_f1)
        models.append(fold_model)

    print("\nCV Results:")
    print("F1 per fold:", [f"{x:.4f}" for x in scores])
    print(f"Mean={np.mean(scores):.4f}  Std={np.std(scores):.4f}")
    return scores, models


# ============
# Inference
# ============

def infer_folder(cfg: Config, model: nn.Module, df_test: pd.DataFrame, out_csv: str):
    """Predict every image listed in *df_test* with orientation-search TTA and save a CSV.

    Image names come from the column chosen by ``discover_cols``; relative
    names are resolved against ``cfg.test_dir``. Each file is decoded once,
    inside ``predict_with_orientation_search``. The dataset object is only
    used for its deterministic ``val_aug`` preprocessing. Previously a
    DataLoader also decoded and transformed every image and then discarded
    the result.

    The output CSV has the image column (basenames) and a ``target`` column of
    predicted class ids, in *df_test* row order. Puts *model* in eval mode and
    creates the output directory if needed.
    """
    device = cfg.device
    img_col, _ = discover_cols(df_test)
    ds = ImageDataset(df_test, base_dir=cfg.test_dir, is_train=False, img_size=cfg.img_size, total_epochs=1)
    refine = 15 if cfg.tta_refine else 0

    model.eval()
    preds, paths = [], []
    for name in (str(v) for v in df_test[img_col]):
        pth = name if os.path.isabs(name) else os.path.join(cfg.test_dir, name)
        # predict_with_orientation_search already runs under torch.inference_mode().
        pred, _conf, _angle, _flipped = predict_with_orientation_search(
            model, pth, ds.val_aug, device, base_angles=(0, 90, 180, 270),
            refine_window=refine, refine_step=refine,
            try_hflip=cfg.tta_try_hflip, conf_margin=0.05,
        )
        preds.append(pred)
        paths.append(pth)

    out_dir = os.path.dirname(out_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    sub = pd.DataFrame({img_col: [os.path.basename(p) for p in paths], 'target': preds})
    sub.to_csv(out_csv, index=False)
    print(f"Saved predictions → {out_csv}")


# ==================
# Confusion Matrix
# ==================

def plot_confusion(val_loader: DataLoader, model: nn.Module, cfg: Config, class_names: Optional[List[str]] = None):
    """Run *model* over *val_loader* and display a confusion-matrix heatmap.

    The matrix is ``cfg.num_classes`` x ``cfg.num_classes`` (rows = true class,
    columns = prediction). Ticks are labelled with *class_names* or ``C0..Ck``
    by default. matplotlib is imported lazily. The model is switched to eval
    mode. The loader must yield ``(images, integer_targets)``.
    """
    import matplotlib.pyplot as plt

    model.eval()
    truth, guess = [], []
    with torch.inference_mode():
        for batch, labels in val_loader:
            batch = batch.to(cfg.device)
            with autocast(enabled=torch.cuda.is_available()):
                scores = model(batch)
            guess.extend(scores.argmax(1).cpu().numpy().tolist())
            truth.extend(labels.cpu().numpy().tolist())

    ticks = np.arange(cfg.num_classes)
    cm = confusion_matrix(np.array(truth), np.array(guess), labels=ticks)
    names = class_names if class_names is not None else [f'C{i}' for i in ticks]

    fig, ax = plt.subplots(figsize=(8,8))
    heat = ax.imshow(cm, cmap='Blues')
    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Pred')
    ax.set_ylabel('True')
    fig.colorbar(heat, ax=ax)
    ax.set_xticks(ticks)
    ax.set_xticklabels(names, rotation=45)
    ax.set_yticks(ticks)
    ax.set_yticklabels(names)
    plt.tight_layout()
    plt.show()


# =====
# Main
# =====

def main():
    """Command-line entry point.

    Parses flags into a ``Config``, prints it, and trains ``n_folds``
    stratified CV folds. With ``--do_infer`` it then predicts every regular
    file in ``test_dir`` using the fold model with the highest validation
    macro-F1 and writes ``out_csv``.

    Unknown arguments are ignored (``parse_known_args``), so this also works
    inside a notebook whose kernel injects its own ``-f`` argument.

    Raises:
        ValueError: if the training CSV contains no rows.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='../data')
    parser.add_argument('--train_csv', type=str, default='../data/train.csv')
    parser.add_argument('--test_dir', type=str, default='../data/test')
    parser.add_argument('--ckpt_dir', type=str, default='./checkpoints')
    parser.add_argument('--out_csv', type=str, default='./predict.csv')

    parser.add_argument('--model', type=str, default='convnextv2_base.fcmae_ft_in22k_in1k')
    parser.add_argument('--img_size', type=int, default=384)
    parser.add_argument('--num_classes', type=int, default=17)

    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=24)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--weight_decay', type=float, default=2e-2)
    parser.add_argument('--warmup_epochs', type=int, default=3)

    parser.add_argument('--label_smoothing', type=float, default=0.025)
    parser.add_argument('--mixup_p', type=float, default=0.10)
    parser.add_argument('--mixup_alpha', type=float, default=0.4)

    parser.add_argument('--n_folds', type=int, default=5)
    parser.add_argument('--seed', type=int, default=42)

    parser.add_argument('--ema', action='store_true')
    parser.add_argument('--ema_decay_min', type=float, default=0.99)
    parser.add_argument('--ema_decay_max', type=float, default=0.9995)

    parser.add_argument('--class_weight', action='store_true')
    parser.add_argument('--boost_indices', type=str, default='')  # e.g., '3,7,14'
    parser.add_argument('--boost_factor', type=float, default=1.2)

    parser.add_argument('--tta_refine', action='store_true')
    parser.add_argument('--tta_try_hflip', action='store_true')

    parser.add_argument('--wandb', action='store_true')
    parser.add_argument('--wandb_entity', type=str, default=None)
    parser.add_argument('--wandb_project', type=str, default=None)

    parser.add_argument('--do_infer', action='store_true')
    args, _ = parser.parse_known_args()  # safe in notebooks

    cfg = Config(
        data_dir=args.data_dir, train_csv=args.train_csv, test_dir=args.test_dir, ckpt_dir=args.ckpt_dir, out_csv=args.out_csv,
        model_name=args.model, img_size=args.img_size, num_classes=args.num_classes,
        epochs=args.epochs, batch_size=args.batch_size, num_workers=args.num_workers, lr=args.lr, weight_decay=args.weight_decay,
        warmup_epochs=args.warmup_epochs, label_smoothing=args.label_smoothing, mixup_p=args.mixup_p, mixup_alpha=args.mixup_alpha,
        n_folds=args.n_folds, seed=args.seed, use_ema=args.ema, ema_decay_min=args.ema_decay_min, ema_decay_max=args.ema_decay_max,
        use_wandb=args.wandb, wandb_entity=args.wandb_entity, wandb_project=args.wandb_project,
        use_class_weight=args.class_weight, weight_boost_indices=tuple(int(x) for x in args.boost_indices.split(',') if x!=''),
        weight_boost_factor=args.boost_factor, tta_refine=args.tta_refine, tta_try_hflip=args.tta_try_hflip,
    )

    print("Config:\n", json.dumps(asdict(cfg), indent=2))

    seed_everything(cfg.seed)

    df = pd.read_csv(cfg.train_csv)
    if df.empty:
        # Fail early with a clear message instead of an IndexError deep inside CV.
        raise ValueError(f"Training CSV has no rows: {cfg.train_csv}")
    # Relative image names are resolved later against <data_dir>/train by ImageDataset.
    img_col, _ = discover_cols(df)

    # CV training
    fold_scores, fold_models = run_cv(cfg, df)

    # (Optional) inference on the test folder with the best-scoring fold model
    if args.do_infer:
        # Regular files only: glob('*') would also return sub-directories.
        test_paths = sorted(p for p in glob.glob(os.path.join(cfg.test_dir, '*')) if os.path.isfile(p))
        if not test_paths:
            print("No test images found; skip inference")
        else:
            test_df = pd.DataFrame({img_col: [os.path.basename(p) for p in test_paths]})
            best_fold = int(np.argmax(fold_scores))
            print(f"Inference with fold {best_fold} (val_f1={fold_scores[best_fold]:.4f})")
            best_model = fold_models[best_fold]
            best_model.eval()
            infer_folder(cfg, best_model, test_df, cfg.out_csv)


if __name__ == '__main__':
    main()


# ---------------------------------------------------------------------------
# Captured notebook output from one run (kept as comments so this file stays
# valid Python):
#
# Config:
#  {
#   "data_dir": "../data",
#   "train_csv": "../data/train.csv",
#   "test_dir": "../data/test",
#   "ckpt_dir": "./checkpoints",
#   "out_csv": "./predict.csv",
#   "model_name": "convnextv2_base.fcmae_ft_in22k_in1k",
#   "img_size": 384,
#   "num_classes": 17,
#   "epochs": 50,
#   "batch_size": 24,
#   "num_workers": 8,
#   "lr": 0.0002,
#   "weight_decay": 0.02,
#   "warmup_epochs": 3,
#   "label_smoothing": 0.025,
#   "mixup_p": 0.1,
#   "mixup_alpha": 0.4,
#   "n_folds": 5,
#   "seed": 42,
#   "device": "cuda",
#   "use_ema": false,
#   "ema_decay_min": 0.99,
#   "ema_decay_max": 0.9995,
#   "use_wandb": false,
#   "wandb_entity": null,
#   "wandb_project": null,
#   "use_class_weight": false,
#   "weight_boost_indices": [],
#   "weight_boost_factor": 1.2,
#   "tta_refine": false,
#   "tta_try_hflip": false
# }
#
# ===== Fold 0 / 5 =====
#
# OutOfMemoryError: CUDA out of memory. Tried to allocate 432.00 MiB. GPU 0 has a total capacty of 23.69 GiB of which 74.00 MiB is free. Process 660397 has 5.71 GiB memory in use. Process 2320209 has 13.74 GiB memory in use. Process 3229775 has 632.00 MiB memory in use. Process 3253917 has 3.48 GiB memory in use. Of the allocated memory 3.14 GiB is allocated by PyTorch, and 22.87 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
#
# NOTE: this OOM is caused by GPU sharing, not by this script's own footprint.
# Processes 660397, 2320209 and 3229775 together held about 20 GiB of the
# 23.69 GiB card. The 3.48 GiB process is likely this run: 3.14 GiB allocated
# by PyTorch plus the CUDA context. Free the GPU, or lower --batch_size /
# --img_size to fit in the remaining memory.