In [1]:
# Mount Google Drive so the paths under /content/drive used below (DATA_DIR, OUTPUT_FILE) are reachable.
from google.colab import drive
drive.mount('/content/drive')

import os
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
from scipy.ndimage import binary_dilation
import ipywidgets as widgets
from IPython.display import display


# 📁 Config
# Source folder (one .pt file per patient) and destination of the packed dataset.
DATA_DIR = "/content/drive/MyDrive/NeuroOnco/Derivate"
OUTPUT_FILE = "/content/drive/MyDrive/NeuroOnco/TEMP/patches_dataset.npz"

# Series (channels) kept from every patient volume, in this order.
SERIE_USATE = ["T1W", "FLAIR", "DWI", "KTRAN", "CBV", "APT"]

PATCH_SIZE = 5          # side of the cubic patch, in voxels
N_PATCH_TOTAL = 20000   # target total number of patches, split across selected patients
SEED = 42

# Seed the global numpy and torch RNGs (torch's drives the patch sampling).
np.random.seed(SEED)
torch.manual_seed(SEED)

# Scratch area: each patch/label is written here as its own .npy file
# before everything is packed into OUTPUT_FILE.
TMP_DIR = "/content/tmp_patches"
os.makedirs(f"{TMP_DIR}/patches", exist_ok=True)
os.makedirs(f"{TMP_DIR}/labels", exist_ok=True)

# NOTE: esegui_estrazione assigns its own local `patch_counter` (no `global`
# statement), so this module-level value is never updated by it.
patch_counter = 0


# 📋 Patient list: one entry per .pt file in DATA_DIR, keyed by the file stem.
pt_files_all = sorted(Path(DATA_DIR).glob("*.pt"))
# Path.stem strips only the trailing ".pt". str.replace(".pt", "") would also
# remove ".pt" occurring elsewhere in the name, e.g. "pz.pt_v2.pt" -> "pz_v2",
# and could make two files collide on the same key.
file_dict = {f.stem: f for f in pt_files_all}
# Derive the options from the same dict so every option is a valid key;
# dicts keep insertion order, so the sorted filename order is preserved.
tutti_pazienti = list(file_dict)

# 🔘 Selection widget: every patient is pre-selected.
multi_select = widgets.SelectMultiple(
    options=tutti_pazienti,
    value=tutti_pazienti,
    description='Pazienti:',
    layout=widgets.Layout(width='60%', height='200px')
)

run_button = widgets.Button(description="Esegui elaborazione", button_style='success')
# Log area the extraction callback writes into.
output = widgets.Output()

display(multi_select, run_button, output)


