# 문서 검색 효율화를 위한 기계독해
- 1차 모의경진대회(22.11.14 ~ 22.11.25)
- 자연어 기계독해(Machine Reading Comprehension) 과제

## 데이터 구조

```
$ MRC/
├── DATA/
│   ├── train.json
│   ├── test.json
│   └── sample_submission.csv
├── prediction.csv (코드 실행 후 생성)
└── results/ (코드 실행 후 생성)
```

In [2]:
!nvidia-smi

Mon Nov 21 23:19:50 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.72       Driver Version: 410.72       CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla V100-PCIE...  Off  | 00000000:3B:00.0 Off |                    0 |
| N/A   28C    P0    24W / 250W |      0MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

# 0. 사전 준비

## 0.1 구글 드라이브 마운트

In [3]:
# # 구글 Colaboratory 를 사용하기 위해 구글 계정으로 로그인합니다. 
# from google.colab import drive
# drive.mount('/content/drive')

## 0.2 라이브러리 설치

In [4]:
!pip install transformers



## 1. 라이브러리 불러오기

In [5]:
import os
import sys
import csv
import copy
import json
import random
import shutil
import numpy as np
import pandas as pd
from time import time
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from datetime import datetime, timezone, timedelta

from transformers import ElectraTokenizerFast, ElectraTokenizer
from transformers import ElectraForQuestionAnswering

###
from transformers import ElectraModel
from tokenizers import Encoding
###


import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

## 2. 하이퍼파라미터 및 기타 인자 설정

### 2.1 데이터 경로

In [6]:
# Root directory of the project, and the data directory located directly under it.
PROJECT_DIR = './MRC'
DATA_DIR = f'{PROJECT_DIR}/DATA'

### 2.2 시드 설정

In [7]:
# Fix the seed of every random number generator in use so that runs are repeatable.
RANDOM_SEED = 42

for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(RANDOM_SEED)

# Ask cuDNN for deterministic kernels and turn off its benchmark-based autotuning.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

### 2.3 하이퍼파라미터 설정

In [8]:
LEARNING_RATE = 1e-5     # Learning rate: the step size taken by gradient descent on each update
BATCH_SIZE = 4     # Batch: the unit of training data per weight update; here the weights are updated once every 4 samples
PIN_MEMORY = True
NUM_WORKERS = 0
EPOCHS = 6     # Epochs: how many times the whole training set is used for training
DROP_LAST = False
EARLY_STOPPING_MODE = 'min'     # 'min' = lower is better. A string (not the builtin `min`), because EarlyStopper takes `mode: str` and compares it with 'max'
EARLY_STOPPING_PATIENCE = 10
EARLY_STOPPING_TARGET = 'val_loss'     # Early stopping is decided on the validation-set loss
LOGGING_INTERVAL = 200

### 2.4 디바이스 설정

In [9]:
# Expose only GPU 0 to this process, then pick CUDA when it is available and CPU otherwise.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## 3. Dataset 정의

In [10]:
class QADataset(Dataset):
    """Dataset that turns a SQuAD-style JSON file into tokenized QA inputs.

    The file is read once in ``__init__`` and everything is tokenized up front.
    In ``'train'`` / ``'val'`` mode the encodings also carry token-level
    ``start_positions`` / ``end_positions`` labels; in ``'test'`` mode the
    question ids are kept in ``self.question_ids`` instead.

    Args:
        data_dir: Path of the JSON file to load (a file path, despite the name).
        tokenizer: Tokenizer called on ``(contexts, questions)``. It must be a
            *fast* tokenizer, because label creation uses ``char_to_token``.
        max_seq_len: ``max_length`` passed to the tokenizer (with truncation).
        mode: ``'train'`` (first ~80% of the documents), ``'val'`` (last ~20%)
            or ``'test'`` (all documents, no answers).
    """

    def __init__ (self, data_dir: str, tokenizer, max_seq_len: int, mode = 'train'):
        self.mode = mode
        # Use a context manager so the file handle is closed right after loading.
        with open(data_dir, 'r', encoding='utf8') as f:
            self.data = json.load(f)

        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        if mode == 'test':
            self.encodings, self.question_ids = self.preprocess()
        else:
            self.encodings, self.answers = self.preprocess()

    def __len__(self):
        """Return the number of encoded (context, question) pairs."""
        return len(self.encodings.input_ids)

    def __getitem__(self, index: int):
        """Return every encoding field of sample ``index`` as a tensor."""
        return {key: torch.tensor(val[index]) for key, val in self.encodings.items()}

    def preprocess(self):
        """Read the raw data and tokenize it.

        Returns:
            ``(encodings, question_ids)`` in test mode, otherwise
            ``(encodings, answers)`` where the encodings also contain
            ``start_positions`` and ``end_positions``.
        """
        contexts, questions, answers, question_ids = self.read_squad()
        if self.mode == 'test':
            encodings = self.tokenizer(contexts, questions, truncation=True, max_length = self.max_seq_len, padding=True)
            return encodings, question_ids
        else: # train or val
            # The data only provides 'answer_start'; derive 'answer_end' as well.
            self.add_end_idx(answers, contexts)
            encodings = self.tokenizer(contexts, questions, truncation=True, max_length = self.max_seq_len, padding=True)
            self.add_token_positions(encodings, answers)

            return encodings, answers

    def read_squad(self):
        """Flatten the SQuAD-style structure into parallel lists.

        In train/val mode ``self.data['data']`` is first replaced by the
        corresponding split (this mutates ``self.data``).

        Returns:
            ``(contexts, questions, answers, question_ids)``. ``answers`` is
            filled only in train/val mode (one entry per listed answer) and
            ``question_ids`` only in test mode.
        """
        contexts = []
        questions = []
        question_ids = []
        answers = []

        # train - val split: the last 20% of the documents are held out.
        # An explicit split index is used because slicing with ``[:-k]`` /
        # ``[-k:]`` breaks when k == 0 (fewer than 5 documents): ``[:-0]`` is
        # empty and ``[-0:]`` is the whole list.
        if self.mode in ('train', 'val'):
            documents = self.data['data']
            split_idx = len(documents) - int(len(documents) * 0.2)
            if self.mode == 'train':
                self.data['data'] = documents[:split_idx]
            else:
                self.data['data'] = documents[split_idx:]

        for group in self.data['data']:
            for passage in group['paragraphs']:
                context = passage['context']
                for qa in passage['qas']:
                    question = qa['question']
                    if self.mode == 'test':
                        contexts.append(context)
                        questions.append(question)
                        question_ids.append(qa['question_id'])
                    else: # train or val
                        for ans in qa['answers']:
                            contexts.append(context)
                            questions.append(question)

                            if qa['is_impossible']:
                                # Unanswerable question: marked with answer_start == -1.
                                answers.append({'text':'','answer_start':-1})
                            else:
                                answers.append(ans)

        # return formatted data lists
        return contexts, questions, answers, question_ids

    def add_end_idx(self, answers, contexts):
        """Add a character-level ``'answer_end'`` to each answer dict, in place.

        When the text at ``answer_start`` does not equal the answer text, the
        span is retried shifted by 1 and 2 characters in both directions and
        ``'answer_start'`` is corrected on a match. If nothing matches, the
        answer is left without ``'answer_end'``.
        """
        for answer, context in zip(answers, contexts):
            gold_text = answer['text']
            start_idx = answer['answer_start']
            end_idx = start_idx + len(gold_text)

            # in case the indices are off 1-2 idxs
            if context[start_idx:end_idx] == gold_text:
                answer['answer_end'] = end_idx
            else:
                for n in [1, 2]:
                    if context[start_idx-n:end_idx-n] == gold_text:
                        answer['answer_start'] = start_idx - n
                        answer['answer_end'] = end_idx - n
                    elif context[start_idx+n:end_idx+n] == gold_text:
                        answer['answer_start'] = start_idx + n
                        answer['answer_end'] = end_idx + n

    def add_token_positions(self, encodings, answers):
        """Convert character-level answer spans to token positions.

        Adds ``'start_positions'`` and ``'end_positions'`` to ``encodings``.

        Raises:
            AssertionError: if an answerable answer has no ``'answer_end'``.
        """
        # should use Fast tokenizer
        start_positions = []
        end_positions = []
        for i, answer in enumerate(answers):
            if answer['answer_start'] == -1:
                # is_impossible: use the span [0, 1), i.e. the first token
                # (the original note calls it the [CLS] token).
                start_positions.append(0)
                end_positions.append(1)
            else:
                start_positions.append(encodings.char_to_token(i, answer['answer_start']))

                assert 'answer_end' in answer.keys(), f'no answer_end at {i}'
                end_positions.append(encodings.char_to_token(i, answer['answer_end']))

            # answer passage truncated
            if start_positions[-1] is None:
                # Use this dataset's own tokenizer rather than a module-level
                # global that merely happens to be called `tokenizer`.
                start_positions[-1] = self.tokenizer.model_max_length
            # end position cannot be found, shift until found
            shift = 1
            while end_positions[-1] is None:
                end_positions[-1] = encodings.char_to_token(i, answer['answer_end'] - shift)
                shift += 1

        # char-based -> token based
        encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

