# Imports, settings & definitions

In [None]:
%load_ext autoreload
%autoreload 2
import os
import pathlib
import functools

import pandas as pd
import torch
import wandb

from gnn_fiedler_approx.algebraic_connectivity_script import load_dataset, generate_model, generate_loss_function, evaluate
from gnn_fiedler_approx.algebraic_connectivity_evaluate import combine_data_objects

# True when the PBS_O_HOME environment variable is set
# (presumably exported by the PBS scheduler on the cluster - TODO confirm).
ON_HPC = "PBS_O_HOME" in os.environ
# Directory holding the trained checkpoints on the cluster.
HPC_MODEL_DIR = pathlib.Path("/lustre/home/mkrizman/Topocon_GNN/gnn_fiedler_approx/models/")
# Local checkpoint directory: "models" next to the current working directory.
LOC_MODEL_DIR = pathlib.Path().cwd().parent / "models"

# Weights & Biases client used to look up sweeps and runs.
api = wandb.Api()
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
__builtins__.device = device  # Hack to make device available in other modules.


In [None]:
@functools.lru_cache(maxsize=None)
def load_dataset_cached(selected_graph_sizes, selected_features, transform, batch_size, split):
    """Memoised front-end for ``load_dataset``.

    All arguments must be hashable because they form the cache key. Repeated
    calls with the same arguments return the very same objects, so callers
    share (and must not mutate) the result.

    Args:
        selected_graph_sizes: ``None`` to pass no size selection, otherwise a
            single dataset key that is forwarded as ``{key: -1}``.
        selected_features: Forwarded unchanged to ``load_dataset``.
        transform: Forwarded unchanged to ``load_dataset``.
        batch_size: Forwarded unchanged to ``load_dataset``.
        split: Forwarded unchanged to ``load_dataset``.

    Returns:
        Whatever ``load_dataset`` returns for these settings, with label
        normalization disabled and console output suppressed.
    """
    # A single key is wrapped into the mapping form expected by the loader.
    size_spec = None if selected_graph_sizes is None else {selected_graph_sizes: -1}

    loader_kwargs = dict(
        selected_graph_sizes=size_spec,
        selected_features=selected_features,
        label_normalization=None,
        transform=transform,
        batch_size=batch_size,
        split=split,
        suppress_output=True,
    )
    return load_dataset(**loader_kwargs)

def prepare_data(run, dataset_definition):
    """Load the data objects needed to evaluate ``run``.

    Args:
        run: W&B run whose ``config`` supplies ``selected_features``,
            ``transform`` and the ``dataset`` settings (``batch_size``, ``split``).
        dataset_definition: ``"standard"`` to evaluate on the run's own test
            split and on train+val+test combined, or ``"generalization"`` to
            additionally evaluate on the ``11-15``, ``16-20``, ``21-25`` and
            ``50`` node datasets.

    Returns:
        A tuple ``(train_data_obj, datasets_for_evaluation,
        dataobjects_for_evaluation, dataset_props)``. The two lists are aligned
        element by element. ``dataset_props`` comes from the last dataset that
        was loaded.

    Raises:
        ValueError: If ``dataset_definition`` is not one of the supported values.
    """
    if dataset_definition not in ("standard", "generalization"):
        raise ValueError(
            f"Unknown dataset_definition {dataset_definition!r}; "
            "expected 'standard' or 'generalization'."
        )

    def _load(graph_sizes, split):
        # Keyword order is kept identical for every call so that equal
        # requests map to the same lru_cache entry.
        return load_dataset_cached(
            selected_graph_sizes=graph_sizes,
            selected_features=tuple(run.config["selected_features"]),
            transform=run.config["transform"],
            batch_size=run.config["dataset"]["batch_size"],
            split=split,
        )

    # Dataset the model was trained on, with the split used by the run.
    train_data_obj, val_data_obj, test_data_obj, _, _, dataset_props = _load(
        None, tuple(run.config["dataset"]["split"])
    )

    if dataset_definition == "standard":
        datasets_for_evaluation = ["test", "entire"]
        dataobjects_for_evaluation = [
            test_data_obj,
            combine_data_objects([train_data_obj, val_data_obj, test_data_obj]),
        ]
        return train_data_obj, datasets_for_evaluation, dataobjects_for_evaluation, dataset_props

    # Additional datasets, loaded with a (0.0, 0.0) split; only the third
    # returned object (the test object) is used for each of them.
    extra_datasets = (
        ("15", "11-15_mix_200"),
        ("20", "16-20_mix_200"),
        ("25", "21-25_mix_200"),
        ("50", "50_mix_200"),
    )

    datasets_for_evaluation = ["test"]
    dataobjects_for_evaluation = [test_data_obj]
    for label, graph_sizes in extra_datasets:
        _, _, extra_test_obj, _, _, dataset_props = _load(graph_sizes, (0.0, 0.0))
        datasets_for_evaluation.append(label)
        dataobjects_for_evaluation.append(extra_test_obj)

    # "entire" here is the union of all evaluation sets listed above.
    datasets_for_evaluation.append("entire")
    dataobjects_for_evaluation.append(combine_data_objects(dataobjects_for_evaluation))

    return train_data_obj, datasets_for_evaluation, dataobjects_for_evaluation, dataset_props