def _maschera_senza_nan(volume, mask, patch_size):
    """Restrict *mask* to centres whose extraction window contains no NaN.

    The window is the cube of side ``2 * (patch_size // 2) + 1`` that
    ``estrai_patch`` cuts around each centre, across all channels. Every
    centre kept here therefore yields a NaN-free patch. Filtering before
    sampling, instead of discarding NaN patches afterwards, lets a class quota
    be filled whenever the ROI has enough clean voxels.

    Args:
        volume: float tensor indexed as (c, z, y, x).
        mask: tensor whose non-zero entries mark candidate centres.
        patch_size: same value later passed to ``estrai_patch``.

    Returns:
        ``(mask_pulita, n_esclusi)``: a bool tensor, and how many voxels of
        *mask* were removed. If *mask* does not match the spatial shape of
        *volume*, it is returned unfiltered (as bool), which preserves the
        previous behaviour.
    """
    mask = mask.bool()
    nan_voxel = torch.isnan(volume).any(dim=0)
    if tuple(mask.shape) != tuple(nan_voxel.shape) or not bool(nan_voxel.any()):
        return mask, 0
    side = 2 * (patch_size // 2) + 1
    # A voxel is "near NaN" iff some NaN voxel lies inside its side^3 window.
    vicino_a_nan = binary_dilation(
        nan_voxel.numpy(), structure=np.ones((side, side, side), dtype=bool)
    )
    mask_pulita = mask & ~torch.from_numpy(vicino_a_nan)
    return mask_pulita, int(mask.sum()) - int(mask_pulita.sum())


def _salva_patch_validi(patches, quota, etichetta, contatore):
    """Write up to *quota* NaN-free patches and their label into TMP_DIR.

    Files are named ``patch_{i:05d}.npy`` / ``label_{i:05d}.npy`` with ``i``
    starting at *contatore*.

    Returns:
        ``(n_salvati, n_scartati_nan, contatore)``: patches written, patches
        skipped because they contain NaN, and the next free file index.
    """
    salvati = 0
    scartati_nan = 0
    for p in patches:
        if salvati >= quota:
            break
        if np.isnan(p).any():
            scartati_nan += 1
            continue
        np.save(f"{TMP_DIR}/patches/patch_{contatore:05d}.npy", p)
        np.save(f"{TMP_DIR}/labels/label_{contatore:05d}.npy", np.uint8(etichetta))
        contatore += 1
        salvati += 1
    return salvati, scartati_nan, contatore


def esegui_estrazione(_):
    """Button callback: build a class-balanced patch dataset from the selected patients.

    For each selected patient file:
      - load the volume with ``carica_volume``;
      - sample ``N_PER_PAZIENTE // 2`` MALATO patches (label 1) and the
        remainder as SANO patches (label 0) with ``estrai_patch``;
      - write each patch as a separate file under ``TMP_DIR``.

    Finally everything is packed into ``OUTPUT_FILE`` as an ``.npz`` with
    arrays ``patches`` and ``labels``. All messages go to the ``output`` widget.

    Args:
        _: object passed by the click handler (unused).
    """
    import shutil

    with output:
        output.clear_output()
        selezionati = multi_select.value
        if not selezionati:
            tqdm.write("⚠️ Nessun paziente selezionato.")
            return

        # Re-seed on every click. The module-level seed only covers the first
        # run, so later clicks would otherwise draw different patches.
        torch.manual_seed(SEED)
        np.random.seed(SEED)

        # 🧹 Start from an empty scratch directory.
        if os.path.exists(TMP_DIR):
            shutil.rmtree(TMP_DIR)
        os.makedirs(f"{TMP_DIR}/patches", exist_ok=True)
        os.makedirs(f"{TMP_DIR}/labels", exist_ok=True)

        # ↪️ Selected data (selezionati is non-empty, so NUM_PAZIENTI >= 1).
        pt_files = [file_dict[nome] for nome in selezionati]
        NUM_PAZIENTI = len(pt_files)
        N_PER_PAZIENTE = N_PATCH_TOTAL // NUM_PAZIENTI

        tqdm.write(f"🎯 Estrazione da {NUM_PAZIENTI} pazienti selezionati (~{N_PER_PAZIENTE} patch ciascuno)")
        patch_counter = 0  # next free file index == number of committed patches
        log_pazienti = {}

        nan_mdc = 0       # sampled MALATO patches discarded for NaN
        nan_sano = 0      # sampled SANO patches discarded for NaN
        esclusi_nan = 0   # ROI centres excluded up front (NaN in their window)

        for path in tqdm(pt_files, desc="📦 Estrazione patch bilanciate"):
            try:
                volume, roi_masks = carica_volume(path, SERIE_USATE)
                quota_malato = N_PER_PAZIENTE // 2
                quota_sano = N_PER_PAZIENTE - quota_malato

                mask_malato, escl_m = _maschera_senza_nan(volume, roi_masks["MALATO"], PATCH_SIZE)
                mask_sano, escl_s = _maschera_senza_nan(volume, roi_masks["SANO"], PATCH_SIZE)

                p_malato = estrai_patch(volume, mask_malato, PATCH_SIZE, quota_malato)
                p_sano = estrai_patch(volume, mask_sano, PATCH_SIZE, quota_sano)

                # Write with a local counter and commit it only once both
                # classes succeeded. If a patient fails mid-save, its files are
                # overwritten by the next patient or ignored by the final
                # index-based load, so the dataset matches log_pazienti.
                valid_malato, scart_m, contatore = _salva_patch_validi(
                    p_malato, quota_malato, 1, patch_counter)  # label MALATO = 1
                valid_sano, scart_s, contatore = _salva_patch_validi(
                    p_sano, quota_sano, 0, contatore)          # label SANO = 0
                patch_counter = contatore
                nan_mdc += scart_m
                nan_sano += scart_s
                esclusi_nan += escl_m + escl_s

                log_pazienti[path.name] = {"malato": valid_malato, "sano": valid_sano}
                tqdm.write(f"📂 {path.name} → MALATO={valid_malato}, SANO={valid_sano}, ACCUM={patch_counter}")

            except Exception as e:
                # Per-patient boundary: log and continue with the next patient.
                tqdm.write(f"❌ Errore su {path.name}: {e}")

        if patch_counter == 0:
            # np.stack([]) would raise; report instead of crashing.
            tqdm.write("\n❌ Nessuna patch valida estratta: nessun file salvato.")
            return

        # 🔁 Final save. Load by index 0..patch_counter-1 instead of globbing,
        # so patch/label pairing never depends on filename sort order or on
        # leftovers from a patient that failed mid-save.
        patches = np.stack([
            np.load(f"{TMP_DIR}/patches/patch_{i:05d}.npy") for i in range(patch_counter)
        ])
        labels = np.array(
            [np.load(f"{TMP_DIR}/labels/label_{i:05d}.npy") for i in range(patch_counter)],
            dtype=np.uint8,
        )

        np.savez(OUTPUT_FILE, patches=patches, labels=labels)

        tqdm.write(f"\n🧹 Patch scartate per NaN:")
        tqdm.write(f"   MDC  : {nan_mdc}")
        tqdm.write(f"   SANO : {nan_sano}")
        tqdm.write(f"   TOT  : {nan_mdc + nan_sano}")
        tqdm.write(f"   Centri ROI esclusi (NaN nell'intorno): {esclusi_nan}")

        tqdm.write(f"\n✅ Salvate {len(labels)} patch in: {OUTPUT_FILE}")
        tqdm.write(f"🧠 Patch shape: {patches.shape}")

        # 📋 Per-patient summary.
        tqdm.write("\n📋 Riepilogo pazienti:")
        for k, v in log_pazienti.items():
            tqdm.write(f"  - {k}: MALATO={v['malato']}, SANO={v['sano']}, TOTAL={v['malato']+v['sano']}")





# ✅ Function: load a patient volume and check that the 'MALATO'/'SANO' ROIs exist
def carica_volume(path_pt, serie_usate=None):
    """Load a patient ``.pt`` file and return the selected series and its ROI masks.

    The file is expected to hold a dict with ``"volume"`` (tensor, series on
    axis 0), an optional ``"nomi_serie"`` list naming those series, and
    ``"roi_masks"`` (a mapping that must contain ``"MALATO"`` and ``"SANO"``).

    Args:
        path_pt: path to the file (``str`` or ``Path``).
        serie_usate: series names to keep, in this order. If falsy, all
            series are kept.

    Returns:
        ``(volume, roi_masks)``: the volume as float32 (restricted to and
        ordered by *serie_usate*), and the ROI mapping as stored in the file.

    Raises:
        ValueError: if a requested series is missing, or if the
            ``MALATO``/``SANO`` ROIs (or the whole ``roi_masks`` entry) are absent.
    """
    # Path(...) makes the error messages work for plain-string paths too.
    nome_file = Path(path_pt).name
    data = torch.load(path_pt, map_location=torch.device('cpu'))
    volume = data["volume"].float()
    # Without stored names, fall back to positional placeholders.
    nomi = data.get("nomi_serie", [f"Serie_{i}" for i in range(volume.shape[0])])
    # A missing/None entry yields the clear ValueError below instead of a bare KeyError.
    roi_masks = data.get("roi_masks") or {}

    if serie_usate:
        serie_mancanti = [s for s in serie_usate if s not in nomi]
        if serie_mancanti:
            raise ValueError(f"⚠️ Serie mancanti in {nome_file}: {serie_mancanti}")
        volume = volume[[nomi.index(s) for s in serie_usate]]

    # ✅ Both ROIs are needed for balanced sampling.
    if "MALATO" not in roi_masks or "SANO" not in roi_masks:
        raise ValueError(f"❌ ROI 'MALATO' o 'SANO' mancanti in {nome_file}")

    return volume, roi_masks


# ✅ Funzione: estrai patch da una ROI
def estrai_patch(volume, mask, patch_size, n_max, generator=None):
    """Extract up to *n_max* cubic patches centred on voxels of *mask*.

    Args:
        volume: tensor indexed as (c, z, y, x).
        mask: tensor whose non-zero entries mark candidate centres.
            ``torch.argwhere`` must yield (z, y, x) triples; presumably it has
            the spatial shape of *volume* (TODO confirm with the data files).
        patch_size: nominal patch side. The cut window is
            ``2 * (patch_size // 2) + 1`` voxels per axis, so an even size
            yields ``patch_size + 1``.
        n_max: maximum number of patches. Values <= 0 return an empty list.
        generator: optional ``torch.Generator`` for the random draw. Defaults
            to the global torch RNG.

    Returns:
        A list of independent numpy arrays of shape ``(c, side, side, side)``,
        in random order. Centres closer than ``patch_size // 2`` voxels to any
        border are skipped.
    """
    _, z, y, x = volume.shape
    half = patch_size // 2
    coords = torch.argwhere(mask)

    # Drop centres whose window would fall outside the volume.
    dentro = (
        (coords[:, 0] >= half) & (coords[:, 0] < z - half)
        & (coords[:, 1] >= half) & (coords[:, 1] < y - half)
        & (coords[:, 2] >= half) & (coords[:, 2] < x - half)
    )
    coords = coords[dentro]

    # Clamp so a negative n_max cannot turn into a "[:-k]" slice that
    # returns almost every centre.
    n_max = max(int(n_max), 0)
    # Always shuffle, not only when truncating. Callers may consume just a
    # prefix of the result, and raster (argwhere) order would bias that prefix
    # toward the first slices.
    coords = coords[torch.randperm(len(coords), generator=generator)[:n_max]]

    patches = []
    for zc, yc, xc in coords.tolist():
        patch = volume[
            :,
            zc - half: zc + half + 1,
            yc - half: yc + half + 1,
            xc - half: xc + half + 1,
        ]
        # .copy() so each patch owns its memory instead of being a strided
        # view that keeps the whole volume's storage alive.
        patches.append(patch.numpy().copy())

    return patches


# ✅ Register the extraction routine as the button's click handler (it runs on every click).
run_button.on_click(esegui_estrazione, remove=False)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


SelectMultiple(description='Pazienti:', index=(0, 1, 2, 3, 4, 5), layout=Layout(height='200px', width='60%'), …

Button(button_style='success', description='Esegui elaborazione', style=ButtonStyle())

Output()