# 4-1 Generation based MRC


https://github.com/huggingface/transformers/tree/master/examples/research_projects/longform-qa 참고

In [1]:
from transformers import BartConfig, BartTokenizer, BartForQuestionAnswering
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup

In [2]:
def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device) # BartForConditionalGeneration
    if from_file is not None:
        param_dict = torch.load(from_file)  # has model weights, optimizer, and scheduler states
        model.load_state_dict(param_dict["model"])
    return tokenizer, model

In [3]:
from transformers.data.metrics.squad_metrics import (
    compute_predictions_logits,
    squad_evaluate,
)

from transformers.data.processors.squad import SquadResult, SquadProcessor, squad_convert_examples_to_features

In [4]:
import config as cfg
from utils import load_and_cache_examples, set_seed

In [5]:
import functools
import os
from time import time
from tqdm.auto import tqdm, trange

import numpy as np
import math 
from random import choice, randint

import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

import apex
from apex import amp

# Preparation

## squad

In [6]:
import datasets
squad = datasets.load_dataset("squad")
squad_train = squad["train"]
squad_valid = squad["validation"]

Reusing dataset squad (/root/.cache/huggingface/datasets/squad/plain_text/1.0.0/4c81550d83a2ac7c7ce23783bd8ff36642800e6633c1f18417fb58c3ff50cdd7)


In [7]:
squad_train[46]

{'answers': {'answer_start': [565], 'text': ['1st overall']},
 'context': 'In 2015-2016, Notre Dame ranked 18th overall among "national universities" in the United States in U.S. News & World Report\'s Best Colleges 2016. In 2014, USA Today ranked Notre Dame 10th overall for American universities based on data from College Factual. Forbes.com\'s America\'s Best Colleges ranks Notre Dame 13th among colleges in the United States in 2015, 8th among Research Universities, and 1st in the Midwest. U.S. News & World Report also lists Notre Dame Law School as 22nd overall. BusinessWeek ranks Mendoza College of Business undergraduate school as 1st overall. It ranks the MBA program as 20th overall. The Philosophical Gourmet Report ranks Notre Dame\'s graduate philosophy program as 15th nationally, while ARCHITECT Magazine ranked the undergraduate architecture program as 12th nationally. Additionally, the study abroad program ranks sixth in highest participation percentage in the nation, with 57.

In [8]:
len(squad_train)

87599

In [9]:
# eli5 = datasets.load_dataset("eli5", name="LFQA_reddit")
# eli5_train = eli5["train_eli5"]
# eli5_train[1]

## load model, dataset

In [10]:
tokenizer, model = make_qa_s2s_model(model_name = "facebook/bart-large")
# "Primer/bart-squad2"


In [11]:
class SquadDatasetS2S(Dataset):
    def __init__(
        self, examples_array, tokenizer
    ):
        self.data = examples_array
        self.tokenizer = tokenizer
        # 모든 데이터가 답이 하나 뿐이라 qa_id_list 필요없음

    def __len__(self):
        return len(self.data) 

    def make_example(self, idx):
        example = self.data[idx]
        question = example["question"]
#         answer = example["answers"]["text"][0] + " " + self.tokenizer.eos_token
        answer = example["answers"]["text"][0]
        q_id = example["id"]

        document = example["context"]
        in_st = "question: {} context: {}".format(
            question.lower().replace(" --t--", "").strip(), document.lower().strip(),
        )
        out_st = answer
        return (in_st, out_st)

    def __getitem__(self, idx):
        return self.make_example(idx)

In [12]:
def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):
    q_ls = [q for q, a in qa_list]
    a_ls = [a for q, a in qa_list]
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, padding="max_length", truncation=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks["input_ids"]).to(device),
        torch.LongTensor(q_toks["attention_mask"]).to(device),
    )
    a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), padding="max_length", truncation=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks["input_ids"]).to(device),
        torch.LongTensor(a_toks["attention_mask"]).to(device),
    )
    lm_labels = a_ids[:, 1:].contiguous().clone()
    lm_labels[a_mask[:, 1:].contiguous() == 0] = -100
    model_inputs = {
        "input_ids": q_ids,
        "attention_mask": q_mask,
        "decoder_input_ids": a_ids[:, :-1].contiguous(),
        "labels": lm_labels,
    }
    return model_inputs

In [13]:
# mode = 'train'
# train_dataset = load_and_cache_examples(cfg, tokenizer, mode_or_filename=mode, output_examples=False)

In [14]:
train_dset = SquadDatasetS2S(squad_train, tokenizer)
valid_dset = SquadDatasetS2S(squad_valid, tokenizer)

In [15]:
len_tr = 10000
len_vl = 2000
train_dset = torch.utils.data.random_split(train_dset, [len_tr, len(train_dset) - len_tr])[0]
valid_dset = torch.utils.data.random_split(valid_dset, [len_vl, len(valid_dset) - len_vl])[0]

