# Setup

In [None]:
# Mount Google Drive at /content/drive; the CSV and data paths defined below live there.
# This import only exists inside Google Colab.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install -q numpy pandas librosa soundfile xgboost lightgbm scikit-learn scipy tqdm rich

# Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from torch import optim
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.optim.lr_scheduler import OneCycleLR, ReduceLROnPlateau, CosineAnnealingLR, SequentialLR, LinearLR
from torch.cuda.amp import GradScaler, autocast
from torch.utils.checkpoint import checkpoint
from torch.utils.tensorboard.writer import SummaryWriter
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
from sklearn.linear_model import LogisticRegression
import scipy, scipy.signal, scipy.stats
import librosa
import soundfile
import os, re, math, warnings, random
from rich.console import Console
# from tqdm import tqdm # 멀티쓰레딩 가끔 문제잇음 안쓰는걸로

# Silence every warning for the rest of the notebook to keep the training logs short.
warnings.filterwarnings('ignore')
# NOTE(review): the blanket filter above appears to cover RuntimeWarning already —
# confirm whether this second filter is still needed.
warnings.filterwarnings("ignore", category=RuntimeWarning)

In [None]:
# Locations of the competition files on the mounted Google Drive.
# The two CSVs are read with pandas in the DataLoaders section.
train_csv_path = "/content/drive/MyDrive/ST_KAGGLE_2/input/train.csv"
test_csv_path = "/content/drive/MyDrive/ST_KAGGLE_2/input/test.csv"
# NOTE(review): the two directories below are only referenced by the commented-out EDA
# cells in the visible code — presumably they hold the raw .wav files; confirm.
train_data_path = "/content/drive/MyDrive/ST_KAGGLE_2/input/train"
test_data_path = "/content/drive/MyDrive/ST_KAGGLE_2/input/test"

In [None]:
# 구글드라이브로 트레이닝하면 엄청 오래걸림!!! 런타임 처음 시작하면 카피 한 번만 해두기
# !unzip -q /content/drive/MyDrive/ST_KAGGLE_2/input/precomputed.zip -d /content/

In [None]:
# Directory with one "<ID>.npz" feature file per sample; passed to DS as feature_dir.
precomputed_dir = "/content/precomputed"

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Allow TF32 math for matmuls and cuDNN convolutions.
torch.backends.cuda.matmul.allow_tf32 = True # A100
torch.backends.cudnn.allow_tf32 = True
# NOTE(review): set_seed(), defined and called in the next cell, assigns the opposite
# values to the two flags below, so whichever cell runs last wins — confirm the intent.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

In [None]:
def set_seed(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), and puts cuDNN into deterministic mode with autotuning
    switched off.

    Args:
        seed: Integer seed shared by all generators.
    """
    # The cuDNN switches are independent of the seeding calls below.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


set_seed(42)

# Utils

In [None]:
# Shared rich console used by the print_* helpers and count_parameters below.
console = Console()

def print_start(title: str):
    """Announce on the shared console that the task named ``title`` is starting."""
    banner = f"[bold cyan]🐶 작업 시작: {title}[/bold cyan]"
    console.print(banner)

def print_epoch_summary(epoch_index: int, average_loss: float):
    """Print a two-line epoch summary: a header and the mean training loss.

    Args:
        epoch_index: Epoch number shown in the header.
        average_loss: Mean training loss, printed with four decimals.
    """
    header = f"[bold blue]⚙️ Epoch {epoch_index} Summary[/bold blue]"
    loss_line = f"[green]평균 Training Loss:[/green] {average_loss:.4f}"
    for line in (header, loss_line):
        console.print(line)

def print_validation_accuracy(accuracy: float, min_prob: float, max_prob: float):
    """Print the validation accuracy followed by the predicted-probability range.

    Args:
        accuracy: Validation accuracy, printed with four decimals.
        min_prob: Lower end of the probability range, printed with three decimals.
        max_prob: Upper end of the probability range, printed with three decimals.
    """
    accuracy_line = f"[bold green]✅ Val Accuracy:[/bold green] {accuracy:.4f}"
    range_line = f"[dim]Probability range: {min_prob:.3f}–{max_prob:.3f}[/dim]"
    for line in (accuracy_line, range_line):
        console.print(line)

def progress_bar(iterable, description: str):
    """Wrap ``iterable`` in a tqdm progress bar when tqdm can be imported.

    The notebook's top-level ``tqdm`` import is commented out, so the name is
    resolved here instead of relying on a global that does not exist.

    Args:
        iterable: Any iterable to loop over.
        description: Label shown in front of the bar.

    Returns:
        A tqdm-wrapped iterable, or ``iterable`` itself when tqdm is unavailable.
    """
    try:
        from tqdm import tqdm
    except ImportError:
        # No progress display available; iteration itself must still work.
        return iterable
    return tqdm(iterable, desc=description, ncols=120)

def print_success(message: str):
    """Print ``message`` on the shared console in bold green with a check mark."""
    styled = f"[bold green]✅ {message}[/bold green]"
    console.print(styled)

def print_warning(message: str):
    """Print ``message`` on the shared console in bold yellow with a warning sign."""
    styled = f"[bold yellow]⚠️ {message}[/bold yellow]"
    console.print(styled)

def print_error(message: str):
    """Print ``message`` on the shared console in bold red with a cross mark."""
    styled = f"[bold red]❌ {message}[/bold red]"
    console.print(styled)

def print_info(message: str):
    """Print ``message`` on the shared console in bold blue with an info sign."""
    styled = f"[bold blue]ℹ️ {message}[/bold blue]"
    console.print(styled)

def count_parameters(model):
    """Print the total and the trainable parameter counts of ``model``.

    Args:
        model: Object exposing ``parameters()`` whose items provide ``numel()``
            and ``requires_grad``.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size

    console.print(f"[bold purple]🦄 총 파라미터: {total:,}[/bold purple]")
    console.print(f"[bold purple]🦄 Trainable: {trainable:,}[/bold purple]")

