In [1]:
import sys
sys.path.append('../src')

import pickle

import torch
import torch.nn.functional as F
import random

from torch.nn import Linear, Sequential, BatchNorm1d, ReLU
import torch_geometric
from torch_geometric.explain import Explainer
from torch_geometric.nn import SAGEConv, GATConv, GINConv, GIN, GCNConv
from cf_explainer import C2Explainer
from torch_geometric.nn import global_mean_pool,  global_max_pool
from cf_explainer.utils import seed_everything

import numpy as np

# import networkx as nx
# from pyvis.network import Network

from tqdm.auto import tqdm

import pickle

'''Config parameters'''
use_cuda_if_available = True
device = torch.device('cuda' if torch.cuda.is_available() and use_cuda_if_available else 'cpu')

# seed_everything(42, deterministic=True)
# if error when setting use_deterministic_algorithms(True)
# try this:
%env CUBLAS_WORKSPACE_CONFIG=:4096:8

print("PyTorch version:", torch.__version__)
print("PyTorch device:", device)

env: CUBLAS_WORKSPACE_CONFIG=:4096:8
PyTorch version: 2.0.1
PyTorch device: cuda


In [2]:
def results(num_perturbs, prop_perturbs):
    """Print and return aggregate counterfactual statistics.

    Args:
        num_perturbs: per-success perturbation counts.
        prop_perturbs: per-success perturbed proportions (expected to have the
            same length as ``num_perturbs``).

    Returns:
        ``(size, prop)``. ``size`` is ``sum(num_perturbs) / (2 * n_success)``.
        The factor 2 presumably undoes double counting of undirected edges
        (TODO confirm against C2Explainer). ``prop`` is the mean of
        ``prop_perturbs``. Both are the string ``"N/A"`` when
        ``num_perturbs`` is empty.
    """
    print("######")
    n_success = len(num_perturbs)
    if n_success == 0:
        size = prop = "N/A"
    else:
        size = sum(num_perturbs) / (2 * n_success)
        prop = sum(prop_perturbs) / len(prop_perturbs)
    print(f"size: {size}, num_success: {n_success}, prop_perturbs: {prop}")
    print("finished")
    return size, prop


def explain(model, dataset, explainer, seed=42, max_perturbs=20):
    """Run a counterfactual explainer over every graph in ``dataset`` and report metrics.

    Args:
        model: graph classifier called as ``model(x, edge_index, batch)`` that
            returns raw logits.
        dataset: sized iterable of graphs. Each graph must expose ``x`` and
            ``edge_index``.
        explainer: explanation algorithm (e.g. ``C2Explainer``). It is passed as
            ``algorithm=`` to ``torch_geometric.explain.Explainer``.
        seed: forwarded to ``seed_everything(seed, deterministic=True)``.
        max_perturbs: a graph counts as a success only if its explanation has a
            ``perturbs`` attribute strictly below this value. The default of 20
            is the previously hard-coded threshold.

    Returns:
        ``[[n_graphs, size, n_success, prop]]``. ``size`` and ``prop`` come from
        ``results`` and are ``"N/A"`` when nothing succeeded.
    """
    seed_everything(seed, deterministic=True)

    # Wrap the raw algorithm with the graph-level multiclass configuration.
    wrapped = Explainer(
        model=model,
        algorithm=explainer,
        explanation_type='model',
        node_mask_type=None,
        edge_mask_type='object',
        model_config=dict(
            mode='multiclass_classification',
            task_level='graph',
            return_type='raw',
        ),
    )

    num_perturbs = []
    prop_perturbs = []

    # Iterate the dataset that was passed in. The previous version read the
    # global `test_dataset` here and silently ignored the argument.
    for data in tqdm(dataset):
        explanation = wrapped(data.x, data.edge_index, batch=None)
        # A missing `perturbs` attribute presumably means no counterfactual was
        # found (TODO confirm against C2Explainer). Too many edits also counts
        # as a failure.
        if hasattr(explanation, "perturbs") and explanation.perturbs < max_perturbs:
            num_perturbs.append(explanation.perturbs)
            prop_perturbs.append(explanation.prop_perturbs)

    size, prop = results(num_perturbs, prop_perturbs)

    n_graphs = len(dataset)
    n_success = len(num_perturbs)
    fidelity = n_success / n_graphs if n_graphs else "N/A"
    # `results` returns the string "N/A" when there were no successes, so
    # `1 - prop` must be guarded.
    similarity = 1 - prop if n_success else "N/A"
    print(f"Fedility: {fidelity}, Num_perturbs: {size}, Similarity: {similarity}")
    return [[n_graphs, size, n_success, prop]]

