In [6]:
from torch_geometric.loader import DataLoader
from Model import GraphCL
import numpy as np

import torch
from torch import optim
import time
import os
import json

from dataloader import GraphDatasetPretrain, AddRWStructEncoding
from dataloader import drop_node_augment, edge_pert_augment, attr_mask_augment, subgraph_augment

from transformers import get_scheduler

In [7]:
# Every pre-training hyperparameter is read from one JSON file so that a run
# can be reproduced from the config alone.
with open('graph_config.json') as f:
    graph_config = json.load(f)

# Use the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Setting handed to AddRWStructEncoding when the datasets are built.
walk_length = graph_config['walk_length']

# Length and batching of the pre-training run.
nb_epochs = graph_config['pretraining_epochs']
batch_size = graph_config['pretraining_batch_size']

# Optimizer and learning-rate schedule settings.
learning_rate = graph_config['pretraining_lr']
weight_decay = graph_config['pretraining_weight_decay']
pretraining_scheduler_steps_factor = graph_config['pretraining_scheduler_steps_factor']

In [8]:
# The .npy file holds a single pickled Python object inside a 0-d array;
# indexing with `[()]` unwraps it (presumably a token -> embedding dict, going
# by the file name -- confirm against GraphDatasetPretrain).
# NOTE(review): allow_pickle=True runs arbitrary code while unpickling, so this
# file must only ever come from a trusted source.
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]


def _make_pretrain_dataset(split):
    """Build the GraphDatasetPretrain for `split` with the shared augmentation setup.

    Both splits use the same pair of augmentations (attribute masking and
    subgraph sampling) with strength 0.2, and each gets its own
    AddRWStructEncoding transform instance.
    """
    return GraphDatasetPretrain(
        root='./data/',
        gt=gt,
        split=split,
        graph_augment1=attr_mask_augment,
        graph_augment2=subgraph_augment,
        aug_p=0.2,
        graph_transform=AddRWStructEncoding(walk_length),
    )


val_dataset = _make_pretrain_dataset('val')
train_dataset = _make_pretrain_dataset('train')

# The contrastive loss uses the other graphs of a batch as negatives, so the
# batch composition influences the loss value. Validation batches are therefore
# kept in a fixed order to keep the validation loss comparable between epochs;
# only the training data is reshuffled.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [9]:
# Build the contrastive model and move its parameters to the selected device
# in one step (nn.Module.to returns the module itself).
model = GraphCL(graph_config, 64, 128).to(device)

CE = torch.nn.CrossEntropyLoss()


def contrastive_loss(v1, v2):
    """Symmetric cross-entropy loss over the similarity matrix of two views.

    Entry (i, j) of the similarity matrix is the dot product of row i of `v1`
    with row j of `v2`. Row i of `v1` is matched with row i of `v2` (the
    diagonal is the target class); every other row acts as a negative. The
    loss is the mean of the v1->v2 and v2->v1 cross-entropy terms.

    Args:
        v1: embeddings of the first view; expected to be a (batch, dim) matrix.
        v2: embeddings of the second view, with the same number of rows as `v1`.

    Returns:
        Scalar tensor holding the averaged loss.
    """
    similarity = v1 @ v2.transpose(0, 1)
    targets = torch.arange(similarity.shape[0], device=v1.device)
    first_to_second = CE(similarity, targets)
    second_to_first = CE(similarity.transpose(0, 1), targets)
    return (first_to_second + second_to_first) / 2

# AdamW over every model parameter; learning rate and weight decay come from
# the config, the remaining settings are fixed here.
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    betas=(0.9, 0.999),
    eps=1e-08,
)

# Cosine schedule with 250 warm-up steps. The horizon is the number of batches
# in the whole run, scaled by the configured steps factor.
lr_scheduler = get_scheduler(
    'cosine',
    optimizer=optimizer,
    num_warmup_steps=250,
    num_training_steps=len(train_loader) * nb_epochs * pretraining_scheduler_steps_factor,
)

In [10]:
# Total number of scalar entries across all parameter tensors of the model.
params = sum(weight.numel() for weight in model.parameters())
print('Number of parameters: {:,}'.format(params))

Number of parameters: 526,344


In [12]:
def train_one_epoch(model, train_loader, criterion, optimizer, losses, device, count_iter, printEvery, time1,
                    scheduler=None):
    """Run one pass over `train_loader`, updating `model` after every batch.

    Each batch is a pair of augmented views. Both views are embedded by
    `model`, compared with `criterion`, and the resulting loss drives one
    optimizer step followed by one scheduler step.

    Args:
        model: module being trained; switched to training mode here.
        train_loader: iterable yielding `(view_1, view_2)` pairs whose items
            support `.to(device)`.
        criterion: callable `(x_1, x_2) -> scalar loss tensor`.
        optimizer: optimizer holding the parameters of `model`.
        losses: list the reported mean losses are appended to (mutated in place).
        device: device the batches are moved to.
        count_iter: number of iterations already done; keeps counting across epochs.
        printEvery: a progress line is emitted whenever `count_iter` is a
            multiple of this value.
        time1: reference timestamp (`time.time()`) for the elapsed-time report.
        scheduler: learning-rate scheduler stepped once per batch. When left as
            None, the notebook-level `lr_scheduler` is used, which keeps
            existing calls working.

    Returns:
        Tuple `(losses, count_iter)` with the extended loss history and the
        updated iteration counter.
    """
    if scheduler is None:
        # Backward-compatible fallback to the scheduler defined at notebook level.
        scheduler = lr_scheduler

    running_loss = 0.0
    # Number of batches accumulated since the last report. It is tracked
    # separately because count_iter carries over between epochs while the
    # running loss starts at zero, so the first window of an epoch can be
    # shorter than printEvery.
    window_batches = 0

    model.train()
    for aug_batch1, aug_batch2 in train_loader:
        x_1 = model(aug_batch1.to(device))
        x_2 = model(aug_batch2.to(device))

        current_loss = criterion(x_1, x_2)
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += current_loss.item()
        window_batches += 1

        count_iter += 1
        if count_iter % printEvery == 0:
            # Average over the batches actually seen in this window, and store
            # the same value that is printed.
            mean_loss = running_loss / window_batches
            print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(
                count_iter, time.time() - time1, mean_loss))
            losses.append(mean_loss)
            running_loss = 0.0
            window_batches = 0

    # Batches after the last report of the epoch are not reported.
    return losses, count_iter