In [16]:
train_dset[1]

('question: what story was written by child in 1842? context: the figure of the "tragic octoroon" was a stock character of abolitionist literature: a mixed-race woman raised as if a white woman in her white father\'s household, until his bankruptcy or death has her reduced to a menial position she may even be unaware of her status before being reduced to victimization. the first character of this type was the heroine of lydia maria child\'s "the quadroons" (1842), a short story. this character allowed abolitionists to draw attention to the sexual exploitation in slavery and, unlike portrayals of the suffering of the field hands, did not allow slaveholders to retort that the sufferings of northern mill hands were no easier. the northern mill owner would not sell his own children into slavery.',
 '"The Quadroons"')

## test

In [17]:
# from transformers import BartModel
# test_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# test_model = BartModel.from_pretrained("facebook/bart-base")

# device = "cuda:0"
# qa_list = [train_dset[1], train_dset[2]]
# q_ls = [q for q, a in qa_list]
# a_ls = [a for q, a in qa_list]
# q_toks = test_tokenizer.batch_encode_plus(q_ls, max_length=64, padding="max_length", truncation=True)
# q_ids, q_mask = (
#     torch.LongTensor(q_toks["input_ids"]),
#     torch.LongTensor(q_toks["attention_mask"]),
# )

# a_toks = test_tokenizer.batch_encode_plus(a_ls, max_length=min(64, 360), padding="max_length", truncation=True)
# a_ids, a_mask = (
#     torch.LongTensor(a_toks["input_ids"]),
#     torch.LongTensor(a_toks["attention_mask"]),
# )

# # q_ids.shape

# # a_ids.shape # 0~64

# lm_labels = a_ids[:, 1:].contiguous().clone()

# # lm_labels.shape # 1번째부터 64번째까지

# # lm_labels

# lm_labels[a_mask[:, 1:].contiguous() == 0] = -100

# # lm_labels

# model_inputs = {
#     "input_ids": q_ids,
#     "attention_mask": q_mask,
#     "decoder_input_ids": a_ids[:, :-1].contiguous(),
#     "return_dict": True,
#     #     "labels": lm_labels,
# }

# outputs = test_model(**model_inputs)

# # 'last_hidden_state', 'past_key_values', 'decoder_hidden_states', 
# # 'decoder_attentions', 'cross_attentions', 'encoder_last_hidden_state',
# # 'encoder_hidden_states', 'encoder_attentions'

# # 3개만 나올땐 last_hidden_state past_key_values encoder_last_hidden_state 이렇게 3개인듯

# outputs.__dict__.keys()

# print(outputs[0].shape) # 출력 모양, (bs, seq_len, hidden_dim)
# print(test_model.shared.num_embeddings) # 단어 수 

# import torch.nn as nn
# lm_head = nn.Linear(outputs[0].shape[-1], test_model.shared.num_embeddings, bias=False)

# lm_logits = lm_head(outputs[0])
# print(lm_logits.shape) # 출력 시퀀스 내 각 포지션 별 로짓

# print(lm_labels.shape) # (bs, max_seq_len - 1)
# print(lm_labels.view(-1).shape) # (bs x max_seq_len -1)
# print(lm_logits.view(-1, 50265).shape) # (bs x max_seq_len -1, vocab_size)

# loss_fct = nn.CrossEntropyLoss()
# masked_lm_loss = loss_fct(lm_logits.view(-1, 50265), lm_labels.view(-1))
# print(masked_lm_loss)

# Training

## 함수꼴

In [18]:
def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False):
    model.train()
    # make iterator
    if curriculum:
        train_sampler = SequentialSampler(dataset)
    else:
        train_sampler = RandomSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batch_inputs in enumerate(epoch_iterator):
        loss = model(**batch_inputs)[0]
#         loss = pre_loss.sum() / pre_loss.shape[0] # 배치단위로 평균 내 줄 수 있으나, bartcondgen 은 필요없음
        
        # amp
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        # optimizer
        if step % args.backward_freq == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
        # some printing within the epoch
        loc_loss += loss.item()
        loc_steps += 1
        if step % args.print_freq == 0 or step == 1:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0

In [19]:
def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
    model.eval()
    # make iterator
    train_sampler = SequentialSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    with torch.no_grad():
        for step, batch_inputs in enumerate(epoch_iterator):
            loss = model(**batch_inputs)[0]
#             loss = pre_loss.sum() / pre_loss.shape[0]
            loc_loss += loss.item()
            loc_steps += 1
            if step % args.print_freq == 0:
                print(
                    "{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                        step,
                        len(dataset) // args.batch_size,
                        loc_loss / loc_steps,
                        time() - st_time,
                    )
                )
    print(
        "Total \t L: {:.3f} \t -- {:.3f}".format(
            loc_loss / loc_steps,
            time() - st_time,
        )
    )

