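# KoBERT
Fine-tune KoBERT to classify Korean dialogue into six emotion categories (기쁨, 불안, 당황, 슬픔, 분노, 상처).

References:
https://velog.io/@dev-junku/KoBERT-%EB%AA%A8%EB%8D%B8%EC%97%90-%EB%8C%80%ED%95%B4
https://hoit1302.tistory.com/159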

In [None]:
# Colab cell: each "!" line is run as a shell command, not as Python.
!pip install mxnet # In Colab, shell commands must be prefixed with "!".
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
# NOTE(review): old pinned version; the classifier's forward unpacks the BERT output as a tuple (`_, pooler = ...`), which presumably relies on this pin — confirm before upgrading.
!pip install transformers==3.0.2
!pip install torch
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mxnet
  Downloading mxnet-1.9.1-py3-none-manylinux2014_x86_64.whl (49.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.1/49.1 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
Collecting graphviz<0.9.0,>=0.8.1
  Downloading graphviz-0.8.4-py2.py3-none-any.whl (16 kB)
Installing collected packages: graphviz, mxnet
  Attempting uninstall: graphviz
    Found existing installation: graphviz 0.10.1
    Uninstalling graphviz-0.10.1:
      Successfully uninstalled graphviz-0.10.1
Successfully installed graphviz-0.8.4 mxnet-1.9.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gluonnlp
  Downloading gluonnlp-0.10.0.tar.gz (344 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m344.5/344.5 KB[0m [31m22.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hd

In [None]:
# torch
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook

#kobert
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

In [None]:
# Use the first GPU when available; fall back to CPU so the notebook still
# runs (slowly) on a machine without CUDA instead of failing on .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the pretrained KoBERT encoder and its vocabulary (both are required below).
bertmodel, vocab = get_pytorch_kobert_model()

/content/.cache/kobert_v1.zip[██████████████████████████████████████████████████]
/content/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece[██████████████████████████████████████████████████]


In [None]:
# KoBERT에 입력될 데이터셋 정리
class BERTDataset(Dataset):
    """Map-style dataset that pre-tokenizes a DataFrame for KoBERT.

    Every item is the tuple produced by gluonnlp's ``BERTSentenceTransform``
    for one row's text, extended with that row's label as ``np.int32``.
    """

    def __init__(self, df, text_col, label_col, bert_tokenizer, max_len,
                 pad, pair):
        to_features = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        # Tokenize everything up front so __getitem__ is a cheap lookup.
        self.sentences = [to_features([text]) for text in df[text_col]]
        self.labels = [np.int32(label) for label in df[label_col]]

    def __getitem__(self, i):
        features = self.sentences[i]
        return features + (self.labels[i],)

    def __len__(self):
        return len(self.labels)

In [None]:
# 모델 정의
class BERTClassifier(nn.Module):
    """Linear classifier on top of a pretrained BERT encoder's pooled output."""

    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=6,   # number of emotion classes; must match the label set
                 dr_rate=None,
                 params=None):
        """
        Args:
            bert: pretrained encoder module (here the KoBERT model).
            hidden_size: size of the pooled output fed to the classifier.
            num_classes: number of output logits.
            dr_rate: dropout probability applied to the pooled output;
                a falsy value (None/0) disables dropout.
            params: unused; kept so existing callers do not break.
        """
        super().__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Return a float mask that is 1.0 for the first ``valid_length[i]``
        positions of row ``i`` of ``token_ids`` and 0.0 elsewhere.

        The mask is built on ``token_ids``'s device with one vectorized
        comparison instead of a Python loop over the batch.
        """
        positions = torch.arange(token_ids.size(1), device=token_ids.device)
        lengths = torch.as_tensor(valid_length, device=token_ids.device)
        return (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return class logits for a batch of token ids."""
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        # NOTE(review): assumes self.bert(...) returns a
        # (sequence_output, pooled_output) tuple, as with the transformers
        # version pinned in the install cell — confirm if upgrading.
        _, pooler = self.bert(input_ids=token_ids,
                              token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask)
        # Previously `out` was only assigned when dropout was enabled, so the
        # default dr_rate=None raised NameError here.
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)

In [None]:
# Mount Google Drive so the corpus spreadsheets and saved weights are reachable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import pandas as pd

# Drive folder holding the emotional-dialogue corpus spreadsheets.
data_root = '/content/drive/MyDrive/WithKL/KoBERT_Sentiment'
train_file = f'{data_root}/감성대화말뭉치(최종데이터)_Training.xlsx'
valid_file = f'{data_root}/감성대화말뭉치(최종데이터)_Validation.xlsx'

# Read training first, then validation.
naturalTraining_data, naturalValidation_data = (
    pd.read_excel(path) for path in (train_file, valid_file)
)

In [None]:
# Peek at the first rows (use .sample(n=10) for a random look instead).
naturalTraining_data.head(n=5)

Unnamed: 0.1,Unnamed: 0,연령,성별,상황키워드,신체질환,감정_대분류,감정_소분류,사람문장1,시스템문장1,사람문장2,시스템문장2,사람문장3,시스템문장3
0,1,청년,여성,"진로,취업,직장",해당없음,분노,노여워하는,일은 왜 해도 해도 끝이 없을까? 화가 난다.,많이 힘드시겠어요. 주위에 의논할 상대가 있나요?,그냥 내가 해결하는 게 나아. 남들한테 부담 주고 싶지도 않고.,혼자 해결하기로 했군요. 혼자서 해결하기 힘들면 주위에 의논할 사람을 찾아보세요.,,
1,2,청년,여성,"진로,취업,직장",해당없음,분노,노여워하는,이번 달에 또 급여가 깎였어! 물가는 오르는데 월급만 자꾸 깎이니까 너무 화가 나.,급여가 줄어 속상하시겠어요. 월급이 줄어든 것을 어떻게 보완하실 건가요?,최대한 지출을 억제해야겠어. 월급이 줄어들었으니 고정지출을 줄일 수밖에 없을 것 같아.,월급이 줄어든 만큼 소비를 줄일 계획이군요.,,
2,3,청년,여성,"진로,취업,직장",해당없음,분노,노여워하는,회사에 신입이 들어왔는데 말투가 거슬려. 그런 애를 매일 봐야 한다고 생각하니까 스...,회사 동료 때문에 스트레스를 많이 받는 것 같아요. 문제 해결을 위해 어떤 노력을 ...,잘 안 맞는 사람이랑 억지로 잘 지내는 것보단 조금은 거리를 두고 예의를 갖춰서 대...,스트레스받지 않기 위해선 인간관계에 있어 약간의 거리를 두는 게 좋겠군요.,,
3,4,청년,여성,"진로,취업,직장",해당없음,분노,노여워하는,직장에서 막내라는 이유로 나에게만 온갖 심부름을 시켜. 일도 많은 데 정말 분하고 ...,관련 없는 심부름을 모두 하게 되어서 노여우시군요. 어떤 것이 상황을 나아질 수 있...,직장 사람들과 솔직하게 이야기해보고 싶어. 일하는 데에 방해된다고.,직장 사람들과 이야기를 해 보겠다고 결심하셨군요.,,
4,5,청년,여성,"진로,취업,직장",해당없음,분노,노여워하는,얼마 전 입사한 신입사원이 나를 무시하는 것 같아서 너무 화가 나.,무시하는 것 같은 태도에 화가 나셨군요. 상대방의 어떤 행동이 그런 감정을 유발하는...,상사인 나에게 먼저 인사하지 않아서 매일 내가 먼저 인사한다고!,항상 먼저 인사하게 되어 화가 나셨군요. 어떻게 하면 신입사원에게 화났음을 표현할 ...,,


In [None]:
# Emotion names in label-id order: a row's label is its index in this list.
sentiments = ['기쁨', '불안', '당황', '슬픔', '분노', '상처']


def prepare_data(df):
    """Turn a corpus spreadsheet into a two-column ``(text, label)`` frame.

    The six dialogue-turn columns of each row are joined with single spaces
    (missing turns are skipped), and the ``'감정_대분류'`` emotion name is
    mapped to its index in ``sentiments``.

    Args:
        df: DataFrame with the six dialogue columns and ``'감정_대분류'``.
            It is not modified.

    Returns:
        A new DataFrame with columns ``'text'`` (str) and ``'label'`` (int),
        sharing ``df``'s index.

    Raises:
        ValueError: if an emotion name is not in ``sentiments``.
    """
    def combine(rows):
        # Join the non-missing cells of one row into a single string.
        return ' '.join(str(cell) for cell in rows if not pd.isna(cell))

    cols = ['사람문장1', '시스템문장1', '사람문장2', '시스템문장2', '사람문장3', '시스템문장3']
    # Previously this function ignored `df` and always read the training
    # spreadsheet, so the "test" set silently became the training set.
    # Building plain lists also avoids mutating the caller's frame and works
    # for an empty frame (where DataFrame.apply(axis=1) returns a DataFrame).
    texts = [combine(row) for row in df[cols].itertuples(index=False, name=None)]
    labels = [sentiments.index(cell) for cell in df['감정_대분류']]
    return pd.DataFrame({'text': texts, 'label': labels}, index=df.index)

In [None]:
# Preview the prepared training frame. This reuses prepare_data instead of
# duplicating its column-joining / label-mapping logic in a second place.
df = prepare_data(naturalTraining_data)
df.head()

Unnamed: 0,text,label
0,일은 왜 해도 해도 끝이 없을까? 화가 난다. 많이 힘드시겠어요. 주위에 의논할 상...,4
1,이번 달에 또 급여가 깎였어! 물가는 오르는데 월급만 자꾸 깎이니까 너무 화가 나....,4
2,회사에 신입이 들어왔는데 말투가 거슬려. 그런 애를 매일 봐야 한다고 생각하니까 스...,4
3,직장에서 막내라는 이유로 나에게만 온갖 심부름을 시켜. 일도 많은 데 정말 분하고 ...,4
4,얼마 전 입사한 신입사원이 나를 무시하는 것 같아서 너무 화가 나. 무시하는 것 같...,4


In [None]:
# (text, label) frames: training split and held-out validation split.
dataset_train = prepare_data(df=naturalTraining_data)
dataset_test = prepare_data(df=naturalValidation_data)

In [None]:
# Hyperparameters (required by the cells below).

# sequence length / batching
max_len = 128    # default 64
batch_size = 64

# optimisation schedule
num_epochs = 15
learning_rate = 5e-5
warmup_ratio = 0.1
max_grad_norm = 1

# logging
log_interval = 100

In [None]:
# Tokenization: SentencePiece model + KoBERT vocab, case preserved.
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

# Positional args after tok: max_len, pad=True, pair=False (single-sentence input).
data_train = BERTDataset(dataset_train, 'text', 'label', tok, max_len, True, False)
data_test = BERTDataset(dataset_test, 'text', 'label', tok, max_len, True, False)

# Shuffle the training set every epoch. The spreadsheet rows appear to be
# grouped by emotion (the preview head is all 분노), and without shuffling
# consecutive batches would each see (nearly) a single label.
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True, num_workers=5)
# Evaluation order does not affect accuracy; keep it deterministic.
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

using cached model. /content/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece




In [None]:
from transformers import get_cosine_schedule_with_warmup
from transformers import AdamW

# Classifier on top of the pretrained KoBERT encoder, dropout 0.5 on the pooled output.
model = BERTClassifier(bertmodel,  dr_rate=0.5).to(device)

# Parameters whose name contains any of these substrings get no weight decay.
no_decay = ['bias', 'LayerNorm.weight']
decayed, undecayed = [], []
for name, param in model.named_parameters():
    if any(tag in name for tag in no_decay):
        undecayed.append(param)
    else:
        decayed.append(param)

optimizer_grouped_parameters = [
    {'params': decayed, 'weight_decay': 0.01},
    {'params': undecayed, 'weight_decay': 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

# Cosine LR schedule over every optimizer step, with a linear warm-up
# covering the first `warmup_ratio` fraction of them.
t_total = num_epochs * len(train_dataloader)
warmup_step = int(t_total * warmup_ratio)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_step,
    num_training_steps=t_total,
)

# Accuracy helper
def calc_accuracy(X, Y):
    """Return the fraction of rows of logits ``X`` whose argmax equals ``Y``."""
    predicted_class = X.max(dim=1).indices
    batch = predicted_class.shape[0]
    hits = predicted_class.eq(Y).sum().detach().cpu().numpy()
    return hits / batch

# tqdm.notebook.tqdm replaces the deprecated tqdm.tqdm_notebook.
from tqdm.notebook import tqdm as notebook_tqdm

for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0

    # ---- training pass ----
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(notebook_tqdm(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        label = label.long().to(device)
        # valid_length may stay on the CPU; the model builds its mask on token_ids' device.
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # the LR schedule advances once per optimizer step
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.item(), train_acc / (batch_id+1)))
    # Mean of per-batch accuracies over the epoch.
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))

    # ---- evaluation pass ----
    # no_grad: evaluation needs no autograd graph, which otherwise wastes GPU memory.
    model.eval()
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(notebook_tqdm(test_dataloader)):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 1 batch id 1 loss 1.8456945419311523 train acc 0.140625
epoch 1 batch id 101 loss 1.8095557689666748 train acc 0.18193069306930693
epoch 1 batch id 201 loss 1.6618930101394653 train acc 0.21260883084577115
epoch 1 batch id 301 loss 1.1619409322738647 train acc 0.28976328903654486
epoch 1 batch id 401 loss 0.639406144618988 train acc 0.3825202618453865
epoch 1 batch id 501 loss 0.6619148254394531 train acc 0.4333832335329341
epoch 1 batch id 601 loss 1.623644471168518 train acc 0.47290973377703827
epoch 1 batch id 701 loss 0.6222613453865051 train acc 0.48613587731811697
epoch 1 batch id 801 loss 0.6523513197898865 train acc 0.5180828651685393
epoch 1 train acc 0.5204284184580572


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 1 test acc 0.43708329966057863


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 2 batch id 1 loss 1.395207405090332 train acc 0.546875
epoch 2 batch id 101 loss 0.8298099040985107 train acc 0.6633663366336634
epoch 2 batch id 201 loss 0.926529586315155 train acc 0.6856343283582089
epoch 2 batch id 301 loss 0.6699690222740173 train acc 0.6907703488372093
epoch 2 batch id 401 loss 0.6752663254737854 train acc 0.7030081047381546
epoch 2 batch id 501 loss 0.5488507151603699 train acc 0.6966067864271457
epoch 2 batch id 601 loss 1.4255428314208984 train acc 0.6985232945091514
epoch 2 batch id 701 loss 0.5304467678070068 train acc 0.6899295649072753
epoch 2 batch id 801 loss 0.5430184006690979 train acc 0.7009597378277154
epoch 2 train acc 0.702115659177846


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 2 test acc 0.47365359490329184


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 3 batch id 1 loss 1.1592212915420532 train acc 0.625
epoch 3 batch id 101 loss 0.6886431574821472 train acc 0.7049814356435643
epoch 3 batch id 201 loss 0.9028903245925903 train acc 0.7242692786069652
epoch 3 batch id 301 loss 0.6776923537254333 train acc 0.7246677740863787
epoch 3 batch id 401 loss 0.5187483429908752 train acc 0.7342191396508728
epoch 3 batch id 501 loss 0.4338489770889282 train acc 0.7301334830339321
epoch 3 batch id 601 loss 1.3433884382247925 train acc 0.733101081530782
epoch 3 batch id 701 loss 0.369276225566864 train acc 0.7273092011412268
epoch 3 batch id 801 loss 0.32397013902664185 train acc 0.7381398252184769
epoch 3 train acc 0.7392011475674801


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 3 test acc 0.5589652025752923


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 4 batch id 1 loss 0.7858498692512512 train acc 0.78125
epoch 4 batch id 101 loss 0.6881260871887207 train acc 0.7479888613861386
epoch 4 batch id 201 loss 0.8482486009597778 train acc 0.7608830845771144
epoch 4 batch id 301 loss 0.6266888380050659 train acc 0.7600705980066446
epoch 4 batch id 401 loss 0.5033056139945984 train acc 0.7667549875311721
epoch 4 batch id 501 loss 0.33409133553504944 train acc 0.7630364271457086
epoch 4 batch id 601 loss 1.0675990581512451 train acc 0.7677568635607321
epoch 4 batch id 701 loss 0.2108861804008484 train acc 0.764978601997147
epoch 4 batch id 801 loss 0.3023231625556946 train acc 0.7746761860174781
epoch 4 train acc 0.7757941719196163


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 4 test acc 0.5961938001724045


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 5 batch id 1 loss 0.6099874377250671 train acc 0.859375
epoch 5 batch id 101 loss 0.501189112663269 train acc 0.7820235148514851
epoch 5 batch id 201 loss 0.8094882369041443 train acc 0.7960976368159204
epoch 5 batch id 301 loss 0.5177580118179321 train acc 0.7925664451827242
epoch 5 batch id 401 loss 0.4032920002937317 train acc 0.7979660224438903
epoch 5 batch id 501 loss 0.18018309772014618 train acc 0.7938498003992016
epoch 5 batch id 601 loss 0.8989397287368774 train acc 0.7998388103161398
epoch 5 batch id 701 loss 0.24806228280067444 train acc 0.7980117689015692
epoch 5 batch id 801 loss 0.2418084442615509 train acc 0.8074282147315855
epoch 5 train acc 0.8084230712246108


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 5 test acc 0.5658420209040461


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 6 batch id 1 loss 0.5956190228462219 train acc 0.8125
epoch 6 batch id 101 loss 0.44437175989151 train acc 0.8098700495049505
epoch 6 batch id 201 loss 0.6701738238334656 train acc 0.8264925373134329
epoch 6 batch id 301 loss 0.3919488191604614 train acc 0.8226225083056479
epoch 6 batch id 401 loss 0.33046093583106995 train acc 0.828514650872818
epoch 6 batch id 501 loss 0.201883003115654 train acc 0.8247255489021956
epoch 6 batch id 601 loss 0.4759903848171234 train acc 0.8315307820299501
epoch 6 batch id 701 loss 0.09353163838386536 train acc 0.8296852710413695
epoch 6 batch id 801 loss 0.24004210531711578 train acc 0.8380735018726592
epoch 6 train acc 0.838886812402349


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 6 test acc 0.5910124050428317


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 7 batch id 1 loss 0.5584732890129089 train acc 0.859375
epoch 7 batch id 101 loss 0.34292301535606384 train acc 0.8349319306930693
epoch 7 batch id 201 loss 0.638351559638977 train acc 0.853467039800995
epoch 7 batch id 301 loss 0.31684187054634094 train acc 0.8501349667774086
epoch 7 batch id 401 loss 0.2133106291294098 train acc 0.8548940149625935
epoch 7 batch id 501 loss 0.09166759997606277 train acc 0.8506736526946108
epoch 7 batch id 601 loss 0.6182287335395813 train acc 0.8564891846921797
epoch 7 batch id 701 loss 0.0881117507815361 train acc 0.8553851640513552
epoch 7 batch id 801 loss 0.11037115752696991 train acc 0.8627887016229713
epoch 7 train acc 0.8635226415602607


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 7 test acc 0.61795229244114


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 8 batch id 1 loss 0.3354845345020294 train acc 0.921875
epoch 8 batch id 101 loss 0.27917760610580444 train acc 0.8598391089108911
epoch 8 batch id 201 loss 0.40682896971702576 train acc 0.878886815920398
epoch 8 batch id 301 loss 0.2656410336494446 train acc 0.8757786544850499
epoch 8 batch id 401 loss 0.14409850537776947 train acc 0.8806889027431422
epoch 8 batch id 501 loss 0.054516877979040146 train acc 0.8763098802395209
epoch 8 batch id 601 loss 0.6991859674453735 train acc 0.8808236272878536
epoch 8 batch id 701 loss 0.045873988419771194 train acc 0.8792796005706134
epoch 8 batch id 801 loss 0.1165243536233902 train acc 0.8862554619225967
epoch 8 train acc 0.8869193133451861


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 8 test acc 0.5864505885997522


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 9 batch id 1 loss 0.4124372601509094 train acc 0.90625
epoch 9 batch id 101 loss 0.20922230184078217 train acc 0.880569306930693
epoch 9 batch id 201 loss 0.39500492811203003 train acc 0.8981654228855721
epoch 9 batch id 301 loss 0.1832129955291748 train acc 0.8978405315614618
epoch 9 batch id 401 loss 0.03650318458676338 train acc 0.9014572942643392
epoch 9 batch id 501 loss 0.06667708605527878 train acc 0.8978293413173652
epoch 9 batch id 601 loss 0.7918826937675476 train acc 0.902428244592346
epoch 9 batch id 701 loss 0.029030200093984604 train acc 0.9004324179743224
epoch 9 batch id 801 loss 0.06674880534410477 train acc 0.9059964107365793
epoch 9 train acc 0.9065252747696783


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 9 test acc 0.6300652241258553


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 10 batch id 1 loss 0.3711910545825958 train acc 0.890625
epoch 10 batch id 101 loss 0.14323711395263672 train acc 0.8969678217821783
epoch 10 batch id 201 loss 0.29800766706466675 train acc 0.9142568407960199
epoch 10 batch id 301 loss 0.21611522138118744 train acc 0.913984634551495
epoch 10 batch id 401 loss 0.019157223403453827 train acc 0.9177447007481296
epoch 10 batch id 501 loss 0.027629844844341278 train acc 0.9144211576846307
epoch 10 batch id 601 loss 0.8337579965591431 train acc 0.918261231281198
epoch 10 batch id 701 loss 0.12286260724067688 train acc 0.915789942938659
epoch 10 batch id 801 loss 0.05219171196222305 train acc 0.9204705056179775
epoch 10 train acc 0.920872393728786


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 10 test acc 0.6412058550185874


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 11 batch id 1 loss 0.33665716648101807 train acc 0.890625
epoch 11 batch id 101 loss 0.09082601219415665 train acc 0.906559405940594
epoch 11 batch id 201 loss 0.18219730257987976 train acc 0.9303482587064676
epoch 11 batch id 301 loss 0.34453284740448 train acc 0.9286233388704319
epoch 11 batch id 401 loss 0.015436988323926926 train acc 0.9313045511221946
epoch 11 batch id 501 loss 0.014268414117395878 train acc 0.9282996506986028
epoch 11 batch id 601 loss 0.7552205920219421 train acc 0.9314683860232945
epoch 11 batch id 701 loss 0.03655414655804634 train acc 0.9282944008559201
epoch 11 batch id 801 loss 0.08077150583267212 train acc 0.9329744069912609
epoch 11 train acc 0.9333759293680297


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 11 test acc 0.6592317224287485


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 12 batch id 1 loss 0.20460639894008636 train acc 0.96875
epoch 12 batch id 101 loss 0.10663991421461105 train acc 0.9057858910891089
epoch 12 batch id 201 loss 0.09907136857509613 train acc 0.9303482587064676
epoch 12 batch id 301 loss 0.14696231484413147 train acc 0.9308035714285714
epoch 12 batch id 401 loss 0.01643856056034565 train acc 0.9350451995012469
epoch 12 batch id 501 loss 0.050071634352207184 train acc 0.9325099800399201
epoch 12 batch id 601 loss 0.5569334626197815 train acc 0.9365120632279534
epoch 12 batch id 701 loss 0.05341941490769386 train acc 0.9321950784593438
epoch 12 batch id 801 loss 0.012497592717409134 train acc 0.9367197253433208
epoch 12 train acc 0.9370740396530359


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 12 test acc 0.6997947645600991


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 13 batch id 1 loss 0.41863080859184265 train acc 0.90625
epoch 13 batch id 101 loss 0.026833899319171906 train acc 0.9031559405940595
epoch 13 batch id 201 loss 0.09934071451425552 train acc 0.933613184079602
epoch 13 batch id 301 loss 0.11721456795930862 train acc 0.9343334717607974
epoch 13 batch id 401 loss 0.020258165895938873 train acc 0.9399548004987531
epoch 13 batch id 501 loss 0.011825514025986195 train acc 0.9369074351297405
epoch 13 batch id 601 loss 0.8250396251678467 train acc 0.9412177620632279
epoch 13 batch id 701 loss 0.013876190409064293 train acc 0.9342234308131241
epoch 13 batch id 801 loss 0.02950126864016056 train acc 0.9388654806491885
epoch 13 train acc 0.9392425650557621


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 13 test acc 0.7256815365551424


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 14 batch id 1 loss 0.227501779794693 train acc 0.953125
epoch 14 batch id 101 loss 0.09707602858543396 train acc 0.8777846534653465
epoch 14 batch id 201 loss 0.092288076877594 train acc 0.9221859452736318
epoch 14 batch id 301 loss 0.131122887134552 train acc 0.928156146179402
epoch 14 batch id 401 loss 0.006249370984733105 train acc 0.9356686408977556
epoch 14 batch id 501 loss 0.006502480246126652 train acc 0.9337574850299402
epoch 14 batch id 601 loss 0.857865571975708 train acc 0.9378899750415973
epoch 14 batch id 701 loss 0.038526542484760284 train acc 0.9228557417974322
epoch 14 batch id 801 loss 0.11980308592319489 train acc 0.9281171972534332
epoch 14 train acc 0.9285548327137546


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 14 test acc 0.7670384138785625


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 15 batch id 1 loss 0.1942870318889618 train acc 0.953125
epoch 15 batch id 101 loss 0.21266278624534607 train acc 0.7713490099009901
epoch 15 batch id 201 loss 0.08269178122282028 train acc 0.8582089552238806
epoch 15 batch id 301 loss 0.06617610156536102 train acc 0.8869912790697675
epoch 15 batch id 401 loss 0.0066981809213757515 train acc 0.9059382793017456
epoch 15 batch id 501 loss 0.0478762686252594 train acc 0.9099301397205589
epoch 15 batch id 601 loss 0.8552554845809937 train acc 0.9188851913477537
epoch 15 batch id 701 loss 1.3271716833114624 train acc 0.8741084165477889
epoch 15 batch id 801 loss 1.303822636604309 train acc 0.8395950374531835
epoch 15 train acc 0.8378017078821185


  0%|          | 0/807 [00:00<?, ?it/s]

epoch 15 test acc 0.8859512216475406


In [None]:
## Save the trained model
import os

# data_root is on Google Drive, so the drive must be mounted first.
PATH = data_root
# os.path.join inserts the missing separator. Previously PATH + 'name'
# produced '.../KoBERT_SentimentKoBERT_담화.pt', which is outside the folder.
torch.save(model, os.path.join(PATH, 'KoBERT_담화.pt'))  # whole pickled model object
torch.save(model.state_dict(), os.path.join(PATH, 'model_state_dict.pt'))  # weights only
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict()
}, os.path.join(PATH, 'all.tar'))  # checkpoint dict; epoch, loss, etc. could be added here too

In [None]:
# Re-create the tokenizer for inference (same settings as the training section).
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(
    tokenizer,
    vocab,
    lower=False,
)

def new_softmax(a):
    """Return softmax(a) as percentages (0-100), rounded to 3 decimals.

    The maximum is subtracted before exponentiating so large logits cannot
    overflow. Normalization is over every element of ``a``.
    """
    shifted_exp = np.exp(a - np.max(a))
    percentages = (shifted_exp / np.sum(shifted_exp)) * 100
    return np.round(percentages, 3)


# Prediction helper
def predict(predict_sentence):
    """Classify text(s) into one of the emotion categories in ``sentiments``.

    Args:
        predict_sentence: either a single string, or a DataFrame with a
            ``'text'`` column and a ``'label'`` column (the label value is
            ignored but ``BERTDataset`` requires the column).

    Returns:
        ``(probability, preds)``: ``probability`` holds the softmax
        percentages (rounded to 3 decimals) of the LAST classified item
        followed by its emotion name (empty if there was no input).
        ``preds`` is the list of predicted emotion names, one per input row,
        in input order.
    """
    if isinstance(predict_sentence, str):
        # A lone sentence: wrap it so BERTDataset can consume it.
        predict_sentence = pd.DataFrame({'text': [predict_sentence], 'label': [0]})

    another_test = BERTDataset(predict_sentence, 'text', 'label', tok, max_len, True, False)
    # No shuffling (DataLoader default) so preds line up with the input rows.
    loader = torch.utils.data.DataLoader(another_test, batch_size=batch_size, num_workers=5)

    model.eval()
    preds = []
    probability = []
    # Inference only: no_grad avoids building autograd graphs for every batch.
    with torch.no_grad():
        for token_ids, valid_length, segment_ids, _label in loader:
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            out = model(token_ids, valid_length, segment_ids)

            for logits in out.cpu().numpy():
                percentages = new_softmax(logits).tolist()
                # Class index i corresponds to sentiments[i] (the training label mapping).
                emotion = sentiments[int(np.argmax(percentages))]
                probability = [np.round(p, 3) for p in percentages]
                probability.append(emotion)
                print(probability)
                preds.append(emotion)
    return probability, preds

using cached model. /content/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [None]:
import numpy as np

# Unlabeled corpus to classify; the 'korean' column holds the sentences.
test_file = '/content/drive/MyDrive/WithKL/KoBERT_Sentiment/LDA_MBN_AH_ALL.csv'
df = pd.read_csv(test_file)
# Missing sentences become '' so tokenization does not fail on NaN floats.
df['text'] = df['korean'].fillna('').astype(str)
# BERTDataset requires a label column; its value is ignored at prediction time.
df['label'] = np.zeros(len(df))
cols = ['text','label']
# .copy() makes df an independent frame, so adding 'preds' to it later does
# not raise pandas' SettingWithCopyWarning.
df = df[cols].copy()
_, preds = predict(df)

In [None]:
# Attach the predicted emotion names and write (text, preds) to Drive as UTF-8 CSV.
# (Use encoding='euc-kr' instead if the file must open directly in legacy Korean Excel.)
df['preds'] = preds
cols = ['text', 'preds']
df = df.loc[:, cols]
outfile = '/content/drive/MyDrive/WithKL/KoBERT_Sentiment/preds87.csv'
df.to_csv(outfile, mode="w", encoding='utf-8')