## 4. 모델 정의

In [11]:
class electra(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained ``ElectraForQuestionAnswering``.

    Args:
        pretrained: Name or path handed to
            ``ElectraForQuestionAnswering.from_pretrained``.
        **kwargs: Accepted but not used.
    """

    def __init__(self, pretrained, **kwargs):
        super().__init__()

        # Load the pretrained weights through the Hugging Face API and keep the
        # resulting module as a sub-module.
        self.model = ElectraForQuestionAnswering.from_pretrained(pretrained)

    def forward(self, input_ids, attention_mask, start_positions=None, end_positions=None):
        """Forward all arguments to the wrapped model and return its output unchanged."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            start_positions=start_positions,
            end_positions=end_positions,
        )

## 5. Utils 정의
### 5.1 EarlyStopper

In [12]:
class EarlyStopper():
    """Signal that training should stop once the monitored value stops improving.

    Args:
        patience: Number of consecutive non-improving checks after which
            ``stop`` becomes True.
        mode: ``'min'`` when lower values are better, ``'max'`` when higher
            values are better. The builtins ``min`` / ``max`` are accepted as
            well, because callers in this file pass the builtin.

    Attributes set here:
        patience_counter: consecutive non-improving checks so far.
        stop: True once ``patience_counter`` has reached ``patience``.
        best_loss: best value seen so far (stored negated in max mode).
    """

    def __init__(self, patience: int, mode:str)-> None:
        self.patience = patience
        self.mode = mode

        # Initiate
        self.patience_counter = 0
        self.stop = False
        self.best_loss = np.inf

        print(f"Initiated early stopper, mode: {self.mode}, best score: {self.best_loss}, patience: {self.patience}")

    def check_early_stopping(self, loss: float)-> None:
        """Update the counter, best value and stop flag with a new measurement.

        A value that is not better than or equal to the best one — including
        NaN — counts as a non-improving check.
        """
        # In max mode the value is negated so that "smaller is better" holds
        # for the comparison below. Both 'max' and the builtin `max` select it.
        maximize = self.mode == 'max' or self.mode is max
        loss = -loss if maximize else loss

        if loss <= self.best_loss:
            # got better (or equal) score
            self.patience_counter = 0

            print(f"Early stopper, counter {self.patience_counter}/{self.patience}, best:{abs(self.best_loss)} -> now:{abs(loss)}")
            print(f"Set counter as {self.patience_counter}")
            print(f"Update best score as {abs(loss)}")

            self.best_loss = loss

        else:
            # got worse score; NaN also ends up here because every comparison
            # with NaN is False, so a diverged epoch is never taken as the best.
            self.patience_counter += 1

            print(f"Early stopper, counter {self.patience_counter}/{self.patience}, best:{abs(self.best_loss)} -> now:{abs(loss)}")

            if self.patience_counter >= self.patience:
                print(f"Early stopper, stop")
                self.stop = True  # end

### 5.2 Trainer

In [13]:
class Trainer():
    """Run one epoch of training or evaluation and keep that epoch's history.

    Args:
        model: Callable model whose output exposes ``loss``, ``start_logits``
            and ``end_logits``.
        optimizer: Optimizer stepped once per batch in train mode.
        loss: Stored on the instance; not used by ``train`` (the loss comes
            from the model output).
        metrics: Callable ``metrics(y_true, y_pred)`` applied to the decoded
            gold and predicted answer strings.
        device: Device the batch tensors are moved to.
        tokenizer: Tokenizer whose ``decode`` turns token ids into text.
        interval: Log the batch loss every ``interval`` batches.
    """

    def __init__(self,
                 model,
                 optimizer,
                 loss,
                 metrics,
                 device,
                 tokenizer,
                 interval=100):

        self.model = model
        self.optimizer = optimizer
        self.loss = loss
        self.metrics = metrics
        self.device = device
        self.interval = interval
        self.tokenizer = tokenizer

        # History
        self.loss_sum = 0  # Epoch loss sum
        self.loss_mean = 0 # Epoch loss mean
        self.y = list()
        self.y_preds = list()
        self.score_dict = dict()  # metric score
        self.elapsed_time = 0

    def train(self, mode, dataloader, tokenizer, epoch_index=0):
        """Run one pass over ``dataloader``.

        Args:
            mode: ``'train'`` updates the weights; any other value (e.g.
                ``'val'``) only evaluates.
            dataloader: Iterable of batches with ``input_ids``,
                ``attention_mask``, ``start_positions`` and ``end_positions``.
            tokenizer: Unused; ``self.tokenizer`` is used for decoding. Kept
                for interface compatibility.
            epoch_index: Unused; kept for interface compatibility.

        Results are left in ``loss_mean``, ``score_dict``, ``y``, ``y_preds``
        and ``elapsed_time``; call ``clear_history`` before the next pass.
        """
        start_timestamp = time()
        is_training = mode == 'train'

        # Switch the model between train and eval mode.
        if is_training:
            self.model.train()
        else:
            self.model.eval()

        for batch_index, batch in enumerate(tqdm(dataloader, leave=True)):

            # pull all the tensor batches required for training
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            start_positions = batch['start_positions'].to(self.device)
            end_positions = batch['end_positions'].to(self.device)

            if is_training:
                # Parameters are updated per batch, so gradients left over
                # from the previous step must be cleared first.
                self.optimizer.zero_grad()

            # Forward pass (the model output includes the loss). Autograd is
            # enabled only when training, so evaluation does not build and
            # hold a computation graph.
            with torch.set_grad_enabled(is_training):
                outputs = self.model(input_ids, attention_mask=attention_mask,
                                start_positions=start_positions,
                                end_positions=end_positions)

            loss = outputs.loss
            start_score = outputs.start_logits
            end_score = outputs.end_logits

            start_idx = torch.argmax(start_score, dim=1).cpu().tolist()
            end_idx = torch.argmax(end_score, dim=1).cpu().tolist()

            # Update
            if is_training:
                loss.backward()     # backpropagation
                self.optimizer.step()     # parameter update

            # History
            self.loss_sum += loss.item()

            # create answer; list of strings
            # When the predicted start lies after the predicted end the slice
            # is empty, so an empty token sequence is decoded.
            for ids, pred_start, pred_end, gold_start, gold_end in zip(
                    input_ids, start_idx, end_idx, start_positions, end_positions):
                self.y_preds.append(self.tokenizer.decode(ids[pred_start:pred_end]))
                self.y.append(self.tokenizer.decode(ids[gold_start:gold_end]))

            # Logging
            if batch_index % self.interval == 0:
                print(f"batch: {batch_index}/{len(dataloader)} loss: {loss.item()}")

        # Epoch history
        self.loss_mean = self.loss_sum / len(dataloader)  # Epoch loss mean

        # Metric
        score = self.metrics(self.y, self.y_preds)
        self.score_dict['metric_name'] = score

        # Elapsed time
        end_timestamp = time()
        self.elapsed_time = end_timestamp - start_timestamp

    def clear_history(self):
        """Reset everything accumulated by the previous ``train`` call."""
        self.loss_sum = 0
        self.loss_mean = 0
        self.y_preds = list()
        self.y = list()
        self.score_dict = dict()
        self.elapsed_time = 0

