In [1]:
import sys
sys.path.append('../src')

import pickle

import torch
import torch.nn.functional as F
import random

from cf_explainer.gcn_conv import GCNConv
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU
import torch_geometric
from torch_geometric.explain import Explainer
from torch_geometric.nn import SAGEConv, GATConv, GINConv, GIN
from cf_explainer import C2Explainer
from cf_explainer.utils import seed_everything

import numpy as np

import networkx as nx
from pyvis.network import Network

from tqdm.auto import tqdm

import pickle

'''Config parameters'''
use_cuda_if_available = True
device = torch.device('cuda' if torch.cuda.is_available() and use_cuda_if_available else 'cpu')

# seed_everything(42, deterministic=True)
# if error when setting use_deterministic_algorithms(True)
# try this:
%env CUBLAS_WORKSPACE_CONFIG=:4096:8

print("PyTorch version:", torch.__version__)
print("PyTorch device:", device)

env: CUBLAS_WORKSPACE_CONFIG=:4096:8
PyTorch version: 2.0.1
PyTorch device: cuda


In [2]:
def results(num_perturbs, prop_perturbs):
    print("######")
    if len(num_perturbs) != 0:
        size = sum(num_perturbs)/(2*len(num_perturbs))
        prop = sum(prop_perturbs)/len(prop_perturbs)
    else:
        size = "N/A"
        prop = "N/A"
    print(f"size: {size}, num_success: {len(num_perturbs)}, prop_perturbs: {prop}")
    print("finished")
    return size, prop


def explain(model, data, explainer, seed=42):
    seed_everything(seed, deterministic=True)
    result = []
    
    condition = (data.test_mask.cpu() | data.val_mask.cpu())
    df_cf = np.where(condition)[0].tolist()

    print(len(df_cf))
    
    explainer = Explainer(
        model=model,
        algorithm=explainer,
        explanation_type='model',
        node_mask_type=None,
        edge_mask_type='object',
        model_config=dict(
            mode='multiclass_classification',
            task_level='node',
            return_type='raw',
        ), 
    )
    
    cfs = []
    num_perturbs = []
    prop_perturbs = []
    indices = []

    for i in tqdm(df_cf):
        explanation = explainer(data.x, data.edge_index, index=i)
        
        if hasattr(explanation, "perturbs"):
            if explanation.perturbs < 20:
                cfs.append(explanation.cf)
                num_perturbs.append(explanation.perturbs)
                prop_perturbs.append(explanation.prop_perturbs)
                indices.append(i)
    
    
    size, prop = results(num_perturbs, prop_perturbs)
    
    result.append([len(df_cf), size, len(num_perturbs), prop])
    print(f"Fedility: {len(num_perturbs)/len(df_cf)}, Num_perturbs: {size}, Similarity: {1-prop}")
    return result, cfs, num_perturbs, indices

# LoanDecision

In [3]:
class GCN(torch.nn.Module):
    def __init__(self, nhid, nout, dropout):
        super().__init__()
        self.conv1 = GCNConv(2, nhid, normalize=True)
        self.conv2 = GCNConv(nhid, nhid, normalize=True)
        self.conv3 = GCNConv(nhid, nout, normalize=True)
        self.lin1 = Linear(nout, nout)
        self.lin2 = Linear(nout,2)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        x = self.conv1(x, edge_index, edge_weight)
        x = x.relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        x = x.relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv3(x, edge_index, edge_weight)
        x = self.lin1(x)
        x = x.relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
        return x
    
model = GCN(nhid=100, nout=20, dropout=0).to(device)
model.load_state_dict(torch.load("../models/GCN_LoanDecision_sd.pt", weights_only=True))

with open("../data/LoanDecision.pickle", "rb") as f:
	data = pickle.load(f)
    
data

  return self.fget.__get__(instance, owner)()


Data(edge_index=[2, 3950], num_nodes=1000, x=[1000, 2], y=[1000], train_mask=[1000], val_mask=[1000], test_mask=[1000])

In [4]:
%%time
explainer = C2Explainer(epochs=1000, lr=0.1, silent_mode=True, undirected=True, AR_mode=True, FPM=True)

result1, cfs1, num_perturbs1, indices1 = explain(model, data, explainer, seed=42)

200


  0%|          | 0/200 [00:00<?, ?it/s]

######
size: 0.9169603586196899, num_success: 198, prop_perturbs: 0.005527470260858536
finished
Fedility: 0.99, Num_perturbs: 0.9169603586196899, Similarity: 0.9944725036621094
CPU times: user 20min 31s, sys: 1.24 s, total: 20min 32s
Wall time: 20min 46s


In [5]:
%%time
explainer = C2Explainer(epochs=1000, lr=0.1, silent_mode=True, undirected=True, AR_mode=True)

result2, cfs2, num_perturbs2, indices2 = explain(model, data, explainer, seed=42)

200


  0%|          | 0/200 [00:00<?, ?it/s]

######
size: 1.425414364640884, num_success: 181, prop_perturbs: 0.007844492793083191
finished
Fedility: 0.905, Num_perturbs: 1.425414364640884, Similarity: 0.9921554923057556
CPU times: user 19min 55s, sys: 1.25 s, total: 19min 57s
Wall time: 20min 10s


