In [2]:
!pip install openvino --break-system-packages

Collecting openvino
  Downloading openvino-2025.3.0-19807-cp310-cp310-manylinux2014_x86_64.whl.metadata (12 kB)
Collecting openvino-telemetry>=2023.2.1 (from openvino)
  Downloading openvino_telemetry-2025.2.0-py3-none-any.whl.metadata (2.3 kB)
Downloading openvino-2025.3.0-19807-cp310-cp310-manylinux2014_x86_64.whl (49.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.0/49.0 MB[0m [31m74.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading openvino_telemetry-2025.2.0-py3-none-any.whl (25 kB)
Installing collected packages: openvino-telemetry, openvino
Successfully installed openvino-2025.3.0 openvino-telemetry-2025.2.0
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [3]:
!pip install nvidia-tensorrt --break-system-packages
!pip install torch-tensorrt --break-system-packages

Collecting nvidia-tensorrt
  Downloading nvidia_tensorrt-99.0.0-py3-none-manylinux_2_17_x86_64.whl.metadata (596 bytes)
Collecting tensorrt (from nvidia-tensorrt)
  Downloading tensorrt-10.13.3.9.tar.gz (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.5/40.5 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting tensorrt_cu13==10.13.3.9 (from tensorrt->nvidia-tensorrt)
  Downloading tensorrt_cu13-10.13.3.9.tar.gz (18 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting tensorrt_cu13_libs==10.13.3.9 (from tensorrt_cu13==10.13.3.9->tensorrt->nvidia-tensorrt)
  Downloading tensorrt_cu13_libs-10.13.3.9.tar.gz (706 bytes)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting tensorrt_cu13_bindings==10.13.3.9 (from tensorrt_cu13==10.13.3.9->tensorrt->nvidia-tensorrt

In [5]:
!pip install onnx==1.15.0 --break-system-packages

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [22]:
#!/usr/bin/env python3
"""
🏥 醫學影像分割系統 - 專業修正版
修正：spacing計算、標註處理、向量化縮放、解剖學標註
"""

import os
import json
import warnings
warnings.filterwarnings('ignore')

print("=" * 80)
print("🏥 醫學影像分割系統 - 專業版")
print("=" * 80)

import numpy as np
import nibabel as nib
from pathlib import Path
from datetime import datetime

try:
    from PIL import Image, ImageDraw, ImageFont
    USE_PIL = True
    print("✓ 使用 PIL 繪製文字")
except:
    USE_PIL = False
    print("✓ 不使用文字標註")

print("✓ 所有套件已就緒")

# ==================== 資料載入 ====================

class DataLoader:
    def __init__(self, data_root):
        self.data_root = Path(data_root)
        dataset_json = self.data_root / 'dataset.json'
        
        if not dataset_json.exists():
            raise FileNotFoundError(f"找不到 dataset.json: {dataset_json}")
        
        with open(dataset_json, 'r') as f:
            self.metadata = json.load(f)
        
        print(f"\n📊 資料集: {self.metadata['name']}")
        print(f"   訓練樣本: {self.metadata['numTraining']}")
    
    def load_nifti(self, filepath):
        """載入 NIfTI 檔案"""
        if not filepath.exists():
            raise FileNotFoundError(f"檔案不存在: {filepath}")
        
        img = nib.load(str(filepath))
        data = img.get_fdata()
        spacing = np.array(img.header.get_zooms()[:3])
        
        return data, {'spacing': spacing, 'shape': np.array(data.shape)}
    
    def get_cases(self, num=5):
        """安全地獲取案例"""
        cases = []
        training_data = self.metadata.get('training', [])
        
        for item in training_data[:num]:
            # 保守處理路徑
            img_rel = item['image'].lstrip('./')
            lbl_rel = item['label'].lstrip('./')
            
            img_path = self.data_root / img_rel
            lbl_path = self.data_root / lbl_rel
            
            # 檢查存在性
            if img_path.exists() and lbl_path.exists():
                case_id = img_path.stem.replace('hippocampus_', '').replace('.nii', '')
                cases.append({
                    'image': img_path,
                    'label': lbl_path,
                    'id': case_id
                })
            else:
                print(f"⚠️ 跳過缺失檔案: {img_path.name}")
        
        if not cases:
            raise ValueError("沒有找到有效的案例檔案")
        
        return cases

# ==================== 向量化縮放 ====================

def resize_nearest_vectorized(image, target_shape):
    """向量化的最近鄰縮放 - O(target_size)"""
    src_d, src_h, src_w = image.shape
    dst_d, dst_h, dst_w = target_shape
    
    # 計算源索引 (向量化)
    scale_d = src_d / dst_d
    scale_h = src_h / dst_h
    scale_w = src_w / dst_w
    
    idx_d = np.floor(np.arange(dst_d) * scale_d).astype(int)
    idx_h = np.floor(np.arange(dst_h) * scale_h).astype(int)
    idx_w = np.floor(np.arange(dst_w) * scale_w).astype(int)
    
    # 邊界檢查
    idx_d = np.clip(idx_d, 0, src_d - 1)
    idx_h = np.clip(idx_h, 0, src_h - 1)
    idx_w = np.clip(idx_w, 0, src_w - 1)
    
    # 一次性索引 (NumPy 高級索引)
    result = image[idx_d[:, None, None], idx_h[None, :, None], idx_w[None, None, :]]
    
    return result

# ==================== 預處理（分流版本）====================

def preprocess_image(image, target_size=64):
    """影像預處理 - z-score 標準化"""
    # 百分位裁切
    p1, p99 = np.percentile(image, [1, 99])
    image = np.clip(image, p1, p99)
    
    # Z-score 標準化
    mean = np.mean(image)
    std = np.std(image)
    if std > 1e-8:
        image = (image - mean) / std
    
    # 縮放
    target_shape = (target_size, target_size, target_size)
    result = resize_nearest_vectorized(image, target_shape)
    
    return result.astype(np.float32)

def preprocess_label(label, target_size=64):
    """標註預處理 - 僅重採樣（保持整數）"""
    # 直接最近鄰縮放（不做標準化）
    target_shape = (target_size, target_size, target_size)
    result = resize_nearest_vectorized(label, target_shape)
    
    # 四捨五入並轉為整數
    result = np.round(result).astype(np.uint8)
    
    return result

def calculate_new_spacing(orig_spacing, orig_shape, target_shape):
    """計算重採樣後的新 spacing"""
    orig_spacing = np.array(orig_spacing)
    orig_shape = np.array(orig_shape)
    target_shape = np.array(target_shape)
    
    new_spacing = (orig_spacing * orig_shape) / target_shape
    
    return new_spacing

# ==================== 分割 ====================

def segment(image):
    """簡單閾值分割"""
    threshold = np.percentile(image, 70)
    binary = (image > threshold).astype(np.uint8)
    
    mid = binary.shape[0] // 2
    center = binary.shape[2] // 2  # 寬度軸中點
    prediction = np.zeros_like(binary)
    
    # 注意：這裡用左右分區，非前後
    # 在中間層處理
    for i in range(max(0, mid-10), min(binary.shape[0], mid+10)):
        prediction[i, :, :center] = binary[i, :, :center] * 1  # 左側
        prediction[i, :, center:] = binary[i, :, center:] * 2  # 右側
    
    return prediction

# ==================== 特徵提取（修正 spacing）====================

def extract_features(pred, gt, new_spacing):
    """
    使用重採樣後的 spacing 計算體積
    new_spacing: 已根據縮放比例調整過的 spacing
    """
    voxel_volume = float(np.prod(new_spacing))
    
    left_vol = float(np.sum(pred == 1)) * voxel_volume
    right_vol = float(np.sum(pred == 2)) * voxel_volume
    total = left_vol + right_vol
    
    # Dice 係數
    dice_scores = []
    for i in [1, 2]:
        p = (pred == i).astype(np.float32)
        g = (gt == i).astype(np.float32)
        inter = float(np.sum(p * g))
        union = float(np.sum(p) + np.sum(g))
        dice = 2 * inter / (union + 1e-8) if union > 0 else 0.0
        dice_scores.append(dice)
    
    return {
        'total_volume': total,
        'left_vol': left_vol,
        'right_vol': right_vol,
        'dice': float(np.mean(dice_scores)),
        'dice_left': dice_scores[0],
        'dice_right': dice_scores[1]
    }

# ==================== 可視化 ====================

def add_text_pil(image_array, features):
    """使用PIL添加文字"""
    img = Image.fromarray(image_array)
    draw = ImageDraw.Draw(img)
    
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
    except:
        font = ImageFont.load_default()
    
    texts = [
        "HIPPOCAMPUS SEGMENTATION",
        "(Left-Right Split Demo)",
        "",
        f"Total: {features['total_volume']:.1f} mm3",
        f"Left:  {features['left_vol']:.1f} mm3",
        f"Right: {features['right_vol']:.1f} mm3",
        "",
        f"Dice: {features['dice']:.3f}",
        f"  Left:  {features['dice_left']:.3f}",
        f"  Right: {features['dice_right']:.3f}",
    ]
    
    y = 10
    for text in texts:
        draw.text((10, y), text, fill=(0, 0, 0), font=font)
        y += 18
    
    return np.array(img)

def visualize(image, pred, gt, features):
    """創建可視化"""
    mid = pred.shape[0] // 2
    views = []
    
    for offset in [-8, 0, 8]:
        idx = mid + offset
        if 0 <= idx < pred.shape[0]:
            # 標準化到0-255
            img_slice = image[idx]
            img_min, img_max = float(img_slice.min()), float(img_slice.max())
            if img_max > img_min:
                img_norm = ((img_slice - img_min) / (img_max - img_min) * 255).astype(np.uint8)
            else:
                img_norm = np.zeros_like(img_slice, dtype=np.uint8)
            
            # 轉為RGB
            img_rgb = np.stack([img_norm, img_norm, img_norm], axis=-1)
            
            # 預測疊加（左=紅，右=藍）
            pred_overlay = img_rgb.copy()
            pred_overlay[pred[idx] == 1] = [255, 0, 0]  # 左側-紅色
            pred_overlay[pred[idx] == 2] = [0, 0, 255]  # 右側-藍色
            pred_result = (img_rgb * 0.7 + pred_overlay * 0.3).astype(np.uint8)
            
            # Ground Truth疊加（左=綠，右=黃）
            gt_overlay = img_rgb.copy()
            gt_overlay[gt[idx] == 1] = [0, 255, 0]      # 左側-綠色
            gt_overlay[gt[idx] == 2] = [255, 255, 0]    # 右側-黃色
            gt_result = (img_rgb * 0.7 + gt_overlay * 0.3).astype(np.uint8)
            
            combined = np.hstack([pred_result, gt_result])
            views.append(combined)
    
    result = np.vstack(views)
    
    # 添加文字
    if USE_PIL:
        text_h = result.shape[0]
        text_w = 300
        text_area = np.ones((text_h, text_w, 3), dtype=np.uint8) * 255
        text_area = add_text_pil(text_area, features)
        result = np.hstack([result, text_area])
    
    return result

def save_image(image_array, filepath):
    """使用PIL保存"""
    img = Image.fromarray(image_array.astype(np.uint8))
    img.save(filepath)

# ==================== 主程序 ====================

def main():
    DATA_ROOT = Path('/workspace/Task04_Hippocampus')
    OUTPUT_DIR = Path('/workspace/outputs')
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    
    print("\n" + "=" * 80)
    print("開始處理")
    print("=" * 80)
    
    # 載入資料
    print("\n📊 步驟 1: 載入資料")
    loader = DataLoader(DATA_ROOT)
    cases = loader.get_cases(num=3)
    print(f"✓ 載入 {len(cases)} 個案例")
    
    # 處理第一個案例
    case = cases[0]
    print(f"\n📊 步驟 2: 處理案例 {case['id']}")
    
    image, img_meta = loader.load_nifti(case['image'])
    label, lbl_meta = loader.load_nifti(case['label'])
    print(f"   原始尺寸: {image.shape}")
    print(f"   原始 spacing: {img_meta['spacing']}")
    
    # 預處理（分流）
    print("\n📊 步驟 3: 預處理")
    target_size = 64
    image_prep = preprocess_image(image, target_size)
    label_prep = preprocess_label(label, target_size)
    
    # 計算新 spacing
    new_spacing = calculate_new_spacing(
        img_meta['spacing'], 
        img_meta['shape'], 
        (target_size, target_size, target_size)
    )
    
    print(f"   處理後尺寸: {image_prep.shape}")
    print(f"   新 spacing: {new_spacing}")
    print(f"   體素體積: {np.prod(new_spacing):.2f} mm³/voxel")
    
    # 分割
    print("\n📊 步驟 4: 執行分割")
    prediction = segment(image_prep)
    print(f"   左側: {np.sum(prediction == 1)} 體素")
    print(f"   右側: {np.sum(prediction == 2)} 體素")
    
    # 特徵提取（使用新 spacing）
    print("\n📊 步驟 5: 特徵提取")
    features = extract_features(prediction, label_prep, new_spacing)
    print(f"   總體積: {features['total_volume']:.1f} mm³")
    print(f"   左側體積: {features['left_vol']:.1f} mm³")
    print(f"   右側體積: {features['right_vol']:.1f} mm³")
    print(f"   Dice係數: {features['dice']:.3f}")
    
    # 可視化
    print("\n📊 步驟 6: 創建可視化")
    viz = visualize(image_prep, prediction, label_prep, features)
    
    output_path = OUTPUT_DIR / f"result_{case['id']}.png"
    save_image(viz, output_path)
    print(f"   ✓ 已保存: {output_path}")
    
    # JSON報告
    report = {
        'case_id': case['id'],
        'original_shape': img_meta['shape'].tolist(),
        'original_spacing': img_meta['spacing'].tolist(),
        'resampled_shape': [target_size] * 3,
        'resampled_spacing': new_spacing.tolist(),
        'features': features,
        'timestamp': datetime.now().isoformat()
    }
    
    json_path = OUTPUT_DIR / f"report_{case['id']}.json"
    with open(json_path, 'w') as f:
        json.dump(report, f, indent=2)
    print(f"   ✓ 已保存: {json_path}")
    
    print("\n" + "=" * 80)
    print("✓ 處理完成!")
    print("=" * 80)
    print(f"\n📊 最終結果:")
    print(f"  總體積: {features['total_volume']:.1f} mm³")
    print(f"  Dice分數: {features['dice']:.3f}")
    print(f"\n輸出位置: {OUTPUT_DIR}/")
    print("\n✅ 系統運行成功!")
    
    return viz, features

if __name__ == "__main__":
    try:
        result_viz, result_features = main()
    except Exception as e:
        print(f"\n❌ 錯誤: {e}")
        import traceback
        traceback.print_exc()


🏥 醫學影像分割系統 - 專業版
✓ 使用 PIL 繪製文字
✓ 所有套件已就緒

開始處理

📊 步驟 1: 載入資料

📊 資料集: Hippocampus
   訓練樣本: 260
✓ 載入 3 個案例

📊 步驟 2: 處理案例 367
   原始尺寸: (36, 57, 37)
   原始 spacing: [1. 1. 1.]

📊 步驟 3: 預處理
   處理後尺寸: (64, 64, 64)
   新 spacing: [0.5625   0.890625 0.578125]
   體素體積: 0.29 mm³/voxel

📊 步驟 4: 執行分割
   左側: 4639 體素
   右側: 13376 體素

📊 步驟 5: 特徵提取
   總體積: 5217.6 mm³
   左側體積: 1343.6 mm³
   右側體積: 3874.1 mm³
   Dice係數: 0.014

📊 步驟 6: 創建可視化
   ✓ 已保存: /workspace/outputs/result_367.png
   ✓ 已保存: /workspace/outputs/report_367.json

✓ 處理完成!

📊 最終結果:
  總體積: 5217.6 mm³
  Dice分數: 0.014

輸出位置: /workspace/outputs/

✅ 系統運行成功!


In [30]:
#!/usr/bin/env python3
"""
🧠 醫學影像分割訓練系統 - 專業完整版
整合：隨機切分、類別平衡、AMP、梯度裁剪、優化的評估指標
"""

import os
import json
import numpy as np
import nibabel as nib
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

print("=" * 80)
print("🧠 醫學影像分割訓練系統 - 專業版")
print("=" * 80)

os.environ['TORCH_DISABLE_ONNX'] = '1'

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

print(f"✓ PyTorch {torch.__version__}")
print(f"✓ CUDA: {torch.cuda.is_available()}")

# cuDNN 優化
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    print("✓ cuDNN benchmark enabled")

# ==================== 資料集 ====================

class HippocampusDataset(Dataset):
    def __init__(self, data_root, indices, target_size=64):
        self.data_root = Path(data_root)
        self.target_size = target_size
        
        with open(self.data_root / 'dataset.json', 'r') as f:
            metadata = json.load(f)
        
        training_data = metadata['training']
        self.cases = [training_data[i] for i in indices]
    
    def __len__(self):
        return len(self.cases)
    
    def resize_3d(self, image, target_shape):
        """向量化最近鄰縮放"""
        src_d, src_h, src_w = image.shape
        dst_d, dst_h, dst_w = target_shape
        
        idx_d = np.clip(np.floor(np.arange(dst_d) * src_d / dst_d).astype(int), 0, src_d - 1)
        idx_h = np.clip(np.floor(np.arange(dst_h) * src_h / dst_h).astype(int), 0, src_h - 1)
        idx_w = np.clip(np.floor(np.arange(dst_w) * src_w / dst_w).astype(int), 0, src_w - 1)
        
        return image[idx_d[:, None, None], idx_h[None, :, None], idx_w[None, None, :]]
    
    def __getitem__(self, idx):
        case = self.cases[idx]
        img_path = self.data_root / case['image'].lstrip('./')
        lbl_path = self.data_root / case['label'].lstrip('./')
        
        # 載入影像 - 直接轉 float32 節省記憶體
        image = nib.load(str(img_path)).get_fdata(dtype=np.float32)
        p1, p99 = np.percentile(image, [1, 99])
        image = np.clip(image, p1, p99)
        mean, std = image.mean(), image.std()
        if std > 1e-8:
            image = (image - mean) / std
        
        # 載入標註 - float32
        label = nib.load(str(lbl_path)).get_fdata(dtype=np.float32)
        
        # 縮放
        image = self.resize_3d(image, (self.target_size,) * 3)
        label = self.resize_3d(label, (self.target_size,) * 3)
        label = np.round(label).astype(np.int64)
        
        return torch.FloatTensor(image).unsqueeze(0), torch.LongTensor(label)

# ==================== 資料切分（隨機） ====================

def get_train_val_split(num_samples, train_ratio=0.8, seed=42):
    """固定隨機種子的資料切分"""
    np.random.seed(seed)
    indices = np.random.permutation(num_samples)
    split_idx = int(num_samples * train_ratio)
    return indices[:split_idx], indices[split_idx:]

# ==================== 模型 ====================

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, 3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_ch)
        self.conv2 = nn.Conv3d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_ch)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        return self.relu(self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x))))))

class UNet3D(nn.Module):
    def __init__(self, in_ch=1, num_classes=3, base=16):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base)
        self.enc2 = ConvBlock(base, base * 2)
        self.enc3 = ConvBlock(base * 2, base * 4)
        self.bottleneck = ConvBlock(base * 4, base * 8)
        
        self.up3 = nn.ConvTranspose3d(base * 8, base * 4, 2, stride=2)
        self.dec3 = ConvBlock(base * 8, base * 4)
        self.up2 = nn.ConvTranspose3d(base * 4, base * 2, 2, stride=2)
        self.dec2 = ConvBlock(base * 4, base * 2)
        self.up1 = nn.ConvTranspose3d(base * 2, base, 2, stride=2)
        self.dec1 = ConvBlock(base * 2, base)
        
        self.out = nn.Conv3d(base, num_classes, 1)
        self.pool = nn.MaxPool3d(2, 2)
    
    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        
        d3 = self.dec3(torch.cat([self.up3(b), e3], 1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], 1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], 1))
        
        return self.out(d1)

# ==================== 損失函數（類別平衡）====================

class FocalLoss(nn.Module):
    """Focal Loss 處理類別不平衡"""
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        if alpha is not None:
            self.register_buffer('alpha', alpha)
        else:
            self.alpha = None
        self.gamma = gamma
    
    def forward(self, pred, target):
        ce_loss = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        
        if self.alpha is not None:
            alpha_t = self.alpha[target]
            focal_loss = alpha_t * focal_loss
        
        return focal_loss.mean()

class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth
    
    def forward(self, pred, target):
        pred = F.softmax(pred, dim=1)
        target_one_hot = F.one_hot(target, pred.shape[1]).permute(0, 4, 1, 2, 3).float()
        
        inter = (pred * target_one_hot).sum(dim=(2, 3, 4))
        union = pred.sum(dim=(2, 3, 4)) + target_one_hot.sum(dim=(2, 3, 4))
        dice = (2. * inter + self.smooth) / (union + self.smooth)
        
        return 1 - dice.mean()

class CombinedLoss(nn.Module):
    def __init__(self, use_focal=True):
        super().__init__()
        if use_focal:
            # 背景權重較低
            alpha = torch.tensor([0.2, 1.0, 1.0])
            self.ce = FocalLoss(alpha=alpha, gamma=2.0)
        else:
            # 傳統加權 CE
            weight = torch.tensor([0.2, 1.0, 1.0])
            self.ce = nn.CrossEntropyLoss(weight=weight)
        
        self.dice = DiceLoss()
    
    def forward(self, pred, target):
        return 0.3 * self.ce(pred, target) + 0.7 * self.dice(pred, target)

# ==================== Adam with Weight Decay ====================

class AdamW:
    """手動實現 AdamW，支援 weight decay"""
    def __init__(self, params, lr=1e-3, weight_decay=1e-4):
        self.params = list(params)
        self.lr = lr
        self.weight_decay = weight_decay
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.eps = 1e-8
        self.t = 0
        
        # 為了相容 GradScaler，添加 param_groups
        self.param_groups = [{'params': self.params, 'lr': lr}]
        
        self.m = [torch.zeros_like(p.data) for p in self.params]
        self.v = [torch.zeros_like(p.data) for p in self.params]
    
    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()
    
    def step(self, closure=None):
        """添加 closure 參數以相容 optimizer 介面"""
        self.t += 1
        for i, p in enumerate(self.params):
            if p.grad is None:
                continue
            
            grad = p.grad.data
            
            # Weight decay
            if self.weight_decay > 0:
                p.data.mul_(1 - self.lr * self.weight_decay)
            
            # Adam 更新
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * grad ** 2
            
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            
            p.data -= self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)
    
    def set_lr(self, lr):
        """動態調整學習率"""
        self.lr = lr
        self.param_groups[0]['lr'] = lr

# ==================== 訓練器（優化版）====================

class Trainer:
    def __init__(self, model, train_loader, val_loader, device, output_dir, use_amp=True):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        self.optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
        self.criterion = CombinedLoss(use_focal=True).to(device)
        
        # AMP
        self.use_amp = use_amp and torch.cuda.is_available()
        self.scaler = GradScaler() if self.use_amp else None
        
        self.history = {'train_loss': [], 'val_loss': [], 'val_dice': []}
        self.best_dice = 0.0
        self.lr = 1e-3
        
        print(f"  AMP enabled: {self.use_amp}")
    
    def train_epoch(self):
        self.model.train()
        total_loss = 0
        
        for batch_idx, (images, labels) in enumerate(self.train_loader):
            images, labels = images.to(self.device), labels.to(self.device)
            
            self.optimizer.zero_grad()
            
            # AMP forward
            if self.use_amp:
                with autocast():
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels)
                
                self.scaler.scale(loss).backward()
                
                # 梯度裁剪
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                
                # 梯度裁剪
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                
                self.optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f"    Batch {batch_idx}/{len(self.train_loader)}, Loss: {loss.item():.4f}")
        
        return total_loss / len(self.train_loader)
    
    def validate(self):
        """累積式 Dice 計算"""
        self.model.eval()
        total_loss = 0
        
        # 累積 intersection 和 union
        total_inter = {1: 0.0, 2: 0.0}
        total_union = {1: 0.0, 2: 0.0}
        
        with torch.no_grad():
            for images, labels in self.val_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                if self.use_amp:
                    with autocast():
                        outputs = self.model(images)
                        loss = self.criterion(outputs, labels)
                else:
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels)
                
                preds = torch.argmax(outputs, dim=1)
                
                # 累積每個類別的 intersection 和 union
                for cls in [1, 2]:
                    pred_mask = (preds == cls).float()
                    target_mask = (labels == cls).float()
                    
                    total_inter[cls] += (pred_mask * target_mask).sum().item()
                    total_union[cls] += (pred_mask.sum() + target_mask.sum()).item()
                
                total_loss += loss.item()
        
        # 計算整體 Dice
        dice_scores = []
        for cls in [1, 2]:
            dice = (2.0 * total_inter[cls]) / (total_union[cls] + 1e-8)
            dice_scores.append(dice)
        
        avg_dice = np.mean(dice_scores)
        
        return total_loss / len(self.val_loader), avg_dice
    
    def train(self, num_epochs):
        print(f"\n開始訓練 {num_epochs} epochs")
        print("=" * 80)
        
        for epoch in range(1, num_epochs + 1):
            print(f"\nEpoch {epoch}/{num_epochs}")
            print("-" * 40)
            
            train_loss = self.train_epoch()
            val_loss, val_dice = self.validate()
            
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['val_dice'].append(val_dice)
            
            # 學習率衰減
            if epoch % 10 == 0:
                self.lr *= 0.5
                self.optimizer.set_lr(self.lr)
            
            print(f"\n  訓練損失: {train_loss:.4f}")
            print(f"  驗證損失: {val_loss:.4f}")
            print(f"  驗證 Dice: {val_dice:.4f}")
            print(f"  學習率: {self.lr:.6f}")
            
            # 保存最佳
            if val_dice > self.best_dice:
                self.best_dice = val_dice
                torch.save({
                    'epoch': epoch,
                    'model': self.model.state_dict(),
                    'dice': val_dice,
                    'history': self.history
                }, self.output_dir / 'best_model.pth')
                print(f"  ✓ 保存最佳模型 (Dice: {val_dice:.4f})")
        
        # 保存歷史
        with open(self.output_dir / 'history.json', 'w') as f:
            json.dump(self.history, f, indent=2)
        
        print("\n" + "=" * 80)
        print(f"✓ 訓練完成! 最佳 Dice: {self.best_dice:.4f}")

