In [14]:
# ========================================
# Alzheimer Early Detection - OASIS-2
# Notebook: Imágenes MRI con Deep Learning
# ========================================

# 1. Importación de librerías
import os
import glob
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim

In [15]:
# ========================================
# 2. Configuración de paths
# ========================================

# Carpeta donde tienes las imágenes descargadas/descomprimidas
DATA_DIR = "DATA/OAS2_RAW/"   # ruta a las imágenes OASIS-2
OUTPUT_DIR = "DATA/processed_images/"  # donde guardaremos PNGs

os.makedirs(OUTPUT_DIR, exist_ok=True)

In [16]:
# ========================================
# 3. Exploración inicial de archivos
# ========================================

# Vamos a buscar imágenes en formato .hdr/.img dentro de RAW
raw_files = glob.glob(os.path.join(DATA_DIR, "**", "RAW", "*.hdr"), recursive=True)
print(f"Total archivos RAW encontrados: {len(raw_files)}")
print("Ejemplo:", raw_files[0])

Total archivos RAW encontrados: 1368
Ejemplo: DATA/OAS2_RAW\OAS2_0001_MR1\RAW\mpr-1.nifti.hdr


In [65]:
# ========================================
# 4. Pre-processing and Conversion
# ========================================
import pandas as pd
import cv2

# --- Cargar y limpiar el dataframe de etiquetas ---
labels_df = pd.read_excel("oasis_longitudinal_demographics-8d83e569fa2e2d30.xlsx")

# Renombrar columnas
rename_map = {
    'Subject ID': 'subject_id', 'MRI ID': 'scan_id', 'Group': 'group',
    'Visit': 'visit', 'MR Delay': 'mr_delay', 'M/F': 'sex', 'Hand': 'hand',
    'Age': 'age', 'EDUC': 'educ', 'SES': 'ses', 'MMSE': 'mmse', 'CDR': 'cdr',
    'eTIV': 'etiv', 'nWBV': 'nwbv', 'ASF': 'asf'
}
labels_df = labels_df.rename(columns=rename_map)

# Crear la etiqueta binaria (0=Nondemented, 1=Demented/Converted)
labels_df['label'] = labels_df['group'].map({'Nondemented': 0, 'Demented': 1, 'Converted': 1})

# Eliminar filas donde la etiqueta es desconocida (si las hubiera)
labels_df.dropna(subset=['label'], inplace=True)
labels_df['label'] = labels_df['label'].astype(int)

# *** CLAVE: Crear un diccionario para búsqueda rápida de etiquetas ***
# Esto resuelve el problema de la coincidencia de IDs.
label_map = labels_df.set_index('scan_id')['label'].to_dict()
print(f"Creado mapa de etiquetas para {len(label_map)} scans.")

# --- Función para guardar slices con CLAHE + z-score ---
def save_slices_from_nifti(img_path, scan_id, max_slices=20):
    img = nib.load(img_path)
    data = img.get_fdata()

    # Normalizar a [0,255]
    if data.max() > 0:
        data = (data - data.min()) / (data.max() - data.min()) * 255
    data = data.astype(np.uint8)

    # Seleccionar 20 cortes axiales alrededor del centro
    mid_slice = data.shape[2] // 2
    start_slice = mid_slice - max_slices // 2
    end_slice = start_slice + max_slices
    
    saved_paths = []
    slice_indices = range(start_slice, end_slice)

    for i, slice_idx in enumerate(slice_indices):
        if 0 <= slice_idx < data.shape[2]:
            slice_img = data[:, :, slice_idx]

            # --- Aplicar CLAHE y z-score normalización ---
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
            slice_img = clahe.apply(slice_img)

            # Normalización z-score
            mean, std = slice_img.mean(), slice_img.std()
            if std > 0:
                slice_img = ((slice_img - mean) / std) * 64 + 128
            slice_img = np.clip(slice_img, 0, 255).astype(np.uint8)

            out_path = os.path.join(OUTPUT_DIR, f"{scan_id}_slice{i}.png")
            cv2.imwrite(out_path, slice_img)
            saved_paths.append(out_path)

    return saved_paths

