<center><h1><b><u>PGExplainer with Custom BA2MOTIF Dataset</u></b></h1></center>

## __Importing Libraries, Classes, and Functions__

### __Defining the Class__

In [2]:
#BA2MOTIFS and Custom BA2MOTIFS
import torch
from torch_geometric.datasets import BA2MotifDataset, ExplainerDataset
from torch_geometric.data import Data, Dataset
from torch.utils.data import ConcatDataset, Subset
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from torch_geometric.utils import is_undirected, degree, to_networkx, from_networkx
import matplotlib.pyplot as plt

#Custom BA2MOTIFS
from torch_geometric.datasets.graph_generator import BAGraph
from torch_geometric.datasets.motif_generator import HouseMotif, CycleMotif

#BA2MOTIFS graph from scratch
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx #focuses on network analysis and graph theory; can create/visualize graphs
import inspect #for viewing source code
import pprint
from tqdm import tqdm
import pickle
import random
from dataclasses import dataclass

#Classifier
import torch.nn as nn
from torch.nn import Linear, Sequential
import torch_geometric
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv, global_max_pool, global_mean_pool

#Fine tuning
import copy

## __BA2Motif from Scratch__

In [43]:
device = torch.device('cpu')

In [3]:
'''Common in Python to group related functions in a module not encapsulated in a class. NodeView node order is correct since dictionary order is 
by insertion but don't rely on this in an algorithm.'''

def create_house_motif():
    house = nx.Graph()
    house.add_edges_from([ #isomorphically a house, but may not always look like house; 4 nodes in cycle with 2 connected to 5th node
        (0, 1), (1, 2), (2, 3), (3, 0),  # square base
        (1, 4), (2, 4)  # roof
    ])
    return house

def create_cycle_motif(size):
    cycle = nx.Graph()

    cycle.add_nodes_from(range(size))
    
    for i in range(size - 1):
        cycle.add_edge(i, i+1)
    cycle.add_edge(size - 1, 0) #connect last to first to make cycle
    
    return cycle


#Give all nodes and edges same features, otherwise it's "cheating" by labeling the answer; adding features here instead of randomly in class so it is more flexible
def add_node_features(G, motif_nodes):
    for node in G.nodes():
        G.nodes[node]['motif'] = 1

def add_edge_features(G, motif_edges):
    for u, v in G.edges():
        G[u][v]['motif'] = 1


def attach_motif(G, motif):
    orig_num_nodes = G.number_of_nodes() #len(graph._node)
    orig_nodes = list(G.nodes())

    motif_node_mapping = {old: orig_num_nodes + map for map, old in enumerate(motif.nodes())} #shifts motif indices by n so it doesn't override current edges
    mapped_motif_nodes = motif_node_mapping.values()
    
    #Create a copy of the motif and add it to existing graph data (still no edge connecting motif to original graph)
    G.add_nodes_from(mapped_motif_nodes)

    #Map and add edges
    motif_edge_mapping = {(u,v): (u + orig_num_nodes, v + orig_num_nodes) for (u, v) in motif.edges()}
    mapped_motif_edges = motif_edge_mapping.values()
    G.add_edges_from(mapped_motif_edges) 
    
    #Choose attachment point and connect motif
    attachment_point = random.choice(orig_nodes) #chooses random element of the list of node id's as attachment point
    G.add_edge(attachment_point, orig_num_nodes) #n is first node of motif

    #Add features
    add_node_features(G, mapped_motif_nodes)
    add_edge_features(G, mapped_motif_edges)

    return G

def generate_ba2motif_dataset(num_graphs, n, m, cycle_size, motif_prob):
    dataset = [] #list of graphs
    labels = []
    
    for i in range(num_graphs):
        ba_graph = nx.barabasi_albert_graph(n, m)
        
        if random.random() < motif_prob: #random number from 0-1; works by Law of Large Numbers
            motif = create_cycle_motif(cycle_size)
            label = 0
        else:
            motif = create_house_motif()
            label = 1  # House motif label is 1
        
        ba_motif_graph = attach_motif(ba_graph, motif)
        
        dataset.append(ba_motif_graph)
        labels.append(label)
    
    return dataset, labels

### __Creating the Dataset - Function Approach to be Compatible with PGExplainer Code__

In [38]:
num_graphs = 1000
n = 20
m = 1
cycle_size = 5

In [39]:
nx_graphs, labels = generate_ba2motif_dataset(num_graphs, n, m, cycle_size, motif_prob = 0.5) #creates nx graphs

