In [8]:
from dataloader import GraphTextDataset, GraphDataset, TextDataset, AddRWStructEncoding
from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader as TorchDataLoader
from Model import Model
import numpy as np
from transformers import AutoTokenizer
import torch
from torch import optim
import time
import os
import pandas as pd

In [37]:
CE = torch.nn.CrossEntropyLoss()
def contrastive_loss(v1, v2):
  logits = torch.matmul(v1,torch.transpose(v2, 0, 1))
  labels = torch.arange(logits.shape[0], device=v1.device)
  return CE(logits, labels) + CE(torch.transpose(logits, 0, 1), labels)

model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

graph_model_name = 'base'

gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]
walk_length = 20
val_dataset = GraphTextDataset(root='./data/', gt=gt, split='val', tokenizer=tokenizer, graph_transform=AddRWStructEncoding(walk_length))
train_dataset = GraphTextDataset(root='./data/', gt=gt, split='train', tokenizer=tokenizer, graph_transform=AddRWStructEncoding(walk_length))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

nb_epochs = 5
batch_size_train = 16
batch_size_test = 16
learning_rate = 2e-5
graph_config = {}
graph_config['graph_layer'] = 3
graph_config['n_head'] = 4
graph_config['n_feedforward'] = 128
graph_config['input_dropout'] = 0.1
graph_config['dropout'] = 0.0
graph_config['attention_dropout'] = 0.25
graph_config['conv_type'] = 'Gated'
graph_config['walk_length'] = 20
graph_config['dim_se'] = 28

val_loader = DataLoader(val_dataset, batch_size=batch_size_test, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)

model = Model(model_name=model_name, graph_model_name=graph_model_name, num_node_features=300, nout=768, nhid=300, graph_hidden_channels=300, graph_config=graph_config) # nout = bert model hidden dim
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
                                betas=(0.9, 0.999),
                                weight_decay=0.01)

epoch = 0
loss = 0
losses = []
count_iter = 0
time1 = time.time()
printEvery = 50
best_validation_loss = 1000000

In [38]:
total_params = sum(p.numel() for p in model.parameters())
graph_params = sum(p.numel() for p in model.graph_encoder.parameters())
text_params = sum(p.numel() for p in model.text_encoder.parameters())

print(f'Total number of parameters: {total_params:,}')
print(f'    Graph encoder: {graph_params:,} parameters')
print(f'    Text encoder: {text_params:,} parameters')

Total number of parameters: 66,956,784
    Graph encoder: 593,904 parameters
    Text encoder: 66,362,880 parameters


### Train

In [39]:
graph_hidden_channels = 64
graph_layer = 10

model_save_name = f'{model_name}__{graph_model_name}_{graph_layer}_{graph_hidden_channels}_{graph_params//1000}m__base_'
model_save_name

'distilbert-base-uncased__base_10_64_593m__base_'

In [5]:
for i in range(nb_epochs):
    print('-----EPOCH {}-----'.format(i+1))
    model.train()
    for batch in train_loader:
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        
        x_graph, x_text = model(graph_batch.to(device), 
                                input_ids.to(device), 
                                attention_mask.to(device))
        current_loss = contrastive_loss(x_graph, x_text)   
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        loss += current_loss.item()
        
        count_iter += 1
        if count_iter % printEvery == 0:
            time2 = time.time()
            print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter,
                                                                        time2 - time1, loss/printEvery))
            losses.append(loss)
            loss = 0 

    model.eval()       
    val_loss = 0        
    for batch in val_loader:
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        with torch.no_grad():
            x_graph, x_text = model(graph_batch.to(device), 
                                    input_ids.to(device), 
                                    attention_mask.to(device))
            current_loss = contrastive_loss(x_graph, x_text)   
            val_loss += current_loss.item()
    best_validation_loss = min(best_validation_loss, val_loss)

    print('-----EPOCH '+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )
    if best_validation_loss==val_loss:
        print('validation loss improved saving checkpoint...')
        save_path = os.path.join('./checkpoints', model_save_name+str(i)+'.pt')
        torch.save({
        'epoch': i,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'validation_accuracy': val_loss,
        'loss': loss,
        }, save_path)
        print('checkpoint saved to: {}'.format(save_path))

