## Import Packages


In [None]:
# Mount Google Drive under ./MyDrive relative to the Colab working directory.
# PATH in the "Load Data" section assumes the Drive root ends up at /content/MyDrive/MyDrive.
from google.colab import drive
drive.mount(mountpoint='./MyDrive')

In [None]:
import tensorflow
from tensorflow.python.client import device_lib

# List the CPU/GPU devices TensorFlow can see (quick check that the runtime has a GPU attached).
available_devices = device_lib.list_local_devices()
available_devices

In [None]:
# Colab-only environment setup (IPython shell escapes): transformers and pororo for NLP,
# selenium plus the chromium chromedriver for scraping Papago translations.
# NOTE(review): versions are not pinned, so later runs may install releases with different APIs.
!pip install transformers
!pip install selenium
!apt-get update
!apt install chromium-chromedriver
!pip install pororo

In [None]:
import glob
import json
import math
import os
import random
import re
import tarfile
import time

from tqdm import tqdm

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch.cuda.amp as amp
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold

import transformers
from transformers import TrainingArguments, Trainer, AdamW, get_scheduler
from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer, EarlyStoppingCallback, TrainerCallback, TrainerControl

import pororo
from pororo import Pororo

import selenium
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC

## Data Preprocessing

### Train Data Distribution

In [None]:
train = pd.read_csv('train_data.csv')

train_pre = train['premise'].str.cat(sep='')
train_hyp = train['hypothesis'].str.cat(sep='')

train_union = set.union(set(train_pre), set(train_hyp))
print(''.join(sorted(train_union)))

In [None]:
train_all = train['premise'] + train['hypothesis']

### KLUE Official Dev Data

In [None]:
!wget https://aistages-prod-server-public.s3.amazonaws.com/app/Competitions/000068/data/klue-nli-v1.1.tar.gz

In [None]:
tar = tarfile.open('klue-nli-v1.1.tar.gz')
tar.extractall()
tar.close()

with open('klue-nli-v1.1/klue-nli-v1.1_dev.json') as f:
  df = pd.DataFrame(json.load(f))

KLUE = pd.DataFrame(df[['premise', 'hypothesis', 'gold_label']])
KLUE.columns = ['premise', 'hypothesis', 'label']

KLUE.head()

In [None]:
# Character inventory of the KLUE dev set (premise + hypothesis).
KLUE_pre = KLUE['premise'].str.cat(sep='')
KLUE_hyp = KLUE['hypothesis'].str.cat(sep='')
KLUE_union = set(KLUE_pre) | set(KLUE_hyp)
print(''.join(sorted(KLUE_union)))

### KorNLI

In [None]:
!gdown https://drive.google.com/uc?id=1UJKeJneCtKt_bSH_CXqv_gcMxH7ThfMf

In [None]:
tar = tarfile.open('/content/drive/MyDrive/dacon/KorNLI.tar')
tar.extractall('KorNLI')
tar.close()

with open('KorNLI/xnli.dev.ko.tsv') as f:
  nli_devfile = f.read().splitlines()

with open('KorNLI/xnli.test.ko.tsv') as g:
  nli_testfile = g.read().splitlines()
     
korNLI_dev_list = [i.split('\t') for i in nli_devfile]
korNLI_test_list = [i.split('\t') for i in nli_testfile]

korNLI_dev = pd.DataFrame(korNLI_dev_list[1:], columns=korNLI_dev_list[0])
korNLI_test = pd.DataFrame(korNLI_test_list[1:], columns=korNLI_test_list[0])

korNLI = pd.concat([korNLI_dev, korNLI_test], ignore_index=True)

In [None]:
korNLI_dev.shape

In [None]:
korNLI_test.shape

In [None]:
korNLI.columns = ['premise', 'hypothesis', 'label']

korNLI

In [None]:
# Character inventory of KorNLI before cleaning.
korNLI_pre = korNLI['premise'].str.cat(sep='')
korNLI_hyp = korNLI['hypothesis'].str.cat(sep='')

korNLI_union = set(korNLI_pre) | set(korNLI_hyp)
print(''.join(sorted(korNLI_union)))

In [None]:
# 따옴표 변경
apostrophes = ['‘', '’']
for apostrophe in apostrophes:
  korNLI['hypothesis'] = korNLI['hypothesis'].str.replace(apostrophe, '\'')
  korNLI['premise'] = korNLI['premise'].str.replace(apostrophe, '\'')

apostrophes = ['“', '”']
for apostrophe in apostrophes:
  korNLI['hypothesis'] = korNLI['hypothesis'].str.replace(apostrophe, '\"')
  korNLI['premise'] = korNLI['premise'].str.replace(apostrophe, '\"')

In [None]:
korNLI[korNLI['premise'].str.contains('=')].iloc[0]['premise']