# ==================== 主程序 ====================

def main():
    DATA_ROOT = '/workspace/Task04_Hippocampus'
    OUTPUT_DIR = '/workspace/outputs/training_pro'
    BATCH_SIZE = 2
    NUM_EPOCHS = 10
    TARGET_SIZE = 64
    USE_AMP = True
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n使用設備: {device}")
    
    # 隨機切分資料
    print("\n準備資料...")
    with open(Path(DATA_ROOT) / 'dataset.json') as f:
        num_samples = len(json.load(f)['training'])
    
    train_indices, val_indices = get_train_val_split(num_samples, train_ratio=0.8, seed=42)
    print(f"  訓練集: {len(train_indices)} 個案例")
    print(f"  驗證集: {len(val_indices)} 個案例")
    
    train_dataset = HippocampusDataset(DATA_ROOT, train_indices, TARGET_SIZE)
    val_dataset = HippocampusDataset(DATA_ROOT, val_indices, TARGET_SIZE)
    
    # 優化的 DataLoader 參數
    num_workers = min(4, os.cpu_count() or 1) if torch.cuda.is_available() else 0
    train_loader = DataLoader(
        train_dataset, 
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=num_workers > 0,
        prefetch_factor=2 if num_workers > 0 else None
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available()
    )
    
    print("\n建立模型...")
    model = UNet3D(1, 3, 16)
    print(f"  參數: {sum(p.numel() for p in model.parameters()):,}")
    
    trainer = Trainer(model, train_loader, val_loader, device, OUTPUT_DIR, use_amp=USE_AMP)
    trainer.train(NUM_EPOCHS)

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n訓練中斷")
    except Exception as e:
        print(f"\n錯誤: {e}")
        import traceback
        traceback.print_exc()


🧠 醫學影像分割訓練系統 - 專業版
✓ PyTorch 2.1.0+cu118
✓ CUDA: True
✓ cuDNN benchmark enabled

使用設備: cuda

準備資料...
  訓練集: 208 個案例
  驗證集: 52 個案例

建立模型...
  參數: 1,402,003
  AMP enabled: True

開始訓練 10 epochs

Epoch 1/10
----------------------------------------
    Batch 0/104, Loss: 0.5706
    Batch 10/104, Loss: 0.4696
    Batch 20/104, Loss: 0.4542
    Batch 30/104, Loss: 0.4323
    Batch 40/104, Loss: 0.4178
    Batch 50/104, Loss: 0.4056
    Batch 60/104, Loss: 0.4053
    Batch 70/104, Loss: 0.3969
    Batch 80/104, Loss: 0.3498
    Batch 90/104, Loss: 0.3299
    Batch 100/104, Loss: 0.3091

  訓練損失: 0.4097
  驗證損失: 0.3119
  驗證 Dice: 0.7447
  學習率: 0.001000
  ✓ 保存最佳模型 (Dice: 0.7447)

Epoch 2/10
----------------------------------------
    Batch 0/104, Loss: 0.3355
    Batch 10/104, Loss: 0.2731
    Batch 20/104, Loss: 0.2825
    Batch 30/104, Loss: 0.2375
    Batch 40/104, Loss: 0.2094
    Batch 50/104, Loss: 0.2021
    Batch 60/104, Loss: 0.1952
    Batch 70/104, Loss: 0.1852
    Batch 80/104, Loss: 

In [33]:
#!/usr/bin/env python3
"""
🧠 醫學影像分割訓練系統 - 生產級版本
精準短路 torch._compile/onnx、trilinear 插值、完全可重現、Cosine LR
"""

import os
import sys
import warnings
warnings.filterwarnings('ignore')

print("=" * 80)
print("🧠 醫學影像分割訓練系統 - 生產級")
print("=" * 80)

# ==================== 精準短路：避免載入 transformers ====================
print("\n🔧 設置環境...")

# 創建 stub 模組來短路導入鏈
class _StubModule:
    """Stub 模組，阻止實際載入但不破壞導入鏈"""
    def __getattr__(self, name):
        return _StubModule()
    def __call__(self, *args, **kwargs):
        return _StubModule()

# 短路 torch._compile 和 torch.onnx，避免觸發 transformers
sys.modules['torch._compile'] = _StubModule()
sys.modules['torch.onnx'] = _StubModule()
print("  ✓ 已短路 torch._compile 和 torch.onnx")

# ==================== 導入套件 ====================

import json
import math
import numpy as np
import nibabel as nib
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

print(f"\n✓ PyTorch {torch.__version__}")
print(f"✓ CUDA: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    print("✓ cuDNN benchmark enabled")

# ==================== 完全可重現設置 ====================

def set_seed(seed=42):
    """設置所有隨機種子以確保可重現"""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # 為了完全可重現（會稍微降低性能）
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ==================== 資料集（優化插值）====================

class HippocampusDataset(Dataset):
    def __init__(self, data_root, indices, target_size=64):
        self.data_root = Path(data_root)
        self.target_size = target_size
        
        with open(self.data_root / 'dataset.json', 'r') as f:
            metadata = json.load(f)
        
        training_data = metadata['training']
        self.cases = [training_data[i] for i in indices]
    
    def __len__(self):
        return len(self.cases)
    
    def __getitem__(self, idx):
        case = self.cases[idx]
        img_path = self.data_root / case['image'].lstrip('./')
        lbl_path = self.data_root / case['label'].lstrip('./')
        
        # 載入影像 - float32
        image = nib.load(str(img_path)).get_fdata(dtype=np.float32)
        
        # Z-score 正規化
        p1, p99 = np.percentile(image, [1, 99])
        image = np.clip(image, p1, p99)
        mean, std = image.mean(), image.std()
        if std > 1e-8:
            image = (image - mean) / std
        
        # 載入標註 - float32
        label = nib.load(str(lbl_path)).get_fdata(dtype=np.float32)
        
        # 轉為 torch tensor（添加 batch 和 channel 維度）
        image_t = torch.from_numpy(image).unsqueeze(0).unsqueeze(0)  # [1, 1, D, H, W]
        label_t = torch.from_numpy(label).unsqueeze(0).unsqueeze(0)  # [1, 1, D, H, W]
        
        # 使用 PyTorch 的 F.interpolate
        # 影像：trilinear（更平滑）
        image_resized = F.interpolate(
            image_t,
            size=(self.target_size, self.target_size, self.target_size),
            mode='trilinear',
            align_corners=False
        ).squeeze(0)  # [1, D, H, W]
        
        # 標註：nearest（保持整數值）
        label_resized = F.interpolate(
            label_t,
            size=(self.target_size, self.target_size, self.target_size),
            mode='nearest'
        ).squeeze(0).squeeze(0)  # [D, H, W]
        
        # 清洗標註
        label_resized = torch.clamp(label_resized.long(), 0, 2)
        
        return image_resized, label_resized

def get_train_val_split(num_samples, train_ratio=0.8, seed=42):
    """固定隨機種子的資料切分"""
    np.random.seed(seed)
    indices = np.random.permutation(num_samples)
    split_idx = int(num_samples * train_ratio)
    return indices[:split_idx], indices[split_idx:]

# ==================== 模型 ====================

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, num_groups=8):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, 3, padding=1)
        self.gn1 = nn.GroupNorm(min(num_groups, out_ch), out_ch)
        self.conv2 = nn.Conv3d(out_ch, out_ch, 3, padding=1)
        self.gn2 = nn.GroupNorm(min(num_groups, out_ch), out_ch)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.relu(self.gn1(self.conv1(x)))
        x = self.relu(self.gn2(self.conv2(x)))
        return x

class UNet3D(nn.Module):
    def __init__(self, in_ch=1, num_classes=3, base=16):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base)
        self.enc2 = ConvBlock(base, base * 2)
        self.enc3 = ConvBlock(base * 2, base * 4)
        self.bottleneck = ConvBlock(base * 4, base * 8)
        
        self.up3 = nn.ConvTranspose3d(base * 8, base * 4, 2, stride=2)
        self.dec3 = ConvBlock(base * 8, base * 4)
        self.up2 = nn.ConvTranspose3d(base * 4, base * 2, 2, stride=2)
        self.dec2 = ConvBlock(base * 4, base * 2)
        self.up1 = nn.ConvTranspose3d(base * 2, base, 2, stride=2)
        self.dec1 = ConvBlock(base * 2, base)
        
        self.out = nn.Conv3d(base, num_classes, 1)
        self.pool = nn.MaxPool3d(2, 2)
    
    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        
        d3 = self.dec3(torch.cat([self.up3(b), e3], 1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], 1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], 1))
        
        return self.out(d1)

# ==================== 損失函數 ====================

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        if alpha is not None:
            self.register_buffer('alpha', alpha)
        else:
            self.alpha = None
        self.gamma = gamma
    
    def forward(self, pred, target):
        ce_loss = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        
        if self.alpha is not None:
            alpha_t = self.alpha[target]
            focal_loss = alpha_t * focal_loss
        
        return focal_loss.mean()

class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth
    
    def forward(self, pred, target):
        pred = F.softmax(pred, dim=1)
        target_one_hot = F.one_hot(target, pred.shape[1]).permute(0, 4, 1, 2, 3).float()
        
        inter = (pred * target_one_hot).sum(dim=(2, 3, 4))
        union = pred.sum(dim=(2, 3, 4)) + target_one_hot.sum(dim=(2, 3, 4))
        dice = (2. * inter + self.smooth) / (union + self.smooth)
        
        return 1 - dice.mean()

class CombinedLoss(nn.Module):
    def __init__(self, device, use_focal=True):
        super().__init__()
        if use_focal:
            alpha = torch.tensor([0.2, 1.0, 1.0], device=device)
            self.ce = FocalLoss(alpha=alpha, gamma=2.0)
        else:
            weight = torch.tensor([0.2, 1.0, 1.0], device=device)
            self.ce = nn.CrossEntropyLoss(weight=weight)
        
        self.dice = DiceLoss()
    
    def forward(self, pred, target):
        return 0.3 * self.ce(pred, target) + 0.7 * self.dice(pred, target)

# ==================== Cosine Annealing 學習率 ====================

class CosineAnnealingLR:
    """手動實現 Cosine Annealing LR Scheduler"""
    def __init__(self, optimizer, T_max, eta_min=0):
        self.optimizer = optimizer
        self.T_max = T_max
        self.eta_min = eta_min
        self.base_lr = optimizer.param_groups[0]['lr']
        self.current_epoch = 0
    
    def step(self):
        """更新學習率"""
        self.current_epoch += 1
        lr = self.eta_min + (self.base_lr - self.eta_min) * \
             (1 + math.cos(math.pi * self.current_epoch / self.T_max)) / 2
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
    
    def get_lr(self):
        """獲取當前學習率"""
        return self.optimizer.param_groups[0]['lr']

# ==================== 訓練器 ====================

class Trainer:
    def __init__(self, model, train_loader, val_loader, device, output_dir, num_epochs):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        # 使用 SGD with momentum
        self.optimizer = torch.optim.SGD(
            model.parameters(),
            lr=1e-2,
            momentum=0.9,
            weight_decay=1e-4,
            nesterov=True
        )
        
        # Cosine Annealing LR
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=num_epochs,
            eta_min=1e-5
        )
        
        self.criterion = CombinedLoss(device=device, use_focal=True)
        self.history = {'train_loss': [], 'val_loss': [], 'val_dice': [], 'lr': []}
        self.best_dice = 0.0
        
        print(f"  使用 SGD + Nesterov momentum")
        print(f"  使用 Cosine Annealing LR (T_max={num_epochs})")
    
    def train_epoch(self):
        self.model.train()
        total_loss = 0
        
        for batch_idx, (images, labels) in enumerate(self.train_loader):
            images, labels = images.to(self.device), labels.to(self.device)
            
            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f"    Batch {batch_idx}/{len(self.train_loader)}, Loss: {loss.item():.4f}")
        
        return total_loss / len(self.train_loader)
    
    def validate(self):
        """累積式 Dice 計算"""
        self.model.eval()
        total_loss = 0
        total_inter = {1: 0.0, 2: 0.0}
        total_union = {1: 0.0, 2: 0.0}
        
        with torch.no_grad():
            for images, labels in self.val_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                preds = torch.argmax(outputs, dim=1)
                
                for cls in [1, 2]:
                    pred_mask = (preds == cls).float()
                    target_mask = (labels == cls).float()
                    total_inter[cls] += (pred_mask * target_mask).sum().item()
                    total_union[cls] += (pred_mask.sum() + target_mask.sum()).item()
                
                total_loss += loss.item()
        
        dice_scores = []
        for cls in [1, 2]:
            dice = (2.0 * total_inter[cls]) / (total_union[cls] + 1e-8)
            dice_scores.append(dice)
        
        return total_loss / len(self.val_loader), np.mean(dice_scores)
    
    def train(self, num_epochs):
        print(f"\n開始訓練 {num_epochs} epochs")
        print("=" * 80)
        
        for epoch in range(1, num_epochs + 1):
            print(f"\nEpoch {epoch}/{num_epochs}")
            print("-" * 40)
            
            train_loss = self.train_epoch()
            val_loss, val_dice = self.validate()
            
            # 更新學習率
            self.scheduler.step()
            current_lr = self.scheduler.get_lr()
            
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['val_dice'].append(val_dice)
            self.history['lr'].append(current_lr)
            
            print(f"\n  訓練損失: {train_loss:.4f}")
            print(f"  驗證損失: {val_loss:.4f}")
            print(f"  驗證 Dice: {val_dice:.4f}")
            print(f"  學習率: {current_lr:.6f}")
            
            if val_dice > self.best_dice:
                self.best_dice = val_dice
                torch.save({
                    'epoch': epoch,
                    'model': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'dice': val_dice,
                    'history': self.history
                }, self.output_dir / 'best_model.pth')
                print(f"  ✓ 保存最佳模型 (Dice: {val_dice:.4f})")
        
        with open(self.output_dir / 'history.json', 'w') as f:
            json.dump(self.history, f, indent=2)
        
        print("\n" + "=" * 80)
        print(f"✓ 訓練完成! 最佳 Dice: {self.best_dice:.4f}")

# ==================== 主程序 ====================

def main():
    # 設置隨機種子（確保可重現）
    SEED = 42
    set_seed(SEED)
    print(f"\n🎲 設置隨機種子: {SEED}")
    
    DATA_ROOT = '/workspace/Task04_Hippocampus'
    OUTPUT_DIR = '/workspace/outputs/training_production'
    BATCH_SIZE = 2
    NUM_EPOCHS = 30  # SGD 需要更多 epochs 才能達到好的收斂
    TARGET_SIZE = 64
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用設備: {device}")
    
    print("\n準備資料...")
    with open(Path(DATA_ROOT) / 'dataset.json') as f:
        num_samples = len(json.load(f)['training'])
    
    train_indices, val_indices = get_train_val_split(num_samples, train_ratio=0.8, seed=SEED)
    print(f"  訓練集: {len(train_indices)} 個案例")
    print(f"  驗證集: {len(val_indices)} 個案例")
    
    train_dataset = HippocampusDataset(DATA_ROOT, train_indices, TARGET_SIZE)
    val_dataset = HippocampusDataset(DATA_ROOT, val_indices, TARGET_SIZE)
    
    # DataLoader 配置（可選優化）
    use_multiprocessing = torch.cuda.is_available()
    num_workers = min(4, os.cpu_count() or 1) if use_multiprocessing else 0
    
    train_loader_kwargs = {
        'batch_size': BATCH_SIZE,
        'shuffle': True,
        'num_workers': num_workers,
    }
    if num_workers > 0:
        train_loader_kwargs.update({
            'pin_memory': True,
            'persistent_workers': True,
            'prefetch_factor': 2
        })
    
    val_loader_kwargs = {
        'batch_size': BATCH_SIZE,
        'shuffle': False,
        'num_workers': num_workers,
    }
    if num_workers > 0:
        val_loader_kwargs['pin_memory'] = True
    
    train_loader = DataLoader(train_dataset, **train_loader_kwargs)
    val_loader = DataLoader(val_dataset, **val_loader_kwargs)
    
    print(f"  DataLoader workers: {num_workers}")
    print(f"  影像插值: trilinear")
    print(f"  標註插值: nearest")
    
    print("\n建立模型...")
    model = UNet3D(1, 3, 16)
    print(f"  參數: {sum(p.numel() for p in model.parameters()):,}")
    print(f"  使用 GroupNorm (穩定於小 batch)")
    
    trainer = Trainer(model, train_loader, val_loader, device, OUTPUT_DIR, NUM_EPOCHS)
    trainer.train(NUM_EPOCHS)

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n訓練中斷")
    except Exception as e:
        print(f"\n錯誤: {e}")
        import traceback
        traceback.print_exc()


🧠 醫學影像分割訓練系統 - 生產級

🔧 設置環境...
  ✓ 已短路 torch._compile 和 torch.onnx

✓ PyTorch 2.1.0+cu118
✓ CUDA: True
✓ cuDNN benchmark enabled

🎲 設置隨機種子: 42
使用設備: cuda

準備資料...
  訓練集: 208 個案例
  驗證集: 52 個案例
  DataLoader workers: 4
  影像插值: trilinear
  標註插值: nearest

建立模型...
  參數: 1,402,003
  使用 GroupNorm (穩定於小 batch)
  使用 SGD + Nesterov momentum
  使用 Cosine Annealing LR (T_max=30)

開始訓練 30 epochs

Epoch 1/30
----------------------------------------
    Batch 0/104, Loss: 0.6117
    Batch 10/104, Loss: 0.5528
    Batch 20/104, Loss: 0.5192
    Batch 30/104, Loss: 0.4926
    Batch 40/104, Loss: 0.4747
    Batch 50/104, Loss: 0.4702
    Batch 60/104, Loss: 0.4432
    Batch 70/104, Loss: 0.4337
    Batch 80/104, Loss: 0.4114
    Batch 90/104, Loss: 0.4004
    Batch 100/104, Loss: 0.3695

  訓練損失: 0.4646
  驗證損失: 0.3680
  驗證 Dice: 0.0298
  學習率: 0.009973
  ✓ 保存最佳模型 (Dice: 0.0298)

Epoch 2/30
----------------------------------------
    Batch 0/104, Loss: 0.3700
    Batch 10/104, Loss: 0.3577
    Batch 20/104, 

In [44]:
#!/usr/bin/env python3
"""
🚀 提升 Dice 的改進版訓練系統
策略：更大模型 + 更強增強 + LR warmup + 更長訓練
"""

import os
import sys
import warnings
warnings.filterwarnings('ignore')

print("=" * 80)
print("提升 Dice 的改進版訓練系統")
print("=" * 80)

import json
import math
import copy
import random
from types import ModuleType
import numpy as np
import nibabel as nib
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

print(f"\nPyTorch {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

# ==================== Stub ====================

def create_stub_module(name, attributes):
    mod = ModuleType(name)
    for attr_name, attr_value in attributes.items():
        setattr(mod, attr_name, attr_value)
    return mod

sys.modules['torch._compile'] = create_stub_module('torch._compile', {'inner': lambda func: func})
sys.modules['torch.onnx'] = create_stub_module('torch.onnx', {'is_in_onnx_export': False})

print("\n設置環境完成")

# ==================== 改進的 Config ====================

class ImprovedConfig:
    def __init__(self):
        self.seed = 42
        self.data_root = '/workspace/Task04_Hippocampus'
        self.output_dir = '/workspace/outputs/training_improved'
        
        # 更長訓練
        self.batch_size = 2
        self.num_epochs = 50  # 從 30 增加到 50
        self.target_size = 64
        
        # 改進的學習率策略
        self.lr = 5e-4  # 從 1e-3 降到 5e-4（更穩定）
        self.weight_decay = 1e-4
        self.lr_min = 1e-6  # 從 1e-5 降到 1e-6
        self.warmup_epochs = 5  # 新增：前 5 epochs warmup
        
        # 更大的模型
        self.model_base = 24  # 從 16 增加到 24
        self.num_classes = 3
        self.deep_supervision = True
        self.ds_weights = [1.0, 0.5, 0.25]
        self.ds_warmup_epochs = 10  # 延長深度監督暖身
        self.ds_warmup_weights = [1.0, 0.75, 0.5]
        
        self.use_amp = True
        self.use_ema = True
        self.ema_decay = 0.9999
        self.grad_clip = 1.0
        
        # 調整損失權重（更重視 Dice）
        self.use_focal = True
        self.focal_alpha = [0.2, 1.0, 1.0]
        self.focal_gamma = 2.0
        self.loss_weights = [0.2, 0.8]  # CE:Dice = 0.2:0.8（從 0.3:0.7）
        self.dice_ignore_background = True
        
        # 更強的數據增強
        self.aug_flip_prob = 0.5
        self.aug_rotate_prob = 0.5
        self.aug_gamma_prob = 0.5
        self.aug_gamma_range = [0.7, 1.3]  # 從 [0.8, 1.2] 擴大
        self.aug_intensity_shift_prob = 0.5
        self.aug_intensity_shift_range = [-0.15, 0.15]  # 從 [-0.1, 0.1] 擴大
        self.aug_intensity_scale_prob = 0.5
        self.aug_intensity_scale_range = [0.85, 1.15]  # 從 [0.9, 1.1] 擴大
        self.aug_noise_prob = 0.3  # 新增：高斯噪聲
        self.aug_noise_std = 0.1
        self.aug_blur_prob = 0.2  # 新增：模糊
        
        self.num_workers = 4
        self.pin_memory = True
        self.persistent_workers = True
        self.prefetch_factor = 2
        
        self.use_tf32 = True
        self.use_channels_last = False
        self.use_deterministic_algorithms = False
    
    def save(self, path):
        with open(path, 'w') as f:
            json.dump(self.__dict__, f, indent=2)
    
    @classmethod
    def load(cls, path):
        config = cls()
        with open(path, 'r') as f:
            config.__dict__.update(json.load(f))
        return config

# ==================== 設置 ====================

def set_seed(seed=42, use_tf32=True):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    if use_tf32 and torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print("  TF32 enabled")

def worker_init_fn(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def get_generator(seed):
    g = torch.Generator()
    g.manual_seed(seed)
    return g

def gcd(a, b):
    while b:
        a, b = b, a % b
    return a

MEMORY_FORMAT = getattr(torch, "channels_last_3d", torch.contiguous_format)

# ==================== 改進的數據增強 ====================

class ImprovedAugmentation3D:
    def __init__(self, config):
        self.config = config
    
    def random_flip(self, image, label):
        if np.random.rand() < self.config.aug_flip_prob:
            axis = np.random.choice([0, 1, 2])
            image = torch.flip(image, dims=[axis + 1])
            label = torch.flip(label, dims=[axis])
        return image, label
    
    def random_rotate_90(self, image, label):
        if np.random.rand() < self.config.aug_rotate_prob:
            k = np.random.randint(1, 4)
            image = torch.rot90(image, k, dims=[2, 3])
            label = torch.rot90(label, k, dims=[1, 2])
        return image, label
    
    def random_gamma(self, image):
        if np.random.rand() < self.config.aug_gamma_prob:
            gamma = np.random.uniform(*self.config.aug_gamma_range)
            image_min = image.min()
            image_max = image.max()
            if image_max > image_min:
                image_norm = (image - image_min) / (image_max - image_min)
                image = torch.pow(image_norm, gamma)
                image = image * (image_max - image_min) + image_min
        return image
    
    def random_intensity_shift(self, image):
        if np.random.rand() < self.config.aug_intensity_shift_prob:
            shift = np.random.uniform(*self.config.aug_intensity_shift_range)
            image = image + shift
        return image
    
    def random_intensity_scale(self, image):
        if np.random.rand() < self.config.aug_intensity_scale_prob:
            scale = np.random.uniform(*self.config.aug_intensity_scale_range)
            image = image * scale
        return image
    
    def add_gaussian_noise(self, image):
        """新增：高斯噪聲"""
        if np.random.rand() < self.config.aug_noise_prob:
            noise = torch.randn_like(image) * self.config.aug_noise_std
            image = image + noise
        return image
    
    def gaussian_blur(self, image):
        """新增：高斯模糊"""
        if np.random.rand() < self.config.aug_blur_prob:
            # 使用 avg pooling 模擬模糊
            image = F.avg_pool3d(image.unsqueeze(0), 3, stride=1, padding=1).squeeze(0)
        return image
    
    def apply(self, image, label, is_training=True):
        if not is_training:
            return image, label
        
        image, label = self.random_flip(image, label)
        image, label = self.random_rotate_90(image, label)
        image = self.random_gamma(image)
        image = self.random_intensity_shift(image)
        image = self.random_intensity_scale(image)
        image = self.add_gaussian_noise(image)
        image = self.gaussian_blur(image)
        
        return image, label

# ==================== 數據集 ====================

class HippocampusDataset(Dataset):
    def __init__(self, data_root, indices, config, is_training=True):
        self.data_root = Path(data_root)
        self.config = config
        self.is_training = is_training
        self.augmentation = ImprovedAugmentation3D(config)
        
        with open(self.data_root / 'dataset.json', 'r') as f:
            metadata = json.load(f)
        
        training_data = metadata['training']
        self.cases = [training_data[i] for i in indices]
    
    def __len__(self):
        return len(self.cases)
    
    def __getitem__(self, idx):
        case = self.cases[idx]
        img_path = self.data_root / case['image'].lstrip('./')
        lbl_path = self.data_root / case['label'].lstrip('./')
        
        image = nib.load(str(img_path)).get_fdata(dtype=np.float32)
        label = nib.load(str(lbl_path)).get_fdata(dtype=np.float32)
        
        p1, p99 = np.percentile(image, [1, 99])
        image = np.clip(image, p1, p99)
        mean, std = image.mean(), image.std()
        if std > 1e-8:
            image = (image - mean) / std
        
        image_t = torch.from_numpy(image).unsqueeze(0).unsqueeze(0)
        label_t = torch.from_numpy(label).unsqueeze(0).unsqueeze(0)
        
        image = F.interpolate(
            image_t, size=(self.config.target_size,) * 3,
            mode='trilinear', align_corners=False
        ).squeeze(0)
        
        label = F.interpolate(
            label_t, size=(self.config.target_size,) * 3,
            mode='nearest'
        ).squeeze(0).squeeze(0)
        
        label = torch.round(label).long()
        label = torch.clamp(label, 0, 2)
        
        image, label = self.augmentation.apply(image, label, self.is_training)
        
        return image, label

def get_train_val_split(num_samples, train_ratio=0.8, seed=42):
    np.random.seed(seed)
    indices = np.random.permutation(num_samples)
    split_idx = int(num_samples * train_ratio)
    return indices[:split_idx], indices[split_idx:]

# ==================== 模型 ====================

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, num_groups=8):
        super().__init__()
        actual_groups = gcd(num_groups, out_ch)
        
        self.conv1 = nn.Conv3d(in_ch, out_ch, 3, padding=1)
        self.gn1 = nn.GroupNorm(actual_groups, out_ch)
        self.conv2 = nn.Conv3d(out_ch, out_ch, 3, padding=1)
        self.gn2 = nn.GroupNorm(actual_groups, out_ch)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.relu(self.gn1(self.conv1(x)))
        x = self.relu(self.gn2(self.conv2(x)))
        return x

class UNet3DDeepSupervision(nn.Module):
    def __init__(self, in_ch=1, num_classes=3, base=16, deep_supervision=True):
        super().__init__()
        self.deep_supervision = deep_supervision
        
        self.enc1 = ConvBlock(in_ch, base)
        self.enc2 = ConvBlock(base, base * 2)
        self.enc3 = ConvBlock(base * 2, base * 4)
        self.bottleneck = ConvBlock(base * 4, base * 8)
        
        self.up3 = nn.ConvTranspose3d(base * 8, base * 4, 2, stride=2)
        self.dec3 = ConvBlock(base * 8, base * 4)
        self.up2 = nn.ConvTranspose3d(base * 4, base * 2, 2, stride=2)
        self.dec2 = ConvBlock(base * 4, base * 2)
        self.up1 = nn.ConvTranspose3d(base * 2, base, 2, stride=2)
        self.dec1 = ConvBlock(base * 2, base)
        
        self.out = nn.Conv3d(base, num_classes, 1)
        self.pool = nn.MaxPool3d(2, 2)
        
        if deep_supervision:
            self.ds_out3 = nn.Conv3d(base * 4, num_classes, 1)
            self.ds_out2 = nn.Conv3d(base * 2, num_classes, 1)
    
    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        
        d3 = self.dec3(torch.cat([self.up3(b), e3], 1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], 1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], 1))
        
        out = self.out(d1)
        
        if self.deep_supervision and self.training:
            ds3 = F.interpolate(self.ds_out3(d3), size=out.shape[2:], mode='trilinear', align_corners=False)
            ds2 = F.interpolate(self.ds_out2(d2), size=out.shape[2:], mode='trilinear', align_corners=False)
            return out, ds3, ds2
        
        return out