-----EPOCH 1-----
Iteration: 50, Time: 57.2708 s, training loss: 6.1487
Iteration: 100, Time: 86.3908 s, training loss: 4.9513
Iteration: 150, Time: 115.5573 s, training loss: 4.2516
Iteration: 200, Time: 145.7405 s, training loss: 3.7652
Iteration: 250, Time: 176.1655 s, training loss: 3.4676
Iteration: 300, Time: 205.8538 s, training loss: 2.9618
Iteration: 350, Time: 235.9578 s, training loss: 2.6791
Iteration: 400, Time: 265.8205 s, training loss: 2.5051
Iteration: 450, Time: 295.6485 s, training loss: 2.4506
Iteration: 500, Time: 325.2785 s, training loss: 2.2394
Iteration: 550, Time: 355.1813 s, training loss: 2.3337
Iteration: 600, Time: 385.0626 s, training loss: 1.9586
Iteration: 650, Time: 415.1148 s, training loss: 1.8303
Iteration: 700, Time: 445.0983 s, training loss: 1.8304
Iteration: 750, Time: 474.8970 s, training loss: 1.7427
Iteration: 800, Time: 504.6938 s, training loss: 1.6192
Iteration: 850, Time: 534.5232 s, training loss: 1.4412
Iteration: 900, Time: 564.4980 s,

  return torch._native_multi_head_attention(


-----EPOCH 1----- done.  Validation loss:  0.46229666404450903
validation loss improved saving checkpoint...
checkpoint saved to: ./checkpoints/distilbert-base-uncased__gps_10_64_563m__base_0.pt
-----EPOCH 2-----
Iteration: 1700, Time: 1100.5266 s, training loss: 0.8687
Iteration: 1750, Time: 1131.7850 s, training loss: 0.8996
Iteration: 1800, Time: 1161.5148 s, training loss: 0.8910
Iteration: 1850, Time: 1191.4225 s, training loss: 0.8263
Iteration: 1900, Time: 1221.3418 s, training loss: 0.8947
Iteration: 1950, Time: 1252.5566 s, training loss: 0.7691
Iteration: 2000, Time: 1282.4204 s, training loss: 0.7597
Iteration: 2050, Time: 1312.0023 s, training loss: 0.7370
Iteration: 2100, Time: 1348.1531 s, training loss: 0.7069
Iteration: 2150, Time: 1382.8558 s, training loss: 0.8399
Iteration: 2200, Time: 1414.8659 s, training loss: 0.8374
Iteration: 2250, Time: 1447.6664 s, training loss: 0.7106
Iteration: 2300, Time: 1480.3001 s, training loss: 0.7179
Iteration: 2350, Time: 1510.9931 

KeyboardInterrupt: 

### Submission

In [4]:
#save_path = os.path.join('./checkpoints', 'model'+str(4)+'.pt')

print('loading best model...')
checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

graph_model = model.get_graph_encoder()
text_model = model.get_text_encoder()

test_cids_dataset = GraphDataset(root='./data/', gt=gt, split='test_cids')
test_text_dataset = TextDataset(file_path='./data/test_text.txt', tokenizer=tokenizer)

idx_to_cid = test_cids_dataset.get_idx_to_cid()

test_loader = DataLoader(test_cids_dataset, batch_size=batch_size_test, shuffle=False)

graph_embeddings = []
for batch in test_loader:
    for output in graph_model(batch.to(device)):
        graph_embeddings.append(output.tolist())

test_text_loader = TorchDataLoader(test_text_dataset, batch_size=batch_size_test, shuffle=False)
text_embeddings = []
for batch in test_text_loader:
    for output in text_model(batch['input_ids'].to(device), 
                             attention_mask=batch['attention_mask'].to(device)):
        text_embeddings.append(output.tolist())


from sklearn.metrics.pairwise import cosine_similarity

similarity = cosine_similarity(text_embeddings, graph_embeddings)

solution = pd.DataFrame(similarity)
solution['ID'] = solution.index
solution = solution[['ID'] + [col for col in solution.columns if col!='ID']]
solution.to_csv('submission.csv', index=False)

loading best model...


### Eval

In [43]:
model_save_name = 'model'

In [44]:
for i in range(nb_epochs):
    save_path = os.path.join('./checkpoints', model_save_name+str(i)+'.pt')
    checkpoint = torch.load(save_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()       
    val_loss = 0        
    for batch in val_loader:
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        with torch.no_grad():
            x_graph, x_text = model(graph_batch.to(device), 
                                    input_ids.to(device), 
                                    attention_mask.to(device))
            current_loss = contrastive_loss(x_graph, x_text)   
            val_loss += current_loss.item()

    print('-----EPOCH '+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )

RuntimeError: Error(s) in loading state_dict for Model:
	Missing key(s) in state_dict: "graph_encoder.conv_layers.0.bias", "graph_encoder.conv_layers.0.lin.weight", "graph_encoder.conv_layers.1.bias", "graph_encoder.conv_layers.1.lin.weight", "graph_encoder.conv_layers.2.bias", "graph_encoder.conv_layers.2.lin.weight". 
	Unexpected key(s) in state_dict: "graph_encoder.conv1.bias", "graph_encoder.conv1.lin.weight", "graph_encoder.conv2.bias", "graph_encoder.conv2.lin.weight", "graph_encoder.conv3.bias", "graph_encoder.conv3.lin.weight". 

In [19]:
input_ids = batch.input_ids
batch.pop('input_ids')
attention_mask = batch.attention_mask
batch.pop('attention_mask')
graph_batch = batch
with torch.no_grad():
    x_graph, x_text = model(graph_batch.to(device), 
                                    input_ids.to(device), 
                                    attention_mask.to(device))
    print(x_graph.shape)
    print(x_text.shape)
    current_loss = contrastive_loss(x_graph, x_text)   
    val_loss += current_loss.item()

torch.Size([8, 768])
torch.Size([8, 768])
torch.Size([8, 8])
torch.Size([8])