In [None]:
# 7) 잘라버리기 (726, 727, 728)
korNLI.loc[korNLI['premise'].str.contains('='), 'premise'] = korNLI[korNLI['premise'].str.contains('=')]['premise'].str[3:].copy()
korNLI[korNLI['premise'].str.contains('=')].iloc[2]

In [None]:
korNLI[korNLI['premise'].str.contains('…')]

In [None]:
# ...  는 .로 대체
korNLI['premise'] = korNLI['premise'].str.replace("…", ".")

In [None]:
korNLI[korNLI['premise'].str.contains('_')]

In [None]:
korNLI.loc[korNLI['premise'].str.contains('_'), 'premise'] = korNLI[korNLI['premise'].str.contains('_')]['premise'].str[:4] + korNLI[korNLI['premise'].str.contains('_')]['premise'].str[8:]
korNLI[korNLI['premise'].str.contains('_')]

In [None]:
korNLI['premise'] = korNLI['premise'].str.replace("McCohe__", "McCoy는 The")
korNLI.iloc[1923]

In [None]:
korNLI[korNLI['premise'].str.contains('·')]

In [None]:
korNLI['premise'] = korNLI['premise'].str.replace("·", "")

In [None]:
korNLI[korNLI['premise'].str.contains('『')]

In [None]:
korNLI[korNLI['premise'].str.contains('《')]

In [None]:
korNLI['premise'] = korNLI['premise'].str.replace("[《》『』]", "'")
korNLI['hypothesis'] = korNLI['hypothesis'].str.replace("[《》『』]", "'")

In [None]:
korNLI[korNLI['premise'].str.contains('`')]

In [None]:
korNLI.iloc[7155]['premise']

In [None]:
korNLI['premise'] = korNLI['premise'].str.replace("`", "")

In [None]:
korNLI[korNLI['premise'].str.contains('\$')]

In [None]:
korNLI[korNLI['hypothesis'].str.contains('\$')]

In [None]:
korNLI['premise'] = korNLI['premise'].str.replace("00 생활보호를", "생활보호를")
korNLI['premise'] = korNLI['premise'].str.replace("\$3을", "3달러를")
korNLI['premise'] = korNLI['premise'].str.replace("\$-", "돈")
korNLI['hypothesis'] = korNLI['hypothesis'].str.replace("\$500를", "500달러를")
korNLI['hypothesis'] = korNLI['hypothesis'].str.replace("\$500이", "500달러가")

In [None]:
korNLI.iloc[6214]

In [None]:
# 변경된 데이터셋 단어 집합 확인
korNLI_dropped_pre = korNLI['premise'].str.cat(sep='')
korNLI_dropped_hyp = korNLI['hypothesis'].str.cat(sep='')
korNLI_dropped_union = set.union(set(korNLI_dropped_pre), set(korNLI_dropped_hyp))
print(''.join(sorted(korNLI_dropped_union)))

In [None]:
korNLI[korNLI['premise'].str.contains('\[')]

In [None]:
korNLI['premise'] = korNLI['premise'].str.replace("\[", "")
korNLI['premise'] = korNLI['premise'].str.replace("\]", "")

In [None]:
korNLI['premise'] = korNLI['premise'].str.replace(";", "")

In [None]:
korNLI[korNLI['premise'].str.contains('/')]

In [None]:
korNLI.drop(korNLI[korNLI['premise'].str.contains('/')].index, inplace=True)
korNLI.drop(korNLI[korNLI['hypothesis'].str.contains('/')].index, inplace=True)
korNLI = korNLI.reset_index(drop=True)

In [None]:
korNLI[korNLI['premise'].str.contains('-')]

In [None]:
# 변경된 데이터셋 단어 집합 확인
korNLI_dropped_pre = korNLI['premise'].str.cat(sep='')
korNLI_dropped_hyp = korNLI['hypothesis'].str.cat(sep='')
korNLI_dropped_union = set.union(set(korNLI_dropped_pre), set(korNLI_dropped_hyp))
print(''.join(sorted(korNLI_dropped_union)))

In [None]:
korNLI['index'] = korNLI.index
korNLI = korNLI[['index', 'premise', 'hypothesis', 'label']]

In [None]:
korNLI

In [None]:
korNLI.to_csv('korNLI_final.csv', index=False)

### Back Translation

In [None]:
train=pd.read_csv('train_data.csv')

In [None]:
def chrome_setting():
  chrome_options = webdriver.ChromeOptions()
  chrome_options.add_argument('--headless')
  chrome_options.add_argument('--no-sandbox')
  chrome_options.add_argument('--disable-dev-shm-usage')
  driver = webdriver.Chrome('chromedriver', options=chrome_options)
  return driver

In [None]:
driver=chrome_setting()