# ==================== EMA ====================

class ModelEMA:
    def __init__(self, model, decay=0.9999):
        self.model = copy.deepcopy(model).eval()
        self.decay = decay
        for param in self.model.parameters():
            param.requires_grad = False
    
    def update(self, model):
        with torch.no_grad():
            model_state = model.state_dict()
            ema_state = self.model.state_dict()
            assert model_state.keys() == ema_state.keys()
            for key in ema_state.keys():
                if ema_state[key].dtype.is_floating_point:
                    ema_state[key].mul_(self.decay).add_(model_state[key], alpha=1 - self.decay)
                else:
                    ema_state[key].copy_(model_state[key])

# ==================== 損失函數 ====================

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        if alpha is not None:
            self.register_buffer('alpha', alpha)
        else:
            self.alpha = None
        self.gamma = gamma
    
    def forward(self, pred, target):
        ce_loss = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        if self.alpha is not None:
            alpha_t = self.alpha[target]
            focal_loss = alpha_t * focal_loss
        return focal_loss.mean()

class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0, ignore_index=None):
        super().__init__()
        self.smooth = smooth
        self.ignore_index = ignore_index
    
    def forward(self, pred, target):
        pred = F.softmax(pred, dim=1)
        target_1h = F.one_hot(target, pred.shape[1]).permute(0, 4, 1, 2, 3).float()
        
        if self.ignore_index is not None:
            keep = [i for i in range(pred.shape[1]) if i != self.ignore_index]
            pred = pred[:, keep]
            target_1h = target_1h[:, keep]
        
        inter = (pred * target_1h).sum(dim=(2, 3, 4))
        union = pred.sum(dim=(2, 3, 4)) + target_1h.sum(dim=(2, 3, 4))
        dice = (2. * inter + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

class CombinedLossDeepSupervision(nn.Module):
    def __init__(self, config, device):
        super().__init__()
        self.config = config
        self.loss_weights = config.loss_weights
        
        if config.use_focal:
            alpha = torch.tensor(config.focal_alpha, device=device)
            self.ce = FocalLoss(alpha=alpha, gamma=config.focal_gamma)
        else:
            weight = torch.tensor(config.focal_alpha, device=device)
            self.ce = nn.CrossEntropyLoss(weight=weight)
        
        ignore_idx = 0 if config.dice_ignore_background else None
        self.dice = DiceLoss(ignore_index=ignore_idx)
    
    def forward(self, outputs, target, epoch=None):
        ce_weight, dice_weight = self.loss_weights
        
        if epoch is not None and epoch <= self.config.ds_warmup_epochs:
            ds_weights = self.config.ds_warmup_weights
        else:
            ds_weights = self.config.ds_weights
        
        if isinstance(outputs, tuple):
            main_out, ds3, ds2 = outputs
            loss = ds_weights[0] * (ce_weight * self.ce(main_out, target) + dice_weight * self.dice(main_out, target))
            loss += ds_weights[1] * (ce_weight * self.ce(ds3, target) + dice_weight * self.dice(ds3, target))
            loss += ds_weights[2] * (ce_weight * self.ce(ds2, target) + dice_weight * self.dice(ds2, target))
            return loss
        else:
            return ce_weight * self.ce(outputs, target) + dice_weight * self.dice(outputs, target)

# ==================== 指標 ====================

class DatasetMetrics:
    def __init__(self, num_classes, device):
        self.num_classes = num_classes
        self.device = device
        self.reset()
    
    def reset(self):
        self.tp = torch.zeros(self.num_classes, dtype=torch.float32, device=self.device)
        self.fp = torch.zeros(self.num_classes, dtype=torch.float32, device=self.device)
        self.fn = torch.zeros(self.num_classes, dtype=torch.float32, device=self.device)
    
    def update(self, pred, target):
        for cls in range(1, self.num_classes):
            pred_mask = (pred == cls)
            target_mask = (target == cls)
            self.tp[cls] += (pred_mask & target_mask).sum().to(self.tp.dtype)
            self.fp[cls] += (pred_mask & ~target_mask).sum().to(self.fp.dtype)
            self.fn[cls] += (~pred_mask & target_mask).sum().to(self.fn.dtype)
    
    def compute(self):
        smooth = 1.0
        tp = self.tp.detach().cpu()
        fp = self.fp.detach().cpu()
        fn = self.fn.detach().cpu()
        
        dice_scores = {}
        iou_scores = {}
        
        for cls in range(1, self.num_classes):
            dice = (2 * tp[cls] + smooth) / (2 * tp[cls] + fp[cls] + fn[cls] + smooth)
            iou = (tp[cls] + smooth) / (tp[cls] + fp[cls] + fn[cls] + smooth)
            dice_scores[cls] = dice.item()
            iou_scores[cls] = iou.item()
        
        return dice_scores, iou_scores

# ==================== Warmup Cosine Scheduler ====================

class WarmupCosineAnnealingLR:
    def __init__(self, optimizer, T_max, warmup_epochs, eta_min=0):
        self.optimizer = optimizer
        self.T_max = T_max
        self.warmup_epochs = warmup_epochs
        self.eta_min = eta_min
        self.base_lr = optimizer.param_groups[0]['lr']
        self.current_epoch = 0
    
    def step(self):
        self.current_epoch += 1
        
        if self.current_epoch <= self.warmup_epochs:
            # Warmup: 線性增長
            lr = self.base_lr * (self.current_epoch / self.warmup_epochs)
        else:
            # Cosine annealing
            progress = (self.current_epoch - self.warmup_epochs) / (self.T_max - self.warmup_epochs)
            lr = self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(math.pi * progress)) / 2
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
    
    def get_lr(self):
        return self.optimizer.param_groups[0]['lr']

# ==================== 訓練器 ====================

class Trainer:
    def __init__(self, model, train_loader, val_loader, device, config):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.config = config
        self.output_dir = Path(config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.lr,
            weight_decay=config.weight_decay,
            betas=(0.9, 0.999)
        )
        
        self.scheduler = WarmupCosineAnnealingLR(
            self.optimizer, 
            T_max=config.num_epochs, 
            warmup_epochs=config.warmup_epochs,
            eta_min=config.lr_min
        )
        self.ema = ModelEMA(model, decay=config.ema_decay) if config.use_ema else None
        self.criterion = CombinedLossDeepSupervision(config, device)
        
        self.use_amp = config.use_amp and torch.cuda.is_available()
        self.scaler = GradScaler() if self.use_amp else None
        
        self.history = {
            'train_loss': [], 'val_loss': [], 'val_dice': [],
            'val_dice_class1': [], 'val_dice_class2': [],
            'val_iou_class1': [], 'val_iou_class2': [], 'lr': []
        }
        self.best_dice = 0.0
        self.current_epoch = 0
        
        print(f"  AdamW (lr={config.lr}) + Warmup({config.warmup_epochs}) + Cosine")
        print(f"  更大模型 (base={config.model_base})")
        print(f"  更強增強 (noise + blur)")
        print(f"  更長訓練 ({config.num_epochs} epochs)")
        print(f"  Dice 權重: {config.loss_weights[1]:.1f}")
    
    def train_epoch(self):
        self.model.train()
        total_loss = 0
        
        for batch_idx, (images, labels) in enumerate(self.train_loader):
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            
            self.optimizer.zero_grad()
            
            if self.use_amp:
                with autocast():
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels, epoch=self.current_epoch)
                
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.config.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(images)
                loss = self.criterion(outputs, labels, epoch=self.current_epoch)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.config.grad_clip)
                self.optimizer.step()
            
            if self.ema:
                self.ema.update(self.model)
            
            total_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f"    Batch {batch_idx}/{len(self.train_loader)}, Loss: {loss.item():.4f}")
        
        return total_loss / len(self.train_loader)
    
    def validate(self, use_ema=True):
        model = self.ema.model if (use_ema and self.ema) else self.model
        model.eval()
        
        total_loss = 0
        metrics_tracker = DatasetMetrics(self.config.num_classes, self.device)
        
        with torch.inference_mode():
            for images, labels in self.val_loader:
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                
                if self.use_amp:
                    with autocast():
                        outputs = model(images)
                        if isinstance(outputs, tuple):
                            outputs = outputs[0]
                        loss = self.criterion(outputs, labels)
                else:
                    outputs = model(images)
                    if isinstance(outputs, tuple):
                        outputs = outputs[0]
                    loss = self.criterion(outputs, labels)
                
                preds = torch.argmax(outputs, dim=1)
                
                for i in range(preds.shape[0]):
                    metrics_tracker.update(preds[i], labels[i])
                
                total_loss += loss.item()
        
        dice_scores, iou_scores = metrics_tracker.compute()
        avg_dice = np.mean([dice_scores[1], dice_scores[2]])
        
        avg_metrics = {
            1: {'dice': dice_scores[1], 'iou': iou_scores[1]},
            2: {'dice': dice_scores[2], 'iou': iou_scores[2]}
        }
        
        return total_loss / len(self.val_loader), avg_dice, avg_metrics
    
    def train(self):
        print(f"\n開始訓練 {self.config.num_epochs} epochs")
        print("=" * 80)
        
        for epoch in range(1, self.config.num_epochs + 1):
            self.current_epoch = epoch
            
            print(f"\nEpoch {epoch}/{self.config.num_epochs}")
            print("-" * 40)
            
            train_loss = self.train_epoch()
            val_loss, val_dice, metrics = self.validate(use_ema=True)
            
            self.scheduler.step()
            current_lr = self.scheduler.get_lr()
            
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['val_dice'].append(val_dice)
            self.history['val_dice_class1'].append(metrics[1]['dice'])
            self.history['val_dice_class2'].append(metrics[2]['dice'])
            self.history['val_iou_class1'].append(metrics[1]['iou'])
            self.history['val_iou_class2'].append(metrics[2]['iou'])
            self.history['lr'].append(current_lr)
            
            print(f"\n  訓練損失: {train_loss:.4f}")
            print(f"  驗證損失: {val_loss:.4f}")
            print(f"  驗證 Dice: {val_dice:.4f} {'🎯 新紀錄!' if val_dice > self.best_dice else ''}")
            print(f"    Class 1 - Dice: {metrics[1]['dice']:.4f}, IoU: {metrics[1]['iou']:.4f}")
            print(f"    Class 2 - Dice: {metrics[2]['dice']:.4f}, IoU: {metrics[2]['iou']:.4f}")
            print(f"  學習率: {current_lr:.6f}")
            
            if val_dice > self.best_dice:
                self.best_dice = val_dice
                checkpoint = {
                    'epoch': epoch,
                    'model': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'dice': val_dice,
                    'metrics': metrics,
                    'history': self.history,
                    'config': self.config.__dict__
                }
                if self.ema:
                    checkpoint['ema'] = self.ema.model.state_dict()
                
                torch.save(checkpoint, self.output_dir / 'best_model.pth')
                print(f"  ✓ 已保存最佳模型")
        
        with open(self.output_dir / 'history.json', 'w') as f:
            json.dump(self.history, f, indent=2)
        
        self.config.save(self.output_dir / 'config.json')
        
        print("\n" + "=" * 80)
        print(f"訓練完成! 最佳 Dice: {self.best_dice:.4f}")

# ==================== 主程序 ====================

def main():
    config = ImprovedConfig()
    
    set_seed(config.seed, config.use_tf32)
    
    print(f"隨機種子: {config.seed}")
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用設備: {device}")
    
    print("\n準備資料...")
    with open(Path(config.data_root) / 'dataset.json') as f:
        num_samples = len(json.load(f)['training'])
    
    train_indices, val_indices = get_train_val_split(num_samples, seed=config.seed)
    print(f"  訓練集: {len(train_indices)}")
    print(f"  驗證集: {len(val_indices)}")
    
    train_dataset = HippocampusDataset(config.data_root, train_indices, config, is_training=True)
    val_dataset = HippocampusDataset(config.data_root, val_indices, config, is_training=False)
    
    num_workers = min(config.num_workers, os.cpu_count() or 1) if torch.cuda.is_available() else 0
    
    train_generator = get_generator(config.seed)
    val_generator = get_generator(config.seed)
    
    train_loader_kwargs = {
        'batch_size': config.batch_size,
        'shuffle': True,
        'num_workers': num_workers,
        'pin_memory': config.pin_memory,
        'worker_init_fn': worker_init_fn,
        'generator': train_generator
    }
    if num_workers > 0:
        train_loader_kwargs.update({
            'persistent_workers': config.persistent_workers,
            'prefetch_factor': config.prefetch_factor
        })
    
    val_loader_kwargs = {
        'batch_size': config.batch_size,
        'shuffle': False,
        'num_workers': num_workers,
        'pin_memory': config.pin_memory,
        'worker_init_fn': worker_init_fn,
        'generator': val_generator
    }
    
    train_loader = DataLoader(train_dataset, **train_loader_kwargs)
    val_loader = DataLoader(val_dataset, **val_loader_kwargs)
    
    print("\n建立模型...")
    model = UNet3DDeepSupervision(1, config.num_classes, config.model_base, config.deep_supervision)
    print(f"  參數: {sum(p.numel() for p in model.parameters()):,}")
    
    trainer = Trainer(model, train_loader, val_loader, device, config)
    trainer.train()

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n訓練中斷")
    except Exception as e:
        print(f"\n錯誤: {e}")
        import traceback
        traceback.print_exc()


提升 Dice 的改進版訓練系統

PyTorch 2.1.0+cu118
CUDA: True

設置環境完成
  TF32 enabled
隨機種子: 42
使用設備: cuda

準備資料...
  訓練集: 208
  驗證集: 52

建立模型...
  參數: 3,152,913
  AdamW (lr=0.0005) + Warmup(5) + Cosine
  更大模型 (base=24)
  更強增強 (noise + blur)
  更長訓練 (50 epochs)
  Dice 權重: 0.8

開始訓練 50 epochs

Epoch 1/50
----------------------------------------
    Batch 0/104, Loss: 1.7866
    Batch 10/104, Loss: 1.6767
    Batch 20/104, Loss: 1.5896
    Batch 30/104, Loss: 1.4519
    Batch 40/104, Loss: 1.2524
    Batch 50/104, Loss: 1.1830
    Batch 60/104, Loss: 1.1920
    Batch 70/104, Loss: 1.1376
    Batch 80/104, Loss: 1.1990
    Batch 90/104, Loss: 1.1095
    Batch 100/104, Loss: 1.1447

  訓練損失: 1.3372
  驗證損失: 0.7977
  驗證 Dice: 0.0584 🎯 新紀錄!
    Class 1 - Dice: 0.0770, IoU: 0.0400
    Class 2 - Dice: 0.0399, IoU: 0.0203
  學習率: 0.000100
  ✓ 已保存最佳模型

Epoch 2/50
----------------------------------------
    Batch 0/104, Loss: 1.0202
    Batch 10/104, Loss: 0.9235
    Batch 20/104, Loss: 0.8873
    Batch 30/104, Lo

In [47]:
#!/usr/bin/env python3
"""
純粹的 3D U-Net 實現（修正版）
- 自動檢測數據路徑
- 移除深度監督
- 移除 Focal Loss
- 移除 EMA
- 使用最基礎的配置
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import nibabel as nib
from pathlib import Path
import random
from tqdm import tqdm
import os

# ==================== 設置隨機種子 ====================
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# ==================== 純粹的 3D U-Net 模型 ====================

class ConvBlock(nn.Module):
    """基礎卷積塊"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        return x


class PureUNet3D(nn.Module):
    """純粹的 3D U-Net（沒有深度監督）"""
    def __init__(self, in_channels=1, num_classes=3, base_channels=16):
        super().__init__()
        
        # 編碼器
        self.enc1 = ConvBlock(in_channels, base_channels)
        self.pool1 = nn.MaxPool3d(2)
        
        self.enc2 = ConvBlock(base_channels, base_channels * 2)
        self.pool2 = nn.MaxPool3d(2)
        
        self.enc3 = ConvBlock(base_channels * 2, base_channels * 4)
        self.pool3 = nn.MaxPool3d(2)
        
        # 瓶頸層
        self.bottleneck = ConvBlock(base_channels * 4, base_channels * 8)
        
        # 解碼器
        self.upconv3 = nn.ConvTranspose3d(base_channels * 8, base_channels * 4, 
                                          kernel_size=2, stride=2)
        self.dec3 = ConvBlock(base_channels * 8, base_channels * 4)
        
        self.upconv2 = nn.ConvTranspose3d(base_channels * 4, base_channels * 2, 
                                          kernel_size=2, stride=2)
        self.dec2 = ConvBlock(base_channels * 4, base_channels * 2)
        
        self.upconv1 = nn.ConvTranspose3d(base_channels * 2, base_channels, 
                                          kernel_size=2, stride=2)
        self.dec1 = ConvBlock(base_channels * 2, base_channels)
        
        # 輸出層
        self.out = nn.Conv3d(base_channels, num_classes, kernel_size=1)
    
    def forward(self, x):
        # 編碼路徑
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        
        # 瓶頸
        b = self.bottleneck(self.pool3(e3))
        
        # 解碼路徑
        d3 = self.upconv3(b)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        
        d2 = self.upconv2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        
        d1 = self.upconv1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        
        # 輸出
        out = self.out(d1)
        return out


# ==================== Dice Loss（簡化版）====================

class DiceLoss(nn.Module):
    """簡單的 Dice Loss"""
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth
    
    def forward(self, pred, target):
        pred = F.softmax(pred, dim=1)
        
        # 計算每個類別的 Dice
        dice_scores = []
        for c in range(pred.shape[1]):
            pred_c = pred[:, c]
            target_c = (target == c).float()
            
            intersection = (pred_c * target_c).sum()
            union = pred_c.sum() + target_c.sum()
            
            dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
            dice_scores.append(dice)
        
        return 1.0 - torch.stack(dice_scores).mean()


# ==================== 數據集 ====================

class HippocampusDataset(Dataset):
    """海馬體數據集"""
    def __init__(self, data_dir, target_size=64, is_train=True):
        self.data_dir = Path(data_dir)
        self.target_size = target_size
        self.is_train = is_train
        
        # 獲取所有圖像文件（過濾掉 macOS 隱藏文件）
        all_image_files = list((self.data_dir / 'imagesTr').glob('*.nii.gz'))
        self.image_files = sorted([f for f in all_image_files if not f.name.startswith('._')])
        
        all_label_files = list((self.data_dir / 'labelsTr').glob('*.nii.gz'))
        self.label_files = sorted([f for f in all_label_files if not f.name.startswith('._')])
        
        if len(self.image_files) == 0:
            raise ValueError(f"在 {self.data_dir / 'imagesTr'} 找不到任何 .nii.gz 文件！")
        
        print(f"找到 {len(self.image_files)} 個訓練樣本")
    
    def __len__(self):
        return len(self.image_files)
    
    def preprocess(self, image):
        """預處理"""
        # Percentile clipping
        p1, p99 = np.percentile(image, [1, 99])
        image = np.clip(image, p1, p99)
        
        # Z-score normalization
        mean, std = image.mean(), image.std()
        image = (image - mean) / (std + 1e-8)
        return image
    
    def augment(self, image, label):
        """簡單的數據增強"""
        # 隨機翻轉
        if random.random() > 0.5:
            axis = random.choice([0, 1, 2])
            image = np.flip(image, axis).copy()
            label = np.flip(label, axis).copy()
        
        # 隨機旋轉 90 度
        if random.random() > 0.5:
            k = random.randint(1, 3)
            axes = random.choice([(0, 1), (0, 2), (1, 2)])
            image = np.rot90(image, k, axes).copy()
            label = np.rot90(label, k, axes).copy()
        
        return image, label
    
    def __getitem__(self, idx):
        # 讀取文件
        image = nib.load(self.image_files[idx]).get_fdata(dtype=np.float32)
        label = nib.load(self.label_files[idx]).get_fdata(dtype=np.float32)
        
        # 預處理
        image = self.preprocess(image)
        
        # 數據增強（僅訓練時）
        if self.is_train:
            image, label = self.augment(image, label)
        
        # 轉為 tensor 並調整大小
        image = torch.from_numpy(image).unsqueeze(0)  # [1, H, W, D]
        label = torch.from_numpy(label).unsqueeze(0)  # [1, H, W, D]
        
        image = F.interpolate(
            image.unsqueeze(0),
            size=(self.target_size, self.target_size, self.target_size),
            mode='trilinear',
            align_corners=False
        ).squeeze(0)
        
        label = F.interpolate(
            label.unsqueeze(0),
            size=(self.target_size, self.target_size, self.target_size),
            mode='nearest'
        ).squeeze(0)
        
        label = label.squeeze(0).long()
        label = torch.clamp(label, 0, 2)
        
        return image, label


# ==================== 評估指標 ====================

def compute_dice(pred, target, num_classes=3):
    """計算 Dice 分數"""
    dice_scores = []
    
    for c in range(1, num_classes):  # 跳過背景
        pred_c = (pred == c)
        target_c = (target == c)
        
        intersection = (pred_c & target_c).sum().float()
        union = pred_c.sum().float() + target_c.sum().float()
        
        if union == 0:
            dice = 1.0 if intersection == 0 else 0.0
        else:
            dice = (2.0 * intersection) / union
        
        dice_scores.append(dice.item())
    
    return dice_scores


