In [1]:
from torch_geometric.loader import DataLoader
import torch_geometric
from torch.utils.data import DataLoader as TorchDataLoader
from Model import GraphCL
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.functional as F
from torch import optim
import time
import os
import pandas as pd
import networkx as nx

from dataloader import GraphDatasetPretrain, AddRWStructEncoding
from dataloader import drop_node_augment, edge_pert_augment, attr_mask_augment, subgraph_augment

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Token-embedding lookup handed to the datasets. `[()]` unwraps the object stored in
# the 0-d array that np.load returns for a pickled Python object (presumably a dict,
# going by the file name).
# NOTE(review): allow_pickle=True can execute arbitrary code while loading; only use
# it on files you trust.
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]

# Walk length passed to the random-walk structural-encoding transform.
walk_length = 20

# Both splits use the same pair of augmentations (attribute masking / subgraph) with
# the same augmentation parameter and the same structural-encoding transform.
val_dataset = GraphDatasetPretrain(root='./data/', gt=gt, split='val', 
                                   graph_augment1=attr_mask_augment, graph_augment2=subgraph_augment, 
                                   aug_p=0.2, graph_transform=AddRWStructEncoding(walk_length))
train_dataset = GraphDatasetPretrain(root='./data/', gt=gt, split='train',
                                     graph_augment1=attr_mask_augment, graph_augment2=subgraph_augment,
                                     aug_p=0.2, graph_transform=AddRWStructEncoding(walk_length))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training hyper-parameters.
nb_epochs = 5
batch_size_train = 32
batch_size_test = 32
learning_rate = 0.001

# The validation loader is NOT shuffled: the contrastive loss uses the other graphs of
# a batch as negatives, so a fixed batch composition keeps the validation loss
# comparable from one epoch to the next (it drives the "best checkpoint" decision).
val_loader = DataLoader(val_dataset, batch_size=batch_size_test, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)

In [12]:
# Configuration dictionary passed to GraphCL as its first argument.
graph_config = {
    'graph_model_name': 'gps',
    'num_node_features': 300,
    'graph_hidden_channels': 64,
    'graph_layer': 3,
    'n_head': 4,
    'n_feedforward': 128,
    'input_dropout': 0.1,
    'dropout': 0.0,
    'attention_dropout': 0.25,
    'conv_type': 'Gated',
    # NOTE(review): presumably has to match the walk length given to
    # AddRWStructEncoding when the datasets were built (also 20) - confirm.
    'walk_length': 20,
    'dim_se': 28,
}

# NOTE(review): the meaning of the two positional arguments (64, 128) is defined in
# Model.GraphCL, which is not visible here - confirm before changing them.
model = GraphCL(graph_config, 64, 128)
model.to(device)

CE = torch.nn.CrossEntropyLoss()


def contrastive_loss(v1, v2):
    """Symmetric cross-entropy loss between two batches of embeddings.

    The pairwise similarity matrix ``v1 @ v2^T`` is used as logits, with row ``k``
    labelled as class ``k`` (the matching pair sits on the diagonal). The
    cross-entropy is evaluated on the matrix and on its transpose and the two
    values are summed.

    Args:
        v1: 2-D tensor of embeddings for the first view, one row per sample.
        v2: 2-D tensor of embeddings for the second view, same shape as ``v1``.

    Returns:
        Scalar tensor holding the summed two-direction cross-entropy.
    """
    similarity = v1 @ v2.transpose(0, 1)
    targets = torch.arange(similarity.size(0), device=v1.device)
    loss_rows = CE(similarity, targets)
    loss_cols = CE(similarity.transpose(0, 1), targets)
    return loss_rows + loss_cols

# AdamW over every model parameter, using the learning rate from the setup cell.
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

# Mutable run state read and updated by the training loop.
epoch = 0                       # not referenced by the training loop in this notebook
loss = 0                        # running sum of training losses since the last log line
losses = []                     # one entry appended per logging window
count_iter = 0                  # optimisation steps taken so far, across epochs
printEvery = 50                 # log every `printEvery` optimisation steps
best_validation_loss = 1000000  # large sentinel so the first epoch always counts as best
time1 = time.time()             # reference point for the elapsed time printed in the logs

