# This model uses the modified similarity measure (similarity2)

In [2]:
import sys
# caution: path[0] is reserved for script path (or '' in REPL)
sys.path.insert(1, '../src')

%load_ext autoreload
%autoreload 2

import torch
import torch_geometric
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt


from torch_geometric.datasets import TUDataset
from preprocessing import data_transformation
from similarity import calculate_similarity_matrix

from model import GCN
import copy

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [174]:
# Load the Mutagenicity graph-classification dataset; files are stored under datasets/
dataset = TUDataset(root='datasets/', name='Mutagenicity')
# Seed torch's global RNG right before shuffling — presumably so the shuffled order
# (and therefore the splits taken from it below) is the same on every run.
torch.manual_seed(1234)
dataset = dataset.shuffle()

## Preprocessing

#### Split: Train test validation

```train_dataset```: used to train the model<br/>
```val_dataset```: used to evaluate the model during hyperparameter tuning<br/>
```test_dataset```: used to test the model after training is complete<br/>

In [175]:
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data

In [176]:
# Split fractions: train / test / validation (validation receives whatever remains)
tr, ts, vl = 0.8, 0.1, 0.1
dslen = len(dataset)
# Cut points into the already-shuffled dataset
tri, tsi = round(tr * dslen), round((tr + ts) * dslen)
train_dataset, test_dataset, val_dataset = dataset[:tri], dataset[tri:tsi], dataset[tsi:]

In [177]:
# Peek at the first graph of the shuffled dataset
dataset[0]

Data(edge_index=[2, 24], x=[12, 14], edge_attr=[24, 3], y=[1])

In [178]:
# Class labels of the shuffled dataset
dataset.y

tensor([1, 1, 0,  ..., 0, 1, 0])

In [179]:
# data = dataset[0]

# Function to reindex nodes
def reindex_nodes(data):
    """Relabel the node ids used in ``data.edge_index`` to the contiguous range ``0..n-1``.

    The ids that occur in ``data.edge_index`` are ranked in ascending order and
    each occurrence is replaced by its rank; ``data.x`` is then restricted to
    the rows of those ids (in the same ascending order), so rows of nodes that
    never appear in an edge are dropped.

    Args:
        data: object exposing ``edge_index`` (integer tensor) and ``x``
            (tensor indexable by node id along its first dimension).

    Returns:
        The same ``data`` object, with ``edge_index`` and ``x`` rebound to new
        tensors. The tensor originally stored in ``data.edge_index`` is not
        modified in place, so storage it may share with other objects is left
        untouched.
    """
    # `inverse` has the shape of edge_index and holds, for every entry, the
    # position of that id inside the sorted `unique_nodes` — i.e. its new id.
    unique_nodes, inverse = torch.unique(data.edge_index, return_inverse=True)

    # Reindex edge_index (keep the original integer dtype)
    data.edge_index = inverse.to(data.edge_index.dtype)

    # Reorder x according to new indices
    data.x = data.x[unique_nodes]

    return data

# Apply reindex_nodes to every graph of the dataset.
# NOTE(review): each `dataset[i]` expression below is a separate indexing call; if the
# dataset returns a fresh object per access (presumably the case for PyG datasets —
# TODO confirm), the three attribute assignments land on a temporary and are lost.
# NOTE(review): train_dataset / test_dataset / val_dataset were sliced from `dataset`
# in an earlier cell; verify that they actually see the reindexed graphs.
for i in range(len(dataset)):
    data = reindex_nodes(dataset[i])
    dataset[i].edge_index = data.edge_index
    dataset[i].x = data.x
    # edge_attr is copied as-is; reindex_nodes does not change it
    dataset[i].edge_attr = data.edge_attr

In [180]:
# Show the dataset summary (name and number of graphs)
dataset

Mutagenicity(4337)

In [182]:
# Size of the training split, followed by its labels (displayed as the cell result)
print(len(train_dataset))
train_dataset.y

3470


tensor([1, 1, 0,  ..., 0, 1, 1])

In [183]:
# NOTE: only the last expression of a cell is displayed, so this length is evaluated
# but never shown; wrap it in print() if the size is wanted in the output.
len(test_dataset)
# Labels of the test split
test_dataset.y

