## Relation Extraction Model

In [154]:
### torch == 1.9.1
### transformers == 4.10.3
### pytorch_lightning == 1.2.8
import ast
import os
import random
import re
from collections import Counter

import easydict
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from tqdm import tqdm

from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertModel, AdamW, get_linear_schedule_with_warmup

from sklearn.metrics import classification_report
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy, f1, auroc
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

In [14]:
# Experiment configuration as an attribute-accessible dict; further fields are set below.
args = easydict.EasyDict(seed=42)

In [339]:
# Training / data configuration (set in one call; key order is preserved for the log file).
args.update(
    batch_size=12,
    hidden_size=768,
    n_class=97,            # number of relation classes (output size of the classifier)
    num_workers=4,
    epochs=5,
    train=True,            # True: train a model; False: evaluate the saved best checkpoint
    bert_model='snunlp/KR-Medium',
    max_token_len=512,
    train_data='../data/toy_data_split/train.csv',
    val_data='../data/toy_data_split/val.csv',
    test_data='../data/toy_data_split/test.csv',
    relation_list='../data/relation/relation_list.txt',
    save_dir='../ckpt/',
    log_file='../log/toy_result.txt',
    mode="ALLCC",          # "ALLCC": 4 entity-marker states, "ENTMARK": 2 entity-marker states
)

In [16]:
# Show the current configuration as the cell output.
args

{'seed': 42,
 'batch_size': 12,
 'n_class': 97,
 'num_workers': 4,
 'epochs': 5,
 'train': True,
 'bert_model': 'snunlp/KR-Medium',
 'max_token_len': 512,
 'train_data': '../data/toy_data_split/train.csv',
 'val_data': '../data/toy_data_split/val.csv',
 'test_data': '../data/toy_data_split/test.csv',
 'save_dir': '../ckpt/toy.pt',
 'log_file': '../log/toy_result.txt'}

In [17]:
# Load the training split for inspection.
train_df = pd.read_csv(filepath_or_buffer=args.train_data)

In [19]:
# Preview the first rows of the training split.
train_df.head()

Unnamed: 0,sentence,subj_name,subj_start_pos,subj_end_pos,subj_type,obj_name,obj_start_pos,obj_end_pos,obj_type,relation_id,relation_label,relation_type,relation_description,label_onehot
0,모토로라 레이저 M는 모토로라 모빌리티에서 제조/판매하는 안드로이드 스마트폰이다.,모토로라 레이저 M,0,10,ARTIFACT,모토로라 모빌리티,12,21,ORGANIZATION,P176,제조사,WikibaseItem,이 물건을 제조/제작한 주요 회사,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,"웨인 페더먼은 미국의 배우, 텔레비전 배우, 각본가, 성우, 영화배우, 영화 프로듀...",웨인 페더먼,0,6,PERSON,배우,12,14,OCCUPATION,P106,직업,WikibaseItem,항목 주제인 인물의 직업,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,"공부하는 곳인 명륜당이 앞에, 사당인 대성전이 뒤에 있는 전학후묘의 형태로 향교의 ...",뒤,26,27,TERM,앞,13,14,TERM,P461,반대 개념,WikibaseItem,이 항목과 반대 관계의 항목,"[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ..."
3,"《2000 투데이》의 방송을 위해 영국의 BBC는 당시 총 5,000여명의 인력이 ...",텔레비전 센터,79,86,ARTIFACT,BBC,23,26,ORGANIZATION,P127,소유자,WikibaseItem,항목 주제를 소유하고 있는 사람/단체,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,캐피틀 레코드가 미국에 비틀즈 음반을 판매하는 것을 거부하고 있던 것이 거슬렸던 브...,미국,9,11,COUNTRY,미국,67,69,COUNTRY,P17,다음 나라의 것임,WikibaseItem,항목 주제는 다음 나라(국가)의 것을 다루고 있음,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


## 각 문장을 BERT를 통과할 수 있는 형태로 변형