# --- *** BUCLE CORREGIDO: Procesar TODAS las imágenes *** ---
print("Iniciando conversión de imágenes .hdr a .png...")
# Importa la librería Path
from pathlib import Path

raw_files = glob.glob(os.path.join(DATA_DIR, "**", "*.hdr"), recursive=True)

for file_path in tqdm(raw_files, desc="Procesando archivos"):
    try:
        # --- LÍNEA CORREGIDA ---
        # Esta es la forma robusta de obtener el ID independientemente del sistema operativo.
        # Extrae el nombre de la carpeta que está dos niveles por encima del archivo.
        # Ejemplo: para '...\OAS2_0001_MR1\RAW\mpr-1.hdr', el resultado es 'OAS2_0001_MR1'
        p = Path(file_path)
        scan_id = p.parent.parent.name

        if scan_id in label_map: # Solo procesar imágenes que tienen etiqueta
            save_slices_from_nifti(file_path, scan_id)

    except IndexError:
        # Este error ya no debería ocurrir, pero es bueno mantenerlo por si acaso
        print(f"No se pudo extraer scan_id de {file_path}")

print("Conversión completada.")

Creado mapa de etiquetas para 373 scans.
Iniciando conversión de imágenes .hdr a .png...


Procesando archivos: 100%|██████████| 1368/1368 [07:09<00:00,  3.19it/s]

Conversión completada.





In [66]:
# ========================================
# 5. División Train/Val/Test por PACIENTE
# ========================================
from sklearn.model_selection import train_test_split

# Lista de pacientes (scan_ids) y sus etiquetas
scan_ids = list(label_map.keys())
scan_labels = [label_map[s] for s in scan_ids]

# Primero separamos Train (60%) y Temp (40%)
train_ids, temp_ids, train_labels, temp_labels = train_test_split(
    scan_ids, scan_labels,
    test_size=0.4, stratify=scan_labels, random_state=42
)

# Luego dividimos Temp en Val (20%) y Test (20%)
val_ids, test_ids = train_test_split(
    temp_ids, test_size=0.5, stratify=temp_labels, random_state=42
)

print(f"Pacientes en Train: {len(train_ids)} | Val: {len(val_ids)} | Test: {len(test_ids)}")

# ----------------------------------------
# Dataset adaptado para aceptar subconjunto de pacientes
# ----------------------------------------
class MRIDataset(Dataset):
    def __init__(self, img_dir, label_map, scan_ids, transform=None, return_path=False):
        self.img_dir = img_dir
        self.label_map = label_map
        self.scan_ids = set(scan_ids)   # filtramos por IDs permitidos
        self.transform = transform
        self.return_path = return_path

        # Solo imágenes cuyo scan_id está en el subset (train/val/test)
        all_paths = glob.glob(os.path.join(img_dir, "*.png"))
        self.img_paths = [p for p in all_paths if self.get_scan_id_from_path(p) in self.scan_ids]
        
    def __len__(self):
        return len(self.img_paths)
    
    def get_scan_id_from_path(self, img_path):
        basename = os.path.basename(img_path)
        return basename.split("_slice")[0]
        
    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        scan_id = self.get_scan_id_from_path(img_path)

        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) 
        
        if self.transform:
            img = self.transform(img)
        
        label = self.label_map[scan_id]
        
        if self.return_path:
            return img, label, img_path
        else:
            return img, label

# ----------------------------------------
# Transformaciones
# ----------------------------------------
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

val_test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# ----------------------------------------
# Datasets con transform diferentes
# ----------------------------------------
train_dataset = MRIDataset(OUTPUT_DIR, label_map, train_ids, transform=train_transform)
val_dataset   = MRIDataset(OUTPUT_DIR, label_map, val_ids, transform=val_test_transform)
test_dataset  = MRIDataset(OUTPUT_DIR, label_map, test_ids, transform=val_test_transform, return_path=True)

# Loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=16, shuffle=False)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")



Pacientes en Train: 223 | Val: 75 | Test: 75
Train: 4460, Val: 1500, Test: 1500


In [67]:
# ========================================
# 6. Modelo Preentrenado (ResNet50)
# ========================================

