# Import

In [None]:
import os
import random

import pandas as pd
import numpy as np

from PIL import Image
from tqdm import tqdm 

from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch import nn, optim

from sklearn.metrics import log_loss
from sklearn.model_selection import StratifiedKFold


if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("Using device:", device)

Using device: mps


# Hyperparameter Setting

In [2]:
CFG = {
    'IMG_SIZE': 384,
    'BATCH_SIZE': 16,
    'EPOCHS': 15,
    'LEARNING_RATE': 1e-4,
    'SEED' : 42,
    'N_SPLITS': 3
}

# Fixed RandomSeed

In [None]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(CFG['SEED'])

# CustomDataset

In [None]:
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None, is_test=False):
        self.root_dir = root_dir
        self.transform = transform
        self.is_test = is_test
        self.samples = []

        if is_test:
            for fname in sorted(os.listdir(root_dir)):
                if fname.lower().endswith(('.jpg')):
                    img_path = os.path.join(root_dir, fname)
                    self.samples.append((img_path,))
        else:
            self.classes = sorted(os.listdir(root_dir))
            self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

            for cls_name in self.classes:
                cls_folder = os.path.join(root_dir, cls_name)

                if not os.path.isdir(cls_folder):
                    continue
                
                for fname in os.listdir(cls_folder):
                    if fname.lower().endswith(('.jpg')):
                        img_path = os.path.join(cls_folder, fname)
                        label = self.class_to_idx[cls_name]
                        self.samples.append((img_path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if self.is_test:
            img_path = self.samples[idx][0]
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image
        else:
            img_path, label = self.samples[idx]
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label

# Data Load

In [5]:
train_root = './filtered_train'
test_root = './test'

In [None]:
train_transform = transforms.Compose([
    transforms.Resize((CFG['IMG_SIZE'] + 32, CFG['IMG_SIZE'] + 32)),
    transforms.RandomResizedCrop(CFG['IMG_SIZE'], scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

val_transform = transforms.Compose([
    transforms.Resize((CFG['IMG_SIZE'], CFG['IMG_SIZE'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [None]:
full_dataset = CustomImageDataset(train_root, transform=None)
print(f"총 이미지 수: {len(full_dataset)}")

targets = [label for _, label in full_dataset.samples]
class_names = full_dataset.classes

train_idx, val_idx = train_test_split(
    range(len(targets)), test_size=0.2, stratify=targets, random_state=42
)

train_dataset = Subset(CustomImageDataset(train_root, transform=train_transform), train_idx)
val_dataset = Subset(CustomImageDataset(train_root, transform=val_transform), val_idx)
print(f'train 이미지 수: {len(train_dataset)}, valid 이미지 수: {len(val_dataset)}')

train_loader = DataLoader(train_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False)


총 이미지 수: 33131
train 이미지 수: 26504, valid 이미지 수: 6627


# ConvNeXt

In [None]:
import torch.nn as nn
import timm

class BaseModel(nn.Module):
    def __init__(self, num_classes):
        super(BaseModel, self).__init__()
        self.backbone = timm.create_model('convnext_base_384_in22ft1k', pretrained=True, features_only=False)
        self.feature_dim = self.backbone.head.in_features
        self.backbone.head = nn.Identity()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.head = nn.Linear(self.feature_dim, num_classes)

    def forward(self, x):
        x = self.backbone.forward_features(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.head(x)
        return x

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def mixup_data(x, y, alpha=0.2):
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [None]:
def cutmix_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    batch_size, _, H, W = x.size()
    index = torch.randperm(batch_size).to(x.device)

    cx = np.random.randint(W)
    cy = np.random.randint(H)
    cut_w = int(W * np.sqrt(1 - lam))
    cut_h = int(H * np.sqrt(1 - lam))

    x1 = np.clip(cx - cut_w // 2, 0, W)
    x2 = np.clip(cx + cut_w // 2, 0, W)
    y1 = np.clip(cy - cut_h // 2, 0, H)
    y2 = np.clip(cy + cut_h // 2, 0, H)

    x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]
    y_a, y_b = y, y[index]
    lam = 1 - ((x2 - x1) * (y2 - y1) / (W * H))
    return x, y_a, y_b, lam


def apply_mixup_or_cutmix(x, y, mix_prob=0.5, mixup_alpha=0.2, cutmix_alpha=1.0):
    if np.random.rand() < mix_prob:
        return mixup_data(x, y, alpha=mixup_alpha)
    else:
        return cutmix_data(x, y, alpha=cutmix_alpha)

# Train/ Validation

In [None]:
def load_checkpoint(model, optimizer, scheduler, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch']
    best_logloss = checkpoint.get('best_logloss', float('inf'))
    best_acc = checkpoint.get('best_acc', 0.0)
    best_ce_loss = checkpoint.get('best_ce_loss', float('inf'))
    return start_epoch, best_logloss, best_acc, best_ce_loss

resume = True
resume_ckpt = {
    0: "final_model_fold1/checkpoint_epoch_005.pth",
    1: None,
    2: None,
    3: None,
    4: None
}

skf = StratifiedKFold(n_splits=CFG['N_SPLITS'], shuffle=True, random_state=42)
targets = [label for _, label in full_dataset.samples]
class_names = full_dataset.classes

for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(targets)), targets)):
    print(f"\n📂 Fold {fold+1}/{CFG['N_SPLITS']}")

    train_dataset = Subset(CustomImageDataset(train_root, transform=train_transform), train_idx)
    val_dataset = Subset(CustomImageDataset(train_root, transform=val_transform), val_idx)
    train_loader = DataLoader(train_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False)


    model = BaseModel(num_classes=len(class_names)).to(device)
    optimizer = optim.Adam(model.parameters(), lr=CFG['LEARNING_RATE'])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG['EPOCHS'], eta_min=1e-6)
    criterion = nn.CrossEntropyLoss()

    save_dir = f"final_model_fold{fold+1}"
    os.makedirs(save_dir, exist_ok=True)

    start_epoch = 0
    best_logloss = float('inf')
    best_acc = 0.0
    best_ce_loss = float('inf')
    if resume and resume_ckpt.get(fold):
        print(f"🔄 Resuming fold {fold+1} from {resume_ckpt[fold]}")
        start_epoch, best_logloss, best_acc, best_ce_loss = load_checkpoint(
            model, optimizer, scheduler, resume_ckpt[fold]
        )
        print(f"Resume: start_epoch={start_epoch}, best_logloss={best_logloss}, best_acc={best_acc}, best_ce_loss={best_ce_loss}")

    for epoch in range(start_epoch, CFG['EPOCHS']):
        model.train()
        train_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"[Fold {fold+1}][Epoch {epoch+1}/{CFG['EPOCHS']}] Training"):
            images, labels = images.to(device), labels.to(device)
            inputs, targets_a, targets_b, lam = apply_mixup_or_cutmix(
                images, labels, mix_prob=0.5, mixup_alpha=0.2, cutmix_alpha=1.0
            )

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)

        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        all_probs = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"[Fold {fold+1}][Epoch {epoch+1}/{CFG['EPOCHS']}] Validation"):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, preds = torch.max(outputs, 1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

                probs = F.softmax(outputs, dim=1)
                all_probs.extend(probs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = 100 * correct / total
        val_logloss = log_loss(all_labels, all_probs, labels=list(range(len(class_names))))
        scheduler.step()

        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | LogLoss: {val_logloss:.4f}")

        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_logloss': best_logloss,
            'best_acc': best_acc,
            'best_ce_loss': best_ce_loss,
        }
        torch.save(checkpoint, f"{save_dir}/checkpoint_epoch_{epoch+1:03d}.pth")

        if val_logloss < best_logloss:
            best_logloss = val_logloss
            torch.save(model.state_dict(), f"{save_dir}/best_logloss.pth")

        if val_accuracy > best_acc:
            best_acc = val_accuracy
            torch.save(model.state_dict(), f"{save_dir}/best_acc.pth")

        if avg_val_loss < best_ce_loss:
            best_ce_loss = avg_val_loss
            torch.save(model.state_dict(), f"{save_dir}/best_loss.pth")



📂 Fold 1/3


  model = create_fn(


🔄 Resuming fold 1 from final_model_fold1/checkpoint_epoch_005.pth
Resume: start_epoch=5, best_logloss=0.15354506109472799, best_acc=95.19195943498733, best_ce_loss=0.15338741486953428


[Fold 1][Epoch 6/15] Training: 100%|██████████| 1381/1381 [38:40<00:00,  1.68s/it]
[Fold 1][Epoch 6/15] Validation: 100%|██████████| 691/691 [05:29<00:00,  2.10it/s]


Train Loss: 0.9392 | Val Loss: 0.1390 | Val Acc: 95.93% | LogLoss: 0.1392


[Fold 1][Epoch 7/15] Training: 100%|██████████| 1381/1381 [38:14<00:00,  1.66s/it]
[Fold 1][Epoch 7/15] Validation: 100%|██████████| 691/691 [05:26<00:00,  2.11it/s]


Train Loss: 0.8840 | Val Loss: 0.1104 | Val Acc: 96.44% | LogLoss: 0.1105


[Fold 1][Epoch 8/15] Training: 100%|██████████| 1381/1381 [37:51<00:00,  1.65s/it]
[Fold 1][Epoch 8/15] Validation: 100%|██████████| 691/691 [05:22<00:00,  2.14it/s]


Train Loss: 0.8386 | Val Loss: 0.1110 | Val Acc: 96.57% | LogLoss: 0.1111


[Fold 1][Epoch 9/15] Training: 100%|██████████| 1381/1381 [37:34<00:00,  1.63s/it]
[Fold 1][Epoch 9/15] Validation: 100%|██████████| 691/691 [05:21<00:00,  2.15it/s]


Train Loss: 0.7874 | Val Loss: 0.1048 | Val Acc: 96.90% | LogLoss: 0.1050


[Fold 1][Epoch 10/15] Training: 100%|██████████| 1381/1381 [37:50<00:00,  1.64s/it]
[Fold 1][Epoch 10/15] Validation: 100%|██████████| 691/691 [05:26<00:00,  2.11it/s]


Train Loss: 0.7301 | Val Loss: 0.0960 | Val Acc: 97.08% | LogLoss: 0.0961


[Fold 1][Epoch 11/15] Training: 100%|██████████| 1381/1381 [38:53<00:00,  1.69s/it]
[Fold 1][Epoch 11/15] Validation: 100%|██████████| 691/691 [05:41<00:00,  2.02it/s]


Train Loss: 0.7042 | Val Loss: 0.0884 | Val Acc: 97.32% | LogLoss: 0.0885


[Fold 1][Epoch 12/15] Training: 100%|██████████| 1381/1381 [39:01<00:00,  1.70s/it]
[Fold 1][Epoch 12/15] Validation: 100%|██████████| 691/691 [05:37<00:00,  2.05it/s]


Train Loss: 0.6850 | Val Loss: 0.0866 | Val Acc: 97.35% | LogLoss: 0.0867


[Fold 1][Epoch 13/15] Training: 100%|██████████| 1381/1381 [37:51<00:00,  1.64s/it]
[Fold 1][Epoch 13/15] Validation: 100%|██████████| 691/691 [05:23<00:00,  2.14it/s]


Train Loss: 0.6648 | Val Loss: 0.0858 | Val Acc: 97.41% | LogLoss: 0.0859


[Fold 1][Epoch 14/15] Training: 100%|██████████| 1381/1381 [37:33<00:00,  1.63s/it]
[Fold 1][Epoch 14/15] Validation: 100%|██████████| 691/691 [05:21<00:00,  2.15it/s]


Train Loss: 0.6607 | Val Loss: 0.0840 | Val Acc: 97.40% | LogLoss: 0.0841


[Fold 1][Epoch 15/15] Training: 100%|██████████| 1381/1381 [37:22<00:00,  1.62s/it]
[Fold 1][Epoch 15/15] Validation: 100%|██████████| 691/691 [05:17<00:00,  2.18it/s]


Train Loss: 0.6621 | Val Loss: 0.0841 | Val Acc: 97.41% | LogLoss: 0.0842

📂 Fold 2/3


  model = create_fn(
[Fold 2][Epoch 1/15] Training: 100%|██████████| 1381/1381 [38:58<00:00,  1.69s/it]
[Fold 2][Epoch 1/15] Validation: 100%|██████████| 691/691 [05:17<00:00,  2.18it/s]


Train Loss: 4.0099 | Val Loss: 0.5538 | Val Acc: 84.76% | LogLoss: 0.5544


[Fold 2][Epoch 2/15] Training: 100%|██████████| 1381/1381 [41:37<00:00,  1.81s/it]
[Fold 2][Epoch 2/15] Validation: 100%|██████████| 691/691 [05:44<00:00,  2.01it/s]


Train Loss: 1.5564 | Val Loss: 0.2685 | Val Acc: 91.32% | LogLoss: 0.2688


[Fold 2][Epoch 3/15] Training: 100%|██████████| 1381/1381 [42:35<00:00,  1.85s/it]
[Fold 2][Epoch 3/15] Validation: 100%|██████████| 691/691 [05:49<00:00,  1.98it/s]


Train Loss: 1.2858 | Val Loss: 0.2244 | Val Acc: 93.63% | LogLoss: 0.2246


[Fold 2][Epoch 4/15] Training: 100%|██████████| 1381/1381 [42:22<00:00,  1.84s/it]
[Fold 2][Epoch 4/15] Validation: 100%|██████████| 691/691 [05:46<00:00,  2.00it/s]


Train Loss: 1.1111 | Val Loss: 0.1727 | Val Acc: 94.62% | LogLoss: 0.1728


[Fold 2][Epoch 5/15] Training: 100%|██████████| 1381/1381 [41:51<00:00,  1.82s/it]
[Fold 2][Epoch 5/15] Validation: 100%|██████████| 691/691 [05:51<00:00,  1.97it/s]


Train Loss: 1.0052 | Val Loss: 0.1490 | Val Acc: 95.39% | LogLoss: 0.1492


[Fold 2][Epoch 6/15] Training: 100%|██████████| 1381/1381 [42:25<00:00,  1.84s/it]
[Fold 2][Epoch 6/15] Validation: 100%|██████████| 691/691 [05:50<00:00,  1.97it/s]


Train Loss: 0.9495 | Val Loss: 0.1268 | Val Acc: 96.05% | LogLoss: 0.1269


[Fold 2][Epoch 7/15] Training: 100%|██████████| 1381/1381 [42:01<00:00,  1.83s/it]
[Fold 2][Epoch 7/15] Validation: 100%|██████████| 691/691 [05:42<00:00,  2.02it/s]


Train Loss: 0.9023 | Val Loss: 0.1156 | Val Acc: 96.35% | LogLoss: 0.1157


[Fold 2][Epoch 8/15] Training: 100%|██████████| 1381/1381 [42:25<00:00,  1.84s/it]
[Fold 2][Epoch 8/15] Validation: 100%|██████████| 691/691 [05:55<00:00,  1.94it/s]


Train Loss: 0.8302 | Val Loss: 0.1117 | Val Acc: 96.65% | LogLoss: 0.1118


[Fold 2][Epoch 9/15] Training: 100%|██████████| 1381/1381 [43:31<00:00,  1.89s/it]
[Fold 2][Epoch 9/15] Validation: 100%|██████████| 691/691 [05:49<00:00,  1.97it/s]


Train Loss: 0.7989 | Val Loss: 0.1017 | Val Acc: 96.89% | LogLoss: 0.1018


[Fold 2][Epoch 10/15] Training: 100%|██████████| 1381/1381 [43:00<00:00,  1.87s/it]
[Fold 2][Epoch 10/15] Validation: 100%|██████████| 691/691 [05:57<00:00,  1.93it/s]


Train Loss: 0.7254 | Val Loss: 0.0975 | Val Acc: 97.08% | LogLoss: 0.0976


[Fold 2][Epoch 11/15] Training: 100%|██████████| 1381/1381 [42:46<00:00,  1.86s/it]
[Fold 2][Epoch 11/15] Validation: 100%|██████████| 691/691 [05:52<00:00,  1.96it/s]


Train Loss: 0.7333 | Val Loss: 0.0891 | Val Acc: 97.21% | LogLoss: 0.0892


[Fold 2][Epoch 12/15] Training: 100%|██████████| 1381/1381 [42:28<00:00,  1.85s/it]
[Fold 2][Epoch 12/15] Validation: 100%|██████████| 691/691 [05:52<00:00,  1.96it/s]


Train Loss: 0.6956 | Val Loss: 0.0855 | Val Acc: 97.49% | LogLoss: 0.0856


[Fold 2][Epoch 13/15] Training: 100%|██████████| 1381/1381 [43:11<00:00,  1.88s/it]
[Fold 2][Epoch 13/15] Validation: 100%|██████████| 691/691 [05:45<00:00,  2.00it/s]


Train Loss: 0.6735 | Val Loss: 0.0820 | Val Acc: 97.43% | LogLoss: 0.0820


[Fold 2][Epoch 14/15] Training: 100%|██████████| 1381/1381 [41:50<00:00,  1.82s/it]
[Fold 2][Epoch 14/15] Validation: 100%|██████████| 691/691 [05:45<00:00,  2.00it/s]


Train Loss: 0.6844 | Val Loss: 0.0809 | Val Acc: 97.69% | LogLoss: 0.0810


[Fold 2][Epoch 15/15] Training: 100%|██████████| 1381/1381 [42:18<00:00,  1.84s/it]
[Fold 2][Epoch 15/15] Validation: 100%|██████████| 691/691 [05:57<00:00,  1.93it/s]


Train Loss: 0.6549 | Val Loss: 0.0805 | Val Acc: 97.65% | LogLoss: 0.0806

📂 Fold 3/3


  model = create_fn(
[Fold 3][Epoch 1/15] Training: 100%|██████████| 1381/1381 [44:12<00:00,  1.92s/it] 
[Fold 3][Epoch 1/15] Validation: 100%|██████████| 691/691 [05:18<00:00,  2.17it/s]


Train Loss: 4.0209 | Val Loss: 0.6242 | Val Acc: 83.41% | LogLoss: 0.6248


[Fold 3][Epoch 2/15] Training:   5%|▌         | 70/1381 [36:45<9:18:59, 25.58s/it] 

In [None]:
def load_checkpoint(model, optimizer, scheduler, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch']
    best_logloss = checkpoint.get('best_logloss', float('inf'))
    best_acc = checkpoint.get('best_acc', 0.0)
    best_ce_loss = checkpoint.get('best_ce_loss', float('inf'))
    return start_epoch, best_logloss, best_acc, best_ce_loss

resume = True
resume_ckpt = {
    0: None,
    1: None,
    2: "final_model_fold3/checkpoint_epoch_001.pth",
    3: None,
    4: None
}

skf = StratifiedKFold(n_splits=CFG['N_SPLITS'], shuffle=True, random_state=42)
targets = [label for _, label in full_dataset.samples]
class_names = full_dataset.classes

for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(targets)), targets)):
    if fold != 2:
        continue
    
    print(f"\n📂 Fold {fold+1}/{CFG['N_SPLITS']}")

    train_dataset = Subset(CustomImageDataset(train_root, transform=train_transform), train_idx)
    val_dataset = Subset(CustomImageDataset(train_root, transform=val_transform), val_idx)
    train_loader = DataLoader(train_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False)

    model = BaseModel(num_classes=len(class_names)).to(device)
    optimizer = optim.Adam(model.parameters(), lr=CFG['LEARNING_RATE'])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG['EPOCHS'], eta_min=1e-6)
    criterion = nn.CrossEntropyLoss()

    save_dir = f"final_model_fold{fold+1}"
    os.makedirs(save_dir, exist_ok=True)

    start_epoch = 0
    best_logloss = float('inf')
    best_acc = 0.0
    best_ce_loss = float('inf')
    if resume and resume_ckpt.get(fold):
        print(f"🔄 Resuming fold {fold+1} from {resume_ckpt[fold]}")
        start_epoch, best_logloss, best_acc, best_ce_loss = load_checkpoint(
            model, optimizer, scheduler, resume_ckpt[fold]
        )
        print(f"Resume: start_epoch={start_epoch}, best_logloss={best_logloss}, best_acc={best_acc}, best_ce_loss={best_ce_loss}")

    for epoch in range(start_epoch, CFG['EPOCHS']):
        model.train()
        train_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"[Fold {fold+1}][Epoch {epoch+1}/{CFG['EPOCHS']}] Training"):
            images, labels = images.to(device), labels.to(device)
            inputs, targets_a, targets_b, lam = apply_mixup_or_cutmix(
                images, labels, mix_prob=0.5, mixup_alpha=0.2, cutmix_alpha=1.0
            )

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)

        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        all_probs = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"[Fold {fold+1}][Epoch {epoch+1}/{CFG['EPOCHS']}] Validation"):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, preds = torch.max(outputs, 1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

                probs = F.softmax(outputs, dim=1)
                all_probs.extend(probs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = 100 * correct / total
        val_logloss = log_loss(all_labels, all_probs, labels=list(range(len(class_names))))
        scheduler.step()

        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}% | LogLoss: {val_logloss:.4f}")

        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_logloss': best_logloss,
            'best_acc': best_acc,
            'best_ce_loss': best_ce_loss,
        }
        torch.save(checkpoint, f"{save_dir}/checkpoint_epoch_{epoch+1:03d}.pth")

        if val_logloss < best_logloss:
            best_logloss = val_logloss
            torch.save(model.state_dict(), f"{save_dir}/best_logloss.pth")

        if val_accuracy > best_acc:
            best_acc = val_accuracy
            torch.save(model.state_dict(), f"{save_dir}/best_acc.pth")

        if avg_val_loss < best_ce_loss:
            best_ce_loss = avg_val_loss
            torch.save(model.state_dict(), f"{save_dir}/best_loss.pth")


📂 Fold 3/3


  model = create_fn(


🔄 Resuming fold 3 from final_model_fold3/checkpoint_epoch_001.pth
Resume: start_epoch=1, best_logloss=inf, best_acc=0.0, best_ce_loss=inf


[Fold 3][Epoch 2/15] Training: 100%|██████████| 1381/1381 [38:17<00:00,  1.66s/it]
[Fold 3][Epoch 2/15] Validation: 100%|██████████| 691/691 [05:13<00:00,  2.21it/s]


Train Loss: 1.5755 | Val Loss: 0.2542 | Val Acc: 92.61% | LogLoss: 0.2544


[Fold 3][Epoch 3/15] Training: 100%|██████████| 1381/1381 [43:29<00:00,  1.89s/it]
[Fold 3][Epoch 3/15] Validation: 100%|██████████| 691/691 [05:14<00:00,  2.20it/s]


Train Loss: 1.2462 | Val Loss: 0.2089 | Val Acc: 93.93% | LogLoss: 0.2091


[Fold 3][Epoch 4/15] Training: 100%|██████████| 1381/1381 [46:48<00:00,  2.03s/it] 
[Fold 3][Epoch 4/15] Validation: 100%|██████████| 691/691 [05:16<00:00,  2.18it/s]


Train Loss: 1.1290 | Val Loss: 0.1709 | Val Acc: 95.28% | LogLoss: 0.1711


[Fold 3][Epoch 5/15] Training: 100%|██████████| 1381/1381 [45:53<00:00,  1.99s/it]
[Fold 3][Epoch 5/15] Validation: 100%|██████████| 691/691 [05:20<00:00,  2.16it/s]


Train Loss: 1.0111 | Val Loss: 0.1387 | Val Acc: 95.78% | LogLoss: 0.1388


[Fold 3][Epoch 6/15] Training: 100%|██████████| 1381/1381 [45:10<00:00,  1.96s/it]
[Fold 3][Epoch 6/15] Validation: 100%|██████████| 691/691 [05:13<00:00,  2.21it/s]


Train Loss: 0.9136 | Val Loss: 0.1272 | Val Acc: 96.06% | LogLoss: 0.1274


[Fold 3][Epoch 7/15] Training: 100%|██████████| 1381/1381 [43:59<00:00,  1.91s/it]
[Fold 3][Epoch 7/15] Validation: 100%|██████████| 691/691 [05:15<00:00,  2.19it/s]


Train Loss: 0.8666 | Val Loss: 0.1188 | Val Acc: 96.08% | LogLoss: 0.1189


[Fold 3][Epoch 8/15] Training: 100%|██████████| 1381/1381 [44:44<00:00,  1.94s/it]
[Fold 3][Epoch 8/15] Validation: 100%|██████████| 691/691 [05:16<00:00,  2.18it/s]


Train Loss: 0.8142 | Val Loss: 0.1171 | Val Acc: 96.58% | LogLoss: 0.1172


[Fold 3][Epoch 9/15] Training: 100%|██████████| 1381/1381 [47:37<00:00,  2.07s/it] 
[Fold 3][Epoch 9/15] Validation: 100%|██████████| 691/691 [05:32<00:00,  2.08it/s]


Train Loss: 0.7764 | Val Loss: 0.1087 | Val Acc: 96.70% | LogLoss: 0.1089


[Fold 3][Epoch 10/15] Training: 100%|██████████| 1381/1381 [50:03<00:00,  2.18s/it] 
[Fold 3][Epoch 10/15] Validation: 100%|██████████| 691/691 [05:35<00:00,  2.06it/s]


Train Loss: 0.7401 | Val Loss: 0.0982 | Val Acc: 97.06% | LogLoss: 0.0983


[Fold 3][Epoch 11/15] Training: 100%|██████████| 1381/1381 [46:47<00:00,  2.03s/it]
[Fold 3][Epoch 11/15] Validation: 100%|██████████| 691/691 [05:29<00:00,  2.10it/s]


Train Loss: 0.7176 | Val Loss: 0.0936 | Val Acc: 97.11% | LogLoss: 0.0937


[Fold 3][Epoch 12/15] Training: 100%|██████████| 1381/1381 [49:50<00:00,  2.17s/it] 
[Fold 3][Epoch 12/15] Validation: 100%|██████████| 691/691 [05:09<00:00,  2.24it/s]


Train Loss: 0.7031 | Val Loss: 0.0884 | Val Acc: 97.22% | LogLoss: 0.0885


[Fold 3][Epoch 13/15] Training: 100%|██████████| 1381/1381 [44:13<00:00,  1.92s/it]
[Fold 3][Epoch 13/15] Validation: 100%|██████████| 691/691 [05:09<00:00,  2.23it/s]


Train Loss: 0.6732 | Val Loss: 0.0866 | Val Acc: 97.23% | LogLoss: 0.0867


[Fold 3][Epoch 14/15] Training: 100%|██████████| 1381/1381 [44:01<00:00,  1.91s/it]
[Fold 3][Epoch 14/15] Validation: 100%|██████████| 691/691 [05:08<00:00,  2.24it/s]


Train Loss: 0.6783 | Val Loss: 0.0844 | Val Acc: 97.41% | LogLoss: 0.0845


[Fold 3][Epoch 15/15] Training: 100%|██████████| 1381/1381 [43:51<00:00,  1.91s/it]
[Fold 3][Epoch 15/15] Validation: 100%|██████████| 691/691 [05:08<00:00,  2.24it/s]


Train Loss: 0.6560 | Val Loss: 0.0838 | Val Acc: 97.39% | LogLoss: 0.0839


# Inference

In [15]:
test_dataset = CustomImageDataset(test_root, transform=None, is_test=True)
test_loader = DataLoader(test_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=False)

In [None]:
IMG_SIZE = CFG['IMG_SIZE']
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
BATCH_SIZE = CFG['BATCH_SIZE']


def get_tta_transforms():
    return [
        transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
        transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.RandomHorizontalFlip(p=1.0),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
        transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.RandomRotation(degrees=5),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]),
    ]

tta_transforms = get_tta_transforms()

fold_list = [1, 2, 3]
model_paths = [
    f'/Users/hyun/dev_ws/dacon_car/final_model_fold{fold}/checkpoint_epoch_015.pth'
    for fold in fold_list
]

ensemble_probs = []

for fold_idx, path in enumerate(model_paths):
    model = BaseModel(num_classes=len(class_names))
    checkpoint = torch.load(path, map_location=device)

    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()

    with torch.no_grad():
        for tta_idx, tta_tf in enumerate(tta_transforms):
            probs_list = []

            for images in tqdm(test_loader, desc=f"Fold {fold_idx+1}, TTA {tta_idx+1}"):
                transformed = [tta_tf(img) for img in images]
                transformed = torch.stack(transformed).to(device)

                outputs = model(transformed)
                probs = F.softmax(outputs, dim=1)
                probs_list.append(probs.cpu())

            all_probs = torch.cat(probs_list, dim=0)
            ensemble_probs.append(all_probs)

avg_probs = torch.stack(ensemble_probs).mean(dim=0)

results = [
    {class_names[i]: prob[i].item() for i in range(len(class_names))}
    for prob in avg_probs
]
pred = pd.DataFrame(results)

pred['label'] = pred[class_names].idxmax(axis=1)
pred['ID'] = [f'TEST_{i:05d}' for i in range(len(pred))]

submission = pred[['ID', 'label']]
submission.to_csv("ensemble_3fold_3tta_submission.csv", index=False)
print("✅ 3-Fold × 3-TTA soft voting 완료 및 저장 완료")

Fold 1, TTA 1: 100%|██████████| 517/517 [03:58<00:00,  2.17it/s]
Fold 1, TTA 2: 100%|██████████| 517/517 [03:58<00:00,  2.17it/s]
Fold 1, TTA 3: 100%|██████████| 517/517 [04:05<00:00,  2.11it/s]
Fold 2, TTA 1: 100%|██████████| 517/517 [04:02<00:00,  2.13it/s]
Fold 2, TTA 2: 100%|██████████| 517/517 [04:02<00:00,  2.13it/s]
Fold 2, TTA 3: 100%|██████████| 517/517 [04:06<00:00,  2.09it/s]
Fold 3, TTA 1: 100%|██████████| 517/517 [04:07<00:00,  2.09it/s]
Fold 3, TTA 2: 100%|██████████| 517/517 [04:06<00:00,  2.10it/s]
Fold 3, TTA 3: 100%|██████████| 517/517 [04:07<00:00,  2.09it/s]


✅ 3-Fold × 3-TTA soft voting 완료 및 저장 완료


# Submission

In [None]:
submission = pd.read_csv('./sample_submission.csv', encoding='utf-8-sig')

class_columns = submission.columns[1:]
pred = pred[class_columns]

submission[class_columns] = pred.values
submission.to_csv('test36_submission.csv', index=False, encoding='utf-8-sig')