In [99]:
def entity_markers_added(sent: str, subj_range: list, obj_range: list) -> str:
    """ Wrap the subject and object spans of a sentence in entity-marker tokens.

    The subject span ``sent[subj_range[0]:subj_range[1]]`` is wrapped in ``[E1] ... [/E1]`` and
    the object span ``sent[obj_range[0]:obj_range[1]]`` in ``[E2] ... [/E2]``. Every marker is
    padded with a space on both sides, and the result is stripped of leading/trailing whitespace.

    At each character position, closing markers are emitted before opening markers. This keeps
    the tags ordered when one span ends exactly where the other begins. Spans ending at
    ``len(sent)`` are always closed, including when both spans end there.

    Args:
        sent: the raw sentence.
        subj_range: ``[start, end)`` character offsets of the subject (end exclusive).
        obj_range: ``[start, end)`` character offsets of the object (end exclusive).

    Example:
        sent = '모토로라 레이저 M는 모토로라 모빌리티에서 제조/판매하는 안드로이드 스마트폰이다.'
        subj_range = [0, 10]   # sent[subj_range[0]: subj_range[1]] => '모토로라 레이저 M'
        obj_range = [12, 21]   # sent[obj_range[0]: obj_range[1]] => '모토로라 모빌리티'

    Return:
        '[E1] 모토로라 레이저 M [/E1] 는  [E2] 모토로라 모빌리티 [/E2] 에서 제조/판매하는 안드로이드 스마트폰이다.'
    """
    subj_start, subj_end = subj_range[0], subj_range[1]
    obj_start, obj_end = obj_range[0], obj_range[1]

    pieces = []
    # Visit every boundary position including len(sent), so spans that end at the end of the
    # sentence are closed inside the loop (both of them, if both end there).
    for i in range(len(sent) + 1):
        # Close before open: a span ending at i is closed before a span starting at i opens.
        if i == subj_end:
            pieces.append(' [/E1] ')
        if i == obj_end:
            pieces.append(' [/E2] ')
        if i == subj_start:
            pieces.append(' [E1] ')
        if i == obj_start:
            pieces.append(' [E2] ')
        if i < len(sent):
            pieces.append(sent[i])

    return ''.join(pieces).strip()
    

In [135]:
# Base BERT tokenizer for the configured pretrained model.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=args.bert_model)

In [136]:
# Register the four entity-marker tokens; the cell displays how many were newly added.
# NOTE: a model used with this tokenizer must resize its input embeddings to len(tokenizer).
special_tokens_dict = dict(additional_special_tokens=['[E1]', '[/E1]', '[E2]', '[/E2]'])
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
num_added_toks

4

In [137]:
# Inspect the tokenizer configuration, including the added entity-marker special tokens.
tokenizer

PreTrainedTokenizer(name_or_path='snunlp/KR-Medium', vocab_size=20000, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]', 'additional_special_tokens': ['[E1]', '[/E1]', '[E2]', '[/E2]']})

In [138]:
# Vocabulary size after adding the entity markers.
len(tokenizer)

20004

In [111]:
# Example input: mark the entities of the first training sentence, then encode it.
# (converted_sent is built here so the cell also runs on a fresh kernel.)
first_row = train_df.iloc[0]
converted_sent = entity_markers_added(
    first_row['sentence'],
    [first_row['subj_start_pos'], first_row['subj_end_pos']],
    [first_row['obj_start_pos'], first_row['obj_end_pos']],
)

text_info = tokenizer.encode_plus(
    converted_sent,
    add_special_tokens=True,
    max_length=args.max_token_len,
    return_token_type_ids=False,
    padding="max_length",
    truncation=True,
    return_attention_mask=True,
    return_tensors="pt"
)

In [112]:
# Token ids of the encoded example, reshaped to 1-D for display.
input_ids = text_info['input_ids'].view(-1)
input_ids