### 5.3 Recorder

In [14]:
class Recorder():
    """Write per-epoch rows to ``record.csv`` and checkpoints to ``model.pt``.

    Both files live in ``record_dir``.

    Args:
        record_dir: Directory in which the record file and weights are stored.
        model: Object whose ``state_dict()`` is saved by ``save_weight``.
        optimizer: Object whose ``state_dict()`` is saved by ``save_weight``.
    """

    def __init__(self,
                 record_dir: str,
                 model: object,
                 optimizer: object):

        self.model = model
        self.optimizer = optimizer

        self.record_dir = record_dir
        self.record_filepath = os.path.join(record_dir, 'record.csv')
        self.weight_path = os.path.join(record_dir, 'model.pt')

    def set_model(self, model: 'model'):
        """Replace the model whose weights will be saved."""
        self.model = model

    def add_row(self, row_dict: dict):
        """Append ``row_dict`` as one CSV row.

        The header (the dict keys, in order) is written first when the file is
        still empty. ``row_dict`` must contain ``'epoch_index'``.
        """
        with open(self.record_filepath, newline='', mode='a') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=list(row_dict))

            # Position 0 right after opening in append mode means nothing has
            # been written to this file yet.
            file_is_empty = csv_file.tell() == 0
            if file_is_empty:
                writer.writeheader()

            writer.writerow(row_dict)
            print(f"Write row {row_dict['epoch_index']}")

    def save_weight(self, epoch: int)-> None:
        """Save model and optimizer state to ``self.weight_path``.

        The stored ``'epoch'`` field is ``epoch + 1``.
        """
        torch.save(
            {
                'epoch': epoch + 1,
                'model': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            },
            self.weight_path,
        )
        print(f"Recorder, epoch {epoch} Model saved: {self.weight_path}")

## 6. 모델 학습

### 6.1 모델과 기타 utils 설정

In [None]:
# Name of the pretrained checkpoint; the model and the tokenizer must use the same one.
PRETRAINED_MODEL_NAME = "monologg/koelectra-base-v3-discriminator"

# Load model
model = electra(pretrained=PRETRAINED_MODEL_NAME).to(device)

# Set optimizer, loss function, metric function
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
loss = F.cross_entropy     # only stored on the Trainer; the training loss comes from the model output
metrics = accuracy_score

# Set tokenizer
tokenizer = ElectraTokenizerFast.from_pretrained(PRETRAINED_MODEL_NAME)

# Set Trainer
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  loss=loss,
                  metrics=metrics,
                  device=device,
                  tokenizer=tokenizer,
                  interval=LOGGING_INTERVAL)

# Set earlystopper
# Take the mode from the hyperparameter section instead of hard-coding it here,
# so changing EARLY_STOPPING_MODE actually takes effect.
early_stopper = EarlyStopper(patience=EARLY_STOPPING_PATIENCE,
                            mode=EARLY_STOPPING_MODE)

# Set train serial (timestamp in UTC+9, used as the name of the results directory)
kst = timezone(timedelta(hours=9))
train_serial = datetime.now(tz=kst).strftime("%Y%m%d_%H%M%S")


# Set recorder
RECORDER_DIR = os.path.join(PROJECT_DIR, 'results', 'train', train_serial)
os.makedirs(RECORDER_DIR, exist_ok=True)

recorder = Recorder(record_dir=RECORDER_DIR,
                    model=model,
                    optimizer=optimizer)

Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraForQuestionAnswering: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForQuestionAnswering were not initialized from the model checkpoint at monologg/koelectra-base-v3-discriminator and are newly initialized: ['qa_outputs.bias', 

Initiated early stopper, mode: <built-in function min>, best score: inf, patience: 10


### 6.2 Dataset & Dataloader 설정

In [None]:
# torch.utils.data.Dataset: converts the raw data into model inputs.
# Both splits are built from the same train.json; `mode` selects which part is kept.
train_json_path = os.path.join(DATA_DIR, 'train.json')
train_dataset, val_dataset = (
    QADataset(data_dir=train_json_path, tokenizer=tokenizer, max_seq_len=512, mode=split)
    for split in ('train', 'val')
)

# torch.utils.data.DataLoader: returns the inputs batch by batch.
# Only the training loader shuffles; every other option is shared.
loader_options = dict(batch_size=BATCH_SIZE,
                      num_workers=NUM_WORKERS,
                      pin_memory=PIN_MEMORY,
                      drop_last=DROP_LAST)

train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, **loader_options)
val_dataloader = DataLoader(dataset=val_dataset, shuffle=False, **loader_options)

print(f"Load data, train:{len(train_dataset)} val:{len(val_dataset)}")

Load data, train:22387 val:5637


### 6.3 Epoch 단위 학습 진행

In [17]:
# Train
for epoch_index in range(EPOCHS):

    # Recorder row for this epoch; key insertion order becomes the CSV column order.
    row_dict = {'epoch_index': epoch_index, 'train_serial': train_serial}

    # ---- Train ----
    print(f"Train {epoch_index}/{EPOCHS}")
    print(f"--Train {epoch_index}/{EPOCHS}")
    trainer.train(mode='train', dataloader=train_dataloader, tokenizer=tokenizer, epoch_index=epoch_index)

    row_dict.update(train_loss=trainer.loss_mean, train_elapsed_time=trainer.elapsed_time)
    row_dict.update({f"train_{metric_str}": score for metric_str, score in trainer.score_dict.items()})
    trainer.clear_history()

    # ---- Validation ----
    print(f"Val {epoch_index}/{EPOCHS}")
    print(f"--Val {epoch_index}/{EPOCHS}")
    trainer.train(mode='val', dataloader=val_dataloader, tokenizer=tokenizer, epoch_index=epoch_index)

    row_dict.update(val_loss=trainer.loss_mean, val_elapsed_time=trainer.elapsed_time)
    row_dict.update({f"val_{metric_str}": score for metric_str, score in trainer.score_dict.items()})
    trainer.clear_history()

    # ---- Record ----
    recorder.add_row(row_dict)

    # ---- Early stopper ----
    early_stopping_target = EARLY_STOPPING_TARGET
    early_stopper.check_early_stopping(loss=row_dict[early_stopping_target])

    # The counter is reset to 0 exactly when this epoch matched or beat the
    # best value so far, so that is when the checkpoint is refreshed.
    if early_stopper.patience_counter == 0:
        recorder.save_weight(epoch=epoch_index)
        best_row_dict = copy.deepcopy(row_dict)

    if early_stopper.stop:
        print(f"Early stopped, counter {early_stopper.patience_counter}/{EARLY_STOPPING_PATIENCE}")

        break