In [None]:
def kor_to_trans(text_data, trans_lang, start_index, final_index):
  """Translate Korean sentences into `trans_lang` by scraping the Papago web UI.

  One entry per index in [start_index, final_index) is appended to the module-level list
  `trans_list` ('' when the request failed), and a checkpoint of that list is written with
  np.save every time the index is a positive multiple of 99.

  Args:
    text_data: indexable collection of Korean strings (e.g. a pandas Series with a
      RangeIndex); text_data[i] is read for each i in the range.
    trans_lang: Papago target-language code, e.g. 'en'.
    start_index: first index to translate (inclusive).
    final_index: end of the range (exclusive).

  Relies on the globals `driver` (from chrome_setting()), `trans_list`, and `data_path`
  (checkpoint path prefix). NOTE(review): `data_path` is not assigned in any cell shown here — TODO.
  NOTE(review): the checkpoint name always says 'kor_to_eng' whatever `trans_lang` is.
  """
  from urllib.parse import quote

  target_present = EC.presence_of_element_located((By.XPATH, '//*[@id="txtTarget"]'))
  backtrans = ''  # last translation seen; printed at checkpoints even if every request so far failed

  for i in tqdm(range(start_index, final_index)):

    if i != 0 and i % 99 == 0:
      time.sleep(2)
      print('{}th : '.format(i), backtrans)
      np.save(data_path + 'kor_to_eng_train_{}_{}.npy'.format(start_index, final_index), trans_list)

    try:
      # Percent-encode the sentence: a raw '&', '#', '%', '?' or '+' would be parsed as URL
      # syntax and truncate or corrupt the text Papago receives.
      driver.get('https://papago.naver.com/?sk=ko&tk=' + trans_lang + '&st=' + quote(text_data[i], safe=''))
      time.sleep(1.5)
      element = WebDriverWait(driver, 10).until(target_present)
      time.sleep(0.1)
      backtrans = element.text

      if backtrans in ('', ' '):
        # The output box already exists, so waiting for its presence returns at once;
        # pause briefly so the re-read has a chance to see the finished translation.
        time.sleep(1)
        element = WebDriverWait(driver, 10).until(target_present)
        backtrans = element.text
      trans_list.append(backtrans)

    except Exception:
      # Best effort: record an empty translation so positions stay aligned with the input.
      # (Exception, not a bare except, so KeyboardInterrupt still stops the loop.)
      trans_list.append('')

In [None]:
trans_list=[]
kor_to_trans(train['premise'], 'en',0,10000)
np.save(data_path+'kor_to_eng_train_0_10000.npy',trans_list)

In [None]:
eng_data0=np.load(data_path+'kor_to_eng_train_0_10000.npy')
eng_data1=np.load(data_path+'kor_to_eng_train_10000_20000.npy')
eng_data2=np.load(data_path+'kor_to_eng_train_20000_all.npy')

eng_data=[*eng_data0, *eng_data1, *eng_data2]
eng_data=pd.DataFrame(eng_data,columns=['eng_premise'])

In [None]:
back_train=pd.concat([train,eng_data],axis=1)
back_train

In [None]:
# Reset selenium chrome driver
driver=chrome_setting()

In [None]:
def trans_to_kor(transed_list, transed_lang,start_index,final_index): 
  
  target_present = EC.presence_of_element_located((By.XPATH, '//*[@id="txtTarget"]'))

  for i in tqdm(range(start_index,final_index)): 
    
    if (i!=0)&(i%99==0):
      time.sleep(1.5)
      print('{}th : '.format(i), backtrans)
      np.save(data_path+'eng_to_kor_train_{}_{}.npy'.format(start_index,final_index),trans_list)
    
    try:
      driver.get('https://papago.naver.com/?sk=en&tk='+transed_lang+'&st='+transed_list[i])
      time.sleep(2)
      element=WebDriverWait(driver, 10).until(target_present)
      time.sleep(0.2)
      backtrans = element.text 

      if (backtrans=='')|(backtrans==' '):
        element=WebDriverWait(driver, 10).until(target_present)
        backtrans = element.text 
        trans_list.append(backtrans)
      else:
        trans_list.append(backtrans)
    
    except:
      trans_list.append('')

In [None]:
# Backward pass (English -> Korean) in three chunks; each chunk's result is saved separately.
for start, end, tag in [(0, 10000, '0_10000'),
                        (10000, 20000, '10000_20000'),
                        (20000, len(back_train), '20000_all')]:
  trans_list = []  # trans_to_kor appends into this global list
  trans_to_kor(back_train['eng_premise'], 'ko', start, end)
  np.save(data_path + 'eng_to_kor_train_{}.npy'.format(tag), trans_list)

In [None]:
back0=np.load(data_path+'eng_to_kor_train_0_10000.npy')
back1=np.load(data_path+'eng_to_kor_train_10000_20000.npy')
back2=np.load(data_path+'eng_to_kor_train_20000_all.npy')
back_train_fin=[*back0,*back1,*back2]
back_train_fin=pd.DataFrame(back_train_fin,columns=['back_premise'])