In [None]:

def _format_metric(value):
    """Return ``value`` with five decimals, or ``"N/A"`` if it is not numeric."""
    try:
        return f"{value:.5f}"
    except (TypeError, ValueError):
        return "N/A"


def evaluate_run(run, dataset_definition):
    """Rebuild the model of a W&B run, load its best checkpoint and evaluate it.

    Args:
        run: W&B run providing the model hyper-parameters (``run.config``),
            the checkpoint name (``run.id``) and the logged metrics
            (``run.summary``) printed next to the recomputed ones.
        dataset_definition: Forwarded to ``prepare_data``.

    Returns:
        A tuple ``(results_table, results_scores, datasets_for_evaluation)``
        where the two dicts are keyed by dataset name and hold the ``"table"``
        and ``"good_within"`` entries returned by ``evaluate``.

    Raises:
        FileNotFoundError: If the checkpoint is missing locally and could not
            be copied from the cluster.
    """
    import subprocess  # Local import: only needed for the optional checkpoint download.

    train_data_obj, datasets_for_evaluation, dataobjects_for_evaluation, dataset_props = prepare_data(run, dataset_definition)

    # Generate the model skeleton.
    model = generate_model(
        architecture=run.config["architecture"],
        in_channels=dataset_props["feature_dim"],
        hidden_channels=run.config["hidden_channels"],
        gnn_layers=run.config["gnn_layers"],
        mlp_layers=run.config["mlp_layers"],
        pool=run.config["pool"],
        jk=run.config["jk"],
        dropout=run.config["dropout"],
        norm=run.config["norm"],
        act=run.config["activation"],
    )
    model.to(device)

    # Download or locate the model dict.
    checkpoint_name = f"{run.id}_best_model.pth"
    if ON_HPC:
        model_dict_path = HPC_MODEL_DIR / checkpoint_name
    else:
        model_dict_path = LOC_MODEL_DIR / checkpoint_name
        if not model_dict_path.exists():
            # scp cannot create the target directory, so make sure it exists.
            LOC_MODEL_DIR.mkdir(parents=True, exist_ok=True)
            subprocess.run(
                [
                    "scp",
                    f"mkrizman@login-gpu.hpc.srce.hr:{HPC_MODEL_DIR}/{checkpoint_name}",
                    f"{LOC_MODEL_DIR}/",
                ],
                check=False,
            )
            if not model_dict_path.exists():
                raise FileNotFoundError(
                    f"Checkpoint {checkpoint_name} is not available locally and could not be copied from the HPC."
                )

    # Load the model state dict.
    model.load_state_dict(torch.load(model_dict_path, map_location=device)["model_state_dict"])

    criterion = generate_loss_function(run.config.get("loss", "MAPE"))

    results_table = dict.fromkeys(datasets_for_evaluation)
    results_scores = dict.fromkeys(datasets_for_evaluation)
    epoch = -1

    print(f"    {run.id}")
    for dataset, eval_data_obj in zip(datasets_for_evaluation, dataobjects_for_evaluation):
        eval_results = evaluate(
            model,
            epoch,
            criterion,
            train_data_obj,
            eval_data_obj,
            dataset_props["transformation"],
            title=f"Results on the {dataset} dataset",
            plot_graphs_wandb=False,
            plot_embeddings=False,
            make_table_wandb=False,
            suppress_output=True,
        )

        print(f"        {dataset}")
        if dataset == "test":
            # Recomputed value | value logged by the run. Metrics missing from
            # the run summary are shown as N/A instead of breaking the format.
            for metric in ("mean_err", "stddev_err", "eval_train_loss", "eval_test_loss"):
                print(f"            {metric}: {_format_metric(eval_results[metric])} | {_format_metric(run.summary.get(metric))}")
            for level in ("99", "95", "90", "80"):
                try:
                    logged = run.summary["good_within"][level]
                except (KeyError, TypeError):
                    logged = None
                print(f"            good_within.{level}: {_format_metric(eval_results['good_within'][level])} | {_format_metric(logged)}")

        results_table[dataset] = eval_results["table"]
        results_scores[dataset] = eval_results["good_within"]

    return results_table, results_scores, datasets_for_evaluation


