<a href="https://colab.research.google.com/github/jongwoonalee/jongwoonalee.github.io/blob/main/Bladder_TRIAL1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
FlexAttention-based Multi-Instance Learning for Bladder Cancer Classification
FINAL CORRECT VERSION:
- 1024x1024 megapatch → 16개 256x256 patches
- 각 patch별로: LR(64x64) + HR(256x256) + Global(64x64)
- FlexAttention으로 중요한 HR patches만 선택
"""

import os
import re
import zipfile
import numpy as np
import pandas as pd
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import KFold, StratifiedKFold
from skimage.filters import threshold_otsu
import time
import random
import math
import pickle
import hashlib
from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

# Device setup (written for a 2x RTX 6000 Ada machine).
# NOTE: the cudnn.benchmark flag enabled here is switched off again by
# set_seed(), which runs right after this block.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    num_gpus = torch.cuda.device_count()
    print(f"Found {num_gpus} GPUs")
    for i in range(num_gpus):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()

print(f"Using device: {device}")

def set_seed(seed=42):
    """Seed every RNG used by the pipeline and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus all CUDA devices),
    sets ``cudnn.deterministic = True`` / ``cudnn.benchmark = False`` and
    exports ``PYTHONHASHSEED``.

    NOTE(review): PYTHONHASHSEED is presumably only honoured at interpreter
    start-up, so setting it here likely affects child processes only - confirm.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed()

# =============================================================================
# 1. 기존 데이터 로딩 함수들 (그대로 유지)
# =============================================================================

def extract_identifier(filename):
    """Derive a slide/patient identifier from an image filename.

    The extension is split off and anything from the first ``[`` onward is
    discarded. Three filename shapes are then tried in order:

    1. ``S<slide>-<patch>`` (optionally followed by ``_YYYY-MM-DD``): a patch
       number of 3-5 digits is left-padded with zeros to 6 digits and the
       result ``"S<slide><patch>"`` is returned as a plain string.
    2. ``S<digits>`` followed by ``,`` or ``;``: returns ``("S<digits>", ext)``.
    3. ``S`` followed by 8, 7 or 6 digits: returns ``("S<digits>", ext)``.

    If nothing matches, ``(None, ext)`` is returned.

    NOTE(review): shape 1 returns a bare ``str`` while every other path
    returns a ``(identifier, ext)`` tuple. Callers must cope with both; the
    mixed return type is kept here because the callers are not visible.
    """
    stem, ext = os.path.splitext(filename)
    if '[' in stem:
        stem = stem.split('[')[0].strip()

    dashed = re.match(r'^S(\d+)-(\d+)(?:_\d{4}-\d{2}-\d{2})?', stem)
    if dashed:
        slide, patch = dashed.groups()
        # Only 3-5 digit patch numbers are padded (to 6 digits); shorter or
        # longer ones pass through untouched.
        if 3 <= len(patch) <= 5:
            patch = patch.zfill(6)
        return f"S{slide}{patch}"

    for pattern in (r'^S(\d+)[,;]', r'^S(\d{8}|\d{7}|\d{6})'):
        found = re.match(pattern, stem)
        if found:
            return f"S{found.group(1)}", ext

    return None, ext

def convert_file_id_to_excel_format(file_id):
    """Normalise a file identifier to the form used in the Excel sheet.

    * ``None`` -> ``None``.
    * ``"<prefix>-<digits>"`` (exactly one dash, numeric suffix): the suffix is
      zero-padded to 6 digits when it has 3-5 digits and glued to the prefix.
      Any other dashed value yields ``None``.
    * A dash-free value longer than 3 characters that starts with ``"S"`` is
      returned unchanged (after whitespace stripping).
    * Everything else -> ``None``.
    """
    if file_id is None:
        return None

    text = str(file_id).strip()

    if "-" not in text:
        if len(text) > 3 and text.startswith("S"):
            return text
        return None

    pieces = text.split("-")
    if len(pieces) != 2 or not pieces[1].isdigit():
        return None

    prefix, number = pieces
    if 3 <= len(number) <= 5:
        number = number.zfill(6)
    return f"{prefix}{number}"

# =============================================================================
# 2. 핵심! 메가패치 처리 함수 (완전히 새로운 방식)
# =============================================================================

