# MedAug

In [None]:
from medaug import *

# === Basic settings ===
# Archive that is extracted into ./all when it exists.
ZIP_FILE = "all.zip"
# Root folder under which one "fold_<i>" dataset directory is created.
OUTPUT_ROOT = "./datasets_kfold_medaug_original"
# Number of cross-validation folds (datasets built and models trained).
FOLDS = 7
# Passed to model.train() as `epochs` and `batch`.
EPOCHS = 25
BATCH = 12
# Weights file handed to YOLO() at the start of every fold.
MODEL_PATH = "yolo12x.pt"
# Written to each dataset YAML as the name of class 0.
CLASS_NAME = "aortic_valve"

# === MedAugment parameters (settings from the original paper) ===
# When False, training images/labels are copied unchanged instead of augmented.
USE_MEDAUGMENT = True
MEDAUG_LEVEL = 5         # default value in the original paper
NUMBER_BRANCH = 4        # default value in the original paper
# Forwarded to augment_image_medaugment() as `save_original`.
SAVE_ORIGINAL = True

# === Extract all.zip ===
# The archive is optional: nothing is extracted when it is absent.
if os.path.exists(ZIP_FILE):
    print(f"  解壓 {ZIP_FILE} 到 ./all")
    with zipfile.ZipFile(ZIP_FILE, "r") as z:
        z.extractall("./all")
# NOTE(review): the archive is unpacked into ./all, but images and labels are
# read from ./datasets/... below - confirm that those folders are populated
# by the archive layout or by a separate step.

# === Image and label locations ===
IMAGE_ROOT = "./datasets/all_images"
LABEL_ROOT = "./datasets/all_labels"

# Fixed seed for the fold shuffle so that re-running this cell reproduces the
# same train/validation split instead of a new random one each time.
SPLIT_SEED = 42

# === Scan image files ===
# All .jpg files (sorted) followed by all .png files (sorted).
image_files = sorted(glob(os.path.join(IMAGE_ROOT, "*.jpg"))) + sorted(glob(os.path.join(IMAGE_ROOT, "*.png")))
# Expected label path for each image: same base name with a .txt extension.
label_files = [
    os.path.join(LABEL_ROOT, os.path.splitext(os.path.basename(f))[0] + ".txt")
    for f in image_files
]

# Drop images that have no matching label file.
paired = [(img, lbl) for img, lbl in zip(image_files, label_files) if os.path.exists(lbl)]

# Fail early with a clear message: with fewer samples than folds, at least one
# fold would have an empty validation set (and an empty dataset would only
# surface later as an unrelated error).
if len(paired) < FOLDS:
    raise ValueError(
        f"Found {len(paired)} image/label pairs under {IMAGE_ROOT!r} and "
        f"{LABEL_ROOT!r}; at least {FOLDS} are required for {FOLDS}-fold "
        "cross-validation."
    )

# A dedicated Random instance keeps the split deterministic without touching
# the global random state.
random.Random(SPLIT_SEED).shuffle(paired)
fold_size = len(paired) // FOLDS
print(f"  總共 {len(paired)} 張影像，平均每折 {fold_size} 張")
print("  使用原始 MedAugment 論文參數 (Document 3)")
print("   - keep_ratio=True (所有 Affine 變換)")
print("   - Shear/Translate: 單向 (0, X)")
print(f"   - Level: {MEDAUG_LEVEL}, Branches: {NUMBER_BRANCH}")

# Build one dataset directory per fold: the training split (optionally
# augmented), the untouched validation split, and a dataset YAML.
#
# NOTE(review): the fold directories are reused (exist_ok=True) and nothing is
# deleted, so files left by an earlier run with a different split stay in
# place and can mix training and validation data. Clear OUTPUT_ROOT before
# re-running with a changed split.
num_samples = len(paired)

for i in range(FOLDS):
    print(f"\n  建立 Fold {i}...")
    fold_dir = os.path.join(OUTPUT_ROOT, f"fold_{i}")
    train_img = os.path.join(fold_dir, "train/images")
    train_lbl = os.path.join(fold_dir, "train/labels")
    val_img = os.path.join(fold_dir, "val/images")
    val_lbl = os.path.join(fold_dir, "val/labels")

    for p in [train_img, train_lbl, val_img, val_lbl]:
        os.makedirs(p, exist_ok=True)

    # Fold boundaries are derived from the sample count rather than from a
    # floor-divided fold size, so the remainder is spread over the folds and
    # every sample lands in exactly one validation set.
    val_start = i * num_samples // FOLDS
    val_end = (i + 1) * num_samples // FOLDS
    val_set = paired[val_start:val_end]
    # Everything outside the validation slice; slicing avoids a quadratic
    # "not in list" scan over the whole dataset.
    train_set = paired[:val_start] + paired[val_end:]

    # Copy (and optionally augment) the training data.
    print(f"    處理訓練資料 ({len(train_set)} 張)...")
    for done, (img, lbl) in enumerate(train_set, start=1):
        if USE_MEDAUGMENT:
            augment_image_medaugment(
                img, lbl, train_img, train_lbl,
                level=MEDAUG_LEVEL,
                number_branch=NUMBER_BRANCH,
                save_original=SAVE_ORIGINAL,
            )
        else:
            shutil.copy(img, train_img)
            shutil.copy(lbl, train_lbl)

        # Progress report every 50 samples.
        if done % 50 == 0:
            print(f"    處理進度: {done}/{len(train_set)}")

    # Copy the validation data without augmentation.
    print(f"    複製驗證資料 ({len(val_set)} 張)...")
    for img, lbl in val_set:
        shutil.copy(img, val_img)
        shutil.copy(lbl, val_lbl)

    # Count the image files that are now on disk for this fold.
    train_count = len(glob(os.path.join(train_img, "*.jpg"))) + len(glob(os.path.join(train_img, "*.png")))
    val_count = len(glob(os.path.join(val_img, "*.jpg"))) + len(glob(os.path.join(val_img, "*.png")))
    print(f"    訓練集: {train_count} 張 | 驗證集: {val_count} 張")
    # The ratio is undefined for an empty training split; skip it rather than
    # dividing by zero.
    if USE_MEDAUGMENT and train_set:
        print(f"     增強倍率: {train_count / len(train_set):.1f}x")

    # Write the dataset YAML with absolute image paths and the single class.
    yaml_path = os.path.join(fold_dir, f"dataset_fold_{i}.yaml")
    abs_train_img = os.path.abspath(train_img)
    abs_val_img = os.path.abspath(val_img)
    with open(yaml_path, "w", encoding="utf-8") as f:
        f.write(f"train: {abs_train_img}\n")
        f.write(f"val: {abs_val_img}\n\n")
        f.write("names:\n")
        f.write(f"  0: {CLASS_NAME}\n")
    print(f"    YAML 已建立: {yaml_path}")

# Model Training

In [None]:
# Imported once, ahead of the loop, instead of on every iteration.
import gc

import torch

# Train one model per fold, each starting from MODEL_PATH and reading the
# dataset YAML written for that fold under OUTPUT_ROOT.
for i in range(FOLDS):
    yaml_path = os.path.join(OUTPUT_ROOT, f"fold_{i}", f"dataset_fold_{i}.yaml")
    model = YOLO(MODEL_PATH)
    try:
        model.train(
            data=yaml_path,
            epochs=EPOCHS,
            batch=BATCH,
            lr0=0.001428,
            amp=False,
            val=True,
            patience=4,
            optimizer="AdamW",
            save_period=1,
            name=f"yolo12x_medaug_original_fold{i}",
            exist_ok=True,
            # NOTE(review): Ultralytics documents `augment` as a
            # prediction-time (test-time augmentation) option; training-time
            # augmentation is governed by settings such as mosaic / hsv_* /
            # fliplr. Confirm this flag has the intended effect here.
            augment=False,
        )
    finally:
        # Drop the model and release cached GPU memory before the next fold,
        # including when training raised, so a failed run does not keep the
        # memory held.
        del model
        gc.collect()
        torch.cuda.empty_cache()