<a href="https://colab.research.google.com/github/huang624/NaturalLanguageUnderstanding-Fine_tuning_BERT_for_QuestionAnswering_on_SQuAD_Dataset/blob/main/BERT_for_QuestionAnswering_SQuAD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Goal 1**:
### Train a question-answering model on the SQuAD dataset with BERT-base-uncased


# Install packages

In [None]:
pip install transformers datasets accelerate

Collecting transformers
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
Collecting datasets
  Downloading datasets-2.0.0-py3-none-any.whl (325 kB)
Collecting accelerate
  Downloading accelerate-0.6.2-py3-none-any.whl (65 kB)
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)

# Check the GPU allocation

In [None]:
!nvidia-smi

Wed Apr 13 16:11:56 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P8    26W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


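## cd into your Colab folder on Google Drive

Adjust the paths in this notebook to your own Drive layout; after drive.mount('/content/gdrive'), the contents of My Drive normally appear under /content/gdrive/MyDrive/.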

In [None]:
%cd /content/gdrive/"BERT_SQuAD"  

# Data preprocessing

## Copy the dataset from Google Drive

In [None]:
!cp /content/gdrive/"SQuAD 1.1"/train-v1.1.json .
!cp /content/gdrive/"SQuAD 1.1"/dev-v1.1.json .

In [None]:
!ls

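### SQuAD dataset
The Stanford Question Answering Dataset (SQuAD) is a reading-comprehension dataset compiled at Stanford University.
It contains more than 100,000 context-question-answer (CQA) examples collected from Wikipedia articles, where each answer is a span of its context paragraph.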


For more information, see the paper: https://arxiv.org/abs/1606.05250

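### Data format

- version : <String> : dataset version ("1.1")
- data : <Array> : one entry per Wikipedia article
  - title : <String> : article title
  - paragraphs : <Array>
    - context : <String> : paragraph text
    - qas : <Array> : questions about this paragraph
      - question : <String> : question text
      - id : <String> : unique question ID
      - answers : <Array>
        - answer_start : <int> : character offset of the answer in the context
        - text : <String> : answer text, a span copied from the context

In the training set each question has exactly one answer; in the dev set each question lists several reference answers, which often repeat the same span.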
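To make the nesting concrete, here is a made-up dictionary in the same layout (the article, text, and IDs are invented, not taken from SQuAD), plus a small helper, count_items (a hypothetical name, not part of any library), that walks the hierarchy in the same order as read_data below.

```python
# Made-up example in the SQuAD 1.1 layout (not real SQuAD data)
squad_like = {
    "version": "1.1",
    "data": [
        {
            "title": "Example_Article",
            "paragraphs": [
                {
                    "context": "BERT was released by Google in 2018.",
                    "qas": [
                        {
                            "question": "Who released BERT?",
                            "id": "q1",
                            "answers": [{"answer_start": 21, "text": "Google"}],
                        },
                        {
                            "question": "When was BERT released?",
                            "id": "q2",
                            "answers": [{"answer_start": 31, "text": "2018"}],
                        },
                    ],
                }
            ],
        }
    ],
}


def count_items(squad):
    """Count (articles, paragraphs, questions, answers) in a SQuAD-style dict."""
    n_articles = len(squad["data"])
    n_paragraphs = n_questions = n_answers = 0
    for article in squad["data"]:
        for paragraph in article["paragraphs"]:
            n_paragraphs += 1
            for qa in paragraph["qas"]:
                n_questions += 1
                n_answers += len(qa["answers"])
    return n_articles, n_paragraphs, n_questions, n_answers


print(count_items(squad_like))  # (1, 1, 2, 2)
```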


In [None]:
import json
from pprint import pprint
with open('dev-v1.1.json') as file:
  train_data = json.load(file)

for ele in train_data['data']:
  pprint(ele['paragraphs'][0])
  break

{'context': 'Super Bowl 50 was an American football game to determine the '
            'champion of the National Football League (NFL) for the 2015 '
            'season. The American Football Conference (AFC) champion Denver '
            'Broncos defeated the National Football Conference (NFC) champion '
            'Carolina Panthers 24–10 to earn their third Super Bowl title. The '
            "game was played on February 7, 2016, at Levi's Stadium in the San "
            'Francisco Bay Area at Santa Clara, California. As this was the '
            '50th Super Bowl, the league emphasized the "golden anniversary" '
            'with various gold-themed initiatives, as well as temporarily '
            'suspending the tradition of naming each Super Bowl game with '
            'Roman numerals (under which the game would have been known as '
            '"Super Bowl L"), so that the logo could prominently feature the '
            'Arabic numerals 50.',
 'qas': [{'answers': [{'answe

## Read the data

In [None]:
import json
from pathlib import Path

def read_data(path, limit=None):
    """Flatten a SQuAD JSON file into parallel lists of contexts, questions, and answers.

    One example is produced per annotated answer; `limit` is checked after each question.
    """
    path = Path(path)
    with open(path, 'rb') as f:
        data_dict = json.load(f)

    contexts = []
    questions = []
    answers = []
    for group in data_dict['data']:
        for passage in group['paragraphs']:
            context = passage['context']
            for qa in passage['qas']:
                question = qa['question']
                for answer in qa['answers']:
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer)
                if limit is not None and len(contexts) > limit:
                    return contexts, questions, answers

    return contexts, questions, answers

In [None]:
train_contexts, train_questions, train_answers = read_data('train-v1.1.json', 8000)
eval_contexts, eval_questions, eval_answers = read_data('dev-v1.1.json',1500)

In [None]:
print(len(train_contexts))
print(len(eval_contexts))

8001
1502
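The counts are 8001 and 1502 rather than 8000 and 1500 because read_data creates one example per annotated answer and only checks `limit` after finishing a question, so it stops at the first question that pushes the total past the limit. Since each dev question carries several reference answers, the same (context, question) pair also appears several times in the eval lists. The sketch below reproduces that loop on a made-up dictionary; `flatten` and `toy` are hypothetical names used only for illustration.

```python
def flatten(data_dict, limit=None):
    """Flatten SQuAD-style data into parallel lists, mirroring read_data's loop."""
    contexts, questions, answers = [], [], []
    for group in data_dict["data"]:
        for passage in group["paragraphs"]:
            for qa in passage["qas"]:
                # One example per annotated answer
                for answer in qa["answers"]:
                    contexts.append(passage["context"])
                    questions.append(qa["question"])
                    answers.append(answer)
                # The limit is only checked after a whole question has been added
                if limit is not None and len(contexts) > limit:
                    return contexts, questions, answers
    return contexts, questions, answers


# Toy data: 4 questions, each with 3 reference answers
toy = {"data": [{"paragraphs": [{
    "context": "ctx",
    "qas": [{"question": f"q{i}", "answers": [{"answer_start": 0, "text": "c"}] * 3}
            for i in range(4)],
}]}]}

contexts, questions, answers = flatten(toy, limit=4)
print(len(contexts))  # 6: reading stops only once the count exceeds 4
print(questions)      # ['q0', 'q0', 'q0', 'q1', 'q1', 'q1']
```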


In [None]:
print(train_contexts[0])
print(train_questions[0])
print(train_answers[0])

Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
{'answer_start': 515, 'text': 'Saint Bernadette Soubirous'}


## Add the answer end position

In [None]:
def add_end_idx(answers):
    for answer in answers:
        gold_text = answer['text']
        start_idx = answer['answer_start']
        # Exclusive end character index: context[start_idx:end_idx] == gold_text
        end_idx = start_idx + len(gold_text)
        answer.update({'answer_end': end_idx})

In [None]:
add_end_idx(train_answers)
add_end_idx(eval_answers)

In [None]:
print(train_answers[0])

{'answer_start': 515, 'text': 'Saint Bernadette Soubirous', 'answer_end': 541}
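Here 541 = 515 + 26, the length of 'Saint Bernadette Soubirous', so answer_end is an exclusive index: the answer occupies characters 515 through 540. A quick check on a short, made-up context shows what that means for Python slicing:

```python
# Made-up context (not from SQuAD) illustrating that answer_end is exclusive
context = "The grotto is a replica of the one at Lourdes, France."
answer = {"text": "Lourdes, France", "answer_start": context.find("Lourdes, France")}

# Same computation as add_end_idx
answer["answer_end"] = answer["answer_start"] + len(answer["text"])

# Slicing with the exclusive end recovers the answer text exactly
print(context[answer["answer_start"]:answer["answer_end"]])  # Lourdes, France
# The last answer character sits at answer_end - 1, which is why
# add_token_positions looks up answer_end - 1 rather than answer_end
print(context[answer["answer_end"] - 1])  # e
```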


# Tokenize the data
## Convert the inputs into token IDs, token type IDs, and attention masks

In [None]:
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
eval_encodings = tokenizer(eval_contexts, eval_questions, truncation=True, padding=True)


## Check that the conversion is correct

In [None]:
train_encodings.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [None]:
print(train_encodings['input_ids'][0])
print(tokenizer.convert_ids_to_tokens(train_encodings['input_ids'][0]))
print(tokenizer.decode(train_encodings['input_ids'][0]))

print(train_encodings['token_type_ids'][0])
print(train_encodings['attention_mask'][0])

[101, 6549, 2135, 1010, 1996, 2082, 2038, 1037, 3234, 2839, 1012, 10234, 1996, 2364, 2311, 1005, 1055, 2751, 8514, 2003, 1037, 3585, 6231, 1997, 1996, 6261, 2984, 1012, 3202, 1999, 2392, 1997, 1996, 2364, 2311, 1998, 5307, 2009, 1010, 2003, 1037, 6967, 6231, 1997, 4828, 2007, 2608, 2039, 14995, 6924, 2007, 1996, 5722, 1000, 2310, 3490, 2618, 4748, 2033, 18168, 5267, 1000, 1012, 2279, 2000, 1996, 2364, 2311, 2003, 1996, 13546, 1997, 1996, 6730, 2540, 1012, 3202, 2369, 1996, 13546, 2003, 1996, 24665, 23052, 1010, 1037, 14042, 2173, 1997, 7083, 1998, 9185, 1012, 2009, 2003, 1037, 15059, 1997, 1996, 24665, 23052, 2012, 10223, 26371, 1010, 2605, 2073, 1996, 6261, 2984, 22353, 2135, 2596, 2000, 3002, 16595, 9648, 4674, 2061, 12083, 9711, 2271, 1999, 8517, 1012, 2012, 1996, 2203, 1997, 1996, 2364, 3298, 1006, 1998, 1999, 1037, 3622, 2240, 2008, 8539, 2083, 1017, 11342, 1998, 1996, 2751, 8514, 1007, 1010, 2003, 1037, 3722, 1010, 2715, 2962, 6231, 1997, 2984, 1012, 102, 2000, 3183, 2106, 1996, 
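The printed IDs begin with 101 ([CLS]), and the context is followed by 102 ([SEP]) and then the question, because each pair is passed as (context, question). As a rough sketch of how a pair turns into input_ids, token_type_ids, and attention_mask, here is a toy encoder over already-split word lists; encode_pair is a hypothetical helper, and real WordPiece tokenization, truncation, and ID lookup are left out.

```python
def encode_pair(context_tokens, question_tokens, pad_to):
    """Toy BERT pair layout: [CLS] context [SEP] question [SEP], then [PAD] up to pad_to."""
    tokens = ["[CLS]"] + context_tokens + ["[SEP]"] + question_tokens + ["[SEP]"]
    # Segment A ([CLS], context, first [SEP]) -> 0; segment B (question, final [SEP]) -> 1
    token_type_ids = [0] * (len(context_tokens) + 2) + [1] * (len(question_tokens) + 1)
    # 1 marks real tokens the model attends to; padding positions get 0
    attention_mask = [1] * len(tokens)
    n_pad = pad_to - len(tokens)
    tokens += ["[PAD]"] * n_pad
    token_type_ids += [0] * n_pad
    attention_mask += [0] * n_pad
    return tokens, token_type_ids, attention_mask


# padding=True pads every pair to the longest sequence in the batch; here we pad to 12 by hand
tokens, token_type_ids, attention_mask = encode_pair(
    ["bert", "is", "great"], ["what", "is", "bert", "?"], pad_to=12
)
print(tokens)
print(token_type_ids)   # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]
print(attention_mask)   # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
```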

### After tokenization, recompute the answer's start and end positions in the context as token indices

In [None]:
def add_token_positions(encodings, answers):
    start_positions = []
    end_positions = []
    for i in range(len(answers)):
        start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))
        end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))
        # if None, the answer passage has been truncated
        if start_positions[-1] is None:
            start_positions[-1] = tokenizer.model_max_length
        if end_positions[-1] is None:
            end_positions[-1] = tokenizer.model_max_length
    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

In [None]:
add_token_positions(train_encodings, train_answers)
add_token_positions(eval_encodings, eval_answers)

In [None]:
print(train_encodings['start_positions'][0])
print(train_encodings['end_positions'][0])

114
121
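Conceptually, char_to_token maps a character offset in the original text to the index of the token whose character span covers it, and returns None when no token does, for example when truncation cut that part of the context. add_token_positions replaces such None values with tokenizer.model_max_length (512); BertForQuestionAnswering clamps out-of-range positions to the sequence length and passes that value as ignore_index to its cross-entropy loss, so those positions are ignored during training. The toy below re-implements the lookup over hypothetical (start, end) offsets; the WordPiece split shown is made up.

```python
def char_to_token(offsets, char_idx):
    """Index of the token whose (start, end) span contains char_idx, or None.

    offsets: one (start, end) pair per token, end exclusive; special tokens use (0, 0).
    """
    for token_idx, (start, end) in enumerate(offsets):
        if start <= char_idx < end:
            return token_idx
    return None


# Hypothetical tokenization of the context "to Saint Bernadette"
tokens = ["[CLS]", "to", "saint", "bern", "##ade", "##tte", "[SEP]"]
offsets = [(0, 0), (0, 2), (3, 8), (9, 13), (13, 16), (16, 19), (0, 0)]

answer_start = 3                                      # "Saint Bernadette" starts at character 3
answer_end = answer_start + len("Saint Bernadette")   # exclusive end, 19
start_token = char_to_token(offsets, answer_start)    # 2 -> "saint"
end_token = char_to_token(offsets, answer_end - 1)    # 5 -> "##tte"
print(start_token, end_token, tokens[start_token:end_token + 1])

# If truncation had cut the context after "bern", the last answer character has no token
truncated = offsets[:4] + [(0, 0)]
print(char_to_token(truncated, answer_end - 1))  # None
```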


# Define a Dataset and convert the encodings to tensors

In [None]:
import torch
class Dataset(torch.utils.data.Dataset):
  def __init__(self, encodings):
    self.encodings = encodings

  def __getitem__(self, idx):
    return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

  def __len__(self):
    return len(self.encodings.input_ids)

In [None]:
train_dataset = Dataset(train_encodings)
eval_dataset = Dataset(eval_encodings)

In [None]:
train_dataset[0]

{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0

# Load the model architecture (question answering)

In [None]:
from transformers import BertConfig, BertForQuestionAnswering
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased', config=config)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

## Inspect the model architecture

In [None]:
print(model)

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_
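The printout is cut off before the task-specific head. In the Hugging Face implementation, BertForQuestionAnswering ends with one linear layer, qa_outputs (768 inputs, 2 outputs), applied to every token's final hidden state; the two outputs are the start and end logits, and when start_positions and end_positions are passed, the loss is the average of the start and end cross-entropy losses. The toy below mimics that with 2-dimensional hidden states and made-up weights; qa_head and cross_entropy are hypothetical helpers, not library functions.

```python
import math


def qa_head(hidden_states, w_start, b_start, w_end, b_end):
    """Toy span head: one start score and one end score per token."""
    start_logits = [sum(w * h for w, h in zip(w_start, hs)) + b_start for hs in hidden_states]
    end_logits = [sum(w * h for w, h in zip(w_end, hs)) + b_end for hs in hidden_states]
    return start_logits, end_logits


def cross_entropy(logits, target):
    """Negative log of the softmax probability of the target index."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


# 3 tokens with 2-dim hidden states (BERT-base uses 768 dims)
hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
start_logits, end_logits = qa_head(hidden, w_start=[2.0, 0.0], b_start=0.0,
                                   w_end=[0.0, 2.0], b_end=0.0)
print(start_logits)  # [2.0, 0.0, 2.0]
print(end_logits)    # [0.0, 2.0, 2.0]

# Gold answer spans tokens 0..2: loss is the mean of the start and end cross-entropies
loss = (cross_entropy(start_logits, 0) + cross_entropy(end_logits, 2)) / 2
print(loss)
```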

# Train the model

In [None]:
import logging
import math

from torch.utils.data import DataLoader
from torch.optim import AdamW  # PyTorch's AdamW; transformers.AdamW is deprecated
from tqdm.auto import tqdm, trange

from accelerate import Accelerator
from transformers import default_data_collator, get_scheduler

## Set the number of epochs and the batch sizes

In [None]:
train_batch_size = 4   # training batch size
eval_batch_size = 4    # evaluation batch size
num_train_epochs = 3   # number of epochs

## Wrap the datasets in DataLoaders


In [None]:
data_collator = default_data_collator
train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=eval_batch_size)

## Optimizer, learning rate, and scheduler settings

In [None]:
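learning_rate = 3e-5              # learning rate
gradient_accumulation_steps = 1   # number of batches whose gradients are accumulated before each optimizer update

# Parameters whose names match no_decay are usually exempt from weight decay.
# Both groups use 0.0 here, so no weight decay is applied at all;
# give the first group a positive value (e.g. 0.01) to enable it.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]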
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch
print('max_train_steps', max_train_steps)

# scheduler
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=max_train_steps,
)

max_train_steps 6003
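The printed value follows from the settings above: with 8001 training examples and a batch size of 4, the DataLoader yields ceil(8001 / 4) = 2001 batches per epoch; with no gradient accumulation that is 2001 updates per epoch, and 3 epochs give 6003 steps. The "linear" schedule with zero warmup then decays the learning rate linearly from 3e-5 to 0 over those steps. The sketch re-derives both in plain Python; linear_lr is a hypothetical helper that mirrors the multiplier of the linear warmup/decay schedule in transformers, written out for illustration.

```python
import math

# Values from the cells above
num_examples = 8001               # len(train_dataset)
train_batch_size = 4
gradient_accumulation_steps = 1
num_train_epochs = 3
learning_rate = 3e-5

batches_per_epoch = math.ceil(num_examples / train_batch_size)   # len(train_dataloader)
steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
max_train_steps = num_train_epochs * steps_per_epoch
print(batches_per_epoch, steps_per_epoch, max_train_steps)  # 2001 2001 6003


def linear_lr(step, base_lr, num_warmup_steps, num_training_steps):
    """Learning rate after `step` scheduler steps: linear warmup, then linear decay to 0."""
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    remaining = (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps)
    return base_lr * max(0.0, remaining)


for step in (0, 2001, 4002, 6003):
    print(step, linear_lr(step, learning_rate, 0, max_train_steps))
```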




## Pass the model, optimizer, and data loaders to Accelerator



In [None]:
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

## Set up the metric

In [None]:
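# Get the metric function
# Accuracy is not used: QA batches carry start/end positions rather than "labels".
# Exact match and F1 are computed after training instead.
# metric = load_metric("accuracy")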

## Start training

In [None]:
# Train!
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state)
output_dir = 'model/'


total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(train_dataset)}")
logger.info(f"  Num Epochs = {num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {max_train_steps}")


completed_steps = 0
best_epoch = {"epoch": 0, "acc": 0}

for epoch in trange(num_train_epochs, desc="Epoch"):
  model.train()
  for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
    outputs = model(**batch)
    loss = outputs.loss
    loss = loss / gradient_accumulation_steps
    accelerator.backward(loss)
    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
      optimizer.step()
      lr_scheduler.step()
      optimizer.zero_grad()
      completed_steps += 1

    if step % 50 == 0:
      print({'epoch': epoch, 'step': step, 'loss': loss.item()})

    if completed_steps >= max_train_steps:
      break
      
  # logger.info("***** Running eval *****")
  # model.eval()
  # for step, batch in enumerate(tqdm(eval_dataloader, desc="Eval Iteration")):
  #   outputs = model(**batch)
  #   predictions = outputs.logits.argmax(dim=-1)
  #   metric.add_batch(
  #       predictions=accelerator.gather(predictions),
  #       references=accelerator.gather(batch["labels"]),
  #   )

  # eval_metric = metric.compute()
  # logger.info(f"epoch {epoch}: {eval_metric}")
  # if eval_metric > best_epoch['acc']:
  #   best_epoch['epoch'] = epoch
  #   best_epoch['acc'] = eval_metric


  if output_dir is not None:
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir + 'epoch_' + str(epoch) + '/', save_function=accelerator.save)
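Gradient accumulation sums gradients over several batches (each loss is divided by gradient_accumulation_steps) before one optimizer update; with gradient_accumulation_steps = 1, as here, every batch triggers an update. The hypothetical helper below lists which batch indices end with an update under the (step + 1) % k == 0 convention, which waits for k full batches before the first update and also flushes the last partial group at the end of the epoch. The number of updates then equals ceil(num_batches / k), the value used for num_update_steps_per_epoch.

```python
import math


def update_steps(num_batches, accumulation_steps):
    """Batch indices after which optimizer.step() would run."""
    return [
        step
        for step in range(num_batches)
        if (step + 1) % accumulation_steps == 0 or step == num_batches - 1
    ]


print(update_steps(10, 1))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(update_steps(10, 4))  # [3, 7, 9]
# The update count matches ceil(num_batches / k)
print(len(update_steps(2001, 4)), math.ceil(2001 / 4))  # 501 501
```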


# Evaluate the model (exact match and F1 score)

In [None]:
from transformers import BertTokenizerFast, BertConfig, BertForQuestionAnswering, default_data_collator
from torch.utils.data import DataLoader
from accelerate import Accelerator
from tqdm.auto import tqdm
import json

In [None]:
%cd /content/gdrive/"BERT_SQuAD"  
!ls

## Load the model and the evaluation data

In [None]:
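tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# Load the checkpoint saved after the second epoch (epoch numbering starts at 0).
# Passing the directory lets from_pretrained locate config.json and the weight file itself.
config = BertConfig.from_pretrained("./model/epoch_1")
model = BertForQuestionAnswering.from_pretrained("./model/epoch_1", config=config)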

In [None]:
eval_contexts, eval_questions, eval_answers = read_data('dev-v1.1.json',1500)
add_end_idx(eval_answers)
eval_encodings = tokenizer(eval_contexts, eval_questions, truncation=True, padding=True)
add_token_positions(eval_encodings, eval_answers)
eval_dataset = Dataset(eval_encodings)

In [None]:
eval_batch_size = 5      # batch size
data_collator = default_data_collator

eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=eval_batch_size)

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()

# Prepare everything with our `accelerator`.
model, eval_dataloader = accelerator.prepare(
    model, eval_dataloader
)

## Run predictions

In [None]:
import torch

print("***** Running eval *****")
model.eval()
ref = []   # gold start/end token positions per example
pre = []   # predicted start/end token positions per example
index = 0
for step, batch in enumerate(tqdm(eval_dataloader, desc="Eval Iteration")):
  # Inference only: disable gradient tracking to save memory
  with torch.no_grad():
    outputs = model(**batch)

  # Most likely start and end token, chosen independently for each example
  start_predicted = outputs.start_logits.argmax(dim=-1)
  end_predicted = outputs.end_logits.argmax(dim=-1)

  for input_id, s_label, e_label, s_pre, e_pre in zip(batch["input_ids"].tolist(), batch["start_positions"].tolist(), batch["end_positions"].tolist(), start_predicted.tolist(), end_predicted.tolist()):
    ref.append({'id': index, 'input_ids': input_id, 'start': s_label, 'end': e_label})
    pre.append({'id': index, 'input_ids': input_id, 'start': s_pre, 'end': e_pre})
    index += 1


***** Running eval *****



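## Map the start/end positions back to text and compute EM and F1

Both the references and the predictions are decoded from the tokenized input, so the scores below compare each predicted span with the gold token span of its flattened example (one per annotated answer), not with all reference answers as the official SQuAD evaluation does.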

In [None]:
predictions = []
references = []

for r, p in zip(ref, pre):

  if r['id'] == p['id']:
    context_token = r['input_ids']

    label_answer = tokenizer.decode(context_token[r['start']:r['end']+1])
    references.append({'answers': {'answer_start': [r['start']], 'text': [label_answer]}, 'id': str(r['id'])})

    prediction_answer = tokenizer.decode(context_token[p['start']:p['end']+1])
    predictions.append({'prediction_text': prediction_answer, 'id': str(p['id'])})


In [None]:
print(references[0])
print(predictions[0])

{'answers': {'answer_start': [34], 'text': ['denver broncos']}, 'id': '0'}
{'prediction_text': 'denver broncos', 'id': '0'}


In [None]:
import datasets

squad_metric = datasets.load_metric("squad")
results = squad_metric.compute(predictions=predictions, references=references)
print(results)


{'exact_match': 61.98402130492676, 'f1': 71.79389441893014}
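The squad metric follows the official SQuAD v1.1 evaluation: both strings are normalized (lowercased, punctuation and the articles a/an/the removed, whitespace collapsed), exact match checks whether the normalized strings are identical, and F1 measures token overlap between them; each prediction takes its best score over the reference answers, and the averages are reported as percentages. A plain-Python sketch of that logic follows; the function names are our own and are not the datasets API.

```python
import collections
import re
import string

PUNCTUATION = set(string.punctuation)


def normalize_answer(s):
    """Lowercase, strip punctuation, drop the articles a/an/the, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCTUATION)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match(prediction, ground_truth):
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1_score(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    # Multiset intersection: tokens shared by prediction and reference
    common = collections.Counter(pred_tokens) & collections.Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def squad_scores(pairs):
    """pairs: list of (prediction, [reference answers]); returns percentages."""
    em = sum(max(exact_match(p, g) for g in golds) for p, golds in pairs)
    f1 = sum(max(f1_score(p, g) for g in golds) for p, golds in pairs)
    return {"exact_match": 100 * em / len(pairs), "f1": 100 * f1 / len(pairs)}


print(exact_match("The Denver Broncos", "denver broncos"))  # 1.0
print(f1_score("the Denver Broncos team", "Denver Broncos"))
print(squad_scores([("Denver Broncos", ["Denver Broncos", "the Broncos"]),
                    ("Carolina", ["Denver"])]))
```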


# Inference

In [None]:
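import torch

# **Prediction function**
def QA_model(model, context, question):
  # Encode the (context, question) pair the same way as the training data
  input_encodings = tokenizer([context], [question], truncation=True, padding=True)
  input_dataset = Dataset(input_encodings)

  data_collator = default_data_collator
  input_dataloader = DataLoader(input_dataset, collate_fn=data_collator, batch_size=1)

  # Let Accelerate move the model and the batch to the available device
  accelerator = Accelerator()
  model, input_dataloader = accelerator.prepare(model, input_dataloader)
  model.eval()

  answer = ''
  for batch in input_dataloader:
    with torch.no_grad():
      outputs = model(**batch)

    # Most likely start and end token, chosen independently
    start_predicted = outputs.start_logits.argmax(dim=-1).item()
    end_predicted = outputs.end_logits.argmax(dim=-1).item()

    input_ids = batch['input_ids'][0]
    # Decode the predicted span (empty if end < start)
    answer = tokenizer.decode(input_ids[start_predicted:end_predicted + 1])

  return answer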

In [None]:
context = '''Harry Potter is a series of seven fantasy novels written by British author J. K. Rowling. The novels chronicle the lives of a young wizard, Harry Potter, and his friends Hermione Granger and Ron Weasley, all of whom are students at Hogwarts School of Witchcraft and Wizardry. The main story arc concerns Harry's struggle against Lord Voldemort, a dark wizard who intends to become immortal, overthrow the wizard governing body known as the Ministry of Magic and subjugate all wizards and Muggles (non-magical people).'''
question = 'Who is the author of harry potter?'

answer = QA_model(model, context, question)  
print(answer)

j. k. rowling
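Both the evaluation loop and QA_model pick the start and end tokens with independent argmax calls, so nothing prevents the predicted end from landing before the start; the slice then decodes to an empty string. A common alternative, similar in spirit to the post-processing in the Hugging Face question-answering examples, is to search for the pair that maximizes start_logit + end_logit subject to start <= end and a maximum answer length. The sketch below works on plain lists; best_span and the default max_answer_len=30 are illustrative choices, and a fuller version would also skip question tokens and special tokens.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end, score) maximizing start_logits[start] + end_logits[end]
    over start <= end < start + max_answer_len."""
    best = (0, 0, float("-inf"))
    for start, s_score in enumerate(start_logits):
        # Only ends at or after the start, and at most max_answer_len tokens long
        for end in range(start, min(start + max_answer_len, len(end_logits))):
            score = s_score + end_logits[end]
            if score > best[2]:
                best = (start, end, score)
    return best


start_logits = [0.1, 0.2, 2.0, 3.0]
end_logits = [0.0, 4.0, 2.5, 0.5]
# Independent argmax gives start=3, end=1, i.e. an empty span
print(max(range(4), key=start_logits.__getitem__), max(range(4), key=end_logits.__getitem__))  # 3 1
print(best_span(start_logits, end_logits))  # (2, 2, 4.5)
```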
