In [None]:
pip install torch_geometric

[0mNote: you may need to restart the kernel to use updated packages.


In [None]:
from os import WIFCONTINUED
import numpy as np
import os.path as osp
import time
import sklearn
from sklearn.model_selection import train_test_split
import torch
import torch_geometric
from torch import nn
from torch_geometric.data import Data, DataLoader, DataListLoader
from torch_geometric.utils import degree
import torch.nn.functional as F
from torch.nn import ModuleList, Embedding
from torch.nn import Sequential, ReLU, Linear, GRUCell
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn import PNAConv, BatchNorm, global_mean_pool, DataParallel
import argparse


In [None]:
# Arguments
# Hyper-parameters and paths are exposed as command-line style options so the
# same cell works when the notebook is exported and run as a script.
parser = argparse.ArgumentParser()
parser.add_argument('-b','--batch_size', default=2, type=int,
                    help='batch size')
parser.add_argument('-d','--data_path', default='/DATA/graphspiking/data/', type=str,
                    help='data path')
parser.add_argument('-i','--input_dim', default=2, type=int,
                    help='the dimension of coordinates (2D or 3D)')
parser.add_argument('-n','--num_data', default=2000, type=int,
                    help='the number of all data')
parser.add_argument('-l','--num_layer', default=14, type=int,
                    help='the number of PNAConv layers')
parser.add_argument('-v','--hidden_dim', default=50, type=int,
                    help='the hidden dimension of PNANet')
parser.add_argument('-m','--max_degree', default=4, type=int,
                    help='maximum degree of all nodes')
parser.add_argument('-e','--epoch', default=500, type=int,
                    help='number of epoch')
parser.add_argument('-s','--scale_factor', default=1e-6, type=float,
                    help='scale factor for node labels')
# NOTE(review): presumably absorbs the `-f <connection file>` argument that the
# Jupyter kernel launcher puts on sys.argv — kept so `args.f` still exists.
parser.add_argument('-f')
# parse_known_args() ignores argv entries this parser does not declare instead
# of terminating the kernel with SystemExit, so the cell does not depend on the
# exact way the interpreter was launched.
args, _unparsed_argv = parser.parse_known_args()

# The hidden PNAConv layers of PNANet are built with towers=5, hence the
# divisibility requirement on the hidden width.
if args.hidden_dim % 5 != 0:
    raise ValueError("Sorry, not available hidden dimension, need to be multiple of 5")
if args.num_layer < 1:
    raise ValueError("Sorry, the number of layer is not enough")

In [None]:
class PNANet(torch.nn.Module):
    """Node-level regressor: PNAConv + GRUCell + BatchNorm1d layers, then a PNAConv readout.

    The constructor takes no arguments. It reads the layer count and feature
    widths from the module-level ``args`` and the degree histogram from the
    module-level ``deg``, so both must be defined before instantiation.
    """

    def __init__(self):
        super().__init__()

        # Settings shared by every PNAConv in the network (hidden layers and readout).
        shared_pna_options = dict(
            aggregators=['mean', 'min', 'max', 'std'],
            scalers=['identity', 'amplification', 'attenuation'],
            deg=deg,
            pre_layers=1,
            post_layers=1,
            divide_input=False,
        )
        width = args.hidden_dim

        self.convs = ModuleList()
        self.batch_norms  = ModuleList()
        self.grus = ModuleList()

        for depth in range(args.num_layer):
            # Only the first layer consumes the raw node features; it also
            # uses a single tower, whereas every later layer uses five.
            is_input_layer = depth == 0
            in_width = args.input_dim if is_input_layer else width
            tower_count = 1 if is_input_layer else 5

            self.convs.append(
                PNAConv(in_channels=in_width, out_channels=width,
                        towers=tower_count, **shared_pna_options)
            )
            self.grus.append(nn.GRUCell(in_width, width))
            self.batch_norms.append(nn.BatchNorm1d(width))

        # Final convolution maps the hidden features to one value per node.
        self.readout = PNAConv(in_channels=width, out_channels=1,
                               towers=1, **shared_pna_options)

    def forward(self, data):
        """Return the readout output for ``data``.

        Args:
            data: object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            The tensor produced by the readout PNAConv.
        """
        features = data.x
        edges = data.edge_index
        for conv, gru, norm in zip(self.convs, self.grus, self.batch_norms):
            message = conv(features, edges)
            # The GRU cell receives the current features as its input and the
            # convolution output as its hidden state.
            features = F.relu(norm(gru(features, message)))
        return self.readout(features, edges)

In [None]:
#Train function
def train(model, dataloader, optimizer, device):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: module called as ``model(graph)`` for each graph in a batch.
        dataloader: iterable yielding lists of graph objects that expose ``y``
            and a ``to(device)`` method.
        optimizer: optimizer stepped once per batch.
        device: target passed to ``.to(...)`` for labels and graphs.

    Returns:
        ``np.mean`` of the per-batch MSE losses.
    """
    model.train()
    per_batch_losses = []

    for graphs in dataloader:
        targets = torch.cat([graph.y for graph in graphs]).to(device)

        # Each graph goes through the model on its own; the per-graph outputs
        # are concatenated afterwards to line up with the concatenated labels.
        outputs = [model(graph.to(device)) for graph in graphs]
        predictions = torch.cat(outputs)

        loss = F.mse_loss(predictions.squeeze(), targets.squeeze())

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        per_batch_losses.append(loss.item())

    return np.mean(np.array(per_batch_losses))

In [None]:
# Validation function
def validate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without gradients; return the mean batch loss.

    Args:
        model: module called as ``model(graph)`` for each graph in a batch.
        dataloader: iterable yielding lists of graph objects that expose ``y``
            and a ``to(device)`` method.
        device: target passed to ``.to(...)`` for labels and graphs.

    Returns:
        ``np.mean`` of the per-batch MSE losses.
    """
    model.eval()
    per_batch_losses = []

    with torch.no_grad():
        for graphs in dataloader:
            targets = torch.cat([graph.y for graph in graphs]).to(device)

            # Graphs are evaluated one by one and the outputs concatenated so
            # they line up with the concatenated labels.
            outputs = [model(graph.to(device)) for graph in graphs]
            predictions = torch.cat(outputs)

            batch_loss = F.mse_loss(predictions.squeeze(), targets.squeeze())
            per_batch_losses.append(batch_loss.item())

    return np.mean(np.array(per_batch_losses))

In [None]:
if __name__ == "__main__":
    # Read relevant data files. Context managers close each handle as soon as
    # its lines are in memory.
    with open(args.data_path + "edge.txt", "r") as f1:
        lines1 = f1.readlines()
    with open(args.data_path + "node_features.txt", "r") as f2:
        lines2 = f2.readlines()
    with open(args.data_path + "node_labels_sxx.txt", "r") as f3:
        lines3 = f3.readlines()

    # The feature file stores `input_dim` consecutive lines per graph (one per
    # coordinate axis); only 1 to 4 axes are supported.
    if args.input_dim not in (1, 2, 3, 4):
        raise Exception("Sorry, not available input dimension")

    # Data preprocessing
    num_data = args.num_data
    data_list = []
    t0 = time.time()
    print("Number of data processed\ttime")
    for i in range(num_data):
        if i % 200 == 0:
            print(i, time.time() - t0)

        # Two lines per graph in the edge file: the first becomes row 0 of
        # edge_index, the second row 1. The leading token of every data line
        # is skipped.
        node1 = [int(idx) for idx in lines1[2 * i].split()[1:]]
        node2 = [int(idx) for idx in lines1[2 * i + 1].split()[1:]]
        edge_index = torch.tensor([node1, node2], dtype=torch.long)

        # One list of floats per coordinate axis, then transposed so that each
        # node gets one row of `input_dim` features.
        coords = [
            [float(tok) for tok in lines2[args.input_dim * i + axis].split()[1:]]
            for axis in range(args.input_dim)
        ]
        node_feature = [[column[j] for column in coords] for j in range(len(coords[0]))]

        x = torch.tensor(node_feature, dtype=torch.float)
        # Labels are multiplied by the scale factor before training.
        node_label = [float(idx) * args.scale_factor for idx in lines3[i].split()[1:]]
        y = torch.tensor(node_label, dtype=torch.float)
        data = Data(x=x, edge_index=edge_index, y=y)
        data_list.append(data)

    # 27.5% of the graphs are held out for testing; 3.5% of the remainder is
    # used for validation. Fixed seeds keep the split reproducible.
    Train_data, Test_data = train_test_split(data_list, test_size = 0.275, random_state=42)
    Train_data, Val_data = train_test_split(Train_data, test_size = 0.035, random_state=42)

    batch_size = args.batch_size
    train_loader = DataListLoader(Train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataListLoader(Test_data, batch_size=batch_size, shuffle=True)
    val_loader = DataListLoader(Val_data, batch_size=batch_size, shuffle=True)

    # Degree histogram of the training graphs; PNANet reads this
    # module-level `deg` when it builds its PNAConv layers. The histogram
    # starts with `max_degree` bins and grows if a graph contains a node of
    # higher degree, so the accumulation cannot fail on a length mismatch.
    deg = torch.zeros(args.max_degree, dtype=torch.long)
    for data in Train_data:
        d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
        hist = torch.bincount(d, minlength=deg.numel())
        if hist.numel() > deg.numel():
            deg = torch.cat([deg, deg.new_zeros(hist.numel() - deg.numel())])
        deg += hist

    # Use the first GPU when one is available, otherwise fall back to the CPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch.cuda.empty_cache()
    model = PNANet().to(device)
    # model = DataParallel(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # min_lr must be a positive floor: the previous negative value never
    # clamped the learning rate, so it could keep halving towards zero.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=20, min_lr=1e-5, verbose=True)
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Model architecture:")
    print(model)
    print("The number of trainable parameters is:{}".format(params))


    path = '/DATA/graphspiking/ckpt_orig_2/'
    # Training
    print("epoch", "train loss", "validation loss")

    val_loss_curve = []
    train_loss_curve = []

    for epoch in range(args.epoch):

        # Train the model for one pass over the training data
        epoch_loss = train(model, train_loader, optimizer, device=device)

        # Evaluate the model on the validation data
        val_loss = validate(model, val_loader, device=device)


        # Record training and validation losses
        train_loss_curve.append(epoch_loss)
        val_loss_curve.append(val_loss)

        # The learning rate scheduler tracks the validation loss
        scheduler.step(val_loss)

        # Checkpoint every 20 epochs; the file name is the 1-based epoch count.
        if (epoch + 1) % 20 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict':model.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),
                'loss':epoch_loss,

            },
            path + str(epoch+1) + ".pt")
        print(epoch, epoch_loss, val_loss)

