In [2]:
import os, random, time
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as TF
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

from tqdm import tqdm
from torch.amp import autocast
from torch.cuda.amp import autocast, GradScaler

# --- Í≤ΩÎ°ú ÏÑ§Ï†ï Î∞è ÌïòÏù¥ÌçºÌååÎùºÎØ∏ÌÑ∞ ---
BASE_DIR     = f"./data/covid19-xray-severity-scoring/"
CSV_PATH     = str(Path(BASE_DIR) / "Brixia.csv")
IMAGE_DIR    = str(Path(BASE_DIR) / "segmented_png")

OUT_DIR      = "./runs_severity_classification"
BEST_PATH    = str(Path(OUT_DIR) / "best_efficientnet_b0_classification.pth")
PHASE2_PATH  = str(Path(OUT_DIR) / "phase2_weighted_classification.pth")
os.makedirs(OUT_DIR, exist_ok=True)

DEVICE       = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED         = 42
IMG_SIZE     = 224
BATCH_SIZE   = 128  # üîÑ Î∂ÑÎ•òÎäî Î∞∞Ïπò ÌÅ¨Í∏∞ Ï§ÑÏûÑ
NUM_CLASSES  = 4   # ‚úÖ NEW: 0, 1, 2, 3
EPOCHS_PHASE1 = 30
EPOCHS_PHASE2 = 50
LR           = 1e-4
WEIGHT_DECAY = 5e-4
AMP          = True
EARLY_STOP_ACC = 0.75  # üîÑ MAE ‚Üí Accuracy
DROP_RATIO   = 0.3
AUG_RATIO    = 0.5
MIXUP_ALPHA  = 0.2
LABEL_SMOOTHING = 0.1  # ‚úÖ NEW: Label smoothing

# --- ÏãúÎìú Í≥†Ï†ï ---
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
set_seed(SEED)

def make_transform_with_label(train: bool, img_size: int = IMG_SIZE, aug_ratio=AUG_RATIO):
    """Brixia ScoreÏùò Ï¢åÏö∞ Íµ¨Ï°∞Î•º Í≥†Î†§Ìïú transform"""
    def _tfm(img: Image.Image, label: torch.Tensor = None):
        img = img.convert('RGB')
        img = TF.resize(
            img, 
            [img_size, img_size], 
            interpolation=TF.InterpolationMode.BILINEAR,
            antialias=True
        )
        
        if train:
            # 1. Horizontal Flip (ÎùºÎ≤®ÎèÑ Ìï®Íªò flip)
            if random.random() < aug_ratio:
                img = TF.hflip(img)
                if label is not None:
                    label = label[[3, 4, 5, 0, 1, 2]]
            
            # 2. ÏïΩÌïú ÌöåÏ†Ñ (¬±5ÎèÑ)
            if random.random() < aug_ratio:
                angle = float(torch.empty(1).uniform_(-5, 5))
                img = TF.rotate(
                    img, 
                    angle, 
                    interpolation=TF.InterpolationMode.BILINEAR,
                    fill=0
                )
            
            # 3. ÏïΩÌïú Translation
            if random.random() < aug_ratio:
                max_dx = 0.05 * img_size
                max_dy = 0.05 * img_size
                translations = (
                    float(torch.empty(1).uniform_(-max_dx, max_dx)),
                    float(torch.empty(1).uniform_(-max_dy, max_dy))
                )
                img = TF.affine(
                    img,
                    angle=0,
                    translate=translations,
                    scale=1.0,
                    shear=0,
                    interpolation=TF.InterpolationMode.BILINEAR,
                    fill=0
                )
            
            # 4. Brightness & Contrast
            if random.random() < aug_ratio:
                brightness_factor = float(torch.empty(1).uniform_(0.9, 1.1))
                img = TF.adjust_brightness(img, brightness_factor)
            
            if random.random() < aug_ratio:
                contrast_factor = float(torch.empty(1).uniform_(0.9, 1.1))
                img = TF.adjust_contrast(img, contrast_factor)
            
            # 5. Gamma Correction
            if random.random() < 0.3:
                gamma = float(torch.empty(1).uniform_(0.9, 1.1))
                img = TF.adjust_gamma(img, gamma)
        
        # Tensor Î≥ÄÌôò
        img = TF.to_tensor(img)
        
        # Gaussian Noise (train only)
        if train and random.random() < 0.2:
            noise = torch.randn_like(img) * 0.01
            img = img + noise
            img = torch.clamp(img, 0, 1)
        
        # Ï†ïÍ∑úÌôî
        img = TF.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        if label is not None:
            return img, label
        return img
    
    return _tfm

def load_and_split_brixia(csv_path, val_ratio=0.2, seed=SEED):
    df = pd.read_csv(csv_path, dtype={'BrixiaScore': str})
    df = df.dropna(subset=['BrixiaScore'])
    df = df[df['BrixiaScore'] != 'nan']
    df = df[df['BrixiaScore'].str.len() == 6].copy()
    
    print(f"Ïú†Ìö®Ìïú Îç∞Ïù¥ÌÑ∞: {len(df)}Í∞ú")
    
    if 'ConsensusTestset' in df.columns:
        test_df = df[df['ConsensusTestset'] == 1].copy()
        train_val_df = df[df['ConsensusTestset'] == 0].copy()
    else:
        test_df = pd.DataFrame()
        train_val_df = df.copy()
    
    gss = GroupShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
    train_idx, val_idx = next(gss.split(
        train_val_df, 
        groups=train_val_df['StudyId']
    ))
    
    tr_df = train_val_df.iloc[train_idx].copy()
    val_df = train_val_df.iloc[val_idx].copy()
    
    print(f"Train: {len(tr_df)}, Val: {len(val_df)}, Test: {len(test_df)}")
    validate_split(tr_df, val_df, test_df)
    
    return tr_df, val_df, test_df

def validate_split(tr_df, val_df, tt_df):
    train_studies = set(tr_df['StudyId'])
    val_studies = set(val_df['StudyId'])
    test_studies = set(tt_df['StudyId']) if len(tt_df) > 0 else set()
    
    assert len(train_studies & val_studies) == 0, "Train-Val Í∞Ñ StudyId Ï§ëÎ≥µ!"
    assert len(train_studies & test_studies) == 0, "Train-Test Í∞Ñ StudyId Ï§ëÎ≥µ!"
    assert len(val_studies & test_studies) == 0, "Val-Test Í∞Ñ StudyId Ï§ëÎ≥µ!"
    
    for name, data in [('Train', tr_df), ('Val', val_df), ('Test', tt_df)]:
        if len(data) > 0:
            scores = data['BrixiaScore'].apply(lambda x: sum(int(c) for c in x))
            print(f"{name} - Mean: {scores.mean():.2f}, Std: {scores.std():.2f}")
    
    return True

# ============================================================
# Dataset
# ============================================================
class BrixiaDataset(Dataset):
    def __init__(self, dataframe, img_dir, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.img_col = "Filename"
        self.label_col = "BrixiaScore"
        self._validate_data()
    
    def _validate_data(self):
        assert self.img_col in self.df.columns
        assert self.label_col in self.df.columns
        
        invalid_scores = self.df[self.df[self.label_col].str.len() != 6]
        if len(invalid_scores) > 0:
            print(f"‚ö†Ô∏è Í≤ΩÍ≥†: {len(invalid_scores)}Í∞úÏùò ÏûòÎ™ªÎêú BrixiaScore Î∞úÍ≤¨")
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        img_name_from_csv = row[self.img_col]
        img_name = img_name_from_csv.replace('.dcm', '.png')
        img_path = os.path.join(self.img_dir, img_name)
        
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"‚ùå Ïù¥ÎØ∏ÏßÄ Î°úÎìú Ïò§Î•ò: {img_path}")
            raise
        
        scores_str = row[self.label_col]
        scores_list = [int(c) for c in scores_str]
        labels = torch.tensor(scores_list, dtype=torch.long)  # üîÑ longÏúºÎ°ú Î≥ÄÍ≤Ω
        
        if self.transform:
            image, labels = self.transform(image, labels)
        else:
            image = TF.to_tensor(image)
            image = TF.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        return image, labels

def create_dataloaders(tr_df, val_df, tt_df, img_dir, 
                       batch_size=32, img_size=224, num_workers=4):
    train_transform = make_transform_with_label(train=True, img_size=img_size)
    val_transform = make_transform_with_label(train=False, img_size=img_size)
    
    tr_ds = BrixiaDataset(tr_df, img_dir, transform=train_transform)
    val_ds = BrixiaDataset(val_df, img_dir, transform=val_transform)
    tt_ds = BrixiaDataset(tt_df, img_dir, transform=val_transform)
    
    tr_loader = DataLoader(
        tr_ds, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=num_workers, 
        pin_memory=torch.cuda.is_available(),
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_ds, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=num_workers, 
        pin_memory=torch.cuda.is_available()
    )
    
    tt_loader = DataLoader(
        tt_ds, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=num_workers, 
        pin_memory=torch.cuda.is_available()
    )
    
    print("‚úÖ DataLoader Ï§ÄÎπÑ ÏôÑÎ£å")
    print(f"   Train: {len(tr_ds)} samples, {len(tr_loader)} batches")
    print(f"   Val:   {len(val_ds)} samples, {len(val_loader)} batches")
    print(f"   Test:  {len(tt_ds)} samples, {len(tt_loader)} batches")
    
    return tr_loader, val_loader, tt_loader

# ============================================================
# Loss Function (Î∂ÑÎ•òÏö©ÏúºÎ°ú Î≥ÄÍ≤Ω)
# ============================================================
def calculate_class_weights(labels, num_classes=4, method='sqrt_inverse'):
    """ÌÅ¥ÎûòÏä§ Î∂àÍ∑†Ìòï Ìï¥Í≤∞ÏùÑ ÏúÑÌïú Í∞ÄÏ§ëÏπò Í≥ÑÏÇ∞"""
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    
    labels_flat = labels.flatten()
    counts = np.bincount(labels_flat.astype(int), minlength=num_classes)
    
    if method == 'sqrt_inverse':
        weights = 1.0 / (np.sqrt(counts) + 1e-6)
    elif method == 'inverse':
        weights = 1.0 / (counts + 1e-6)
    else:
        total = len(labels_flat)
        weights = total / (num_classes * (counts + 1e-6))
    
    weights = weights / weights.mean()
    return torch.FloatTensor(weights)

def print_class_distribution(labels):
    """ÌÅ¥ÎûòÏä§ Î∂ÑÌè¨ ÏãúÍ∞ÅÌôî"""
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    
    print("=" * 60)
    print("Class Distribution Analysis")
    print("=" * 60)
    
    region_names = ['A', 'B', 'C', 'D', 'E', 'F']
    for idx, name in enumerate(region_names):
        region_labels = labels[:, idx]
        counts = np.bincount(region_labels.astype(int), minlength=4)
        total = counts.sum()
        
        print(f"\n{name}:")
        for cls in range(4):
            pct = 100 * counts[cls] / total if total > 0 else 0
            bar = '‚ñà' * int(pct / 2)
            print(f"  Class {cls}: {counts[cls]:4d} ({pct:5.1f}%) {bar}")
    
    print("=" * 60)

