In [1]:
from dataloader import GraphTextDataset, GraphDataset, TextDataset, AddRWStructEncoding
from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader as TorchDataLoader
from Model import Model
import numpy as np
from transformers import AutoTokenizer
from transformers import get_scheduler
import gensim
from nltk import word_tokenize
import torch
from torch import optim
import time
import os
import pandas as pd
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the experiment settings and the graph-encoder settings.
# JSON is UTF-8 by specification, so the encoding is passed explicitly rather
# than relying on the platform's locale default.
with open('config.json', encoding='utf-8') as f:
    config = json.load(f)

with open('graph_config.json', encoding='utf-8') as f:
    graph_config = json.load(f)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Unpack the settings used by later cells into module-level names.
# A missing key raises KeyError here, before any expensive work starts.
model_name = config['model_name']
model_type = config['model_type']  # compared against 'text' and 'w2v' in the next cell
pooling_type = config['pooling_type']
nout = config['nout']
nhid = config['nhid']
nb_epochs = config['nb_epochs']
batch_size_train = config['batch_size_train']
batch_size_test = config['batch_size_test']
learning_rate = config['learning_rate']
load_graph_pretrained = config['load_graph_pretrained']

# Passed to AddRWStructEncoding when the datasets are built.
walk_length = graph_config['walk_length']

In [3]:
# Text-side resources start out as "unused"; only the branch that matches
# `model_type` replaces them.
tokenizer = None
nltk_tokenizer = None
word2idx = None
w2v_embeddings = None

if model_type == 'text':
    tokenizer = AutoTokenizer.from_pretrained(model_name)

if model_type == 'w2v':
    model_w2v = gensim.models.KeyedVectors.load_word2vec_format(model_name + '.txt')
    word_vectors = model_w2v.vectors
    # One extra all-zero row is prepended at index 0; the pretrained vectors
    # fill rows 1..N.
    w2v_embeddings = np.zeros((len(word_vectors) + 1, word_vectors.shape[1]), dtype=np.float32)
    w2v_embeddings[1:] = word_vectors
    nltk_tokenizer = word_tokenize
    word2idx = model_w2v.key_to_index

# NOTE(review): allow_pickle=True unpickles arbitrary objects, so this file must
# come from a trusted source. `[()]` unwraps the object stored in the 0-d array.
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]

# Arguments common to both splits; each split still gets its own transform instance.
shared_dataset_args = dict(root='./data/', gt=gt, tokenizer=tokenizer,
                           nltk_tokenizer=nltk_tokenizer, word2idx=word2idx)

val_dataset = GraphTextDataset(split='val',
                               graph_transform=AddRWStructEncoding(walk_length),
                               **shared_dataset_args)
train_dataset = GraphTextDataset(split='train',
                                 graph_transform=AddRWStructEncoding(walk_length),
                                 **shared_dataset_args)

# Validation keeps a fixed order; training batches are reshuffled every epoch.
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size_test)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size_train)

In [4]:
CE = torch.nn.CrossEntropyLoss()
def contrastive_loss(v1, v2, reference_batch_size=16):
  """Symmetric cross-entropy loss over the pairwise similarity matrix of two batches.

  Row i of `v1` is treated as the match of row i of `v2`: the matrix of dot
  products `v1 @ v2.T` is scored with cross-entropy against the diagonal, once
  along rows and once along columns, and the two terms are averaged.

  Args:
    v1: 2-D tensor with one embedding per row.
    v2: 2-D tensor with the same number of rows as `v1`, paired row by row.
    reference_batch_size: the result is multiplied by
      `reference_batch_size / batch_size`. Defaults to 16, the constant that was
      previously hard-coded. NOTE(review): presumably this keeps the loss scale
      comparable to runs that used batches of 16 -- confirm.

  Returns:
    A scalar tensor holding the rescaled loss.
  """
  batch_size = v1.shape[0]
  logits = torch.matmul(v1, torch.transpose(v2, 0, 1))
  # The matching pair for row i sits on the diagonal, so the target class is i.
  labels = torch.arange(logits.shape[0], device=v1.device)
  symmetric_ce = (CE(logits, labels) + CE(torch.transpose(logits, 0, 1), labels))/2
  return symmetric_ce*(reference_batch_size/batch_size)

