In [7]:
import torch
from torch.utils.data import DataLoader
from utils.dataset import BreastDataset2DMulticlass
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
from torchvision import transforms
from model.model_lightning import MyModelMulticlass
from addict import Dict  

In [8]:
def load_model(ckpt_path, model_name="resnet"):
    # Determinar el dispositivo (GPU o CPU)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(f"Usando el dispositivo: {device}")

    # Crear configuraciones como objetos con atributos
    model_opts = Dict({'name': model_name})
    train_par = Dict({'eval_threshold': 0.5, 'loss_opts': {'name': 'CrossEntropyLoss'}})

    # Inicializar el modelo
    model = MyModelMulticlass(model_opts=model_opts, train_par=train_par)

    # Cargar el checkpoint en el dispositivo adecuado
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    # Enviar el modelo al dispositivo (GPU o CPU)
    model = model.to(device)
    model.eval()  # Establecer en modo evaluación
    return model, device


In [9]:
# Función para hacer predicciones con un modelo
def predict_with_model(model, dataloader, device):
    predictions = []

    with torch.no_grad():
        for images, _, _ in dataloader:
            images = images.to(device)
            preds = model(images)
            probs = torch.nn.functional.softmax(preds, dim=-1)
            predictions.append(probs.cpu().numpy())

    return np.vstack(predictions)

In [10]:
def ensemble_predictions(models, dataloader, device, method="average"):
    patient_predictions = {}
    label_mapping_reverse = {0: "No follow up", 1: "Follow up", 2: "Biopsy"}  # Convertir números a texto

    # Asegurarte de que cada modelo esté en el dispositivo
    models = [model.to(device).eval() for model in models]

    with torch.no_grad():
        for images, labels, patient_ids in dataloader:
            images = images.to(device)

            # Predicciones para cada modelo
            model_preds = [torch.nn.functional.softmax(model(images), dim=-1) for model in models]

            # Ensemble usando promedio
            if method == "average":
                preds = torch.stack(model_preds).mean(dim=0)
            else:
                raise ValueError("Método no soportado. Usa 'average'.")

            # Etiquetas predichas (convertir a texto)
            predicted_labels = torch.argmax(preds, dim=1).cpu().numpy()
            predicted_labels = [label_mapping_reverse[label] for label in predicted_labels]

            # Agrupar predicciones por paciente
            for i, patient_id in enumerate(patient_ids):
                if patient_id not in patient_predictions:
                    patient_predictions[patient_id] = []
                patient_predictions[patient_id].append(predicted_labels[i])

    # Consolidar predicciones a nivel de paciente (mayoría de votos, en texto)
    final_predictions = {
        patient_id: max(set(preds), key=preds.count)
        for patient_id, preds in patient_predictions.items()
    }

    return final_predictions



In [11]:
def evaluate_ensemble(ground_truth, predictions, classes):
    # Convertir ground_truth y predictions a listas alineadas
    y_true = [ground_truth[patient_id] for patient_id in predictions.keys()]
    y_pred = [predictions[patient_id] for patient_id in predictions.keys()]

    # Mapeo de etiquetas numéricas a texto (si es necesario)
    label_mapping_reverse = {0: "No follow up", 1: "Follow up", 2: "Biopsy"}
    
    # Convertir etiquetas numéricas a texto, solo si es necesario
    y_true = [label_mapping_reverse[label] if isinstance(label, int) else label for label in y_true]
    y_pred = [label_mapping_reverse[label] if isinstance(label, int) else label for label in y_pred]

    # Verificar que todas las etiquetas sean de tipo texto
    if any(isinstance(label, int) for label in y_true + y_pred):
        raise ValueError("Aún hay etiquetas en formato numérico. Verifica tu mapeo.")

    # Calcular matriz de confusión
    cm = confusion_matrix(y_true, y_pred, labels=classes)
    print("Matriz de Confusión:")
    print(pd.DataFrame(cm, index=classes, columns=classes))

    # Reporte de clasificación
    print("\nReporte de Clasificación:")
    print(classification_report(y_true, y_pred, target_names=classes))



In [12]:
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Rutas a los modelos y dataset
    ckpt_densenet = "results_multiclass/breast_cancer_classification/densenet.ckpt"
    ckpt_mobilenet = "results_multiclass/breast_cancer_classification/mobilenet.ckpt"
    ckpt_vgg16 = "results_multiclass/breast_cancer_classification/vgg16.ckpt"
    data_csv = "df_full.csv"  # CSV con "ID_paciente" y "ground_truth"
    data_dir = "Breast AI study data"   # Carpeta raíz con imágenes en "benign" y "malign"

    # Transformaciones para el dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Crear el dataset y DataLoader
    patient_dataset = BreastDataset2DMulticlass(
        csv_file=data_csv,
        data_dir=data_dir,
        transform=transform,
        resize_to=(224, 224)
    )
    dataloader = DataLoader(patient_dataset, batch_size=16, shuffle=False, num_workers=4)

    # Obtener ground truth por paciente
    df = pd.read_csv(data_csv)

    # Mapeo de etiquetas para ground truth
    label_mapping_reverse = {"No follow up": 0, "Follow up": 1, "Biopsy": 2}

    # Convertir ground_truth a numérico para garantizar consistencia
    ground_truth = df.groupby("ID_paciente")["ground_truth"].first().apply(
        lambda x: label_mapping_reverse[x] if isinstance(x, str) else x
    ).to_dict()

    # Cargar modelos
    densenet, _ = load_model(ckpt_densenet, model_name="densenet")
    mobilenet, _ = load_model(ckpt_mobilenet, model_name="mobilenet")
    vgg16, _ = load_model(ckpt_vgg16, model_name="vgg16")

    # Hacer predicciones con el ensemble
    ensemble_preds = ensemble_predictions([densenet, mobilenet, vgg16], dataloader, device, method="average")

    # Evaluar resultados
    evaluate_ensemble(ground_truth, ensemble_preds, classes=["No follow up", "Follow up", "Biopsy"])

if __name__ == "__main__":
    main()




[BreastDatasetMulticlass] Found 1390 total images across all patients.
Usando el dispositivo: cuda


  checkpoint = torch.load(ckpt_path, map_location=device)


Usando el dispositivo: cuda
Usando el dispositivo: cuda




Matriz de Confusión:
              No follow up  Follow up  Biopsy
No follow up            15          2       3
Follow up                3         13       2
Biopsy                   2          3      17

Reporte de Clasificación:
              precision    recall  f1-score   support

No follow up       0.77      0.77      0.77        22
   Follow up       0.72      0.72      0.72        18
      Biopsy       0.75      0.75      0.75        20

    accuracy                           0.75        60
   macro avg       0.75      0.75      0.75        60
weighted avg       0.75      0.75      0.75        60