#Convert to Data objects
dataset = []
for nx_graph, label in tqdm(zip(nx_graphs, labels), total = len(labels), desc="Creating Dataset", unit = " Graphs", colour="green", ncols=100):
    data = from_networkx(nx_graph)  # converts networkx to torch Data instance; used debugging to see name and dimensions of data object
    
    data.x = data.motif.unsqueeze(1).float()  # change shape from 1D tensor to 2D: [num_nodes, 1]
    #data.x = torch.ones((data.num_nodes, 1), dtype=torch.float) #Other option if features not created when generating dataset
    del data.motif #removes this attribute from namespace, not necessarily memory

    #Most GNN don't use edge features
    data.edge_attr = data.edge_motif.unsqueeze(1).float()
    del data.edge_motif
    data.y = torch.tensor([label], dtype=torch.long)
    
    dataset.append(data)

ba2_explain = dataset

Creating Dataset: 100%|[32m██████████████████████████████████[0m| 1000/1000 [00:00<00:00, 2854.59 Graphs/s][0m


### __Pickling the Dataset__

In [40]:
'''Pickle saves data to disk in byte stream (binary data more compact than text data); like MNIST from scratch. No whitespaces or other transformations necessary
to decode/deserialize the data so it is also faster.'''

#with open('../data/ba2-explain.pkl', 'wb') as f:
    #pickle.dump('ba2-explain', f)

In [26]:
with open('../data/ba2-explain.pkl', 'rb') as f:
        ba2_explain = pickle.load(f)

### __Loading Pretrained Model__

In [42]:
class GIN(torch.nn.Module):
    def __init__(self, num_features, dropout_rate = 0.2, num_classes = 2, pretrained_model = False):
        super(GIN, self).__init__()

        self.nn1 = Sequential(
            Linear(num_features, 32),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            Linear(32, 16),
            nn.ReLU())

        self.conv1 = GINConv(self.nn1)
        self.bn1 = nn.BatchNorm1d(16)

        self.nn2 = Sequential(
            Linear(16, 16),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            Linear(16, 16),
            nn.ReLU())
        
        self.conv2 = GINConv(self.nn2)
        self.bn2 = nn.BatchNorm1d(16)

        self.fc1 = Linear(32, num_classes) #2 classes

        #Ensure original model isn't altered during fine-tuning
        if pretrained_model:
            temp = torch.load(f'../models/BA2-Scratch/{pretrained_model}')
            cloned_state_dict = copy.deepcopy(temp)
            self.load_state_dict(cloned_state_dict)
            
            
            self.freeze_layers()
            self.fc1 = Linear(32, 2) #Common to freeze earlier layers; replace desired layer for fine tuning

    def freeze_layers(self):
        for name, param in self.named_parameters():
                param.requires_grad = True if 'fc1' else False # Don't freeze the final layer

    def node_embedding(self, x, edge_index):
        # Process through GIN layers
        x = F.relu(self.bn1(self.conv1(x, edge_index)))
        x = F.relu(self.bn2(self.conv2(x, edge_index)))
        return x
    
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch #extracts feature matrix from graph, edge info, batch indices
        node_embeddings = self.node_embedding(x, edge_index) #reduces dimension of nodes and edges

        x_max = global_max_pool(node_embeddings, batch)
        x_mean = global_mean_pool(node_embeddings, batch)
        x = torch.cat([x_max, x_mean], dim=1) #combines these two poolings into single tensor

        return self.fc1(x)

In [44]:
def load_model(filename):
    model = GIN(num_features=1)
    model.load_state_dict(torch.load(f'../models/BA2-Scratch/{filename}', weights_only=True))
    model.eval()
    model = model.to(device)
    return model

In [45]:
aug_filename = "GIN-aug-fine-tuned-100.pth"
model = load_model(aug_filename)

In [41]:
len(ba2_explain)

1000

In [36]:
def evaluate(model, dataset):
    model.eval()

    
    for data in dataset:
        with torch.no_grad():
            out = model(data)
            pred = out.argmax(dim=1)
            correct = int((pred == data.y).sum())
    
    total = len(dataset)
    return correct / total

# Usage
model = model.to(device)  # Ensure your model is on the correct device
accuracy = evaluate(model, ba2_explain, device)
print(f"Test Accuracy: {accuracy:.4f}")

AttributeError: 'strBatch' object has no attribute 'stores_as'

In [29]:
def evaluate(model, loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            out = model(data)
            pred = out.argmax(dim=1) #dim1 is for certain batch; index of max val will be predicted class (0 == cycle, 1 == house)
            correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)

test_loader = DataLoader(ba2_explain, batch_size=32, shuffle = False)
evaluate(model, test_loader)

In [30]:
test_loader = DataLoader(ba2_explain, batch_size=32, shuffle = False)
evaluate(model, test_loader)

AttributeError: 'list' object has no attribute 'to'