def split_megapatch_to_patches(megapatch_path, grid_size=4):
    """
    STEP 1: split a megapatch image into a grid_size x grid_size grid of tiles.

    For the intended 1024x1024 input and the default ``grid_size=4`` this
    yields 16 tiles of 256x256. Tile height and width are computed
    independently (``h // grid_size`` and ``w // grid_size``) so a non-square
    image is tiled over its full extent instead of being cut with the height
    on both axes. Remainder pixels (when a side is not divisible by
    ``grid_size``) are dropped at the bottom/right edge.

    Args:
        megapatch_path: path of the megapatch image on disk
        grid_size: number of tiles per side (4 -> 4x4 = 16 tiles)

    Returns:
        list: the tiles as RGB numpy arrays (views into the loaded image),
              in row-major order
        list: the (row, column) grid position of each tile

    Raises:
        ValueError: if the image cannot be read by OpenCV
    """
    megapatch = cv2.imread(megapatch_path)
    if megapatch is None:
        raise ValueError(f"Cannot read megapatch: {megapatch_path}")

    # OpenCV loads BGR; the rest of the pipeline works in RGB.
    megapatch = cv2.cvtColor(megapatch, cv2.COLOR_BGR2RGB)
    h, w = megapatch.shape[:2]

    # Per-axis tile size: 1024 / 4 = 256 for the standard megapatch.
    patch_h = h // grid_size
    patch_w = w // grid_size

    patches = []
    positions = []

    for i in range(grid_size):
        y_start = i * patch_h
        for j in range(grid_size):
            x_start = j * patch_w
            patches.append(
                megapatch[y_start:y_start + patch_h, x_start:x_start + patch_w]
            )
            positions.append((i, j))

    return patches, positions

def create_three_streams_from_patch(patch_256, megapatch_1024):
    """
    STEP 2: build the three input streams for one tile.

    Args:
        patch_256: the tile (numpy array; 256x256 in the intended pipeline)
        megapatch_1024: the full megapatch the tile was cut from

    Returns:
        dict with
            'lr':     the tile area-downsampled to 64x64,
            'hr':     a copy of the tile at its original resolution,
            'global': the whole megapatch area-downsampled to 64x64
    """
    low_res_size = (64, 64)

    def shrink(image):
        # INTER_AREA is the interpolation mode used for both downsampled streams.
        return cv2.resize(image, low_res_size, interpolation=cv2.INTER_AREA)

    return {
        'lr': shrink(patch_256),
        'hr': patch_256.copy(),
        'global': shrink(megapatch_1024),
    }

def process_megapatch_complete(megapatch_path):
    """
    STEP 3: full megapatch pipeline - image -> tiles -> three streams per tile.

    The image is loaded here to serve as the global context, then
    ``split_megapatch_to_patches`` (which loads the file again) supplies the
    tiles, and ``create_three_streams_from_patch`` is applied to every tile.

    Args:
        megapatch_path: path of the megapatch image on disk

    Returns:
        dict with
            'lr_patches':    one 64x64 LR array per tile,
            'hr_patches':    one full-resolution array per tile,
            'global_tokens': one 64x64 global-context array per tile
                             (all computed from the same megapatch),
            'positions':     (row, column) of each tile

    Raises:
        ValueError: if the image cannot be read by OpenCV
    """
    full_image = cv2.imread(megapatch_path)
    if full_image is None:
        raise ValueError(f"Cannot read megapatch: {megapatch_path}")
    full_image = cv2.cvtColor(full_image, cv2.COLOR_BGR2RGB)

    tiles, positions = split_megapatch_to_patches(megapatch_path)

    # One {'lr', 'hr', 'global'} dict per tile, in tile order.
    per_tile = [create_three_streams_from_patch(tile, full_image) for tile in tiles]

    return {
        'lr_patches': [streams['lr'] for streams in per_tile],
        'hr_patches': [streams['hr'] for streams in per_tile],
        'global_tokens': [streams['global'] for streams in per_tile],
        'positions': positions,
    }

# =============================================================================
# 3. ResNet 기반 Feature Extractor (64x64와 256x256용)
# =============================================================================

