# **DACON-GAS**

## **Default Setting**

In [1]:
import torch
import torch_optimizer
import transformers

import collections
import copy
import datetime
import itertools
import json
import os
import pprint
import random
import re

import numpy as np
import pandas as pd

from IPython.display import clear_output
from matplotlib import pyplot as plt
from pathlib import Path
from tqdm import tqdm

print("[VERSION]")
print(f"torch: {torch.__version__}")
print(f"torch_optimizer: {torch_optimizer.__version__}")
print(f"transformers: {transformers.__version__}")

[VERSION]
torch: 1.10.0
torch_optimizer: 0.3.0
transformers: 4.11.3


In [2]:
class HParams():
    
    def __init__(self):
        ## Path.
        self.data = Path("data") ## 문서요약 텍스트
        
        self.tr_data = self.data / Path("tr")
        self.tr_law_data = self.tr_data / Path("tr_law_data.json")
        self.tr_journal_data = self.tr_data / Path("tr_journal_data.json")
        self.tr_article_data = self.tr_data / Path("tr_article_data.json")

        self.vl_data = self.data / Path("vl")
        self.vl_law_data = self.vl_data / Path("vl_law_data.json")
        self.vl_journal_data = self.vl_data / Path("vl_journal_data.json")
        self.vl_article_data = self.vl_data / Path("vl_article_data.json")
        
        self.ts_data = self.data / Path("ts")
        self.ts_all_data = self.ts_data / Path("test.jsonl")
        self.ts_sample_submission = self.ts_data / Path("sample_submission.csv")
        
        self.ckpt_dir = Path("ckpt")

        ## Seed.
        self.seed = 42
        
        ## Dataloader
        self.tokenizer_name = "beomi/KcELECTRA-base"
        ## Avoid using `tokenizers` before the fork if possible.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.tokenizer_name)
        self.pt_path = self.data
        self.max_cls_len = 50 ## only 50 sentences are allowed.
        
        self.per_replica_batch_size = 4 ## multiple of 8
        self.global_batch_size = self.per_replica_batch_size * torch.cuda.device_count() ## 32
        self.num_workers = 4
        
        ## Modeling.
        self.n_gpus = torch.cuda.device_count()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        ## Transformer Decoder.
        self.d_model = 768
        self.nhead = 12
        self.dropout = 0.1
        self.batch_first = True
        
        self.num_layers = 2
        
        ## Optimizer.
        self.lr = 2e-5
        self.weight_decay = 1e-4
        
        ## Scheduler.
        self.pct_start = 0.1
        self.max_lr = self.lr
        self.epochs = 10
        
        ## Checkpoint manager.
        self.monitor = "vl_loss"
        self.max_to_keep = 3
        
        ## Set environments.
        self._set_os_environments()
        self._seed_everything(self.seed)
        
    
    def _set_os_environments(self):
        os.environ["TOKENIZERS_PARALLELISM"] = "true"
        
    def _seed_everything(self, seed: int):
        random.seed(seed)
        np.random.seed(seed)
        os.environ["PYTHONHASHSEED"] = str(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True  # type: ignore
        torch.backends.cudnn.benchmark = True  # type: ignore
        
        
args = HParams()
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(vars(args))

{   'data': PosixPath('data'),
    'tr_data': PosixPath('data/tr'),
    'tr_law_data': PosixPath('data/tr/tr_law_data.json'),
    'tr_journal_data': PosixPath('data/tr/tr_journal_data.json'),
    'tr_article_data': PosixPath('data/tr/tr_article_data.json'),
    'vl_data': PosixPath('data/vl'),
    'vl_law_data': PosixPath('data/vl/vl_law_data.json'),
    'vl_journal_data': PosixPath('data/vl/vl_journal_data.json'),
    'vl_article_data': PosixPath('data/vl/vl_article_data.json'),
    'ts_data': PosixPath('data/ts'),
    'ts_all_data': PosixPath('data/ts/test.jsonl'),
    'ts_sample_submission': PosixPath('data/ts/sample_submission.csv'),
    'ckpt_dir': PosixPath('ckpt'),
    'seed': 42,
    'tokenizer_name': 'beomi/KcELECTRA-base',
    'tokenizer': PreTrainedTokenizerFast(name_or_path='beomi/KcELECTRA-base', vocab_size=50135, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'm

In [3]:
!nvidia-smi; free -h

Sat Nov  6 13:48:41 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.74       Driver Version: 470.74       CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:0A:00.0  On |                  N/A |
|  0%   41C    P8    25W / 220W |    328MiB /  7979MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## **Prepare Dataset**

### **Naming**

제공받은 데이터세트의 이름을 조금 변경했습니다.

In [5]:
# !tree -alh

### **Data Format**

참조: https://aihub.or.kr/aidata/8054

In [7]:
def print_tr_sample(sample: dict) -> None:
    tmp = copy.deepcopy(sample)
    tmp["documents"] = [tmp["documents"][0], tmp["documents"][-1]]
    pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(tmp)

## Print sample.
with open(args.tr_law_data, "r") as f:
    sample = json.loads(f.read())
    
print_tr_sample(sample)

{   'name': '법률문서 프로젝트',
    'delivery_date': '2020-12-23 17:23:13',
    'documents': [   {   'id': '100004',
                         'category': '일반행정',
                         'size': 'small',
                         'char_count': 377,
                         'publish_date': '19841226',
                         'title': '부당노동행위구제재심판정취소',
                         'text': [   [   {   'index': 0,
                                             'sentence': '원고가 소속회사의 노동조합에서 분규가 '
                                                         '발생하자 노조활동을 구실로 정상적인 '
                                                         '근무를 해태하고,',
                                             'highlight_indices': ''},
                                         {   'index': 1,
                                             'sentence': '노조조합장이 사임한 경우,',
                                             'highlight_indices': ''},
                                         {   'index': 2,
                                   

In [5]:
def print_ts_sample(sample: list) -> None:
    tmp = copy.deepcopy(sample)
    tmp = [tmp[0], tmp[-1]]
    pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(tmp)

## Print sample.
with open(args.ts_all_data, "r") as f:
    sample = [json.loads(line) for line in f]
    
print_ts_sample(sample)

[   {   'id': '368851881',
        'article_original': [   '한국은행이 지난달 기준금리를 추가 인하한 영향으로 대출금리가 일제 하락했다.',
                                '특히 기업대출금리는 1996년 관련 통계 집계 이후 가장 낮은 수준까지 하락했다.',
                                "한국은행이 28일 발표한 '2019년 10월 중 금융기관 가중평균금리'에 따르면 "
                                '신규취급액 기준 지난달 예금은행의 대출평균금리는 연 3.20%로 전월 대비 '
                                '0.11%포인트 하락했다.',
                                '대출금리를 기업과 가계로 나눠보면 기업대출금리 하락폭이 가팔랐다.',
                                '10월 기업대출금리는 연 3.28%로 전월 대비 0.14%포인트 떨어졌다.',
                                '연 3.28%는 1996년 1월 관련 통계 작성 이후 최저치다.',
                                '대기업의 경우 기준이 되는 단기지표 금리의 하락 영향으로 10월 대출금리가 연 '
                                '3.13%로 전월 대비 0.17%포인트 빠졌다.',
                                '중소기업은 일부 은행의 저금리 대출 취급 등의 영향으로 연 3.39%로 '
                                '0.11%포인트 하락했다.',
                                '한은 관계자는 "지난달 기준금리가 인하되면서 단기 지표에 영향을 많이 받는 '
                                '기업대출 금리도 역대

### **DataLoader**

In [4]:
class TrainDataset(torch.utils.data.Dataset):
    
    def __init__(
        self, 
        mode: str,
        data_path: Path, 
        pt_path: Path = args.pt_path,
        sample_submission_path: Path = args.ts_sample_submission,
        max_cls_len: int = args.max_cls_len, ## 50
        vocab_size: int = 512,
        inp_pad_id: int = 0,
        cls_pad_id: int = -1,
        seg_pad_id: int = 0,
        num_ext_answers: int = 3,
        tokenizer = args.tokenizer,
        debug: bool = False, ## referenced by Flask
    ):
        if not (mode in ["tr", "vl", "ts"]):
            raise AssertionError(f"Mode must be the one of 'tr', 'vl', or 'ts': not {mode}")
        self.mode = mode
        
        self.data_path = data_path
        self.pt_path = pt_path
        self.sample_submission_path = sample_submission_path
        self.max_cls_len = max_cls_len
        self.vocab_size = vocab_size ## maximum embedding length
        self.inp_pad_id = inp_pad_id
        self.cls_pad_id = cls_pad_id
        self.seg_pad_id = seg_pad_id
        self.tokenizer = tokenizer        
        self.num_ext_answers = num_ext_answers
        
        self.debug = debug
        
        self.dummy_tar_ext = np.array(range(self.num_ext_answers))
        self.dummy_tar_abs = np.array([self.tokenizer.cls_token_id, self.tokenizer.sep_token_id]) ## [2, 3]
        
        ## Error types.
        self.errors = {
            "none_in_ext_answer": 0,
            "leak_of_ext_answer": 0, ## len(tar_ext) < 3
            "duplicated_ext_answer": 0, ## e.g. [0, 0, 3]
            "too_many_sentences": 0,
            "out_of_answer_index": 0,
        }
        
        self.df = self._data_loader() ## dataframe format
        
        
    def _get_documents(self, data_path: Path) -> list:
        print(f"Loading json data in {data_path}:")
        
        documents = []

        ## If training and validation phase...
        if self.mode in ["tr", "vl"]:
            for file_path in sorted(list(data_path.glob("*.json"))):
                print(f"  - {file_path}...")
                with open(file_path, "r", encoding="utf-8") as f:
                    document = json.loads(f.read()) ## list type -> not append, but extent
                documents.extend(document["documents"])
                
        ## If test phase...
        elif self.mode in ["ts"]:
            for file_path in sorted(list(data_path.glob("*.jsonl"))):
                print(f"  - {file_path}...")
                with open(file_path, "r", encoding="utf-8") as f:
                    documents.extend([json.loads(line) for line in f])

        return documents
                      
        
    def _data_loader(self) -> dict:
        ## Check if the dataset already constructed or not.
        df_name = Path(self.pt_path, f"{self.mode}_df{'_debug' if self.debug else ''}.pt")
        
        if df_name.is_file():
            print(f"Preprocessed dataframe is already exist: loading {df_name}...")
            # df = pd.DataFrame(torch.load(df_name))
            df = torch.load(df_name)
        
        else:
            print(f"Preprocessed dataframe is not exist: constructing {df_name}...")
            documents = self._get_documents(self.data_path)
            df = self._construct_dataframe(documents)
            # torch.save(df.to_dict(), df_name)
            torch.save(df, df_name)
        
        ## Print the informations.
        print(f"File {df_name} loaded:")
        print(f"  - Shape: {df.shape}")
        print(f"  - Columns: {list(df.columns)}")
        print(f"  - Errors:")
        print(*[f"    * {key}: {value}" for key, value in self.errors.items()], sep="\n")
        print(end="\n"*1)
        
        return df
        

    def _construct_dataframe(self, documents: list) -> dict:
        ## Empty dictionary.
        data = {"inp": [], "tar_ext": [], "tar_abs": []}
        # data = {"inp": [], "tar_ext": []}
        
        ## Training or validation phase...
        if self.mode in ["tr", "vl"]:
            for document in documents:
                flag = False
                
                ## Maybe some errors can be occured in answers.
                if None in document["extractive"]:
                    self.errors["none_in_ext_answer"] += 1
                    flag = True
                    
                if len(document["extractive"]) < self.num_ext_answers:
                    self.errors["leak_of_ext_answer"] += 1
                    flag = True
                    
                if len(np.unique(document["extractive"])) < self.num_ext_answers:
                    self.errors["duplicated_ext_answer"] += 1
                    flag = True

                ## Elements.
                inp = [self._clean_text(sentence["sentence"]) for sentence in itertools.chain(*document["text"])]
                tar_ext = document["extractive"]
                tar_abs = document["abstractive"]
                
                ## Check the number of sentences.
                if len(inp) > self.max_cls_len:
                    self.errors["too_many_sentences"] += 1
                    ## It's not a serious error, so we need to drop the documents.
                    inp = inp[:self.max_cls_len]
                    # flag = True
                
                ## Check if extractive answer is out-of-data or not.
                ## Thus the element 'tar_ext' is index type, 
                ## we need to use ">= (great or equal)", not "> (great)".
                if max(tar_ext) >= self.max_cls_len:
                    self.errors["out_of_answer_index"] += 1
                    flag = True
                
                ## Insert if flag == False.
                if not flag:
                    data["inp"].append(inp)
                    data["tar_ext"].append(tar_ext)
                    data["tar_abs"].append(tar_abs)

                ## In development stage, we limit the maximum document size to 10 for fast experiments.
                if self.debug and len(data["inp"]) >= 10:
                    break
        
        ## Test phase...
        elif self.mode in ["ts"]:
            ## First, we need to sort the indexes as in 'sample_submission.csv'.
            inp_ids = np.array([int(document["id"]) for document in documents])
            tar_ids = np.array(pd.read_csv(self.sample_submission_path, index_col=False)["id"])
            
            reallocated_idx = np.concatenate([np.where(inp_ids == i)[0] for i in tar_ids])
            documents = np.array(documents)[reallocated_idx]
            
            for document in documents:                
                ## Elements.
                inp = [self._clean_text(sentence) for sentence in document["article_original"]]
                tar_ext = self.dummy_tar_ext ## dummy
                tar_abs = self.dummy_tar_abs ## dummy
                
                ## Check the number of sentences.
                if len(inp) > self.max_cls_len:
                    self.errors["too_many_sentences"] += 1
                    ## We cannot drop any items in test dataset.
                    inp = inp[:self.max_cls_len]
                    # continue
                
                ## Insert.
                data["inp"].append(inp)
                data["tar_ext"].append(tar_ext)
                data["tar_abs"].append(tar_abs)
                
                ## In development stage, we limit the maximum document size to 10 for fast experiments.
                if self.debug and len(data["inp"]) >= 10:
                    break

        ## Convert to dataframe.
        df = pd.DataFrame(data)

        ## Encoding.
        df["inp"] = self._tokenize(df["inp"])
        
        ## Generate 'cls' s.t. means the index of tokens.
        df["cls"] = df["inp"].map(lambda x: np.concatenate([np.where(x == self.tokenizer.cls_token_id)[0], [len(x)]]))
        
        ## Generate 'seg' s.t. means segmentation embeddings which represented as [0, 0, ..., 0, 1, ..., 1, 0, 0, ...].
        df["seg"] = df["cls"].map(lambda x: list(itertools.starmap(lambda x, y: [x] * y, zip(np.arange(len(np.diff(x))) % 2, np.diff(x)))))
        df["seg"] = df["seg"].map(lambda x: np.array(list(itertools.chain.from_iterable(x))))
        
        ## Drop the last token in cls.
        df["cls"] = df["cls"].map(lambda x: x[:-1])
        
        ## Padding.
        self.max_inp_len = max(df["inp"].map(lambda x: len(x)))
        self.max_cls_len = max(df["cls"].map(lambda x: len(x))) if self.mode in ["ts"] else self.max_cls_len
        
        df["inp"] = self._pad(df["inp"], self.inp_pad_id, self.max_inp_len) ## 0
        df["cls"] = self._pad(df["cls"], self.cls_pad_id, self.max_cls_len) ## -1
        df["seg"] = self._pad(df["seg"], self.seg_pad_id, self.max_inp_len) ## 0
        df["msk"] = df["inp"].map(lambda x: ~(x == self.inp_pad_id))
        df["msk_cls"] = df["cls"].map(lambda x: ~(x == self.cls_pad_id))
        
        ## One hot label.
        if self.mode in ["tr", "vl"]:
            ## Extractive.
            df["tar_ext"] = df["tar_ext"].map(lambda x: self._one_hot_encoding(x))
            ## Abstractive.
            df["tar_abs"] = self._tokenize(df["tar_abs"])
            df["tar_abs"] = self._pad(df["inp"], self.inp_pad_id, self.max_inp_len) ## 0
            df["tar_abs_msk"] = df["tar_abs"].map(lambda x: ~(x == self.inp_pad_id))
            
        elif self.mode in ["ts"]:
            pass
        
        ## Reallocate the columns' name.
        df = df[["inp", "cls", "seg", "msk", "msk_cls", "tar_ext", "tar_abs"]]
        # df = df[["inp", "cls", "seg", "msk", "msk_cls", "tar_ext"]]
        
        return df
    
    
    def _clean_text(self, text: str) -> str:
        ## Ref. https://blog.naver.com/PostView.nhn?blogId=wideeyed&logNo=221347960543
        
        ## Remove email.
        pattern = "([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" 
        text = re.sub(pattern=pattern, repl="", string=text)
        
        ## Remove URL.
        pattern = "(http|ftp|https)://(?:[-\w.]|(?:%[\da-fA-F]{2}))+"
        text = re.sub(pattern=pattern, repl="", string=text)
        
        ## Stand-alone korean 자음/모음.
        pattern = "([ㄱ-ㅎㅏ-ㅣ]+)"
        text = re.sub(pattern=pattern, repl="", string=text)
        
        ## HTML tags.
        pattern = "<[^>]*>"
        text = re.sub(pattern=pattern, repl="", string=text)
        
        ## Specail words.
        # pattern = "[^\w\s]"
        # text = re.sub(pattern=pattern, repl="", string=text)
        
        ## Strip.
        text = text.strip()
        
        ## Remove double space, line feed, carrage returns.
        text = " ".join(text.split())
        
        return text
    
    
    def _tokenize(self, data: pd.Series, truncation: bool = True, add_special_tokens: bool = True) -> pd.Series:
        ## Tokenize input and abstractive target.
        ## Eithers must be in type of list format.
        ##  e.g. [sent_1, sent_2, ...] -> for using 'itertools.chain.from_iterable'.
        return data.map(lambda x: np.array(list(itertools.chain.from_iterable([self.tokenizer.encode(
            x[i], max_length=int(self.vocab_size / len(x)), truncation=truncation, add_special_tokens=add_special_tokens,
        ) for i in range(len(x))]))))
        

    def _pad(self, data: pd.Series, pad_id: int, max_len: int) -> pd.Series:
        ## When len(x) == max_len, the concatenate function will be try to
        ## concat with empty list [], and it will be forced to casting from
        ## long (int64) to double (float32).
        return data.map(lambda x: x if len(x) == max_len else np.concatenate([x, np.array([pad_id] * (max_len - len(x)))]))
    
    
    def _one_hot_encoding(self, tar: list) -> list:
        return np.sum(np.eye(self.max_cls_len)[np.array(tar), :], axis=0)
    
    
    def __len__(self) -> int:
        return self.df.shape[0]

    
    def __getitem__(self, idx) -> dict:
        return {key: torch.from_numpy(value) for key, value in self.df.loc[idx, :].to_dict().items()}

In [8]:
%%time
tr_ds = TrainDataset("tr", data_path=args.tr_data, debug=False)
vl_ds = TrainDataset("vl", data_path=args.vl_data, debug=False)
ts_ds = TrainDataset("ts", data_path=args.ts_data, debug=False)

tr_dataloader = torch.utils.data.DataLoader(tr_ds, batch_size=args.global_batch_size, num_workers=args.num_workers, shuffle=True)
vl_dataloader = torch.utils.data.DataLoader(vl_ds, batch_size=args.global_batch_size, num_workers=args.num_workers)
ts_dataloader = torch.utils.data.DataLoader(ts_ds, batch_size=args.global_batch_size, num_workers=args.num_workers)

Preprocessed dataframe is not exist: constructing data/tr_df.pt...
Loading json data in data/tr:
  - data/tr/tr_article_data.json...
  - data/tr/tr_journal_data.json...
  - data/tr/tr_law_data.json...
File data/tr_df.pt loaded:
  - Shape: (361038, 7)
  - Columns: ['inp', 'cls', 'seg', 'msk', 'msk_cls', 'tar_ext', 'tar_abs']
  - Errors:
    - none_in_ext_answer: 8
    - leak_of_ext_answer: 0
    - duplicated_ext_answer: 71
    - too_many_sentences: 239
    - out_of_answer_index: 76

Preprocessed dataframe is not exist: constructing data/vl_df.pt...
Loading json data in data/vl:
  - data/vl/vl_article_data.json...
  - data/vl/vl_journal_data.json...
  - data/vl/vl_law_data.json...
File data/vl_df.pt loaded:
  - Shape: (40130, 7)
  - Columns: ['inp', 'cls', 'seg', 'msk', 'msk_cls', 'tar_ext', 'tar_abs']
  - Errors:
    - none_in_ext_answer: 1
    - leak_of_ext_answer: 0
    - duplicated_ext_answer: 2
    - too_many_sentences: 5
    - out_of_answer_index: 1

Preprocessed dataframe is not e

In [7]:
# tr_ds[0]

## **Modeling**

In [9]:
class KoBERTSumExt(torch.nn.Module):
    
    def __init__(
        self,
        d_model: int = args.d_model,
        nhead: int = args.nhead,
        dropout: float = args.dropout,
        num_layers: int = args.num_layers,
        batch_first: bool = args.batch_first,
        name: str = None
    ):
        super(KoBERTSumExt, self).__init__()
        ## Encoder.
        self.encoder = transformers.BertModel.from_pretrained(args.tokenizer_name)
        
        ## Decoder.
        self.d_model = d_model
        self.nhead = nhead
        self.dropout = dropout
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.name = name
        
        self.decoder = torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(
                d_model=self.d_model, 
                nhead=self.nhead, 
                dropout=self.dropout,
                batch_first=self.batch_first,
            ),
            num_layers=self.num_layers,
        )
        
        ## Fully connected.
        self.fc = torch.nn.Linear(self.d_model, 1)

        
    @torch.cuda.amp.autocast()
    def forward(self, inp, cls, seg, msk, msk_cls) -> torch.tensor:
        ## Pretrained language model encoder.
        top_vec = self.encoder(
            input_ids=inp.long(), 
            attention_mask=msk.float(), ## bool -> float
            token_type_ids=seg.long(),
        ).last_hidden_state
        
        sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), cls.long()]
        sents_vec = sents_vec * msk_cls[..., None].float() ## (batch, 50, d_model)
        
        ## Transformer decoder.
        sent_scores = self.decoder(sents_vec)
        
        ## FC layer.
        sent_scores = self.fc(sent_scores).squeeze(-1)
        
        return sent_scores

In [10]:
def get_model():
    ## Model naming.
    KST = datetime.timezone(datetime.timedelta(hours=9))
    model_name = datetime.datetime.now(tz=KST).strftime("%Y%m%d-%H%M%S")

    ## Generate model.
    model = KoBERTSumExt(name=model_name)
    clear_output(wait=True)

    ## Multi-gpu setting.
    if args.n_gpus > 1:
        print(f"{args.n_gpus} gpus available.")
        model = torch.nn.DataParallel(model)

    ## RAM to VRAM.
    _ = model.to(args.device)

    print(f"Model {model.name} generated.")
    
    return model

## **Hyperparameters for Compile**

In [11]:
def get_optimizer():
    optimizer = torch_optimizer.RAdam(
        model.parameters(), 
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    return optimizer


def get_scheduler(optimizer, steps_per_epoch: int):
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer, 
        pct_start=args.pct_start, 
        max_lr=args.max_lr,
        epochs=args.epochs, 
        steps_per_epoch=steps_per_epoch,
    )
    return scheduler


def get_loss_fn():
    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")
    return loss_fn


def get_metric_fn():
    def hitrate(y_true, y_pred, num_ans: int = 3):
        h = [len(list(set(ans).intersection(y_true[i]))) / num_ans for i, ans in enumerate(y_pred)]
        score = np.mean(np.array(h))
        return score
    return hitrate

## **Checkpoint Manager**

In [12]:
class CheckpointManager():
    
    def __init__(
        self,
        model_name: str,
        ckpt_dir: Path = args.ckpt_dir,
        monitor: str = args.monitor,
        max_to_keep: int = args.max_to_keep,
        mode: str = "less_good",
    ):
        self.model_name = model_name
        self.ckpt_dir = ckpt_dir
        self.monitor = monitor
        self.max_to_keep = max_to_keep
        self.mode = mode
        if not (mode in ["less_good", "great_good"]):
            raise AssertionError(f"Monitering mode must be 'less_good' or 'great_good': not {mode}")

        self.latest_monitering_value = np.inf if self.mode == "less_good" else -np.inf
            
        ## Directory s.t. ckpt will be stored.
        self.save_dir = self.ckpt_dir / Path(model_name)
        self.save_dir.mkdir(parents=True, exist_ok=False)
        
        
        print(f"Checkpoint manager is now ready.")
        print(f"  - Save path: {self.save_dir}")
        
    
    def latest(self):
        ckpt_list = sorted(list(self.ckpt_dir.glob("*.pt")))
        
        ## If checkpoints are exist,
        if len(ckpt_list) != 0:
            latest_ckpt_name = ckpt_list[-1]
            latest_ckpt = torch.load(latest_ckpt_name)
            print(f"Latest checkpoint {latest_ckpt_name} loaded.")
        
        ## If no valid checkpoints are exist,
        else:
            latest_ckpt = None
            
        return latest_ckpt
    
    
    def _make_clean(self) -> None:
        ckpt_list = sorted(list(self.ckpt_dir.glob("*.pt")))
        
        ## # of data must be more then 'max_to_keep'.
        if len(ckpt_list) <= self.max_to_keep:
            print(f"Noting to clean: {len(ckpt)} checkpoints exist.")
            return
    
        tar_ckpt_list = ckpt_list[:-self.max_to_keep]
        free_size = sum([c.stat().st_size for c in tar_ckpt_list])
        _ = [c.unlink() for c in tar_ckpt_list]
        
        print(f"Checkpoint folder {self.ckpt_dir} is now clean, {free_size / (2 ** 20):.2f}MB free.")
    
    
    def save(self, epoch: int, model_state_dict: dict, optimizer_state_dict: dict, performance: dict) -> None:
        ## Monitoring value must be located in performance dictionary.
        monitering_value = performance.get(self.monitor)
        
        ## Naming.
        fname = self.save_dir / Path(f"cp-{epoch:03d}-{monitoring_value:.6f}.pt")
        
        ## Save it.
        torch.save({
            "epoch": epoch,
            "model_state_dict": model_state_dict,
            "optimizer_state_dict": optimizer_state_dict,
            self.monitor: monitoring_value,
        }, fname)
        
        ## Make clean directory.
        self._make_clean()
        
        ## Keep memory.
        # del model_state_dict, optimizer_state_dict
        
        ## Update.
        self.latest_monitering_value = monitering_value
        
        
    def is_need_to_save(self, monitering_value: float) -> bool:
        if self.mode == "less_good":
            answer = True if monitering_value < self.latest_monitering_value else False
        elif self.mode == "great_good":
            answer = True if monitering_value > self.latest_monitering_value else False
            
        return answer

## **Trainer**

In [13]:
class Trainer():
    
    def __init__(
        self,
        model,
        steps_per_epoch: int,
        epochs: int = args.epochs,
    ):
        self.model = model
        self.steps_per_epoch = steps_per_epoch
        self.epochs = epochs
        
        self.device = args.device
        self.optimizer = get_optimizer()
        self.scheduler = get_scheduler(self.optimizer, self.steps_per_epoch)
        self.loss_fn = get_loss_fn()
        self.metric_fn = get_metric_fn()
        self.ckpt_manager = CheckpointManager(model.name)
        
        self.num_ext_answers = 3
        self.scaler = torch.cuda.amp.GradScaler()
        
        
    def fit(self, tr_dataloader, vl_dataloader):
        
        for epoch_index in range(self.epochs):
            tr_result = self.train_epoch(tr_dataloader, epoch_index)
            ##########
            continue
            ##########
            vl_result = self.validate_epoch(vl_dataloader, epoch_index)
            
            ## Record performance.
            performance = {}
            performance.update(tr_result)
            performance.update(vl_result)
            
            ## Save.
            monitering_value = performance.get(self.ckpt_manager.monitor)
            if monitering_value == None:
                raise ValueError(f"Performances must have the element '{self.ckpt_manager.moniter}': {performance.keys()}")
            
            if self.ckpt_manager.is_need_to_save(monitering_value):
                self.ckpt_manager.save(
                    epoch=epoch_index,
                    model_state_dict=self.model.state_dict(),
                    optimizer_state_dict=self.optimizer.state_dict(),
                    performance=performance,
                )
            
            
    def predict(self):
        pass
                
    
    def train_epoch(self, dataloader, epoch_index: int) -> dict:
        self.model.train()
        self.tr_total_loss = 0
        
        tot_y_true = []
        tot_y_pred = []
        
        tqdm_dataloader = tqdm(dataloader, desc=f"[Epoch {epoch_index + 1:02}]")
        
        ## Batch iteration.
        for batch_index, data in enumerate(tqdm_dataloader):
            ## Initialize optimizer.
            self.optimizer.zero_grad()
            
            ## Unpack.
            inp = data["inp"].to(self.device)
            cls = data["cls"].to(self.device)
            seg = data["seg"].to(self.device)
            msk = data["msk"].to(self.device)
            msk_cls = data["msk_cls"].to(self.device)
            
            tar_ext = data["tar_ext"].to(self.device)
            # tar_abs = data["tar_abs"].to(self.device)
            
            ## Forward.
            with torch.cuda.amp.autocast():
                sent_score = self.model(inp, cls, seg, msk, msk_cls)
                loss = (self.loss_fn(sent_score, tar_ext.float()) * msk_cls.float()).sum() / sent_score.size(0)
                                
            ## Backward.
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            
            scale = self.scaler.get_scale()
            self.scaler.update()
            
            ## No update scheduler when errors occured in mixed precision policy.
            if scale == self.scaler.get_scale():
                self.scheduler.step()

            ## Keep loss.
            self.tr_total_loss += loss
                
            ## Inference.
            cur_y_true = torch.where(tar_ext == 1)[1].reshape(-1, self.num_ext_answers).tolist()
            cur_y_pred = torch.topk(torch.sigmoid(sent_score) * msk_cls.float(), self.num_ext_answers, axis=1).indices.tolist()
            
            tot_y_true.extend(cur_y_true)
            tot_y_pred.extend(cur_y_pred)
            
            ## Show the results.
            tqdm_dataloader.set_postfix(self._get_assets(batch_index, self.tr_total_loss, tot_y_true, tot_y_pred))
            
        ## Total loss and accuracy in one epoch.
        return self._get_assets(batch_index, tr_total_loss, tot_y_true, tot_y_pred, as_str=False)
    
    
    def _get_assets(self, batch_index: int, total_loss: float, y_true: list, y_pred: list, as_str: bool = True) -> dict:
        ## We define to show loss and accuracy.
        assets = {
            "loss": total_loss / (batch_index + 1),
            "acc": self.metric_fn(y_true=y_true, y_pred=y_pred),
        }
        ## When using tqdm dataloader, we need to convert float to string.
        if as_str:
            assets = {key: f"{value:.6f}" for key, value in assets.items()}
        
        return assets

## **Train**

In [14]:
model = get_model()
trainer = Trainer(model, steps_per_epoch=len(tr_dataloader))

Model 20211106-135835 generated.
Checkpoint manager is now ready.
  - Save path: ckpt/20211106-135835


In [15]:
trainer.fit(tr_dataloader, vl_dataloader)

[Epoch 01]:   7%|█              | 6226/90260 [15:53<3:34:30,  6.53it/s, loss=6.856729, acc=0.384851]


KeyboardInterrupt: 