In [None]:
import kagglehub

# Download the latest version of the dataset through kagglehub; `path` is the
# location returned by the download call and is reused by the following cells.
path = kagglehub.dataset_download("yiweilu2033/well-documented-alzheimers-dataset")

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/well-documented-alzheimers-dataset


In [None]:
import os

# List the top-level entries of the downloaded dataset directory.
print(os.listdir(path))


['ModerateDemented', 'NonDemented (2)', 'oasis_cross-sectional-5708aa0a98d82080 (1).xlsx', 'VeryMildDemented', 'MildDemented']


In [None]:
import os



def count_images_in_folders(dataset_path):
    """Count the image files of every class folder under ``dataset_path``.

    Each class folder is expected to contain a single nested subfolder holding
    the images. The subfolder has the same name as the class folder, except for
    folders whose name contains "NonDemented", whose subfolder is always called
    "NonDemented".

    Args:
        dataset_path: Directory whose immediate subdirectories are class folders.

    Returns:
        dict mapping class-folder name to the number of files in its nested
        subfolder whose name ends in .png, .jpg or .jpeg (case-insensitive).
        Folders without the expected subfolder are reported on stdout and
        left out of the result.
    """
    # Extensions carry the leading dot so that names such as "backup_png" are
    # not mistaken for images (and to match the other cells of the notebook).
    image_extensions = (".png", ".jpg", ".jpeg")
    folder_counts = {}

    for folder in os.listdir(dataset_path):
        folder_path = os.path.join(dataset_path, folder)

        # Skip plain files sitting next to the class folders.
        if not os.path.isdir(folder_path):
            continue

        # Nested subfolder shares the folder name (except for "NonDemented").
        subfolder_name = folder if "NonDemented" not in folder else "NonDemented"
        subfolder_path = os.path.join(folder_path, subfolder_name)

        if os.path.isdir(subfolder_path):
            folder_counts[folder] = sum(
                1
                for file_name in os.listdir(subfolder_path)
                if file_name.lower().endswith(image_extensions)
            )
        else:
            print(f"Subcarpeta no encontrada en: {folder_path}")

    return folder_counts

# Count the images of every class folder under the dataset root.
image_counts = count_images_in_folders(path)

# Print the per-folder totals.
for folder, count in image_counts.items():
    print(f"{folder}: {count} imágenes")


ModerateDemented: 376 imágenes
NonDemented (2): 63560 imágenes
VeryMildDemented: 13796 imágenes
MildDemented: 5184 imágenes


In [None]:
# Install the image-augmentation library imported by the augmentation cell.
!pip install imgaug
# Pin NumPy to 1.24.0.
# NOTE(review): presumably required for imgaug compatibility, and a runtime
# restart may be needed before the pinned version is picked up — confirm.
!pip install numpy==1.24.0


