In [None]:
import pickle
import os
import pandas as pd
import numbers
import networkx as nx
import numpy as np
from tqdm.auto import tqdm
import plotly.express as px
import plotly.io as pio
from datasets import get_dataset

# Set the default Plotly template to "plotly_white" for all figures created below.
pio.templates.default = "plotly_white"

In [None]:
# Dataset whose results are analysed in the rest of the notebook.
dataset_name = "Shapes_Ones"

# Load every dataset once, keyed by its name, so later cells can look up
# per-dataset metadata (e.g. GRAPH_CLS, num_classes).
datasets = dict(
    (candidate, get_dataset(candidate))
    for candidate in ("Is_Acyclic_Ones", "MUTAG", "Shapes_Ones")
)

In [None]:
def create_graph(adjacency_matrix, node_features):
    """
    Build an undirected NetworkX graph from an adjacency matrix and node features.

    Node ``i`` receives ``node_features[i]`` as its ``label`` attribute. An edge
    ``(i, j)`` with ``i < j`` is added whenever ``adjacency_matrix[i, j]`` is
    non-zero; only the strict upper triangle is inspected, so self-loops and the
    lower triangle are ignored.

    Parameters:
    - adjacency_matrix (numpy.ndarray): The adjacency matrix of the graph.
    - node_features (numpy.ndarray): The matrix of node features.

    Returns:
    - nx.Graph: The created NetworkX graph.
    """
    graph = nx.Graph()
    node_count = adjacency_matrix.shape[0]

    # Register every node (in index order) together with its feature row.
    graph.add_nodes_from(
        (node, {"label": node_features[node]}) for node in range(node_count)
    )

    # Walk the strict upper triangle and connect every non-zero entry.
    graph.add_edges_from(
        (source, target)
        for source in range(node_count)
        for target in range(source + 1, node_count)
        if adjacency_matrix[source, target] != 0
    )

    return graph

def average_edit_distance(Gs):
    """
    Return the mean graph edit distance over all ordered pairs of graphs in ``Gs``.

    For every graph the edit distance to each graph in ``Gs`` is computed
    (including the graph itself and both orderings of every pair), averaged per
    graph, and the per-graph averages are averaged again. Two nodes are treated
    as matching when their ``label`` attributes are element-wise close
    (``np.isclose``).

    Parameters:
    - Gs (list): Graphs accepted by ``nx.graph_edit_distance``.

    Returns:
    - The mean of the per-graph mean edit distances.
    """

    def labels_match(attrs_a, attrs_b):
        # Nodes match only if every label component is numerically close.
        return np.isclose(attrs_a['label'], attrs_b['label']).all()

    per_graph_means = [
        np.mean([
            nx.graph_edit_distance(source, target, node_match=labels_match)
            for target in Gs
        ])
        for source in Gs
    ]
    return np.mean(per_graph_means)

In [None]:
%%capture

# Load every pickled run from `gdir`, keeping only runs for `dataset_name`.
gdir = "./results/all_info/"
d_list = []
skipped_files = []  # (filename, exception) for files that could not be loaded/annotated

# Sort the listing so that d_list[0], d_list[1], ... are the same on every run
# (os.listdir returns entries in arbitrary order).
for filename in sorted(os.listdir(gdir)):
    try:
        # NOTE: pickle can execute arbitrary code; only load trusted result files.
        with open(os.path.join(gdir, filename), "rb") as result_file:
            d = pickle.load(result_file)
        # The record must be loaded *before* it can be filtered; checking `d`
        # first would either raise NameError on a fresh kernel (silently leaving
        # d_list empty) or filter on the previously loaded file.
        if d["dataset_name"] != dataset_name:
            continue
        d["run_id"] = filename.split(".")[0]
        d["max_class_name"] = datasets[d["dataset_name"]].GRAPH_CLS[d["max_class"]]
        d_list.append(d)
    except Exception as exc:
        # Best-effort loading: keep going, but remember what was skipped so the
        # failure can be inspected via `skipped_files` instead of vanishing.
        skipped_files.append((filename, exc))
        continue