In [6]:
%%time
explainer = C2Explainer(epochs=1000, lr=0.1, silent_mode=True, undirected=True, AR_mode=False)

result3, cfs3, num_perturbs3, indices3 = explain(model, data, explainer, seed=42)

200


  0%|          | 0/200 [00:00<?, ?it/s]

######
size: 1.2043010752688172, num_success: 186, prop_perturbs: 0.006940612103790045
finished
Fedility: 0.93, Num_perturbs: 1.2043010752688172, Similarity: 0.9930593967437744
CPU times: user 46min 25s, sys: 1.43 s, total: 46min 26s
Wall time: 46min 24s


In [7]:
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.utils import to_undirected, coalesce

def remove_edges(edge_index, edge_index_to_remove):
        r"""
        remove edges in edge_index that are also in edge_index_to_remove.
        """
        # Trick from https://github.com/pyg-team/pytorch_geometric/discussions/9440
        all_edge_index = torch.cat([edge_index,
                                    edge_index_to_remove], dim=1)

        # mark removed edges as 1 and 0 otherwise
        all_edge_weights = torch.cat([torch.zeros(edge_index.size(1)),
                                      torch.ones(edge_index_to_remove.size(1))]
                                     ).to(all_edge_index.device)

        all_edge_index, all_edge_weights = coalesce(
            all_edge_index, all_edge_weights)

        # remove edges indicated by 1
        edge_index = all_edge_index[:, all_edge_weights == 0]
        return edge_index

def isAR(edge_index1, edge_index2, num_perturbs, index):
    subset, _, _, hard_edge_mask1 = k_hop_subgraph(
                index,
                num_hops=3,
                edge_index=edge_index1,
                relabel_nodes=False)
    
    edge_index1 = edge_index1[:, hard_edge_mask1] # edges
    
    a = subset
    b = torch.tensor([index]).to(subset.device)
    AR_edge_index1 = torch.cartesian_prod(a, b).T.to(int)
    AR_edge_index1 = to_undirected(AR_edge_index1)
    AR_edge_index1 = remove_edges(AR_edge_index1, edge_index2)
    
    edges1 = edge_index1.t()
    AR_edges1 = AR_edge_index1.t()
    
    # Find unique edges in edge_index1
    unique_edges1 = torch.empty((0, 2), dtype=torch.long).to(edge_index1.device)
    for edge in edges1:
        if any((edge == AR_edges1).all(dim=1)):
            unique_edges1 = torch.cat((unique_edges1, edge.unsqueeze(0)), dim=0)
            
            
    #======#
    
    _, _, _, hard_edge_mask1 = k_hop_subgraph(
                index,
                num_hops=1,
                edge_index=edge_index1,
                relabel_nodes=False)

    _, _, _, hard_edge_mask2 = k_hop_subgraph(
                    index,
                    num_hops=1,
                    edge_index=edge_index2,
                    relabel_nodes=False)
    
    edge_index1 = edge_index1[:, hard_edge_mask1] 
    edge_index2 = edge_index2[:, hard_edge_mask2]    
    
    edges1 = edge_index1.t()
    edges2 = edge_index2.t()

    # Find unique edges in edge_index2
    unique_edges2 = torch.empty((0, 2), dtype=torch.long).to(edge_index2.device)
    for edge in edges2:
        if not any((edge == edges1).all(dim=1)):
            unique_edges2 = torch.cat((unique_edges2, edge.unsqueeze(0)), dim=0)

    # Transpose to get back to edge_index format
    unique_edge_index1 = unique_edges1.t()
    unique_edge_index2 = unique_edges2.t()
    
    if unique_edge_index1.size(1)==0 & unique_edge_index2.size(1)==0:
        return 0
    elif unique_edge_index1.size(1)+unique_edge_index2.size(1)==num_perturbs[i]: # all edges perturbed are in hop-1
        return 1
    else:
        return 0

#     print("Edges in added:")
#     print(unique_edge_index1)

#     print("Edges in deleted:")
#     print(unique_edge_index2)
    

In [8]:
# num_AR=0
# for i, index in enumerate(indices1):
#     a = isAR(cfs1[i], data.edge_index, num_perturbs1, index) # isAR ii not suitable for counterfactuals with feature perturbations
#     # print(a)
#     num_AR+=a
#     # print("======")
# print("num_AR", num_AR)

## isAR ii not suitable for counterfactuals with feature perturbations, we calculate the AR_val manually for it.

In [9]:
num_AR=0
for i, index in enumerate(indices2):
    a = isAR(cfs2[i], data.edge_index, num_perturbs2, index)
    # print(a)
    num_AR+=a
    # print("======")
print("num_AR", num_AR)
print("AR_val", num_AR/200)

num_AR 181
AR_val 0.905


In [10]:
num_AR=0
for i, index in enumerate(indices3):
    a = isAR(cfs3[i], data.edge_index, num_perturbs3, index)
    # print(a)
    num_AR+=a
    # print("======")
print("num_AR", num_AR/200)

num_AR 0.73
