In [1]:
from dataloader import GraphTextDataset, GraphDataset, TextDataset, AddRWStructEncoding
from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader as TorchDataLoader
from Model import Model
import numpy as np
from transformers import AutoTokenizer
from transformers import get_scheduler
import torch
from torch import optim
import time
import os
import pandas as pd
import json

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
def _read_json(path):
    """Return the decoded contents of the JSON file at ``path``."""
    with open(path) as handle:
        return json.load(handle)


# General experiment settings and graph-encoder settings are kept in separate files.
config = _read_json('config.json')
graph_config = _read_json('graph_config.json')

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Expose the settings that later cells read as plain module-level names.
model_name, model_type = config['model_name'], config['model_type']
nout, nhid = config['nout'], config['nhid']
nb_epochs = config['nb_epochs']
batch_size_train, batch_size_test = config['batch_size_train'], config['batch_size_test']
learning_rate = config['learning_rate']
load_graph_pretrained = config['load_graph_pretrained']

walk_length = graph_config['walk_length']

In [11]:
# A tokenizer is only created when the model runs in 'text' mode.
tokenizer = AutoTokenizer.from_pretrained(model_name) if model_type == 'text' else None

# NOTE(review): allow_pickle=True unpickles arbitrary objects, so this file must come from a trusted source.
# The trailing [()] presumably unwraps a 0-d object array holding the lookup table — confirm.
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]


def _build_split(split):
    """Create the GraphTextDataset for ``split`` with its own AddRWStructEncoding transform."""
    return GraphTextDataset(
        root='./data/',
        gt=gt,
        split=split,
        tokenizer=tokenizer,
        graph_transform=AddRWStructEncoding(walk_length),
    )


val_dataset = _build_split('val')
train_dataset = _build_split('train')

# Validation batches keep a fixed order; training batches are shuffled.
val_loader = DataLoader(val_dataset, batch_size=batch_size_test, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)

In [12]:
CE = torch.nn.CrossEntropyLoss()


def contrastive_loss(v1, v2):
  """Two-directional cross-entropy over the pairwise dot products of ``v1`` and ``v2``.

  Row ``i`` of ``v1`` is treated as the correct match for row ``i`` of ``v2``.
  The score matrix is classified once along its rows and once along its
  columns, and the two cross-entropy terms are added together.
  """
  scores = v1 @ v2.transpose(0, 1)
  # The target class of row i is index i, i.e. the diagonal entry.
  targets = torch.arange(scores.size(0), device=v1.device)
  row_term = CE(scores, targets)
  column_term = CE(scores.transpose(0, 1), targets)
  return row_term + column_term

# Instantiate the model from the loaded settings and move it to the selected device.
model = Model(
    model_name,
    nout,
    nhid,
    graph_config,
    load_graph_pretrained=load_graph_pretrained,
    model_type=model_type,
)
model.to(device)

# AdamW over every model parameter.
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

# Cosine schedule with 500 warm-up steps, sized for one scheduler step per training batch.
total_training_steps = len(train_loader) * nb_epochs
lr_scheduler = get_scheduler(
    'cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=total_training_steps,
)

In [13]:
def _count_parameters(module):
    """Total number of elements across all parameters of ``module``."""
    return sum(param.numel() for param in module.parameters())


total_params = _count_parameters(model)
graph_params = _count_parameters(model.graph_encoder)
text_params = _count_parameters(model.text_encoder)

# Report the overall size followed by the per-encoder breakdown.
print(
    f'Total number of parameters: {total_params:,}',
    f'    Graph encoder: {graph_params:,} parameters',
    f'    Text encoder: {text_params:,} parameters',
    sep='\n',
)

Total number of parameters: 82,493,812
    Graph encoder: 375,412 parameters
    Text encoder: 82,118,400 parameters


### Train

In [14]:
# Short aliases for the graph-encoder settings that are embedded in the checkpoint name.
g_m_n, g_l, g_h_l = (
    graph_config['graph_model_name'],
    graph_config['graph_layers'],
    graph_config['graph_hidden_channels'],
)
pretrained = 'pretrained' if load_graph_pretrained else ''