Train 0/6
--Train 0/6


  0%|          | 2/5597 [00:00<24:15,  3.85it/s]

batch: 0/5597 loss: 6.099382400512695


  4%|▎         | 202/5597 [00:27<12:07,  7.41it/s]

batch: 200/5597 loss: 2.3963546752929688


  7%|▋         | 402/5597 [00:54<11:40,  7.42it/s]

batch: 400/5597 loss: 1.1598236560821533


 11%|█         | 602/5597 [01:21<11:12,  7.42it/s]

batch: 600/5597 loss: 1.0227819681167603


 14%|█▍        | 802/5597 [01:48<10:47,  7.40it/s]

batch: 800/5597 loss: 1.4500176906585693


 18%|█▊        | 1002/5597 [02:15<10:18,  7.43it/s]

batch: 1000/5597 loss: 1.1187474727630615


 21%|██▏       | 1202/5597 [02:42<09:53,  7.41it/s]

batch: 1200/5597 loss: 2.6995301246643066


 25%|██▌       | 1402/5597 [03:09<09:23,  7.44it/s]

batch: 1400/5597 loss: 0.64496248960495


 29%|██▊       | 1602/5597 [03:36<09:04,  7.33it/s]

batch: 1600/5597 loss: 1.6376339197158813


 32%|███▏      | 1802/5597 [04:03<08:31,  7.42it/s]

batch: 1800/5597 loss: 1.7765495777130127


 36%|███▌      | 2002/5597 [04:30<08:04,  7.42it/s]

batch: 2000/5597 loss: 0.893886923789978


 39%|███▉      | 2202/5597 [04:57<07:38,  7.41it/s]

batch: 2200/5597 loss: 0.5850663781166077


 43%|████▎     | 2402/5597 [05:24<07:09,  7.44it/s]

batch: 2400/5597 loss: 1.3873870372772217


 46%|████▋     | 2602/5597 [05:51<06:43,  7.42it/s]

batch: 2600/5597 loss: 1.3489704132080078


 50%|█████     | 2802/5597 [06:18<06:17,  7.41it/s]

batch: 2800/5597 loss: 0.3923242688179016


 54%|█████▎    | 3002/5597 [06:45<05:49,  7.43it/s]

batch: 3000/5597 loss: 0.8188754916191101


 57%|█████▋    | 3202/5597 [07:12<05:22,  7.43it/s]

batch: 3200/5597 loss: 0.4350345730781555


 61%|██████    | 3402/5597 [07:39<04:55,  7.42it/s]

batch: 3400/5597 loss: 0.2894554138183594


 64%|██████▍   | 3602/5597 [08:06<04:29,  7.41it/s]

batch: 3600/5597 loss: 1.010141134262085


 68%|██████▊   | 3802/5597 [08:33<04:02,  7.39it/s]

batch: 3800/5597 loss: 0.33209922909736633


 72%|███████▏  | 4002/5597 [09:00<03:35,  7.41it/s]

batch: 4000/5597 loss: 0.77495276927948


 75%|███████▌  | 4202/5597 [09:27<03:07,  7.43it/s]

batch: 4200/5597 loss: 0.22682581841945648


 79%|███████▊  | 4402/5597 [09:54<02:40,  7.43it/s]

batch: 4400/5597 loss: 1.677100419998169


 82%|████████▏ | 4602/5597 [10:20<02:14,  7.41it/s]

batch: 4600/5597 loss: 0.3210865557193756


 86%|████████▌ | 4802/5597 [10:47<01:47,  7.41it/s]

batch: 4800/5597 loss: 1.4620288610458374


 89%|████████▉ | 5002/5597 [11:14<01:20,  7.43it/s]

batch: 5000/5597 loss: 0.902995228767395


 93%|█████████▎| 5202/5597 [11:41<00:53,  7.43it/s]

batch: 5200/5597 loss: 0.16519169509410858


 97%|█████████▋| 5402/5597 [12:08<00:26,  7.42it/s]

batch: 5400/5597 loss: 0.9855315089225769


100%|██████████| 5597/5597 [12:35<00:00,  7.41it/s]


Val 0/6
--Val 0/6


  0%|          | 3/1410 [00:00<01:02, 22.37it/s]

batch: 0/1410 loss: 0.030173731967806816


 14%|█▍        | 204/1410 [00:08<00:53, 22.74it/s]

batch: 200/1410 loss: 0.03921408951282501


 29%|██▊       | 405/1410 [00:17<00:44, 22.72it/s]

batch: 400/1410 loss: 0.941118061542511


 43%|████▎     | 603/1410 [00:26<00:35, 22.68it/s]

batch: 600/1410 loss: 0.30977416038513184


 57%|█████▋    | 804/1410 [00:35<00:26, 22.73it/s]

batch: 800/1410 loss: 1.2776985168457031


 71%|███████▏  | 1005/1410 [00:44<00:17, 22.71it/s]

batch: 1000/1410 loss: 0.9613958597183228


 85%|████████▌ | 1203/1410 [00:53<00:09, 22.68it/s]

batch: 1200/1410 loss: 1.3843533992767334


100%|█████████▉| 1404/1410 [01:01<00:00, 22.76it/s]

batch: 1400/1410 loss: 1.3358421325683594


100%|██████████| 1410/1410 [01:02<00:00, 22.70it/s]


Write row 0
Early stopper, counter 0/10, best:inf -> now:0.5520814343547145
Set counter as 0
Update best score as 0.5520814343547145
Recorder, epoch 0 Model saved: /mnt/workspace/MyFiles/MNC_mock_1/MRC/results/train/20221121_232000/model.pt
Train 1/6
--Train 1/6


  0%|          | 1/5597 [00:00<14:17,  6.53it/s]

batch: 0/5597 loss: 0.6606749296188354


  4%|▎         | 202/5597 [00:27<12:06,  7.42it/s]

batch: 200/5597 loss: 0.12557905912399292


  7%|▋         | 402/5597 [00:54<11:39,  7.43it/s]

batch: 400/5597 loss: 0.14774075150489807


 11%|█         | 602/5597 [01:21<11:13,  7.42it/s]

batch: 600/5597 loss: 0.044682301580905914


 14%|█▍        | 802/5597 [01:48<10:46,  7.42it/s]

batch: 800/5597 loss: 0.18094158172607422


 18%|█▊        | 1002/5597 [02:14<10:19,  7.42it/s]

batch: 1000/5597 loss: 0.0399436391890049


 21%|██▏       | 1202/5597 [02:41<09:52,  7.42it/s]

batch: 1200/5597 loss: 0.019183646887540817


 25%|██▌       | 1402/5597 [03:08<09:26,  7.41it/s]

batch: 1400/5597 loss: 0.24356701970100403


 29%|██▊       | 1602/5597 [03:35<08:58,  7.42it/s]

batch: 1600/5597 loss: 0.5347874164581299


 32%|███▏      | 1802/5597 [04:02<08:32,  7.41it/s]

batch: 1800/5597 loss: 1.7758004665374756


 36%|███▌      | 2002/5597 [04:29<08:04,  7.41it/s]

