In [None]:
from Config.config import CONFIG
# Rebind the imported `CONFIG` callable to the object it returns for the "MOOC" key.
# From here on `CONFIG` is that object (read below via `CONFIG.model.*` / `CONFIG.data.*`),
# and the imported callable is no longer reachable under this name.
CONFIG = CONFIG("MOOC")

In [None]:
from DyGLib.models.GraphMixer import GraphMixer
from DyGLib.models.TGAT import TGAT
from DyGLib.models.TCL import TCL
from DyGLib.models.CAWN import CAWN
from DyGLib.models.DyGFormer import DyGFormer
from DyGLib.models.MemoryModel import MemoryModel, compute_src_dst_node_time_shifts

from DyGLib.models.modules import TGNN, NeuralNetworkSrcDst, BatchSubgraphs
from DyGLib.utils.DataLoader import get_link_prediction_data
from DyGLib.utils.utils import get_neighbor_sampler, NegativeEdgeSampler

import random
import time
from pathlib import Path

import graphviz
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from IPython.display import SVG

# Initialization

In [None]:
# Checkpoint location plus the data files. Each data path is the configured folder
# joined to a file name with `+`, so no separator is inserted between the two parts.
data_folder = CONFIG.data.folder

trained_model_path = CONFIG.model.trained_model_path
edge_feat_path = data_folder + CONFIG.data.edge_feat_file
node_feat_path = data_folder + CONFIG.data.node_feat_file
index_path = data_folder + CONFIG.data.index_file
feature_names_path = data_folder + CONFIG.data.feature_names_file

In [None]:
# Load raw node/edge features together with the full/train/val/test splits.
(node_raw_features, edge_raw_features,
 full_data, train_data, val_data, test_data) = get_link_prediction_data(
    val_ratio=0.1, test_ratio=0.1, node_dim=CONFIG.model.node_dim)

# Two neighbor samplers that differ only in the data they index:
# one over all interactions, one restricted to the training split.
full_neighbor_sampler = get_neighbor_sampler(
    data=full_data,
    edge_features=edge_raw_features,
    sample_neighbor_strategy=CONFIG.model.sample_neighbor_strategy,
    time_scaling_factor=CONFIG.model.time_scaling_factor,
    seed=1,
)
train_neighbor_sampler = get_neighbor_sampler(
    data=train_data,
    edge_features=edge_raw_features,
    sample_neighbor_strategy=CONFIG.model.sample_neighbor_strategy,
    time_scaling_factor=CONFIG.model.time_scaling_factor,
    seed=1,
)

# Build the dynamic backbone selected by the configuration.
model_name = CONFIG.model.model_name
if model_name not in ('TGAT', 'JODIE', 'DyRep', 'TGN', 'CAWN', 'TCL', 'GraphMixer', 'DyGFormer'):
    # Reject unknown names before reading any backbone-specific settings.
    raise ValueError(f"Wrong value for model_name {CONFIG.model.model_name}!")

# Keyword arguments that every backbone constructor below receives.
shared_backbone_kwargs = dict(
    num_nodes=node_raw_features.shape[0],
    node_dim=node_raw_features.shape[1],
    edge_dim=edge_raw_features.shape[1],
    time_feat_dim=CONFIG.model.time_feat_dim,
    dropout=CONFIG.model.dropout,
    device=CONFIG.model.device,
)

if model_name == 'TGAT':
    dynamic_backbone = TGAT(
        num_layers=CONFIG.model.num_layers,
        num_heads=CONFIG.model.num_heads,
        **shared_backbone_kwargs,
    )
elif model_name in ('JODIE', 'DyRep', 'TGN'):
    # Mean / std of the source and destination node time shifts, computed on the
    # training split only and handed to the memory model (per the original note,
    # these statistics are what JODIE uses).
    (src_node_mean_time_shift, src_node_std_time_shift,
     dst_node_mean_time_shift_dst, dst_node_std_time_shift) = compute_src_dst_node_time_shifts(
        train_data.src_node_ids, train_data.dst_node_ids, train_data.node_interact_times)
    dynamic_backbone = MemoryModel(
        model_name=CONFIG.model.model_name,
        num_layers=CONFIG.model.num_layers,
        num_heads=CONFIG.model.num_heads,
        src_node_mean_time_shift=src_node_mean_time_shift,
        src_node_std_time_shift=src_node_std_time_shift,
        dst_node_mean_time_shift_dst=dst_node_mean_time_shift_dst,
        dst_node_std_time_shift=dst_node_std_time_shift,
        **shared_backbone_kwargs,
    )
