In [14]:
# ========================================
# Alzheimer Early Detection - OASIS-2
# Notebook: Imágenes MRI con Deep Learning
# ========================================

# 1. Importación de librerías
import os
import glob
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim

In [15]:
# ========================================
# 2. Configuración de paths
# ========================================

# Carpeta donde tienes las imágenes descargadas/descomprimidas
DATA_DIR = "DATA/OAS2_RAW/"   # ruta a las imágenes OASIS-2
OUTPUT_DIR = "DATA/processed_images/"  # donde guardaremos PNGs

os.makedirs(OUTPUT_DIR, exist_ok=True)

In [16]:
# ========================================
# 3. Exploración inicial de archivos
# ========================================

# Vamos a buscar imágenes en formato .hdr/.img dentro de RAW
raw_files = glob.glob(os.path.join(DATA_DIR, "**", "RAW", "*.hdr"), recursive=True)
print(f"Total archivos RAW encontrados: {len(raw_files)}")
print("Ejemplo:", raw_files[0])

Total archivos RAW encontrados: 1368
Ejemplo: DATA/OAS2_RAW\OAS2_0001_MR1\RAW\mpr-1.nifti.hdr


In [18]:
# ========================================
# 4. Pre-processing and Conversion
# ========================================
import pandas as pd
import cv2

# --- Cargar y limpiar el dataframe de etiquetas ---
labels_df = pd.read_excel("oasis_longitudinal_demographics-8d83e569fa2e2d30.xlsx")

# Renombrar columnas
rename_map = {
    'Subject ID': 'subject_id', 'MRI ID': 'scan_id', 'Group': 'group',
    'Visit': 'visit', 'MR Delay': 'mr_delay', 'M/F': 'sex', 'Hand': 'hand',
    'Age': 'age', 'EDUC': 'educ', 'SES': 'ses', 'MMSE': 'mmse', 'CDR': 'cdr',
    'eTIV': 'etiv', 'nWBV': 'nwbv', 'ASF': 'asf'
}
labels_df = labels_df.rename(columns=rename_map)

# Crear la etiqueta binaria (0=Nondemented, 1=Demented/Converted)
labels_df['label'] = labels_df['group'].map({'Nondemented': 0, 'Demented': 1, 'Converted': 1})

# Eliminar filas donde la etiqueta es desconocida (si las hubiera)
labels_df.dropna(subset=['label'], inplace=True)
labels_df['label'] = labels_df['label'].astype(int)

# *** CLAVE: Crear un diccionario para búsqueda rápida de etiquetas ***
# Esto resuelve el problema de la coincidencia de IDs.
label_map = labels_df.set_index('scan_id')['label'].to_dict()
print(f"Creado mapa de etiquetas para {len(label_map)} scans.")

# --- Función para guardar slices (sin cambios) ---
def save_slices_from_nifti(img_path, scan_id, max_slices=5):
    img = nib.load(img_path)
    data = img.get_fdata()
    # Normalizar para guardar como imagen de 8 bits
    if data.max() > 0:
        data = (data - data.min()) / (data.max() - data.min()) * 255
    data = data.astype(np.uint8)

    # Seleccionar 5 cortes axiales centrales
    mid_slice = data.shape[2] // 2
    start_slice = mid_slice - max_slices // 2
    end_slice = start_slice + max_slices
    
    saved_paths = []
    slice_indices = range(start_slice, end_slice)

    for i, slice_idx in enumerate(slice_indices):
        if 0 <= slice_idx < data.shape[2]:
            slice_img = data[:, :, slice_idx]
            out_path = os.path.join(OUTPUT_DIR, f"{scan_id}_slice{i}.png")
            cv2.imwrite(out_path, slice_img)
            saved_paths.append(out_path)
    return saved_paths

# --- *** BUCLE CORREGIDO: Procesar TODAS las imágenes *** ---
print("Iniciando conversión de imágenes .hdr a .png...")
# Importa la librería Path
from pathlib import Path

raw_files = glob.glob(os.path.join(DATA_DIR, "**", "*.hdr"), recursive=True)

for file_path in tqdm(raw_files, desc="Procesando archivos"):
    try:
        # --- LÍNEA CORREGIDA ---
        # Esta es la forma robusta de obtener el ID independientemente del sistema operativo.
        # Extrae el nombre de la carpeta que está dos niveles por encima del archivo.
        # Ejemplo: para '...\OAS2_0001_MR1\RAW\mpr-1.hdr', el resultado es 'OAS2_0001_MR1'
        p = Path(file_path)
        scan_id = p.parent.parent.name

        if scan_id in label_map: # Solo procesar imágenes que tienen etiqueta
            save_slices_from_nifti(file_path, scan_id)

    except IndexError:
        # Este error ya no debería ocurrir, pero es bueno mantenerlo por si acaso
        print(f"No se pudo extraer scan_id de {file_path}")

