# Task Description

## Goal
Explain the predictions of a graph neural network

## Tasks
1. Pick a dataset
2. Analyze the dataset
3. Train a graph neural network on the data and evaluate it
4. Explain the graph neural network

## Datasets
Suggested homogeneous datasets (one edge type)  
▪ BA-Shapes: https://docs.dgl.ai/generated/dgl.data.BAShapeDataset.html  
▪ BA-Community: https://docs.dgl.ai/generated/dgl.data.BACommunityDataset.html  
▪ Tree-Cycles: https://docs.dgl.ai/generated/dgl.data.TreeCycleDataset.html  
▪ Tree-Grid: https://docs.dgl.ai/generated/dgl.data.TreeGridDataset.html  
▪ Cora: https://docs.dgl.ai/generated/dgl.data.CoraGraphDataset.html  

Suggested heterogeneous datasets (multiple edge types)  
▪ AIFB: https://data.dgl.ai/dataset/rdf/aifb-hetero.zip  
▪ MUTAG: https://data.dgl.ai/dataset/rdf/mutag-hetero.zip  
▪ BGS: https://data.dgl.ai/dataset/rdf/bgs-hetero.zip  
▪ AM: https://data.dgl.ai/dataset/rdf/am-hetero.zip  

## Converting graph data to tabular data  

▪ Potential libraries to convert graph data to tabular data (not tested)  
&emsp;&emsp;• https://derwen.ai/docs/kgl/ref/#build_df-method  
&emsp;&emsp;• https://github.com/cadmiumkitty/rdfpandas  


▪ Implement conversion yourself  
&emsp;&emsp;• RDF library: https://rdflib.readthedocs.io/en/stable/intro_to_parsing.html  
&emsp;&emsp;• “Table” library: https://pandas.pydata.org/docs/  
&emsp;&emsp;• Given: set of triples: (subject, predicate, object)  
&emsp;&emsp;• Each subject becomes a row in the table  
&emsp;&emsp;• Each predicate becomes a column in the table  
&emsp;&emsp;• Each object is inserted into a cell in the table  
&emsp;&emsp;&emsp;&emsp;o If there are multiple objects for one subject and predicate, a cell contains multiple values  
&emsp;&emsp;&emsp;&emsp;o If a subject does not have a certain predicate, the cell contains “NA” indicating a missing value  
&emsp;&emsp;&emsp;&emsp;o The best way to deal with multiple values and NA values depends on dataset and use case  
&emsp;&emsp;• Dealing with multiple values and NA values  
&emsp;&emsp;&emsp;&emsp;o Simple approach: one-hot encoding. For each combination of predicate and object, create a column with a binary feature (0 or 1) indicating whether the subject is connected to the object via the predicate  
&emsp;&emsp;• It is fine to convert only part of a dataset to a table, as long as the fidelity stays appropriate, i.e., the model trained on the tabular data still approximates the graph neural network well
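
The conversion steps above can be sketched with pandas. This assumes the triples are already available as a Python list; for RDF files they could be collected from an rdflib graph. The example triples and the `predicate=object` column naming are made up for illustration. One-hot encoding covers both the multi-value case and the NA case, because an absent combination simply becomes 0.

```python
import pandas as pd

# hypothetical example triples (subject, predicate, object)
triples = [
    ("alice", "worksAt", "uni_a"),
    ("alice", "memberOf", "group_x"),
    ("alice", "memberOf", "group_y"),  # two objects for the same subject and predicate
    ("bob", "worksAt", "uni_b"),       # bob has no 'memberOf' predicate at all
]

df = pd.DataFrame(triples, columns=["subject", "predicate", "object"])

# one column per (predicate, object) combination, e.g. 'memberOf=group_x'
df["column"] = df["predicate"] + "=" + df["object"].astype(str)

# crosstab counts how often each subject is linked to each combination;
# clipping to 1 makes the features binary, and missing combinations become 0 instead of NA
table = pd.crosstab(df["subject"], df["column"]).clip(upper=1)
print(table)
```

Numeric objects, such as Cora's bag-of-words values, are usually better kept as one numeric column per predicate than one-hot encoded per value.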

## Explaining the predictions

Strategy 1: Explain surrogate model on tabular data  
▪ Train a surrogate model that works on tabular data and approximates the graph neural network  
&emsp;&emsp;• Convert graph data to tabular data  
&emsp;&emsp;• Compute fidelity: measure how well the surrogate model approximates the graph neural network, e.g., the fraction of nodes on which both models predict the same class  
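
A common way to compute fidelity is the fraction of nodes on which the surrogate predicts the same class as the GNN. This measures agreement with the GNN's predictions, not with the ground-truth labels. Below is a minimal NumPy sketch; the function name `fidelity` and the example predictions are hypothetical.

```python
import numpy as np

def fidelity(gnn_pred, surrogate_pred):
    """Fraction of nodes on which the surrogate reproduces the GNN's predicted class."""
    gnn_pred = np.asarray(gnn_pred)
    surrogate_pred = np.asarray(surrogate_pred)
    if gnn_pred.shape != surrogate_pred.shape:
        raise ValueError("prediction arrays must have the same shape")
    return float(np.mean(gnn_pred == surrogate_pred))

# hypothetical predicted classes for five nodes
gnn_pred = [0, 2, 1, 1, 3]        # e.g. out.argmax(dim=1) of the GNN
surrogate_pred = [0, 2, 1, 0, 3]  # predictions of the tabular surrogate
print(fidelity(gnn_pred, surrogate_pred))  # 0.8
```

Computing it on nodes the surrogate was not trained on, for example the test split, gives a less optimistic estimate.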

▪ If the surrogate model is interpretable, it can be explained directly  
&emsp;&emsp;• Weights of logistic regression  
&emsp;&emsp;• Decision tree  
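
Here is a sketch of an interpretable surrogate built with scikit-learn. The tabular features and the GNN predictions are random placeholders. They stand in for the converted Cora table and for `model(...).argmax(dim=1)`, and both are assumptions. The essential point is that the tree is fit on the GNN's predictions, not on the true labels, so its rules describe the GNN.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

rng = np.random.default_rng(0)

# placeholder tabular data: 200 nodes, 5 binary features (stand-in for the converted graph)
X = rng.integers(0, 2, size=(200, 5))
feature_names = [f"word_{j}" for j in range(5)]

# placeholder GNN predictions: class 1 exactly when word_0 and word_3 are both present
gnn_pred = (X[:, 0] & X[:, 3]).astype(int)

# fit a shallow tree that mimics the GNN (targets are GNN predictions, not ground truth)
surrogate = DecisionTreeClassifier(max_depth=3, random_state=0)
surrogate.fit(X, gnn_pred)

# fidelity on the training nodes; held-out nodes give a fairer estimate
print("fidelity:", (surrogate.predict(X) == gnn_pred).mean())

# the learned decision rules are the explanation
rules = export_text(surrogate, feature_names=feature_names)
print(rules)
```

For logistic regression, the analogous explanation is the `coef_` attribute of a fitted `LogisticRegression`. It holds the learned feature weights, with one row per class in the multiclass case.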

▪ If the surrogate model is not interpretable, it can be explained with a post-hoc explainer  
&emsp;&emsp;• Example models  
&emsp;&emsp;&emsp;&emsp;o Random forest  
&emsp;&emsp;&emsp;&emsp;o Neural network on tabular data  
&emsp;&emsp;• Example explainers  
&emsp;&emsp;&emsp;&emsp;o LIME  
&emsp;&emsp;&emsp;&emsp;o Anchors  
&emsp;&emsp;&emsp;&emsp;o SHAP  
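
The following sketch illustrates how a local explainer such as LIME works. It is a simplified, from-scratch version for binary (one-hot) features and is not the `lime` package. The steps are:

1. Perturb the instance by switching off random subsets of its active features.
2. Query the black-box model on the perturbed samples.
3. Weight each sample by its similarity to the original instance.
4. Fit a weighted linear model, whose coefficients serve as the explanation.

The function name `lime_binary`, the distance measure, the kernel width and the toy black box are all assumptions chosen for illustration.

```python
import numpy as np

def lime_binary(predict_proba, instance, target_class, num_samples=2000, kernel_width=0.75, seed=0):
    """LIME-style local explanation for one instance with binary features (simplified sketch).

    Returns one weight per feature; features that are 0 in the instance get weight 0,
    because removing them cannot change the instance.
    """
    rng = np.random.default_rng(seed)
    instance = np.asarray(instance, dtype=float)
    active = np.flatnonzero(instance)              # features present in the instance
    explanation = np.zeros(instance.size)
    if active.size == 0:
        return explanation
    # interpretable representation: z[:, k] = 1 keeps active feature k, 0 removes it
    z = rng.integers(0, 2, size=(num_samples, active.size))
    z[0] = 1                                       # include the unperturbed instance
    samples = np.zeros((num_samples, instance.size))
    samples[:, active] = z
    # black-box predictions for the perturbed samples
    y = predict_proba(samples)[:, target_class]
    # exponential kernel on the fraction of removed features (assumed distance measure)
    distance = 1.0 - z.mean(axis=1)
    weights = np.exp(-(distance ** 2) / kernel_width ** 2)
    # weighted least squares with an intercept column
    design = np.hstack([z, np.ones((num_samples, 1))])
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(design * sw[:, None], y * sw, rcond=None)
    explanation[active] = coef[:-1]
    return explanation

# toy black box: the class-1 probability depends only on features 0 and 2
def toy_predict_proba(X):
    p1 = 1.0 / (1.0 + np.exp(-(3 * X[:, 0] + 1.5 * X[:, 2] - 2)))
    return np.column_stack([1 - p1, p1])

instance = np.array([1, 1, 1, 0])
print(np.round(lime_binary(toy_predict_proba, instance, target_class=1), 2))
```

With real libraries, `lime.lime_tabular.LimeTabularExplainer` plays this role, and `shap.TreeExplainer` is a common choice for random-forest surrogates.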

# Implementation

## Imports

In [None]:
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid, TUDataset, KarateClub, GNNBenchmarkDataset, CitationFull
import pandas as pd
import os
import shutil
import pickle
from torch_geometric.explain.algorithm import GNNExplainer
from torch_geometric.explain import Explainer

pd.set_option('display.max_rows', 50000)

## Training

### Data

In [None]:
dataset = Planetoid(root='data/Cora', name='Cora')

data = dataset[0]

### Model

In [None]:
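class GNN(torch.nn.Module):
    """Two-layer GCN for node classification."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, p):
        # p: dropout probability applied between the two GCN layers
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=p, training=self.training)
        x = self.conv2(x, edge_index)
        # return log-probabilities: F.nll_loss and the Explainer (return_type="log_probs") expect them
        return F.log_softmax(x, dim=1)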


def train(p):
    model.train()
    optimizer.zero_grad()
    # node features, edge indices and node labels
    x = data.x
    edge_index = data.edge_index
    y = data.y

    try:
        # predefined training mask (Planetoid datasets ship with one)
        train_mask = data.train_mask
    except AttributeError:
        # no predefined split: fall back to a 60/20/20 split
        train_mask, val_mask, test_mask = get_mask(len(y))

    # forward pass
    out = model(x, edge_index, p)
    # loss on the training nodes only; F.nll_loss expects log-probabilities from the model
    loss = F.nll_loss(out[train_mask], y[train_mask])
    # backward pass and parameter update
    loss.backward()
    optimizer.step()
    return loss

def get_mask(num_nodes):
    # split the nodes in index order into 60% train, 20% validation, 20% test
    trn_len = int(num_nodes * 0.60)
    val_len = int(num_nodes * 0.20)
    # the remainder goes to the test split, so the masks have exactly num_nodes entries
    test_len = num_nodes - trn_len - val_len

    train_mask = torch.tensor([True] * trn_len + [False] * val_len + [False] * test_len)
    val_mask = torch.tensor([False] * trn_len + [True] * val_len + [False] * test_len)
    test_mask = torch.tensor([False] * trn_len + [False] * val_len + [True] * test_len)

    return train_mask, val_mask, test_mask
    

def save_pickle(data, fname):
    with open(fname, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return True


def load_pickle(fname):
    with open(fname, 'rb') as handle:
        data = pickle.load(handle)
        return data
    
@torch.no_grad()
def test(p):
    model.eval()
    x = data.x
    edge_index = data.edge_index
    y = data.y

    try:
        val_mask = data.val_mask
        test_mask = data.test_mask
    except AttributeError:
        # no predefined validation/test masks: fall back to a 60/20/20 split
        _, val_mask, test_mask = get_mask(len(y))

    out = model(x, edge_index, p)
    pred = out.argmax(dim=1)
    # accuracy = share of correctly classified nodes within each split
    val_acc = pred[val_mask].eq(y[val_mask]).sum().item() / val_mask.sum().item()
    test_acc = pred[test_mask].eq(y[test_mask]).sum().item() / test_mask.sum().item()
    return val_acc, test_acc
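
`get_mask` is only a fallback, since Planetoid's Cora ships with predefined masks. It assigns nodes to the splits in index order, which is only safe if the node order carries no information. A shuffled 60/20/20 variant could look like the NumPy sketch below. The name `get_random_mask` and the seed are hypothetical. The resulting arrays can be converted with `torch.from_numpy` before they are used to index tensors.

```python
import numpy as np

def get_random_mask(num_nodes, train_frac=0.6, val_frac=0.2, seed=0):
    """Shuffled train/val/test boolean masks that together cover every node exactly once."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_nodes)
    n_train = int(num_nodes * train_frac)
    n_val = int(num_nodes * val_frac)

    train_mask = np.zeros(num_nodes, dtype=bool)
    val_mask = np.zeros(num_nodes, dtype=bool)
    test_mask = np.zeros(num_nodes, dtype=bool)
    train_mask[perm[:n_train]] = True
    val_mask[perm[n_train:n_train + n_val]] = True
    test_mask[perm[n_train + n_val:]] = True  # the remainder goes to test, so no node is dropped
    return train_mask, val_mask, test_mask

train_mask, val_mask, test_mask = get_random_mask(2708)
print(train_mask.sum(), val_mask.sum(), test_mask.sum())  # 1624 541 543
```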



In [None]:
EPOCHS = 10

# grid of dropout probabilities and learning rates
ps  = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9]
lrs = [0.01, 0.015, 0.02, 0.025, 0.03, 0.035, 0.04, 0.045, 0.05, 0.055, 0.06, 0.065, 0.07, 0.075, 0.08, 0.085, 0.09, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# tag appended to the saved file names (hypothetical value); the cell that parses
# the file names splits them on '_' and expects exactly one '_' inside this tag
dataset_name = 'Planetoid_Cora'
os.makedirs('./models', exist_ok=True)

finaldf = pd.DataFrame()

for i, p in enumerate(ps):
    print(f'{i+1}/{len(ps)}', end='\r')

    for lr in lrs:
        # fresh model and optimizer for every (p, lr) combination
        model = GNN(dataset.num_node_features, 16, dataset.num_classes)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        val_accs, test_accs = [], []
        # run for EPOCHS epochs
        for epoch in range(1, EPOCHS + 1):
            loss = train(p)
            val_acc, test_acc = test(p)
            val_accs.append(val_acc)
            test_accs.append(test_acc)

        tmpdf = pd.DataFrame({
            'p': [p],
            'lr': [lr],
            'val_accs': [max(val_accs)],
            'test_accs': [max(test_accs)],
            'val_accs_iter': [val_accs.index(max(val_accs))],
            'test_accs_iter': [test_accs.index(max(test_accs))],
        })
        finaldf = pd.concat([finaldf, tmpdf])

        # pickle the model so the best one can be reloaded for the explainer
        # (note: this is the model after the last epoch, while the accuracies are the best over all epochs)
        save_pickle(model, f'./models/p-{p}_lr-{lr}_val_accs-{max(val_accs)}_test_accs-{max(test_accs)}_{dataset_name}')

# rank the configurations by validation accuracy; the test split is only for reporting
finaldf.reset_index(drop=True).sort_values('val_accs', ascending=False).head(10)


In [None]:
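# collect the saved models and parse their hyperparameters and scores from the file names
li = os.listdir('./models/')
resultsdf = pd.DataFrame()
for fname in li:
    if 'val_accs' in fname:
        # split into at most 7 parts so that a dataset tag containing '_' stays in one column
        items = fname.split('_', 6)
        tmpdf = pd.DataFrame([items], columns=['p', 'lr', 'val', 'val_acc', 'test', 'test_acc', 'dataset'])
        resultsdf = pd.concat([resultsdf, tmpdf])

# sort numerically by validation accuracy ('accs-0.81' -> 0.81) and keep the order, so row 0 is the best model
resultsdf = resultsdf.sort_values(
    'val_acc', ascending=False, key=lambda s: s.str.split('-').str[1].astype(float)
).reset_index(drop=True)
resultsdf.head()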

In [None]:
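# rebuild the file name of the best model from its parsed parts
winner_model = '_'.join(resultsdf.iloc[0].tolist())
print(winner_model)

# start from an empty ./best_model/ so that it only contains the current winner
shutil.rmtree('./best_model', ignore_errors=True)
os.makedirs('./best_model')
shutil.copyfile(f'./models/{winner_model}', f'./best_model/{winner_model}')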

## EDA 

### EdgeDF

In [None]:
x = data.x  # tensor of shape [2708, 1433]: 2708 nodes, 1433 features
y = data.y  # tensor of shape [2708]: integer class label of each node
edge_index = data.edge_index  # tensor of shape [2, 10556]: each column is a (source, target) pair

# convert the node features, labels and edges to NumPy arrays
x = x.numpy()
y = y.numpy()
edge_index = edge_index.numpy()

triples = []

# one triple (source, 'cites', target) per edge
for i in range(edge_index.shape[1]):
    source = edge_index[0, i]
    target = edge_index[1, i]
    triples.append((source, 'cites', target))

# node-level triples: type, label and one triple per bag-of-words feature
for i in range(x.shape[0]):
    node = i
    features = x[i, :]
    label = y[i]
    triples.append((node, 'rdf:type', 'Paper'))
    triples.append((node, 'label', label))
    # predicate 'word_j' with the value of feature j as the object
    for j in range(features.shape[0]):
        triples.append((node, f'word_{j}', features[j]))

edge_df = pd.DataFrame(triples, columns=['subject', 'predicate', 'object'])

# print the shape and show the first five rows
print(edge_df.shape)
edge_df.head()


In [None]:
edge_df.subject.value_counts().to_frame().reset_index(drop=False).head()

In [None]:
edge_df.predicate.value_counts().to_frame().reset_index(drop=False).head()


In [None]:
edge_df.object.value_counts().to_frame().reset_index(drop=False).head()


### NodeDF

In [None]:
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')

# Get the node features, labels and edge indices
x = data.x # a tensor of shape [2708, 1433], where 2708 is the number of nodes and 1433 is the number of features
y = data.y # a tensor of shape [2708], where each element is an integer representing the class label of the node
edge_index = data.edge_index # a tensor of shape [2, 10556], where each column is a pair of node indices representing an edge

# Convert the node features and labels to numpy arrays
x = x.numpy()
y = y.numpy()

# Create a pandas dataframe from the node features and labels
node_df = pd.DataFrame(x)
node_df['label'] = y

# Print the first five rows of the dataframe
print(node_df.shape)
node_df.head()


In [None]:
node_df.isnull().sum()

In [None]:
node_df.describe()
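
`node_df` only contains each paper's own words, while the GCN also aggregates information from neighbouring papers. A surrogate trained on `node_df` alone may therefore reach low fidelity. One way to bring graph structure into the table is to append the mean feature vector of each node's neighbours. Below is a NumPy sketch on a toy graph. The function name `neighbor_mean_features` is hypothetical; for Cora it would be called with `data.x.numpy()` and `data.edge_index.numpy()`.

```python
import numpy as np

def neighbor_mean_features(x, edge_index):
    """Mean feature vector of each node's in-neighbours (sources of edges pointing at it).

    x: array [num_nodes, num_features]; edge_index: array [2, num_edges] of (source, target) pairs.
    Nodes without incoming edges get a zero vector.
    """
    num_nodes = x.shape[0]
    source, target = edge_index
    summed = np.zeros_like(x, dtype=float)
    np.add.at(summed, target, x[source])  # unbuffered add: repeated targets accumulate correctly
    degree = np.bincount(target, minlength=num_nodes).astype(float)
    return summed / np.maximum(degree, 1.0)[:, None]

# toy graph: 3 nodes, undirected edges 0-1 and 1-2 stored in both directions (as in Cora)
toy_x = np.array([[1, 0], [0, 1], [1, 1]], dtype=float)
toy_edges = np.array([[0, 1, 1, 2],
                      [1, 0, 2, 1]])
print(neighbor_mean_features(toy_x, toy_edges))
```

The resulting columns could be appended to `node_df`, for example under a prefix such as `nbr_`, before training the surrogate.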


## Explainer

In [None]:
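# load the best model selected above; the `model` left over from the grid search is
# just the last configuration trained, not the best one
best_dir = './best_model/'
model = load_pickle(os.path.join(best_dir, os.listdir(best_dir)[0]))
model.eval()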

# create an Explainer object using the trained GNN model and the GNNExplainer algorithm
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=20),
    explanation_type="model",
    node_mask_type="attributes",
    edge_mask_type="object",
    model_config=dict(
        mode="multiclass_classification",
        task_level="node",
        return_type="log_probs", # Model returns log probabilities
    ),
)

# explain the prediction for node 152; the keyword argument p is passed through to GNN.forward
explanation = explainer(data.x, data.edge_index, p=0.50, index=152)


In [None]:
explanation.node_attrs()

In [None]:
explanation.get_explanation_subgraph()

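# the returned Explanation keeps only the nodes and edges with non-zero attribution:
# node_mask=[num_nodes, num_features] : feature attributions of the remaining nodes
# edge_mask=[num_edges]               : attributions of the remaining edges
# prediction=[num_nodes, num_classes] : model output for the remaining nodes
# target=[num_nodes]                  : explained class (the model's own prediction)
# index                               : index of the explained node
# x=[num_nodes, num_features]         : features of the remaining nodes
# edge_index=[2, num_edges]           : remaining edges
# p                                   : extra keyword argument passed to the model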


In [None]:
explanation.visualize_graph()


In [None]:
explanation.visualize_feature_importance(top_k=10)
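
To connect the GNNExplainer output back to the tabular view, the node mask can be reduced to one score per feature. With `node_mask_type="attributes"` the mask has shape [num_nodes, num_features]. The scores can then be mapped to the `word_j` predicate names used in the EdgeDF table. This sketch assumes the mask has been converted to NumPy, e.g. with `explanation.node_mask.detach().numpy()`. Summing over nodes is assumed here to match what `visualize_feature_importance` plots. The function name and the toy mask are hypothetical.

```python
import numpy as np
import pandas as pd

def top_features(node_mask, k=3, prefix="word_"):
    """Sum each feature's attribution over all nodes and return the k highest-scoring features."""
    node_mask = np.asarray(node_mask, dtype=float)
    scores = node_mask.sum(axis=0)  # one score per feature
    names = [f"{prefix}{j}" for j in range(scores.size)]
    ranking = pd.Series(scores, index=names).sort_values(ascending=False)
    return ranking.head(k)

# toy mask: 3 nodes, 4 features
toy_mask = np.array([[0.0, 0.9, 0.1, 0.0],
                     [0.2, 0.8, 0.0, 0.0],
                     [0.0, 0.1, 0.6, 0.3]])
print(top_features(toy_mask, k=2))
```

Because the names match the `word_j` predicates, the top features can be compared directly with the rules or feature weights of a tabular surrogate.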