batch: 2000/5597 loss: 0.07695814967155457


 39%|███▉      | 2202/5597 [04:56<07:37,  7.43it/s]

batch: 2200/5597 loss: 0.03789898753166199


 43%|████▎     | 2402/5597 [05:23<07:10,  7.42it/s]

batch: 2400/5597 loss: 0.49385398626327515


 46%|████▋     | 2602/5597 [05:50<06:43,  7.42it/s]

batch: 2600/5597 loss: 0.08574932813644409


 50%|█████     | 2802/5597 [06:17<06:16,  7.43it/s]

batch: 2800/5597 loss: 0.2706417739391327


 54%|█████▎    | 3002/5597 [06:44<05:49,  7.42it/s]

batch: 3000/5597 loss: 0.5588668584823608


 57%|█████▋    | 3202/5597 [07:11<05:22,  7.42it/s]

batch: 3200/5597 loss: 1.0734269618988037


 61%|██████    | 3402/5597 [07:38<04:55,  7.42it/s]

batch: 3400/5597 loss: 0.112637460231781


 64%|██████▍   | 3602/5597 [08:05<04:29,  7.41it/s]

batch: 3600/5597 loss: 0.6777994632720947


 68%|██████▊   | 3802/5597 [08:32<04:01,  7.43it/s]

batch: 3800/5597 loss: 0.2619664669036865


 72%|███████▏  | 4002/5597 [08:59<03:34,  7.44it/s]

batch: 4000/5597 loss: 0.8584862351417542


 75%|███████▌  | 4202/5597 [09:26<03:07,  7.43it/s]

batch: 4200/5597 loss: 0.6141237616539001


 79%|███████▊  | 4402/5597 [09:53<02:40,  7.44it/s]

batch: 4400/5597 loss: 0.21087639033794403


 82%|████████▏ | 4602/5597 [10:20<02:14,  7.42it/s]

batch: 4600/5597 loss: 0.6751890182495117


 86%|████████▌ | 4802/5597 [10:46<01:46,  7.43it/s]

batch: 4800/5597 loss: 0.6205365657806396


 89%|████████▉ | 5002/5597 [11:13<01:20,  7.42it/s]

batch: 5000/5597 loss: 0.37785813212394714


 93%|█████████▎| 5202/5597 [11:40<00:53,  7.43it/s]

batch: 5200/5597 loss: 0.06858767569065094


 97%|█████████▋| 5402/5597 [12:07<00:26,  7.42it/s]

batch: 5400/5597 loss: 0.1741677075624466


100%|██████████| 5597/5597 [12:34<00:00,  7.42it/s]


Val 1/6
--Val 1/6


  0%|          | 3/1410 [00:00<01:02, 22.47it/s]

batch: 0/1410 loss: 0.01928095333278179


 14%|█▍        | 204/1410 [00:08<00:53, 22.75it/s]

batch: 200/1410 loss: 0.0732903927564621


 29%|██▊       | 405/1410 [00:17<00:44, 22.76it/s]

batch: 400/1410 loss: 0.8134999871253967


 43%|████▎     | 603/1410 [00:26<00:35, 22.74it/s]

batch: 600/1410 loss: 0.34397825598716736


 57%|█████▋    | 804/1410 [00:35<00:27, 22.19it/s]

batch: 800/1410 loss: 1.2447720766067505


 71%|███████▏  | 1005/1410 [00:44<00:17, 22.77it/s]

batch: 1000/1410 loss: 0.9514705538749695


 85%|████████▌ | 1203/1410 [00:52<00:09, 22.78it/s]

batch: 1200/1410 loss: 1.558793544769287


100%|█████████▉| 1404/1410 [01:01<00:00, 22.76it/s]

batch: 1400/1410 loss: 1.223635196685791


100%|██████████| 1410/1410 [01:01<00:00, 22.74it/s]


Write row 1
Early stopper, counter 0/10, best:0.5520814343547145 -> now:0.5513570183278755
Set counter as 0
Update best score as 0.5513570183278755
Recorder, epoch 1 Model saved: /mnt/workspace/MyFiles/MNC_mock_1/MRC/results/train/20221121_232000/model.pt
Train 2/6
--Train 2/6


  0%|          | 1/5597 [00:00<14:16,  6.54it/s]

batch: 0/5597 loss: 0.16415813565254211


  4%|▎         | 202/5597 [00:27<12:07,  7.41it/s]

batch: 200/5597 loss: 0.25571101903915405


  7%|▋         | 402/5597 [00:54<11:40,  7.41it/s]

batch: 400/5597 loss: 0.19540303945541382


 11%|█         | 602/5597 [01:21<11:12,  7.42it/s]

batch: 600/5597 loss: 0.025437548756599426


 14%|█▍        | 802/5597 [01:48<10:47,  7.41it/s]

batch: 800/5597 loss: 0.03639456629753113


 18%|█▊        | 1002/5597 [02:15<10:18,  7.43it/s]

batch: 1000/5597 loss: 0.04662257432937622


 21%|██▏       | 1202/5597 [02:41<09:52,  7.42it/s]

batch: 1200/5597 loss: 0.5285959243774414


 25%|██▌       | 1402/5597 [03:08<09:25,  7.42it/s]

batch: 1400/5597 loss: 0.16054858267307281


 29%|██▊       | 1602/5597 [03:35<08:58,  7.41it/s]

batch: 1600/5597 loss: 0.5972471237182617


 32%|███▏      | 1802/5597 [04:02<08:30,  7.44it/s]

batch: 1800/5597 loss: 0.06714729219675064


 36%|███▌      | 2002/5597 [04:29<08:03,  7.43it/s]

batch: 2000/5597 loss: 0.17344066500663757


 39%|███▉      | 2202/5597 [04:56<07:36,  7.43it/s]

batch: 2200/5597 loss: 0.11941918730735779


 43%|████▎     | 2402/5597 [05:23<07:09,  7.44it/s]

batch: 2400/5597 loss: 0.01480952836573124


 46%|████▋     | 2602/5597 [05:50<06:43,  7.42it/s]

batch: 2600/5597 loss: 0.03693780303001404


 50%|█████     | 2802/5597 [06:17<06:17,  7.41it/s]

batch: 2800/5597 loss: 0.0316411592066288


 54%|█████▎    | 3002/5597 [06:44<05:48,  7.44it/s]

batch: 3000/5597 loss: 0.22977706789970398


 57%|█████▋    | 3202/5597 [07:11<05:22,  7.44it/s]

batch: 3200/5597 loss: 0.1890258938074112


 61%|██████    | 3402/5597 [07:38<04:54,  7.45it/s]

batch: 3400/5597 loss: 0.053979720920324326


 64%|██████▍   | 3602/5597 [08:04<04:28,  7.43it/s]

batch: 3600/5597 loss: 0.043595943599939346


 68%|██████▊   | 3802/5597 [08:31<04:01,  7.43it/s]

batch: 3800/5597 loss: 0.4566221237182617


 72%|███████▏  | 4002/5597 [08:58<03:35,  7.41it/s]

batch: 4000/5597 loss: 0.0333266481757164


 75%|███████▌  | 4202/5597 [09:25<03:07,  7.43it/s]

batch: 4200/5597 loss: 0.9961682558059692


 79%|███████▊  | 4402/5597 [09:52<02:40,  7.44it/s]

batch: 4400/5597 loss: 0.430366188287735


 82%|████████▏ | 4602/5597 [10:19<02:14,  7.42it/s]

batch: 4600/5597 loss: 0.6579272747039795


 86%|████████▌ | 4802/5597 [10:46<01:46,  7.44it/s]