In [None]:
import numpy as np


def evaluate_baseline(run, dataset_definition, template_tables):
    """Score a constant predictor that always outputs the mean training label.

    Args:
        run: W&B run, used only to load the matching datasets.
        dataset_definition: Forwarded to ``prepare_data``.
        template_tables: Mapping from dataset name to a DataFrame with a
            ``"True"`` column; each one is copied and its prediction and error
            columns are overwritten.

    Returns:
        A tuple ``(results_table, results_scores)`` of dicts keyed by dataset
        name: the filled-in tables and the percentage of rows whose relative
        error lies within 1, 5, 10 and 20 percent (keys "99", "95", "90", "80").
    """
    train_data_obj, datasets_for_evaluation, _, _ = prepare_data(run, dataset_definition)

    # The baseline prediction is the average label of the training data.
    mean_label = torch.mean(train_data_obj.y).item()

    # Score key -> tolerated relative error in percent.
    tolerances = (("99", 1), ("95", 5), ("90", 10), ("80", 20))

    results_table = {}
    results_scores = {}
    for name in datasets_for_evaluation:
        frame = template_tables[name].copy()
        frame["Predicted"] = mean_label

        # Calculate the statistics.
        frame["Error"] = frame["True"] - frame["Predicted"]
        frame["Error %"] = 100 * frame["Error"] / frame["True"]
        frame["abs(Error)"] = np.abs(frame["Error"])

        n_rows = len(frame)
        results_scores[name] = {
            level: len(frame[frame["Error %"].between(-tol, tol)]) / n_rows * 100
            for level, tol in tolerances
        }
        results_table[name] = frame

    return results_table, results_scores


# Evaluate final models on the dataset and store results for all graphs

In [None]:
# Directory for the pickled evaluation results; the file name is appended
# by whichever configuration below is active.
save_path = pathlib.Path().cwd().parent / "results" / "runs_data"
save_path.mkdir(parents=True, exist_ok=True)

# Exactly one of the two configurations below should be active. Each one sets
# the sweeps to evaluate (sweep_ids), the pickle file name, the suffix added to
# exported figure names (save_ext) and the dataset_definition for prepare_data.

## Standard evaluation with full and relaxed models on "in-distribution" data.
# sweep_ids = {
#     "Full": "lo3tkjor",
#     "Dist. 32": "i6swq9gx",
#     "Dist. 64": "vuezsfvg",
# }
# save_path /= "results_final_evaluation.pkl"
# save_ext = ""
# dataset_definition = "standard"

## Evaluation of the full model on "out-of-distribution" data to test generalization.
sweep_ids = {
    "Full": "lo3tkjor",
}
save_path /= "results_final_generalization.pkl"
save_ext = "_gen"
dataset_definition = "generalization"

In [None]:
# Evaluate every run of every configured sweep.
# results_all_table / results_all_scores: model type -> run id -> dataset -> result.
results_all_table = {}
results_all_scores = {}
# (model type, run) whose result tables serve as the template for the baseline.
baseline_reference = None
for model_type, sweep_id in sweep_ids.items():
    sweep = api.sweep(f"marko-krizmancic/gnn_fiedler_approx_v3/{sweep_id}")  # labels_all
    runs = sweep.runs

    results_all_table[model_type] = {}
    results_all_scores[model_type] = {}
    print(f"Evaluating {model_type} models...")
    first_run = None
    for run in runs:
        # evaluate_run() prints the run id itself, so it is not repeated here.
        table, scores, datasets_for_evaluation = evaluate_run(run, dataset_definition)
        results_all_table[model_type][run.id] = table
        results_all_scores[model_type][run.id] = scores
        if first_run is None:
            first_run = run

    if first_run is not None:
        # The first run of the last non-empty sweep provides the baseline template.
        baseline_reference = (model_type, first_run)

if baseline_reference is None:
    raise RuntimeError("No runs were evaluated, so the baseline cannot be computed.")