class AdaptiveClassificationLoss(nn.Module):
    """
    ‚úÖ NEW: Îã§Ï§ë ÌÅ¥ÎûòÏä§ Î∂ÑÎ•òÎ•º ÏúÑÌïú ÏÜêÏã§Ìï®Ïàò
    - CrossEntropyLoss Í∏∞Î∞ò
    - ÌÅ¥ÎûòÏä§ Í∞ÄÏ§ëÏπò Ï†ÅÏö©
    - Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò Ï†ÅÏö©
    - Label Smoothing ÏßÄÏõê
    """
    
    def __init__(self, train_labels, num_classes=4, use_class_weights=True, 
                 part_weights=None, label_smoothing=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.label_smoothing = label_smoothing
        
        if isinstance(train_labels, torch.Tensor):
            train_labels_np = train_labels.cpu().numpy()
        else:
            train_labels_np = train_labels
        
        print_class_distribution(train_labels_np)
        
        # ÌÅ¥ÎûòÏä§ Í∞ÄÏ§ëÏπò
        if use_class_weights:
            class_weights = calculate_class_weights(train_labels_np, num_classes=num_classes)
            self.register_buffer('class_weights', class_weights)
            print(f"‚úÖ Class weights: {class_weights.numpy()}")
        else:
            self.class_weights = None
        
        # Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò
        if part_weights is None:
            self.part_weights = torch.ones(6)
        else:
            self.part_weights = torch.tensor(part_weights, dtype=torch.float32)
        self.register_buffer('part_weights_buf', self.part_weights)
        print(f"‚úÖ Part weights: {self.part_weights.numpy()}")

    def forward(self, pred_logits, target, use_mixup=False):
        """
        pred_logits: [B, 6, 4] - Í∞Å Î∂ÄÏúÑÎ≥Ñ 4Í∞ú ÌÅ¥ÎûòÏä§Ïóê ÎåÄÌïú logits
        target: [B, 6] - Í∞Å Î∂ÄÏúÑÎ≥Ñ ÌÅ¥ÎûòÏä§ Î†àÏù¥Î∏î (0~3)
        """
        B, num_regions, num_classes = pred_logits.shape
        
        # Reshape for loss computation
        pred_logits_flat = pred_logits.view(B * num_regions, num_classes)  # [B*6, 4]
        target_flat = target.view(B * num_regions)  # [B*6]
        
        # CrossEntropyLoss with label smoothing
        if self.class_weights is not None and not use_mixup:
            criterion = nn.CrossEntropyLoss(
                weight=self.class_weights.to(pred_logits.device),
                label_smoothing=self.label_smoothing,
                reduction='none'
            )
        else:
            criterion = nn.CrossEntropyLoss(
                label_smoothing=self.label_smoothing,
                reduction='none'
            )
        
        loss = criterion(pred_logits_flat, target_flat)  # [B*6]
        loss = loss.view(B, num_regions)  # [B, 6]
        
        # Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò Ï†ÅÏö©
        part_weights = self.part_weights_buf.to(pred_logits.device)
        loss = loss * part_weights.unsqueeze(0)  # [B, 6]
        
        # Î∂ÄÏúÑÎ≥Ñ ÌèâÍ∑† ÏÜêÏã§ (Î™®ÎãàÌÑ∞ÎßÅÏö©)
        part_losses = loss.mean(dim=0)  # [6]
        
        return loss.mean(), part_losses

# ============================================================
# Mixup (Î∂ÑÎ•òÏö©)
# ============================================================
def mixup_data_classification(x, y, alpha=MIXUP_ALPHA):
    """
    ‚úÖ NEW: Î∂ÑÎ•òÏö© Mixup
    yÎäî one-hotÏúºÎ°ú Î≥ÄÌôò ÌõÑ mixup
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    
    # One-hot encoding
    y_onehot = F.one_hot(y, num_classes=NUM_CLASSES).float()  # [B, 6, 4]
    y_onehot_shuffled = y_onehot[index]
    
    mixed_y = lam * y_onehot + (1 - lam) * y_onehot_shuffled  # [B, 6, 4]
    
    return mixed_x, mixed_y, lam

def mixup_criterion(criterion, pred_logits, y_mixed, lam):
    """MixupÏùÑ ÏúÑÌïú ÏÜêÏã§ Í≥ÑÏÇ∞"""
    # y_mixedÎäî [B, 6, 4] soft labels
    # pred_logitsÎäî [B, 6, 4]
    
    B, num_regions, num_classes = pred_logits.shape
    
    # Reshape
    pred_flat = pred_logits.view(B * num_regions, num_classes)  # [B*6, 4]
    target_flat = y_mixed.view(B * num_regions, num_classes)  # [B*6, 4]
    
    # Soft target loss
    log_probs = F.log_softmax(pred_flat, dim=1)
    loss = -(target_flat * log_probs).sum(dim=1)  # [B*6]
    loss = loss.view(B, num_regions)  # [B, 6]
    
    # Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò Ï†ÅÏö© (criterionÏóêÏÑú Í∞ÄÏ†∏Ïò¥)
    if hasattr(criterion, 'part_weights_buf'):
        part_weights = criterion.part_weights_buf.to(pred_logits.device)
        loss = loss * part_weights.unsqueeze(0)
    
    part_losses = loss.mean(dim=0)
    return loss.mean(), part_losses

# ============================================================
# Metrics (Î∂ÑÎ•òÏö©)
# ============================================================
@torch.no_grad()
def calculate_classification_metrics(pred_logits, labels):
    """
    ‚úÖ NEW: Î∂ÑÎ•ò ÏßÄÌëú Í≥ÑÏÇ∞
    pred_logits: [B, 6, 4]
    labels: [B, 6]
    """
    # ÏòàÏ∏° ÌÅ¥ÎûòÏä§
    preds = pred_logits.argmax(dim=-1)  # [B, 6]
    
    # Exact match accuracy
    exact_acc = (preds == labels).float().mean().item()
    
    # Off-by-1 accuracy (Ïù∏Ï†ë ÌÅ¥ÎûòÏä§ ÌóàÏö©)
    off_by_1 = (torch.abs(preds - labels) <= 1).float().mean().item()
    
    # Per-region accuracy
    region_acc = (preds == labels).float().mean(dim=0)  # [6]
    
    # MAE (Ï∞∏Í≥†Ïö©)
    mae = torch.abs(preds.float() - labels.float()).mean().item()
    
    return exact_acc, off_by_1, mae, region_acc

# ============================================================
# Model (Î∂ÑÎ•òÏö©ÏúºÎ°ú Î≥ÄÍ≤Ω)
# ============================================================
class EfficientNetB0Classification(nn.Module):
    """
    ‚úÖ NEW: Î∂ÑÎ•òÎ•º ÏúÑÌïú Î™®Îç∏
    Í∞Å Î∂ÄÏúÑ(A~F)ÎßàÎã§ 4Í∞ú ÌÅ¥ÎûòÏä§(0~3) ÏòàÏ∏°
    """
    def __init__(self, pretrained=True, drop=0.3, num_regions=6, num_classes=4):
        super().__init__()
        
        weights = EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = efficientnet_b0(weights=weights)
        
        self.features = backbone.features
        in_feat = 1280
        
        # Positional embedding
        self.pos_embed = nn.Parameter(torch.randn(1, 49, in_feat))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        
        # Region queries
        self.region_queries = nn.Parameter(torch.randn(num_regions, in_feat))
        nn.init.xavier_uniform_(self.region_queries)
        
        # Cross attention
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=in_feat,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        
        self.norm1 = nn.LayerNorm(in_feat)
        self.norm2 = nn.LayerNorm(in_feat)
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(in_feat, in_feat * 2),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(in_feat * 2, in_feat),
            nn.Dropout(drop)
        )
        
        # üîÑ Classification heads: Í∞Å Î∂ÄÏúÑÎßàÎã§ 4Í∞ú ÌÅ¥ÎûòÏä§ ÏòàÏ∏°
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_feat, 256),
                nn.LayerNorm(256),
                nn.GELU(),
                nn.Dropout(drop),
                nn.Linear(256, 128),
                nn.LayerNorm(128),
                nn.GELU(),
                nn.Dropout(drop),
                nn.Linear(128, num_classes)  # üîÑ 4Í∞ú ÌÅ¥ÎûòÏä§ Ï∂úÎ†•
            ) for _ in range(num_regions)
        ])
    
    def forward(self, x):
        B = x.size(0)
        
        # Feature extraction
        feat = self.features(x)  # [B, 1280, 7, 7]
        feat = feat.flatten(2).transpose(1, 2)  # [B, 49, 1280]
        feat = feat + self.pos_embed
        
        # Region queries
        queries = self.region_queries.unsqueeze(0).expand(B, -1, -1)  # [B, 6, 1280]
        
        # Cross attention
        attn_out, _ = self.cross_attention(
            query=queries,
            key=feat,
            value=feat
        )
        
        attn_out = self.norm1(attn_out + queries)
        ffn_out = self.ffn(attn_out)
        attn_out = self.norm2(attn_out + ffn_out)  # [B, 6, 1280]
        
        # üîÑ Í∞Å Î∂ÄÏúÑÎ≥Ñ classification head ÌÜµÍ≥º
        outputs = []
        for i in range(len(self.heads)):
            region_feat = attn_out[:, i, :]  # [B, 1280]
            logits = self.heads[i](region_feat)  # [B, 4]
            outputs.append(logits)
        
        out = torch.stack(outputs, dim=1)  # [B, 6, 4]
        return out

# ============================================================
# Training Functions (Î∂ÑÎ•òÏö© ÏàòÏ†ï)
# ============================================================
def train_epoch(model, tr_loader, criterion, optimizer, scaler, device, 
                amp=True, use_mixup=True):
    model.train()
    run_loss = run_acc = run_off1 = run_mae = n = 0
    part_losses_sum = torch.zeros(6)
    
    pbar = tqdm(tr_loader, desc="Train", leave=False)
    
    for imgs, labels in pbar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)  # [B, 6]
        
        # Mixup Ï†ÅÏö©
        is_mixup = use_mixup and (random.random() < 0.5)
        
        optimizer.zero_grad(set_to_none=True)
        
        with autocast(enabled=amp):
            pred_logits = model(imgs)  # [B, 6, 4]
            
            if is_mixup:
                imgs_mixed, labels_mixed, lam = mixup_data_classification(imgs, labels)
                pred_logits = model(imgs_mixed)
                loss, part_losses = mixup_criterion(criterion, pred_logits, labels_mixed, lam)
            else:
                loss, part_losses = criterion(pred_logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Metrics (ÏõêÎ≥∏ labelsÎ°ú Í≥ÑÏÇ∞)
        exact_acc, off_by_1, mae, _ = calculate_classification_metrics(pred_logits.detach(), labels)
        
        bs = imgs.size(0)
        run_loss += loss.item() * bs
        run_acc += exact_acc * bs
        run_off1 += off_by_1 * bs
        run_mae += mae * bs
        part_losses_sum += part_losses.cpu() * bs
        n += bs
        
        pbar.set_postfix(
            loss=f"{run_loss/n:.4f}",
            acc=f"{run_acc/n:.4f}",
            mae=f"{run_mae/n:.4f}"
        )
    
    part_losses_avg = part_losses_sum / n
    
    return run_loss/n, run_acc/n, run_off1/n, run_mae/n, part_losses_avg

@torch.no_grad()
def evaluate(model, val_loader, criterion, device, split='val'):
    model.eval()
    run_loss = run_acc = run_off1 = run_mae = n = 0
    part_losses_sum = torch.zeros(6)
    region_acc_sum = torch.zeros(6)
    
    pbar = tqdm(val_loader, desc=f"{split.capitalize()}", leave=False)
    
    for imgs, labels in pbar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        
        pred_logits = model(imgs)  # [B, 6, 4]
        loss, part_losses = criterion(pred_logits, labels)
        
        exact_acc, off_by_1, mae, region_acc = calculate_classification_metrics(pred_logits, labels)
        
        bs = imgs.size(0)
        run_loss += loss.item() * bs
        run_acc += exact_acc * bs
        run_off1 += off_by_1 * bs
        run_mae += mae * bs
        part_losses_sum += part_losses.cpu() * bs
        region_acc_sum += region_acc.cpu() * bs
        n += bs
        
        pbar.set_postfix(
            loss=f"{run_loss/n:.4f}",
            acc=f"{run_acc/n:.4f}"
        )
    
    avg_loss = run_loss/n
    avg_acc = run_acc/n
    avg_off1 = run_off1/n
    avg_mae = run_mae/n
    part_losses_avg = part_losses_sum / n
    region_acc_avg = region_acc_sum / n
    
    print(f"[{split}] loss:{avg_loss:.4f} acc:{avg_acc:.4f} "
          f"off1:{avg_off1:.4f} mae:{avg_mae:.4f}")
    print(f"  Region Acc: {region_acc_avg.numpy().round(3)}")
    
    return avg_loss, avg_acc, avg_off1, avg_mae, part_losses_avg, region_acc_avg

def get_lrs(optimizer):
    return [pg['lr'] for pg in optimizer.param_groups]

# ============================================================
# Main Function
# ============================================================
def main():
    print("\n" + "="*70)
    print("üöÄ Two-Phase Classification Training")
    print("="*70)
    
    # Îç∞Ïù¥ÌÑ∞ Ï§ÄÎπÑ
    print("\nüìÇ Loading data...")
    tr_df, val_df, tt_df = load_and_split_brixia(CSV_PATH)
    
    print("\nüì¶ Creating DataLoaders...")
    tr_loader, val_loader, tt_loader = create_dataloaders(
        tr_df, val_df, tt_df, img_dir=IMAGE_DIR, 
        batch_size=BATCH_SIZE, img_size=IMG_SIZE, num_workers=4
    )
    
    # Train labels Ï∂îÏ∂ú
    train_labels = torch.cat([labels for _, labels in tr_loader], dim=0)
    
    # ========================================
    # Phase 1: Í∑†Îì± Í∞ÄÏ§ëÏπò ÌïôÏäµ
    # ========================================
    print("\n" + "="*70)
    print("üìç PHASE 1: Training with Uniform Weights")
    print("="*70)
    
    criterion_phase1 = AdaptiveClassificationLoss(
        train_labels, 
        num_classes=NUM_CLASSES,
        use_class_weights=True,
        part_weights=None,
        label_smoothing=LABEL_SMOOTHING
    )
    
    model = EfficientNetB0Classification(
        pretrained=True, 
        drop=DROP_RATIO,
        num_regions=6,
        num_classes=NUM_CLASSES
    ).to(DEVICE)
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-6
    )
    scaler = GradScaler(enabled=AMP)
    
    best_acc_phase1 = 0.0
    patience_counter = 0
    max_patience = 10
    
    # Phase 1 ÌïôÏäµ
    phase1_part_losses = []
    
    for ep in range(1, EPOCHS_PHASE1 + 1):
        t0 = time.time()
        
        tr_loss, tr_acc, tr_off1, tr_mae, tr_part_losses = train_epoch(
            model, tr_loader, criterion_phase1, optimizer, scaler, DEVICE, AMP
        )
        
        val_loss, val_acc, val_off1, val_mae, val_part_losses, val_region_acc = evaluate(
            model, val_loader, criterion_phase1, DEVICE, split='val'
        )
        
        scheduler.step(val_acc)
        
        # Î∂ÄÏúÑÎ≥Ñ ÏÜêÏã§ ÎàÑÏ†Å
        phase1_part_losses.append(val_part_losses.numpy())
        
        # Save best
        if val_acc > best_acc_phase1:
            best_acc_phase1 = val_acc
            patience_counter = 0
            torch.save({
                'epoch': ep,
                'model_state_dict': model.state_dict(),
                'val_acc': best_acc_phase1,
                'val_mae': val_mae,
            }, BEST_PATH)
            print(f"‚úÖ Phase 1 Best (Acc={best_acc_phase1:.4f}, MAE={val_mae:.4f})")
        else:
            patience_counter += 1
        
        if patience_counter >= max_patience:
            print(f"\n‚èπÔ∏è Phase 1 Early stopping at epoch {ep}")
            break
        
        elapsed = time.time() - t0
        print(f"\n[Phase1 Epoch {ep:02d}/{EPOCHS_PHASE1}]")
        print(f"  Train - acc:{tr_acc:.4f} off1:{tr_off1:.4f} mae:{tr_mae:.4f}")
        print(f"  Val   - acc:{val_acc:.4f} off1:{val_off1:.4f} mae:{val_mae:.4f}")
        print(f"  Part losses (Val): {val_part_losses.numpy().round(3)}")
        print(f"  LR:{get_lrs(optimizer)[0]:.2e} | {elapsed:.1f}s | Pat:{patience_counter}/{max_patience}")
        print("-" * 70)
    
    # Phase 1 Î∂ÄÏúÑÎ≥Ñ ÌèâÍ∑† ÏÜêÏã§ Í≥ÑÏÇ∞
    avg_part_losses = np.mean(phase1_part_losses, axis=0)
    print("\n" + "="*70)
    print("üìä Phase 1 Average Part Losses:")
    region_names = ['A', 'B', 'C', 'D', 'E', 'F']
    for i, name in enumerate(region_names):
        print(f"   {name}: {avg_part_losses[i]:.4f}")
    print("="*70)
    
    # ========================================
    # Phase 2: ÎÇúÏù¥ÎèÑ Í∏∞Î∞ò Í∞ÄÏ§ëÏπò
    # ========================================
    print("\n" + "="*70)
    print("üìç PHASE 2: Training with Difficulty-Based Weights")
    print("="*70)
    
    # ÎÇúÏù¥ÎèÑ Í∏∞Î∞ò Í∞ÄÏ§ëÏπò Í≥ÑÏÇ∞
    normalized_losses = avg_part_losses / avg_part_losses.mean()
    part_weights_phase2 = normalized_losses ** 0.5
    part_weights_phase2 = part_weights_phase2.tolist()
    
    print(f"\nüéØ Calculated Part Weights:")
    for i, name in enumerate(region_names):
        print(f"   {name}: {part_weights_phase2[i]:.3f} (loss: {avg_part_losses[i]:.4f})")
    
    # Phase 2 criterion
    criterion_phase2 = AdaptiveClassificationLoss(
        train_labels,
        num_classes=NUM_CLASSES,
        use_class_weights=True,
        part_weights=part_weights_phase2,
        label_smoothing=LABEL_SMOOTHING
    )
    
    # Phase 1 best model Î°úÎìú
    model.load_state_dict(torch.load(BEST_PATH)['model_state_dict'])
    
    # ÏÉàÎ°úÏö¥ optimizer & scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR*0.5, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-7
    )
    scaler = GradScaler(enabled=AMP)
    
    best_acc_phase2 = 0.0
    patience_counter = 0
    
    # Phase 2 ÌïôÏäµ
    for ep in range(1, EPOCHS_PHASE2 + 1):
        t0 = time.time()
        
        tr_loss, tr_acc, tr_off1, tr_mae, tr_part_losses = train_epoch(
            model, tr_loader, criterion_phase2, optimizer, scaler, DEVICE, AMP
        )
        
        val_loss, val_acc, val_off1, val_mae, val_part_losses, val_region_acc = evaluate(
            model, val_loader, criterion_phase2, DEVICE, split='val'
        )
        
        scheduler.step(val_acc)
        
        # Early stopping check
        if val_acc >= EARLY_STOP_ACC:
            print(f"\n‚úÖ Target Acc {EARLY_STOP_ACC} reached at epoch {ep}!")
            torch.save({
                'epoch': ep,
                'model_state_dict': model.state_dict(),
                'val_acc': val_acc,
                'val_mae': val_mae,
                'part_weights': part_weights_phase2
            }, PHASE2_PATH)
            break
        
        # Save best
        if val_acc > best_acc_phase2:
            best_acc_phase2 = val_acc
            patience_counter = 0
            torch.save({
                'epoch': ep,
                'model_state_dict': model.state_dict(),
                'val_acc': best_acc_phase2,
                'val_mae': val_mae,
                'part_weights': part_weights_phase2
            }, PHASE2_PATH)
            print(f"‚úÖ Phase 2 Best (Acc={best_acc_phase2:.4f}, MAE={val_mae:.4f})")
        else:
            patience_counter += 1
        
        if patience_counter >= max_patience:
            print(f"\n‚èπÔ∏è Phase 2 Early stopping at epoch {ep}")
            break
        
        elapsed = time.time() - t0
        print(f"\n[Phase2 Epoch {ep:02d}/{EPOCHS_PHASE2}]")
        print(f"  Train - acc:{tr_acc:.4f} off1:{tr_off1:.4f} mae:{tr_mae:.4f}")
        print(f"  Val   - acc:{val_acc:.4f} off1:{val_off1:.4f} mae:{val_mae:.4f}")
        print(f"  Part losses (Val): {val_part_losses.numpy().round(3)}")
        print(f"  LR:{get_lrs(optimizer)[0]:.2e} | {elapsed:.1f}s | Pat:{patience_counter}/{max_patience}")
        print("-" * 70)
    
    # ========================================
    # Test Evaluation
    # ========================================
    print("\n" + "="*70)
    print("üéâ Training Finished!")
    print("="*70)
    
    if len(tt_loader) > 0:
        print("\nüìä Test evaluation with Phase 2 model...")
        model.load_state_dict(torch.load(PHASE2_PATH)['model_state_dict'])
        tt_loss, tt_acc, tt_off1, tt_mae, tt_part_losses, tt_region_acc = evaluate(
            model, tt_loader, criterion_phase2, DEVICE, split='test'
        )
        print(f"\nüèÜ Test Results:")
        print(f"   Accuracy: {tt_acc:.4f}")
        print(f"   Off-by-1: {tt_off1:.4f}")
        print(f"   MAE: {tt_mae:.4f}")
        print(f"   Region Acc: {tt_region_acc.numpy().round(3)}")
        print(f"   Part losses: {tt_part_losses.numpy().round(3)}")
    
    print(f"\nüíæ Phase 1 model saved: {BEST_PATH}")
    print(f"üíæ Phase 2 model saved: {PHASE2_PATH}")
    
    print("\nüìà Summary:")
    print(f"   Phase 1 Best Acc: {best_acc_phase1:.4f}")
    print(f"   Phase 2 Best Acc: {best_acc_phase2:.4f}")
    print(f"   Improvement: {(best_acc_phase2 - best_acc_phase1):.4f}")

if __name__ == "__main__":
    main()


üöÄ Two-Phase Classification Training

üìÇ Loading data...
Ïú†Ìö®Ìïú Îç∞Ïù¥ÌÑ∞: 4695Í∞ú
Train: 3637, Val: 912, Test: 146
Train - Mean: 8.31, Std: 4.26
Val - Mean: 8.35, Std: 4.15
Test - Mean: 7.78, Std: 4.20

üì¶ Creating DataLoaders...
‚úÖ DataLoader Ï§ÄÎπÑ ÏôÑÎ£å
   Train: 3637 samples, 28 batches
   Val:   912 samples, 8 batches
   Test:  146 samples, 2 batches

üìç PHASE 1: Training with Uniform Weights
Class Distribution Analysis

A:
  Class 0: 1791 ( 50.0%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 1: 1122 ( 31.3%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 2:  457 ( 12.8%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 3:  214 (  6.0%) ‚ñà‚ñà

B:
  Class 0:  740 ( 20.6%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 1:  912 ( 25.4%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 2: 1159 ( 32.3%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 3:  773 ( 21.6%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà

C:
  Class 0:  392 ( 10.9%) ‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 1:

                                                                                           

[val] loss:1.2036 acc:0.4492 off1:0.8598 mae:0.7092
  Region Acc: [0.49  0.394 0.404 0.593 0.404 0.411]
‚úÖ Phase 1 Best (Acc=0.4492, MAE=0.7092)

[Phase1 Epoch 01/30]
  Train - acc:0.3878 off1:0.8030 mae:0.8513
  Val   - acc:0.4492 off1:0.8598 mae:0.7092
  Part losses (Val): [1.228 1.262 1.242 1.015 1.245 1.231]
  LR:1.00e-04 | 18.5s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.1401 acc:0.4960 off1:0.9077 mae:0.6043
  Region Acc: [0.537 0.461 0.458 0.603 0.459 0.457]
‚úÖ Phase 1 Best (Acc=0.4960, MAE=0.6043)

[Phase1 Epoch 02/30]
  Train - acc:0.4324 off1:0.8436 mae:0.7551
  Val   - acc:0.4960 off1:0.9077 mae:0.6043
  Part losses (Val): [1.134 1.195 1.191 0.991 1.15  1.182]
  LR:1.00e-04 | 13.1s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.1196 acc:0.5217 off1:0.9092 mae:0.5740
  Region Acc: [0.579 0.473 0.481 0.638 0.495 0.465]
‚úÖ Phase 1 Best (Acc=0.5217, MAE=0.5740)

[Phase1 Epoch 03/30]
  Train - acc:0.4648 off1:0.8708 mae:0.6828
  Val   - acc:0.5217 off1:0.9092 mae:0.5740
  Part losses (Val): [1.111 1.165 1.172 0.969 1.135 1.166]
  LR:1.00e-04 | 13.5s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.1032 acc:0.5256 off1:0.9174 mae:0.5610
  Region Acc: [0.559 0.496 0.486 0.598 0.518 0.498]
‚úÖ Phase 1 Best (Acc=0.5256, MAE=0.5610)

[Phase1 Epoch 04/30]
  Train - acc:0.4704 off1:0.8710 mae:0.6797
  Val   - acc:0.5256 off1:0.9174 mae:0.5610
  Part losses (Val): [1.102 1.132 1.162 0.964 1.103 1.156]
  LR:1.00e-04 | 14.1s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0906 acc:0.5417 off1:0.9280 mae:0.5351
  Region Acc: [0.58  0.522 0.507 0.629 0.519 0.493]
‚úÖ Phase 1 Best (Acc=0.5417, MAE=0.5351)

[Phase1 Epoch 05/30]
  Train - acc:0.4817 off1:0.8811 mae:0.6576
  Val   - acc:0.5417 off1:0.9280 mae:0.5351
  Part losses (Val): [1.082 1.113 1.152 0.959 1.084 1.154]
  LR:1.00e-04 | 13.7s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0746 acc:0.5431 off1:0.9355 mae:0.5258
  Region Acc: [0.586 0.518 0.5   0.643 0.524 0.489]
‚úÖ Phase 1 Best (Acc=0.5431, MAE=0.5258)

[Phase1 Epoch 06/30]
  Train - acc:0.4834 off1:0.8760 mae:0.6637
  Val   - acc:0.5431 off1:0.9355 mae:0.5258
  Part losses (Val): [1.042 1.089 1.125 0.936 1.082 1.175]
  LR:1.00e-04 | 13.8s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0613 acc:0.5609 off1:0.9386 mae:0.5038
  Region Acc: [0.595 0.524 0.524 0.673 0.541 0.508]
‚úÖ Phase 1 Best (Acc=0.5609, MAE=0.5038)

[Phase1 Epoch 07/30]
  Train - acc:0.4972 off1:0.8841 mae:0.6400
  Val   - acc:0.5609 off1:0.9386 mae:0.5038
  Part losses (Val): [1.052 1.106 1.106 0.914 1.064 1.126]
  LR:1.00e-04 | 14.1s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0670 acc:0.5598 off1:0.9360 mae:0.5095
  Region Acc: [0.598 0.532 0.52  0.664 0.533 0.512]

[Phase1 Epoch 08/30]
  Train - acc:0.5209 off1:0.9062 mae:0.5871
  Val   - acc:0.5598 off1:0.9360 mae:0.5095
  Part losses (Val): [1.057 1.098 1.108 0.916 1.093 1.131]
  LR:1.00e-04 | 13.6s | Pat:1/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0437 acc:0.5694 off1:0.9516 mae:0.4828
  Region Acc: [0.592 0.546 0.518 0.68  0.571 0.51 ]
‚úÖ Phase 1 Best (Acc=0.5694, MAE=0.4828)

[Phase1 Epoch 09/30]
  Train - acc:0.5169 off1:0.8982 mae:0.6026
  Val   - acc:0.5694 off1:0.9516 mae:0.4828
  Part losses (Val): [1.018 1.064 1.107 0.903 1.048 1.122]
  LR:1.00e-04 | 13.8s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0549 acc:0.5718 off1:0.9441 mae:0.4874
  Region Acc: [0.602 0.544 0.519 0.677 0.579 0.511]
‚úÖ Phase 1 Best (Acc=0.5718, MAE=0.4874)

[Phase1 Epoch 10/30]
  Train - acc:0.5225 off1:0.9066 mae:0.5855
  Val   - acc:0.5718 off1:0.9441 mae:0.4874
  Part losses (Val): [1.071 1.086 1.096 0.907 1.044 1.126]
  LR:1.00e-04 | 13.8s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0444 acc:0.5694 off1:0.9503 mae:0.4832
  Region Acc: [0.599 0.553 0.531 0.643 0.568 0.524]

[Phase1 Epoch 11/30]
  Train - acc:0.5349 off1:0.9080 mae:0.5714
  Val   - acc:0.5694 off1:0.9503 mae:0.4832
  Part losses (Val): [1.014 1.068 1.108 0.91  1.043 1.123]
  LR:1.00e-04 | 13.8s | Pat:1/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0523 acc:0.5696 off1:0.9384 mae:0.4976
  Region Acc: [0.596 0.549 0.542 0.679 0.556 0.496]

[Phase1 Epoch 12/30]
  Train - acc:0.5412 off1:0.9136 mae:0.5590
  Val   - acc:0.5696 off1:0.9384 mae:0.4976
  Part losses (Val): [1.05  1.068 1.087 0.906 1.076 1.127]
  LR:1.00e-04 | 13.8s | Pat:2/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0441 acc:0.5755 off1:0.9488 mae:0.4792
  Region Acc: [0.613 0.548 0.536 0.664 0.568 0.523]
‚úÖ Phase 1 Best (Acc=0.5755, MAE=0.4792)

[Phase1 Epoch 13/30]
  Train - acc:0.5194 off1:0.8928 mae:0.6110
  Val   - acc:0.5755 off1:0.9488 mae:0.4792
  Part losses (Val): [1.023 1.069 1.094 0.913 1.037 1.13 ]
  LR:1.00e-04 | 13.7s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0639 acc:0.5603 off1:0.9412 mae:0.5020
  Region Acc: [0.588 0.559 0.542 0.641 0.534 0.498]

[Phase1 Epoch 14/30]
  Train - acc:0.5144 off1:0.8913 mae:0.6160
  Val   - acc:0.5603 off1:0.9412 mae:0.5020
  Part losses (Val): [1.02  1.08  1.132 0.922 1.08  1.15 ]
  LR:1.00e-04 | 13.3s | Pat:1/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0581 acc:0.5742 off1:0.9463 mae:0.4845
  Region Acc: [0.607 0.546 0.549 0.674 0.572 0.496]

[Phase1 Epoch 15/30]
  Train - acc:0.5000 off1:0.8802 mae:0.6452
  Val   - acc:0.5742 off1:0.9463 mae:0.4845
  Part losses (Val): [1.072 1.08  1.075 0.915 1.062 1.144]
  LR:1.00e-04 | 13.0s | Pat:2/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0419 acc:0.5810 off1:0.9530 mae:0.4693
  Region Acc: [0.598 0.561 0.546 0.666 0.594 0.521]
‚úÖ Phase 1 Best (Acc=0.5810, MAE=0.4693)

[Phase1 Epoch 16/30]
  Train - acc:0.5437 off1:0.9072 mae:0.5673
  Val   - acc:0.5810 off1:0.9530 mae:0.4693
  Part losses (Val): [1.013 1.064 1.098 0.904 1.043 1.13 ]
  LR:1.00e-04 | 12.8s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0391 acc:0.5786 off1:0.9492 mae:0.4761
  Region Acc: [0.6   0.555 0.542 0.669 0.578 0.529]

[Phase1 Epoch 17/30]
  Train - acc:0.5112 off1:0.8800 mae:0.6331
  Val   - acc:0.5786 off1:0.9492 mae:0.4761
  Part losses (Val): [1.011 1.058 1.1   0.906 1.038 1.121]
  LR:1.00e-04 | 13.3s | Pat:1/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0531 acc:0.5853 off1:0.9543 mae:0.4645
  Region Acc: [0.6   0.581 0.56  0.673 0.584 0.513]
‚úÖ Phase 1 Best (Acc=0.5853, MAE=0.4645)

[Phase1 Epoch 18/30]
  Train - acc:0.5915 off1:0.9363 mae:0.4816
  Val   - acc:0.5853 off1:0.9543 mae:0.4645
  Part losses (Val): [1.043 1.058 1.091 0.938 1.046 1.143]
  LR:1.00e-04 | 13.6s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0398 acc:0.5762 off1:0.9485 mae:0.4793
  Region Acc: [0.593 0.568 0.548 0.662 0.576 0.51 ]

[Phase1 Epoch 19/30]
  Train - acc:0.5342 off1:0.8969 mae:0.5895
  Val   - acc:0.5762 off1:0.9485 mae:0.4793
  Part losses (Val): [1.013 1.05  1.084 0.916 1.048 1.128]
  LR:1.00e-04 | 13.0s | Pat:1/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0535 acc:0.5777 off1:0.9521 mae:0.4731
  Region Acc: [0.588 0.564 0.552 0.652 0.587 0.524]

[Phase1 Epoch 20/30]
  Train - acc:0.5758 off1:0.9170 mae:0.5226
  Val   - acc:0.5777 off1:0.9521 mae:0.4731
  Part losses (Val): [1.036 1.066 1.091 0.935 1.055 1.139]
  LR:1.00e-04 | 13.0s | Pat:2/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0484 acc:0.5711 off1:0.9476 mae:0.4852
  Region Acc: [0.609 0.556 0.541 0.65  0.575 0.497]

[Phase1 Epoch 21/30]
  Train - acc:0.5525 off1:0.9025 mae:0.5646
  Val   - acc:0.5711 off1:0.9476 mae:0.4852
  Part losses (Val): [1.035 1.069 1.088 0.93  1.036 1.132]
  LR:1.00e-04 | 13.2s | Pat:3/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0643 acc:0.5707 off1:0.9556 mae:0.4766
  Region Acc: [0.589 0.567 0.555 0.649 0.57  0.495]

[Phase1 Epoch 22/30]
  Train - acc:0.5213 off1:0.8723 mae:0.6318
  Val   - acc:0.5707 off1:0.9556 mae:0.4766
  Part losses (Val): [1.025 1.094 1.098 0.929 1.07  1.171]
  LR:1.00e-04 | 13.2s | Pat:4/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0811 acc:0.5689 off1:0.9565 mae:0.4783
  Region Acc: [0.566 0.575 0.553 0.637 0.578 0.505]

[Phase1 Epoch 23/30]
  Train - acc:0.5458 off1:0.8896 mae:0.5896
  Val   - acc:0.5689 off1:0.9565 mae:0.4783
  Part losses (Val): [1.054 1.088 1.108 0.969 1.076 1.19 ]
  LR:1.00e-04 | 12.9s | Pat:5/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0769 acc:0.5711 off1:0.9518 mae:0.4806
  Region Acc: [0.61  0.565 0.541 0.641 0.558 0.512]

[Phase1 Epoch 24/30]
  Train - acc:0.5398 off1:0.8823 mae:0.6072
  Val   - acc:0.5711 off1:0.9518 mae:0.4806
  Part losses (Val): [1.053 1.087 1.132 0.95  1.068 1.172]
  LR:5.00e-05 | 12.8s | Pat:6/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0737 acc:0.5711 off1:0.9485 mae:0.4841
  Region Acc: [0.594 0.567 0.552 0.638 0.573 0.502]

[Phase1 Epoch 25/30]
  Train - acc:0.5903 off1:0.9162 mae:0.5099
  Val   - acc:0.5711 off1:0.9485 mae:0.4841
  Part losses (Val): [1.034 1.092 1.114 0.963 1.067 1.172]
  LR:5.00e-05 | 12.9s | Pat:7/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0888 acc:0.5705 off1:0.9455 mae:0.4874
  Region Acc: [0.589 0.557 0.558 0.638 0.579 0.502]

[Phase1 Epoch 26/30]
  Train - acc:0.5603 off1:0.8902 mae:0.5746
  Val   - acc:0.5705 off1:0.9455 mae:0.4874
  Part losses (Val): [1.064 1.112 1.12  0.985 1.075 1.177]
  LR:5.00e-05 | 13.0s | Pat:8/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0848 acc:0.5735 off1:0.9490 mae:0.4814
  Region Acc: [0.598 0.559 0.558 0.651 0.566 0.509]

[Phase1 Epoch 27/30]
  Train - acc:0.5799 off1:0.9055 mae:0.5341
  Val   - acc:0.5735 off1:0.9490 mae:0.4814
  Part losses (Val): [1.06  1.098 1.128 0.963 1.089 1.172]
  LR:5.00e-05 | 13.3s | Pat:9/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0911 acc:0.5711 off1:0.9466 mae:0.4861
  Region Acc: [0.596 0.56  0.545 0.656 0.56  0.509]

‚èπÔ∏è Phase 1 Early stopping at epoch 28

üìä Phase 1 Average Part Losses:
   A: 1.0567
   B: 1.0975
   C: 1.1190
   D: 0.9391
   E: 1.0765
   F: 1.1522

üìç PHASE 2: Training with Difficulty-Based Weights

üéØ Calculated Part Weights:
   A: 0.992 (loss: 1.0567)
   B: 1.011 (loss: 1.0975)
   C: 1.021 (loss: 1.1190)
   D: 0.935 (loss: 0.9391)
   E: 1.001 (loss: 1.0765)
   F: 1.036 (loss: 1.1522)
Class Distribution Analysis

A:
  Class 0: 1791 ( 50.0%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 1: 1122 ( 31.3%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 2:  457 ( 12.8%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 3:  214 (  6.0%) ‚ñà‚ñà

B:
  Class 0:  740 ( 20.6%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 1:  912 ( 25.4%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 2: 1159 ( 32.3%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 3:  773 ( 21

                                                                                           

[val] loss:1.0394 acc:0.5837 off1:0.9571 mae:0.4625
  Region Acc: [0.598 0.583 0.562 0.649 0.593 0.516]
‚úÖ Phase 2 Best (Acc=0.5837, MAE=0.4625)

[Phase2 Epoch 01/50]
  Train - acc:0.5283 off1:0.8868 mae:0.6114
  Val   - acc:0.5837 off1:0.9571 mae:0.4625
  Part losses (Val): [0.999 1.061 1.107 0.85  1.047 1.173]
  LR:5.00e-05 | 13.3s | Pat:0/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0511 acc:0.5676 off1:0.9501 mae:0.4867
  Region Acc: [0.592 0.537 0.544 0.663 0.56  0.509]

[Phase2 Epoch 02/50]
  Train - acc:0.5843 off1:0.9285 mae:0.5007
  Val   - acc:0.5676 off1:0.9501 mae:0.4867
  Part losses (Val): [1.028 1.078 1.105 0.861 1.068 1.167]
  LR:5.00e-05 | 12.7s | Pat:1/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0558 acc:0.5819 off1:0.9554 mae:0.4662
  Region Acc: [0.606 0.569 0.553 0.66  0.578 0.525]

[Phase2 Epoch 03/50]
  Train - acc:0.5724 off1:0.9119 mae:0.5344
  Val   - acc:0.5819 off1:0.9554 mae:0.4662
  Part losses (Val): [1.036 1.08  1.112 0.869 1.052 1.185]
  LR:5.00e-05 | 13.5s | Pat:2/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0629 acc:0.5804 off1:0.9545 mae:0.4691
  Region Acc: [0.592 0.577 0.555 0.65  0.58  0.529]

[Phase2 Epoch 04/50]
  Train - acc:0.5764 off1:0.9159 mae:0.5247
  Val   - acc:0.5804 off1:0.9545 mae:0.4691
  Part losses (Val): [1.039 1.084 1.119 0.883 1.06  1.191]
  LR:5.00e-05 | 12.8s | Pat:3/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0633 acc:0.5757 off1:0.9505 mae:0.4777
  Region Acc: [0.589 0.572 0.56  0.635 0.59  0.508]

[Phase2 Epoch 05/50]
  Train - acc:0.5916 off1:0.9206 mae:0.5029
  Val   - acc:0.5757 off1:0.9505 mae:0.4777
  Part losses (Val): [1.029 1.083 1.132 0.872 1.056 1.208]
  LR:5.00e-05 | 13.0s | Pat:4/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0599 acc:0.5769 off1:0.9541 mae:0.4730
  Region Acc: [0.586 0.578 0.569 0.636 0.571 0.522]

[Phase2 Epoch 06/50]
  Train - acc:0.5774 off1:0.9153 mae:0.5219
  Val   - acc:0.5769 off1:0.9541 mae:0.4730
  Part losses (Val): [1.022 1.087 1.126 0.873 1.064 1.187]
  LR:5.00e-05 | 12.8s | Pat:5/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0627 acc:0.5819 off1:0.9525 mae:0.4693
  Region Acc: [0.598 0.569 0.565 0.66  0.58  0.52 ]

[Phase2 Epoch 07/50]
  Train - acc:0.5836 off1:0.9172 mae:0.5154
  Val   - acc:0.5819 off1:0.9525 mae:0.4693
  Part losses (Val): [1.041 1.079 1.114 0.874 1.074 1.194]
  LR:2.50e-05 | 13.2s | Pat:6/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0649 acc:0.5806 off1:0.9538 mae:0.4693
  Region Acc: [0.589 0.569 0.561 0.652 0.588 0.524]

[Phase2 Epoch 08/50]
  Train - acc:0.5572 off1:0.9007 mae:0.5642
  Val   - acc:0.5806 off1:0.9538 mae:0.4693
  Part losses (Val): [1.043 1.085 1.118 0.882 1.062 1.199]
  LR:2.50e-05 | 12.7s | Pat:7/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0759 acc:0.5821 off1:0.9545 mae:0.4667
  Region Acc: [0.601 0.561 0.565 0.651 0.587 0.527]

[Phase2 Epoch 09/50]
  Train - acc:0.5859 off1:0.9128 mae:0.5196
  Val   - acc:0.5821 off1:0.9545 mae:0.4667
  Part losses (Val): [1.056 1.096 1.127 0.888 1.078 1.21 ]
  LR:2.50e-05 | 12.9s | Pat:8/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0793 acc:0.5777 off1:0.9525 mae:0.4733
  Region Acc: [0.598 0.567 0.561 0.649 0.58  0.511]

[Phase2 Epoch 10/50]
  Train - acc:0.6305 off1:0.9409 mae:0.4383
  Val   - acc:0.5777 off1:0.9525 mae:0.4733
  Part losses (Val): [1.062 1.096 1.131 0.895 1.076 1.216]
  LR:2.50e-05 | 13.2s | Pat:9/10
----------------------------------------------------------------------


                                                                                           

[val] loss:1.0792 acc:0.5760 off1:0.9518 mae:0.4762
  Region Acc: [0.603 0.557 0.56  0.654 0.566 0.516]

‚èπÔ∏è Phase 2 Early stopping at epoch 11

üéâ Training Finished!

üìä Test evaluation with Phase 2 model...


                                                                            

[test] loss:1.0002 acc:0.5936 off1:0.9658 mae:0.4418
  Region Acc: [0.623 0.589 0.521 0.726 0.568 0.534]

üèÜ Test Results:
   Accuracy: 0.5936
   Off-by-1: 0.9658
   MAE: 0.4418
   Region Acc: [0.623 0.589 0.521 0.726 0.568 0.534]
   Part losses: [0.91  1.081 1.075 0.787 0.997 1.15 ]

üíæ Phase 1 model saved: runs_severity_classification/best_efficientnet_b0_classification.pth
üíæ Phase 2 model saved: runs_severity_classification/phase2_weighted_classification.pth

üìà Summary:
   Phase 1 Best Acc: 0.5853
   Phase 2 Best Acc: 0.5837
   Improvement: -0.0016




In [3]:
import os, random, time
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as TF
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

from tqdm import tqdm
from torch.amp import autocast
from torch.cuda.amp import autocast, GradScaler

# --- Í≤ΩÎ°ú ÏÑ§Ï†ï Î∞è ÌïòÏù¥ÌçºÌååÎùºÎØ∏ÌÑ∞ ---
BASE_DIR     = f"./data/covid19-xray-severity-scoring/"
CSV_PATH     = str(Path(BASE_DIR) / "Brixia.csv")
IMAGE_DIR    = str(Path(BASE_DIR) / "segmented_png")

OUT_DIR      = "./runs_severity_classification"
BEST_PATH    = str(Path(OUT_DIR) / "best_efficientnet_b0_classification.pth")
os.makedirs(OUT_DIR, exist_ok=True)

DEVICE       = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED         = 42
IMG_SIZE     = 224
BATCH_SIZE   = 32
NUM_CLASSES  = 4   # 0, 1, 2, 3
EPOCHS       = 100  # Single phase training
LR           = 1e-4
WEIGHT_DECAY = 5e-4
AMP          = True
EARLY_STOP_ACC = 0.75  # üîÑ MAE ‚Üí Accuracy
DROP_RATIO   = 0.3
AUG_RATIO    = 0.5
MIXUP_ALPHA  = 0.2
LABEL_SMOOTHING = 0.1  # ‚úÖ NEW: Label smoothing

# --- ÏãúÎìú Í≥†Ï†ï ---
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
set_seed(SEED)

def make_transform_with_label(train: bool, img_size: int = IMG_SIZE, aug_ratio=AUG_RATIO):
    """Brixia ScoreÏùò Ï¢åÏö∞ Íµ¨Ï°∞Î•º Í≥†Î†§Ìïú transform"""
    def _tfm(img: Image.Image, label: torch.Tensor = None):
        img = img.convert('RGB')
        img = TF.resize(
            img, 
            [img_size, img_size], 
            interpolation=TF.InterpolationMode.BILINEAR,
            antialias=True
        )
        
        if train:
            # 1. Horizontal Flip (Ï¢åÏö∞ Î∞òÏ†Ñ: ABC ‚Üî DEF)
            if random.random() < aug_ratio:
                img = TF.hflip(img)
                if label is not None:
                    # [A, B, C, D, E, F] ‚Üí [D, E, F, A, B, C]
                    label = label[[3, 4, 5, 0, 1, 2]]
            
            # 2. ÏïΩÌïú ÌöåÏ†Ñ (¬±5ÎèÑ)
            if random.random() < aug_ratio:
                angle = float(torch.empty(1).uniform_(-5, 5))
                img = TF.rotate(
                    img, 
                    angle, 
                    interpolation=TF.InterpolationMode.BILINEAR,
                    fill=0
                )
            
            # 3. ÏïΩÌïú Translation
            if random.random() < aug_ratio:
                max_dx = 0.05 * img_size
                max_dy = 0.05 * img_size
                translations = (
                    float(torch.empty(1).uniform_(-max_dx, max_dx)),
                    float(torch.empty(1).uniform_(-max_dy, max_dy))
                )
                img = TF.affine(
                    img,
                    angle=0,
                    translate=translations,
                    scale=1.0,
                    shear=0,
                    interpolation=TF.InterpolationMode.BILINEAR,
                    fill=0
                )
            
            # 4. Brightness & Contrast
            if random.random() < aug_ratio:
                brightness_factor = float(torch.empty(1).uniform_(0.9, 1.1))
                img = TF.adjust_brightness(img, brightness_factor)
            
            if random.random() < aug_ratio:
                contrast_factor = float(torch.empty(1).uniform_(0.9, 1.1))
                img = TF.adjust_contrast(img, contrast_factor)
            
            # 5. Gamma Correction
            if random.random() < 0.3:
                gamma = float(torch.empty(1).uniform_(0.9, 1.1))
                img = TF.adjust_gamma(img, gamma)
        
        # Tensor Î≥ÄÌôò
        img = TF.to_tensor(img)
        
        # Gaussian Noise (train only)
        if train and random.random() < 0.2:
            noise = torch.randn_like(img) * 0.01
            img = img + noise
            img = torch.clamp(img, 0, 1)
        
        # Ï†ïÍ∑úÌôî
        img = TF.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        if label is not None:
            return img, label
        return img
    
    return _tfm

def load_and_split_brixia(csv_path, val_ratio=0.2, seed=SEED):
    df = pd.read_csv(csv_path, dtype={'BrixiaScore': str})
    df = df.dropna(subset=['BrixiaScore'])
    df = df[df['BrixiaScore'] != 'nan']
    df = df[df['BrixiaScore'].str.len() == 6].copy()
    
    print(f"Ïú†Ìö®Ìïú Îç∞Ïù¥ÌÑ∞: {len(df)}Í∞ú")
    
    if 'ConsensusTestset' in df.columns:
        test_df = df[df['ConsensusTestset'] == 1].copy()
        train_val_df = df[df['ConsensusTestset'] == 0].copy()
    else:
        test_df = pd.DataFrame()
        train_val_df = df.copy()
    
    gss = GroupShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
    train_idx, val_idx = next(gss.split(
        train_val_df, 
        groups=train_val_df['StudyId']
    ))
    
    tr_df = train_val_df.iloc[train_idx].copy()
    val_df = train_val_df.iloc[val_idx].copy()
    
    print(f"Train: {len(tr_df)}, Val: {len(val_df)}, Test: {len(test_df)}")
    validate_split(tr_df, val_df, test_df)
    
    return tr_df, val_df, test_df

def validate_split(tr_df, val_df, tt_df):
    train_studies = set(tr_df['StudyId'])
    val_studies = set(val_df['StudyId'])
    test_studies = set(tt_df['StudyId']) if len(tt_df) > 0 else set()
    
    assert len(train_studies & val_studies) == 0, "Train-Val Í∞Ñ StudyId Ï§ëÎ≥µ!"
    assert len(train_studies & test_studies) == 0, "Train-Test Í∞Ñ StudyId Ï§ëÎ≥µ!"
    assert len(val_studies & test_studies) == 0, "Val-Test Í∞Ñ StudyId Ï§ëÎ≥µ!"
    
    for name, data in [('Train', tr_df), ('Val', val_df), ('Test', tt_df)]:
        if len(data) > 0:
            scores = data['BrixiaScore'].apply(lambda x: sum(int(c) for c in x))
            print(f"{name} - Mean: {scores.mean():.2f}, Std: {scores.std():.2f}")
    
    return True

# ============================================================
# Dataset
# ============================================================
class BrixiaDataset(Dataset):
    def __init__(self, dataframe, img_dir, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.img_col = "Filename"
        self.label_col = "BrixiaScore"
        self._validate_data()
    
    def _validate_data(self):
        assert self.img_col in self.df.columns
        assert self.label_col in self.df.columns
        
        invalid_scores = self.df[self.df[self.label_col].str.len() != 6]
        if len(invalid_scores) > 0:
            print(f"‚ö†Ô∏è Í≤ΩÍ≥†: {len(invalid_scores)}Í∞úÏùò ÏûòÎ™ªÎêú BrixiaScore Î∞úÍ≤¨")
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        img_name_from_csv = row[self.img_col]
        img_name = img_name_from_csv.replace('.dcm', '.png')
        img_path = os.path.join(self.img_dir, img_name)
        
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"‚ùå Ïù¥ÎØ∏ÏßÄ Î°úÎìú Ïò§Î•ò: {img_path}")
            raise
        
        scores_str = row[self.label_col]
        scores_list = [int(c) for c in scores_str]
        labels = torch.tensor(scores_list, dtype=torch.long)  # üîÑ longÏúºÎ°ú Î≥ÄÍ≤Ω
        
        if self.transform:
            image, labels = self.transform(image, labels)
        else:
            image = TF.to_tensor(image)
            image = TF.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        return image, labels

def create_dataloaders(tr_df, val_df, tt_df, img_dir, 
                       batch_size=32, img_size=224, num_workers=4):
    train_transform = make_transform_with_label(train=True, img_size=img_size)
    val_transform = make_transform_with_label(train=False, img_size=img_size)
    
    tr_ds = BrixiaDataset(tr_df, img_dir, transform=train_transform)
    val_ds = BrixiaDataset(val_df, img_dir, transform=val_transform)
    tt_ds = BrixiaDataset(tt_df, img_dir, transform=val_transform)
    
    tr_loader = DataLoader(
        tr_ds, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=num_workers, 
        pin_memory=torch.cuda.is_available(),
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_ds, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=num_workers, 
        pin_memory=torch.cuda.is_available()
    )
    
    tt_loader = DataLoader(
        tt_ds, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=num_workers, 
        pin_memory=torch.cuda.is_available()
    )
    
    print("‚úÖ DataLoader Ï§ÄÎπÑ ÏôÑÎ£å")
    print(f"   Train: {len(tr_ds)} samples, {len(tr_loader)} batches")
    print(f"   Val:   {len(val_ds)} samples, {len(val_loader)} batches")
    print(f"   Test:  {len(tt_ds)} samples, {len(tt_loader)} batches")
    
    return tr_loader, val_loader, tt_loader

# ============================================================
# Loss Function (Î∂ÑÎ•òÏö©ÏúºÎ°ú Î≥ÄÍ≤Ω)
# ============================================================
def calculate_class_weights(labels, num_classes=4, method='sqrt_inverse'):
    """ÌÅ¥ÎûòÏä§ Î∂àÍ∑†Ìòï Ìï¥Í≤∞ÏùÑ ÏúÑÌïú Í∞ÄÏ§ëÏπò Í≥ÑÏÇ∞"""
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    
    labels_flat = labels.flatten()
    counts = np.bincount(labels_flat.astype(int), minlength=num_classes)
    
    if method == 'sqrt_inverse':
        weights = 1.0 / (np.sqrt(counts) + 1e-6)
    elif method == 'inverse':
        weights = 1.0 / (counts + 1e-6)
    else:
        total = len(labels_flat)
        weights = total / (num_classes * (counts + 1e-6))
    
    weights = weights / weights.mean()
    return torch.FloatTensor(weights)

def print_class_distribution(labels):
    """ÌÅ¥ÎûòÏä§ Î∂ÑÌè¨ ÏãúÍ∞ÅÌôî"""
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    
    print("=" * 60)
    print("Class Distribution Analysis")
    print("=" * 60)
    
    region_names = ['A', 'B', 'C', 'D', 'E', 'F']
    for idx, name in enumerate(region_names):
        region_labels = labels[:, idx]
        counts = np.bincount(region_labels.astype(int), minlength=4)
        total = counts.sum()
        
        print(f"\n{name}:")
        for cls in range(4):
            pct = 100 * counts[cls] / total if total > 0 else 0
            bar = '‚ñà' * int(pct / 2)
            print(f"  Class {cls}: {counts[cls]:4d} ({pct:5.1f}%) {bar}")
    
    print("=" * 60)

class AdaptiveClassificationLoss(nn.Module):
    """
    ‚úÖ NEW: Îã§Ï§ë ÌÅ¥ÎûòÏä§ Î∂ÑÎ•òÎ•º ÏúÑÌïú ÏÜêÏã§Ìï®Ïàò
    - CrossEntropyLoss Í∏∞Î∞ò
    - ÌÅ¥ÎûòÏä§ Í∞ÄÏ§ëÏπò Ï†ÅÏö©
    - Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò Ï†ÅÏö©
    - Label Smoothing ÏßÄÏõê
    """
    
    def __init__(self, train_labels, num_classes=4, use_class_weights=True, 
                 part_weights=None, label_smoothing=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.label_smoothing = label_smoothing
        
        if isinstance(train_labels, torch.Tensor):
            train_labels_np = train_labels.cpu().numpy()
        else:
            train_labels_np = train_labels
        
        print_class_distribution(train_labels_np)
        
        # ÌÅ¥ÎûòÏä§ Í∞ÄÏ§ëÏπò
        if use_class_weights:
            class_weights = calculate_class_weights(train_labels_np, num_classes=num_classes)
            self.register_buffer('class_weights', class_weights)
            print(f"‚úÖ Class weights: {class_weights.numpy()}")
        else:
            self.class_weights = None
        
        # Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò
        if part_weights is None:
            self.part_weights = torch.ones(6)
        else:
            self.part_weights = torch.tensor(part_weights, dtype=torch.float32)
        self.register_buffer('part_weights_buf', self.part_weights)
        print(f"‚úÖ Part weights: {self.part_weights.numpy()}")

    def forward(self, pred_logits, target, use_mixup=False):
        """
        pred_logits: [B, 6, 4] - Í∞Å Î∂ÄÏúÑÎ≥Ñ 4Í∞ú ÌÅ¥ÎûòÏä§Ïóê ÎåÄÌïú logits
        target: [B, 6] - Í∞Å Î∂ÄÏúÑÎ≥Ñ ÌÅ¥ÎûòÏä§ Î†àÏù¥Î∏î (0~3)
        """
        B, num_regions, num_classes = pred_logits.shape
        
        # Reshape for loss computation
        pred_logits_flat = pred_logits.view(B * num_regions, num_classes)  # [B*6, 4]
        target_flat = target.view(B * num_regions)  # [B*6]
        
        # CrossEntropyLoss with label smoothing
        if self.class_weights is not None and not use_mixup:
            criterion = nn.CrossEntropyLoss(
                weight=self.class_weights.to(pred_logits.device),
                label_smoothing=self.label_smoothing,
                reduction='none'
            )
        else:
            criterion = nn.CrossEntropyLoss(
                label_smoothing=self.label_smoothing,
                reduction='none'
            )
        
        loss = criterion(pred_logits_flat, target_flat)  # [B*6]
        loss = loss.view(B, num_regions)  # [B, 6]
        
        # Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò Ï†ÅÏö©
        part_weights = self.part_weights_buf.to(pred_logits.device)
        loss = loss * part_weights.unsqueeze(0)  # [B, 6]
        
        # Î∂ÄÏúÑÎ≥Ñ ÌèâÍ∑† ÏÜêÏã§ (Î™®ÎãàÌÑ∞ÎßÅÏö©)
        part_losses = loss.mean(dim=0)  # [6]
        
        return loss.mean(), part_losses

# ============================================================
# Mixup (Î∂ÑÎ•òÏö©)
# ============================================================
def mixup_data_classification(x, y, alpha=MIXUP_ALPHA):
    """
    ‚úÖ NEW: Î∂ÑÎ•òÏö© Mixup
    yÎäî one-hotÏúºÎ°ú Î≥ÄÌôò ÌõÑ mixup
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    
    # One-hot encoding
    y_onehot = F.one_hot(y, num_classes=NUM_CLASSES).float()  # [B, 6, 4]
    y_onehot_shuffled = y_onehot[index]
    
    mixed_y = lam * y_onehot + (1 - lam) * y_onehot_shuffled  # [B, 6, 4]
    
    return mixed_x, mixed_y, lam

def mixup_criterion(criterion, pred_logits, y_mixed, lam):
    """MixupÏùÑ ÏúÑÌïú ÏÜêÏã§ Í≥ÑÏÇ∞"""
    # y_mixedÎäî [B, 6, 4] soft labels
    # pred_logitsÎäî [B, 6, 4]
    
    B, num_regions, num_classes = pred_logits.shape
    
    # Reshape
    pred_flat = pred_logits.view(B * num_regions, num_classes)  # [B*6, 4]
    target_flat = y_mixed.view(B * num_regions, num_classes)  # [B*6, 4]
    
    # Soft target loss
    log_probs = F.log_softmax(pred_flat, dim=1)
    loss = -(target_flat * log_probs).sum(dim=1)  # [B*6]
    loss = loss.view(B, num_regions)  # [B, 6]
    
    # Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò Ï†ÅÏö© (criterionÏóêÏÑú Í∞ÄÏ†∏Ïò¥)
    if hasattr(criterion, 'part_weights_buf'):
        part_weights = criterion.part_weights_buf.to(pred_logits.device)
        loss = loss * part_weights.unsqueeze(0)
    
    part_losses = loss.mean(dim=0)
    return loss.mean(), part_losses

# ============================================================
# Metrics (Î∂ÑÎ•òÏö©)
# ============================================================
@torch.no_grad()
def calculate_classification_metrics(pred_logits, labels):
    """
    ‚úÖ NEW: Î∂ÑÎ•ò ÏßÄÌëú Í≥ÑÏÇ∞
    pred_logits: [B, 6, 4]
    labels: [B, 6]
    """
    # ÏòàÏ∏° ÌÅ¥ÎûòÏä§
    preds = pred_logits.argmax(dim=-1)  # [B, 6]
    
    # Exact match accuracy
    exact_acc = (preds == labels).float().mean().item()
    
    # Off-by-1 accuracy (Ïù∏Ï†ë ÌÅ¥ÎûòÏä§ ÌóàÏö©)
    off_by_1 = (torch.abs(preds - labels) <= 1).float().mean().item()
    
    # Per-region accuracy
    region_acc = (preds == labels).float().mean(dim=0)  # [6]
    
    # MAE (Ï∞∏Í≥†Ïö©)
    mae = torch.abs(preds.float() - labels.float()).mean().item()
    
    return exact_acc, off_by_1, mae, region_acc

# ============================================================
# Model (Î∂ÑÎ•òÏö©ÏúºÎ°ú Î≥ÄÍ≤Ω)
# ============================================================
class EfficientNetB0Classification(nn.Module):
    """
    ‚úÖ Î∂ÑÎ•òÎ•º ÏúÑÌïú Î™®Îç∏ (Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ Í≥†Î†§)
    Í∞Å Î∂ÄÏúÑ(A~F)ÎßàÎã§ 4Í∞ú ÌÅ¥ÎûòÏä§(0~3) ÏòàÏ∏°
    + Self-AttentionÏúºÎ°ú Î∂ÄÏúÑ Í∞Ñ Í¥ÄÍ≥Ñ ÌïôÏäµ
    """
    def __init__(self, pretrained=True, drop=0.3, num_regions=6, num_classes=4):
        super().__init__()
        
        weights = EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = efficientnet_b0(weights=weights)
        
        self.features = backbone.features
        in_feat = 1280
        
        # Positional embedding
        self.pos_embed = nn.Parameter(torch.randn(1, 49, in_feat))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        
        # Region queries
        self.region_queries = nn.Parameter(torch.randn(num_regions, in_feat))
        nn.init.xavier_uniform_(self.region_queries)
        
        # Cross attention (Ïù¥ÎØ∏ÏßÄ ÌäπÏßï ‚Üí Î∂ÄÏúÑÎ≥Ñ ÌäπÏßï)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=in_feat,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        
        self.norm1 = nn.LayerNorm(in_feat)
        
        # ‚úÖ NEW: Self-Attention (Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ ÌïôÏäµ)
        self.self_attention = nn.MultiheadAttention(
            embed_dim=in_feat,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        
        self.norm2 = nn.LayerNorm(in_feat)
        self.norm3 = nn.LayerNorm(in_feat)
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(in_feat, in_feat * 2),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(in_feat * 2, in_feat),
            nn.Dropout(drop)
        )
        
        self.norm4 = nn.LayerNorm(in_feat)
        
        # Classification heads: Í∞Å Î∂ÄÏúÑÎßàÎã§ 4Í∞ú ÌÅ¥ÎûòÏä§ ÏòàÏ∏°
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_feat, 256),
                nn.LayerNorm(256),
                nn.GELU(),
                nn.Dropout(drop),
                nn.Linear(256, 128),
                nn.LayerNorm(128),
                nn.GELU(),
                nn.Dropout(drop),
                nn.Linear(128, num_classes)
            ) for _ in range(num_regions)
        ])
    
    def forward(self, x):
        B = x.size(0)
        
        # Feature extraction
        feat = self.features(x)  # [B, 1280, 7, 7]
        feat = feat.flatten(2).transpose(1, 2)  # [B, 49, 1280]
        feat = feat + self.pos_embed
        
        # Region queries
        queries = self.region_queries.unsqueeze(0).expand(B, -1, -1)  # [B, 6, 1280]
        
        # Cross attention (Ïù¥ÎØ∏ÏßÄ ‚Üí Î∂ÄÏúÑ)
        attn_out, _ = self.cross_attention(
            query=queries,
            key=feat,
            value=feat
        )
        attn_out = self.norm1(attn_out + queries)  # [B, 6, 1280]
        
        # ‚úÖ Self-Attention (Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ)
        # A, B, CÍ∞Ä ÏÑúÎ°ú ÏòÅÌñ•ÏùÑ Ï£ºÍ≥†, D, E, FÎèÑ ÏÑúÎ°ú ÏòÅÌñ•
        self_attn_out, attn_weights = self.self_attention(
            query=attn_out,
            key=attn_out,
            value=attn_out
        )
        attn_out = self.norm2(attn_out + self_attn_out)  # [B, 6, 1280]
        
        # FFN
        ffn_out = self.ffn(attn_out)
        attn_out = self.norm3(attn_out + ffn_out)  # [B, 6, 1280]
        
        # Í∞Å Î∂ÄÏúÑÎ≥Ñ classification head ÌÜµÍ≥º
        outputs = []
        for i in range(len(self.heads)):
            region_feat = attn_out[:, i, :]  # [B, 1280]
            logits = self.heads[i](region_feat)  # [B, 4]
            outputs.append(logits)
        
        out = torch.stack(outputs, dim=1)  # [B, 6, 4]
        return out, attn_weights  # ‚úÖ attention weights Î∞òÌôò (Grad-CAMÏö©)