# ==================== 訓練函數 ====================

def train_epoch(model, loader, criterion_ce, criterion_dice, optimizer, device):
    """訓練一個 epoch"""
    model.train()
    total_loss = 0
    
    pbar = tqdm(loader, desc='Training')
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)
        
        # 前向傳播
        outputs = model(images)
        
        # 計算損失（簡單的 CE + Dice）
        loss_ce = criterion_ce(outputs, labels)
        loss_dice = criterion_dice(outputs, labels)
        loss = 0.5 * loss_ce + 0.5 * loss_dice
        
        # 反向傳播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(loader)


def validate(model, loader, device):
    """驗證"""
    model.eval()
    all_dice_scores = []
    
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validating'):
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            
            # 計算 Dice
            for pred, label in zip(preds, labels):
                dice_scores = compute_dice(pred.cpu(), label.cpu())
                all_dice_scores.append(dice_scores)
    
    # 平均 Dice
    all_dice_scores = np.array(all_dice_scores)
    mean_dice = all_dice_scores.mean(axis=0)
    
    return mean_dice


# ==================== 自動檢測數據路徑 ====================

def find_data_directory():
    """自動查找數據目錄"""
    possible_paths = [
        '/home/claude/Task04_Hippocampus',
        '/workspace/data/Task04_Hippocampus',
        './Task04_Hippocampus',
        './data/Task04_Hippocampus',
        '../data/Task04_Hippocampus',
        '/data/Task04_Hippocampus',
    ]
    
    for path in possible_paths:
        if Path(path).exists():
            images_dir = Path(path) / 'imagesTr'
            if images_dir.exists():
                # 過濾掉 macOS 隱藏文件
                image_files = [f for f in images_dir.glob('*.nii.gz') if not f.name.startswith('._')]
                if len(image_files) > 0:
                    print(f"✓ 找到數據目錄: {path}")
                    print(f"  ({len(image_files)} 個有效文件)")
                    return path
    
    # 如果都找不到，列出當前目錄
    print("\n❌ 找不到數據目錄！")
    print("\n當前目錄內容:")
    for item in Path('.').iterdir():
        print(f"  - {item}")
    
    print("\n請確保數據在以下位置之一:")
    for path in possible_paths:
        print(f"  - {path}")
    
    return None


# ==================== 主訓練流程 ====================

def main():
    # 自動查找數據目錄
    data_dir = find_data_directory()
    
    if data_dir is None:
        print("\n" + "=" * 60)
        print("錯誤：找不到數據集！")
        print("=" * 60)
        print("\n解決方案:")
        print("1. 下載 Hippocampus 數據集")
        print("2. 解壓到以下任一位置:")
        print("   - /workspace/data/Task04_Hippocampus")
        print("   - ./data/Task04_Hippocampus")
        print("   - ./Task04_Hippocampus")
        print("\n數據集結構應該是:")
        print("Task04_Hippocampus/")
        print("├── imagesTr/")
        print("│   ├── hippocampus_001.nii.gz")
        print("│   ├── hippocampus_002.nii.gz")
        print("│   └── ...")
        print("└── labelsTr/")
        print("    ├── hippocampus_001.nii.gz")
        print("    ├── hippocampus_002.nii.gz")
        print("    └── ...")
        print("=" * 60)
        return
    
    # 配置
    config = {
        'data_dir': data_dir,
        'batch_size': 2,
        'num_epochs': 30,
        'lr': 1e-3,
        'base_channels': 16,
        'target_size': 64,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu'
    }
    
    print("\n" + "=" * 60)
    print("純粹的 3D U-Net 訓練")
    print("=" * 60)
    print(f"配置:")
    for key, value in config.items():
        print(f"  {key}: {value}")
    print()
    
    device = torch.device(config['device'])
    
    # 創建數據集
    print("準備數據集...")
    try:
        train_dataset = HippocampusDataset(
            config['data_dir'],
            target_size=config['target_size'],
            is_train=True
        )
    except ValueError as e:
        print(f"\n錯誤: {e}")
        return
    
    # 簡單的訓練/驗證劃分 (80/20)
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        train_dataset, 
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    print(f"訓練集: {len(train_dataset)} 樣本")
    print(f"驗證集: {len(val_dataset)} 樣本")
    print()
    
    # 創建模型
    print("創建模型...")
    model = PureUNet3D(
        in_channels=1,
        num_classes=3,
        base_channels=config['base_channels']
    ).to(device)
    
    # 計算參數量
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型參數量: {total_params:,} ({total_params/1e6:.2f}M)")
    print()
    
    # 損失函數和優化器
    criterion_ce = nn.CrossEntropyLoss()
    criterion_dice = DiceLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    
    # 簡單的學習率衰減
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, 
        step_size=10, 
        gamma=0.5
    )
    
    # 訓練循環
    print("開始訓練...")
    print()
    
    best_dice = 0.0
    
    for epoch in range(config['num_epochs']):
        print(f"Epoch {epoch+1}/{config['num_epochs']}")
        print("-" * 60)
        
        # 訓練
        train_loss = train_epoch(
            model, train_loader, criterion_ce, criterion_dice, 
            optimizer, device
        )
        
        # 驗證
        val_dice = validate(model, val_loader, device)
        
        # 學習率衰減
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        
        # 打印結果
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Dice - Class 1: {val_dice[0]:.4f}, Class 2: {val_dice[1]:.4f}, "
              f"Avg: {val_dice.mean():.4f}")
        print(f"Learning Rate: {current_lr:.6f}")
        
        # 保存最佳模型
        if val_dice.mean() > best_dice:
            best_dice = val_dice.mean()
            
            # 確保輸出目錄存在
            output_dir = Path('/workspace/outputs')
            if not output_dir.exists():
                output_dir = Path('./outputs')
                output_dir.mkdir(exist_ok=True)
            
            save_path = output_dir / 'pure_unet_best.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'dice': val_dice,
                'config': config
            }, save_path)
            print(f"✓ 保存最佳模型 (Dice: {best_dice:.4f}) -> {save_path}")
        
        print()
    
    print("=" * 60)
    print("訓練完成！")
    print(f"最佳驗證 Dice: {best_dice:.4f}")
    print(f"模型已保存至: {save_path}")
    print("=" * 60)


if __name__ == '__main__':
    main()


✓ 找到數據目錄: ./Task04_Hippocampus
  (260 個有效文件)

純粹的 3D U-Net 訓練
配置:
  data_dir: ./Task04_Hippocampus
  batch_size: 2
  num_epochs: 30
  lr: 0.001
  base_channels: 16
  target_size: 64
  device: cuda

準備數據集...
找到 260 個訓練樣本
訓練集: 208 樣本
驗證集: 52 樣本

創建模型...
模型參數量: 1,402,003 (1.40M)

開始訓練...

Epoch 1/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 26.90it/s, loss=0.4429]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.89it/s]


Train Loss: 0.6237
Val Dice - Class 1: 0.6355, Class 2: 0.6403, Avg: 0.6379
Learning Rate: 0.001000
✓ 保存最佳模型 (Dice: 0.6379) -> /workspace/outputs/pure_unet_best.pth

Epoch 2/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 29.54it/s, loss=0.2678]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.73it/s]


Train Loss: 0.3207
Val Dice - Class 1: 0.6834, Class 2: 0.7307, Avg: 0.7070
Learning Rate: 0.001000
✓ 保存最佳模型 (Dice: 0.7070) -> /workspace/outputs/pure_unet_best.pth

Epoch 3/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:04<00:00, 25.00it/s, loss=0.1526]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.05it/s]


Train Loss: 0.1983
Val Dice - Class 1: 0.7137, Class 2: 0.7626, Avg: 0.7381
Learning Rate: 0.001000
✓ 保存最佳模型 (Dice: 0.7381) -> /workspace/outputs/pure_unet_best.pth

Epoch 4/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 29.30it/s, loss=0.1304]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.31it/s]


Train Loss: 0.1533
Val Dice - Class 1: 0.6823, Class 2: 0.7495, Avg: 0.7159
Learning Rate: 0.001000

Epoch 5/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 32.09it/s, loss=0.1059]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.80it/s]


Train Loss: 0.1299
Val Dice - Class 1: 0.7504, Class 2: 0.7680, Avg: 0.7592
Learning Rate: 0.001000
✓ 保存最佳模型 (Dice: 0.7592) -> /workspace/outputs/pure_unet_best.pth

Epoch 6/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 32.12it/s, loss=0.1206]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.55it/s]


Train Loss: 0.1147
Val Dice - Class 1: 0.7780, Class 2: 0.7877, Avg: 0.7829
Learning Rate: 0.001000
✓ 保存最佳模型 (Dice: 0.7829) -> /workspace/outputs/pure_unet_best.pth

Epoch 7/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.75it/s, loss=0.0998]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.82it/s]


Train Loss: 0.1109
Val Dice - Class 1: 0.7887, Class 2: 0.8045, Avg: 0.7966
Learning Rate: 0.001000
✓ 保存最佳模型 (Dice: 0.7966) -> /workspace/outputs/pure_unet_best.pth

Epoch 8/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.78it/s, loss=0.0930]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.37it/s]


Train Loss: 0.1031
Val Dice - Class 1: 0.7821, Class 2: 0.7951, Avg: 0.7886
Learning Rate: 0.001000

Epoch 9/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.49it/s, loss=0.1116]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.30it/s]


Train Loss: 0.1015
Val Dice - Class 1: 0.7980, Class 2: 0.8020, Avg: 0.8000
Learning Rate: 0.001000
✓ 保存最佳模型 (Dice: 0.8000) -> /workspace/outputs/pure_unet_best.pth

Epoch 10/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.72it/s, loss=0.0916]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.37it/s]


Train Loss: 0.1018
Val Dice - Class 1: 0.8194, Class 2: 0.8148, Avg: 0.8171
Learning Rate: 0.000500
✓ 保存最佳模型 (Dice: 0.8171) -> /workspace/outputs/pure_unet_best.pth

Epoch 11/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.60it/s, loss=0.0827]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.68it/s]


Train Loss: 0.0937
Val Dice - Class 1: 0.8369, Class 2: 0.8204, Avg: 0.8287
Learning Rate: 0.000500
✓ 保存最佳模型 (Dice: 0.8287) -> /workspace/outputs/pure_unet_best.pth

Epoch 12/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.12it/s, loss=0.0698]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.87it/s]


Train Loss: 0.0871
Val Dice - Class 1: 0.8187, Class 2: 0.8157, Avg: 0.8172
Learning Rate: 0.000500

Epoch 13/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.93it/s, loss=0.0780]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.87it/s]


Train Loss: 0.0887
Val Dice - Class 1: 0.8303, Class 2: 0.8223, Avg: 0.8263
Learning Rate: 0.000500

Epoch 14/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.62it/s, loss=0.0742]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.60it/s]


Train Loss: 0.0866
Val Dice - Class 1: 0.8296, Class 2: 0.8007, Avg: 0.8151
Learning Rate: 0.000500

Epoch 15/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.71it/s, loss=0.0687]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.58it/s]


Train Loss: 0.0854
Val Dice - Class 1: 0.8364, Class 2: 0.8212, Avg: 0.8288
Learning Rate: 0.000500
✓ 保存最佳模型 (Dice: 0.8288) -> /workspace/outputs/pure_unet_best.pth

Epoch 16/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.87it/s, loss=0.0768]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.56it/s]


Train Loss: 0.0855
Val Dice - Class 1: 0.8348, Class 2: 0.8211, Avg: 0.8280
Learning Rate: 0.000500

Epoch 17/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.78it/s, loss=0.0676]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.58it/s]


Train Loss: 0.0851
Val Dice - Class 1: 0.8375, Class 2: 0.8297, Avg: 0.8336
Learning Rate: 0.000500
✓ 保存最佳模型 (Dice: 0.8336) -> /workspace/outputs/pure_unet_best.pth

Epoch 18/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 32.15it/s, loss=0.0954]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.47it/s]


Train Loss: 0.0846
Val Dice - Class 1: 0.8441, Class 2: 0.8247, Avg: 0.8344
Learning Rate: 0.000500
✓ 保存最佳模型 (Dice: 0.8344) -> /workspace/outputs/pure_unet_best.pth

Epoch 19/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.88it/s, loss=0.0842]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.23it/s]


Train Loss: 0.0840
Val Dice - Class 1: 0.8414, Class 2: 0.8183, Avg: 0.8299
Learning Rate: 0.000500

Epoch 20/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 33.07it/s, loss=0.0814]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.42it/s]


Train Loss: 0.0829
Val Dice - Class 1: 0.8421, Class 2: 0.8304, Avg: 0.8363
Learning Rate: 0.000250
✓ 保存最佳模型 (Dice: 0.8363) -> /workspace/outputs/pure_unet_best.pth

Epoch 21/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 32.46it/s, loss=0.0545]
Validating: 100%|██████████| 26/26 [00:02<00:00, 10.39it/s]


Train Loss: 0.0822
Val Dice - Class 1: 0.8485, Class 2: 0.8302, Avg: 0.8394
Learning Rate: 0.000250
✓ 保存最佳模型 (Dice: 0.8394) -> /workspace/outputs/pure_unet_best.pth

Epoch 22/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.95it/s, loss=0.0890]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.95it/s]


Train Loss: 0.0802
Val Dice - Class 1: 0.8500, Class 2: 0.8312, Avg: 0.8406
Learning Rate: 0.000250
✓ 保存最佳模型 (Dice: 0.8406) -> /workspace/outputs/pure_unet_best.pth

Epoch 23/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 32.10it/s, loss=0.0727]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.69it/s]


Train Loss: 0.0788
Val Dice - Class 1: 0.8522, Class 2: 0.8319, Avg: 0.8421
Learning Rate: 0.000250
✓ 保存最佳模型 (Dice: 0.8421) -> /workspace/outputs/pure_unet_best.pth

Epoch 24/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.80it/s, loss=0.0785]
Validating: 100%|██████████| 26/26 [00:02<00:00, 10.40it/s]


Train Loss: 0.0794
Val Dice - Class 1: 0.8461, Class 2: 0.8333, Avg: 0.8397
Learning Rate: 0.000250

Epoch 25/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 32.27it/s, loss=0.0814]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.36it/s]


Train Loss: 0.0788
Val Dice - Class 1: 0.8473, Class 2: 0.8378, Avg: 0.8425
Learning Rate: 0.000250
✓ 保存最佳模型 (Dice: 0.8425) -> /workspace/outputs/pure_unet_best.pth

Epoch 26/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.75it/s, loss=0.0944]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.73it/s]


Train Loss: 0.0800
Val Dice - Class 1: 0.8395, Class 2: 0.8173, Avg: 0.8284
Learning Rate: 0.000250

Epoch 27/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.59it/s, loss=0.0681]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.96it/s]


Train Loss: 0.0790
Val Dice - Class 1: 0.8515, Class 2: 0.8339, Avg: 0.8427
Learning Rate: 0.000250
✓ 保存最佳模型 (Dice: 0.8427) -> /workspace/outputs/pure_unet_best.pth

Epoch 28/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.63it/s, loss=0.0790]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.43it/s]


Train Loss: 0.0781
Val Dice - Class 1: 0.8521, Class 2: 0.8295, Avg: 0.8408
Learning Rate: 0.000250

Epoch 29/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.60it/s, loss=0.0619]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.63it/s]


Train Loss: 0.0787
Val Dice - Class 1: 0.8498, Class 2: 0.8351, Avg: 0.8424
Learning Rate: 0.000250

Epoch 30/30
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:03<00:00, 31.48it/s, loss=0.0616]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.64it/s]

Train Loss: 0.0781
Val Dice - Class 1: 0.8457, Class 2: 0.8304, Avg: 0.8380
Learning Rate: 0.000125

訓練完成！
最佳驗證 Dice: 0.8427
模型已保存至: /workspace/outputs/pure_unet_best.pth





In [52]:
#!/usr/bin/env python3
"""
完整的 nnU-Net 實現
基於論文: nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation

核心特性:
- nnU-Net 架構（Leaky ReLU + Instance Norm）
- 深度監督（Deep Supervision）
- 數據增強（旋轉、縮放、彈性變形等）
- Dice + CE 組合損失
- Poly 學習率調度
- 5-fold 交叉驗證（簡化版）
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import nibabel as nib
from pathlib import Path
import random
from tqdm import tqdm

# 檢查 scipy 是否可用（不在這裡導入以避免版本衝突）
SCIPY_AVAILABLE = False
try:
    import scipy
    SCIPY_AVAILABLE = True
except:
    pass

if not SCIPY_AVAILABLE:
    print("警告: scipy 不可用，將使用簡化的數據增強")

# ==================== 設置隨機種子 ====================
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# ==================== nnU-Net 架構組件 ====================

class nnUNetConvBlock(nn.Module):
    """nnU-Net 卷積塊: Conv -> InstanceNorm -> LeakyReLU"""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
        self.norm = nn.InstanceNorm3d(out_channels, affine=True)
        self.activation = nn.LeakyReLU(negative_slope=0.01, inplace=True)
    
    def forward(self, x):
        return self.activation(self.norm(self.conv(x)))


class nnUNetResidualBlock(nn.Module):
    """nnU-Net 的雙卷積塊（類殘差結構）"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nnUNetConvBlock(in_channels, out_channels)
        self.conv2 = nnUNetConvBlock(out_channels, out_channels)
        
        # 如果通道數改變，需要 1x1 卷積調整
        self.skip = None
        if in_channels != out_channels:
            self.skip = nn.Conv3d(in_channels, out_channels, kernel_size=1)
    
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        
        if self.skip is not None:
            residual = self.skip(residual)
        
        return out + residual


class nnUNetDownsample(nn.Module):
    """nnU-Net 下採樣: Strided Convolution"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nnUNetConvBlock(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
    
    def forward(self, x):
        return self.conv(x)


class nnUNetUpsample(nn.Module):
    """nnU-Net 上採樣: Transposed Convolution"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
    
    def forward(self, x):
        return self.upconv(x)


class nnUNet(nn.Module):
    """
    完整的 nnU-Net 架構（3D）
    - 使用 Instance Normalization
    - 使用 Leaky ReLU
    - 使用 Strided Convolution 下採樣
    - 支持深度監督
    """
    def __init__(self, in_channels=1, num_classes=3, base_channels=32, num_pool=3, deep_supervision=True):
        super().__init__()
        
        self.num_pool = num_pool
        self.deep_supervision = deep_supervision
        
        # 編碼器
        self.encoders = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        
        current_channels = in_channels
        for i in range(num_pool + 1):
            out_channels = base_channels * (2 ** i)
            self.encoders.append(nnUNetResidualBlock(current_channels, out_channels))
            
            if i < num_pool:
                self.downsamplers.append(nnUNetDownsample(out_channels, out_channels))
            
            current_channels = out_channels
        
        # 解碼器
        self.upsamplers = nn.ModuleList()
        self.decoders = nn.ModuleList()
        
        for i in range(num_pool):
            in_ch = base_channels * (2 ** (num_pool - i))
            out_ch = base_channels * (2 ** (num_pool - i - 1))
            
            self.upsamplers.append(nnUNetUpsample(in_ch, out_ch))
            self.decoders.append(nnUNetResidualBlock(in_ch, out_ch))  # in_ch 因為有 skip connection
        
        # 輸出頭（多個用於深度監督）
        self.seg_outputs = nn.ModuleList()
        for i in range(num_pool + 1):
            out_ch = base_channels * (2 ** i) if i == num_pool else base_channels * (2 ** i)
            self.seg_outputs.append(nn.Conv3d(out_ch if i == 0 else base_channels * (2 ** i), num_classes, kernel_size=1))
    
    def forward(self, x):
        # 編碼路徑
        encoder_outputs = []
        current = x
        
        for i, encoder in enumerate(self.encoders):
            current = encoder(current)
            encoder_outputs.append(current)
            
            if i < self.num_pool:
                current = self.downsamplers[i](current)
        
        # 解碼路徑
        seg_outputs = []
        
        # 最深層的輸出（用於深度監督）
        if self.deep_supervision:
            seg_outputs.append(self.seg_outputs[-1](encoder_outputs[-1]))
        
        current = encoder_outputs[-1]
        
        for i in range(self.num_pool):
            # 上採樣
            current = self.upsamplers[i](current)
            
            # Skip connection
            skip = encoder_outputs[-(i + 2)]
            current = torch.cat([current, skip], dim=1)
            
            # 解碼塊
            current = self.decoders[i](current)
            
            # 深度監督輸出
            if self.deep_supervision:
                seg_outputs.append(self.seg_outputs[-(i + 2)](current))
        
        # 最終輸出
        final_output = self.seg_outputs[0](current) if not self.deep_supervision else seg_outputs[-1]
        
        if self.deep_supervision and self.training:
            # 反轉順序，從淺到深
            return list(reversed(seg_outputs))
        else:
            return final_output


# ==================== nnU-Net 損失函數 ====================

class nnUNetLoss(nn.Module):
    """
    nnU-Net 損失: Dice + CE
    支持深度監督
    """
    def __init__(self, deep_supervision_weights=None, dice_weight=1.0, ce_weight=1.0):
        super().__init__()
        self.deep_supervision_weights = deep_supervision_weights
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.ce_loss = nn.CrossEntropyLoss()
    
    def dice_loss(self, pred, target, smooth=1.0):
        """Soft Dice Loss"""
        pred = F.softmax(pred, dim=1)
        
        # 計算每個類別的 Dice
        dice_scores = []
        for c in range(pred.shape[1]):
            pred_c = pred[:, c]
            target_c = (target == c).float()
            
            intersection = (pred_c * target_c).sum()
            union = pred_c.sum() + target_c.sum()
            
            dice = (2.0 * intersection + smooth) / (union + smooth)
            dice_scores.append(dice)
        
        return 1.0 - torch.stack(dice_scores).mean()
    
    def forward(self, outputs, target):
        if isinstance(outputs, (list, tuple)):
            # 深度監督
            if self.deep_supervision_weights is None:
                # 默認權重：越深層權重越小
                weights = [1.0 / (2 ** i) for i in range(len(outputs))]
                weights = [w / sum(weights) for w in weights]
            else:
                weights = self.deep_supervision_weights
            
            total_loss = 0
            for i, output in enumerate(outputs):
                # 需要調整 target 大小以匹配輸出
                if output.shape[2:] != target.shape[1:]:
                    target_resized = F.interpolate(
                        target.unsqueeze(1).float(),
                        size=output.shape[2:],
                        mode='nearest'
                    ).squeeze(1).long()
                else:
                    target_resized = target
                
                ce = self.ce_loss(output, target_resized)
                dice = self.dice_loss(output, target_resized)
                total_loss += weights[i] * (self.ce_weight * ce + self.dice_weight * dice)
            
            return total_loss
        else:
            # 單一輸出
            ce = self.ce_loss(outputs, target)
            dice = self.dice_loss(outputs, target)
            return self.ce_weight * ce + self.dice_weight * dice


# ==================== nnU-Net 數據增強 ====================

class nnUNetAugmentation:
    """nnU-Net 風格的數據增強"""
    
    @staticmethod
    def random_rotation(image, label, angle_range=(-15, 15)):
        """隨機旋轉"""
        if not SCIPY_AVAILABLE:
            # 簡化版：只做 90 度旋轉
            if random.random() > 0.5:
                k = random.randint(1, 3)
                axes = random.choice([(0, 1), (0, 2), (1, 2)])
                image = np.rot90(image, k, axes).copy()
                label = np.rot90(label, k, axes).copy()
            return image, label
        
        if random.random() > 0.5:
            try:
                from scipy.ndimage import rotate
                angle = random.uniform(*angle_range)
                axes = random.choice([(0, 1), (0, 2), (1, 2)])
                image = rotate(image, angle, axes=axes, reshape=False, order=3, mode='constant')
                label = rotate(label, angle, axes=axes, reshape=False, order=0, mode='constant')
            except:
                # 如果導入失敗，使用 90 度旋轉
                k = random.randint(1, 3)
                axes = random.choice([(0, 1), (0, 2), (1, 2)])
                image = np.rot90(image, k, axes).copy()
                label = np.rot90(label, k, axes).copy()
        return image, label
    
    @staticmethod
    def random_scaling(image, label, scale_range=(0.85, 1.25)):
        """隨機縮放"""
        if not SCIPY_AVAILABLE:
            return image, label
        
        if random.random() > 0.5:
            try:
                from scipy.ndimage import zoom
                scale = random.uniform(*scale_range)
                scales = [scale] * 3
                image = zoom(image, scales, order=3, mode='constant')
                label = zoom(label, scales, order=0, mode='constant')
            except:
                pass  # 如果失敗就跳過縮放
        return image, label
    
    @staticmethod
    def random_elastic_deformation(image, label, alpha=100, sigma=10):
        """彈性變形"""
        if not SCIPY_AVAILABLE:
            return image, label
        
        if random.random() > 0.3:  # 30% 機率
            try:
                from scipy.ndimage import gaussian_filter, map_coordinates
                shape = image.shape
                
                # 生成隨機位移場
                dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma) * alpha
                dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma) * alpha
                dz = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma) * alpha
                
                x, y, z = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]), indexing='ij')
                indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1)), np.reshape(z + dz, (-1, 1))
                
                image = map_coordinates(image, indices, order=3, mode='reflect').reshape(shape)
                label = map_coordinates(label, indices, order=0, mode='reflect').reshape(shape)
            except:
                pass  # 如果失敗就跳過彈性變形
        
        return image, label
    
    @staticmethod
    def random_gamma(image, gamma_range=(0.7, 1.5)):
        """隨機 Gamma 校正"""
        if random.random() > 0.5:
            gamma = random.uniform(*gamma_range)
            image_min = image.min()
            image_range = image.max() - image_min
            if image_range > 0:
                image = ((image - image_min) / image_range) ** gamma * image_range + image_min
        return image
    
    @staticmethod
    def random_brightness(image, brightness_range=(-0.2, 0.2)):
        """隨機亮度調整"""
        if random.random() > 0.5:
            brightness = random.uniform(*brightness_range)
            image = image + brightness * image.std()
        return image
    
    @staticmethod
    def random_contrast(image, contrast_range=(0.75, 1.25)):
        """隨機對比度調整"""
        if random.random() > 0.5:
            contrast = random.uniform(*contrast_range)
            mean = image.mean()
            image = (image - mean) * contrast + mean
        return image
    
    @staticmethod
    def random_flip(image, label):
        """隨機翻轉"""
        for axis in range(3):
            if random.random() > 0.5:
                image = np.flip(image, axis=axis).copy()
                label = np.flip(label, axis=axis).copy()
        return image, label