# Build the joint graph/text model and move it to the selected device.
model = Model(
    model_name,
    nout,
    nhid,
    graph_config,
    load_graph_pretrained=load_graph_pretrained,
    model_type=model_type,
    pooling_type=pooling_type,
    w2v_embeddings=w2v_embeddings,
)
model.to(device)

# AdamW over every model parameter; weight decay comes from the config file.
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    weight_decay=config['weight_decay'],
    eps=1e-08,
)

# Cosine schedule with warm-up. Its horizon is the number of optimizer steps in
# the whole run multiplied by the configured factor.
warmup_steps = config['num_warmup_steps']
scheduler_total_steps = len(train_loader)*nb_epochs*config['scheduler_steps_factor']
lr_scheduler = get_scheduler(
    'cosine',
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=scheduler_total_steps,
)

In [5]:
def _count_parameters(module):
    """Return the total number of elements across all parameters of `module`."""
    n_elements = 0
    for param in module.parameters():
        n_elements += param.numel()
    return n_elements


# Whole model first, then each encoder on its own.
total_params = _count_parameters(model)
graph_params = _count_parameters(model.graph_encoder)
text_params = _count_parameters(model.text_encoder)

print(
    f'Total number of parameters: {total_params:,}',
    f'    Graph encoder: {graph_params:,} parameters',
    f'    Text encoder: {text_params:,} parameters',
    sep='\n',
)

Total number of parameters: 82,882,932
    Graph encoder: 764,532 parameters
    Text encoder: 82,118,400 parameters


### Train

In [6]:
# Assemble the checkpoint file stem from the text- and graph-encoder settings.
s_name = model_name.replace('/', '-')  # '/' in a hub model id would otherwise act as a path separator
g_m_n, g_l, g_h_l = (
    graph_config[key]
    for key in ('graph_model_name', 'graph_layers', 'graph_hidden_channels')
)
# Tag the name only when a pretrained graph-encoder setting was supplied.
pretrained = 'pretrained' if len(load_graph_pretrained) > 0 else ''

# NOTE(review): `graph_params//1000` is a count in thousands even though the
# suffix reads "m". It is kept as is because existing checkpoint files already
# use this naming.
model_save_name = f'{model_type}_{s_name}__{g_m_n}_{g_l}_{g_h_l}_{graph_params//1000}m_{pretrained}__base2_'
model_save_name

'text_sentence-transformers-all-distilroberta-v1__gps_10_64_764m___base2_'

In [7]:
def train_one_epoch(model, train_loader, criterion, optimizer, losses, device, count_iter, printEvery, time1,
                    scheduler=None):
    """Run one optimisation pass over `train_loader`.

    Args:
        model: callable as `model(graph_batch, input_ids=..., attention_mask=...)`
            and returning a `(x_graph, x_text)` pair.
        train_loader: iterable of batches exposing `input_ids` and
            `attention_mask` attributes plus `pop` and `to` methods.
        criterion: callable mapping `(x_graph, x_text)` to a scalar loss tensor.
        optimizer: optimizer stepped once per batch.
        losses: list extended in place with the summed loss of every completed
            window of `printEvery` iterations.
        device: device the batch tensors are moved to.
        count_iter: number of iterations already performed before this call.
        printEvery: number of iterations between progress reports.
        time1: reference timestamp (from `time.time()`) used to report elapsed time.
        scheduler: learning-rate scheduler stepped once per batch. When omitted,
            the notebook-level `lr_scheduler` is used, which keeps existing
            calls working.

    Returns:
        `(losses, count_iter)` with the updated loss history and iteration count.
    """
    if scheduler is None:
        # Backward-compatible fallback to the global defined in an earlier cell.
        # Passing the scheduler explicitly avoids this hidden dependency.
        scheduler = lr_scheduler

    loss = 0
    model.train()
    for batch in train_loader:
        # Split the text fields off the batch; what remains is passed to the
        # model as the graph batch.
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch

        x_graph, x_text = model(graph_batch.to(device),
                                input_ids=input_ids.to(device),
                                attention_mask=attention_mask.to(device))

        current_loss = criterion(x_graph, x_text)
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        scheduler.step()
        loss += current_loss.item()

        count_iter += 1
        if count_iter % printEvery == 0:
            time2 = time.time()
            print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter,
                                                                        time2 - time1, loss/printEvery))
            # The stored value is the window SUM; the printed value is the window mean.
            losses.append(loss)
            loss = 0

    return losses, count_iter