elif model_name == 'CAWN':
    dynamic_backbone = CAWN(
        position_feat_dim=CONFIG.model.position_feat_dim,
        walk_length=CONFIG.model.walk_length,
        num_walk_heads=CONFIG.model.num_walk_heads,
        **shared_backbone_kwargs,
    )
elif model_name == 'TCL':
    dynamic_backbone = TCL(
        num_layers=CONFIG.model.num_layers,
        num_heads=CONFIG.model.num_heads,
        num_depths=CONFIG.model.num_neighbors + 1,
        **shared_backbone_kwargs,
    )
elif model_name == 'GraphMixer':
    dynamic_backbone = GraphMixer(
        num_tokens=CONFIG.model.num_neighbors,
        num_layers=CONFIG.model.num_layers,
        **shared_backbone_kwargs,
    )
else:
    # Only 'DyGFormer' can reach this branch: unknown names were rejected above.
    dynamic_backbone = DyGFormer(
        channel_embedding_dim=CONFIG.model.channel_embedding_dim,
        patch_size=CONFIG.model.patch_size,
        num_layers=CONFIG.model.num_layers,
        num_heads=CONFIG.model.num_heads,
        max_input_sequence_length=CONFIG.model.max_input_sequence_length,
        **shared_backbone_kwargs,
    )

# Prediction head over (source, destination) pairs, wrapped with the backbone into one model.
regressor = NeuralNetworkSrcDst(
    input_dim=node_raw_features.shape[1],
    num_layers=CONFIG.model.num_reg_layers,
    hidden_dim=CONFIG.model.hidden_reg_layers_dim,
)
model = TGNN(dynamic_backbone, regressor)

In [None]:
# Restore the trained weights and switch the model to inference mode.
# `map_location` remaps the stored tensors onto the configured device while loading,
# so a checkpoint written on a GPU can still be opened on a CPU-only (or differently
# numbered GPU) machine. `weights_only=True` restricts unpickling to tensor data.
model.load_state_dict(
    torch.load(trained_model_path, map_location=CONFIG.model.device, weights_only=True)
)
model.to(CONFIG.model.device)
model.eval()

## Select edges

In [None]:
# Number of interactions to draw at random in the sampling cell below.
num_samples = 200

In [None]:
def get_edge_by_id(link_index):
    """Fetch one interaction of the global `full_data` by position.

    Args:
        link_index: Position into the parallel arrays of `full_data`. Despite the
            function name this is an index, not a value stored in `edge_ids`.

    Returns:
        Tuple `(src, dst, time_stamp, edge_id, true_value)` read from
        `src_node_ids`, `dst_node_ids`, `node_interact_times` and `edge_ids`;
        `true_value` is always the constant 1.
    """
    field_names = ("src_node_ids", "dst_node_ids", "node_interact_times", "edge_ids")
    src, dst, time_stamp, edge_id = (
        getattr(full_data, field)[link_index] for field in field_names  # type: ignore
    )
    return src, dst, time_stamp, edge_id, 1

In [None]:
# Fixed seed so the same interactions are drawn on every run.
random.seed(2025)

# Candidate positions in `full_data`: interactions whose label is not NaN and whose
# edge id does not occur in the training split.
has_label = ~np.isnan(full_data.labels)
outside_train = ~np.isin(full_data.edge_ids, train_data.edge_ids)
candidate_positions = np.where(has_label & outside_train)[0].tolist()
if len(candidate_positions) < num_samples:
    raise ValueError(
        f"Only {len(candidate_positions)} candidate edges available, "
        f"cannot sample num_samples={num_samples}."
    )

# NOTE: despite the name, these are positions into `full_data`, not edge ids.
sampled_edge_ids = random.sample(candidate_positions, num_samples)

edge_info_array = np.array([list(get_edge_by_id(i)) for i in sampled_edge_ids])
edge_info = pd.DataFrame(edge_info_array, columns=["Src", "Dst", "Time", "Edge", "Target"])

edges = edge_info["Edge"].to_numpy(dtype=int)

