# BraTS 2024 3D Segmentation - Kaggle Optimized
Acest notebook este configurat pentru a rula pe Kaggle folosind un GPU (T4 sau P100).
Include:
- **MONAI** pentru procesare 3D.
- **3D U-Net** cu patch-uri de 128x128x128.
- **Cosine Annealing LR** pentru optimizare fină.
- **Resume logic** pentru a continua antrenarea.

In [None]:
# Install MONAI, nibabel and tqdm into the Kaggle kernel (-q: quiet pip output).
!pip install -q monai nibabel tqdm

In [None]:
import os
import json
import torch
import numpy as np
from pathlib import Path
from monai.data import DataLoader, Dataset, decollate_batch
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Orientationd,
    NormalizeIntensityd, RandCropByPosNegLabeld, RandFlipd,
    RandAffined, EnsureTyped, Lambdad, AsDiscrete
)
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.utils import set_determinism
from torch.optim.lr_scheduler import CosineAnnealingLR

# --- KAGGLE CONFIGURATION ---
CONFIG = dict(
    # Path to a metadata JSON uploaded to Kaggle.
    # NOTE(review): nothing in this notebook section reads "json_path"; the
    # patient list is built by scanning the dataset folder instead.
    json_path="/kaggle/input/brats2024-metadata/brats_metadata_splits.json",
    model_dir="/kaggle/working/checkpoints",  # best_model.pth is written here
    # Patients per batch (raised to 4 for a 16 GB GPU). train_transforms draws
    # num_samples=2 patches per patient, so one step sees up to 8 patches.
    batch_size=4,
    spatial_size=(128, 128, 128),  # bigger patch -> more spatial context
    epochs=100,
    lr=2e-4,
    val_interval=5,  # validate every N epochs
    max_steps_per_epoch=300,  # cap on optimiser steps per epoch
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

# Ensure the checkpoint folder exists, then seed the RNGs via MONAI.
os.makedirs(CONFIG["model_dir"], exist_ok=True)
set_determinism(seed=42)


In [None]:
def build_brats_list_simple(base_path):
    """Scan a BraTS-style folder and build one MONAI data dict per patient.

    Expects one sub-directory per patient under ``base_path``, each holding
    the four MRI modalities plus a segmentation as ``*.nii.gz`` files. No JSON
    metadata is needed. The modality is recognised from the *file name* only:
    ``seg`` -> label, ``t1c``/``t1ce`` -> t1c, ``t1n``/``t1.nii`` -> t1n,
    ``t2f``/``flair`` -> t2f, ``t2w``/``t2.nii`` -> t2w.

    Args:
        base_path: Folder containing the per-patient sub-directories.

    Returns:
        A list of dicts ordered by patient folder name. Each dict has
        ``subject_id`` (folder name), one key per modality, ``label`` and
        ``image`` (the modality paths in t1c, t1n, t2f, t2w order). Patients
        missing any modality or the label are skipped.
    """
    import glob

    items = []
    for p_dir in sorted(glob.glob(os.path.join(base_path, "*"))):
        if not os.path.isdir(p_dir):
            continue

        sample = {"subject_id": os.path.basename(p_dir)}
        # Sorted so that, if two files match the same modality, the winner
        # (the last one assigned) is deterministic.
        for f in sorted(glob.glob(os.path.join(p_dir, "*.nii.gz"))):
            # Match on the file name only: matching on the full path would
            # misclassify every file whenever a parent folder name contains
            # e.g. "seg" (as in "brats-segmentation") or "flair".
            name = os.path.basename(f).lower()
            if "seg" in name:
                sample["label"] = f
            elif "t1c" in name:  # also covers the older "t1ce" naming
                sample["t1c"] = f
            elif "t1n" in name or "t1.nii" in name:
                sample["t1n"] = f
            elif "flair" in name or "t2f" in name:
                sample["t2f"] = f
            elif "t2w" in name or "t2.nii" in name:
                sample["t2w"] = f

        # Keep only patients with all four modalities and a label.
        if all(k in sample for k in ("t1c", "t1n", "t2f", "t2w", "label")):
            sample["image"] = [sample["t1c"], sample["t1n"], sample["t2f"], sample["t2w"]]
            items.append(sample)

    print(f"Succes! Am găsit {len(items)} pacienți gata de antrenare.")
    return items

def remap_seg(x):
    """Return a copy of a label map in which every value 4 is replaced by 3.

    Presumably this folds the legacy BraTS enhancing-tumour label (4) onto
    label 3 so labels fit the 4-class output used here (out_channels=4).
    NOTE(review): in some BraTS 2024 tasks label 4 is a distinct class (e.g.
    resection cavity); confirm that merging it into 3 is intended.

    The input is never modified. Torch tensors (including MONAI
    ``MetaTensor``) stay tensors, so their affine/metadata survive the rest of
    the pipeline. Any other array-like input is returned as a NumPy array.
    """
    if isinstance(x, torch.Tensor):
        # clone() keeps the tensor subclass (MetaTensor) and its metadata,
        # which np.asarray() would silently discard.
        out = x.clone()
    else:
        out = np.array(x, copy=True)
    out[out == 4] = 3
    return out

# --- Transforms ---
def _shared_preprocessing():
    """Build the deterministic steps that start both the train and val pipelines.

    Returns new transform instances on every call, so the two ``Compose``
    objects below never share transform objects. Steps: load image + label,
    move the channel axis first, reorient to RAS, remap label 4 -> 3 via
    ``remap_seg``, then z-score each image channel over its non-zero voxels.
    """
    both = ("image", "label")
    return [
        LoadImaged(keys=both),
        EnsureChannelFirstd(keys=both),
        Orientationd(keys=both, axcodes="RAS"),
        Lambdad(keys="label", func=remap_seg),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]


# Training: shared preprocessing, then random patch sampling and a flip.
train_transforms = Compose(
    _shared_preprocessing()
    + [
        # Two patches of CONFIG["spatial_size"] per volume. pos=1/neg=1 draws
        # centres equally often from foreground (label > 0) and from
        # background voxels where the image exceeds image_threshold.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=CONFIG["spatial_size"],
            pos=1,
            neg=1,
            num_samples=2,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        EnsureTyped(keys=["image", "label"]),
    ]
)

# Validation: whole volumes, no cropping and no augmentation.
val_transforms = Compose(_shared_preprocessing() + [EnsureTyped(keys=["image", "label"])])

def get_model():
    """Build the residual 3D U-Net used for training and validation.

    Four input channels (the four modality images listed under ``image``) and
    four output logits (background plus labels 1-3).
    """
    architecture = dict(
        spatial_dims=3,
        in_channels=4,
        out_channels=4,
        channels=(32, 64, 128, 256, 512),  # feature maps per resolution level
        strides=(2, 2, 2, 2),  # one stride-2 step between consecutive levels
        num_res_units=2,
        norm="instance",
        dropout=0.1,
    )
    return UNet(**architecture)


In [None]:
def _split_files(files, train_fraction=0.85, seed=42):
    """Shuffle ``files`` reproducibly and split them into ``(train, val)`` lists.

    The patient list arrives sorted by folder name, so slicing it directly
    would put a contiguous (possibly non-representative) block of patients
    into validation. A seeded shuffle keeps the split random yet repeatable.
    When ``files`` is non-empty at least one patient goes to training.
    """
    import random

    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)
    split_idx = int(len(shuffled) * train_fraction)
    if shuffled:
        split_idx = max(1, split_idx)
    return shuffled[:split_idx], shuffled[split_idx:]