Number of data processed	time
0 0.00010752677917480469
200 4.213025093078613
400 7.891521453857422
600 11.482861995697021
800 15.20004940032959
1000 18.805384159088135
1200 22.25380039215088
1400 25.793635606765747
1600 29.23866558074951
1800 32.68548798561096
Model architecture:
PNANet(
  (convs): ModuleList(
    (0): PNAConv(2, 50, towers=1, edge_dim=None)
    (1-13): 13 x PNAConv(50, 50, towers=5, edge_dim=None)
  )
  (batch_norms): ModuleList(
    (0-13): 14 x BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (grus): ModuleList(
    (0): GRUCell(2, 50)
    (1-13): 13 x GRUCell(50, 50)
  )
  (readout): PNAConv(50, 1, towers=1, edge_dim=None)
)
The number of trainable parameters is:1002563
epoch train loss validation loss




0 0.9755017411176647 1.4325430619716644
1 0.5270898751701627 0.728704474568367
2 0.31815855404894267 0.6462131260335445
3 0.31317959370623744 0.6801067620515824
5 0.27319307948063526 0.8169910082221031
6 0.27085092765944346 0.6197202572226524
7 0.25572684415882185 0.715461283326149
8 0.2509342587925494 0.9330663299560547
9 0.2466496991698763 0.6403337754309177
10 0.24081455906187849 0.7266723623871804
11 0.23832622642229712 0.6782606676220894
12 0.2343905958506678 0.5696723632514477
13 0.23796351459941695 0.7173037105798721
14 0.2294299913197756 0.5772407113015652
15 0.22585447035995979 0.6905628705024719
16 0.225772169978757 0.7081995479762554
17 0.2170141119456717 0.5864555698633194
18 0.21900726879813842 0.6493200007081031
19 0.2317973183068846 0.6260060124099255
20 0.21738554454009448 0.5727256074547767
21 0.21935982600918838 0.6336311806738376
22 0.21384887091682425 0.6120408488810063
23 0.21358979263742056 0.5995454309880733
24 0.2171960343101195 0.5739412993192673
25 0.211086363

In [None]:
# Serialize the entire `model` object (not only its state_dict) to 453.pt.
torch.save(model,'/DATA/graphspiking/ckpt_orig_2/453.pt')

In [None]:
# Rebind `model` to whatever object is stored in 453.pt.
# NOTE(review): `optimizer` and `scheduler` were built from the parameters of
# the previous `model` object; after this rebinding they presumably no longer
# point at this model's parameters — confirm and rebuild them before training.
model=torch.load('/DATA/graphspiking/ckpt_orig_2/453.pt')

In [None]:
# Display the training loss left in `epoch_loss` by the most recent training loop.
epoch_loss

0.1376843332635638

In [None]:
# Write a dict-style checkpoint (epoch number, model weights, optimizer state)
# to 453.pt. This is the same path the whole-model save above used, so that
# file is overwritten. The epoch number is hard-coded and the scheduler state
# is not included.
torch.save({
    'epoch': 453,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, '/DATA/graphspiking/ckpt_orig_2/453.pt')

In [None]:
# Load the model checkpoint at epoch 453
# Restores the weights and optimizer state into the existing `model` and
# `optimizer` objects; the scheduler is left as it is because the checkpoint
# read here is only asked for the three keys used below.
checkpoint = torch.load('/DATA/graphspiking/ckpt_orig_2/453.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # Start from the next epoch

In [None]:
  # Resume training from `start_epoch` up to the configured number of epochs,
  # reusing the model, optimizer, scheduler, loaders and device already in
  # the kernel.
  path = '/DATA/graphspiking/ckpt_orig_2/'

  print("epoch", "train loss", "validation loss")
  # The loss curves restart here; losses from before the resume are not kept.
  val_loss_curve = []
  train_loss_curve = []

  for ep in range(start_epoch, args.epoch):

    # Train the model for one pass over the training data
    epoch_loss = train(model, train_loader, optimizer, device)

    # Evaluate the model on the validation data
    val_loss = validate(model, val_loader, device)

    # Record training and validation losses
    train_loss_curve.append(epoch_loss)
    val_loss_curve.append(val_loss)

    # The learning rate scheduler tracks the validation loss
    scheduler.step(val_loss)

    # Checkpoint every 20 epochs. The stored epoch is this loop's counter
    # (`ep`), so the metadata matches the file name instead of repeating the
    # last value left behind by the earlier training loop.
    if (ep + 1) % 20 == 0:
      torch.save({
          'epoch': ep,
          'model_state_dict':model.state_dict(),
          'optimizer_state_dict':optimizer.state_dict(),
          'loss':epoch_loss,

      },
      path + str(ep+1) + ".pt")

    print(ep, epoch_loss, val_loss)

epoch train loss validation loss
454 0.137649029729655 0.5549473875015974
455 0.1377469482352691 0.5513487920165062
456 0.1376733301420297 0.5938590614497662
457 0.13769823523504393 0.5645996236801147
458 0.13777859871009632 0.5448531959950924
459 0.13762232600977378 0.5570385427027941
460 0.13760328427655621 0.5548129042983055
461 0.13763012084527873 0.5482439592480659
462 0.1378346759716182 0.5602750903367997
463 0.13774404184975927 0.5751591975986957
464 0.13766275910427794 0.5440929931402206
465 0.1376417409069836 0.5723896725475788
466 0.13768315743788012 0.5468115486204624
467 0.1376068274527123 0.5660721746087074
468 0.13761295611537727 0.5533799427747726
469 0.1377065071776243 0.5499404111504554
470 0.13767208398585873 0.5813687826693058
471 0.1375731175565826 0.5919675751030445
472 0.13770013745874166 0.5684518429636956
473 0.1376341019517609 0.6084320424497127
474 0.13771640879176889 0.5524491652846336
475 0.13758622834924608 0.5999023343622685
476 0.13763504669070245 0.56764

In [None]:
# Build a fresh, untrained network from whichever PNANet definition is
# currently bound in the kernel; it is not moved to `device` here.
model_new = PNANet()

In [None]:
# Read the pretrained checkpoint from the current working directory.
ckpt = torch.load('pretrained_porous_graphene.pt')
# model_new.load_state_dict(ckpt['model_state_dict'])
# NOTE(review): this loads the pretrained optimizer state into `optimizer`,
# which was created for `model`, not `model_new` — confirm this is intended.
optimizer.load_state_dict(ckpt['optimizer_state_dict'])

#     pred_test = []
#     label_test = []
#     x_test = []
#     for batch in test_loader:
#        #node, edge, label, batch = data.x, data.edge_index, data.y, data.batch
#        #node = node.to(device)
#        #edge = edge.to(device)
#        #label = label.to(device)

#        # Test the model on each batch
#        with torch.no_grad():
#          pred = model(batch)
#        node = torch.cat([data.x for data in batch]).to(device)
#        label = torch.cat([data.y for data in batch]).to(device)
#        split_size = [data.x.size()[0] for data in batch]
#        pred_split = torch.split(pred, split_size)
#        label_split = torch.split(label, split_size)
#        x_split = torch.split(node, split_size, dim=0)
#        num_graphs = len(pred_split)
#        for i in range(num_graphs):
#           pred_test.append(pred_split[i].cpu().detach().numpy().squeeze().tolist())
#           label_test.append(label_split[i].cpu().detach().numpy().tolist())
#           x_test.append(x_split[i].cpu().detach().numpy().tolist())
#        torch.cuda.empty_cache()

#     write_data(pred_test, label_test, x_test)

In [None]:
# Load the pretrained weights into `model_new`.
# A strict load fails when the checkpoint has no `aggr_module.avg_deg_lin` /
# `aggr_module.avg_deg_log` entries. NOTE(review): those look like degree
# statistics that PNAConv derives from the `deg` histogram at construction
# time rather than learned weights — confirm against the installed
# torch_geometric version, and that `deg` matches the pretrained model's data.
# Only those entries are tolerated; any other missing or unexpected key still
# raises, so a genuinely incompatible checkpoint is not accepted silently.
load_result = model_new.load_state_dict(ckpt['model_state_dict'], strict=False)
_tolerated_suffixes = ('.aggr_module.avg_deg_lin', '.aggr_module.avg_deg_log')
_unexplained_missing = [
    key for key in load_result.missing_keys if not key.endswith(_tolerated_suffixes)
]
if _unexplained_missing or load_result.unexpected_keys:
    raise RuntimeError(
        "Error(s) in loading state_dict for PNANet: missing key(s) {}, unexpected key(s) {}".format(
            _unexplained_missing, list(load_result.unexpected_keys)
        )
    )
load_result

RuntimeError: Error(s) in loading state_dict for PNANet:
	Missing key(s) in state_dict: "convs.0.aggr_module.avg_deg_lin", "convs.0.aggr_module.avg_deg_log", "convs.1.aggr_module.avg_deg_lin", "convs.1.aggr_module.avg_deg_log", "convs.2.aggr_module.avg_deg_lin", "convs.2.aggr_module.avg_deg_log", "convs.3.aggr_module.avg_deg_lin", "convs.3.aggr_module.avg_deg_log", "convs.4.aggr_module.avg_deg_lin", "convs.4.aggr_module.avg_deg_log", "convs.5.aggr_module.avg_deg_lin", "convs.5.aggr_module.avg_deg_log", "convs.6.aggr_module.avg_deg_lin", "convs.6.aggr_module.avg_deg_log", "convs.7.aggr_module.avg_deg_lin", "convs.7.aggr_module.avg_deg_log", "convs.8.aggr_module.avg_deg_lin", "convs.8.aggr_module.avg_deg_log", "convs.9.aggr_module.avg_deg_lin", "convs.9.aggr_module.avg_deg_log", "convs.10.aggr_module.avg_deg_lin", "convs.10.aggr_module.avg_deg_log", "convs.11.aggr_module.avg_deg_lin", "convs.11.aggr_module.avg_deg_log", "convs.12.aggr_module.avg_deg_lin", "convs.12.aggr_module.avg_deg_log", "convs.13.aggr_module.avg_deg_lin", "convs.13.aggr_module.avg_deg_log", "readout.aggr_module.avg_deg_lin", "readout.aggr_module.avg_deg_log". 

In [None]:
class PNANet(torch.nn.Module):
    """PNAConv + GRUCell + BatchNorm layers followed by a PNAConv readout.

    This definition replaces any earlier ``PNANet`` bound in the kernel. It
    uses the ``BatchNorm`` and ``GRUCell`` names imported at the top of the
    notebook. Sizes come from the module-level ``args`` and the degree
    histogram from the module-level ``deg``; both must exist before the class
    is instantiated.
    """

    def __init__(self):
        super().__init__()

        aggregators = ['mean', 'min', 'max', 'std']
        scalers = ['identity', 'amplification', 'attenuation']

        def make_conv(n_in, n_out, n_towers):
            # Every convolution in the network shares these settings.
            return PNAConv(in_channels=n_in, out_channels=n_out,
                           aggregators=aggregators, scalers=scalers, deg=deg,
                           towers=n_towers, pre_layers=1, post_layers=1,
                           divide_input=False)

        hidden = args.hidden_dim

        self.convs = ModuleList()
        self.batch_norms  = ModuleList()
        self.grus = ModuleList()

        # (input width, tower count) for each layer: the first layer reads the
        # raw node features with one tower, the others are hidden-to-hidden
        # with five towers.
        layer_plan = [(args.input_dim, 1)] + [(hidden, 5)] * (args.num_layer - 1)
        for n_in, n_towers in layer_plan[:max(args.num_layer, 0)]:
            self.convs.append(make_conv(n_in, hidden, n_towers))
            self.grus.append(GRUCell(n_in, hidden))
            self.batch_norms.append(BatchNorm(hidden))

        # Final convolution maps the hidden features to one value per node.
        self.readout = make_conv(hidden, 1, 1)

    def forward(self, data):
        """Return the readout output for ``data``.

        Args:
            data: object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            The tensor produced by the readout PNAConv.
        """
        node_state = data.x
        connectivity = data.edge_index
        for conv, gru, norm in zip(self.convs, self.grus, self.batch_norms):
            aggregated = conv(node_state, connectivity)
            # Current features are the GRU input; the convolution output is
            # passed as its hidden state.
            updated = gru(node_state, aggregated)
            node_state = F.relu(norm(updated))
        return self.readout(node_state, connectivity)

In [None]:
# `state_dict` is referenced without being called, so this displays the bound
# method (whose repr includes the module architecture), not the weights.
model_new.state_dict

<bound method Module.state_dict of PNANet(
  (convs): ModuleList(
    (0): PNAConv(2, 50, towers=1, edge_dim=None)
    (1-13): 13 x PNAConv(50, 50, towers=5, edge_dim=None)
  )
  (batch_norms): ModuleList(
    (0-13): 14 x BatchNorm(50)
  )
  (grus): ModuleList(
    (0): GRUCell(2, 50)
    (1-13): 13 x GRUCell(50, 50)
  )
  (readout): PNAConv(50, 1, towers=1, edge_dim=None)
)>