# Training pipeline


### 0. import

In [None]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import glob
from tqdm import tqdm
from collections import defaultdict
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from IPython.display import clear_output

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchvision.models as models

### 1. Configuration

##### path & format


In [None]:
# 1. Âü∫Êú¨ÂèÉÊï∏Ë®≠ÂÆö
NUM_CLASSES_DETAILED = 12 # Detailed (12 classes): 0:BG, 1:SA, 2:RF, 3:VL, 4:VI, 5:VM, 6:AM, 7:GR, 8:BFL, 9:ST, 10:SM, 11:BFS
NUM_CLASSES_ROUGH = 5 # Rough (5 classes): 0:BG, 1:SA, 2:QF, 3:GR, 4:HS
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 0
# Page-locked host memory only speeds up host-to-GPU copies, so request it
# only when CUDA is in use (a hard-coded True makes DataLoader warn on CPU).
PIN_MEMORY = DEVICE == "cuda"

# 2. Data paths
PROJECT_ROOT = "/home/n26141826/114-1_TAICA_cv_Final_Project"
DATA_DIR = os.path.join(PROJECT_ROOT, "data", "npy_2D_dataset_with_Embedding")
TRAIN_DATA_DIR = os.path.join(DATA_DIR, "train")
TEST_DATA_DIR = os.path.join(DATA_DIR, "test")

CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints_ResNet34_pretrain")
RESULTS_DIR = os.path.join(PROJECT_ROOT, "results_ResNet34_pretrain")
# Output directories are created eagerly so later cells can write into them.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)


# 3. Model checkpoint paths
ROUGH_MODEL_PATH  = os.path.join(CHECKPOINT_DIR, "rough_model_best.pth")
WARMUP_MODEL_PATH = os.path.join(CHECKPOINT_DIR, "warmup_model.pth")
DETAIL_MODEL_PATH = os.path.join(CHECKPOINT_DIR, "detail_model_best.pth")
EVAL_MODEL_PATH   = DETAIL_MODEL_PATH

# 4. Train-loss / val-loss log paths (CSV log + rendered curve per phase)
ROUGH_LOSS_LOG  = os.path.join(RESULTS_DIR, "rough_loss.csv")
ROUGH_LOSS_IMG   = os.path.join(RESULTS_DIR, "rough_loss.png")
WARMUP_LOSS_LOG = os.path.join(RESULTS_DIR, "warmup_loss.csv")
WARMUP_LOSS_IMG  = os.path.join(RESULTS_DIR, "warmup_loss.png")
DETAIL_LOSS_LOG = os.path.join(RESULTS_DIR, "detail_loss.csv")
DETAIL_LOSS_IMG  = os.path.join(RESULTS_DIR, "detail_loss.png")


# 5. Evaluation output paths
EVAL_CSV_OUTPUT = os.path.join(RESULTS_DIR, "evaluation_metrics_per_sequence.csv")
FLATTENED_METRICS_OUTPUT = os.path.join(RESULTS_DIR, "flattened_metrics.csv")

# Model architecture parameters
IN_CHANNELS = 1        # number of input image channels (grayscale)
EMBEDDING_DIM = 64     # size of each conditioning vector (type / position)
NUM_MRI_TYPES = 5      # number of MRI sequence types (Water, Fat, T1, T2, STIR)

# Visualization settings
VIZ_INTERVAL = 50

print(f"‚úÖ Configuration Loaded!")
print(f"   - Device     : {DEVICE}")
print(f"   - Data       : {DATA_DIR}")
print(f"   - Checkpoints: {CHECKPOINT_DIR}")
print(f"   - Results    : {RESULTS_DIR}")

#### Training Hyper-parameter

In [None]:
TARGET_SIZE = 512
# Geometric transforms (joint: applied to image and label together)
AUG_P_FLIP = 0.5          # probability of a horizontal flip
AUG_P_SCALE = 0.5         # probability of a random rescale
AUG_LIMIT_SCALE = 0.1     # rescale range (0.1 means 0.9x ~ 1.1x)

# Pixel transforms (independent: image only)
AUG_P_BRIGHTNESS = 0.5    # probability of a brightness/contrast adjustment
AUG_LIMIT_BRIGHT = 0.2    # brightness range (+-20%)
AUG_LIMIT_CONTRAST = 0.2  # contrast range (+-20%)

# Pre-train (Rough)
ROUGH_BATCH_SIZE = 32
ROUGH_LR = 1e-3
ROUGH_EPOCHS = 30

# Hierarchical Warm-up (Detail Head Warm-up)
WARMUP_BATCH_SIZE = 32    # batch size for the detail-labelled subset
WARMUP_LR = 1e-3          # learning rate for the warm-up phase
WARMUP_EPOCHS = 10        # a short warm-up is enough
# Strength with which the student (detail head) is pulled toward the teacher (rough head).
# Suggested range: 0.1 ~ 1.0
CONSISTENCY_WEIGHT = 0.5

# Fine-tune (Detail)
DETAIL_BATCH_SIZE = 24
DETAIL_EPOCHS = 50
DETAIL_LR_ENCODER = 1e-5 # encoder: small LR (gentle adjustment)
DETAIL_LR_DECODER = 1e-4 # decoder: regular LR

# Mappings & Definitions
# ROUGH_MAP[detail_class] -> rough_class
ROUGH_MAP = [0, 1, 2, 2, 2, 2, 0, 3, 4, 4, 4, 4]  # 0:BG, 1:SA, 2:QF, 3:GR, 4:HS
# MRI sequence name -> type ID (modality mapping)
TYPE_MAP = {
    'Water': 0,
    'FATFRACTION': 1, # Fat Fraction shares the Fat ID; give it its own ID if needed
    'Fat': 1,
    'T1': 2,
    'T2': 3,
    'STIR': 4
}
# Reverse lookup. Two names share ID 1; the later entry ('Fat') wins here.
ID_TO_TYPE = {v: k for k, v in TYPE_MAP.items()}

# Projection matrix (12 detail classes -> 5 rough classes)
# 0:BG, 1:SA, 2:RF, 3:VL, 4:VI, 5:VM, 6:AM, 7:GR, 8:BFL, 9:ST, 10:SM, 11:BFS
# Map to: 0:BG, 1:SA, 2:QF, 3:GR, 4:HS
# Row i is the one-hot encoding of ROUGH_MAP[i] (e.g. 6 -> 0, AM -> BG).
# Deriving it from ROUGH_MAP keeps the two definitions from drifting apart.
MAPPING_MATRIX = F.one_hot(
    torch.tensor(ROUGH_MAP, dtype=torch.long),
    num_classes=NUM_CLASSES_ROUGH,
).to(torch.float32).to(DEVICE)

# Evaluation Configuration
EVAL_BATCH_SIZE = 1
# 2. Muscle name lookup (labels 1-11)
MUSCLE_NAMES = {
    1: 'Sartorius',
    2: 'Rectus Femoris',
    3: 'Vastus Lateralis',
    4: 'Vastus Intermedius',
    5: 'Vastus Medialis',
    6: 'Adductor Magnus',
    7: 'Gracilis',
    8: 'Biceps Femoris LH',
    9: 'Semitendinosus',
    10: 'Semimembranosus',
    11: 'Biceps Femoris SH'
}
# 3. Visualization colors (one per detail class, index 0 = background)
VIZ_COLORS = [
    '#000000', '#e6194b', '#006400', '#228B22', '#32CD32', '#7CFC00',
    '#911eb4', '#46f0f0', '#00008B', '#0000CD', '#4169E1', '#87CEEB'
]
VIZ_CMAP = mcolors.ListedColormap(VIZ_COLORS)
# Bin edges at -0.5, 0.5, ..., 11.5 so each integer label gets its own color.
VIZ_NORM = mcolors.BoundaryNorm(boundaries=np.arange(13)-0.5, ncolors=12)

