In [1]:
import numpy as np
import torch
from torch.nn.functional import cosine_similarity
import matplotlib.pyplot as plt
from tqdm import tqdm

node_directory = f'../../../data/jesse/additional_exp/'
dataset = 'Cora'
gnn_type = 'SGC'
min_layers = 1
max_layers = 5
torch.manual_seed(0)

<torch._C.Generator at 0x7f2a84ba8690>

In [2]:
# Input-level similarities to 1{Grad > 0}
means = []
for layer in range(min_layers, max_layers + 1):
    gi_expls = torch.vstack(torch.load(node_directory + dataset + f'/{gnn_type}_{layer}_layers_gi_preds.pt',
                                    map_location = 'cpu'))
    occ_expls = torch.load(node_directory + dataset + f'/{gnn_type}_{layer}_layers_occ_preds.pt',
                                    map_location = 'cpu')
    gnn_expls = torch.vstack(torch.load(node_directory + dataset + f'/{gnn_type}_{layer}_layers_gnnexplainer_preds.pt', 
                                    map_location = 'cpu'))
    rand_expls = torch.vstack(torch.load(node_directory + dataset + f'/{gnn_type}_{layer}_layers_rand_preds.pt',
                                        map_location = 'cpu'))
    # full_expls = torch.load(node_directory + dataset + f'/{gnn_type}_{layer}_layers_full_preds.pt',
    #                                     map_location = 'cpu')
    inds_to_compare = torch.any(gi_expls > 0, axis = 1)
    gi_pos = (gi_expls > 0)[inds_to_compare]
    means.append([
        cosine_similarity(gnn_expls[inds_to_compare], gi_pos).mean(),
        cosine_similarity(gi_expls[inds_to_compare], gi_pos).mean(),
        cosine_similarity(occ_expls[inds_to_compare], gi_pos).mean(),
        cosine_similarity(rand_expls[inds_to_compare], gi_pos).mean(),
        # cosine_similarity(full_expls[inds_to_compare], gi_pos).mean(),
    ])
means = np.array(means)
print(means)

[[0.965144   0.8906654  0.8906654  0.01520702]
 [0.892093   0.5467398  0.5467399  0.04867632]
 [0.851444   0.34041625 0.34545377 0.10266495]
 [0.81744975 0.22560844 0.23079652 0.18101051]
 [0.77889395 0.15996967 0.16458367 0.28231594]]


In [3]:
means = []
for layer in range(min_layers, max_layers + 1):
    l_gi_expls = torch.load(node_directory + dataset + f'/{gnn_type}_{layer}_layers_layerwise_grad_preds.pt', map_location = 'cpu')
    l_occ_expls = torch.load(node_directory + dataset + f'/{gnn_type}_{layer}_layers_layerwise_occ_preds.pt', map_location = 'cpu')
    l_gnn_expls = torch.load(node_directory + dataset + f'/{gnn_type}_{layer}_layers_gnnexplainer_layerwise_preds.pt', map_location = 'cpu')
    l_full_expls = torch.load(node_directory + dataset + f'/{gnn_type}_{layer}_layers_full_preds.pt', map_location = 'cpu')
    sims = []
    for gi_expl, occ_expl, gnn_expl, full_expl in zip(l_gi_expls, l_occ_expls, l_gnn_expls, l_full_expls):
        rand_expl = torch.rand_like(full_expl)
        if torch.any(occ_expl != 0):
            sims.append([
                cosine_similarity(gnn_expl.flatten(), occ_expl.flatten(), 0),
                cosine_similarity(gi_expl.flatten(), occ_expl.flatten(), 0),
                cosine_similarity(rand_expl.flatten(), occ_expl.flatten(), 0),
                cosine_similarity(full_expl.flatten(), occ_expl.flatten(), 0),
            ])
    sims = np.array(sims)
    means.append(sims.mean(0))
means = np.array(means)
print(means)

[[0.7591908  1.         0.01197875 0.47585166]
 [0.5412214  1.         0.01759271 0.28255373]
 [0.39997467 1.         0.02078626 0.19767343]
 [0.44111547 1.         0.02337404 0.13589779]
 [0.45087516 1.         0.0261202  0.09893397]]


In [None]:
means = []
for num_layers in range(min_layers, max_layers + 1):
    # gi_expls = torch.load(node_directory + dataset + f'/{gnn_type}_{num_layers}_layers_gi_preds.pt', map_location = 'cpu')
    flips = torch.load(node_directory + dataset + f'/{gnn_type}_{num_layers}_grad_flips.pt', map_location = 'cpu')
    pos_to_neg = torch.load(node_directory + dataset + f'/{gnn_type}_{num_layers}_pos_to_neg_flips.pt', map_location = 'cpu')
    neg_to_pos = torch.load(node_directory + dataset + f'/{gnn_type}_{num_layers}_neg_to_pos_flips.pt', map_location = 'cpu')
    mask = torch.where(gi_expls != 0.)[0]
    means.append([float((flips.sum(1) / (gi_expls != 0.).sum(1))[mask].mean()),
                float((pos_to_neg.sum(1) / (gi_expls != 0.).sum(1))[mask].mean()),
                float((neg_to_pos.sum(1) / (gi_expls != 0.).sum(1))[mask].mean())
                ])
print(means)

[[0.0, 0.0, 0.0], [1.4870525774313137e-05, 1.2246315236552618e-05, 2.6242100830131676e-06], [0.0003175294550601393, 0.00018806841399054974, 0.00012946105562150478], [0.001366338925436139, 0.0008152547525241971, 0.000551084172911942], [0.005478476174175739, 0.0031385556794703007, 0.002339920960366726]]