# EDA

In [None]:
# with soundfile.SoundFile(train_data_path+"/steth_20190623_10_35_17_033.wav") as audio:
#     waveform_i = audio.read(dtype="float32")
#     sr = audio.samplerate
#     plt.figure(figsize=(15, 4))
#     plt.subplot(1,2,1)
#     librosa.display.waveshow(waveform_i, sr=sr)
#     plt.title('Inhale')

# with soundfile.SoundFile(train_data_path+"/steth_20190623_10_34_54_018.wav") as audio:
#     waveform_e = audio.read(dtype="float32")
#     sr = audio.samplerate
#     plt.subplot(1,2,2)
#     librosa.display.waveshow(waveform_e, sr=sr)
#     plt.title('Exhale')


In [None]:
# stft_spectrum_matrix = librosa.stft(waveform_i)
# plt.figure(figsize=(10, 4))
# librosa.display.specshow(librosa.amplitude_to_db(np.abs(stft_spectrum_matrix), ref=np.max),y_axis='log', x_axis='time')
# plt.title('Inhale STFT Power spectrogram')
# plt.colorbar(format='%+2.0f dB')
# plt.tight_layout()

# stft_spectrum_matrix = librosa.stft(waveform_e)
# plt.figure(figsize=(10, 4))
# librosa.display.specshow(librosa.amplitude_to_db(np.abs(stft_spectrum_matrix), ref=np.max),y_axis='log', x_axis='time')
# plt.title('Exhale STFT Power spectrogram')
# plt.colorbar(format='%+2.0f dB')
# plt.tight_layout()

In [None]:
# from matplotlib.colors import Normalize
# mfc_coefficients = librosa.feature.mfcc(y=waveform_i, sr=sr, n_mfcc=32)
# plt.figure(figsize=(10, 4))
# librosa.display.specshow(mfc_coefficients, x_axis='time',norm=Normalize(vmin=-30,vmax=30))
# plt.colorbar()
# plt.yticks(())
# plt.ylabel('MFC Coefficient')
# plt.title('Inhale MFC Coefficients')
# plt.tight_layout()

# mfc_coefficients = librosa.feature.mfcc(y=waveform_e, sr=sr, n_mfcc=32)
# plt.figure(figsize=(10, 4))
# librosa.display.specshow(mfc_coefficients, x_axis='time',norm=Normalize(vmin=-30,vmax=30))
# plt.colorbar()
# plt.yticks(())
# plt.ylabel('MFC Coefficient')
# plt.title('Exhale MFC Coefficients')
# plt.tight_layout()

In [None]:
# melspectrogram_i = librosa.feature.melspectrogram(y=waveform_i, sr=sr, n_mels=128, fmax=4000)
# plt.figure(figsize=(10, 4))
# librosa.display.specshow(librosa.power_to_db(S=melspectrogram_i, ref=np.mean),y_axis='mel',fmax=4000, x_axis='time', norm=Normalize(vmin=-20,vmax=20))
# plt.colorbar(format='%+2.0f dB',label='Amplitude')
# plt.ylabel('Mels')
# plt.title('Inhale Mel spectrogram')
# plt.tight_layout()

# melspectrogram_e = librosa.feature.melspectrogram(y=waveform_e, sr=sr, n_mels=128, fmax=4000)
# plt.figure(figsize=(10, 4))
# librosa.display.specshow(librosa.power_to_db(S=melspectrogram_e, ref=np.mean),y_axis='mel',fmax=4000, x_axis='time', norm=Normalize(vmin=-20,vmax=20))
# plt.colorbar(format='%+2.0f dB',label='Amplitude')
# plt.ylabel('Mels')
# plt.title('Exhale Mel spectrogram')
# plt.tight_layout()

In [None]:
# chromagram = librosa.feature.chroma_stft(y=waveform_i, sr=sr)
# plt.figure(figsize=(10, 4))
# librosa.display.specshow(chromagram, y_axis='chroma', x_axis='time')
# plt.colorbar(label='Relative Intensity')
# plt.title('Inhale Chromagram')
# plt.tight_layout()

# chromagram = librosa.feature.chroma_stft(y=waveform_e, sr=sr)
# plt.figure(figsize=(10, 4))
# librosa.display.specshow(chromagram, y_axis='chroma', x_axis='time')
# plt.colorbar(label='Relative Intensity')
# plt.title('Exhale Chromagram')
# plt.tight_layout()

In [None]:
# def feature_chromagram(waveform, sample_rate):
#     stft_spectrogram=np.abs(librosa.stft(waveform))
#     chromagram=np.mean(librosa.feature.chroma_stft(S=stft_spectrogram, sr=sample_rate).T,axis=0)
#     return chromagram

# def feature_melspectrogram(waveform, sample_rate):
#     melspectrogram=np.mean(librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=128, fmax=4000).T,axis=0)
#     return melspectrogram