Collecting imgaug
  Downloading imgaug-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Downloading imgaug-0.4.0-py2.py3-none-any.whl (948 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m948.0/948.0 kB[0m [31m25.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: imgaug
Successfully installed imgaug-0.4.0
Collecting numpy==1.24.0
  Downloading numpy-1.24.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)
Downloading numpy-1.24.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.3/17.3 MB[0m [31m34.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour 

In [None]:
import imgaug.augmenters as iaa
import cv2
import os
import numpy as np

# Augmentation pipeline applied to every generated image.
augmenters = iaa.Sequential([
    iaa.Fliplr(0.5),  # horizontal flip with probability 0.5
    iaa.Affine(rotate=(-15, 15)),  # random rotation between -15 and 15 degrees
    iaa.GaussianBlur(sigma=(0, 1.0)),  # random blur with sigma in [0, 1.0]
])

# Directory where the augmented images are written (one subfolder per class).
augmented_dir = "/root/.cache/kagglehub/datasets/augmented_dataset"
os.makedirs(augmented_dir, exist_ok=True)

# Base directory holding the original images.
base_dir = path

# Class folder -> nested image subfolder (ModerateDemented is left out).
category_subfolders = {
    "MildDemented": "MildDemented",
    "VeryMildDemented": "VeryMildDemented",
    "NonDemented (2)": "NonDemented"
}

# Target number of images per class, keyed by the same class-folder names as
# `category_subfolders` so that the lookup below is explicit. None means the
# class is not augmented.
target_counts = {
    "MildDemented": 8000,
    "VeryMildDemented": 15000,
    "NonDemented (2)": None
}

for category, subfolder in category_subfolders.items():
    category_path = os.path.join(base_dir, category, subfolder)
    target = target_counts.get(category)

    if target is None or not os.path.exists(category_path):
        continue

    # Sorted so that the images chosen for augmentation do not depend on the
    # arbitrary order returned by os.listdir.
    images = sorted(
        img for img in os.listdir(category_path)
        if img.lower().endswith((".png", ".jpg", ".jpeg"))
    )

    # Nothing to do when the folder has no images or already meets the target.
    if not images or len(images) >= target:
        continue

    new_category_path = os.path.join(augmented_dir, category)
    os.makedirs(new_category_path, exist_ok=True)

    augmentation_needed = target - len(images)

    for i in range(augmentation_needed):
        img_name = images[i % len(images)]  # cycle through the originals as often as needed
        img_path = os.path.join(category_path, img_name)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None for unreadable files; skip them instead
            # of crashing in cvtColor.
            print(f"Advertencia: no se pudo leer la imagen {img_path}, se omitirá.")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Apply the augmentation pipeline.
        augmented_image = augmenters(image=image)

        # Keep the patient ID (first three "_" fields, e.g. "OAS1_0028_MR1")
        # and the slice information (remaining fields, e.g. "1.nii_slice_113")
        # in the new file name.
        name_fields = img_name.split("_")
        patient_id = "_".join(name_fields[:3])
        slice_info = "_".join(name_fields[3:])

        # The "_aug{i}_" marker goes between the patient ID and the slice info.
        aug_img_name = f"{patient_id}_aug{i}_{slice_info}"

        cv2.imwrite(os.path.join(new_category_path, aug_img_name), cv2.cvtColor(augmented_image, cv2.COLOR_RGB2BGR))

print("Data augmentation completado SIN la clase 'ModerateDemented'.")


Data augmentation completado SIN la clase 'ModerateDemented'.


In [None]:
import os

def count_augmented_images(augmented_dir):
    """Return a dict mapping each subfolder of ``augmented_dir`` to its image count.

    Only entries that are directories are considered; inside each one, files
    whose lower-cased name ends in .png, .jpg or .jpeg are counted.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")
    folder_counts = {}

    for entry in os.listdir(augmented_dir):
        entry_path = os.path.join(augmented_dir, entry)
        if not os.path.isdir(entry_path):
            continue
        folder_counts[entry] = sum(
            1
            for file_name in os.listdir(entry_path)
            if file_name.lower().endswith(image_suffixes)
        )

    return folder_counts

# Folder that holds the augmented images.
augmented_dir = "/root/.cache/kagglehub/datasets/augmented_dataset"

# Count the images of every class subfolder.
augmented_counts = count_augmented_images(augmented_dir)

# Print the per-class totals.
for category, count in augmented_counts.items():
    print(f"{category}: {count} imágenes")


VeryMildDemented: 1204 imágenes
MildDemented: 2816 imágenes


In [None]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split

# Base directories: original dataset and augmented images.
original_base_dir = path
augmented_base_dir = "/root/.cache/kagglehub/datasets/augmented_dataset"

# Class folder -> nested image subfolder (without "ModerateDemented").
category_subfolders = {
    "MildDemented": "MildDemented",
    "VeryMildDemented": "VeryMildDemented",
    "NonDemented (2)": "NonDemented"
}

# Parallel lists that accumulate one entry per image.
image_paths = []
patient_ids = []
labels = []

# Derive the patient ID from an image file name.
def extract_patient_id(img_name):
    """Return the first three "_"-separated fields of ``img_name`` joined by "_".

    Example: a name starting with "OAS1_0028_MR1_" yields "OAS1_0028_MR1".
    """
    leading_fields = img_name.split("_")[:3]
    return "_".join(leading_fields)

def _collect_images(folder, label):
    """Append path, patient ID and label of every image file in ``folder``
    to the module-level lists ``image_paths``, ``patient_ids`` and ``labels``."""
    for file_name in os.listdir(folder):
        if not file_name.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        owner_id = extract_patient_id(file_name)
        image_paths.append(os.path.join(folder, file_name))
        patient_ids.append(owner_id)
        labels.append(label)


# Original images: <base>/<class folder>/<nested subfolder>/
for category, subfolder in category_subfolders.items():
    category_path = os.path.join(original_base_dir, category, subfolder)

    if os.path.exists(category_path):
        _collect_images(category_path, category)
    else:
        print(f"Advertencia: No se encontró la carpeta {category_path}, se omitirá.")

# Augmented images: <augmented base>/<class folder>/
for category in category_subfolders.keys():
    category_path = os.path.join(augmented_base_dir, category)

    if os.path.exists(category_path):
        _collect_images(category_path, category)
    else:
        print(f"Advertencia: No se encontró la carpeta de aumentación {category_path}, se omitirá.")

# Build a DataFrame with one row per image.
images_df = pd.DataFrame({
    "image_path": image_paths,
    "patient_id": patient_ids,
    "label": labels
})

# One label per patient (the first one seen), used to stratify the split.
patient_labels = images_df.groupby("patient_id")["label"].first()

# Unique patient IDs: the split is done on patients, not on images, so that
# all images of one patient end up in the same subset.
unique_patient_ids = images_df["patient_id"].unique()

# Train (80%) / test (20%) split by patient.
train_patient_ids, test_patient_ids = train_test_split(
    unique_patient_ids, test_size=0.2, random_state=42, stratify=patient_labels.loc[unique_patient_ids]
)

# Split the training patients again into train (80%) / validation (20%).
train_patient_ids, val_patient_ids = train_test_split(
    train_patient_ids, test_size=0.2, random_state=42, stratify=patient_labels.loc[train_patient_ids]
)

# Sets give constant-time membership tests; `pid in <ndarray>` would scan the
# whole array for every image row.
train_id_set = set(train_patient_ids)
val_id_set = set(val_patient_ids)
test_id_set = set(test_patient_ids)

# Sanity check: no patient may appear in more than one subset.
assert train_id_set.isdisjoint(val_id_set), "patient shared between train and val"
assert train_id_set.isdisjoint(test_id_set), "patient shared between train and test"
assert val_id_set.isdisjoint(test_id_set), "patient shared between val and test"

# Assign every image to the subset of its patient.
# NOTE(review): every row of images_df inherits its patient's subset, so if
# images_df contains augmented copies they will also land in val/test —
# confirm this is intended.
final_df = images_df.copy()
final_df["set"] = final_df["patient_id"].apply(
    lambda pid: "train" if pid in train_id_set else ("val" if pid in val_id_set else "test")
)

# Report the final distribution.
print(f"Tamaño de entrenamiento: {len(final_df[final_df['set'] == 'train'])}")
print(f"Tamaño de validación: {len(final_df[final_df['set'] == 'val'])}")
print(f"Tamaño de prueba: {len(final_df[final_df['set'] == 'test'])}")

# Show a few rows of the final DataFrame.
print(final_df.head())


Advertencia: No se encontró la carpeta de aumentación /root/.cache/kagglehub/datasets/augmented_dataset/NonDemented (2), se omitirá.
Tamaño de entrenamiento: 54773
Tamaño de validación: 14014
Tamaño de prueba: 17773
                                          image_path     patient_id  \
0  /kaggle/input/well-documented-alzheimers-datas...  OAS1_0073_MR1   
1  /kaggle/input/well-documented-alzheimers-datas...  OAS1_0316_MR1   
2  /kaggle/input/well-documented-alzheimers-datas...  OAS1_0430_MR1   
3  /kaggle/input/well-documented-alzheimers-datas...  OAS1_0184_MR1   
4  /kaggle/input/well-documented-alzheimers-datas...  OAS1_0269_MR1   

          label    set  
0  MildDemented  train  
1  MildDemented  train  
2  MildDemented  train  
3  MildDemented  train  
4  MildDemented  train  


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Standard transformation pipeline for the images.
# It contains random operations (flip, rotation), so two reads of the same
# image can produce different tensors.
transform = transforms.Compose([
    transforms.Resize((224, 224)),                 # Resize the image to 224x224
    transforms.RandomHorizontalFlip(p=0.5),        # Random horizontal flip
    transforms.RandomRotation(degrees=15),         # Random rotation
    transforms.ToTensor(),                         # Convert the image to a tensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalization
])

# Mapping from label name to class index (without "ModerateDemented").
label_map = {"MildDemented": 0, "VeryMildDemented": 1, "NonDemented (2)": 2}

class AlzheimerDataset(Dataset):
    """Dataset that loads images listed in a DataFrame.

    Rows are addressed by position (``iloc``), so the DataFrame index does not
    need to be contiguous. Labels are converted to integers through the
    module-level ``label_map``.
    """

    def __init__(self, df, transform=None):
        """
        Initialise the dataset.

        Args:
        df (pd.DataFrame): DataFrame with "image_path" and "label" columns.
        transform (callable, optional): Transformation applied to each image.
        """
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Load the image at position ``idx`` together with its label.

        Args:
        idx (int): Positional index of the sample.

        Returns:
        image: The RGB image, passed through ``transform`` when one was given.
        label (int): Class index looked up in ``label_map``.
        """
        # Fetch the row once instead of once per column.
        row = self.df.iloc[idx]
        img_path = row["image_path"]
        label = label_map[row["label"]]

        # Open inside a context manager so the file handle is released
        # deterministically; convert() returns an independent in-memory image.
        with Image.open(img_path) as source_image:
            image = source_image.convert("RGB")

        # Apply the transformations, if any.
        if self.transform:
            image = self.transform(image)

        return image, label

# Drop any "ModerateDemented" rows.
filtered_df = final_df[final_df["label"] != "ModerateDemented"]

# One DataFrame per subset.
train_df = filtered_df[filtered_df["set"] == "train"]
val_df = filtered_df[filtered_df["set"] == "val"]
test_df = filtered_df[filtered_df["set"] == "test"]

# Deterministic preprocessing for validation and test: same resize and
# normalisation as `transform`, but without the random flip/rotation, so that
# evaluation metrics are reproducible and computed on unaltered images.
# Keep the size and normalisation values in sync with `transform`.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# PyTorch datasets: random augmentation only for training.
train_dataset = AlzheimerDataset(train_df, transform=transform)
val_dataset = AlzheimerDataset(val_df, transform=eval_transform)
test_dataset = AlzheimerDataset(test_df, transform=eval_transform)

# Report the size of each dataset.
print(f"Tamaño del dataset -> Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

# DataLoaders that serve the data in batches; only the training data is shuffled.
batch_size = 128

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"DataLoaders creados con batch_size={batch_size}")


Tamaño del dataset -> Train: 54773, Val: 14014, Test: 17773
DataLoaders creados con batch_size=128


In [None]:
from collections import Counter
import torch
import torch.nn as nn

# Select the compute device (GPU when available, otherwise CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Label name -> class index used for the weight vector.
label_to_idx = {
    "MildDemented": 0,
    "VeryMildDemented": 1,
    "NonDemented (2)": 2
}

# Class frequencies (by index) among the training rows.
train_label_column = final_df[final_df["set"] == "train"]["label"]
label_counts = Counter(label_to_idx[name] for name in train_label_column)
total_samples = sum(label_counts.values())

# Weight of a class = total / (number of classes * class frequency), i.e.
# inversely proportional to how often the class occurs.
n_classes = len(label_to_idx)
class_weights = [0] * n_classes
for label, idx in label_to_idx.items():
    class_weights[idx] = total_samples / (n_classes * label_counts[idx])

# Turn the weights into a float tensor on the selected device.
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)

# Weighted cross-entropy loss.
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

# `criterion` is now available to the following cells.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AlzheimerCNN(nn.Module):
    """Small CNN classifier: two conv blocks followed by three dense layers.

    Each block is conv(3x3) -> ReLU -> conv(3x3) -> ReLU -> 2x2 max-pool, with
    32 channels throughout. The flattened feature map feeds
    fc1(128) -> dropout(0.5) -> fc2(64) -> fc3(num_classes).

    Args:
        num_classes: Number of output logits (default 3).
        input_size: Spatial size of the input images, either an int (square)
            or a ``(height, width)`` pair. It determines the input width of
            ``fc1``. Defaults to ``(224, 224)``.
    """

    def __init__(self, num_classes=3, input_size=(224, 224)):
        super().__init__()

        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        self.input_size = tuple(input_size)

        # Block 1
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Block 2
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Input width of the first dense layer, measured with a dummy forward pass.
        self.flattened_size = self._get_flattened_size()

        # Dense layers
        self.fc1 = nn.Linear(self.flattened_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.dropout = nn.Dropout(0.5)

    def _extract_features(self, x):
        """Run both convolutional blocks and return the pooled feature map."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool1(x)

        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool2(x)
        return x

    def _get_flattened_size(self):
        """Return the number of features per sample after the conv blocks."""
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, *self.input_size)
            return torch.flatten(self._extract_features(dummy_input), 1).size(1)

    def forward(self, x):
        """Map a batch of images ``(N, 3, H, W)`` to logits ``(N, num_classes)``."""
        x = self._extract_features(x)
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)  # dropout is applied only after the first dense layer
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x


# Instantiate the model on the available device (GPU when present, else CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlzheimerCNN(num_classes=3).to(device)  # three output classes
print(model)


AlzheimerCNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=100352, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=3, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from google.colab import files
# Calcular pesos de clase con base en train_df
from collections import Counter
# (A stray bare `i` statement that used to sit here was removed: it did nothing
# and raised NameError whenever no earlier loop had defined `i`.)

# Count how many training samples each class has.
class_counts = Counter(train_df["label"])
total_samples = sum(class_counts.values())

# Inverse-frequency weights; the labels are listed in label_map index order
# (0, 1, 2) so that each weight lines up with its class index.
weights = []
for label in ["MildDemented", "VeryMildDemented", "NonDemented (2)"]:
    count = class_counts[label]
    weights.append(total_samples / count)

# Convert to a tensor and normalise the weights so that they sum to 1.
weights = torch.tensor(weights, dtype=torch.float32)
weights = weights / weights.sum()
weights = weights.to(device)

# 1. Weighted loss function and optimizer
criterion = nn.CrossEntropyLoss(weight=weights)

optimizer = optim.Adam(model.parameters(), lr=0.001)

# 2. Training function (no early stopping)
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=25,
                checkpoint_path="best_model.pth"):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Args:
        model: Network to train; batches are moved to the device of its parameters.
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        val_loader: DataLoader yielding ``(inputs, labels)`` validation batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating the model parameters.
        num_epochs: Number of full passes over the training data.
        checkpoint_path: File that receives the state dict every time the
            validation loss improves.

    Returns:
        Tuple ``(train_losses, val_losses, val_accuracies)``: three lists of
        Python floats with one entry per epoch.
    """
    # Use the model's own device instead of relying on a global variable.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    train_losses = []
    val_losses = []
    val_accuracies = []
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_preds = 0
        total_preds = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # The loss is a batch mean; weight it by the batch size so that the
            # epoch figure is a per-sample average.
            running_loss += loss.item() * inputs.size(0)

            _, preds = torch.max(outputs, 1)
            # .item() keeps the counters as plain ints rather than tensors.
            correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0)

        epoch_train_loss = running_loss / len(train_loader.dataset)
        epoch_train_acc = correct_preds / total_preds

        # Validation pass
        model.eval()
        running_val_loss = 0.0
        correct_preds_val = 0
        total_preds_val = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                running_val_loss += loss.item() * inputs.size(0)

                _, preds = torch.max(outputs, 1)
                correct_preds_val += (preds == labels).sum().item()
                total_preds_val += labels.size(0)

        epoch_val_loss = running_val_loss / len(val_loader.dataset)
        epoch_val_acc = correct_preds_val / total_preds_val

        train_losses.append(epoch_train_loss)
        val_losses.append(epoch_val_loss)
        val_accuracies.append(epoch_val_acc)

        # Keep the weights of the epoch with the lowest validation loss.
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print(f" Mejor modelo guardado en la época {epoch+1} con val_loss: {epoch_val_loss:.4f}")

        # Per-epoch report
        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {epoch_train_loss:.4f}, Train Accuracy: {epoch_train_acc:.4f} - "
              f"Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {epoch_val_acc:.4f}")

    return train_losses, val_losses, val_accuracies

# 3. Evaluation of the model on the test set
def evaluate_model(model, test_loader, checkpoint_path="best_model.pth"):
    """Load a checkpoint into ``model`` and measure its accuracy on ``test_loader``.

    Note that the checkpoint is loaded in place, so ``model`` keeps those
    weights after the call.

    Args:
        model: Network whose architecture matches the checkpoint.
        test_loader: DataLoader yielding ``(inputs, labels)`` batches.
        checkpoint_path: State-dict file to load before evaluating.

    Returns:
        The accuracy computed by ``accuracy_score`` over all test samples.
    """
    # Use the model's own device instead of relying on a global variable.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    # map_location lets a checkpoint written on a GPU be loaded on a CPU-only
    # runtime (and vice versa).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    correct_preds_test = 0
    total_preds_test = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            # .item() keeps the counter as a plain int rather than a tensor.
            correct_preds_test += (preds == labels).sum().item()
            total_preds_test += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_accuracy = correct_preds_test / total_preds_test
    print(f"Test Accuracy: {test_accuracy:.4f}")

    # Cross-check with scikit-learn's implementation.
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Test Accuracy (sklearn): {accuracy:.4f}")

    return accuracy

# 4. Train the model (no early stopping)
num_epochs = 25
train_losses, val_losses, val_accuracies = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs)

# 5. Final evaluation on the test set
test_accuracy = evaluate_model(model, test_loader)

# 6. Save the trained model.
# evaluate_model loads "best_model.pth" into `model`, so the weights written
# here are those of the best-validation-loss checkpoint, not of the last epoch.
model_save_path = "modelopaper2.pth"
torch.save(model.state_dict(), model_save_path)
print(f"Modelo guardado en {model_save_path}")
# Hand the saved file to google.colab's download helper.
files.download(model_save_path)


 Mejor modelo guardado en la época 1 con val_loss: 0.8058
Epoch 1/25 - Train Loss: 0.8964, Train Accuracy: 0.6066 - Validation Loss: 0.8058, Validation Accuracy: 0.5933
 Mejor modelo guardado en la época 2 con val_loss: 0.5811
Epoch 2/25 - Train Loss: 0.7042, Train Accuracy: 0.7269 - Validation Loss: 0.5811, Validation Accuracy: 0.7077
Epoch 3/25 - Train Loss: 0.6223, Train Accuracy: 0.7512 - Validation Loss: 0.6002, Validation Accuracy: 0.6929
Epoch 4/25 - Train Loss: 0.5555, Train Accuracy: 0.7743 - Validation Loss: 0.7093, Validation Accuracy: 0.6593
Epoch 5/25 - Train Loss: 0.5086, Train Accuracy: 0.7906 - Validation Loss: 0.6826, Validation Accuracy: 0.7035
Epoch 6/25 - Train Loss: 0.4765, Train Accuracy: 0.8006 - Validation Loss: 0.7284, Validation Accuracy: 0.6797
Epoch 7/25 - Train Loss: 0.4442, Train Accuracy: 0.8101 - Validation Loss: 0.6721, Validation Accuracy: 0.7045
Epoch 8/25 - Train Loss: 0.4157, Train Accuracy: 0.8203 - Validation Loss: 0.7307, Validation Accuracy: 0.6

In [None]:
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Collect the true and predicted class indices over the test set.
model.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        y_true.extend(labels.cpu().numpy())
        y_pred.extend(predicted.cpu().numpy())

# Build the confusion matrix. Passing `labels` fixes the matrix to one
# row/column per class even if a class never appears in y_true or y_pred, so
# the tick labels below always line up with the right class index.
labels_classes = ["MildDemented", "VeryMildDemented", "NonDemented"]  # no "ModerateDemented"
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels_classes))))

# Plot the confusion matrix.
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels_classes, yticklabels=labels_classes)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Matriz de Confusión")
plt.show()


In [None]:
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
import numpy as np
import matplotlib.pyplot as plt

# Collect labels and class scores in the SAME pass over the test set so that
# every score row is guaranteed to belong to its label.
model.eval()
roc_labels = []
y_scores = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        # Softmax probabilities are comparable across samples, which a
        # one-vs-rest ROC curve needs; raw logits are only meaningful relative
        # to the other logits of the same sample.
        y_scores.extend(torch.softmax(outputs, dim=1).cpu().numpy())
        roc_labels.extend(labels.cpu().numpy())

y_scores = np.array(y_scores)

# One-vs-rest binary indicator matrix for the three classes.
y_true_bin = label_binarize(roc_labels, classes=[0, 1, 2])

# Plot one ROC curve per class.
plt.figure(figsize=(8,6))
labels_classes = ["MildDemented", "VeryMildDemented", "NonDemented"]  # no "ModerateDemented"

for i in range(3):  # three classes
    fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_scores[:, i])
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f'Clase {labels_classes[i]} (AUC = {roc_auc:.2f})')

plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Curvas ROC por Clase")
plt.legend()
plt.show()
