In [38]:
from transformers import (AdamW, get_linear_schedule_with_warmup, AutoModelForMaskedLM, AutoConfig, AutoTokenizer)
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import json

In [2]:
# load pretrained-model
model_name = "bert-base-uncased"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# tokenizer test
input_text = "[CLS] She goes school by [HL] bus [/HL] [SEP]"
tokenized_text = tokenizer.tokenize(input_text,add_special_tokens=False)
print(tokenized_text)

['[CLS]', 'she', 'goes', 'school', 'by', '[', 'h', '##l', ']', 'bus', '[', '/', 'h', '##l', ']', '[SEP]']


In [5]:
# add [HL], [/HL] token 
added_token_num = tokenizer.add_special_tokens({"additional_special_tokens":["[HL]","[/HL]"]})
tokenized_text = tokenizer.tokenize(input_text,add_special_tokens=False)
print(tokenized_text)

# add token number
print(model.get_input_embeddings())
model.resize_token_embeddings(tokenizer.vocab_size + added_token_num)
print(model.get_input_embeddings()) # 2개 증가 된 것을 확인 가능함

['[CLS]', 'she', 'goes', 'school', 'by', '[HL]', 'bus', '[/HL]', '[SEP]']
Embedding(30522, 768, padding_idx=0)
Embedding(30524, 768)


In [6]:
# load data
file_path = "./data/squad_nqg/test.json"
with open(file_path, 'r') as file:
    data = json.load(file)

In [41]:
# sampling data
print(data[0])

input_text = data[0]['context']
target_text = data[0]['question']

# input_text = "[CLS] Jane's favorite food is [HL] chicken [/HL] [SEP]"
# target_text = "What is Jane's favorite food"

{'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?', 'answers': [{'answer_start': 515, 'text': 'Saint Bernadette Soubirous'}]}


In [56]:
# add [HL], [/HL]

cls_token = tokenizer.cls_token
sep_token = tokenizer.sep_token
mask_token = tokenizer.mask_token


answer_start = data[0]['answers'][0]['answer_start']
answer = data[0]['answers'][0]['text']

add_hl = f"{cls_token} {data[0]['context'][:answer_start]} [HL] {data[0]['context'][answer_start:answer_start+len(answer)]} [/HL] {data[0]['context'][answer_start+len(answer):]} {sep_token}"
print(add_hl)

[CLS] Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to  [HL] Saint Bernadette Soubirous [/HL]  in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary. [SEP]


In [34]:
cls_token = tokenizer.cls_token
sep_token = tokenizer.sep_token
mask_token = tokenizer.mask_token

example_pair = dict()


for d in data:
  input_text = d['context']
  target_text = d['question']
  
  tokenized_target = tokenizer.tokenize(target_text) # tokenize question
for i in range(0,len(tokenized_target)+1):
      # tokenized
      tokenized_text = tokenizer.tokenize(input_text)
      tokenized_text.extend(tokenized_target[:i]) # tokenized_context + tokenized_question[:i]
      tokenized_text.append('[MASK]') # tokenized_context + tokenized_question[:i] + [MASK]
      indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
      tokens_tensor = torch.tensor([indexed_tokens]).to(device)
    
      # loss
      loss_ids = [-1] * (len(tokenizer.convert_tokens_to_ids(tokenized_text))-1)

      if i == len(tokenized_target):
          loss_ids.append(tokenizer.convert_tokens_to_ids(tokenizer.tokenize('[SEP]'))[0])
      else:
          loss_ids.append(tokenizer.convert_tokens_to_ids(tokenized_target[i]))

      loss_tensors = torch.tensor([loss_ids]).to(device)
  
      example_pair[tokens_tensor] = loss_tensors
      # print(tokenized_text,loss_ids,loss_ids[-1])
      # print(len(indexed_tokens),len(loss_tensors))

['[CLS]', 'jane', "'", 's', 'favorite', 'food', 'is', '[HL]', 'chicken', '[/HL]', '[SEP]', '[MASK]'] [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2054] 2054
12 1
['[CLS]', 'jane', "'", 's', 'favorite', 'food', 'is', '[HL]', 'chicken', '[/HL]', '[SEP]', 'what', '[MASK]'] [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2003] 2003
13 1
['[CLS]', 'jane', "'", 's', 'favorite', 'food', 'is', '[HL]', 'chicken', '[/HL]', '[SEP]', 'what', 'is', '[MASK]'] [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 4869] 4869
14 1
['[CLS]', 'jane', "'", 's', 'favorite', 'food', 'is', '[HL]', 'chicken', '[/HL]', '[SEP]', 'what', 'is', 'jane', '[MASK]'] [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1005] 1005
15 1
['[CLS]', 'jane', "'", 's', 'favorite', 'food', 'is', '[HL]', 'chicken', '[/HL]', '[SEP]', 'what', 'is', 'jane', "'", '[MASK]'] [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1055] 1055
16 1
['[CLS]', 'jane', "'", 's', 'favorite', 'food', 'is', '[HL]', 'chicken', '[/HL]