print("Conversión completada.")

Creado mapa de etiquetas para 373 scans.
Iniciando conversión de imágenes .hdr a .png...


Procesando archivos: 100%|██████████| 1368/1368 [05:44<00:00,  3.97it/s]

Conversión completada.





In [None]:
# ========================================
# 5. División Train/Test por PACIENTE
# ========================================
from sklearn.model_selection import train_test_split

# Lista de pacientes (scan_ids) y sus etiquetas
scan_ids = list(label_map.keys())
scan_labels = [label_map[s] for s in scan_ids]

# División estratificada por etiqueta
train_ids, test_ids = train_test_split(scan_ids, test_size=0.2, stratify=scan_labels, random_state=42)

print(f"Pacientes en Train: {len(train_ids)} | Pacientes en Test: {len(test_ids)}")

# Dataset adaptado para aceptar subconjunto de pacientes
class MRIDataset(Dataset):
    def __init__(self, img_dir, label_map, scan_ids, transform=None):
        self.img_dir = img_dir
        self.label_map = label_map
        self.scan_ids = set(scan_ids)   # <- filtramos por IDs permitidos
        self.transform = transform
        # Solo imágenes cuyo scan_id está en el subset (train/test)
        all_paths = glob.glob(os.path.join(img_dir, "*.png"))
        self.img_paths = [p for p in all_paths if self.get_scan_id_from_path(p) in self.scan_ids]
        
    def __len__(self):
        return len(self.img_paths)
    
    def get_scan_id_from_path(self, img_path):
        basename = os.path.basename(img_path)
        return basename.split("_slice")[0]
        
    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        scan_id = self.get_scan_id_from_path(img_path)

        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) 
        
        if self.transform:
            img = self.transform(img)
        
        label = self.label_map[scan_id]
        
        if self.return_path:
            return img, label, img_path
        else:
            return img, label

# --- Transforms (con augmentación en train) ---
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Crear datasets
train_dataset = MRIDataset(OUTPUT_DIR, label_map, train_ids, transform=train_transform)
test_dataset = MRIDataset(OUTPUT_DIR, label_map, test_ids, transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

print(f"Train: {len(train_dataset)} imágenes | Test: {len(test_dataset)} imágenes")


Pacientes en Train: 298 | Pacientes en Test: 75
Train: 1490 imágenes | Test: 375 imágenes


In [27]:
# ========================================
# 6. Modelo Preentrenado (ResNet50)
# ========================================

from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Modelo preentrenado
model = models.resnet50(weights="IMAGENET1K_V2")
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)  
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [28]:
# ========================================
# 7. Entrenamiento del modelo
# ========================================

EPOCHS = 5
for epoch in range(EPOCHS):
    model.train()
    running_loss, correct = 0, 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        preds = outputs.argmax(1)
        correct += (preds == labels).sum().item()
    
    acc = correct / len(train_dataset)
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {running_loss/len(train_loader):.4f} - Train Acc: {acc:.4f}")

Epoch 1/5 - Loss: 0.6558 - Train Acc: 0.6128
Epoch 2/5 - Loss: 0.4656 - Train Acc: 0.7859
Epoch 3/5 - Loss: 0.3206 - Train Acc: 0.8691
Epoch 4/5 - Loss: 0.1587 - Train Acc: 0.9396
Epoch 5/5 - Loss: 0.1634 - Train Acc: 0.9443


In [30]:
# ========================================
# 8. Evaluación del modelo (por paciente)
# ========================================

from collections import defaultdict
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

model.eval()
all_probs = []
all_labels = []
all_ids = []

with torch.no_grad():
    for imgs, labels, paths in test_loader:  # Dataset ahora devuelve también paths
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        probs = torch.softmax(outputs, dim=1)[:,1].cpu().numpy()
        
        for p, prob, label in zip(paths, probs, labels):
            scan_id = os.path.basename(p).split("_slice")[0]  # OAS2_0001_MR1
            all_ids.append(scan_id)
            all_probs.append(prob)
            all_labels.append(label.item())

# --- Agrupar por paciente ---
patient_probs = defaultdict(list)
patient_labels = {}

for pid, prob, label in zip(all_ids, all_probs, all_labels):
    patient_probs[pid].append(prob)
    patient_labels[pid] = label

final_probs = {pid: np.mean(probs) for pid, probs in patient_probs.items()}
final_preds = {pid: int(prob > 0.5) for pid, prob in final_probs.items()}

y_true = [patient_labels[pid] for pid in final_preds.keys()]
y_pred = [final_preds[pid] for pid in final_preds.keys()]
y_score = [final_probs[pid] for pid in final_preds.keys()]

print("\n=== Evaluación en TEST (nivel paciente) ===")
print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred))
print("ROC-AUC:", roc_auc_score(y_true, y_score))


ValueError: not enough values to unpack (expected 3, got 2)