In [3]:
import json, collections, os, random, glob, math, string, re, torch, pickle
from nltk import word_tokenize
# from tqdm import trange, tqdm_notebook as tqdm 
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import WEIGHTS_NAME, BertConfig, BertForQuestionAnswering, BertTokenizerFast, BasicTokenizer, AdamW, get_linear_schedule_with_warmup
from transformers.models.bert.tokenization_bert import whitespace_tokenize
from transformers.models.roberta.modeling_roberta import *
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
from transformers import RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer
from transformers import XLMRobertaConfig, XLMRobertaForQuestionAnswering, XLMRobertaTokenizer
from transformers import AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer

In [4]:
from transformers import AutoTokenizer
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Trand\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:

def tokenize_function(example, tokenizer):
    question_word = word_tokenize(example["question"])
    context_word = word_tokenize(example["context"])

    question_sub_words_ids = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(w)) for w in question_word]
    context_sub_words_ids = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(w)) for w in context_word]
    valid = True
    if len([j for i in question_sub_words_ids + context_sub_words_ids for j in
            i]) > tokenizer.max_len_single_sentence - 1:
        valid = False

    question_sub_words_ids = [[tokenizer.bos_token_id]] + question_sub_words_ids + [[tokenizer.eos_token_id]]
    context_sub_words_ids = context_sub_words_ids + [[tokenizer.eos_token_id]]

    input_ids = [j for i in question_sub_words_ids + context_sub_words_ids for j in i]
    if len(input_ids) > tokenizer.max_len_single_sentence + 2:
        valid = False

    words_lengths = [len(item) for item in question_sub_words_ids + context_sub_words_ids]

    return {
        "input_ids": input_ids,
        "words_lengths": words_lengths,
        "valid": valid
    }


def data_collator(samples, tokenizer, device='cpu'):
    if len(samples) == 0:
        return {}

    def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
        """Convert a list of 1d tensors into a padded 2d tensor."""
        size = max(v.size(0) for v in values)
        res = values[0].new(len(values), size).fill_(pad_idx)

        def copy_tensor(src, dst):
            assert dst.numel() == src.numel()
            if move_eos_to_beginning:
                assert src[-1] == eos_idx
                dst[0] = eos_idx
                dst[1:] = src[:-1]
            else:
                dst.copy_(src)

        for i, v in enumerate(values):
            copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
        return res

    input_ids = collate_tokens([torch.tensor(item['input_ids'], device=device) for item in samples], pad_idx=tokenizer.pad_token_id)
    attention_mask = torch.zeros_like(input_ids, device=device)
    for i in range(len(samples)):
        attention_mask[i][:len(samples[i]['input_ids'])] = 1
    words_lengths = collate_tokens([torch.tensor(item['words_lengths'], device=device) for item in samples], pad_idx=0)

    batch_samples = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'words_lengths': words_lengths,
    }

    return batch_samples