### 2. Utilities

In [None]:
# --- Metrics (ÁÇ∫‰∫ÜË∑üÁµÑÂì°ÊØîÂ∞çÔºå‰ΩøÁî®Ê®ôÊ∫ñ Dice) ---
def dice_score(preds, targets, num_classes):
    """
    Mean foreground Dice over one batch.

    Args:
        preds: class logits/scores; argmax is taken over dim 1.
        targets: integer label map compared element-wise with the argmax.
        num_classes: total class count; class 0 (background) is skipped.

    Returns:
        Mean of the per-class Dice values for classes 1..num_classes-1.
        A class absent from both prediction and target scores 1.0
        because of the 1e-6 smoothing term.
    """
    hard_preds = preds.argmax(dim=1)

    def _class_dice(cls):
        # Binary masks for this class in prediction and ground truth.
        pred_mask = (hard_preds == cls).float()
        true_mask = (targets == cls).float()
        overlap = (pred_mask * true_mask).sum()
        total = pred_mask.sum() + true_mask.sum()
        return ((2. * overlap + 1e-6) / (total + 1e-6)).item()

    # Foreground classes only (background index 0 is excluded).
    return np.mean([_class_dice(cls) for cls in range(1, num_classes)])

def plot_training_curves(log_path, image_path=None):
    """
    Plot the loss curves (and the Dice curve, if logged) from a CSV training log.

    The previous notebook output is cleared before the new figure is shown,
    which gives a live-refresh effect when called once per epoch.

    Args:
        log_path: CSV file with 'epoch', 'train_loss' and 'val_loss' columns,
            and optionally a 'val_dice' column.
        image_path: if truthy, the figure is also saved to this path.

    Returns:
        None. Nothing is drawn when the log is missing, empty, or has a
        header but no rows yet.
    """
    if not os.path.exists(log_path):
        return

    # Read the log; a zero-byte file means nothing has been logged yet.
    try:
        df = pd.read_csv(log_path)
    except pd.errors.EmptyDataError:
        return
    if df.empty:
        # Header only: there is nothing to plot and idxmax() would raise.
        return

    # Canvas: loss on the left, Dice on the right.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    try:
        # --- 1. Loss Curve ---
        ax1.plot(df['epoch'], df['train_loss'], label='Train Loss', marker='o', color='tab:blue')
        ax1.plot(df['epoch'], df['val_loss'], label='Val Loss', marker='o', color='tab:orange')
        ax1.set_title("Loss Curve")
        ax1.set_xlabel("Epoch")
        ax1.set_ylabel("Loss")
        ax1.grid(True, linestyle='--', alpha=0.6)
        ax1.legend()

        # --- 2. Dice Curve ---
        if 'val_dice' in df.columns:
            dice = df['val_dice']
            ax2.plot(df['epoch'], dice, label='Val Dice', marker='o', color='tab:green')
            ax2.set_title("Validation Dice Score")
            ax2.set_xlabel("Epoch")
            ax2.set_ylabel("Dice")
            ax2.grid(True, linestyle='--', alpha=0.6)
            ax2.legend()

            # Mark the best score so far (skipped when every value is NaN).
            if dice.notna().any():
                max_dice = dice.max()
                max_epoch = df.loc[dice.idxmax(), 'epoch']
                ax2.annotate(f'Max: {max_dice:.4f}', xy=(max_epoch, max_dice),
                             xytext=(max_epoch, max_dice - 0.05),
                             arrowprops=dict(facecolor='black', shrink=0.05))

        fig.tight_layout()
        if image_path:
            fig.savefig(image_path)

        # Clear the previous plot, then show the new one.
        clear_output(wait=True)
        plt.show()
    finally:
        # This runs once per epoch; without an explicit close, every call
        # leaves a figure alive on backends that do not close on show().
        plt.close(fig)

### 2. Train - validation split

In [None]:
# 1. ÊêúÂ∞ãÊâÄÊúâ .npy Ê™îÊ°à
# glob returns files in filesystem order, which differs between machines.
# Sorting makes the seeded split below reproducible.
all_files = sorted(glob.glob(os.path.join(TRAIN_DATA_DIR, "*.npy")))
print(f"Total .npy files found in train data: {len(all_files)}")

# 2. Read each file's metadata once and reuse it for grouping and counting.
# SECURITY NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
file_has_detail = {}
file_type_idx = {}
for path in all_files:
    sample = np.load(path, allow_pickle=True).item()
    file_has_detail[path] = bool(sample.get('has_detail'))
    file_type_idx[path] = sample.get('type_idx')

train_detail = [path for path in all_files if file_has_detail[path]]
train_without_detail = [path for path in all_files if not file_has_detail[path]]

# NOTE(review): the split below is per file (slice). The original comment
# mentions splitting by Subject; if several slices come from one subject they
# can land in both train and val. Confirm whether a subject-level split is wanted.
# --- split rough only training data & val ---
train_roughonly_subs, val_roughonly_subs = train_test_split(
    train_without_detail, test_size=0.1, random_state=42
)
# --- split detail training data & val ---
train_detail_subs, val_detail_subs = train_test_split(
    train_detail, test_size=0.1, random_state=42
)
# --- combine ---
train_files = train_roughonly_subs + train_detail_subs
val_files = val_roughonly_subs + val_detail_subs


def _count_types(paths):
    """Count how many of the given files carry each 'type_idx' value."""
    counts = {}
    for path in paths:
        key = file_type_idx[path]
        counts[key] = counts.get(key, 0) + 1
    return counts


# --- sequence-type distribution of train and val ---
train_rough_seq_distribution = _count_types(train_files)
val_rough_seq_distribution = _count_types(val_files)
# --- log ---
print("-"*30)
print(f"Train Files : {len(train_files)}")
print(f"Val Files   : {len(val_files)}")
print("-"*30)

# --- measure type distribution ---
# 1. Every type ID seen in either split (union, sorted)
all_keys = sorted(set(train_rough_seq_distribution.keys()) | set(val_rough_seq_distribution.keys()))

# 2. One row per sequence type: [train count, val count, val/train ratio]
data = {}
for k in all_keys:
    seq_name = ID_TO_TYPE.get(k, f"Unknown({k})")
    train_count = train_rough_seq_distribution.get(k, 0)
    val_count = val_rough_seq_distribution.get(k, 0)
    ratio = val_count / train_count if train_count > 0 else 0.0
    data[seq_name] = [train_count, val_count, ratio]

# 3. Convert to a DataFrame
df_dist = pd.DataFrame.from_dict(data, orient='index', columns=['Train', 'Val', 'Val : Train'])

# 4. Append a Total row
total_train = df_dist['Train'].sum()
total_val = df_dist['Val'].sum()
total_ratio = total_val / total_train if total_train > 0 else 0.0

df_dist.loc['Total'] = [total_train, total_val, total_ratio]

# 5. Format for display (counts as integers, ratio as a percentage)
styled_df = df_dist.style.format({
    'Train': '{:,.0f}',       # integer
    'Val': '{:,.0f}',         # integer
    'Val : Train': '{:.1%}'   # percentage (e.g. 0.11 -> 11.0%)
})
display(styled_df)

### 3. Dataset

In [None]:
class SliceMasterDataset(Dataset):
    """
    Dataset of 2D slices stored one per .npy file as a pickled dict.

    Each file must provide the keys 'image', 'z_pos', 'type_idx' and the
    label key for the selected mode ('rough_label' or 'detail_label').
    """

    # Supported modes and the sample-dict key holding the matching label mask.
    _LABEL_KEYS = {'rough': 'rough_label', 'detail': 'detail_label'}

    def __init__(self, file_list, mode='rough', transform=None):
        """
        Args:
            file_list (list): List of .npy file paths
            mode (str): 'rough' or 'detail'
            transform: Albumentations transform called as transform(image=..., mask=...)

        Raises:
            ValueError: if mode is not 'rough' or 'detail'.
        """
        # Fail here rather than with an unbound `label` on the first __getitem__.
        if mode not in self._LABEL_KEYS:
            raise ValueError(
                f"mode must be one of {sorted(self._LABEL_KEYS)}, got {mode!r}"
            )
        self.file_list = file_list
        self.mode = mode
        self.transform = transform
        # Rough Map (12 -> 5); stored for callers, not used by __getitem__.
        self.rough_map = np.array(ROUGH_MAP, dtype=np.uint8)

    def __len__(self):
        """Return the number of slice files."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """
        Load one slice.

        Returns:
            (image, label, z_pos, type_idx) as float, long, float and long
            tensors; a 2D image gets a leading channel axis.
        """
        # 1. Load the sample dict.
        # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
        data = np.load(self.file_list[idx], allow_pickle=True).item()

        image = data['image']  # presumably (H, W) float32 per the original note - TODO confirm
        z_pos = data['z_pos']
        type_idx = data['type_idx']

        # 2. Select the label for the configured mode.
        # NOTE(review): slices without a detail annotation are not filtered or
        # masked here; callers are expected to pass only suitable files.
        label = data[self._LABEL_KEYS[self.mode]]

        # 3. Augmentation (image and mask are transformed together)
        if self.transform:
            augmented = self.transform(image=image, mask=label)
            image = augmented['image']
            label = augmented['mask']

        # 4. To tensor. copy() yields a contiguous array, which torch.from_numpy
        # needs after transforms that return negative-stride views.
        if isinstance(image, np.ndarray):
            if image.ndim == 2:
                image = image[np.newaxis, ...]  # (1, H, W)
            image = image.copy()

        if isinstance(label, np.ndarray):
            label = label.copy()

        return (
            torch.from_numpy(image).float(),
            torch.from_numpy(label).long(),
            torch.tensor(z_pos).float(),
            torch.tensor(type_idx).long()
        )

### 3. Transform

In [None]:
# Training pipeline, in application order:
#   1. photometric step (image only, geometry untouched)
#   2. geometric steps (image and label transformed in sync)
#   3. resize to the fixed network input size
# (Optional Gaussian noise, currently disabled:
#  A.GaussNoise(var_limit=(10.0, 50.0), p=0.3))
train_transform = A.Compose([
    A.RandomBrightnessContrast(
        p=AUG_P_BRIGHTNESS,
        brightness_limit=AUG_LIMIT_BRIGHT,
        contrast_limit=AUG_LIMIT_CONTRAST,
    ),
    A.HorizontalFlip(p=AUG_P_FLIP),  # left-right flip
    A.RandomScale(p=AUG_P_SCALE, scale_limit=AUG_LIMIT_SCALE),  # random zoom in/out
    A.Resize(interpolation=1, width=TARGET_SIZE, height=TARGET_SIZE),
])

# Validation pipeline: resize only, no random augmentation.
val_transform = A.Compose([
    A.Resize(interpolation=1, width=TARGET_SIZE, height=TARGET_SIZE),
])

### 4. Model Architecture

#### Main model (conditioned ResNet34 U-Net)

In [None]:


class ConditionedResNetUNet(nn.Module):
    """
    U-Net style segmentation network with a ResNet34 encoder, a bottleneck
    conditioned on MRI type and slice position, and two output heads
    (rough and detail) that share one decoder.
    """

    _RETURN_MODES = ('rough', 'detail', 'both')

    def __init__(self, n_channels=1, n_classes_rough=5, n_classes_detail=12, embed_dim=64, num_mri_types=5):
        """
        Args:
            n_channels: number of input image channels.
            n_classes_rough: output channels of the rough head.
            n_classes_detail: output channels of the detail head.
            embed_dim: size of each conditioning embedding (type and position).
            num_mri_types: number of entries in the MRI-type embedding table.
        """
        super().__init__()

        # ==========================================
        # 1. Encoder (pre-trained ResNet34)
        # ==========================================
        # Load the default pre-trained torchvision weights.
        resnet = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)

        # The stock first conv expects 3 (RGB) channels. Collapse its kernels
        # over the input-channel axis and share the result across our
        # n_channels inputs. Dividing by n_channels keeps the response to a
        # channel-replicated image equal to the original RGB response
        # (for n_channels=1 this is a plain sum, as before).
        original_conv1 = resnet.conv1
        self.enc_conv1 = nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            self.enc_conv1.weight[:] = original_conv1.weight.sum(dim=1, keepdim=True) / n_channels

        self.enc_bn1 = resnet.bn1
        self.enc_relu = resnet.relu
        self.enc_maxpool = resnet.maxpool

        self.enc_layer1 = resnet.layer1 # 64
        self.enc_layer2 = resnet.layer2 # 128
        self.enc_layer3 = resnet.layer3 # 256
        self.enc_layer4 = resnet.layer4 # 512 (Bottleneck)

        # ==========================================
        # 2. Conditioning embeddings
        # ==========================================
        self.type_emb = nn.Embedding(num_embeddings=num_mri_types, embedding_dim=embed_dim)
        self.pos_emb = nn.Sequential(
            nn.Linear(1, embed_dim), nn.ReLU(), nn.Linear(embed_dim, embed_dim)
        )

        # ==========================================
        # 3. Fusion layer
        # ==========================================
        # Bottleneck features (512) + type embedding + position embedding,
        # projected back to 512 channels with a 1x1 conv.
        self.fusion = nn.Sequential(
            nn.Conv2d(512 + embed_dim*2, 512, kernel_size=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )

        # ==========================================
        # 4. Decoder (channel counts follow the ResNet34 stages)
        # ==========================================
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Up1: Fusion(512) + Layer3(256) -> 256
        self.up1 = self._block(512 + 256, 256)
        # Up2: Up1(256) + Layer2(128) -> 128
        self.up2 = self._block(256 + 128, 128)
        # Up3: Up2(128) + Layer1(64) -> 64
        self.up3 = self._block(128 + 64, 64)
        # Up4: no skip connection here (simplification): 64 -> 32 at 1/2
        # resolution. The heads apply the last 2x upsample.
        self.up4 = self._block(64, 32)

        # Not used by forward(); kept so the attribute stays available.
        self.final_up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # ==========================================
        # 5. Dual heads
        # ==========================================
        self.head_rough = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), # final 2x back to input size
            nn.Conv2d(32, n_classes_rough, kernel_size=1)
        )
        self.head_detail = nn.Sequential(
             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), # final 2x back to input size
             nn.Conv2d(32, n_classes_detail, kernel_size=1)
        )

    def _block(self, in_c, out_c):
        """Two 3x3 conv + BatchNorm + ReLU stages mapping in_c -> out_c channels."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, padding=1), nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, padding=1), nn.BatchNorm2d(out_c), nn.ReLU(inplace=True)
        )

    def forward(self, x, type_idx, z_pos, return_mode='both'):
        """
        Args:
            x: input image batch. Shape comments below assume (B, 1, 512, 512).
            type_idx: MRI-type indices, one per sample, fed to the embedding table.
            z_pos: slice positions, one scalar per sample (unsqueezed to (B, 1)).
            return_mode: 'rough', 'detail' or 'both'.

        Returns:
            Rough logits, detail logits, or the tuple (rough, detail).

        Raises:
            ValueError: if return_mode is not one of the three supported values.
        """
        # Reject an unknown mode up front instead of silently returning None.
        if return_mode not in ('rough', 'detail', 'both'):
            raise ValueError(
                f"return_mode must be 'rough', 'detail' or 'both', got {return_mode!r}"
            )

        # --- Encoder (ResNet) ---
        x0 = self.enc_conv1(x)      # (B, 64, 256, 256)
        x0 = self.enc_bn1(x0)
        x0 = self.enc_relu(x0)
        x1 = self.enc_maxpool(x0)   # (B, 64, 128, 128)

        x2 = self.enc_layer1(x1)    # (B, 64, 128, 128), layer1 keeps the size
        x3 = self.enc_layer2(x2)    # (B, 128, 64, 64)
        x4 = self.enc_layer3(x3)    # (B, 256, 32, 32)
        x5 = self.enc_layer4(x4)    # (B, 512, 16, 16), bottleneck

        # --- Condition injection ---
        t_vec = self.type_emb(type_idx)
        p_vec = self.pos_emb(z_pos.unsqueeze(1))
        cond = torch.cat([t_vec, p_vec], dim=1)
        # Broadcast the per-sample condition vector over the bottleneck grid.
        cond = cond.unsqueeze(2).unsqueeze(3).expand(-1, -1, x5.shape[2], x5.shape[3])

        x5 = torch.cat([x5, cond], dim=1)
        x5 = self.fusion(x5) # (B, 512, 16, 16)

        # --- Decoder ---
        d1 = self.up1(torch.cat([self.up(x5), x4], dim=1)) # (B, 256, 32, 32)
        d2 = self.up2(torch.cat([self.up(d1), x3], dim=1)) # (B, 128, 64, 64)
        # x1 and x2 share a resolution; x2 is used as the skip.
        d3 = self.up3(torch.cat([self.up(d2), x2], dim=1)) # (B, 64, 128, 128)
        d4 = self.up4(self.up(d3)) # (B, 32, 256, 256)

        # The heads contain the final upsample (256 -> 512).
        if return_mode == 'rough':
            return self.head_rough(d4)
        if return_mode == 'detail':
            return self.head_detail(d4)
        return self.head_rough(d4), self.head_detail(d4)

    # --- Helper Methods ---
    def freeze_encoder_and_rough(self):
        """
        Disable gradients for the encoder, the embeddings, the fusion layer
        and the rough head, and enable them for the detail head.

        NOTE(review): the decoder blocks (up1-up4) are not touched and stay
        trainable. requires_grad=False also does not stop BatchNorm running
        statistics from updating while the model is in train() mode.
        """
        print("üîí Freezing ResNet Encoder & Rough Head...")
        frozen_modules = (
            # ResNet encoder
            self.enc_conv1, self.enc_bn1,
            self.enc_layer1, self.enc_layer2, self.enc_layer3, self.enc_layer4,
            # Embeddings & fusion
            self.type_emb, self.pos_emb, self.fusion,
            # Rough head
            self.head_rough,
        )
        for module in frozen_modules:
            for param in module.parameters():
                param.requires_grad = False

        # Make sure the detail head is trainable.
        for param in self.head_detail.parameters():
            param.requires_grad = True
        print("‚úÖ Done.")

    def unfreeze_all(self):
        """Enable gradients for every parameter of the model."""
        for param in self.parameters():
            param.requires_grad = True
        print("üîì All layers unfrozen.")

    def load_pretrained_encoder(self, path):
        """
        Load weights from a checkpoint saved as a state_dict.

        Only tensors whose name and shape match this model are copied. A
        leading 'module.' prefix (DataParallel checkpoints) is stripped. If
        the file does not exist the current weights are kept and a message
        is printed.

        Args:
            path: path to a file written with torch.save(state_dict, path).
        """
        print("‚ÑπÔ∏è Note: This model uses ImageNet weights by default.")
        if not os.path.exists(path):
            print(f"Checkpoint not found, keeping current weights: {path}")
            return

        print(f"üîÑ Loading checkpoint: {path}")
        # Load on CPU; load_state_dict copies into whatever device the parameters live on.
        state_dict = torch.load(path, map_location="cpu")

        own_state = self.state_dict()
        compatible = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                key = key[len("module."):]
            if key in own_state and own_state[key].shape == value.shape:
                compatible[key] = value

        # strict=False: tensors missing from the checkpoint keep their current values.
        self.load_state_dict(compatible, strict=False)
        print(f"Loaded {len(compatible)}/{len(own_state)} tensors from checkpoint.")