# def feature_mfcc(waveform, sample_rate):
#     mfc_coefficients=np.mean(librosa.feature.mfcc(y=waveform, sr=sample_rate, n_mfcc=40).T, axis=0)
#     return mfc_coefficients

# def get_features(file):
#     with soundfile.SoundFile(file) as audio:
#         waveform = audio.read(dtype="float32")
#         sample_rate = audio.samplerate

#     chromagram = feature_chromagram(waveform, sample_rate) # 12,
#     melspectrogram = feature_melspectrogram(waveform, sample_rate) # 128,
#     mfc_coefficients = feature_mfcc(waveform, sample_rate) # 40,

#     feature_vector = np.hstack((chromagram, melspectrogram, mfc_coefficients)).astype(np.float32)

#     return feature_vector # 12 + 128 + 40 = 180,

In [None]:
# inhale_matrix = get_features(train_data_path+"/steth_20190623_10_35_17_033.wav")
# exhale_matrix = get_features(train_data_path+"/steth_20190623_10_34_54_018.wav")

# Dataset

In [None]:
class DS(Dataset):
    """Dataset over precomputed ``<ID>.npz`` feature files.

    Every archive is expected to hold a ``scalars`` vector plus one array per
    feature map. All keys that are not listed in ``EXCLUDED_KEYS`` are treated
    as feature maps and stacked, in sorted key order, along a new leading axis.

    Args:
        data_frame: Frame with an ``ID`` column and, when ``is_training`` is
            true, a ``Target`` column where ``"E"`` maps to label 1.0 and any
            other value to 0.0.
        feature_dir: Directory containing the ``.npz`` files.
        is_training: If true, items carry a float label; otherwise the file ID.

    Raises:
        ValueError: If ``data_frame`` is empty, because the feature layout is
            probed from its first row.
    """

    # Archive keys that are not feature maps.
    EXCLUDED_KEYS = {'scalars', 'sr', 'hop_length', 'n_fft'}

    def __init__(self, data_frame: pd.DataFrame, feature_dir: str, is_training: bool):
        self.df = data_frame.reset_index(drop=True)
        self.feature_dir = feature_dir
        self.is_training = is_training

        self._detect_features()

    def _detect_features(self):
        """Read the first sample's archive to learn feature names and scalar width."""
        if len(self.df) == 0:
            raise ValueError("DS needs a non-empty data frame to detect the feature layout")

        first_id = self.df.iloc[0]["ID"]
        npz_path = os.path.join(self.feature_dir, first_id + ".npz")

        with np.load(npz_path) as data:
            self.feature_names = sorted(k for k in data.keys() if k not in self.EXCLUDED_KEYS)
            self.n_features = len(self.feature_names)
            self.scalar_dim = data['scalars'].shape[0]

        print(f"#Features: {self.n_features} - {', '.join(self.feature_names)}")
        print(f"#Scalars: {self.scalar_dim}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(features, scalars, label)`` in training mode, else ``(features, scalars, file_id)``."""
        row = self.df.iloc[idx]
        file_id = row["ID"]
        npz_path = os.path.join(self.feature_dir, file_id + ".npz")

        # The context manager closes the archive; without it every fetched
        # sample leaves a file handle open until garbage collection.
        with np.load(npz_path) as data:
            features = np.stack(
                [data[feat_name] for feat_name in self.feature_names], axis=0
            ).astype(np.float32)
            scalars = data['scalars'].astype(np.float32)

        features = torch.from_numpy(features)
        scalars = torch.from_numpy(scalars)

        if self.is_training:
            label = 1.0 if row["Target"] == "E" else 0.0
            return features, scalars, torch.tensor(label, dtype=torch.float32)
        return features, scalars, file_id

In [None]:
def collate_fn(batch):
    """Collate ``(features, scalars, label_or_id)`` samples into one batch.

    Args:
        batch: Sequence of 3-tuples as produced by the dataset.

    Returns:
        ``(features, scalars, labels)`` with all three stacked along a new
        batch axis when the third element is a tensor; otherwise
        ``(features, scalars, ids)`` where ``ids`` is a plain list.
    """
    feature_items = [sample[0] for sample in batch]
    scalar_items = [sample[1] for sample in batch]
    targets = [sample[2] for sample in batch]

    features = torch.stack(feature_items, dim=0)
    # The width of this tensor is what the models expect as num_scalar_features.
    scalars = torch.stack(scalar_items, dim=0)

    # Non-tensor targets (file IDs) are passed through as a list.
    if not isinstance(targets[0], torch.Tensor):
        return features, scalars, targets
    return features, scalars, torch.stack(targets, dim=0)

# Model

In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss on logits, averaged over all elements.

    Each element's BCE-with-logits term is scaled by
    ``alpha * (1 - p_t) ** gamma``, where ``p_t`` is the probability the model
    assigns to the target.

    Args:
        alpha: Constant scale applied to every element.
        gamma: Focusing exponent; larger values down-weight easy examples more.
    """

    def __init__(self, alpha=0.75, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        """Return the mean focal loss for ``logits`` against float ``targets`` in [0, 1]."""
        p = torch.sigmoid(logits)
        # Soft-label form of p_t. For hard 0/1 targets it is exactly p or 1 - p.
        # An equality test against 1 would treat every mixed (MixUp/CutMix)
        # target such as 0.7 as a pure negative.
        pt = p * targets + (1 - p) * (1 - targets)
        weight = self.alpha * (1 - pt) ** self.gamma
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        return (weight * bce).mean()

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gate for 4-D feature maps.

    The input is average-pooled to one value per channel, passed through a
    bias-free two-layer bottleneck ending in a sigmoid, and the resulting
    per-channel factors rescale the input.

    Args:
        channels: Number of input (and output) channels.
        reduction: Divisor for the bottleneck width; the width never drops
            below 16.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        bottleneck = channels // reduction
        if bottleneck < 16:
            bottleneck = 16

        gate_layers = [
            nn.Linear(channels, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channels, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Scale each channel of ``x`` (batch, channels, H, W) by its learned gate."""
        batch, channels, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        # Broadcasting over the two spatial axes applies one factor per channel.
        return x * gate

## VGG

In [None]:
class VGG(nn.Module):
    """VGG-style CNN over stacked feature maps, fused with a small scalar MLP.

    Three stages of three 3x3 Conv-BatchNorm-GELU units are followed by a
    fourth stage that is summed with a 1x1 projection shortcut. The globally
    average-pooled 512-d vector is concatenated with a 64-d embedding of the
    scalar features and classified into a single logit.

    Args:
        in_channels: Number of stacked feature maps in the input.
        num_scalar_features: Width of the scalar feature vector.
        dropout_rate: Dropout probability (halved after the first stage).
    """

    @staticmethod
    def _conv_stack(in_ch, out_ch, last_stride=1):
        """Return three Conv3x3-BN-GELU units as a flat list; only the last one may stride."""
        layers = []
        for unit_in, stride in ((in_ch, 1), (out_ch, 1), (out_ch, last_stride)):
            layers.extend([
                nn.Conv2d(unit_in, out_ch, 3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.GELU(),
            ])
        return layers

    @staticmethod
    def _dense_unit(in_dim, out_dim):
        """Return a bias-free Linear-BN-GELU unit as a flat list."""
        return [
            nn.Linear(in_dim, out_dim, bias=False),
            nn.BatchNorm1d(out_dim),
            nn.GELU(),
        ]

    def __init__(self, in_channels=9, num_scalar_features=39, dropout_rate=0.2):
        super().__init__()

        # Stage 1 downsamples with a strided convolution instead of pooling.
        self.block1 = nn.Sequential(
            *self._conv_stack(in_channels, 64, last_stride=2),
            nn.Dropout2d(dropout_rate * 0.5),
        )
        self.block2 = nn.Sequential(
            *self._conv_stack(64, 128),
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Dropout2d(dropout_rate),
        )
        self.block3 = nn.Sequential(
            *self._conv_stack(128, 256),
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Dropout2d(dropout_rate),
        )
        self.block4_conv = nn.Sequential(
            *self._conv_stack(256, 512),
            nn.Dropout2d(dropout_rate),
        )
        # 1x1 projection so the shortcut matches the 512 channels of block4_conv.
        self.block4_residual = nn.Sequential(
            nn.Conv2d(256, 512, 1, bias=False),
            nn.BatchNorm2d(512),
        )

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.scalar_net = nn.Sequential(
            *self._dense_unit(num_scalar_features, 64),
            nn.Dropout(dropout_rate),
            *self._dense_unit(64, 64),
        )

        self.classifier = nn.Sequential(
            *self._dense_unit(512 + 64, 256),
            nn.Dropout(dropout_rate),
            *self._dense_unit(256, 128),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 1),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal for convs, Xavier-uniform for linears, identity for batch norms."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
                continue
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, features, scalars):
        """Return one logit per sample, shape ``(batch,)``."""
        stage3 = self.block3(self.block2(self.block1(features)))
        shortcut = self.block4_residual(stage3)
        stage4 = self.block4_conv(stage3) + shortcut
        pooled = self.global_pool(stage4)
        pooled = pooled.view(pooled.size(0), -1)
        scalar_embedding = self.scalar_net(scalars)
        fused = torch.cat([pooled, scalar_embedding], dim=1)
        return self.classifier(fused).squeeze(1)

## CNN 8 Layers

In [None]:
class CNN8(nn.Module):
    """Eight-convolution CNN over stacked feature maps, fused with a scalar MLP.

    Every convolution is a 3x3 Conv -> ReLU -> BatchNorm unit. The globally
    average-pooled 256-d vector is concatenated with a 64-d embedding of the
    scalar features and classified into a single logit.

    Args:
        in_channels: Number of stacked feature maps in the input.
        num_scalar_features: Width of the scalar feature vector.
        dropout_rate: Dropout probability used in all dropout layers.
    """

    @staticmethod
    def _conv_unit(in_ch, out_ch):
        """Return a Conv3x3 -> ReLU -> BatchNorm2d unit as a flat list."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
        ]

    @staticmethod
    def _dense_unit(in_dim, out_dim):
        """Return a Linear -> ReLU -> BatchNorm1d unit as a flat list."""
        return [
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.BatchNorm1d(out_dim),
        ]

    def __init__(self, in_channels=9, num_scalar_features=39, dropout_rate=0.3):
        super().__init__()

        conv, dense = self._conv_unit, self._dense_unit

        self.cnn = nn.Sequential(
            *conv(in_channels, 32),
            *conv(32, 64),
            nn.MaxPool2d(2),
            *conv(64, 128),
            *conv(128, 128),
            nn.MaxPool2d(2),
            nn.Dropout2d(dropout_rate),
            *conv(128, 256),
            *conv(256, 256),
            *conv(256, 256),
            *conv(256, 256),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        self.scalar_net = nn.Sequential(
            *dense(num_scalar_features, 64),
            nn.Dropout(dropout_rate),
            *dense(64, 64),
        )

        self.classifier = nn.Sequential(
            *dense(256 + 64, 256),
            nn.Dropout(dropout_rate),
            *dense(256, 128),
            nn.Linear(128, 1),
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal for conv weights, Xavier-uniform for linear weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
            else:
                continue
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, features, scalars):
        """Return one logit per sample, shape ``(batch,)``."""
        pooled = self.cnn(features)
        pooled = pooled.view(pooled.size(0), -1)
        scalar_embedding = self.scalar_net(scalars)
        fused = torch.cat([pooled, scalar_embedding], dim=1)
        return self.classifier(fused).squeeze(1)

# Train and Ensemble

In [None]:
# Audio / feature-extraction settings.
# NOTE(review): none of these constants is referenced by the visible code — presumably
# they mirror the settings used when the .npz features were precomputed; confirm.
SR = 16000
DURATION = 1.0
# Clip length in samples (SR * DURATION, truncated to int).
EXPECTED_LEN = int(SR * DURATION)
N_MELS = 128
N_MFCC = 40
HOP_LENGTH = 256
N_FFT = 512
FMAX = 4500
# NOTE(review): the DataLoaders cell defines its own lowercase batch_size / num_workers,
# and the Train cell its own epochs — confirm whether the three values below are still used.
BATCH_SIZE = 64
NUM_WORKERS = 4
NUM_EPOCHS = 30
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
SWA_START_RATIO = 0.75
SEED_LIST = [0, 1, 2]
# presumably the number of time frames per clip for a centred STFT — TODO confirm
T_FIXED = (EXPECTED_LEN // HOP_LENGTH) + 1

## Train Functions

## Data augmentation

In [None]:
def cutmix_data(features, labels, alpha=1.0, device='cuda'):
    """Apply CutMix to a batch: paste one random rectangle from a shuffled copy.

    A single rectangle (shared by the whole batch) is sampled with an area
    fraction of roughly ``1 - lam`` where ``lam ~ Beta(alpha, alpha)``. After
    clipping the rectangle to the image, ``lam`` is recomputed from the area
    actually pasted and used to mix the labels.

    Args:
        features: Tensor of shape ``(batch, channels, H, W)``.
        labels: Float tensor with one label per sample.
        alpha: Beta distribution parameter.
        device: Kept for backward compatibility and ignored. The permutation is
            placed on ``features.device`` so indexing can never hit a device
            mismatch, and CPU batches work with the default arguments.

    Returns:
        ``(features_mixed, labels_mixed, indices, lam)`` where ``indices`` is
        the permutation used and ``lam`` the final mixing weight as a float.
    """
    batch_size = features.size(0)
    # Drawn on the CPU generator (as before), then moved next to the data.
    indices = torch.randperm(batch_size).to(features.device)

    lam = np.random.beta(alpha, alpha)

    W = features.size(3)
    H = features.size(2)

    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Rectangle centre, then corners clipped to the image bounds.
    cx = int(np.random.randint(W))
    cy = int(np.random.randint(H))

    bbx1 = min(max(cx - cut_w // 2, 0), W)
    bby1 = min(max(cy - cut_h // 2, 0), H)
    bbx2 = min(max(cx + cut_w // 2, 0), W)
    bby2 = min(max(cy + cut_h // 2, 0), H)

    features_mixed = features.clone()
    features_mixed[:, :, bby1:bby2, bbx1:bbx2] = features[indices, :, bby1:bby2, bbx1:bbx2]

    # Clipping may have shrunk the rectangle, so derive lam from the real area.
    lam = float(1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H)))

    labels_mixed = lam * labels + (1 - lam) * labels[indices]

    return features_mixed, labels_mixed, indices, lam


def mixup_data(features, labels, alpha=1.0, device='cuda'):
    """Apply MixUp to a batch: blend every sample with a shuffled partner.

    Args:
        features: Tensor whose first axis is the batch axis.
        labels: Float tensor with one label per sample.
        alpha: Beta distribution parameter; ``lam ~ Beta(alpha, alpha)``.
        device: Kept for backward compatibility and ignored. The permutation is
            placed on ``features.device`` so indexing can never hit a device
            mismatch, and CPU batches work with the default arguments.

    Returns:
        ``(mixed_features, mixed_labels, indices, lam)`` where ``indices`` is
        the permutation used and ``lam`` the mixing weight.
    """
    batch_size = features.size(0)
    # Drawn on the CPU generator (as before), then moved next to the data.
    indices = torch.randperm(batch_size).to(features.device)

    lam = np.random.beta(alpha, alpha)

    mixed_features = lam * features + (1 - lam) * features[indices]
    mixed_labels = lam * labels + (1 - lam) * labels[indices]

    return mixed_features, mixed_labels, indices, lam

## Train Model

In [None]:
def train_model(
    model,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    save_dir: str,
    num_epochs: int = 30,
    base_lr: float = 1e-3,
    weight_decay: float = 1e-4,
    patience: int = 15,
    min_delta: float = 1e-4,
    monitor: str = "val_acc",
    restore_best_weights: bool = True,
    use_cutmix: bool = True,
    use_mixup: bool = True,
    cutmix_prob: float = 0.5,
    mixup_prob: float = 0.5,
    cutmix_alpha: float = 1.0,
    mixup_alpha: float = 0.2,
    warmup_epochs: int = 5,
):
    """Train a binary classifier with AdamW, warm-up + cosine LR, AMP and CutMix/MixUp.

    The learning rate is stepped once per batch: linear warm-up over the first
    5% of all steps, then cosine annealing. Mixing augmentation is off for the
    first ``warmup_epochs`` epochs. Whenever the monitored metric improves by
    more than ``min_delta`` a checkpoint is written to ``save_dir``; training
    stops early after ``patience`` epochs without improvement.

    Args:
        model: Module called as ``model(features, scalars)`` returning logits.
        train_loader: Yields ``(features, scalars, labels)`` batches.
        val_loader: Yields ``(features, scalars, labels)`` batches.
        device: Device used for training and validation.
        save_dir: Directory for checkpoints (created if missing).
        num_epochs: Maximum number of epochs.
        base_lr: Peak learning rate for AdamW.
        weight_decay: AdamW weight decay.
        patience: Epochs without improvement before stopping early.
        min_delta: Minimum improvement of the monitored metric.
        monitor: ``"val_acc"`` to maximise accuracy; any other value
            minimises the validation loss.
        restore_best_weights: If true, the model holds the best epoch's weights
            when this function returns (after early stopping or after the
            final epoch).
        use_cutmix: Enable CutMix.
        use_mixup: Enable MixUp.
        cutmix_prob: A batch gets CutMix when a uniform draw is below this.
        mixup_prob: Otherwise it gets MixUp when the same draw is below
            ``cutmix_prob + mixup_prob``.
        cutmix_alpha: Beta parameter for CutMix.
        mixup_alpha: Beta parameter for MixUp.
        warmup_epochs: Number of initial epochs trained without mixing.

    Returns:
        ``(best_ckpt, best_val_acc)``: path of the best checkpoint (``None`` if
        none was written) and the validation accuracy of that epoch.

    Raises:
        ValueError: If either loader yields no batches in an epoch.
    """
    os.makedirs(save_dir, exist_ok=True)
    model = model.to(device)
    count_parameters(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)

    # The scheduler is stepped per batch, so its horizon is counted in batches.
    total_steps = len(train_loader) * num_epochs
    warmup_steps = int(0.05 * total_steps)
    scheduler = SequentialLR(
        optimizer,
        schedulers=[
            LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps),
            CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=1e-6),
        ],
        milestones=[warmup_steps],
    )

    criterion = nn.BCEWithLogitsLoss()
    use_amp = device.type != "cpu"
    scaler = GradScaler(enabled=use_amp)

    best_val_acc = 0.0
    best_val_loss = float("inf")
    best_ckpt = None
    best_weights = None
    early_stop_counter = 0

    print_start(f"학습 레츠고~ (CutMix: {use_cutmix}, MixUp: {use_mixup})\n")

    for epoch in range(num_epochs):
        model.train()
        train_loss, train_correct, total = 0.0, 0, 0

        use_aug = epoch >= warmup_epochs

        for features, scalars, labels in train_loader:
            # non_blocking=True lets the copy overlap with compute when the
            # loader provides pinned memory.
            features = features.to(device, non_blocking=True)
            scalars = scalars.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # Mixing rebinds `labels` to a new tensor; accuracy is always
            # measured against the unmixed ones.
            original_labels = labels

            if use_aug and (use_cutmix or use_mixup):
                r = np.random.rand()

                if use_cutmix and r < cutmix_prob:
                    # CutMix only touches the feature maps; scalars stay as they are.
                    features, labels, _, _ = cutmix_data(features, labels, cutmix_alpha, features.device)
                elif use_mixup and r < (cutmix_prob + mixup_prob):
                    indices = torch.randperm(features.size(0)).to(device)
                    lam = np.random.beta(mixup_alpha, mixup_alpha)

                    features = lam * features + (1 - lam) * features[indices]
                    scalars = lam * scalars + (1 - lam) * scalars[indices]
                    labels = lam * labels + (1 - lam) * labels[indices]

            optimizer.zero_grad(set_to_none=True)
            with autocast(enabled=use_amp):
                logits = model(features, scalars)
                loss = criterion(logits, labels)

            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            with torch.no_grad():
                preds = (logits > 0.0).float()
                train_correct += (preds == original_labels).sum().item()
                total += original_labels.size(0)

            train_loss += loss.item()

        if total == 0:
            raise ValueError("train_loader produced no batches; check batch_size / drop_last")
        train_acc = train_correct / total
        avg_train_loss = train_loss / len(train_loader)

        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for features, scalars, labels in val_loader:
                features = features.to(device, non_blocking=True)
                scalars = scalars.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                with autocast(enabled=use_amp):
                    logits = model(features, scalars)
                    loss = criterion(logits, labels)

                preds = (logits > 0.0).float()
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)
                val_loss += loss.item()

        if val_total == 0:
            raise ValueError("val_loader produced no batches; check batch_size / drop_last")
        val_acc = val_correct / val_total
        avg_val_loss = val_loss / len(val_loader)

        aug_status = f" [Aug: {'ON' if use_aug else 'OFF'}]" if use_cutmix or use_mixup else ""
        print(
            f"[Epoch {epoch+1:02d}]{aug_status} "
            f"Train Loss: {avg_train_loss:.6f} | Train Acc: {train_acc:.6f} || "
            f"Val Loss: {avg_val_loss:.6f} | Val Acc: {val_acc:.6f}"
        )

        # Express both monitors as "higher is better".
        if monitor == "val_acc":
            metric = val_acc
            best_metric = best_val_acc
        else:
            metric = -avg_val_loss
            best_metric = -best_val_loss

        if metric - best_metric > min_delta:
            best_val_acc = val_acc
            best_val_loss = avg_val_loss
            best_ckpt = os.path.join(save_dir, f"best_epoch{epoch+1:02d}.pth")
            if restore_best_weights:
                # state_dict() returns references to the live tensors, so a real
                # copy is needed or later updates would overwrite the snapshot.
                best_weights = {
                    name: tensor.detach().to("cpu", copy=True)
                    for name, tensor in model.state_dict().items()
                }
            else:
                best_weights = None
            early_stop_counter = 0
            torch.save({
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "val_acc": val_acc,
                "val_loss": avg_val_loss,
                "epoch": epoch + 1,
                "cutmix_used": use_cutmix,
                "mixup_used": use_mixup
            }, best_ckpt)
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print_warning("끝;;;;")
                break

    # Restore after early stopping and after a full run alike.
    if restore_best_weights and best_weights is not None:
        model.load_state_dict(best_weights)

    return best_ckpt, best_val_acc

## Load Model

In [None]:
def load_model(ckpt_path: str, arch: str, num_scalar_features: int, device: torch.device):
    """Build the architecture named ``arch`` and load checkpoint weights into it.

    Args:
        ckpt_path: Checkpoint file; either a dict with a ``model_state_dict``
            entry or a bare state dict.
        arch: ``'vgg'`` or ``'cnn8'``.
        num_scalar_features: Width of the scalar input the model was built for.
        device: Device the model and the checkpoint tensors are placed on.

    Returns:
        The model in eval mode.

    Raises:
        ValueError: If ``arch`` is not a known architecture name.
    """
    if arch not in ('vgg', 'cnn8'):
        raise ValueError(f"이게머임? {arch}")

    model_cls = VGG if arch == 'vgg' else CNN8
    model = model_cls(num_scalar_features=num_scalar_features).to(device)

    ckpt = torch.load(ckpt_path, map_location=device)
    # Accept both the full training checkpoint and a bare state dict.
    state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
    model.load_state_dict(state_dict)
    model.eval()
    return model

## Ensemble Functions

In [None]:
def average_ensemble(
    ckpt_paths: list[str],
    archs:      list[str],
    test_loader: DataLoader,
    device:      torch.device,
    num_scalar_features: int,
):
    """Average the sigmoid probabilities of several checkpoints over a test loader.

    Args:
        ckpt_paths: One checkpoint path per model.
        archs: Architecture name for each checkpoint, in the same order.
        test_loader: Yields ``(features, scalars, ids)`` batches.
        device: Device used for inference.
        num_scalar_features: Width of the scalar input of every model.

    Returns:
        ``(all_ids, final_probs)``: the IDs in loader order and a NumPy array
        with the unweighted mean probability per sample.
    """
    assert len(ckpt_paths) == len(archs), "제대로 하자."

    n_models = len(ckpt_paths)
    models = []
    for position, (path, arch) in enumerate(zip(ckpt_paths, archs), 1):
        models.append(load_model(path, arch, num_scalar_features=num_scalar_features, device=device))
        print_info(f"모델: {position}/{n_models} @ '{path}'")

    collected_ids = []
    batch_means = []

    with torch.no_grad():
        for feats, scals, ids in test_loader:
            feats = feats.to(device, non_blocking=True)
            scals = scals.to(device, non_blocking=True)

            # Row i holds model i's logits for the batch.
            per_model_logits = [net(feats, scals).view(-1) for net in models]
            mean_probs = torch.sigmoid(torch.stack(per_model_logits, dim=0)).mean(dim=0)

            batch_means.append(mean_probs.cpu().numpy())
            collected_ids.extend(ids)

    final_probs = np.concatenate(batch_means, axis=0)
    print_success("완료~")
    return collected_ids, final_probs

In [None]:
def weighted_ensemble(
    ckpt_paths: list[str],
    archs: list[str],
    test_loader: torch.utils.data.DataLoader,
    device: torch.device,
    num_scalar_features: int,
    val_scores: list[float],
    use_softmax_weights: bool = True
):
    """Combine several checkpoints by a score-weighted mean of their probabilities.

    Args:
        ckpt_paths: One checkpoint path per model.
        archs: Architecture name for each checkpoint, in the same order.
        test_loader: Yields ``(features, scalars, ids)`` batches.
        device: Device used for inference.
        num_scalar_features: Width of the scalar input of every model.
        val_scores: One validation score per model; higher means more weight.
        use_softmax_weights: If true, weights are ``softmax(val_scores)``;
            otherwise the scores are divided by their sum.

    Returns:
        ``(all_ids, probs)``: the IDs in loader order and a NumPy array with
        the weighted mean probability per sample (empty if the loader is empty).

    Raises:
        ValueError: If sum-normalisation is requested and the scores do not
            add up to a positive number.
    """
    assert len(ckpt_paths) == len(archs) == len(val_scores), "제대로하자"

    weights = torch.tensor(val_scores, dtype=torch.float32)
    if use_softmax_weights:
        weights = torch.softmax(weights, dim=0)
    else:
        score_sum = weights.sum()
        if not score_sum > 0:
            # Dividing by zero would silently turn every prediction into NaN/inf.
            raise ValueError("val_scores must sum to a positive value when use_softmax_weights=False")
        weights = weights / score_sum

    # The weights never change, so move them to the device once as a column.
    weight_column = weights[:, None].to(device)

    models = [
        load_model(path, arch, num_scalar_features=num_scalar_features, device=device)
        for path, arch in zip(ckpt_paths, archs)
    ]

    all_ids = []
    all_probs = []

    with torch.no_grad():
        for feats, scalars, ids in test_loader:
            feats, scalars = feats.to(device), scalars.to(device)

            logits_stack = torch.stack([model(feats, scalars).view(-1) for model in models])
            probs_stack = torch.sigmoid(logits_stack)

            weighted_avg = (weight_column * probs_stack).sum(dim=0)
            all_probs.append(weighted_avg.cpu().numpy())
            all_ids.extend(ids)

    if not all_probs:
        # np.concatenate rejects an empty list; an empty loader yields no predictions.
        return all_ids, np.empty(0, dtype=np.float32)
    return all_ids, np.concatenate(all_probs)

## DataLoaders

In [None]:
train_df = pd.read_csv(train_csv_path)
test_df = pd.read_csv(test_csv_path)

# Hold out 20% of the labelled data for validation (fixed seed for a stable split).
train_df_split, val_df_split = train_test_split(train_df, test_size=0.20, shuffle=True, random_state=42)

train_dataset = DS(data_frame=train_df_split, feature_dir=precomputed_dir, is_training=True)
val_dataset = DS(data_frame=val_df_split, feature_dir=precomputed_dir, is_training=True)
test_dataset = DS(data_frame=test_df, feature_dir=precomputed_dir, is_training=False)

batch_size = 512
num_workers = 8
prefetch = 2

# Settings shared by all three loaders.
loader_kwargs = dict(
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=prefetch,
    collate_fn=collate_fn,
)

# Only the training loader drops the trailing partial batch.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    **loader_kwargs,
)

# Validation must see every sample: with drop_last=True the last partial batch
# is never scored, and a validation set smaller than one batch yields no
# batches at all.
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size * 2,
    shuffle=False,
    drop_last=False,
    **loader_kwargs,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size * 2,
    shuffle=False,
    **loader_kwargs,
)

## Train

In [None]:
# Width of the scalar feature vector, taken from what the dataset detected in the
# .npz files instead of a hand-maintained constant, so the models' first Linear
# layer always matches the data.
num_scalars = train_dataset.scalar_dim
# Maximum number of epochs for the CNN8 run below.
epochs = 100

In [None]:
# Train CNN8; CutMix/MixUp start after a 4-epoch warm-up, early stopping after 25 stale epochs.
cnn8_model = CNN8(num_scalar_features=num_scalars)
cnn8_ckpt, cnn8_val_acc = train_model(
    model=cnn8_model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    save_dir="./checkpoints/cnn8",
    # optimisation
    num_epochs=epochs,
    base_lr=4e-4,
    weight_decay=1e-4,
    patience=25,
    # augmentation
    warmup_epochs=4,
    use_cutmix=True,
    cutmix_prob=0.6,
    use_mixup=True,
    mixup_prob=0.4,
)
print(f"CNN8 best‐val‐acc = {cnn8_val_acc:.4f}, saved to {cnn8_ckpt}")

In [None]:
# Train VGG with train_model's default learning rate and augmentation settings.
vgg_model = VGG(num_scalar_features=num_scalars)
vgg_ckpt, vgg_val_acc = train_model(
    model=vgg_model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    save_dir="./checkpoints/vgg",
    num_epochs=140,
    patience=55,
)
print(f"VGG best‐val‐acc = {vgg_val_acc:.4f}, saved to {vgg_ckpt}")

## Inference

In [None]:
# ckpt_paths = [cnn8_ckpt, resnet_ckpt, convgru_ckpt, vgg_ckpt, vgg_large_ckpt, convnext_ckpt, resnet_se_stochdepth_ckpt]
# raw_scores = [cnn8_val_acc, resnet_val_acc, convgru_val_acc, vgg_val_acc, vgg_large_val_acc, convnext_val_acc, resnet_se_stochdepth_val_acc]
# archs = ["cnn8", "resnet", "convgru", "vgg", "vgg_large", "convnext", "resnet_se_stochdepth"]

In [None]:
# Parallel lists describing the ensemble members: entry i of each list belongs to the
# same model (checkpoint path, validation accuracy used as its weight score, arch name).
ckpt_paths = [cnn8_ckpt, vgg_ckpt]
raw_scores = [cnn8_val_acc, vgg_val_acc]
archs = ["cnn8", "vgg"]

In [None]:
# Score the test set with the weighted ensemble.
all_ids, avg_probs = weighted_ensemble(
    ckpt_paths=ckpt_paths,
    archs=archs,
    test_loader=test_loader,
    device=device,
    num_scalar_features=num_scalars,
    val_scores=raw_scores,
)

# Probabilities above 0.5 become class 1, written as "E"; everything else is "I".
predictions = (avg_probs > 0.5).astype(int)
final_labels = np.where(predictions == 1, "E", "I").tolist()

submission_df = pd.DataFrame({"ID": all_ids, "Target": final_labels})
submission_df.to_csv("./ensemble_submission.csv", index=False)
print(submission_df.head(10))