# ==================== nnU-Net 數據集 ====================

class nnUNetDataset(Dataset):
    """nnU-Net 風格的數據集"""
    def __init__(self, data_dir, target_size=64, is_train=True, use_augmentation=True):
        self.data_dir = Path(data_dir)
        self.target_size = target_size
        self.is_train = is_train
        self.use_augmentation = use_augmentation and is_train
        
        # 獲取所有圖像文件（過濾掉 macOS 隱藏文件）
        all_image_files = list((self.data_dir / 'imagesTr').glob('*.nii.gz'))
        self.image_files = sorted([f for f in all_image_files if not f.name.startswith('._')])
        
        all_label_files = list((self.data_dir / 'labelsTr').glob('*.nii.gz'))
        self.label_files = sorted([f for f in all_label_files if not f.name.startswith('._')])
        
        if len(self.image_files) == 0:
            raise ValueError(f"在 {self.data_dir / 'imagesTr'} 找不到任何 .nii.gz 文件！")
        
        print(f"找到 {len(self.image_files)} 個訓練樣本")
        
        self.aug = nnUNetAugmentation()
    
    def __len__(self):
        return len(self.image_files)
    
    def preprocess(self, image):
        """nnU-Net 風格預處理"""
        # Clip to percentiles
        p1, p99 = np.percentile(image[image > 0], [0.5, 99.5]) if (image > 0).any() else (0, 1)
        image = np.clip(image, p1, p99)
        
        # Z-score normalization (per image)
        mean = image[image > 0].mean() if (image > 0).any() else 0
        std = image[image > 0].std() if (image > 0).any() else 1
        image = (image - mean) / (std + 1e-8)
        
        return image
    
    def apply_augmentation(self, image, label):
        """應用 nnU-Net 數據增強"""
        # 幾何變換
        image, label = self.aug.random_rotation(image, label)
        image, label = self.aug.random_scaling(image, label)
        image, label = self.aug.random_flip(image, label)
        
        # 強度變換（僅對圖像）
        image = self.aug.random_gamma(image)
        image = self.aug.random_brightness(image)
        image = self.aug.random_contrast(image)
        
        return image, label
    
    def __getitem__(self, idx):
        # 讀取文件
        image = nib.load(self.image_files[idx]).get_fdata(dtype=np.float32)
        label = nib.load(self.label_files[idx]).get_fdata(dtype=np.float32)
        
        # 預處理
        image = self.preprocess(image)
        
        # 數據增強
        if self.use_augmentation:
            image, label = self.apply_augmentation(image, label)
        
        # 轉為 tensor 並調整大小
        image = torch.from_numpy(image).unsqueeze(0).unsqueeze(0)  # [1, 1, H, W, D]
        label = torch.from_numpy(label).unsqueeze(0).unsqueeze(0)  # [1, 1, H, W, D]
        
        image = F.interpolate(
            image,
            size=(self.target_size, self.target_size, self.target_size),
            mode='trilinear',
            align_corners=False
        ).squeeze(0)
        
        label = F.interpolate(
            label,
            size=(self.target_size, self.target_size, self.target_size),
            mode='nearest'
        ).squeeze(0)
        
        label = label.squeeze(0).long()
        label = torch.clamp(label, 0, 2)
        
        return image, label


# ==================== 評估指標 ====================

def compute_dice(pred, target, num_classes=3):
    """計算 Dice 分數"""
    dice_scores = []
    
    for c in range(1, num_classes):  # 跳過背景
        pred_c = (pred == c)
        target_c = (target == c)
        
        intersection = (pred_c & target_c).sum().float()
        union = pred_c.sum().float() + target_c.sum().float()
        
        if union == 0:
            dice = 1.0 if intersection == 0 else 0.0
        else:
            dice = (2.0 * intersection) / union
        
        dice_scores.append(dice.item())
    
    return dice_scores


# ==================== Poly 學習率調度器 ====================

class PolynomialLRScheduler:
    """nnU-Net 使用的 Polynomial 學習率調度"""
    def __init__(self, optimizer, initial_lr, max_epochs, power=0.9):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.max_epochs = max_epochs
        self.power = power
        self.current_epoch = 0
    
    def step(self):
        self.current_epoch += 1
        lr = self.initial_lr * (1 - self.current_epoch / self.max_epochs) ** self.power
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr


# ==================== 訓練函數 ====================

def train_epoch(model, loader, criterion, optimizer, device):
    """訓練一個 epoch"""
    model.train()
    total_loss = 0
    
    pbar = tqdm(loader, desc='Training')
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)
        
        # 前向傳播
        outputs = model(images)
        
        # 計算損失（支持深度監督）
        loss = criterion(outputs, labels)
        
        # 反向傳播
        optimizer.zero_grad()
        loss.backward()
        
        # 梯度裁剪（nnU-Net 使用）
        torch.nn.utils.clip_grad_norm_(model.parameters(), 12)
        
        optimizer.step()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(loader)


def validate(model, loader, device):
    """驗證"""
    model.eval()
    all_dice_scores = []
    
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validating'):
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            
            # 如果是深度監督，只取最終輸出
            if isinstance(outputs, (list, tuple)):
                outputs = outputs[-1]
            
            preds = outputs.argmax(dim=1)
            
            # 計算 Dice
            for pred, label in zip(preds, labels):
                dice_scores = compute_dice(pred.cpu(), label.cpu())
                all_dice_scores.append(dice_scores)
    
    # 平均 Dice
    all_dice_scores = np.array(all_dice_scores)
    mean_dice = all_dice_scores.mean(axis=0)
    
    return mean_dice


# ==================== 自動檢測數據路徑 ====================

def find_data_directory():
    """自動查找數據目錄"""
    possible_paths = [
        '/home/claude/Task04_Hippocampus',
        '/workspace/data/Task04_Hippocampus',
        './Task04_Hippocampus',
        './data/Task04_Hippocampus',
        '../data/Task04_Hippocampus',
        '/data/Task04_Hippocampus',
    ]
    
    for path in possible_paths:
        if Path(path).exists():
            images_dir = Path(path) / 'imagesTr'
            if images_dir.exists():
                # 過濾掉 macOS 隱藏文件
                image_files = [f for f in images_dir.glob('*.nii.gz') if not f.name.startswith('._')]
                if len(image_files) > 0:
                    print(f"✓ 找到數據目錄: {path}")
                    print(f"  ({len(image_files)} 個有效文件)")
                    return path
    
    print("\n❌ 找不到數據目錄！")
    return None


# ==================== 主訓練流程 ====================

def main():
    # 自動查找數據目錄
    data_dir = find_data_directory()
    
    if data_dir is None:
        print("\n" + "=" * 60)
        print("錯誤：找不到數據集！")
        print("=" * 60)
        return
    
    # nnU-Net 配置
    config = {
        'data_dir': data_dir,
        'batch_size': 2,
        'num_epochs': 100,  # nnU-Net 通常訓練更長時間
        'initial_lr': 1e-2,  # nnU-Net 使用較大的初始學習率
        'base_channels': 32,  # nnU-Net 使用更多通道
        'num_pool': 3,  # 下採樣層數
        'target_size': 64,
        'deep_supervision': True,
        'use_augmentation': True,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu'
    }
    
    print("\n" + "=" * 60)
    print("完整的 nnU-Net 訓練")
    print("=" * 60)
    print(f"配置:")
    for key, value in config.items():
        print(f"  {key}: {value}")
    print()
    print(f"SciPy 可用: {SCIPY_AVAILABLE}")
    if not SCIPY_AVAILABLE:
        print("  ⚠️  將使用簡化的數據增強（僅翻轉和 90° 旋轉）")
    print()
    
    device = torch.device(config['device'])
    
    # 創建數據集
    print("準備數據集...")
    try:
        full_dataset = nnUNetDataset(
            config['data_dir'],
            target_size=config['target_size'],
            is_train=True,
            use_augmentation=config['use_augmentation']
        )
    except ValueError as e:
        print(f"\n錯誤: {e}")
        return
    
    # 訓練/驗證劃分 (80/20)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, 
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    print(f"訓練集: {len(train_dataset)} 樣本")
    print(f"驗證集: {len(val_dataset)} 樣本")
    print()
    
    # 創建 nnU-Net 模型
    print("創建 nnU-Net 模型...")
    model = nnUNet(
        in_channels=1,
        num_classes=3,
        base_channels=config['base_channels'],
        num_pool=config['num_pool'],
        deep_supervision=config['deep_supervision']
    ).to(device)
    
    # 計算參數量
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型參數量: {total_params:,} ({total_params/1e6:.2f}M)")
    print()
    
    # 損失函數和優化器
    criterion = nnUNetLoss(dice_weight=1.0, ce_weight=1.0)
    
    # nnU-Net 使用 SGD with momentum 和 Nesterov
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config['initial_lr'],
        momentum=0.99,
        weight_decay=3e-5,
        nesterov=True
    )
    
    # Polynomial 學習率調度
    scheduler = PolynomialLRScheduler(
        optimizer,
        initial_lr=config['initial_lr'],
        max_epochs=config['num_epochs'],
        power=0.9
    )
    
    # 訓練循環
    print("開始訓練...")
    print(f"使用深度監督: {config['deep_supervision']}")
    print(f"使用數據增強: {config['use_augmentation']}")
    print()
    
    best_dice = 0.0
    
    for epoch in range(config['num_epochs']):
        print(f"Epoch {epoch+1}/{config['num_epochs']}")
        print("-" * 60)
        
        # 訓練
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        
        # 驗證
        val_dice = validate(model, val_loader, device)
        
        # 學習率調度
        current_lr = scheduler.step()
        
        # 打印結果
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Dice - Class 1: {val_dice[0]:.4f}, Class 2: {val_dice[1]:.4f}, "
              f"Avg: {val_dice.mean():.4f}")
        print(f"Learning Rate: {current_lr:.6f}")
        
        # 保存最佳模型
        if val_dice.mean() > best_dice:
            best_dice = val_dice.mean()
            
            # 確保輸出目錄存在
            output_dir = Path('/workspace/outputs')
            if not output_dir.exists():
                output_dir = Path('./outputs')
                output_dir.mkdir(exist_ok=True)
            
            save_path = output_dir / 'nnunet_best.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'dice': val_dice,
                'config': config
            }, save_path)
            print(f"✓ 保存最佳模型 (Dice: {best_dice:.4f}) -> {save_path}")
        
        print()
    
    print("=" * 60)
    print("訓練完成！")
    print(f"最佳驗證 Dice: {best_dice:.4f}")
    print(f"模型已保存至: {save_path}")
    print("=" * 60)


if __name__ == '__main__':
    main()


✓ 找到數據目錄: ./Task04_Hippocampus
  (260 個有效文件)

完整的 nnU-Net 訓練
配置:
  data_dir: ./Task04_Hippocampus
  batch_size: 2
  num_epochs: 100
  initial_lr: 0.01
  base_channels: 32
  num_pool: 3
  target_size: 64
  deep_supervision: True
  use_augmentation: True
  device: cuda

SciPy 可用: True

準備數據集...
找到 260 個訓練樣本
訓練集: 208 樣本
驗證集: 52 樣本

創建 nnU-Net 模型...
模型參數量: 6,271,980 (6.27M)

開始訓練...
使用深度監督: True
使用數據增強: True

Epoch 1/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.32it/s, loss=0.4146]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.90it/s]


Train Loss: 0.6857
Val Dice - Class 1: 0.6209, Class 2: 0.6890, Avg: 0.6549
Learning Rate: 0.009910
✓ 保存最佳模型 (Dice: 0.6549) -> /workspace/outputs/nnunet_best.pth

Epoch 2/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.10it/s, loss=0.2751]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.47it/s]


Train Loss: 0.3103
Val Dice - Class 1: 0.7304, Class 2: 0.7598, Avg: 0.7451
Learning Rate: 0.009820
✓ 保存最佳模型 (Dice: 0.7451) -> /workspace/outputs/nnunet_best.pth

Epoch 3/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.04it/s, loss=0.2592]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.41it/s]


Train Loss: 0.2503
Val Dice - Class 1: 0.7698, Class 2: 0.7791, Avg: 0.7745
Learning Rate: 0.009730
✓ 保存最佳模型 (Dice: 0.7745) -> /workspace/outputs/nnunet_best.pth

Epoch 4/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.36it/s, loss=0.2159]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.09it/s]


Train Loss: 0.2270
Val Dice - Class 1: 0.7947, Class 2: 0.7938, Avg: 0.7942
Learning Rate: 0.009639
✓ 保存最佳模型 (Dice: 0.7942) -> /workspace/outputs/nnunet_best.pth

Epoch 5/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.80it/s, loss=0.1724]
Validating: 100%|██████████| 26/26 [00:02<00:00, 10.13it/s]


Train Loss: 0.2143
Val Dice - Class 1: 0.7947, Class 2: 0.7970, Avg: 0.7959
Learning Rate: 0.009549
✓ 保存最佳模型 (Dice: 0.7959) -> /workspace/outputs/nnunet_best.pth

Epoch 6/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 11.98it/s, loss=0.2008]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.28it/s]


Train Loss: 0.2040
Val Dice - Class 1: 0.8088, Class 2: 0.7982, Avg: 0.8035
Learning Rate: 0.009458
✓ 保存最佳模型 (Dice: 0.8035) -> /workspace/outputs/nnunet_best.pth

Epoch 7/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.26it/s, loss=0.1725]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.31it/s]


Train Loss: 0.1973
Val Dice - Class 1: 0.8149, Class 2: 0.8051, Avg: 0.8100
Learning Rate: 0.009368
✓ 保存最佳模型 (Dice: 0.8100) -> /workspace/outputs/nnunet_best.pth

Epoch 8/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.00it/s, loss=0.2426]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.65it/s]


Train Loss: 0.1978
Val Dice - Class 1: 0.8212, Class 2: 0.8132, Avg: 0.8172
Learning Rate: 0.009277
✓ 保存最佳模型 (Dice: 0.8172) -> /workspace/outputs/nnunet_best.pth

Epoch 9/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.02it/s, loss=0.1516]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.63it/s]


Train Loss: 0.1893
Val Dice - Class 1: 0.8189, Class 2: 0.8140, Avg: 0.8165
Learning Rate: 0.009186

Epoch 10/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.05it/s, loss=0.1726]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.06it/s]


Train Loss: 0.1846
Val Dice - Class 1: 0.8244, Class 2: 0.8150, Avg: 0.8197
Learning Rate: 0.009095
✓ 保存最佳模型 (Dice: 0.8197) -> /workspace/outputs/nnunet_best.pth

Epoch 11/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 11.96it/s, loss=0.1821]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.33it/s]


Train Loss: 0.1814
Val Dice - Class 1: 0.8288, Class 2: 0.8192, Avg: 0.8240
Learning Rate: 0.009004
✓ 保存最佳模型 (Dice: 0.8240) -> /workspace/outputs/nnunet_best.pth

Epoch 12/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.16it/s, loss=0.2224]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.14it/s]


Train Loss: 0.1811
Val Dice - Class 1: 0.8264, Class 2: 0.8224, Avg: 0.8244
Learning Rate: 0.008913
✓ 保存最佳模型 (Dice: 0.8244) -> /workspace/outputs/nnunet_best.pth

Epoch 13/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.04it/s, loss=0.1965]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.41it/s]


Train Loss: 0.1771
Val Dice - Class 1: 0.8279, Class 2: 0.8143, Avg: 0.8211
Learning Rate: 0.008822

Epoch 14/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 11.92it/s, loss=0.2030]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.66it/s]


Train Loss: 0.1742
Val Dice - Class 1: 0.8345, Class 2: 0.8236, Avg: 0.8291
Learning Rate: 0.008731
✓ 保存最佳模型 (Dice: 0.8291) -> /workspace/outputs/nnunet_best.pth

Epoch 15/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.75it/s, loss=0.2407]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.18it/s]


Train Loss: 0.1748
Val Dice - Class 1: 0.8323, Class 2: 0.8258, Avg: 0.8291
Learning Rate: 0.008639
✓ 保存最佳模型 (Dice: 0.8291) -> /workspace/outputs/nnunet_best.pth

Epoch 16/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 11.77it/s, loss=0.1472]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.57it/s]


Train Loss: 0.1735
Val Dice - Class 1: 0.8361, Class 2: 0.8200, Avg: 0.8280
Learning Rate: 0.008548

Epoch 17/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.26it/s, loss=0.1576]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.32it/s]


Train Loss: 0.1712
Val Dice - Class 1: 0.8350, Class 2: 0.8208, Avg: 0.8279
Learning Rate: 0.008456

Epoch 18/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.35it/s, loss=0.1474]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.22it/s]


Train Loss: 0.1716
Val Dice - Class 1: 0.8368, Class 2: 0.8222, Avg: 0.8295
Learning Rate: 0.008364
✓ 保存最佳模型 (Dice: 0.8295) -> /workspace/outputs/nnunet_best.pth

Epoch 19/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.08it/s, loss=0.1765]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.59it/s]


Train Loss: 0.1670
Val Dice - Class 1: 0.8444, Class 2: 0.8275, Avg: 0.8360
Learning Rate: 0.008272
✓ 保存最佳模型 (Dice: 0.8360) -> /workspace/outputs/nnunet_best.pth

Epoch 20/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.15it/s, loss=0.1579]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.11it/s]


Train Loss: 0.1681
Val Dice - Class 1: 0.8478, Class 2: 0.8264, Avg: 0.8371
Learning Rate: 0.008181
✓ 保存最佳模型 (Dice: 0.8371) -> /workspace/outputs/nnunet_best.pth

Epoch 21/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 11.94it/s, loss=0.1647]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.34it/s]


Train Loss: 0.1652
Val Dice - Class 1: 0.8454, Class 2: 0.8291, Avg: 0.8373
Learning Rate: 0.008088
✓ 保存最佳模型 (Dice: 0.8373) -> /workspace/outputs/nnunet_best.pth

Epoch 22/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.24it/s, loss=0.1386]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.31it/s]


Train Loss: 0.1638
Val Dice - Class 1: 0.8425, Class 2: 0.8311, Avg: 0.8368
Learning Rate: 0.007996

Epoch 23/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.26it/s, loss=0.1782]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.60it/s]


Train Loss: 0.1660
Val Dice - Class 1: 0.8401, Class 2: 0.8116, Avg: 0.8258
Learning Rate: 0.007904

Epoch 24/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.07it/s, loss=0.1542]
Validating: 100%|██████████| 26/26 [00:02<00:00, 10.45it/s]


Train Loss: 0.1645
Val Dice - Class 1: 0.8443, Class 2: 0.8306, Avg: 0.8375
Learning Rate: 0.007811
✓ 保存最佳模型 (Dice: 0.8375) -> /workspace/outputs/nnunet_best.pth

Epoch 25/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.00it/s, loss=0.1252]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.55it/s]


Train Loss: 0.1614
Val Dice - Class 1: 0.8459, Class 2: 0.8310, Avg: 0.8385
Learning Rate: 0.007719
✓ 保存最佳模型 (Dice: 0.8385) -> /workspace/outputs/nnunet_best.pth

Epoch 26/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 11.94it/s, loss=0.1834]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.47it/s]


Train Loss: 0.1612
Val Dice - Class 1: 0.8481, Class 2: 0.8314, Avg: 0.8398
Learning Rate: 0.007626
✓ 保存最佳模型 (Dice: 0.8398) -> /workspace/outputs/nnunet_best.pth

Epoch 27/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.14it/s, loss=0.1442]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.92it/s]


Train Loss: 0.1603
Val Dice - Class 1: 0.8492, Class 2: 0.8307, Avg: 0.8400
Learning Rate: 0.007533
✓ 保存最佳模型 (Dice: 0.8400) -> /workspace/outputs/nnunet_best.pth

Epoch 28/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.29it/s, loss=0.1483]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.53it/s]


Train Loss: 0.1601
Val Dice - Class 1: 0.8448, Class 2: 0.8294, Avg: 0.8371
Learning Rate: 0.007440

Epoch 29/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.79it/s, loss=0.2072]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.53it/s]


Train Loss: 0.1596
Val Dice - Class 1: 0.8522, Class 2: 0.8333, Avg: 0.8428
Learning Rate: 0.007347
✓ 保存最佳模型 (Dice: 0.8428) -> /workspace/outputs/nnunet_best.pth

Epoch 30/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.06it/s, loss=0.1859]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.71it/s]


Train Loss: 0.1580
Val Dice - Class 1: 0.8488, Class 2: 0.8340, Avg: 0.8414
Learning Rate: 0.007254

Epoch 31/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.00it/s, loss=0.1590]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.44it/s]


Train Loss: 0.1580
Val Dice - Class 1: 0.8451, Class 2: 0.8334, Avg: 0.8392
Learning Rate: 0.007161

Epoch 32/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 11.96it/s, loss=0.1449]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.05it/s]


Train Loss: 0.1546
Val Dice - Class 1: 0.8488, Class 2: 0.8361, Avg: 0.8425
Learning Rate: 0.007067

Epoch 33/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.09it/s, loss=0.1603]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.44it/s]


Train Loss: 0.1556
Val Dice - Class 1: 0.8513, Class 2: 0.8245, Avg: 0.8379
Learning Rate: 0.006974

Epoch 34/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.87it/s, loss=0.1641]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.36it/s]


Train Loss: 0.1559
Val Dice - Class 1: 0.8461, Class 2: 0.8340, Avg: 0.8401
Learning Rate: 0.006880

Epoch 35/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.07it/s, loss=0.1296]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.20it/s]


Train Loss: 0.1549
Val Dice - Class 1: 0.8479, Class 2: 0.8298, Avg: 0.8389
Learning Rate: 0.006786

Epoch 36/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.03it/s, loss=0.2202]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.99it/s]


Train Loss: 0.1539
Val Dice - Class 1: 0.8507, Class 2: 0.8348, Avg: 0.8428
Learning Rate: 0.006692

Epoch 37/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.25it/s, loss=0.1766]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.51it/s]


Train Loss: 0.1522
Val Dice - Class 1: 0.8551, Class 2: 0.8332, Avg: 0.8442
Learning Rate: 0.006598
✓ 保存最佳模型 (Dice: 0.8442) -> /workspace/outputs/nnunet_best.pth

Epoch 38/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 11.98it/s, loss=0.1441]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.17it/s]


Train Loss: 0.1537
Val Dice - Class 1: 0.8523, Class 2: 0.8354, Avg: 0.8438
Learning Rate: 0.006504

Epoch 39/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.20it/s, loss=0.1755]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.48it/s]


Train Loss: 0.1539
Val Dice - Class 1: 0.8567, Class 2: 0.8417, Avg: 0.8492
Learning Rate: 0.006409
✓ 保存最佳模型 (Dice: 0.8492) -> /workspace/outputs/nnunet_best.pth

Epoch 40/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.18it/s, loss=0.1352]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.60it/s]


Train Loss: 0.1536
Val Dice - Class 1: 0.8543, Class 2: 0.8372, Avg: 0.8457
Learning Rate: 0.006314

Epoch 41/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.82it/s, loss=0.1327]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.01it/s]


Train Loss: 0.1513
Val Dice - Class 1: 0.8543, Class 2: 0.8347, Avg: 0.8445
Learning Rate: 0.006220

Epoch 42/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.32it/s, loss=0.1674]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.76it/s]


Train Loss: 0.1514
Val Dice - Class 1: 0.8550, Class 2: 0.8339, Avg: 0.8444
Learning Rate: 0.006125

Epoch 43/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.43it/s, loss=0.1328]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.22it/s]


Train Loss: 0.1482
Val Dice - Class 1: 0.8548, Class 2: 0.8403, Avg: 0.8475
Learning Rate: 0.006030

Epoch 44/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.85it/s, loss=0.1863]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.79it/s]


Train Loss: 0.1483
Val Dice - Class 1: 0.8525, Class 2: 0.8332, Avg: 0.8429
Learning Rate: 0.005934

Epoch 45/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.22it/s, loss=0.1644]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.14it/s]


Train Loss: 0.1530
Val Dice - Class 1: 0.8588, Class 2: 0.8406, Avg: 0.8497
Learning Rate: 0.005839
✓ 保存最佳模型 (Dice: 0.8497) -> /workspace/outputs/nnunet_best.pth