In [20]:
def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
#     s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
    s2s_optimizer = apex.optimizers.FusedLAMB(qa_s2s_model.parameters(),
                                    lr = s2s_args.learning_rate,
                                    eps=1e-8,
                                    weight_decay=0.0,
                                    max_grad_norm=1.0)
    qa_s2s_model, s2s_optimizer = amp.initialize(qa_s2s_model, s2s_optimizer, opt_level="O1")
    s2s_scheduler = get_linear_schedule_with_warmup(
        s2s_optimizer,
        num_warmup_steps=400,
        num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size),
    )
    for e in range(s2s_args.num_epochs):
        train_qa_s2s_epoch(
            qa_s2s_model,
            s2s_train_dset,
            qa_s2s_tokenizer,
            s2s_optimizer,
            s2s_scheduler,
            s2s_args,
            e,
            curriculum=(e == 0),
        )
        m_save_dict = {
            "model": qa_s2s_model.state_dict(),
            "optimizer": s2s_optimizer.state_dict(),
            "scheduler": s2s_scheduler.state_dict(),
        }
        print("Saving model {}".format(s2s_args.model_save_name))
        eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args)
        torch.save(m_save_dict, "{}_{}.pth".format(s2s_args.model_save_name, e))

## 훈련

In [21]:
# training loop proper
class ArgumentsS2S():
    def __init__(self):
        self.batch_size = 4
        self.backward_freq = 16
        self.max_length = 512
        self.print_freq = 100
        self.model_save_name = "seq2seq_models/squad_bart_model1"
        self.learning_rate = 1e-4
        self.num_epochs = 2

s2s_args = ArgumentsS2S()

In [None]:
train_qa_s2s(model, tokenizer, train_dset, valid_dset, s2s_args)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
 0     0 of  2500 	 L: 9.341 	 -- 0.735




 0     1 of  2500 	 L: 11.416 	 -- 0.969
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 64.0
 0   100 of  2500 	 L: 9.972 	 -- 24.370
 0   200 of  2500 	 L: 8.324 	 -- 48.321
 0   300 of  2500 	 L: 5.689 	 -- 72.754
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32.0
 0   400 of  2500 	 L: 4.499 	 -- 97.859
 0   500 of  2500 	 L: 3.821 	 -- 122.450
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8.0
 0   600 of  2500 	 L: 3.201 	 -- 147.500
 0   700 of  2500 	 L: 2.778 	 -- 172.067
 0   800 of  2500 	 L: 2.214 	 -- 196.247
 0   900 of  2500 	 L: 1.943 	 -- 220.488
 0  1000 of  2500 	 L: 1.635 	 -- 245.155
 0  1100 of  2500 	 L: 1.539 	 -- 269.701
 0  1200 of  2500 	 L: 1.387 	 -- 294.863
 0  1300 of  2500 	 L: 1.409 	 -- 318.182
 0  1400 of  2500 	 L: 1.165 	 -- 344.592


In [None]:

# generate answer from input "question: ... context: <p> ..."
def qa_s2s_generate(
    question_doc,
    qa_s2s_model,
    qa_s2s_tokenizer,
    num_answers=1,
    num_beams=None,
    min_len=64,
    max_len=256,
    do_sample=False,
    temp=1.0,
    top_p=None,
    top_k=None,
    max_input_length=512,
    device="cuda:0",
):
    model_inputs = make_qa_s2s_batch(
        [(question_doc, "A")],
        qa_s2s_tokenizer,
        max_input_length,
        device=device,
    )
    n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
    generated_ids = qa_s2s_model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        min_length=min_len,
        max_length=max_len,
        do_sample=do_sample,
        early_stopping=True,
        num_beams=1 if do_sample else n_beams,
        temperature=temp,
        top_k=top_k,
        top_p=top_p,
        eos_token_id=qa_s2s_tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        num_return_sequences=num_answers,
        decoder_start_token_id=qa_s2s_tokenizer.bos_token_id,
    )
    return [qa_s2s_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in generated_ids]

In [None]:
# tokenizer, model = make_qa_s2s_model(from_file="seq2seq_models/squad_bart_model_.pth")

In [None]:

example = squad_valid[1634]

print(f"question = {example['question']}")
print(f"original answer = {example['answers']['text'][0]}")
question_document = "question: {} context: {}".format(example['question'], example['context'])
answer = qa_s2s_generate(question_document, model, tokenizer,
                         max_len = 20, top_p=0.95, top_k=30,
                         device="cuda:0"
                        )

print("="*50)

print(f"generated answer = {answer}")

In [None]:
pri
answer