class ResNetFeatureExtractor(nn.Module):
    """
    ResNet18-based feature extractor.

    The convolutional trunk of ResNet18 (everything except its final average
    pool and fully connected layer) is followed by adaptive average pooling to
    1x1 and a Linear -> LayerNorm -> ReLU -> Dropout(0.1) projection. Because
    of the adaptive pooling the same module accepts both the 64x64 (LR,
    Global) and the 256x256 (HR) inputs.

    NOTE(review): ``pretrained=`` is presumably the legacy torchvision flag
    (newer releases prefer ``weights=``) - confirm against the installed
    torchvision version.
    """

    def __init__(self, feature_dim=384, pretrained=True):
        super().__init__()

        # Keep only the convolutional stages; drop avgpool and fc.
        resnet = models.resnet18(pretrained=pretrained)
        conv_stages = list(resnet.children())[:-2]
        self.backbone = nn.Sequential(*conv_stages)

        # Collapse the spatial grid to a single 512-channel vector.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Map the 512-d ResNet feature to the model's token width.
        projection_layers = [
            nn.Linear(512, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        ]
        self.projection = nn.Sequential(*projection_layers)

    def forward(self, x):
        """
        Args:
            x: [batch_size, 3, H, W] image batch (64x64 or 256x256)
        Returns:
            [batch_size, feature_dim] feature vectors
        """
        pooled = self.avgpool(self.backbone(x))       # [B, 512, 1, 1]
        batch = pooled.size(0)
        return self.projection(pooled.view(batch, -1))  # [B, feature_dim]

# =============================================================================
# 4. 핵심! High-Resolution Feature Selection Module
# =============================================================================

class HRFeatureSelector(nn.Module):
    """
    Core of FlexAttention: keep only the HR features whose LR attention is high.

    Per sample, a threshold is derived from the attention scores (Otsu, with a
    60th-percentile fallback), patches above it are kept, and the selection is
    clamped to between ``min(2, num_patches)`` and
    ``max(int(num_patches * max_selection_ratio), min(2, num_patches))``
    patches. The result is zero-padded to the upper bound so every sample in
    the batch has the same shape.

    ``threshold`` is stored but not used by ``forward``.
    """

    def __init__(self, threshold=0.1, max_selection_ratio=0.5):
        super(HRFeatureSelector, self).__init__()
        self.threshold = threshold
        self.max_selection_ratio = max_selection_ratio

    def forward(self, lr_attention_scores, hr_features):
        """
        Select HR features according to the LR attention scores.

        Args:
            lr_attention_scores: [batch_size, num_patches] attention per LR patch
            hr_features: [batch_size, num_patches, feature_dim] HR features

        Returns:
            selected_hr_features: [batch_size, max_selections, feature_dim];
                rows beyond the number actually selected are zeros
            selection_mask: [batch_size, num_patches] with 1 at selected
                patches (for visualisation)
        """
        batch_size, num_patches, feature_dim = hr_features.shape

        # Never ask topk for more entries than exist, and keep the padded
        # width >= the forced minimum so all samples stack to one shape.
        min_selections = min(2, num_patches)
        max_selections = max(int(num_patches * self.max_selection_ratio), min_selections)

        selected_hr_features = []
        selection_masks = []

        for b in range(batch_size):
            att_scores = lr_attention_scores[b]  # [num_patches]

            # Dynamic threshold: Otsu, falling back to the 60th percentile
            # (i.e. roughly the top 40%) if Otsu cannot be computed.
            try:
                threshold_val = threshold_otsu(att_scores.detach().cpu().numpy())
            except Exception:
                scores = att_scores.detach()
                if scores.dtype not in (torch.float32, torch.float64):
                    # torch.quantile only accepts float32/float64 input.
                    scores = scores.float()
                threshold_val = torch.quantile(scores, 0.6)

            selected_indices = torch.where(att_scores > threshold_val)[0]

            if selected_indices.numel() > max_selections:
                # Too many candidates: keep the highest-scoring ones.
                _, top_indices = torch.topk(att_scores[selected_indices], max_selections)
                selected_indices = selected_indices[top_indices]
            elif selected_indices.numel() < min_selections:
                # Too few candidates: fall back to the overall top scores.
                _, selected_indices = torch.topk(att_scores, min_selections)

            selected_features = hr_features[b, selected_indices]  # [num_selected, feature_dim]

            # Pad to a fixed width; new_zeros matches hr_features' dtype/device.
            padding_size = max_selections - selected_indices.numel()
            if padding_size > 0:
                padding = hr_features.new_zeros(padding_size, feature_dim)
                selected_features = torch.cat([selected_features, padding], dim=0)

            selected_hr_features.append(selected_features)

            binary_mask = torch.zeros_like(att_scores)
            binary_mask[selected_indices] = 1
            selection_masks.append(binary_mask)

        selected_hr_features = torch.stack(selected_hr_features)  # [batch_size, max_selections, feature_dim]
        selection_masks = torch.stack(selection_masks)            # [batch_size, num_patches]

        return selected_hr_features, selection_masks

# =============================================================================
# 5. 핵심! Hierarchical Self-Attention Module (FlexAttention 구현)
# =============================================================================

class HierarchicalSelfAttention(nn.Module):
    """
    Hierarchical self-attention from the FlexAttention paper (Equations 3-7).

    Queries come from the hidden states only; keys and values are the hidden
    states concatenated with the selected HR features, the latter going
    through their own key/value projections.
    """

    def __init__(self, feature_dim=384, num_heads=6, dropout=0.1):
        super().__init__()

        self.feature_dim = feature_dim
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads

        assert feature_dim % num_heads == 0

        def square_linear():
            return nn.Linear(feature_dim, feature_dim)

        # Projections applied to the hidden states (Eq. 3).
        self.q_proj = square_linear()
        self.k_proj = square_linear()
        self.v_proj = square_linear()

        # Dedicated projections for the HR features (Eq. 4-5): W'_K and W'_V.
        self.k_proj_hr = square_linear()
        self.v_proj_hr = square_linear()

        self.out_proj = square_linear()
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def _to_heads(self, tensor):
        """Reshape [B, L, feature_dim] into [B, num_heads, L, head_dim]."""
        batch, length, _ = tensor.shape
        return tensor.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, hr_features):
        """
        Hierarchical attention step (Equations 3-7).

        Args:
            hidden_states: [batch_size, seq_len, feature_dim] token sequence;
                the last token is treated as the CLS token
            hr_features: [batch_size, num_hr_selected, feature_dim] selected
                HR features

        Returns:
            output: [batch_size, seq_len, feature_dim] attended hidden states
                (before any residual connection)
            attention_map: [batch_size, seq_len - 1] head-averaged attention
                of the last (CLS) token over all other hidden-state tokens,
                taken after attention dropout
        """
        batch_size, seq_len, _ = hidden_states.shape
        _, num_hr, _ = hr_features.shape

        # Eq. 3: queries from the hidden states.
        queries = self._to_heads(self.q_proj(hidden_states))

        # Eq. 4-5: keys/values = hidden-state projections ++ HR projections.
        keys = torch.cat(
            [self.k_proj(hidden_states), self.k_proj_hr(hr_features)], dim=1
        )  # [batch_size, seq_len + num_hr, feature_dim]
        values = torch.cat(
            [self.v_proj(hidden_states), self.v_proj_hr(hr_features)], dim=1
        )  # [batch_size, seq_len + num_hr, feature_dim]
        keys = self._to_heads(keys)
        values = self._to_heads(values)

        # Eq. 6: scaled dot-product attention over the extended key set.
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale
        weights = self.dropout(F.softmax(logits, dim=-1))  # [B, heads, seq_len, seq_len + num_hr]

        context = torch.matmul(weights, values)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.feature_dim)
        output = self.out_proj(context)

        # Eq. 7: attention of the last (CLS) query over the non-CLS hidden
        # tokens, averaged across heads, for use by the next layer.
        attention_map = weights[:, :, -1, :seq_len - 1].mean(dim=1)

        return output, attention_map