#### Fine-tune Optimizer Design

##### Differential Learning Rates

In [None]:
# [Fine-tune Strategy Design] (Differential Learning Rates)
def get_fine_tuning_optimizer(model, encoder_lr=1e-5, decoder_lr=1e-4):
    """
    Build an Adam optimizer with two learning rates for fine-tuning.

    1. Encoder-side parameters get `encoder_lr` (small updates).
    2. All other trainable parameters (decoder blocks, heads) get `decoder_lr`.

    A parameter is encoder-side when its top-level attribute name starts with
    'enc_' (the ResNet stages of ConditionedResNetUNet) or is one of
    'type_emb', 'pos_emb', 'fusion'. The plain U-Net names 'inc' and
    'down1'..'down4' are still recognised. A leading 'module.' (DataParallel)
    is ignored. Parameters with requires_grad=False are skipped.

    Args:
        model: the network whose parameters are grouped.
        encoder_lr: learning rate of the encoder group.
        decoder_lr: learning rate of the decoder/head group.

    Returns:
        torch.optim.Adam with param_groups[0] = encoder, param_groups[1] = decoder.
    """
    encoder_params = []
    decoder_params = []

    # Match on the top-level attribute name, not on a substring of the full
    # dotted name, so unrelated layers are never picked up by accident.
    encoder_prefix = 'enc_'
    encoder_modules = {
        'type_emb', 'pos_emb', 'fusion',
        'inc', 'down1', 'down2', 'down3', 'down4',
    }

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        parts = name.split('.')
        if parts[0] == 'module' and len(parts) > 1:
            top_level = parts[1]
        else:
            top_level = parts[0]

        if top_level.startswith(encoder_prefix) or top_level in encoder_modules:
            encoder_params.append(param)
        else:
            decoder_params.append(param)  # up1..up4, heads, ...

    print(f"üîß Optimizer Groups Setup:")
    print(f"   - Encoder Params: {len(encoder_params)} tensors (LR={encoder_lr}) -> Slow updates")
    print(f"   - Decoder Params: {len(decoder_params)} tensors (LR={decoder_lr}) -> Fast learning")

    # Two parameter groups, each with its own learning rate.
    optimizer = torch.optim.Adam([
        {'params': encoder_params, 'lr': encoder_lr},
        {'params': decoder_params, 'lr': decoder_lr}
    ])

    return optimizer