from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Modelo preentrenado
model = models.resnet50(weights="IMAGENET1K_V2")
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)  
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [68]:
# ========================================
# 7. Early Stopping
# ========================================

class EarlyStopping:
    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = np.inf
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

In [69]:
# ========================================
# 8. Entrenamiento del modelo
# ========================================

EPOCHS = 20
early_stopping = EarlyStopping(patience=4)

for epoch in range(EPOCHS):
    # ---- Entrenamiento ----
    model.train()
    running_loss, correct = 0, 0
    for imgs, labels_ in train_loader:
        imgs, labels_ = imgs.to(device), labels_.to(device)
        
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels_)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        preds = outputs.argmax(1)
        correct += (preds == labels_).sum().item()
    
    train_acc = correct / len(train_dataset)
    train_loss = running_loss / len(train_loader)

    # ---- Validación ----
    model.eval()
    val_loss, val_correct = 0, 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            preds = outputs.argmax(1)
            val_correct += (preds == labels).sum().item()
    
    val_acc = val_correct / len(val_dataset)
    val_loss /= len(val_loader)

    print(f"Epoch {epoch+1}/{EPOCHS} - "
          f"Train Loss: {train_loss:.4f} - Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} - Val Acc: {val_acc:.4f}")

    # ---- Early Stopping ----
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("⏹️ Early stopping activado")
        break

Epoch 1/20 - Train Loss: 0.5279 - Train Acc: 0.7186 | Val Loss: 0.8643 - Val Acc: 0.6187
Epoch 2/20 - Train Loss: 0.2585 - Train Acc: 0.8996 | Val Loss: 0.8215 - Val Acc: 0.6613
Epoch 3/20 - Train Loss: 0.1182 - Train Acc: 0.9558 | Val Loss: 1.0935 - Val Acc: 0.6527
Epoch 4/20 - Train Loss: 0.0783 - Train Acc: 0.9711 | Val Loss: 1.0667 - Val Acc: 0.7013
Epoch 5/20 - Train Loss: 0.0533 - Train Acc: 0.9825 | Val Loss: 1.6310 - Val Acc: 0.6600
Epoch 6/20 - Train Loss: 0.0773 - Train Acc: 0.9724 | Val Loss: 0.9106 - Val Acc: 0.7113
⏹️ Early stopping activado


In [70]:
# ========================================
# 9. Evaluación del modelo (por paciente)
# ========================================

from collections import defaultdict
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

model.eval()
all_probs = []
all_labels = []
all_ids = []

with torch.no_grad():
    # El bucle ahora es más limpio, desempaquetando las rutas directamente
    for imgs, labels, paths in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        probs = torch.softmax(outputs, dim=1)[:,1].cpu().numpy()
        
        # Ya no necesitas el cálculo manual de batch_start y batch_end
        for p, prob, label in zip(paths, probs, labels):
            scan_id = os.path.basename(p).split("_slice")[0]
            all_ids.append(scan_id)
            all_probs.append(prob)
            all_labels.append(label.item())

# --- Agrupar por paciente ---
patient_probs = defaultdict(list)
patient_labels = {}

for pid, prob, label in zip(all_ids, all_probs, all_labels):
    patient_probs[pid].append(prob)
    patient_labels[pid] = label

final_probs = {pid: np.mean(probs) for pid, probs in patient_probs.items()}
final_preds = {pid: int(prob > 0.5) for pid, prob in final_probs.items()}

y_true = [patient_labels[pid] for pid in final_preds.keys()]
y_pred = [final_preds[pid] for pid in final_preds.keys()]
y_score = [final_probs[pid] for pid in final_preds.keys()]

print("\n=== Evaluación en TEST (nivel paciente) ===")
print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred))
print("ROC-AUC:", roc_auc_score(y_true, y_score))



=== Evaluación en TEST (nivel paciente) ===
[[33  5]
 [10 27]]
              precision    recall  f1-score   support

           0       0.77      0.87      0.81        38
           1       0.84      0.73      0.78        37

    accuracy                           0.80        75
   macro avg       0.81      0.80      0.80        75
weighted avg       0.81      0.80      0.80        75

ROC-AUC: 0.8577524893314367