# ============================================================
# Training Functions (Î∂ÑÎ•òÏö© ÏàòÏ†ï)
# ============================================================
def train_epoch(model, tr_loader, criterion, optimizer, scaler, device, 
                amp=True, use_mixup=True):
    model.train()
    run_loss = run_acc = run_off1 = run_mae = n = 0
    part_losses_sum = torch.zeros(6)
    
    pbar = tqdm(tr_loader, desc="Train", leave=False)
    
    for imgs, labels in pbar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)  # [B, 6]
        
        # Mixup Ï†ÅÏö©
        is_mixup = use_mixup and (random.random() < 0.5)
        
        optimizer.zero_grad(set_to_none=True)
        
        with autocast(enabled=amp):
            if is_mixup:
                imgs_mixed, labels_mixed, lam = mixup_data_classification(imgs, labels)
                pred_logits, _ = model(imgs_mixed)  # attention weights Î¨¥Ïãú
                loss, part_losses = mixup_criterion(criterion, pred_logits, labels_mixed, lam)
            else:
                pred_logits, _ = model(imgs)  # attention weights Î¨¥Ïãú
                loss, part_losses = criterion(pred_logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Metrics (ÏõêÎ≥∏ labelsÎ°ú Í≥ÑÏÇ∞)
        exact_acc, off_by_1, mae, _ = calculate_classification_metrics(pred_logits.detach(), labels)
        
        bs = imgs.size(0)
        run_loss += loss.item() * bs
        run_acc += exact_acc * bs
        run_off1 += off_by_1 * bs
        run_mae += mae * bs
        part_losses_sum += part_losses.cpu() * bs
        n += bs
        
        pbar.set_postfix(
            loss=f"{run_loss/n:.4f}",
            acc=f"{run_acc/n:.4f}",
            mae=f"{run_mae/n:.4f}"
        )
    
    part_losses_avg = part_losses_sum / n
    
    return run_loss/n, run_acc/n, run_off1/n, run_mae/n, part_losses_avg

@torch.no_grad()
def evaluate(model, val_loader, criterion, device, split='val'):
    model.eval()
    run_loss = run_acc = run_off1 = run_mae = n = 0
    part_losses_sum = torch.zeros(6)
    region_acc_sum = torch.zeros(6)
    
    pbar = tqdm(val_loader, desc=f"{split.capitalize()}", leave=False)
    
    for imgs, labels in pbar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        
        pred_logits, _ = model(imgs)  # [B, 6, 4], attention weights Î¨¥Ïãú
        loss, part_losses = criterion(pred_logits, labels)
        
        exact_acc, off_by_1, mae, region_acc = calculate_classification_metrics(pred_logits, labels)
        
        bs = imgs.size(0)
        run_loss += loss.item() * bs
        run_acc += exact_acc * bs
        run_off1 += off_by_1 * bs
        run_mae += mae * bs
        part_losses_sum += part_losses.cpu() * bs
        region_acc_sum += region_acc.cpu() * bs
        n += bs
        
        pbar.set_postfix(
            loss=f"{run_loss/n:.4f}",
            acc=f"{run_acc/n:.4f}"
        )
    
    avg_loss = run_loss/n
    avg_acc = run_acc/n
    avg_off1 = run_off1/n
    avg_mae = run_mae/n
    part_losses_avg = part_losses_sum / n
    region_acc_avg = region_acc_sum / n
    
    print(f"[{split}] loss:{avg_loss:.4f} acc:{avg_acc:.4f} "
          f"off1:{avg_off1:.4f} mae:{avg_mae:.4f}")
    print(f"  Region Acc: {region_acc_avg.numpy().round(3)}")
    
    return avg_loss, avg_acc, avg_off1, avg_mae, part_losses_avg, region_acc_avg

def get_lrs(optimizer):
    return [pg['lr'] for pg in optimizer.param_groups]

# ============================================================
# Main Function
# ============================================================
def main():
    print("\n" + "="*70)
    print("üöÄ Brixia COVID-19 Classification Training")
    print("   ‚úÖ ÌÅ¥ÎûòÏä§ Î∂àÍ∑†Ìòï Ï≤òÎ¶¨ (ÏûêÎèô Í∞ÄÏ§ëÏπò)")
    print("   ‚úÖ Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ ÌïôÏäµ (Self-Attention)")
    print("="*70)
    
    # Îç∞Ïù¥ÌÑ∞ Ï§ÄÎπÑ
    print("\nüìÇ Loading data...")
    tr_df, val_df, tt_df = load_and_split_brixia(CSV_PATH)
    
    print("\nüì¶ Creating DataLoaders...")
    tr_loader, val_loader, tt_loader = create_dataloaders(
        tr_df, val_df, tt_df, img_dir=IMAGE_DIR, 
        batch_size=BATCH_SIZE, img_size=IMG_SIZE, num_workers=4
    )
    
    # Train labels Ï∂îÏ∂ú
    train_labels = torch.cat([labels for _, labels in tr_loader], dim=0)
    
    # ========================================
    # ÌïôÏäµ Ï§ÄÎπÑ
    # ========================================
    print("\n" + "="*70)
    print("üìç Training Setup")
    print("="*70)
    
    # ‚úÖ ÌÅ¥ÎûòÏä§ Î∂àÍ∑†Ìòï ÏûêÎèô Ï≤òÎ¶¨
    criterion = AdaptiveClassificationLoss(
        train_labels, 
        num_classes=NUM_CLASSES,
        use_class_weights=True,  # ÌÅ¥ÎûòÏä§ Í∞ÄÏ§ëÏπò ÏûêÎèô Í≥ÑÏÇ∞
        part_weights=None,  # Î∂ÄÏúÑÎ≥Ñ Í∑†Îì± Í∞ÄÏ§ëÏπò
        label_smoothing=LABEL_SMOOTHING
    )
    
    # ‚úÖ Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥ÑÎ•º Í≥†Î†§ÌïòÎäî Î™®Îç∏
    model = EfficientNetB0Classification(
        pretrained=True, 
        drop=DROP_RATIO,
        num_regions=6,
        num_classes=NUM_CLASSES
    ).to(DEVICE)
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-6
    )
    scaler = GradScaler(enabled=AMP)
    
    best_acc = 0.0
    best_mae = float('inf')
    patience_counter = 0
    max_patience = 15
    
    # ========================================
    # ÌïôÏäµ Î£®ÌîÑ
    # ========================================
    print("\n" + "="*70)
    print("üèãÔ∏è Training Start")
    print("="*70)
    
    for ep in range(1, EPOCHS + 1):
        t0 = time.time()
        
        tr_loss, tr_acc, tr_off1, tr_mae, tr_part_losses = train_epoch(
            model, tr_loader, criterion, optimizer, scaler, DEVICE, AMP
        )
        
        val_loss, val_acc, val_off1, val_mae, val_part_losses, val_region_acc = evaluate(
            model, val_loader, criterion, DEVICE, split='val'
        )
        
        scheduler.step(val_acc)
        
        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            best_mae = val_mae
            patience_counter = 0
            torch.save({
                'epoch': ep,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': best_acc,
                'val_mae': best_mae,
                'val_off1': val_off1,
                'region_acc': val_region_acc,
            }, BEST_PATH)
            print(f"‚úÖ New Best! (Acc={best_acc:.4f}, MAE={best_mae:.4f})")
        else:
            patience_counter += 1
        
        # Early stopping
        if patience_counter >= max_patience:
            print(f"\n‚èπÔ∏è Early stopping at epoch {ep}")
            break
        
        elapsed = time.time() - t0
        print(f"\n[Epoch {ep:03d}/{EPOCHS}]")
        print(f"  Train - loss:{tr_loss:.4f} acc:{tr_acc:.4f} off1:{tr_off1:.4f} mae:{tr_mae:.4f}")
        print(f"  Val   - loss:{val_loss:.4f} acc:{val_acc:.4f} off1:{val_off1:.4f} mae:{val_mae:.4f}")
        print(f"  Part losses (Val): {val_part_losses.numpy().round(3)}")
        print(f"  Region Acc (Val): {val_region_acc.numpy().round(3)}")
        print(f"  LR:{get_lrs(optimizer)[0]:.2e} | {elapsed:.1f}s | Patience:{patience_counter}/{max_patience}")
        print("-" * 70)
    
    # ========================================
    # Test Evaluation
    # ========================================
    print("\n" + "="*70)
    print("üéâ Training Finished!")
    print("="*70)
    
    if len(tt_loader) > 0:
        print("\nüìä Test evaluation with best model...")
        checkpoint = torch.load(BEST_PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        
        tt_loss, tt_acc, tt_off1, tt_mae, tt_part_losses, tt_region_acc = evaluate(
            model, tt_loader, criterion, DEVICE, split='test'
        )
        
        print(f"\nüèÜ Test Results:")
        print(f"   Accuracy: {tt_acc:.4f}")
        print(f"   Off-by-1: {tt_off1:.4f}")
        print(f"   MAE: {tt_mae:.4f}")
        print(f"   Region Acc: {tt_region_acc.numpy().round(3)}")
        print(f"   Part losses: {tt_part_losses.numpy().round(3)}")
    
    print(f"\nüíæ Best model saved: {BEST_PATH}")
    print(f"üìà Best Validation Accuracy: {best_acc:.4f}")
    print(f"üìâ Best Validation MAE: {best_mae:.4f}")
    print("\n‚úÖ Ïù¥Ï†ú gradcam_inference.pyÎ•º Ïã§ÌñâÌïòÏó¨ Í≤∞Í≥ºÎ•º ÏãúÍ∞ÅÌôîÌïòÏÑ∏Ïöî!")
    print("="*70)

if __name__ == "__main__":
    main()


üöÄ Brixia COVID-19 Classification Training
   ‚úÖ ÌÅ¥ÎûòÏä§ Î∂àÍ∑†Ìòï Ï≤òÎ¶¨ (ÏûêÎèô Í∞ÄÏ§ëÏπò)
   ‚úÖ Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ ÌïôÏäµ (Self-Attention)

üìÇ Loading data...
Ïú†Ìö®Ìïú Îç∞Ïù¥ÌÑ∞: 4695Í∞ú
Train: 3637, Val: 912, Test: 146
Train - Mean: 8.31, Std: 4.26
Val - Mean: 8.35, Std: 4.15
Test - Mean: 7.78, Std: 4.20

üì¶ Creating DataLoaders...
‚úÖ DataLoader Ï§ÄÎπÑ ÏôÑÎ£å
   Train: 3637 samples, 113 batches
   Val:   912 samples, 29 batches
   Test:  146 samples, 5 batches

üìç Training Setup
Class Distribution Analysis

A:
  Class 0: 1810 ( 50.1%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 1: 1126 ( 31.1%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 2:  446 ( 12.3%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 3:  234 (  6.5%) ‚ñà‚ñà‚ñà

B:
  Class 0:  721 ( 19.9%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 1:  949 ( 26.2%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 2: 1171 ( 32.4%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 3:  7

                                                                                             

[val] loss:1.1561 acc:0.4856 off1:0.8688 mae:0.6603
  Region Acc: [0.542 0.44  0.405 0.617 0.467 0.443]
‚úÖ New Best! (Acc=0.4856, MAE=0.6603)

[Epoch 001/100]
  Train - loss:1.2178 acc:0.4151 off1:0.8294 mae:0.7886
  Val   - loss:1.1561 acc:0.4856 off1:0.8688 mae:0.6603
  Part losses (Val): [1.141 1.203 1.239 0.984 1.166 1.203]
  Region Acc (Val): [0.542 0.44  0.405 0.617 0.467 0.443]
  LR:1.00e-04 | 15.4s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.1078 acc:0.5143 off1:0.9123 mae:0.5780
  Region Acc: [0.546 0.489 0.463 0.612 0.505 0.47 ]
‚úÖ New Best! (Acc=0.5143, MAE=0.5780)

[Epoch 002/100]
  Train - loss:1.1414 acc:0.4534 off1:0.8569 mae:0.7148
  Val   - loss:1.1078 acc:0.5143 off1:0.9123 mae:0.5780
  Part losses (Val): [1.073 1.14  1.175 0.969 1.119 1.171]
  Region Acc (Val): [0.546 0.489 0.463 0.612 0.505 0.47 ]
  LR:1.00e-04 | 14.9s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0821 acc:0.5521 off1:0.9240 mae:0.5307
  Region Acc: [0.582 0.518 0.514 0.654 0.534 0.511]
‚úÖ New Best! (Acc=0.5521, MAE=0.5307)

[Epoch 003/100]
  Train - loss:1.1075 acc:0.4791 off1:0.8797 mae:0.6592
  Val   - loss:1.0821 acc:0.5521 off1:0.9240 mae:0.5307
  Part losses (Val): [1.057 1.125 1.132 0.941 1.096 1.141]
  Region Acc (Val): [0.582 0.518 0.514 0.654 0.534 0.511]
  LR:1.00e-04 | 15.0s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0551 acc:0.5523 off1:0.9390 mae:0.5143
  Region Acc: [0.557 0.541 0.534 0.621 0.545 0.516]
‚úÖ New Best! (Acc=0.5523, MAE=0.5143)

[Epoch 004/100]
  Train - loss:1.0675 acc:0.4775 off1:0.8786 mae:0.6667
  Val   - loss:1.0551 acc:0.5523 off1:0.9390 mae:0.5143
  Part losses (Val): [1.04  1.072 1.093 0.944 1.059 1.122]
  Region Acc (Val): [0.557 0.541 0.534 0.621 0.545 0.516]
  LR:1.00e-04 | 15.4s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0389 acc:0.5698 off1:0.9527 mae:0.4790
  Region Acc: [0.601 0.554 0.523 0.656 0.569 0.516]
‚úÖ New Best! (Acc=0.5698, MAE=0.4790)

[Epoch 005/100]
  Train - loss:1.0458 acc:0.4950 off1:0.8903 mae:0.6343
  Val   - loss:1.0389 acc:0.5698 off1:0.9527 mae:0.4790
  Part losses (Val): [1.004 1.059 1.1   0.91  1.044 1.117]
  Region Acc (Val): [0.601 0.554 0.523 0.656 0.569 0.516]
  LR:1.00e-04 | 14.7s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0407 acc:0.5720 off1:0.9432 mae:0.4889
  Region Acc: [0.599 0.558 0.539 0.661 0.575 0.5  ]
‚úÖ New Best! (Acc=0.5720, MAE=0.4889)

[Epoch 006/100]
  Train - loss:1.0339 acc:0.4810 off1:0.8760 mae:0.6656
  Val   - loss:1.0407 acc:0.5720 off1:0.9432 mae:0.4889
  Part losses (Val): [1.03  1.044 1.077 0.898 1.063 1.132]
  Region Acc (Val): [0.599 0.558 0.539 0.661 0.575 0.5  ]
  LR:1.00e-04 | 15.1s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0404 acc:0.5623 off1:0.9507 mae:0.4905
  Region Acc: [0.59  0.543 0.53  0.664 0.559 0.488]

[Epoch 007/100]
  Train - loss:1.0283 acc:0.5035 off1:0.8903 mae:0.6269
  Val   - loss:1.0404 acc:0.5623 off1:0.9507 mae:0.4905
  Part losses (Val): [1.001 1.062 1.099 0.896 1.053 1.132]
  Region Acc (Val): [0.59  0.543 0.53  0.664 0.559 0.488]
  LR:1.00e-04 | 14.2s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0244 acc:0.5788 off1:0.9512 mae:0.4739
  Region Acc: [0.61  0.547 0.542 0.677 0.598 0.5  ]
‚úÖ New Best! (Acc=0.5788, MAE=0.4739)

[Epoch 008/100]
  Train - loss:1.0099 acc:0.5132 off1:0.9006 mae:0.6046
  Val   - loss:1.0244 acc:0.5788 off1:0.9512 mae:0.4739
  Part losses (Val): [1.011 1.056 1.066 0.884 1.018 1.111]
  Region Acc (Val): [0.61  0.547 0.542 0.677 0.598 0.5  ]
  LR:1.00e-04 | 14.9s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0607 acc:0.5667 off1:0.9388 mae:0.4996
  Region Acc: [0.609 0.545 0.527 0.671 0.557 0.491]

[Epoch 009/100]
  Train - loss:1.0068 acc:0.5140 off1:0.8972 mae:0.6103
  Val   - loss:1.0607 acc:0.5667 off1:0.9388 mae:0.4996
  Part losses (Val): [1.01  1.117 1.112 0.904 1.088 1.133]
  Region Acc (Val): [0.609 0.545 0.527 0.671 0.557 0.491]
  LR:1.00e-04 | 13.4s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0276 acc:0.5757 off1:0.9512 mae:0.4762
  Region Acc: [0.605 0.539 0.566 0.651 0.566 0.526]

[Epoch 010/100]
  Train - loss:1.0012 acc:0.5422 off1:0.9163 mae:0.5549
  Val   - loss:1.0276 acc:0.5757 off1:0.9512 mae:0.4762
  Part losses (Val): [0.974 1.062 1.06  0.896 1.067 1.106]
  Region Acc (Val): [0.605 0.539 0.566 0.651 0.566 0.526]
  LR:1.00e-04 | 14.8s | Patience:2/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0291 acc:0.5872 off1:0.9527 mae:0.4629
  Region Acc: [0.629 0.575 0.546 0.675 0.57  0.527]
‚úÖ New Best! (Acc=0.5872, MAE=0.4629)

[Epoch 011/100]
  Train - loss:0.9737 acc:0.5294 off1:0.9037 mae:0.5845
  Val   - loss:1.0291 acc:0.5872 off1:0.9527 mae:0.4629
  Part losses (Val): [0.982 1.049 1.078 0.89  1.045 1.131]
  Region Acc (Val): [0.629 0.575 0.546 0.675 0.57  0.527]
  LR:1.00e-04 | 14.9s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0252 acc:0.5870 off1:0.9582 mae:0.4578
  Region Acc: [0.617 0.589 0.55  0.685 0.57  0.51 ]

[Epoch 012/100]
  Train - loss:0.9708 acc:0.5355 off1:0.9004 mae:0.5849
  Val   - loss:1.0252 acc:0.5870 off1:0.9582 mae:0.4578
  Part losses (Val): [0.985 1.028 1.077 0.887 1.051 1.123]
  Region Acc (Val): [0.617 0.589 0.55  0.685 0.57  0.51 ]
  LR:1.00e-04 | 14.3s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0349 acc:0.5755 off1:0.9631 mae:0.4635
  Region Acc: [0.603 0.544 0.553 0.677 0.562 0.514]

[Epoch 013/100]
  Train - loss:0.9796 acc:0.5336 off1:0.9043 mae:0.5797
  Val   - loss:1.0349 acc:0.5755 off1:0.9631 mae:0.4635
  Part losses (Val): [1.003 1.052 1.092 0.889 1.043 1.13 ]
  Region Acc (Val): [0.603 0.544 0.553 0.677 0.562 0.514]
  LR:1.00e-04 | 14.6s | Patience:2/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0449 acc:0.5810 off1:0.9518 mae:0.4702
  Region Acc: [0.601 0.571 0.553 0.68  0.568 0.513]

[Epoch 014/100]
  Train - loss:0.9560 acc:0.5643 off1:0.9207 mae:0.5293
  Val   - loss:1.0449 acc:0.5810 off1:0.9518 mae:0.4702
  Part losses (Val): [1.045 1.043 1.076 0.918 1.047 1.141]
  Region Acc (Val): [0.601 0.571 0.553 0.68  0.568 0.513]
  LR:1.00e-04 | 14.7s | Patience:3/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0410 acc:0.5789 off1:0.9594 mae:0.4636
  Region Acc: [0.606 0.561 0.55  0.663 0.569 0.523]

[Epoch 015/100]
  Train - loss:0.9371 acc:0.5409 off1:0.9002 mae:0.5806
  Val   - loss:1.0410 acc:0.5789 off1:0.9594 mae:0.4636
  Part losses (Val): [1.011 1.064 1.088 0.895 1.052 1.136]
  Region Acc (Val): [0.606 0.561 0.55  0.663 0.569 0.523]
  LR:1.00e-04 | 14.1s | Patience:4/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0541 acc:0.5789 off1:0.9492 mae:0.4753
  Region Acc: [0.623 0.555 0.541 0.691 0.561 0.503]

[Epoch 016/100]
  Train - loss:0.9381 acc:0.5517 off1:0.9065 mae:0.5589
  Val   - loss:1.0541 acc:0.5789 off1:0.9492 mae:0.4753
  Part losses (Val): [1.016 1.069 1.103 0.914 1.064 1.16 ]
  Region Acc (Val): [0.623 0.555 0.541 0.691 0.561 0.503]
  LR:1.00e-04 | 13.7s | Patience:5/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0446 acc:0.5835 off1:0.9538 mae:0.4651
  Region Acc: [0.606 0.591 0.537 0.681 0.567 0.519]

[Epoch 017/100]
  Train - loss:0.9321 acc:0.5712 off1:0.9179 mae:0.5278
  Val   - loss:1.0446 acc:0.5835 off1:0.9538 mae:0.4651
  Part losses (Val): [1.037 1.047 1.093 0.907 1.042 1.142]
  Region Acc (Val): [0.606 0.591 0.537 0.681 0.567 0.519]
  LR:5.00e-05 | 14.8s | Patience:6/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0328 acc:0.5859 off1:0.9552 mae:0.4618
  Region Acc: [0.612 0.573 0.564 0.667 0.57  0.53 ]

[Epoch 018/100]
  Train - loss:0.9015 acc:0.5732 off1:0.9140 mae:0.5310
  Val   - loss:1.0328 acc:0.5859 off1:0.9552 mae:0.4618
  Part losses (Val): [0.988 1.04  1.066 0.905 1.048 1.15 ]
  Region Acc (Val): [0.612 0.573 0.564 0.667 0.57  0.53 ]
  LR:5.00e-05 | 14.0s | Patience:7/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0522 acc:0.5870 off1:0.9556 mae:0.4607
  Region Acc: [0.609 0.582 0.561 0.678 0.58  0.512]

[Epoch 019/100]
  Train - loss:0.8709 acc:0.5857 off1:0.9109 mae:0.5229
  Val   - loss:1.0522 acc:0.5870 off1:0.9556 mae:0.4607
  Part losses (Val): [1.015 1.057 1.089 0.932 1.049 1.17 ]
  Region Acc (Val): [0.609 0.582 0.561 0.678 0.58  0.512]
  LR:5.00e-05 | 15.0s | Patience:8/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0585 acc:0.5866 off1:0.9576 mae:0.4585
  Region Acc: [0.607 0.567 0.568 0.673 0.572 0.532]

[Epoch 020/100]
  Train - loss:0.8829 acc:0.5706 off1:0.9075 mae:0.5420
  Val   - loss:1.0585 acc:0.5866 off1:0.9576 mae:0.4585
  Part losses (Val): [1.037 1.068 1.08  0.951 1.069 1.147]
  Region Acc (Val): [0.607 0.567 0.568 0.673 0.572 0.532]
  LR:5.00e-05 | 14.1s | Patience:9/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0761 acc:0.5850 off1:0.9552 mae:0.4625
  Region Acc: [0.62  0.572 0.55  0.672 0.578 0.518]

[Epoch 021/100]
  Train - loss:0.8491 acc:0.6151 off1:0.9310 mae:0.4679
  Val   - loss:1.0761 acc:0.5850 off1:0.9552 mae:0.4625
  Part losses (Val): [1.066 1.064 1.118 0.948 1.081 1.179]
  Region Acc (Val): [0.62  0.572 0.55  0.672 0.578 0.518]
  LR:5.00e-05 | 14.9s | Patience:10/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0741 acc:0.5779 off1:0.9554 mae:0.4695
  Region Acc: [0.602 0.57  0.556 0.657 0.555 0.527]

[Epoch 022/100]
  Train - loss:0.8681 acc:0.5972 off1:0.9173 mae:0.5017
  Val   - loss:1.0741 acc:0.5779 off1:0.9554 mae:0.4695
  Part losses (Val): [1.043 1.08  1.112 0.949 1.085 1.174]
  Region Acc (Val): [0.602 0.57  0.556 0.657 0.555 0.527]
  LR:5.00e-05 | 14.8s | Patience:11/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0966 acc:0.5830 off1:0.9545 mae:0.4660
  Region Acc: [0.615 0.56  0.559 0.679 0.573 0.511]

[Epoch 023/100]
  Train - loss:0.8471 acc:0.5946 off1:0.9120 mae:0.5131
  Val   - loss:1.0966 acc:0.5830 off1:0.9545 mae:0.4660
  Part losses (Val): [1.094 1.089 1.113 0.989 1.104 1.191]
  Region Acc (Val): [0.615 0.56  0.559 0.679 0.573 0.511]
  LR:2.50e-05 | 13.9s | Patience:12/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0951 acc:0.5808 off1:0.9547 mae:0.4678
  Region Acc: [0.616 0.566 0.565 0.664 0.565 0.509]

[Epoch 024/100]
  Train - loss:0.8297 acc:0.6269 off1:0.9249 mae:0.4644
  Val   - loss:1.0951 acc:0.5808 off1:0.9547 mae:0.4678
  Part losses (Val): [1.074 1.094 1.131 0.968 1.106 1.198]
  Region Acc (Val): [0.616 0.566 0.565 0.664 0.565 0.509]
  LR:2.50e-05 | 15.0s | Patience:13/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0956 acc:0.5810 off1:0.9576 mae:0.4653
  Region Acc: [0.606 0.562 0.557 0.673 0.575 0.512]

[Epoch 025/100]
  Train - loss:0.7912 acc:0.6228 off1:0.9186 mae:0.4758
  Val   - loss:1.0956 acc:0.5810 off1:0.9576 mae:0.4653
  Part losses (Val): [1.064 1.108 1.137 0.971 1.089 1.205]
  Region Acc (Val): [0.606 0.562 0.557 0.673 0.575 0.512]
  LR:2.50e-05 | 14.4s | Patience:14/15
----------------------------------------------------------------------


                                                                                             

[val] loss:1.0993 acc:0.5833 off1:0.9543 mae:0.4656
  Region Acc: [0.621 0.562 0.557 0.673 0.562 0.524]

‚èπÔ∏è Early stopping at epoch 26

üéâ Training Finished!

üìä Test evaluation with best model...


                                                                            

[test] loss:1.0082 acc:0.5868 off1:0.9646 mae:0.4521
  Region Acc: [0.616 0.521 0.596 0.671 0.555 0.562]

üèÜ Test Results:
   Accuracy: 0.5868
   Off-by-1: 0.9646
   MAE: 0.4521
   Region Acc: [0.616 0.521 0.596 0.671 0.555 0.562]
   Part losses: [0.937 1.083 1.057 0.859 1.013 1.1  ]

üíæ Best model saved: runs_severity_classification/best_efficientnet_b0_classification.pth
üìà Best Validation Accuracy: 0.5872
üìâ Best Validation MAE: 0.4629

‚úÖ Ïù¥Ï†ú gradcam_inference.pyÎ•º Ïã§ÌñâÌïòÏó¨ Í≤∞Í≥ºÎ•º ÏãúÍ∞ÅÌôîÌïòÏÑ∏Ïöî!


In [5]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import functional as TF
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

# ============================================================
# ÏÑ§Ï†ï
# ============================================================
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMG_SIZE = 224
NUM_CLASSES = 4
NUM_REGIONS = 6

MODEL_PATH = "./runs_severity_classification/best_efficientnet_b0_classification.pth"
OUTPUT_DIR = "./gradcam_results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

REGION_NAMES = ['A (Ï¢åÏÉÅ)', 'B (Ï¢åÏ§ë)', 'C (Ï¢åÌïò)', 'D (Ïö∞ÏÉÅ)', 'E (Ïö∞Ï§ë)', 'F (Ïö∞Ìïò)']
CLASS_NAMES = ['Ï†ïÏÉÅ (0)', 'Í≤ΩÏ¶ù (1)', 'Ï§ëÎì±ÎèÑ (2)', 'Ï§ëÏ¶ù (3)']

# Î∂ÄÏúÑÎ≥Ñ ÏÉâÏÉÅ (ÏãúÍ∞ÅÌôîÏö©)
REGION_COLORS = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8', '#F7DC6F']

# ============================================================
# Î™®Îç∏ Ï†ïÏùò (ÌïôÏäµ ÏΩîÎìúÏôÄ ÎèôÏùº)
# ============================================================
class EfficientNetB0Classification(nn.Module):
    def __init__(self, pretrained=False, drop=0.3, num_regions=6, num_classes=4):
        super().__init__()
        
        weights = EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = efficientnet_b0(weights=weights)
        
        self.features = backbone.features
        in_feat = 1280
        
        self.pos_embed = nn.Parameter(torch.randn(1, 49, in_feat))
        self.region_queries = nn.Parameter(torch.randn(num_regions, in_feat))
        
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=in_feat, num_heads=8, dropout=0.1, batch_first=True
        )
        self.norm1 = nn.LayerNorm(in_feat)
        
        self.self_attention = nn.MultiheadAttention(
            embed_dim=in_feat, num_heads=8, dropout=0.1, batch_first=True
        )
        self.norm2 = nn.LayerNorm(in_feat)
        self.norm3 = nn.LayerNorm(in_feat)
        
        self.ffn = nn.Sequential(
            nn.Linear(in_feat, in_feat * 2),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(in_feat * 2, in_feat),
            nn.Dropout(drop)
        )
        self.norm4 = nn.LayerNorm(in_feat)
        
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_feat, 256),
                nn.LayerNorm(256),
                nn.GELU(),
                nn.Dropout(drop),
                nn.Linear(256, 128),
                nn.LayerNorm(128),
                nn.GELU(),
                nn.Dropout(drop),
                nn.Linear(128, num_classes)
            ) for _ in range(num_regions)
        ])
    
    def forward(self, x):
        B = x.size(0)
        feat = self.features(x)
        feat = feat.flatten(2).transpose(1, 2)
        feat = feat + self.pos_embed
        
        queries = self.region_queries.unsqueeze(0).expand(B, -1, -1)
        attn_out, _ = self.cross_attention(query=queries, key=feat, value=feat)
        attn_out = self.norm1(attn_out + queries)
        
        self_attn_out, attn_weights = self.self_attention(
            query=attn_out, key=attn_out, value=attn_out
        )
        attn_out = self.norm2(attn_out + self_attn_out)
        ffn_out = self.ffn(attn_out)
        attn_out = self.norm3(attn_out + ffn_out)
        
        outputs = []
        for i in range(len(self.heads)):
            region_feat = attn_out[:, i, :]
            logits = self.heads[i](region_feat)
            outputs.append(logits)
        
        out = torch.stack(outputs, dim=1)
        return out, attn_weights

# ============================================================
# Grad-CAM Íµ¨ÌòÑ
# ============================================================
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        
        # Hook Îì±Î°ù
        self.target_layer.register_forward_hook(self.save_activation)
        self.target_layer.register_full_backward_hook(self.save_gradient)
    
    def save_activation(self, module, input, output):
        self.activations = output.detach()
    
    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()
    
    def generate_cam(self, input_image, region_idx, class_idx=None):
        """
        ÌäπÏ†ï Î∂ÄÏúÑ(region_idx)ÏôÄ ÌÅ¥ÎûòÏä§(class_idx)Ïóê ÎåÄÌïú CAM ÏÉùÏÑ±
        """
        self.model.eval()
        
        # Forward pass
        output, _ = self.model(input_image)  # [1, 6, 4]
        
        # ÌäπÏ†ï Î∂ÄÏúÑÏùò ÏòàÏ∏°
        region_output = output[0, region_idx, :]  # [4]
        
        # class_idxÍ∞Ä ÏóÜÏúºÎ©¥ ÏòàÏ∏°Îêú ÌÅ¥ÎûòÏä§ ÏÇ¨Ïö©
        if class_idx is None:
            class_idx = region_output.argmax().item()
        
        # Backward pass
        self.model.zero_grad()
        target = region_output[class_idx]
        target.backward(retain_graph=True)
        
        # Grad-CAM Í≥ÑÏÇ∞
        gradients = self.gradients[0]  # [C, H, W]
        activations = self.activations[0]  # [C, H, W]
        
        # Global Average Pooling on gradients
        weights = gradients.mean(dim=(1, 2), keepdim=True)  # [C, 1, 1]
        
        # Weighted combination
        cam = (weights * activations).sum(dim=0)  # [H, W]
        cam = F.relu(cam)
        
        # Normalize
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        
        return cam.cpu().numpy(), class_idx

# ============================================================
# ÏãúÍ∞ÅÌôî Ìï®Ïàò
# ============================================================
def preprocess_image(image_path, img_size=IMG_SIZE):
    """Ïù¥ÎØ∏ÏßÄ Ï†ÑÏ≤òÎ¶¨"""
    img = Image.open(image_path).convert('RGB')
    img_resized = img.resize((img_size, img_size), Image.BILINEAR)
    
    # Tensor Î≥ÄÌôò
    img_tensor = TF.to_tensor(img_resized)
    img_tensor = TF.normalize(img_tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    
    return img_tensor.unsqueeze(0), img_resized

def visualize_gradcam(original_img, cam, region_name, pred_class, confidence, 
                      save_path=None, alpha=0.4):
    """Grad-CAM ÌûàÌä∏Îßµ ÏãúÍ∞ÅÌôî"""
    # CAMÏùÑ ÏõêÎ≥∏ Ïù¥ÎØ∏ÏßÄ ÌÅ¨Í∏∞Î°ú Î¶¨ÏÇ¨Ïù¥Ï¶à
    cam_resized = np.array(Image.fromarray(cam).resize(
        original_img.size, Image.BILINEAR
    ))
    
    # ÌûàÌä∏Îßµ ÏÉùÏÑ±
    plt.figure(figsize=(10, 8))
    
    # ÏõêÎ≥∏ Ïù¥ÎØ∏ÏßÄ
    plt.subplot(1, 2, 1)
    plt.imshow(original_img)
    plt.title(f'{region_name}\nÏòàÏ∏°: {CLASS_NAMES[pred_class]} ({confidence:.1f}%)', 
              fontsize=12, fontweight='bold')
    plt.axis('off')
    
    # Grad-CAM Ïò§Î≤ÑÎ†àÏù¥
    plt.subplot(1, 2, 2)
    plt.imshow(original_img)
    plt.imshow(cam_resized, cmap='jet', alpha=alpha)
    plt.title(f'Grad-CAM Heatmap', fontsize=12, fontweight='bold')
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.axis('off')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"‚úÖ Ï†ÄÏû•: {save_path}")
    
    plt.close()

def visualize_all_regions(original_img, model, img_tensor, predictions, 
                          confidences, save_path=None):
    """6Í∞ú Î∂ÄÏúÑ Î™®Îëê ÏãúÍ∞ÅÌôî (2x3 grid)"""
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()
    
    # Grad-CAM ÏÉùÏÑ±
    target_layer = model.features[-1]  # EfficientNet ÎßàÏßÄÎßâ Conv layer
    grad_cam = GradCAM(model, target_layer)
    
    for idx in range(6):
        cam, pred_class = grad_cam.generate_cam(img_tensor, region_idx=idx)
        
        # CAM Î¶¨ÏÇ¨Ïù¥Ï¶à
        cam_resized = np.array(Image.fromarray(cam).resize(
            original_img.size, Image.BILINEAR
        ))
        
        # ÏãúÍ∞ÅÌôî
        axes[idx].imshow(original_img)
        im = axes[idx].imshow(cam_resized, cmap='jet', alpha=0.4)
        
        axes[idx].set_title(
            f'{REGION_NAMES[idx]}\n{CLASS_NAMES[pred_class]} ({confidences[idx]:.1f}%)',
            fontsize=12, fontweight='bold', color=REGION_COLORS[idx]
        )
        axes[idx].axis('off')
        
        # Colorbar
        cbar = plt.colorbar(im, ax=axes[idx], fraction=0.046, pad=0.04)
        cbar.ax.tick_params(labelsize=8)
    
    # Ï†ÑÏ≤¥ Brixia Score ÌëúÏãú
    total_score = sum(predictions)
    fig.suptitle(
        f'Brixia COVID-19 Severity Score: {total_score}/18\n'
        f'Í∞Å Î∂ÄÏúÑÎ≥Ñ Grad-CAM ÌûàÌä∏Îßµ',
        fontsize=16, fontweight='bold'
    )
    
    plt.tight_layout(rect=[0, 0.03, 1, 0.97])
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"‚úÖ Ï†ÑÏ≤¥ ÏãúÍ∞ÅÌôî Ï†ÄÏû•: {save_path}")
    
    plt.close()

def visualize_prediction_summary(predictions, confidences, save_path=None):
    """ÏòàÏ∏° Í≤∞Í≥º ÏöîÏïΩ Ï∞®Ìä∏"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # 1. Î∂ÄÏúÑÎ≥Ñ Ï§ëÏ¶ùÎèÑ ÎßâÎåÄ Í∑∏ÎûòÌîÑ
    x = np.arange(6)
    bars = ax1.bar(x, predictions, color=REGION_COLORS, alpha=0.7, edgecolor='black')
    ax1.set_xticks(x)
    ax1.set_xticklabels(REGION_NAMES, rotation=45, ha='right')
    ax1.set_ylabel('Ï§ëÏ¶ùÎèÑ Ï†êÏàò', fontsize=12, fontweight='bold')
    ax1.set_ylim(0, 3.5)
    ax1.set_title('Î∂ÄÏúÑÎ≥Ñ Ï§ëÏ¶ùÎèÑ Ï†êÏàò', fontsize=14, fontweight='bold')
    ax1.grid(axis='y', alpha=0.3)
    
    # Í∞Å ÎßâÎåÄ ÏúÑÏóê Í∞í ÌëúÏãú
    for i, (bar, conf) in enumerate(zip(bars, confidences)):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                f'{int(height)}\n({conf:.0f}%)',
                ha='center', va='bottom', fontsize=10, fontweight='bold')
    
    # 2. Ï¢åÏö∞ ÎπÑÍµê
    left_scores = predictions[:3]  # A, B, C
    right_scores = predictions[3:]  # D, E, F
    
    categories = ['ÏÉÅÎ∂Ä', 'Ï§ëÎ∂Ä', 'ÌïòÎ∂Ä']
    x2 = np.arange(len(categories))
    width = 0.35
    
    ax2.bar(x2 - width/2, left_scores, width, label='Ï¢åÌèê', 
            color='#FF6B6B', alpha=0.7, edgecolor='black')
    ax2.bar(x2 + width/2, right_scores, width, label='Ïö∞Ìèê', 
            color='#4ECDC4', alpha=0.7, edgecolor='black')
    
    ax2.set_xticks(x2)
    ax2.set_xticklabels(categories)
    ax2.set_ylabel('Ï§ëÏ¶ùÎèÑ Ï†êÏàò', fontsize=12, fontweight='bold')
    ax2.set_ylim(0, 3.5)
    ax2.set_title('Ï¢åÏö∞ Ìèê ÎπÑÍµê', fontsize=14, fontweight='bold')
    ax2.legend(fontsize=11)
    ax2.grid(axis='y', alpha=0.3)
    
    # Ï¥ùÏ†ê ÌëúÏãú
    total_score = sum(predictions)
    fig.suptitle(f'Ï†ÑÏ≤¥ Brixia Score: {total_score}/18', 
                 fontsize=16, fontweight='bold')
    
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"‚úÖ ÏöîÏïΩ Ï∞®Ìä∏ Ï†ÄÏû•: {save_path}")
    
    plt.close()

# ============================================================
# Ï∂îÎ°† Ìï®Ïàò
# ============================================================
# @torch.no_grad()
def predict_single_image(model, image_path, visualize=True):
    """Îã®Ïùº Ïù¥ÎØ∏ÏßÄ Ï∂îÎ°† Î∞è ÏãúÍ∞ÅÌôî"""
    model.eval()

    with torch.no_grad():
        # Ïù¥ÎØ∏ÏßÄ Î°úÎìú Î∞è Ï†ÑÏ≤òÎ¶¨
        img_tensor, original_img = preprocess_image(image_path)
        img_tensor = img_tensor.to(DEVICE)
        
        # ÏòàÏ∏°
        output, attn_weights = model(img_tensor)  # [1, 6, 4]
        probs = F.softmax(output, dim=-1)  # [1, 6, 4]
        
        predictions = output[0].argmax(dim=-1).cpu().numpy()  # [6]
        confidences = probs[0].max(dim=-1).values.detach().cpu().numpy() * 100  # [6]
    
    # Í≤∞Í≥º Ï∂úÎ†•
    print("\n" + "="*60)
    print(f"üìä ÏòàÏ∏° Í≤∞Í≥º: {os.path.basename(image_path)}")
    print("="*60)
    
    for i, (name, pred, conf) in enumerate(zip(REGION_NAMES, predictions, confidences)):
        print(f"{name}: {CLASS_NAMES[pred]:12s} (Ïã†Î¢∞ÎèÑ: {conf:5.1f}%)")
    
    total_score = predictions.sum()
    print(f"\nÏ†ÑÏ≤¥ Brixia Score: {total_score}/18")
    print("="*60)
    
    # ÏãúÍ∞ÅÌôî
    if visualize:
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        
        # 1. Ï†ÑÏ≤¥ 6Í∞ú Î∂ÄÏúÑ Grad-CAM
        all_regions_path = os.path.join(OUTPUT_DIR, f"{base_name}_all_regions.png")
        visualize_all_regions(original_img, model, img_tensor, 
                             predictions, confidences, all_regions_path)
        
        # 2. ÏòàÏ∏° ÏöîÏïΩ Ï∞®Ìä∏
        summary_path = os.path.join(OUTPUT_DIR, f"{base_name}_summary.png")
        visualize_prediction_summary(predictions, confidences, summary_path)
        
        # 3. Í∞úÎ≥Ñ Î∂ÄÏúÑ Grad-CAM (ÏÑ†ÌÉùÏ†Å)
        # for idx in range(6):
        #     single_path = os.path.join(OUTPUT_DIR, f"{base_name}_{REGION_NAMES[idx]}.png")
        #     target_layer = model.features[-1]
        #     grad_cam = GradCAM(model, target_layer)
        #     cam, pred_class = grad_cam.generate_cam(img_tensor, region_idx=idx)
        #     visualize_gradcam(original_img, cam, REGION_NAMES[idx], 
        #                      pred_class, confidences[idx], single_path)
    
    return predictions, confidences, attn_weights

# ============================================================
# Î©îÏù∏ Ïã§Ìñâ
# ============================================================
def main():
    print("\n" + "="*70)
    print("üî¨ Brixia COVID-19 Grad-CAM ÏãúÍ∞ÅÌôî")
    print("="*70)
    
    # Î™®Îç∏ Î°úÎìú
    print(f"\nüì¶ Î™®Îç∏ Î°úÎî©: {MODEL_PATH}")
    model = EfficientNetB0Classification(
        pretrained=False,
        drop=0.3,
        num_regions=NUM_REGIONS,
        num_classes=NUM_CLASSES
    ).to(DEVICE)
    
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print("‚úÖ Î™®Îç∏ Î°úÎìú ÏôÑÎ£å!")
    
    # ÌÖåÏä§Ìä∏ Ïù¥ÎØ∏ÏßÄ Í≤ΩÎ°ú (ÏòàÏãú)
    # Ïã§Ï†ú ÏÇ¨Ïö© Ïãú Í≤ΩÎ°ú ÏàòÏ†ï ÌïÑÏöî
    test_image_path = "./data/covid19-xray-severity-scoring/segmented_png/5032497707401895.png"
    
    # Ïù¥ÎØ∏ÏßÄÍ∞Ä Ï°¥Ïû¨ÌïòÎäîÏßÄ ÌôïÏù∏
    if not os.path.exists(test_image_path):
        print(f"\n‚ö†Ô∏è ÌÖåÏä§Ìä∏ Ïù¥ÎØ∏ÏßÄÎ•º Ï∞æÏùÑ Ïàò ÏóÜÏäµÎãàÎã§: {test_image_path}")
        print("üìù ÏÇ¨Ïö©Î≤ï:")
        print("   1. test_image_path Î≥ÄÏàòÎ•º Ïã§Ï†ú Ïù¥ÎØ∏ÏßÄ Í≤ΩÎ°úÎ°ú ÏàòÏ†ï")
        print("   2. ÎòêÎäî ÏïÑÎûò Ìï®ÏàòÎ•º ÏßÅÏ†ë Ìò∏Ï∂ú:")
        print("      predict_single_image(model, 'your_image_path.png')")
        return
    
    # Ï∂îÎ°† Î∞è ÏãúÍ∞ÅÌôî
    predictions, confidences, attn_weights = predict_single_image(
        model, test_image_path, visualize=True
    )
    
    print(f"\n‚úÖ Í≤∞Í≥ºÍ∞Ä {OUTPUT_DIR}/ Ìè¥ÎçîÏóê Ï†ÄÏû•ÎêòÏóàÏäµÎãàÎã§!")
    print("="*70)

if __name__ == "__main__":
    main()


üî¨ Brixia COVID-19 Grad-CAM ÏãúÍ∞ÅÌôî

üì¶ Î™®Îç∏ Î°úÎî©: ./runs_severity_classification/best_efficientnet_b0_classification.pth
‚úÖ Î™®Îç∏ Î°úÎìú ÏôÑÎ£å!

üìä ÏòàÏ∏° Í≤∞Í≥º: 5032497707401895.png
A (Ï¢åÏÉÅ): Ï†ïÏÉÅ (0)       (Ïã†Î¢∞ÎèÑ:  65.3%)
B (Ï¢åÏ§ë): Í≤ΩÏ¶ù (1)       (Ïã†Î¢∞ÎèÑ:  51.2%)
C (Ï¢åÌïò): Ï§ëÎì±ÎèÑ (2)      (Ïã†Î¢∞ÎèÑ:  56.5%)
D (Ïö∞ÏÉÅ): Ï†ïÏÉÅ (0)       (Ïã†Î¢∞ÎèÑ:  56.0%)
E (Ïö∞Ï§ë): Ï§ëÎì±ÎèÑ (2)      (Ïã†Î¢∞ÎèÑ:  50.0%)
F (Ïö∞Ìïò): Ï§ëÏ¶ù (3)       (Ïã†Î¢∞ÎèÑ:  66.3%)

Ï†ÑÏ≤¥ Brixia Score: 8/18


  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.tight_layout(rect=[0, 0.03, 1, 0.97])
  plt.savefig(save_path, dpi=150, bbox_inches='tight')
  plt.savefig(save_path, dpi=150, bbox_inches='tight')
  plt.savefig(save_path, dpi=150, bbox_inches='tight')
  plt.savefig(save_path, dpi=150, bbox_inches='tight')
  plt.savefig(save_path, dpi=150

‚úÖ Ï†ÑÏ≤¥ ÏãúÍ∞ÅÌôî Ï†ÄÏû•: ./gradcam_results/5032497707401895_all_regions.png


  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.tight_layout(rect=[0, 0.03, 1, 0.95])
  plt.savefig(save_path, dpi=150, bbox_inches='tight')
  plt.savefig(save_path, dpi=150, bbox_inches='tight')
  plt.savefig(save_path, dpi=150, bbox_inches='tight')
  plt.savefig(save_path, dpi=150, bbox_inches='tight')
  plt.savefig(save_path, dpi=150

‚úÖ ÏöîÏïΩ Ï∞®Ìä∏ Ï†ÄÏû•: ./gradcam_results/5032497707401895_summary.png

‚úÖ Í≤∞Í≥ºÍ∞Ä ./gradcam_results/ Ìè¥ÎçîÏóê Ï†ÄÏû•ÎêòÏóàÏäµÎãàÎã§!


In [6]:
import os, random, time
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as TF
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

from tqdm import tqdm
from torch.amp import autocast
from torch.cuda.amp import autocast, GradScaler

# --- Í≤ΩÎ°ú ÏÑ§Ï†ï Î∞è ÌïòÏù¥ÌçºÌååÎùºÎØ∏ÌÑ∞ ---
BASE_DIR     = f"./data/covid19-xray-severity-scoring/"
CSV_PATH     = str(Path(BASE_DIR) / "Brixia.csv")
IMAGE_DIR    = str(Path(BASE_DIR) / "segmented_png")

OUT_DIR      = "./runs_severity_classification"
BEST_PATH    = str(Path(OUT_DIR) / "best_efficientnet_b0_classification.pth")
os.makedirs(OUT_DIR, exist_ok=True)

DEVICE       = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED         = 42
IMG_SIZE     = 224
BATCH_SIZE   = 32
NUM_CLASSES  = 4   # 0, 1, 2, 3
EPOCHS       = 100  # Single phase training
LR           = 1e-4
WEIGHT_DECAY = 5e-4
AMP          = True
EARLY_STOP_ACC = 0.75  # üîÑ MAE ‚Üí Accuracy
DROP_RATIO   = 0.3
AUG_RATIO    = 0.5
MIXUP_ALPHA  = 0.2
LABEL_SMOOTHING = 0.1  # ‚úÖ NEW: Label smoothing

# --- ÏãúÎìú Í≥†Ï†ï ---
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
set_seed(SEED)

def make_transform_with_label(train: bool, img_size: int = IMG_SIZE, aug_ratio=AUG_RATIO):
    """Brixia ScoreÏùò Ï¢åÏö∞ Íµ¨Ï°∞Î•º Í≥†Î†§Ìïú transform"""
    def _tfm(img: Image.Image, label: torch.Tensor = None):
        img = img.convert('RGB')
        img = TF.resize(
            img, 
            [img_size, img_size], 
            interpolation=TF.InterpolationMode.BILINEAR,
            antialias=True
        )
        
        if train:
            # 1. Horizontal Flip (Ï¢åÏö∞ Î∞òÏ†Ñ: ABC ‚Üî DEF)
            if random.random() < aug_ratio:
                img = TF.hflip(img)
                if label is not None:
                    # [A, B, C, D, E, F] ‚Üí [D, E, F, A, B, C]
                    label = label[[3, 4, 5, 0, 1, 2]]
            
            # 2. ÏïΩÌïú ÌöåÏ†Ñ (¬±5ÎèÑ)
            if random.random() < aug_ratio:
                angle = float(torch.empty(1).uniform_(-5, 5))
                img = TF.rotate(
                    img, 
                    angle, 
                    interpolation=TF.InterpolationMode.BILINEAR,
                    fill=0
                )
            
            # 3. ÏïΩÌïú Translation
            if random.random() < aug_ratio:
                max_dx = 0.05 * img_size
                max_dy = 0.05 * img_size
                translations = (
                    float(torch.empty(1).uniform_(-max_dx, max_dx)),
                    float(torch.empty(1).uniform_(-max_dy, max_dy))
                )
                img = TF.affine(
                    img,
                    angle=0,
                    translate=translations,
                    scale=1.0,
                    shear=0,
                    interpolation=TF.InterpolationMode.BILINEAR,
                    fill=0
                )
            
            # 4. Brightness & Contrast
            if random.random() < aug_ratio:
                brightness_factor = float(torch.empty(1).uniform_(0.9, 1.1))
                img = TF.adjust_brightness(img, brightness_factor)
            
            if random.random() < aug_ratio:
                contrast_factor = float(torch.empty(1).uniform_(0.9, 1.1))
                img = TF.adjust_contrast(img, contrast_factor)
            
            # 5. Gamma Correction
            if random.random() < 0.3:
                gamma = float(torch.empty(1).uniform_(0.9, 1.1))
                img = TF.adjust_gamma(img, gamma)
        
        # Tensor Î≥ÄÌôò
        img = TF.to_tensor(img)
        
        # Gaussian Noise (train only)
        if train and random.random() < 0.2:
            noise = torch.randn_like(img) * 0.01
            img = img + noise
            img = torch.clamp(img, 0, 1)
        
        # Ï†ïÍ∑úÌôî
        img = TF.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        if label is not None:
            return img, label
        return img
    
    return _tfm

def load_and_split_brixia(csv_path, val_ratio=0.2, seed=SEED):
    df = pd.read_csv(csv_path, dtype={'BrixiaScore': str})
    df = df.dropna(subset=['BrixiaScore'])
    df = df[df['BrixiaScore'] != 'nan']
    df = df[df['BrixiaScore'].str.len() == 6].copy()
    
    print(f"Ïú†Ìö®Ìïú Îç∞Ïù¥ÌÑ∞: {len(df)}Í∞ú")
    
    if 'ConsensusTestset' in df.columns:
        test_df = df[df['ConsensusTestset'] == 1].copy()
        train_val_df = df[df['ConsensusTestset'] == 0].copy()
    else:
        test_df = pd.DataFrame()
        train_val_df = df.copy()
    
    gss = GroupShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
    train_idx, val_idx = next(gss.split(
        train_val_df, 
        groups=train_val_df['StudyId']
    ))
    
    tr_df = train_val_df.iloc[train_idx].copy()
    val_df = train_val_df.iloc[val_idx].copy()
    
    print(f"Train: {len(tr_df)}, Val: {len(val_df)}, Test: {len(test_df)}")
    validate_split(tr_df, val_df, test_df)
    
    return tr_df, val_df, test_df

def validate_split(tr_df, val_df, tt_df):
    train_studies = set(tr_df['StudyId'])
    val_studies = set(val_df['StudyId'])
    test_studies = set(tt_df['StudyId']) if len(tt_df) > 0 else set()
    
    assert len(train_studies & val_studies) == 0, "Train-Val Í∞Ñ StudyId Ï§ëÎ≥µ!"
    assert len(train_studies & test_studies) == 0, "Train-Test Í∞Ñ StudyId Ï§ëÎ≥µ!"
    assert len(val_studies & test_studies) == 0, "Val-Test Í∞Ñ StudyId Ï§ëÎ≥µ!"
    
    for name, data in [('Train', tr_df), ('Val', val_df), ('Test', tt_df)]:
        if len(data) > 0:
            scores = data['BrixiaScore'].apply(lambda x: sum(int(c) for c in x))
            print(f"{name} - Mean: {scores.mean():.2f}, Std: {scores.std():.2f}")
    
    return True

# ============================================================
# Dataset
# ============================================================
class BrixiaDataset(Dataset):
    def __init__(self, dataframe, img_dir, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.img_col = "Filename"
        self.label_col = "BrixiaScore"
        self._validate_data()
    
    def _validate_data(self):
        assert self.img_col in self.df.columns
        assert self.label_col in self.df.columns
        
        invalid_scores = self.df[self.df[self.label_col].str.len() != 6]
        if len(invalid_scores) > 0:
            print(f"‚ö†Ô∏è Í≤ΩÍ≥†: {len(invalid_scores)}Í∞úÏùò ÏûòÎ™ªÎêú BrixiaScore Î∞úÍ≤¨")
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        img_name_from_csv = row[self.img_col]
        img_name = img_name_from_csv.replace('.dcm', '.png')
        img_path = os.path.join(self.img_dir, img_name)
        
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"‚ùå Ïù¥ÎØ∏ÏßÄ Î°úÎìú Ïò§Î•ò: {img_path}")
            raise
        
        scores_str = row[self.label_col]
        scores_list = [int(c) for c in scores_str]
        labels = torch.tensor(scores_list, dtype=torch.long)  # üîÑ longÏúºÎ°ú Î≥ÄÍ≤Ω
        
        if self.transform:
            image, labels = self.transform(image, labels)
        else:
            image = TF.to_tensor(image)
            image = TF.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        return image, labels

def create_dataloaders(tr_df, val_df, tt_df, img_dir, 
                       batch_size=32, img_size=224, num_workers=4):
    train_transform = make_transform_with_label(train=True, img_size=img_size)
    val_transform = make_transform_with_label(train=False, img_size=img_size)
    
    tr_ds = BrixiaDataset(tr_df, img_dir, transform=train_transform)
    val_ds = BrixiaDataset(val_df, img_dir, transform=val_transform)
    tt_ds = BrixiaDataset(tt_df, img_dir, transform=val_transform)
    
    tr_loader = DataLoader(
        tr_ds, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=num_workers, 
        pin_memory=torch.cuda.is_available(),
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_ds, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=num_workers, 
        pin_memory=torch.cuda.is_available()
    )
    
    tt_loader = DataLoader(
        tt_ds, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=num_workers, 
        pin_memory=torch.cuda.is_available()
    )
    
    print("‚úÖ DataLoader Ï§ÄÎπÑ ÏôÑÎ£å")
    print(f"   Train: {len(tr_ds)} samples, {len(tr_loader)} batches")
    print(f"   Val:   {len(val_ds)} samples, {len(val_loader)} batches")
    print(f"   Test:  {len(tt_ds)} samples, {len(tt_loader)} batches")
    
    return tr_loader, val_loader, tt_loader

# ============================================================
# Loss Function - Ordinal Classification
# ============================================================
def calculate_class_weights(labels, num_classes=4, method='sqrt_inverse'):
    """ÌÅ¥ÎûòÏä§ Î∂àÍ∑†Ìòï Ìï¥Í≤∞ÏùÑ ÏúÑÌïú Í∞ÄÏ§ëÏπò Í≥ÑÏÇ∞"""
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    
    labels_flat = labels.flatten()
    counts = np.bincount(labels_flat.astype(int), minlength=num_classes)
    
    if method == 'sqrt_inverse':
        weights = 1.0 / (np.sqrt(counts) + 1e-6)
    elif method == 'inverse':
        weights = 1.0 / (counts + 1e-6)
    else:
        total = len(labels_flat)
        weights = total / (num_classes * (counts + 1e-6))
    
    weights = weights / weights.mean()
    return torch.FloatTensor(weights)

def print_class_distribution(labels):
    """ÌÅ¥ÎûòÏä§ Î∂ÑÌè¨ ÏãúÍ∞ÅÌôî"""
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    
    print("=" * 60)
    print("Class Distribution Analysis")
    print("=" * 60)
    
    region_names = ['A', 'B', 'C', 'D', 'E', 'F']
    for idx, name in enumerate(region_names):
        region_labels = labels[:, idx]
        counts = np.bincount(region_labels.astype(int), minlength=4)
        total = counts.sum()
        
        print(f"\n{name}:")
        for cls in range(4):
            pct = 100 * counts[cls] / total if total > 0 else 0
            bar = '‚ñà' * int(pct / 2)
            print(f"  Class {cls}: {counts[cls]:4d} ({pct:5.1f}%) {bar}")
    
    print("=" * 60)

class OrdinalRegressionLoss(nn.Module):
    """
    ‚úÖ Ordinal Regression Loss (ÏàúÏÑúÌòï ÌöåÍ∑Ä)
    0 < 1 < 2 < 3Ïùò ÏàúÏÑúÎ•º Î™ÖÏãúÏ†ÅÏúºÎ°ú ÌïôÏäµ
    """
    def __init__(self, num_classes=4):
        super().__init__()
        self.num_classes = num_classes
        
    def forward(self, logits, target):
        """
        logits: [B*6, num_classes-1] - cumulative logits
        target: [B*6] - class labels (0~3)
        """
        # Cumulative labels ÏÉùÏÑ±
        # Class 0: [0, 0, 0]
        # Class 1: [1, 0, 0]
        # Class 2: [1, 1, 0]
        # Class 3: [1, 1, 1]
        batch_size = target.size(0)
        target_expanded = target.unsqueeze(1)  # [B*6, 1]
        
        # [0, 1, 2, ..., num_classes-2]
        thresholds = torch.arange(self.num_classes - 1).to(target.device)
        thresholds = thresholds.unsqueeze(0).expand(batch_size, -1)  # [B*6, 3]
        
        # target > thresholdÏù¥Î©¥ 1, ÏïÑÎãàÎ©¥ 0
        cumulative_target = (target_expanded > thresholds).float()  # [B*6, 3]
        
        # Binary cross entropy for each threshold
        loss = F.binary_cross_entropy_with_logits(
            logits, cumulative_target, reduction='none'
        )
        
        return loss.mean()

class AdaptiveOrdinalLoss(nn.Module):
    """
    ‚úÖ Ordinal Loss + Class Weights + Part Weights
    ÏàúÏÑúÎ•º Í≥†Î†§ÌïòÎ©¥ÏÑú Î∂àÍ∑†ÌòïÎèÑ Ìï¥Í≤∞
    """
    
    def __init__(self, train_labels, num_classes=4, use_class_weights=True, 
                 part_weights=None, ordinal_weight=0.5):
        super().__init__()
        self.num_classes = num_classes
        self.ordinal_weight = ordinal_weight  # Ordinal lossÏôÄ CE lossÏùò ÎπÑÏú®
        
        if isinstance(train_labels, torch.Tensor):
            train_labels_np = train_labels.cpu().numpy()
        else:
            train_labels_np = train_labels
        
        print_class_distribution(train_labels_np)
        
        # ÌÅ¥ÎûòÏä§ Í∞ÄÏ§ëÏπò
        if use_class_weights:
            class_weights = calculate_class_weights(train_labels_np, num_classes=num_classes)
            self.register_buffer('class_weights', class_weights)
            print(f"‚úÖ Class weights: {class_weights.numpy()}")
        else:
            self.class_weights = None
        
        # Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò
        if part_weights is None:
            self.part_weights = torch.ones(6)
        else:
            self.part_weights = torch.tensor(part_weights, dtype=torch.float32)
        self.register_buffer('part_weights_buf', self.part_weights)
        print(f"‚úÖ Part weights: {self.part_weights.numpy()}")
        print(f"‚úÖ Ordinal weight: {ordinal_weight}")
        
        # Ordinal loss
        self.ordinal_loss = OrdinalRegressionLoss(num_classes)

    def forward(self, pred_dict, target, use_mixup=False):
        """
        pred_dict: {'logits': [B, 6, 4], 'ordinal_logits': [B, 6, 3]}
        target: [B, 6]
        """
        pred_logits = pred_dict['logits']  # [B, 6, 4]
        ordinal_logits = pred_dict['ordinal_logits']  # [B, 6, 3]
        
        B, num_regions, num_classes = pred_logits.shape
        
        # Reshape
        pred_logits_flat = pred_logits.view(B * num_regions, num_classes)
        ordinal_logits_flat = ordinal_logits.view(B * num_regions, num_classes - 1)
        target_flat = target.view(B * num_regions)
        
        # 1. CrossEntropyLoss (Í∏∞Î≥∏ Î∂ÑÎ•ò)
        if self.class_weights is not None and not use_mixup:
            ce_loss = F.cross_entropy(
                pred_logits_flat, 
                target_flat,
                weight=self.class_weights.to(pred_logits.device),
                reduction='none'
            )
        else:
            ce_loss = F.cross_entropy(
                pred_logits_flat, 
                target_flat,
                reduction='none'
            )
        
        # 2. Ordinal Loss (ÏàúÏÑú ÌïôÏäµ)
        ord_loss = self.ordinal_loss(ordinal_logits_flat, target_flat)
        
        # 3. Í≤∞Ìï©
        total_loss_flat = (1 - self.ordinal_weight) * ce_loss + self.ordinal_weight * ord_loss
        total_loss = total_loss_flat.view(B, num_regions)
        
        # Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò Ï†ÅÏö©
        part_weights = self.part_weights_buf.to(pred_logits.device)
        total_loss = total_loss * part_weights.unsqueeze(0)
        
        part_losses = total_loss.mean(dim=0)
        
        return total_loss.mean(), part_losses

# ============================================================
# Mixup (Î∂ÑÎ•òÏö©)
# ============================================================
def mixup_data_classification(x, y, alpha=MIXUP_ALPHA):
    """
    ‚úÖ NEW: Î∂ÑÎ•òÏö© Mixup
    yÎäî one-hotÏúºÎ°ú Î≥ÄÌôò ÌõÑ mixup
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    
    # One-hot encoding
    y_onehot = F.one_hot(y, num_classes=NUM_CLASSES).float()  # [B, 6, 4]
    y_onehot_shuffled = y_onehot[index]
    
    mixed_y = lam * y_onehot + (1 - lam) * y_onehot_shuffled  # [B, 6, 4]
    
    return mixed_x, mixed_y, lam

def mixup_criterion(criterion, pred_dict, y_mixed, lam):
    """MixupÏùÑ ÏúÑÌïú ÏÜêÏã§ Í≥ÑÏÇ∞ (Ordinal Loss ÏßÄÏõê)"""
    pred_logits = pred_dict['logits']
    ordinal_logits = pred_dict['ordinal_logits']
    
    B, num_regions, num_classes = pred_logits.shape
    
    # Reshape
    pred_flat = pred_logits.view(B * num_regions, num_classes)
    ordinal_flat = ordinal_logits.view(B * num_regions, num_classes - 1)
    target_flat = y_mixed.view(B * num_regions, num_classes)
    
    # Soft target loss (CE part)
    log_probs = F.log_softmax(pred_flat, dim=1)
    ce_loss = -(target_flat * log_probs).sum(dim=1)
    
    # Ordinal partÎäî mixupÏóêÏÑú skip (hard labelÎßå ÏÇ¨Ïö©)
    loss = ce_loss  # Mixup ÏãúÏóêÎäî ordinal loss Ï†úÏô∏
    loss = loss.view(B, num_regions)
    
    # Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò Ï†ÅÏö©
    if hasattr(criterion, 'part_weights_buf'):
        part_weights = criterion.part_weights_buf.to(pred_logits.device)
        loss = loss * part_weights.unsqueeze(0)
    
    part_losses = loss.mean(dim=0)
    return loss.mean(), part_losses

# ============================================================
# Metrics (Î∂ÑÎ•òÏö©)
# ============================================================
@torch.no_grad()
def calculate_classification_metrics(pred_dict, labels):
    """
    ‚úÖ Î∂ÑÎ•ò ÏßÄÌëú Í≥ÑÏÇ∞
    pred_dict: {'logits': [B, 6, 4], 'ordinal_logits': [B, 6, 3]}
    labels: [B, 6]
    """
    pred_logits = pred_dict['logits']
    
    # ÏòàÏ∏° ÌÅ¥ÎûòÏä§
    preds = pred_logits.argmax(dim=-1)  # [B, 6]
    
    # Exact match accuracy
    exact_acc = (preds == labels).float().mean().item()
    
    # Off-by-1 accuracy (Ïù∏Ï†ë ÌÅ¥ÎûòÏä§ ÌóàÏö©)
    off_by_1 = (torch.abs(preds - labels) <= 1).float().mean().item()
    
    # Per-region accuracy
    region_acc = (preds == labels).float().mean(dim=0)  # [6]
    
    # MAE (Ï∞∏Í≥†Ïö©)
    mae = torch.abs(preds.float() - labels.float()).mean().item()
    
    return exact_acc, off_by_1, mae, region_acc

# ============================================================
# Model (Î∂ÑÎ•òÏö©ÏúºÎ°ú Î≥ÄÍ≤Ω)
# ============================================================
class EfficientNetB0Classification(nn.Module):
    """
    ‚úÖ Ordinal Classification Î™®Îç∏
    - Í∞Å Î∂ÄÏúÑ(A~F)ÎßàÎã§ 4Í∞ú ÌÅ¥ÎûòÏä§(0~3) ÏòàÏ∏°
    - Self-AttentionÏúºÎ°ú Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ ÌïôÏäµ
    - Ordinal Regression Head Ï∂îÍ∞Ä (ÏàúÏÑú ÌïôÏäµ)
    """
    def __init__(self, pretrained=True, drop=0.3, num_regions=6, num_classes=4):
        super().__init__()
        
        weights = EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = efficientnet_b0(weights=weights)
        
        self.features = backbone.features
        in_feat = 1280
        
        # Positional embedding
        self.pos_embed = nn.Parameter(torch.randn(1, 49, in_feat))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        
        # Region queries
        self.region_queries = nn.Parameter(torch.randn(num_regions, in_feat))
        nn.init.xavier_uniform_(self.region_queries)
        
        # Cross attention (Ïù¥ÎØ∏ÏßÄ ÌäπÏßï ‚Üí Î∂ÄÏúÑÎ≥Ñ ÌäπÏßï)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=in_feat,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        
        self.norm1 = nn.LayerNorm(in_feat)
        
        # Self-Attention (Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ ÌïôÏäµ)
        self.self_attention = nn.MultiheadAttention(
            embed_dim=in_feat,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        
        self.norm2 = nn.LayerNorm(in_feat)
        self.norm3 = nn.LayerNorm(in_feat)
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(in_feat, in_feat * 2),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(in_feat * 2, in_feat),
            nn.Dropout(drop)
        )
        
        self.norm4 = nn.LayerNorm(in_feat)
        
        # Shared feature extractor
        self.shared_fc = nn.Sequential(
            nn.Linear(in_feat, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(drop)
        )
        
        # ‚úÖ Classification heads (ÏùºÎ∞ò Î∂ÑÎ•ò)
        self.classification_heads = nn.ModuleList([
            nn.Linear(128, num_classes) for _ in range(num_regions)
        ])
        
        # ‚úÖ Ordinal Regression heads (ÏàúÏÑúÌòï ÌöåÍ∑Ä)
        # num_classes-1 Í∞úÏùò thresholdÎ•º ÌïôÏäµ
        self.ordinal_heads = nn.ModuleList([
            nn.Linear(128, num_classes - 1) for _ in range(num_regions)
        ])
    
    def forward(self, x):
        B = x.size(0)
        
        # Feature extraction
        feat = self.features(x)  # [B, 1280, 7, 7]
        feat = feat.flatten(2).transpose(1, 2)  # [B, 49, 1280]
        feat = feat + self.pos_embed
        
        # Region queries
        queries = self.region_queries.unsqueeze(0).expand(B, -1, -1)  # [B, 6, 1280]
        
        # Cross attention (Ïù¥ÎØ∏ÏßÄ ‚Üí Î∂ÄÏúÑ)
        attn_out, _ = self.cross_attention(
            query=queries,
            key=feat,
            value=feat
        )
        attn_out = self.norm1(attn_out + queries)  # [B, 6, 1280]
        
        # Self-Attention (Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ)
        self_attn_out, attn_weights = self.self_attention(
            query=attn_out,
            key=attn_out,
            value=attn_out
        )
        attn_out = self.norm2(attn_out + self_attn_out)  # [B, 6, 1280]
        
        # FFN
        ffn_out = self.ffn(attn_out)
        attn_out = self.norm3(attn_out + ffn_out)  # [B, 6, 1280]
        
        # Shared feature extraction
        shared_features = self.shared_fc(attn_out)  # [B, 6, 128]
        
        # ‚úÖ Í∞Å Î∂ÄÏúÑÎ≥Ñ ÏòàÏ∏° (Classification + Ordinal)
        classification_outputs = []
        ordinal_outputs = []
        
        for i in range(len(self.classification_heads)):
            region_feat = shared_features[:, i, :]  # [B, 128]
            
            # Classification logits
            class_logits = self.classification_heads[i](region_feat)  # [B, 4]
            classification_outputs.append(class_logits)
            
            # Ordinal logits (cumulative)
            ordinal_logits = self.ordinal_heads[i](region_feat)  # [B, 3]
            ordinal_outputs.append(ordinal_logits)
        
        class_out = torch.stack(classification_outputs, dim=1)  # [B, 6, 4]
        ordinal_out = torch.stack(ordinal_outputs, dim=1)  # [B, 6, 3]
        
        return {
            'logits': class_out,
            'ordinal_logits': ordinal_out,
            'attn_weights': attn_weights
        }

# ============================================================
# Training Functions (Î∂ÑÎ•òÏö© ÏàòÏ†ï)
# ============================================================
def train_epoch(model, tr_loader, criterion, optimizer, scaler, device, 
                amp=True, use_mixup=True):
    model.train()
    run_loss = run_acc = run_off1 = run_mae = n = 0
    part_losses_sum = torch.zeros(6)
    
    pbar = tqdm(tr_loader, desc="Train", leave=False)
    
    for imgs, labels in pbar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)  # [B, 6]
        
        # Mixup Ï†ÅÏö©
        is_mixup = use_mixup and (random.random() < 0.5)
        
        optimizer.zero_grad(set_to_none=True)
        
        with autocast(enabled=amp):
            if is_mixup:
                imgs_mixed, labels_mixed, lam = mixup_data_classification(imgs, labels)
                pred_dict = model(imgs_mixed)
                loss, part_losses = mixup_criterion(criterion, pred_dict, labels_mixed, lam)
            else:
                pred_dict = model(imgs)
                loss, part_losses = criterion(pred_dict, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Metrics (ÏõêÎ≥∏ labelsÎ°ú Í≥ÑÏÇ∞)
        exact_acc, off_by_1, mae, _ = calculate_classification_metrics(pred_dict, labels)
        
        bs = imgs.size(0)
        run_loss += loss.item() * bs
        run_acc += exact_acc * bs
        run_off1 += off_by_1 * bs
        run_mae += mae * bs
        part_losses_sum += part_losses.cpu() * bs
        n += bs
        
        pbar.set_postfix(
            loss=f"{run_loss/n:.4f}",
            acc=f"{run_acc/n:.4f}",
            mae=f"{run_mae/n:.4f}"
        )
    
    part_losses_avg = part_losses_sum / n
    
    return run_loss/n, run_acc/n, run_off1/n, run_mae/n, part_losses_avg

@torch.no_grad()
def evaluate(model, val_loader, criterion, device, split='val'):
    model.eval()
    run_loss = run_acc = run_off1 = run_mae = n = 0
    part_losses_sum = torch.zeros(6)
    region_acc_sum = torch.zeros(6)
    
    pbar = tqdm(val_loader, desc=f"{split.capitalize()}", leave=False)
    
    for imgs, labels in pbar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        
        pred_dict = model(imgs)
        loss, part_losses = criterion(pred_dict, labels)
        
        exact_acc, off_by_1, mae, region_acc = calculate_classification_metrics(pred_dict, labels)
        
        bs = imgs.size(0)
        run_loss += loss.item() * bs
        run_acc += exact_acc * bs
        run_off1 += off_by_1 * bs
        run_mae += mae * bs
        part_losses_sum += part_losses.cpu() * bs
        region_acc_sum += region_acc.cpu() * bs
        n += bs
        
        pbar.set_postfix(
            loss=f"{run_loss/n:.4f}",
            acc=f"{run_acc/n:.4f}"
        )
    
    avg_loss = run_loss/n
    avg_acc = run_acc/n
    avg_off1 = run_off1/n
    avg_mae = run_mae/n
    part_losses_avg = part_losses_sum / n
    region_acc_avg = region_acc_sum / n
    
    print(f"[{split}] loss:{avg_loss:.4f} acc:{avg_acc:.4f} "
          f"off1:{avg_off1:.4f} mae:{avg_mae:.4f}")
    print(f"  Region Acc: {region_acc_avg.numpy().round(3)}")
    
    return avg_loss, avg_acc, avg_off1, avg_mae, part_losses_avg, region_acc_avg

def get_lrs(optimizer):
    return [pg['lr'] for pg in optimizer.param_groups]

# ============================================================
# Main Function
# ============================================================
def main():
    print("\n" + "="*70)
    print("üöÄ Brixia COVID-19 Ordinal Classification Training")
    print("   üí° 80% Î™©ÌëúÎ•º ÏúÑÌïú ÌïµÏã¨ Í∏∞Ïà†:")
    print("   ‚úÖ Ordinal Loss - ÏàúÏÑúÌòï ÌöåÍ∑Ä (0<1<2<3)")
    print("   ‚úÖ ÌÅ¥ÎûòÏä§ Î∂àÍ∑†Ìòï Ï≤òÎ¶¨ (ÏûêÎèô Í∞ÄÏ§ëÏπò)")
    print("   ‚úÖ Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ ÌïôÏäµ (Self-Attention)")
    print("   ‚úÖ Mixed Precision Training (AMP)")
    print("   ‚úÖ Mixup Augmentation")
    print("="*70)
    
    # Îç∞Ïù¥ÌÑ∞ Ï§ÄÎπÑ
    print("\nüìÇ Loading data...")
    tr_df, val_df, tt_df = load_and_split_brixia(CSV_PATH)
    
    print("\nüì¶ Creating DataLoaders...")
    tr_loader, val_loader, tt_loader = create_dataloaders(
        tr_df, val_df, tt_df, img_dir=IMAGE_DIR, 
        batch_size=BATCH_SIZE, img_size=IMG_SIZE, num_workers=4
    )
    
    # Train labels Ï∂îÏ∂ú
    train_labels = torch.cat([labels for _, labels in tr_loader], dim=0)
    
    # ========================================
    # ÌïôÏäµ Ï§ÄÎπÑ
    # ========================================
    print("\n" + "="*70)
    print("üìç Training Setup")
    print("="*70)
    
    # ‚úÖ Ordinal Loss + ÌÅ¥ÎûòÏä§ Í∞ÄÏ§ëÏπò
    criterion = AdaptiveOrdinalLoss(
        train_labels, 
        num_classes=NUM_CLASSES,
        use_class_weights=True,
        part_weights=None,
        ordinal_weight=0.5  # CE:Ordinal = 50:50
    )
    
    # ‚úÖ Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥ÑÎ•º Í≥†Î†§ÌïòÎäî Î™®Îç∏
    model = EfficientNetB0Classification(
        pretrained=True, 
        drop=DROP_RATIO,
        num_regions=6,
        num_classes=NUM_CLASSES
    ).to(DEVICE)
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-6
    )
    scaler = GradScaler(enabled=AMP)
    
    best_acc = 0.0
    best_mae = float('inf')
    patience_counter = 0
    max_patience = 15
    
    # ========================================
    # ÌïôÏäµ Î£®ÌîÑ
    # ========================================
    print("\n" + "="*70)
    print("üèãÔ∏è Training Start")
    print("="*70)
    
    for ep in range(1, EPOCHS + 1):
        t0 = time.time()
        
        tr_loss, tr_acc, tr_off1, tr_mae, tr_part_losses = train_epoch(
            model, tr_loader, criterion, optimizer, scaler, DEVICE, AMP
        )
        
        val_loss, val_acc, val_off1, val_mae, val_part_losses, val_region_acc = evaluate(
            model, val_loader, criterion, DEVICE, split='val'
        )
        
        scheduler.step(val_acc)
        
        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            best_mae = val_mae
            patience_counter = 0
            torch.save({
                'epoch': ep,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': best_acc,
                'val_mae': best_mae,
                'val_off1': val_off1,
                'region_acc': val_region_acc,
            }, BEST_PATH)
            print(f"‚úÖ New Best! (Acc={best_acc:.4f}, MAE={best_mae:.4f})")
        else:
            patience_counter += 1
        
        # Early stopping
        if patience_counter >= max_patience:
            print(f"\n‚èπÔ∏è Early stopping at epoch {ep}")
            break
        
        elapsed = time.time() - t0
        print(f"\n[Epoch {ep:03d}/{EPOCHS}]")
        print(f"  Train - loss:{tr_loss:.4f} acc:{tr_acc:.4f} off1:{tr_off1:.4f} mae:{tr_mae:.4f}")
        print(f"  Val   - loss:{val_loss:.4f} acc:{val_acc:.4f} off1:{val_off1:.4f} mae:{val_mae:.4f}")
        print(f"  Part losses (Val): {val_part_losses.numpy().round(3)}")
        print(f"  Region Acc (Val): {val_region_acc.numpy().round(3)}")
        print(f"  LR:{get_lrs(optimizer)[0]:.2e} | {elapsed:.1f}s | Patience:{patience_counter}/{max_patience}")
        print("-" * 70)
    
    # ========================================
    # Test Evaluation
    # ========================================
    print("\n" + "="*70)
    print("üéâ Training Finished!")
    print("="*70)
    
    if len(tt_loader) > 0:
        print("\nüìä Test evaluation with best model...")
        checkpoint = torch.load(BEST_PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        
        tt_loss, tt_acc, tt_off1, tt_mae, tt_part_losses, tt_region_acc = evaluate(
            model, tt_loader, criterion, DEVICE, split='test'
        )
        
        print(f"\nüèÜ Test Results:")
        print(f"   Accuracy: {tt_acc:.4f}")
        print(f"   Off-by-1: {tt_off1:.4f}")
        print(f"   MAE: {tt_mae:.4f}")
        print(f"   Region Acc: {tt_region_acc.numpy().round(3)}")
        print(f"   Part losses: {tt_part_losses.numpy().round(3)}")
    
    print(f"\nüíæ Best model saved: {BEST_PATH}")
    print(f"üìà Best Validation Accuracy: {best_acc:.4f}")
    print(f"üìâ Best Validation MAE: {best_mae:.4f}")
    
    # ÏÑ±Îä• Ìñ•ÏÉÅ Ï†úÏïà
    print("\n" + "="*70)
    print("üí° Ï∂îÍ∞Ä ÏÑ±Îä• Ìñ•ÏÉÅÏùÑ ÏúÑÌïú Ï†úÏïà:")
    print("="*70)
    print("1. üéØ ÏïôÏÉÅÎ∏î: Îã§Î•∏ ÏãúÎìúÎ°ú 3~5Í∞ú Î™®Îç∏ ÌïôÏäµ ÌõÑ Ìà¨Ìëú")
    print("2. üî¨ ÏùòÎ£å Ï†ÑÏö© pretrained model ÏÇ¨Ïö©:")
    print("   - CheXpert, MIMIC-CXR Îì±ÏúºÎ°ú ÏÇ¨Ï†ÑÌïôÏäµÎêú Î™®Îç∏")
    print("3. üìä Îç∞Ïù¥ÌÑ∞ Ï∂îÍ∞Ä:")
    print("   - Ïô∏Î∂Ä COVID-19 Îç∞Ïù¥ÌÑ∞ÏÖã ÌôúÏö©")
    print("   - Pseudo-labelingÏúºÎ°ú unlabeled Îç∞Ïù¥ÌÑ∞ ÌôúÏö©")
    print("4. üé® Í≥†Í∏â Ï¶ùÍ∞ï:")
    print("   - CutMix, AugMix, RandAugment")
    print("   - Test-Time Augmentation (TTA)")
    print("5. üß™ ÌïòÏù¥ÌçºÌååÎùºÎØ∏ÌÑ∞ ÌäúÎãù:")
    print("   - ordinal_weight Ï°∞Ï†ï (0.3~0.7)")
    print("   - Learning rate, batch size Ïã§Ìóò")
    print("="*70)
    
    print("\n‚úÖ Ïù¥Ï†ú gradcam_inference.pyÎ•º Ïã§ÌñâÌïòÏó¨ Í≤∞Í≥ºÎ•º ÏãúÍ∞ÅÌôîÌïòÏÑ∏Ïöî!")
    print("="*70)

if __name__ == "__main__":
    main()


üöÄ Brixia COVID-19 Ordinal Classification Training
   üí° 80% Î™©ÌëúÎ•º ÏúÑÌïú ÌïµÏã¨ Í∏∞Ïà†:
   ‚úÖ Ordinal Loss - ÏàúÏÑúÌòï ÌöåÍ∑Ä (0<1<2<3)
   ‚úÖ ÌÅ¥ÎûòÏä§ Î∂àÍ∑†Ìòï Ï≤òÎ¶¨ (ÏûêÎèô Í∞ÄÏ§ëÏπò)
   ‚úÖ Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ ÌïôÏäµ (Self-Attention)
   ‚úÖ Mixed Precision Training (AMP)
   ‚úÖ Mixup Augmentation

üìÇ Loading data...
Ïú†Ìö®Ìïú Îç∞Ïù¥ÌÑ∞: 4695Í∞ú
Train: 3637, Val: 912, Test: 146
Train - Mean: 8.31, Std: 4.26
Val - Mean: 8.35, Std: 4.15
Test - Mean: 7.78, Std: 4.20

üì¶ Creating DataLoaders...
‚úÖ DataLoader Ï§ÄÎπÑ ÏôÑÎ£å
   Train: 3637 samples, 113 batches
   Val:   912 samples, 29 batches
   Test:  146 samples, 5 batches

üìç Training Setup
Class Distribution Analysis

A:
  Class 0: 1810 ( 50.1%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 1: 1126 ( 31.1%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 2:  446 ( 12.3%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 3:  234 (  6.5%) ‚ñà‚ñà‚ñà

B:
  Class 0:  721 ( 19.9%) ‚ñà‚ñà‚ñà‚ñà‚ñà

                                                                                             

[val] loss:0.8177 acc:0.5000 off1:0.8849 mae:0.6274
  Region Acc: [0.535 0.478 0.442 0.604 0.477 0.464]
‚úÖ New Best! (Acc=0.5000, MAE=0.6274)

[Epoch 001/100]
  Train - loss:1.0706 acc:0.4070 off1:0.8209 mae:0.8038
  Val   - loss:0.8177 acc:0.5000 off1:0.8849 mae:0.6274
  Part losses (Val): [0.805 0.849 0.851 0.726 0.832 0.843]
  Region Acc (Val): [0.535 0.478 0.442 0.604 0.477 0.464]
  LR:1.00e-04 | 17.4s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.7633 acc:0.5055 off1:0.8977 mae:0.6080
  Region Acc: [0.553 0.469 0.433 0.623 0.48  0.475]
‚úÖ New Best! (Acc=0.5055, MAE=0.6080)

[Epoch 002/100]
  Train - loss:0.9917 acc:0.4487 off1:0.8573 mae:0.7199
  Val   - loss:0.7633 acc:0.5055 off1:0.8977 mae:0.6080
  Part losses (Val): [0.741 0.79  0.816 0.671 0.777 0.785]
  Region Acc (Val): [0.553 0.469 0.433 0.623 0.48  0.475]
  LR:1.00e-04 | 13.4s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.7017 acc:0.5481 off1:0.9262 mae:0.5305
  Region Acc: [0.562 0.536 0.515 0.657 0.523 0.495]
‚úÖ New Best! (Acc=0.5481, MAE=0.5305)

[Epoch 003/100]
  Train - loss:0.9166 acc:0.4787 off1:0.8809 mae:0.6574
  Val   - loss:0.7017 acc:0.5481 off1:0.9262 mae:0.5305
  Part losses (Val): [0.684 0.709 0.732 0.608 0.727 0.75 ]
  Region Acc (Val): [0.562 0.536 0.515 0.657 0.523 0.495]
  LR:1.00e-04 | 14.2s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6826 acc:0.5590 off1:0.9315 mae:0.5148
  Region Acc: [0.586 0.539 0.524 0.649 0.548 0.508]
‚úÖ New Best! (Acc=0.5590, MAE=0.5148)

[Epoch 004/100]
  Train - loss:0.8872 acc:0.4768 off1:0.8747 mae:0.6703
  Val   - loss:0.6826 acc:0.5590 off1:0.9315 mae:0.5148
  Part losses (Val): [0.677 0.703 0.711 0.587 0.697 0.72 ]
  Region Acc (Val): [0.586 0.539 0.524 0.649 0.548 0.508]
  LR:1.00e-04 | 14.2s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6711 acc:0.5607 off1:0.9412 mae:0.5031
  Region Acc: [0.579 0.549 0.541 0.64  0.556 0.499]
‚úÖ New Best! (Acc=0.5607, MAE=0.5031)

[Epoch 005/100]
  Train - loss:0.8796 acc:0.4918 off1:0.8787 mae:0.6533
  Val   - loss:0.6711 acc:0.5607 off1:0.9412 mae:0.5031
  Part losses (Val): [0.658 0.681 0.697 0.589 0.684 0.718]
  Region Acc (Val): [0.579 0.549 0.541 0.64  0.556 0.499]
  LR:1.00e-04 | 13.6s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6551 acc:0.5673 off1:0.9494 mae:0.4861
  Region Acc: [0.547 0.553 0.538 0.658 0.583 0.524]
‚úÖ New Best! (Acc=0.5673, MAE=0.4861)

[Epoch 006/100]
  Train - loss:0.8611 acc:0.4769 off1:0.8703 mae:0.6791
  Val   - loss:0.6551 acc:0.5673 off1:0.9494 mae:0.4861
  Part losses (Val): [0.653 0.665 0.682 0.573 0.652 0.706]
  Region Acc (Val): [0.547 0.553 0.538 0.658 0.583 0.524]
  LR:1.00e-04 | 14.4s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6416 acc:0.5768 off1:0.9485 mae:0.4783
  Region Acc: [0.588 0.569 0.556 0.673 0.569 0.505]
‚úÖ New Best! (Acc=0.5768, MAE=0.4783)

[Epoch 007/100]
  Train - loss:0.8326 acc:0.4980 off1:0.8850 mae:0.6393
  Val   - loss:0.6416 acc:0.5768 off1:0.9485 mae:0.4783
  Part losses (Val): [0.625 0.654 0.667 0.555 0.651 0.697]
  Region Acc (Val): [0.588 0.569 0.556 0.673 0.569 0.505]
  LR:1.00e-04 | 13.8s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6365 acc:0.5853 off1:0.9488 mae:0.4698
  Region Acc: [0.601 0.558 0.555 0.689 0.605 0.504]
‚úÖ New Best! (Acc=0.5853, MAE=0.4698)

[Epoch 008/100]
  Train - loss:0.8044 acc:0.5110 off1:0.8947 mae:0.6146
  Val   - loss:0.6365 acc:0.5853 off1:0.9488 mae:0.4698
  Part losses (Val): [0.62  0.662 0.662 0.54  0.641 0.695]
  Region Acc (Val): [0.601 0.558 0.555 0.689 0.605 0.504]
  LR:1.00e-04 | 14.8s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6281 acc:0.5848 off1:0.9598 mae:0.4580
  Region Acc: [0.613 0.577 0.544 0.657 0.587 0.532]

[Epoch 009/100]
  Train - loss:0.7970 acc:0.5205 off1:0.9004 mae:0.5962
  Val   - loss:0.6281 acc:0.5848 off1:0.9598 mae:0.4580
  Part losses (Val): [0.613 0.635 0.654 0.549 0.628 0.69 ]
  Region Acc (Val): [0.613 0.577 0.544 0.657 0.587 0.532]
  LR:1.00e-04 | 13.6s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6311 acc:0.5879 off1:0.9510 mae:0.4638
  Region Acc: [0.6   0.578 0.554 0.673 0.594 0.529]
‚úÖ New Best! (Acc=0.5879, MAE=0.4638)

[Epoch 010/100]
  Train - loss:0.8100 acc:0.5379 off1:0.9179 mae:0.5582
  Val   - loss:0.6311 acc:0.5879 off1:0.9510 mae:0.4638
  Part losses (Val): [0.631 0.642 0.655 0.539 0.64  0.679]
  Region Acc (Val): [0.6   0.578 0.554 0.673 0.594 0.529]
  LR:1.00e-04 | 14.3s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6420 acc:0.5744 off1:0.9538 mae:0.4755
  Region Acc: [0.602 0.541 0.537 0.669 0.577 0.521]

[Epoch 011/100]
  Train - loss:0.8088 acc:0.5138 off1:0.8934 mae:0.6139
  Val   - loss:0.6420 acc:0.5744 off1:0.9538 mae:0.4755
  Part losses (Val): [0.622 0.662 0.675 0.546 0.653 0.694]
  Region Acc (Val): [0.602 0.541 0.537 0.669 0.577 0.521]
  LR:1.00e-04 | 13.3s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6272 acc:0.5824 off1:0.9560 mae:0.4644
  Region Acc: [0.596 0.577 0.549 0.639 0.595 0.537]

[Epoch 012/100]
  Train - loss:0.7947 acc:0.5263 off1:0.9006 mae:0.5932
  Val   - loss:0.6272 acc:0.5824 off1:0.9560 mae:0.4644
  Part losses (Val): [0.616 0.64  0.652 0.546 0.623 0.687]
  Region Acc (Val): [0.596 0.577 0.549 0.639 0.595 0.537]
  LR:1.00e-04 | 13.3s | Patience:2/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6221 acc:0.5963 off1:0.9556 mae:0.4527
  Region Acc: [0.626 0.579 0.564 0.68  0.592 0.537]
‚úÖ New Best! (Acc=0.5963, MAE=0.4527)

[Epoch 013/100]
  Train - loss:0.8079 acc:0.5246 off1:0.8993 mae:0.5944
  Val   - loss:0.6221 acc:0.5963 off1:0.9556 mae:0.4527
  Part losses (Val): [0.599 0.635 0.654 0.528 0.637 0.68 ]
  Region Acc (Val): [0.626 0.579 0.564 0.68  0.592 0.537]
  LR:1.00e-04 | 14.0s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6289 acc:0.5894 off1:0.9538 mae:0.4611
  Region Acc: [0.609 0.576 0.559 0.682 0.587 0.524]

[Epoch 014/100]
  Train - loss:0.7494 acc:0.5568 off1:0.9229 mae:0.5325
  Val   - loss:0.6289 acc:0.5894 off1:0.9538 mae:0.4611
  Part losses (Val): [0.621 0.639 0.654 0.532 0.631 0.696]
  Region Acc (Val): [0.609 0.576 0.559 0.682 0.587 0.524]
  LR:1.00e-04 | 13.9s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6283 acc:0.5892 off1:0.9558 mae:0.4583
  Region Acc: [0.611 0.583 0.558 0.683 0.577 0.523]

[Epoch 015/100]
  Train - loss:0.7585 acc:0.5350 off1:0.8992 mae:0.5878
  Val   - loss:0.6283 acc:0.5892 off1:0.9558 mae:0.4583
  Part losses (Val): [0.607 0.636 0.655 0.534 0.638 0.7  ]
  Region Acc (Val): [0.611 0.583 0.558 0.683 0.577 0.523]
  LR:1.00e-04 | 13.0s | Patience:2/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6338 acc:0.5822 off1:0.9510 mae:0.4700
  Region Acc: [0.607 0.545 0.555 0.682 0.582 0.522]

[Epoch 016/100]
  Train - loss:0.7548 acc:0.5459 off1:0.9080 mae:0.5644
  Val   - loss:0.6338 acc:0.5822 off1:0.9510 mae:0.4700
  Part losses (Val): [0.62  0.655 0.658 0.539 0.641 0.689]
  Region Acc (Val): [0.607 0.545 0.555 0.682 0.582 0.522]
  LR:1.00e-04 | 13.4s | Patience:3/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6293 acc:0.5808 off1:0.9560 mae:0.4675
  Region Acc: [0.615 0.561 0.546 0.673 0.569 0.52 ]

[Epoch 017/100]
  Train - loss:0.7459 acc:0.5668 off1:0.9176 mae:0.5302
  Val   - loss:0.6293 acc:0.5808 off1:0.9560 mae:0.4675
  Part losses (Val): [0.612 0.643 0.655 0.542 0.632 0.691]
  Region Acc (Val): [0.615 0.561 0.546 0.673 0.569 0.52 ]
  LR:1.00e-04 | 13.2s | Patience:4/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6285 acc:0.5928 off1:0.9545 mae:0.4558
  Region Acc: [0.617 0.577 0.561 0.683 0.572 0.546]

[Epoch 018/100]
  Train - loss:0.7330 acc:0.5635 off1:0.9139 mae:0.5399
  Val   - loss:0.6285 acc:0.5928 off1:0.9545 mae:0.4558
  Part losses (Val): [0.615 0.645 0.663 0.535 0.629 0.683]
  Region Acc (Val): [0.617 0.577 0.561 0.683 0.572 0.546]
  LR:1.00e-04 | 13.1s | Patience:5/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6385 acc:0.5828 off1:0.9549 mae:0.4655
  Region Acc: [0.603 0.562 0.537 0.679 0.59  0.525]

[Epoch 019/100]
  Train - loss:0.7070 acc:0.5570 off1:0.9095 mae:0.5519
  Val   - loss:0.6385 acc:0.5828 off1:0.9549 mae:0.4655
  Part losses (Val): [0.621 0.655 0.674 0.549 0.636 0.696]
  Region Acc (Val): [0.603 0.562 0.537 0.679 0.59  0.525]
  LR:5.00e-05 | 13.1s | Patience:6/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6363 acc:0.5906 off1:0.9554 mae:0.4572
  Region Acc: [0.616 0.56  0.566 0.684 0.588 0.53 ]

[Epoch 020/100]
  Train - loss:0.7481 acc:0.5572 off1:0.9085 mae:0.5517
  Val   - loss:0.6363 acc:0.5906 off1:0.9554 mae:0.4572
  Part losses (Val): [0.631 0.651 0.662 0.55  0.632 0.692]
  Region Acc (Val): [0.616 0.56  0.566 0.684 0.588 0.53 ]
  LR:5.00e-05 | 12.8s | Patience:7/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6478 acc:0.5808 off1:0.9545 mae:0.4691
  Region Acc: [0.622 0.57  0.538 0.668 0.575 0.512]

[Epoch 021/100]
  Train - loss:0.6738 acc:0.5989 off1:0.9288 mae:0.4856
  Val   - loss:0.6478 acc:0.5808 off1:0.9545 mae:0.4691
  Part losses (Val): [0.646 0.654 0.677 0.558 0.639 0.712]
  Region Acc (Val): [0.622 0.57  0.538 0.668 0.575 0.512]
  LR:5.00e-05 | 12.2s | Patience:8/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6346 acc:0.5883 off1:0.9585 mae:0.4560
  Region Acc: [0.607 0.575 0.558 0.681 0.587 0.522]

[Epoch 022/100]
  Train - loss:0.7160 acc:0.5860 off1:0.9175 mae:0.5125
  Val   - loss:0.6346 acc:0.5883 off1:0.9585 mae:0.4560
  Part losses (Val): [0.615 0.645 0.669 0.534 0.646 0.698]
  Region Acc (Val): [0.607 0.575 0.558 0.681 0.587 0.522]
  LR:5.00e-05 | 12.9s | Patience:9/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6528 acc:0.5824 off1:0.9539 mae:0.4682
  Region Acc: [0.602 0.573 0.554 0.668 0.594 0.503]

[Epoch 023/100]
  Train - loss:0.6885 acc:0.5806 off1:0.9145 mae:0.5240
  Val   - loss:0.6528 acc:0.5824 off1:0.9539 mae:0.4682
  Part losses (Val): [0.65  0.652 0.683 0.559 0.658 0.716]
  Region Acc (Val): [0.602 0.573 0.554 0.668 0.594 0.503]
  LR:5.00e-05 | 13.3s | Patience:10/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6476 acc:0.5906 off1:0.9541 mae:0.4583
  Region Acc: [0.61  0.578 0.569 0.672 0.584 0.531]

[Epoch 024/100]
  Train - loss:0.6519 acc:0.6068 off1:0.9256 mae:0.4819
  Val   - loss:0.6476 acc:0.5906 off1:0.9541 mae:0.4583
  Part losses (Val): [0.644 0.66  0.67  0.553 0.647 0.711]
  Region Acc (Val): [0.61  0.578 0.569 0.672 0.584 0.531]
  LR:5.00e-05 | 13.8s | Patience:11/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6555 acc:0.5828 off1:0.9536 mae:0.4682
  Region Acc: [0.598 0.577 0.557 0.669 0.578 0.519]

[Epoch 025/100]
  Train - loss:0.6583 acc:0.5978 off1:0.9216 mae:0.4983
  Val   - loss:0.6555 acc:0.5828 off1:0.9536 mae:0.4682
  Part losses (Val): [0.659 0.658 0.684 0.569 0.647 0.717]
  Region Acc (Val): [0.598 0.577 0.557 0.669 0.578 0.519]
  LR:2.50e-05 | 13.2s | Patience:12/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6572 acc:0.5828 off1:0.9583 mae:0.4620
  Region Acc: [0.606 0.576 0.541 0.664 0.582 0.527]

[Epoch 026/100]
  Train - loss:0.6394 acc:0.6116 off1:0.9248 mae:0.4810
  Val   - loss:0.6572 acc:0.5828 off1:0.9583 mae:0.4620
  Part losses (Val): [0.647 0.659 0.692 0.571 0.659 0.715]
  Region Acc (Val): [0.606 0.576 0.541 0.664 0.582 0.527]
  LR:2.50e-05 | 13.1s | Patience:13/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6470 acc:0.5846 off1:0.9572 mae:0.4614
  Region Acc: [0.613 0.567 0.567 0.662 0.57  0.529]

[Epoch 027/100]
  Train - loss:0.6684 acc:0.5970 off1:0.9136 mae:0.5092
  Val   - loss:0.6470 acc:0.5846 off1:0.9572 mae:0.4614
  Part losses (Val): [0.633 0.653 0.68  0.56  0.646 0.709]
  Region Acc (Val): [0.613 0.567 0.567 0.662 0.57  0.529]
  LR:2.50e-05 | 13.8s | Patience:14/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6601 acc:0.5793 off1:0.9556 mae:0.4682
  Region Acc: [0.599 0.561 0.559 0.669 0.57  0.518]

‚èπÔ∏è Early stopping at epoch 28

üéâ Training Finished!

üìä Test evaluation with best model...


                                                                            

[test] loss:0.5895 acc:0.5982 off1:0.9703 mae:0.4326
  Region Acc: [0.664 0.555 0.603 0.719 0.521 0.527]

üèÜ Test Results:
   Accuracy: 0.5982
   Off-by-1: 0.9703
   MAE: 0.4326
   Region Acc: [0.664 0.555 0.603 0.719 0.521 0.527]
   Part losses: [0.53  0.62  0.607 0.503 0.614 0.663]

üíæ Best model saved: runs_severity_classification/best_efficientnet_b0_classification.pth
üìà Best Validation Accuracy: 0.5963
üìâ Best Validation MAE: 0.4527

üí° Ï∂îÍ∞Ä ÏÑ±Îä• Ìñ•ÏÉÅÏùÑ ÏúÑÌïú Ï†úÏïà:
1. üéØ ÏïôÏÉÅÎ∏î: Îã§Î•∏ ÏãúÎìúÎ°ú 3~5Í∞ú Î™®Îç∏ ÌïôÏäµ ÌõÑ Ìà¨Ìëú
2. üî¨ ÏùòÎ£å Ï†ÑÏö© pretrained model ÏÇ¨Ïö©:
   - CheXpert, MIMIC-CXR Îì±ÏúºÎ°ú ÏÇ¨Ï†ÑÌïôÏäµÎêú Î™®Îç∏
3. üìä Îç∞Ïù¥ÌÑ∞ Ï∂îÍ∞Ä:
   - Ïô∏Î∂Ä COVID-19 Îç∞Ïù¥ÌÑ∞ÏÖã ÌôúÏö©
   - Pseudo-labelingÏúºÎ°ú unlabeled Îç∞Ïù¥ÌÑ∞ ÌôúÏö©
4. üé® Í≥†Í∏â Ï¶ùÍ∞ï:
   - CutMix, AugMix, RandAugment
   - Test-Time Augmentation (TTA)
5. üß™ ÌïòÏù¥ÌçºÌååÎùºÎØ∏ÌÑ∞ ÌäúÎãù:
   - ordinal_weight Ï°∞Ï†ï (0.3~0.7)
   - Learning rate, batch size Ïã§Ìóò

‚úÖ Ïù



In [13]:
import os, random, time
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import functional as TF
import torchxrayvision as xrv  # ‚úÖ ÏùòÎ£å ÏòÅÏÉÅ Ï†ÑÏö© ÎùºÏù¥Î∏åÎü¨Î¶¨

from tqdm import tqdm
from torch.amp import autocast
from torch.cuda.amp import autocast, GradScaler

# --- Í≤ΩÎ°ú ÏÑ§Ï†ï Î∞è ÌïòÏù¥ÌçºÌååÎùºÎØ∏ÌÑ∞ ---
BASE_DIR     = f"./data/covid19-xray-severity-scoring/"
CSV_PATH     = str(Path(BASE_DIR) / "Brixia.csv")
IMAGE_DIR    = str(Path(BASE_DIR) / "segmented_png")

OUT_DIR      = "./runs_severity_classification"
BEST_PATH    = str(Path(OUT_DIR) / "best_densenet121_mimic_classification.pth")
os.makedirs(OUT_DIR, exist_ok=True)

DEVICE       = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED         = 42
IMG_SIZE     = 224
BATCH_SIZE   = 256
NUM_CLASSES  = 4   # 0, 1, 2, 3
EPOCHS       = 100  # Single phase training
LR           = 1e-4
WEIGHT_DECAY = 5e-4
AMP          = True
EARLY_STOP_ACC = 0.75  # üîÑ MAE ‚Üí Accuracy
DROP_RATIO   = 0.3
AUG_RATIO    = 0.5
MIXUP_ALPHA  = 0.2
LABEL_SMOOTHING = 0.1  # ‚úÖ NEW: Label smoothing

# --- ÏãúÎìú Í≥†Ï†ï ---
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
set_seed(SEED)

def make_transform_with_label(train: bool, img_size: int = IMG_SIZE, aug_ratio=AUG_RATIO):
    """Brixia ScoreÏùò Ï¢åÏö∞ Íµ¨Ï°∞Î•º Í≥†Î†§Ìïú transform"""
    def _tfm(img: Image.Image, label: torch.Tensor = None):
        img = img.convert('RGB')
        img = TF.resize(
            img, 
            [img_size, img_size], 
            interpolation=TF.InterpolationMode.BILINEAR,
            antialias=True
        )
        
        if train:
            # 1. Horizontal Flip (Ï¢åÏö∞ Î∞òÏ†Ñ: ABC ‚Üî DEF)
            if random.random() < aug_ratio:
                img = TF.hflip(img)
                if label is not None:
                    # [A, B, C, D, E, F] ‚Üí [D, E, F, A, B, C]
                    label = label[[3, 4, 5, 0, 1, 2]]
            
            # 2. ÏïΩÌïú ÌöåÏ†Ñ (¬±5ÎèÑ)
            if random.random() < aug_ratio:
                angle = float(torch.empty(1).uniform_(-5, 5))
                img = TF.rotate(
                    img, 
                    angle, 
                    interpolation=TF.InterpolationMode.BILINEAR,
                    fill=0
                )
            
            # 3. ÏïΩÌïú Translation
            if random.random() < aug_ratio:
                max_dx = 0.05 * img_size
                max_dy = 0.05 * img_size
                translations = (
                    float(torch.empty(1).uniform_(-max_dx, max_dx)),
                    float(torch.empty(1).uniform_(-max_dy, max_dy))
                )
                img = TF.affine(
                    img,
                    angle=0,
                    translate=translations,
                    scale=1.0,
                    shear=0,
                    interpolation=TF.InterpolationMode.BILINEAR,
                    fill=0
                )
            
            # 4. Brightness & Contrast
            if random.random() < aug_ratio:
                brightness_factor = float(torch.empty(1).uniform_(0.9, 1.1))
                img = TF.adjust_brightness(img, brightness_factor)
            
            if random.random() < aug_ratio:
                contrast_factor = float(torch.empty(1).uniform_(0.9, 1.1))
                img = TF.adjust_contrast(img, contrast_factor)
            
            # 5. Gamma Correction
            if random.random() < 0.3:
                gamma = float(torch.empty(1).uniform_(0.9, 1.1))
                img = TF.adjust_gamma(img, gamma)
        
        # Tensor Î≥ÄÌôò
        img = TF.to_tensor(img)
        
        # Gaussian Noise (train only)
        if train and random.random() < 0.2:
            noise = torch.randn_like(img) * 0.01
            img = img + noise
            img = torch.clamp(img, 0, 1)
        
        # Ï†ïÍ∑úÌôî (torchxrayvision ÌëúÏ§Ä)
        # MIMIC-CXR Îç∞Ïù¥ÌÑ∞Î°ú ÌïôÏäµÎêú Î™®Îç∏ÏùÄ [-1024, 1024] Î≤îÏúÑÎ•º ÏÇ¨Ïö©
        # ÌïòÏßÄÎßå ÏùºÎ∞ò RGB Ïù¥ÎØ∏ÏßÄÎäî ImageNet Ï†ïÍ∑úÌôî Ïú†ÏßÄ
        img = TF.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        if label is not None:
            return img, label
        return img
    
    return _tfm

def load_and_split_brixia(csv_path, val_ratio=0.2, seed=SEED):
    df = pd.read_csv(csv_path, dtype={'BrixiaScore': str})
    df = df.dropna(subset=['BrixiaScore'])
    df = df[df['BrixiaScore'] != 'nan']
    df = df[df['BrixiaScore'].str.len() == 6].copy()
    
    print(f"Ïú†Ìö®Ìïú Îç∞Ïù¥ÌÑ∞: {len(df)}Í∞ú")
    
    if 'ConsensusTestset' in df.columns:
        test_df = df[df['ConsensusTestset'] == 1].copy()
        train_val_df = df[df['ConsensusTestset'] == 0].copy()
    else:
        test_df = pd.DataFrame()
        train_val_df = df.copy()
    
    gss = GroupShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
    train_idx, val_idx = next(gss.split(
        train_val_df, 
        groups=train_val_df['StudyId']
    ))
    
    tr_df = train_val_df.iloc[train_idx].copy()
    val_df = train_val_df.iloc[val_idx].copy()
    
    print(f"Train: {len(tr_df)}, Val: {len(val_df)}, Test: {len(test_df)}")
    validate_split(tr_df, val_df, test_df)
    
    return tr_df, val_df, test_df

def validate_split(tr_df, val_df, tt_df):
    train_studies = set(tr_df['StudyId'])
    val_studies = set(val_df['StudyId'])
    test_studies = set(tt_df['StudyId']) if len(tt_df) > 0 else set()
    
    assert len(train_studies & val_studies) == 0, "Train-Val Í∞Ñ StudyId Ï§ëÎ≥µ!"
    assert len(train_studies & test_studies) == 0, "Train-Test Í∞Ñ StudyId Ï§ëÎ≥µ!"
    assert len(val_studies & test_studies) == 0, "Val-Test Í∞Ñ StudyId Ï§ëÎ≥µ!"
    
    for name, data in [('Train', tr_df), ('Val', val_df), ('Test', tt_df)]:
        if len(data) > 0:
            scores = data['BrixiaScore'].apply(lambda x: sum(int(c) for c in x))
            print(f"{name} - Mean: {scores.mean():.2f}, Std: {scores.std():.2f}")
    
    return True

# ============================================================
# Dataset
# ============================================================
class BrixiaDataset(Dataset):
    def __init__(self, dataframe, img_dir, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.img_col = "Filename"
        self.label_col = "BrixiaScore"
        self._validate_data()
    
    def _validate_data(self):
        assert self.img_col in self.df.columns
        assert self.label_col in self.df.columns
        
        invalid_scores = self.df[self.df[self.label_col].str.len() != 6]
        if len(invalid_scores) > 0:
            print(f"‚ö†Ô∏è Í≤ΩÍ≥†: {len(invalid_scores)}Í∞úÏùò ÏûòÎ™ªÎêú BrixiaScore Î∞úÍ≤¨")
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        img_name_from_csv = row[self.img_col]
        img_name = img_name_from_csv.replace('.dcm', '.png')
        img_path = os.path.join(self.img_dir, img_name)
        
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"‚ùå Ïù¥ÎØ∏ÏßÄ Î°úÎìú Ïò§Î•ò: {img_path}")
            raise
        
        scores_str = row[self.label_col]
        scores_list = [int(c) for c in scores_str]
        labels = torch.tensor(scores_list, dtype=torch.long)  # üîÑ longÏúºÎ°ú Î≥ÄÍ≤Ω
        
        if self.transform:
            image, labels = self.transform(image, labels)
        else:
            image = TF.to_tensor(image)
            image = TF.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        return image, labels

def create_dataloaders(tr_df, val_df, tt_df, img_dir, 
                       batch_size=32, img_size=224, num_workers=4):
    train_transform = make_transform_with_label(train=True, img_size=img_size)
    val_transform = make_transform_with_label(train=False, img_size=img_size)
    
    tr_ds = BrixiaDataset(tr_df, img_dir, transform=train_transform)
    val_ds = BrixiaDataset(val_df, img_dir, transform=val_transform)
    tt_ds = BrixiaDataset(tt_df, img_dir, transform=val_transform)
    
    tr_loader = DataLoader(
        tr_ds, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=num_workers, 
        pin_memory=torch.cuda.is_available(),
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_ds, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=num_workers, 
        pin_memory=torch.cuda.is_available()
    )
    
    tt_loader = DataLoader(
        tt_ds, 
        batch_size=batch_size, 
        shuffle=False,
        num_workers=num_workers, 
        pin_memory=torch.cuda.is_available()
    )
    
    print("‚úÖ DataLoader Ï§ÄÎπÑ ÏôÑÎ£å")
    print(f"   Train: {len(tr_ds)} samples, {len(tr_loader)} batches")
    print(f"   Val:   {len(val_ds)} samples, {len(val_loader)} batches")
    print(f"   Test:  {len(tt_ds)} samples, {len(tt_loader)} batches")
    
    return tr_loader, val_loader, tt_loader

# ============================================================
# Loss Function - Ordinal Classification
# ============================================================
def calculate_class_weights(labels, num_classes=4, method='sqrt_inverse'):
    """ÌÅ¥ÎûòÏä§ Î∂àÍ∑†Ìòï Ìï¥Í≤∞ÏùÑ ÏúÑÌïú Í∞ÄÏ§ëÏπò Í≥ÑÏÇ∞"""
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    
    labels_flat = labels.flatten()
    counts = np.bincount(labels_flat.astype(int), minlength=num_classes)
    
    if method == 'sqrt_inverse':
        weights = 1.0 / (np.sqrt(counts) + 1e-6)
    elif method == 'inverse':
        weights = 1.0 / (counts + 1e-6)
    else:
        total = len(labels_flat)
        weights = total / (num_classes * (counts + 1e-6))
    
    weights = weights / weights.mean()
    return torch.FloatTensor(weights)

def print_class_distribution(labels):
    """ÌÅ¥ÎûòÏä§ Î∂ÑÌè¨ ÏãúÍ∞ÅÌôî"""
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    
    print("=" * 60)
    print("Class Distribution Analysis")
    print("=" * 60)
    
    region_names = ['A', 'B', 'C', 'D', 'E', 'F']
    for idx, name in enumerate(region_names):
        region_labels = labels[:, idx]
        counts = np.bincount(region_labels.astype(int), minlength=4)
        total = counts.sum()
        
        print(f"\n{name}:")
        for cls in range(4):
            pct = 100 * counts[cls] / total if total > 0 else 0
            bar = '‚ñà' * int(pct / 2)
            print(f"  Class {cls}: {counts[cls]:4d} ({pct:5.1f}%) {bar}")
    
    print("=" * 60)

class OrdinalRegressionLoss(nn.Module):
    """
    ‚úÖ Ordinal Regression Loss (ÏàúÏÑúÌòï ÌöåÍ∑Ä)
    0 < 1 < 2 < 3Ïùò ÏàúÏÑúÎ•º Î™ÖÏãúÏ†ÅÏúºÎ°ú ÌïôÏäµ
    """
    def __init__(self, num_classes=4):
        super().__init__()
        self.num_classes = num_classes
        
    def forward(self, logits, target):
        """
        logits: [B*6, num_classes-1] - cumulative logits
        target: [B*6] - class labels (0~3)
        """
        # Cumulative labels ÏÉùÏÑ±
        # Class 0: [0, 0, 0]
        # Class 1: [1, 0, 0]
        # Class 2: [1, 1, 0]
        # Class 3: [1, 1, 1]
        batch_size = target.size(0)
        target_expanded = target.unsqueeze(1)  # [B*6, 1]
        
        # [0, 1, 2, ..., num_classes-2]
        thresholds = torch.arange(self.num_classes - 1).to(target.device)
        thresholds = thresholds.unsqueeze(0).expand(batch_size, -1)  # [B*6, 3]
        
        # target > thresholdÏù¥Î©¥ 1, ÏïÑÎãàÎ©¥ 0
        cumulative_target = (target_expanded > thresholds).float()  # [B*6, 3]
        
        # Binary cross entropy for each threshold
        loss = F.binary_cross_entropy_with_logits(
            logits, cumulative_target, reduction='none'
        )
        
        return loss.mean()

class AdaptiveOrdinalLoss(nn.Module):
    """
    ‚úÖ Ordinal Loss + Class Weights + Part Weights
    ÏàúÏÑúÎ•º Í≥†Î†§ÌïòÎ©¥ÏÑú Î∂àÍ∑†ÌòïÎèÑ Ìï¥Í≤∞
    """
    
    def __init__(self, train_labels, num_classes=4, use_class_weights=True, 
                 part_weights=None, ordinal_weight=0.5):
        super().__init__()
        self.num_classes = num_classes
        self.ordinal_weight = ordinal_weight  # Ordinal lossÏôÄ CE lossÏùò ÎπÑÏú®
        
        if isinstance(train_labels, torch.Tensor):
            train_labels_np = train_labels.cpu().numpy()
        else:
            train_labels_np = train_labels
        
        print_class_distribution(train_labels_np)
        
        # ÌÅ¥ÎûòÏä§ Í∞ÄÏ§ëÏπò
        if use_class_weights:
            class_weights = calculate_class_weights(train_labels_np, num_classes=num_classes)
            self.register_buffer('class_weights', class_weights)
            print(f"‚úÖ Class weights: {class_weights.numpy()}")
        else:
            self.class_weights = None
        
        # Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò
        if part_weights is None:
            self.part_weights = torch.ones(6)
        else:
            self.part_weights = torch.tensor(part_weights, dtype=torch.float32)
        self.register_buffer('part_weights_buf', self.part_weights)
        print(f"‚úÖ Part weights: {self.part_weights.numpy()}")
        print(f"‚úÖ Ordinal weight: {ordinal_weight}")
        
        # Ordinal loss
        self.ordinal_loss = OrdinalRegressionLoss(num_classes)

    def forward(self, pred_dict, target, use_mixup=False):
        """
        pred_dict: {'logits': [B, 6, 4], 'ordinal_logits': [B, 6, 3]}
        target: [B, 6]
        """
        pred_logits = pred_dict['logits']  # [B, 6, 4]
        ordinal_logits = pred_dict['ordinal_logits']  # [B, 6, 3]
        
        B, num_regions, num_classes = pred_logits.shape
        
        # Reshape
        pred_logits_flat = pred_logits.view(B * num_regions, num_classes)
        ordinal_logits_flat = ordinal_logits.view(B * num_regions, num_classes - 1)
        target_flat = target.view(B * num_regions)
        
        # 1. CrossEntropyLoss (Í∏∞Î≥∏ Î∂ÑÎ•ò)
        if self.class_weights is not None and not use_mixup:
            ce_loss = F.cross_entropy(
                pred_logits_flat, 
                target_flat,
                weight=self.class_weights.to(pred_logits.device),
                reduction='none'
            )
        else:
            ce_loss = F.cross_entropy(
                pred_logits_flat, 
                target_flat,
                reduction='none'
            )
        
        # 2. Ordinal Loss (ÏàúÏÑú ÌïôÏäµ)
        ord_loss = self.ordinal_loss(ordinal_logits_flat, target_flat)
        
        # 3. Í≤∞Ìï©
        total_loss_flat = (1 - self.ordinal_weight) * ce_loss + self.ordinal_weight * ord_loss
        total_loss = total_loss_flat.view(B, num_regions)
        
        # Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò Ï†ÅÏö©
        part_weights = self.part_weights_buf.to(pred_logits.device)
        total_loss = total_loss * part_weights.unsqueeze(0)
        
        part_losses = total_loss.mean(dim=0)
        
        return total_loss.mean(), part_losses

# ============================================================
# Mixup (Î∂ÑÎ•òÏö©)
# ============================================================
def mixup_data_classification(x, y, alpha=MIXUP_ALPHA):
    """
    ‚úÖ NEW: Î∂ÑÎ•òÏö© Mixup
    yÎäî one-hotÏúºÎ°ú Î≥ÄÌôò ÌõÑ mixup
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    
    # One-hot encoding
    y_onehot = F.one_hot(y, num_classes=NUM_CLASSES).float()  # [B, 6, 4]
    y_onehot_shuffled = y_onehot[index]
    
    mixed_y = lam * y_onehot + (1 - lam) * y_onehot_shuffled  # [B, 6, 4]
    
    return mixed_x, mixed_y, lam

def mixup_criterion(criterion, pred_dict, y_mixed, lam):
    """MixupÏùÑ ÏúÑÌïú ÏÜêÏã§ Í≥ÑÏÇ∞ (Ordinal Loss ÏßÄÏõê)"""
    pred_logits = pred_dict['logits']
    ordinal_logits = pred_dict['ordinal_logits']
    
    B, num_regions, num_classes = pred_logits.shape
    
    # Reshape
    pred_flat = pred_logits.view(B * num_regions, num_classes)
    ordinal_flat = ordinal_logits.view(B * num_regions, num_classes - 1)
    target_flat = y_mixed.view(B * num_regions, num_classes)
    
    # Soft target loss (CE part)
    log_probs = F.log_softmax(pred_flat, dim=1)
    ce_loss = -(target_flat * log_probs).sum(dim=1)
    
    # Ordinal partÎäî mixupÏóêÏÑú skip (hard labelÎßå ÏÇ¨Ïö©)
    loss = ce_loss  # Mixup ÏãúÏóêÎäî ordinal loss Ï†úÏô∏
    loss = loss.view(B, num_regions)
    
    # Î∂ÄÏúÑÎ≥Ñ Í∞ÄÏ§ëÏπò Ï†ÅÏö©
    if hasattr(criterion, 'part_weights_buf'):
        part_weights = criterion.part_weights_buf.to(pred_logits.device)
        loss = loss * part_weights.unsqueeze(0)
    
    part_losses = loss.mean(dim=0)
    return loss.mean(), part_losses

# ============================================================
# Metrics (Î∂ÑÎ•òÏö©)
# ============================================================
@torch.no_grad()
def calculate_classification_metrics(pred_dict, labels):
    """
    ‚úÖ Î∂ÑÎ•ò ÏßÄÌëú Í≥ÑÏÇ∞
    pred_dict: {'logits': [B, 6, 4], 'ordinal_logits': [B, 6, 3]}
    labels: [B, 6]
    """
    pred_logits = pred_dict['logits']
    
    # ÏòàÏ∏° ÌÅ¥ÎûòÏä§
    preds = pred_logits.argmax(dim=-1)  # [B, 6]
    
    # Exact match accuracy
    exact_acc = (preds == labels).float().mean().item()
    
    # Off-by-1 accuracy (Ïù∏Ï†ë ÌÅ¥ÎûòÏä§ ÌóàÏö©)
    off_by_1 = (torch.abs(preds - labels) <= 1).float().mean().item()
    
    # Per-region accuracy
    region_acc = (preds == labels).float().mean(dim=0)  # [6]
    
    # MAE (Ï∞∏Í≥†Ïö©)
    mae = torch.abs(preds.float() - labels.float()).mean().item()
    
    return exact_acc, off_by_1, mae, region_acc

# ============================================================
# Model - DenseNet121-MIMIC Í∏∞Î∞ò (ÏùòÎ£å ÏòÅÏÉÅ Ï†ÑÏö©)
# ============================================================
class DenseNet121MIMICClassification(nn.Module):
    """
    ‚úÖ DenseNet121-MIMIC Í∏∞Î∞ò Ordinal Classification Î™®Îç∏
    - MIMIC-CXR Îç∞Ïù¥ÌÑ∞Î°ú ÏÇ¨Ï†ÑÌïôÏäµÎêú ÏùòÎ£å ÏòÅÏÉÅ Ï†ÑÏö© Î™®Îç∏
    - Í∞Å Î∂ÄÏúÑ(A~F)ÎßàÎã§ 4Í∞ú ÌÅ¥ÎûòÏä§(0~3) ÏòàÏ∏°
    - Self-AttentionÏúºÎ°ú Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ ÌïôÏäµ
    - Ordinal Regression Head Ï∂îÍ∞Ä (ÏàúÏÑú ÌïôÏäµ)
    """
    def __init__(self, pretrained=True, drop=0.3, num_regions=6, num_classes=4):
        super().__init__()
        
        # ‚úÖ MIMIC-CXRÎ°ú ÏÇ¨Ï†ÑÌïôÏäµÎêú DenseNet121
        if pretrained:
            # self.backbone = xrv.models.DenseNet(weights="densenet121-res224-mimic_ch")
            self.backbone = xrv.models.DenseNet(weights="densenet121-res224-chex")
            print("‚úÖ Loaded DenseNet121 pretrained on CXR!")
        else:
            self.backbone = xrv.models.DenseNet(weights=None)
        
        # DenseNet121Ïùò feature extractorÎßå ÏÇ¨Ïö©
        self.features = self.backbone.features
        in_feat = 1024  # DenseNet121Ïùò Ï∂úÎ†• Ï±ÑÎÑê
        
        # Adaptive poolingÏúºÎ°ú Í≥†Ï†ïÎêú ÌÅ¨Í∏∞ Ï∂úÎ†•
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))  # [B, 1024, 7, 7]
        
        # Positional embedding
        self.pos_embed = nn.Parameter(torch.randn(1, 49, in_feat))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        
        # Region queries
        self.region_queries = nn.Parameter(torch.randn(num_regions, in_feat))
        nn.init.xavier_uniform_(self.region_queries)
        
        # Cross attention (Ïù¥ÎØ∏ÏßÄ ÌäπÏßï ‚Üí Î∂ÄÏúÑÎ≥Ñ ÌäπÏßï)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=in_feat,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        
        self.norm1 = nn.LayerNorm(in_feat)
        
        # Self-Attention (Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ ÌïôÏäµ)
        self.self_attention = nn.MultiheadAttention(
            embed_dim=in_feat,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )
        
        self.norm2 = nn.LayerNorm(in_feat)
        self.norm3 = nn.LayerNorm(in_feat)
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(in_feat, in_feat * 2),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(in_feat * 2, in_feat),
            nn.Dropout(drop)
        )
        
        self.norm4 = nn.LayerNorm(in_feat)
        
        # Shared feature extractor
        self.shared_fc = nn.Sequential(
            nn.Linear(in_feat, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(drop)
        )
        
        # ‚úÖ Classification heads (ÏùºÎ∞ò Î∂ÑÎ•ò)
        self.classification_heads = nn.ModuleList([
            nn.Linear(256, num_classes) for _ in range(num_regions)
        ])
        
        # ‚úÖ Ordinal Regression heads (ÏàúÏÑúÌòï ÌöåÍ∑Ä)
        self.ordinal_heads = nn.ModuleList([
            nn.Linear(256, num_classes - 1) for _ in range(num_regions)
        ])
    
    def forward(self, x):
        B = x.size(0)
        
        # ‚úÖ DenseNet121 feature extraction
        # torchxrayvision Î™®Îç∏ÏùÄ grayscaleÏùÑ Í∏∞ÎåÄÌïòÎØÄÎ°ú RGBÏùò Í≤ΩÏö∞ ÌèâÍ∑†
        if x.size(1) == 3:
            # RGBÎ•º grayscaleÎ°ú Î≥ÄÌôò (weighted average)
            x = 0.299 * x[:, 0:1] + 0.587 * x[:, 1:2] + 0.114 * x[:, 2:3]
        
        feat = self.features(x)  # [B, 1024, H, W]
        feat = self.adaptive_pool(feat)  # [B, 1024, 7, 7]
        feat = feat.flatten(2).transpose(1, 2)  # [B, 49, 1024]
        feat = feat + self.pos_embed
        
        # Region queries
        queries = self.region_queries.unsqueeze(0).expand(B, -1, -1)  # [B, 6, 1024]
        
        # Cross attention (Ïù¥ÎØ∏ÏßÄ ‚Üí Î∂ÄÏúÑ)
        attn_out, _ = self.cross_attention(
            query=queries,
            key=feat,
            value=feat
        )
        attn_out = self.norm1(attn_out + queries)  # [B, 6, 1024]
        
        # Self-Attention (Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ)
        self_attn_out, attn_weights = self.self_attention(
            query=attn_out,
            key=attn_out,
            value=attn_out
        )
        attn_out = self.norm2(attn_out + self_attn_out)  # [B, 6, 1024]
        
        # FFN
        ffn_out = self.ffn(attn_out)
        attn_out = self.norm3(attn_out + ffn_out)  # [B, 6, 1024]
        
        # Shared feature extraction
        shared_features = self.shared_fc(attn_out)  # [B, 6, 256]
        
        # ‚úÖ Í∞Å Î∂ÄÏúÑÎ≥Ñ ÏòàÏ∏° (Classification + Ordinal)
        classification_outputs = []
        ordinal_outputs = []
        
        for i in range(len(self.classification_heads)):
            region_feat = shared_features[:, i, :]  # [B, 256]
            
            # Classification logits
            class_logits = self.classification_heads[i](region_feat)  # [B, 4]
            classification_outputs.append(class_logits)
            
            # Ordinal logits (cumulative)
            ordinal_logits = self.ordinal_heads[i](region_feat)  # [B, 3]
            ordinal_outputs.append(ordinal_logits)
        
        class_out = torch.stack(classification_outputs, dim=1)  # [B, 6, 4]
        ordinal_out = torch.stack(ordinal_outputs, dim=1)  # [B, 6, 3]
        
        return {
            'logits': class_out,
            'ordinal_logits': ordinal_out,
            'attn_weights': attn_weights
        }

# ============================================================
# Training Functions (Î∂ÑÎ•òÏö© ÏàòÏ†ï)
# ============================================================
def train_epoch(model, tr_loader, criterion, optimizer, scaler, device, 
                amp=True, use_mixup=True):
    model.train()
    run_loss = run_acc = run_off1 = run_mae = n = 0
    part_losses_sum = torch.zeros(6)
    
    pbar = tqdm(tr_loader, desc="Train", leave=False)
    
    for imgs, labels in pbar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)  # [B, 6]
        
        # Mixup Ï†ÅÏö©
        is_mixup = use_mixup and (random.random() < 0.5)
        
        optimizer.zero_grad(set_to_none=True)
        
        with autocast(enabled=amp):
            if is_mixup:
                imgs_mixed, labels_mixed, lam = mixup_data_classification(imgs, labels)
                pred_dict = model(imgs_mixed)
                loss, part_losses = mixup_criterion(criterion, pred_dict, labels_mixed, lam)
            else:
                pred_dict = model(imgs)
                loss, part_losses = criterion(pred_dict, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Metrics (ÏõêÎ≥∏ labelsÎ°ú Í≥ÑÏÇ∞)
        exact_acc, off_by_1, mae, _ = calculate_classification_metrics(pred_dict, labels)
        
        bs = imgs.size(0)
        run_loss += loss.item() * bs
        run_acc += exact_acc * bs
        run_off1 += off_by_1 * bs
        run_mae += mae * bs
        part_losses_sum += part_losses.cpu() * bs
        n += bs
        
        pbar.set_postfix(
            loss=f"{run_loss/n:.4f}",
            acc=f"{run_acc/n:.4f}",
            mae=f"{run_mae/n:.4f}"
        )
    
    part_losses_avg = part_losses_sum / n
    
    return run_loss/n, run_acc/n, run_off1/n, run_mae/n, part_losses_avg

@torch.no_grad()
def evaluate(model, val_loader, criterion, device, split='val'):
    model.eval()
    run_loss = run_acc = run_off1 = run_mae = n = 0
    part_losses_sum = torch.zeros(6)
    region_acc_sum = torch.zeros(6)
    
    pbar = tqdm(val_loader, desc=f"{split.capitalize()}", leave=False)
    
    for imgs, labels in pbar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        
        pred_dict = model(imgs)
        loss, part_losses = criterion(pred_dict, labels)
        
        exact_acc, off_by_1, mae, region_acc = calculate_classification_metrics(pred_dict, labels)
        
        bs = imgs.size(0)
        run_loss += loss.item() * bs
        run_acc += exact_acc * bs
        run_off1 += off_by_1 * bs
        run_mae += mae * bs
        part_losses_sum += part_losses.cpu() * bs
        region_acc_sum += region_acc.cpu() * bs
        n += bs
        
        pbar.set_postfix(
            loss=f"{run_loss/n:.4f}",
            acc=f"{run_acc/n:.4f}"
        )
    
    avg_loss = run_loss/n
    avg_acc = run_acc/n
    avg_off1 = run_off1/n
    avg_mae = run_mae/n
    part_losses_avg = part_losses_sum / n
    region_acc_avg = region_acc_sum / n
    
    print(f"[{split}] loss:{avg_loss:.4f} acc:{avg_acc:.4f} "
          f"off1:{avg_off1:.4f} mae:{avg_mae:.4f}")
    print(f"  Region Acc: {region_acc_avg.numpy().round(3)}")
    
    return avg_loss, avg_acc, avg_off1, avg_mae, part_losses_avg, region_acc_avg

def get_lrs(optimizer):
    return [pg['lr'] for pg in optimizer.param_groups]

# ============================================================
# Main Function
# ============================================================
def main():
    print("\n" + "="*70)
    print("üöÄ Brixia COVID-19 Ordinal Classification Training")
    print("   üè• DenseNet121-MIMIC (ÏùòÎ£å ÏòÅÏÉÅ Ï†ÑÏö© ÏÇ¨Ï†ÑÌïôÏäµ!)")
    print("   üí° 80% Î™©ÌëúÎ•º ÏúÑÌïú ÌïµÏã¨ Í∏∞Ïà†:")
    print("   ‚úÖ MIMIC-CXR ÏÇ¨Ï†ÑÌïôÏäµ Î™®Îç∏ (Í∞ÄÏû• Í∞ïÎ†•!)")
    print("   ‚úÖ Ordinal Loss - ÏàúÏÑúÌòï ÌöåÍ∑Ä (0<1<2<3)")
    print("   ‚úÖ ÌÅ¥ÎûòÏä§ Î∂àÍ∑†Ìòï Ï≤òÎ¶¨ (ÏûêÎèô Í∞ÄÏ§ëÏπò)")
    print("   ‚úÖ Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ ÌïôÏäµ (Self-Attention)")
    print("   ‚úÖ Mixed Precision Training (AMP)")
    print("   ‚úÖ Mixup Augmentation")
    print("="*70)
    
    # Îç∞Ïù¥ÌÑ∞ Ï§ÄÎπÑ
    print("\nüìÇ Loading data...")
    tr_df, val_df, tt_df = load_and_split_brixia(CSV_PATH)
    
    print("\nüì¶ Creating DataLoaders...")
    tr_loader, val_loader, tt_loader = create_dataloaders(
        tr_df, val_df, tt_df, img_dir=IMAGE_DIR, 
        batch_size=BATCH_SIZE, img_size=IMG_SIZE, num_workers=4
    )
    
    # Train labels Ï∂îÏ∂ú
    train_labels = torch.cat([labels for _, labels in tr_loader], dim=0)
    
    # ========================================
    # ÌïôÏäµ Ï§ÄÎπÑ
    # ========================================
    print("\n" + "="*70)
    print("üìç Training Setup")
    print("="*70)
    
    # ‚úÖ Ordinal Loss + ÌÅ¥ÎûòÏä§ Í∞ÄÏ§ëÏπò
    criterion = AdaptiveOrdinalLoss(
        train_labels, 
        num_classes=NUM_CLASSES,
        use_class_weights=True,
        part_weights=None,
        ordinal_weight=0.5  # CE:Ordinal = 50:50
    )
    
    # ‚úÖ ÏùòÎ£å ÏòÅÏÉÅ Ï†ÑÏö© ÏÇ¨Ï†ÑÌïôÏäµ Î™®Îç∏
    model = DenseNet121MIMICClassification(
        pretrained=True, 
        drop=DROP_RATIO,
        num_regions=6,
        num_classes=NUM_CLASSES
    ).to(DEVICE)
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-6
    )
    scaler = GradScaler(enabled=AMP)
    
    best_acc = 0.0
    best_mae = float('inf')
    patience_counter = 0
    max_patience = 15
    
    # ========================================
    # ÌïôÏäµ Î£®ÌîÑ
    # ========================================
    print("\n" + "="*70)
    print("üèãÔ∏è Training Start")
    print("="*70)
    
    for ep in range(1, EPOCHS + 1):
        t0 = time.time()
        
        tr_loss, tr_acc, tr_off1, tr_mae, tr_part_losses = train_epoch(
            model, tr_loader, criterion, optimizer, scaler, DEVICE, AMP
        )
        
        val_loss, val_acc, val_off1, val_mae, val_part_losses, val_region_acc = evaluate(
            model, val_loader, criterion, DEVICE, split='val'
        )
        
        scheduler.step(val_acc)
        
        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            best_mae = val_mae
            patience_counter = 0
            torch.save({
                'epoch': ep,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': best_acc,
                'val_mae': best_mae,
                'val_off1': val_off1,
                'region_acc': val_region_acc,
            }, BEST_PATH)
            print(f"‚úÖ New Best! (Acc={best_acc:.4f}, MAE={best_mae:.4f})")
        else:
            patience_counter += 1
        
        # Early stopping
        if patience_counter >= max_patience:
            print(f"\n‚èπÔ∏è Early stopping at epoch {ep}")
            break
        
        elapsed = time.time() - t0
        print(f"\n[Epoch {ep:03d}/{EPOCHS}]")
        print(f"  Train - loss:{tr_loss:.4f} acc:{tr_acc:.4f} off1:{tr_off1:.4f} mae:{tr_mae:.4f}")
        print(f"  Val   - loss:{val_loss:.4f} acc:{val_acc:.4f} off1:{val_off1:.4f} mae:{val_mae:.4f}")
        print(f"  Part losses (Val): {val_part_losses.numpy().round(3)}")
        print(f"  Region Acc (Val): {val_region_acc.numpy().round(3)}")
        print(f"  LR:{get_lrs(optimizer)[0]:.2e} | {elapsed:.1f}s | Patience:{patience_counter}/{max_patience}")
        print("-" * 70)
    
    # ========================================
    # Test Evaluation
    # ========================================
    print("\n" + "="*70)
    print("üéâ Training Finished!")
    print("="*70)
    
    if len(tt_loader) > 0:
        print("\nüìä Test evaluation with best model...")
        checkpoint = torch.load(BEST_PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        
        tt_loss, tt_acc, tt_off1, tt_mae, tt_part_losses, tt_region_acc = evaluate(
            model, tt_loader, criterion, DEVICE, split='test'
        )
        
        print(f"\nüèÜ Test Results:")
        print(f"   Accuracy: {tt_acc:.4f}")
        print(f"   Off-by-1: {tt_off1:.4f}")
        print(f"   MAE: {tt_mae:.4f}")
        print(f"   Region Acc: {tt_region_acc.numpy().round(3)}")
        print(f"   Part losses: {tt_part_losses.numpy().round(3)}")
    
    print(f"\nüíæ Best model saved: {BEST_PATH}")
    print(f"üìà Best Validation Accuracy: {best_acc:.4f}")
    print(f"üìâ Best Validation MAE: {best_mae:.4f}")
    
    # ÏÑ±Îä• Ìñ•ÏÉÅ Ï†úÏïà
    print("\n" + "="*70)
    print("üí° Ï∂îÍ∞Ä ÏÑ±Îä• Ìñ•ÏÉÅÏùÑ ÏúÑÌïú Ï†úÏïà:")
    print("="*70)
    print("1. üéØ ÏïôÏÉÅÎ∏î: Îã§Î•∏ ÏãúÎìúÎ°ú 3~5Í∞ú Î™®Îç∏ ÌïôÏäµ ÌõÑ Ìà¨Ìëú")
    print("2. üî¨ ÏùòÎ£å Ï†ÑÏö© pretrained model ÏÇ¨Ïö©:")
    print("   - CheXpert, MIMIC-CXR Îì±ÏúºÎ°ú ÏÇ¨Ï†ÑÌïôÏäµÎêú Î™®Îç∏")
    print("3. üìä Îç∞Ïù¥ÌÑ∞ Ï∂îÍ∞Ä:")
    print("   - Ïô∏Î∂Ä COVID-19 Îç∞Ïù¥ÌÑ∞ÏÖã ÌôúÏö©")
    print("   - Pseudo-labelingÏúºÎ°ú unlabeled Îç∞Ïù¥ÌÑ∞ ÌôúÏö©")
    print("4. üé® Í≥†Í∏â Ï¶ùÍ∞ï:")
    print("   - CutMix, AugMix, RandAugment")
    print("   - Test-Time Augmentation (TTA)")
    print("5. üß™ ÌïòÏù¥ÌçºÌååÎùºÎØ∏ÌÑ∞ ÌäúÎãù:")
    print("   - ordinal_weight Ï°∞Ï†ï (0.3~0.7)")
    print("   - Learning rate, batch size Ïã§Ìóò")
    print("="*70)
    
    print("\n‚úÖ Ïù¥Ï†ú gradcam_inference.pyÎ•º Ïã§ÌñâÌïòÏó¨ Í≤∞Í≥ºÎ•º ÏãúÍ∞ÅÌôîÌïòÏÑ∏Ïöî!")
    print("="*70)

if __name__ == "__main__":
    main()


üöÄ Brixia COVID-19 Ordinal Classification Training
   üè• DenseNet121-MIMIC (ÏùòÎ£å ÏòÅÏÉÅ Ï†ÑÏö© ÏÇ¨Ï†ÑÌïôÏäµ!)
   üí° 80% Î™©ÌëúÎ•º ÏúÑÌïú ÌïµÏã¨ Í∏∞Ïà†:
   ‚úÖ MIMIC-CXR ÏÇ¨Ï†ÑÌïôÏäµ Î™®Îç∏ (Í∞ÄÏû• Í∞ïÎ†•!)
   ‚úÖ Ordinal Loss - ÏàúÏÑúÌòï ÌöåÍ∑Ä (0<1<2<3)
   ‚úÖ ÌÅ¥ÎûòÏä§ Î∂àÍ∑†Ìòï Ï≤òÎ¶¨ (ÏûêÎèô Í∞ÄÏ§ëÏπò)
   ‚úÖ Î∂ÄÏúÑ Í∞Ñ ÏÉÅÍ¥ÄÍ¥ÄÍ≥Ñ ÌïôÏäµ (Self-Attention)
   ‚úÖ Mixed Precision Training (AMP)
   ‚úÖ Mixup Augmentation

üìÇ Loading data...
Ïú†Ìö®Ìïú Îç∞Ïù¥ÌÑ∞: 4695Í∞ú
Train: 3637, Val: 912, Test: 146
Train - Mean: 8.31, Std: 4.26
Val - Mean: 8.35, Std: 4.15
Test - Mean: 7.78, Std: 4.20

üì¶ Creating DataLoaders...
‚úÖ DataLoader Ï§ÄÎπÑ ÏôÑÎ£å
   Train: 3637 samples, 113 batches
   Val:   912 samples, 29 batches
   Test:  146 samples, 5 batches

üìç Training Setup
Class Distribution Analysis

A:
  Class 0: 1810 ( 50.1%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Class 1: 1126 ( 31.1%) ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
  Clas

                                                                                             

[val] loss:0.8672 acc:0.4346 off1:0.8299 mae:0.7577
  Region Acc: [0.471 0.373 0.404 0.578 0.374 0.408]
‚úÖ New Best! (Acc=0.4346, MAE=0.7577)

[Epoch 001/100]
  Train - loss:1.0916 acc:0.3780 off1:0.8085 mae:0.8496
  Val   - loss:0.8672 acc:0.4346 off1:0.8299 mae:0.7577
  Part losses (Val): [0.847 0.924 0.884 0.753 0.911 0.884]
  Region Acc (Val): [0.471 0.373 0.404 0.578 0.374 0.408]
  LR:1.00e-04 | 20.3s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.7997 acc:0.4713 off1:0.8628 mae:0.6838
  Region Acc: [0.525 0.407 0.416 0.607 0.414 0.458]
‚úÖ New Best! (Acc=0.4713, MAE=0.6838)

[Epoch 002/100]
  Train - loss:1.0302 acc:0.4215 off1:0.8358 mae:0.7681
  Val   - loss:0.7997 acc:0.4713 off1:0.8628 mae:0.6838
  Part losses (Val): [0.774 0.837 0.839 0.695 0.829 0.824]
  Region Acc (Val): [0.525 0.407 0.416 0.607 0.414 0.458]
  LR:1.00e-04 | 18.3s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.7686 acc:0.4839 off1:0.8792 mae:0.6488
  Region Acc: [0.532 0.445 0.438 0.614 0.419 0.456]
‚úÖ New Best! (Acc=0.4839, MAE=0.6488)

[Epoch 003/100]
  Train - loss:0.9762 acc:0.4330 off1:0.8454 mae:0.7468
  Val   - loss:0.7686 acc:0.4839 off1:0.8792 mae:0.6488
  Part losses (Val): [0.749 0.791 0.806 0.675 0.793 0.798]
  Region Acc (Val): [0.532 0.445 0.438 0.614 0.419 0.456]
  LR:1.00e-04 | 17.7s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.7622 acc:0.4978 off1:0.8867 mae:0.6232
  Region Acc: [0.544 0.467 0.452 0.612 0.452 0.461]
‚úÖ New Best! (Acc=0.4978, MAE=0.6232)

[Epoch 004/100]
  Train - loss:0.9641 acc:0.4335 off1:0.8455 mae:0.7491
  Val   - loss:0.7622 acc:0.4978 off1:0.8867 mae:0.6232
  Part losses (Val): [0.739 0.775 0.798 0.682 0.787 0.793]
  Region Acc (Val): [0.544 0.467 0.452 0.612 0.452 0.461]
  LR:1.00e-04 | 19.1s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.7375 acc:0.5042 off1:0.9115 mae:0.5903
  Region Acc: [0.558 0.479 0.463 0.598 0.455 0.473]
‚úÖ New Best! (Acc=0.5042, MAE=0.5903)

[Epoch 005/100]
  Train - loss:0.9621 acc:0.4500 off1:0.8544 mae:0.7192
  Val   - loss:0.7375 acc:0.5042 off1:0.9115 mae:0.5903
  Part losses (Val): [0.716 0.756 0.767 0.656 0.759 0.77 ]
  Region Acc (Val): [0.558 0.479 0.463 0.598 0.455 0.473]
  LR:1.00e-04 | 19.6s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.7404 acc:0.5031 off1:0.9110 mae:0.5932
  Region Acc: [0.534 0.507 0.461 0.612 0.477 0.429]

[Epoch 006/100]
  Train - loss:0.9336 acc:0.4463 off1:0.8519 mae:0.7278
  Val   - loss:0.7404 acc:0.5031 off1:0.9110 mae:0.5932
  Part losses (Val): [0.728 0.75  0.763 0.646 0.769 0.785]
  Region Acc (Val): [0.534 0.507 0.461 0.612 0.477 0.429]
  LR:1.00e-04 | 17.6s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.7318 acc:0.5106 off1:0.9097 mae:0.5844
  Region Acc: [0.552 0.505 0.443 0.626 0.485 0.453]
‚úÖ New Best! (Acc=0.5106, MAE=0.5844)

[Epoch 007/100]
  Train - loss:0.9174 acc:0.4612 off1:0.8642 mae:0.6981
  Val   - loss:0.7318 acc:0.5106 off1:0.9097 mae:0.5844
  Part losses (Val): [0.705 0.743 0.766 0.649 0.753 0.775]
  Region Acc (Val): [0.552 0.505 0.443 0.626 0.485 0.453]
  LR:1.00e-04 | 19.9s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.7085 acc:0.5302 off1:0.9225 mae:0.5545
  Region Acc: [0.576 0.509 0.473 0.627 0.523 0.474]
‚úÖ New Best! (Acc=0.5302, MAE=0.5545)

[Epoch 008/100]
  Train - loss:0.8942 acc:0.4646 off1:0.8681 mae:0.6881
  Val   - loss:0.7085 acc:0.5302 off1:0.9225 mae:0.5545
  Part losses (Val): [0.688 0.725 0.731 0.621 0.729 0.757]
  Region Acc (Val): [0.576 0.509 0.473 0.627 0.523 0.474]
  LR:1.00e-04 | 17.6s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6987 acc:0.5369 off1:0.9229 mae:0.5473
  Region Acc: [0.555 0.522 0.511 0.615 0.513 0.505]
‚úÖ New Best! (Acc=0.5369, MAE=0.5473)

[Epoch 009/100]
  Train - loss:0.8811 acc:0.4742 off1:0.8773 mae:0.6678
  Val   - loss:0.6987 acc:0.5369 off1:0.9229 mae:0.5473
  Part losses (Val): [0.693 0.713 0.716 0.614 0.708 0.748]
  Region Acc (Val): [0.555 0.522 0.511 0.615 0.513 0.505]
  LR:1.00e-04 | 19.0s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.7003 acc:0.5292 off1:0.9154 mae:0.5621
  Region Acc: [0.539 0.522 0.495 0.623 0.508 0.489]

[Epoch 010/100]
  Train - loss:0.9018 acc:0.4827 off1:0.8844 mae:0.6508
  Val   - loss:0.7003 acc:0.5292 off1:0.9154 mae:0.5621
  Part losses (Val): [0.688 0.714 0.725 0.613 0.718 0.744]
  Region Acc (Val): [0.539 0.522 0.495 0.623 0.508 0.489]
  LR:1.00e-04 | 17.9s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.7023 acc:0.5289 off1:0.9165 mae:0.5610
  Region Acc: [0.525 0.482 0.516 0.627 0.526 0.496]

[Epoch 011/100]
  Train - loss:0.8968 acc:0.4798 off1:0.8806 mae:0.6574
  Val   - loss:0.7023 acc:0.5289 off1:0.9165 mae:0.5610
  Part losses (Val): [0.711 0.726 0.712 0.61  0.713 0.742]
  Region Acc (Val): [0.525 0.482 0.516 0.627 0.526 0.496]
  LR:1.00e-04 | 17.9s | Patience:2/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6857 acc:0.5453 off1:0.9282 mae:0.5322
  Region Acc: [0.573 0.518 0.522 0.632 0.527 0.5  ]
‚úÖ New Best! (Acc=0.5453, MAE=0.5322)

[Epoch 012/100]
  Train - loss:0.8887 acc:0.4794 off1:0.8793 mae:0.6607
  Val   - loss:0.6857 acc:0.5453 off1:0.9282 mae:0.5322
  Part losses (Val): [0.676 0.704 0.699 0.598 0.704 0.734]
  Region Acc (Val): [0.573 0.518 0.522 0.632 0.527 0.5  ]
  LR:1.00e-04 | 17.8s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6973 acc:0.5307 off1:0.9346 mae:0.5389
  Region Acc: [0.541 0.512 0.503 0.629 0.501 0.498]

[Epoch 013/100]
  Train - loss:0.8933 acc:0.4809 off1:0.8812 mae:0.6563
  Val   - loss:0.6973 acc:0.5307 off1:0.9346 mae:0.5389
  Part losses (Val): [0.686 0.717 0.727 0.604 0.714 0.737]
  Region Acc (Val): [0.541 0.512 0.503 0.629 0.501 0.498]
  LR:1.00e-04 | 18.8s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6844 acc:0.5406 off1:0.9360 mae:0.5272
  Region Acc: [0.572 0.526 0.52  0.605 0.532 0.488]

[Epoch 014/100]
  Train - loss:0.8470 acc:0.5058 off1:0.8990 mae:0.6087
  Val   - loss:0.6844 acc:0.5406 off1:0.9360 mae:0.5272
  Part losses (Val): [0.665 0.698 0.704 0.607 0.699 0.734]
  Region Acc (Val): [0.572 0.526 0.52  0.605 0.532 0.488]
  LR:1.00e-04 | 17.4s | Patience:2/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6767 acc:0.5450 off1:0.9304 mae:0.5291
  Region Acc: [0.576 0.513 0.514 0.635 0.533 0.499]

[Epoch 015/100]
  Train - loss:0.8611 acc:0.4864 off1:0.8817 mae:0.6524
  Val   - loss:0.6767 acc:0.5450 off1:0.9304 mae:0.5291
  Part losses (Val): [0.652 0.696 0.699 0.593 0.687 0.732]
  Region Acc (Val): [0.576 0.513 0.514 0.635 0.533 0.499]
  LR:1.00e-04 | 18.8s | Patience:3/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.7055 acc:0.5320 off1:0.9167 mae:0.5585
  Region Acc: [0.555 0.484 0.515 0.623 0.526 0.489]

[Epoch 016/100]
  Train - loss:0.8567 acc:0.4936 off1:0.8904 mae:0.6340
  Val   - loss:0.7055 acc:0.5320 off1:0.9167 mae:0.5585
  Part losses (Val): [0.713 0.743 0.705 0.611 0.717 0.744]
  Region Acc (Val): [0.555 0.484 0.515 0.623 0.526 0.489]
  LR:1.00e-04 | 17.4s | Patience:4/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6831 acc:0.5376 off1:0.9346 mae:0.5323
  Region Acc: [0.577 0.499 0.515 0.618 0.527 0.489]

[Epoch 017/100]
  Train - loss:0.8588 acc:0.4971 off1:0.8942 mae:0.6244
  Val   - loss:0.6831 acc:0.5376 off1:0.9346 mae:0.5323
  Part losses (Val): [0.664 0.712 0.7   0.594 0.7   0.728]
  Region Acc (Val): [0.577 0.499 0.515 0.618 0.527 0.489]
  LR:1.00e-04 | 18.6s | Patience:5/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6910 acc:0.5444 off1:0.9287 mae:0.5325
  Region Acc: [0.576 0.523 0.523 0.624 0.527 0.493]

[Epoch 018/100]
  Train - loss:0.8501 acc:0.4913 off1:0.8891 mae:0.6382
  Val   - loss:0.6910 acc:0.5444 off1:0.9287 mae:0.5325
  Part losses (Val): [0.696 0.711 0.694 0.599 0.708 0.737]
  Region Acc (Val): [0.576 0.523 0.523 0.624 0.527 0.493]
  LR:5.00e-05 | 17.5s | Patience:6/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6776 acc:0.5457 off1:0.9351 mae:0.5236
  Region Acc: [0.577 0.518 0.537 0.627 0.521 0.495]
‚úÖ New Best! (Acc=0.5457, MAE=0.5236)

[Epoch 019/100]
  Train - loss:0.8274 acc:0.4940 off1:0.8858 mae:0.6399
  Val   - loss:0.6776 acc:0.5457 off1:0.9351 mae:0.5236
  Part losses (Val): [0.671 0.7   0.686 0.588 0.695 0.726]
  Region Acc (Val): [0.577 0.518 0.537 0.627 0.521 0.495]
  LR:5.00e-05 | 17.7s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6783 acc:0.5431 off1:0.9366 mae:0.5239
  Region Acc: [0.564 0.522 0.51  0.622 0.543 0.499]

[Epoch 020/100]
  Train - loss:0.8877 acc:0.4857 off1:0.8825 mae:0.6522
  Val   - loss:0.6783 acc:0.5431 off1:0.9366 mae:0.5239
  Part losses (Val): [0.673 0.698 0.693 0.589 0.689 0.728]
  Region Acc (Val): [0.564 0.522 0.51  0.622 0.543 0.499]
  LR:5.00e-05 | 18.0s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6673 acc:0.5563 off1:0.9335 mae:0.5146
  Region Acc: [0.58  0.533 0.527 0.65  0.543 0.504]
‚úÖ New Best! (Acc=0.5563, MAE=0.5146)

[Epoch 021/100]
  Train - loss:0.8301 acc:0.5150 off1:0.9030 mae:0.5958
  Val   - loss:0.6673 acc:0.5563 off1:0.9335 mae:0.5146
  Part losses (Val): [0.666 0.681 0.68  0.58  0.677 0.721]
  Region Acc (Val): [0.58  0.533 0.527 0.65  0.543 0.504]
  LR:5.00e-05 | 18.1s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6715 acc:0.5523 off1:0.9346 mae:0.5174
  Region Acc: [0.591 0.529 0.527 0.633 0.526 0.508]

[Epoch 022/100]
  Train - loss:0.8610 acc:0.4992 off1:0.8936 mae:0.6237
  Val   - loss:0.6715 acc:0.5523 off1:0.9346 mae:0.5174
  Part losses (Val): [0.658 0.688 0.69  0.58  0.691 0.722]
  Region Acc (Val): [0.591 0.529 0.527 0.633 0.526 0.508]
  LR:5.00e-05 | 18.8s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6623 acc:0.5557 off1:0.9428 mae:0.5049
  Region Acc: [0.583 0.546 0.53  0.637 0.534 0.504]

[Epoch 023/100]
  Train - loss:0.8472 acc:0.4939 off1:0.8897 mae:0.6346
  Val   - loss:0.6623 acc:0.5557 off1:0.9428 mae:0.5049
  Part losses (Val): [0.646 0.671 0.682 0.58  0.671 0.724]
  Region Acc (Val): [0.583 0.546 0.53  0.637 0.534 0.504]
  LR:5.00e-05 | 19.4s | Patience:2/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6671 acc:0.5535 off1:0.9368 mae:0.5143
  Region Acc: [0.573 0.536 0.529 0.626 0.552 0.505]

[Epoch 024/100]
  Train - loss:0.8177 acc:0.5126 off1:0.9019 mae:0.5998
  Val   - loss:0.6671 acc:0.5535 off1:0.9368 mae:0.5143
  Part losses (Val): [0.664 0.674 0.68  0.583 0.68  0.722]
  Region Acc (Val): [0.573 0.536 0.529 0.626 0.552 0.505]
  LR:5.00e-05 | 18.4s | Patience:3/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6751 acc:0.5451 off1:0.9306 mae:0.5287
  Region Acc: [0.557 0.53  0.515 0.621 0.547 0.501]

[Epoch 025/100]
  Train - loss:0.8402 acc:0.5044 off1:0.8991 mae:0.6127
  Val   - loss:0.6751 acc:0.5451 off1:0.9306 mae:0.5287
  Part losses (Val): [0.686 0.686 0.689 0.59  0.676 0.723]
  Region Acc (Val): [0.557 0.53  0.515 0.621 0.547 0.501]
  LR:5.00e-05 | 18.0s | Patience:4/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6654 acc:0.5590 off1:0.9371 mae:0.5077
  Region Acc: [0.588 0.55  0.527 0.638 0.542 0.509]
‚úÖ New Best! (Acc=0.5590, MAE=0.5077)

[Epoch 026/100]
  Train - loss:0.8228 acc:0.5098 off1:0.9005 mae:0.6065
  Val   - loss:0.6654 acc:0.5590 off1:0.9371 mae:0.5077
  Part losses (Val): [0.651 0.675 0.687 0.582 0.676 0.721]
  Region Acc (Val): [0.588 0.55  0.527 0.638 0.542 0.509]
  LR:5.00e-05 | 19.6s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6688 acc:0.5524 off1:0.9340 mae:0.5190
  Region Acc: [0.569 0.546 0.522 0.632 0.542 0.504]

[Epoch 027/100]
  Train - loss:0.8548 acc:0.4988 off1:0.8900 mae:0.6298
  Val   - loss:0.6688 acc:0.5524 off1:0.9340 mae:0.5190
  Part losses (Val): [0.666 0.676 0.688 0.585 0.68  0.719]
  Region Acc (Val): [0.569 0.546 0.522 0.632 0.542 0.504]
  LR:5.00e-05 | 18.7s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6654 acc:0.5504 off1:0.9370 mae:0.5174
  Region Acc: [0.579 0.543 0.525 0.627 0.541 0.488]

[Epoch 028/100]
  Train - loss:0.8235 acc:0.5106 off1:0.9012 mae:0.6044
  Val   - loss:0.6654 acc:0.5504 off1:0.9370 mae:0.5174
  Part losses (Val): [0.653 0.675 0.68  0.578 0.682 0.725]
  Region Acc (Val): [0.579 0.543 0.525 0.627 0.541 0.488]
  LR:5.00e-05 | 17.3s | Patience:2/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6635 acc:0.5504 off1:0.9399 mae:0.5135
  Region Acc: [0.587 0.544 0.501 0.627 0.539 0.504]

[Epoch 029/100]
  Train - loss:0.8174 acc:0.5270 off1:0.9137 mae:0.5713
  Val   - loss:0.6635 acc:0.5504 off1:0.9399 mae:0.5135
  Part losses (Val): [0.646 0.678 0.682 0.578 0.677 0.72 ]
  Region Acc (Val): [0.587 0.544 0.501 0.627 0.539 0.504]
  LR:5.00e-05 | 17.7s | Patience:3/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6637 acc:0.5526 off1:0.9320 mae:0.5208
  Region Acc: [0.587 0.532 0.529 0.632 0.549 0.488]

[Epoch 030/100]
  Train - loss:0.8478 acc:0.4937 off1:0.8897 mae:0.6381
  Val   - loss:0.6637 acc:0.5526 off1:0.9320 mae:0.5208
  Part losses (Val): [0.647 0.68  0.682 0.57  0.682 0.722]
  Region Acc (Val): [0.587 0.532 0.529 0.632 0.549 0.488]
  LR:5.00e-05 | 17.4s | Patience:4/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6684 acc:0.5495 off1:0.9393 mae:0.5148
  Region Acc: [0.567 0.539 0.536 0.627 0.529 0.499]

[Epoch 031/100]
  Train - loss:0.8314 acc:0.4912 off1:0.8807 mae:0.6519
  Val   - loss:0.6684 acc:0.5495 off1:0.9393 mae:0.5148
  Part losses (Val): [0.658 0.677 0.681 0.589 0.686 0.719]
  Region Acc (Val): [0.567 0.539 0.536 0.627 0.529 0.499]
  LR:5.00e-05 | 17.9s | Patience:5/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6595 acc:0.5636 off1:0.9419 mae:0.4976
  Region Acc: [0.602 0.558 0.534 0.647 0.542 0.499]
‚úÖ New Best! (Acc=0.5636, MAE=0.4976)

[Epoch 032/100]
  Train - loss:0.8033 acc:0.5241 off1:0.9147 mae:0.5741
  Val   - loss:0.6595 acc:0.5636 off1:0.9419 mae:0.4976
  Part losses (Val): [0.635 0.676 0.681 0.576 0.67  0.72 ]
  Region Acc (Val): [0.602 0.558 0.534 0.647 0.542 0.499]
  LR:5.00e-05 | 19.4s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6704 acc:0.5508 off1:0.9331 mae:0.5216
  Region Acc: [0.591 0.529 0.522 0.641 0.539 0.482]

[Epoch 033/100]
  Train - loss:0.8003 acc:0.5166 off1:0.9006 mae:0.5985
  Val   - loss:0.6704 acc:0.5508 off1:0.9331 mae:0.5216
  Part losses (Val): [0.66  0.68  0.682 0.578 0.692 0.731]
  Region Acc (Val): [0.591 0.529 0.522 0.641 0.539 0.482]
  LR:5.00e-05 | 17.4s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6662 acc:0.5504 off1:0.9417 mae:0.5115
  Region Acc: [0.581 0.539 0.526 0.639 0.522 0.495]

[Epoch 034/100]
  Train - loss:0.7914 acc:0.5291 off1:0.9135 mae:0.5706
  Val   - loss:0.6662 acc:0.5504 off1:0.9417 mae:0.5115
  Part losses (Val): [0.656 0.676 0.682 0.575 0.684 0.724]
  Region Acc (Val): [0.581 0.539 0.526 0.639 0.522 0.495]
  LR:5.00e-05 | 17.1s | Patience:2/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6835 acc:0.5450 off1:0.9335 mae:0.5263
  Region Acc: [0.567 0.515 0.534 0.629 0.526 0.498]

[Epoch 035/100]
  Train - loss:0.8367 acc:0.5101 off1:0.8977 mae:0.6082
  Val   - loss:0.6835 acc:0.5450 off1:0.9335 mae:0.5263
  Part losses (Val): [0.68  0.694 0.692 0.602 0.705 0.729]
  Region Acc (Val): [0.567 0.515 0.534 0.629 0.526 0.498]
  LR:5.00e-05 | 17.2s | Patience:3/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6598 acc:0.5568 off1:0.9393 mae:0.5073
  Region Acc: [0.575 0.545 0.537 0.643 0.536 0.505]

[Epoch 036/100]
  Train - loss:0.8250 acc:0.5210 off1:0.9086 mae:0.5852
  Val   - loss:0.6598 acc:0.5568 off1:0.9393 mae:0.5073
  Part losses (Val): [0.656 0.667 0.674 0.57  0.672 0.718]
  Region Acc (Val): [0.575 0.545 0.537 0.643 0.536 0.505]
  LR:5.00e-05 | 17.3s | Patience:4/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6565 acc:0.5616 off1:0.9402 mae:0.5022
  Region Acc: [0.596 0.538 0.552 0.645 0.544 0.495]

[Epoch 037/100]
  Train - loss:0.8154 acc:0.5037 off1:0.8937 mae:0.6219
  Val   - loss:0.6565 acc:0.5616 off1:0.9402 mae:0.5022
  Part losses (Val): [0.636 0.668 0.672 0.568 0.675 0.718]
  Region Acc (Val): [0.596 0.538 0.552 0.645 0.544 0.495]
  LR:5.00e-05 | 19.3s | Patience:5/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6624 acc:0.5532 off1:0.9388 mae:0.5122
  Region Acc: [0.587 0.524 0.529 0.645 0.541 0.495]

[Epoch 038/100]
  Train - loss:0.8366 acc:0.5041 off1:0.8913 mae:0.6245
  Val   - loss:0.6624 acc:0.5532 off1:0.9388 mae:0.5122
  Part losses (Val): [0.647 0.677 0.681 0.57  0.678 0.721]
  Region Acc (Val): [0.587 0.524 0.529 0.645 0.541 0.495]
  LR:2.50e-05 | 17.6s | Patience:6/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6591 acc:0.5601 off1:0.9348 mae:0.5099
  Region Acc: [0.581 0.549 0.546 0.639 0.545 0.5  ]

[Epoch 039/100]
  Train - loss:0.8497 acc:0.4973 off1:0.8867 mae:0.6362
  Val   - loss:0.6591 acc:0.5601 off1:0.9348 mae:0.5099
  Part losses (Val): [0.652 0.668 0.675 0.569 0.67  0.72 ]
  Region Acc (Val): [0.581 0.549 0.546 0.639 0.545 0.5  ]
  LR:2.50e-05 | 18.5s | Patience:7/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6635 acc:0.5539 off1:0.9412 mae:0.5086
  Region Acc: [0.581 0.53  0.533 0.643 0.529 0.509]

[Epoch 040/100]
  Train - loss:0.8224 acc:0.5152 off1:0.8948 mae:0.6085
  Val   - loss:0.6635 acc:0.5539 off1:0.9412 mae:0.5086
  Part losses (Val): [0.648 0.672 0.679 0.572 0.69  0.721]
  Region Acc (Val): [0.581 0.53  0.533 0.643 0.529 0.509]
  LR:2.50e-05 | 17.4s | Patience:8/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6573 acc:0.5694 off1:0.9404 mae:0.4940
  Region Acc: [0.611 0.545 0.554 0.658 0.556 0.493]
‚úÖ New Best! (Acc=0.5694, MAE=0.4940)

[Epoch 041/100]
  Train - loss:0.7996 acc:0.5173 off1:0.9003 mae:0.6003
  Val   - loss:0.6573 acc:0.5694 off1:0.9404 mae:0.4940
  Part losses (Val): [0.642 0.665 0.676 0.563 0.679 0.72 ]
  Region Acc (Val): [0.611 0.545 0.554 0.658 0.556 0.493]
  LR:2.50e-05 | 19.2s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6507 acc:0.5671 off1:0.9397 mae:0.4973
  Region Acc: [0.596 0.552 0.541 0.66  0.547 0.507]

[Epoch 042/100]
  Train - loss:0.8156 acc:0.5126 off1:0.8940 mae:0.6141
  Val   - loss:0.6507 acc:0.5671 off1:0.9397 mae:0.4973
  Part losses (Val): [0.63  0.659 0.671 0.561 0.666 0.718]
  Region Acc (Val): [0.596 0.552 0.541 0.66  0.547 0.507]
  LR:2.50e-05 | 17.5s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6637 acc:0.5627 off1:0.9371 mae:0.5051
  Region Acc: [0.588 0.559 0.544 0.641 0.548 0.496]

[Epoch 043/100]
  Train - loss:0.8040 acc:0.5170 off1:0.8972 mae:0.6056
  Val   - loss:0.6637 acc:0.5627 off1:0.9371 mae:0.5051
  Part losses (Val): [0.654 0.673 0.675 0.573 0.688 0.72 ]
  Region Acc (Val): [0.588 0.559 0.544 0.641 0.548 0.496]
  LR:2.50e-05 | 18.6s | Patience:2/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6545 acc:0.5643 off1:0.9395 mae:0.5005
  Region Acc: [0.594 0.543 0.547 0.651 0.552 0.499]

[Epoch 044/100]
  Train - loss:0.7883 acc:0.5403 off1:0.9158 mae:0.5579
  Val   - loss:0.6545 acc:0.5643 off1:0.9395 mae:0.5005
  Part losses (Val): [0.641 0.664 0.669 0.564 0.675 0.715]
  Region Acc (Val): [0.594 0.543 0.547 0.651 0.552 0.499]
  LR:2.50e-05 | 18.6s | Patience:3/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6536 acc:0.5674 off1:0.9377 mae:0.4987
  Region Acc: [0.59  0.55  0.553 0.658 0.555 0.499]

[Epoch 045/100]
  Train - loss:0.7794 acc:0.5219 off1:0.9006 mae:0.5956
  Val   - loss:0.6536 acc:0.5674 off1:0.9377 mae:0.4987
  Part losses (Val): [0.637 0.659 0.675 0.566 0.668 0.716]
  Region Acc (Val): [0.59  0.55  0.553 0.658 0.555 0.499]
  LR:2.50e-05 | 18.5s | Patience:4/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6545 acc:0.5614 off1:0.9435 mae:0.4995
  Region Acc: [0.594 0.535 0.542 0.647 0.557 0.493]

[Epoch 046/100]
  Train - loss:0.8012 acc:0.5180 off1:0.8922 mae:0.6080
  Val   - loss:0.6545 acc:0.5614 off1:0.9435 mae:0.4995
  Part losses (Val): [0.638 0.671 0.672 0.563 0.665 0.718]
  Region Acc (Val): [0.594 0.535 0.542 0.647 0.557 0.493]
  LR:2.50e-05 | 17.7s | Patience:5/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6541 acc:0.5649 off1:0.9413 mae:0.4980
  Region Acc: [0.594 0.555 0.55  0.649 0.542 0.499]

[Epoch 047/100]
  Train - loss:0.7861 acc:0.5156 off1:0.8971 mae:0.6075
  Val   - loss:0.6541 acc:0.5649 off1:0.9413 mae:0.4980
  Part losses (Val): [0.643 0.661 0.666 0.566 0.671 0.717]
  Region Acc (Val): [0.594 0.555 0.55  0.649 0.542 0.499]
  LR:1.25e-05 | 17.4s | Patience:6/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6571 acc:0.5656 off1:0.9443 mae:0.4942
  Region Acc: [0.594 0.538 0.547 0.657 0.555 0.502]

[Epoch 048/100]
  Train - loss:0.8056 acc:0.5219 off1:0.8993 mae:0.5959
  Val   - loss:0.6571 acc:0.5656 off1:0.9443 mae:0.4942
  Part losses (Val): [0.642 0.667 0.669 0.565 0.678 0.722]
  Region Acc (Val): [0.594 0.538 0.547 0.657 0.555 0.502]
  LR:1.25e-05 | 17.2s | Patience:7/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6525 acc:0.5674 off1:0.9410 mae:0.4958
  Region Acc: [0.6   0.554 0.547 0.649 0.548 0.507]

[Epoch 049/100]
  Train - loss:0.7890 acc:0.5274 off1:0.9059 mae:0.5826
  Val   - loss:0.6525 acc:0.5674 off1:0.9410 mae:0.4958
  Part losses (Val): [0.638 0.658 0.667 0.561 0.674 0.717]
  Region Acc (Val): [0.6   0.554 0.547 0.649 0.548 0.507]
  LR:1.25e-05 | 19.3s | Patience:8/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6512 acc:0.5643 off1:0.9430 mae:0.4962
  Region Acc: [0.602 0.543 0.541 0.651 0.546 0.503]

[Epoch 050/100]
  Train - loss:0.7866 acc:0.5281 off1:0.9087 mae:0.5780
  Val   - loss:0.6512 acc:0.5643 off1:0.9430 mae:0.4962
  Part losses (Val): [0.636 0.657 0.669 0.561 0.669 0.716]
  Region Acc (Val): [0.602 0.543 0.541 0.651 0.546 0.503]
  LR:1.25e-05 | 18.2s | Patience:9/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6507 acc:0.5694 off1:0.9437 mae:0.4912
  Region Acc: [0.598 0.558 0.555 0.652 0.547 0.507]
‚úÖ New Best! (Acc=0.5694, MAE=0.4912)

[Epoch 051/100]
  Train - loss:0.7964 acc:0.5183 off1:0.9005 mae:0.5982
  Val   - loss:0.6507 acc:0.5694 off1:0.9437 mae:0.4912
  Part losses (Val): [0.632 0.655 0.668 0.56  0.673 0.717]
  Region Acc (Val): [0.598 0.558 0.555 0.652 0.547 0.507]
  LR:1.25e-05 | 19.1s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6515 acc:0.5678 off1:0.9433 mae:0.4931
  Region Acc: [0.602 0.543 0.544 0.658 0.553 0.508]

[Epoch 052/100]
  Train - loss:0.8136 acc:0.5113 off1:0.8946 mae:0.6136
  Val   - loss:0.6515 acc:0.5678 off1:0.9433 mae:0.4931
  Part losses (Val): [0.634 0.663 0.67  0.559 0.667 0.715]
  Region Acc (Val): [0.602 0.543 0.544 0.658 0.553 0.508]
  LR:1.25e-05 | 18.9s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6539 acc:0.5640 off1:0.9417 mae:0.4989
  Region Acc: [0.59  0.543 0.548 0.656 0.554 0.493]

[Epoch 053/100]
  Train - loss:0.7919 acc:0.5083 off1:0.8883 mae:0.6238
  Val   - loss:0.6539 acc:0.5640 off1:0.9417 mae:0.4989
  Part losses (Val): [0.64  0.662 0.671 0.564 0.671 0.716]
  Region Acc (Val): [0.59  0.543 0.548 0.656 0.554 0.493]
  LR:6.25e-06 | 18.3s | Patience:2/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6524 acc:0.5669 off1:0.9439 mae:0.4936
  Region Acc: [0.596 0.549 0.544 0.654 0.553 0.505]

[Epoch 054/100]
  Train - loss:0.7871 acc:0.5313 off1:0.9054 mae:0.5788
  Val   - loss:0.6524 acc:0.5669 off1:0.9439 mae:0.4936
  Part losses (Val): [0.638 0.661 0.669 0.563 0.669 0.714]
  Region Acc (Val): [0.596 0.549 0.544 0.654 0.553 0.505]
  LR:6.25e-06 | 17.7s | Patience:3/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6613 acc:0.5598 off1:0.9393 mae:0.5053
  Region Acc: [0.586 0.53  0.539 0.646 0.555 0.503]

[Epoch 055/100]
  Train - loss:0.8262 acc:0.5085 off1:0.8881 mae:0.6250
  Val   - loss:0.6613 acc:0.5598 off1:0.9393 mae:0.5053
  Part losses (Val): [0.653 0.675 0.675 0.572 0.676 0.717]
  Region Acc (Val): [0.586 0.53  0.539 0.646 0.555 0.503]
  LR:6.25e-06 | 17.7s | Patience:4/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6523 acc:0.5658 off1:0.9433 mae:0.4952
  Region Acc: [0.596 0.548 0.548 0.647 0.555 0.5  ]

[Epoch 056/100]
  Train - loss:0.7855 acc:0.5261 off1:0.9063 mae:0.5839
  Val   - loss:0.6523 acc:0.5658 off1:0.9433 mae:0.4952
  Part losses (Val): [0.636 0.659 0.669 0.56  0.671 0.718]
  Region Acc (Val): [0.596 0.548 0.548 0.647 0.555 0.5  ]
  LR:6.25e-06 | 18.8s | Patience:5/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6504 acc:0.5652 off1:0.9448 mae:0.4942
  Region Acc: [0.598 0.544 0.555 0.651 0.545 0.499]

[Epoch 057/100]
  Train - loss:0.7870 acc:0.5304 off1:0.9004 mae:0.5857
  Val   - loss:0.6504 acc:0.5652 off1:0.9448 mae:0.4942
  Part losses (Val): [0.632 0.657 0.669 0.561 0.669 0.714]
  Region Acc (Val): [0.598 0.544 0.555 0.651 0.545 0.499]
  LR:6.25e-06 | 17.7s | Patience:6/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6516 acc:0.5693 off1:0.9430 mae:0.4921
  Region Acc: [0.602 0.556 0.553 0.65  0.552 0.503]

[Epoch 058/100]
  Train - loss:0.7643 acc:0.5468 off1:0.9177 mae:0.5496
  Val   - loss:0.6516 acc:0.5693 off1:0.9430 mae:0.4921
  Part losses (Val): [0.636 0.658 0.668 0.56  0.672 0.716]
  Region Acc (Val): [0.602 0.556 0.553 0.65  0.552 0.503]
  LR:6.25e-06 | 18.6s | Patience:7/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6532 acc:0.5658 off1:0.9432 mae:0.4956
  Region Acc: [0.589 0.552 0.552 0.648 0.553 0.502]

[Epoch 059/100]
  Train - loss:0.7643 acc:0.5393 off1:0.9133 mae:0.5619
  Val   - loss:0.6532 acc:0.5658 off1:0.9432 mae:0.4956
  Part losses (Val): [0.638 0.659 0.668 0.563 0.674 0.716]
  Region Acc (Val): [0.589 0.552 0.552 0.648 0.553 0.502]
  LR:3.13e-06 | 18.5s | Patience:8/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6527 acc:0.5640 off1:0.9435 mae:0.4969
  Region Acc: [0.593 0.547 0.547 0.65  0.547 0.499]

[Epoch 060/100]
  Train - loss:0.8112 acc:0.5295 off1:0.9054 mae:0.5804
  Val   - loss:0.6527 acc:0.5640 off1:0.9435 mae:0.4969
  Part losses (Val): [0.637 0.659 0.669 0.563 0.673 0.716]
  Region Acc (Val): [0.593 0.547 0.547 0.65  0.547 0.499]
  LR:3.13e-06 | 18.5s | Patience:9/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6534 acc:0.5702 off1:0.9424 mae:0.4920
  Region Acc: [0.598 0.559 0.55  0.654 0.557 0.503]
‚úÖ New Best! (Acc=0.5702, MAE=0.4920)

[Epoch 061/100]
  Train - loss:0.7857 acc:0.5311 off1:0.9057 mae:0.5790
  Val   - loss:0.6534 acc:0.5702 off1:0.9424 mae:0.4920
  Part losses (Val): [0.641 0.66  0.668 0.562 0.672 0.716]
  Region Acc (Val): [0.598 0.559 0.55  0.654 0.557 0.503]
  LR:3.13e-06 | 18.1s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6538 acc:0.5671 off1:0.9428 mae:0.4943
  Region Acc: [0.601 0.556 0.536 0.654 0.554 0.502]

[Epoch 062/100]
  Train - loss:0.7663 acc:0.5316 off1:0.9076 mae:0.5775
  Val   - loss:0.6538 acc:0.5671 off1:0.9428 mae:0.4943
  Part losses (Val): [0.638 0.658 0.67  0.561 0.677 0.718]
  Region Acc (Val): [0.601 0.556 0.536 0.654 0.554 0.502]
  LR:3.13e-06 | 18.3s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6509 acc:0.5707 off1:0.9419 mae:0.4918
  Region Acc: [0.599 0.55  0.559 0.66  0.547 0.509]
‚úÖ New Best! (Acc=0.5707, MAE=0.4918)

[Epoch 063/100]
  Train - loss:0.7755 acc:0.5223 off1:0.8970 mae:0.5977
  Val   - loss:0.6509 acc:0.5707 off1:0.9419 mae:0.4918
  Part losses (Val): [0.637 0.656 0.668 0.559 0.67  0.716]
  Region Acc (Val): [0.599 0.55  0.559 0.66  0.547 0.509]
  LR:3.13e-06 | 17.9s | Patience:0/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6557 acc:0.5651 off1:0.9433 mae:0.4960
  Region Acc: [0.595 0.543 0.535 0.654 0.561 0.502]

[Epoch 064/100]
  Train - loss:0.8399 acc:0.5060 off1:0.8857 mae:0.6321
  Val   - loss:0.6557 acc:0.5651 off1:0.9433 mae:0.4960
  Part losses (Val): [0.644 0.663 0.671 0.563 0.676 0.717]
  Region Acc (Val): [0.595 0.543 0.535 0.654 0.561 0.502]
  LR:3.13e-06 | 16.2s | Patience:1/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6507 acc:0.5704 off1:0.9430 mae:0.4910
  Region Acc: [0.603 0.556 0.555 0.656 0.55  0.502]

[Epoch 065/100]
  Train - loss:0.8018 acc:0.5219 off1:0.8960 mae:0.6027
  Val   - loss:0.6507 acc:0.5704 off1:0.9430 mae:0.4910
  Part losses (Val): [0.634 0.657 0.669 0.559 0.67  0.716]
  Region Acc (Val): [0.603 0.556 0.555 0.656 0.55  0.502]
  LR:3.13e-06 | 17.2s | Patience:2/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6542 acc:0.5676 off1:0.9424 mae:0.4943
  Region Acc: [0.6   0.556 0.549 0.649 0.557 0.495]

[Epoch 066/100]
  Train - loss:0.8034 acc:0.5085 off1:0.8901 mae:0.6239
  Val   - loss:0.6542 acc:0.5676 off1:0.9424 mae:0.4943
  Part losses (Val): [0.641 0.661 0.67  0.562 0.673 0.718]
  Region Acc (Val): [0.6   0.556 0.549 0.649 0.557 0.495]
  LR:3.13e-06 | 17.3s | Patience:3/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6539 acc:0.5658 off1:0.9437 mae:0.4949
  Region Acc: [0.594 0.545 0.545 0.652 0.557 0.501]

[Epoch 067/100]
  Train - loss:0.7759 acc:0.5279 off1:0.9004 mae:0.5894
  Val   - loss:0.6539 acc:0.5658 off1:0.9437 mae:0.4949
  Part losses (Val): [0.64  0.662 0.671 0.562 0.672 0.716]
  Region Acc (Val): [0.594 0.545 0.545 0.652 0.557 0.501]
  LR:3.13e-06 | 17.2s | Patience:4/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6510 acc:0.5682 off1:0.9441 mae:0.4921
  Region Acc: [0.603 0.55  0.547 0.657 0.55  0.501]

[Epoch 068/100]
  Train - loss:0.7806 acc:0.5362 off1:0.9084 mae:0.5731
  Val   - loss:0.6510 acc:0.5682 off1:0.9441 mae:0.4921
  Part losses (Val): [0.636 0.659 0.668 0.559 0.669 0.715]
  Region Acc (Val): [0.603 0.55  0.547 0.657 0.55  0.501]
  LR:3.13e-06 | 18.3s | Patience:5/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6507 acc:0.5702 off1:0.9430 mae:0.4912
  Region Acc: [0.605 0.557 0.552 0.655 0.55  0.502]

[Epoch 069/100]
  Train - loss:0.8111 acc:0.5303 off1:0.9033 mae:0.5846
  Val   - loss:0.6507 acc:0.5702 off1:0.9430 mae:0.4912
  Part losses (Val): [0.635 0.657 0.669 0.558 0.67  0.715]
  Region Acc (Val): [0.605 0.557 0.552 0.655 0.55  0.502]
  LR:1.56e-06 | 17.6s | Patience:6/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6539 acc:0.5649 off1:0.9430 mae:0.4962
  Region Acc: [0.594 0.543 0.545 0.651 0.555 0.501]

[Epoch 070/100]
  Train - loss:0.7732 acc:0.5215 off1:0.8985 mae:0.6003
  Val   - loss:0.6539 acc:0.5649 off1:0.9430 mae:0.4962
  Part losses (Val): [0.641 0.661 0.671 0.563 0.673 0.715]
  Region Acc (Val): [0.594 0.543 0.545 0.651 0.555 0.501]
  LR:1.56e-06 | 17.6s | Patience:7/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6554 acc:0.5636 off1:0.9435 mae:0.4973
  Region Acc: [0.596 0.548 0.533 0.647 0.555 0.502]

[Epoch 071/100]
  Train - loss:0.7650 acc:0.5304 off1:0.9022 mae:0.5847
  Val   - loss:0.6554 acc:0.5636 off1:0.9435 mae:0.4973
  Part losses (Val): [0.641 0.661 0.672 0.562 0.676 0.719]
  Region Acc (Val): [0.596 0.548 0.533 0.647 0.555 0.502]
  LR:1.56e-06 | 16.2s | Patience:8/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6513 acc:0.5676 off1:0.9439 mae:0.4927
  Region Acc: [0.604 0.555 0.544 0.652 0.547 0.503]

[Epoch 072/100]
  Train - loss:0.7572 acc:0.5496 off1:0.9161 mae:0.5473
  Val   - loss:0.6513 acc:0.5676 off1:0.9439 mae:0.4927
  Part losses (Val): [0.635 0.657 0.669 0.559 0.671 0.716]
  Region Acc (Val): [0.604 0.555 0.544 0.652 0.547 0.503]
  LR:1.56e-06 | 17.5s | Patience:9/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6542 acc:0.5645 off1:0.9430 mae:0.4967
  Region Acc: [0.6   0.545 0.538 0.652 0.554 0.498]

[Epoch 073/100]
  Train - loss:0.7572 acc:0.5263 off1:0.9057 mae:0.5849
  Val   - loss:0.6542 acc:0.5645 off1:0.9430 mae:0.4967
  Part losses (Val): [0.641 0.662 0.671 0.563 0.673 0.716]
  Region Acc (Val): [0.6   0.545 0.538 0.652 0.554 0.498]
  LR:1.56e-06 | 17.5s | Patience:10/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6500 acc:0.5691 off1:0.9459 mae:0.4889
  Region Acc: [0.595 0.56  0.554 0.656 0.545 0.504]

[Epoch 074/100]
  Train - loss:0.7784 acc:0.5338 off1:0.9080 mae:0.5738
  Val   - loss:0.6500 acc:0.5691 off1:0.9459 mae:0.4889
  Part losses (Val): [0.633 0.656 0.668 0.559 0.669 0.716]
  Region Acc (Val): [0.595 0.56  0.554 0.656 0.545 0.504]
  LR:1.56e-06 | 17.4s | Patience:11/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6524 acc:0.5671 off1:0.9432 mae:0.4942
  Region Acc: [0.6   0.554 0.55  0.649 0.55  0.499]

[Epoch 075/100]
  Train - loss:0.7535 acc:0.5369 off1:0.9142 mae:0.5637
  Val   - loss:0.6524 acc:0.5671 off1:0.9432 mae:0.4942
  Part losses (Val): [0.639 0.659 0.669 0.561 0.671 0.716]
  Region Acc (Val): [0.6   0.554 0.55  0.649 0.55  0.499]
  LR:1.00e-06 | 19.0s | Patience:12/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6509 acc:0.5660 off1:0.9433 mae:0.4947
  Region Acc: [0.591 0.545 0.554 0.652 0.555 0.499]

[Epoch 076/100]
  Train - loss:0.8149 acc:0.5184 off1:0.8945 mae:0.6076
  Val   - loss:0.6509 acc:0.5660 off1:0.9433 mae:0.4947
  Part losses (Val): [0.636 0.658 0.669 0.561 0.668 0.714]
  Region Acc (Val): [0.591 0.545 0.554 0.652 0.555 0.499]
  LR:1.00e-06 | 15.9s | Patience:13/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6527 acc:0.5665 off1:0.9439 mae:0.4940
  Region Acc: [0.6   0.55  0.547 0.65  0.552 0.5  ]

[Epoch 077/100]
  Train - loss:0.7989 acc:0.5044 off1:0.8874 mae:0.6282
  Val   - loss:0.6527 acc:0.5665 off1:0.9439 mae:0.4940
  Part losses (Val): [0.638 0.658 0.669 0.561 0.674 0.717]
  Region Acc (Val): [0.6   0.55  0.547 0.65  0.552 0.5  ]
  LR:1.00e-06 | 18.3s | Patience:14/15
----------------------------------------------------------------------


                                                                                             

[val] loss:0.6558 acc:0.5627 off1:0.9421 mae:0.4996
  Region Acc: [0.594 0.545 0.532 0.65  0.549 0.505]

‚èπÔ∏è Early stopping at epoch 78

üéâ Training Finished!

üìä Test evaluation with best model...


                                                                            

[test] loss:0.6052 acc:0.5765 off1:0.9623 mae:0.4646
  Region Acc: [0.568 0.582 0.534 0.692 0.603 0.479]

üèÜ Test Results:
   Accuracy: 0.5765
   Off-by-1: 0.9623
   MAE: 0.4646
   Region Acc: [0.568 0.582 0.534 0.692 0.603 0.479]
   Part losses: [0.593 0.63  0.62  0.507 0.599 0.682]

üíæ Best model saved: runs_severity_classification/best_densenet121_mimic_classification.pth
üìà Best Validation Accuracy: 0.5707
üìâ Best Validation MAE: 0.4918

üí° Ï∂îÍ∞Ä ÏÑ±Îä• Ìñ•ÏÉÅÏùÑ ÏúÑÌïú Ï†úÏïà:
1. üéØ ÏïôÏÉÅÎ∏î: Îã§Î•∏ ÏãúÎìúÎ°ú 3~5Í∞ú Î™®Îç∏ ÌïôÏäµ ÌõÑ Ìà¨Ìëú
2. üî¨ ÏùòÎ£å Ï†ÑÏö© pretrained model ÏÇ¨Ïö©:
   - CheXpert, MIMIC-CXR Îì±ÏúºÎ°ú ÏÇ¨Ï†ÑÌïôÏäµÎêú Î™®Îç∏
3. üìä Îç∞Ïù¥ÌÑ∞ Ï∂îÍ∞Ä:
   - Ïô∏Î∂Ä COVID-19 Îç∞Ïù¥ÌÑ∞ÏÖã ÌôúÏö©
   - Pseudo-labelingÏúºÎ°ú unlabeled Îç∞Ïù¥ÌÑ∞ ÌôúÏö©
4. üé® Í≥†Í∏â Ï¶ùÍ∞ï:
   - CutMix, AugMix, RandAugment
   - Test-Time Augmentation (TTA)
5. üß™ ÌïòÏù¥ÌçºÌååÎùºÎØ∏ÌÑ∞ ÌäúÎãù:
   - ordinal_weight Ï°∞Ï†ï (0.3~0.7)
   - Learning rate, batch size Ïã§Ìóò

‚úÖ 

In [8]:
pip install torchxrayvision

Looking in indexes: https://mirror.kakao.com/pypi/simple
Collecting torchxrayvision
  Downloading https://mirror.kakao.com/pypi/packages/8d/fe/0cad4be210168c00cb0bce2870f57b92d43c2f5b5e604a46d97f9d6c3957/torchxrayvision-1.4.0-py3-none-any.whl (29.0 MB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m29.0/29.0 MB[0m [31m2.7 MB/s[0m  [33m0:00:10[0mm0:00:01[0m00:01[0m
Collecting scikit-image>=0.16 (from torchxrayvision)
  Downloading https://mirror.kakao.com/pypi/packages/96/08/916e7d9ee4721031b2f625db54b11d8379bd51707afaa3e5a29aecf10bc4/scikit_image-0.25.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.8 MB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m14.8/14.8 MB[0m [31m2.1 MB/s[0m  [33m0:00:07[0mm0:00:01[0m00:01[0m
Collecting imageio (from torchxrayvision)
  Downl