# =============================================================================
# 6. 메인 FlexAttention MIL 모델
# =============================================================================

class FlexAttentionPatientMIL(nn.Module):
    """
    Patient-level FlexAttention MIL model.

    Structure:
    1. Per patient, up to 20 megapatches -> 16 tiles each -> 3 streams
       (LR, HR, Global).
    2. LR + Global tokens (+ CLS) -> standard self-attention layers.
    3. LR attention -> HR selection -> FlexAttention layers.
    4. CLS token -> patient-level classification.

    NOTE(review): the three ResNet extractors are created here but forward()
    never calls them; it takes already-extracted feature tensors. Presumably
    the training loop invokes the extractors itself - confirm against caller.
    """

    def __init__(self, feature_dim=384, num_classes=2, num_heads=6,
                 num_sa_layers=1, num_fa_layers=2, dropout=0.1):
        super(FlexAttentionPatientMIL, self).__init__()

        self.feature_dim = feature_dim
        self.num_sa_layers = num_sa_layers
        self.num_fa_layers = num_fa_layers

        # Feature extractors (one independent ResNet18 per stream)
        self.lr_extractor = ResNetFeatureExtractor(feature_dim=feature_dim)     # for 64x64 inputs
        self.global_extractor = ResNetFeatureExtractor(feature_dim=feature_dim) # for 64x64 inputs
        self.hr_extractor = ResNetFeatureExtractor(feature_dim=feature_dim)     # for 256x256 inputs

        # Learnable CLS token, appended at the END of the sequence in forward()
        self.cls_token = nn.Parameter(torch.randn(1, 1, feature_dim))

        # Learnable positional encoding. forward() builds at most
        # 256 LR + 20 Global + 1 CLS = 277 tokens, so 400 leaves headroom.
        max_tokens = 400  # comfortable margin
        self.pos_encoding = nn.Parameter(torch.randn(1, max_tokens, feature_dim))

        # Standard Self-Attention layers (LR + Global + CLS)
        self.sa_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=feature_dim,
                nhead=num_heads,
                dim_feedforward=feature_dim * 4,
                dropout=dropout,
                batch_first=True
            ) for _ in range(num_sa_layers)
        ])

        # FlexAttention layers: one selector / attention / FFN / LayerNorm per layer
        self.hr_selectors = nn.ModuleList([
            HRFeatureSelector() for _ in range(num_fa_layers)
        ])

        self.hierarchical_attentions = nn.ModuleList([
            HierarchicalSelfAttention(feature_dim, num_heads, dropout)
            for _ in range(num_fa_layers)
        ])

        self.fa_ffns = nn.ModuleList([
            nn.Sequential(
                nn.Linear(feature_dim, feature_dim * 4),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(feature_dim * 4, feature_dim),
                nn.Dropout(dropout)
            ) for _ in range(num_fa_layers)
        ])

        self.fa_layer_norms = nn.ModuleList([
            nn.LayerNorm(feature_dim) for _ in range(num_fa_layers)
        ])

        # Final classifier applied to the CLS token
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(feature_dim // 2, num_classes)
        )

    def forward(self, lr_features, global_features, hr_features):
        """
        FlexAttention MIL forward pass.

        Args:
            lr_features: [batch_size, total_lr_patches, feature_dim] - all LR features
            global_features: [batch_size, num_megapatches, feature_dim] - Global features
            hr_features: [batch_size, total_hr_patches, feature_dim] - all HR features

        Only the first 256 LR tokens and the first 20 Global tokens are used;
        HR features are truncated to the number of LR tokens kept.

        Returns:
            logits: [batch_size, num_classes] - patient-level predictions
            attention_maps: list with one attention map per FlexAttention
                layer (for visualization)
        """
        batch_size = lr_features.shape[0]

        # Step 1: build the token sequence (LR + Global + CLS)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        # Combine LR and Global features.
        # For memory efficiency only a prefix is used when there are many tokens.
        max_lr_tokens = min(lr_features.shape[1], 256)  # at most 256 LR tokens
        max_global_tokens = min(global_features.shape[1], 20)  # at most 20 Global tokens

        lr_subset = lr_features[:, :max_lr_tokens]
        global_subset = global_features[:, :max_global_tokens]

        # Initial token sequence: [LR tokens + Global tokens + CLS]
        hidden_states = torch.cat([lr_subset, global_subset, cls_tokens], dim=1)

        # Add positional encoding (silently skipped if the sequence were ever
        # longer than the encoding table; the caps above keep it within range)
        seq_len = hidden_states.shape[1]
        if seq_len <= self.pos_encoding.shape[1]:
            hidden_states = hidden_states + self.pos_encoding[:, :seq_len, :]

        attention_maps = []

        # Step 2: Standard Self-Attention layers (Algorithm 1, lines 8-12)
        for i in range(self.num_sa_layers):
            hidden_states = self.sa_layers[i](hidden_states)

        # Step 3: FlexAttention layers (Algorithm 1, lines 14-19)
        for i in range(self.num_fa_layers):
            # Step 3a: HR selection driven by LR attention (Algorithm 1, line 15)
            if i == 0:
                # First layer: no previous attention map, so use uniform attention.
                # NOTE(review): with all-equal scores the selector has nothing
                # to discriminate on - confirm this is the intended behavior.
                num_lr_tokens = lr_subset.shape[1]
                lr_attention_map = torch.ones(batch_size, num_lr_tokens, device=lr_features.device)
                lr_attention_map = lr_attention_map / lr_attention_map.sum(dim=1, keepdim=True)
            else:
                # Use the previous layer's attention map
                lr_attention_map = attention_maps[-1][:, :lr_subset.shape[1]]  # LR part only

            # Truncate HR features to the LR token count (patch-wise correspondence)
            hr_subset_size = min(hr_features.shape[1], lr_subset.shape[1])
            hr_subset = hr_features[:, :hr_subset_size]
            lr_attention_subset = lr_attention_map[:, :hr_subset_size]

            # Step 3b: select the important HR features
            selected_hr_features, selection_mask = self.hr_selectors[i](
                lr_attention_subset, hr_subset
            )

            # Step 3c: Hierarchical Self-Attention (Algorithm 1, line 16)
            attended_output, new_attention_map = self.hierarchical_attentions[i](
                hidden_states, selected_hr_features
            )

            # Step 3d: Skip connection (Algorithm 1, line 17)
            hidden_states = hidden_states + attended_output

            # Step 3e: Layer normalization
            hidden_states = self.fa_layer_norms[i](hidden_states)

            # Step 3f: FFN + skip connection (Algorithm 1, line 18)
            ffn_output = self.fa_ffns[i](hidden_states)
            hidden_states = hidden_states + ffn_output

            attention_maps.append(new_attention_map)

        # Step 4: Patient-level classification (Algorithm 1, line 20)
        cls_output = hidden_states[:, -1]  # CLS token (last position)
        logits = self.classifier(cls_output)

        return logits, attention_maps

