## WSD using BERT Masked Language Model
This notebook explores part of the idea proposed by Ajit Rajasekharan in his blog post
[Examining BERT's raw embeddings.](https://towardsdatascience.com/examining-berts-raw-embeddings-fd905cb22df7)

The idea is that examining the predictions of a masked language model for a masked ambiguous word can yield insights into the semantic meaning of the ambiguous word.

We use the HuggingFace BERT for Masked LM with weights from a bert-base-cased pre-trained model for our experiment.

We mask the ambiguous word (here we use bank as our test word) in sentences and then send them through a BERT MLM model. The output is an array of logits for each position of the input sequence: for a sentence with T tokens (including the special [CLS] and [SEP] tokens) and a vocabulary of size V, the MLM predictions have shape (1, T, V), where 1 is the batch size (we process one input sentence at a time).

To find the top k predictions, we apply a softmax to the logits at the masked position and keep the k highest probabilities.
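The softmax-then-top-k step can be sketched in plain Python on a single logit vector. The vocabulary and logit values below are made up for illustration; in the notebook the vector is `outputs.logits[0, mask_idx, :]` (V = 28996 for bert-base-cased) and `torch.nn.functional.softmax` plus `torch.topk` do the same work.

```python
import math

def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(tokens, logits, k):
    """Return the k most probable (token, probability) pairs, highest first."""
    probs = softmax(logits)
    ranked = sorted(zip(tokens, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Hypothetical 5-word vocabulary and the logits for one [MASK] position
vocab = ["Paris", "Lyon", "Toulouse", "dog", "the"]
mask_logits = [4.0, 2.5, 2.3, -1.0, 0.0]
print(top_k(vocab, mask_logits, 3))
```

Because softmax is monotonic, the top k tokens by probability are the same as the top k by raw logit; the softmax is only needed to report the scores as probabilities.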



## Prepare your environment

As always, we highly recommend that you install all packages with a virtual environment manager, like [venv](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/) or [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html), to prevent version conflicts of different packages.  

### Masked LM Model and Tokenizer 
[tutorial](https://huggingface.co/docs/transformers/tasks/language_modeling)  
The task is to predict masked words with BERT, so we use the BertForMaskedLM model and BertTokenizer with the pre-trained bert-base-cased weights.

In [237]:
import pandas as pd
import torch
from transformers import BertTokenizer, BertForMaskedLM

In [238]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertForMaskedLM.from_pretrained('bert-base-cased', return_dict=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


We are going to use the pre-trained BERT language model in inference mode only.

The tokenizer splits the input sequence into tokens and adds the special [CLS] token at the start and the [SEP] token at the end.

The model returns an output object with loss and logits fields; loss is only populated when labels are passed, so in inference mode we only use the logits. The logits have shape (1, number_of_tokens, vocab_size), where the leading 1 represents the single input sentence.

We will identify the logits corresponding to the position of our masked token, identify the top 5 vocabulary words predicted for that position, and return the softmax probabilities for each of the top 5 predicted words.
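Locating the masked position is a search over the input ids. A plain-Python sketch using the ids printed further down for the "capital of France" sentence, where 103 is the [MASK] id of bert-base-cased; the nested list is a made-up stand-in for a (1, T, V) logits tensor with V = 3, and `find_mask_index` is an illustrative helper, not a library function.

```python
MASK_ID = 103  # [MASK] id in the bert-base-cased vocabulary (tokenizer.mask_token_id)

def find_mask_index(input_ids, mask_id=MASK_ID):
    """Return the position of the single [MASK] token in one sequence."""
    positions = [i for i, tok in enumerate(input_ids) if tok == mask_id]
    if len(positions) != 1:
        raise ValueError(f"expected exactly one [MASK], found {len(positions)}")
    return positions[0]

# [CLS] The capital of France is [MASK] . [SEP]
input_ids = [101, 1109, 2364, 1104, 1699, 1110, 103, 119, 102]
T = len(input_ids)

# Toy logits with shape (1, T, V) for V = 3; the real tensor has V = 28996
logits = [[[float(t * 10 + v) for v in range(3)] for t in range(T)]]

mask_idx = find_mask_index(input_ids)
mask_row = logits[0][mask_idx]  # the V scores for the masked position
print(mask_idx, mask_row)
```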

In [239]:
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
outputs = model(**inputs)

In [240]:
tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

['[CLS]', 'The', 'capital', 'of', 'France', 'is', '[MASK]', '.', '[SEP]']

In [241]:
outputs.logits.shape

torch.Size([1, 9, 28996])

In [242]:
outputs.logits

tensor([[[ -7.1545,  -6.9931,  -7.1826,  ...,  -5.9124,  -5.6733,  -5.9854],
         [ -8.0190,  -8.1319,  -8.0509,  ...,  -6.5679,  -6.4058,  -6.8998],
         [ -4.9772,  -6.1781,  -6.0669,  ...,  -5.6362,  -4.6603,  -5.1241],
         ...,
         [ -3.4420,  -3.2557,  -3.5733,  ...,  -2.4606,  -2.6495,  -3.1952],
         [-10.5890, -10.4621, -11.7181,  ...,  -7.4646,  -9.9543,  -8.3927],
         [-14.8900, -14.8873, -14.4569,  ..., -11.6588, -13.0151, -11.6073]]],
       grad_fn=<ViewBackward0>)

In [243]:
outputs.logits[0][0][100]

tensor(0.5711, grad_fn=<SelectBackward0>)

In [244]:
inputs.input_ids

tensor([[ 101, 1109, 2364, 1104, 1699, 1110,  103,  119,  102]])

In [245]:
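def get_mask_index(input_ids, tokenizer):
    # Position of the [MASK] token in the first (and only) sequence of the batch
    x = input_ids[0]
    mask_positions = (x == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    return mask_positions.item()  # raises if there is not exactly one [MASK]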

mask_idx = get_mask_index(inputs.input_ids, tokenizer)
mask_idx

6

In [246]:
def get_top_k_predictions(pred_logits, mask_idx, top_k):
    probs = torch.nn.functional.softmax(pred_logits[0, mask_idx, :], dim=-1)
    top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
    top_k_pct_weights = [100 * x.item() for x in top_k_weights]
    top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_indices)
    return list(zip(top_k_tokens, top_k_pct_weights))


get_top_k_predictions(outputs.logits, mask_idx, 5)

[('Paris', 44.46821212768555),
 ('Lyon', 9.396013617515564),
 ('Toulouse', 8.234520256519318),
 ('Lille', 7.515154033899307),
 ('Marseille', 5.692289397120476)]

### WSD Test Sentences
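We take our pair of sentences for disambiguating the word bank, mask the word, and extract the top 20 predictions from the pre-trained BERT MLM model.

As expected, the two sets of predictions differ by sense. In the first, bank itself takes about 70% of the probability mass, and several of the remaining candidates are places or counters where a deposit could be made (office, register, store, counter). In the second, most predictions describe a river or its surroundings, either as standalone words (bank, basin, valley, pier, Thames) or as WordPiece continuations of river (##bank, ##boat, ##front, ##bed).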

In [247]:
sentences = [
  "Go to the [MASK] and deposit your pay check.",
  "Jim and Janet went down to the river [MASK] to admire the swans."
]

In [248]:
def get_predictions(sentence, tokenizer, model):
    inputs = tokenizer(sentence, return_tensors="pt")
    outputs = model(**inputs)
    mask_idx = get_mask_index(inputs.input_ids, tokenizer)
    top_preds = get_top_k_predictions(outputs.logits, mask_idx, 20)
    return top_preds

In [249]:
get_predictions(sentences[0], tokenizer, model)

[('bank', 70.31388878822327),
 ('office', 10.280612856149673),
 ('register', 1.745203323662281),
 ('store', 1.6284776851534843),
 ('bathroom', 0.9394790045917034),
 ('library', 0.893486849963665),
 ('desk', 0.8724372833967209),
 ('counter', 0.7977361790835857),
 ('hotel', 0.5163752939552069),
 ('lobby', 0.4956980235874653),
 ('kitchen', 0.3637084970250726),
 ('garage', 0.34799231216311455),
 ('door', 0.3412739373743534),
 ('car', 0.33113795798271894),
 ('house', 0.26490604504942894),
 ('airport', 0.25470347609370947),
 ('elevator', 0.24911430664360523),
 ('back', 0.24807732552289963),
 ('computer', 0.24019635748118162),
 ('banks', 0.23491408210247755)]

In [250]:
get_predictions(sentences[1], tokenizer, model)

[('##bank', 32.60203301906586),
 ('below', 13.032017648220062),
 ('bank', 11.940890550613403),
 (',', 5.626513808965683),
 ('##boat', 3.1638912856578827),
 ('##front', 2.733219973742962),
 ('basin', 1.6210557892918587),
 ('##bed', 1.2178429402410984),
 ('together', 1.1841758154332638),
 ('bed', 0.9657194837927818),
 ('again', 0.8369872346520424),
 ('deck', 0.8356154896318913),
 ('valley', 0.7271421141922474),
 ('mouth', 0.7227553520351648),
 ('boat', 0.7151072844862938),
 ('pier', 0.649328576400876),
 ('house', 0.6301597692072392),
 ('banks', 0.5700565874576569),
 ('pool', 0.5345731042325497),
 ('Thames', 0.4995553754270077)]

## Assignment
In this week's assignment, you will process SemCor data and feed it into the BERT masked LM. Then use the predictions to find the most likely sense of the target word using WordNet similarity.

### Data Preprocessing 
You can find a sample of the SemCor dataset [here](https://drive.google.com/file/d/1inmv3rUcGrtiS4VQwTMsT9HF-iL8jc5V/view?usp=sharing) and load the data with the following code.
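Each line of the file is one JSON object. A minimal loader sketch, run here on a hypothetical in-memory record instead of the real file: the field names sent, tokens and wnid come from the loading cell in this notebook, while the use of 0 for untagged tokens is an assumption inferred from the `!= 0` check applied to wn_id later on. The record itself is illustrative, not real SemCor data.

```python
import io
import json

# Hypothetical one-record sample in the assumed format
sample = io.StringIO(json.dumps({
    "sent": "implementation of georgia 's automobile title law",
    "tokens": ["implementation", "of", "georgia", "'s", "automobile", "title", "law"],
    "wnid": ["implementation%1:04:01::"] + [0] * 6,
}) + "\n")

def load_semcor_jsonl(f):
    """Read a JSON-lines file object into parallel lists of sentences, tokens and sense ids."""
    sents, tokens, wn_id = [], [], []
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines
        data = json.loads(line)
        sents.append(data["sent"])
        tokens.append(data["tokens"])
        wn_id.append(data["wnid"])
    return sents, tokens, wn_id

# Demo names avoid overwriting the notebook's sents/tokens/wn_id
demo_sents, demo_tokens, demo_wnid = load_semcor_jsonl(sample)
print(demo_tokens[0][0], demo_wnid[0][0])
```

With the real file, the same function can be called as `load_semcor_jsonl(open('semcor.sample.jsonl'))`.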

In [251]:
import json
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet as wn
sents = []
tokens = []
wn_id = []
lemmatizer = WordNetLemmatizer()

with open('semcor.sample.jsonl') as f:
    for line in f:
        data = json.loads(line)
        sents.append(data['sent'])
        tokens.append(data['tokens'])
        wn_id.append(data['wnid'])


In [252]:
import nltk
nltk.download('omw-1.4')

[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /home/nlplab/yhc/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


True

In [253]:
for index in range(len(sents)):
    if len(wn_id[index]) !=  len(tokens[index]): 
        print(index, len(wn_id[index]) , len(tokens[index]))

In [254]:
# The WordNet ID can be converted to NLTK Lemma using the following function
wn.lemma_from_key('implementation%1:04:01::')

Lemma('execution.n.06.implementation')

### TODO 
Please implement a method that converts the data to BERT masked-LM format and keeps track of the headword. Store the data in the following lists:

word[i] = 'implementation'  
ground_truth[i] = 'implementation%1:04:01::'  
sent[i] = "[MASK] of georgia 's automobile title law was also recommended by the outgoing jury ."  
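The conversion can be sketched as a pure function over one sentence. `make_mlm_instances` is a hypothetical helper, and the sense ids for every token except the first are marked untagged (0) here only for illustration, since the TODO gives just the headword's key.

```python
def make_mlm_instances(tokens, wnids, mask_token="[MASK]"):
    """Return one (headword, sense_key, masked_sentence) triple per sense-tagged token.

    Assumes wnids is parallel to tokens and uses 0 for untagged positions.
    """
    instances = []
    for pos, (tok, key) in enumerate(zip(tokens, wnids)):
        if key == 0:
            continue  # untagged token: nothing to disambiguate
        masked = tokens[:pos] + [mask_token] + tokens[pos + 1:]
        instances.append((tok, key, " ".join(masked)))
    return instances

toks = "implementation of georgia 's automobile title law was also recommended by the outgoing jury .".split()
keys = ["implementation%1:04:01::"] + [0] * (len(toks) - 1)
for word_i, key_i, sent_i in make_mlm_instances(toks, keys):
    print(word_i, key_i, sent_i, sep="\n")
```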



In [257]:
word = []           # headword (the token that gets masked)
ground_truth = []   # WordNet sense key of the headword
sent = []           # sentence with the headword replaced by [MASK]

for index in range(len(sents)):
    for index1 in range(len(tokens[index])):
        # 0 marks a token without a sense annotation; skip it
        if wn_id[index][index1] != 0:
            ground_truth.append(wn_id[index][index1])
            word.append(tokens[index][index1])

            # copy the token list, mask the headword, and rebuild the sentence
            tmp_tokens = tokens[index].copy()
            tmp_tokens[index1] = '[MASK]'
            modify_str = ' '.join(str(e) for e in tmp_tokens)
            sent.append(modify_str)

In [258]:
sent[200]

"the jury also commented on the fulton ordinary 's court which has been under fire for its [MASK] in the appointment of appraisers , guardians and administrators and the awarding of fees and compensation ."

In [259]:
ground_truth[0]

'say%2:32:00::'

In [260]:
print(len(sent))
print(len(word))
print(len(ground_truth))

1042
1042
1042


#### Identify the top 5 predictions other than the headword using Masked-LM 
1. Use get_predictions to get the predicted words
2. Use lemmatizer to lemmatize the prediction
3. Remove headword
4. Keep top 5 unique predictions
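The four steps above can be sketched as one filter function. `filter_candidates` is a hypothetical helper; the lemmatizer is passed in as a callable (in the notebook it would be `lemmatizer.lemmatize`), and an identity function stands in for it below. Unlike the notebook cell that builds candidate_lemmas, this sketch also enforces step 4's uniqueness, which the cell does not (election appears twice in one row of its output).

```python
import string

def filter_candidates(predictions, headword, lemmatize, k=5):
    """Keep the first k unique lemmas from (token, prob) predictions.

    Drops tokens with digits or punctuation (this also removes WordPiece
    pieces such as '##bank') and any lemma equal to the headword's lemma.
    """
    head_lemma = lemmatize(headword)
    kept = []
    for token, _prob in predictions:  # predictions are already sorted by probability
        if any(c.isdigit() for c in token) or any(c in string.punctuation for c in token):
            continue
        lemma = lemmatize(token)
        if lemma == head_lemma or lemma in kept:
            continue  # skip the headword itself and duplicates
        kept.append(lemma)
        if len(kept) == k:
            break
    return kept

# Top predictions for the first "bank" sentence, copied (rounded) from the output above
preds = [("bank", 70.31), ("office", 10.28), ("register", 1.75),
         ("store", 1.63), ("bathroom", 0.94), ("library", 0.89), ("desk", 0.87)]
print(filter_candidates(preds, "bank", lemmatize=lambda w: w))
```

Note that NLTK's `WordNetLemmatizer.lemmatize` treats words as nouns unless a `pos` argument is given, which is why verb forms such as found or took survive the headword check and has becomes ha in the notebook's output.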

In [261]:
get_predictions(sent[4], tokenizer, model)

[('2010', 13.912227749824524),
 ('2008', 7.134079933166504),
 ('2012', 4.418175294995308),
 ('2006', 4.327632486820221),
 ('2009', 3.2670795917510986),
 ('2004', 3.139011934399605),
 ('2007', 3.1135501340031624),
 ('2011', 2.980724535882473),
 ('2005', 2.7421828359365463),
 ('2016', 2.4403834715485573),
 ('2002', 2.3336390033364296),
 ('2013', 2.2248294204473495),
 ('2014', 2.1104879677295685),
 ('first', 1.676163636147976),
 ('November', 1.5382764860987663),
 ('2003', 1.4332271181046963),
 ('May', 1.2353728525340557),
 ('Democratic', 1.1448323726654053),
 ('2000', 1.0835643857717514),
 ('Republican', 1.0767881758511066)]

In [263]:
import re 
print(re.search(r'\d', 'dadacc3sd') )

<re.Match object; span=(6, 7), match='3'>


In [265]:
import string

candidate_lemmas = []

for index in range(len(sent)):
    prediction = get_predictions(sent[index], tokenizer, model)

    individual_lemmas = []
    # each prediction is a (token, probability) tuple, e.g. ('-', 1.1574707925319672);
    # only the token (the first element) is used
    for tmp_tokens, _prob in prediction:
        # skip tokens containing digits (years, numbers)
        if any(c.isdigit() for c in tmp_tokens):
            continue
        # skip tokens containing punctuation; this also drops WordPiece pieces like '##bank'
        if any(c in string.punctuation for c in tmp_tokens):
            continue
        # remove the headword: keep the prediction only if its lemma differs
        # (note: duplicate lemmas are not removed here)
        if lemmatizer.lemmatize(tmp_tokens) != lemmatizer.lemmatize(word[index]):
            individual_lemmas.append(lemmatizer.lemmatize(tmp_tokens))
        if len(individual_lemmas) == 5:
            break
    candidate_lemmas.append(individual_lemmas)

# next step: score the headword's senses with WordNet similarity (nltk.corpus.wordnet)

In [266]:
# candidate_lemmas is a list of lists: candidate_lemmas[i] holds the candidates for sent[i]
candidate_lemmas

[['found', 'and', 'in', 'concluded', 'told'],
 ['that', 'after', 'in', 'during', 'of'],
 ['analysis', 'examination', 'audit', 'evaluation', 'inspection'],
 ['California', 'Virginia', 'Florida', 'Arizona', 'Alabama'],
 ['first', 'November', 'May', 'Democratic', 'Republican'],
 ['municipal', 'general', 'recall', 'presidential', 'gubernatorial'],
 ['found', 'showed', 'uncovered', 'revealed', 'discovered'],
 ['indication', 'proof', 'sign', 'information', 'confirmation'],
 ['violence', 'election', 'voting', 'fraud', 'election'],
 ['take', 'taking', 'in', 'taken', 'did'],
 ['mayor', 'governor', 'president', 'council', 'city'],
 ['stated', 'declared', 'noted', 'concluded', 'reported'],
 ['the', 'it', 'their', 'his', 'her'],
 ['of', 'three', 'two', 'length', 'court'],
 ['celebration', 'term', 'celebration', 'remark', 'speech'],
 ['took', 'ha', 'take', 'hold', 'held'],
 ['knowledge', 'control', 'coverage', 'approval', 'account'],
 ['contest', 'matter', 'event', 'vote', 'city'],
 ['received', 'd

Example (the first bank sentence above, with the headword removed):  
candidate_lemmas = ['office', 'register', 'store', 'bathroom', 'library']


Identify the sense of the headword that is most similar to the 5 unique candidates.
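One way to pick the sense is to score each headword sense by summing its similarity to every sense of every candidate, then take the highest score. The sketch below is library-free: `pick_sense` is a hypothetical helper, the senses and similarity values are invented, and with NLTK the senses would come from `wn.synsets(...)` and the similarity from `wn.path_similarity`. It treats a `None` similarity (no connecting path) as 0 and keeps the earlier sense on ties, a design choice that may differ from the notebook's own loop.

```python
def pick_sense(head_senses, candidate_senses, similarity):
    """Return the head sense with the largest summed similarity to all candidate senses.

    similarity(a, b) may return None when no path exists; such pairs contribute 0.
    """
    best_sense, best_score = None, float("-inf")
    for sense in head_senses:
        score = 0.0
        for cand in candidate_senses:
            sim = similarity(cand, sense)
            if sim is not None:
                score += sim
        if score > best_score:  # strict '>' keeps the earlier sense on ties
            best_sense, best_score = sense, score
    return best_sense, best_score

# Hypothetical senses and similarity scores, for illustration only
toy_sim = {
    ("depository", "bank/finance"): 0.5, ("depository", "bank/river"): 0.1,
    ("shore", "bank/finance"): 0.1, ("shore", "bank/river"): 0.5,
    ("stream", "bank/river"): 0.33,
}
sense, score = pick_sense(
    ["bank/finance", "bank/river"],
    ["shore", "stream", "depository"],
    lambda a, b: toy_sim.get((a, b)),
)
print(sense, round(score, 2))
```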

In [88]:
print(word[0])
print(ground_truth[0])
print(wn.lemma_from_key(ground_truth[0]) )
# Lemma('state.v.01.say')
# lemma's key = say%2:32:00::

said
say%2:32:00::
Lemma('state.v.01.say')


In [60]:
wn.synsets(word[0])

[Synset('state.v.01'),
 Synset('allege.v.01'),
 Synset('suppose.v.01'),
 Synset('read.v.02'),
 Synset('order.v.01'),
 Synset('pronounce.v.01'),
 Synset('say.v.07'),
 Synset('say.v.08'),
 Synset('say.v.09'),
 Synset('say.v.10'),
 Synset('say.v.11'),
 Synset('aforesaid.s.01')]

In [87]:
print(f'word:  {word[0]}, ground_truth:  {ground_truth[0]}')

# map the sense key to its lemma, synset and gloss
gt_lemma = wn.lemma_from_key(ground_truth[0])
print(f'ground_truth to lemma:  {gt_lemma}')
print(f'ground_truth to synset:  {gt_lemma.synset()}')
print(f'ground_truth def:  {gt_lemma.synset().definition()}')

word:  said, ground_truth:  say%2:32:00::
ground_truth to lemma:  Lemma('state.v.01.say')
ground_truth to synset:  Synset('state.v.01')
ground_truth def:  express in words


In [64]:
for x in wn.synsets(word[0]):
    print(x.lemmas()[0].key())
    print(wn.lemma_from_key(x.lemmas()[0].key()) )

state%2:32:00::
Lemma('state.v.01.state')
allege%2:32:00::
Lemma('allege.v.01.allege')
suppose%2:32:00::
Lemma('suppose.v.01.suppose')
read%2:42:00::
Lemma('read.v.02.read')
order%2:32:01::
Lemma('order.v.01.order')
pronounce%2:32:01::
Lemma('pronounce.v.01.pronounce')
say%2:32:07::
Lemma('say.v.07.say')
say%2:32:15::
Lemma('say.v.08.say')
say%2:32:13::
Lemma('say.v.09.say')
say%2:32:08::
Lemma('say.v.10.say')
say%2:32:09::
Lemma('say.v.11.say')
aforesaid%5:00:00:same:02
Lemma('aforesaid.s.01.aforesaid')


In [34]:
wn.synset('found.n.01').definition()

'food and lodging provided in addition to money'

In [39]:
a = wn.lemma_from_key('say%2:32:00::')
print(a)
print(wn.synset('say.v.01').definition())

Lemma('state.v.01.say')
express in words


In [225]:
len(candidate_lemmas)

1042

In [226]:
ppredict = []  # predicted synset for each instance (None if it could not be scored)

for index in range(len(candidate_lemmas)):
    try:
        head_synsets = wn.synsets(word[index])
        # Score every synset of the headword by summing its path similarity
        # to every synset of every candidate lemma (a sum, not a mean)
        synset_scores = []
        for head_syn in head_synsets:
            accu_score = 0
            for cand in candidate_lemmas[index]:
                for cand_syn in wn.synsets(cand):
                    sim_score = wn.path_similarity(cand_syn, head_syn)
                    if sim_score is not None:  # None means no connecting path
                        accu_score += sim_score
            synset_scores.append(accu_score)

        # Pick the highest-scoring synset; '>=' means ties go to the later synset
        max_score = 0
        most_sim_word = None
        for k, score in enumerate(synset_scores):
            if score >= max_score:
                max_score = score
                most_sim_word = head_synsets[k]
        ppredict.append(most_sim_word)
    except Exception as e:
        # keep ppredict aligned with candidate_lemmas even when scoring fails
        print("Caught error: " + str(e))
        ppredict.append(None)

In [229]:
for x in range(len(word)):
    if not wn.synsets(word[x]):
        print(x, word[x], wn.synsets(word[x]))

247 jan. []
314 aug. []
365 jan. []
369 sept. []
404 sept. []
463 per []
581 and []
975 aug. []


In [230]:
# fall back to the first WordNet sense (usually the most frequent) when no sense was predicted
for x in range(len(ppredict)):
    if not ppredict[x] and wn.synsets(word[x]):
        ppredict[x] = wn.synsets(word[x])[0]

In [231]:
ground_err = 0   # instances whose ground-truth key cannot be mapped to a synset
correct = 0
total = len(ppredict)
for index in range(len(ppredict)):
    try:
        ground_truth_lemma = wn.lemma_from_key(ground_truth[index])
        ground_truth_synset = ground_truth_lemma.synset()

        if ppredict[index] == ground_truth_synset:
            correct += 1
    except Exception as e:
        # print("Caught error: " + str(e))
        ground_err += 1

In [232]:
print(f'correct: {correct}, total: {total}')
correct/total

correct: 305, total: 1042


0.2927063339731286

In [168]:
for i in range(len(ppredict)):
    try:
        ground_truth_lemma = wn.lemma_from_key(ground_truth[i])
        ground_truth_synset = ground_truth_lemma.synset()
        # predicted gloss ||| ground-truth gloss
        print(ppredict[i].definition(), '|||', ground_truth_synset.definition(), '\n')
    except Exception:
        print('ground_truth could not be converted')


being the one previously mentioned or spoken of ||| express in words 

the sixth day of the week; the fifth working day ||| the sixth day of the week; the fifth working day 

the work of inquiring into something thoroughly and systematically ||| an inquiry into unfamiliar or questionable activities 

a siege in which Federal troops under Sherman cut off the railroads supplying the city and then burned it; 1864 ||| state capital and largest city of Georgia; chief commercial center of the southeastern United States; was plundered and burned by Sherman's army during the American Civil War 

ground_truth could not be converted
of primary importance ||| a preliminary election where delegates or nominees are chosen 

come to have or undergo a change of (physical features and attributes) ||| bring out for display 

give evidence ||| your basis for belief or disbelief; knowledge on which to base belief 

irregular and infrequent or difficult evacuation of the bowels; can be a symptom of intestinal obstruction o

In [113]:

print(sent[3])
print(word[3])
print(ground_truth[3])
print(candidate_lemmas[3])

the fulton county grand jury said friday an investigation of [MASK] 's recent primary election produced " no evidence " that any irregularities took place .
atlanta
atlanta%1:15:00::
['California', 'Virginia', 'Florida', 'Arizona', 'Alabama']


In [109]:
ground_truth[4]

'recent%3:00:00:past:00'

For evaluation purposes, run the process for i = 50 and print out the following:  
1. word[50]
2. ground_truth[50] (as a synset or lemma)
3. sent[50]
4. candidate_lemmas[50]
5. predicted_sense (as a synset or lemma)    

Also, please print out the accuracy of the process over our dataset.
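The notebook reports accuracy two ways: correct/total, and correct/(total - ground_err), which excludes instances whose ground-truth key cannot be mapped to a synset. A library-free sketch of the second variant; `sense_accuracy` is a hypothetical helper, gold entries are `None` when the key could not be resolved, and the synset names in the example are placeholders.

```python
def sense_accuracy(predicted, gold):
    """Accuracy over instances whose gold sense could be resolved.

    gold[i] is None when the ground-truth key could not be mapped to a synset;
    those instances are excluded from the denominator.
    """
    scored = [(p, g) for p, g in zip(predicted, gold) if g is not None]
    if not scored:
        return 0.0
    correct = sum(1 for p, g in scored if p == g)
    return correct / len(scored)

# Placeholder synset names, for illustration only
pred = ["size.n.02", "x.n.01", "y.n.01", None]
gold = ["size.n.01", "x.n.01", None, "z.n.01"]
print(sense_accuracy(pred, gold))
```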

In [90]:
hit = wn.synset('hit.v.01')
slap = wn.synset('slap.v.01')
print(type(hit))
wn.path_similarity(hit, slap)

<class 'nltk.corpus.reader.wordnet.Synset'>


0.14285714285714285
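`path_similarity` scores two synsets as 1 / (d + 1), where d is the number of edges on the shortest path between them in the hypernym/hyponym hierarchy: identical synsets score 1, and the score shrinks as the path grows. The hit/slap score printed above therefore corresponds to d = 6. A small arithmetic check (the helper names are illustrative, not NLTK functions):

```python
def path_similarity_from_distance(d):
    """WordNet path similarity for a shortest path of d edges."""
    return 1.0 / (d + 1)

def distance_from_path_similarity(sim):
    """Invert the formula to recover the path length d from a similarity score."""
    return round(1.0 / sim) - 1

print(path_similarity_from_distance(6))
print(distance_from_path_similarity(0.14285714285714285))
```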

In [281]:
print(f'correct: {correct}, total: {total}')

# accuracy over instances whose ground-truth key could be mapped to a synset
score = correct/(total-ground_err)
print(f'acc: {score}')

correct: 305, total: 1042
acc: 0.31671858774662515


## TA's Note

Congratulations, you made it to the end of the tutorial! Make sure you make an appointment to show your work and turn in your finished assignment before next week's lesson. We will ask you to run your code, so double check that everything is working and that your model is saved. Don't worry if you didn't pass the evaluation requirements, you'll still get partial points for trying.

In [277]:
# word[50]
# ground_truth[50] (in synset or lemma)
# sent[50]
# candidate_lemmas
# predicted_sense (in synset or lemma)    

def eva(i):
    print(f'+ word[{i}]: {word[i]}\n')
    print(f'+ ground_truth[{i}]: {wn.lemma_from_key(ground_truth[i]).synset()}\n')
    print(f'+ sent[{i}]: {sent[i]}\n')
    print(f'+ candidate_lemmas[{i}]: {candidate_lemmas[i]}\n')
    print(f'+ predicted_sense[{i}]: {ppredict[i]}\n')

In [278]:
eva(50)

+ word[50]: size

+ ground_truth[50]: Synset('size.n.01')

+ sent[50]: " only a relative handful of such reports was received " , the jury said , " considering the widespread interest in the election , the number of voters and the [MASK] of this city " .

+ candidate_lemmas[50]: ['population', 'reputation', 'status', 'character', 'history']

+ predicted_sense[50]: Synset('size.n.02')