tensor([    2, 20000,  2871,  5457,  5016,  5095, 14144,  5316,    49, 20001,
         2367, 20002,  2871,  5457,  5016,  5095,  2871,  5821, 17420, 20003,
         9235, 11532,    19,  9779,  8453,  3516, 13733, 12026, 12823,  8459,
           18,     3,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0, 

In [113]:
# Attention mask of the encoded example (1 = real token, 0 = padding), reshaped to 1-D.
mask = text_info['attention_mask'].view(-1)
mask

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [118]:
# Round-trip the marked sentence through the tokenizer to check the markers stay intact.
tokenizer.decode(token_ids=tokenizer(converted_sent)['input_ids'])

'[CLS] [E1] 모토로라 레이저 M [/E1] 는 [E2] 모토로라 모빌리티 [/E2] 에서 제조 / 판매하는 안드로이드 스마트폰이다. [SEP]'

## Dataset 만들기

In [149]:
class KREDataset(Dataset):
    """ Dataset for the Korean Relation Extraction data.

    Each item corresponds to one row of ``data``. The sentence is wrapped with entity markers
    by ``entity_markers_added`` and tokenized with a BERT tokenizer extended by the
    ``[E1] [/E1] [E2] [/E2]`` special tokens.

    Args:
        data: DataFrame with the columns ``sentence``, ``subj_start_pos``, ``subj_end_pos``,
            ``obj_start_pos``, ``obj_end_pos`` and ``label_onehot``.
        args: config providing ``bert_model`` and ``max_token_len``.
    """
    def __init__(self, data: pd.DataFrame, args):
        super().__init__()

        self.args = args
        self.data = data

        self.tokenizer = BertTokenizer.from_pretrained(args.bert_model)
        # Entity-marker tokens; the model must resize its embeddings to len(tokenizer).
        special_tokens_dict = {'additional_special_tokens': ['[E1]', '[/E1]', '[E2]', '[/E2]']}
        self.tokenizer.add_special_tokens(special_tokens_dict)

        self.max_token_len = args.max_token_len

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _parse_labels(raw_labels):
        """ Return the ``label_onehot`` cell as a sequence of numbers.

        When read from CSV the cell is the string form of a Python list (e.g. ``"[0, 1, 0]"``).
        It is parsed with ``ast.literal_eval``, which only accepts literals, so a malformed or
        malicious cell cannot execute code. Non-string values are returned unchanged.
        """
        if isinstance(raw_labels, str):
            return ast.literal_eval(raw_labels)
        return raw_labels

    def __getitem__(self, idx: int):
        """ Return a dict with ``sentence`` (marked text), ``input_ids``, ``attention_mask``
        (both 1-D, length ``max_token_len``) and float ``labels``.
        """
        data_row = self.data.iloc[idx]

        # Character ranges of the subject / object inside the raw sentence.
        subj_range = [data_row['subj_start_pos'], data_row['subj_end_pos']]
        obj_range = [data_row['obj_start_pos'], data_row['obj_end_pos']]

        # Input sentence with entity markers inserted around the subject and object.
        converted_sent = entity_markers_added(data_row.sentence, subj_range, obj_range)

        labels = torch.tensor(self._parse_labels(data_row.label_onehot), dtype=torch.float32)

        encoding = self.tokenizer.encode_plus(
            converted_sent,
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt"
        )

        # encode_plus with return_tensors="pt" adds a batch dimension; drop it.
        input_ids = encoding['input_ids'].flatten()
        mask = encoding['attention_mask'].flatten()

        return dict(sentence=converted_sent,
                    input_ids=input_ids,
                    attention_mask=mask,
                    labels=labels)

In [150]:
# Wrap the training split in the dataset class for inspection.
kredataset = KREDataset(data=train_df, args=args)

In [152]:
# The dataset's own tokenizer carries the same entity-marker special tokens.
kredataset.tokenizer

PreTrainedTokenizer(name_or_path='snunlp/KR-Medium', vocab_size=20000, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]', 'additional_special_tokens': ['[E1]', '[/E1]', '[E2]', '[/E2]']})

## Model

In [325]:
class KREModel(pl.LightningModule):
    """ Multi-label relation classifier for the Korean Relation Extraction dataset.

    The marked sentence is encoded with BERT. The final hidden states at the entity-marker
    tokens are concatenated and passed to a linear layer followed by a per-class sigmoid.

    Modes (``args.mode``):
        "ALLCC":   concat the states at [E1], [/E1], [E2], [/E2] -> hidden_size * 4 features
        "ENTMARK": concat the states at [E1], [E2]               -> hidden_size * 2 features

    Args:
        args: config providing ``bert_model``, ``mode`` and ``n_class``.
        n_training_steps: total optimizer steps for the linear LR schedule.
        n_warmup_steps: warm-up steps for the linear LR schedule.

    Raises:
        ValueError: if ``args.mode`` is not one of the modes above.
    """
    # Markers whose hidden states are concatenated, per mode. The order fixes the layout of
    # the classifier input, so it must stay stable for existing checkpoints.
    MODE_MARKERS = {
        "ALLCC": ['[E1]', '[/E1]', '[E2]', '[/E2]'],
        "ENTMARK": ['[E1]', '[E2]'],
    }

    def __init__(self, args, n_training_steps=None, n_warmup_steps=None):
        super().__init__()

        self.args = args
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps

        if args.mode not in self.MODE_MARKERS:
            raise ValueError(
                f"Unsupported mode {args.mode!r}; expected one of {list(self.MODE_MARKERS)}"
            )

        self.bert = BertModel.from_pretrained(args.bert_model, return_dict=True)

        self.tokenizer = BertTokenizer.from_pretrained(args.bert_model)
        # entity marker tokens (same set the dataset tokenizer adds)
        special_tokens_dict = {'additional_special_tokens': ['[E1]', '[/E1]', '[E2]', '[/E2]']}
        self.tokenizer.add_special_tokens(special_tokens_dict)

        self.bert.resize_token_embeddings(len(self.tokenizer))

        # Resolve the marker ids from the tokenizer instead of hard-coding 20000..20003, which
        # only holds for a base vocabulary of exactly 20000 tokens.
        self.marker_ids = self.tokenizer.convert_tokens_to_ids(self.MODE_MARKERS[args.mode])
        self.scale = len(self.marker_ids)

        self.classifier = nn.Linear(self.bert.config.hidden_size * self.scale, args.n_class)

        # Applied to logits: equals BCELoss(sigmoid(x)) mathematically but is numerically stable.
        self.criterion = nn.BCEWithLogitsLoss()

    @staticmethod
    def _marker_positions(input_ids, token_id):
        """ Per row of ``input_ids`` ([batch, seq_len]), the index of the first ``token_id``.

        A row without the marker (e.g. it was cut off by truncation) falls back to position 0,
        the [CLS] token. This way every row still gets a feature vector, and rows are never
        paired with another row's markers.
        """
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device).expand_as(input_ids)
        not_found = torch.full_like(input_ids, seq_len)
        first = torch.where(input_ids == token_id, positions, not_found).min(dim=1).values
        return first.masked_fill(first == seq_len, 0)

    def forward(self, input_ids, attention_mask, labels=None):
        """ Return ``(loss, probabilities)``; ``loss`` is 0 when ``labels`` is None. """
        bert_outputs = self.bert(input_ids, attention_mask=attention_mask)
        last_hidden_state = bert_outputs.last_hidden_state  # [batch, seq_len, hidden]

        batch_idx = torch.arange(input_ids.size(0), device=input_ids.device)
        # One [batch, hidden] slice per marker, concatenated in MODE_MARKERS order
        # -> concat_state: [batch, hidden * scale]
        concat_state = torch.cat(
            [last_hidden_state[batch_idx, self._marker_positions(input_ids, marker_id)]
             for marker_id in self.marker_ids],
            dim=-1,
        )

        logits = self.classifier(concat_state)
        output = torch.sigmoid(logits)

        loss = 0
        if labels is not None:
            loss = self.criterion(logits, labels)
        return loss, output

    def _shared_step(self, batch):
        """ Run the model on a batch dict; return ``(loss, probabilities, labels)``. """
        labels = batch["labels"]
        loss, outputs = self(batch["input_ids"], batch["attention_mask"], labels)
        return loss, outputs, labels

    def training_step(self, batch, batch_idx):
        loss, outputs, labels = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": outputs, "labels": labels}

    def validation_step(self, batch, batch_idx):
        loss, _, _ = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, _, _ = self._shared_step(batch)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def training_epoch_end(self, outputs):
        """ Log a per-class training ROC-AUC to the experiment logger. """
        labels = torch.cat([out["labels"].detach().cpu() for out in outputs]).int()
        predictions = torch.cat([out["predictions"].detach().cpu() for out in outputs])

        for i in range(self.args.n_class):
            class_labels = labels[:, i]
            # ROC-AUC is undefined when only one label value occurs for this class.
            if class_labels.min() == class_labels.max():
                continue
            class_roc_auc = auroc(predictions[:, i], class_labels)
            self.logger.experiment.add_scalar(f"{str(i)}_roc_auc/Train", class_roc_auc, self.current_epoch)

    def configure_optimizers(self):
        """ AdamW with a per-step linear warm-up/decay schedule when step counts are known. """
        optimizer = AdamW(self.parameters(), lr=2e-5)
        if self.n_training_steps is None or self.n_warmup_steps is None:
            # No schedule information (e.g. a model built only for inference): plain AdamW.
            return optimizer
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=self.n_warmup_steps,
                                                    num_training_steps=self.n_training_steps)

        return dict(optimizer=optimizer, lr_scheduler=dict(scheduler=scheduler, interval='step'))