In [None]:
back_train_fin=pd.concat([train,back_train_fin],axis=1)
back_train_fin

In [None]:
# Tidy the back-translated premises, vectorised over the column (the previous row loop relied on
# `re`, which no cell imports). Steps, in order:
#  - spell out 'U.S.' as '미국' (Korean for "United States");
#  - add a space after runs of . ? ! / \ unless a digit follows (so decimals like 3.5 are untouched);
#  - undo the spacing this introduces inside ellipses and collapse one level of double spaces;
#  - strip leading/trailing whitespace.
back_train_fin['back_premise'] = (
  back_train_fin['back_premise']
  .str.replace('U.S.', '미국', regex=False)
  .str.replace(r'([.?!/\\]+)(?![0-9])', r'\1 ', regex=True)
  .str.replace('... ', '...', regex=False)
  .str.replace('  ', ' ', regex=False)
  .str.replace('.. ', '..', regex=False)
  .str.strip()
)

In [None]:
# Flag back-translations that probably failed, for re-translation. A row is flagged when
#  - characters other than Latin letters make up less than 60% of the back-translated sentence
#    (i.e. too much English was left untranslated), or
#  - the back-translation is at most half as long as the original premise.
# NOTE(review): retrans_ind is not used by any later cell shown here; back_trans.csv (loaded
# next) is presumably the result of re-translating these rows — confirm.
retrans_ind = set()

for i in tqdm(range(len(back_train_fin))):
  back_text = back_train_fin['back_premise'][i]
  orig_text = back_train_fin['premise'][i]

  # "+1" keeps the ratio defined for an empty back-translation.
  non_latin_ratio = len(re.sub('[a-zA-Z]', '', back_text)) / (len(back_text) + 1)
  if non_latin_ratio < 0.6:
    retrans_ind.add(i)
  # Skip the length check for an empty premise instead of raising ZeroDivisionError.
  if orig_text and len(back_text) / len(orig_text) <= 0.5:
    retrans_ind.add(i)

# Sorted list, so the order is deterministic across runs.
retrans_ind = sorted(retrans_ind)

In [None]:
back_trans = pd.read_csv('back_trans.csv', index_col=0)

In [None]:
back_trans

In [None]:
# 변경된 데이터셋 단어 집합 확인
back_trans_pre = back_trans['premise'].str.cat(sep='')
back_trans_hyp = back_trans['hypothesis'].str.cat(sep='')
back_trans_all = set.union(set(back_trans_pre), set(back_trans_hyp))
print(''.join(sorted(back_trans_all)))

In [None]:
back_trans[back_trans['premise'].str.contains('故')]

In [None]:
back_trans[back_trans['hypothesis'].str.contains('故')]

In [None]:
back_trans['premise'] = back_trans['premise'].str.replace("故", '고')
back_trans['premise'] = back_trans['premise'].str.replace("京義線", ' ')
back_trans['premise'] = back_trans['premise'].str.replace("九州", ' ')
back_trans['premise'] = back_trans['premise'].str.replace("㎡", '제곱미터')
back_trans['premise'] = back_trans['premise'].str.replace("㎞", '킬로미터')
back_trans['hypothesis'] = back_trans['hypothesis'].str.replace('ㄷ', ' ')
back_trans['hypothesis'] = back_trans['hypothesis'].str.replace('옞ㅇ', '예정')
back_trans['hypothesis'] = back_trans['hypothesis'].str.replace('%', '퍼센트')
back_trans['premise'] = back_trans['premise'].str.replace("%", '퍼센트')
back_trans['hypothesis'] = back_trans['hypothesis'].str.replace('[「」]', ' ')
back_trans['premise'] = back_trans['premise'].str.replace("[「」]", ' ')
back_trans['hypothesis'] = back_trans['hypothesis'].str.replace('[ᅵ…]', ' ')
back_trans['premise'] = back_trans['premise'].str.replace("[ᅵ…]", ' ')
back_trans['hypothesis'] = back_trans['hypothesis'].str.replace('[<>]', ' ')
back_trans['premise'] = back_trans['premise'].str.replace("[<>]", ' ')

In [None]:
# 변경된 데이터셋 단어 집합 확인
back_trans_pre = back_trans['premise'].str.cat(sep='')
back_trans_hyp = back_trans['hypothesis'].str.cat(sep='')
back_trans_all = set.union(set(back_trans_pre), set(back_trans_hyp))
print(''.join(sorted(back_trans_all)))

In [None]:
back_trans[back_trans['premise'].str.contains('\*')]

In [None]:
back_trans['premise'] = back_trans['premise'].str.replace("\*", '원이 ')

In [None]:
back_trans.iloc[1444]

In [None]:
# 변경된 데이터셋 단어 집합 확인
back_trans_pre = back_trans['premise'].str.cat(sep='')
back_trans_hyp = back_trans['hypothesis'].str.cat(sep='')
back_trans_all = set.union(set(back_trans_pre), set(back_trans_hyp))
print(''.join(sorted(back_trans_all)))