Epoch 46/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.24it/s, loss=0.1548]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.34it/s]


Train Loss: 0.1492
Val Dice - Class 1: 0.8523, Class 2: 0.8349, Avg: 0.8436
Learning Rate: 0.005743

Epoch 47/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.38it/s, loss=0.1297]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.45it/s]


Train Loss: 0.1490
Val Dice - Class 1: 0.8564, Class 2: 0.8418, Avg: 0.8491
Learning Rate: 0.005647

Epoch 48/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.35it/s, loss=0.1354]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.71it/s]


Train Loss: 0.1483
Val Dice - Class 1: 0.8607, Class 2: 0.8427, Avg: 0.8517
Learning Rate: 0.005551
✓ 保存最佳模型 (Dice: 0.8517) -> /workspace/outputs/nnunet_best.pth

Epoch 49/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.44it/s, loss=0.1453]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.58it/s]


Train Loss: 0.1478
Val Dice - Class 1: 0.8569, Class 2: 0.8390, Avg: 0.8480
Learning Rate: 0.005455

Epoch 50/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.57it/s, loss=0.1586]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.70it/s]


Train Loss: 0.1470
Val Dice - Class 1: 0.8583, Class 2: 0.8379, Avg: 0.8481
Learning Rate: 0.005359

Epoch 51/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.53it/s, loss=0.1603]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.36it/s]


Train Loss: 0.1463
Val Dice - Class 1: 0.8585, Class 2: 0.8399, Avg: 0.8492
Learning Rate: 0.005262

Epoch 52/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.33it/s, loss=0.1127]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.59it/s]


Train Loss: 0.1447
Val Dice - Class 1: 0.8538, Class 2: 0.8376, Avg: 0.8457
Learning Rate: 0.005166

Epoch 53/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.99it/s, loss=0.1734]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.22it/s]


Train Loss: 0.1477
Val Dice - Class 1: 0.8584, Class 2: 0.8367, Avg: 0.8476
Learning Rate: 0.005069

Epoch 54/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.32it/s, loss=0.1253]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.84it/s]


Train Loss: 0.1459
Val Dice - Class 1: 0.8600, Class 2: 0.8421, Avg: 0.8511
Learning Rate: 0.004971

Epoch 55/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.35it/s, loss=0.1218]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.48it/s]


Train Loss: 0.1436
Val Dice - Class 1: 0.8606, Class 2: 0.8416, Avg: 0.8511
Learning Rate: 0.004874

Epoch 56/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.42it/s, loss=0.1137]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.63it/s]


Train Loss: 0.1451
Val Dice - Class 1: 0.8624, Class 2: 0.8420, Avg: 0.8522
Learning Rate: 0.004776
✓ 保存最佳模型 (Dice: 0.8522) -> /workspace/outputs/nnunet_best.pth

Epoch 57/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.56it/s, loss=0.1239]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.43it/s]


Train Loss: 0.1435
Val Dice - Class 1: 0.8620, Class 2: 0.8462, Avg: 0.8541
Learning Rate: 0.004679
✓ 保存最佳模型 (Dice: 0.8541) -> /workspace/outputs/nnunet_best.pth

Epoch 58/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.44it/s, loss=0.1559]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.00it/s]


Train Loss: 0.1428
Val Dice - Class 1: 0.8613, Class 2: 0.8456, Avg: 0.8534
Learning Rate: 0.004581

Epoch 59/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.52it/s, loss=0.1973]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.49it/s]


Train Loss: 0.1412
Val Dice - Class 1: 0.8644, Class 2: 0.8429, Avg: 0.8537
Learning Rate: 0.004482

Epoch 60/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.39it/s, loss=0.1234]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.63it/s]


Train Loss: 0.1415
Val Dice - Class 1: 0.8626, Class 2: 0.8445, Avg: 0.8536
Learning Rate: 0.004384

Epoch 61/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.23it/s, loss=0.1289]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.40it/s]


Train Loss: 0.1424
Val Dice - Class 1: 0.8638, Class 2: 0.8433, Avg: 0.8536
Learning Rate: 0.004285

Epoch 62/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.59it/s, loss=0.1293]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.50it/s]


Train Loss: 0.1417
Val Dice - Class 1: 0.8601, Class 2: 0.8406, Avg: 0.8503
Learning Rate: 0.004186

Epoch 63/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:07<00:00, 13.11it/s, loss=0.1834]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.27it/s]


Train Loss: 0.1413
Val Dice - Class 1: 0.8624, Class 2: 0.8444, Avg: 0.8534
Learning Rate: 0.004087

Epoch 64/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.51it/s, loss=0.2080]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.26it/s]


Train Loss: 0.1429
Val Dice - Class 1: 0.8612, Class 2: 0.8396, Avg: 0.8504
Learning Rate: 0.003987

Epoch 65/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.43it/s, loss=0.1526]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.51it/s]


Train Loss: 0.1417
Val Dice - Class 1: 0.8628, Class 2: 0.8398, Avg: 0.8513
Learning Rate: 0.003887

Epoch 66/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.21it/s, loss=0.1415]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.26it/s]


Train Loss: 0.1399
Val Dice - Class 1: 0.8626, Class 2: 0.8468, Avg: 0.8547
Learning Rate: 0.003787
✓ 保存最佳模型 (Dice: 0.8547) -> /workspace/outputs/nnunet_best.pth

Epoch 67/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.23it/s, loss=0.1230]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.40it/s]


Train Loss: 0.1408
Val Dice - Class 1: 0.8664, Class 2: 0.8456, Avg: 0.8560
Learning Rate: 0.003687
✓ 保存最佳模型 (Dice: 0.8560) -> /workspace/outputs/nnunet_best.pth

Epoch 68/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.07it/s, loss=0.1630]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.07it/s]


Train Loss: 0.1390
Val Dice - Class 1: 0.8612, Class 2: 0.8405, Avg: 0.8509
Learning Rate: 0.003586

Epoch 69/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.03it/s, loss=0.1194]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.82it/s]


Train Loss: 0.1395
Val Dice - Class 1: 0.8620, Class 2: 0.8405, Avg: 0.8513
Learning Rate: 0.003485

Epoch 70/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.48it/s, loss=0.1557]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.54it/s]


Train Loss: 0.1381
Val Dice - Class 1: 0.8627, Class 2: 0.8463, Avg: 0.8545
Learning Rate: 0.003384

Epoch 71/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.42it/s, loss=0.1368]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.34it/s]


Train Loss: 0.1368
Val Dice - Class 1: 0.8638, Class 2: 0.8444, Avg: 0.8541
Learning Rate: 0.003282

Epoch 72/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.18it/s, loss=0.1579]
Validating: 100%|██████████| 26/26 [00:02<00:00, 10.59it/s]


Train Loss: 0.1357
Val Dice - Class 1: 0.8651, Class 2: 0.8460, Avg: 0.8555
Learning Rate: 0.003180

Epoch 73/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:07<00:00, 13.03it/s, loss=0.1581]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.95it/s]


Train Loss: 0.1381
Val Dice - Class 1: 0.8636, Class 2: 0.8458, Avg: 0.8547
Learning Rate: 0.003078

Epoch 74/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.19it/s, loss=0.1365]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.35it/s]


Train Loss: 0.1360
Val Dice - Class 1: 0.8617, Class 2: 0.8452, Avg: 0.8534
Learning Rate: 0.002975

Epoch 75/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.46it/s, loss=0.1247]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.51it/s]


Train Loss: 0.1350
Val Dice - Class 1: 0.8646, Class 2: 0.8478, Avg: 0.8562
Learning Rate: 0.002872
✓ 保存最佳模型 (Dice: 0.8562) -> /workspace/outputs/nnunet_best.pth

Epoch 76/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.42it/s, loss=0.1273]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.23it/s]


Train Loss: 0.1361
Val Dice - Class 1: 0.8615, Class 2: 0.8451, Avg: 0.8533
Learning Rate: 0.002768

Epoch 77/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.17it/s, loss=0.1342]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.57it/s]


Train Loss: 0.1347
Val Dice - Class 1: 0.8652, Class 2: 0.8408, Avg: 0.8530
Learning Rate: 0.002664

Epoch 78/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.46it/s, loss=0.1157]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.52it/s]


Train Loss: 0.1357
Val Dice - Class 1: 0.8638, Class 2: 0.8475, Avg: 0.8556
Learning Rate: 0.002560

Epoch 79/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.75it/s, loss=0.1491]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.74it/s]


Train Loss: 0.1337
Val Dice - Class 1: 0.8668, Class 2: 0.8490, Avg: 0.8579
Learning Rate: 0.002455
✓ 保存最佳模型 (Dice: 0.8579) -> /workspace/outputs/nnunet_best.pth

Epoch 80/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.32it/s, loss=0.1278]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.44it/s]


Train Loss: 0.1342
Val Dice - Class 1: 0.8679, Class 2: 0.8488, Avg: 0.8584
Learning Rate: 0.002349
✓ 保存最佳模型 (Dice: 0.8584) -> /workspace/outputs/nnunet_best.pth

Epoch 81/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.13it/s, loss=0.1451]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.39it/s]


Train Loss: 0.1341
Val Dice - Class 1: 0.8659, Class 2: 0.8472, Avg: 0.8565
Learning Rate: 0.002243

Epoch 82/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.79it/s, loss=0.1194]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.23it/s]


Train Loss: 0.1325
Val Dice - Class 1: 0.8651, Class 2: 0.8432, Avg: 0.8541
Learning Rate: 0.002137

Epoch 83/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 11.99it/s, loss=0.1404]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.63it/s]


Train Loss: 0.1324
Val Dice - Class 1: 0.8659, Class 2: 0.8489, Avg: 0.8574
Learning Rate: 0.002030

Epoch 84/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.73it/s, loss=0.1334]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.37it/s]


Train Loss: 0.1311
Val Dice - Class 1: 0.8688, Class 2: 0.8500, Avg: 0.8594
Learning Rate: 0.001922
✓ 保存最佳模型 (Dice: 0.8594) -> /workspace/outputs/nnunet_best.pth

Epoch 85/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.40it/s, loss=0.1167]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.50it/s]


Train Loss: 0.1309
Val Dice - Class 1: 0.8692, Class 2: 0.8497, Avg: 0.8595
Learning Rate: 0.001813
✓ 保存最佳模型 (Dice: 0.8595) -> /workspace/outputs/nnunet_best.pth

Epoch 86/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.28it/s, loss=0.1140]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.36it/s]


Train Loss: 0.1288
Val Dice - Class 1: 0.8654, Class 2: 0.8506, Avg: 0.8580
Learning Rate: 0.001704

Epoch 87/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.41it/s, loss=0.1443]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.34it/s]


Train Loss: 0.1308
Val Dice - Class 1: 0.8685, Class 2: 0.8506, Avg: 0.8595
Learning Rate: 0.001594
✓ 保存最佳模型 (Dice: 0.8595) -> /workspace/outputs/nnunet_best.pth

Epoch 88/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.09it/s, loss=0.1125]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.24it/s]


Train Loss: 0.1293
Val Dice - Class 1: 0.8678, Class 2: 0.8486, Avg: 0.8582
Learning Rate: 0.001483

Epoch 89/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.38it/s, loss=0.1265]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.60it/s]


Train Loss: 0.1303
Val Dice - Class 1: 0.8683, Class 2: 0.8482, Avg: 0.8583
Learning Rate: 0.001372

Epoch 90/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.46it/s, loss=0.1469]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.57it/s]


Train Loss: 0.1293
Val Dice - Class 1: 0.8707, Class 2: 0.8496, Avg: 0.8601
Learning Rate: 0.001259
✓ 保存最佳模型 (Dice: 0.8601) -> /workspace/outputs/nnunet_best.pth

Epoch 91/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.15it/s, loss=0.1358]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.61it/s]


Train Loss: 0.1276
Val Dice - Class 1: 0.8684, Class 2: 0.8505, Avg: 0.8594
Learning Rate: 0.001145

Epoch 92/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.81it/s, loss=0.1364]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.49it/s]


Train Loss: 0.1280
Val Dice - Class 1: 0.8722, Class 2: 0.8537, Avg: 0.8629
Learning Rate: 0.001030
✓ 保存最佳模型 (Dice: 0.8629) -> /workspace/outputs/nnunet_best.pth

Epoch 93/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.07it/s, loss=0.1418]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.25it/s]


Train Loss: 0.1254
Val Dice - Class 1: 0.8697, Class 2: 0.8510, Avg: 0.8603
Learning Rate: 0.000913

Epoch 94/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.11it/s, loss=0.1217]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.16it/s]


Train Loss: 0.1270
Val Dice - Class 1: 0.8708, Class 2: 0.8527, Avg: 0.8617
Learning Rate: 0.000795

Epoch 95/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.17it/s, loss=0.1147]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.28it/s]


Train Loss: 0.1270
Val Dice - Class 1: 0.8697, Class 2: 0.8514, Avg: 0.8606
Learning Rate: 0.000675

Epoch 96/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 11.97it/s, loss=0.1824]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.41it/s]


Train Loss: 0.1283
Val Dice - Class 1: 0.8695, Class 2: 0.8525, Avg: 0.8610
Learning Rate: 0.000552

Epoch 97/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.23it/s, loss=0.1218]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.19it/s]


Train Loss: 0.1248
Val Dice - Class 1: 0.8723, Class 2: 0.8544, Avg: 0.8634
Learning Rate: 0.000426
✓ 保存最佳模型 (Dice: 0.8634) -> /workspace/outputs/nnunet_best.pth

Epoch 98/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 11.96it/s, loss=0.1215]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.55it/s]


Train Loss: 0.1254
Val Dice - Class 1: 0.8722, Class 2: 0.8529, Avg: 0.8626
Learning Rate: 0.000296

Epoch 99/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.22it/s, loss=0.1237]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.23it/s]


Train Loss: 0.1267
Val Dice - Class 1: 0.8719, Class 2: 0.8530, Avg: 0.8625
Learning Rate: 0.000158

Epoch 100/100
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:08<00:00, 12.13it/s, loss=0.1153]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.57it/s]

Train Loss: 0.1253
Val Dice - Class 1: 0.8691, Class 2: 0.8517, Avg: 0.8604
Learning Rate: 0.000000

訓練完成！
最佳驗證 Dice: 0.8634
模型已保存至: /workspace/outputs/nnunet_best.pth