def eval(model, val_loader, criterion, device):
    """Sum the loss of `criterion` over every batch of `val_loader`.

    The model is put into evaluation mode and the forward passes run without
    gradient tracking. The returned value is the sum of the per-batch losses,
    not their mean; callers divide by `len(val_loader)` when they want an
    average.

    Note: inside this notebook the name shadows Python's builtin `eval`.

    Args:
        model: module to evaluate.
        val_loader: iterable yielding `(view_1, view_2)` pairs whose items
            support `.to(device)`.
        criterion: callable `(x_1, x_2) -> scalar loss tensor`.
        device: device the batches are moved to.

    Returns:
        Accumulated loss over all batches (0 for an empty loader).
    """
    model.eval()
    total = 0
    for view_a, view_b in val_loader:
        with torch.no_grad():
            emb_a = model(view_a.to(device))
            emb_b = model(view_b.to(device))
            total += criterion(emb_a, emb_b).item()
    return total

In [15]:
epoch = 0
losses = []
count_iter = 0
time1 = time.time()
printEvery = 50
best_validation_loss = 1000000

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first epoch finishes.
checkpoint_dir = './graph_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

for i in range(nb_epochs):
    print('-----EPOCH {}-----'.format(i+1))
    losses, count_iter = train_one_epoch(model, train_loader, contrastive_loss, optimizer, losses, device,
                                         count_iter, printEvery, time1)

    # eval returns the loss summed over all validation batches.
    val_loss = eval(model, val_loader, contrastive_loss, device)

    best_validation_loss = min(best_validation_loss, val_loss)

    print('-----EPOCH '+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )

    # File name uses the zero-based epoch index followed by the model settings.
    save_name = '_'.join([
        'ep' + str(i),
        graph_config['graph_model_name'],
        str(graph_config['graph_layers']),
        str(graph_config['graph_hidden_channels']),
        graph_config['conv_type'],
    ])
    save_path = os.path.join(checkpoint_dir, save_name + '.pt')
    torch.save({
            'epoch': i,
            # Only the graph encoder is stored, not the full contrastive model.
            'model_state_dict': model.graph_base.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': lr_scheduler.state_dict(),
            # Despite the key name this is the summed validation loss; the key
            # is kept unchanged so existing checkpoint readers keep working.
            'validation_accuracy': val_loss,
            # `losses` stays empty until the first printEvery boundary is
            # reached, so guard against indexing an empty list.
            'loss': losses[-1] if losses else None,
            }, save_path)
    print('checkpoint saved to: {}'.format(save_path))

-----EPOCH 1-----
Iteration: 50, Time: 24.0918 s, training loss: 4.5393
Iteration: 100, Time: 44.8924 s, training loss: 1.9357
Iteration: 150, Time: 70.4103 s, training loss: 1.0075
Iteration: 200, Time: 96.9311 s, training loss: 0.8094
Iteration: 250, Time: 122.9633 s, training loss: 0.7841
Iteration: 300, Time: 149.2615 s, training loss: 0.6821
Iteration: 350, Time: 173.2303 s, training loss: 0.6124
Iteration: 400, Time: 195.5359 s, training loss: 0.4933
Iteration: 450, Time: 220.7424 s, training loss: 0.4206
Iteration: 500, Time: 251.1807 s, training loss: 0.4535
Iteration: 550, Time: 274.6487 s, training loss: 0.4817
Iteration: 600, Time: 301.2038 s, training loss: 0.4472
Iteration: 650, Time: 324.1943 s, training loss: 0.4142
Iteration: 700, Time: 348.5735 s, training loss: 0.3738
Iteration: 750, Time: 374.1029 s, training loss: 0.3215
Iteration: 800, Time: 400.0661 s, training loss: 0.3409
-----EPOCH 1----- done.  Validation loss:  0.31544835880503813
checkpoint saved to: ./graph

In [20]:
# Final export of the graph encoder. This cell reuses `i`, `val_loss` and
# `losses` as left behind by the training-loop cell, so it must run after it.
save_name = '_'.join([
    'ep' + str(i),
    graph_config['graph_model_name'],
    str(graph_config['graph_layers']),
    str(graph_config['graph_hidden_channels']),
    graph_config['conv_type'],
    graph_config['agg_type'],
])

# torch.save does not create missing parent directories.
checkpoint_dir = './graph_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

save_path = os.path.join(checkpoint_dir, save_name + '.pt')
torch.save({
        'epoch': i,
        # Only the graph encoder weights are exported; optimizer and scheduler
        # state are not part of this checkpoint.
        'model_state_dict': model.graph_base.state_dict(),
        # Despite the key name this is the summed validation loss of the last epoch.
        'validation_accuracy': val_loss,
        # Guard against an empty history (no printEvery boundary reached).
        'loss': losses[-1] if losses else None,
        }, save_path)
print('checkpoint saved to: {}'.format(save_path))

checkpoint saved to: ./graph_checkpoints/9_gps_10_64_Gated.pt