def _train_one_epoch(model, loader, loss_function, optimizer, scaler, device, max_steps):
    """Run up to ``max_steps`` optimisation steps; return the mean batch loss.

    Mixed precision (autocast + GradScaler) is used when ``scaler`` is given.
    Returns 0.0 if ``loader`` yields no batches.
    """
    model.train()
    use_amp = scaler is not None
    epoch_loss = 0.0
    step = 0
    for batch_data in loader:
        step += 1
        inputs = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)
        optimizer.zero_grad()

        with torch.amp.autocast("cuda", enabled=use_amp):
            outputs = model(inputs)
            loss = loss_function(outputs, labels)

        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        epoch_loss += loss.item()
        if step >= max_steps:
            break
    return epoch_loss / step if step else 0.0


def _validate(model, loader, dice_metric, post_seg, post_label, device, roi_size):
    """Sliding-window inference over ``loader``; return the aggregated Dice.

    ``dice_metric`` is reset after aggregation so it can be reused.
    """
    model.eval()
    with torch.no_grad():
        for val_data in loader:
            val_inputs = val_data["image"].to(device)
            val_labels = val_data["label"].to(device)
            # 4 = number of windows evaluated per forward pass.
            val_outputs = sliding_window_inference(val_inputs, roi_size, 4, model)
            val_outputs = [post_seg(i) for i in decollate_batch(val_outputs)]
            val_labels = [post_label(i) for i in decollate_batch(val_labels)]
            dice_metric(y_pred=val_outputs, y=val_labels)
        metric_batch = dice_metric.aggregate()
        dice_metric.reset()
    return metric_batch


def _save_checkpoint(state, path):
    """Write ``state`` with torch.save via a temp file, so an interrupted
    save never leaves a truncated checkpoint at ``path``."""
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


def train():
    """Train the 3D U-Net on the Kaggle BraTS folder and keep the best weights.

    Builds the patient list, splits it reproducibly into train/val, and trains
    with Dice loss, AdamW and cosine LR (mixed precision on CUDA). Every
    ``CONFIG["val_interval"]`` epochs it validates and saves the weights with
    the best mean foreground Dice to ``<model_dir>/best_model.pth``.

    After every epoch the full training state (model, optimizer, scheduler,
    AMP scaler, epoch, best Dice) goes to ``<model_dir>/last_checkpoint.pth``.
    If that file exists and ``CONFIG.get("resume", True)`` is truthy, training
    resumes from it. Set ``CONFIG["resume"] = False`` to start fresh.

    Optional CONFIG keys (defaults): ``data_path`` (the Kaggle path below),
    ``train_fraction`` (0.85), ``seed`` (42), ``resume`` (True).
    """
    device = CONFIG["device"]

    # 1. Model, loss, optimiser, LR schedule, AMP scaler.
    model = get_model().to(device)
    loss_function = DiceLoss(smooth_nr=1e-5, smooth_dr=1e-5, to_onehot_y=True, softmax=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"])
    scheduler = CosineAnnealingLR(optimizer, T_max=CONFIG["epochs"])
    scaler = torch.amp.GradScaler("cuda") if torch.cuda.is_available() else None

    # 2. Data, scanned straight from the Kaggle input folder (no JSON).
    #    Change CONFIG["data_path"] if the dataset is attached under another name.
    base_data_path = CONFIG.get(
        "data_path", "/kaggle/input/brats2024-small-dataset/BraTS2024_small_dataset"
    )
    all_files = build_brats_list_simple(base_data_path)
    if not all_files:
        print("EROARE: Nu am găsit fișiere. Verifică dacă ai adăugat dataset-ul corect în Kaggle!")
        return

    train_files, val_files = _split_files(
        all_files,
        train_fraction=CONFIG.get("train_fraction", 0.85),
        seed=CONFIG.get("seed", 42),
    )
    print(f"Antrenare pe {len(train_files)} pacienți, Validare pe {len(val_files)} pacienți.")
    if not val_files:
        print("AVERTISMENT: Setul de validare este gol; validarea este omisă.")

    train_ds = Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(
        train_ds, batch_size=CONFIG["batch_size"], shuffle=True, num_workers=4, pin_memory=True
    )
    val_ds = Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=2)

    dice_metric = DiceMetric(include_background=False, reduction="mean_batch")
    post_seg = AsDiscrete(argmax=True, to_onehot=4)
    post_label = AsDiscrete(to_onehot=4)

    # 3. Optional resume from the last full-state checkpoint.
    os.makedirs(CONFIG["model_dir"], exist_ok=True)
    ckpt_path = os.path.join(CONFIG["model_dir"], "last_checkpoint.pth")
    best_metric = -1
    start_epoch = 0
    if CONFIG.get("resume", True) and os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        if scaler is not None and ckpt.get("scaler") is not None:
            scaler.load_state_dict(ckpt["scaler"])
        start_epoch = ckpt["epoch"]
        best_metric = ckpt["best_metric"]
        print(f"Reluare antrenare de la epoca {start_epoch + 1} (best Dice: {best_metric:.4f}).")

    # 4. Training loop.
    for epoch in range(start_epoch, CONFIG["epochs"]):
        mean_loss = _train_one_epoch(
            model, train_loader, loss_function, optimizer, scaler, device,
            CONFIG["max_steps_per_epoch"],
        )
        scheduler.step()
        print(f"Ep {epoch+1}/{CONFIG['epochs']} - Loss: {mean_loss:.4f} - LR: {scheduler.get_last_lr()[0]:.6f}")

        if val_files and (epoch + 1) % CONFIG["val_interval"] == 0:
            metric_batch = _validate(
                model, val_loader, dice_metric, post_seg, post_label, device,
                CONFIG["spatial_size"],
            )
            # nanmean: a class absent from every validation case gives NaN,
            # which would turn a plain mean into NaN and block all later saves.
            avg_dice = torch.nanmean(metric_batch).item()
            if avg_dice > best_metric:
                best_metric = avg_dice
                torch.save(model.state_dict(), os.path.join(CONFIG["model_dir"], "best_model.pth"))
                print(f"*** Model Nou Salvat! Dice: {avg_dice:.4f} ***")

        _save_checkpoint(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "scaler": scaler.state_dict() if scaler is not None else None,
                "epoch": epoch + 1,
                "best_metric": best_metric,
            },
            ckpt_path,
        )

# Entry point: start training when run as a script. In a Jupyter/Kaggle
# kernel ``__name__`` is also "__main__", so executing this cell trains.
if __name__ == "__main__":
    train()
