# BERT-based Query Expansion Playground

In [27]:
from transformers import BertForMaskedLM, BertTokenizer

The BERT large uncased model, in its _whole word masking_ variant, was pretrained on BookCorpus and English Wikipedia with two objectives: NSP (Next Sentence Prediction) and MLM (Masked Language Modeling). Let's import it together with its tokenizer:

In [28]:
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking', cache_dir = 'hf_cache')
unmasking_model = BertForMaskedLM.from_pretrained('bert-large-uncased-whole-word-masking', cache_dir = 'hf_cache')

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


We create a `FillMaskPipeline`, a convenience object that wraps the tokenizer and the model behind a single call.

In [3]:
from transformers import FillMaskPipeline

unmasker = FillMaskPipeline(model = unmasking_model, tokenizer = tokenizer, tokenizer_kwargs = {"truncation": True})

The following cell masks one token of the sentence and asks BERT for its top 50 candidate replacements. Each candidate carries the proposed token and the confidence score the model assigns to it.

In [13]:
from typing import List

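def mask_token(tokens: List[str], idx: int) -> str:
    """Return the sentence with the token at position `idx` replaced by '[MASK]'."""
    tokens = tokens[:]  # Copy so the caller's list is left untouched
    tokens[idx] = '[MASK]'
    # Merge WordPiece continuations ('##...') back into words; a plain ' '.join would leave them split
    return tokenizer.convert_tokens_to_string(tokens)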

original_sentence = 'modern shared room near Harvard.'
original_sentence_tokens = tokenizer.tokenize(original_sentence)

masked_sentence = mask_token(original_sentence_tokens, 2)
candidates = unmasker(masked_sentence, top_k = 50)
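
For orientation, the fill-mask pipeline returns one dictionary per candidate with the keys `score`, `token`, `token_str` and `sequence`. The sketch below shows one way to skim such a list. `summarize_candidates` is a hypothetical helper, and the sample entries (scores and token ids included) are made-up placeholders, not real model output.

```python
from typing import Dict, List, Tuple

def summarize_candidates(candidates: List[Dict], n: int = 3) -> List[Tuple[str, float]]:
    """Return the (token_str, score) pairs of the n highest-scoring candidates."""
    ranked = sorted(candidates, key=lambda c: c['score'], reverse=True)
    return [(c['token_str'], round(c['score'], 3)) for c in ranked[:n]]

# Made-up entries shaped like fill-mask output; scores and token ids are placeholders
sample_candidates = [
    {'score': 0.1, 'token': 0, 'token_str': 'house', 'sequence': 'modern shared house near harvard.'},
    {'score': 0.4, 'token': 0, 'token_str': 'room', 'sequence': 'modern shared room near harvard.'},
    {'score': 0.05, 'token': 0, 'token_str': 'hall', 'sequence': 'modern shared hall near harvard.'},
]

top_two = summarize_candidates(sample_candidates, n=2)
print(top_two)
```

With the real pipeline output, the same call would be `summarize_candidates(candidates)`.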

Not every candidate fits equally well. To estimate how suitable each one is, we put the candidate back into the sentence in place of the mask and measure the similarity between the resulting sentence and the original one.
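
The comparison boils down to two steps: average the token vectors of each sentence into a single vector, then take the cosine of the angle between the two averages. Here is a minimal pure-Python sketch on toy two-dimensional vectors. The helper names `mean_pool` and `cosine_similarity` are hypothetical, and the notebook itself performs both steps with `torch` on BERT's hidden states.

```python
import math
from typing import List

def mean_pool(token_vectors: List[List[float]]) -> List[float]:
    """Average a list of equally sized token vectors, dimension by dimension."""
    n = len(token_vectors)
    return [sum(column) / n for column in zip(*token_vectors)]

def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Toy "sentences" made of two token vectors each (illustrative values only)
original = mean_pool([[1.0, 0.0], [0.0, 1.0]])  # -> [0.5, 0.5]
variant = mean_pool([[1.0, 0.0], [1.0, 1.0]])   # -> [1.0, 0.5]

similarity = cosine_similarity(original, variant)
print(similarity)
```

A value close to 1 means the two averaged vectors point in nearly the same direction, which is the signal used to keep or discard a candidate.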

In [14]:
from transformers import BertModel
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir = 'hf_cache')
encoder = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True, cache_dir = 'hf_cache')

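def get_meaned_embeddings(sentence: str):
    """Return the mean of the last-layer token embeddings of `sentence`."""
    # Note: the ids are built by hand, so no [CLS]/[SEP] special tokens are added
    tokens = tokenizer.tokenize(sentence)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    input_ids = torch.tensor(input_ids).unsqueeze(0)  # Add the batch dimension
    with torch.no_grad():  # Inference only, no gradients needed
        outputs = encoder(input_ids)
        embedding = outputs.last_hidden_state[0]  # Shape: (num_tokens, hidden_size)

    return embedding.mean(dim = 0)  # Average over the tokens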

cos_sim = torch.nn.CosineSimilarity(dim = 0)

from operator import itemgetter
import pydash

original_sentence_embedding = get_meaned_embeddings(original_sentence)


similarities = (
    pydash.chain(candidates)
        .map(itemgetter('sequence'))  # Get complete sentence
        .map(get_meaned_embeddings)  # Get context vectors
        .map(lambda x: cos_sim(x, original_sentence_embedding))  # Compute the similarity
        .value()
)

THRESHOLD = 0.8  # Minimum cosine similarity a candidate must exceed
SLICE = 5  # Maximum number of expansions to keep

expansions = (
    pydash.chain(candidates)
        .map(itemgetter('token_str'))
        .zip(similarities)
        .filter(lambda t: t[1] > THRESHOLD)
        .sort(key = itemgetter(1), reverse = True)
        .map(itemgetter(0))
        .take(SLICE)
        .value()
)

print(expansions)

['room', 'house', 'residence', 'land', 'hall']
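
The same filter, sort and slice logic can also be written without `pydash`. The sketch below is an equivalent formulation under assumed names (`select_expansions` is hypothetical), run on made-up similarity values rather than real model output.

```python
from typing import List, Sequence

def select_expansions(tokens: Sequence[str], similarities: Sequence[float],
                      threshold: float = 0.8, limit: int = 5) -> List[str]:
    """Keep tokens whose similarity exceeds `threshold`, best first, at most `limit` of them."""
    # float() also accepts 0-dimensional tensors, so torch outputs can be passed directly
    scored = [(token, float(sim)) for token, sim in zip(tokens, similarities)
              if float(sim) > threshold]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [token for token, _ in scored[:limit]]

# Made-up values for illustration only
toy_tokens = ['hall', 'room', 'car', 'house']
toy_similarities = [0.85, 0.99, 0.41, 0.93]

toy_expansions = select_expansions(toy_tokens, toy_similarities)
print(toy_expansions)
```

In the notebook this would correspond to passing the `token_str` values of `candidates` together with `similarities`, `THRESHOLD` and `SLICE`.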
