In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
import pandas as pd
from sklearn.model_selection import train_test_split
from kobert_tokenizer import KoBERTTokenizer
from transformers import BertModel
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup


# GPU 가 있는 경우 지정
device = torch.device("cuda:1")

# GPU 가 없는 경우 지정
# device = torch.device("cpu")

In [2]:
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')
bertmodel = BertModel.from_pretrained('skt/kobert-base-v1', return_dict=False)
vocab = nlp.vocab.BERTVocab.from_sentencepiece(tokenizer.vocab_file, padding_token='[PAD]')
tok = tokenizer.tokenize

# Setting parameters
max_len = 64
batch_size = 256
warmup_ratio = 0.1
num_epochs = 30
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'XLNetTokenizer'. 
The class this function is called from is 'KoBERTTokenizer'.


In [3]:
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len,
                 pad, pair):
   
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len,vocab=vocab, pad=pad, pair=pair)
        
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))
         
    def __len__(self):
        return (len(self.labels))


In [4]:
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=58,
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

In [5]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [6]:
def predict(sentence):
    dataset = [[sentence, '0']]
    test = BERTDataset(dataset, 0, 1, tok, vocab, max_len, True, False)
    test_dataloader = torch.utils.data.DataLoader(test, batch_size=batch_size, num_workers=2)
    model.eval()
    answer = 0
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataloader):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        for logits in out:
            logits = logits.detach().cpu().numpy()
            answer = np.argmax(logits)
    return answer

In [7]:
train_set = pd.read_csv('./custom_chatbotdataset(Training).csv')
validation_set = pd.read_csv('./custom_chatbotdataset(Validation).csv')

In [8]:
train_set.head()

Unnamed: 0,label,Q,A
0,9,일은 왜 해도 해도 끝이 없을까? 화가 난다.,많이 힘드시겠어요. 주위에 의논할 상대가 있나요?
1,9,이번 달에 또 급여가 깎였어! 물가는 오르는데 월급만 자꾸 깎이니까 너무 화가 나.,급여가 줄어 속상하시겠어요. 월급이 줄어든 것을 어떻게 보완하실 건가요?
2,9,회사에 신입이 들어왔는데 말투가 거슬려. 그런 애를 매일 봐야 한다고 생각하니까 스...,회사 동료 때문에 스트레스를 많이 받는 것 같아요. 문제 해결을 위해 어떤 노력을 ...
3,9,직장에서 막내라는 이유로 나에게만 온갖 심부름을 시켜. 일도 많은 데 정말 분하고 ...,관련 없는 심부름을 모두 하게 되어서 노여우시군요. 어떤 것이 상황을 나아질 수 있...
4,9,얼마 전 입사한 신입사원이 나를 무시하는 것 같아서 너무 화가 나.,무시하는 것 같은 태도에 화가 나셨군요. 상대방의 어떤 행동이 그런 감정을 유발하는...


In [9]:
train_set_data = [[i, str(j)] for i, j in zip(train_set['Q'], train_set['label'])]
validation_set_data = [[i, str(j)] for i, j in zip(validation_set['Q'], validation_set['label'])]

train_set_data = BERTDataset(train_set_data, 0, 1, tok, vocab, max_len, True, False)
validation_set_data = BERTDataset(validation_set_data, 0, 1, tok, vocab, max_len, True, False)
train_dataloader = torch.utils.data.DataLoader(train_set_data, batch_size=batch_size, num_workers=2)
validation_dataloader = torch.utils.data.DataLoader(validation_set_data, batch_size=batch_size, num_workers=2)