# NOTE(review): graph_params // 1000 is a count in thousands, yet the suffix is 'm'.
# Confirm before renaming, since checkpoint filenames are built from this prefix.
model_save_name = f'{model_type}_{model_name}__{g_m_n}_{g_l}_{g_h_l}_{graph_params//1000}m_{pretrained}__base_'
model_save_name

'text_distilroberta-base__gps_3_64_375m_pretrained__base_'

In [15]:
def train_one_epoch(model, train_loader, criterion, optimizer, losses, device, count_iter, printEvery, time1):
    """Run one optimisation pass over ``train_loader`` and log the running loss.

    Besides its arguments, this function reads two module-level names:
    ``model_type`` (selects how a batch is split into graph and text inputs)
    and ``lr_scheduler`` (stepped once after every optimizer step).

    Args:
        model: Callable returning ``(x_graph, x_text)`` for a batch.
        train_loader: Iterable of batches. In 'text' mode each batch must
            expose ``input_ids`` and ``attention_mask``; otherwise ``text``.
        criterion: Callable mapping ``(x_graph, x_text)`` to a scalar loss.
        optimizer: Optimizer updating the parameters of ``model``.
        losses: List that receives one mean loss per reporting window
            (appended in place).
        device: Device the batch tensors are moved to.
        count_iter: Number of iterations completed before this call; it keeps
            increasing across epochs.
        printEvery: A report is emitted whenever ``count_iter`` is a multiple
            of this value.
        time1: Reference timestamp (from ``time.time()``) for elapsed time.

    Returns:
        Tuple ``(losses, count_iter)`` with the updated list and counter.
    """
    # The accumulator restarts at every epoch while count_iter carries over, so
    # a reporting window can hold fewer than printEvery batches. Counting the
    # batches actually accumulated keeps the reported mean correct.
    window_loss = 0.0
    window_steps = 0
    model.train()
    for batch in train_loader:
        if model_type == 'text':
            # Strip the text fields so that only graph attributes stay on the batch.
            input_ids = batch.input_ids
            batch.pop('input_ids')
            attention_mask = batch.attention_mask
            batch.pop('attention_mask')
            graph_batch = batch

            x_graph, x_text = model(graph_batch.to(device),
                                    input_ids=input_ids.to(device),
                                    attention_mask=attention_mask.to(device))
        else:
            sentences = batch.text
            batch.pop('text')
            graph_batch = batch

            x_graph, x_text = model(graph_batch.to(device),
                                    sentences=sentences.to(device))

        current_loss = criterion(x_graph, x_text)
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        lr_scheduler.step()
        window_loss += current_loss.item()
        window_steps += 1

        count_iter += 1
        if count_iter % printEvery == 0:
            mean_loss = window_loss / window_steps
            elapsed = time.time() - time1
            print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter,
                                                                        elapsed, mean_loss))
            # Store the same mean that was just printed so the log and the history agree.
            losses.append(mean_loss)
            window_loss = 0.0
            window_steps = 0

    return losses, count_iter


def eval(model, val_loader, criterion, device):
    """Return the summed ``criterion`` value over all batches of ``val_loader``.

    The model is switched to evaluation mode and every forward pass runs with
    gradient tracking disabled. The module-level ``model_type`` decides which
    text fields are taken off each batch. The result is a sum, not a mean;
    dividing by the number of batches is left to the caller.
    """
    model.eval()
    accumulated = 0
    for batch in val_loader:
        # Separate the text inputs from the batch; what remains is the graph part.
        if model_type == 'text':
            input_ids = batch.input_ids
            attention_mask = batch.attention_mask
            batch.pop('input_ids')
            batch.pop('attention_mask')
            text_inputs = {'input_ids': input_ids, 'attention_mask': attention_mask}
        else:
            sentences = batch.text
            batch.pop('text')
            text_inputs = {'sentences': sentences}

        with torch.no_grad():
            graph_on_device = batch.to(device)
            text_on_device = {key: value.to(device) for key, value in text_inputs.items()}
            x_graph, x_text = model(graph_on_device, **text_on_device)
            accumulated += criterion(x_graph, x_text).item()

    return accumulated

In [16]:
epoch = 0

losses = []
count_iter = 0
time1 = time.time()
printEvery = 50
# Start from +inf so that any finite validation loss counts as the best so far.
best_validation_loss = float('inf')