def eval(model, val_loader, criterion, device):
    """Return the loss summed over every batch of `val_loader`.

    The model is switched to evaluation mode and each forward pass runs with
    gradient tracking disabled. The result is a sum, not a mean; callers divide
    by the number of batches when they want an average.
    """
    model.eval()
    running_total = 0
    for graph_batch in val_loader:
        # Take the text fields out so that only graph data remains on the batch.
        token_ids, mask = graph_batch.input_ids, graph_batch.attention_mask
        for text_field in ('input_ids', 'attention_mask'):
            graph_batch.pop(text_field)

        with torch.no_grad():
            graph_emb, text_emb = model(
                graph_batch.to(device),
                input_ids=token_ids.to(device),
                attention_mask=mask.to(device),
            )
            running_total += criterion(graph_emb, text_emb).item()

    return running_total

In [8]:
# Load model (disabled).
# The snippet below is wrapped in a string literal, so none of it runs; the cell
# only echoes the text. Pasted back as code, it would restore the model,
# optimizer and scheduler state from the checkpoint saved with epoch index 8.
"""save_path = os.path.join('./checkpoints', 'ep'+str(8)+model_save_name+'.pt')

checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])"""

"save_path = os.path.join('./checkpoints', 'ep'+str(8)+model_save_name+'.pt')\n\ncheckpoint = torch.load(save_path)\nmodel.load_state_dict(checkpoint['model_state_dict'])\noptimizer.load_state_dict(checkpoint['optimizer_state_dict'])\nlr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])"

In [8]:
# First epoch index to run; raise it when resuming from a checkpoint.
epoch = 0

losses = []          # summed training loss of each completed `printEvery` window
count_iter = 0       # optimizer steps performed so far, across epochs
time1 = time.time()  # reference point for the elapsed-time printout
printEvery = 50
best_validation_loss = float('inf')

# torch.save does not create missing directories, so make sure the target exists.
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

for i in range(epoch, nb_epochs):
    print('-----EPOCH {}-----'.format(i+1))
    losses, count_iter = train_one_epoch(model, train_loader, contrastive_loss, optimizer, losses, device, count_iter, printEvery, time1)

    # `eval` returns the loss summed over validation batches.
    val_loss = eval(model, val_loader, contrastive_loss, device)

    best_validation_loss = min(best_validation_loss, val_loss)

    # The printed value is the mean per validation batch.
    print('-----EPOCH '+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )
    # File names use the 0-based epoch index `i`, while the log lines above are 1-based.
    save_path = os.path.join(checkpoint_dir, 'ep' + str(i) + model_save_name+'.pt')
    torch.save({
        'epoch': i,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': lr_scheduler.state_dict(),
        # Legacy key: the value is the summed validation LOSS, not an accuracy.
        # Kept so that code reading older checkpoints keeps working.
        'validation_accuracy': val_loss,
        'validation_loss': val_loss,
        # `losses` stays empty until `printEvery` iterations have run.
        'loss': losses[-1] if losses else None,
        }, save_path)
    print('checkpoint saved to: {}'.format(save_path))

-----EPOCH 1-----
Iteration: 50, Time: 50.1327 s, training loss: 2.0906
Iteration: 100, Time: 99.1638 s, training loss: 1.8220
Iteration: 150, Time: 148.1938 s, training loss: 1.4758
Iteration: 200, Time: 197.4885 s, training loss: 1.2807
Iteration: 250, Time: 246.2370 s, training loss: 1.1419
Iteration: 300, Time: 289.0047 s, training loss: 0.9977
Iteration: 350, Time: 330.6434 s, training loss: 0.8866
Iteration: 400, Time: 373.4468 s, training loss: 0.8167
Iteration: 450, Time: 415.7072 s, training loss: 0.8110
Iteration: 500, Time: 458.3068 s, training loss: 0.6704
Iteration: 550, Time: 500.8691 s, training loss: 0.7080
Iteration: 600, Time: 543.1417 s, training loss: 0.6862
Iteration: 650, Time: 586.4333 s, training loss: 0.5860
Iteration: 700, Time: 629.7994 s, training loss: 0.6038
Iteration: 750, Time: 672.4020 s, training loss: 0.5333
Iteration: 800, Time: 717.4621 s, training loss: 0.5464
Iteration: 850, Time: 765.1645 s, training loss: 0.5289
Iteration: 900, Time: 809.1361 s,