In [None]:
# Preview the "mip_information" entry of the second loaded run as a table
# (requires at least two entries in d_list).
pd.DataFrame(d_list[1]["mip_information"])

In [None]:
# Inspect the first loaded run: its sorted keys, the last entry of "solutions",
# and the first entry of "mip_information".
sorted(d_list[0].keys()), d_list[0]["solutions"][-1], d_list[0]["mip_information"][0]

In [None]:
# Display names for the `mip_information` columns used in the figures below.
MIP_COLUMN_LABELS = {
    "BestBound": "Best Bound",
    "ObjBound": "Objective Bound",
    "WorkUnits": "Work Units",
    "ExploredNodeCount": "Explored Node Count",
    "UnexploredNodeCount": "Unexplored Node Count",
}
FIGURES_DIR = "./results/figures"


def _save_mip_line_plot(mip_df, y_columns, title, y_axis_title, show_legend, output_path):
    """Plot `y_columns` against "Work Units" on a log-scaled y axis and save the image.

    Parameters:
    - mip_df (pandas.DataFrame): Frame containing "Work Units" and every column in `y_columns`.
    - y_columns (list[str]): Columns drawn as separate lines.
    - title (str): Figure title.
    - y_axis_title (str): Label of the y axis.
    - show_legend (bool): Whether the legend is drawn.
    - output_path (str): Destination passed to `fig.write_image`.
    """
    fig = px.line(
        mip_df,
        x="Work Units",
        y=y_columns,
        title=title,
        width=1000,
        height=800,
        log_y=True,
    )
    tick_font = dict(size=18, color='rgb(82, 82, 82)')
    fig.update_layout(
        font=dict(family="Roman Modern", size=18, color='rgb(82, 82, 82)'),
        xaxis=dict(ticks='outside', tickfont=tick_font),
        yaxis=dict(title=y_axis_title, tickfont=tick_font),
        showlegend=show_legend,
        autosize=False,
        margin=dict(autoexpand=True, l=100, r=20, t=110),
    )
    fig.write_image(output_path)


# write_image does not create missing directories.
os.makedirs(FIGURES_DIR, exist_ok=True)

for d in d_list:
    # Only the 8-node Shapes_Ones runs are plotted.
    if d["num_nodes"] != 8 or d["dataset_name"] != "Shapes_Ones":
        continue

    # Store the renamed frame back on the record, as before; re-running the
    # cell is harmless because renaming already-renamed columns is a no-op.
    d["mip_information"] = pd.DataFrame(d["mip_information"]).rename(columns=MIP_COLUMN_LABELS)

    # NOTE(review): the file names were built from d['dataset'], but every other
    # cell uses the 'dataset_name' key. Keep 'dataset' when it exists and fall
    # back to 'dataset_name' instead of raising KeyError -- confirm which is intended.
    dataset_label = d["dataset"] if "dataset" in d else d["dataset_name"]
    run_tag = f"{dataset_label}_class_{d['max_class']}_n_{d['num_nodes']}_id_{d['run_id']}"

    _save_mip_line_plot(
        d["mip_information"],
        y_columns=["Best Bound", "Objective Bound"],
        title="Convergence of Objective Bounds",
        y_axis_title="Objective Value",
        show_legend=False,
        output_path=f"{FIGURES_DIR}/convergence_{run_tag}.png",
    )
    _save_mip_line_plot(
        d["mip_information"],
        y_columns=["Explored Node Count", "Unexplored Node Count"],
        title="Node Counts",
        y_axis_title="Number of Nodes",
        show_legend=True,
        output_path=f"{FIGURES_DIR}/node_counts_{run_tag}.png",
    )
    

In [None]:
# Keep only the numeric fields of every run plus its identifying string fields.
kept_text_keys = {"run_id", "dataset_name"}
scalar_rows = [
    {
        key: value
        for key, value in run.items()
        if key in kept_text_keys or isinstance(value, numbers.Number)
    }
    for run in d_list
]

