# Instructions
Processing the whole dataset can take a long time. To skip it, you can use the preprocessed results and models available in the `Results` and `Models` directories.

To run this reduced analysis, run the following sections of this notebook:
1. Imports
1. Definitions for testing functions
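1. Save and load (only the second cell, which loads the saved results; the first cell saves results that do not exist yet in this reduced run)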
1. Analyze results

### Imports

In [1]:
%load_ext autoreload
%autoreload 2

import pathlib
import pickle
import random
import time

import networkx as nx
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch_geometric.utils as tg_utils
import wandb
from torch_geometric.loader import DataLoader
from tqdm.notebook import tqdm, trange

from MIDS_dataset import MIDSDataset, MIDSLabelsDataset, MIDSProbabilitiesDataset
from MIDS_script import generate_model
from my_graphs_dataset import GraphDataset, GraphType

In [5]:
root = pathlib.Path.cwd()  # Kernel working directory; Dataset/Models/Results paths are built from it.

In [2]:
from collections import Counter

def print_dataset_splits(train_dataset, val_dataset, test_dataset):
    train_counter = Counter([data.x.shape[0] for data in train_dataset])  # type: ignore
    val_counter = Counter([data.x.shape[0] for data in val_dataset])  # type: ignore
    test_counter = Counter([data.x.shape[0] for data in test_dataset])  # type: ignore

    sizes = set(train_counter + val_counter + test_counter)
    total_per_size = {size: train_counter[size] + val_counter[size] + test_counter[size] for size in sizes}
    train_splits_per_size = {size: round(train_counter.get(size, 0) / total_per_size[size], 2) for size in sizes}
    val_splits_per_size = {size: round(val_counter.get(size, 0) / total_per_size[size], 2) for size in sizes}
    test_splits_per_size = {size: round(test_counter.get(size, 0) / total_per_size[size], 2) for size in sizes}

    data = []
    for size in sorted(sizes):
        data.append([size, train_counter.get(size, 0), val_counter.get(size, 0), test_counter.get(size, 0),
                    train_splits_per_size[size], val_splits_per_size[size], test_splits_per_size[size]])

    df = pd.DataFrame(data, columns=["Size", "Train", "Val", "Test", "Train Split", "Val Split", "Test Split"])
    df.loc["Total"] = ["Total", train_counter.total(), val_counter.total(), test_counter.total(), "", "", ""]
    df = df.set_index("Size").T  # Transpose the dataframe

    df_str = df.to_string()
    lines = df_str.split('\n')
    separator = '-' * len(lines[0])
    lines.insert(1, separator)  # Insert separator after the first row
    lines.insert(5, separator)  # Insert separator after the fourth row
    lines.append(separator)  # Append separator at the end
    print('\n'.join(lines))

def load_dataset(root):
    """Build the test split of the probability and label datasets.

    Both datasets are created from ``root / "Dataset"``, shuffled with the
    same permutation under a fixed seed, and split into train/val/test.
    The per-size split statistics of the label dataset are printed, and
    only the test portion is batched and returned.

    Args:
        root: Path-like object supporting ``/``; the datasets are read from
            its ``Dataset`` sub-directory.

    Returns:
        A tuple ``(prob_data, labels_data, num_features)``: the list of test
        batches from the probability dataset, the list of test batches from
        the label dataset (same order), and ``num_features`` of the label
        dataset.
    """
    # Fixed experiment configuration.
    seed = 42
    selected_graph_sizes = {
        # "26-50_mix_100": -1,
        "03-25_mix_750": -1
    }
    split = (0.6, 0.2)
    batch_size = 1
    dataset_dir = root / "Dataset"

    # Both datasets are built from the same graph source.
    graph_source = GraphDataset(selection=selected_graph_sizes, seed=seed)
    prob_dataset = MIDSProbabilitiesDataset(dataset_dir, graph_source)
    labels_dataset = MIDSLabelsDataset(dataset_dir, graph_source, selected_extra_feature="")

    # Seed every RNG before shuffling so the permutation is reproducible.
    for seed_everything in (random.seed, np.random.seed, torch.manual_seed):
        seed_everything(seed)
    prob_dataset, perm = prob_dataset.shuffle(return_perm=True)
    # Apply the identical permutation so item i of both datasets stays paired.
    labels_dataset = labels_dataset.index_select(perm) # type: ignore
    assert isinstance(prob_dataset, MIDSDataset)
    assert isinstance(labels_dataset, MIDSDataset)

    # A tuple means (train, val) fractions with the remainder used for test;
    # a single number means a train fraction with the remainder used for val.
    total = len(prob_dataset)
    if isinstance(split, tuple):
        train_fraction, val_fraction = split
        train_size = round(train_fraction * total)
        val_size = round(val_fraction * total)
    else:
        train_size = round(split * total)
        val_size = total - train_size
    test_start = train_size + val_size

    train_labels = labels_dataset[:train_size]
    val_labels = labels_dataset[train_size:test_start]
    test_labels = labels_dataset[test_start:]
    test_probs = prob_dataset[test_start:]

    print_dataset_splits(train_labels, val_labels, test_labels)

    # Materialise the test batches once, in a fixed (unshuffled) order.
    prob_loader = DataLoader(test_probs, batch_size, shuffle=False, pin_memory=True)  # type: ignore
    labels_loader = DataLoader(test_labels, batch_size, shuffle=False, pin_memory=True)  # type: ignore
    prob_data = list(prob_loader)
    labels_data = list(labels_loader)

    return prob_data, labels_data, labels_dataset.num_features

### Definitions for testing functions

In [3]:
def calc_confusion_matrix(preds, labels):
    """Confusion-matrix counts against the best-matching label column.

    ``labels`` holds one candidate solution per column. Columns are scanned
    left to right until one starts with ``-1`` (``labels[0, i] == -1``),
    which ends the scan. The column that agrees with ``preds[:, 0]`` in the
    most positions is used; on ties the earliest column wins.

    Args:
        preds: 2-D array; only column 0 is read, compared against 0/1 labels.
        labels: 2-D array of candidate label columns.

    Returns:
        Tuple ``(TP, TN, FP, FN)`` for the chosen column. All four are 0
        when ``labels`` has no valid column (previously this case raised
        ``UnboundLocalError``).
    """
    # Defined up front so the function has a result even if no column is valid.
    TP = TN = FP = FN = 0
    best = -1
    for i in range(labels.shape[1]):
        if labels[0, i] == -1:
            break
        column = labels[:, i]
        agrees = preds[:, 0] == column
        correct = agrees.sum()
        if correct > best:
            best = correct
            positive = column == 1
            negative = column == 0
            TP = (agrees & positive).sum()
            TN = (agrees & negative).sum()
            # A disagreement on a 0-label means the prediction was 1, and vice versa.
            FP = (~agrees & negative).sum()
            FN = (~agrees & positive).sum()
    return TP, TN, FP, FN

def calc_hausdorff(A, preds, labels):
    """Directed Hausdorff distance from the predicted set to the closest label set.

    For each valid label column (scanning stops at the first column whose
    first entry is ``-1``) the distance is the largest, over predicted
    nodes, of the hop distance to the nearest labelled node. The smallest
    such value over all columns is returned. Unreachable pairs count as
    ``n`` hops, and ``n`` is also the result when no column is valid.

    Note that only the direction predicted -> labelled is measured.

    Args:
        A: Adjacency matrix as a NumPy array (passed to ``nx.from_numpy_array``).
        preds: Per-node predictions; entries equal to 1 form the predicted set.
        labels: 2-D array with one candidate label set per column.

    Returns:
        Tuple ``(distance, connected)`` where ``connected`` is
        ``nx.is_connected`` of the graph. If nothing is predicted, the
        distance is the sum of the first label column.
    """
    # Load graph from adjacency matrix.
    G = nx.from_numpy_array(A)
    n = A.shape[0]
    connected = nx.is_connected(G)

    # Convert supporting set to actual set.
    pred_nodes = [i for i, v in enumerate(preds) if v == 1]

    if len(pred_nodes) == 0:
        # No predictions: fall back to the sum of the first label column.
        return sum(labels[:, 0]), connected

    # Hop distances depend only on the graph, so compute them once, and only
    # from the predicted nodes, which are the only sources ever looked up.
    distances_from = {src: nx.single_source_shortest_path_length(G, src) for src in pred_nodes}

    best_hausdorff_distance = n
    for i in range(labels.shape[1]):
        if labels[0, i] == -1:
            break

        # Convert supporting set to actual set.
        label_nodes = [j for j, v in enumerate(labels[:, i]) if v == 1]

        # Distance from each predicted node to its nearest labelled node, capped at n.
        nearest = [
            min([n] + [distances_from[src].get(dst, n) for dst in label_nodes])
            for src in pred_nodes
        ]

        hausdorff = max(nearest)
        if hausdorff < best_hausdorff_distance:
            best_hausdorff_distance = hausdorff

    return best_hausdorff_distance, connected

def calc_undominated(A, preds):
    """Count the nodes that are neither selected nor adjacent to a selected node.

    Args:
        A: Square adjacency matrix.
        preds: Selection indicator per node (column vector or 1-D array).

    Returns:
        Number of nodes whose closed neighbourhood contains no selected node.
    """
    n = A.shape[0]
    closed_neighbourhood = A + np.eye(n)  # Adding the identity lets a node cover itself.
    coverage = closed_neighbourhood @ preds
    dominated = (coverage >= 1).sum()
    return n - dominated


def calc_violated(A, preds):
    """Count the edges whose two endpoints are both selected.

    Args:
        A: Square adjacency matrix; an edge is an entry equal to 1.
        preds: Selection indicator per node; a node is selected when its entry equals 1.

    Returns:
        Number of unordered node pairs ``(i, j)`` with ``i < j`` that are
        both selected and adjacent.
    """
    n = A.shape[0]
    # Selected nodes in ascending order, so each pair below satisfies i < j.
    chosen = [node for node in range(n) if preds[node] == 1]
    return sum(
        1
        for position, i in enumerate(chosen)
        for j in chosen[position + 1:]
        if A[i, j] == 1
    )


def calc_iou(preds, labels):
    """Best Intersection over Union between the prediction and any label column.

    Label columns are visited left to right; the scan ends at the first
    column whose first entry is ``-1``.

    Args:
        preds: 2-D array; column 0 is interpreted as a boolean selection.
        labels: 2-D array with one candidate label set per column.

    Returns:
        The highest IoU found, or 0 if no column yields a value above 0.
    """
    top_score = 0
    column = 0
    while column < labels.shape[1] and not labels[0, column] == -1:
        predicted = preds[:, 0]
        truth = labels[:, column]
        overlap = np.logical_and(predicted, truth).sum()
        union = np.logical_or(predicted, truth).sum()
        score = overlap / union
        top_score = score if score > top_score else top_score
        column += 1

    return top_score

In [4]:
from Utilities.mids_utils import check_MIDS

def run_GNN(root, prob_data, label_data, num_features, model_id="5udn5jrt"):
    """Run the trained label model on every test graph and collect per-graph metrics.

    The model architecture and hyper-parameters are read from the Weights &
    Biases run ``model_id`` (requires W&B API access); the weights are read
    from ``root / "Models" / f"{model_id}_best_model.pth"``.

    Graphs for which ``check_MIDS`` raises ``ValueError`` and graphs that are
    not connected are skipped and produce no record.

    Args:
        root: Path-like project root containing the ``Models`` directory.
        prob_data: Test batches of the probability dataset. Only its length
            is used because the probability stage is currently disabled.
        label_data: Test batches of the label dataset, one graph per batch.
        num_features: Number of input node features of the label model.
        model_id: W&B run id of the label model to evaluate.

    Returns:
        ``pandas.DataFrame`` with one row per evaluated graph and the columns
        model, num_nodes, num_edges, connected, correct, accuracy, precision,
        recall, f1, hausdorff, IOU, undominated, violated, ratio and
        execution_time (milliseconds).
    """
    device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"

    # Load the probability model. It is a fully pickled module (it is used via
    # .to()/.eval()), so weights_only=False is required on torch >= 2.6, matching
    # the label-model load below. Unpickling runs arbitrary code: only load
    # checkpoints you trust.
    prob_model = torch.load(root / "Models" / "prob_model_best.pth", weights_only=False)
    prob_model.to(device)
    prob_model.eval()

    api = wandb.Api()
    run = api.run(f"/LARICS-GNN/MIDS-GNN/runs/{model_id}")
    config = run.config
    architecture = config["architecture"]

    # Architecture-specific constructor options.
    model_kwargs = {}
    if "GIN" in architecture:
        model_kwargs = {"train_eps": True}
    elif "GAT" in architecture:
        model_kwargs = {"v2": True}

    label_model = generate_model(
        architecture,
        num_features,
        config["hidden_channels"],
        config["gnn_layers"],
        act=config["activation"],
        jk=config["jk"] if config["jk"] != "none" else None,
        **model_kwargs
    )
    saved_model_dict = torch.load(root / "Models" / f"{model_id}_best_model.pth", weights_only=False)
    label_model.load_state_dict(saved_model_dict["model_state_dict"])
    label_model.to(device)
    label_model.eval()

    records = []
    # Calculate the execution time on each data example
    for i in trange(len(prob_data), leave=False, desc=architecture):
        start = time.perf_counter()
        # First, predict probabilities (stage currently disabled).
        # example = prob_data[i].to(device)
        # out = prob_model(example.x, example.edge_index)

        # Second, predict labels. Gradients are not needed for evaluation.
        example = label_data[i].to(device)
        with torch.no_grad():
            out = label_model(example.x, example.edge_index)
        end = time.perf_counter()

        # Third, check the MIDS. Positive logits are treated as selected nodes.
        pred = torch.where(out > 0, 1.0, 0.0).cpu().numpy()
        labels = example.y.cpu().numpy()
        A = tg_utils.to_dense_adj(example.edge_index).squeeze().cpu().numpy()
        try:
            correct = int(check_MIDS(A, pred, sum(labels[:, 0])))
        except ValueError:
            # Deliberate best-effort: graphs that check_MIDS rejects are skipped.
            continue

        # Fourth, calculate detailed metrics.
        hausdorff, connected = calc_hausdorff(A, pred, labels)
        if not connected:
            # Only connected graphs are analysed; skip before the remaining metrics.
            continue

        TP, TN, FP, FN = calc_confusion_matrix(pred, labels)
        precision = TP / (TP + FP) if (TP + FP) > 0 else 0
        recall = TP / (TP + FN) if (TP + FN) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        accuracy = (TP + TN) / (TP + FP + FN + TN) if (TP + FP + FN + TN) > 0 else 0
        # Predicted set size relative to the label set size; undefined (NaN)
        # when the label set is empty, so it cannot poison the column mean.
        ratio = (TP + FP) / (TP + FN) if (TP + FN) > 0 else float("nan")

        undominated = calc_undominated(A, pred)
        violated = calc_violated(A, pred)
        IOU = calc_iou(pred, labels)

        records.append({
            "model": architecture,
            "num_nodes": example.num_nodes,
            "num_edges": example.num_edges,
            "connected": connected,
            "correct": correct,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "hausdorff": hausdorff,
            "IOU": IOU,
            "undominated": undominated,
            "violated": violated,
            "ratio": ratio,
            "execution_time": (end - start) * 1000
        })

    return pd.DataFrame(records)

### Save and load

In [None]:
# Persist the per-graph results so the analysis can be repeated without re-running the model.
with (root / "Results" / "results_detailed.pkl").open("wb") as results_file:
    pickle.dump(GNN_results, results_file)

In [6]:
# Load the previously saved per-graph results.
# NOTE: unpickling can execute arbitrary code; only load result files you trust.
with (root / "Results" / "results_detailed.pkl").open("rb") as results_file:
    GNN_results = pickle.load(results_file)

### Run experiments

In [None]:

# Load the test split of both datasets and the number of input node features
# (run_GNN needs the latter to build the label model).
prob_data, label_data, num_features = load_dataset(root)

In [None]:
overall_results = {}
by_nodesize_results = {}

# Evaluate the GNN on the test split; one result row per evaluated graph.
print("Running GNN models")
GNN_results = run_GNN(root, prob_data, label_data, num_features)


## Analyze results

In [7]:
# Keep only the rows whose `correct` flag is 0. The explicit copy makes the
# result an independent frame, so adding columns to it later does not raise
# SettingWithCopyWarning.
GNN_results = GNN_results.loc[GNN_results["correct"] == 0].copy()

In [8]:
# Calculate additional 0/1 indicator columns. The comparisons are vectorised,
# and `assign` returns a new frame, so this works on a filtered slice and on
# an empty frame alike.
GNN_results = GNN_results.assign(**{
    "non-dominating": (GNN_results["undominated"] > 0).astype(int),
    "non-independent": (GNN_results["violated"] > 0).astype(int),
    # Not accepted by the MIDS check, yet dominating and independent.
    "IDS": (
        (GNN_results["correct"] == 0)
        & (GNN_results["undominated"] == 0)
        & (GNN_results["violated"] == 0)
    ).astype(int),
})

# Print summary statistics
for column in ["accuracy", "precision", "recall", "f1", "hausdorff", "IOU", "ratio"]:
    print(f"Average {column}: {GNN_results[column].mean():.3f}")

num_examples = len(GNN_results)


def _print_share(label, count):
    """Print `count` as a fraction and a percentage of all analysed graphs."""
    # NaN instead of ZeroDivisionError when there is nothing to analyse.
    share = count / num_examples * 100 if num_examples else float("nan")
    print(f"{label}: {count}/{num_examples} ({share:.2f}%)")


num_correct = int((GNN_results["correct"] == 1).sum())
count_ids = int((GNN_results["IDS"] == 1).sum())

_print_share("Non-dominating", int((GNN_results["undominated"] > 0).sum()))
_print_share("Non-independent", int((GNN_results["violated"] > 0).sum()))
_print_share("MIDS", num_correct)
_print_share("IDS", count_ids)
_print_share("Unconnected graphs", int(GNN_results["connected"].eq(False).sum()))

# Distribution of every continuous metric per graph size.
box_metrics = ["accuracy", "precision", "recall", "f1", "hausdorff", "IOU", "ratio", "undominated", "violated"]

for metric in box_metrics:
    fig = px.box(
        GNN_results,
        x="num_nodes",
        y=metric,
        title=f"{metric.capitalize()} by Number of Nodes",
        labels={"num_nodes": "Number of Nodes", metric: metric.capitalize()},
    )
    fig.update_traces(boxmean=True)
    fig.show()

# For every 0/1 indicator: absolute count (bars, left axis) and share
# (line, right axis) per graph size.
count_metrics = ["non-dominating", "non-independent", "IDS", "correct"]
for m in count_metrics:
    grouped = GNN_results.groupby("num_nodes")[m]
    count_ones = grouped.sum()
    avg_value = grouped.mean()

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=count_ones.index,
        y=count_ones.values,
        name=f'Count of {m} graphs',
        marker_color='indianred',
        yaxis='y1'
    ))
    fig.add_trace(go.Scatter(
        x=avg_value.index,
        y=avg_value.values,
        name=f'Average {m}',
        mode='lines+markers',
        marker_color='blue',
        yaxis='y2'
    ))
    fig.update_layout(
        title=f"{m.replace('-', ' ').capitalize()} by Number of Nodes",
        xaxis_title="Number of Nodes",
        yaxis=dict(
            title=f"Count of {m} graphs",
            side='left'
        ),
        yaxis2=dict(
            title=f"Average {m}",
            overlaying='y',
            side='right'
        ),
        legend=dict(x=0.01, y=0.99)
    )
    fig.show()

Average accuracy: 0.862
Average precision: 0.555
Average recall: 0.408
Average f1: 0.450
Average hausdorff: 0.842
Average IOU: 0.341
Average ratio: 0.685
Non-dominating: 4222/4959 (85.14%)
Non-independent: 433/4959 (8.73%)
MIDS: 0/4959 (0.00%)
IDS: 370/4959 (7.46%)
Unconnected graphs: 0/4959 (0.00%)