# =============================================================================
# 7. Dataset (환자별 데이터 처리)
# =============================================================================

class FlexAttentionBladderDataset(Dataset):
    """
    Per-patient dataset for FlexAttention.

    Each item bundles all megapatches of one patient, processed into the
    three image streams (LR, Global, HR) and converted to normalised tensors.

    Args:
        patient_data: dict mapping patient id -> info dict; the info dict must
            contain 'images' (list of megapatch paths) and may contain the
            label under ``target_type``
        target_type: key of the label inside the info dict
        max_megapatches_per_patient: only this many leading image paths are used
        cache_dir: optional directory for pickled per-megapatch results
    """

    def __init__(self, patient_data, target_type='t_label',
                 max_megapatches_per_patient=20, cache_dir=None):
        self.patient_data = patient_data
        self.patient_ids = list(patient_data.keys())
        self.target_type = target_type
        self.max_megapatches_per_patient = max_megapatches_per_patient
        self.cache_dir = cache_dir

        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        # ToTensor + ImageNet mean/std normalisation
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.patient_ids)

    def _load_processed(self, megapatch_path):
        """Return the processed streams for one megapatch, using the cache if enabled."""
        if not self.cache_dir:
            return process_megapatch_complete(megapatch_path)

        cache_key = hashlib.md5(megapatch_path.encode()).hexdigest()
        cache_path = os.path.join(self.cache_dir, f"{cache_key}.pkl")

        if os.path.exists(cache_path):
            # SECURITY NOTE: pickle executes arbitrary code on load; the cache
            # directory must only ever contain files written by this class.
            try:
                with open(cache_path, 'rb') as f:
                    return pickle.load(f)
            except Exception as e:
                # A truncated/corrupt entry must not disable this megapatch
                # forever: fall through and rebuild it.
                print(f"Ignoring unreadable cache file {cache_path}: {e}")

        processed = process_megapatch_complete(megapatch_path)

        # Write to a per-process temp file and rename, so a concurrent
        # DataLoader worker never observes a half-written pickle.
        tmp_path = f"{cache_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(processed, f)
            os.replace(tmp_path, cache_path)
        except (OSError, pickle.PicklingError) as e:
            # Caching is an optimisation; a write failure must not drop the sample.
            print(f"Could not write cache file {cache_path}: {e}")
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

        return processed

    def _to_tensor(self, patch):
        """Convert one RGB numpy patch to a normalised tensor."""
        return self.transform(Image.fromarray(patch))

    def __getitem__(self, idx):
        """
        Build the sample for patient ``idx``.

        Returns:
            dict with 'patient_id', 'lr_patches' [N, 3, 64, 64],
            'global_patches' [M, 3, 64, 64], 'hr_patches' [N, 3, 256, 256]
            (sizes as produced by the megapatch pipeline) and 'label'
            (long tensor). A missing or None label becomes 0. Megapatches
            that fail to process are reported and skipped; if none succeed,
            all-zero dummy tensors for a single megapatch are returned.
        """
        patient_id = self.patient_ids[idx]
        patient_info = self.patient_data[patient_id]

        label = patient_info.get(self.target_type, 0)
        if label is None:
            label = 0

        all_lr_features = []
        all_global_features = []
        all_hr_features = []

        # Limit the number of megapatches (memory efficiency)
        megapatch_paths = patient_info['images'][:self.max_megapatches_per_patient]

        for megapatch_path in megapatch_paths:
            try:
                processed = self._load_processed(megapatch_path)

                lr_tensors = [self._to_tensor(p) for p in processed['lr_patches']]
                # All global tokens of a megapatch are identical; one is enough.
                global_token = self._to_tensor(processed['global_tokens'][0])
                hr_tensors = [self._to_tensor(p) for p in processed['hr_patches']]
            except Exception as e:
                # Best effort: one unreadable megapatch should not lose the patient.
                print(f"Error processing {megapatch_path}: {e}")
                continue

            # Only extend after the whole megapatch converted successfully so
            # the LR and HR lists stay aligned.
            all_lr_features.extend(lr_tensors)
            all_global_features.append(global_token)
            all_hr_features.extend(hr_tensors)

        # Dummy features if nothing could be loaded
        if not all_lr_features:
            all_lr_features = [torch.zeros(3, 64, 64)] * 16
            all_global_features = [torch.zeros(3, 64, 64)]
            all_hr_features = [torch.zeros(3, 256, 256)] * 16

        return {
            'patient_id': patient_id,
            'lr_patches': torch.stack(all_lr_features),          # [total_lr_patches, 3, 64, 64]
            'global_patches': torch.stack(all_global_features),  # [num_megapatches, 3, 64, 64]
            'hr_patches': torch.stack(all_hr_features),          # [total_hr_patches, 3, 256, 256]
            'label': torch.tensor(label, dtype=torch.long)
        }