In [54]:
#!/usr/bin/env python3
"""
增強版 nnU-Net - 用於 Ensemble
- 200 epochs（更長訓練）
- 48 base channels（更大模型）
- 支持多個隨機種子訓練
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import nibabel as nib
from pathlib import Path
import random
from tqdm import tqdm
import argparse

# 檢查 scipy 是否可用
SCIPY_AVAILABLE = False
try:
    import scipy
    SCIPY_AVAILABLE = True
except:
    pass

# ==================== 設置隨機種子 ====================
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ==================== nnU-Net 架構（與之前相同）====================

class nnUNetConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
        self.norm = nn.InstanceNorm3d(out_channels, affine=True)
        self.activation = nn.LeakyReLU(negative_slope=0.01, inplace=True)
    
    def forward(self, x):
        return self.activation(self.norm(self.conv(x)))

class nnUNetResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nnUNetConvBlock(in_channels, out_channels)
        self.conv2 = nnUNetConvBlock(out_channels, out_channels)
        self.skip = None
        if in_channels != out_channels:
            self.skip = nn.Conv3d(in_channels, out_channels, kernel_size=1)
    
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.skip is not None:
            residual = self.skip(residual)
        return out + residual

class nnUNetDownsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nnUNetConvBlock(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
    
    def forward(self, x):
        return self.conv(x)

class nnUNetUpsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
    
    def forward(self, x):
        return self.upconv(x)

class nnUNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=3, base_channels=48, num_pool=3, deep_supervision=True):
        super().__init__()
        self.num_pool = num_pool
        self.deep_supervision = deep_supervision
        
        # 編碼器
        self.encoders = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        
        current_channels = in_channels
        for i in range(num_pool + 1):
            out_channels = base_channels * (2 ** i)
            self.encoders.append(nnUNetResidualBlock(current_channels, out_channels))
            if i < num_pool:
                self.downsamplers.append(nnUNetDownsample(out_channels, out_channels))
            current_channels = out_channels
        
        # 解碼器
        self.upsamplers = nn.ModuleList()
        self.decoders = nn.ModuleList()
        
        for i in range(num_pool):
            in_ch = base_channels * (2 ** (num_pool - i))
            out_ch = base_channels * (2 ** (num_pool - i - 1))
            self.upsamplers.append(nnUNetUpsample(in_ch, out_ch))
            self.decoders.append(nnUNetResidualBlock(in_ch, out_ch))
        
        # 輸出頭
        self.seg_outputs = nn.ModuleList()
        for i in range(num_pool + 1):
            out_ch = base_channels * (2 ** i)
            self.seg_outputs.append(nn.Conv3d(out_ch, num_classes, kernel_size=1))
    
    def forward(self, x):
        encoder_outputs = []
        current = x
        
        for i, encoder in enumerate(self.encoders):
            current = encoder(current)
            encoder_outputs.append(current)
            if i < self.num_pool:
                current = self.downsamplers[i](current)
        
        seg_outputs = []
        if self.deep_supervision:
            seg_outputs.append(self.seg_outputs[-1](encoder_outputs[-1]))
        
        current = encoder_outputs[-1]
        for i in range(self.num_pool):
            current = self.upsamplers[i](current)
            skip = encoder_outputs[-(i + 2)]
            current = torch.cat([current, skip], dim=1)
            current = self.decoders[i](current)
            if self.deep_supervision:
                seg_outputs.append(self.seg_outputs[-(i + 2)](current))
        
        final_output = self.seg_outputs[0](current) if not self.deep_supervision else seg_outputs[-1]
        
        if self.deep_supervision and self.training:
            return list(reversed(seg_outputs))
        else:
            return final_output

# ==================== 損失函數 ====================

class nnUNetLoss(nn.Module):
    def __init__(self, deep_supervision_weights=None, dice_weight=1.0, ce_weight=1.0):
        super().__init__()
        self.deep_supervision_weights = deep_supervision_weights
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.ce_loss = nn.CrossEntropyLoss()
    
    def dice_loss(self, pred, target, smooth=1.0):
        pred = F.softmax(pred, dim=1)
        dice_scores = []
        for c in range(pred.shape[1]):
            pred_c = pred[:, c]
            target_c = (target == c).float()
            intersection = (pred_c * target_c).sum()
            union = pred_c.sum() + target_c.sum()
            dice = (2.0 * intersection + smooth) / (union + smooth)
            dice_scores.append(dice)
        return 1.0 - torch.stack(dice_scores).mean()
    
    def forward(self, outputs, target):
        if isinstance(outputs, (list, tuple)):
            if self.deep_supervision_weights is None:
                weights = [1.0 / (2 ** i) for i in range(len(outputs))]
                weights = [w / sum(weights) for w in weights]
            else:
                weights = self.deep_supervision_weights
            
            total_loss = 0
            for i, output in enumerate(outputs):
                if output.shape[2:] != target.shape[1:]:
                    target_resized = F.interpolate(
                        target.unsqueeze(1).float(),
                        size=output.shape[2:],
                        mode='nearest'
                    ).squeeze(1).long()
                else:
                    target_resized = target
                
                ce = self.ce_loss(output, target_resized)
                dice = self.dice_loss(output, target_resized)
                total_loss += weights[i] * (self.ce_weight * ce + self.dice_weight * dice)
            return total_loss
        else:
            ce = self.ce_loss(outputs, target)
            dice = self.dice_loss(outputs, target)
            return self.ce_weight * ce + self.dice_weight * dice

# ==================== 數據增強 ====================

class nnUNetAugmentation:
    @staticmethod
    def random_rotation(image, label, angle_range=(-20, 20)):  # 增強到 ±20°
        if not SCIPY_AVAILABLE:
            if random.random() > 0.5:
                k = random.randint(1, 3)
                axes = random.choice([(0, 1), (0, 2), (1, 2)])
                image = np.rot90(image, k, axes).copy()
                label = np.rot90(label, k, axes).copy()
            return image, label
        
        if random.random() > 0.5:
            try:
                from scipy.ndimage import rotate
                angle = random.uniform(*angle_range)
                axes = random.choice([(0, 1), (0, 2), (1, 2)])
                image = rotate(image, angle, axes=axes, reshape=False, order=3, mode='constant')
                label = rotate(label, angle, axes=axes, reshape=False, order=0, mode='constant')
            except:
                k = random.randint(1, 3)
                axes = random.choice([(0, 1), (0, 2), (1, 2)])
                image = np.rot90(image, k, axes).copy()
                label = np.rot90(label, k, axes).copy()
        return image, label
    
    @staticmethod
    def random_scaling(image, label, scale_range=(0.8, 1.3)):  # 增強縮放範圍
        if not SCIPY_AVAILABLE:
            return image, label
        
        if random.random() > 0.5:
            try:
                from scipy.ndimage import zoom
                scale = random.uniform(*scale_range)
                scales = [scale] * 3
                image = zoom(image, scales, order=3, mode='constant')
                label = zoom(label, scales, order=0, mode='constant')
            except:
                pass
        return image, label
    
    @staticmethod
    def random_elastic_deformation(image, label, alpha=100, sigma=10):
        if not SCIPY_AVAILABLE:
            return image, label
        
        if random.random() > 0.3:
            try:
                from scipy.ndimage import gaussian_filter, map_coordinates
                shape = image.shape
                dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma) * alpha
                dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma) * alpha
                dz = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma) * alpha
                x, y, z = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]), indexing='ij')
                indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1)), np.reshape(z + dz, (-1, 1))
                image = map_coordinates(image, indices, order=3, mode='reflect').reshape(shape)
                label = map_coordinates(label, indices, order=0, mode='reflect').reshape(shape)
            except:
                pass
        return image, label
    
    @staticmethod
    def random_gamma(image, gamma_range=(0.7, 1.5)):
        if random.random() > 0.5:
            gamma = random.uniform(*gamma_range)
            image_min = image.min()
            image_range = image.max() - image_min
            if image_range > 0:
                image = ((image - image_min) / image_range) ** gamma * image_range + image_min
        return image
    
    @staticmethod
    def random_brightness(image, brightness_range=(-0.2, 0.2)):
        if random.random() > 0.5:
            brightness = random.uniform(*brightness_range)
            image = image + brightness * image.std()
        return image
    
    @staticmethod
    def random_contrast(image, contrast_range=(0.75, 1.25)):
        if random.random() > 0.5:
            contrast = random.uniform(*contrast_range)
            mean = image.mean()
            image = (image - mean) * contrast + mean
        return image
    
    @staticmethod
    def random_flip(image, label):
        for axis in range(3):
            if random.random() > 0.5:
                image = np.flip(image, axis=axis).copy()
                label = np.flip(label, axis=axis).copy()
        return image, label

# ==================== 數據集 ====================

class nnUNetDataset(Dataset):
    def __init__(self, data_dir, target_size=64, is_train=True, use_augmentation=True):
        self.data_dir = Path(data_dir)
        self.target_size = target_size
        self.is_train = is_train
        self.use_augmentation = use_augmentation and is_train
        
        all_image_files = list((self.data_dir / 'imagesTr').glob('*.nii.gz'))
        self.image_files = sorted([f for f in all_image_files if not f.name.startswith('._')])
        all_label_files = list((self.data_dir / 'labelsTr').glob('*.nii.gz'))
        self.label_files = sorted([f for f in all_label_files if not f.name.startswith('._')])
        
        if len(self.image_files) == 0:
            raise ValueError(f"在 {self.data_dir / 'imagesTr'} 找不到任何 .nii.gz 文件！")
        
        self.aug = nnUNetAugmentation()
    
    def __len__(self):
        return len(self.image_files)
    
    def preprocess(self, image):
        p1, p99 = np.percentile(image[image > 0], [0.5, 99.5]) if (image > 0).any() else (0, 1)
        image = np.clip(image, p1, p99)
        mean = image[image > 0].mean() if (image > 0).any() else 0
        std = image[image > 0].std() if (image > 0).any() else 1
        image = (image - mean) / (std + 1e-8)
        return image
    
    def apply_augmentation(self, image, label):
        image, label = self.aug.random_rotation(image, label)
        image, label = self.aug.random_scaling(image, label)
        image, label = self.aug.random_flip(image, label)
        image = self.aug.random_gamma(image)
        image = self.aug.random_brightness(image)
        image = self.aug.random_contrast(image)
        return image, label
    
    def __getitem__(self, idx):
        image = nib.load(self.image_files[idx]).get_fdata(dtype=np.float32)
        label = nib.load(self.label_files[idx]).get_fdata(dtype=np.float32)
        
        image = self.preprocess(image)
        
        if self.use_augmentation:
            image, label = self.apply_augmentation(image, label)
        
        image = torch.from_numpy(image).unsqueeze(0).unsqueeze(0)
        label = torch.from_numpy(label).unsqueeze(0).unsqueeze(0)
        
        image = F.interpolate(image, size=(self.target_size, self.target_size, self.target_size),
                            mode='trilinear', align_corners=False).squeeze(0)
        label = F.interpolate(label, size=(self.target_size, self.target_size, self.target_size),
                            mode='nearest').squeeze(0)
        
        label = label.squeeze(0).long()
        label = torch.clamp(label, 0, 2)
        return image, label

# ==================== 評估指標 ====================

def compute_dice(pred, target, num_classes=3):
    dice_scores = []
    for c in range(1, num_classes):
        pred_c = (pred == c)
        target_c = (target == c)
        intersection = (pred_c & target_c).sum().float()
        union = pred_c.sum().float() + target_c.sum().float()
        if union == 0:
            dice = 1.0 if intersection == 0 else 0.0
        else:
            dice = (2.0 * intersection) / union
        dice_scores.append(dice.item())
    return dice_scores

# ==================== 學習率調度器 ====================

class PolynomialLRScheduler:
    def __init__(self, optimizer, initial_lr, max_epochs, power=0.9):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.max_epochs = max_epochs
        self.power = power
        self.current_epoch = 0
    
    def step(self):
        self.current_epoch += 1
        lr = self.initial_lr * (1 - self.current_epoch / self.max_epochs) ** self.power
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

# ==================== 訓練函數 ====================

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    pbar = tqdm(loader, desc='Training')
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 12)
        optimizer.step()
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    return total_loss / len(loader)

def validate(model, loader, device):
    model.eval()
    all_dice_scores = []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validating'):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            if isinstance(outputs, (list, tuple)):
                outputs = outputs[-1]
            preds = outputs.argmax(dim=1)
            for pred, label in zip(preds, labels):
                dice_scores = compute_dice(pred.cpu(), label.cpu())
                all_dice_scores.append(dice_scores)
    all_dice_scores = np.array(all_dice_scores)
    mean_dice = all_dice_scores.mean(axis=0)
    return mean_dice

# ==================== 主訓練流程 ====================

def main(seed=42):
    set_seed(seed)
    
    config = {
        'data_dir': './Task04_Hippocampus',
        'batch_size': 2,
        'num_epochs': 200,  # 增加到 200
        'initial_lr': 1e-2,
        'base_channels': 48,  # 增加到 48
        'num_pool': 3,
        'target_size': 64,
        'deep_supervision': True,
        'use_augmentation': True,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'seed': seed
    }
    
    print(f"\n{'='*60}")
    print(f"增強版 nnU-Net 訓練 (Seed: {seed})")
    print(f"{'='*60}")
    print(f"配置:")
    for key, value in config.items():
        print(f"  {key}: {value}")
    print()
    
    device = torch.device(config['device'])
    
    # 數據集
    full_dataset = nnUNetDataset(
        config['data_dir'],
        target_size=config['target_size'],
        is_train=True,
        use_augmentation=config['use_augmentation']
    )
    
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(seed)
    )
    
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                            shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'],
                          shuffle=False, num_workers=4, pin_memory=True)
    
    print(f"訓練集: {len(train_dataset)} 樣本")
    print(f"驗證集: {len(val_dataset)} 樣本\n")
    
    # 模型
    model = nnUNet(
        in_channels=1,
        num_classes=3,
        base_channels=config['base_channels'],
        num_pool=config['num_pool'],
        deep_supervision=config['deep_supervision']
    ).to(device)
    
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型參數量: {total_params:,} ({total_params/1e6:.2f}M)\n")
    
    # 優化器
    criterion = nnUNetLoss(dice_weight=1.0, ce_weight=1.0)
    optimizer = torch.optim.SGD(model.parameters(), lr=config['initial_lr'],
                               momentum=0.99, weight_decay=3e-5, nesterov=True)
    scheduler = PolynomialLRScheduler(optimizer, config['initial_lr'],
                                     config['num_epochs'], power=0.9)
    
    # 訓練
    best_dice = 0.0
    output_dir = Path('./outputs')
    output_dir.mkdir(exist_ok=True)
    
    for epoch in range(config['num_epochs']):
        print(f"Epoch {epoch+1}/{config['num_epochs']}")
        print("-" * 60)
        
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_dice = validate(model, val_loader, device)
        current_lr = scheduler.step()
        
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Dice - Class 1: {val_dice[0]:.4f}, Class 2: {val_dice[1]:.4f}, "
              f"Avg: {val_dice.mean():.4f}")
        print(f"Learning Rate: {current_lr:.6f}")
        
        if val_dice.mean() > best_dice:
            best_dice = val_dice.mean()
            save_path = output_dir / f'nnunet_enhanced_seed{seed}_best.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'dice': val_dice,
                'config': config
            }, save_path)
            print(f"✓ 保存最佳模型 (Dice: {best_dice:.4f})")
        print()
    
    print("=" * 60)
    print(f"訓練完成！Seed: {seed}")
    print(f"最佳驗證 Dice: {best_dice:.4f}")
    print("=" * 60)
    
    return best_dice

if __name__ == '__main__':
    # 檢測是否在 Jupyter 環境中
    try:
        get_ipython()
        # 在 Jupyter 中，直接使用默認種子
        print("檢測到 Jupyter 環境，使用默認種子 42")
        main(seed=42)
    except NameError:
        # 在命令行中，使用 argparse
        parser = argparse.ArgumentParser()
        parser.add_argument('--seed', type=int, default=42, help='隨機種子')
        args = parser.parse_args()
        main(seed=args.seed)


檢測到 Jupyter 環境，使用默認種子 42

增強版 nnU-Net 訓練 (Seed: 42)
配置:
  data_dir: ./Task04_Hippocampus
  batch_size: 2
  num_epochs: 200
  initial_lr: 0.01
  base_channels: 48
  num_pool: 3
  target_size: 64
  deep_supervision: True
  use_augmentation: True
  device: cuda
  seed: 42

訓練集: 208 樣本
驗證集: 52 樣本

模型參數量: 14,105,820 (14.11M)

Epoch 1/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.22it/s, loss=0.4846]
Validating: 100%|██████████| 26/26 [00:03<00:00,  6.59it/s]


Train Loss: 0.6262
Val Dice - Class 1: 0.6574, Class 2: 0.7210, Avg: 0.6892
Learning Rate: 0.009955
✓ 保存最佳模型 (Dice: 0.6892)

Epoch 2/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.22it/s, loss=0.2565]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.18it/s]


Train Loss: 0.3083
Val Dice - Class 1: 0.7604, Class 2: 0.7597, Avg: 0.7601
Learning Rate: 0.009910
✓ 保存最佳模型 (Dice: 0.7601)

Epoch 3/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.25it/s, loss=0.3606]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.47it/s]


Train Loss: 0.2527
Val Dice - Class 1: 0.7627, Class 2: 0.7887, Avg: 0.7757
Learning Rate: 0.009865
✓ 保存最佳模型 (Dice: 0.7757)

Epoch 4/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.22it/s, loss=0.3312]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.26it/s]


Train Loss: 0.2279
Val Dice - Class 1: 0.7816, Class 2: 0.7886, Avg: 0.7851
Learning Rate: 0.009820
✓ 保存最佳模型 (Dice: 0.7851)

Epoch 5/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.22it/s, loss=0.2139]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.44it/s]


Train Loss: 0.2142
Val Dice - Class 1: 0.8059, Class 2: 0.8074, Avg: 0.8067
Learning Rate: 0.009775
✓ 保存最佳模型 (Dice: 0.8067)

Epoch 6/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.17it/s, loss=0.1980]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.25it/s]


Train Loss: 0.2056
Val Dice - Class 1: 0.8091, Class 2: 0.8054, Avg: 0.8072
Learning Rate: 0.009730
✓ 保存最佳模型 (Dice: 0.8072)

Epoch 7/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.18it/s, loss=0.3041]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.18it/s]


Train Loss: 0.1974
Val Dice - Class 1: 0.8150, Class 2: 0.8128, Avg: 0.8139
Learning Rate: 0.009684
✓ 保存最佳模型 (Dice: 0.8139)

Epoch 8/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.18it/s, loss=0.1863]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.93it/s]


Train Loss: 0.1951
Val Dice - Class 1: 0.8179, Class 2: 0.8091, Avg: 0.8135
Learning Rate: 0.009639

Epoch 9/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.17it/s, loss=0.2450]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.48it/s]


Train Loss: 0.1924
Val Dice - Class 1: 0.8231, Class 2: 0.8178, Avg: 0.8205
Learning Rate: 0.009594
✓ 保存最佳模型 (Dice: 0.8205)

Epoch 10/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.16it/s, loss=0.1478]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.20it/s]


Train Loss: 0.1828
Val Dice - Class 1: 0.8331, Class 2: 0.8205, Avg: 0.8268
Learning Rate: 0.009549
✓ 保存最佳模型 (Dice: 0.8268)

Epoch 11/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.17it/s, loss=0.1918]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.08it/s]


Train Loss: 0.1806
Val Dice - Class 1: 0.8321, Class 2: 0.8194, Avg: 0.8258
Learning Rate: 0.009504

Epoch 12/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.18it/s, loss=0.1885]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.47it/s]


Train Loss: 0.1776
Val Dice - Class 1: 0.8354, Class 2: 0.8243, Avg: 0.8299
Learning Rate: 0.009458
✓ 保存最佳模型 (Dice: 0.8299)

Epoch 13/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.16it/s, loss=0.1763]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.19it/s]


Train Loss: 0.1769
Val Dice - Class 1: 0.8329, Class 2: 0.8222, Avg: 0.8276
Learning Rate: 0.009413

Epoch 14/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.21it/s, loss=0.1331]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.58it/s]


Train Loss: 0.1730
Val Dice - Class 1: 0.8334, Class 2: 0.8175, Avg: 0.8255
Learning Rate: 0.009368

Epoch 15/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.18it/s, loss=0.1415]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.19it/s]


Train Loss: 0.1730
Val Dice - Class 1: 0.8389, Class 2: 0.8259, Avg: 0.8324
Learning Rate: 0.009322
✓ 保存最佳模型 (Dice: 0.8324)

Epoch 16/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.16it/s, loss=0.1551]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.24it/s]


Train Loss: 0.1717
Val Dice - Class 1: 0.8380, Class 2: 0.8231, Avg: 0.8306
Learning Rate: 0.009277

Epoch 17/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.18it/s, loss=0.2227]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.21it/s]


Train Loss: 0.1711
Val Dice - Class 1: 0.8451, Class 2: 0.8293, Avg: 0.8372
Learning Rate: 0.009232
✓ 保存最佳模型 (Dice: 0.8372)

Epoch 18/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.15it/s, loss=0.2356]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.30it/s]


Train Loss: 0.1674
Val Dice - Class 1: 0.8401, Class 2: 0.8271, Avg: 0.8336
Learning Rate: 0.009186

Epoch 19/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.21it/s, loss=0.1827]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.28it/s]


Train Loss: 0.1665
Val Dice - Class 1: 0.8454, Class 2: 0.8338, Avg: 0.8396
Learning Rate: 0.009141
✓ 保存最佳模型 (Dice: 0.8396)

Epoch 20/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.17it/s, loss=0.1469]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.93it/s]


Train Loss: 0.1664
Val Dice - Class 1: 0.8422, Class 2: 0.8304, Avg: 0.8363
Learning Rate: 0.009095

Epoch 21/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.1294]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.66it/s]


Train Loss: 0.1649
Val Dice - Class 1: 0.8469, Class 2: 0.8308, Avg: 0.8389
Learning Rate: 0.009050

Epoch 22/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.17it/s, loss=0.1923]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.20it/s]


Train Loss: 0.1633
Val Dice - Class 1: 0.8405, Class 2: 0.8321, Avg: 0.8363
Learning Rate: 0.009004

Epoch 23/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.19it/s, loss=0.1628]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.88it/s]


Train Loss: 0.1653
Val Dice - Class 1: 0.8462, Class 2: 0.8294, Avg: 0.8378
Learning Rate: 0.008959

Epoch 24/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.19it/s, loss=0.1754]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.19it/s]


Train Loss: 0.1642
Val Dice - Class 1: 0.8367, Class 2: 0.8283, Avg: 0.8325
Learning Rate: 0.008913

Epoch 25/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.21it/s, loss=0.1230]
Validating: 100%|██████████| 26/26 [00:02<00:00, 12.20it/s]


Train Loss: 0.1618
Val Dice - Class 1: 0.8488, Class 2: 0.8312, Avg: 0.8400
Learning Rate: 0.008868
✓ 保存最佳模型 (Dice: 0.8400)

Epoch 26/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.24it/s, loss=0.1456]
Validating: 100%|██████████| 26/26 [00:02<00:00, 12.03it/s]


Train Loss: 0.1600
Val Dice - Class 1: 0.8513, Class 2: 0.8328, Avg: 0.8421
Learning Rate: 0.008822
✓ 保存最佳模型 (Dice: 0.8421)

Epoch 27/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.28it/s, loss=0.1508]
Validating: 100%|██████████| 26/26 [00:02<00:00, 10.71it/s]


Train Loss: 0.1594
Val Dice - Class 1: 0.8542, Class 2: 0.8347, Avg: 0.8445
Learning Rate: 0.008776
✓ 保存最佳模型 (Dice: 0.8445)

Epoch 28/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.23it/s, loss=0.2022]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.67it/s]


Train Loss: 0.1578
Val Dice - Class 1: 0.8475, Class 2: 0.8272, Avg: 0.8374
Learning Rate: 0.008731

Epoch 29/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.29it/s, loss=0.1569]
Validating: 100%|██████████| 26/26 [00:02<00:00, 10.22it/s]


Train Loss: 0.1556
Val Dice - Class 1: 0.8500, Class 2: 0.8312, Avg: 0.8406
Learning Rate: 0.008685

Epoch 30/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.24it/s, loss=0.1726]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.58it/s]


Train Loss: 0.1543
Val Dice - Class 1: 0.8483, Class 2: 0.8344, Avg: 0.8414
Learning Rate: 0.008639

Epoch 31/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.28it/s, loss=0.1446]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.56it/s]


Train Loss: 0.1556
Val Dice - Class 1: 0.8515, Class 2: 0.8350, Avg: 0.8433
Learning Rate: 0.008594

Epoch 32/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.30it/s, loss=0.1331]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.00it/s]


Train Loss: 0.1583
Val Dice - Class 1: 0.8508, Class 2: 0.8349, Avg: 0.8429
Learning Rate: 0.008548

Epoch 33/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.28it/s, loss=0.1141]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.62it/s]


Train Loss: 0.1539
Val Dice - Class 1: 0.8539, Class 2: 0.8327, Avg: 0.8433
Learning Rate: 0.008502

Epoch 34/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.27it/s, loss=0.1787]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.17it/s]


Train Loss: 0.1557
Val Dice - Class 1: 0.8561, Class 2: 0.8370, Avg: 0.8466
Learning Rate: 0.008456
✓ 保存最佳模型 (Dice: 0.8466)

Epoch 35/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.28it/s, loss=0.1466]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.50it/s]


Train Loss: 0.1508
Val Dice - Class 1: 0.8527, Class 2: 0.8380, Avg: 0.8453
Learning Rate: 0.008410

Epoch 36/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.29it/s, loss=0.1990]
Validating: 100%|██████████| 26/26 [00:02<00:00, 10.68it/s]


Train Loss: 0.1515
Val Dice - Class 1: 0.8558, Class 2: 0.8397, Avg: 0.8478
Learning Rate: 0.008364
✓ 保存最佳模型 (Dice: 0.8478)

Epoch 37/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.27it/s, loss=0.1467]
Validating: 100%|██████████| 26/26 [00:02<00:00, 12.40it/s]


Train Loss: 0.1500
Val Dice - Class 1: 0.8558, Class 2: 0.8359, Avg: 0.8459
Learning Rate: 0.008318

Epoch 38/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.30it/s, loss=0.1297]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.97it/s]


Train Loss: 0.1491
Val Dice - Class 1: 0.8567, Class 2: 0.8392, Avg: 0.8480
Learning Rate: 0.008272
✓ 保存最佳模型 (Dice: 0.8480)

Epoch 39/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.27it/s, loss=0.1883]
Validating: 100%|██████████| 26/26 [00:02<00:00, 12.37it/s]


Train Loss: 0.1498
Val Dice - Class 1: 0.8581, Class 2: 0.8368, Avg: 0.8474
Learning Rate: 0.008227

Epoch 40/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.25it/s, loss=0.1431]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.89it/s]


Train Loss: 0.1477
Val Dice - Class 1: 0.8547, Class 2: 0.8424, Avg: 0.8486
Learning Rate: 0.008181
✓ 保存最佳模型 (Dice: 0.8486)

Epoch 41/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.20it/s, loss=0.1373]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.75it/s]


Train Loss: 0.1480
Val Dice - Class 1: 0.8566, Class 2: 0.8386, Avg: 0.8476
Learning Rate: 0.008134

Epoch 42/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.25it/s, loss=0.1336]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.91it/s]


Train Loss: 0.1470
Val Dice - Class 1: 0.8570, Class 2: 0.8415, Avg: 0.8492
Learning Rate: 0.008088
✓ 保存最佳模型 (Dice: 0.8492)

Epoch 43/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.26it/s, loss=0.1398]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.05it/s]


Train Loss: 0.1468
Val Dice - Class 1: 0.8590, Class 2: 0.8407, Avg: 0.8499
Learning Rate: 0.008042
✓ 保存最佳模型 (Dice: 0.8499)

Epoch 44/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.25it/s, loss=0.1486]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.49it/s]


Train Loss: 0.1468
Val Dice - Class 1: 0.8617, Class 2: 0.8435, Avg: 0.8526
Learning Rate: 0.007996
✓ 保存最佳模型 (Dice: 0.8526)

Epoch 45/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.27it/s, loss=0.1521]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.05it/s]


Train Loss: 0.1430
Val Dice - Class 1: 0.8628, Class 2: 0.8451, Avg: 0.8539
Learning Rate: 0.007950
✓ 保存最佳模型 (Dice: 0.8539)

Epoch 46/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.07it/s, loss=0.1757]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.85it/s]


Train Loss: 0.1429
Val Dice - Class 1: 0.8624, Class 2: 0.8409, Avg: 0.8516
Learning Rate: 0.007904

Epoch 47/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.07it/s, loss=0.1523]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.40it/s]


Train Loss: 0.1459
Val Dice - Class 1: 0.8592, Class 2: 0.8364, Avg: 0.8478
Learning Rate: 0.007858

Epoch 48/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1713]
Validating: 100%|██████████| 26/26 [00:02<00:00, 11.91it/s]


Train Loss: 0.1431
Val Dice - Class 1: 0.8591, Class 2: 0.8406, Avg: 0.8499
Learning Rate: 0.007811

Epoch 49/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.1219]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.85it/s]


Train Loss: 0.1456
Val Dice - Class 1: 0.8569, Class 2: 0.8419, Avg: 0.8494
Learning Rate: 0.007765

Epoch 50/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1422]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.68it/s]


Train Loss: 0.1457
Val Dice - Class 1: 0.8571, Class 2: 0.8392, Avg: 0.8482
Learning Rate: 0.007719

Epoch 51/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.1373]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.84it/s]


Train Loss: 0.1444
Val Dice - Class 1: 0.8617, Class 2: 0.8454, Avg: 0.8535
Learning Rate: 0.007673

Epoch 52/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1531]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.54it/s]


Train Loss: 0.1454
Val Dice - Class 1: 0.8556, Class 2: 0.8452, Avg: 0.8504
Learning Rate: 0.007626

Epoch 53/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.03it/s, loss=0.1287]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.04it/s]


Train Loss: 0.1433
Val Dice - Class 1: 0.8651, Class 2: 0.8458, Avg: 0.8554
Learning Rate: 0.007580
✓ 保存最佳模型 (Dice: 0.8554)

Epoch 54/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.05it/s, loss=0.1564]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.05it/s]


Train Loss: 0.1452
Val Dice - Class 1: 0.8554, Class 2: 0.8411, Avg: 0.8483
Learning Rate: 0.007533

Epoch 55/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1264]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.90it/s]


Train Loss: 0.1444
Val Dice - Class 1: 0.8602, Class 2: 0.8391, Avg: 0.8496
Learning Rate: 0.007487

Epoch 56/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.05it/s, loss=0.1651]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.97it/s]


Train Loss: 0.1420
Val Dice - Class 1: 0.8673, Class 2: 0.8454, Avg: 0.8564
Learning Rate: 0.007440
✓ 保存最佳模型 (Dice: 0.8564)

Epoch 57/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.06it/s, loss=0.1449]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.19it/s]


Train Loss: 0.1386
Val Dice - Class 1: 0.8636, Class 2: 0.8446, Avg: 0.8541
Learning Rate: 0.007394

Epoch 58/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1269]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.31it/s]


Train Loss: 0.1404
Val Dice - Class 1: 0.8632, Class 2: 0.8466, Avg: 0.8549
Learning Rate: 0.007347

Epoch 59/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.05it/s, loss=0.1062]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.78it/s]


Train Loss: 0.1407
Val Dice - Class 1: 0.8626, Class 2: 0.8436, Avg: 0.8531
Learning Rate: 0.007301

Epoch 60/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1392]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.83it/s]


Train Loss: 0.1394
Val Dice - Class 1: 0.8690, Class 2: 0.8506, Avg: 0.8598
Learning Rate: 0.007254
✓ 保存最佳模型 (Dice: 0.8598)

Epoch 61/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1357]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.99it/s]


Train Loss: 0.1376
Val Dice - Class 1: 0.8580, Class 2: 0.8442, Avg: 0.8511
Learning Rate: 0.007208

Epoch 62/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.13it/s, loss=0.1666]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.70it/s]


Train Loss: 0.1363
Val Dice - Class 1: 0.8647, Class 2: 0.8435, Avg: 0.8541
Learning Rate: 0.007161

Epoch 63/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.16it/s, loss=0.1191]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.75it/s]


Train Loss: 0.1396
Val Dice - Class 1: 0.8645, Class 2: 0.8483, Avg: 0.8564
Learning Rate: 0.007114

Epoch 64/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1484]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.08it/s]


Train Loss: 0.1368
Val Dice - Class 1: 0.8690, Class 2: 0.8511, Avg: 0.8600
Learning Rate: 0.007067
✓ 保存最佳模型 (Dice: 0.8600)

Epoch 65/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.05it/s, loss=0.1635]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.64it/s]


Train Loss: 0.1378
Val Dice - Class 1: 0.8629, Class 2: 0.8469, Avg: 0.8549
Learning Rate: 0.007021

Epoch 66/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.03it/s, loss=0.1366]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.79it/s]


Train Loss: 0.1376
Val Dice - Class 1: 0.8622, Class 2: 0.8469, Avg: 0.8545
Learning Rate: 0.006974

Epoch 67/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.03it/s, loss=0.1303]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.06it/s]


Train Loss: 0.1391
Val Dice - Class 1: 0.8643, Class 2: 0.8434, Avg: 0.8538
Learning Rate: 0.006927

Epoch 68/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1269]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.72it/s]


Train Loss: 0.1368
Val Dice - Class 1: 0.8603, Class 2: 0.8399, Avg: 0.8501
Learning Rate: 0.006880

Epoch 69/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1553]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.18it/s]


Train Loss: 0.1363
Val Dice - Class 1: 0.8612, Class 2: 0.8445, Avg: 0.8529
Learning Rate: 0.006833

Epoch 70/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.1530]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.93it/s]


Train Loss: 0.1370
Val Dice - Class 1: 0.8645, Class 2: 0.8457, Avg: 0.8551
Learning Rate: 0.006786

Epoch 71/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.1496]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.99it/s]


Train Loss: 0.1366
Val Dice - Class 1: 0.8605, Class 2: 0.8470, Avg: 0.8537
Learning Rate: 0.006739

Epoch 72/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.13it/s, loss=0.1183]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.67it/s]


Train Loss: 0.1349
Val Dice - Class 1: 0.8626, Class 2: 0.8458, Avg: 0.8542
Learning Rate: 0.006692

Epoch 73/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1842]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.55it/s]


Train Loss: 0.1346
Val Dice - Class 1: 0.8658, Class 2: 0.8451, Avg: 0.8555
Learning Rate: 0.006645

Epoch 74/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.1530]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.73it/s]


Train Loss: 0.1339
Val Dice - Class 1: 0.8634, Class 2: 0.8480, Avg: 0.8557
Learning Rate: 0.006598

Epoch 75/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1179]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.85it/s]


Train Loss: 0.1345
Val Dice - Class 1: 0.8647, Class 2: 0.8472, Avg: 0.8559
Learning Rate: 0.006551

Epoch 76/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1299]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.69it/s]


Train Loss: 0.1346
Val Dice - Class 1: 0.8647, Class 2: 0.8492, Avg: 0.8569
Learning Rate: 0.006504

Epoch 77/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1140]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.01it/s]


Train Loss: 0.1334
Val Dice - Class 1: 0.8641, Class 2: 0.8459, Avg: 0.8550
Learning Rate: 0.006456

Epoch 78/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1619]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.00it/s]


Train Loss: 0.1346
Val Dice - Class 1: 0.8643, Class 2: 0.8474, Avg: 0.8558
Learning Rate: 0.006409

Epoch 79/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.06it/s, loss=0.1274]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.94it/s]


Train Loss: 0.1319
Val Dice - Class 1: 0.8656, Class 2: 0.8491, Avg: 0.8573
Learning Rate: 0.006362

Epoch 80/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1472]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.79it/s]


Train Loss: 0.1328
Val Dice - Class 1: 0.8671, Class 2: 0.8490, Avg: 0.8580
Learning Rate: 0.006314

Epoch 81/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1184]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.60it/s]


Train Loss: 0.1303
Val Dice - Class 1: 0.8660, Class 2: 0.8526, Avg: 0.8593
Learning Rate: 0.006267

Epoch 82/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.15it/s, loss=0.1048]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.16it/s]


Train Loss: 0.1324
Val Dice - Class 1: 0.8661, Class 2: 0.8478, Avg: 0.8569
Learning Rate: 0.006220

Epoch 83/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.07it/s, loss=0.1463]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.88it/s]


Train Loss: 0.1309
Val Dice - Class 1: 0.8690, Class 2: 0.8495, Avg: 0.8592
Learning Rate: 0.006172

Epoch 84/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1390]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.54it/s]


Train Loss: 0.1318
Val Dice - Class 1: 0.8616, Class 2: 0.8452, Avg: 0.8534
Learning Rate: 0.006125

Epoch 85/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1432]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.67it/s]


Train Loss: 0.1296
Val Dice - Class 1: 0.8668, Class 2: 0.8504, Avg: 0.8586
Learning Rate: 0.006077

Epoch 86/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.1193]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.56it/s]


Train Loss: 0.1308
Val Dice - Class 1: 0.8646, Class 2: 0.8512, Avg: 0.8579
Learning Rate: 0.006030

Epoch 87/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1507]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.66it/s]


Train Loss: 0.1273
Val Dice - Class 1: 0.8680, Class 2: 0.8494, Avg: 0.8587
Learning Rate: 0.005982

Epoch 88/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1139]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.80it/s]


Train Loss: 0.1265
Val Dice - Class 1: 0.8730, Class 2: 0.8556, Avg: 0.8643
Learning Rate: 0.005934
✓ 保存最佳模型 (Dice: 0.8643)

Epoch 89/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.07it/s, loss=0.1208]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.28it/s]


Train Loss: 0.1247
Val Dice - Class 1: 0.8690, Class 2: 0.8531, Avg: 0.8611
Learning Rate: 0.005887

Epoch 90/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.1211]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.69it/s]


Train Loss: 0.1278
Val Dice - Class 1: 0.8655, Class 2: 0.8502, Avg: 0.8579
Learning Rate: 0.005839

Epoch 91/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1100]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.97it/s]


Train Loss: 0.1289
Val Dice - Class 1: 0.8683, Class 2: 0.8498, Avg: 0.8590
Learning Rate: 0.005791

Epoch 92/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.0964]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.76it/s]


Train Loss: 0.1259
Val Dice - Class 1: 0.8680, Class 2: 0.8518, Avg: 0.8599
Learning Rate: 0.005743

Epoch 93/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1336]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.09it/s]


Train Loss: 0.1282
Val Dice - Class 1: 0.8700, Class 2: 0.8509, Avg: 0.8605
Learning Rate: 0.005695

Epoch 94/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.1106]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.10it/s]


Train Loss: 0.1270
Val Dice - Class 1: 0.8673, Class 2: 0.8489, Avg: 0.8581
Learning Rate: 0.005647

Epoch 95/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1544]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.17it/s]


Train Loss: 0.1262
Val Dice - Class 1: 0.8684, Class 2: 0.8553, Avg: 0.8618
Learning Rate: 0.005599

Epoch 96/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1359]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.09it/s]


Train Loss: 0.1255
Val Dice - Class 1: 0.8691, Class 2: 0.8508, Avg: 0.8599
Learning Rate: 0.005551

Epoch 97/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1368]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.23it/s]


Train Loss: 0.1278
Val Dice - Class 1: 0.8691, Class 2: 0.8504, Avg: 0.8597
Learning Rate: 0.005503

Epoch 98/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1017]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.52it/s]


Train Loss: 0.1265
Val Dice - Class 1: 0.8699, Class 2: 0.8515, Avg: 0.8607
Learning Rate: 0.005455

Epoch 99/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.06it/s, loss=0.1257]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.08it/s]


Train Loss: 0.1256
Val Dice - Class 1: 0.8711, Class 2: 0.8521, Avg: 0.8616
Learning Rate: 0.005407

Epoch 100/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.1186]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.69it/s]


Train Loss: 0.1252
Val Dice - Class 1: 0.8675, Class 2: 0.8532, Avg: 0.8604
Learning Rate: 0.005359

Epoch 101/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1254]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.09it/s]


Train Loss: 0.1234
Val Dice - Class 1: 0.8692, Class 2: 0.8534, Avg: 0.8613
Learning Rate: 0.005311

Epoch 102/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1344]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.79it/s]


Train Loss: 0.1237
Val Dice - Class 1: 0.8653, Class 2: 0.8484, Avg: 0.8568
Learning Rate: 0.005262

Epoch 103/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.06it/s, loss=0.1340]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.89it/s]


Train Loss: 0.1245
Val Dice - Class 1: 0.8698, Class 2: 0.8510, Avg: 0.8604
Learning Rate: 0.005214

Epoch 104/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.04it/s, loss=0.1169]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.21it/s]


Train Loss: 0.1228
Val Dice - Class 1: 0.8705, Class 2: 0.8545, Avg: 0.8625
Learning Rate: 0.005166

Epoch 105/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.07it/s, loss=0.1184]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.40it/s]


Train Loss: 0.1215
Val Dice - Class 1: 0.8679, Class 2: 0.8502, Avg: 0.8590
Learning Rate: 0.005117

Epoch 106/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.1149]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.13it/s]


Train Loss: 0.1230
Val Dice - Class 1: 0.8682, Class 2: 0.8544, Avg: 0.8613
Learning Rate: 0.005069

Epoch 107/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.13it/s, loss=0.1239]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.61it/s]


Train Loss: 0.1229
Val Dice - Class 1: 0.8690, Class 2: 0.8552, Avg: 0.8621
Learning Rate: 0.005020

Epoch 108/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.0851]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.08it/s]


Train Loss: 0.1225
Val Dice - Class 1: 0.8720, Class 2: 0.8552, Avg: 0.8636
Learning Rate: 0.004971

Epoch 109/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1081]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.96it/s]


Train Loss: 0.1217
Val Dice - Class 1: 0.8680, Class 2: 0.8545, Avg: 0.8612
Learning Rate: 0.004923

Epoch 110/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1188]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.89it/s]


Train Loss: 0.1214
Val Dice - Class 1: 0.8698, Class 2: 0.8547, Avg: 0.8623
Learning Rate: 0.004874

Epoch 111/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1768]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.69it/s]


Train Loss: 0.1234
Val Dice - Class 1: 0.8667, Class 2: 0.8521, Avg: 0.8594
Learning Rate: 0.004825

Epoch 112/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.06it/s, loss=0.1087]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.41it/s]


Train Loss: 0.1215
Val Dice - Class 1: 0.8656, Class 2: 0.8512, Avg: 0.8584
Learning Rate: 0.004776

Epoch 113/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.0962]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.01it/s]


Train Loss: 0.1195
Val Dice - Class 1: 0.8724, Class 2: 0.8549, Avg: 0.8636
Learning Rate: 0.004728

Epoch 114/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.1147]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.86it/s]


Train Loss: 0.1195
Val Dice - Class 1: 0.8710, Class 2: 0.8539, Avg: 0.8625
Learning Rate: 0.004679

Epoch 115/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.16it/s, loss=0.1530]
Validating: 100%|██████████| 26/26 [00:02<00:00,  8.86it/s]


Train Loss: 0.1199
Val Dice - Class 1: 0.8710, Class 2: 0.8539, Avg: 0.8625
Learning Rate: 0.004630

Epoch 116/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1308]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.98it/s]


Train Loss: 0.1223
Val Dice - Class 1: 0.8701, Class 2: 0.8540, Avg: 0.8621
Learning Rate: 0.004581

Epoch 117/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1105]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.88it/s]


Train Loss: 0.1204
Val Dice - Class 1: 0.8706, Class 2: 0.8554, Avg: 0.8630
Learning Rate: 0.004532

Epoch 118/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1169]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.15it/s]


Train Loss: 0.1201
Val Dice - Class 1: 0.8722, Class 2: 0.8552, Avg: 0.8637
Learning Rate: 0.004482

Epoch 119/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1187]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.71it/s]


Train Loss: 0.1187
Val Dice - Class 1: 0.8719, Class 2: 0.8515, Avg: 0.8617
Learning Rate: 0.004433

Epoch 120/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.0828]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.96it/s]


Train Loss: 0.1187
Val Dice - Class 1: 0.8678, Class 2: 0.8521, Avg: 0.8599
Learning Rate: 0.004384

Epoch 121/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1338]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.39it/s]


Train Loss: 0.1187
Val Dice - Class 1: 0.8712, Class 2: 0.8550, Avg: 0.8631
Learning Rate: 0.004334

Epoch 122/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.18it/s, loss=0.0896]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.15it/s]


Train Loss: 0.1183
Val Dice - Class 1: 0.8714, Class 2: 0.8539, Avg: 0.8626
Learning Rate: 0.004285

Epoch 123/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.07it/s, loss=0.1094]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.85it/s]


Train Loss: 0.1173
Val Dice - Class 1: 0.8695, Class 2: 0.8516, Avg: 0.8605
Learning Rate: 0.004236

Epoch 124/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1348]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.69it/s]


Train Loss: 0.1174
Val Dice - Class 1: 0.8719, Class 2: 0.8531, Avg: 0.8625
Learning Rate: 0.004186

Epoch 125/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.13it/s, loss=0.1013]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.69it/s]


Train Loss: 0.1162
Val Dice - Class 1: 0.8744, Class 2: 0.8570, Avg: 0.8657
Learning Rate: 0.004136
✓ 保存最佳模型 (Dice: 0.8657)

Epoch 126/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.04it/s, loss=0.1492]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.35it/s]


Train Loss: 0.1138
Val Dice - Class 1: 0.8703, Class 2: 0.8558, Avg: 0.8630
Learning Rate: 0.004087

Epoch 127/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.1116]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.70it/s]


Train Loss: 0.1170
Val Dice - Class 1: 0.8708, Class 2: 0.8504, Avg: 0.8606
Learning Rate: 0.004037

Epoch 128/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1601]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.54it/s]


Train Loss: 0.1149
Val Dice - Class 1: 0.8737, Class 2: 0.8561, Avg: 0.8649
Learning Rate: 0.003987

Epoch 129/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1134]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.46it/s]


Train Loss: 0.1130
Val Dice - Class 1: 0.8706, Class 2: 0.8533, Avg: 0.8620
Learning Rate: 0.003937

Epoch 130/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.0934]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.35it/s]


Train Loss: 0.1139
Val Dice - Class 1: 0.8708, Class 2: 0.8537, Avg: 0.8623
Learning Rate: 0.003887

Epoch 131/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.0882]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.65it/s]


Train Loss: 0.1127
Val Dice - Class 1: 0.8711, Class 2: 0.8538, Avg: 0.8624
Learning Rate: 0.003837

Epoch 132/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.1148]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.00it/s]


Train Loss: 0.1128
Val Dice - Class 1: 0.8703, Class 2: 0.8540, Avg: 0.8622
Learning Rate: 0.003787

Epoch 133/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.0918]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.97it/s]


Train Loss: 0.1142
Val Dice - Class 1: 0.8691, Class 2: 0.8542, Avg: 0.8617
Learning Rate: 0.003737

Epoch 134/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.0923]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.07it/s]


Train Loss: 0.1160
Val Dice - Class 1: 0.8741, Class 2: 0.8563, Avg: 0.8652
Learning Rate: 0.003687

Epoch 135/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1090]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.89it/s]


Train Loss: 0.1140
Val Dice - Class 1: 0.8716, Class 2: 0.8530, Avg: 0.8623
Learning Rate: 0.003637

Epoch 136/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.07it/s, loss=0.1227]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.91it/s]


Train Loss: 0.1132
Val Dice - Class 1: 0.8711, Class 2: 0.8535, Avg: 0.8623
Learning Rate: 0.003586

Epoch 137/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.0880]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.01it/s]


Train Loss: 0.1133
Val Dice - Class 1: 0.8708, Class 2: 0.8530, Avg: 0.8619
Learning Rate: 0.003536

Epoch 138/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1264]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.99it/s]


Train Loss: 0.1123
Val Dice - Class 1: 0.8728, Class 2: 0.8561, Avg: 0.8644
Learning Rate: 0.003485

Epoch 139/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1183]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.94it/s]


Train Loss: 0.1114
Val Dice - Class 1: 0.8712, Class 2: 0.8533, Avg: 0.8623
Learning Rate: 0.003435

Epoch 140/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.0842]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.93it/s]


Train Loss: 0.1129
Val Dice - Class 1: 0.8706, Class 2: 0.8550, Avg: 0.8628
Learning Rate: 0.003384

Epoch 141/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.1013]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.72it/s]


Train Loss: 0.1127
Val Dice - Class 1: 0.8726, Class 2: 0.8552, Avg: 0.8639
Learning Rate: 0.003333

Epoch 142/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.0904]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.82it/s]


Train Loss: 0.1128
Val Dice - Class 1: 0.8714, Class 2: 0.8537, Avg: 0.8626
Learning Rate: 0.003282

Epoch 143/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.0913]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.92it/s]


Train Loss: 0.1107
Val Dice - Class 1: 0.8680, Class 2: 0.8535, Avg: 0.8607
Learning Rate: 0.003231

Epoch 144/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1234]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.82it/s]


Train Loss: 0.1137
Val Dice - Class 1: 0.8707, Class 2: 0.8527, Avg: 0.8617
Learning Rate: 0.003180

Epoch 145/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.13it/s, loss=0.1017]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.13it/s]


Train Loss: 0.1111
Val Dice - Class 1: 0.8704, Class 2: 0.8540, Avg: 0.8622
Learning Rate: 0.003129

Epoch 146/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.0895]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.17it/s]


Train Loss: 0.1105
Val Dice - Class 1: 0.8727, Class 2: 0.8555, Avg: 0.8641
Learning Rate: 0.003078

Epoch 147/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.1099]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.37it/s]


Train Loss: 0.1099
Val Dice - Class 1: 0.8722, Class 2: 0.8506, Avg: 0.8614
Learning Rate: 0.003026

Epoch 148/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.07it/s, loss=0.1258]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.87it/s]


Train Loss: 0.1115
Val Dice - Class 1: 0.8704, Class 2: 0.8548, Avg: 0.8626
Learning Rate: 0.002975

Epoch 149/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.1280]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.11it/s]


Train Loss: 0.1094
Val Dice - Class 1: 0.8707, Class 2: 0.8553, Avg: 0.8630
Learning Rate: 0.002923

Epoch 150/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1071]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.75it/s]


Train Loss: 0.1088
Val Dice - Class 1: 0.8697, Class 2: 0.8546, Avg: 0.8621
Learning Rate: 0.002872

Epoch 151/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1097]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.97it/s]


Train Loss: 0.1091
Val Dice - Class 1: 0.8675, Class 2: 0.8533, Avg: 0.8604
Learning Rate: 0.002820

Epoch 152/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.0896]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.84it/s]


Train Loss: 0.1082
Val Dice - Class 1: 0.8686, Class 2: 0.8533, Avg: 0.8609
Learning Rate: 0.002768

Epoch 153/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1056]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.52it/s]


Train Loss: 0.1092
Val Dice - Class 1: 0.8692, Class 2: 0.8549, Avg: 0.8621
Learning Rate: 0.002716

Epoch 154/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1184]
Validating: 100%|██████████| 26/26 [00:03<00:00,  6.84it/s]


Train Loss: 0.1091
Val Dice - Class 1: 0.8728, Class 2: 0.8549, Avg: 0.8639
Learning Rate: 0.002664

Epoch 155/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.06it/s, loss=0.0943]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.17it/s]


Train Loss: 0.1091
Val Dice - Class 1: 0.8738, Class 2: 0.8563, Avg: 0.8650
Learning Rate: 0.002612

Epoch 156/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.0910]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.96it/s]


Train Loss: 0.1069
Val Dice - Class 1: 0.8696, Class 2: 0.8549, Avg: 0.8623
Learning Rate: 0.002560

Epoch 157/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.0865]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.03it/s]


Train Loss: 0.1070
Val Dice - Class 1: 0.8721, Class 2: 0.8567, Avg: 0.8644
Learning Rate: 0.002507

Epoch 158/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1415]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.58it/s]


Train Loss: 0.1066
Val Dice - Class 1: 0.8732, Class 2: 0.8566, Avg: 0.8649
Learning Rate: 0.002455

Epoch 159/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.04it/s, loss=0.0990]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.79it/s]


Train Loss: 0.1076
Val Dice - Class 1: 0.8712, Class 2: 0.8559, Avg: 0.8635
Learning Rate: 0.002402

Epoch 160/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1069]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.32it/s]


Train Loss: 0.1081
Val Dice - Class 1: 0.8702, Class 2: 0.8555, Avg: 0.8629
Learning Rate: 0.002349

Epoch 161/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.04it/s, loss=0.0908]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.25it/s]


Train Loss: 0.1073
Val Dice - Class 1: 0.8717, Class 2: 0.8559, Avg: 0.8638
Learning Rate: 0.002296

Epoch 162/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1179]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.03it/s]


Train Loss: 0.1061
Val Dice - Class 1: 0.8733, Class 2: 0.8578, Avg: 0.8655
Learning Rate: 0.002243

Epoch 163/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1223]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.79it/s]


Train Loss: 0.1049
Val Dice - Class 1: 0.8735, Class 2: 0.8575, Avg: 0.8655
Learning Rate: 0.002190

Epoch 164/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.0941]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.90it/s]


Train Loss: 0.1034
Val Dice - Class 1: 0.8731, Class 2: 0.8555, Avg: 0.8643
Learning Rate: 0.002137

Epoch 165/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1207]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.63it/s]


Train Loss: 0.1037
Val Dice - Class 1: 0.8736, Class 2: 0.8571, Avg: 0.8654
Learning Rate: 0.002083

Epoch 166/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1165]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.00it/s]


Train Loss: 0.1045
Val Dice - Class 1: 0.8708, Class 2: 0.8531, Avg: 0.8620
Learning Rate: 0.002030

Epoch 167/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.13it/s, loss=0.1113]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.84it/s]


Train Loss: 0.1052
Val Dice - Class 1: 0.8738, Class 2: 0.8558, Avg: 0.8648
Learning Rate: 0.001976

Epoch 168/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1165]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.81it/s]


Train Loss: 0.1038
Val Dice - Class 1: 0.8723, Class 2: 0.8578, Avg: 0.8650
Learning Rate: 0.001922

Epoch 169/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.13it/s, loss=0.1266]
Validating: 100%|██████████| 26/26 [00:02<00:00,  9.97it/s]


Train Loss: 0.1035
Val Dice - Class 1: 0.8705, Class 2: 0.8541, Avg: 0.8623
Learning Rate: 0.001868

Epoch 170/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.13it/s, loss=0.0827]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.87it/s]


Train Loss: 0.1057
Val Dice - Class 1: 0.8739, Class 2: 0.8574, Avg: 0.8657
Learning Rate: 0.001813

Epoch 171/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.0935]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.78it/s]


Train Loss: 0.1020
Val Dice - Class 1: 0.8723, Class 2: 0.8568, Avg: 0.8645
Learning Rate: 0.001759

Epoch 172/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.07it/s, loss=0.1306]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.13it/s]


Train Loss: 0.1029
Val Dice - Class 1: 0.8710, Class 2: 0.8550, Avg: 0.8630
Learning Rate: 0.001704

Epoch 173/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.0914]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.11it/s]


Train Loss: 0.1000
Val Dice - Class 1: 0.8707, Class 2: 0.8543, Avg: 0.8625
Learning Rate: 0.001649

Epoch 174/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.1040]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.19it/s]


Train Loss: 0.1018
Val Dice - Class 1: 0.8738, Class 2: 0.8583, Avg: 0.8661
Learning Rate: 0.001594
✓ 保存最佳模型 (Dice: 0.8661)

Epoch 175/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1037]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.25it/s]


Train Loss: 0.1008
Val Dice - Class 1: 0.8720, Class 2: 0.8521, Avg: 0.8620
Learning Rate: 0.001539

Epoch 176/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.15it/s, loss=0.0996]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.82it/s]


Train Loss: 0.1013
Val Dice - Class 1: 0.8707, Class 2: 0.8517, Avg: 0.8612
Learning Rate: 0.001483

Epoch 177/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1324]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.05it/s]


Train Loss: 0.1009
Val Dice - Class 1: 0.8709, Class 2: 0.8573, Avg: 0.8641
Learning Rate: 0.001428

Epoch 178/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.09it/s, loss=0.1049]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.09it/s]


Train Loss: 0.1017
Val Dice - Class 1: 0.8729, Class 2: 0.8579, Avg: 0.8654
Learning Rate: 0.001372

Epoch 179/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.0880]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.12it/s]


Train Loss: 0.1006
Val Dice - Class 1: 0.8707, Class 2: 0.8550, Avg: 0.8628
Learning Rate: 0.001315

Epoch 180/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.0989]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.57it/s]


Train Loss: 0.1003
Val Dice - Class 1: 0.8708, Class 2: 0.8558, Avg: 0.8633
Learning Rate: 0.001259

Epoch 181/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.1182]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.81it/s]


Train Loss: 0.1011
Val Dice - Class 1: 0.8715, Class 2: 0.8555, Avg: 0.8635
Learning Rate: 0.001202

Epoch 182/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.0932]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.73it/s]


Train Loss: 0.1003
Val Dice - Class 1: 0.8758, Class 2: 0.8585, Avg: 0.8672
Learning Rate: 0.001145
✓ 保存最佳模型 (Dice: 0.8672)

Epoch 183/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.19it/s, loss=0.1180]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.15it/s]


Train Loss: 0.0998
Val Dice - Class 1: 0.8724, Class 2: 0.8568, Avg: 0.8646
Learning Rate: 0.001088

Epoch 184/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.1154]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.10it/s]


Train Loss: 0.0994
Val Dice - Class 1: 0.8741, Class 2: 0.8580, Avg: 0.8661
Learning Rate: 0.001030

Epoch 185/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.0849]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.39it/s]


Train Loss: 0.0982
Val Dice - Class 1: 0.8743, Class 2: 0.8579, Avg: 0.8661
Learning Rate: 0.000972

Epoch 186/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.08it/s, loss=0.1393]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.17it/s]


Train Loss: 0.0993
Val Dice - Class 1: 0.8757, Class 2: 0.8592, Avg: 0.8674
Learning Rate: 0.000913
✓ 保存最佳模型 (Dice: 0.8674)

Epoch 187/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.0844]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.70it/s]


Train Loss: 0.0972
Val Dice - Class 1: 0.8727, Class 2: 0.8574, Avg: 0.8651
Learning Rate: 0.000854

Epoch 188/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1131]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.00it/s]


Train Loss: 0.0991
Val Dice - Class 1: 0.8732, Class 2: 0.8571, Avg: 0.8651
Learning Rate: 0.000795

Epoch 189/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.14it/s, loss=0.0898]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.14it/s]


Train Loss: 0.0979
Val Dice - Class 1: 0.8726, Class 2: 0.8572, Avg: 0.8649
Learning Rate: 0.000735

Epoch 190/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1014]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.69it/s]


Train Loss: 0.0972
Val Dice - Class 1: 0.8735, Class 2: 0.8583, Avg: 0.8659
Learning Rate: 0.000675

Epoch 191/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.1026]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.83it/s]


Train Loss: 0.0971
Val Dice - Class 1: 0.8704, Class 2: 0.8546, Avg: 0.8625
Learning Rate: 0.000614

Epoch 192/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1096]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.10it/s]


Train Loss: 0.0964
Val Dice - Class 1: 0.8729, Class 2: 0.8572, Avg: 0.8650
Learning Rate: 0.000552

Epoch 193/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.10it/s, loss=0.0868]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.51it/s]


Train Loss: 0.0988
Val Dice - Class 1: 0.8732, Class 2: 0.8598, Avg: 0.8665
Learning Rate: 0.000489

Epoch 194/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.07it/s, loss=0.0918]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.90it/s]


Train Loss: 0.0958
Val Dice - Class 1: 0.8742, Class 2: 0.8563, Avg: 0.8653
Learning Rate: 0.000426

Epoch 195/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.13it/s, loss=0.0958]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.08it/s]


Train Loss: 0.0979
Val Dice - Class 1: 0.8733, Class 2: 0.8578, Avg: 0.8656
Learning Rate: 0.000362

Epoch 196/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.15it/s, loss=0.0712]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.82it/s]


Train Loss: 0.0970
Val Dice - Class 1: 0.8732, Class 2: 0.8570, Avg: 0.8651
Learning Rate: 0.000296

Epoch 197/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1040]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.94it/s]


Train Loss: 0.0963
Val Dice - Class 1: 0.8748, Class 2: 0.8591, Avg: 0.8670
Learning Rate: 0.000228

Epoch 198/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.13it/s, loss=0.0980]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.87it/s]


Train Loss: 0.0963
Val Dice - Class 1: 0.8745, Class 2: 0.8578, Avg: 0.8661
Learning Rate: 0.000158

Epoch 199/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.12it/s, loss=0.1221]
Validating: 100%|██████████| 26/26 [00:03<00:00,  8.14it/s]


Train Loss: 0.0959
Val Dice - Class 1: 0.8751, Class 2: 0.8594, Avg: 0.8673
Learning Rate: 0.000085

Epoch 200/200
------------------------------------------------------------


Training: 100%|██████████| 104/104 [00:12<00:00,  8.11it/s, loss=0.0976]
Validating: 100%|██████████| 26/26 [00:03<00:00,  7.35it/s]

Train Loss: 0.0958
Val Dice - Class 1: 0.8713, Class 2: 0.8548, Avg: 0.8630
Learning Rate: 0.000000

訓練完成！Seed: 42
最佳驗證 Dice: 0.8674