batch: 4800/5597 loss: 0.255258709192276


 89%|████████▉ | 5002/5597 [11:13<01:19,  7.44it/s]

batch: 5000/5597 loss: 0.1446191668510437


 93%|█████████▎| 5202/5597 [11:40<00:53,  7.44it/s]

batch: 5200/5597 loss: 0.32307955622673035


 97%|█████████▋| 5402/5597 [12:07<00:26,  7.41it/s]

batch: 5400/5597 loss: 2.087002754211426


100%|██████████| 5597/5597 [12:33<00:00,  7.43it/s]


Val 2/6
--Val 2/6


  0%|          | 3/1410 [00:00<01:02, 22.44it/s]

batch: 0/1410 loss: 0.04072363302111626


 14%|█▍        | 204/1410 [00:08<00:52, 22.76it/s]

batch: 200/1410 loss: 0.05190373584628105


 29%|██▊       | 405/1410 [00:17<00:44, 22.76it/s]

batch: 400/1410 loss: 0.7623479962348938


 43%|████▎     | 603/1410 [00:26<00:35, 22.63it/s]

batch: 600/1410 loss: 0.5124049782752991


 57%|█████▋    | 804/1410 [00:35<00:26, 22.74it/s]

batch: 800/1410 loss: 1.3397879600524902


 71%|███████▏  | 1005/1410 [00:44<00:17, 22.87it/s]

batch: 1000/1410 loss: 0.3710503578186035


 85%|████████▌ | 1203/1410 [00:52<00:09, 22.77it/s]

batch: 1200/1410 loss: 2.204555034637451


100%|█████████▉| 1404/1410 [01:01<00:00, 22.77it/s]

batch: 1400/1410 loss: 1.4163835048675537


100%|██████████| 1410/1410 [01:01<00:00, 22.77it/s]


Write row 2
Early stopper, counter 1/10, best:0.5513570183278755 -> now:0.5908287504033173
Train 3/6
--Train 3/6


  0%|          | 1/5597 [00:00<13:18,  7.01it/s]

batch: 0/5597 loss: 0.22266031801700592


  4%|▎         | 202/5597 [00:27<12:05,  7.44it/s]

batch: 200/5597 loss: 0.008087143301963806


  7%|▋         | 402/5597 [00:54<11:39,  7.42it/s]

batch: 400/5597 loss: 0.02184927836060524


 11%|█         | 602/5597 [01:21<11:13,  7.42it/s]

batch: 600/5597 loss: 0.01968090794980526


 14%|█▍        | 802/5597 [01:47<10:46,  7.42it/s]

batch: 800/5597 loss: 0.06398677080869675


 18%|█▊        | 1002/5597 [02:14<10:17,  7.44it/s]

batch: 1000/5597 loss: 0.05801818147301674


 21%|██▏       | 1202/5597 [02:41<09:51,  7.42it/s]

batch: 1200/5597 loss: 0.005153208505362272


 25%|██▌       | 1402/5597 [03:08<09:25,  7.41it/s]

batch: 1400/5597 loss: 0.017806751653552055


 29%|██▊       | 1602/5597 [03:35<08:57,  7.43it/s]

batch: 1600/5597 loss: 0.24843212962150574


 32%|███▏      | 1802/5597 [04:02<08:31,  7.42it/s]

batch: 1800/5597 loss: 0.15924052894115448


 36%|███▌      | 2002/5597 [04:29<08:05,  7.41it/s]

batch: 2000/5597 loss: 0.03827677294611931


 39%|███▉      | 2202/5597 [04:56<07:36,  7.43it/s]

batch: 2200/5597 loss: 0.0017451568273827434


 43%|████▎     | 2402/5597 [05:23<07:10,  7.42it/s]

batch: 2400/5597 loss: 0.13704583048820496


 46%|████▋     | 2602/5597 [05:50<06:43,  7.42it/s]

batch: 2600/5597 loss: 0.049942854791879654


 50%|█████     | 2802/5597 [06:17<06:16,  7.43it/s]

batch: 2800/5597 loss: 0.06177420914173126


 54%|█████▎    | 3002/5597 [06:44<05:49,  7.42it/s]

batch: 3000/5597 loss: 0.10000923275947571


 57%|█████▋    | 3202/5597 [07:11<05:22,  7.43it/s]

batch: 3200/5597 loss: 0.014871936291456223


 61%|██████    | 3402/5597 [07:37<04:55,  7.43it/s]

batch: 3400/5597 loss: 0.21977342665195465


 64%|██████▍   | 3602/5597 [08:04<04:28,  7.44it/s]

batch: 3600/5597 loss: 0.009331803768873215


 68%|██████▊   | 3802/5597 [08:31<04:01,  7.42it/s]

batch: 3800/5597 loss: 0.19751819968223572


 72%|███████▏  | 4002/5597 [08:58<03:34,  7.43it/s]

batch: 4000/5597 loss: 0.9501153826713562


 75%|███████▌  | 4202/5597 [09:25<03:07,  7.43it/s]

batch: 4200/5597 loss: 0.01040003914386034


 79%|███████▊  | 4402/5597 [09:52<02:40,  7.44it/s]

batch: 4400/5597 loss: 2.817898988723755


 82%|████████▏ | 4602/5597 [10:19<02:14,  7.42it/s]

batch: 4600/5597 loss: 0.22259685397148132


 86%|████████▌ | 4802/5597 [10:46<01:46,  7.43it/s]

batch: 4800/5597 loss: 0.03551427274942398


 89%|████████▉ | 5002/5597 [11:13<01:19,  7.45it/s]

batch: 5000/5597 loss: 0.3166767358779907


 93%|█████████▎| 5202/5597 [11:40<00:53,  7.43it/s]

batch: 5200/5597 loss: 0.5629870891571045


 97%|█████████▋| 5402/5597 [12:07<00:26,  7.45it/s]

batch: 5400/5597 loss: 0.0673878937959671


100%|██████████| 5597/5597 [12:33<00:00,  7.43it/s]


Val 3/6
--Val 3/6


  0%|          | 3/1410 [00:00<01:02, 22.48it/s]

batch: 0/1410 loss: 0.04154177010059357


 14%|█▍        | 204/1410 [00:09<00:53, 22.74it/s]

batch: 200/1410 loss: 0.04320785775780678


 29%|██▊       | 405/1410 [00:17<00:44, 22.76it/s]

batch: 400/1410 loss: 0.41224366426467896


 43%|████▎     | 603/1410 [00:26<00:35, 22.74it/s]

batch: 600/1410 loss: 0.48059529066085815


 57%|█████▋    | 804/1410 [00:35<00:26, 22.75it/s]

batch: 800/1410 loss: 1.460432767868042


 71%|███████▏  | 1005/1410 [00:44<00:17, 22.74it/s]

batch: 1000/1410 loss: 0.8606706857681274


 85%|████████▌ | 1203/1410 [00:52<00:09, 22.71it/s]

batch: 1200/1410 loss: 2.283640146255493


100%|█████████▉| 1404/1410 [01:01<00:00, 22.77it/s]

batch: 1400/1410 loss: 1.7065534591674805


100%|██████████| 1410/1410 [01:02<00:00, 22.73it/s]


Write row 3
Early stopper, counter 2/10, best:0.5513570183278755 -> now:0.7430412789309688
Train 4/6
--Train 4/6


  0%|          | 1/5597 [00:00<13:08,  7.09it/s]

