In [None]:
import os
import os.path as osp
import time
from os import WIFCONTINUED

import numpy as np
import sklearn
import snntorch as snn
import torch
import torch.nn.functional as F
import torch_geometric
from sklearn.model_selection import train_test_split
from torch import nn
from torch.nn import ModuleList, Embedding
from torch.nn import Sequential, ReLU, Linear, GRUCell
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import Data, DataLoader, DataListLoader
from torch_geometric.nn import PNAConv, BatchNorm, global_mean_pool, DataParallel
from torch_geometric.utils import degree

In [None]:
# Set the parameters here
# All values below are module-level settings read by the cells that follow.

# --- data ---
data_path = '/DATA/graphspiking/graph_spiking/Porous/'
num_data = 2000          # number of graphs parsed from the text files
input_dim = 2            # node-feature width (rows per graph in node_features.txt)
scale_factor = 1e-6      # multiplier applied to every node label
max_degree = 4           # initial length of the in-degree histogram

# --- model ---
hidden_dim = 50
num_layer = 14
timesteps = 1            # iterations of the spiking loop in the forward pass

# --- optimisation ---
batch_size = 3
epoch = 250    # changed from 500


# Sanity checks on the settings above.
if num_layer < 1:
    raise Exception("Sorry, the number of layer is not enough")
if hidden_dim % 5:
    raise Exception("Sorry, not available hidden dimension, need to be multiple of 5")

In [None]:
class PNANet(torch.nn.Module):
    """Stack of spiking-gated PNA + GRU layers followed by a PNA read-out.

    Every layer ``k`` (1-based) owns the sub-modules ``lif{k}_1``,
    ``lif{k}_2``, ``lif{k}_3``, ``conv{k}``, ``gru{k}`` and ``batch_norm{k}``.
    They are created in a loop but registered under exactly these attribute
    names and in this order, so parameter ordering and ``state_dict`` keys are
    the same as with hand-written per-layer attributes.

    The class reads the module-level settings ``input_dim``, ``hidden_dim``,
    ``timesteps`` and the degree histogram ``deg``; all of them must exist
    before the model is instantiated / called.

    Args:
        num_layers: Number of PNA/GRU layers to build (default 14).

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, num_layers=14):
        super(PNANet, self).__init__()

        if num_layers < 1:
            raise ValueError("num_layers must be at least 1")
        self.num_layers = num_layers

        aggregators = ['mean', 'min', 'max', 'std']
        scalers = ['identity', 'amplification', 'attenuation']

        # Never populated or used by forward(); kept only so the attributes
        # continue to exist for any external code that touches them.
        self.convs = ModuleList()
        self.batch_norms  = ModuleList()
        self.grus = ModuleList()

        # Defining the layers
        for layer in range(1, num_layers + 1):
            is_first = layer == 1
            # Only the first layer consumes the raw node features.
            in_dim = input_dim if is_first else hidden_dim

            # Neuron widths for: (1) conv output, (2) layer input, (3) batch-norm output.
            widths = (hidden_dim, in_dim, hidden_dim)
            # Draw order (all betas, then all thresholds) is kept fixed so a
            # seeded run produces the same initial values.
            betas = [torch.rand(width) for width in widths]
            thresholds = [torch.rand(width) for width in widths]
            for slot, (beta, thr) in enumerate(zip(betas, thresholds), start=1):
                setattr(
                    self,
                    f"lif{layer}_{slot}",
                    snn.Leaky(beta=beta, threshold=thr, learn_beta=True,
                              learn_threshold=True, reset_mechanism='zero'),
                )

            setattr(
                self,
                f"conv{layer}",
                PNAConv(in_channels=in_dim, out_channels=hidden_dim,
                        aggregators=aggregators, scalers=scalers, deg=deg,
                        # first layer uses a single tower, the rest use five
                        towers=1 if is_first else 5,
                        pre_layers=1, post_layers=1, divide_input=False),
            )
            setattr(self, f"gru{layer}", GRUCell(in_dim, hidden_dim))
            setattr(self, f"batch_norm{layer}", BatchNorm(hidden_dim))

        self.readout = PNAConv(in_channels=hidden_dim, out_channels=1, aggregators=aggregators, scalers=scalers, deg=deg, towers=1, pre_layers=1, post_layers=1, divide_input=False)

    def forward(self, data):
        """Run the network on one graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            Tuple ``(out, s1_sum, s2_sum, s3_sum)`` where ``out`` is the
            read-out convolution's output and each ``s*_sum`` is a 1-D tensor
            with one entry per layer holding ``sum(spikes) / spikes.numel()``
            accumulated over all timesteps for, respectively, the spikes on
            the conv output, on the layer input and on the batch-norm output.
        """
        x, edge_index = data.x, data.edge_index

        # One membrane state per LIF neuron group, each taken from its own
        # neuron (previously the third state of layer 10 was taken from
        # lif10_2 by mistake).
        mems = [
            [getattr(self, f"lif{layer}_{slot}").init_leaky() for slot in (1, 2, 3)]
            for layer in range(1, self.num_layers + 1)
        ]

        # Allocate the statistics next to the input instead of relying on a
        # module-level `device` variable.
        s1_sum = torch.zeros(self.num_layers, device=x.device)
        s2_sum = torch.zeros(self.num_layers, device=x.device)
        s3_sum = torch.zeros(self.num_layers, device=x.device)

        # NOTE(review): `x` is carried over between timesteps, so for
        # timesteps > 1 the first conv receives the last layer's output
        # rather than the raw features - confirm this is intended.
        for _ in range(timesteps):
            for idx in range(self.num_layers):
                layer = idx + 1
                conv = getattr(self, f"conv{layer}")
                gru = getattr(self, f"gru{layer}")
                batch_norm = getattr(self, f"batch_norm{layer}")
                lif_conv = getattr(self, f"lif{layer}_1")
                lif_input = getattr(self, f"lif{layer}_2")
                lif_norm = getattr(self, f"lif{layer}_3")
                state = mems[idx]

                y = conv(x, edge_index)

                # spiking on y: the spikes gate the conv output element-wise
                spk_in1, state[0] = lif_conv(y, state[0])
                s1_sum[idx] += torch.sum(spk_in1)/spk_in1.numel()
                y = spk_in1*y

                # spiking on x: the spikes gate the layer input element-wise
                spk_in2, state[1] = lif_input(x, state[1])
                s2_sum[idx] += torch.sum(spk_in2)/spk_in2.numel()
                x = spk_in2*x

                # gated input is the GRU input, gated conv output its hidden state
                x = gru(x, y)

                # spiking on the batch-normalised GRU output (used in place of a ReLU)
                z = batch_norm(x)
                spk_in3, state[2] = lif_norm(z, state[2])
                s3_sum[idx] += torch.sum(spk_in3)/spk_in3.numel()
                x = spk_in3*z

        x = self.readout(x, edge_index)

        return x, s1_sum, s2_sum, s3_sum

In [None]:
def train(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Each element yielded by ``dataloader`` is treated as a list of graph
    objects. The model is called once per graph, the predictions of the list
    are concatenated, and a single MSE loss / optimiser step is performed per
    list.

    Args:
        model: Callable module returning ``(pred, s1, s2, s3)`` for one graph.
        dataloader: Iterable of lists of graph objects exposing ``y`` and ``to``.
        optimizer: Optimiser updating ``model``'s parameters.
        device: Device the labels and graphs are moved to.

    Returns:
        Tuple ``(mean_loss, s1_t, s2_t, s3_t)``. ``mean_loss`` is the mean of
        the per-list losses (NaN if the loader yielded nothing). The three
        statistics are those returned by the model for the LAST graph
        processed, or ``None`` if no graph was processed.
    """
    batch_loss = []
    # Defined up front so an empty loader does not raise UnboundLocalError.
    s1_t = s2_t = s3_t = None
    model.train()

    for batch in dataloader:
        label = torch.cat([data.y for data in batch]).to(device)

        # The model takes a single graph, so the list is processed item by item.
        pred_list = []
        for data in batch:
            pred, s1_t, s2_t, s3_t = model(data.to(device))
            pred_list.append(pred)

        pred_batch = torch.cat(pred_list)

        loss = F.mse_loss(pred_batch.squeeze(), label.squeeze())

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss.append(loss.item())

    # np.mean of an empty array warns and yields NaN; return NaN explicitly.
    mean_loss = np.mean(np.array(batch_loss)) if batch_loss else float("nan")
    return mean_loss, s1_t, s2_t, s3_t