# =============================================================================
# 8. Training Function
# =============================================================================

def _forward_flexattention_batch(model, lr_patches, global_patches, hr_patches):
    """Encode the three patch streams and run them through the MIL head.

    Args:
        model: The MIL model, possibly wrapped in ``nn.DataParallel``. The
            underlying module must expose ``lr_extractor``,
            ``global_extractor`` and ``hr_extractor`` and be callable on the
            three resulting feature tensors.
        lr_patches: 5-D tensor ``[batch, num_lr, C, H_lr, W_lr]``.
        global_patches: 5-D tensor ``[batch, num_global, C, H_g, W_g]``.
        hr_patches: 5-D tensor ``[batch, num_hr, C, H_hr, W_hr]``.

    Returns:
        Whatever the model's forward returns; callers unpack it as
        ``(logits, attention_maps)``.
    """
    # The extractors are attributes of the wrapped module, so unwrap first.
    # NOTE(review): calling the inner module directly means the DataParallel
    # wrapper is bypassed for the forward pass (same as the original code).
    core = model.module if hasattr(model, 'module') else model

    # Local names only: the caller's `batch_size` hyperparameter must not be
    # overwritten by the size of the current (possibly short) batch.
    n_samples, num_lr, channels, h_lr, w_lr = lr_patches.shape
    _, num_global, _, h_global, w_global = global_patches.shape
    _, num_hr, _, h_hr, w_hr = hr_patches.shape

    # Fold the patch axis into the batch axis so each extractor sees 4-D input.
    lr_features = core.lr_extractor(lr_patches.view(-1, channels, h_lr, w_lr))
    global_features = core.global_extractor(
        global_patches.view(-1, channels, h_global, w_global)
    )
    hr_features = core.hr_extractor(hr_patches.view(-1, channels, h_hr, w_hr))

    # Restore the [batch, num_patches, feature_dim] layout for the MIL head.
    return core(
        lr_features.view(n_samples, num_lr, -1),
        global_features.view(n_samples, num_global, -1),
        hr_features.view(n_samples, num_hr, -1),
    )