# Re-check training membership on the sampled edge ids and keep only held-out rows.
# `.copy()` detaches the filtered frame from its parent so that adding columns to
# `edge_info` later does not raise pandas' SettingWithCopyWarning.
edge_info["InTrain"] = np.isin(edges, train_data.edge_ids)
edge_info = edge_info.sort_values(by="InTrain").reset_index(drop=True)
edge_info = edge_info[~edge_info["InTrain"]].copy()

# Plain arrays in the dtypes used for the model call below.
srcs = edge_info["Src"].to_numpy(dtype=int)
dsts = edge_info["Dst"].to_numpy(dtype=int)
timestamps = edge_info["Time"].to_numpy(dtype="float32")
targets = edge_info["Target"].to_numpy(dtype="float32")

In [None]:
model.eval()

# Multi-hop neighborhoods of every sampled source / destination node at its timestamp.
raw_subgraphs_src = full_neighbor_sampler.get_multi_hop_neighbors(
    CONFIG.model.num_layers, srcs, timestamps, num_neighbors=CONFIG.model.num_neighbors)
raw_subgraphs_dst = full_neighbor_sampler.get_multi_hop_neighbors(
    CONFIG.model.num_layers, dsts, timestamps, num_neighbors=CONFIG.model.num_neighbors)

# Edge features are looked up from element [1] of each sampler result.
edge_feat_src = full_neighbor_sampler.get_edge_features_for_multi_hop(raw_subgraphs_src[1])
edge_feat_dst = full_neighbor_sampler.get_edge_features_for_multi_hop(raw_subgraphs_dst[1])

# NOTE(review): the return value of `.to(...)` is discarded, which assumes
# BatchSubgraphs.to moves its tensors in place -- confirm against its implementation.
subgraphs_src = BatchSubgraphs(*raw_subgraphs_src, edge_feat_src)
subgraphs_src.to(CONFIG.model.device)
subgraphs_dst = BatchSubgraphs(*raw_subgraphs_dst, edge_feat_dst)
subgraphs_dst.to(CONFIG.model.device)

# Pure inference: skip autograd bookkeeping, only the detached values are used afterwards.
with torch.no_grad():
    predicts = model(src_node_ids=srcs,
                     dst_node_ids=dsts,
                     node_interact_times=timestamps,
                     src_subgraphs=subgraphs_src,
                     dst_subgraphs=subgraphs_dst,
                     time_gap=CONFIG.model.time_gap,
                     edges_are_positive=True).squeeze(dim=-1).sigmoid()

# `assign` returns a new frame, so re-running this cell is idempotent and never
# writes into a filtered slice of an earlier frame.
edge_info = edge_info.assign(Prediction=predicts.detach().cpu().numpy())
edge_info

# Shapley Value Approximations Comparison

In [None]:
import matplotlib.pyplot as plt
from Config.colors import PRIMARYCOLOR, PALLETTE, PALLETTE2

In [None]:
from tqdm import tqdm
from Explainers.Shapley4TGNN.Explainer import ShapleyExplainerEvents
from Explainers.Shapley4TGNN.Explainer import ShapleyExplainerFeatures

def _feature_shapley_frame(explainer, src, dst, timestamp, value_column):
    """Run one feature-level explainer and aggregate its output per edge.

    Args:
        explainer: Object exposing `explain_instance(src, dst, timestamp, silent=True)`
            that returns a 4-tuple; only the first element is used.
        src, dst, timestamp: The interaction to explain.
        value_column: Name of the column that receives the summed values.

    Returns:
        `(frame, elapsed_ns)`: `frame` has columns `EdgeID` and `value_column`,
        one row per edge id with the values of that edge summed; `elapsed_ns` is
        the duration of the `explain_instance` call only.
    """
    start = time.perf_counter_ns()
    shap_rows, _, _, _ = explainer.explain_instance(src, dst, timestamp, silent=True)
    elapsed_ns = time.perf_counter_ns() - start

    # First column is treated as the edge id, last column as the attributed value.
    frame = pd.DataFrame(np.array(shap_rows)[:, [0, -1]], columns=["EdgeID", value_column])
    frame[value_column] = frame[value_column].astype("float")
    frame["EdgeID"] = frame["EdgeID"].astype("int")
    return frame.groupby("EdgeID").sum().reset_index(), elapsed_ns


