In [1]:
from dataloader import GraphTextDataset, GraphDataset, TextDataset, AddRWStructEncoding
from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader as TorchDataLoader
from Model import Model
import numpy as np
from transformers import AutoTokenizer
import torch
from torch import optim
import time
import os
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
CE = torch.nn.CrossEntropyLoss()
def contrastive_loss(v1, v2):
  logits = torch.matmul(v1,torch.transpose(v2, 0, 1))
  labels = torch.arange(logits.shape[0], device=v1.device)
  return CE(logits, labels) + CE(torch.transpose(logits, 0, 1), labels)

model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

graph_model_name = 'gin'

gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]
walk_length = 20
val_dataset = GraphTextDataset(root='./data/', gt=gt, split='val', tokenizer=tokenizer, pre_transform=AddRWStructEncoding(walk_length))
train_dataset = GraphTextDataset(root='./data/', gt=gt, split='train', tokenizer=tokenizer, pre_transform=AddRWStructEncoding(walk_length))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

nb_epochs = 5
batch_size_train = 16
batch_size_test = 16
learning_rate = 2e-5

val_loader = DataLoader(val_dataset, batch_size=batch_size_test, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)

model = Model(model_name=model_name, graph_model_name=graph_model_name, num_node_features=300, nout=768, nhid=300, graph_hidden_channels=300, graph_layers=3) # nout = bert model hidden dim
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
                                betas=(0.9, 0.999),
                                weight_decay=0.01)

epoch = 0
loss = 0
losses = []
count_iter = 0
time1 = time.time()
printEvery = 50
best_validation_loss = 1000000

In [12]:
total_params = sum(p.numel() for p in model.parameters())
graph_params = sum(p.numel() for p in model.graph_encoder.parameters())
text_params = sum(p.numel() for p in model.text_encoder.parameters())

print(f'Total number of parameters: {total_params:,}')
print(f'    Graph encoder: {graph_params:,} parameters')
print(f'    Text encoder: {text_params:,} parameters')

Total number of parameters: 67,226,148
    Graph encoder: 863,268 parameters
    Text encoder: 66,362,880 parameters


### Train

In [13]:
model_save_name = f'{model_name}__{graph_model_name}_{3}_{300}_{graph_params//1000}m__base_'
model_save_name

'distilbert-base-uncased__gin_3_300_863m__base_'

In [14]:
for i in range(nb_epochs):
    print('-----EPOCH {}-----'.format(i+1))
    model.train()
    for batch in train_loader:
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        
        x_graph, x_text = model(graph_batch.to(device), 
                                input_ids.to(device), 
                                attention_mask.to(device))
        current_loss = contrastive_loss(x_graph, x_text)   
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        loss += current_loss.item()
        
        count_iter += 1
        if count_iter % printEvery == 0:
            time2 = time.time()
            print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter,
                                                                        time2 - time1, loss/printEvery))
            losses.append(loss)
            loss = 0 

    model.eval()       
    val_loss = 0        
    for batch in val_loader:
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        x_graph, x_text = model(graph_batch.to(device), 
                                input_ids.to(device), 
                                attention_mask.to(device))
        current_loss = contrastive_loss(x_graph, x_text)   
        val_loss += current_loss.item()
    best_validation_loss = min(best_validation_loss, val_loss)

    print('-----EPOCH '+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )
    if best_validation_loss==val_loss:
        print('validation loss improved saving checkpoint...')
        save_path = os.path.join('./checkpoints', model_save_name+str(i)+'.pt')
        torch.save({
        'epoch': i,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'validation_accuracy': val_loss,
        'loss': loss,
        }, save_path)
        print('checkpoint saved to: {}'.format(save_path))

-----EPOCH 1-----
Iteration: 50, Time: 34.3453 s, training loss: 5.2524
Iteration: 100, Time: 60.0932 s, training loss: 4.2078
Iteration: 150, Time: 86.1017 s, training loss: 3.7228
Iteration: 200, Time: 112.5527 s, training loss: 3.3896
Iteration: 250, Time: 138.4410 s, training loss: 3.0198
Iteration: 300, Time: 164.3894 s, training loss: 2.8887
Iteration: 350, Time: 190.8463 s, training loss: 2.6419
Iteration: 400, Time: 217.2572 s, training loss: 2.4117
Iteration: 450, Time: 243.2813 s, training loss: 2.2711
Iteration: 500, Time: 269.5218 s, training loss: 2.1473
Iteration: 550, Time: 295.9180 s, training loss: 2.0163
Iteration: 600, Time: 322.2865 s, training loss: 1.8776
Iteration: 650, Time: 348.5860 s, training loss: 1.8327
Iteration: 700, Time: 375.1021 s, training loss: 1.8261
Iteration: 750, Time: 401.5453 s, training loss: 1.7298
Iteration: 800, Time: 427.9165 s, training loss: 1.6739
Iteration: 850, Time: 454.1497 s, training loss: 1.5123
Iteration: 900, Time: 480.2525 s, 

### Submission

In [4]:
#save_path = os.path.join('./checkpoints', 'model'+str(4)+'.pt')

print('loading best model...')
checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

graph_model = model.get_graph_encoder()
text_model = model.get_text_encoder()

test_cids_dataset = GraphDataset(root='./data/', gt=gt, split='test_cids')
test_text_dataset = TextDataset(file_path='./data/test_text.txt', tokenizer=tokenizer)

idx_to_cid = test_cids_dataset.get_idx_to_cid()

test_loader = DataLoader(test_cids_dataset, batch_size=batch_size_test, shuffle=False)

graph_embeddings = []
for batch in test_loader:
    for output in graph_model(batch.to(device)):
        graph_embeddings.append(output.tolist())

test_text_loader = TorchDataLoader(test_text_dataset, batch_size=batch_size_test, shuffle=False)
text_embeddings = []
for batch in test_text_loader:
    for output in text_model(batch['input_ids'].to(device), 
                             attention_mask=batch['attention_mask'].to(device)):
        text_embeddings.append(output.tolist())


from sklearn.metrics.pairwise import cosine_similarity

similarity = cosine_similarity(text_embeddings, graph_embeddings)

solution = pd.DataFrame(similarity)
solution['ID'] = solution.index
solution = solution[['ID'] + [col for col in solution.columns if col!='ID']]
solution.to_csv('submission.csv', index=False)

loading best model...


### Eval

In [5]:
for i in range(nb_epochs):
    save_path = os.path.join('./checkpoints', 'model'+str(i)+'.pt')
    checkpoint = torch.load(save_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()       
    val_loss = 0        
    for batch in val_loader:
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        x_graph, x_text = model(graph_batch.to(device), 
                                input_ids.to(device), 
                                attention_mask.to(device))
        current_loss = contrastive_loss(x_graph, x_text)   
        val_loss += current_loss.item()

    print('-----EPOCH '+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )

-----EPOCH 1----- done.  Validation loss:  0.9059125973262648
-----EPOCH 2----- done.  Validation loss:  0.5076332101664112
-----EPOCH 3----- done.  Validation loss:  0.3639905517425946
-----EPOCH 4----- done.  Validation loss:  0.32216463984858584
-----EPOCH 5----- done.  Validation loss:  0.29098984794151306