tensor([1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1,
        0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
        0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1,
        0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0,
        0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0,
        1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1,
        0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1,
        0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1,
        1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0,

In [184]:
# NOTE: only the last expression of a cell is displayed, so this length is evaluated
# but never shown; wrap it in print() if the size is wanted in the output.
len(val_dataset)
# Labels of the validation split
val_dataset.y

tensor([0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0,
        1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0,
        1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1,
        1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0,
        1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0,
        0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0,
        1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0,
        0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1,
        1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1,
        1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0,

#### Batching

In [185]:
# Author's note: "paper 128" (reference batch size); this cell uses 2.
batch_size = 2

# One loader per split, all iterating in dataset order (no shuffling)
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (train_dataset, val_dataset, test_dataset)
)

## Building Model

In [186]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GraphSAGE
from torch_geometric.nn import GraphConv
from torch_geometric.nn import GINConv
# from torch_geometric.nn import GINConv
from torch.nn import Linear, Sequential, ReLU, BatchNorm1d

from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import global_add_pool
import torch.nn.functional as F

In [187]:
class Base(torch.nn.Module):
    """Baseline graph classifier: two GIN layers, mean readout and an MLP head.

    Args:
        dataset: object exposing ``num_node_features`` and ``num_classes``.
        hidden_channels: width of every hidden layer.
    """

    def __init__(self, dataset, hidden_channels):
        super(Base, self).__init__()

        # weight seed
        torch.manual_seed(42)
        nn1 = Sequential(Linear(dataset.num_node_features, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv1 = GINConv(nn1)
        self.bn1 = BatchNorm1d(hidden_channels)

        nn2 = Sequential(Linear(hidden_channels, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        # GINConv wraps an MLP; GCNConv expects channel sizes and cannot take `nn2`.
        self.conv2 = GINConv(nn2)
        self.bn2 = BatchNorm1d(hidden_channels)

        # classification layer
        self.lin = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        """Return ``(graph_embedding, logits)`` for a batch of graphs.

        Args:
            x: node feature matrix.
            edge_index: edge list of the batched graph.
            batch: per-node graph assignment used by the readout.
        """
        # Embed original. Only conv1/conv2 exist (there is no conv3), and the
        # batch-norm layers created in __init__ are applied after each
        # activation, matching the Experiment model in this notebook.
        embedding = F.relu(self.conv1(x, edge_index))
        embedding = self.bn1(embedding)
        embedding = F.relu(self.conv2(embedding, edge_index))
        embedding = self.bn2(embedding)

        # Readout: one vector per graph
        embedding = global_mean_pool(embedding, batch)

        h = self.lin(embedding)
        h = F.relu(h)
        h = F.dropout(h, p=0.3, training=self.training)
        h = self.lin2(h)

        return embedding, h

#### Train

### Experiment Model

In [188]:
from sklearn.cluster import AffinityPropagation

#### Model modification

In [189]:
# Author's note: "paper 128" (reference batch size), which is the value used from here on.
batch_size = 128

# Rebuild the three loaders with the new batch size; iteration order stays fixed.
_loader_options = dict(batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, **_loader_options)
val_loader = DataLoader(val_dataset, **_loader_options)
test_loader = DataLoader(test_dataset, **_loader_options)

# batch1 = next(iter(train_loader))

In [190]:


from similarity2 import calculate_similarity_matrix, testt


# AP Clustering
from sklearn.cluster import AffinityPropagation

from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import global_max_pool
from torch_geometric.nn import GINConv, global_add_pool
from torch.nn import Linear, Sequential, ReLU, BatchNorm1d

import torch.nn.functional as F

class Experiment(torch.nn.Module):
    """GIN graph classifier augmented with attention-selected subgraph embeddings.

    Pipeline of ``forward``:
      1. two GIN layers (+ batch norm) embed the nodes;
      2. every graph of the batch is clustered into communities with Affinity
         Propagation on a similarity matrix built from the (detached) node
         embeddings;
      3. two GCN layers run on the edges that stay inside a community;
      4. each community is pooled into one vector and the ``k`` communities
         with the highest attention weight w.r.t. the graph embedding are kept;
      5. graph embedding and selected community embeddings are concatenated
         and classified.

    Args:
        dataset: object exposing ``num_node_features`` and ``num_classes``.
        hidden_channels: width of every hidden layer.
        k: number of subgraphs (communities) kept per graph, default 1.
    """

    def __init__(self, dataset, hidden_channels, k = 1):
        super(Experiment, self).__init__()

        # save number of subgraphs, default 1
        self.k_subgraph = k

        # weight seed
        torch.manual_seed(42)
        nn1 = Sequential(Linear(dataset.num_node_features, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv1 = GINConv(nn1)
        self.bn1 = BatchNorm1d(hidden_channels)

        nn2 = Sequential(Linear(hidden_channels, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        self.conv2 = GINConv(nn2)
        self.bn2 = BatchNorm1d(hidden_channels)

        # embeddings for subgraph
        self.conv4 = GCNConv(hidden_channels, hidden_channels)
        self.conv5 = GCNConv(hidden_channels, hidden_channels)

        # attention layer
        # NOTE(review): value_layer is kept so existing checkpoints still load,
        # but its output is never used by selectk_subgraph.
        self.query_layer = Linear(hidden_channels, hidden_channels)
        self.key_layer = Linear(hidden_channels, hidden_channels)
        self.value_layer = Linear(hidden_channels, hidden_channels)

        # classification layer: graph embedding + k selected subgraph embeddings
        # (equals hidden_channels * 2 for the default k = 1)
        self.lin = Linear(hidden_channels * (1 + k), hidden_channels)
        self.lin2 = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch, ptr):
        """Classify a batch of graphs.

        Args:
            x: node feature matrix of the batched graph.
            edge_index: edge list of the batched graph.
            batch: per-node graph assignment.
            ptr: cumulative node offsets; ``ptr[g]:ptr[g + 1]`` is read as the
                node range of graph ``g``.

        Returns:
            Tuple ``(graph_embedding, logits, S, communities, topk)`` where
            ``S`` is the similarity matrix of the last graph of the batch,
            ``communities`` the concatenated per-node cluster labels and
            ``topk`` the flattened selected subgraph embeddings (one row per
            graph).
        """
        # Embed original
        embedding = F.relu(self.conv1(x, edge_index))
        embedding = self.bn1(embedding)
        embedding = F.relu(self.conv2(embedding, edge_index))
        embedding = self.bn2(embedding)

        # generate subgraph based on embeddings (clustering is not differentiable,
        # so it works on a detached copy)
        feature_emb = embedding.detach()
        subgraph_edge_index, communities, S, batch_communities = self.subgraph_generator(feature_emb, edge_index, batch, ptr)

        subgraph_embedding = self.conv4(embedding, subgraph_edge_index)
        subgraph_embedding = subgraph_embedding.relu()
        subgraph_embedding = self.conv5(subgraph_embedding, subgraph_edge_index)

        # apply readout layer/pooling for each subgraphs
        subgraph_pool_embedding = self.subgraph_pooling(subgraph_embedding, communities, batch, ptr, batch_communities)

        # readout layer for original embedding. It is computed BEFORE the
        # attention step so that the query of graph i is the embedding of
        # graph i (passing node embeddings would pair graph i with node i).
        graph_embedding = global_mean_pool(embedding, batch)

        # apply selective (top k) attention
        topk_subgraph_embedding = self.selectk_subgraph(graph_embedding, subgraph_pool_embedding, self.k_subgraph)
        topk_flat = topk_subgraph_embedding.view(len(graph_embedding), -1)

        combined_embeddings = torch.cat((graph_embedding, topk_flat), 1)

        h = self.lin(combined_embeddings)
        h = F.relu(h)
        h = F.dropout(h, p=0.3, training=self.training)
        h = self.lin2(h)

        return graph_embedding, h, S, communities, topk_flat

    def selectk_subgraph(self, embs, sub_embs, k = 1):
        """Select, per graph, the ``k`` subgraph embeddings with the highest attention.

        Args:
            embs: one query embedding per graph.
            sub_embs: per graph, the stacked embeddings of its subgraphs.
            k: number of subgraphs to keep per graph.

        Returns:
            Tensor obtained by stacking, for every graph, its ``k`` selected
            subgraph embeddings.

        Raises:
            ValueError: if a graph has fewer than ``k`` subgraphs.
        """
        topk_subgraphs_all = []

        for graph_idx, (emb, sub_emb) in enumerate(zip(embs, sub_embs)):
            # as_tensor hands back an existing float32 tensor unchanged, so the
            # autograd link to the subgraph convolutions is preserved
            # (torch.tensor(...) would copy and detach it).
            sub = torch.as_tensor(sub_emb, dtype=torch.float32)

            if k > len(sub):
                raise ValueError(
                    f"graph {graph_idx} has {len(sub)} subgraph(s) but k={k} were requested"
                )

            # transform
            query = self.query_layer(emb)
            key = self.key_layer(sub)

            # att score: one weight per subgraph
            attention_score = torch.matmul(query, key.transpose(0, 1))
            attention_weight = F.softmax(attention_score, dim=0)

            # select topk
            # NOTE(review): only the indices are used, so query_layer/key_layer
            # receive no gradient from this selection.
            _, topk_indices = torch.topk(attention_weight, k)
            topk_subgraphs_all.append(sub[topk_indices])

        return torch.stack(topk_subgraphs_all)

    def subgraph_generator(self, embeddings, batch_edge_index, batch, ptr):
        '''
        Cluster every graph of the batch and collect the intra-community edges.

        Each graph is processed from its own node range ``ptr[g]:ptr[g + 1]``,
        so no edge is dropped at the boundary between two graphs and graphs
        without edges keep their position in the batch.

        Return subgraph_edge_index (edge_index of created subgraph, in batch
        node numbering), the concatenated per-node cluster labels, the
        similarity matrix of the last processed graph and a dict
        graph -> community label -> member nodes (local numbering).
        '''
        subgraph_edge_index = [[], []]
        all_communities = []
        batch_communities = {}
        S = []

        src_all, dst_all = batch_edge_index[0], batch_edge_index[1]

        for graph_idx in range(len(ptr) - 1):
            lower_bound = ptr[graph_idx].item()
            upper_bound = ptr[graph_idx + 1].item()

            # Edges touching this graph's node range, shifted to local node ids.
            in_graph = (
                ((src_all >= lower_bound) & (src_all < upper_bound))
                | ((dst_all >= lower_bound) & (dst_all < upper_bound))
            )
            edge_index = [
                list(src_all[in_graph] - lower_bound),
                list(dst_all[in_graph] - lower_bound),
            ]

            # Embeddings of the nodes that belong to this graph
            embs = list(embeddings[batch == graph_idx])

            # make new graph
            # NOTE(review): data_transformation / calculate_similarity_matrix are
            # project helpers whose contracts are not visible here (e.g. how a
            # graph without edges is handled) — confirm.
            G = data_transformation(edge_index, embs)

            # Calculate similarity matrix
            S = calculate_similarity_matrix(G)

            # AP Clustering
            clustering = AffinityPropagation(affinity='precomputed', damping=0.8, random_state=42, convergence_iter=15, max_iter=1000)
            clustering.fit(S)

            # Get community: label -> member nodes, keyed in order of first appearance
            communities = {}
            for nd, clust in enumerate(clustering.labels_):
                communities.setdefault(clust, []).append(nd)
            all_communities.extend(clustering.labels_)
            batch_communities[graph_idx] = communities

            # Make subgraph edge_index (shift back to batch node numbering)
            for c in communities:
                w = G.subgraph(communities[c])
                for sub in w.edges:
                    subgraph_edge_index[0].append(sub[0] + lower_bound)
                    subgraph_edge_index[1].append(sub[1] + lower_bound)

        # Explicit integer dtype so an empty edge list is still a valid edge_index
        subgraph_edge_index = torch.tensor(
            subgraph_edge_index, dtype=torch.long, device=batch_edge_index.device
        )
        return subgraph_edge_index, all_communities, S, batch_communities

    def subgraph_pooling(self, embeddings, communities, batch, ptr, batch_communities, pool_type = 'mean'):
        """Pool node embeddings into one vector per community.

        Args:
            embeddings: node embeddings of the whole batch.
            communities: unused; kept for interface compatibility.
            batch: unused; kept for interface compatibility.
            ptr: cumulative node offsets of the batch.
            batch_communities: graph -> community label -> member nodes (local ids).
            pool_type: ``'mean'`` or ``'add'``.

        Returns:
            List with one tensor per graph; row ``c`` of that tensor is the
            pooled embedding of the community labelled ``c``.

        Raises:
            ValueError: if ``pool_type`` is not supported.
        """
        if pool_type not in ('mean', 'add'):
            raise ValueError(f"unsupported pool_type {pool_type!r}; expected 'mean' or 'add'")

        all_emb_pool = []

        # LOOP THROUGH BATCH
        for b in batch_communities:
            offset = ptr[b].item()

            # Community labels are used as positions in this list
            emb_pool = [None] * len(batch_communities[b])
            for comm, members in batch_communities[b].items():
                # Stacking keeps the autograd graph of the member embeddings
                emb_temp_tensor = torch.stack([embeddings[member + offset] for member in members])
                if pool_type == 'mean':  # mean pool
                    emb_pool[comm] = torch.mean(emb_temp_tensor, dim=0)
                else:  # add pool
                    emb_pool[comm] = torch.sum(emb_temp_tensor, dim=0)

            all_emb_pool.append(torch.stack(emb_pool))
        return all_emb_pool

# experiment = Experiment(dataset, 64)
# emb, h, S, communities, sub_emb = experiment(batch1.x, batch1.edge_index, batch1.batch, batch1.ptr)

In [191]:
def train_base(model, loader, experiment_mode=False, optimizer=None):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Args:
        model: module called as ``model(x, edge_index, batch)`` (returning
            ``(emb, h)``) or, in experiment mode, as
            ``model(x, edge_index, batch, ptr)`` (returning a 5-tuple whose
            second element is the logits).
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and, in experiment mode, ``ptr``.
        experiment_mode: selects the calling convention described above.
        optimizer: optional optimizer to reuse across calls. When omitted, a
            fresh ``Adam(lr=0.001)`` is created, which also resets Adam's
            running statistics on every call.

    Returns:
        0-d tensor holding the cross-entropy loss averaged over all batches
        (a tensor so that callers can keep calling ``.item()`` on it).

    Raises:
        ValueError: if ``loader`` yields no batch.
    """
    criterion = torch.nn.CrossEntropyLoss()
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    total_loss = 0.0
    num_batches = 0

    for data in loader:
        optimizer.zero_grad()
        if experiment_mode:
            _, h, _, _, _ = model(data.x, data.edge_index, data.batch, data.ptr)
        else:
            _, h = model(data.x, data.edge_index, data.batch)
        loss = criterion(h, data.y)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_base received an empty loader")

    # Average over every batch seen, not just the last one.
    return torch.tensor(total_loss / num_batches)

@torch.no_grad()
def test_base(model, loader, experiment_mode=False):
    model.eval()
    correct = 0
    for data in loader:
        if experiment_mode:
            emb, h, S, communities, sub_emb = model(data.x, data.edge_index, data.batch, data.ptr)
        else:
            emb, h = model(data.x, data.edge_index, data.batch)
        pred = h.argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct/len(loader.dataset)

In [192]:
def expTrain(train_loader, val_loader, test_loader, epoch = 10, fold=0):
    """Train an ``Experiment`` model and evaluate the best checkpoint on the test set.

    A checkpoint is written after every epoch, plus a "best" checkpoint each
    time the validation accuracy improves. After training, the best weights
    are loaded into a fresh model and its test accuracy is printed.

    Args:
        train_loader, val_loader, test_loader: loaders of the three splits.
        epoch: number of training epochs.
        fold: fold identifier, used only in checkpoint file names.

    Returns:
        ``[loss_history, train_acc_history, val_acc_history, test_acc_history]``,
        one entry per epoch (empty lists when ``epoch`` is 0).
    """
    import os
    import warnings
    warnings.filterwarnings("ignore", category=UserWarning)

    # torch.save does not create missing directories, so make sure it exists.
    checkpoint_dir = "model-history/GIN-Mutagenicity/"
    os.makedirs(checkpoint_dir, exist_ok=True)

    num_hidden_layer = 128  # hidden width handed to Experiment
    experiment = Experiment(dataset, num_hidden_layer)
    loss_history = []
    train_acc_history = []
    val_acc_history = []
    test_acc_history = []

    # early stop (currently disabled, see the commented block in the loop)
    early_stopping_patience = 20
    best_val_score = -float("inf")
    epochs_without_improvement = 0
    best_state = None

    # Train
    print('process training')
    for epoch_idx in range(epoch):
        loss = round(train_base(experiment, train_loader, True).item(), 5)
        train_acc = round(test_base(experiment, train_loader, True), 5)
        val_acc = round(test_base(experiment, val_loader, True), 5)
        # Test accuracy is only logged; model selection uses val_acc.
        test_acc = test_base(experiment, test_loader, True)

        loss_history.append(loss)
        train_acc_history.append(train_acc)
        val_acc_history.append(val_acc)
        test_acc_history.append(test_acc)

        print(f'epoch {epoch_idx+1}; loss: {loss}; train_acc: {train_acc}; val_acc: {val_acc}; test_acc: {test_acc}')

        torch.save(experiment.state_dict(), checkpoint_dir+str(fold)+"-e"+str(epoch_idx)+".experiment_best_model-gin_data-mutag.pth")

        # early stop
        if (val_acc > best_val_score):
            best_val_score = val_acc
            epochs_without_improvement = 0

            print('best found, save model')
            # save model
            torch.save(experiment.state_dict(), checkpoint_dir+str(fold)+".experiment_best_model-gin_data-mutag.pth")
            best_state = copy.deepcopy(experiment.state_dict())
        # else:
        #     epochs_without_improvement += 1
        #     if (epochs_without_improvement >= early_stopping_patience):
        #         print('early stop triggered')
        #         break

    histories = [loss_history, train_acc_history, val_acc_history, test_acc_history]

    # With epoch == 0 nothing was trained, so there is no best state to test.
    if best_state is None:
        print('no epoch was run, skip final test')
        return histories

    # Test
    # Create a new instance of the model for testing
    best_model = Experiment(dataset, num_hidden_layer)
    best_model.load_state_dict(best_state)

    test = test_base(best_model, test_loader, True)
    print(f'Accuracy: {test}')

    return histories

# expTrain(train_loader, val_loader, test_loader, epoch = 100)

In [193]:
def baseTrain(train_loader, val_loader, test_loader, epoch = 10, fold=0):
    """Train a ``Base`` model and evaluate the best checkpoint on the test set.

    A checkpoint is written after every epoch, plus a "best" checkpoint each
    time the validation accuracy improves. After training, the best weights
    are loaded into a fresh model and its test accuracy is printed.

    Args:
        train_loader, val_loader, test_loader: loaders of the three splits.
        epoch: number of training epochs.
        fold: fold identifier, used only in checkpoint file names.

    Returns:
        ``[loss_history, train_acc_history, val_acc_history, test_acc_history]``,
        one entry per epoch — the same layout ``expTrain`` returns, so the
        per-fold logs of both models can be stored and compared.
    """
    import os

    # torch.save does not create missing directories, so make sure it exists.
    checkpoint_dir = "model-history/GIN-Mutagenicity/"
    os.makedirs(checkpoint_dir, exist_ok=True)

    num_hidden_layer = 128  # hidden width handed to Base
    base = Base(dataset, num_hidden_layer)
    loss_history = []
    train_acc_history = []
    val_acc_history = []
    test_acc_history = []

    # early stop (currently disabled, see the commented block in the loop)
    early_stopping_patience = 20
    best_val_score = -float("inf")
    epochs_without_improvement = 0
    best_state = None

    # Train
    for epoch_idx in range(epoch):

        loss = round(train_base(base, train_loader).item(), 5)
        train_acc = round(test_base(base, train_loader), 5)
        val_acc = round(test_base(base, val_loader), 5)
        # Test accuracy is only logged; model selection uses val_acc.
        test_acc = test_base(base, test_loader)

        loss_history.append(loss)
        train_acc_history.append(train_acc)
        val_acc_history.append(val_acc)
        test_acc_history.append(test_acc)

        print(f'epoch {epoch_idx}; loss: {loss}; train_acc: {train_acc}; val_acc: {val_acc}; test: {round(test_acc, 2)}')

        torch.save(base.state_dict(), checkpoint_dir+str(fold)+"-e"+str(epoch_idx)+".base_best_model-gin_data-mutag.pth")

        if (val_acc > best_val_score):
            best_val_score = val_acc
            epochs_without_improvement = 0

            print('best found, save model')
            # save model
            torch.save(base.state_dict(), checkpoint_dir+str(fold)+".base_best_model-gin_data-mutag.pth")
            best_state = copy.deepcopy(base.state_dict())
        # else:
        #     epochs_without_improvement += 1
        #     if (epochs_without_improvement >= early_stopping_patience):
        #         print('early stop triggered')
        #         break

    histories = [loss_history, train_acc_history, val_acc_history, test_acc_history]

    # With epoch == 0 nothing was trained, so there is no best state to test.
    if best_state is None:
        print('no epoch was run, skip final test')
        return histories

    # Create a new instance of the model for testing
    best_model = Base(dataset, num_hidden_layer)
    best_model.load_state_dict(best_state)

    # Test
    test = test_base(best_model, test_loader)
    print(f'Accuracy: {test}')

    return histories


#### 10-fold cross-validation

In [194]:
from sklearn.model_selection import KFold

In [195]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GraphSAGE
from torch_geometric.nn import GraphConv
from torch_geometric.nn import GINConv
# from torch_geometric.nn import GINConv
from torch.nn import Linear, Sequential, ReLU, BatchNorm1d

from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import global_add_pool
import torch.nn.functional as F

In [196]:
class Base(torch.nn.Module):
    """Baseline GIN classifier: two GIN layers with batch norm, sum readout, MLP head."""

    @staticmethod
    def _mlp(in_channels, hidden_channels):
        """Two-layer perceptron used inside each GIN convolution."""
        return Sequential(
            Linear(in_channels, hidden_channels),
            ReLU(),
            Linear(hidden_channels, hidden_channels),
        )

    def __init__(self, dataset, hidden_channels):
        super().__init__()

        # Fixed seed so the weight initialisation is the same for every instance
        torch.manual_seed(42)

        # Layers are created in a fixed order: conv1, bn1, conv2, bn2, lin, lin2
        self.conv1 = GINConv(self._mlp(dataset.num_node_features, hidden_channels))
        self.bn1 = BatchNorm1d(hidden_channels)

        self.conv2 = GINConv(self._mlp(hidden_channels, hidden_channels))
        self.bn2 = BatchNorm1d(hidden_channels)

        # Classification head
        self.lin = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        """Return ``(graph_embedding, logits)`` for a batch of graphs."""
        # Node embeddings: conv -> ReLU -> batch norm, twice
        node_repr = self.bn1(self.conv1(x, edge_index).relu())
        node_repr = self.bn2(self.conv2(node_repr, edge_index).relu())

        # Sum readout: one vector per graph
        embedding = global_add_pool(node_repr, batch)

        hidden = self.lin(embedding).relu()
        hidden = F.dropout(hidden, p=0.3, training=self.training)
        return embedding, self.lin2(hidden)

In [197]:
# 80 % of the dataset is used for cross-validation, the remaining 20 % is held out for testing
train_dataset, test_dataset = (
    dataset[:round(0.8 * len(dataset))],
    dataset[round(0.8 * len(dataset)):],
)
print(train_dataset.y, len(train_dataset.y), test_dataset.y, sep="\n")

tensor([1, 1, 0,  ..., 0, 1, 1])
3470
tensor([1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1,
        0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
        0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1,
        0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1,
        0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0,
        0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0,
        1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1,
        0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1,
        0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1,
  

In [199]:
# 
# Bare references to both splits; they produce no output here
train_dataset
test_dataset
k = 10
batch_size = 128

# Seeded k-fold split over the positions of train_dataset
splits = KFold(n_splits=k, shuffle=True, random_state=42)
k_counter = 0
fold_logs = {}

for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(train_dataset)))):
    print(f'Fold {fold}/{k}')

    # Materialise the graphs of this fold
    fold_train = [train_dataset[key] for key in train_idx]
    fold_val = [train_dataset[key] for key in val_idx]

    print('ftrain', train_idx[0:20])
    tr, vd, ts = (
        DataLoader(graphs, batch_size=batch_size, shuffle=True)
        for graphs in (fold_train, fold_val, test_dataset)
    )

    print("=== Base model vs Experiment ===")
    # Only the experiment model is trained here; the base-model run is disabled:
    # print("=== Base model ===")
    # fold_logs[fold] = baseTrain(tr, vd, ts, 50, fold=fold)
    print("=== Experiment model ===")
    fold_logs[fold] = expTrain(train_loader=tr, val_loader=vd, test_loader=ts, epoch=50, fold=fold)
    k_counter += 1

Fold 0/10
ftrain [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 15 16 18 19 20 21]
=== Base model vs Experiment ===
=== Experiment model ===
process training


In [26]:
print(fold_logs)

{0: None, 1: None, 2: None, 3: None, 4: None, 5: None, 6: None, 7: None, 8: None, 9: None}


In [80]:
# Rebuild an Experiment model and load the saved "best" weights of fold 6.
# NOTE(review): this reads from "model-history/GIN-MUTAG/", whereas expTrain in this
# notebook writes its checkpoints under "model-history/GIN-Mutagenicity/" — confirm
# that this is the intended checkpoint.
e = Experiment(dataset=dataset, hidden_channels=128)
e.load_state_dict(torch.load("model-history/GIN-MUTAG/6.experiment_best_model-gin_data-mutag.pth"))

<All keys matched successfully>

In [51]:
nodes = [0, 2, 3, 13, 1, 4, 5, 6, 7, 8, 14, 9, 20, 10, 11, 15, 16, 17, 18, 19]

In [81]:
list(set([*list(set(train_dataset[11].edge_index[0].tolist())), *list(set(train_dataset[11].edge_index[1].tolist()))]))


[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

In [93]:
nodes2 = [[0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 6, 6, 6, 7, 7, 7, 7, 8, 8, 9, 9, 10, 11, 11, 13, 14, 15, 16, 17, 18, 19, 20], 
[2, 3, 13, 0, 4, 5, 0, 6, 0, 7, 8, 14, 1, 6, 9, 1, 20, 2, 4, 10, 3, 11, 15, 16, 3, 17, 4, 18, 6, 7, 19, 0, 3, 7, 7, 8, 9, 11, 5]]

In [100]:
list(set([*list(set(nodes2[0])), *list(set(nodes2[1]))]))

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20]

In [114]:
data = {0: {0: 0.0, 1: 0.0, 2: 5.9, 3: 14.3, 4: 5.9, 5: 14.3, 6: 11.7, 7: 14.3, 8: 16.4, 20: 20.2, 9: 11.7, 10: 14.3, 11: 16.4, 21: 20.2, 12: 20.1, 22: 20.2, 23: 20.2, 24: 20.2, 13: 20.1, 25: 20.2, 26: 20.2, 27: 20.2, 14: 23.8, 15: 23.8, 16: 20.2, 17: 20.2, 18: 20.2, 19: 20.2}, 1: {0: 0.0, 1: 0.0, 2: 5.9, 3: 14.3, 4: 5.9, 5: 14.3, 6: 11.7, 7: 14.3, 8: 16.4, 20: 20.2, 9: 11.7, 10: 14.3, 11: 16.4, 21: 20.2, 12: 20.1, 22: 20.2, 23: 20.2, 24: 20.2, 13: 20.1, 25: 20.2, 26: 20.2, 27: 20.2, 14: 23.8, 15: 23.8, 16: 20.2, 17: 20.2, 18: 20.2, 19: 20.2}, 2: {0: 5.9, 1: 5.9, 2: 0.0, 3: 10.2, 4: 0.0, 5: 10.2, 6: 9.1, 7: 10.2, 8: 12.4, 20: 16.7, 9: 9.1, 10: 10.2, 11: 12.4, 21: 16.7, 12: 17.8, 22: 16.7, 23: 16.7, 24: 16.7, 13: 17.8, 25: 16.7, 26: 16.7, 27: 16.7, 14: 23.2, 15: 23.2, 16: 18.2, 17: 18.2, 18: 18.2, 19: 18.2}, 3: {0: 14.3, 1: 14.3, 2: 10.2, 3: 0.0, 4: 10.2, 5: 0.0, 6: 8.8, 7: 0.0, 8: 7.6, 20: 10.2, 9: 8.8, 10: 0.0, 11: 7.6, 21: 10.2, 12: 16.1, 22: 10.2, 23: 10.2, 24: 10.2, 13: 16.1, 25: 10.2, 26: 10.2, 27: 10.2, 14: 22.6, 15: 22.6, 16: 15.7, 17: 15.7, 18: 15.7, 19: 15.7}, 4: {0: 5.9, 1: 5.9, 2: 0.0, 3: 10.2, 4: 0.0, 5: 10.2, 6: 9.1, 7: 10.2, 8: 12.4, 20: 16.7, 9: 9.1, 10: 10.2, 11: 12.4, 21: 16.7, 12: 17.8, 22: 16.7, 23: 16.7, 24: 16.7, 13: 17.8, 25: 16.7, 26: 16.7, 27: 16.7, 14: 23.2, 15: 23.2, 16: 18.2, 17: 18.2, 18: 18.2, 19: 18.2}, 5: {0: 14.3, 1: 14.3, 2: 10.2, 3: 0.0, 4: 10.2, 5: 0.0, 6: 8.8, 7: 0.0, 8: 7.6, 20: 10.2, 9: 8.8, 10: 0.0, 11: 7.6, 21: 10.2, 12: 16.1, 22: 10.2, 23: 10.2, 24: 10.2, 13: 16.1, 25: 10.2, 26: 10.2, 27: 10.2, 14: 22.6, 15: 22.6, 16: 15.7, 17: 15.7, 18: 15.7, 19: 15.7}, 6: {0: 11.7, 1: 11.7, 2: 9.1, 3: 8.8, 4: 9.1, 5: 8.8, 6: 0.0, 7: 8.8, 8: 7.8, 20: 10.6, 9: 0.0, 10: 8.8, 11: 7.8, 21: 10.6, 12: 15.4, 22: 10.6, 23: 10.6, 24: 10.6, 13: 15.4, 25: 10.6, 26: 10.6, 27: 10.6, 14: 21.2, 15: 21.2, 16: 14.4, 17: 14.4, 18: 14.4, 19: 14.4}, 7: {0: 14.3, 1: 14.3, 2: 10.2, 3: 0.0, 4: 10.2, 5: 0.0, 6: 8.8, 7: 0.0, 8: 7.6, 20: 10.2, 9: 8.8, 10: 0.0, 11: 
7.6, 21: 10.2, 12: 16.1, 22: 10.2, 23: 10.2, 24: 10.2, 13: 16.1, 25: 10.2, 26: 10.2, 27: 10.2, 14: 22.6, 15: 22.6, 16: 15.7, 17: 15.7, 18: 15.7, 19: 15.7}, 8: {0: 16.4, 1: 16.4, 2: 12.4, 3: 7.6, 4: 12.4, 5: 7.6, 6: 7.8, 7: 7.6, 8: 0.0, 20: 8.9, 9: 7.8, 10: 7.6, 11: 0.0, 21: 8.9, 12: 16.7, 22: 8.9, 23: 8.9, 24: 8.9, 13: 16.7, 25: 8.9, 26: 8.9, 27: 8.9, 14: 22.6, 15: 22.6, 16: 15.2, 17: 15.2, 18: 15.2, 19: 15.2}, 20: {0: 20.2, 1: 20.2, 2: 16.7, 3: 10.2, 4: 16.7, 5: 10.2, 6: 10.6, 7: 10.2, 8: 8.9, 20: 0.0, 9: 10.6, 10: 10.2, 11: 8.9, 21: 0.0, 12: 16.5, 22: 0.0, 23: 0.0, 24: 0.0, 13: 16.5, 25: 0.0, 26: 0.0, 27: 0.0, 14: 22.7, 15: 22.7, 16: 13.9, 17: 13.9, 18: 13.9, 19: 13.9}, 9: {0: 11.7, 1: 11.7, 2: 9.1, 3: 8.8, 4: 9.1, 5: 8.8, 6: 0.0, 7: 8.8, 8: 7.8, 20: 10.6, 9: 0.0, 10: 8.8, 11: 7.8, 21: 10.6, 12: 15.4, 22: 10.6, 23: 10.6, 24: 10.6, 13: 15.4, 25: 10.6, 26: 10.6, 27: 10.6, 14: 21.2, 15: 21.2, 16: 14.4, 17: 14.4, 18: 14.4, 19: 14.4}, 10: {0: 14.3, 1: 14.3, 2: 10.2, 3: 0.0, 4: 10.2, 5: 0.0, 6: 8.8, 7: 0.0, 8: 7.6, 20: 10.2, 9: 8.8, 10: 0.0, 11: 7.6, 21: 10.2, 12: 16.1, 22: 10.2, 23: 10.2, 24: 10.2, 13: 16.1, 25: 10.2, 26: 10.2, 27: 10.2, 14: 22.6, 15: 22.6, 16: 15.7, 17: 15.7, 18: 15.7, 19: 15.7}, 11: {0: 16.4, 1: 16.4, 2: 12.4, 3: 7.6, 4: 12.4, 5: 7.6, 6: 7.8, 7: 7.6, 8: 0.0, 20: 8.9, 9: 7.8, 10: 7.6, 11: 0.0, 21: 8.9, 12: 16.7, 22: 8.9, 23: 8.9, 24: 8.9, 13: 16.7, 25: 8.9, 26: 8.9, 27: 8.9, 14: 22.6, 15: 22.6, 16: 15.2, 17: 15.2, 18: 15.2, 19: 15.2}, 21: {0: 20.2, 1: 20.2, 2: 16.7, 3: 10.2, 4: 16.7, 5: 10.2, 6: 10.6, 7: 10.2, 8: 8.9, 20: 0.0, 9: 10.6, 10: 10.2, 11: 8.9, 21: 0.0, 12: 16.5, 22: 0.0, 23: 0.0, 24: 0.0, 13: 16.5, 25: 0.0, 26: 0.0, 27: 0.0, 14: 22.7, 15: 22.7, 16: 13.9, 17: 13.9, 18: 13.9, 19: 13.9}, 12: {0: 20.1, 1: 20.1, 2: 17.8, 3: 16.1, 4: 17.8, 5: 16.1, 6: 15.4, 7: 16.1, 8: 16.7, 20: 16.5, 9: 15.4, 10: 16.1, 11: 16.7, 21: 16.5, 12: 0.0, 22: 16.5, 23: 16.5, 24: 16.5, 13: 0.0, 25: 16.5, 26: 16.5, 27: 16.5, 14: 22.7, 15: 22.7, 16: 17.2, 17: 17.2, 18: 
17.2, 19: 17.2}, 22: {0: 20.2, 1: 20.2, 2: 16.7, 3: 10.2, 4: 16.7, 5: 10.2, 6: 10.6, 7: 10.2, 8: 8.9, 20: 0.0, 9: 10.6, 10: 10.2, 11: 8.9, 21: 0.0, 12: 16.5, 22: 0.0, 23: 0.0, 24: 0.0, 13: 16.5, 25: 0.0, 26: 0.0, 27: 0.0, 14: 22.7, 15: 22.7, 16: 13.9, 17: 13.9, 18: 13.9, 19: 13.9}, 23: {0: 20.2, 1: 20.2, 2: 16.7, 3: 10.2, 4: 16.7, 5: 10.2, 6: 10.6, 7: 10.2, 8: 8.9, 20: 0.0, 9: 10.6, 10: 10.2, 11: 8.9, 21: 0.0, 12: 16.5, 22: 0.0, 23: 0.0, 24: 0.0, 13: 16.5, 25: 0.0, 26: 0.0, 27: 0.0, 14: 22.7, 15: 22.7, 16: 13.9, 17: 13.9, 18: 13.9, 19: 13.9}, 24: {0: 20.2, 1: 20.2, 2: 16.7, 3: 10.2, 4: 16.7, 5: 10.2, 6: 10.6, 7: 10.2, 8: 8.9, 20: 0.0, 9: 10.6, 10: 10.2, 11: 8.9, 21: 0.0, 12: 16.5, 22: 0.0, 23: 0.0, 24: 0.0, 13: 16.5, 25: 0.0, 26: 0.0, 27: 0.0, 14: 22.7, 15: 22.7, 16: 13.9, 17: 13.9, 18: 13.9, 19: 13.9}, 13: {0: 20.1, 1: 20.1, 2: 17.8, 3: 16.1, 4: 17.8, 5: 16.1, 6: 15.4, 7: 16.1, 8: 16.7, 20: 16.5, 9: 15.4, 10: 16.1, 11: 16.7, 21: 16.5, 12: 0.0, 22: 16.5, 23: 16.5, 24: 16.5, 13: 0.0, 25: 16.5, 26: 16.5, 27: 16.5, 14: 22.7, 15: 22.7, 16: 17.2, 17: 17.2, 18: 17.2, 19: 17.2}, 25: {0: 20.2, 1: 20.2, 2: 16.7, 3: 10.2, 4: 16.7, 5: 10.2, 6: 10.6, 7: 10.2, 8: 8.9, 20: 0.0, 9: 10.6, 10: 10.2, 11: 8.9, 21: 0.0, 12: 16.5, 22: 0.0, 23: 0.0, 24: 0.0, 13: 16.5, 25: 0.0, 26: 0.0, 27: 0.0, 14: 22.7, 15: 22.7, 16: 13.9, 17: 13.9, 18: 13.9, 19: 13.9}, 26: {0: 20.2, 1: 20.2, 2: 16.7, 3: 10.2, 4: 16.7, 5: 10.2, 6: 10.6, 7: 10.2, 8: 8.9, 20: 0.0, 9: 10.6, 10: 10.2, 11: 8.9, 21: 0.0, 12: 16.5, 22: 0.0, 23: 0.0, 24: 0.0, 13: 16.5, 25: 0.0, 26: 0.0, 27: 0.0, 14: 22.7, 15: 22.7, 16: 13.9, 17: 13.9, 18: 13.9, 19: 13.9}, 27: {0: 20.2, 1: 20.2, 2: 16.7, 3: 10.2, 4: 16.7, 5: 10.2, 6: 10.6, 7: 10.2, 8: 8.9, 20: 0.0, 9: 10.6, 10: 10.2, 11: 8.9, 21: 0.0, 12: 16.5, 22: 0.0, 23: 0.0, 24: 0.0, 13: 16.5, 25: 0.0, 26: 0.0, 27: 0.0, 14: 22.7, 15: 22.7, 16: 13.9, 17: 13.9, 18: 13.9, 19: 13.9}, 14: {0: 23.8, 1: 23.8, 2: 23.2, 3: 22.6, 4: 23.2, 5: 22.6, 6: 21.2, 7: 22.6, 8: 22.6, 20: 22.7, 9: 21.2, 10: 
22.6, 11: 22.6, 21: 22.7, 12: 22.7, 22: 22.7, 23: 22.7, 24: 22.7, 13: 22.7, 25: 22.7, 26: 22.7, 27: 22.7, 14: 0.0, 15: 0.0, 16: 11.3, 17: 11.3, 18: 11.3, 19: 11.3}, 15: {0: 23.8, 1: 23.8, 2: 23.2, 3: 22.6, 4: 23.2, 5: 22.6, 6: 21.2, 7: 22.6, 8: 22.6, 20: 22.7, 9: 21.2, 10: 22.6, 11: 22.6, 21: 22.7, 12: 22.7, 22: 22.7, 23: 22.7, 24: 22.7, 13: 22.7, 25: 22.7, 26: 22.7, 27: 22.7, 14: 0.0, 15: 0.0, 16: 11.3, 17: 11.3, 18: 11.3, 19: 11.3}, 16: {0: 20.2, 1: 20.2, 2: 18.2, 3: 15.7, 4: 18.2, 5: 15.7, 6: 14.4, 7: 15.7, 8: 15.2, 20: 13.9, 9: 14.4, 10: 15.7, 11: 15.2, 21: 13.9, 12: 17.2, 22: 13.9, 23: 13.9, 24: 13.9, 13: 17.2, 25: 13.9, 26: 13.9, 27: 13.9, 14: 11.3, 15: 11.3, 16: 0.0, 17: 0.0, 18: 0.0, 19: 0.0}, 17: {0: 20.2, 1: 20.2, 2: 18.2, 3: 15.7, 4: 18.2, 5: 15.7, 6: 14.4, 7: 15.7, 8: 15.2, 20: 13.9, 9: 14.4, 10: 15.7, 11: 15.2, 21: 13.9, 12: 17.2, 22: 13.9, 23: 13.9, 24: 13.9, 13: 17.2, 25: 13.9, 26: 13.9, 27: 13.9, 14: 11.3, 15: 11.3, 16: 0.0, 17: 0.0, 18: 0.0, 19: 0.0}, 18: {0: 20.2, 1: 20.2, 2: 18.2, 3: 15.7, 4: 18.2, 5: 15.7, 6: 14.4, 7: 15.7, 8: 15.2, 20: 13.9, 9: 14.4, 10: 15.7, 11: 15.2, 21: 13.9, 12: 17.2, 22: 13.9, 23: 13.9, 24: 13.9, 13: 17.2, 25: 13.9, 26: 13.9, 27: 13.9, 14: 11.3, 15: 11.3, 16: 0.0, 17: 0.0, 18: 0.0, 19: 0.0}, 19: {0: 20.2, 1: 20.2, 2: 18.2, 3: 15.7, 4: 18.2, 5: 15.7, 6: 14.4, 7: 15.7, 8: 15.2, 20: 13.9, 9: 14.4, 10: 15.7, 11: 15.2, 21: 13.9, 12: 17.2, 22: 13.9, 23: 13.9, 24: 13.9, 13: 17.2, 25: 13.9, 26: 13.9, 27: 13.9, 14: 11.3, 15: 11.3, 16: 0.0, 17: 0.0, 18: 0.0, 19: 0.0}}

In [116]:
data[0][1]

0.0

In [130]:
# Print, for the first 20 graphs of train_dataset, the node ids that occur in edge_index
i = 0
for dt in train_dataset: 
    # stop once 20 graphs have been printed
    if i == 20:
        break
    i+= 1
    # distinct node ids appearing as source (row 0) or target (row 1) of an edge
    edge = list(set([*list(set(dt.edge_index[0].tolist())), *list(set(dt.edge_index[1].tolist()))]))
    print(edge)
    # break

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 