# In [None]:  (notebook cell marker left over from the export)
import os, sys
# Number CUDA devices by PCI bus id and expose only device "0" to this process.
# NOTE(review): these assignments sit above `import torch`, presumably so they
# are in place before CUDA is initialised -- keep them ahead of that import.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm

from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup

from sklearn.metrics import f1_score
from torch.utils.tensorboard import SummaryWriter
from utils import CheckpointManager, SummaryManager

# Run on the CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Classifier used to sort sentences into 7 emotion classes.
class BERTClassifier(nn.Module):
    """BERT encoder followed by optional dropout and a linear output layer.

    Args:
        bert: Encoder, called as ``bert(input_ids=..., token_type_ids=...,
            attention_mask=...)``. Its result is unpacked as a pair and the
            second element is what gets classified.
        hidden_size: Input width of the linear layer.
        num_classes: Number of logits produced per example.
        dr_rate: Dropout probability. When falsy (``None`` or ``0``) no
            dropout layer is created or applied.
        params: Accepted for interface compatibility; not used.
    """

    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes = 7,
                 dr_rate = None,
                 params = None):
        super().__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size, num_classes)
        # The dropout attribute only exists when a non-zero rate was given.
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Return a float mask shaped like ``token_ids``.

        Row ``i`` holds ones in its first ``valid_length[i]`` positions and
        zeros everywhere else.
        """
        mask = torch.zeros_like(token_ids)
        for row, length in enumerate(valid_length):
            # Writing through the sliced view updates ``mask`` in place.
            mask[row][:length].fill_(1)
        return mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return the classification logits for a batch.

        ``segment_ids`` only supplies the shape/dtype/device of the token type
        ids: the encoder always receives all-zero token type ids.
        """
        mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooled = self.bert(
            input_ids=token_ids,
            token_type_ids=torch.zeros_like(segment_ids),
            attention_mask=mask.float().to(token_ids.device),
        )

        features = self.dropout(pooled) if self.dr_rate else pooled
        return self.classifier(features)