# Attach the graph built from the last recorded solution ("G") and from the
# first recorded solution ("init_G"), tag the method, and restrict the frame
# to the dataset under analysis.
mipexplainer_df = (
    pd.DataFrame(scalar_rows)
    .assign(
        G=[create_graph(run["solutions"][-1]["A"], run["solutions"][-1]["X"]) for run in d_list],
        init_G=[create_graph(run["solutions"][0]["A"], run["solutions"][0]["X"]) for run in d_list],
        method="MIPExplainer",
    )
    .loc[lambda frame: frame["dataset_name"] == dataset_name]
)

In [None]:
# Load the baseline results; `with` closes each file handle instead of leaking it.
# NOTE: pickle can execute arbitrary code; only load trusted result files.
with open(f"results/gnninterpreter_{dataset_name}.pkl", "rb") as result_file:
    gnninterpreter_df = pd.DataFrame(pickle.load(result_file))
with open(f"results/xgnn_{dataset_name}.pkl", "rb") as result_file:
    xgnn_df = pd.DataFrame(pickle.load(result_file))

# Levels that identify one experimental configuration.
index_names = ["dataset_name", "max_class", "num_nodes", "method"]

# Stack all methods, drop class 4, and index by configuration.
# NOTE(review): class index 4 is excluded for every dataset; the reason is not
# visible in this notebook -- confirm it is intended for datasets other than Shapes_Ones.
df = (
    pd.concat([mipexplainer_df, gnninterpreter_df, xgnn_df])
    .loc[lambda frame: frame["max_class"] != 4]
    .set_index(index_names)
    .sort_index()
)

In [None]:
# Select the first four output-logit columns and relabel them with class names.
dataset = datasets[dataset_name]
logit_columns = [f"Output Logit {i}" for i in range(4)]
class_labelled = {
    f"Output Logit {i}": f"{dataset.GRAPH_CLS[i]} Output Logit"
    for i in range(dataset.num_classes)
}
a = df[logit_columns].rename(columns=class_labelled)

# Mean logit per (dataset, class, size, method) configuration.
logit_table = a.groupby(index_names).mean()

# Export as LaTeX; underscores are escaped so the table compiles.
with open(f"results/tables/output_logit_{dataset_name}.tex", "w") as f:
    latex_source = logit_table.to_latex(index=True, float_format="{:.3f}".format)
    f.write(latex_source.replace("_", "\\_"))
logit_table

In [None]:
# Mean runtime per (dataset, class, size, method) configuration.
runs_by_config = df.groupby(index_names)
runtime_table = runs_by_config["runtime"].mean()

# Export as LaTeX; underscores are escaped so the table compiles.
with open(f"results/tables/runtime_{dataset_name}.tex", "w") as f:
    latex_source = runtime_table.to_latex(index=True, float_format="{:.3f}".format)
    f.write(latex_source.replace("_", "\\_"))
runtime_table

In [None]:
# For every configuration, measure how similar its generated graphs are to each
# other via the average pairwise edit distance ("Consistency").
distances = []
for group_key, graph_series in tqdm(df.groupby(index_names)["G"]):
    record = {"Consistency": average_edit_distance(list(graph_series))}
    # Re-attach the configuration values so the record can be indexed below.
    record.update(zip(index_names, group_key))
    distances.append(record)

distances_df = pd.DataFrame(distances).set_index(index_names).sort_index()

In [None]:
# Export the consistency table as LaTeX; underscores are escaped so it compiles.
with open(f"results/tables/consistency_{dataset_name}.tex", "w") as f:
    f.write(distances_df.to_latex(index=True, float_format="{:.3f}".format).replace("_", "\\_"))
# Show the table that was just written (this cell used to echo `runtime_table`,
# a copy-paste leftover from the runtime cell).
distances_df

In [None]:
# Average over num_nodes
averaged_distances_df = distances_df.groupby(["dataset_name", "max_class", "method"]).mean()
with open(f"results/tables/averaged_consistency_{dataset_name}.tex", "w") as f:
    f.write(averaged_distances_df.to_latex(index=True, float_format="{:.3f}".format).replace("_", "\\_"))