In [None]:
import numpy as np
import pandas as pd

import random
import pickle

import os
import glob

import time
import datetime
import json

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

from transformers import BertTokenizer
from transformers import BertForSequenceClassification, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup

import gluonnlp as nlp
from gluonnlp.data import SentencepieceTokenizer
from mxnet.gluon.data import SimpleDataset

from kobert.pytorch_kobert import get_pytorch_kobert_model
from kobert.utils import get_tokenizer

from tqdm import tqdm_notebook


from kobert_classifier import BERTDataset, BERTDataset_Ops
from kobert_classifier import BERTClassifier, BERTClassifier_HL, BERTClassifier_HLHL  # , BERTClassifier_Softmax

pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", 50)

%load_ext autoreload
%autoreload 2

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Fix random seeds (Python, NumPy, PyTorch CPU, and all CUDA devices) for reproducibility.
seed_val = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed_val)
del _seed_fn  # keep the notebook namespace clean

In [None]:
root_dir = '..'

# Input data folder.
input_dir = os.path.join(root_dir, 'data')
# Output folder, nested inside the input data folder.
output_dir = os.path.join(input_dir, 'output')


In [None]:
# Pretrained KoBERT model together with its vocabulary.
bertmodel, vocab = get_pytorch_kobert_model()

# SentencePiece tokenizer handle (presumably a model-file path; verify in kobert.utils)
# wrapped with the BERT vocab; lower=False keeps the original casing.
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

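### 데이터 읽기 (Load data)

**TODO:** no visible cell defines `train_dataset` (or a held-out test split), which the dataset cell below requires. Add the loading step here so the notebook survives Restart & Run All.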

In [None]:
# --- input / batching ---
max_len = 64            # max token length passed to BERTDataset
batch_size = 64

# --- optimisation ---
learning_rate = 5e-5
num_epochs = 5
warmup_ratio = 0.1      # fraction of total training steps used for LR warmup
max_grad_norm = 1       # gradient-norm clipping threshold

# --- logging ---
log_interval = 20       # print progress every `log_interval` batches

In [None]:
# Wrap the raw splits for KoBERT. The positional arguments after the data
# (0, 1, 2, tok, max_len, True, False) are interpreted by
# kobert_classifier.BERTDataset (not visible here).
data_train = BERTDataset(train_dataset, 0, 1, 2, tok, max_len, True, False)
# Build the evaluation set from the held-out split rather than the training
# split; otherwise test accuracy is measured on data the model trained on.
# NOTE(review): `test_dataset` must be created by the data-loading step,
# just as `train_dataset` must.
data_test = BERTDataset(test_dataset, 0, 1, 2, tok, max_len, True, False)

In [None]:
# Training batches are drawn in a random order each epoch.
train_sampler = RandomSampler(data_train)
train_dataloader = DataLoader(data_train, sampler=train_sampler, batch_size=batch_size)

# Evaluation gains nothing from shuffling; a fixed order makes runs and
# per-example predictions reproducible and directly comparable.
test_sampler = SequentialSampler(data_test)
test_dataloader = DataLoader(data_test, sampler=test_sampler, batch_size=batch_size)

### 모델

In [None]:
# Classifier head on top of KoBERT, moved to the selected device.
# Other heads tried in this project: BERTClassifier, BERTClassifier_Softmax,
# BERTClassifier_HLHL, BERTClassifier_HLHL_Softmax. Whichever class is used
# here must be imported from kobert_classifier in the first cell, or a fresh
# kernel raises NameError.
model = BERTClassifier_HL(bertmodel, dr_rate=0.5).to(device)

In [None]:
# Parameters whose names contain any of these substrings are exempt from weight decay.
no_decay = ['bias', 'LayerNorm.weight']

# Two parameter groups: decayed (0.01) and non-decayed (0.0).
optimizer_params = [
    {
        'params': [param for name, param in model.named_parameters()
                   if all(marker not in name for marker in no_decay)],
        'weight_decay': 0.01,
    },
    {
        'params': [param for name, param in model.named_parameters()
                   if any(marker in name for marker in no_decay)],
        'weight_decay': 0.0,
    },
]

In [None]:
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps=1e-6 keeps
# the transformers default epsilon (torch's own default is 1e-8). The
# per-group weight_decay values come from optimizer_params.
optimizer = torch.optim.AdamW(optimizer_params, lr=learning_rate, eps=1e-6)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Total optimizer steps: batches per epoch times epochs (the scheduler is stepped once per batch).
t_total = len(train_dataloader) * num_epochs
# Number of linear-warmup steps, truncated to an integer.
warmup_step = int(t_total * warmup_ratio)

In [None]:
# Learning rate warms up linearly for `warmup_step` steps, then decays linearly over `t_total` steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_step,
    num_training_steps=t_total,
)

In [None]:
def calc_accuracy(X, Y):
    """Return the fraction of rows of X whose arg-max class equals the label in Y.

    Args:
        X: 2-D tensor of per-class scores; the predicted class is taken along dim 1.
        Y: 1-D tensor of integer class labels, one per row of X.

    Returns:
        Batch accuracy as a Python float in [0, 1].
    """
    # torch.max along dim 1 returns (values, indices); only the indices are needed.
    _, max_indices = torch.max(X, 1)
    num_correct = (max_indices == Y).sum().item()
    return num_correct / max_indices.size(0)

### 모델 학습

In [None]:
for e in range(num_epochs):
    # Put the model in training mode every epoch so dropout (dr_rate in the
    # model cell) is active even if an evaluation cell called model.eval().
    model.train()
    train_acc = 0.0
    num_batches = 0  # tracked explicitly; don't rely on a leaked `batch_id`

    for batch_id, (token_ids, valid_length, segment_ids, label, idx) in enumerate(train_dataloader):
        optimizer.zero_grad()

        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        # valid_length is passed to the model unchanged (not moved to `device`).
        label = label.long().to(device)

        out = model(token_ids, valid_length, segment_ids)

        loss = loss_fn(out, label)
        loss.backward()

        # Clip the gradient norm before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        scheduler.step()  # the LR schedule is advanced once per batch

        train_acc += calc_accuracy(out, label)
        num_batches = batch_id + 1

        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    # max(..., 1) avoids ZeroDivisionError if the dataloader yields no batches.
    print("epoch {} train acc {}".format(e+1, train_acc / max(num_batches, 1)))