NameError: name 'pl' is not defined

## Train Module

In [330]:
class KREDataModule(pl.LightningDataModule):
    """ LightningDataModule serving the KRE train/val/test splits.

    The CSVs at ``args.train_data``, ``args.val_data`` and ``args.test_data`` are read in
    ``__init__``. ``setup`` wraps them in ``KREDataset``s built with the same ``args``.
    """
    def __init__(self, args):
        super().__init__()

        self.args = args
        self.batch_size = args.batch_size
        self.max_token_len = args.max_token_len

        self.train_df = pd.read_csv(args.train_data)
        self.val_df = pd.read_csv(args.val_data)
        self.test_df = pd.read_csv(args.test_data)

    def setup(self, stage=None):
        # Use the config given to this module, not the notebook-global `args`.
        self.train_dataset = KREDataset(self.train_df, self.args)
        self.val_dataset = KREDataset(self.val_df, self.args)
        self.test_dataset = KREDataset(self.test_df, self.args)

    def _loader(self, dataset, shuffle=False):
        """ DataLoader with the configured batch size and worker count. """
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=self.args.num_workers)

    def train_dataloader(self):
        # Only the training split is shuffled.
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_dataset)

    def test_dataloader(self):
        return self._loader(self.test_dataset)

NameError: name 'pl' is not defined