In [10]:
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss().to(device)
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(validation_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 1 batch id 1 loss 4.144363880157471 train acc 0.01953125
epoch 1 batch id 201 loss 4.0161824226379395 train acc 0.020522388059701493
epoch 1 batch id 401 loss 3.9211533069610596 train acc 0.027713918329177058
epoch 1 train acc 0.03182464973730298


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(validation_dataloader)):


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 1 test acc 0.0448393485915493


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 2 batch id 1 loss 4.11107873916626 train acc 0.05078125
epoch 2 batch id 201 loss 3.437403917312622 train acc 0.11421408582089553
epoch 2 batch id 401 loss 3.3541839122772217 train acc 0.11778171758104738
epoch 2 train acc 0.10631744488513445


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 2 test acc 0.09628080985915492


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 3 batch id 1 loss 3.9006946086883545 train acc 0.1328125
epoch 3 batch id 201 loss 2.7697913646698 train acc 0.24012748756218905
epoch 3 batch id 401 loss 2.801152229309082 train acc 0.2171329488778055
epoch 3 train acc 0.183953509709488


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 3 test acc 0.109375


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 4 batch id 1 loss 3.718740701675415 train acc 0.203125
epoch 4 batch id 201 loss 2.3249669075012207 train acc 0.30050917288557216
epoch 4 batch id 401 loss 2.547844886779785 train acc 0.26494311097256856
epoch 4 train acc 0.22301963145152984


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 4 test acc 0.12973151408450703


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 5 batch id 1 loss 3.445897102355957 train acc 0.20703125
epoch 5 batch id 201 loss 2.1477620601654053 train acc 0.33257540422885573
epoch 5 batch id 401 loss 2.376643180847168 train acc 0.2935142612219451
epoch 5 train acc 0.24947484740393533


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 5 test acc 0.13721390845070422


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 6 batch id 1 loss 3.3407254219055176 train acc 0.234375
epoch 6 batch id 201 loss 2.007582187652588 train acc 0.35962764303482586
epoch 6 batch id 401 loss 2.266035318374634 train acc 0.32320565773067333
epoch 6 train acc 0.27780652814978885


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 6 test acc 0.14660357981220656


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 7 batch id 1 loss 3.108306407928467 train acc 0.265625
epoch 7 batch id 201 loss 1.9784644842147827 train acc 0.38969216417910446
epoch 7 batch id 401 loss 2.0845694541931152 train acc 0.3533159289276808
epoch 7 train acc 0.3080593418409395


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 7 test acc 0.15056484741784038


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 8 batch id 1 loss 2.881065845489502 train acc 0.31640625
epoch 8 batch id 201 loss 1.768722653388977 train acc 0.42014536691542287
epoch 8 batch id 401 loss 1.9255218505859375 train acc 0.3856861751870324
epoch 8 train acc 0.3403093450602658


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 8 test acc 0.1688306924882629


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 9 batch id 1 loss 2.6725993156433105 train acc 0.328125
epoch 9 batch id 201 loss 1.6568074226379395 train acc 0.4536302860696517
epoch 9 batch id 401 loss 1.720716953277588 train acc 0.4198098503740648
epoch 9 train acc 0.37458752382301436


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 9 test acc 0.1778719190140845


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 10 batch id 1 loss 2.553565740585327 train acc 0.34375
epoch 10 batch id 201 loss 1.4604172706604004 train acc 0.4884561567164179
epoch 10 batch id 401 loss 1.6244605779647827 train acc 0.4540017144638404
epoch 10 train acc 0.4096725462295251


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 10 test acc 0.19069102112676056


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 11 batch id 1 loss 2.3480591773986816 train acc 0.40625
epoch 11 batch id 201 loss 1.3661165237426758 train acc 0.519414645522388
epoch 11 batch id 401 loss 1.475515604019165 train acc 0.48615765274314215
epoch 11 train acc 0.44369720304934585


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 11 test acc 0.20327171361502347


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 12 batch id 1 loss 2.1202218532562256 train acc 0.44921875
epoch 12 batch id 201 loss 1.254910945892334 train acc 0.5493625621890548
epoch 12 batch id 401 loss 1.3473814725875854 train acc 0.5174563591022444
epoch 12 train acc 0.4761995209642526


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 12 test acc 0.20947036384976528


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 13 batch id 1 loss 2.0054781436920166 train acc 0.4609375
epoch 13 batch id 201 loss 1.120219111442566 train acc 0.5814870957711443
epoch 13 batch id 401 loss 1.208771824836731 train acc 0.548131624064838
epoch 13 train acc 0.5074805069794994


  0%|          | 0/71 [00:00<?, ?it/s]

epoch 13 test acc 0.22901995305164322


  0%|          | 0/571 [00:00<?, ?it/s]

epoch 14 batch id 1 loss 1.7841709852218628 train acc 0.49609375
epoch 14 batch id 201 loss 0.9618239402770996 train acc 0.6092195273631841
epoch 14 batch id 401 loss 1.1783982515335083 train acc 0.577842503117207


In [None]:
torch.save({
    'model' : model.state_dict(),
    'optimizer': optimizer.state_dict()
}, './QtEmodel.pth')