def train_flexattention_model(patient_data, target_type='t_label', num_folds=3, num_epochs=15,
                            batch_size=2, learning_rate=3e-4, device=device, save_dir='./checkpoints'):
    """Train and evaluate the FlexAttention MIL model with stratified K-fold CV.

    For every fold a fresh model is trained for ``num_epochs`` epochs with
    AdamW + OneCycleLR under mixed precision, a checkpoint is written every
    5 epochs, the held-out patients are scored, and the final weights are
    saved to ``save_dir``.

    Args:
        patient_data: Mapping ``patient_id -> record``. Each record must
            support ``.get(target_type)``; a missing or ``None`` label is
            treated as class 0 for stratification.
        target_type: Key of the binary label inside each patient record.
        num_folds: Number of stratified folds.
        num_epochs: Training epochs per fold.
        batch_size: Patients per batch for both loaders (kept constant
            across folds).
        learning_rate: Initial AdamW learning rate and OneCycleLR ``max_lr``.
        device: Device the model and batches are moved to.
        save_dir: Directory for checkpoints, final weights and the dataset
            cache (``<save_dir>/cache``).

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``auc``, each a list holding one value per fold.
    """

    os.makedirs(save_dir, exist_ok=True)
    cache_dir = os.path.join(save_dir, 'cache')

    # Collect patient ids and their labels (None -> 0) for stratification.
    patient_ids = list(patient_data.keys())
    patient_labels = [patient_data[pid].get(target_type, 0) for pid in patient_ids]
    patient_labels = [0 if label is None else label for label in patient_labels]

    # Stratified K-Fold
    kf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)

    results = {
        'accuracy': [], 'precision': [], 'recall': [], 'f1': [], 'auc': []
    }

    for fold, (train_idx, test_idx) in enumerate(kf.split(patient_ids, patient_labels)):
        print(f"\n{'='*60}")
        print(f"Fold {fold+1}/{num_folds}")
        print(f"{'='*60}")

        # Split patients into this fold's train / test subsets.
        train_patients = {patient_ids[i]: patient_data[patient_ids[i]] for i in train_idx}
        test_patients = {patient_ids[i]: patient_data[patient_ids[i]] for i in test_idx}

        # Datasets (both folds share the same on-disk cache directory).
        train_dataset = FlexAttentionBladderDataset(
            train_patients, target_type=target_type, cache_dir=cache_dir
        )
        test_dataset = FlexAttentionBladderDataset(
            test_patients, target_type=target_type, cache_dir=cache_dir
        )

        # DataLoaders
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            num_workers=2, pin_memory=True, persistent_workers=True
        )
        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False,
            num_workers=2, pin_memory=True, persistent_workers=True
        )

        # Fresh model for every fold.
        model = FlexAttentionPatientMIL(
            feature_dim=384, num_classes=2, num_heads=6,
            num_sa_layers=1, num_fa_layers=2, dropout=0.1
        )

        # DataParallel for RTX 6000 Ada x2
        if torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
            model = nn.DataParallel(model)

        model = model.to(device)

        # Optimizer & Scheduler (the scheduler is stepped once per batch).
        optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
        total_steps = len(train_loader) * num_epochs
        scheduler = OneCycleLR(optimizer, max_lr=learning_rate, total_steps=total_steps)

        # Loss & gradient scaler for mixed-precision training.
        criterion = nn.CrossEntropyLoss()
        scaler = GradScaler()

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0
            num_batches = 0

            for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                lr_patches = batch['lr_patches'].to(device)       # [B, num_lr, 3, 64, 64]
                global_patches = batch['global_patches'].to(device)  # [B, num_global, 3, 64, 64]
                hr_patches = batch['hr_patches'].to(device)       # [B, num_hr, 3, 256, 256]
                labels = batch['label'].to(device)

                optimizer.zero_grad()

                with autocast():
                    logits, _ = _forward_flexattention_batch(
                        model, lr_patches, global_patches, hr_patches
                    )
                    loss = criterion(logits, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()

                total_loss += loss.item()
                num_batches += 1

            avg_loss = total_loss / num_batches
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

            # Periodic checkpoint (unwrapped weights so it loads without DataParallel).
            if (epoch + 1) % 5 == 0:
                checkpoint_path = os.path.join(save_dir, f'fold_{fold+1}_epoch_{epoch+1}.pt')
                model_state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model_state,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_loss,
                }, checkpoint_path)
                print(f"Checkpoint saved: {checkpoint_path}")
                torch.cuda.empty_cache()

        # Evaluation on the held-out patients of this fold.
        model.eval()
        all_preds = []
        all_labels = []
        all_probs = []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Evaluating"):
                lr_patches = batch['lr_patches'].to(device)
                global_patches = batch['global_patches'].to(device)
                hr_patches = batch['hr_patches'].to(device)
                labels = batch['label'].to(device)

                logits, _ = _forward_flexattention_batch(
                    model, lr_patches, global_patches, hr_patches
                )

                probs = F.softmax(logits, dim=1)
                preds = torch.argmax(probs, dim=1)

                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
                # Probability of the positive class, used for ROC-AUC.
                all_probs.extend(probs[:, 1].cpu().tolist())

        # Metrics
        accuracy = accuracy_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds, zero_division=0)
        recall = recall_score(all_labels, all_preds, zero_division=0)
        f1 = f1_score(all_labels, all_preds, zero_division=0)

        try:
            auc = roc_auc_score(all_labels, all_probs)
        except ValueError as err:
            # roc_auc_score rejects inputs for which AUC is undefined
            # (e.g. a test fold containing a single class); fall back to 0.0.
            print(f"AUC undefined for fold {fold+1} ({err}); using 0.0")
            auc = 0.0

        print(f"\nFold {fold+1} Results:")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1: {f1:.4f}")
        print(f"AUC: {auc:.4f}")

        results['accuracy'].append(accuracy)
        results['precision'].append(precision)
        results['recall'].append(recall)
        results['f1'].append(f1)
        results['auc'].append(auc)

        # Save the final (unwrapped) weights of this fold.
        final_model_path = os.path.join(save_dir, f'final_model_fold_{fold+1}.pt')
        model_state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
        torch.save(model_state, final_model_path)

        # Release GPU memory: the optimizer and scheduler also reference the
        # parameters, so they must go too for empty_cache() to reclaim them.
        del model, optimizer, scheduler
        torch.cuda.empty_cache()

    # Cross-fold summary
    print(f"\n{'='*60}")
    print("FINAL RESULTS")
    print(f"{'='*60}")
    for key, title in (('accuracy', 'Accuracy'), ('precision', 'Precision'),
                       ('recall', 'Recall'), ('f1', 'F1'), ('auc', 'AUC')):
        print(f"Average {title}: {np.mean(results[key]):.4f} ± {np.std(results[key]):.4f}")

    return results