### Concat

In [None]:
output = pd.concat([train, KLUE, korNLI, back_trans])

In [None]:
output = output.reset_index(drop=True)
output['index'] = output.index
output

In [None]:
# 최종 훈련 데이터셋 단어 집합 확인
output_pre = output['premise'].str.cat(sep='')
output_hyp = output['hypothesis'].str.cat(sep='')
output_union = set.union(set(output_pre), set(output_hyp))
print(''.join(sorted(output_union)))

In [None]:
output = output.reset_index(drop=True)
output.drop('index', inplace=True, axis=1)

In [None]:
output.to_csv('train+KLUE+korNLI+back_trans(processed).csv', index=False)

## Load Data

In [None]:
PATH = '/content/MyDrive/MyDrive/Colab Notebooks/dacon'

# Preprocessed training set (presumably the Concat output copied into PATH/data) and the test set.
train = pd.read_csv(os.path.join(PATH, 'data', 'train+KLUE+korNLI+back_trans(processed).csv'), encoding='utf-8')
test = pd.read_csv(os.path.join(PATH, 'data', 'test_data.csv'), encoding='utf-8')

train.head()

In [None]:
pre = train['premise'].str.cat(sep='')
hyp = train['hypothesis'].str.cat(sep='')

union = set.union(set(pre), set(hyp))
print(''.join(sorted(union)))

In [None]:
print(train[train['premise'].str.contains('ㄷ')])
print(train[train['premise'].str.contains('ㅇ')])

In [None]:
print(train[train['hypothesis'].str.contains('ㄷ')])
print(train[train['hypothesis'].str.contains('ㅇ')])

In [None]:
# Fix two typos in specific hypothesis rows (positions found with the searches above).
# Write through one positional indexer on the frame itself: the chained form
# `train['hypothesis'].iloc[i] = ...` assigns into an intermediate Series, which pandas flags
# with SettingWithCopyWarning and which has no effect at all under copy-on-write.
hyp_col = train.columns.get_loc('hypothesis')
train.iat[20415, hyp_col] = train.iat[20415, hyp_col].replace("ㄷ.", "")
train.iat[12955, hyp_col] = train.iat[12955, hyp_col].replace("옞ㅇ", "예정")

## Modeling

### Fixing Seed/GPU

In [None]:
def seed_everything(seed: int = 42):
    """Seed every random number generator used in this notebook and make cuDNN deterministic.

    Args:
        seed: value applied to Python's `random`, NumPy's global RNG, PyTorch (CPU and all
            CUDA devices) and the PYTHONHASHSEED environment variable.

    Note: PYTHONHASHSEED only affects interpreters started after this call (e.g. worker
    subprocesses); string hashing in the running kernel is already fixed.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # every GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and pick possibly non-deterministic algorithms per input
    # shape, which defeats deterministic=True; it must be off for reproducible runs.
    torch.backends.cudnn.benchmark = False

seed_everything(seed=42)

# Use the first GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

### Tokenizing/Splitting

In [None]:
MODEL_NAME = 'klue/roberta-large'

# Tokenizer of the pretrained checkpoint; used for all fold splits and for the test set.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)

In [None]:
class BERTDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-tokenised encodings with integer labels.

    Args:
        pair_dataset: mapping from field name (e.g. 'input_ids', 'attention_mask') to a tensor
            whose first dimension indexes examples, as produced by a tokenizer called with
            return_tensors="pt".
        label: sequence of integer class ids, one per example.
    """

    def __init__(self, pair_dataset, label):
        self.pair_dataset = pair_dataset
        self.label = label

    def __getitem__(self, idx):
        # Detached copies, so a sample never aliases the stored encoding tensors.
        item = {}
        for field, tensor in self.pair_dataset.items():
            item[field] = tensor[idx].clone().detach()
        item['label'] = torch.tensor(self.label[idx])
        return item

    def __len__(self):
        return len(self.label)

In [None]:
def label_to_num(label):
    """Convert NLI label strings to integer class ids.

    Mapping: entailment -> 0, contradiction -> 1, neutral -> 2, answer -> 3 ('answer' is
    presumably the placeholder label of the unlabeled test set).

    Raises:
        KeyError: for any label outside the mapping.
    """
    label_dict = {"entailment": 0, "contradiction": 1, "neutral": 2, "answer": 3}
    return [label_dict[v] for v in label]

In [None]:
def _encode_pairs(frame):
    """Tokenise a frame's (premise, hypothesis) pairs into padded PyTorch tensors (max 256 tokens)."""
    return tokenizer(
        list(frame['premise']),
        list(frame['hypothesis']),
        return_tensors="pt",
        max_length=256,
        padding=True,
        truncation=True,
        add_special_tokens=True
    )