batch: 0/5597 loss: 0.019549306482076645


  4%|▎         | 202/5597 [00:27<12:05,  7.44it/s]

batch: 200/5597 loss: 0.0782378613948822


  7%|▋         | 402/5597 [00:54<11:40,  7.41it/s]

batch: 400/5597 loss: 0.002892290009185672


 11%|█         | 602/5597 [01:20<11:12,  7.43it/s]

batch: 600/5597 loss: 0.02972836047410965


 14%|█▍        | 802/5597 [01:47<10:44,  7.44it/s]

batch: 800/5597 loss: 0.4186243414878845


 18%|█▊        | 1002/5597 [02:14<10:19,  7.42it/s]

batch: 1000/5597 loss: 0.01977718248963356


 21%|██▏       | 1202/5597 [02:41<09:51,  7.43it/s]

batch: 1200/5597 loss: 0.000887631787918508


 25%|██▌       | 1402/5597 [03:08<09:23,  7.44it/s]

batch: 1400/5597 loss: 0.036192815750837326


 29%|██▊       | 1602/5597 [03:35<08:56,  7.44it/s]

batch: 1600/5597 loss: 0.07057259231805801


 32%|███▏      | 1802/5597 [04:02<08:30,  7.44it/s]

batch: 1800/5597 loss: 0.007580433040857315


 36%|███▌      | 2002/5597 [04:29<08:03,  7.44it/s]

batch: 2000/5597 loss: 0.08660005033016205


 39%|███▉      | 2202/5597 [04:56<07:36,  7.43it/s]

batch: 2200/5597 loss: 0.0983383059501648


 43%|████▎     | 2402/5597 [05:23<07:10,  7.43it/s]

batch: 2400/5597 loss: 0.0022531135473400354


 46%|████▋     | 2602/5597 [05:50<06:43,  7.43it/s]

batch: 2600/5597 loss: 0.004814408253878355


 50%|█████     | 2802/5597 [06:16<06:17,  7.39it/s]

batch: 2800/5597 loss: 0.4307619631290436


 54%|█████▎    | 3002/5597 [06:43<05:49,  7.43it/s]

batch: 3000/5597 loss: 0.14180482923984528


 57%|█████▋    | 3202/5597 [07:10<05:22,  7.42it/s]

batch: 3200/5597 loss: 0.11608348041772842


 61%|██████    | 3402/5597 [07:37<04:55,  7.43it/s]

batch: 3400/5597 loss: 0.023807989433407784


 64%|██████▍   | 3602/5597 [08:04<04:28,  7.42it/s]

batch: 3600/5597 loss: 0.012799790129065514


 68%|██████▊   | 3802/5597 [08:31<04:02,  7.41it/s]

batch: 3800/5597 loss: 0.026168158277869225


 72%|███████▏  | 4002/5597 [08:58<03:35,  7.42it/s]

batch: 4000/5597 loss: 0.041159488260746


 75%|███████▌  | 4202/5597 [09:25<03:08,  7.42it/s]

batch: 4200/5597 loss: 0.005180174484848976


 79%|███████▊  | 4402/5597 [09:52<02:40,  7.43it/s]

batch: 4400/5597 loss: 0.03316684439778328


 82%|████████▏ | 4602/5597 [10:19<02:14,  7.42it/s]

batch: 4600/5597 loss: 0.08463942259550095


 86%|████████▌ | 4802/5597 [10:46<01:46,  7.47it/s]

batch: 4800/5597 loss: 0.04799894616007805


 89%|████████▉ | 5002/5597 [11:13<01:20,  7.43it/s]

batch: 5000/5597 loss: 0.0018439748091623187


 93%|█████████▎| 5202/5597 [11:40<00:53,  7.43it/s]

batch: 5200/5597 loss: 0.013001699000597


 97%|█████████▋| 5402/5597 [12:06<00:26,  7.41it/s]

batch: 5400/5597 loss: 0.08430808037519455


100%|██████████| 5597/5597 [12:33<00:00,  7.43it/s]


Val 4/6
--Val 4/6


  0%|          | 3/1410 [00:00<01:02, 22.47it/s]

batch: 0/1410 loss: 0.020247017964720726


 14%|█▍        | 204/1410 [00:08<00:52, 22.78it/s]

batch: 200/1410 loss: 0.01469478104263544


 29%|██▊       | 405/1410 [00:17<00:43, 22.84it/s]

batch: 400/1410 loss: 0.7775681018829346


 43%|████▎     | 603/1410 [00:26<00:35, 22.75it/s]

batch: 600/1410 loss: 0.34227824211120605


 57%|█████▋    | 804/1410 [00:35<00:27, 22.26it/s]

batch: 800/1410 loss: 1.6497870683670044


 71%|███████▏  | 1005/1410 [00:44<00:17, 22.79it/s]

batch: 1000/1410 loss: 0.453810453414917


 85%|████████▌ | 1203/1410 [00:53<00:09, 22.21it/s]

batch: 1200/1410 loss: 2.6626596450805664


100%|█████████▉| 1404/1410 [01:01<00:00, 22.79it/s]

batch: 1400/1410 loss: 2.5939292907714844


100%|██████████| 1410/1410 [01:02<00:00, 22.69it/s]


Write row 4
Early stopper, counter 3/10, best:0.5513570183278755 -> now:0.8604683457891513
Train 5/6
--Train 5/6


  0%|          | 1/5597 [00:00<13:18,  7.01it/s]

batch: 0/5597 loss: 0.0027956380508840084


  4%|▎         | 202/5597 [00:27<12:06,  7.43it/s]

batch: 200/5597 loss: 0.0054092505015432835


  7%|▋         | 402/5597 [00:54<11:40,  7.42it/s]

batch: 400/5597 loss: 0.0032815535087138414


 11%|█         | 602/5597 [01:21<11:12,  7.42it/s]

batch: 600/5597 loss: 0.11841163784265518


 14%|█▍        | 802/5597 [01:48<10:54,  7.32it/s]

batch: 800/5597 loss: 0.004421435762196779


 18%|█▊        | 1002/5597 [02:14<10:26,  7.34it/s]

batch: 1000/5597 loss: 0.3419729173183441


 21%|██▏       | 1202/5597 [02:41<09:52,  7.41it/s]

batch: 1200/5597 loss: 0.029678191989660263


 25%|██▌       | 1402/5597 [03:08<09:25,  7.42it/s]

batch: 1400/5597 loss: 0.003717989893630147


 29%|██▊       | 1602/5597 [03:35<08:57,  7.44it/s]

batch: 1600/5597 loss: 0.04394116625189781


 32%|███▏      | 1802/5597 [04:02<08:29,  7.44it/s]

batch: 1800/5597 loss: 0.07196684181690216


 36%|███▌      | 2002/5597 [04:29<08:03,  7.43it/s]

batch: 2000/5597 loss: 0.006018024869263172


 39%|███▉      | 2202/5597 [04:56<07:37,  7.42it/s]

batch: 2200/5597 loss: 0.01637290045619011


 43%|████▎     | 2402/5597 [05:23<07:09,  7.44it/s]

batch: 2400/5597 loss: 0.02709820121526718


 46%|████▋     | 2602/5597 [05:50<06:43,  7.42it/s]

batch: 2600/5597 loss: 0.5464973449707031


 50%|█████     | 2802/5597 [06:17<06:16,  7.43it/s]

batch: 2800/5597 loss: 0.03224257379770279


 54%|█████▎    | 3002/5597 [06:44<05:49,  7.42it/s]

batch: 3000/5597 loss: 0.004559419117867947


 57%|█████▋    | 3202/5597 [07:11<05:22,  7.43it/s]