# --- Ê∏¨Ë©¶‰∏Ä‰∏ã (Optional) ---
# dummy_model = ConditionedResNetUNet(1, 12).to(DEVICE)
# opt = get_fine_tuning_optimizer(dummy_model)

#### Hierarchical consistency loss (uses the 12 -> 5 mapping matrix)

In [None]:
def hierarchical_consistency_loss(logits_detail, logits_rough, mapping_matrix=None):
    """
    Consistency loss between the detail and rough predictions.

    The detail class probabilities are aggregated into rough classes with
    `mapping_matrix` and compared to the rough class probabilities by MSE.

    Args:
        logits_detail: detail logits, shape (B, D, H, W).
        logits_rough: rough logits, shape (B, R, H, W).
        mapping_matrix: (D, R) matrix sending detail classes to rough
            classes. Defaults to the module-level MAPPING_MATRIX.

    Returns:
        Scalar tensor: mean squared error between projected and rough probabilities.
    """
    if mapping_matrix is None:
        mapping_matrix = MAPPING_MATRIX

    # 1. Logits -> probability distributions over the class axis.
    probs_detail = F.softmax(logits_detail, dim=1) # (B, D, H, W)
    probs_rough = F.softmax(logits_rough, dim=1)   # (B, R, H, W)

    # Follow the logits' device and dtype so the matmul cannot fail on a
    # device mismatch (no-op when they already agree).
    mapping = mapping_matrix.to(device=probs_detail.device, dtype=probs_detail.dtype)

    # 2. Project: (B, H, W, D) x (D, R) -> (B, H, W, R)
    probs_detail_permuted = probs_detail.permute(0, 2, 3, 1)
    probs_projected = torch.matmul(probs_detail_permuted, mapping)

    # Back to channel-first layout (B, R, H, W).
    probs_projected = probs_projected.permute(0, 3, 1, 2)

    # 3. Distance between the two distributions (plain MSE).
    loss = F.mse_loss(probs_projected, probs_rough)

    return loss

### 5. Phase 1: Pre-train

In [None]:
import time

print("\n" + "="*40)
print("   üöÄ START PHASE 1: PRE-TRAINING (ROUGH)")
print("   Target: 5 Classes (Rough Labels)")
print("="*40)
print(f"   üìù log will be saved to: {ROUGH_LOSS_LOG}")


# 1. Setup Data (mode='rough' selects the rough labels)
train_ds = SliceMasterDataset(train_files, mode='rough', transform=train_transform)
val_ds = SliceMasterDataset(val_files, mode='rough', transform=val_transform)

train_loader = DataLoader(train_ds, batch_size=ROUGH_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=ROUGH_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"Phase 1 Data: {len(train_ds)} train slices, {len(val_ds)} val slices")

# 2. Setup Model (both heads are built; this phase only uses the rough head)
model = ConditionedResNetUNet(
    n_channels=IN_CHANNELS, 
    n_classes_rough=NUM_CLASSES_ROUGH,   # 5
    n_classes_detail=NUM_CLASSES_DETAILED, # 12
    embed_dim=EMBEDDING_DIM,
    num_mri_types=NUM_MRI_TYPES
).to(DEVICE)

if torch.cuda.device_count() > 1:
    print(f"üî• Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

optimizer = optim.Adam(model.parameters(), lr=ROUGH_LR)
criterion = nn.CrossEntropyLoss()
scaler = torch.amp.GradScaler("cuda") # mixed-precision training

# 3. Training Loop
best_dice = 0.0
history_rough = []

for epoch in range(ROUGH_EPOCHS):
    start_time = time.time()
    model.train()
    train_loss = 0
    
    # [Training]
    # Each batch unpacks into 4 tensors: image, label, z_pos, type_idx
    for images, masks, z_pos, type_idx in tqdm(train_loader, desc=f"Epoch {epoch+1}/{ROUGH_EPOCHS}"):
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        z_pos, type_idx = z_pos.to(DEVICE), type_idx.to(DEVICE)
        
        with torch.amp.autocast("cuda"):
            # Forward (rough head only)
            preds = model(images, type_idx, z_pos, return_mode='rough')
            loss = criterion(preds, masks)
            
        # Backward (loss is scaled by the GradScaler before backprop)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        train_loss += loss.item()
    
    # [Validation]
    model.eval()
    val_loss = 0
    val_dice = 0
    with torch.no_grad():
        for images, masks, z_pos, type_idx in val_loader:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            z_pos, type_idx = z_pos.to(DEVICE), type_idx.to(DEVICE)
            
            with torch.amp.autocast("cuda"):
                preds = model(images, type_idx, z_pos, return_mode='rough')
                loss = criterion(preds, masks)
            
            val_loss += loss.item()
            # Dice for this batch (dice_score is defined in the Utilities cell)
            val_dice += dice_score(preds, masks, NUM_CLASSES_ROUGH)
            
    # Per-batch averages for the epoch
    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)
    avg_val_dice = val_dice / len(val_loader)
    
    duration = time.time() - start_time
    
    # Record history; the whole log is rewritten to CSV every epoch
    history_rough.append({
        'epoch': epoch + 1,
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss,
        'val_dice': avg_val_dice,
        'time': duration
    })
    pd.DataFrame(history_rough).to_csv(ROUGH_LOSS_LOG, index=False)
    
    # [Save Best Model] keep the checkpoint with the highest validation Dice;
    # unwrap DataParallel so the saved keys carry no 'module.' prefix
    if avg_val_dice > best_dice:
        best_dice = avg_val_dice
        state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
        torch.save(state_dict, ROUGH_MODEL_PATH)
        print(f"  üèÜ Saved Best Model! (Dice: {best_dice:.4f}) -> {ROUGH_MODEL_PATH}")
    
    # plot_training_curves clears the cell output, so the summary is printed after it
    plot_training_curves(ROUGH_LOSS_LOG, ROUGH_LOSS_IMG)
    print(f"Epoch {epoch+1} | T-Loss: {avg_train_loss:.4f} | V-Loss: {avg_val_loss:.4f} | V-Dice: {avg_val_dice:.4f} (Best: {best_dice:.4f})")

print("‚úÖ Phase 1 Training Complete!")

### 6. Phase 1.5: Hierarchical Warm-up

In [None]:
print("\n" + "="*40)
print("   üöÄ START PHASE 1.5: HIERARCHICAL WARM-UP")
print("   Goal: Train Detail Head to align with Rough Head")
print("="*40)
print(f"   üìù log will be saved to: {WARMUP_LOSS_LOG}")

