In [1]:
import sys
import os

# Make the project's root directory importable (needed for the `scripts.*` imports).
# NOTE(review): this is an absolute, machine-specific path — adjust it when running elsewhere.
project_root = "/home/javitrucas/TFG"
if not any(entry == project_root for entry in sys.path):
    sys.path += [project_root]

In [None]:
# Importar bibliotecas necesarias
import os
import torch
import numpy as np
import matplotlib.pyplot as plt  # Para visualización de gráficas
from torchvision.utils import make_grid  # Para visualizar imágenes
import wandb  # Para registro con Weights & Biases
from scripts.MNIST.MNISTMILDataset import MNISTMILDataset
from scripts.MNIST.evaluation import ModelEvaluator
from scripts.MNIST.training import Training

# Initial configuration: directory (relative to the current working directory)
# that is handed to the trainer and from which 'model.pth' is read back.
output_model_dir = './models'

# Ensure the directory is present; a pre-existing directory is not an error.
os.makedirs(name=output_model_dir, exist_ok=True)

# Definir función para ejecutar experimentos
def run_experiment(target_digit, bag_size, num_epochs, learning_rate, pooling_type):
    """
    Run one experiment with the given hyperparameters and log it to wandb.

    The run builds the train/test MIL datasets, holds out 20% of the training
    bags for validation, trains a model, evaluates it on the test set, shows
    the loss and attention plots, and uploads the saved model as a wandb
    artifact. The wandb run is always closed; it is marked as failed when any
    step raises (including a manual interrupt of the cell).

    Args:
        target_digit (int): Target digit for the bags.
        bag_size (int): Number of instances per bag.
        num_epochs (int): Number of training epochs.
        learning_rate (float): Learning rate.
        pooling_type (str): Pooling type ('attention', 'mean', 'max').
    """
    config = {
        "target_digit": target_digit,
        "bag_size": bag_size,
        "num_epochs": num_epochs,
        "learning_rate": learning_rate,
        "pooling_type": pooling_type,
    }
    wandb.init(project="TFG", config=config)

    # Assume failure until the whole pipeline has completed, so that an
    # exception or interrupt never leaves the wandb run open.
    exit_code = 1
    try:
        print("=== Iniciando experimento ===")
        print(f"Target Digit: {target_digit}, Bag Size: {bag_size}, Epochs: {num_epochs}, LR: {learning_rate}, Pooling: {pooling_type}")
        wandb.log({"status": "Experiment started", **config})

        # Build the datasets
        print("Creando datasets...")
        full_train_dataset = MNISTMILDataset(subset="train", bag_size=bag_size, obj_label=target_digit)
        test_dataset = MNISTMILDataset(subset="test", bag_size=bag_size, obj_label=target_digit)
        wandb.log({"status": "Datasets created", "train_dataset_size": len(full_train_dataset), "test_dataset_size": len(test_dataset)})

        # Split the training set into training (80%) and validation (20%)
        print("Dividiendo el conjunto de entrenamiento en entrenamiento (80%) y validación (20%)...")
        train_size = int(len(full_train_dataset) * 0.8)
        val_size = len(full_train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(full_train_dataset, [train_size, val_size])
        wandb.log({"status": "Training and validation split completed", "train_split_size": len(train_dataset), "val_split_size": len(val_dataset)})

        # Training
        print("Iniciando entrenamiento...")
        trainer = Training(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            output_model_dir=output_model_dir,
            pooling_type=pooling_type
        )
        trainer.train()
        wandb.log({"status": "Training completed"})

        # The trainer only optionally exposes its loss history.
        has_loss_history = hasattr(trainer, 'train_losses') and hasattr(trainer, 'val_losses')

        # Log the per-epoch training metrics to wandb
        if has_loss_history:
            print("Registrando métricas de entrenamiento en wandb...")
            # Iterate over what was actually recorded instead of indexing by
            # num_epochs, so a shorter history does not raise IndexError.
            recorded = zip(trainer.train_losses, trainer.val_losses)
            for epoch, (train_loss, val_loss) in enumerate(recorded, start=1):
                wandb.log({
                    "epoch": epoch,
                    "train_loss": train_loss,
                    "val_loss": val_loss
                })
            wandb.log({"status": "Training metrics logged to wandb"})

        # Evaluation
        print("Evaluando el modelo...")
        model_path = os.path.join(output_model_dir, 'model.pth')
        evaluator = ModelEvaluator(
            model_path=model_path,
            test_dataset=test_dataset,
            batch_size=1,
            pooling_type=pooling_type
        )
        results, attention_weights = evaluator.evaluate()
        wandb.log({"status": "Model evaluation completed", **results})

        # Show the training curves
        if has_loss_history:
            print("Generando gráfica de pérdidas durante el entrenamiento...")
            _plot_loss_curves(trainer.train_losses, trainer.val_losses)
            wandb.log({"status": "Training loss plot generated"})

        # Visualise the attention weights (if available and non-empty; the
        # numpy reductions used for plotting fail on an empty sequence).
        if attention_weights is not None and len(attention_weights) > 0:
            _plot_attention(attention_weights)

        # Save the model as a wandb artifact
        print("Guardando el modelo como artefacto en wandb...")
        artifact = wandb.Artifact('trained_model', type='model')
        artifact.add_file(model_path)
        wandb.log_artifact(artifact)
        wandb.log({"status": "Model saved as artifact in wandb"})

        print("=== Resultados del experimento ===")
        print(results)
        print("=== Fin del experimento ===\n")

        wandb.log({"status": "Experiment finished"})
        exit_code = 0
    finally:
        # Close the wandb run whatever happened above.
        wandb.finish(exit_code=exit_code)


def _plot_loss_curves(train_losses, val_losses):
    """Show the training and validation loss curves in a single figure."""
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Loss durante el entrenamiento')
    plt.xlabel('Época')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


def _plot_attention(attention_weights, max_heatmaps=5):
    """Show per-bag heatmaps, the mean attention per bag and a histogram.

    `attention_weights` must be a non-empty sequence with one array of
    weights per bag. Each step is reported to wandb as a status message.
    """
    print("Mostrando heatmaps de atención mejorados...")

    per_bag = [np.asarray(weights) for weights in attention_weights]
    all_weights = np.concatenate([weights.ravel() for weights in per_bag])

    # Shared colour range so the heatmaps are comparable across bags. It is
    # computed on the flattened values so bags need not share a shape.
    att_min, att_max = all_weights.min(), all_weights.max()

    # Show at most `max_heatmaps` bags
    for bag_number, weights in enumerate(per_bag[:max_heatmaps], start=1):
        plt.figure(figsize=(6, 6))
        # imshow needs a 2-D image; atleast_2d turns a 1-D weight vector into
        # a single row and leaves 2-D input untouched.
        plt.imshow(np.atleast_2d(weights), cmap='inferno', aspect='auto', vmin=att_min, vmax=att_max)
        plt.colorbar(label="Intensidad de atención")
        plt.title(f"Heatmap de atención para la bolsa {bag_number}")
        plt.xlabel("Elementos en la bolsa")
        plt.ylabel("Características")
        plt.show()
        wandb.log({"status": f"Attention heatmap for bag {bag_number} generated"})

    # Mean attention of every bag, in evaluation order
    print("Generando gráfica de evolución de la atención...")
    mean_attention = [weights.mean() for weights in per_bag]
    plt.figure(figsize=(10, 5))
    plt.plot(mean_attention, marker='o', linestyle='-', color='blue', alpha=0.7)
    plt.title("Evolución de la Intensidad de Atención por Bolsa")
    plt.xlabel("Bolsa")
    plt.ylabel("Media de Atención")
    plt.grid(True)
    plt.show()
    wandb.log({"status": "Attention evolution plot generated"})

    # Histogram of all attention weights
    print("Generando histograma de distribución de pesos de atención...")
    plt.figure(figsize=(8, 5))
    plt.hist(all_weights, bins=30, color='purple', alpha=0.75)
    plt.title("Distribución de Pesos de Atención")
    plt.xlabel("Valor de Atención")
    plt.ylabel("Frecuencia")
    plt.grid(True)
    plt.show()
    wandb.log({"status": "Attention distribution histogram generated"})

In [None]:
# Experiment using pooling_type="attention"
params_attention = dict(
    target_digit=3,
    bag_size=10,
    num_epochs=5,
    learning_rate=1e-3,
    pooling_type="attention",
)
run_experiment(**params_attention)

In [None]:
# Experiment using pooling_type="mean"
params_mean = dict(
    target_digit=3,
    bag_size=10,
    num_epochs=5,
    learning_rate=1e-3,
    pooling_type="mean",
)
run_experiment(**params_mean)

In [None]:
# Experiment using pooling_type="max"
params_max = dict(
    target_digit=3,
    bag_size=10,
    num_epochs=5,
    learning_rate=1e-3,
    pooling_type="max",
)
run_experiment(**params_max)

In [None]:
import os
import random
import torch
import numpy as np
import pandas as pd
import csv

from scripts.MNIST.MNISTMILDataset import MNISTMILDataset
from scripts.MNIST.training import Training
from scripts.MNIST.evaluation import ModelEvaluator

# Grid search over bag size, learning rate and pooling type. Every
# configuration is trained once per seed, and each finished run is appended
# to a CSV file straight away so partial results survive an interruption.

# Fixed parameters and search grid
target_digit   = 3
bag_sizes      = [10, 15, 30]
learning_rates = [1e-4, 1e-3, 1e-2]
pooling_types  = ['attention', 'mean', 'max']
seeds          = list(range(5))    # one run per seed
num_epochs     = 7                 # adjust if needed
train_fraction = 0.8               # remainder of the training bags is used for validation

# Directory handed to the trainer, and the checkpoint that is read back for evaluation
models_dir = './models'
model_path = os.path.join(models_dir, 'model.pth')
os.makedirs(models_dir, exist_ok=True)  # this cell may run without the earlier setup cell

# Output CSV
csv_file = 'mnist_experiment_runs.csv'

# CSV header (includes the 'seed' column)
fieldnames = [
    'bag_size', 'learning_rate', 'pooling', 'seed',
    'train_accuracy', 'train_auc', 'train_f1',
    'test_accuracy',  'test_auc',  'test_f1'
]

# Mode 'w' truncates any previous file, so no explicit removal is needed.
with open(csv_file, mode='w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()

for bag_size in bag_sizes:
    # Build the full datasets once per bag size; only the train/val split varies with the seed.
    # NOTE(review): the datasets are built before any seed is set, so if MNISTMILDataset
    # draws bags randomly their composition is not reproducible across notebook runs — confirm.
    full_train = MNISTMILDataset(subset="train", bag_size=bag_size, obj_label=target_digit)
    test_ds    = MNISTMILDataset(subset="test",  bag_size=bag_size, obj_label=target_digit)
    split_count = len(full_train)
    split_idx   = int(split_count * train_fraction)

    for lr in learning_rates:
        for pool in pooling_types:
            for seed in seeds:
                print(f"\n=== Experimento: bag_size={bag_size}, lr={lr}, pooling={pool}, seed={seed} ===")

                # Seed every global random number generator
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(seed)

                # Reproducible split through a dedicated torch.Generator
                g = torch.Generator()
                g.manual_seed(seed)
                train_ds, val_ds = torch.utils.data.random_split(
                    full_train,
                    [split_idx, split_count - split_idx],
                    generator=g
                )

                # Training
                trainer = Training(
                    train_dataset=train_ds,
                    val_dataset=val_ds,
                    num_epochs=num_epochs,
                    learning_rate=lr,
                    output_model_dir=models_dir,
                    pooling_type=pool
                )
                trainer.train()

                # Evaluation on the training split (final model)
                evaluator_train = ModelEvaluator(
                    model_path=model_path,
                    test_dataset=train_ds,
                    batch_size=1,
                    pooling_type=pool
                )
                results_train, _ = evaluator_train.evaluate()

                # Evaluation on the test set
                evaluator_test = ModelEvaluator(
                    model_path=model_path,
                    test_dataset=test_ds,
                    batch_size=1,
                    pooling_type=pool
                )
                results_test, _ = evaluator_test.evaluate()

                # Build the CSV row; a metric missing from the results is written as an empty cell
                row = {
                    'bag_size': bag_size,
                    'learning_rate': lr,
                    'pooling': pool,
                    'seed': seed,
                    # training metrics
                    'train_accuracy': results_train.get('accuracy'),
                    'train_auc':      results_train.get('auc'),
                    'train_f1':       results_train.get('f1_score'),
                    # test metrics
                    'test_accuracy': results_test.get('accuracy'),
                    'test_auc':      results_test.get('auc'),
                    'test_f1':       results_test.get('f1_score'),
                }

                # Append the row immediately so finished runs are never lost
                with open(csv_file, mode='a', newline='', encoding='utf-8') as f:
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writerow(row)

print(f"\nTodos los resultados ({len(seeds)} semillas por configuración) se han ido guardando en '{csv_file}'.")



=== Experimento: bag_size=10, lr=0.0001, pooling=attention, seed=0 ===
Epoch 1/7
Entrenamiento - Loss: 1846.9032, Accuracy: 0.8175
Validación - Loss: 183.0522, Accuracy: 0.9467