# BA2Motifs

In [3]:
class GCN(torch.nn.Module):
    """Three-layer GCN graph classifier with a global max-pool readout.

    Args:
        nhid: hidden width of conv1 and conv2.
        nout: output width of conv3, which is also the input width of the final
            linear layer.
        dropout: dropout probability applied to the pooled graph embedding in
            training mode.
        nin: number of input node features. The default of 10 is the value
            previously hard-coded for BA2Motifs.
        nclass: number of output classes (default 2, previously hard-coded).
    """

    def __init__(self, nhid, nout, dropout, nin=10, nclass=2):
        super().__init__()
        # normalize=False: GCNConv adds no self-loops and applies no symmetric
        # normalization.
        self.conv1 = GCNConv(nin, nhid, normalize=False)
        self.conv2 = GCNConv(nhid, nhid, normalize=False)
        self.conv3 = GCNConv(nhid, nout, normalize=False)
        self.lin = Linear(nout, nclass)
        self.dropout = dropout

    def forward(self, x, edge_index, batch, edge_weight=None):
        """Return raw class logits, one row per graph in ``batch``.

        ``edge_weight`` is forwarded to every conv layer. Presumably this lets an
        explainer feed edge masks (TODO confirm against C2Explainer).
        """
        # 1. Node embeddings (no activation after the last conv).
        x = self.conv1(x, edge_index, edge_weight).relu()
        x = self.conv2(x, edge_index, edge_weight).relu()
        x = self.conv3(x, edge_index, edge_weight)

        # 2. Readout: per-graph max over node embeddings -> [num_graphs, nout]
        x = global_max_pool(x, batch)

        # 3. Final classifier
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)