# 1. Data filtering (only files that carry detail labels)
print("Filtering detailed data...")
train_files_detail = [f for f in train_files if np.load(f, allow_pickle=True).item().get('has_detail', False)]
val_files_detail = [f for f in val_files if np.load(f, allow_pickle=True).item().get('has_detail', False)]
print(f"Detail Train: {len(train_files_detail)} | Detail Val: {len(val_files_detail)}")

# Setup Data (Mode='detail')
train_ds_warmup = SliceMasterDataset(train_files_detail, mode='detail', transform=train_transform)
train_loader_warmup = DataLoader(train_ds_warmup, batch_size=WARMUP_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# Validation set: resize only (val_transform), no random augmentation
val_ds_warmup = SliceMasterDataset(val_files_detail, mode='detail', transform=val_transform)
val_loader_warmup = DataLoader(val_ds_warmup, batch_size=WARMUP_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

model = ConditionedResNetUNet(
    n_channels=IN_CHANNELS,
    n_classes_rough=NUM_CLASSES_ROUGH,
    n_classes_detail=NUM_CLASSES_DETAILED,
    embed_dim=EMBEDDING_DIM,
    num_mri_types=NUM_MRI_TYPES
).to(DEVICE)

# 3. Load the Phase 1 weights
model.load_pretrained_encoder(ROUGH_MODEL_PATH)

# 4. Freeze the encoder and the rough head (the teacher must not move)
model.freeze_encoder_and_rough()

if torch.cuda.device_count() > 1:
    print(f"üî• Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

# 5. Weighted loss: class 6 (AM) is up-weighted. AM is mapped to background
# in the rough labels, so the rough teacher gives no signal for it.
warmup_weights = torch.ones(NUM_CLASSES_DETAILED).to(DEVICE)
warmup_weights[6] = 2.0
criterion_warmup = nn.CrossEntropyLoss(weight=warmup_weights)

# 6. Optimizer over every parameter that still requires grad.
# NOTE(review): freeze_encoder_and_rough() leaves the decoder blocks (up1-up4)
# trainable, so this updates the decoder as well as head_detail.
optimizer_warmup = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=WARMUP_LR)
scaler = torch.amp.GradScaler("cuda") # mixed-precision training
history_warmup = []

# 7. Training Loop (Warm-up)
print(f"üî• Starting Warm-up for {WARMUP_EPOCHS} epochs...")
for epoch in range(WARMUP_EPOCHS): # a short run is enough (e.g., 5-10 epochs)
    model.train()
    train_loss = 0
    train_loss_ce = 0
    train_loss_consist = 0

    pbar = tqdm(train_loader_warmup, desc=f"Warm-up Epoch {epoch+1}/{WARMUP_EPOCHS}")

    for images, mask_detail, z_pos, type_idx in pbar:
        images, mask_detail = images.to(DEVICE), mask_detail.to(DEVICE)
        z_pos, type_idx = z_pos.to(DEVICE), type_idx.to(DEVICE)

        with torch.amp.autocast("cuda"):
            # Forward (both sets of logits)
            logits_rough, logits_detail = model(images, type_idx, z_pos, return_mode='both')

            # Loss 1: the detail head must predict the detail labels (cross entropy)
            loss_ce = criterion_warmup(logits_detail, mask_detail)

            # Loss 2: consistency constraint (student -> teacher).
            # detach() makes the rough prediction a fixed target. Wrapping a
            # plain assignment in torch.no_grad() does not cut the graph, so
            # gradients would otherwise flow back through the rough branch.
            teacher_logits = logits_rough.detach()
            loss_consist = hierarchical_consistency_loss(logits_detail, teacher_logits)

            # Total loss, weighted the same way as in validation below
            loss = loss_ce + CONSISTENCY_WEIGHT * loss_consist

        optimizer_warmup.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer_warmup)
        scaler.update()

        # Running totals
        train_loss += loss.item()
        train_loss_ce += loss_ce.item()
        train_loss_consist += loss_consist.item()

        pbar.set_postfix({'L': loss.item(), 'CE': loss_ce.item(), 'C': loss_consist.item()})

    avg_train_loss = train_loss / len(train_loader_warmup)

    # --- [Validation] ---
    model.eval()
    val_loss = 0
    val_loss_ce = 0
    val_loss_consist = 0
    val_dice = 0

    with torch.no_grad():
        for images, masks, z_pos, type_idx in val_loader_warmup:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            z_pos, type_idx = z_pos.to(DEVICE), type_idx.to(DEVICE)

            with torch.amp.autocast("cuda"):
                logits_rough, logits_detail = model(images, type_idx, z_pos, return_mode='both')

                l_ce = criterion_warmup(logits_detail, masks)
                l_consist = hierarchical_consistency_loss(logits_detail, logits_rough)
                l_total = l_ce + CONSISTENCY_WEIGHT * l_consist

            val_loss += l_total.item()
            val_loss_ce += l_ce.item()
            val_loss_consist += l_consist.item()
            # Reuse the detail logits from the forward pass above rather than
            # running the model a second time for the Dice score.
            val_dice += dice_score(logits_detail, masks, NUM_CLASSES_DETAILED)

    avg_val_loss = val_loss / len(val_loader_warmup)
    avg_val_ce = val_loss_ce / len(val_loader_warmup)
    avg_val_consist = val_loss_consist / len(val_loader_warmup)
    avg_val_dice = val_dice / len(val_loader_warmup)

    # Record history; the whole log is rewritten to CSV every epoch
    history_warmup.append({
        'epoch': epoch + 1,
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss,
        'val_ce': avg_val_ce,
        'val_consist': avg_val_consist,
        'val_dice': avg_val_dice
    })
    pd.DataFrame(history_warmup).to_csv(WARMUP_LOSS_LOG, index=False)
    plot_training_curves(WARMUP_LOSS_LOG, WARMUP_LOSS_IMG)

    print(f"Warmup {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} (CE: {avg_val_ce:.4f}, Consist: {avg_val_consist:.4f})")

# Â≠ò‰∏ã Warm-up ÂæåÁöÑÊ¨äÈáçÔºåÁµ¶ Step 2 Áî®
state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
torch.save(state_dict, WARMUP_MODEL_PATH)
print("Phase 1.5 Complete. Detail Head Initialized.")

### 7. Phase 2: Fine-tune

In [None]:
_banner_rule = "=" * 40
print("\n" + _banner_rule)
print("   üöÄ START PHASE 2: FINE-TUNING (DETAIL)")
print(f"   Target: {NUM_CLASSES_DETAILED} Classes (Detailed Labels)")
print(_banner_rule)
print(f"   üìù log will be saved to: {DETAIL_LOSS_LOG}")

# ==========================================
# 0. Data filtering (crucial step!)
# ==========================================
# Only slices that actually carry a detailed label are used in this phase.
# Every .npy file is opened to read its flag, so this can take a while.
print("üîç Filtering dataset for detailed labels... (This may take a moment)")


def _slice_has_detail(npy_path):
    """Return the 'has_detail' entry of the dict stored in *npy_path* (False if the key is absent)."""
    record = np.load(npy_path, allow_pickle=True).item()
    return record.get('has_detail', False)


train_files_detail = list(filter(_slice_has_detail, train_files))
val_files_detail = list(filter(_slice_has_detail, val_files))
print(f"Detailed Train: {len(train_files_detail)} | Detailed Val: {len(val_files_detail)}")

# ==========================================
# 1. Setup Data (Detail Mode)
# ==========================================
train_ds = SliceMasterDataset(train_files_detail, mode='detail', transform=train_transform)
val_ds = SliceMasterDataset(val_files_detail, mode='detail', transform=val_transform)

# Both loaders share everything except shuffling.
_loader_kwargs = dict(
    batch_size=DETAIL_BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

# ==========================================
# 2. Setup Model (12 Classes)
# ==========================================
model = ConditionedResNetUNet(
    n_channels=IN_CHANNELS,
    n_classes_rough=NUM_CLASSES_ROUGH,
    n_classes_detail=NUM_CLASSES_DETAILED,
    embed_dim=EMBEDDING_DIM,
    num_mri_types=NUM_MRI_TYPES
).to(DEVICE)

# ==========================================
# 3. Load Pretrained Weights (Surgery)
# ==========================================
# Fail fast when the Phase 1.5 checkpoint is missing.
if not os.path.exists(WARMUP_MODEL_PATH):
    raise FileNotFoundError(f"‚ùå Phase 1.5 model not found at {WARMUP_MODEL_PATH}.")

print(f"üîÑ Loading pretrained weights from: {WARMUP_MODEL_PATH}")
# map_location keeps the load working when the checkpoint was written from a
# different device than the current DEVICE (e.g. saved on GPU, loaded on CPU),
# matching how the test cell loads its checkpoint.
model.load_state_dict(torch.load(WARMUP_MODEL_PATH, map_location=DEVICE))
print("‚úÖ Weights loaded from Phase 1.5")

# [Key step] Unfreeze everything. The original note says Phase 1.5 trains with
# frozen layers; the freezing logic itself is outside this excerpt.
model.unfreeze_all()

# Wrap for multi-GPU only after the weights are loaded, so the checkpoint keys
# (saved without a "module." prefix) match the bare model.
if torch.cuda.device_count() > 1:
    print(f"üî• Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)
    
# ==========================================
# 4. Optimizer & Weighted Loss
# ==========================================
# Mixed-precision gradient scaler.
scaler = torch.amp.GradScaler("cuda")

# [Key strategy] Class weighting.
# Per the original author's note, AM (class 6) used to be treated as
# background, so it gets a higher weight to force the model to attend to it
# (tune from validation results, e.g. 1.5 ~ 3.0). Background (class 0) is
# optionally down-weighted. Every other class keeps weight 1.
_weight_overrides = {6: 2.0, 0: 0.5}
class_weights = torch.ones(NUM_CLASSES_DETAILED).to(DEVICE)
for _class_id, _class_weight in _weight_overrides.items():
    class_weights[_class_id] = _class_weight

criterion = nn.CrossEntropyLoss(weight=class_weights)

# Separate learning rates for encoder and decoder, as configured.
optimizer = get_fine_tuning_optimizer(
    model,
    encoder_lr=DETAIL_LR_ENCODER,
    decoder_lr=DETAIL_LR_DECODER,
)

# ==========================================
# 5. Training Loop
# ==========================================
# Per epoch: one training pass, one validation pass, rewrite the CSV log and
# the plot, then checkpoint whenever the validation Dice improves.
history_detail = []
best_dice = 0.0
print(f"Start Fine-tuning for {DETAIL_EPOCHS} epochs...")

for epoch in range(DETAIL_EPOCHS):
    start_time = time.time()

    # ---------- training pass ----------
    model.train()
    train_loss = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{DETAIL_EPOCHS}")

    # Each batch unpacks into 4 tensors: image, label, z_pos, type_idx.
    for batch in pbar:
        images, masks, z_pos, type_idx = batch
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)
        z_pos = z_pos.to(DEVICE)
        type_idx = type_idx.to(DEVICE)

        with torch.amp.autocast("cuda"):
            preds = model(images, type_idx, z_pos, return_mode='detail')
            loss = criterion(preds, masks)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

    # ---------- validation pass ----------
    model.eval()
    val_loss = 0
    val_dice = 0
    with torch.no_grad():
        for batch in val_loader:
            images, masks, z_pos, type_idx = batch
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)
            z_pos = z_pos.to(DEVICE)
            type_idx = type_idx.to(DEVICE)

            with torch.amp.autocast("cuda"):
                preds = model(images, type_idx, z_pos, return_mode='detail')
                loss = criterion(preds, masks)

            val_loss += loss.item()
            # Dice over the 12 detailed classes.
            val_dice += dice_score(preds, masks, NUM_CLASSES_DETAILED)

    # ---------- epoch summary ----------
    n_train_batches = len(train_loader)
    n_val_batches = len(val_loader)
    avg_train_loss = train_loss / n_train_batches
    avg_val_loss = val_loss / n_val_batches
    avg_val_dice = val_dice / n_val_batches

    duration = time.time() - start_time

    epoch_record = {
        'epoch': epoch + 1,
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss,
        'val_dice': avg_val_dice,
        'time': duration
    }
    history_detail.append(epoch_record)
    pd.DataFrame(history_detail).to_csv(DETAIL_LOSS_LOG, index=False)
    plot_training_curves(DETAIL_LOSS_LOG, DETAIL_LOSS_IMG)
    print(f"Epoch {epoch+1} | T-Loss: {avg_train_loss:.4f} | V-Loss: {avg_val_loss:.4f} | V-Dice: {avg_val_dice:.4f} | Time: {duration:.1f}s")

    # ---------- keep the best checkpoint ----------
    if avg_val_dice > best_dice:
        best_dice = avg_val_dice
        # Unwrap DataParallel so the saved keys have no "module." prefix.
        if isinstance(model, nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        torch.save(state_dict, DETAIL_MODEL_PATH)
        print(f"  üèÜ Saved Best Detail Model! (Dice: {best_dice:.4f})")

print("Phase 2 Complete.")

### 8. Test

In [None]:
# Collect the test slices. glob.glob returns paths in arbitrary order, so the
# list is sorted to make the evaluation order (and therefore which slices the
# later VIZ_INTERVAL sampling picks) reproducible across runs and machines.
test_files = sorted(glob.glob(os.path.join(TEST_DATA_DIR, "*.npy")))
if not test_files:
    raise FileNotFoundError(f"‚ùå No .npy files found in {TEST_DATA_DIR}. Please check the path!")

print(f"Found {len(test_files)} slices in Test Set.")

# Split by whether the stored dict flags a detailed label; only the flagged
# slices are evaluated below.
test_detail = []
test_without_detail = []
for npy_path in test_files:
    data = np.load(npy_path, allow_pickle=True).item()
    if data.get('has_detail'):
        test_detail.append(npy_path)
    else:
        test_without_detail.append(npy_path)
print(f"  - With Detailed Labels: {len(test_detail)} slices")
print(f"  - Without Detailed Labels: {len(test_without_detail)} slices")

# Build the Dataset and Loader.
# Note: mode='detail'; the test set goes through `val_transform`, the same
# transform object the validation datasets use.
test_ds = SliceMasterDataset(test_detail, mode='detail', transform=val_transform)
test_loader = DataLoader(test_ds, batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# ==========================================
# 2. Load the model
# ==========================================
# Refuse to evaluate without a checkpoint.
if not os.path.exists(EVAL_MODEL_PATH):
    raise FileNotFoundError(f"‚ùå Model weight not found at {EVAL_MODEL_PATH}")

_eval_model_kwargs = dict(
    n_channels=IN_CHANNELS,
    n_classes_rough=NUM_CLASSES_ROUGH,
    n_classes_detail=NUM_CLASSES_DETAILED,
    embed_dim=EMBEDDING_DIM,
    num_mri_types=NUM_MRI_TYPES,
)
model = ConditionedResNetUNet(**_eval_model_kwargs).to(DEVICE)

print(f"üîÑ Loading weights from: {EVAL_MODEL_PATH}")
state_dict = torch.load(EVAL_MODEL_PATH, map_location=DEVICE, weights_only=False)
model.load_state_dict(state_dict)

# Switch to inference mode.
model.eval()

# ================= 3. Inference and metric collection =================
# Storage layout: metrics_data[sequence_id][muscle_id] = [dice1, dice2, ...]
metrics_data = defaultdict(lambda: defaultdict(list))
viz_results = []  # samples kept for plotting
print("üöÄ Starting Inference on Test Set...")

# Running index over individual slices (not batches). With batch size 1 this
# equals the batch index, so the VIZ_INTERVAL sampling is unchanged there.
slice_idx = 0

with torch.no_grad():
    for images, labels, z_pos, type_idx in tqdm(test_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        z_pos = z_pos.to(DEVICE)
        type_idx = type_idx.to(DEVICE)

        # Inference (return_mode='detail')
        with torch.amp.autocast("cuda"):
            out_det = model(images, type_idx, z_pos, return_mode='detail')

        pred_det = torch.argmax(out_det, dim=1)

        # One entry per sample. Calling .item() on the whole tensor would raise
        # for any batch size other than 1, so flatten and convert instead.
        # Assumes one type id and one z value per sample -- TODO confirm
        # against SliceMasterDataset.
        type_ids = type_idx.reshape(-1).tolist()
        z_values = z_pos.reshape(-1).tolist()

        for b, current_type_id in enumerate(type_ids):
            pred_slice = pred_det[b]
            true_slice = labels[b]

            # --- Per-muscle Dice for this slice ---
            slice_dices = []
            for c in range(1, NUM_CLASSES_DETAILED):  # 1~11 (skip background)
                true_mask = (true_slice == c)
                true_area = true_mask.sum().item()

                # Only muscles present in the ground truth enter the statistics.
                if true_area == 0:
                    continue

                pred_mask = (pred_slice == c)
                inter = (pred_mask & true_mask).sum().item()
                union = pred_mask.sum().item() + true_area

                dice_val = 2 * inter / (union + 1e-6)
                metrics_data[current_type_id][c].append(dice_val)
                slice_dices.append(dice_val)

            # --- Collect visualization data (every VIZ_INTERVAL-th slice) ---
            if slice_idx % VIZ_INTERVAL == 0:
                avg_slice_dice = np.mean(slice_dices) if slice_dices else 0.0
                type_name = ID_TO_TYPE.get(current_type_id, str(current_type_id))

                viz_results.append({
                    'type_name': type_name,
                    'z': z_values[b],
                    'img': images[b, 0].cpu().numpy(),
                    'gt': true_slice.cpu().numpy(),
                    'pred': pred_slice.cpu().numpy(),
                    'dice': avg_slice_dice
                })

            slice_idx += 1

# ================= 4. Build the report =================
# Rows: muscle names, columns: MRI sequences.

print("\n" + "="*40)
print("   Test Set Evaluation Report (Dice Score)")
print("="*40)

final_table = {}
all_types_in_data = sorted(metrics_data.keys())

for c in range(1, NUM_CLASSES_DETAILED):
    muscle_name = MUSCLE_NAMES.get(c, f"Muscle_{c}")
    row_data = {}
    for t_id in all_types_in_data:
        dices = metrics_data[t_id][c]
        # NOTE(review): a muscle with no ground-truth slice for a sequence is
        # reported as 0.0 and therefore lowers the AVERAGE row below.
        mean_dice = np.mean(dices) if dices else 0.0

        col_name = ID_TO_TYPE.get(t_id, str(t_id))
        row_data[col_name] = mean_dice
    final_table[muscle_name] = row_data

df_metrics = pd.DataFrame(final_table).T
df_metrics = df_metrics.sort_index()

# Reorder the columns: the known sequences first, in a fixed order, followed by
# any other sequence found in the data. Reindexing on the fixed list alone
# would silently drop the results of every sequence that is not in it.
target_order = ['Fat', 'STIR', 'T1', 'T2', 'Water']
extra_columns = [col for col in df_metrics.columns if col not in target_order]
df_metrics = df_metrics.reindex(columns=target_order + extra_columns)

# Append the per-sequence mean over all muscles.
df_metrics.loc['AVERAGE'] = df_metrics.mean()

# Show a nicely formatted table.
pd.options.display.float_format = '{:.4f}'.format
print("Êï∏ÂÄºÁÇ∫ Mean Dice Score:")
display(df_metrics)

# Save to disk.
df_metrics.to_csv(EVAL_CSV_OUTPUT)
print(f"\nË©ï‰º∞Ë°®Â∑≤ÂÑ≤Â≠òËá≥: {EVAL_CSV_OUTPUT}")

# ================= 5. Visualization =================
# For every collected sample draw three panels side by side: the raw image,
# the ground-truth overlay and the prediction overlay.
if not viz_results:
    print("No visualization samples generated. Check VIZ_INTERVAL or data size.")
else:
    print("\n" + "="*40)
    print(f"Sampled every {VIZ_INTERVAL} slices from Test Set")
    print("="*40)

    title_fs = 11

    for item in viz_results:
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))

        # (label map to overlay or None, panel title), left to right.
        panels = (
            (None, f"{item['type_name']} | Z={item['z']:.2f}"),
            (item['gt'], "Ground Truth"),
            (item['pred'], f"Prediction (Dice: {item['dice']:.2f})"),
        )

        for ax, (overlay, panel_title) in zip(axs, panels):
            # The grayscale image is the backdrop of every panel.
            ax.imshow(item['img'], cmap='gray')
            if overlay is not None:
                ax.imshow(overlay, cmap=VIZ_CMAP, norm=VIZ_NORM, alpha=0.6, interpolation='nearest')
            ax.set_title(panel_title, fontsize=title_fs)
            ax.axis('off')

        plt.tight_layout()
        plt.show()

In [None]:
import pandas as pd

# Reload the evaluation table written by the test cell; the first CSV column
# holds the row labels.
df_metrics = pd.read_csv(EVAL_CSV_OUTPUT, index_col=0)
print(df_metrics.index)

# Move the AVERAGE row to the top and keep the other rows in file order.
avg_row = df_metrics.loc[['AVERAGE']]
other_rows = df_metrics[df_metrics.index != 'AVERAGE']
df_metrics = pd.concat([avg_row, other_rows])

# Flatten to a single row: stack() yields a (row, col) MultiIndex Series,
# whose labels are merged into "<row>_<col>" column names.
flat = df_metrics.stack()
merged_names = []
for row_label, col_label in flat.index:
    merged_names.append(f"{row_label}_{col_label}")
flat.index = merged_names
flat_df = flat.to_frame().T  # one-row DataFrame

flat_df.to_csv(FLATTENED_METRICS_OUTPUT, index=False)
print(f"Flattened metrics saved to '{FLATTENED_METRICS_OUTPUT}'")