# **Homework 7 - BERT (Question Answering)**

If you have any questions, feel free to email us at mlta-2022-spring@googlegroups.com



Slide:    [Link](https://docs.google.com/presentation/d/1H5ZONrb2LMOCixLY7D5_5-7LkIaXO6AGEaV2mRdTOMY/edit?usp=sharing)　Kaggle: [Link](https://www.kaggle.com/c/ml2022spring-hw7)　Data: [Link](https://drive.google.com/uc?id=1AVgZvy3VFeg0fX-6WQJMHPVrx3A-M1kb)




## Task description
- Chinese Extractive Question Answering
  - Input: Paragraph + Question
  - Output: Answer

- Objective: Learn how to fine-tune a pretrained model on a downstream task using transformers

- Todo
    - Fine-tune a pretrained Chinese BERT model
    - Change hyperparameters (e.g. doc_stride)
    - Apply linear learning rate decay
    - Try other pretrained models
    - Improve preprocessing
    - Improve postprocessing
- Training tips
    - Automatic mixed precision
    - Gradient accumulation
    - Ensemble

- Estimated training time (Tesla T4 with automatic mixed precision enabled)
    - Simple: 8 mins
    - Medium: 8 mins
    - Strong: 25 mins
    - Boss: 2.5 hrs
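
To make the input/output contract above more concrete: an extractive QA model scores every token as a possible answer start and as a possible answer end, and the answer is the span with the best combined score. The sketch below is a minimal pure-Python illustration of that selection step, related to the "Improve postprocessing" item. The function name `best_span` and the `max_answer_len` cap are assumptions made for illustration and are not part of the homework code.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end, score) of the highest-scoring valid span.

    A span is valid when start <= end and it covers at most
    max_answer_len tokens. The score is start_logit + end_logit.
    """
    best = (0, 0, float("-inf"))
    for s, s_logit in enumerate(start_logits):
        # Only consider ends at or after the start, within the length cap.
        last = min(len(end_logits), s + max_answer_len)
        for e in range(s, last):
            score = s_logit + end_logits[e]
            if score > best[2]:
                best = (s, e, score)
    return best


# Toy logits: taking each argmax independently gives start=3, end=1,
# which is not a valid span because the end precedes the start.
start_logits = [0.1, 0.2, 0.3, 2.0]
end_logits = [0.0, 1.5, 0.1, 1.0]
start, end, score = best_span(start_logits, end_logits)
print(start, end, score)
```

Searching over joint (start, end) pairs, rather than taking two independent argmaxes, is one common way to avoid empty or reversed answers.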
  

## Download Dataset

In [1]:
# import gdown
# # Download link 1
# !gdown --id '1AVgZvy3VFeg0fX-6WQJMHPVrx3A-M1kb' --output hw7_data.zip

# # Download Link 2 (if the above link fails) 
# # !gdown --id '1qwjbRjq481lHsnTrrF4OjKQnxzgoLEFR' --output hw7_data.zip

# # Download Link 3 (if the above link fails) 
# # !gdown --id '1QXuWjNRZH6DscSd6QcRER0cnxmpZvijn' --output hw7_data.zip

# !unzip -o hw7_data.zip

# # For this HW, K80 < P4 < T4 < P100 <= T4(fp16) < V100
# !nvidia-smi

## Install transformers

Documentation for the toolkit:　https://huggingface.co/transformers/

In [2]:
# # You are allowed to change version of transformers or use other toolkits
# !pip install transformers==4.18.0

## Import Packages

In [3]:
import json
import numpy as np
import random
import torch
from torch.utils.data import DataLoader, Dataset 
from transformers import AdamW, BertForQuestionAnswering, BertTokenizerFast, get_linear_schedule_with_warmup
import os

from tqdm.auto import tqdm

os.environ['CUDA_VISIBLE_DEVICES'] = "0"

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix random seed for reproducibility
def same_seeds(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
same_seeds(777777)

In [4]:
# Set "fp16_training" to True to enable automatic mixed precision training (fp16)
fp16_training = True

if fp16_training:
    # !pip install accelerate==0.2.0
    from accelerate import Accelerator
    # The fp16 argument matches the pinned accelerate 0.2.0; newer releases expect mixed_precision="fp16" instead
    accelerator = Accelerator(fp16=True)
    device = accelerator.device

# Documentation for the toolkit:  https://huggingface.co/docs/accelerate/

## Load Model and Tokenizer




 

In [5]:
# model = BertForQuestionAnswering.from_pretrained("bert-base-chinese").to(device)
# tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
pre_model = "luhua/chinese_pretrain_mrc_macbert_large"
tokenizer = BertTokenizerFast.from_pretrained(pre_model)
model = BertForQuestionAnswering.from_pretrained(pre_model).to(device)

# tokenizer = RobertaTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')
# model = RobertaModel.from_pretrained('hfl/chinese-roberta-wwm-ext')

# You can safely ignore the warning message (it pops up because new prediction heads for QA are initialized randomly)

## Read Data

- Training set: 31690 QA pairs
- Dev set: 4131  QA pairs
- Test set: 4957  QA pairs

- {train/dev/test}_questions:
  - List of dicts with the following keys:
    - id (int)
    - paragraph_id (int)
    - question_text (string)
    - answer_text (string)
    - answer_start (int)
    - answer_end (int)
- {train/dev/test}_paragraphs: 
  - List of strings
  - paragraph_ids in questions correspond to indices in paragraphs
  - A paragraph may be used by several questions 
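
As a small illustration of the layout described above, the sketch below builds a hypothetical one-question dataset and recovers the answer text from its paragraph. The sample sentence, the helper name `answer_from_span`, and the assumption that `answer_end` is an inclusive character index are all illustrative; check them against the real `hw7_*.json` files before relying on them.

```python
# Hypothetical miniature of the questions/paragraphs layout.
paragraphs = ["玉山是台灣最高的山，海拔3952公尺。"]
questions = [
    {
        "id": 0,
        "paragraph_id": 0,
        "question_text": "台灣最高的山是哪一座？",
        "answer_text": "玉山",
        "answer_start": 0,
        "answer_end": 1,  # assumed to be inclusive
    }
]


def answer_from_span(question, paragraphs):
    """Slice the answer out of the paragraph that the question points to."""
    paragraph = paragraphs[question["paragraph_id"]]
    # +1 because answer_end is assumed to index the last answer character.
    return paragraph[question["answer_start"] : question["answer_end"] + 1]


recovered = answer_from_span(questions[0], paragraphs)
print(recovered)
```

A quick consistency check of this kind over the whole training set can catch off-by-one assumptions about the span boundaries early.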

In [6]:
def read_data(file):
    with open(file, 'r', encoding="utf-8") as reader:
        data = json.load(reader)
    return data["questions"], data["paragraphs"]

train_questions, train_paragraphs = read_data("hw7_train.json")
dev_questions, dev_paragraphs = read_data("hw7_dev.json")
test_questions, test_paragraphs = read_data("hw7_test.json")

## Tokenize Data

In [7]:
# Tokenize questions and paragraphs separately
# "add_special_tokens" is set to False since special tokens will be added when tokenized questions and paragraphs are combined in the dataset's __getitem__

train_questions_tokenized = tokenizer([train_question["question_text"] for train_question in train_questions], add_special_tokens=False)
dev_questions_tokenized = tokenizer([dev_question["question_text"] for dev_question in dev_questions], add_special_tokens=False)
test_questions_tokenized = tokenizer([test_question["question_text"] for test_question in test_questions], add_special_tokens=False) 

train_paragraphs_tokenized = tokenizer(train_paragraphs, add_special_tokens=False)
dev_paragraphs_tokenized = tokenizer(dev_paragraphs, add_special_tokens=False)
test_paragraphs_tokenized = tokenizer(test_paragraphs, add_special_tokens=False)

# You can safely ignore the warning message, as tokenized sequences will be further processed in the dataset's __getitem__ before being passed to the model

## Dataset and Dataloader

In [8]:
class QA_Dataset(Dataset):
    def __init__(self, split, questions, tokenized_questions, tokenized_paragraphs):
        self.split = split
        self.questions = questions
        self.tokenized_questions = tokenized_questions
        self.tokenized_paragraphs = tokenized_paragraphs
        self.max_question_len = 50
        self.max_paragraph_len = 350
        
        ##### TODO: Change value of doc_stride #####
        self.doc_stride = 310

        # Input sequence length = [CLS] + question + [SEP] + paragraph + [SEP]
        self.max_seq_len = 1 + self.max_question_len + 1 + self.max_paragraph_len + 1

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        question = self.questions[idx]
        tokenized_question = self.tokenized_questions[idx]
        tokenized_paragraph = self.tokenized_paragraphs[question["paragraph_id"]]

        ##### TODO: Preprocessing #####
        # Hint: How to prevent model from learning something it should not learn
        exceed = len(tokenized_paragraph) > self.max_paragraph_len

        if self.split == "train":
            # Convert answer's start/end positions in paragraph_text to start/end positions in tokenized_paragraph  
            answer_start_token = tokenized_paragraph.char_to_token(question["answer_start"])
            answer_end_token = tokenized_paragraph.char_to_token(question["answer_end"])

            # A single window is obtained by slicing the portion of the paragraph containing the answer
            if exceed:
                # Long paragraph: center the window on the answer, clamped to the paragraph boundaries
                mid = (answer_start_token + answer_end_token) // 2
                paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))
            else:
                # Short paragraph: start at a random position at or before the answer so the answer's position in the window varies
                paragraph_start = random.randint(0, answer_start_token)
            paragraph_end = paragraph_start + self.max_paragraph_len

            # Slice question/paragraph and add special tokens (101: CLS, 102: SEP)
            input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102] 
            input_ids_paragraph = tokenized_paragraph.ids[paragraph_start : paragraph_end] + [102]		
            
            # Convert answer's start/end positions in tokenized_paragraph to start/end positions in the window
            answer_start_token += len(input_ids_question) - paragraph_start
            answer_end_token += len(input_ids_question) - paragraph_start
            
            # Pad sequence and obtain inputs to model 
            input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)
            return torch.tensor(input_ids), torch.tensor(token_type_ids), torch.tensor(attention_mask), answer_start_token, answer_end_token

        # Validation/Testing
        else:
            input_ids_list, token_type_ids_list, attention_mask_list = [], [], []
            qa_offset_list = []
            p_offset_list = []
            # Paragraph is split into several windows, each with start positions separated by step "doc_stride"
            for i in range(0, len(tokenized_paragraph), self.doc_stride):
                
                # Slice question/paragraph and add special tokens (101: CLS, 102: SEP)
                input_ids_question = [101] + tokenized_question.ids[:self.max_question_len] + [102]
                input_ids_paragraph = tokenized_paragraph.ids[i : i + self.max_paragraph_len] + [102]
                
                # Record the question segment length and the paragraph tokens' character offsets for post-processing
                qa_offset = len(input_ids_question)
                qa_offset_list.append(qa_offset)
                p_offset = tokenized_paragraph.offsets[i : i + self.max_paragraph_len]
                p_offset_list.append(p_offset)
                # Pad sequence and obtain inputs to model
                input_ids, token_type_ids, attention_mask = self.padding(input_ids_question, input_ids_paragraph)
                
                input_ids_list.append(input_ids)
                token_type_ids_list.append(token_type_ids)
                attention_mask_list.append(attention_mask)
            
            return torch.tensor(input_ids_list), torch.tensor(token_type_ids_list), torch.tensor(attention_mask_list), torch.tensor(qa_offset_list), p_offset_list

    def padding(self, input_ids_question, input_ids_paragraph):
        # Pad zeros if sequence length is shorter than max_seq_len
        padding_len = self.max_seq_len - len(input_ids_question) - len(input_ids_paragraph)
        # Indices of input sequence tokens in the vocabulary
        input_ids = input_ids_question + input_ids_paragraph + [0] * padding_len
        # Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0, 1]
        token_type_ids = [0] * len(input_ids_question) + [1] * len(input_ids_paragraph) + [0] * padding_len
        # Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]
        attention_mask = [1] * (len(input_ids_question) + len(input_ids_paragraph)) + [0] * padding_len
        
        return input_ids, token_type_ids, attention_mask

train_set = QA_Dataset("train", train_questions, train_questions_tokenized, train_paragraphs_tokenized)
dev_set = QA_Dataset("dev", dev_questions, dev_questions_tokenized, dev_paragraphs_tokenized)
test_set = QA_Dataset("test", test_questions, test_questions_tokenized, test_paragraphs_tokenized)

train_batch_size = 4

# Note: Do NOT change batch size of dev_loader / test_loader !
# Although batch size=1, it is actually a batch consisting of several windows from the same QA pair
train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, pin_memory=True)
dev_loader = DataLoader(dev_set, batch_size=1, shuffle=False, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=1, shuffle=False, pin_memory=True)
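
The way `doc_stride` and `max_paragraph_len` interact during validation and testing can be seen with a small standalone sketch. The function name `window_spans` is hypothetical; the defaults mirror the values used in `QA_Dataset` above, and the spans are token index ranges with an exclusive end.

```python
def window_spans(paragraph_len, max_paragraph_len=350, doc_stride=310):
    """Return the (start, end) token span of each evaluation window."""
    return [
        (start, min(start + max_paragraph_len, paragraph_len))
        for start in range(0, paragraph_len, doc_stride)
    ]


spans = window_spans(700)
print(spans)  # [(0, 350), (310, 660), (620, 700)]

# Consecutive full windows overlap by max_paragraph_len - doc_stride tokens.
overlap = spans[0][1] - spans[1][0]
print(overlap)  # 40
```

A smaller `doc_stride` produces more windows with more overlap, which makes it less likely that an answer is cut at a window boundary, at the cost of more forward passes. A `doc_stride` larger than `max_paragraph_len` would leave gaps between windows.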

## Training
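
When gradient accumulation is combined with a linear learning rate schedule, the schedule length should match the number of optimizer updates, not the number of samples or batches. The sketch below is a hedged, pure-Python illustration: `optimizer_updates` and `linear_lr_factor` are hypothetical helper names, the batch count 7923 is taken from the progress bar in the training log, and the factor formula follows my understanding of how `get_linear_schedule_with_warmup` scales the base learning rate.

```python
import math


def optimizer_updates(num_batches, accum_iter, num_epoch):
    """Count optimizer updates when the last partial group of each epoch also triggers an update."""
    return math.ceil(num_batches / accum_iter) * num_epoch


def linear_lr_factor(step, total_steps, warmup_steps=0):
    """Multiplier applied to the base learning rate at a given update step."""
    if step < warmup_steps:
        # Linear warmup from 0 to 1.
        return step / max(1, warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps.
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


total = optimizer_updates(num_batches=7923, accum_iter=2, num_epoch=3)
print(total)  # 11886

for step in (0, total // 2, total):
    print(step, linear_lr_factor(step, total))
```

If the schedule is configured with more steps than the optimizer actually takes, the learning rate never reaches zero by the end of training, so the decay is weaker than intended.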

In [9]:
num_epoch = 3
validation = False
logging_step = 100
learning_rate = 1e-5
optimizer = AdamW(model.parameters(), lr=learning_rate)

# batch accumulation parameter
accum_iter = 2

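# Set up the scheduler
len_dataset = len(train_set)
print(f"length of train_set: {len_dataset}")
# scheduler.step() runs once per optimizer update, i.e. once per accum_iter batches (plus the final partial group of each epoch)
updates_per_epoch = (len(train_loader) + accum_iter - 1) // accum_iter
total_steps = updates_per_epoch * num_epoch
warm_up_ratio = 0
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = int(warm_up_ratio * total_steps), num_training_steps = total_steps)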



if fp16_training:
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader) 

model.train()

print("Start Training ...")

best_acc = 0

for epoch in range(num_epoch):
    step = 1
    train_loss = train_acc = 0
    
    for data in tqdm(train_loader):	
        # Load all data into GPU
        data = [i.to(device) for i in data]
        
        # Model inputs: input_ids, token_type_ids, attention_mask, start_positions, end_positions (Note: only "input_ids" is mandatory)
        # Model outputs: start_logits, end_logits, loss (returned only when start_positions/end_positions are provided)
        output = model(input_ids=data[0], token_type_ids=data[1], attention_mask=data[2], start_positions=data[3], end_positions=data[4])

        # Choose the most probable start position / end position
        start_index = torch.argmax(output.start_logits, dim=1)
        end_index = torch.argmax(output.end_logits, dim=1)
        
        # Prediction is correct only if both start_index and end_index are correct
        train_acc += ((start_index == data[3]) & (end_index == data[4])).float().mean()

        # Normalize the loss so that gradients accumulated over accum_iter batches are averaged, not summed
        loss = output.loss / accum_iter
        train_loss += loss.detach()
        
        if fp16_training:
            accelerator.backward(loss)
        else:
            loss.backward()
        
        # weights update
        if (step % accum_iter == 0) or (step == len(train_loader)):
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
        # Linear learning rate decay is applied by scheduler.step() in the weights update above
        step += 1

        
        # Print training loss and accuracy over past logging step
        if step % logging_step == 0:
            print(f"Epoch {epoch + 1} | Step {step} | loss = {train_loss.item() / logging_step:.3f}, acc = {train_acc / logging_step:.3f}")
            train_loss = train_acc = 0



    if validation:
        print("Evaluating Dev Set ...")
        model.eval()
        with torch.no_grad():
            dev_acc = 0
            for i, data in enumerate(tqdm(dev_loader)):
                output = model(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),
                       attention_mask=data[2].squeeze(dim=0).to(device))
                # prediction is correct only if answer text exactly matches
                dev_acc += evaluate(data, output) == dev_questions[i]["answer_text"]
            print(f"Validation | Epoch {epoch + 1} | acc = {dev_acc / len(dev_loader):.3f}")

            if dev_acc > best_acc:
                best_acc = dev_acc
                print("Saving Model ...")
                model_save_dir = "saved_model_last" 
                model.save_pretrained(model_save_dir)
        model.train()

# Save the model and its configuration file to the directory "saved_model_last"
# i.e. there are two files under the directory "saved_model_last": "pytorch_model.bin" and "config.json"
# The saved model can be re-loaded using model = BertForQuestionAnswering.from_pretrained("saved_model_last")

print("Saving Model ...")
model_save_dir = "saved_model_last" 
model.save_pretrained(model_save_dir)



length of train_set: 31690
Start Training ...


  1%|▏         | 100/7923 [00:14<19:04,  6.84it/s]

Epoch 1 | Step 100 | loss = 0.496, acc = 0.625


  3%|▎         | 200/7923 [00:27<19:05,  6.74it/s]

Epoch 1 | Step 200 | loss = 0.351, acc = 0.717


  4%|▍         | 300/7923 [00:41<18:47,  6.76it/s]

Epoch 1 | Step 300 | loss = 0.275, acc = 0.750


  5%|▌         | 400/7923 [00:55<18:40,  6.72it/s]

Epoch 1 | Step 400 | loss = 0.340, acc = 0.710


  6%|▋         | 500/7923 [01:09<18:27,  6.70it/s]

Epoch 1 | Step 500 | loss = 0.254, acc = 0.792


  8%|▊         | 600/7923 [01:22<18:03,  6.76it/s]

Epoch 1 | Step 600 | loss = 0.297, acc = 0.750


  9%|▉         | 700/7923 [01:36<17:46,  6.77it/s]

Epoch 1 | Step 700 | loss = 0.247, acc = 0.782


 10%|█         | 800/7923 [01:50<17:44,  6.69it/s]

Epoch 1 | Step 800 | loss = 0.298, acc = 0.765


 11%|█▏        | 900/7923 [02:04<17:29,  6.69it/s]

Epoch 1 | Step 900 | loss = 0.268, acc = 0.785


 13%|█▎        | 1000/7923 [02:18<17:17,  6.67it/s]

Epoch 1 | Step 1000 | loss = 0.287, acc = 0.770


 14%|█▍        | 1100/7923 [02:32<17:04,  6.66it/s]

Epoch 1 | Step 1100 | loss = 0.273, acc = 0.757


 15%|█▌        | 1200/7923 [02:46<16:44,  6.69it/s]

Epoch 1 | Step 1200 | loss = 0.262, acc = 0.775


 16%|█▋        | 1300/7923 [02:59<16:18,  6.77it/s]

Epoch 1 | Step 1300 | loss = 0.223, acc = 0.812


 18%|█▊        | 1400/7923 [03:13<16:14,  6.69it/s]

Epoch 1 | Step 1400 | loss = 0.265, acc = 0.787


 19%|█▉        | 1500/7923 [03:27<16:04,  6.66it/s]

Epoch 1 | Step 1500 | loss = 0.220, acc = 0.803


 20%|██        | 1600/7923 [03:41<15:38,  6.74it/s]

Epoch 1 | Step 1600 | loss = 0.281, acc = 0.757


 21%|██▏       | 1700/7923 [03:55<15:32,  6.67it/s]

Epoch 1 | Step 1700 | loss = 0.212, acc = 0.795


 23%|██▎       | 1800/7923 [04:09<15:21,  6.65it/s]

Epoch 1 | Step 1800 | loss = 0.278, acc = 0.787


 24%|██▍       | 1900/7923 [04:23<15:05,  6.65it/s]

Epoch 1 | Step 1900 | loss = 0.259, acc = 0.785


 25%|██▌       | 2000/7923 [04:37<14:50,  6.65it/s]

Epoch 1 | Step 2000 | loss = 0.238, acc = 0.795


 27%|██▋       | 2100/7923 [04:51<14:40,  6.61it/s]

Epoch 1 | Step 2100 | loss = 0.259, acc = 0.790


 28%|██▊       | 2200/7923 [05:05<14:23,  6.63it/s]

Epoch 1 | Step 2200 | loss = 0.240, acc = 0.815


 29%|██▉       | 2300/7923 [05:19<13:56,  6.72it/s]

Epoch 1 | Step 2300 | loss = 0.206, acc = 0.808


 30%|███       | 2400/7923 [05:33<13:55,  6.61it/s]

Epoch 1 | Step 2400 | loss = 0.227, acc = 0.805


 32%|███▏      | 2500/7923 [05:47<13:37,  6.63it/s]

Epoch 1 | Step 2500 | loss = 0.238, acc = 0.810


 33%|███▎      | 2600/7923 [06:01<13:24,  6.62it/s]

Epoch 1 | Step 2600 | loss = 0.211, acc = 0.830


 34%|███▍      | 2700/7923 [06:15<12:59,  6.70it/s]

Epoch 1 | Step 2700 | loss = 0.219, acc = 0.822


 35%|███▌      | 2800/7923 [06:29<12:46,  6.68it/s]

Epoch 1 | Step 2800 | loss = 0.315, acc = 0.757


 37%|███▋      | 2900/7923 [06:43<12:39,  6.61it/s]

Epoch 1 | Step 2900 | loss = 0.231, acc = 0.820


 38%|███▊      | 3000/7923 [06:57<12:22,  6.63it/s]

Epoch 1 | Step 3000 | loss = 0.243, acc = 0.827


 39%|███▉      | 3100/7923 [07:11<12:10,  6.60it/s]

Epoch 1 | Step 3100 | loss = 0.246, acc = 0.800


 40%|████      | 3200/7923 [07:25<11:54,  6.61it/s]

Epoch 1 | Step 3200 | loss = 0.193, acc = 0.827


 42%|████▏     | 3300/7923 [07:39<11:32,  6.68it/s]

Epoch 1 | Step 3300 | loss = 0.203, acc = 0.847


 43%|████▎     | 3400/7923 [07:53<11:28,  6.56it/s]

Epoch 1 | Step 3400 | loss = 0.274, acc = 0.777


 44%|████▍     | 3500/7923 [08:07<11:13,  6.57it/s]

Epoch 1 | Step 3500 | loss = 0.192, acc = 0.822


 45%|████▌     | 3600/7923 [08:21<10:56,  6.59it/s]

Epoch 1 | Step 3600 | loss = 0.207, acc = 0.822


 47%|████▋     | 3700/7923 [08:35<10:44,  6.55it/s]

Epoch 1 | Step 3700 | loss = 0.231, acc = 0.772


 48%|████▊     | 3800/7923 [08:49<10:18,  6.66it/s]

Epoch 1 | Step 3800 | loss = 0.199, acc = 0.825


 49%|████▉     | 3900/7923 [09:04<10:12,  6.57it/s]

Epoch 1 | Step 3900 | loss = 0.254, acc = 0.803


 50%|█████     | 4000/7923 [09:18<09:54,  6.60it/s]

Epoch 1 | Step 4000 | loss = 0.248, acc = 0.825


 52%|█████▏    | 4100/7923 [09:32<09:40,  6.58it/s]

Epoch 1 | Step 4100 | loss = 0.243, acc = 0.770


 53%|█████▎    | 4200/7923 [09:46<09:24,  6.59it/s]

Epoch 1 | Step 4200 | loss = 0.211, acc = 0.822


 54%|█████▍    | 4300/7923 [10:00<09:04,  6.65it/s]

Epoch 1 | Step 4300 | loss = 0.194, acc = 0.832


 56%|█████▌    | 4400/7923 [10:14<08:54,  6.60it/s]

Epoch 1 | Step 4400 | loss = 0.235, acc = 0.803


 57%|█████▋    | 4500/7923 [10:28<08:39,  6.59it/s]

Epoch 1 | Step 4500 | loss = 0.225, acc = 0.808


 58%|█████▊    | 4600/7923 [10:42<08:26,  6.56it/s]

Epoch 1 | Step 4600 | loss = 0.221, acc = 0.850


 59%|█████▉    | 4700/7923 [10:57<08:04,  6.65it/s]

Epoch 1 | Step 4700 | loss = 0.212, acc = 0.837


 61%|██████    | 4800/7923 [11:11<07:53,  6.60it/s]

Epoch 1 | Step 4800 | loss = 0.220, acc = 0.805


 62%|██████▏   | 4900/7923 [11:25<07:40,  6.56it/s]

Epoch 1 | Step 4900 | loss = 0.258, acc = 0.757


 63%|██████▎   | 5000/7923 [11:39<07:20,  6.63it/s]

Epoch 1 | Step 5000 | loss = 0.209, acc = 0.830


 64%|██████▍   | 5100/7923 [11:53<07:09,  6.57it/s]

Epoch 1 | Step 5100 | loss = 0.184, acc = 0.837


 66%|██████▌   | 5200/7923 [12:07<06:55,  6.55it/s]

Epoch 1 | Step 5200 | loss = 0.251, acc = 0.808


 67%|██████▋   | 5300/7923 [12:21<06:39,  6.57it/s]

Epoch 1 | Step 5300 | loss = 0.201, acc = 0.847


 68%|██████▊   | 5400/7923 [12:35<06:27,  6.52it/s]

Epoch 1 | Step 5400 | loss = 0.185, acc = 0.842


 69%|██████▉   | 5500/7923 [12:50<06:08,  6.58it/s]

Epoch 1 | Step 5500 | loss = 0.186, acc = 0.855


 71%|███████   | 5600/7923 [13:04<05:49,  6.64it/s]

Epoch 1 | Step 5600 | loss = 0.177, acc = 0.815


 72%|███████▏  | 5700/7923 [13:18<05:37,  6.58it/s]

Epoch 1 | Step 5700 | loss = 0.231, acc = 0.827


 73%|███████▎  | 5800/7923 [13:32<05:22,  6.59it/s]

Epoch 1 | Step 5800 | loss = 0.227, acc = 0.837


 74%|███████▍  | 5900/7923 [13:46<05:07,  6.57it/s]

Epoch 1 | Step 5900 | loss = 0.225, acc = 0.800


 76%|███████▌  | 6000/7923 [14:00<04:53,  6.56it/s]

Epoch 1 | Step 6000 | loss = 0.200, acc = 0.810


 77%|███████▋  | 6100/7923 [14:14<04:35,  6.63it/s]

Epoch 1 | Step 6100 | loss = 0.203, acc = 0.840


 78%|███████▊  | 6200/7923 [14:29<04:24,  6.52it/s]

Epoch 1 | Step 6200 | loss = 0.189, acc = 0.820


 80%|███████▉  | 6300/7923 [14:43<04:08,  6.54it/s]

Epoch 1 | Step 6300 | loss = 0.171, acc = 0.835


 81%|████████  | 6400/7923 [14:57<03:52,  6.55it/s]

Epoch 1 | Step 6400 | loss = 0.207, acc = 0.840


 82%|████████▏ | 6500/7923 [15:11<03:34,  6.62it/s]

Epoch 1 | Step 6500 | loss = 0.189, acc = 0.817


 83%|████████▎ | 6600/7923 [15:25<03:21,  6.57it/s]

Epoch 1 | Step 6600 | loss = 0.215, acc = 0.840


 85%|████████▍ | 6700/7923 [15:39<03:06,  6.57it/s]

Epoch 1 | Step 6700 | loss = 0.258, acc = 0.792


 86%|████████▌ | 6800/7923 [15:54<02:51,  6.55it/s]

Epoch 1 | Step 6800 | loss = 0.193, acc = 0.837


 87%|████████▋ | 6900/7923 [16:08<02:35,  6.58it/s]

Epoch 1 | Step 6900 | loss = 0.245, acc = 0.810


 88%|████████▊ | 7000/7923 [16:22<02:20,  6.57it/s]

Epoch 1 | Step 7000 | loss = 0.165, acc = 0.837


 90%|████████▉ | 7100/7923 [16:36<02:04,  6.60it/s]

Epoch 1 | Step 7100 | loss = 0.212, acc = 0.827


 91%|█████████ | 7200/7923 [16:50<01:50,  6.55it/s]

Epoch 1 | Step 7200 | loss = 0.210, acc = 0.837


 92%|█████████▏| 7300/7923 [17:04<01:35,  6.54it/s]

Epoch 1 | Step 7300 | loss = 0.238, acc = 0.808


 93%|█████████▎| 7400/7923 [17:18<01:18,  6.66it/s]

Epoch 1 | Step 7400 | loss = 0.185, acc = 0.845


 95%|█████████▍| 7500/7923 [17:33<01:04,  6.57it/s]

Epoch 1 | Step 7500 | loss = 0.166, acc = 0.860


 96%|█████████▌| 7600/7923 [17:47<00:49,  6.54it/s]

Epoch 1 | Step 7600 | loss = 0.171, acc = 0.825


 97%|█████████▋| 7700/7923 [18:01<00:33,  6.57it/s]

Epoch 1 | Step 7700 | loss = 0.220, acc = 0.810


 98%|█████████▊| 7800/7923 [18:15<00:18,  6.54it/s]

Epoch 1 | Step 7800 | loss = 0.199, acc = 0.820


100%|█████████▉| 7900/7923 [18:29<00:03,  6.63it/s]

Epoch 1 | Step 7900 | loss = 0.224, acc = 0.792


100%|██████████| 7923/7923 [18:32<00:00,  7.12it/s]
  1%|▏         | 100/7923 [00:14<19:52,  6.56it/s]

Epoch 2 | Step 100 | loss = 0.099, acc = 0.890


  3%|▎         | 200/7923 [00:28<19:44,  6.52it/s]

Epoch 2 | Step 200 | loss = 0.088, acc = 0.910


  4%|▍         | 300/7923 [00:42<19:29,  6.52it/s]

Epoch 2 | Step 300 | loss = 0.119, acc = 0.885


  5%|▌         | 400/7923 [00:56<18:57,  6.62it/s]

Epoch 2 | Step 400 | loss = 0.080, acc = 0.922


  6%|▋         | 500/7923 [01:10<18:51,  6.56it/s]

Epoch 2 | Step 500 | loss = 0.080, acc = 0.902


  8%|▊         | 600/7923 [01:24<18:23,  6.64it/s]

Epoch 2 | Step 600 | loss = 0.095, acc = 0.907


  9%|▉         | 700/7923 [01:39<18:29,  6.51it/s]

Epoch 2 | Step 700 | loss = 0.081, acc = 0.915


 10%|█         | 800/7923 [01:53<18:11,  6.53it/s]

Epoch 2 | Step 800 | loss = 0.100, acc = 0.902


 11%|█▏        | 900/7923 [02:07<18:01,  6.50it/s]

Epoch 2 | Step 900 | loss = 0.101, acc = 0.890


 13%|█▎        | 1000/7923 [02:21<17:37,  6.55it/s]

Epoch 2 | Step 1000 | loss = 0.086, acc = 0.905


 14%|█▍        | 1100/7923 [02:35<17:13,  6.60it/s]

Epoch 2 | Step 1100 | loss = 0.114, acc = 0.875


 15%|█▌        | 1200/7923 [02:49<17:11,  6.52it/s]

Epoch 2 | Step 1200 | loss = 0.116, acc = 0.890


 16%|█▋        | 1300/7923 [03:04<16:48,  6.56it/s]

Epoch 2 | Step 1300 | loss = 0.072, acc = 0.922


 18%|█▊        | 1400/7923 [03:18<16:47,  6.47it/s]

Epoch 2 | Step 1400 | loss = 0.104, acc = 0.890


 19%|█▉        | 1500/7923 [03:32<16:09,  6.63it/s]

Epoch 2 | Step 1500 | loss = 0.096, acc = 0.897


 20%|██        | 1600/7923 [03:46<16:06,  6.54it/s]

Epoch 2 | Step 1600 | loss = 0.124, acc = 0.900


 21%|██▏       | 1700/7923 [04:00<15:46,  6.57it/s]

Epoch 2 | Step 1700 | loss = 0.114, acc = 0.902


 23%|██▎       | 1800/7923 [04:14<15:23,  6.63it/s]

Epoch 2 | Step 1800 | loss = 0.119, acc = 0.887


 24%|██▍       | 1900/7923 [04:29<15:23,  6.52it/s]

Epoch 2 | Step 1900 | loss = 0.107, acc = 0.885


 25%|██▌       | 2000/7923 [04:43<15:06,  6.53it/s]

Epoch 2 | Step 2000 | loss = 0.103, acc = 0.907


 27%|██▋       | 2100/7923 [04:57<14:55,  6.50it/s]

Epoch 2 | Step 2100 | loss = 0.080, acc = 0.920


 28%|██▊       | 2200/7923 [05:11<14:28,  6.59it/s]

Epoch 2 | Step 2200 | loss = 0.095, acc = 0.882


 29%|██▉       | 2300/7923 [05:25<14:24,  6.50it/s]

Epoch 2 | Step 2300 | loss = 0.106, acc = 0.880


 30%|███       | 2400/7923 [05:39<14:06,  6.52it/s]

Epoch 2 | Step 2400 | loss = 0.101, acc = 0.917


 32%|███▏      | 2500/7923 [05:54<13:44,  6.58it/s]

Epoch 2 | Step 2500 | loss = 0.099, acc = 0.895


 33%|███▎      | 2600/7923 [06:08<13:33,  6.55it/s]

Epoch 2 | Step 2600 | loss = 0.090, acc = 0.912


 34%|███▍      | 2700/7923 [06:22<13:14,  6.58it/s]

Epoch 2 | Step 2700 | loss = 0.111, acc = 0.873


 35%|███▌      | 2800/7923 [06:36<13:04,  6.53it/s]

Epoch 2 | Step 2800 | loss = 0.116, acc = 0.907


 37%|███▋      | 2900/7923 [06:50<12:46,  6.55it/s]

Epoch 2 | Step 2900 | loss = 0.135, acc = 0.887


 38%|███▊      | 3000/7923 [07:04<12:32,  6.54it/s]

Epoch 2 | Step 3000 | loss = 0.106, acc = 0.897


 39%|███▉      | 3100/7923 [07:19<12:18,  6.53it/s]

Epoch 2 | Step 3100 | loss = 0.100, acc = 0.885


 40%|████      | 3200/7923 [07:33<12:00,  6.55it/s]

Epoch 2 | Step 3200 | loss = 0.092, acc = 0.915


 42%|████▏     | 3300/7923 [07:47<11:36,  6.63it/s]

Epoch 2 | Step 3300 | loss = 0.078, acc = 0.902


 43%|████▎     | 3400/7923 [08:01<11:30,  6.56it/s]

Epoch 2 | Step 3400 | loss = 0.121, acc = 0.892


 44%|████▍     | 3500/7923 [08:15<11:12,  6.57it/s]

Epoch 2 | Step 3500 | loss = 0.111, acc = 0.902


 45%|████▌     | 3600/7923 [08:29<11:02,  6.52it/s]

Epoch 2 | Step 3600 | loss = 0.079, acc = 0.912


 47%|████▋     | 3700/7923 [08:44<10:43,  6.57it/s]

Epoch 2 | Step 3700 | loss = 0.084, acc = 0.900


 48%|████▊     | 3800/7923 [08:58<10:25,  6.59it/s]

Epoch 2 | Step 3800 | loss = 0.090, acc = 0.910


 49%|████▉     | 3900/7923 [09:12<10:12,  6.57it/s]

Epoch 2 | Step 3900 | loss = 0.092, acc = 0.907


 50%|█████     | 4000/7923 [09:26<09:54,  6.60it/s]

Epoch 2 | Step 4000 | loss = 0.101, acc = 0.900


 52%|█████▏    | 4100/7923 [09:40<09:37,  6.62it/s]

Epoch 2 | Step 4100 | loss = 0.111, acc = 0.902


 53%|█████▎    | 4200/7923 [09:54<09:33,  6.49it/s]

Epoch 2 | Step 4200 | loss = 0.100, acc = 0.895


 54%|█████▍    | 4300/7923 [10:09<09:12,  6.56it/s]

Epoch 2 | Step 4300 | loss = 0.096, acc = 0.900


 56%|█████▌    | 4400/7923 [10:23<08:58,  6.54it/s]

Epoch 2 | Step 4400 | loss = 0.116, acc = 0.880


 57%|█████▋    | 4500/7923 [10:37<08:44,  6.53it/s]

Epoch 2 | Step 4500 | loss = 0.116, acc = 0.882


 58%|█████▊    | 4600/7923 [10:51<08:23,  6.60it/s]

Epoch 2 | Step 4600 | loss = 0.101, acc = 0.910


 59%|█████▉    | 4700/7923 [11:05<08:12,  6.55it/s]

Epoch 2 | Step 4700 | loss = 0.107, acc = 0.897


 61%|██████    | 4800/7923 [11:19<07:55,  6.57it/s]

Epoch 2 | Step 4800 | loss = 0.098, acc = 0.897


 62%|██████▏   | 4900/7923 [11:34<07:43,  6.52it/s]

Epoch 2 | Step 4900 | loss = 0.096, acc = 0.910


 63%|██████▎   | 5000/7923 [11:48<07:21,  6.62it/s]

Epoch 2 | Step 5000 | loss = 0.096, acc = 0.910


 64%|██████▍   | 5100/7923 [12:02<07:11,  6.55it/s]

Epoch 2 | Step 5100 | loss = 0.093, acc = 0.897


 66%|██████▌   | 5200/7923 [12:16<06:59,  6.49it/s]

Epoch 2 | Step 5200 | loss = 0.123, acc = 0.880


 67%|██████▋   | 5300/7923 [12:30<06:39,  6.57it/s]

Epoch 2 | Step 5300 | loss = 0.092, acc = 0.915


 68%|██████▊   | 5400/7923 [12:44<06:26,  6.53it/s]

Epoch 2 | Step 5400 | loss = 0.103, acc = 0.895


 69%|██████▉   | 5500/7923 [12:59<06:10,  6.55it/s]

Epoch 2 | Step 5500 | loss = 0.114, acc = 0.885


 71%|███████   | 5600/7923 [13:13<05:50,  6.63it/s]

Epoch 2 | Step 5600 | loss = 0.098, acc = 0.902


 72%|███████▏  | 5700/7923 [13:27<05:40,  6.53it/s]

Epoch 2 | Step 5700 | loss = 0.087, acc = 0.897


 73%|███████▎  | 5800/7923 [13:41<05:23,  6.56it/s]

Epoch 2 | Step 5800 | loss = 0.113, acc = 0.900


 74%|███████▍  | 5900/7923 [13:55<05:05,  6.61it/s]

Epoch 2 | Step 5900 | loss = 0.083, acc = 0.910


 76%|███████▌  | 6000/7923 [14:09<04:52,  6.57it/s]

Epoch 2 | Step 6000 | loss = 0.094, acc = 0.915


 77%|███████▋  | 6100/7923 [14:24<04:37,  6.58it/s]

Epoch 2 | Step 6100 | loss = 0.161, acc = 0.880


 78%|███████▊  | 6200/7923 [14:38<04:20,  6.63it/s]

Epoch 2 | Step 6200 | loss = 0.092, acc = 0.917


 80%|███████▉  | 6300/7923 [14:52<04:08,  6.53it/s]

Epoch 2 | Step 6300 | loss = 0.082, acc = 0.927


 81%|████████  | 6400/7923 [15:06<03:54,  6.50it/s]

Epoch 2 | Step 6400 | loss = 0.095, acc = 0.922


 82%|████████▏ | 6500/7923 [15:20<03:36,  6.57it/s]

Epoch 2 | Step 6500 | loss = 0.087, acc = 0.907


 83%|████████▎ | 6600/7923 [15:34<03:22,  6.52it/s]

Epoch 2 | Step 6600 | loss = 0.134, acc = 0.875


 85%|████████▍ | 6700/7923 [15:49<03:07,  6.52it/s]

Epoch 2 | Step 6700 | loss = 0.087, acc = 0.907


 86%|████████▌ | 6800/7923 [16:03<02:51,  6.55it/s]

Epoch 2 | Step 6800 | loss = 0.096, acc = 0.917


 87%|████████▋ | 6900/7923 [16:17<02:37,  6.51it/s]

Epoch 2 | Step 6900 | loss = 0.091, acc = 0.910


 88%|████████▊ | 7000/7923 [16:31<02:21,  6.53it/s]

Epoch 2 | Step 7000 | loss = 0.110, acc = 0.890


 90%|████████▉ | 7100/7923 [16:45<02:04,  6.60it/s]

Epoch 2 | Step 7100 | loss = 0.087, acc = 0.920


 91%|█████████ | 7200/7923 [16:59<01:50,  6.55it/s]

Epoch 2 | Step 7200 | loss = 0.079, acc = 0.925


 92%|█████████▏| 7300/7923 [17:14<01:34,  6.57it/s]

Epoch 2 | Step 7300 | loss = 0.093, acc = 0.892


 93%|█████████▎| 7400/7923 [17:28<01:19,  6.58it/s]

Epoch 2 | Step 7400 | loss = 0.094, acc = 0.882


 95%|█████████▍| 7500/7923 [17:42<01:05,  6.49it/s]

Epoch 2 | Step 7500 | loss = 0.097, acc = 0.905


 96%|█████████▌| 7600/7923 [17:56<00:49,  6.55it/s]

Epoch 2 | Step 7600 | loss = 0.099, acc = 0.910


 97%|█████████▋| 7700/7923 [18:10<00:33,  6.60it/s]

Epoch 2 | Step 7700 | loss = 0.094, acc = 0.905


 98%|█████████▊| 7800/7923 [18:24<00:18,  6.57it/s]

Epoch 2 | Step 7800 | loss = 0.117, acc = 0.907


100%|█████████▉| 7900/7923 [18:39<00:03,  6.54it/s]

Epoch 2 | Step 7900 | loss = 0.112, acc = 0.895


100%|██████████| 7923/7923 [18:42<00:00,  7.06it/s]
  1%|▏         | 100/7923 [00:14<19:57,  6.53it/s]

Epoch 3 | Step 100 | loss = 0.050, acc = 0.938


  3%|▎         | 200/7923 [00:28<19:37,  6.56it/s]

Epoch 3 | Step 200 | loss = 0.041, acc = 0.957


  4%|▍         | 300/7923 [00:42<19:21,  6.56it/s]

Epoch 3 | Step 300 | loss = 0.041, acc = 0.957


  5%|▌         | 400/7923 [00:56<19:11,  6.54it/s]

Epoch 3 | Step 400 | loss = 0.054, acc = 0.940


  6%|▋         | 500/7923 [01:10<18:38,  6.63it/s]

Epoch 3 | Step 500 | loss = 0.056, acc = 0.955


  8%|▊         | 600/7923 [01:24<18:36,  6.56it/s]

Epoch 3 | Step 600 | loss = 0.067, acc = 0.940


  9%|▉         | 700/7923 [01:39<18:24,  6.54it/s]

Epoch 3 | Step 700 | loss = 0.069, acc = 0.940


 10%|█         | 800/7923 [01:53<18:09,  6.54it/s]

Epoch 3 | Step 800 | loss = 0.049, acc = 0.962


 11%|█▏        | 900/7923 [02:07<17:47,  6.58it/s]

Epoch 3 | Step 900 | loss = 0.037, acc = 0.950


 13%|█▎        | 1000/7923 [02:21<17:40,  6.53it/s]

Epoch 3 | Step 1000 | loss = 0.075, acc = 0.940


 14%|█▍        | 1100/7923 [02:35<17:13,  6.60it/s]

Epoch 3 | Step 1100 | loss = 0.034, acc = 0.947


 15%|█▌        | 1200/7923 [02:49<17:03,  6.57it/s]

Epoch 3 | Step 1200 | loss = 0.047, acc = 0.942


 16%|█▋        | 1300/7923 [03:04<16:46,  6.58it/s]

Epoch 3 | Step 1300 | loss = 0.053, acc = 0.950


 18%|█▊        | 1400/7923 [03:18<16:39,  6.53it/s]

Epoch 3 | Step 1400 | loss = 0.043, acc = 0.955


 19%|█▉        | 1500/7923 [03:32<16:24,  6.52it/s]

Epoch 3 | Step 1500 | loss = 0.055, acc = 0.940


 20%|██        | 1600/7923 [03:46<16:04,  6.55it/s]

Epoch 3 | Step 1600 | loss = 0.052, acc = 0.952


 21%|██▏       | 1700/7923 [04:00<15:37,  6.64it/s]

Epoch 3 | Step 1700 | loss = 0.058, acc = 0.950


 23%|██▎       | 1800/7923 [04:14<15:36,  6.54it/s]

Epoch 3 | Step 1800 | loss = 0.049, acc = 0.952


 24%|██▍       | 1900/7923 [04:29<15:23,  6.52it/s]

Epoch 3 | Step 1900 | loss = 0.037, acc = 0.955


 25%|██▌       | 2000/7923 [04:43<14:54,  6.62it/s]

Epoch 3 | Step 2000 | loss = 0.048, acc = 0.935


 27%|██▋       | 2100/7923 [04:57<14:54,  6.51it/s]

Epoch 3 | Step 2100 | loss = 0.028, acc = 0.975


 28%|██▊       | 2200/7923 [05:11<14:31,  6.57it/s]

Epoch 3 | Step 2200 | loss = 0.037, acc = 0.962


 29%|██▉       | 2300/7923 [05:25<14:12,  6.59it/s]

Epoch 3 | Step 2300 | loss = 0.057, acc = 0.942


 30%|███       | 2400/7923 [05:39<14:02,  6.55it/s]

Epoch 3 | Step 2400 | loss = 0.049, acc = 0.955


 32%|███▏      | 2499/7923 [05:53<13:16,  6.81it/s]

Epoch 3 | Step 2500 | loss = 0.035, acc = 0.960


 33%|███▎      | 2600/7923 [06:08<13:38,  6.50it/s]

Epoch 3 | Step 2600 | loss = 0.044, acc = 0.957


 34%|███▍      | 2700/7923 [06:22<13:20,  6.53it/s]

Epoch 3 | Step 2700 | loss = 0.035, acc = 0.945


 35%|███▌      | 2800/7923 [06:36<13:05,  6.52it/s]

Epoch 3 | Step 2800 | loss = 0.051, acc = 0.947


 37%|███▋      | 2900/7923 [06:50<12:37,  6.63it/s]

Epoch 3 | Step 2900 | loss = 0.047, acc = 0.942


 38%|███▊      | 3000/7923 [07:04<12:34,  6.53it/s]

Epoch 3 | Step 3000 | loss = 0.050, acc = 0.942


 39%|███▉      | 3100/7923 [07:19<12:19,  6.52it/s]

Epoch 3 | Step 3100 | loss = 0.057, acc = 0.938


 40%|████      | 3200/7923 [07:33<12:05,  6.51it/s]

Epoch 3 | Step 3200 | loss = 0.051, acc = 0.947


 42%|████▏     | 3300/7923 [07:47<11:39,  6.61it/s]

Epoch 3 | Step 3300 | loss = 0.046, acc = 0.935


 43%|████▎     | 3400/7923 [08:01<11:29,  6.56it/s]

Epoch 3 | Step 3400 | loss = 0.077, acc = 0.938


 44%|████▍     | 3500/7923 [08:15<11:05,  6.64it/s]

Epoch 3 | Step 3500 | loss = 0.049, acc = 0.960


 45%|████▌     | 3600/7923 [08:29<11:00,  6.54it/s]

Epoch 3 | Step 3600 | loss = 0.062, acc = 0.952


 47%|████▋     | 3700/7923 [08:44<10:45,  6.54it/s]

Epoch 3 | Step 3700 | loss = 0.062, acc = 0.942


 48%|████▊     | 3800/7923 [08:58<10:19,  6.65it/s]

Epoch 3 | Step 3800 | loss = 0.055, acc = 0.952


 49%|████▉     | 3900/7923 [09:12<10:13,  6.56it/s]

Epoch 3 | Step 3900 | loss = 0.066, acc = 0.930


 50%|█████     | 4000/7923 [09:26<09:56,  6.58it/s]

Epoch 3 | Step 4000 | loss = 0.061, acc = 0.942


 52%|█████▏    | 4100/7923 [09:40<09:40,  6.58it/s]

Epoch 3 | Step 4100 | loss = 0.057, acc = 0.940


 53%|█████▎    | 4200/7923 [09:55<09:21,  6.62it/s]

Epoch 3 | Step 4200 | loss = 0.046, acc = 0.942


 54%|█████▍    | 4300/7923 [10:09<09:16,  6.51it/s]

Epoch 3 | Step 4300 | loss = 0.080, acc = 0.940


 56%|█████▌    | 4399/7923 [10:23<08:40,  6.77it/s]

Epoch 3 | Step 4400 | loss = 0.056, acc = 0.925


 57%|█████▋    | 4500/7923 [10:37<08:36,  6.62it/s]

Epoch 3 | Step 4500 | loss = 0.062, acc = 0.935


 58%|█████▊    | 4600/7923 [10:51<08:25,  6.57it/s]

Epoch 3 | Step 4600 | loss = 0.044, acc = 0.947


 59%|█████▉    | 4700/7923 [11:05<08:13,  6.54it/s]

Epoch 3 | Step 4700 | loss = 0.032, acc = 0.957


 61%|██████    | 4800/7923 [11:20<07:51,  6.63it/s]

Epoch 3 | Step 4800 | loss = 0.063, acc = 0.922


 62%|██████▏   | 4900/7923 [11:34<07:43,  6.52it/s]

Epoch 3 | Step 4900 | loss = 0.055, acc = 0.940


 63%|██████▎   | 5000/7923 [11:48<07:29,  6.50it/s]

Epoch 3 | Step 5000 | loss = 0.044, acc = 0.945


 64%|██████▍   | 5100/7923 [12:02<07:09,  6.57it/s]

Epoch 3 | Step 5100 | loss = 0.049, acc = 0.952


 66%|██████▌   | 5200/7923 [12:16<06:54,  6.57it/s]

Epoch 3 | Step 5200 | loss = 0.045, acc = 0.950


 67%|██████▋   | 5300/7923 [12:30<06:37,  6.59it/s]

Epoch 3 | Step 5300 | loss = 0.055, acc = 0.938


 68%|██████▊   | 5400/7923 [12:45<06:24,  6.55it/s]

Epoch 3 | Step 5400 | loss = 0.044, acc = 0.960


 69%|██████▉   | 5500/7923 [12:59<06:08,  6.58it/s]

Epoch 3 | Step 5500 | loss = 0.056, acc = 0.933


 71%|███████   | 5600/7923 [13:13<05:54,  6.56it/s]

Epoch 3 | Step 5600 | loss = 0.078, acc = 0.950


 72%|███████▏  | 5700/7923 [13:27<05:37,  6.59it/s]

Epoch 3 | Step 5700 | loss = 0.064, acc = 0.942


 73%|███████▎  | 5800/7923 [13:41<05:23,  6.57it/s]

Epoch 3 | Step 5800 | loss = 0.084, acc = 0.920


 74%|███████▍  | 5900/7923 [13:55<05:09,  6.55it/s]

Epoch 3 | Step 5900 | loss = 0.067, acc = 0.927


 76%|███████▌  | 6000/7923 [14:10<04:53,  6.55it/s]

Epoch 3 | Step 6000 | loss = 0.048, acc = 0.942


 77%|███████▋  | 6100/7923 [14:24<04:39,  6.53it/s]

Epoch 3 | Step 6100 | loss = 0.058, acc = 0.938


 78%|███████▊  | 6200/7923 [14:38<04:22,  6.57it/s]

Epoch 3 | Step 6200 | loss = 0.056, acc = 0.942


 80%|███████▉  | 6300/7923 [14:52<04:03,  6.65it/s]

Epoch 3 | Step 6300 | loss = 0.064, acc = 0.935


 81%|████████  | 6400/7923 [15:06<03:53,  6.52it/s]

Epoch 3 | Step 6400 | loss = 0.072, acc = 0.940


 82%|████████▏ | 6500/7923 [15:20<03:37,  6.54it/s]

Epoch 3 | Step 6500 | loss = 0.042, acc = 0.945


 83%|████████▎ | 6600/7923 [15:35<03:20,  6.61it/s]

Epoch 3 | Step 6600 | loss = 0.044, acc = 0.930


 85%|████████▍ | 6700/7923 [15:49<03:06,  6.56it/s]

Epoch 3 | Step 6700 | loss = 0.042, acc = 0.950


 86%|████████▌ | 6800/7923 [16:03<02:51,  6.55it/s]

Epoch 3 | Step 6800 | loss = 0.059, acc = 0.942


 87%|████████▋ | 6900/7923 [16:17<02:35,  6.57it/s]

Epoch 3 | Step 6900 | loss = 0.048, acc = 0.955


 88%|████████▊ | 7000/7923 [16:31<02:21,  6.54it/s]

Epoch 3 | Step 7000 | loss = 0.066, acc = 0.940


 90%|████████▉ | 7100/7923 [16:45<02:05,  6.55it/s]

Epoch 3 | Step 7100 | loss = 0.058, acc = 0.935


 91%|█████████ | 7200/7923 [17:00<01:48,  6.63it/s]

Epoch 3 | Step 7200 | loss = 0.033, acc = 0.970


 92%|█████████▏| 7300/7923 [17:14<01:35,  6.56it/s]

Epoch 3 | Step 7300 | loss = 0.057, acc = 0.930


 93%|█████████▎| 7400/7923 [17:28<01:20,  6.52it/s]

Epoch 3 | Step 7400 | loss = 0.052, acc = 0.940


 95%|█████████▍| 7500/7923 [17:42<01:03,  6.64it/s]

Epoch 3 | Step 7500 | loss = 0.046, acc = 0.940


 96%|█████████▌| 7600/7923 [17:56<00:49,  6.49it/s]

Epoch 3 | Step 7600 | loss = 0.052, acc = 0.955


 97%|█████████▋| 7700/7923 [18:10<00:34,  6.50it/s]

Epoch 3 | Step 7700 | loss = 0.053, acc = 0.947


 98%|█████████▊| 7800/7923 [18:25<00:18,  6.61it/s]

Epoch 3 | Step 7800 | loss = 0.036, acc = 0.967


100%|█████████▉| 7900/7923 [18:39<00:03,  6.56it/s]

Epoch 3 | Step 7900 | loss = 0.063, acc = 0.927


100%|██████████| 7923/7923 [18:42<00:00,  7.06it/s]


Saving Model ...


In [10]:
max_answer_length = 30

def evaluate(data, output1, paragraph):

    p_offsets = data[4]
    question_offset = data[3].squeeze(dim=0)
    ##### TODO: Postprocessing #####
    # There is a bug and room for improvement in postprocessing 
    # Hint: Open your prediction file to see what is wrong 
    answer = ''
    max_prob = float('-inf')
    num_of_windows = data[0].shape[1]
    n_best = 60

    
    ans_start, ans_end, ans_k  = 0,0,0
    for k in range(num_of_windows):
        output_start = output1.start_logits[k].cpu().numpy()
        output_end = output1.end_logits[k].cpu().numpy()

        start_indexes = np.argsort(output_start)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(output_end)[-1 : -n_best - 1 : -1].tolist()

        for start_index in start_indexes:
            for end_index in end_indexes:
                if start_index > end_index or end_index - start_index + 1 > max_answer_length:
                    continue
                
                if start_index < question_offset[k].item():
                    # print(f"answer is in question: {start_index},{end_index}")
                    continue
                
                start_prob= output1.start_logits[k][start_index]
                end_prob = output1.end_logits[k][end_index]
                
                prob = start_prob + end_prob
                if prob > max_prob:
                    max_prob = prob
                    # Convert tokens to chars (e.g. [1920, 7032] --> "大 金")
                    answer = tokenizer.decode(data[0][0][k][start_index : end_index + 1])
                    ans_start, ans_end, ans_k = start_index , end_index, k


    # print(ans_start, ans_end, ans_k)
    # print(paragraph)
    
    res = ''
    # Number of tokens that precede the paragraph in this window (question and special tokens)
    prefix = question_offset[ans_k].item()

    # Shift the answer span so that it indexes the window's paragraph tokens
    ans_p_start = ans_start - prefix
    ans_p_end = ans_end - prefix


    # tks = ids[ans_k][ans_start:ans_end+1].tolist()
    # print(tks)
    st, ed = 0,0
    p_offset = [(a.item(),b.item()) for a,b in p_offsets[ans_k]]
    # print(len(p_offset), ans_p_start, ans_p_end, ans_k)
    for i in range(ans_p_start, ans_p_end+1):
        st, ed = p_offset[i]
        res += paragraph[st : ed]
    # print(f"k: {ans_k}, res: {res}")
    # raise Exception
    # if '[UNK]' in answer:
    #     print('found [UNK] in prediction, using original text')
    #     print('original prediction', answer)
    #     print('final prediction',res)
        
    # res = res.replace(' ','')
    # print(res)
    return res

def ensemble(data, output1, output2, output5, paragraph):

    p_offsets = data[4]
    question_offset = data[3].squeeze(dim=0)
    ##### TODO: Postprocessing #####
    # There is a bug and room for improvement in postprocessing 
    # Hint: Open your prediction file to see what is wrong 
    answer = ''
    max_prob = float('-inf')
    num_of_windows = data[0].shape[1]
    n_best = 60

    
    ans_start, ans_end, ans_k  = 0,0,0
    for k in range(num_of_windows):
        # output_start = output1.start_logits[k].cpu().numpy() + output2.start_logits[k].cpu().numpy() + output5.start_logits[k].cpu().numpy()
        # output_end = output1.end_logits[k].cpu().numpy() + output2.end_logits[k].cpu().numpy() + output5.end_logits[k].cpu().numpy()
        output_start = output1.start_logits[k].cpu().numpy() + output2.start_logits[k].cpu().numpy()
        output_end = output1.end_logits[k].cpu().numpy() + output2.end_logits[k].cpu().numpy()

        start_indexes = np.argsort(output_start)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(output_end)[-1 : -n_best - 1 : -1].tolist()

        for i, start_index in enumerate(start_indexes):
            for j, end_index in enumerate(end_indexes):
                if start_index > end_index or end_index - start_index + 1 > max_answer_length:
                    continue

                if start_index < question_offset[k].item():
                    # print(f"answer is in question: {start_index},{end_index}")
                    continue
                
                # start_prob= output1.start_logits[k][start_index] + output2.start_logits[k][start_index] + output5.start_logits[k][start_index]
                # end_prob = output1.end_logits[k][end_index]+ output2.end_logits[k][end_index] + output5.end_logits[k][end_index]
                start_prob= output1.start_logits[k][start_index] + output2.start_logits[k][start_index]
                end_prob = output1.end_logits[k][end_index]+ output2.end_logits[k][end_index]
                
                prob = start_prob + end_prob
                if prob > max_prob:
                    max_prob = prob
                    # Convert tokens to chars (e.g. [1920, 7032] --> "大 金")
                    answer = tokenizer.decode(data[0][0][k][start_index : end_index + 1])
                    ans_start, ans_end, ans_k = start_index , end_index, k


    # print(ans_start, ans_end, ans_k)
    # print(paragraph)
    
    res = ''
    # Number of tokens that precede the paragraph in this window (question and special tokens)
    prefix = question_offset[ans_k].item()

    # Shift the answer span so that it indexes the window's paragraph tokens
    ans_p_start = ans_start - prefix
    ans_p_end = ans_end - prefix


    # tks = ids[ans_k][ans_start:ans_end+1].tolist()
    # print(tks)
    st, ed = 0,0
    p_offset = [(a.item(),b.item()) for a,b in p_offsets[ans_k]]
    # print(p_offset)
    for i in range(ans_p_start, ans_p_end+1):
        st, ed = p_offset[i]
        res += paragraph[st : ed]
    # print(f"k: {ans_k}, res: {res}")
    # raise Exception
    # if '[UNK]' in answer:
    #     print('found [UNK] in prediction, using original text')
    #     print('original prediction', answer)
    #     print('final prediction',res)
        
    # res = res.replace(' ','')
    # print(res)
    return res
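
The span search and the character-offset lookup used by `evaluate` and `ensemble` can be tried in isolation on toy inputs. The sketch below is a simplified illustration, not the notebook's exact code: the helper names `best_span` and `span_to_text`, the toy paragraph, its offsets, and the logits are all made up for the example, and the token indices are assumed to already be relative to the paragraph (the question prefix has been subtracted). Unlike the cell above, it slices the paragraph once from the first token's start to the last token's end, which also keeps any characters that sit between tokens.

```python
import numpy as np

def best_span(start_logits, end_logits, n_best=20, max_len=30, min_start=0):
    """Return (start, end, score) of the best valid span, or None if there is none."""
    start_logits = np.asarray(start_logits, dtype=float)
    end_logits = np.asarray(end_logits, dtype=float)
    # Keep only the n_best highest-scoring start and end positions.
    top_starts = np.argsort(start_logits)[::-1][:n_best]
    top_ends = np.argsort(end_logits)[::-1][:n_best]
    best = None
    for s in top_starts:
        for e in top_ends:
            # Skip spans that start too early, end before they start, or are too long.
            if s < min_start or e < s or e - s + 1 > max_len:
                continue
            score = start_logits[s] + end_logits[e]
            if best is None or score > best[2]:
                best = (int(s), int(e), float(score))
    return best

def span_to_text(paragraph, offsets, start, end):
    """Slice the original paragraph from the first token's start to the last token's end."""
    return paragraph[offsets[start][0]:offsets[end][1]]

# Toy data (hypothetical): one (char_start, char_end) pair per paragraph token.
paragraph = "BERT was released in 2018."
offsets = [(0, 4), (5, 8), (9, 17), (18, 20), (21, 25), (25, 26)]
start_logits = [0.1, 0.0, 0.2, 0.3, 2.0, 0.0]
end_logits = [0.0, 0.1, 0.0, 0.2, 1.5, 0.4]

span = best_span(start_logits, end_logits, n_best=3, max_len=4)
answer = span_to_text(paragraph, offsets, span[0], span[1])
print(span, answer)
```

For an ensemble in the style of the cell above, the start and end logits of the models would be summed element-wise before being passed to `best_span`.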

In [11]:
# model2 = BertForQuestionAnswering.from_pretrained("saved_model_cml5_2").to(device)

In [12]:
# print("Evaluating Dev Set ...")
# model.eval()
# with torch.no_grad():
#     dev_acc = 0
#     for i, data in enumerate(tqdm(dev_loader)):
#         output = model(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),
#                 attention_mask=data[2].squeeze(dim=0).to(device))
#         output2 = model2(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),
#                 attention_mask=data[2].squeeze(dim=0).to(device))
#         # prediction is correct only if answer text exactly matches
#         dev_acc += ensemble(data, output,output2, None, dev_paragraphs[dev_questions[i]['paragraph_id']]) == dev_questions[i]["answer_text"]
#     print(f"Validation | acc = {dev_acc / len(dev_loader):.3f}")

In [13]:
# result = []

# with torch.no_grad():
#     for i, data in enumerate(tqdm(test_loader)):
#         output1 = model(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),
#                        attention_mask=data[2].squeeze(dim=0).to(device))
        
#         output2 = model2(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),
#                        attention_mask=data[2].squeeze(dim=0).to(device))

#         output5 = None
        
#         result.append(ensemble(data, output1, output2, output5, test_paragraphs[test_questions[i]['paragraph_id']]))


# result_file = "result_ensemble_last1_6.csv"
# with open(result_file, 'w') as f:
#     f.write("ID,Answer\n")
#     for i, test_question in enumerate(test_questions):
#         # Replace commas in answers with empty strings (since csv is separated by comma)
#         # Answers in kaggle are processed in the same way
#         f.write(f"{test_question['id']},{result[i].replace(',','')}\n")

# print(f"Completed! Result is in {result_file}")

## Testing

In [14]:
# newresult = []
# for res in result:
#     new_res = ''
#     fixed = False
#     if res[0] == '《':
#         closed = False
#         for i in range(1,len(res)):
#             if res[i] == '》':
#                 closed = True
#                 break
#         if not closed:
#             new_res = res + '》'
#             fixed = True
#             print(f"fixed: {new_res}")
        
    
#     elif res[-1] == '》':
#         closed = False
#         for i in range(len(res)-1, 0, -1):
#             if res[i] == '《':
#                 closed = True
#                 break
#         if not closed:
#             new_res = '《' + res
#             fixed = True
#             print(f"fixed: {new_res}")
#     if res[0] == '「':
#         closed = False
#         for i in range(1,len(res)):
#             if res[i] == '」':
#                 closed = True
#                 break
#         if not closed:
#             new_res = res + '」'
#             fixed = True
#             print(f"fixed: {new_res}")
    
#     elif res[-1] == '」':
#         closed = False
#         for i in range(len(res)-1, 0, -1):
#             if res[i] == '「':
#                 closed = True
#                 break
#         if not closed:
#             new_res = '「' + res
#             fixed = True
#             print(f"fixed: {new_res}")
    
#     if fixed:
#         newresult.append(new_res)
#     else:
#         newresult.append(res)

# newresult_file = "newresult_2.csv"
# with open(newresult_file, 'w') as f:
#     f.write("ID,Answer\n")
#     for i, test_question in enumerate(test_questions):
#         # Replace commas in answers with empty strings (since csv is separated by comma)
#         # Answers in kaggle are processed in the same way
#         f.write(f"{test_question['id']},{newresult[i].replace(',','')}\n")
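
The bracket repair in the cell above can be written more compactly. The following is a minimal sketch rather than a drop-in replacement: `balance_brackets` and `PAIRS` are hypothetical names, the helper returns after the first repair it makes (the loops above may apply a second one), and it returns empty predictions unchanged instead of indexing into them.

```python
# Bracket pairs handled by the cell above.
PAIRS = [("《", "》"), ("「", "」")]

def balance_brackets(text):
    """Add the missing partner when a bracket at either edge of `text` is unmatched."""
    if not text:
        return text  # nothing to repair in an empty prediction
    for open_b, close_b in PAIRS:
        # Opening bracket at the front with no closing bracket after it.
        if text.startswith(open_b) and close_b not in text[1:]:
            return text + close_b
        # Closing bracket at the end with no opening bracket before it.
        if text.endswith(close_b) and open_b not in text[:-1]:
            return open_b + text
    return text

print(balance_brackets("《紅樓夢"))
print(balance_brackets("紅樓夢》"))
```

Applied to the predictions, this would be `newresult = [balance_brackets(res) for res in result]`.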


In [15]:
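# import re
# p = test_paragraphs[test_questions[2]['paragraph_id']]
# print(p)
# # Match "置縣", then at most 20 non-whitespace characters (whitespace allowed in between), then "六年"
# comp = "置縣" + r'\s*(?:\S\s*){0,20}' + "六年"
# print(re.search(comp, p))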
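
The pattern in the cell above looks for two fragments separated by a bounded number of non-whitespace characters. A small self-contained version is sketched below, assuming a hypothetical helper name `find_between` and an invented sample string; `re.escape` is added so that fragments containing regex metacharacters are matched literally.

```python
import re

def find_between(text, head, tail, max_gap=20):
    """Return the substring from `head` to `tail` if at most `max_gap`
    non-whitespace characters lie between them, else None."""
    pattern = (
        re.escape(head)
        + r"\s*(?:\S\s*){0," + str(max_gap) + r"}"
        + re.escape(tail)
    )
    match = re.search(pattern, text)
    return match.group(0) if match else None

# Invented sample text, for illustration only.
sample = "城於元年置縣, 至 六年 廢"
print(find_between(sample, "置縣", "六年"))
print(find_between(sample, "置縣", "六年", max_gap=1))
```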

In [16]:

# print("Evaluating Test Set ...")

# result = []

# model.eval()
# with torch.no_grad():
#     for i, data in enumerate(tqdm(test_loader)):
#         output = model(input_ids=data[0].squeeze(dim=0).to(device), token_type_ids=data[1].squeeze(dim=0).to(device),
#                        attention_mask=data[2].squeeze(dim=0).to(device))
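#         # evaluate() takes the original paragraph text as its third argument
#         result.append(evaluate(data, output, test_paragraphs[test_questions[i]['paragraph_id']]))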

# result_file = "result.csv"
# with open(result_file, 'w') as f:
#     f.write("ID,Answer\n")
#     for i, test_question in enumerate(test_questions):
#         # Replace commas in answers with empty strings (since csv is separated by comma)
#         # Answers in kaggle are processed in the same way
#         f.write(f"{test_question['id']},{result[i].replace(',','')}\n")

# print(f"Completed! Result is in {result_file}")