# The baseline only needs the true values and graph metadata of an evaluated
# run; its predictions are replaced inside evaluate_baseline().
baseline_model_type, baseline_run = baseline_reference
table, scores = evaluate_baseline(
    baseline_run,
    dataset_definition,
    results_all_table[baseline_model_type][baseline_run.id],
)
results_all_table["Baseline"] = {baseline_run.id: table}
results_all_scores["Baseline"] = {baseline_run.id: scores}

### Save or load evaluated data to avoid computing it again

In [None]:
import pickle

# Ask before replacing an existing results file. Declining only skips the
# write: calling exit() here would shut down the notebook kernel and throw
# away the results that were just computed.
write_results = True
if save_path.exists():
    x = input(f"Are you sure you want to overwrite {save_path.name}? (y/n): ")
    write_results = x.strip().lower() == "y"

if write_results:
    with open(save_path, "wb") as f:
        pickle.dump((results_all_table, results_all_scores, datasets_for_evaluation), f)
else:
    print("Aborting save.")

In [None]:
import pickle
# Restore previously saved results instead of re-evaluating every run.
# NOTE(review): pickle.load executes arbitrary code embedded in the file;
# only load result files that were produced by this notebook.
with open(save_path, "rb") as f:
    results_all_table, results_all_scores, datasets_for_evaluation = pickle.load(f)

# Plot average error distributions

In [None]:
import plotly.graph_objects as go
import numpy as np

# --- Zoom window parameters ---
# Data-coordinate ranges shown on the inset axes; the same window is marked
# by a rectangle on the main plot.
zoom_x_range = [0, 5]
zoom_y_range = [50, 100]
# Extent of the inset axes domain (fraction of the plotting area).
zoom_width = 0.5
zoom_height = 0.5
# Start of the inset axes domain (its lower-left corner).
zoom_x_pos = 0.25
zoom_y_pos = 0.15

def paper_to_data(coord, range):
    """Map a fraction ``coord`` of an axis onto the data interval ``range``.

    ``range`` is indexable as ``[lower, upper]``; 0 maps to ``lower`` and 1
    maps to ``upper``.
    """
    lower = range[0]
    span = range[1] - lower
    return coord * span + lower

def data_to_paper(coord, range):
    """Express the data value ``coord`` as a fraction of the interval ``range``.

    ``range`` is indexable as ``[lower, upper]``; ``lower`` maps to 0 and
    ``upper`` maps to 1.
    """
    lower = range[0]
    offset = coord - lower
    span = range[1] - lower
    return offset / span
# --- End of parameters ---

fig = go.Figure()

# One colour per model type.
colors = {
    "Full": "blue",
    "Dist. 32": "green",
    "Dist. 64": "red",
    "Baseline": "#FFA500"  # Orange
}

# One dash pattern per evaluation dataset.
line_styles = {  # ['solid', 'dot', 'dash', 'longdash', 'dashdot', 'longdashdot']
    "test": "solid",
    "entire": "dash",
    "15": "dot",
    "20": "longdash",
    "25": "dashdot",
    "50": "longdashdot",
}

# Shared error grid on which every per-run ECDF is resampled.
x_common = np.linspace(0, 100, 501)
xaxis_range = [0, 100]
yaxis_range = [0, 105]

# RGB triplets of the named colours, needed to build translucent fills.
color_options = {'blue': '0,0,255', 'green': '0,128,0', 'red': '255,0,0'}