# torch.save does not create missing directories, so make sure the target exists.
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

for i in range(nb_epochs):
    print('-----EPOCH {}-----'.format(i+1))
    losses, count_iter = train_one_epoch(model, train_loader, contrastive_loss, optimizer, losses, device, count_iter, printEvery, time1)

    # eval returns the loss summed over all validation batches.
    val_loss = eval(model, val_loader, contrastive_loss, device)

    print('-----EPOCH '+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )
    # A checkpoint is written whenever the summed validation loss matches or beats the best seen so far.
    if val_loss <= best_validation_loss:
        best_validation_loss = val_loss
        print('validation loss improved saving checkpoint...')
        save_path = os.path.join(checkpoint_dir, model_save_name+str(i)+'.pt')
        torch.save({
        'epoch': i,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Historical key name: the stored value is the summed validation loss, not an accuracy.
        'validation_accuracy': val_loss,
        'validation_loss': val_loss,
        # No training loss has been recorded yet if fewer than printEvery iterations have run.
        'loss': losses[-1] if losses else None,
        }, save_path)
        print('checkpoint saved to: {}'.format(save_path))

-----EPOCH 1-----
Iteration: 50, Time: 45.8356 s, training loss: 6.7651
Iteration: 100, Time: 91.0632 s, training loss: 6.5115
Iteration: 150, Time: 135.8789 s, training loss: 6.0522
Iteration: 200, Time: 180.8622 s, training loss: 5.0110
Iteration: 250, Time: 226.0309 s, training loss: 4.2482
Iteration: 300, Time: 271.1350 s, training loss: 3.7098
Iteration: 350, Time: 316.7400 s, training loss: 3.3752
Iteration: 400, Time: 361.6860 s, training loss: 3.0922
Iteration: 450, Time: 407.6856 s, training loss: 2.7404
Iteration: 500, Time: 453.9784 s, training loss: 2.5600
Iteration: 550, Time: 500.0239 s, training loss: 2.4964
Iteration: 600, Time: 546.3326 s, training loss: 2.4903
Iteration: 650, Time: 592.8481 s, training loss: 2.2617
Iteration: 700, Time: 639.2184 s, training loss: 2.1177
Iteration: 750, Time: 684.9796 s, training loss: 1.9738
Iteration: 800, Time: 729.2038 s, training loss: 1.9128
Iteration: 850, Time: 776.1986 s, training loss: 1.8294
Iteration: 900, Time: 821.1359 s,

KeyboardInterrupt: 

### Submission

In [4]:
from sklearn.metrics.pairwise import cosine_similarity

# NOTE(review): save_path is inherited from the training cell. On a fresh kernel, or when
# training stopped before any checkpoint was written, set it explicitly, for example:
#save_path = os.path.join('./checkpoints', 'model'+str(4)+'.pt')

print('loading best model...')
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

graph_model = model.get_graph_encoder()
text_model = model.get_text_encoder()

test_cids_dataset = GraphDataset(root='./data/', gt=gt, split='test_cids')
test_text_dataset = TextDataset(file_path='./data/test_text.txt', tokenizer=tokenizer)

idx_to_cid = test_cids_dataset.get_idx_to_cid()

test_loader = DataLoader(test_cids_dataset, batch_size=batch_size_test, shuffle=False)
test_text_loader = TorchDataLoader(test_text_dataset, batch_size=batch_size_test, shuffle=False)

graph_embeddings = []
text_embeddings = []
# Inference only: disabling autograd avoids building graphs that are never used.
with torch.no_grad():
    for batch in test_loader:
        for output in graph_model(batch.to(device)):
            graph_embeddings.append(output.tolist())

    for batch in test_text_loader:
        for output in text_model(batch['input_ids'].to(device),
                                 attention_mask=batch['attention_mask'].to(device)):
            text_embeddings.append(output.tolist())

# One row per test text, one column per test graph.
similarity = cosine_similarity(text_embeddings, graph_embeddings)

solution = pd.DataFrame(similarity)
solution['ID'] = solution.index
# Move the ID column to the front, keeping the similarity columns in their original order.
solution = solution[['ID'] + [col for col in solution.columns if col!='ID']]
solution.to_csv('submission.csv', index=False)

loading best model...