# 5-fold split stratified on the label column; each fold is [train_frame, eval_frame].
skf = StratifiedKFold(n_splits=5, shuffle=True)
folds = [[train.iloc[tr_idx], train.iloc[ev_idx]]
         for tr_idx, ev_idx in skf.split(train, train['label'])]

# Per fold: (train encodings, eval encodings), then [train labels, eval labels] as ids.
tokens = [(_encode_pairs(tr_frame), _encode_pairs(ev_frame)) for tr_frame, ev_frame in folds]
label_folds = [[label_to_num(tr_frame['label'].values), label_to_num(ev_frame['label'].values)]
               for tr_frame, ev_frame in folds]

# Per fold: [train dataset, eval dataset].
dataset_folds = [
    [BERTDataset(tr_tok, tr_lab), BERTDataset(ev_tok, ev_lab)]
    for (tr_tok, ev_tok), (tr_lab, ev_lab) in zip(tokens, label_folds)
]

In [None]:
# Sanity check on fold 1's training split: size, first example, and that example decoded.
first_train_ds = dataset_folds[0][0]
print(len(first_train_ds))
print(first_train_ds[0])
print(tokenizer.decode(first_train_ds[0]['input_ids']))

### Training

In [None]:
class CustomTrainer(Trainer):
    """Trainer that adds an R3F-style symmetric-KL noise penalty to the training loss.

    Training steps (`return_outputs=False`) run the RoBERTa embeddings, encoder and
    classification head by hand so Gaussian noise can be added to the embedding output. The
    loss is cross-entropy plus `args.r3f_lambda` times the symmetric KL between the clean and
    noisy predictions. Evaluation/prediction steps (`return_outputs=True`) use the model's
    normal forward pass with plain cross-entropy.

    Expects `args` to provide `r3f_lambda` (see CustomTrainingArguments) and `model` to expose
    `.roberta` and `.classifier` submodules.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments that newer transformers versions pass
        # (e.g. num_items_in_batch); they are not used here.
        labels = inputs.get("labels")
        num_labels = self.model.config.num_labels
        ce_loss = torch.nn.CrossEntropyLoss()
        with torch.cuda.amp.autocast(enabled=True):
            if return_outputs:
                outputs = model(**inputs)
                logits = outputs.get("logits")
                loss = ce_loss(logits.view(-1, num_labels), labels.view(-1))
                return (loss, outputs)

            normal_emb = model.roberta.embeddings(input_ids=inputs['input_ids'], token_type_ids=inputs['token_type_ids'])
            # The encoder's self-attention padding mask is `attention_mask`, in additive 4-D form.
            # `encoder_attention_mask` is only used for cross-attention, so passing the padding
            # mask there let tokens attend to padding during training but not in evaluation.
            extended_mask = self._extended_attention_mask(inputs['attention_mask'], normal_emb.dtype)
            normal_encoder = model.roberta.encoder(normal_emb, attention_mask=extended_mask)
            normal_sequence = normal_encoder[0]
            logits = model.classifier(normal_sequence)
            loss = ce_loss(logits.view(-1, num_labels), labels.view(-1))

            if self.args.r3f_lambda != 0:
                # N(0, (1e-5)^2) noise created directly on the embeddings' device and dtype
                # (Tensor.get_device() returns -1 on CPU, which .to() rejects).
                noise = torch.randn_like(normal_emb) * 1e-5
                noise_emb = normal_emb.detach().clone() + noise
                noise_encoder = model.roberta.encoder(noise_emb, attention_mask=extended_mask)
                noise_sequence = noise_encoder[0]
                noise_logits = model.classifier(noise_sequence)
                loss += self.args.r3f_lambda * self.symm_kl_loss(
                    noise_logits.view(-1, num_labels), logits.view(-1, num_labels))

            return loss

    @staticmethod
    def _extended_attention_mask(attention_mask, dtype):
        """Turn a (batch, seq) 0/1 padding mask into the additive (batch, 1, 1, seq) form:
        0 where attention is allowed, the most negative `dtype` value where it is masked."""
        extended = attention_mask[:, None, None, :].to(dtype)
        return (1.0 - extended) * torch.finfo(dtype).min

    def symm_kl_loss(self, noised_logits, input_logits):
        """Symmetric KL divergence between the softmax distributions of two logit batches.

        Computes KL(p_input || p_noised) + KL(p_noised || p_input) in float32, summed over all
        rows and divided by the batch size (first dimension of `noised_logits`).
        """
        return (
            F.kl_div(
                F.log_softmax(noised_logits, dim=-1, dtype=torch.float32),
                F.softmax(input_logits, dim=-1, dtype=torch.float32),
                reduction="sum",
            )
            + F.kl_div(
                F.log_softmax(input_logits, dim=-1, dtype=torch.float32),
                F.softmax(noised_logits, dim=-1, dtype=torch.float32),
                reduction="sum",
            )
        ) / noised_logits.size(0)

In [None]:
class CustomCallback(EarlyStoppingCallback):
    """EarlyStoppingCallback that also aborts runs whose evaluation metric is too low.

    Besides the usual patience logic, training stops immediately when a greater-is-better metric
    (accuracy in this notebook) is at or below `min_metric_value`. This presumably cuts short
    hyper-parameter trials that are clearly not competitive.
    """

    # Floor for greater-is-better metrics; an evaluation at or below it stops training at once.
    min_metric_value = 0.8

    def check_metric_value(self, args, state, control, metric_value):
        operator = np.greater if args.greater_is_better else np.less
        # The floor only makes sense when larger is better; for a loss-like metric every good
        # (small) value would otherwise trigger an immediate stop.
        if args.greater_is_better and metric_value <= self.min_metric_value:
            control.should_training_stop = True

        elif state.best_metric is None or (
            operator(metric_value, state.best_metric)
            and abs(metric_value - state.best_metric) > self.early_stopping_threshold
        ):
            self.early_stopping_patience_counter = 0

        else:
            self.early_stopping_patience_counter += 1

In [None]:
class CustomTrainingArguments(TrainingArguments):
    """TrainingArguments with one extra keyword, `r3f_lambda`.

    `r3f_lambda` weights the symmetric-KL noise penalty in CustomTrainer.compute_loss and
    defaults to 0 (penalty disabled) when not given. It is stored as a plain attribute rather
    than a dataclass field, so it is removed from kwargs before the parent constructor runs.
    """

    def __init__(self, *args, **kwargs):
        self.r3f_lambda = kwargs.pop('r3f_lambda') if 'r3f_lambda' in kwargs else 0
        super().__init__(*args, **kwargs)

In [None]:
def compute_metrics(pred):
  """Accuracy of the arg-max predictions, for Trainer evaluation.

  Args:
    pred: prediction object with `predictions` (scores whose last axis is the class axis)
      and `label_ids`.

  Returns:
    {'accuracy': accuracy}
  """
  predicted = pred.predictions.argmax(axis=-1)
  return {'accuracy': accuracy_score(pred.label_ids, predicted)}

In [None]:
config = AutoConfig.from_pretrained(MODEL_NAME)
config.num_labels = 3  # entailment / contradiction / neutral
CHECK_PATH = os.path.join(PATH, 'checkpoints')

num_epochs = 5
r3f_lambda_search = [0, 0.5, 1, 2]  # R3F penalty weights tried on every fold (0 disables it)
optimizer_lr = 1e-5
batch_size = 30 # [28, 30, 32]

# Train one model per (fold, r3f_lambda) pair. Each run's best epoch is saved to
# CHECK_PATH/fold_<k>/lambda_<λ>/best_model(<eval accuracy>).
for i, (train_dataset, eval_dataset) in enumerate(dataset_folds, start=1):
    FOLD_CHECK_PATH = os.path.join(CHECK_PATH, 'fold_' + str(i))
    print('Fold %d training started' % (i))

    for r3f_lambda in r3f_lambda_search:
        PARAMETER_CHECK_PATH = os.path.join(FOLD_CHECK_PATH, 'lambda_' + str(r3f_lambda))

        training_args = CustomTrainingArguments(
            output_dir=PARAMETER_CHECK_PATH,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            save_total_limit=3,
            save_strategy='epoch',
            evaluation_strategy='epoch',
            load_best_model_at_end=True,
            half_precision_backend='amp',
            metric_for_best_model='accuracy',
            r3f_lambda=r3f_lambda
        )

        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, config=config)
        # torch's AdamW configured like transformers.AdamW's defaults (eps=1e-6, no weight decay);
        # transformers.AdamW is deprecated and removed in recent releases.
        optimizer = torch.optim.AdamW(model.parameters(), lr=optimizer_lr, eps=1e-6, weight_decay=0.0)
        # The scheduler counts optimizer steps, i.e. whole batches (the last partial batch is a
        # step too), assuming a single device and no gradient accumulation.
        num_training_steps = num_epochs * math.ceil(len(train_dataset) / batch_size)
        num_warmup_steps = num_training_steps // batch_size  # same warm-up ratio as before: 1/batch_size
        lr_scheduler = get_scheduler("polynomial", optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

        trainer = CustomTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            compute_metrics=compute_metrics,
            callbacks=[CustomCallback(1, 0)],
            optimizers=(optimizer, lr_scheduler)
        )
        print('lambda = %.4f training' % (r3f_lambda))
        trainer.train()

        # load_best_model_at_end=True: trainer.model now holds the best epoch's weights.
        best_result = trainer.evaluate(eval_dataset)
        best_acc = 'best_model(%.4f)' % (best_result['eval_accuracy'])
        BEST_PATH = os.path.join(PARAMETER_CHECK_PATH, best_acc)
        trainer.model.save_pretrained(BEST_PATH)

        # Free this run's GPU memory before the next roberta-large is loaded.
        del model, optimizer, lr_scheduler, trainer
        torch.cuda.empty_cache()

## Submission

### Best Model Path

In [None]:
# Best checkpoint directory per fold. Set a path here to pin a specific checkpoint; any entry left
# as None is resolved to the highest-accuracy 'best_model(<acc>)' directory that the training loop
# saved for that fold, over all r3f_lambda values.
FOLD_1_BEST = None
FOLD_2_BEST = None
FOLD_3_BEST = None
FOLD_4_BEST = None
FOLD_5_BEST = None

CHECKPOINT_ROOT = os.path.join(PATH, 'checkpoints')  # same directory as CHECK_PATH in the training cell


def find_best_checkpoint(fold_number, root=CHECKPOINT_ROOT):
    """Return the best_model(<acc>) directory with the highest accuracy for one fold.

    Args:
        fold_number: 1-based fold index, as in the 'fold_<k>' directory names.
        root: directory containing the fold_<k>/lambda_<λ>/best_model(<acc>) tree.

    Raises:
        FileNotFoundError: if no best_model(...) directory exists for the fold.
    """
    pattern = os.path.join(glob.escape(root), 'fold_%d' % fold_number, 'lambda_*', 'best_model(*)')
    candidates = glob.glob(pattern)
    if not candidates:
        raise FileNotFoundError('No best_model(...) checkpoint matches ' + pattern)
    # The directory name encodes the eval accuracy, e.g. best_model(0.8912).
    return max(candidates, key=lambda path: float(re.search(r'best_model\(([0-9.]+)\)$', path).group(1)))


BEST_MODELS = [
    path if path is not None else find_best_checkpoint(fold_number)
    for fold_number, path in enumerate(
        [FOLD_1_BEST, FOLD_2_BEST, FOLD_3_BEST, FOLD_4_BEST, FOLD_5_BEST], start=1)
]

In [None]:
def num_to_label(label):
    """Convert predicted class ids back to label strings, paired with their position.

    Returns:
        A list of [position, label_string] pairs in input order, e.g.
        num_to_label([2, 0]) == [[0, 'neutral'], [1, 'entailment']].

    Raises:
        KeyError: for a class id other than 0, 1 or 2.
    """
    id_to_label = {0: "entailment", 1: "contradiction", 2: "neutral"}
    return [[position, id_to_label[class_id]] for position, class_id in enumerate(label)]

In [None]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Tokenise the test pairs exactly like the training folds (same tokenizer, max 256 tokens).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenized_test = tokenizer(
    list(test['premise']),
    list(test['hypothesis']),
    add_special_tokens=True,
    padding=True,
    truncation=True,
    max_length=256,
    return_tensors="pt",
)
# Test labels are only needed because BERTDataset stores a label per row.
test_label = label_to_num(test['label'].values)
test_dataset = BERTDataset(tokenized_test, test_label)
dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)  # keep file order for submission

### Soft Voting

In [None]:
# Soft voting: average the per-model class probabilities (softmax outputs, despite the variable
# name) over the fold models and take the arg-max.
# resize_token_embeddings(tokenizer.vocab_size) was removed: the checkpoints were fine-tuned from
# MODEL_NAME with this same tokenizer, so it was at best a no-op, and vocab_size excludes added
# tokens, so it could have truncated the embedding matrix.
ensemble_logits = 0
for best_model in BEST_MODELS:
    model = AutoModelForSequenceClassification.from_pretrained(best_model)
    model.to(device)
    model.eval()
    fold_probs = []
    with torch.no_grad():
        for data in tqdm(dataloader):
            outputs = model(
                input_ids=data['input_ids'].to(device),
                attention_mask=data['attention_mask'].to(device),
                token_type_ids=data['token_type_ids'].to(device)
            )
            fold_probs.extend(F.softmax(outputs[0], dim=-1).tolist())

    ensemble_logits += np.array(fold_probs)

    # Release this fold's model before the next one is loaded.
    del model
    torch.cuda.empty_cache()

# Average over the models actually used. The old divisor, len(dataset_folds), is undefined in a
# submission-only run and need not equal the number of models.
ensemble_logits = ensemble_logits / len(BEST_MODELS)
ensemble_pred = np.argmax(ensemble_logits, axis=-1)
ensemble_pred = np.argmax(ensemble_logits, axis=-1)

In [None]:
# Build the submission: one [position, label] row per test example.
# NOTE(review): 'index' is the row position 0..n-1, which assumes the test set's own index column
# follows file order — confirm against test_data.csv.
answer = num_to_label(ensemble_pred)

submission = pd.DataFrame(answer, columns=['index', 'label'])

# Save under PATH (the same folder the previously hard-coded path pointed to).
submission.to_csv(os.path.join(PATH, 'submission_ensemble_final.csv'), index=False)

submission.head()