for dataset in datasets_for_evaluation:
    for model_type in results_all_table:
        tables_by_run = results_all_table[model_type]

        # Resample the ECDF of |Error %| of every run that has this dataset.
        interpolated_ecdfs = []
        for run_tables in tables_by_run.values():
            if dataset not in run_tables:
                continue

            sorted_abs_error = np.abs(run_tables[dataset]["Error %"]).sort_values()
            n_graphs = len(sorted_abs_error)
            cumulative_pct = np.arange(1, n_graphs + 1) / n_graphs * 100

            interpolated_ecdfs.append(
                np.interp(x_common, sorted_abs_error, cumulative_pct, right=100)
            )

        if not interpolated_ecdfs:
            continue

        # Mean curve plus the min/max envelope across runs.
        all_ecdfs = np.vstack(interpolated_ecdfs)
        mean_ecdf = all_ecdfs.mean(axis=0)
        min_ecdf = all_ecdfs.min(axis=0)
        max_ecdf = all_ecdfs.max(axis=0)

        # Translucent version of the model colour for the envelope.
        model_color = colors[model_type]
        if model_color.startswith("#"):
            hex_digits = model_color.lstrip('#')
            red, green, blue = (int(hex_digits[i:i + 2], 16) for i in (0, 2, 4))
            fillcolor = f"rgba({red}, {green}, {blue}, 0.2)"
        else:
            fillcolor = f"rgba({color_options[model_color]}, 0.2)"

        stroke = dict(color=model_color, dash=line_styles[dataset])
        # Closed polygon: lower bound left-to-right, upper bound right-to-left.
        envelope = dict(
            x=np.concatenate([x_common, x_common[::-1]]),
            y=np.concatenate([min_ecdf, max_ecdf[::-1]]),
            fill='toself',
            fillcolor=fillcolor,
            line=dict(color='rgba(255,255,255,0)'),
            hoverinfo="skip",
            showlegend=False,
        )

        # --- Main plot traces ---
        fig.add_trace(go.Scatter(
            x=x_common,
            y=mean_ecdf,
            mode='lines',
            name=f"{model_type} - {dataset}",
            line=stroke,
        ))
        fig.add_trace(go.Scatter(**envelope))

        # --- Zoom plot traces (same data drawn on the second pair of axes) ---
        fig.add_trace(go.Scatter(
            x=x_common,
            y=mean_ecdf,
            mode='lines',
            line=stroke,
            xaxis='x2',
            yaxis='y2',
            showlegend=False,
            hoverinfo="skip",
        ))
        fig.add_trace(go.Scatter(**envelope, xaxis='x2', yaxis='y2'))


# Axis titles, spike lines and tick settings passed to the figure's x and y axes.
fig.update_xaxes(showspikes=True, tickvals=[1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100], title_text="Absolute Error (%)")
fig.update_yaxes(showspikes=True, nticks=10, title_text="% of graphs within error")

