In [1]:
from dataloader import GraphTextDataset, GraphDataset, TextDataset, AddRWStructEncoding
from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader as TorchDataLoader
from Model import Model
import numpy as np
from transformers import AutoTokenizer
import gensim
from nltk import word_tokenize
import torch
import os
import pandas as pd
import json

from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import label_ranking_average_precision_score

  from .autonotebook import tqdm as notebook_tqdm


# Model

In [2]:
# Load the experiment configuration and the graph-encoder configuration.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on the
# platform's default locale encoding (which differs e.g. on Windows).
with open('config.json', encoding='utf-8') as f:
    config = json.load(f)

with open('graph_config.json', encoding='utf-8') as f:
    graph_config = json.load(f)

# Use the GPU when present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = config['model_name']
model_type = config['model_type']  # compared against 'text' and 'w2v' below to pick the text pipeline
nout = config['nout']
nhid = config['nhid']
nb_epochs = config['nb_epochs']
batch_size_train = config['batch_size_train']
batch_size_test = config['batch_size_test']
learning_rate = config['learning_rate']
load_graph_pretrained = config['load_graph_pretrained']

walk_length = graph_config['walk_length']

# Hugging Face tokenizer is only needed for the 'text' model type.
if model_type=='text':
    tokenizer = AutoTokenizer.from_pretrained(model_name)
else:
    tokenizer = None
# word2vec pipeline: load vectors from '<model_name>.txt' and build an embedding
# matrix with one extra leading row that is left as zeros.
if model_type=='w2v':
    model_w2v = gensim.models.KeyedVectors.load_word2vec_format(model_name + '.txt')
    w2v_embeddings = np.zeros((len(model_w2v.vectors)+1, model_w2v.vectors.shape[1]), dtype=np.float32)
    # NOTE(review): row 0 stays zero (presumably padding/unknown) and vectors are shifted
    # by one row, while key_to_index is 0-based — presumably the dataloader adds 1; verify.
    w2v_embeddings[1:] = model_w2v.vectors
    nltk_tokenizer = word_tokenize
    word2idx = model_w2v.key_to_index
else:
    nltk_tokenizer = None
    word2idx = None
    w2v_embeddings = None
# allow_pickle=True unpickles the stored object (only load trusted files);
# [()] unwraps the 0-d object array to the stored Python object (presumably a dict).
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]

In [3]:
# Build the graph + text model on the selected device.
model = Model(model_name, nout, nhid, graph_config, load_graph_pretrained=load_graph_pretrained, 
              model_type=model_type, w2v_embeddings=w2v_embeddings).to(device)

# Parameter counts; graph_params is also embedded in the checkpoint name below.
total_params = sum(p.numel() for p in model.parameters())
graph_params = sum(p.numel() for p in model.graph_encoder.parameters())
text_params = sum(p.numel() for p in model.text_encoder.parameters())

g_m_n = graph_config['graph_model_name']
g_l = graph_config['graph_layers']
g_h_l = graph_config['graph_hidden_channels']
# Truthiness rather than len(): an empty string behaves as before, and a null value
# in config.json (None) no longer raises TypeError.
pretrained = 'pretrained' if load_graph_pretrained else ''

s_name = model_name.replace('/', '-')
# NOTE: graph_params // 1000 counts *thousands* of parameters even though it is
# suffixed 'm'; the format is kept because saved checkpoints/submissions use this name.
model_save_name = f'{model_type}_{s_name}__{g_m_n}_{g_l}_{g_h_l}_{graph_params//1000}m_{pretrained}__base2_'
model_save_name

'text_sentence-transformers-all-distilroberta-v1__gps_10_64_764m___base2_'

# Evaluate model

In [4]:
# Validation split: each graph paired with its caption. Graphs are passed through
# AddRWStructEncoding (presumably random-walk structural encodings). The loader is not
# shuffled, so embedding row i stays aligned with dataset item i during evaluation.
rw_encoding = AddRWStructEncoding(walk_length)
val_dataset = GraphTextDataset(
    root='./data/',
    gt=gt,
    split='val',
    tokenizer=tokenizer,
    nltk_tokenizer=nltk_tokenizer,
    word2idx=word2idx,
    graph_transform=rw_encoding,
)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size_test)