class BERTDataset(Dataset):
    """Map-style dataset pairing BERT-transformed sentences with integer labels.

    Args:
        dataset: Iterable of rows; it is iterated twice (sentences, then labels).
        sent_idx: Index of the sentence text inside a row.
        label_idx: Index of the label inside a row.
        bert_tokenizer: Tokenizer handed to ``BERTSentenceTransform``.
        max_len: Passed to the transform as ``max_seq_length``.
        pad: Passed to the transform as ``pad``.
        pair: Passed to the transform as ``pair``.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len, pad, pair):
        to_bert_input = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)
        # Transformed emotion sentences, one entry per row.
        self.sentences = [to_bert_input([row[sent_idx]]) for row in dataset]
        # Emotion label of each sentence, stored as int32.
        self.labels = [np.int32(row[label_idx]) for row in dataset]

    def __getitem__(self, i):
        """Return the transformed sentence fields with the label appended."""
        return (*self.sentences[i], self.labels[i])

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

def calc_accuracy(X, Y):
    """Return the fraction of rows in ``X`` whose arg-max along dim 1 equals ``Y``.

    ``X`` holds per-class scores, ``Y`` the target class indices. The result is
    a NumPy scalar.
    """
    _, predicted = X.max(dim=1)
    hits = (predicted == Y).sum().detach().cpu().numpy()
    return hits / predicted.size(0)

def calc_F1(X, Y):
    """Return the micro-averaged F1 of the arg-max predictions of ``X`` against ``Y``.

    Both tensors are moved to the CPU and converted to NumPy before being
    handed to ``f1_score``.
    """
    _, predicted = X.max(dim=1)
    targets = Y.cpu().numpy()
    return f1_score(targets, predicted.cpu().numpy(), average='micro')

# Training hyper-parameters.
max_len = 64            # max_seq_length handed to the BERT sentence transform
batch_size = 32         # batch size of both the train and the test DataLoader
warmup_ratio = 0.1      # fraction of all optimizer steps used for LR warm-up
num_epochs = 20         # number of passes over the training set
max_grad_norm = 1       # gradient-norm clipping threshold
log_interval = 200      # print running training metrics every N batches
learning_rate =  5e-5   # learning rate passed to AdamW

# Fetch the KoBERT model together with its vocabulary.
bertmodel, vocab = get_pytorch_kobert_model()

# Wrap the corpus in a gluonnlp TSVDataset so it can be used for training.
# all_korea_final.txt is the merge of the single-turn and the multi-turn
# dialogue data. Columns 1 and 2 are kept, the first line is discarded and the
# file is decoded as cp949.
dataset = nlp.data.TSVDataset(
    "all_korea_final.txt",
    field_indices=[1, 2],
    num_discard_samples=1,
    encoding='cp949',
)

# Split the rows into a training and a test set.
train_dataset = []
test_dataset = []

# Each row is assigned by a fair coin flip: a row lands in the test set only
# when the flip is 1 AND the test set still holds at most 18000 rows (so it
# ends up with at most 18001); every other row goes to the training set.
# NOTE(review): the original note calls this an 8:2 split -- that ratio only
# results for a matching dataset size, and the draw is unseeded. Confirm intent.
for sample in dataset:
    goes_to_test = np.random.randint(2) == 1 and len(test_dataset) <= 18000
    if goes_to_test:
        test_dataset.append(sample)
    else:
        train_dataset.append(sample)

# Count how many training rows carry each of the labels 0..6.
label_cnt = [int(row[1]) for row in train_dataset]
label_counts = [label_cnt.count(cls) for cls in range(7)]

tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

# Wrap both splits in the BERTDataset defined above so they yield BERT inputs:
# column 0 is the sentence, column 1 the label; padding on, sentence pairs off.
data_train = BERTDataset(train_dataset, sent_idx=0, label_idx=1, bert_tokenizer=tok,
                         max_len=max_len, pad=True, pair=False)
data_test = BERTDataset(test_dataset, sent_idx=0, label_idx=1, bert_tokenizer=tok,
                        max_len=max_len, pad=True, pair=False)

# Batch both splits with PyTorch DataLoaders (loaded in the main process, shuffled).
train_dataloader = DataLoader(data_train, batch_size=batch_size, shuffle=True, num_workers=0)
test_dataloader = DataLoader(data_test, batch_size=batch_size, shuffle=True, num_workers=0)

# Build the classifier on top of the pretrained encoder and move it to `device`.
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)

# Parameters whose name contains one of these fragments are excluded from
# weight decay; every other parameter is decayed with 0.01.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

# AdamW is the optimizer.
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

# Class-weighted cross entropy: class c is weighted by total / count(c), so
# rarer classes contribute more to the loss. A class with no training sample
# gets weight 0.0 instead of raising ZeroDivisionError; it never appears as a
# training target, so its weight has no effect on the training loss.
total_labels = sum(label_counts)
num_weights = [total_labels / n if n else 0.0 for n in label_counts]
# The weight tensor is placed on `device` (not hard-wired to CUDA) so the
# script also runs on CPU-only machines.
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(num_weights, dtype=torch.float, device=device))

# Total number of optimizer steps and how many of them are linear warm-up.
train_num = len(train_dataloader) * num_epochs
warmup_step = int(train_num * warmup_ratio)

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=train_num)

# Base directory for saved models; the TensorBoard writer logs to its "runs"
# sub-directory and both managers are constructed with it.
# NOTE(review): where CheckpointManager / SummaryManager actually write is
# defined in the project's utils module -- presumably under model_dir; confirm.
model_dir = "./saves"
tb_writer = SummaryWriter('{}/runs'.format(model_dir))
checkpoint_manager = CheckpointManager(model_dir)
summary_manager = SummaryManager(model_dir)

# Best evaluation F1 seen so far. It starts at a very large negative number so
# that the first epoch's F1 always compares as the best.
best_dev_f1 = -sys.maxsize 

# Main loop: one training pass and one evaluation pass per epoch, followed by
# writing the summary and a checkpoint.
# NOTE(review): tb_writer is opened above and closed at the end, but nothing in
# this loop writes to it.

# Batches per pass, used to average the per-batch metrics. max(..., 1) keeps
# the averages defined (0.0) when a loader is empty instead of failing on an
# unset loop variable.
num_train_batches = max(len(train_dataloader), 1)
num_test_batches = max(len(test_dataloader), 1)

for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    train_f1 = 0.0
    test_f1 = 0.0

    # ---- training pass ----
    model.train()

    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        label = label.long().to(device)
        # valid_length is passed on exactly as the loader produced it.
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        train_f1 += calc_F1(out, label)

        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {} train F1 {}".format(e+1, batch_id+1, loss.item(), train_acc / (batch_id+1) , train_f1 / (batch_id+1)))

    print("epoch {} train acc {} train F1 {}".format(e+1, train_acc / num_train_batches, train_f1 / num_train_batches))
    tr_summary = {'acc': train_acc / num_train_batches, 'f1': train_f1 / num_train_batches}

    # ---- evaluation pass ----
    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every test batch.
    with torch.no_grad():
        for token_ids, valid_length, segment_ids, label in tqdm(test_dataloader):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            test_acc += calc_accuracy(out, label)
            test_f1 += calc_F1(out, label)
    print("epoch {} test acc {} test_f1 {}".format(e+1, test_acc / num_test_batches, test_f1 / num_test_batches))
    eval_summary = {'acc': test_acc / num_test_batches, 'f1': test_f1 / num_test_batches}

    # ---- save model ----
    output_dir = "./output/checkpoints/epoch-{}".format(e + 1)
    # exist_ok makes a separate existence check unnecessary.
    os.makedirs(output_dir, exist_ok=True)
    print("model checkpoint: ", output_dir)

    # 'global_step' holds the 1-based epoch number.
    state = {'global_step': e + 1,
             'model_state_dict': model.state_dict(),
             'opt_state_dict': optimizer.state_dict()}

    summary = {'train': tr_summary, 'eval': eval_summary}
    summary_manager.update(summary)
    print("summary: ", summary)
    summary_manager.save('summary.json')

    # An epoch that ties the best F1 so far also counts as best.
    is_best = eval_summary['f1'] >= best_dev_f1

    if is_best:
        best_dev_f1 = eval_summary['f1']
        checkpoint_manager.save_checkpoint(state, 'best-epoch-{}-f1-{:.3f}.bin'.format(e + 1, best_dev_f1))
        print("model checkpoint has been saved: best-epoch-{}-f1-{:.3f}.bin".format(e + 1, best_dev_f1))

    else:
        # Non-best epochs are written into this epoch's own directory; the log
        # line names the file that is actually saved.
        checkpoint_name = 'model-epoch-{}-f1-{:.3f}.bin'.format(e + 1, eval_summary["f1"])
        torch.save(state, os.path.join(output_dir, checkpoint_name))
        print("model checkpoint has been saved: {}".format(checkpoint_name))

tb_writer.close()