In [None]:
def validate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Each element yielded by ``dataloader`` is treated as a list of graph
    objects. The model is switched to eval mode, called once per graph with
    gradients disabled, and one MSE loss is computed per list.

    Args:
        model: Callable module returning ``(pred, s1, s2, s3)`` for one graph.
        dataloader: Iterable of lists of graph objects exposing ``y`` and ``to``.
        device: Device the labels and graphs are moved to.

    Returns:
        Tuple ``(mean_loss, s1_v, s2_v, s3_v)``. ``mean_loss`` is the mean of
        the per-list losses (NaN if the loader yielded nothing). The three
        statistics are those returned by the model for the LAST graph
        processed, or ``None`` if no graph was processed.
    """
    val_loss = []
    # Defined up front so an empty loader does not raise UnboundLocalError.
    s1_v = s2_v = s3_v = None
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            label = torch.cat([data.y for data in batch]).to(device)

            # The model takes a single graph, so the list is processed item by item.
            pred_list = []
            for data in batch:
                pred, s1_v, s2_v, s3_v = model(data.to(device))
                pred_list.append(pred)

            pred_batch = torch.cat(pred_list)

            loss = F.mse_loss(pred_batch.squeeze(), label.squeeze())
            val_loss.append(loss.item())

    # np.mean of an empty array warns and yields NaN; return NaN explicitly.
    mean_loss = np.mean(np.array(val_loss)) if val_loss else float("nan")
    return mean_loss, s1_v, s2_v, s3_v

In [None]:
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Read relevant data files (closed as soon as their lines are read)
    # ------------------------------------------------------------------
    with open(data_path + "edge.txt", "r") as f1:
        lines1 = f1.readlines()
    with open(data_path + "node_features.txt", "r") as f2:
        lines2 = f2.readlines()
    with open(data_path + "node_labels_sxx.txt", "r") as f3:
        lines3 = f3.readlines()

    # The feature file stores `input_dim` consecutive rows per graph; only
    # widths 1-4 are accepted.
    if input_dim not in (1, 2, 3, 4):
        raise Exception("Sorry, not available input dimension")

    # ------------------------------------------------------------------
    # Data preprocessing
    # NOTE(review): `data_list` is not used further down in this cell (the
    # loaders read pre-saved .pt files instead); it is still built in case
    # another cell relies on it - confirm whether it can be dropped.
    # ------------------------------------------------------------------
    data_list = []
    t0 = time.time()
    print("Number of data processed\ttime")
    ave = []
    for i in range(num_data):
        if i % 200 == 0:
            print(i, time.time() - t0)

        # Two rows per graph in the edge file; the first token of every row is skipped.
        node1 = [int(idx) for idx in lines1[2 * i].split()[1:]]
        node2 = [int(idx) for idx in lines1[2 * i + 1].split()[1:]]
        edge_index = torch.tensor([node1, node2], dtype=torch.long)

        # One row per feature channel; transpose to one row per node.
        feature_rows = [
            [float(idx) for idx in lines2[input_dim * i + k].split()[1:]]
            for k in range(input_dim)
        ]
        node_feature = [
            [row[j] for row in feature_rows]
            for j in range(len(feature_rows[0]))
        ]
        x = torch.tensor(node_feature, dtype=torch.float)

        node_label = [float(idx) * scale_factor for idx in lines3[i].split()[1:]]
        y = torch.tensor(node_label, dtype=torch.float)

        data = Data(x=x, edge_index=edge_index, y=y)
        data_list.append(data)

    # `ave` is never filled, so the mean is NaN; computed without triggering
    # numpy's empty-slice warning. Kept so the name stays defined.
    mean_value = np.mean(np.array(ave)) if ave else float("nan")

    # ------------------------------------------------------------------
    # Pre-split datasets
    # weights_only=False unpickles arbitrary objects: only load trusted files.
    # ------------------------------------------------------------------
    train_data = torch.load("train_dataset.pt", weights_only=False)
    test_data = torch.load("test_dataset.pt", weights_only=False)
    val_data = torch.load("val_dataset.pt", weights_only=False)

    # NOTE(review): `batch_size` from the config cell is not passed here, so
    # the loaders use their default batch size - confirm this is intended.
    train_loader = DataListLoader(train_data)
    test_loader = DataListLoader(test_data)
    val_loader = DataListLoader(val_data)

    # ------------------------------------------------------------------
    # In-degree histogram of the training graphs (consumed by PNANet)
    # ------------------------------------------------------------------
    deg = torch.zeros(max_degree, dtype=torch.long)
    for data in train_data:
        d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
        hist = torch.bincount(d, minlength=deg.numel())
        # Grow the histogram if a node exceeds the configured degree range,
        # instead of failing on a shape mismatch.
        if hist.numel() > deg.numel():
            deg = torch.cat([deg, deg.new_zeros(hist.numel() - deg.numel())])
        deg += hist

    # ------------------------------------------------------------------
    # Model, optimiser, scheduler
    # ------------------------------------------------------------------
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch.cuda.empty_cache()
    model = PNANet().to(device)
    # model = DataParallel(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # min_lr must be non-negative to act as a floor (was -1e-5).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=20, min_lr=1e-5, verbose=True)
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Model architecture:")
    print(model)
    print("The number of trainable parameters is:{}".format(params))

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    path = './out/'
    # Make sure the checkpoint directory exists before the first save.
    os.makedirs(path, exist_ok=True)
    print("epoch", "train loss", "validation loss")

    val_loss_curve = []
    train_loss_curve = []

    # A separate loop variable keeps the configured `epoch` count intact, so
    # the cell can be re-run without silently changing the number of epochs.
    num_epochs = epoch
    for epoch_idx in range(num_epochs):

        # Train the model on the training data
        epoch_loss, s1_t, s2_t, s3_t = train(model, train_loader, optimizer, device)

        # Evaluate on the validation data
        val_loss, s1_v, s2_v, s3_v = validate(model, val_loader, device)

        # Record train and loss performance
        train_loss_curve.append(epoch_loss)
        val_loss_curve.append(val_loss)

        # The learning rate scheduler tracks the validation loss
        scheduler.step(val_loss)

        # Checkpoint every 10 epochs
        if (epoch_idx + 1) % 10 == 0:
            torch.save({
                'epoch': epoch_idx,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
            },
            path + str(epoch_idx + 1) + ".pt")
        print(epoch_idx, epoch_loss, val_loss)
        print("s1_t : ",s1_t)
        print("s2_t : ",s2_t)
        print("s3_t : ",s3_t)
        print("s1_v : ",s1_v)
        print("s2_v : ",s2_v)
        print("s3_v : ",s3_v)