fig.update_layout(
    font=dict(size=20),
    margin=dict(l=5, r=5, t=5, b=65),
    # Horizontal legend placed above the plotting area, right-aligned.
    legend=dict(
            title=dict(text="Model - Dataset", side="top"),
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
    xaxis_range=xaxis_range,
    yaxis_range=yaxis_range,
    # Second pair of axes: the zoom inset, positioned through its domain.
    xaxis2=dict(
        domain=[zoom_x_pos, zoom_x_pos + zoom_width],
        range=zoom_x_range,
        showticklabels=True,
        tickfont=dict(size=12),
        ticks="inside",
        tickmode="linear",
        tick0=zoom_x_range[0],
        dtick=1,
        anchor="y2"
    ),
    yaxis2=dict(
        domain=[zoom_y_pos, zoom_y_pos + zoom_height],
        range=zoom_y_range,
        showticklabels=True,
        tickfont=dict(size=14),
        ticks="inside",
        tickmode="linear",
        tick0=zoom_y_range[0],
        dtick=10,
        anchor="x2"
    ),
    shapes=[
        # Shape to indicate zoomed area on the main plot
        dict(
            type="rect",
            xref="x", yref="y",
            x0=zoom_x_range[0], y0=zoom_y_range[0],
            x1=zoom_x_range[1], y1=zoom_y_range[1],
            line=dict(color="rgba(0,0,0,0.5)", width=1, dash="dot"),
            fillcolor="rgba(0,0,0,0.1)"
        ),
        # Shape for the border of the zoom window
        dict(
            type='rect',
            xref='paper', yref='paper',
            x0=zoom_x_pos, y0=zoom_y_pos,
            x1=zoom_x_pos + zoom_width, y1=zoom_y_pos + zoom_height,
            line=dict(color='black', width=1)
        ),
        # Line connecting lower-left corners
        # (data_to_paper converts the main-plot corner to a fraction of the
        # axis range; this assumes the main axes span the whole paper area)
        dict(
            type="line",
            xref="paper", yref="paper",
            x0=data_to_paper(zoom_x_range[0], xaxis_range), y0=data_to_paper(zoom_y_range[0], yaxis_range),
            x1=zoom_x_pos, y1=zoom_y_pos,
            line=dict(color="rgba(0,0,0,0.5)", width=1, dash="dot")
        ),
        # Line connecting upper-right corners
        dict(
            type="line",
            xref="paper", yref="paper",
            x0=data_to_paper(zoom_x_range[1], xaxis_range), y0=data_to_paper(zoom_y_range[1], yaxis_range),
            x1=zoom_x_pos + zoom_width, y1=zoom_y_pos + zoom_height,
            line=dict(color="rgba(0,0,0,0.5)", width=1, dash="dot")
        )
    ],
)

fig.show()
# Export the figure as a PDF; save_ext distinguishes the evaluation variants.
image_path = pathlib.Path().cwd().parent / "results" / "error_plots"
image_path.mkdir(parents=True, exist_ok=True)
image_path /= f"final_comparison{save_ext}.pdf"
fig.write_image(image_path)


# Plot accuracies per graph size

In [None]:
# Create plot showing percentage of graphs with absolute error <= 1% and <= 5%
import numpy as np
import pandas as pd
import plotly.graph_objects as go

# Line colour per model type (redefined here so the cell can run on its own).
colors = {
    "Full": "blue",
    "Dist. 32": "green",
    "Dist. 64": "red",
    "Baseline": "#FFA500"  # Orange
}

# Group data by number of nodes and calculate percentages
def calculate_error_percentages(df):
    """Summarise, per graph size, how many graphs fall within 1% and 5% error.

    Args:
        df: DataFrame with a ``'Nodes'`` column and an ``'Error %'`` column.

    Returns:
        DataFrame with one row per distinct ``'Nodes'`` value (ascending) and
        the columns ``Nodes``, ``Total``, ``Count_1pct``, ``Count_5pct``,
        ``Percentage_1pct`` and ``Percentage_5pct``.
    """
    def summarise(node_count):
        # Rows belonging to graphs of this size.
        subset = df[df['Nodes'] == node_count]
        n_graphs = len(subset)
        magnitude = np.abs(subset['Error %'])

        # Number of graphs whose absolute error is at most 1% / 5%.
        within_1 = np.sum(magnitude <= 1.0)
        within_5 = np.sum(magnitude <= 5.0)

        return {
            'Nodes': node_count,
            'Total': n_graphs,
            'Count_1pct': within_1,
            'Count_5pct': within_5,
            'Percentage_1pct': (within_1 / n_graphs) * 100,
            'Percentage_5pct': (within_5 / n_graphs) * 100
        }

    return pd.DataFrame([summarise(node_count) for node_count in sorted(df['Nodes'].unique())])

# Create the plot
fig = go.Figure()

dataset = "entire"
for model_type in results_all_table:
    # Tables of all runs of this model type that contain the dataset.
    run_tables = [
        tables[dataset]
        for tables in results_all_table[model_type].values()
        if dataset in tables
    ]
    if not run_tables:
        # Nothing to aggregate; grouping an empty frame would fail.
        continue

    # Average the error of each graph over all runs; every other column is
    # taken from the first occurrence of the graph.
    all_df = pd.concat(run_tables)
    all_df = all_df.groupby("Graph", as_index=False).agg({"Error %": "mean", **{col: 'first' for col in all_df.columns if col not in ["Error %"]}})

    # Calculate the percentages
    error_percentages = calculate_error_percentages(all_df)

    # One line per error threshold; the 5% lines go to a second legend.
    for threshold, dash, extra in (
        (1, 'solid', {}),
        (5, 'dot', {"legend": "legend2"}),
    ):
        fig.add_trace(go.Scatter(
            x=error_percentages['Nodes'],
            y=error_percentages[f'Percentage_{threshold}pct'],
            mode='lines+markers',
            name=f'{model_type} | ≤ {threshold}%',
            line=dict(color=colors[model_type], dash=dash),
            hovertemplate=(
                'Nodes: %{x}<br>' +
                f'Percentage ≤ {threshold}%: ' + '%{y:.2f}%<br>' +
                'Count: %{customdata[0]}/%{customdata[1]}<br>' +
                '<extra></extra>'
            ),
            customdata=np.column_stack([error_percentages[f'Count_{threshold}pct'], error_percentages['Total']]),
            **extra,
        ))

# Figure-wide settings are applied once, after all traces have been added.
fig.update_layout(
    xaxis_title='Number of Nodes',
    yaxis_title='% of graphs within error',
    hovermode='x unified',
    showlegend=True,
    yaxis=dict(range=[0, 100])  # Set y-axis from 0 to 100%
)

fig.update_layout(
    font=dict(size=20),
    margin=dict(l=5, r=5, t=5, b=65),
    # Two stacked horizontal legends above the plot: 1% lines on top, 5% below.
    legend=dict(
        title=dict(text="Model | Error threshold", side="left"),
        orientation="h",
        yanchor="bottom",
        y=1.14,
        xanchor="right",
        x=1,
        ),
    legend2=dict(
        orientation="h",
        yanchor="bottom",
        y=1.02,
        xanchor="right",
        x=1,
    ),
    xaxis=dict(
        tickmode="linear",
        tick0=0,
        dtick=5,
    ),
    yaxis=dict(
        tickmode="linear",
        tick0=0,
        dtick=10,
    )
)
fig.show()
image_path = pathlib.Path().cwd().parent / "results" / "error_plots"
image_path.mkdir(parents=True, exist_ok=True)
image_path /= f"accuracy_per_size{save_ext}.pdf"
fig.write_image(image_path)


# Plot errors per graph size
For each evaluated dataset (e.g. test or entire) and model type (Full, Dist. 32, Dist. 64), we take the results from all seeds and average the error on each graph over those 20 seeds. We then plot the distribution of the averaged errors as if they came from a single experiment.

In [None]:
from gnn_fiedler_approx.gnn_utils.utils import create_combined_histogram

# One histogram per (dataset, model type); only the "Full"/"test" figure is
# restyled and exported.
for dataset in datasets_for_evaluation:
    for model_type in results_all_table:
        # Tables of all runs of this model type that contain the dataset.
        run_tables = [
            tables[dataset]
            for tables in results_all_table[model_type].values()
            if dataset in tables
        ]
        if not run_tables:
            # Nothing to aggregate; grouping an empty frame would fail.
            continue

        # Average the error of each graph over all runs; every other column
        # is taken from the first occurrence of the graph.
        all_df = pd.concat(run_tables)
        all_df = all_df.groupby("Graph", as_index=False).agg({"Error %": "mean", **{col: 'first' for col in all_df.columns if col not in ["Error %"]}})

        fig = create_combined_histogram(all_df, "Nodes", "Error %", title=f"{model_type} - {dataset}", xlabel="Number of nodes", option="ci")
        fig.show()

        if model_type == "Full" and dataset == "test":
            # Make sure the export directory exists even if this cell is run
            # before the other plotting cells.
            histogram_dir = pathlib.Path().cwd().parent / "results" / "error_plots"
            histogram_dir.mkdir(parents=True, exist_ok=True)

            fig.update_layout(
                font=dict(size=20),
                margin=dict(l=5, r=5, t=5, b=65),
                legend=dict(
                        # title=dict(text="Model - Dataset", side="top"),
                        orientation="h",
                        yanchor="bottom",
                        y=1.02,
                        xanchor="right",
                        x=1,
                    ),
            )
            fig.write_image(histogram_dir / f"histogram_{model_type}_{dataset}{save_ext}.pdf")

            # Second export with the secondary y-axis limited to [-5, 5].
            fig.update_layout(
                yaxis2=dict(
                    range=[-5, 5],
                    tickvals=list(range(-5, 6, 1)),
                    )
            )
            fig.write_image(histogram_dir / f"histogram_{model_type}_{dataset}_zoomed{save_ext}.pdf")


# Create the latex table with results

In [None]:
import pandas as pd
import numpy as np
import pathlib
from IPython.display import display, Markdown

# --- Data Extraction and Processing ---
dataset_keys = datasets_for_evaluation
# Listed from the tightest (1%) to the loosest (20%) error threshold, which
# is also the row order of the generated tables.
good_within_keys = ["99", "95", "90", "80"]


def _scores_across_runs(model_type, dataset, key):
    """Collect one ``good_within`` score per run that provides it."""
    # Model types without evaluated runs contribute an empty list.
    runs_scores = results_all_scores.get(model_type) or {}
    return [
        run_results[dataset][key]
        for run_results in runs_scores.values()
        if dataset in run_results and key in run_results[dataset]
    ]


# model type -> dataset -> threshold key -> list of scores (one per run).
processed_results = {
    model_type: {
        dataset: {key: _scores_across_runs(model_type, dataset, key) for key in good_within_keys}
        for dataset in dataset_keys
    }
    for model_type in sweep_ids.keys()
}

# --- Table Generation ---

def create_latex_table(df, caption, label):
    """Generates LaTeX code for a given DataFrame."""
    df_copy = df.copy()
    # Escape the '%' sign in the index for LaTeX rendering
    df_copy.index = df_copy.index.str.replace('%', r'\%', regex=False)

    with pd.option_context('display.max_colwidth', None):
        latex_str = df_copy.to_latex(
            multicolumn_format='c',
            multirow=True,
            escape=False,
            caption=caption,
            label=label,
            column_format='l' + 'c' * (len(df.columns))
        )
    latex_str = latex_str.replace("\\begin{table}", "\\begin{table}[htbp]")
    return latex_str

def format_df_for_latex(df_numeric):
    """Collapse ``mean``/``std`` column pairs into ``"$mean \\pm std$"`` strings.

    Args:
        df_numeric: DataFrame whose columns are a three-level MultiIndex
            ``(model, dataset, stat)``.

    Returns:
        DataFrame with the same index and two-level ``(model, dataset)``
        columns. Cells with a missing mean or std read ``"N/A"``; pairs that
        lack a ``mean`` or ``std`` column are left unfilled.
    """
    # Drop the stat level to obtain the target column layout.
    pair_columns = df_numeric.columns.droplevel(2).unique()
    formatted = pd.DataFrame(index=df_numeric.index, columns=pair_columns)

    available = df_numeric.columns
    for pair in pair_columns:
        mean_key = (*pair, 'mean')
        std_key = (*pair, 'std')
        if mean_key not in available or std_key not in available:
            continue

        cells = []
        for mean_value, std_value in zip(df_numeric[mean_key], df_numeric[std_key]):
            if pd.isna(mean_value) or pd.isna(std_value):
                cells.append("N/A")
            else:
                cells.append(f"${mean_value:.2f} \\pm {std_value:.2f}$")
        formatted[pair] = cells

    return formatted

# --- Create Numerical Tables ---

# Table 1: Full Model (Numerical)
def _mean_and_std(values):
    """Return ``[mean, std]`` of ``values``, or two NaNs when there are none."""
    if not values:
        return [np.nan, np.nan]
    return [np.mean(values), np.std(values)]


def _build_numeric_table(model_types):
    """Build the mean/std table for ``model_types`` plus its column index."""
    columns = pd.MultiIndex.from_product([model_types, dataset_keys, ['mean', 'std']], names=['Model', 'Dataset', 'Stat'])

    rows = []
    for key in good_within_keys:
        # A "good within" key of 99 corresponds to the 1% error threshold.
        row = [f"{100 - int(key)}%"]
        for model_type in model_types:
            for dataset in dataset_keys:
                values = processed_results.get(model_type, {}).get(dataset, {}).get(key, [])
                row += _mean_and_std(values)
        rows.append(row)

    # Build with placeholder integer labels first, then attach the MultiIndex.
    placeholder_labels = ['Error threshold'] + list(range(len(columns)))
    frame = pd.DataFrame(rows, columns=placeholder_labels).set_index('Error threshold')
    frame.columns = columns
    return frame, columns


# Table 1: Full Model (Numerical)
df_full_num, cols_full = _build_numeric_table(['Full'])

# Table 2: Dist. 32 & Dist. 64 Models (Numerical)
df_dist_num, cols_dist = _build_numeric_table(["Dist. 32", "Dist. 64"])

# Table 3: Combined Table (Numerical)
df_combined_num = pd.concat([df_full_num, df_dist_num], axis=1).sort_index(axis=1, level=[0, 1])


# --- Display Numerical Tables ---
for heading, frame in (
    ("### Full Model (Numerical)", df_full_num),
    ("### Distributed Models (Numerical)", df_dist_num),
    ("### Combined Models (Numerical)", df_combined_num),
):
    display(Markdown(heading))
    display(frame.style.format(precision=2, na_rep="N/A"))


# --- Generate and Save LaTeX ---
# Convert the numeric tables into "mean ± std" string tables.
df_full_latex = format_df_for_latex(df_full_num)
df_dist_latex = format_df_for_latex(df_dist_num)
df_combined_latex = format_df_for_latex(df_combined_num)

latex_full = create_latex_table(df_full_latex, "Performance of the Full model.", "tab:full_model")
latex_dist = create_latex_table(df_dist_latex, "Performance of the Distributed models.", "tab:dist_models")
latex_combined = create_latex_table(df_combined_latex, "Combined performance of all models.", "tab:combined_models")

# parents=True so the export also works when "results" does not exist yet,
# matching the other output directories created in this notebook.
output_path = pathlib.Path().cwd().parent / "results" / "latex_tables"
output_path.mkdir(parents=True, exist_ok=True)

# NOTE(review): the file names do not include save_ext, so the standard and
# the generalization evaluation overwrite each other's tables - confirm
# whether that is intended.
for file_name, latex_code in (
    ("full.tex", latex_full),
    ("dist.tex", latex_dist),
    ("combined.tex", latex_combined),
):
    with open(output_path / file_name, "w", encoding="utf-8") as f:
        f.write(latex_code)