batch: 3200/5597 loss: 0.27574265003204346


 61%|██████    | 3402/5597 [07:37<04:55,  7.42it/s]

batch: 3400/5597 loss: 0.019435517489910126


 64%|██████▍   | 3602/5597 [08:04<04:28,  7.42it/s]

batch: 3600/5597 loss: 0.017246749252080917


 68%|██████▊   | 3802/5597 [08:31<04:01,  7.42it/s]

batch: 3800/5597 loss: 0.015549279749393463


 72%|███████▏  | 4002/5597 [08:58<03:35,  7.41it/s]

batch: 4000/5597 loss: 0.06339015066623688


 75%|███████▌  | 4202/5597 [09:25<03:08,  7.42it/s]

batch: 4200/5597 loss: 0.0014892916660755873


 79%|███████▊  | 4402/5597 [09:52<02:41,  7.42it/s]

batch: 4400/5597 loss: 0.0008502982091158628


 82%|████████▏ | 4602/5597 [10:19<02:14,  7.42it/s]

batch: 4600/5597 loss: 0.6702120900154114


 86%|████████▌ | 4802/5597 [10:46<01:47,  7.42it/s]

batch: 4800/5597 loss: 0.0022065737284719944


 89%|████████▉ | 5002/5597 [11:13<01:20,  7.42it/s]

batch: 5000/5597 loss: 0.7002713084220886


 93%|█████████▎| 5202/5597 [11:40<00:53,  7.43it/s]

batch: 5200/5597 loss: 0.015437686815857887


 97%|█████████▋| 5402/5597 [12:07<00:26,  7.42it/s]

batch: 5400/5597 loss: 0.004249596036970615


100%|██████████| 5597/5597 [12:33<00:00,  7.43it/s]


Val 5/6
--Val 5/6


  0%|          | 3/1410 [00:00<01:02, 22.47it/s]

batch: 0/1410 loss: 0.042583294212818146


 14%|█▍        | 204/1410 [00:08<00:52, 22.78it/s]

batch: 200/1410 loss: 0.0014824550598859787


 29%|██▊       | 405/1410 [00:17<00:44, 22.75it/s]

batch: 400/1410 loss: 0.7681834697723389


 43%|████▎     | 603/1410 [00:26<00:35, 22.77it/s]

batch: 600/1410 loss: 0.2529301345348358


 57%|█████▋    | 804/1410 [00:35<00:26, 22.78it/s]

batch: 800/1410 loss: 1.7780925035476685


 71%|███████▏  | 1005/1410 [00:44<00:17, 22.77it/s]

batch: 1000/1410 loss: 0.9703247547149658


 85%|████████▌ | 1203/1410 [00:52<00:09, 22.76it/s]

batch: 1200/1410 loss: 2.6980364322662354


100%|█████████▉| 1404/1410 [01:01<00:00, 22.77it/s]

batch: 1400/1410 loss: 2.6678555011749268


100%|██████████| 1410/1410 [01:01<00:00, 22.77it/s]

Write row 5
Early stopper, counter 4/10, best:0.5513570183278755 -> now:0.9161958823309102





## 7. 추론

### 7.1 테스트 Dataset & Dataloader 설정

In [18]:
# Build the test dataset from test.json (mode='test').
test_dataset = QADataset(
    data_dir=os.path.join(DATA_DIR, 'test.json'),
    tokenizer=tokenizer,
    max_seq_len=512,
    mode='test',
)

# Question ids are read from the dataset; the inference loop slices this list by
# batch position, so the loader must keep the dataset order (shuffle=False).
question_ids = test_dataset.question_ids

# drop_last is fixed to False for inference: dropping the final incomplete batch
# would leave those questions without a prediction in the submission file.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    pin_memory=PIN_MEMORY,
    drop_last=False,
)

### 7.2 모델 로드

In [19]:
# Instantiate the model from the pretrained KoELECTRA discriminator and move it
# to the target device; its weights are replaced by the checkpoint below.
model = electra(pretrained="monologg/koelectra-base-v3-discriminator").to(device)

# map_location remaps tensors that were saved on another device (e.g. a GPU)
# onto the current one, so the checkpoint also loads on a CPU-only runtime.
checkpoint = torch.load(os.path.join(RECORDER_DIR, 'model.pt'), map_location=device)

# Kept as the last expression so the notebook shows the key-matching result.
model.load_state_dict(checkpoint['model'])

Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraForQuestionAnswering: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForQuestionAnswering were not initialized from the model checkpoint at monologg/koelectra-base-v3-discriminator and are newly initialized: ['qa_outputs.bias', 

<All keys matched successfully>

### 7.3 추론 진행

In [20]:
# Switch the model to eval mode so that dropout is disabled during inference.
model.eval()

# Start from the sample submission and fill in 'answer_text' per question id.
pred_df = pd.read_csv(os.path.join(DATA_DIR, 'sample_submission.csv'))

# Gradients are not needed for inference; no_grad() avoids building the
# autograd graph and keeps memory usage down.
with torch.no_grad():
    for batch_index, batch in enumerate(tqdm(test_dataloader, leave=True)):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Inference
        outputs = model(input_ids, attention_mask=attention_mask)

        start_score = outputs.start_logits
        end_score = outputs.end_logits

        # Highest-scoring start / end token position for each sample.
        start_idx = torch.argmax(start_score, dim=1).cpu().tolist()
        end_idx = torch.argmax(end_score, dim=1).cpu().tolist()

        y_pred = []
        for ids, start, end in zip(input_ids, start_idx, end_idx):
            if start > end:
                # An inverted span is not a valid answer: predict an empty string.
                ans_txt = ''
            else:
                # NOTE(review): the slice excludes the token at `end`; confirm this
                # matches how QADataset encodes end positions.
                # replace('#', '') strips every '#' character from the decoded text
                # (presumably to drop '##' sub-word markers).
                ans_txt = tokenizer.decode(ids[start:end]).replace('#', '')

                # A span that decodes to only the [CLS] token is written out as
                # an empty answer.
                if ans_txt == '[CLS]':
                    ans_txt = ''

            y_pred.append(ans_txt)

        # Map this batch's predictions back to question ids by position; this
        # relies on the dataloader iterating in dataset order (shuffle=False).
        q_start_idx = BATCH_SIZE * batch_index
        q_end_idx = q_start_idx + len(y_pred)
        for q_id, pred in zip(question_ids[q_start_idx:q_end_idx], y_pred):
            pred_df.loc[pred_df['question_id'] == q_id, 'answer_text'] = pred

100%|██████████| 407/407 [00:14<00:00, 27.90it/s]


### 7.4 결과 저장

In [21]:
# Use the current time in UTC+9 as the serial that identifies this prediction run.
kst = timezone(timedelta(hours=9))
predict_timestamp = datetime.now(tz=kst).strftime("%Y%m%d_%H%M%S")
predict_serial = predict_timestamp

# Each run gets its own directory: <PROJECT_DIR>/results/predict/<serial>/
PREDICT_DIR = os.path.join(PROJECT_DIR, 'results', 'predict', predict_serial)
os.makedirs(PREDICT_DIR, exist_ok=True)

# The file name records the hyperparameters used for this run.
prediction_filename = f'prediction_lr{LEARNING_RATE}_bs{BATCH_SIZE}_epoch{EPOCHS}_split0.2_koelectra-base.csv'
prediction_path = os.path.join(PREDICT_DIR, prediction_filename)
pred_df.to_csv(prediction_path, index=False)