In [332]:
# Relation ids, one per line; the line index is the class index used by idx2class and the
# classification report. Blank lines (e.g. a trailing empty line) are skipped so the list
# length matches the number of classes.
with open(args.relation_list, 'r', encoding='utf-8') as f:
    relation_list = [line.strip() for line in f if line.strip()]

In [None]:
# Preview the first few relation ids.
relation_list[:5]

In [334]:
def idx2class(idx_list, relations=None):
    """ Map relation-class indices to relation ids, e.g. ``[0, 1, 4] -> ['P17', 'P131', 'P47']``.

    Args:
        idx_list: iterable of integer class indices (e.g. ``np.where(row == 1)[0]``).
        relations: index -> relation-id lookup. Defaults to the notebook-level
            ``relation_list``.

    Returns:
        The list of relation ids, or ``np.nan`` when ``idx_list`` is empty, so that an empty
        label set shows up as a missing value in a DataFrame.
    """
    if relations is None:
        relations = relation_list
    label_out = [relations[idx] for idx in idx_list]
    return label_out if label_out else np.nan

In [336]:
# Sanity check: map class indices back to relation ids.
idx2class([0,1,4])

['P17', 'P131', 'P47']

In [None]:
def evaluate(trained_model, test_dataset, log_filepath, result_csv_path='../log/results.csv'):
    """ Evaluate a trained model on ``test_dataset`` and write the results to disk.

    1. Runs the model on every item (batch size 1, without gradients) and collects the
       probabilities returned by its forward pass.
    2. Computes ``accuracy`` for thresholds 0.1 .. 0.9 and keeps the best one. Ties go to the
       lowest threshold.
    3. Binarizes the probabilities with ``prob >= best_threshold``. The same rule is used for
       the classification report and for the exact-match accuracy. Prints both, plus the
       weighted-average F1.
    4. Writes gold and predicted relation ids per sentence to ``result_csv_path``. Writes a text
       log (config, per-threshold accuracy, report, summary) to ``log_filepath``.

    Args:
        trained_model: model whose forward returns ``(loss, probabilities)``.
        test_dataset: dataset yielding dicts with ``input_ids``, ``attention_mask``,
            ``labels`` and ``sentence``.
        log_filepath: path of the text log.
        result_csv_path: path of the per-sentence results CSV.

    Uses the notebook-level ``relation_list`` (report class names, also via ``idx2class``)
    and ``args`` (written to the log).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trained_model = trained_model.to(device)

    predictions = []
    labels = []
    sentence_list = []
    label_list = []

    with torch.no_grad():
        for item in tqdm(test_dataset):
            _, prediction = trained_model(
                item["input_ids"].unsqueeze(dim=0).to(device),
                item["attention_mask"].unsqueeze(dim=0).to(device)
            )

            predictions.append(prediction.flatten().cpu())
            labels.append(item["labels"].int())
            # Reuse `item`: indexing test_dataset[i] again would re-tokenize the sentence.
            sentence_list.append(item['sentence'])
            label_list.append(idx2class(np.where(item['labels'].numpy() == 1)[0]))

    predictions = torch.stack(predictions)
    labels = torch.stack(labels)

    # [0.1, 0.2, .., 0.9]
    threshold_list = np.arange(0, 1, 0.1)[1:]

    print("********** Accuracy per threshold **********")
    acc_per_threshold = dict()
    for threshold in threshold_list:
        threshold_acc = accuracy(predictions, labels, threshold=threshold)
        acc_per_threshold[threshold] = threshold_acc
        print(f"Threshold: {threshold:.1f}, Accuracy: {threshold_acc:.6f}")

    # max() returns the first maximal key, i.e. the lowest threshold among ties.
    max_acc_threshold = max(acc_per_threshold, key=acc_per_threshold.get)
    print(f"********** Max Threshold: {max_acc_threshold} **********")

    print("********** Classification Report **********")
    y_true = labels.numpy()
    # One binarization rule for the report and the exact-match accuracy below.
    y_pred = (predictions.numpy() >= max_acc_threshold).astype(int)

    report_kwargs = dict(target_names=relation_list, zero_division=0)
    cls_report = classification_report(y_true, y_pred, output_dict=True, **report_kwargs)
    print(classification_report(y_true, y_pred, **report_kwargs))

    # Exact-match accuracy: a sentence counts as correct only if every relation label matches.
    acc = float((y_pred == y_true).all(axis=1).mean())
    weighted_f1 = cls_report['weighted avg']['f1-score']
    print("***********************************************************************")
    print(f"단순 정확도(Accuracy): {acc*100:.4f}")
    print(f"조항별 정확도를 고려한 전체 정확도(Weighted Avg의 F1 Accuracy): {weighted_f1*100:.4f}")
    print("***********************************************************************")

    # Save per-sentence gold / predicted relation ids as CSV.
    preds_list = [idx2class(np.where(row == 1)[0]) for row in y_pred]
    result_df = pd.DataFrame({
        'sentence': sentence_list,
        'relation_id': label_list,
        'predicted_relation_id': preds_list,
    })
    result_df.to_csv(result_csv_path, index=False)

    # utf-8 explicitly: the log contains Korean text.
    with open(log_filepath, 'w', encoding='utf-8') as f:
        # write args
        f.write('********** Args **********\n')
        for k, v in vars(args).items():
            f.write(f"{k}: {v}\n")

        f.write('\n********** Accuracy per threshold **********\n')
        for k, v in acc_per_threshold.items():
            # Format the number directly rather than slicing the tensor's repr string.
            f.write(f"{k:.1f}\t{float(v):.4f}\n")
        f.write('\n')
        f.write('********** Max Threshold: ' + str(max_acc_threshold) + ' **********\n')

        f.write('\n********** Classification Report **********\n')
        f.write('라벨이름\tprecision\trecall  \tf1-score\tsupport\n')
        for k, v in cls_report.items():
            if not isinstance(v, dict):
                # Scalar report entries (if any) have no per-metric breakdown.
                continue
            f.write(str(k) + '\t')
            f.write(f"{v['precision']:.6f}\t{v['recall']:.6f}\t{v['f1-score']:.6f}\t{v['support']}" + '\n')
        f.write('\n')
        f.write("***********************************************************************\n")
        f.write(f"단순 정확도(Accuracy): {acc*100:.4f}\n")
        f.write(f"조항별 정확도를 고려한 전체 정확도(Weighted Avg의 F1 Accuracy): {weighted_f1*100:.4f}\n")
        f.write("***********************************************************************\n")


In [337]:
# Create the output directories (no-op when they already exist). '../log' also receives
# results.csv from evaluate(); the configured log-file and checkpoint directories are included.
for directory in ('../log', '../ckpt', os.path.dirname(args.log_file), args.save_dir):
    if directory:  # dirname() is '' for a bare file name
        os.makedirs(directory, exist_ok=True)

pl.seed_everything(args.seed)

data_module = KREDataModule(args)

NameError: name 'pl' is not defined

In [338]:
# train
if args.train:
    steps_per_epoch = len(data_module.train_df) // args.batch_size
    total_training_steps = steps_per_epoch * args.epochs
    warmup_steps = total_training_steps // 5
    
    model = KREModel(args, n_training_steps=total_training_steps, n_warmup_steps=warmup_steps)
    
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.save_dir,
        filename="best-checkpoint",
        save_top_k=1,
        verbose=True,
        monitor="val_loss",
        mode="min"
    )
    
    logger = TensorBoardLogger("lightning_logs", name="KoreanRE")
    
    early_stopping_callback = EarlyStopping(monitor="val_loss", patience=2)
    
    trainer = pl.Trainer(
        logger=logger,
        checkpoint_callback=checkpoint_callback,
        callbacks=[early_stopping_callback],
        max_epochs=args.epochs,
        gpus=1,
        progress_bar_refresh_rate=30
    )
    
    trainer.fit(model, data_module)
    
    trainer.test()
    
    model.bert.save_pretrained(args.save_dir + '_bert')
    
# evaluate
else:
    test_dataset = KREDataset(data_module.test_df, args)
    
    trained_model = KREModel.load_from_checkpoint(os.path.join(args.save_dir, "best-checkpoint.ckpt"), args=args)
    
    trained_model.eval()
    trained_model.freeze()
    
    evaluate(trained_model, test_dataset, args.log_file)

NameError: name 'data_module' is not defined

In [None]:
args.train = False

# train
if args.train:
    # Ceiling division: the DataLoader keeps the last partial batch (drop_last defaults to
    # False). The LR schedule must count it, or the learning rate reaches 0 before training ends.
    steps_per_epoch = (len(data_module.train_df) + args.batch_size - 1) // args.batch_size
    total_training_steps = steps_per_epoch * args.epochs
    warmup_steps = total_training_steps // 5

    model = KREModel(args, n_training_steps=total_training_steps, n_warmup_steps=warmup_steps)

    checkpoint_callback = ModelCheckpoint(
        dirpath=args.save_dir,
        filename="best-checkpoint",
        save_top_k=1,
        verbose=True,
        monitor="val_loss",
        mode="min"
    )

    logger = TensorBoardLogger("lightning_logs", name="KoreanRE")

    early_stopping_callback = EarlyStopping(monitor="val_loss", patience=2)

    trainer = pl.Trainer(
        logger=logger,
        # ModelCheckpoint goes in `callbacks`; passing an instance to the boolean
        # `checkpoint_callback` argument is deprecated in PyTorch Lightning 1.x.
        callbacks=[checkpoint_callback, early_stopping_callback],
        max_epochs=args.epochs,
        gpus=1 if torch.cuda.is_available() else 0,
        progress_bar_refresh_rate=30
    )

    trainer.fit(model, data_module)

    trainer.test()

    model.bert.save_pretrained(args.save_dir + '_bert')

# evaluate
else:
    test_dataset = KREDataset(data_module.test_df, args)

    trained_model = KREModel.load_from_checkpoint(os.path.join(args.save_dir, "best-checkpoint.ckpt"), args=args)

    trained_model.eval()
    trained_model.freeze()

    evaluate(trained_model, test_dataset, args.log_file)