In [None]:
import faiss
import numpy as np
import torch

In [None]:
from transformers import BertTokenizer
# Tokenizer used later to map token ids back to their word-piece strings.
bert_vocab_name = "bert-base-uncased"
tz = BertTokenizer.from_pretrained(bert_vocab_name)

In [None]:
# Standard-library module needed for the validation-text pickle below.
import pickle

# Name of the method variant; it selects both the checkpoint folder and the embeds folder.
method='max_vocab'

# Load the checkpoint onto the CPU and extract the BERT word-embedding matrix as a NumPy array.
checkpoint_path = 'pretrained_berts_{}/mf+mlm/best/model.pt'.format(method)
model_sd = torch.load(checkpoint_path, map_location='cpu')
word_emb = model_sd['sd']['bert.embeddings.word_embeddings.weight'].cpu().numpy()

# Load the three facet vector arrays saved for this method.
view_1, view_2, view_3 = (
    np.load(template.format(method))
    for template in (
        'embeds/{}/view_1.npy',
        'embeds/{}/view_2.npy',
        'embeds/{}/view_3.npy',
    )
)

# Load the raw text of the validation set.
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files here.
with open('embeds/raw_val.pkl', 'rb') as handle:
    raw_text = pickle.load(handle)

# Number of validation examples; used as the bound when sampling an example id.
total_examples = len(raw_text)

### For each facet, find the nearest tokens in the full BERT vocabulary

In [None]:
# Use the faiss library for the nearest-neighbor search.
# Cosine similarity is obtained as the inner product of L2-normalized vectors, see
# https://github.com/facebookresearch/faiss/issues/95#issuecomment-714562162

# Build a flat inner-product index over the word embeddings.
# The dimensionality is read from the embedding matrix instead of being hard-coded
# to 768, so the cell keeps working for checkpoints with a different hidden size.
index = faiss.index_factory(word_emb.shape[1], "Flat", faiss.METRIC_INNER_PRODUCT)
faiss.normalize_L2(word_emb)  # normalizes the rows in place
index.add(word_emb)

# Normalize every facet array in place as well, so that query/index inner
# products are cosine similarities.
for facet_array in (view_1, view_2, view_3):
    faiss.normalize_L2(facet_array)

In [None]:
def get_neighbor(query,text_id,n):
    """Return the tokens of the `n` nearest vocabulary entries for one facet vector.

    Parameters
    ----------
    query : array-like
        Facet vectors indexed by example id; row `text_id` is the search vector.
        (Assumed to be a 2-D float32 array that was already L2-normalized -
        confirm against the caller.)
    text_id : int
        Index of the example whose vector is looked up.
    n : int
        Number of neighbors to retrieve.

    Returns
    -------
    list of str
        The neighbor tokens, in the order returned by ``index.search``.
    """
    # The index searches batches of vectors, so wrap the single row as a (1, dim) batch.
    q = np.expand_dims(query[text_id], axis=0)
    # search() returns (distances, neighbor ids); only the ids, i.e. token indices, are needed.
    _, neighbor_ids = index.search(q, n)
    # token index -> word: one batched call, handing the tokenizer plain Python ints
    # rather than NumPy scalars.
    return tz.convert_ids_to_tokens(neighbor_ids[0].tolist())

In [None]:
# Show the top-n nearest vocabulary tokens of each facet for one randomly chosen example.

import random

n=5  # number of neighbors printed per facet
choose_id = random.randint(0, total_examples-1)  # randint includes both end points
print('Query:', raw_text[choose_id])
print('\n')

# Same report for every facet: title line, neighbor tokens, blank separator.
for facet_title, facet_vectors in (
    ('Facet 1', view_1),
    ('Facet 2', view_2),
    ('Facet 3', view_3),
):
    print(facet_title)
    print(get_neighbor(facet_vectors, choose_id, n))
    print('\n')