def extract_answer(inputs, outputs, tokenizer):
    plain_result = []
    for sample_input, start_logit, end_logit in zip(inputs, outputs.start_logits, outputs.end_logits):
        sample_words_length = sample_input['words_lengths']
        input_ids = sample_input['input_ids']
        # Get the most likely beginning of answer with the argmax of the score
        answer_start = sum(sample_words_length[:torch.argmax(start_logit)])
        # Get the most likely end of answer with the argmax of the score
        answer_end = sum(sample_words_length[:torch.argmax(end_logit) + 1])

        if answer_start <= answer_end:
            answer = tokenizer.convert_tokens_to_string(
                tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
            if answer == tokenizer.bos_token:
                answer = ''
        else:
            answer = ''

        score_start = torch.max(torch.softmax(start_logit, dim=-1)).cpu().detach().numpy().tolist()
        score_end = torch.max(torch.softmax(end_logit, dim=-1)).cpu().detach().numpy().tolist()
        plain_result.append({
            "answer": answer,
            "score_start": score_start,
            "score_end": score_end
        })
    return plain_result



In [5]:
class MRCQuestionAnswering(RobertaPreTrainedModel):
    config_class = RobertaConfig

    def _reorder_cache(self, past, beam_idx):
        pass

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
            self,
            input_ids=None,
            words_lengths=None,
            start_idx=None,
            end_idx=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            start_positions=None,
            end_positions=None,
            span_answer_ids=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        context_embedding = sequence_output

        # Compute align word sub_word matrix
        batch_size = input_ids.shape[0]
        max_sub_word = input_ids.shape[1]
        max_word = words_lengths.shape[1]
        align_matrix = torch.zeros((batch_size, max_word, max_sub_word))

        for i, sample_length in enumerate(words_lengths):
            for j in range(len(sample_length)):
                start_idx = torch.sum(sample_length[:j])
                align_matrix[i][j][start_idx: start_idx + sample_length[j]] = 1 if sample_length[j] > 0 else 0

        align_matrix = align_matrix.to(context_embedding.device)
        # Combine sub_word features to make word feature
        context_embedding_align = torch.bmm(align_matrix, context_embedding)

        logits = self.qa_outputs(context_embedding_align)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [6]:
model_checkpoint = "nguyenvulebinh/vi-mrc-large"
#model_checkpoint = "nguyenvulebinh/vi-mrc-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [7]:
# !git clone https://github.com/nguyenvulebinh/extractive-qa-mrc

* Train:

    Before processing

    - <b>Invalid</b> (lack of answer and not match in context): 80641

    - <b>Valid</b> (answer macth in context): 49678

    After processing

    - <b>Invalid</b> (lack of answer and not match in context): 27286

    - <b>Valid</b> (answer macth in context): 103033

* Test:

    Before processing

    - <b>Invalid</b> (lack of answer and not match in context): 8380

    - <b>Valid</b> (answer macth in context): 3493

    After processing

    - <b>Invalid</b> (lack of answer and not match in context): 2262
    
    - <b>Valid</b> (answer macth in context): 9611

In [112]:
input_file = r'D:\NLP_project\data\SQuA2.0\extractive-qa-mrc\test-preprocessed.json'
with open(input_file, "r", encoding='utf-8') as reader:
    source = json.load(reader)

In [113]:
context, question, answer = source.values()
invalid = 0
valid = 0
for txt, qus, ans in zip(context, question, answer):
    if ans['answer_start'][0] == -1:
        invalid += 1
    else:
        valid += 1
print(valid, invalid, len(context))

9611 2262 11873


In [103]:
11873-8380

3493

In [109]:
miss = 0
non_match = 0
#non_match_ans = {'context':[], 'question': [], 'answer':[]}
inputs = []
pbar = tqdm(source, total=len(source))
for entry in pbar:
    if entry[2] == '':
        miss += 1
    elif entry[0].lower().find(entry[2].lower()) == -1:
        non_match += 1
print(miss, non_match)

100%|██████████| 11873/11873 [00:00<00:00, 288631.13it/s]

4054 1183





In [102]:
miss + non_match

8380

In [11]:
'''with open("non_match_answer.json", "w") as outfile:
    json.dump(non_match_ans, outfile)
with open("inputs.pickle", "wb") as outfile:
    pickle.dump(inputs, outfile)
print(miss)
print(non_match)'''

'with open("non_match_answer.json", "w") as outfile:\n    json.dump(non_match_ans, outfile)\nwith open("inputs.pickle", "wb") as outfile:\n    pickle.dump(inputs, outfile)\nprint(miss)\nprint(non_match)'

In [12]:
'''with open("inputs.pickle", "rb") as file_to_read:
    inputs = pickle.load(file_to_read)
with open("non_match_answer.json", "r", encoding='utf-8') as reader:
    non_match_ans = json.load(reader)
print(len(inputs))'''

'with open("inputs.pickle", "rb") as file_to_read:\n    inputs = pickle.load(file_to_read)\nwith open("non_match_answer.json", "r", encoding=\'utf-8\') as reader:\n    non_match_ans = json.load(reader)\nprint(len(inputs))'

In [72]:
model = MRCQuestionAnswering.from_pretrained(model_checkpoint)

In [73]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model = model.eval()


In [15]:
# model_checkpoint = "nguyenvulebinh/vi-mrc-large"
# tokenizer_large = AutoTokenizer.from_pretrained(model_checkpoint)
# model_large = MRCQuestionAnswering.from_pretrained(model_checkpoint)

# model_checkpoint = "nguyenvulebinh/vi-mrc-base"
# tokenizer_base = AutoTokenizer.from_pretrained(model_checkpoint, token='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')
# model_base = MRCQuestionAnswering.from_pretrained(model_checkpoint, token='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')

In [16]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model_large = model_large.to(device)
# model_large = model_large.eval()

# model_base = model_base.to(device)
# model_base = model_base.eval()

In [74]:
non_answer_elements = 0
good_answer = 0
#context, question, answer = non_match_ans.values()
#min_sent = len(context[0].split('. '))
pbar = tqdm(range(len(source)), total=len(source))
# err_lst = {'context': [], 'question': [], 'answer': []}
for i in pbar:
    if source[i][2] != '' and source[i][0].lower().find(source[i][2].lower()) != -1:
        continue
    QA_input = {
        'context': source[i][0],
        'question': source[i][1]
    }
    plain_result = None
    # try:
    #     inputs = [tokenize_function(QA_input, tokenizer)]
    #     inputs_ids = data_collator(inputs, tokenizer, device)
    #     outputs = model(**inputs_ids)
    #     plain_result = extract_answer(inputs, outputs, tokenizer)[0]
    # except:
    sentences = source[i][0].split('. ')
    for sent in sentences:
        QA_input['context'] = sent
        try:
            inputs = [tokenize_function(QA_input, tokenizer)]
            inputs_ids = data_collator(inputs, tokenizer, device)
            outputs = model(**inputs_ids)
        except:
            continue
        plain_result = extract_answer(inputs, outputs, tokenizer)[0]
        if plain_result['answer'] != '':
            break
                
    if plain_result == None or plain_result['answer'] == '':
        non_answer_elements += 1
    else:
        good_answer += 1
        source[i][2] = plain_result['answer']
    mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
    s = ('Memory: %5s  ---  Good answers: %8d  ---  None answer: %8d') %(mem, good_answer, non_answer_elements)
    pbar.set_description(s)


Memory:  5.5G  ---  Good answers:     3672  ---  None answer:     4708: 100%|██████████| 11873/11873 [16:56<00:00, 11.68it/s] 


In [75]:
miss = 0
non_match = 0
pbar = tqdm(enumerate(source), total=len(source))
non_match_lst = []
for i, entry in pbar:
    if entry[2] == '':
        miss += 1
    elif entry[0].lower().find(entry[2].lower()) == -1:
        non_match += 1
        non_match_lst.append(i)
print(miss, non_match)

100%|██████████| 11873/11873 [00:00<00:00, 391926.55it/s]

4054 1183





In [76]:
n = 0
not_good_ans = []
for i in non_match_lst:
    txt = source[i][0]
    qus = source[i][1]
    ans = source[i][2]
    tmp = ans.split('</s>')[-1]
    tmp = ans.split('.')
    for i, term in enumerate(tmp):
        term = term.split(',')
        tmp[i] = ', '.join([t.strip() for t in term])
    process_ans = ''
    if len(tmp) == 1 and tmp[0][-1] == '%':
        process_ans = ''.join([t.strip() for t in tmp[0].split()])
    elif len(tmp) == 2 and tmp[-1] != '' and tmp[-1][-1] == '%':
        process_ans = tmp[0] + '.' + ''.join([t.strip() for t in tmp[-1].split()])
    else:
        process_ans = '. '.join([sent.strip() for sent in tmp])
        process_ans = ' (' .join([t.strip() for t in process_ans.split('(')])
        process_ans = ') ' .join([t.strip() for t in process_ans.split(')')])
    if source[i][0].lower().find(process_ans.lower().strip()) == -1:
        n += 1
        not_good_ans.append(i)
    else:
        source[i][2] = process_ans.lower().strip()
print(n)

1183


In [63]:
miss = 0
non_match = 0
pbar = tqdm(enumerate(source), total=len(source))
non_match_lst = []
for i, entry in pbar:
    if entry[2] == '':
        miss += 1
    elif entry[0].lower().find(entry[2].lower().strip()) == -1:
        non_match += 1
        non_match_lst.append(i)
print(miss, non_match)

100%|██████████| 130319/130319 [00:00<00:00, 443338.98it/s]

32114 13309





In [110]:
data = {'context':[], 'question':[], 'answers':[]}
for txt, qus, ans in source:
    idx = txt.find(ans)
    data['context'].append(txt)
    data['question'].append(qus)
    data['answers'].append({'text':[ans if idx != -1 else ''], 'answer_start': [idx]})

In [90]:
data

{'context': ['Beyoncé Giselle Knowles-Carter (/ b i gì ɒ n s eɪ / bee-YON-say) (sinh ngày 04 tháng 9 1981) là một ca sĩ, nhạc sĩ, nhà sản xuất thu âm và nữ diễn viên người Mỹ. Sinh ra và lớn lên ở Houston, Texas, cô đã biểu diễn trong các cuộc thi ca hát và nhảy múa khác nhau khi còn nhỏ, và nổi tiếng vào cuối những năm 1990 với tư cách là ca sĩ chính của nhóm nhạc nữ R & B Destiny\'s Child. Được quản lý bởi cha cô, Mathew Knowles, nhóm đã trở thành một trong những nhóm nhạc nữ bán chạy nhất thế giới mọi thời đại. Sự gián đoạn của họ đã chứng kiến việc phát hành album đầu tay của Beyoncé, Dangerously in Love (2003), giúp cô trở thành một nghệ sĩ solo trên toàn thế giới, giành được năm giải Grammy và có đĩa đơn quán quân Billboard Hot 100 "Crazy in Love" và "Baby Boy".',
  'Beyoncé Giselle Knowles-Carter (/ b i gì ɒ n s eɪ / bee-YON-say) (sinh ngày 04 tháng 9 1981) là một ca sĩ, nhạc sĩ, nhà sản xuất thu âm và nữ diễn viên người Mỹ. Sinh ra và lớn lên ở Houston, Texas, cô đã biểu diễn t

In [111]:
with open("test-preprocessed.json", "w") as outfile:
    json.dump(data, outfile)

In [12]:
QA_input = {
  'context': 'Vào ngày 6 tháng 2 năm 2016, một ngày trước buổi biểu diễn của cô tại Super Bowl, Beyoncé đã phát hành một đĩa đơn mới độc quyền trên dịch vụ nghe nhạc trực tuyến Tidal có tên "Formation".',
  'question': 'Beyonce phát hành bài hát "Formation" trên dịch vụ âm nhạc trực tuyến nào?'
}
try:
    inputs = [tokenize_function(QA_input, tokenizer)]
    inputs_ids = data_collator(inputs, tokenizer, device)
    outputs = model(**inputs_ids)
    ans = extract_answer(inputs, outputs, tokenizer)[0]
except:
    print('split')
    qlen = len(QA_input['context'])
    qus = QA_input['context'].split('. ')
    qlen = len(qus)
    qus = '. '.join(qus[:qlen//2])
    QA_input['context'] = qus
    inputs = [tokenize_function(QA_input, tokenizer)]
    inputs_ids = data_collator(inputs, tokenizer, device)
    outputs = model(**inputs_ids)
    ans = extract_answer(inputs, outputs, tokenizer)[0]
print(ans)

{'answer': 'Tidal', 'score_start': 0.9999996423721313, 'score_end': 0.9999996423721313}
