In [5]:
import torch
import numpy as np
import pandas as pd

import fasttext as ft
from transformers import BertTokenizer, BertModel

In [56]:
# Run on the GPU when torch can see one; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# fastText model used for the sentence vectors of relation titles/predictions.
# NOTE(review): the path looks like a placeholder -- point it at the local file.
ft_model = ft.load_model('path/to/model/cc.en.300.bin')



# Read data

In [9]:
# Corpus sentences; the functions below read its rel_id, text and rel columns.
TEXT_REL = pd.read_csv('../data/text_rel.csv')

# Precomputed relation vectors, minus the 'Unnamed: 0' column stored in the CSV.
rel_vectors = pd.read_csv('../data/rel_vectors_trex.csv').drop(columns=['Unnamed: 0'])

# 'index' holds the key of each row; every column after the first is treated
# as a vector component (assumes 'index' is the first remaining column).
keys = rel_vectors['index'].values.tolist()
values = rel_vectors.iloc[:, 1:].astype(float).values.tolist()

# Lookup table: row key -> vector as a plain list of floats.
corpus_rel_vectors = dict(zip(keys, values))




# Define model and tokenizer

In [51]:
# Tokenizer and encoder are loaded from the same pretrained checkpoint name.
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Load the encoder and move its weights to the selected device in one step.
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)

# Define vectorization functions

In [52]:
def get_bert_vector(text, rel=None):
    """Embed `text` with BERT, optionally focusing on the word pieces of `rel`.

    The text is tokenized and run through `bert_model`; the hidden states of
    the first and last positions (presumably the [CLS]/[SEP] specials added
    by the tokenizer) are dropped before averaging.

    Args:
        text: Sentence to encode.
        rel: Optional relation surface form. Each of its word pieces is looked
            up in the tokenized text, first as-is and then with a '##'
            continuation prefix; the first occurrence of a piece is used.

    Returns:
        A 1-D numpy array: the mean hidden state over the matched `rel`
        pieces, or over all text tokens when `rel` is falsy or none of its
        pieces occurs in the text.
    """
    encoded_input = bert_tokenizer(text, return_tensors='pt').to(device)

    # Inference only: do not build an autograd graph for every call.
    with torch.no_grad():
        output = bert_model(**encoded_input)

    # squeeze(0) removes only the batch axis; [1:-1] drops first/last position.
    logits = output[0].squeeze(0)[1:-1]

    if rel:
        encoded_rel = bert_tokenizer(rel, return_tensors='pt')

        # .tolist() moves the ids to the host once instead of element by element.
        tokens = bert_tokenizer.convert_ids_to_tokens(
            encoded_input['input_ids'].squeeze(0).tolist()
        )[1:-1]
        rel_tokens = bert_tokenizer.convert_ids_to_tokens(
            encoded_rel['input_ids'].squeeze(0).tolist()[1:-1]
        )

        indices = []
        for rel_tok in rel_tokens:
            # Try the piece itself, then its '##'-prefixed continuation form.
            for candidate in (rel_tok, '##' + rel_tok):
                try:
                    indices.append(tokens.index(candidate))
                    break
                except ValueError:
                    # list.index only signals "not found" this way; try the next form.
                    continue

        if indices:
            return torch.mean(logits[indices, :], dim=0).detach().cpu().numpy()

    # No relation given, or none of its pieces found: average every text token.
    return torch.mean(logits, dim=0).detach().cpu().numpy()

In [53]:
def vectorize_rel_via_corpus(relation):
    """Build the vector of a relation from its title, corpus rows and description.

    Args:
        relation: Relation id, matched against `TEXT_REL.rel_id`.

    Returns:
        A 1-D numpy array made of, in order: the fastText sentence vector of
        `get_title(relation)`, the mean `get_bert_vector` over the relation's
        corpus rows, and the `get_bert_vector` of `get_description(relation)`.

    Raises:
        ValueError: if no corpus row of `relation` could be vectorized (no
            matching rows, or every row failed).
    """
    sub_df = TEXT_REL[TEXT_REL.rel_id == relation]

    corpus_vecs = []
    for row in sub_df.itertuples():
        try:
            corpus_vecs.append(get_bert_vector(row.text, row.rel))
        except Exception:
            # Best effort: a row that cannot be encoded is skipped. Catching
            # Exception (not everything) keeps Ctrl-C / SystemExit working.
            continue

    if not corpus_vecs:
        # Averaging an empty list would yield a 0-d NaN that np.concatenate
        # rejects with an unrelated-looking message; fail clearly instead.
        raise ValueError(
            f"no corpus sentence could be vectorized for relation {relation!r}"
        )

    corpus_vec = np.mean(np.array(corpus_vecs), axis=0)

    definition_vec = get_bert_vector(get_description(relation))
    fasttext_vec = ft_model.get_sentence_vector(get_title(relation))

    return np.concatenate((fasttext_vec, corpus_vec, definition_vec))

In [54]:
def vectorize_pred_rel(text, rel_pred):
    """Build the vector of a predicted relation mention found in `text`.

    Args:
        text: Sentence the relation was predicted from.
        rel_pred: Predicted relation surface form.

    Returns:
        A 1-D numpy array: the fastText sentence vector of `rel_pred` followed
        by the `get_bert_vector(text, rel_pred)` embedding repeated twice.
    """
    fasttext_vec = ft_model.get_sentence_vector(rel_pred)

    # The same embedding fills both BERT slots (presumably to line up with the
    # three-part layout returned by vectorize_rel_via_corpus), so compute it
    # once instead of paying for a second forward pass.
    bert_vec = get_bert_vector(text, rel_pred)

    return np.concatenate((fasttext_vec, bert_vec, bert_vec))