In [6]:
# Reload a saved checkpoint and score text -> graph retrieval on the validation split
# with label ranking average precision (LRAP). y_true is the identity matrix, so the
# only relevant graph for caption i is graph i.
best_epoch = 9  # epoch of the checkpoint to evaluate
save_path = os.path.join('./checkpoints', f'ep{best_epoch}{model_save_name}.pt')

print('loading best model...')
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
# (device falls back to CPU when CUDA is unavailable).
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

graph_chunks = []
text_chunks = []

with torch.no_grad():
    for batch in val_loader:
        # Take the tokenized text out of the batch; what remains is passed to the
        # model as the graph input.
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        x_graph, x_text = model(graph_batch.to(device),
                                input_ids=input_ids.to(device),
                                attention_mask=attention_mask.to(device))
        # Keep whole batches on the CPU and concatenate once, instead of turning
        # every row into a Python list.
        graph_chunks.append(x_graph.cpu())
        text_chunks.append(x_text.cpu())

# float64 reproduces exactly the values the per-row .tolist() conversion produced.
graph_embeddings = torch.cat(graph_chunks).double().numpy()
text_embeddings = torch.cat(text_chunks).double().numpy()

similarity = cosine_similarity(text_embeddings, graph_embeddings)
y_true = np.identity(len(val_dataset))
assert similarity.shape == y_true.shape, "embedding count does not match val_dataset size"
label_ranking_average_precision_score(y_true, similarity)

loading best model...


0.3653687012137659

# Submission

In [7]:
# Reload a saved checkpoint, embed the test graphs and test captions separately, and
# write the caption-by-graph cosine-similarity matrix as the submission CSV.
best_epoch = 9  # epoch of the checkpoint used for the submission
save_path = os.path.join('./checkpoints', f'ep{best_epoch}{model_save_name}.pt')

print('loading best model...')
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

graph_model = model.get_graph_encoder()
text_model = model.get_text_encoder()

test_cids_dataset = GraphDataset(root='./data/', gt=gt, split='test_cids', graph_transform=AddRWStructEncoding(walk_length))
test_text_dataset = TextDataset(file_path='./data/test_text.txt', tokenizer=tokenizer, nltk_tokenizer=nltk_tokenizer, word2idx=word2idx)

# NOTE(review): idx_to_cid is not used in this cell; the submission 'ID' column below
# is the positional row index of each test caption.
idx_to_cid = test_cids_dataset.get_idx_to_cid()

# No shuffling: row/column order of the similarity matrix follows dataset order.
test_loader = DataLoader(test_cids_dataset, batch_size=batch_size_test, shuffle=False)
test_text_loader = TorchDataLoader(test_text_dataset, batch_size=batch_size_test, shuffle=False)

with torch.no_grad():
    # Concatenate whole batches once; float64 reproduces the values .tolist() produced.
    graph_embeddings = torch.cat(
        [graph_model(batch.to(device)).cpu() for batch in test_loader]
    ).double().numpy()
    text_embeddings = torch.cat(
        [
            text_model(batch['input_ids'].to(device),
                       attention_mask=batch['attention_mask'].to(device),
                       sentences=None).cpu()
            for batch in test_text_loader
        ]
    ).double().numpy()

similarity = cosine_similarity(text_embeddings, graph_embeddings)

# One row per test caption, one column per test graph, with 'ID' as the first column.
solution = pd.DataFrame(similarity)
solution.insert(0, 'ID', solution.index)
# to_csv does not create missing directories.
os.makedirs('submissions', exist_ok=True)
solution.to_csv('submissions/' + model_save_name + '_submissiontest.csv', index=False)

loading best model...


Processing...
  return torch.LongTensor(edge_index).T, torch.FloatTensor(x)
  return torch.LongTensor(edge_index).T, torch.FloatTensor(x)
Done!
