<a href="https://colab.research.google.com/github/gaurav8901/MCQ-Generation-Project-1/blob/main/Sentence2MCQ_using_BERT_Word_Sense_Disambiguation_and_T5_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Author: **Ramsri Goutham Golla**  [Linkedin](https://www.linkedin.com/in/ramsrig/)   [Twitter](https://twitter.com/ramsri_goutham/)





The BERT Word Sense Disambiguation (WSD) code in this notebook is adapted from the [BERT WSD](https://github.com/BPYap/BERT-WSD) repository.

## Install dependencies and mount Google Drive

In [None]:
!pip install --quiet transformers==2.9.0
!pip install --quiet nltk==3.4.5

In [None]:
# Mount your personal Google Drive; the pretrained BERT WSD model is read from it
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


## Generate distractors (wrong choices) for MCQ options

In [None]:
import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet as wn

sentence1 = "Srivatsan loves to watch cricket during his free time"
sentence2 = "Srivatsan is annoyed by a cricket in his room"

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [None]:
# An example of a word with two different senses
original_word = "cricket"

syns = wn.synsets(original_word,'n')

for syn in syns:
  print (syn, ": ",syn.definition(),"\n" )

Synset('cricket.n.01') :  leaping insect; male makes chirping noises by rubbing the forewings together 

Synset('cricket.n.02') :  a game played with a ball and bat by two teams of 11 players; teams take turns trying to score runs 



In [None]:
# Distractors from WordNet: co-hyponyms (siblings) of the chosen sense
def get_distractors_wordnet(syn, word):
    distractors = []
    # WordNet lemma names are lower-case with underscores in place of spaces
    word = word.lower().replace(" ", "_")
    hypernym = syn.hypernyms()
    if len(hypernym) == 0:
        return distractors
    # Siblings of the sense: every hyponym of its first hypernym
    for item in hypernym[0].hyponyms():
        name = item.lemmas()[0].name()
        # Skip the answer itself
        if name.lower() == word:
            continue
        name = name.replace("_", " ")
        name = " ".join(w.capitalize() for w in name.split())
        if name not in distractors:
            distractors.append(name)
    return distractors


synset_to_use = wn.synsets(original_word,'n')[0]
distractors_calculated = get_distractors_wordnet(synset_to_use,original_word)

print ("\noriginal word: ",original_word.capitalize())
print (distractors_calculated)


original_word = "cricket"
synset_to_use = wn.synsets(original_word,'n')[1]
distractors_calculated = get_distractors_wordnet(synset_to_use,original_word)

print ("\noriginal word: ",original_word.capitalize())
print (distractors_calculated)


original word:  Cricket
['Grasshopper']

original word:  Cricket
['Ball Game', 'Field Hockey', 'Football', 'Hurling', 'Lacrosse', 'Polo', 'Pushball', 'Ultimate Frisbee']
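
The function above treats distractors as co-hyponyms: the siblings of the chosen sense under its first hypernym. The sketch below shows the same idea on a small hand-made taxonomy, so it runs without WordNet. `TOY_HYPERNYM` and `toy_distractors` are hypothetical names, and the entries are invented for illustration rather than taken from WordNet.

```python
# Toy taxonomy mapping child -> parent (hypernym).
# Entries are illustrative only and are NOT real WordNet data.
TOY_HYPERNYM = {
    "cricket_game": "field_game",
    "football": "field_game",
    "polo": "field_game",
    "cricket_insect": "orthopteran",
    "grasshopper": "orthopteran",
}


def toy_distractors(sense):
    """Return the siblings of `sense` under its parent, title-cased like the notebook output."""
    parent = TOY_HYPERNYM.get(sense)
    if parent is None:
        return []
    siblings = [
        child for child, p in TOY_HYPERNYM.items()
        if p == parent and child != sense
    ]
    return [" ".join(w.capitalize() for w in s.split("_")) for s in siblings]


print(toy_distractors("cricket_game"))
print(toy_distractors("cricket_insect"))
```

Because siblings share a parent concept, they tend to be plausible wrong answers, which is why the choice of sense matters so much for the quality of the options.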


## Download and extract the pretrained BERT WSD model

Download the pre-trained BERT WSD model from [here](https://entuedu-my.sharepoint.com/:f:/g/personal/boonpeng001_e_ntu_edu_sg/EiWzblOyyOBDtuO3klUbXoAB3THFzke-2MLWguIXrDopWg?e=08umXD).

Click the download button at the top left of the page to download a file named "bert_base-augmented-batch_size=128-lr=2e-5-max_gloss=6.zip".

Place the zip file in your Google Drive home folder ("My Drive"), which is where the next cell looks for it.

In [None]:
import os
import zipfile

bert_wsd_pytorch = "/content/gdrive/My Drive/bert_base-augmented-batch_size=128-lr=2e-5-max_gloss=6.zip"
extract_directory = "/content/gdrive/My Drive"

extracted_folder = bert_wsd_pytorch.replace(".zip","")

#  If unzipped folder exists don't unzip again.
if not os.path.isdir(extracted_folder):
  with zipfile.ZipFile(bert_wsd_pytorch, 'r') as zip_ref:
      zip_ref.extractall(extract_directory)
else:
  print (extracted_folder," is extracted already")

/content/gdrive/My Drive/bert_base-augmented-batch_size=128-lr=2e-5-max_gloss=6  is extracted already


## Find the correct sense (contextual meaning) of a given word in a sentence

In [None]:
import torch
import math
from transformers import BertModel, BertConfig, BertPreTrainedModel, BertTokenizer

class BertWSD(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        self.ranking_linear = torch.nn.Linear(config.hidden_size, 1)

        self.init_weights()


# def _forward(args, model, batch):
#     batch = tuple(t.to(args.device) for t in batch)
#     outputs = model.bert(input_ids=batch[0], attention_mask=batch[1], token_type_ids=batch[2])

#     return model.dropout(outputs[1])
    

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Must match the folder the zip file was extracted to in the previous section
model_dir = "/content/gdrive/My Drive/bert_base-augmented-batch_size=128-lr=2e-5-max_gloss=6"


model = BertWSD.from_pretrained(model_dir)
tokenizer = BertTokenizer.from_pretrained(model_dir)
# add new special token
if '[TGT]' not in tokenizer.additional_special_tokens:
    tokenizer.add_special_tokens({'additional_special_tokens': ['[TGT]']})
    assert '[TGT]' in tokenizer.additional_special_tokens
    model.resize_token_embeddings(len(tokenizer))
    
model.to(DEVICE)
model.eval()

In [None]:
import csv
import os
from collections import namedtuple

import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet as wn

import torch
from tqdm import tqdm

GlossSelectionRecord = namedtuple("GlossSelectionRecord", ["guid", "sentence", "sense_keys", "glosses", "targets"])
BertInput = namedtuple("BertInput", ["input_ids", "input_mask", "segment_ids", "label_id"])



def _create_features_from_records(records, max_seq_length, tokenizer, cls_token_at_end=False, pad_on_left=False,
                                  cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                  sequence_a_segment_id=0, sequence_b_segment_id=1,
                                  cls_token_segment_id=1, pad_token_segment_id=0,
                                  mask_padding_with_zero=True, disable_progress_bar=False):
    """ Convert records to list of features. Each feature is a list of sub-features where the first element is
        always the feature created from context-gloss pair while the rest of the elements are features created from
        context-example pairs (if available)
        `cls_token_at_end` define the location of the CLS token:
            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
        `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
    """
    features = []
    for record in tqdm(records, disable=disable_progress_bar):
        tokens_a = tokenizer.tokenize(record.sentence)

        sequences = [(gloss, 1 if i in record.targets else 0) for i, gloss in enumerate(record.glosses)]

        pairs = []
        for seq, label in sequences:
            tokens_b = tokenizer.tokenize(seq)

            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

            # The convention in BERT is:
            # (a) For sequence pairs:
            #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
            #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
            #
            # Where "type_ids" are used to indicate whether this is the first
            # sequence or the second sequence. The embedding vectors for `type=0` and
            # `type=1` were learned during pre-training and are added to the wordpiece
            # embedding vector (and position vector). This is not *strictly* necessary
            # since the [SEP] token unambiguously separates the sequences, but it makes
            # it easier for the model to learn the concept of sequences.
            #
            # For classification tasks, the first vector (corresponding to [CLS]) is
            # used as the "sentence vector". Note that this only makes sense because
            # the entire model is fine-tuned.
            tokens = tokens_a + [sep_token]
            segment_ids = [sequence_a_segment_id] * len(tokens)

            tokens += tokens_b + [sep_token]
            segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)

            if cls_token_at_end:
                tokens = tokens + [cls_token]
                segment_ids = segment_ids + [cls_token_segment_id]
            else:
                tokens = [cls_token] + tokens
                segment_ids = [cls_token_segment_id] + segment_ids

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding_length = max_seq_length - len(input_ids)
            if pad_on_left:
                input_ids = ([pad_token] * padding_length) + input_ids
                input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
                segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
            else:
                input_ids = input_ids + ([pad_token] * padding_length)
                input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
                segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            pairs.append(
                BertInput(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_id=label)
            )

        features.append(pairs)

    return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


![picture](https://drive.google.com/uc?export=view&id=1rVyHMMl0YoQrQO8aLD54prOIKIlmzwT0)
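
As a rough illustration of what `_create_features_from_records` builds for each context-gloss pair, the sketch below lays out the tokens, segment ids, and attention mask by hand. It assumes whitespace splitting as a stand-in for WordPiece tokenization and raises an error where the notebook would truncate; `build_pair` is a hypothetical helper, not part of the BERT WSD code. The default `cls_segment_id=1` mirrors the `cls_token_segment_id=1` argument the notebook passes.

```python
def build_pair(sentence, gloss, max_seq_length=24, cls_segment_id=1):
    """Lay out [CLS] A [SEP] B [SEP] and pad on the right to max_seq_length."""
    tokens_a = sentence.split()  # assumption: whitespace split instead of WordPiece
    tokens_b = gloss.split()

    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment 0 covers the sentence and its [SEP]; segment 1 covers the gloss and its [SEP]
    segment_ids = [cls_segment_id] + [0] * (len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
    input_mask = [1] * len(tokens)  # 1 marks real tokens

    padding = max_seq_length - len(tokens)
    if padding < 0:
        raise ValueError("pair exceeds max_seq_length; the notebook truncates instead")

    tokens += ["[PAD]"] * padding
    segment_ids += [0] * padding
    input_mask += [0] * padding  # 0 marks padding, which is not attended to
    return tokens, segment_ids, input_mask


tokens, segment_ids, input_mask = build_pair(
    "He saw a [TGT] cricket [TGT] in his room", "leaping insect"
)
for row in zip(tokens, segment_ids, input_mask):
    print(row)
```

One such pair is built for every candidate gloss of the target word, so a word with two noun senses yields two inputs that share the same sentence.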


In [None]:
import re
import torch
from tabulate import tabulate
from torch.nn.functional import softmax
from tqdm import tqdm
from transformers import BertTokenizer
import time


MAX_SEQ_LENGTH = 128

def get_sense(sent):
  re_result = re.search(r"\[TGT\](.*)\[TGT\]", sent)
  if re_result is None:
      raise ValueError("Incorrect input format: wrap the target word as '[TGT] word [TGT]'.")

  ambiguous_word = re_result.group(1).strip()

  results = dict()

  wn_pos = wn.NOUN
  for i, synset in enumerate(set(wn.synsets(ambiguous_word, pos=wn_pos))):
      results[synset] =  synset.definition()

  if len(results) ==0:
    return (None,None,ambiguous_word)

  # print (results)
  sense_keys=[]
  definitions=[]
  for sense_key, definition in results.items():
      sense_keys.append(sense_key)
      definitions.append(definition)


  record = GlossSelectionRecord("test", sent, sense_keys, definitions, [-1])

  features = _create_features_from_records([record], MAX_SEQ_LENGTH, tokenizer,
                                            cls_token=tokenizer.cls_token,
                                            sep_token=tokenizer.sep_token,
                                            cls_token_segment_id=1,
                                            pad_token_segment_id=0,
                                            disable_progress_bar=True)[0]

  with torch.no_grad():
      logits = torch.zeros(len(definitions), dtype=torch.double).to(DEVICE)
      # for i, bert_input in tqdm(list(enumerate(features)), desc="Progress"):
      for i, bert_input in list(enumerate(features)):
          logits[i] = model.ranking_linear(
              model.bert(
                  input_ids=torch.tensor(bert_input.input_ids, dtype=torch.long).unsqueeze(0).to(DEVICE),
                  attention_mask=torch.tensor(bert_input.input_mask, dtype=torch.long).unsqueeze(0).to(DEVICE),
                  token_type_ids=torch.tensor(bert_input.segment_ids, dtype=torch.long).unsqueeze(0).to(DEVICE)
              )[1]
          )
      scores = softmax(logits, dim=0)

      preds = (sorted(zip(sense_keys, definitions, scores), key=lambda x: x[-1], reverse=True))


  # print (preds)
  sense = preds[0][0]
  meaning = preds[0][1]
  return (sense,meaning,ambiguous_word)


sentence1 = "Srivatsan loves to watch **cricket** during his free time"


sentence_for_bert = sentence1.replace("**"," [TGT] ")
sentence_for_bert = " ".join(sentence_for_bert.split())
sense,meaning,answer = get_sense(sentence_for_bert)

print (sentence1)
print (sense)
print (meaning)

sentence2 = "Srivatsan is annoyed by a **cricket** in his room"
sentence_for_bert = sentence2.replace("**"," [TGT] ")
sentence_for_bert = " ".join(sentence_for_bert.split())
sense,meaning,answer = get_sense(sentence_for_bert)

print ("\n-------------------------------")
print (sentence2)
print (sense)
print (meaning)



Srivatsan loves to watch **cricket** during his free time
Synset('cricket.n.02')
a game played with a ball and bat by two teams of 11 players; teams take turns trying to score runs

-------------------------------
Srivatsan is annoyed by a **cricket** in his room
Synset('cricket.n.01')
leaping insect; male makes chirping noises by rubbing the forewings together
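
`get_sense` scores each gloss with one logit, applies a softmax across the glosses, and keeps the top-ranked sense. The sketch below isolates that ranking step in plain Python. The logits are made-up numbers, not model outputs, and `rank_glosses` is a hypothetical helper.

```python
import math


def rank_glosses(glosses, logits):
    """Softmax the per-gloss logits and return (gloss, probability) pairs, best first."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return sorted(zip(glosses, probs), key=lambda pair: pair[1], reverse=True)


glosses = ["leaping insect", "a game played with a ball and bat"]
made_up_logits = [-1.2, 3.4]  # hypothetical scores for illustration only

ranked = rank_glosses(glosses, made_up_logits)
print(ranked[0][0])
```

Since the softmax is monotonic, the top sense is the same as the one with the largest logit; the probabilities are mainly useful if you want a confidence threshold.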


## Generate a question using context and answer with T5

![picture](https://drive.google.com/uc?export=view&id=1Dc6W3F__okw1q6GxhKs46lvgeeBsP0iG)


In [None]:
from transformers import T5ForConditionalGeneration,T5Tokenizer

question_model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_squad_v1')
question_tokenizer = T5Tokenizer.from_pretrained('t5-base')

def get_question(sentence,answer):
  text = "context: {} answer: {} </s>".format(sentence,answer)
  print (text)
  max_len = 256
  encoding = question_tokenizer.encode_plus(text,max_length=max_len, pad_to_max_length=True, return_tensors="pt")

  input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

  outs = question_model.generate(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  early_stopping=True,
                                  num_beams=5,
                                  num_return_sequences=1,
                                  no_repeat_ngram_size=2,
                                  max_length=200)


  dec = [question_tokenizer.decode(ids) for ids in outs]


  Question = dec[0].replace("question:","")
  Question= Question.strip()
  return Question


sentence1 = "Srivatsan loves to watch **cricket** during his free time"
sentence2 = "Srivatsan is annoyed by a **cricket** in his room"


answer = "cricket"

sentence_for_T5 = sentence1.replace("**"," ")
sentence_for_T5 = " ".join(sentence_for_T5.split()) 
ques = get_question(sentence_for_T5,answer)
print (ques)


print ("\n**************************************\n")
sentence_for_T5 = sentence2.replace("**"," ")
sentence_for_T5 = " ".join(sentence_for_T5.split()) 
ques = get_question(sentence_for_T5,answer)
print (ques)


context: Srivatsan loves to watch cricket during his free time answer: cricket </s>
What sport does Srivatsan enjoy watching?

**************************************

context: Srivatsan is annoyed by a cricket in his room answer: cricket </s>
What insect is in Srivatsan's room?
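
The text handling around the T5 call can be checked without loading the model. The sketch below reproduces the prompt format printed above and the cleanup applied to the decoded string; `build_t5_input` and `clean_question` are hypothetical helper names, and the `question:` prefix is only assumed to appear because `get_question` strips it.

```python
def build_t5_input(marked_sentence, answer):
    """Remove the ** markers, normalise whitespace, and format the prompt as the notebook does."""
    plain = " ".join(marked_sentence.replace("**", " ").split())
    return "context: {} answer: {} </s>".format(plain, answer)


def clean_question(decoded):
    """Drop a 'question:' prefix, if present, and surrounding whitespace."""
    return decoded.replace("question:", "").strip()


prompt = build_t5_input("Srivatsan is annoyed by a **cricket** in his room", "cricket")
print(prompt)
print(clean_question("question: What insect is in Srivatsan's room?"))
```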


## Putting it all together

In [None]:
def getMCQs(sent):
  sentence_for_bert = sent.replace("**"," [TGT] ")
  sentence_for_bert = " ".join(sentence_for_bert.split())
  # try:
  sense,meaning,answer = get_sense(sentence_for_bert)
  if sense is not None:
    distractors = get_distractors_wordnet(sense,answer)
  else: 
    distractors = ["Word not found in Wordnet. So unable to extract distractors."]
  sentence_for_T5 = sent.replace("**"," ")
  sentence_for_T5 = " ".join(sentence_for_T5.split()) 
  ques = get_question(sentence_for_T5,answer)
  return ques,answer,distractors,meaning



print ("\n")
question,answer,distractors,meaning = getMCQs(sentence1)
print (question)
print (answer)
print (distractors)
print (meaning)

print ("\n")
question,answer,distractors,meaning = getMCQs(sentence2)
print (question)
print (answer)
print (distractors)
print (meaning)



context: Srivatsan loves to watch cricket during his free time answer: cricket </s>
What sport does Srivatsan enjoy watching?
cricket
['Ball Game', 'Field Hockey', 'Football', 'Hurling', 'Lacrosse', 'Polo', 'Pushball', 'Ultimate Frisbee']
a game played with a ball and bat by two teams of 11 players; teams take turns trying to score runs


context: Srivatsan is annoyed by a cricket in his room answer: cricket </s>
What insect is in Srivatsan's room?
cricket
['Grasshopper']
leaping insect; male makes chirping noises by rubbing the forewings together
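
`getMCQs` returns the question, the answer, and the distractors separately. One possible way to turn them into a printable multiple-choice question is sketched below; this step is not part of the original notebook, and `format_mcq`, `num_options`, and `seed` are hypothetical names chosen for illustration.

```python
import random
import string


def format_mcq(question, answer, distractors, num_options=4, seed=None):
    """Sample distractors, mix in the answer, and return (text, correct_label)."""
    rng = random.Random(seed)  # pass a seed for a reproducible option order
    wrong = rng.sample(distractors, min(num_options - 1, len(distractors)))
    correct = answer.capitalize()
    options = wrong + [correct]
    rng.shuffle(options)

    lines = [question]
    for label, option in zip(string.ascii_uppercase, options):
        lines.append("{}) {}".format(label, option))
    correct_label = string.ascii_uppercase[options.index(correct)]
    return "\n".join(lines), correct_label


text, label = format_mcq(
    "What sport does Srivatsan enjoy watching?",
    "cricket",
    ["Ball Game", "Field Hockey", "Football", "Hurling"],
    seed=0,
)
print(text)
print("Correct option:", label)
```

When WordNet offers fewer distractors than requested, as with the single 'Grasshopper' sibling above, the sketch simply produces a shorter list of options.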


**A few more examples with ambiguous words (words whose meaning depends on context)**

In [None]:
# More examples

sentence = "John went to river **bank** to cry"
# sentence = "John went to deposit money in the **bank**"

# sentence = "John bought a **mouse** for his computer."
# sentence = "John saw a **mouse** under his bed."


print ("\n")
question,answer,distractors,meaning = getMCQs(sentence)
print (question)
print (answer)
print (distractors)
print (meaning)



context: John went to river bank to cry answer: bank </s>
Where did John go to cry?
bank
['Ascent', 'Canyonside', 'Coast', 'Descent', 'Escarpment', 'Hillside', 'Mountainside', 'Piedmont', 'Ski Slope']
sloping land (especially the slope beside a body of water)