In [13]:
# Contrastive pre-training: every batch yields two augmented views of the same graphs;
# both views are embedded by the model and compared with contrastive_loss.
for i in range(nb_epochs):
    print('-----EPOCH {}-----'.format(i+1))

    # ---- training pass ----
    model.train()
    for view_a, view_b in train_loader:
        view_a = view_a.to(device)
        view_b = view_b.to(device)

        emb_a = model(view_a)
        emb_b = model(view_b)

        step_loss = contrastive_loss(emb_a, emb_b)
        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()

        loss += step_loss.item()
        count_iter += 1

        # Only log once every `printEvery` optimisation steps.
        if count_iter % printEvery != 0:
            continue

        time2 = time.time()
        print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter,
                                                                    time2 - time1, loss/printEvery))
        # NOTE(review): the value stored here is the window *sum*, while the printed
        # value is the window mean (loss/printEvery) - confirm which one is intended.
        losses.append(loss)
        loss = 0

    # ---- validation pass (no gradients) ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for view_a, view_b in val_loader:
            view_a = view_a.to(device)
            view_b = view_b.to(device)

            emb_a = model(view_a)
            emb_b = model(view_b)

            val_loss += contrastive_loss(emb_a, emb_b).item()

    # The summed validation loss is what gets compared; a tie counts as an improvement.
    improved = val_loss <= best_validation_loss
    best_validation_loss = min(best_validation_loss, val_loss)

    print('-----EPOCH '+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )
    if improved:
        print('validation loss improved saving checkpoint...')
        # The string literal below is a no-op expression: checkpoint saving is
        # currently disabled, and `model_save_name` is not defined in the visible cells.
        """save_path = os.path.join('./checkpoints', model_save_name+str(i)+'.pt')
        torch.save({
        'epoch': i,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'validation_accuracy': val_loss,
        'loss': loss,
        }, save_path)
        print('checkpoint saved to: {}'.format(save_path))"""

-----EPOCH 1-----
Iteration: 50, Time: 26.7592 s, training loss: 2.7703
Iteration: 100, Time: 53.4157 s, training loss: 0.9629
Iteration: 150, Time: 79.4561 s, training loss: 0.6625
Iteration: 200, Time: 108.7342 s, training loss: 0.5122
Iteration: 250, Time: 135.5258 s, training loss: 0.4553
Iteration: 300, Time: 162.2332 s, training loss: 0.3549
Iteration: 350, Time: 189.7228 s, training loss: 0.3015
Iteration: 400, Time: 215.9638 s, training loss: 0.3520
Iteration: 450, Time: 246.3472 s, training loss: 0.2635
Iteration: 500, Time: 275.3385 s, training loss: 0.2887
Iteration: 550, Time: 303.9189 s, training loss: 0.2533
Iteration: 600, Time: 333.6415 s, training loss: 0.1809
Iteration: 650, Time: 363.1292 s, training loss: 0.2398
Iteration: 700, Time: 390.6797 s, training loss: 0.2583
Iteration: 750, Time: 419.9492 s, training loss: 0.2012
Iteration: 800, Time: 448.8914 s, training loss: 0.2089


  return torch._native_multi_head_attention(


-----EPOCH 1----- done.  Validation loss:  0.1722753430949524
validation loss improved saving checkpoint...
-----EPOCH 2-----
Iteration: 850, Time: 509.9113 s, training loss: 0.2164
Iteration: 900, Time: 537.5832 s, training loss: 0.1786
Iteration: 950, Time: 566.2843 s, training loss: 0.1811
Iteration: 1000, Time: 593.9891 s, training loss: 0.2036
Iteration: 1050, Time: 622.9134 s, training loss: 0.1738
Iteration: 1100, Time: 652.1620 s, training loss: 0.1581
Iteration: 1150, Time: 680.9639 s, training loss: 0.1978
Iteration: 1200, Time: 708.9790 s, training loss: 0.1712
Iteration: 1250, Time: 738.2047 s, training loss: 0.1596
Iteration: 1300, Time: 767.9549 s, training loss: 0.1694
Iteration: 1350, Time: 794.8194 s, training loss: 0.1891
Iteration: 1400, Time: 822.4557 s, training loss: 0.1443
Iteration: 1450, Time: 851.0774 s, training loss: 0.1437
Iteration: 1500, Time: 878.8630 s, training loss: 0.1892
Iteration: 1550, Time: 908.5150 s, training loss: 0.2465
Iteration: 1600, Time: