In [1]:
import json, re
import pandas as pd
import torch
from tqdm import tqdm
import datasets
import logging
from kobart_transformers import get_kobart_tokenizer
from kobart_transformers import get_kobart_for_conditional_generation
import os
from rouge_score import rouge_scorer
from transformers import AutoTokenizer

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory that holds the fine-tuned checkpoints.
ROOT_DIR = '/USER/Kaggle/dacon/final_model'
# Both names carried the same literal; derive one from the other so they cannot drift apart.
MODEL_DIR = ROOT_DIR

# Echo the selected device. A bare expression is only displayed when it is the
# last statement of the cell, so it has to come after the assignments.
device

In [3]:
# Locations of the raw summarisation corpora, relative to the notebook.
train_path, train_path2, train_path3 = (
    './data/train/train_original.json',   # legal-domain corpus
    './data/train/train_original2.json',
    './data/train/train_original3.json',
)
valid_path, valid_path2 = (
    './data/valid/valid_original.json',
    './data/valid/valid_original2.json',
)

In [4]:
def _read_json(path):
    """Parse the JSON file at ``path`` and return the decoded object.

    The file is opened inside a ``with`` block so the handle is closed as soon
    as parsing finishes, and it is decoded explicitly as UTF-8 rather than with
    the platform's locale encoding (the corpora contain Korean text).
    """
    with open(path, 'r', encoding='utf-8') as fp:
        return json.load(fp)


train_data = _read_json(train_path)  # legal-domain corpus
# train_data2 = _read_json(train_path2)  # second corpus, currently not used
train_data3 = _read_json(train_path3)
# NOTE(review): the legal-domain training set is read from train_path (no suffix)
# but the legal-domain validation set from valid_path2 -- confirm that the two
# validation paths are not swapped.
valid_data = _read_json(valid_path2)  # legal-domain corpus
valid_data3 = _read_json(valid_path)

In [5]:
def load_data(datas):
    """Flatten raw summarisation documents into a pandas DataFrame.

    Args:
        datas: iterable of document dicts. Each document is read through the
            keys ``'id'``, ``'abstractive'`` (only element 0 is used),
            ``'extractive'`` (an iterable of sentence indices) and ``'text'``
            (a list of paragraphs, each a list of dicts with the keys
            ``'index'`` and ``'sentence'``).

    Returns:
        pandas.DataFrame with one row per document and the columns

        * ``id``       -- the document's ``'id'`` value,
        * ``abs``      -- the first abstractive summary,
        * ``ext``      -- the sentences whose ``'index'`` appears in
          ``'extractive'``, joined by single spaces in the order of the
          extractive indices,
        * ``original`` -- all sentences joined by single spaces, except those
          that contain an e-mail address or the substring '기자'.

    Raises:
        KeyError / IndexError: if a document lacks one of the keys above or
            has an empty ``'abstractive'`` list.
    """
    # Raw string: the former plain literal relied on the invalid escape '\.',
    # which newer Python versions warn about. The compiled pattern is unchanged.
    email_re = re.compile(r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)')
    reporter_re = re.compile('기자')

    # Local names deliberately avoid shadowing the builtins ``id`` and ``abs``.
    ids, originals, extractives, abstractives = [], [], [], []

    for data in tqdm(datas):
        ids.append(data['id'])
        abstractives.append(data['abstractive'][0])

        # Every sentence record of the document, in reading order.
        sentences = [article for articles in data['text'] for article in articles]

        # index -> sentences carrying that index (document order kept), so each
        # extractive index is resolved by lookup instead of a full rescan.
        by_index = {}
        for article in sentences:
            by_index.setdefault(article['index'], []).append(article['sentence'])

        picked = []
        for idx in data['extractive']:
            # Indices that match no sentence (e.g. None) contribute nothing.
            picked.extend(by_index.get(idx, ()))
        extractives.append(' '.join(picked))

        # Drop sentences that contain an e-mail address or the word '기자'.
        kept = [
            article['sentence']
            for article in sentences
            if not email_re.search(article['sentence'])
            and not reporter_re.search(article['sentence'])
        ]
        originals.append(' '.join(kept))

    sum_data = {'id': ids, 'abs': abstractives, 'ext': extractives, 'original': originals}
    return pd.DataFrame(sum_data)

In [6]:
def data_split(data, s_type):
    """Return the ``id``/``original`` columns plus one summary column.

    Args:
        data: DataFrame containing the columns ``id``, ``original`` and the
            column named by ``s_type``.
        s_type: name of the summary column to keep; it is exposed as
            ``summary`` in the result.

    Returns:
        A new DataFrame with the columns ``id``, ``original`` and ``summary``;
        ``data`` itself is left unchanged.
    """
    wanted_columns = ['id', 'original', s_type]
    return data[wanted_columns].rename(columns={s_type: 'summary'})

In [7]:
# Fixed row ranges taken from each corpus.
doc_train_rows = train_data3['documents'][50000:70000]  # document corpus
law_train_rows = train_data['documents'][1000:8000]     # legal corpus
doc_valid_rows = valid_data3['documents'][500:1000]     # document corpus
law_valid_rows = valid_data['documents'][300:600]       # legal corpus

# One DataFrame per (split, corpus) pair.
df_t_1 = load_data(doc_train_rows)
df_t_2 = load_data(law_train_rows)
df_v_1 = load_data(doc_valid_rows)
df_v_2 = load_data(law_valid_rows)

  0%|          | 0/20000 [00:00<?, ?it/s]

  0%|          | 0/7000 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

In [8]:
# Stack the per-corpus frames row-wise (training first, then validation),
# renumbering the index from 0 in each result.
df_t, df_v = (
    pd.concat(frames, ignore_index=True)
    for frames in ([df_t_1, df_t_2], [df_v_1, df_v_2])
)

In [9]:
# One (id, original, summary) frame per split and per summary type.
train_abs = data_split(df_t, 'abs')
valid_abs = data_split(df_v, 'abs')
train_ext = data_split(df_t, 'ext')
valid_ext = data_split(df_v, 'ext')
# Backward-compatible alias: other cells still refer to the original, misspelled name.
trian_ext = train_ext

In [10]:
import numpy as np
from torch.utils.data import Dataset, DataLoader, IterableDataset

class KoBARTSummaryDataset(Dataset):
    """Map-style dataset turning (original, summary) rows into fixed-length id arrays.

    Every item is a dict of three integer arrays of length ``max_len``:

    * ``input_ids``         -- tokens of ``'<s> ' + original``, padded with ``pad_index``;
    * ``decoder_input_ids`` -- BOS followed by the summary tokens, padded with ``pad_index``;
    * ``labels``            -- summary tokens followed by EOS, padded with ``ignore_index``.

    Args:
        df: DataFrame with the string columns ``original`` and ``summary``.
        pad_index: id used to pad ``input_ids`` and ``decoder_input_ids``;
            defaults to the tokenizer's ``pad_token_id``.
        ignore_index: id used to pad ``labels`` (presumably the value the loss
            is configured to skip -- confirm against the model).
        tok: optional tokenizer to reuse. It must provide ``encode`` and the
            attributes ``pad_token_id``, ``bos_token_id`` and ``eos_token_id``.
            When omitted, the 'hyunwoongko/kobart' tokenizer is loaded, as before.
        max_len: length every returned array is padded or truncated to.
    """

    def __init__(self, df, pad_index=None, ignore_index=-100, tok=None, max_len=1024):
        super().__init__()
        if tok is None:
            # Default kept for backward compatibility; pass ``tok`` to share one
            # tokenizer between several datasets instead of reloading it.
            tok = AutoTokenizer.from_pretrained('hyunwoongko/kobart', use_fast=True)
        self.tok = tok
        self.max_len = max_len
        self.df = df
        self.len = len(self.df)
        self.pad_index = self.tok.pad_token_id if pad_index is None else pad_index
        self.ignore_index = ignore_index

    def _fit_to_max_len(self, inputs, fill_value):
        """Truncate ``inputs`` to ``max_len`` or right-pad it with ``fill_value``."""
        ids = list(inputs)[:self.max_len]
        padding = [fill_value] * (self.max_len - len(ids))
        return np.array(ids + padding, dtype=np.int_)

    def add_padding_data(self, inputs):
        """Return ``inputs`` as a length-``max_len`` array padded with ``pad_index``."""
        return self._fit_to_max_len(inputs, self.pad_index)

    def add_ignored_data(self, inputs):
        """Return ``inputs`` as a length-``max_len`` array padded with ``ignore_index``."""
        return self._fit_to_max_len(inputs, self.ignore_index)

    def __getitem__(self, idx):
        """Build the encoder input, decoder input and labels for row ``idx``."""
        instance = self.df.iloc[idx]
        input_ids = self.add_padding_data(self.tok.encode('<s> ' + instance['original']))

        # Reserve one position for EOS so that an over-long summary still ends
        # with EOS in ``labels`` instead of having it cut off by truncation.
        summary_ids = list(self.tok.encode(instance['summary']))[:self.max_len - 1]
        label_ids = summary_ids + [self.tok.eos_token_id]
        # Decoder input is the label sequence shifted right by one: BOS first, no EOS.
        dec_input_ids = [self.tok.bos_token_id] + summary_ids

        return {'input_ids': input_ids,
                'decoder_input_ids': self.add_padding_data(dec_input_ids),
                'labels': self.add_ignored_data(label_ids)}

    def __len__(self):
        """Number of rows in the wrapped DataFrame (captured at construction)."""
        return self.len

In [11]:
# Mini-batch sizes for the training and the evaluation DataLoaders.
TRAIN_BATCH_SIZE, EVAL_BATCH_SIZE = 2, 4

In [12]:
# Datasets for the extractive-summary and the abstractive-summary targets.
ext_train_dataset = KoBARTSummaryDataset(trian_ext)
ext_valid_dataset = KoBARTSummaryDataset(valid_ext)
abs_train_dataset = KoBARTSummaryDataset(train_abs)
abs_valid_dataset = KoBARTSummaryDataset(valid_abs)

# Training batches are reshuffled every epoch.
ext_train_loader = DataLoader(ext_train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0)
abs_train_loader = DataLoader(abs_train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0)

# Validation batches keep a fixed order: the validation score is an average of
# per-batch scores, so reshuffling would change the batch composition (and with
# it the score) from one epoch to the next for reasons unrelated to the model.
ext_valid_loader = DataLoader(ext_valid_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=0)
abs_valid_loader = DataLoader(abs_valid_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=0)

In [13]:
def Hitrate(y_true, y_pred):
    """Return the mean of the ROUGE-1, ROUGE-2 and ROUGE-L F-measures.

    Args:
        y_true: reference summaries (list of str).
        y_pred: predicted summaries (list of str), aligned with ``y_true``.

    Returns:
        The average of the ``mid.fmeasure`` values reported for ``rouge1``,
        ``rouge2`` and ``rougeL``.
    """
    # Load the metric once and reuse it: this function is called for every
    # validation batch, and reloading the metric each time is needlessly slow.
    # NOTE(review): ``datasets.load_metric`` is, as far as I know, deprecated in
    # newer ``datasets`` releases in favour of the ``evaluate`` package.
    m = getattr(Hitrate, '_rouge_metric', None)
    if m is None:
        m = datasets.load_metric('rouge')
        Hitrate._rouge_metric = m

    # NOTE(review): the stock ROUGE tokenizer appears to keep only ASCII letters
    # and digits -- confirm that Korean summaries are scored meaningfully.
    rouge = m.compute(predictions=y_pred, references=y_true)
    f_measures = [rouge[key].mid.fmeasure for key in ('rouge1', 'rouge2', 'rougeL')]
    return (f_measures[0] + f_measures[1] + f_measures[2]) / 3

In [14]:
class Trainer():
    """Trainer: defines the training and validation procedure for one epoch.

    Attributes:
        model: model that is called with ``input_ids``, ``attention_mask``,
            ``decoder_input_ids``, ``decoder_attention_mask`` and ``labels``
            (its output must expose ``.loss``) and that provides ``generate``.
        device: device every batch tensor is moved to.
        tok: tokenizer providing ``pad_token_id``, ``bos_token_id`` and
            ``batch_decode``.
        metric_fn (Callable): called as ``metric_fn(y_true=..., y_pred=...)``
            with lists of decoded strings; must return a number.
        optimizer (`optimizer`): stepped once per training batch.
        scheduler (`scheduler`): stepped once per training batch when given.
        logger: stored but not used by this class.
        train_total_loss / train_mean_loss (float): set by ``train_epoch``.
        val_score / val_score_all / val_mean_score: set by ``validate_epoch``.
    """

    # Number of training batches between two running-loss reports.
    LOG_EVERY = 200

    def __init__(self, model, device, metric_fn, optimizer=None, scheduler=None, logger=None, tok=None):
        """Store the collaborators; load the KoBART tokenizer unless ``tok`` is given."""
        self.model = model
        self.device = device
        if tok is None:
            tok = AutoTokenizer.from_pretrained('hyunwoongko/kobart', use_fast=True)
        self.tok = tok
        self.metric_fn = metric_fn
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.logger = logger

    def train_epoch(self, dataloader, epoch_index):
        """Training procedure performed in one epoch.

        Args:
            dataloader (`dataloader`): yields dicts with the tensor entries
                ``input_ids``, ``decoder_input_ids`` and ``labels``.
            epoch_index (int): only used in the printed messages.
        """
        self.model.train()
        self.train_total_loss = 0.0
        window_loss = 0.0
        pad_id = self.tok.pad_token_id

        for batch_index, data in enumerate(tqdm(dataloader)):
            # Use the trainer's own device rather than a notebook global.
            input_ids = data['input_ids'].to(self.device)
            decoder_input_ids = data['decoder_input_ids'].to(self.device)
            labels = data['labels'].to(self.device)
            # Padding positions are masked out for both encoder and decoder.
            attention_mask = input_ids.ne(pad_id).float()
            decoder_attention_mask = decoder_input_ids.ne(pad_id).float()

            self.optimizer.zero_grad()
            outputs = self.model(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 decoder_input_ids=decoder_input_ids,
                                 decoder_attention_mask=decoder_attention_mask,
                                 labels=labels, return_dict=True)
            outputs.loss.backward()
            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()

            # Accumulate plain floats: summing the loss tensors themselves keeps
            # every batch's autograd history alive until the end of the epoch.
            batch_loss = outputs.loss.item()
            self.train_total_loss += batch_loss
            window_loss += batch_loss

            if (batch_index + 1) % self.LOG_EVERY == 0:
                print('[Epoch {}] Iteration {} -> Train Loss: {:.4f}'.format
                      (epoch_index, (batch_index + 1), window_loss / self.LOG_EVERY))
                window_loss = 0.0

        # max(..., 1) avoids a ZeroDivisionError for an empty dataloader.
        self.train_mean_loss = self.train_total_loss / max(len(dataloader), 1)

        msg = f'Epoch {epoch_index}, Train, loss: {self.train_mean_loss}'
        print(msg)

    def validate_epoch(self, dataloader, epoch_index):
        """Validation procedure performed in one epoch.

        Generates a summary for every batch, scores it against the decoded
        ``decoder_input_ids`` with ``metric_fn`` and stores the average of the
        per-batch scores in ``val_mean_score``.

        Args:
            dataloader (`dataloader`): yields dicts with the tensor entries
                ``input_ids`` and ``decoder_input_ids``.
            epoch_index (int): only used in the printed message.
        """
        self.model.eval()
        self.val_score_all = 0
        pad_id = self.tok.pad_token_id

        with torch.no_grad():
            for data in tqdm(dataloader):
                input_ids = data['input_ids'].to(self.device)
                attention_mask = input_ids.ne(pad_id).float()
                # NOTE(review): sampling is enabled together with beam search, so
                # the generated text (and the score) can differ between runs --
                # confirm this is intended for model selection.
                outputs = self.model.generate(input_ids=input_ids,
                                              attention_mask=attention_mask,
                                              num_beams=5,
                                              no_repeat_ngram_size=4,
                                              decoder_start_token_id=self.tok.bos_token_id,
                                              temperature=1.0, top_k=0, top_p=0.92,
                                              length_penalty=1.0, min_length=1,
                                              max_length=100,
                                              early_stopping=False,
                                              num_return_sequences=1,
                                              do_sample=True)
                # The reference text is recovered from the decoder input (BOS +
                # summary tokens + padding); special tokens are dropped on decode.
                ref = self.tok.batch_decode(
                        data['decoder_input_ids'],
                        skip_special_tokens=True
                )
                pred = self.tok.batch_decode(
                        outputs,
                        skip_special_tokens=True
                )

                self.val_score = self.metric_fn(y_true=ref, y_pred=pred)
                self.val_score_all += self.val_score

        # max(..., 1) avoids a ZeroDivisionError for an empty dataloader.
        self.val_mean_score = self.val_score_all / max(len(dataloader), 1)
        msg = f'Epoch {epoch_index}, Validation, score : {self.val_mean_score}'
        print(msg)

In [15]:
# Training configuration.
ext_epochs, abs_epochs = 1, 7
LEARNING_RATE = 3e-5
WEIGHT_DECAY = 1e-5
NUM_WORKERS = 1
EARLY_STOPPING_PATIENCE = 20

In [16]:
# Resume from the previously saved abstractive checkpoint.
MODEL_DIR = os.path.join(ROOT_DIR, 'final_abs.pt')
model = get_kobart_for_conditional_generation().to(device)
# map_location='cpu' lets a checkpoint saved on a GPU be loaded on a CPU-only
# machine; load_state_dict then copies the weights onto the model's own device.
# SECURITY: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load(MODEL_DIR, map_location='cpu')['model_state_dict'])

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# One-cycle schedule over all abstractive epochs; the peak learning rate is the
# configured LEARNING_RATE rather than a second copy of the same literal.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    div_factor=1e3,
    max_lr=LEARNING_RATE,
    epochs=abs_epochs,
    steps_per_epoch=len(abs_train_loader),
)
metric_fn = Hitrate

In [17]:
# Bundle model, device, metric, optimizer and LR schedule into one trainer.
trainer = Trainer(
    model=model,
    device=device,
    metric_fn=metric_fn,
    optimizer=optimizer,
    scheduler=scheduler,
)

In [18]:
# TRAIN
import time

start = time.time()
criterion = 0        # best validation score seen so far
best_epoch = None    # epoch whose weights are currently stored on disk, if any
checkpoint_path = os.path.join(ROOT_DIR, 'final_abs.pt')

for epoch_index in tqdm(range(abs_epochs)):

    trainer.train_epoch(abs_train_loader, epoch_index=epoch_index)
    trainer.validate_epoch(abs_valid_loader, epoch_index=epoch_index)

    # TODO: early stopping is not wired up (EARLY_STOPPING_PATIENCE is unused).

    # Keep only the checkpoint with the highest validation score.
    if trainer.val_mean_score > criterion:
        criterion = trainer.val_mean_score
        best_epoch = epoch_index
        torch.save({
            'epoch': epoch_index,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Key name kept for compatibility; the stored value is the validation score.
            'loss': trainer.val_mean_score,
            }, checkpoint_path)
        print('final_abs.pt saved ', epoch_index, trainer.val_mean_score)

elapsed = time.time() - start
if best_epoch is None:
    # No epoch beat the initial criterion, so nothing was written.
    print(f'train finished in {elapsed:.1f}s, no checkpoint saved.')
else:
    print(f'train finished in {elapsed:.1f}s, final_abs.pt saved (epoch {best_epoch}, score {criterion}).')

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/13500 [00:00<?, ?it/s]

KeyboardInterrupt: 