model = GCN(nhid=20, nout=20, dropout=0).to(device)
# map_location lets a checkpoint saved on GPU load when `device` fell back to CPU.
state_dict = torch.load("../models/GCN_BA2Motifs_sd.pt", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# NOTE: pickle.load can execute arbitrary code — only load trusted local files.
with open("../data/BA2motifs.pickle", "rb") as f:
    dataset = pickle.load(f)

# Seed torch's global RNG so the shuffle, and hence the split, is reproducible.
# This assumes dataset.shuffle() draws from that RNG.
torch.manual_seed(42)
dataset = dataset.shuffle()

# Float slices select fractions of the dataset: first 80% train, last 20% test.
train_dataset = dataset[:0.8]
test_dataset = dataset[0.8:]

  return self.fget.__get__(instance, owner)()


In [4]:
%%time
explainer = C2Explainer(epochs=1000, lr=0.1, subgraph_mode=True, silent_mode=True, undirected=True)

result1 = explain(model, test_dataset, explainer, seed=42)

  0%|          | 0/200 [00:00<?, ?it/s]

######
size: 1.803191489361702, num_success: 188, prop_perturbs: 0.07015821058374232
finished
Fedility: 0.94, Num_perturbs: 1.803191489361702, Similarity: 0.9298417894162577
CPU times: user 8min 31s, sys: 1.11 s, total: 8min 32s
Wall time: 8min 42s


In [5]:
%%time
explainer = C2Explainer(epochs=1000, lr=0.1, silent_mode=True, undirected=True)

result1 = explain(model, test_dataset, explainer, seed=42)

  0%|          | 0/200 [00:00<?, ?it/s]

######
size: 1.7842105263157895, num_success: 190, prop_perturbs: 0.06954183535762462
finished
Fedility: 0.95, Num_perturbs: 1.7842105263157895, Similarity: 0.9304581646423754
CPU times: user 8min 55s, sys: 1.21 s, total: 8min 57s
Wall time: 9min 8s


# MUTAG

In [6]:
class GCN(torch.nn.Module):
    """Three-layer GCN graph classifier with a global max-pool readout.

    Args:
        nhid: hidden width of conv1 and conv2.
        nout: output width of conv3, which is also the input width of the final
            linear layer.
        dropout: dropout probability applied to the pooled graph embedding in
            training mode.
        nin: number of input node features. The default of 7 is the value
            previously hard-coded for MUTAG.
        nclass: number of output classes (default 2, previously hard-coded).
    """

    def __init__(self, nhid, nout, dropout, nin=7, nclass=2):
        super().__init__()
        # normalize=False: GCNConv adds no self-loops and applies no symmetric
        # normalization.
        self.conv1 = GCNConv(nin, nhid, normalize=False)
        self.conv2 = GCNConv(nhid, nhid, normalize=False)
        self.conv3 = GCNConv(nhid, nout, normalize=False)
        self.lin = Linear(nout, nclass)
        self.dropout = dropout

    def forward(self, x, edge_index, batch, edge_weight=None):
        """Return raw class logits, one row per graph in ``batch``.

        ``edge_weight`` is forwarded to every conv layer. Presumably this lets an
        explainer feed edge masks (TODO confirm against C2Explainer).
        """
        # 1. Node embeddings (no activation after the last conv).
        x = self.conv1(x, edge_index, edge_weight).relu()
        x = self.conv2(x, edge_index, edge_weight).relu()
        x = self.conv3(x, edge_index, edge_weight)

        # 2. Readout: per-graph max over node embeddings -> [num_graphs, nout]
        x = global_max_pool(x, batch)

        # 3. Final classifier
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)

model = GCN(nhid=20, nout=20, dropout=0).to(device)
# map_location lets a checkpoint saved on GPU load when `device` fell back to CPU.
state_dict = torch.load("../models/GCN_MUTAG_sd.pt", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# NOTE: pickle.load can execute arbitrary code — only load trusted local files.
with open("../data/MUTAG.pickle", "rb") as f:
    dataset = pickle.load(f)

# Seed torch's global RNG so the shuffle, and hence the split, is reproducible.
# This assumes dataset.shuffle() draws from that RNG.
torch.manual_seed(42)
dataset = dataset.shuffle()

# Float slices select fractions of the dataset: first 80% train, last 20% test.
train_dataset = dataset[:0.8]
test_dataset = dataset[0.8:]

In [7]:
%%time
explainer = C2Explainer(epochs=1000, lr=0.1, subgraph_mode=True, silent_mode=True, undirected=True)

result2 = explain(model, test_dataset, explainer, seed=42)

  0%|          | 0/38 [00:00<?, ?it/s]

######
size: 2.121212121212121, num_success: 33, prop_perturbs: 0.12041421982233662
finished
Fedility: 0.868421052631579, Num_perturbs: 2.121212121212121, Similarity: 0.8795857801776634
CPU times: user 1min 36s, sys: 188 ms, total: 1min 37s
Wall time: 1min 38s


In [8]:
%%time
explainer = C2Explainer(epochs=1000, lr=0.1, silent_mode=True, undirected=True)

result1 = explain(model, test_dataset, explainer, seed=42)

  0%|          | 0/38 [00:00<?, ?it/s]

######
size: 1.6842105263157894, num_success: 38, prop_perturbs: 0.10486203322270968
finished
Fedility: 1.0, Num_perturbs: 1.6842105263157894, Similarity: 0.8951379667772903
CPU times: user 1min 40s, sys: 245 ms, total: 1min 40s
Wall time: 1min 42s