explainerEdge = ShapleyExplainerEvents(model, full_neighbor_sampler, full_data, edge_raw_features)
explainerFeatMonteCarlo = ShapleyExplainerFeatures(model, full_neighbor_sampler, full_data, edge_raw_features, None, shapley_alg="MonteCarlo", top_k=4)
explainerFeatPermutation = ShapleyExplainerFeatures(model, full_neighbor_sampler, full_data, edge_raw_features, None, shapley_alg="Permutation", top_k=4)

explainerEdge.initialize()
explainerFeatMonteCarlo.initialize()
explainerFeatPermutation.initialize()

# Per-edge signed differences between the three attribution variants.
dists12 = []
dists13 = []
dists23 = []

# One [duration_ns, method] row per feature-level explanation (edge-level run is not timed).
timings_shap = []

for src, dst, timestamp in tqdm(zip(srcs, dsts, timestamps), total=len(srcs)):
    shap1 = explainerEdge.explain_instance(src, dst, timestamp, silent=True)
    # `np.concatenate` instead of `np.concat`: the latter only exists in NumPy >= 2.0.
    shap1_df = pd.DataFrame(np.concatenate((shap1[1].data, shap1[1].values)).T, columns=["EdgeID", "Value1"])
    shap1_df["EdgeID"] = shap1_df["EdgeID"].astype("int")

    shap2_df, elapsed_ns = _feature_shapley_frame(explainerFeatMonteCarlo, src, dst, timestamp, "Value2")
    timings_shap.append([elapsed_ns, "Monte Carlo"])

    shap3_df, elapsed_ns = _feature_shapley_frame(explainerFeatPermutation, src, dst, timestamp, "Value3")
    timings_shap.append([elapsed_ns, "Permutation"])

    # Inner joins: only edges attributed by all three explainers are compared.
    merged = pd.merge(shap1_df, shap2_df, on="EdgeID")
    merged = pd.merge(merged, shap3_df, on="EdgeID")
    dists12.extend(merged["Value1"] - merged["Value2"])
    dists13.extend(merged["Value1"] - merged["Value3"])
    dists23.extend(merged["Value2"] - merged["Value3"])


df_dist12 = pd.DataFrame(dists12, columns=["Distance"])
df_dist12["Comparison"] = "Edge vs. Feat. (Monte Carlo)"

df_dist13 = pd.DataFrame(dists13, columns=["Distance"])
df_dist13["Comparison"] = "Edge vs. Feat. (Permutation)"

# Built from dists23 (Monte Carlo vs. Permutation); using dists13 here would
# duplicate the previous comparison under the wrong label.
df_dist23 = pd.DataFrame(dists23, columns=["Distance"])
df_dist23["Comparison"] = "Feat. (Monte Carlo) vs. Feat. (Permutation)"

df = pd.concat([df_dist12, df_dist13, df_dist23])
df_timings_shap = pd.DataFrame(timings_shap, columns=["Timing", "Method"])
df_timings_shap["Timing"] = df_timings_shap["Timing"] / 1_000_000  # ns -> ms

In [None]:
# Boxplot of the per-edge differences, one box per pair of attribution methods.
# An explicit figure/axes pair keeps the plot independent of whatever pyplot
# currently considers the "active" figure.
fig, ax = plt.subplots()
boxplt = sns.boxplot(data=df, x="Comparison", y="Distance", color=PRIMARYCOLOR, ax=ax)
plt.setp(ax.get_xticklabels(), rotation=20, horizontalalignment='right')
ax.set_xlabel("")

# Create the output folder if needed so saving also works on a fresh checkout.
distances_plot_path = Path("Documents/Images/MOOC/Distances_approx_method.png")
distances_plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(distances_plot_path, bbox_inches='tight')

In [None]:
# Boxplot of the explanation run time (milliseconds) per approximation method.
# An explicit figure/axes pair keeps the plot independent of whatever pyplot
# currently considers the "active" figure.
fig, ax = plt.subplots()
boxplt = sns.boxplot(data=df_timings_shap, x="Method", y="Timing", color=PRIMARYCOLOR, ax=ax)
ax.set_xlabel("")
ax.set_ylabel("Time (in ms)")

# Create the output folder if needed so saving also works on a fresh checkout.
timings_plot_path = Path("Documents/Images/MOOC/Timings_approx_method.png")
timings_plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(timings_plot_path, bbox_inches='tight')