In [1]:
from torchtext import datasets, data
import matplotlib.pyplot as plt
import numpy as np
import os, sys
from time import time

from neural_interaction_detection import *
from sampling_and_inference import *
from utils.general_utils import *
from utils.graph_utils import *

%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

device = torch.device("cuda:0")

## Load Model

In [2]:
# Folder passed to get_graph_model to locate the pretrained model.
model_folder = "utils/pretrained"

# get_graph_model returns four values: the model itself, plus n_nodes,
# n_hops and test_idxs, which later cells reuse.
# NOTE(review): whether the returned model is already in eval mode is not
# visible from this notebook -- confirm in get_graph_model.
loaded_model, n_nodes, n_hops, test_idxs = get_graph_model(model_folder)

# Move the model onto the device selected in the setup cell.
model = loaded_model.to(device)

## Classify Graph

In [3]:
# Folder passed to load_cora, relative to the notebook's working directory.
data_folder = "utils/data/cora"

# load_cora receives the target device; presumably it places the returned
# tensors there -- TODO confirm against its implementation.
cora_graph = load_cora(data_folder, device)
node_feats, adj_mat, labels = cora_graph

In [4]:
# The node to classify and explain: the first entry of test_idxs.
target_idx = test_idxs[0]

# Inference only: run the forward pass without autograd recording so no
# computation graph is kept alive by `preds`.
with torch.no_grad():
    preds = model(node_feats, convert_adj_to_da(adj_mat))

# Arg-max along dimension 1 gives one predicted class per row; the entry at
# target_idx is the prediction for the node of interest.
classification = torch.argmax(preds, 1).cpu().numpy()[target_idx]
print("target node classification:", classification)

target node classification: 6


## Run MADEX

In [5]:
# Bundle the graph inputs handed to the perturbation sampler.
# NOTE(review): a later cell reads data_inst["local_idx_map"], which is not
# set here, so generate_perturbation_dataset_graph presumably adds that key to
# this dict in place -- confirm.
data_inst = dict(nodes=node_feats, edges=adj_mat, test_idxs=test_idxs)

# Sample the perturbation dataset around target_idx with a fixed seed; the
# fourth positional argument is n_hops + 1 and std_scale is disabled.
perturbation_data = generate_perturbation_dataset_graph(
    data_inst,
    model,
    target_idx,
    n_hops + 1,
    device,
    seed=42,
    std_scale=False,
)
Xs, Ys = perturbation_data

100%|██████████| 6000/6000 [01:40<00:00, 59.72it/s]


In [6]:
# Wall-clock start, used to report how long interaction detection takes.
t0 = time()

# Fixed seed, sample weighting enabled, progress output suppressed.
detection_result = detect_interactions(
    Xs,
    Ys,
    weight_samples=True,
    seed=42,
    verbose=False,
)
interactions, mlp_loss = detection_result

# Report the loss rounded to 4 decimals and the elapsed time to 0.1 s.
elapsed_seconds = time() - t0
print("{} test loss, {} seconds elapsed".format(round(mlp_loss, 4), round(elapsed_seconds, 1)))

19.4754 test loss, 94.2 seconds elapsed


## Show Main Effects and Interaction Interpretations

In [7]:
# Lookup from node index to hop distance from the target node.
node_to_hop = get_hops_to_target(target_idx, adj_mat, n_hops)
# Lookup from a local feature index (as used in Xs / interactions) to a node
# index. NOTE(review): this key is not set where data_inst is created, so it
# is presumably added by generate_perturbation_dataset_graph -- confirm.
local_map = data_inst["local_idx_map"]


def _hop_and_node(local_idx):
    """Translate a local feature index into a (hops from target, node idx) pair."""
    node_idx = local_map[local_idx]
    return (node_to_hop[node_idx], node_idx)


print("legend: (hops from target node, node idx). All hops should be within n_hops:", n_hops)

print("\ntarget", (0, target_idx))

# Main effects: take the first five LIME attributions and show only those
# with a strictly positive weight.
print("\nmain effects")
top_attributions = get_lime_attributions(Xs, Ys)[:5]
for feature_idx, weight in top_attributions:
    if not weight > 0:
        continue
    print(_hop_and_node(feature_idx))

# Interactions: for each of the first five detected interactions print its
# size, then its member nodes as (hop, node idx) pairs.
print("\ninteractions")
for rank, interaction in enumerate(interactions[:5]):
    members = interaction[0]
    print(len(members))
    print("inter {}:".format(rank), tuple(_hop_and_node(member) for member in members))


legend: (hops from target node, node idx). All hops should be within n_hops: 3

target (0, 1808)

main effects
(2, 722)
(2, 2465)
(2, 264)
(2, 1189)
(2, 2146)

interactions
2
inter 0: ((1, 638), (2, 722))
4
inter 1: ((2, 264), (1, 638), (2, 722), (2, 2465))
5
inter 2: ((2, 264), (1, 638), (2, 722), (2, 1189), (2, 2465))
6
inter 3: ((2, 264), (1, 638), (2, 722), (2, 1189), (2, 2146), (2, 2465))
9
inter 4: ((2, 264), (2, 294), (2, 296), (1, 638), (2, 722), (2, 1189), (2, 1327), (2, 2146), (2, 2465))