# =============================================================================
# 9. Main Execution
# =============================================================================

if __name__ == "__main__":
    # Input locations: raw archive, its extracted directory, clinical spreadsheet.
    zip_path = "/home/ubuntu/ExternalUSB_Bladder_240710.zip"
    base_dir = "/home/ubuntu/ExternalUSB_Bladder_240710"
    excel_path = "/home/ubuntu/MIL_TURB_240918_Modified_LambdaLabs.xlsx"

    print("Loading and matching data...")
    # [Reuse the existing data-loading code here]
    # NOTE(review): `patient_data` is not assigned anywhere in this block; the
    # loading code above must define it before the training calls below run.

    # Hyperparameters shared by both training runs.
    run_settings = dict(
        num_folds=3,
        num_epochs=15,
        batch_size=2,
        learning_rate=3e-4,
        device=device,
    )

    # Task 1: T-stage classification.
    print("\nTraining FlexAttention MIL for T-stage classification...")
    t_results = train_flexattention_model(
        patient_data=patient_data,
        target_type='t_label',
        save_dir='./checkpoints_t_stage',
        **run_settings,
    )

    # Task 2: recurrence prediction.
    print("\nTraining FlexAttention MIL for recurrence prediction...")
    recur_results = train_flexattention_model(
        patient_data=patient_data,
        target_type='recur_label',
        save_dir='./checkpoints_recurrence',
        **run_settings,
    )

    print("\nTraining completed!")