# Librerías

# Cargar Dataset

# Prueba Chat

In [None]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch_geometric.data import DataLoader
from dataset_loader import load_test_dataset, load_norm_info
from evaluate import evaluate
from dataset_utils import load_model_by_name
from gcn_model import GCN
from gat_model import GAT
from sage_model import SAGE
from nnconv_model import NNConvNet

In [None]:
# Root folder that holds one sub-folder of checkpoints per architecture.
folder_base = "saved_models"
# Architectures to evaluate, in the order they are reported and plotted.
architectures = "GCN GAT SAGE NNConv".split()
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Serialized test split and the accompanying normalization-info file.
dataset_path, norm_info_path = "data/test_dataset.pth", "data/normalization_info.pth"

# NOTE(review): presumably an iterable of batched test graphs (it is passed to
# evaluate() as the loader) — confirm against dataset_loader.load_test_dataset.
test_loader = load_test_dataset(dataset_path)
norm_info = load_norm_info(norm_info_path)

In [None]:
def get_model_class(arch):
    """Return the model class that implements architecture ``arch``.

    Args:
        arch: One of ``"GCN"``, ``"GAT"``, ``"SAGE"`` or ``"NNConv"``.

    Returns:
        The corresponding class (``NNConv`` maps to ``NNConvNet``).

    Raises:
        ValueError: If ``arch`` is not one of the supported names.
    """
    if arch not in ("GCN", "GAT", "SAGE", "NNConv"):
        raise ValueError(f"Arquitectura no reconocida: {arch}")
    registry = {"GCN": GCN, "GAT": GAT, "SAGE": SAGE, "NNConv": NNConvNet}
    return registry[arch]


def get_model_config_from_filename(filename):
    """Return the hyper-parameters used to rebuild a saved model.

    ``filename`` is currently ignored: every checkpoint is assumed to share the
    fixed configuration below. Extend this function to parse hyper-parameters
    from the file name if checkpoints start to differ.

    Returns:
        A new dict on every call, so callers may mutate it freely.
    """
    return dict(
        input_dim=3,
        hidden_dim=128,
        output_dim=1,
        num_layers=3,
        use_dropout=True,
        dropout_rate=0.2,
        use_batchnorm=False,
        use_residual=False,
        edge_dim=3,  # only passed to the NNConv architecture
    )

In [None]:
# Evaluate one checkpoint per architecture on the test set and collect the
# return value of evaluate() under the architecture name.
resultados = {}

# Constructor arguments shared by every architecture; NNConv additionally
# receives ``edge_dim``.
_SHARED_MODEL_KEYS = (
    "input_dim",
    "hidden_dim",
    "output_dim",
    "num_layers",
    "use_dropout",
    "dropout_rate",
    "use_batchnorm",
    "use_residual",
)

for arch in architectures:
    model_dir = os.path.join(folder_base, arch)
    if not os.path.isdir(model_dir):
        print(f"No se encontró la carpeta de modelos para {arch}.")
        continue

    model_files = [f for f in os.listdir(model_dir) if f.endswith(".pth")]
    if not model_files:
        print(f"No hay modelos en {model_dir}")
        continue

    # Pick the lexicographically last checkpoint name (a specific file could be
    # chosen here instead). NOTE(review): this is not necessarily the newest
    # checkpoint, e.g. "epoch_9.pth" sorts after "epoch_10.pth".
    filename = max(model_files)
    model_config = get_model_config_from_filename(filename)
    ModelClass = get_model_class(arch)

    model_kwargs = {key: model_config[key] for key in _SHARED_MODEL_KEYS}
    if arch == "NNConv":
        model_kwargs["edge_dim"] = model_config["edge_dim"]
    model = ModelClass(**model_kwargs).to(device)

    # NOTE(review): presumably loads the checkpoint's weights into ``model``
    # and returns it — confirm against dataset_utils.load_model_by_name.
    model = load_model_by_name(model, filename, folder=model_dir)

    print(f"\nEvaluando {arch} desde archivo: {filename}")
    plot = True  # plot the results of every evaluated architecture

    # All auxiliary loss terms (physics, boundary, heater) are switched off.
    results = evaluate(
        model, test_loader, device, norm_info,
        lambda_physics=0.0,
        use_physics=False,
        lambda_boundary=0.0,
        use_boundary_loss=False,
        lambda_heater=0.0,
        use_heater_loss=False,
        error_threshold=2.0,
        percentage_threshold=None,
        plot_results=plot
    )

    resultados[arch] = results

In [None]:
# --- CELDA 6: Comparación visual o resumen ---
# Grouped bar chart comparing the test metrics of every architecture.
metricas = ["MSE", "MAE", "R2", "Accuracy"]

# One list per metric, aligned index-by-index with ``architectures``.
# Missing values are stored as NaN rather than None: matplotlib simply skips
# NaN bars, whereas a None height makes ``ax.bar`` raise a TypeError.
valores = {m: [] for m in metricas}

for arch in architectures:
    if arch not in resultados:
        for m in metricas:
            valores[m].append(np.nan)
        continue

    # NOTE(review): assumes evaluate() returns a sequence whose first four
    # items are (MSE, MAE, R2, Accuracy) — confirm against evaluate.py.
    res = resultados[arch]
    for i, m in enumerate(metricas):
        value = res[i]
        valores[m].append(np.nan if value is None else value)

x = np.arange(len(architectures))
width = 0.2

fig, ax = plt.subplots(figsize=(10, 6))

for i, m in enumerate(metricas):
    ax.bar(x + i * width, valores[m], width, label=m)

# Bars are centre-aligned at x + i * width for i = 0 .. len(metricas) - 1,
# so the middle of each group sits at x + width * (len(metricas) - 1) / 2.
ax.set_xticks(x + width * (len(metricas) - 1) / 2)
ax.set_xticklabels(architectures)
ax.set_ylabel("Valor")
ax.set_title("Comparación de métricas por arquitectura")
ax.legend()
ax.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()