In [1]:
import torch
import numpy as np
import torch.optim as optim
import torch.nn as nn
import wandb
from utils import *
from model import *
from dataset import *
from tqdm import tqdm
from torch.utils.data import DataLoader

# Vocab 구축

In [2]:
#%pip install sentencepiece

In [3]:
import sentencepiece as spm

model_file = "./data/kowiki.model" 

# SentencePiece 모델 로드
sp = spm.SentencePieceProcessor()
sp.load(model_file)


True

# 학습 데이터 불러오기

In [4]:
import pandas as pd
dataset_name = 'ChatbotData_KorQuAD'
csv_path = './data/ChatbotData_KorQuAD.csv'
df = pd.read_csv(csv_path)
df.head(10)

Unnamed: 0,Q,A
0,1941년 이우가 배속된 소속은?,조선군사령부
1,1787년경 프랑스에서 거두고 있던 소득세는 무엇일까?,벵티엠
2,늑대 보호 조치와 효과적 법 집행으로 적당한 늑대 개체수를 유지하고 있는 국가는?,이스라엘
3,독도가 한국영토로 표기된 사례는 세계지도 3380건 중에 몇 건인가?,49건
4,중력과 관계된 낙차에 의해 움직여지는 수차를 무엇이라 부르는가?,중력수차
5,나랑 놀아줘,같이 놀아요.
6,친구들한테 인기 얻으려면,성격이 좋으면 인기가 있을 거예요.
7,시민단체는 정부가 기초연금에 대해 어떤 원리를 경직되게 적용한다고 보았는가?,보충성의 원리
8,한국독립당이 창당될 때 안창호 외 창당 발기인은 몇 명이었는가?,28명
9,노무현이 5공 청문회에서 명패를 던졌던 사람은?,전두환


In [5]:
#시퀀스 max 길이 찾기
seq_max_len = 0
for line in list(df['Q'].values):
    leng = len(sp.encode_as_ids(line))
    if seq_max_len < leng:
        seq_max_len = leng
print("Q seq_max_len:", seq_max_len)


for line in list(df['A'].values):
    leng = len(sp.encode_as_ids(line))
    if seq_max_len < leng:
        seq_max_len = leng
print("A seq_max_len:", seq_max_len)

#학습데이터 Vocab 적용해보기
ids_stack = []
for line in list(df['Q'][:5].values):
    pieces = sp.encode_as_pieces(line)
    ids = sp.encode_as_ids(line)
    ids += (seq_max_len-len(ids))*[0]
    ids_stack.append(ids)
    print("Original Text:", line)
    print("Tokens:", pieces)
    print("IDs:", ids)
    print()


Q seq_max_len: 78
A seq_max_len: 78
Original Text: 1941년 이우가 배속된 소속은?
Tokens: ['▁194', '1', '년', '▁이', '우', '가', '▁배', '속', '된', '▁소속', '은', '?']
IDs: [429, 3597, 3616, 8, 3679, 3599, 179, 3763, 3703, 765, 3604, 4245, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Original Text: 1787년경 프랑스에서 거두고 있던 소득세는 무엇일까?
Tokens: ['▁17', '87', '년', '경', '▁프랑스', '에서', '▁거두', '고', '▁있던', '▁소', '득', '세는', '▁무', '엇', '일', '까', '?']
IDs: [381, 3209, 3616, 3673, 542, 10, 1987, 3600, 804, 68, 4054, 1561, 108, 4491, 3620, 3794, 4245, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Original Text: 늑대 보호 조치와 효과적 법 집행으로 적당한 늑대 개체수를 유지하고 있는 국가는?
Tokens: ['▁', '늑', '대', '▁보호', '▁조', '치', '와', '▁효과', '적', '▁법', '▁집', '행', '으로', '▁

In [6]:
train_dataset = ChatBotDataset(csv_path, seq_max_len)
val_dataset = ChatBotDataset(csv_path, seq_max_len, train=False)

print("train_dataset length:", len(train_dataset))
print("val_dataset length:", len(val_dataset)) 

train_dataset length: 65007
val_dataset length: 7223


# Scaled_dot_Attention 계산 확인

In [7]:
n_dim = 3
q = torch.tensor([[[1,2,3],
                   [3,2,1],
                   [4,5,6]],
                  
                  [[3,2,2],
                   [1,1,1],
                   [5,2,4]]], dtype=torch.float32)
k = q.transpose(-1,-2)
token_ids = torch.tensor([[1,1,0], [1,0,0]])


scaled_attention = torch.matmul(q,k) / np.sqrt(n_dim)
masked_attention =  making_padding_mask(scaled_attention, token_ids)
attention_score = torch.softmax(masked_attention, dim=-1)
output = torch.matmul(attention_score, q)
print("정답:", output)


정답: tensor([[[1.1807, 2.0000, 2.8193],
         [2.8193, 2.0000, 1.1807],
         [1.1807, 2.0000, 2.8193]],

        [[3.0000, 2.0000, 2.0000],
         [3.0000, 2.0000, 2.0000],
         [3.0000, 2.0000, 2.0000]]])


In [8]:
#확인할 때 q,k,v의 가중치는 제외하고 확인해볼 것
model = ScaledDotProductAttention(3)
model(q,q,q,token_ids)

tensor([[[-0.1136, -0.9815,  2.9034],
         [ 0.0749, -1.1777,  2.9638],
         [ 0.2185, -1.3270,  3.0097]],

        [[-0.0239, -1.0855,  3.2401],
         [-0.0239, -1.0855,  3.2401],
         [-0.0239, -1.0855,  3.2401]]], grad_fn=<UnsafeViewBackward0>)

# Hyper Parameter

In [9]:
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

#model hyperparameter
model_name = 'transformer_chatbot'
n_seq = seq_max_len
n_vocab = sp.vocab_size()
# 논문의 절반으로 setting
n_dim = 256
n_head = 4
n_layer = 3

# dataloader hyperparameter
batch_size = 64
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

#train, Loss, Optimizer hyperparameter
epochs = 100
lr = 1e-5

# Transformer Model

In [19]:
model = Transformer(n_seq, n_vocab, n_dim, n_head, n_layer, device=device).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

In [11]:
def train(loss_fn, optimizer, save_path, eval_step=1):
    
    wandb.init(
    project='Transformer_ChatBot',
    name = f'{model_name}_{dataset_name}',
    config={
        "architecture": model_name,
        "dataset": dataset_name,
        "batch_size" : batch_size,
        "lr": lr,
        "epochs": epochs,
        "n_seq": n_seq,
        "n_vocab": n_vocab,
        "n_dim": n_dim,
        "n_head": n_head,
        "n_layer": n_layer,
        "loss_fn": "CrossEntropyLoss",
        "optimizer": "Adam",
        })
    

    model.train()
    best_val_loss = np.inf
    # BOS 토큰 ID를 가져옵니다.
    bos_token_id = sp.bos_id()
    for epoch in tqdm(range(epochs), ascii=True, desc="epoch"):
        print()
        print(f"********** epoch{epoch+1} train start **********")
        train_loss = 0
        for idx, (q, a) in enumerate(train_dataloader):
            q, a = q.to(device), a.to(device)
            # a를 bos_a로 변경 (오른쪽 시프트)
            bos_a = shift_right(a, bos_token_id)
            
            optimizer.zero_grad()
            pred = model(q, bos_a)
            loss = loss_fn(pred.view(-1, pred.size(-1)), a.view(-1))
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            
            if idx % (len(train_dataloader)//3) == 0:
                print("epoch: {}, loss: {}".format(epoch+1, loss.item()))
        train_metrics = {'train_loss': train_loss/len(train_dataloader)}
        
        print("epoch: {}, train_loss: {}".format(epoch+1, train_loss/len(train_dataloader)))
        print("********** train end **********")
        wandb.log(train_metrics, step=epoch)
        print()
        if epoch % eval_step == 0:
            with torch.no_grad():
                print("********** eval start **********")
                model.eval()
                val_loss = 0
                for idx, (q, a) in enumerate(val_dataloader):
                    q, a = q.to(device), a.to(device)
                    pred = model(q, a)
                    loss = loss_fn(pred.view(-1, pred.size(-1)), a.view(-1))
                    val_loss += loss.item()
                
                if best_val_loss > val_loss:
                    best_val_loss = val_loss
                    torch.save(model.state_dict(), save_path)
                    print()
                    print("****best model saved****")
                    print()
                val_metrics = {'val_loss': val_loss/len(val_dataloader)}
                print("epoch: {}, val_loss: {}".format(epoch+1, val_loss/len(val_dataloader)))
                
                #문장 만들기
                indx = np.random.randint(0, len(q))#0~len(val_dataset) 사이의 숫자 랜덤으로 뽑기
                max_arg = torch.argmax(pred, dim=-1)
                pred_sentc = sp.DecodeIds(max_arg[indx].tolist())
                label_sentc = sp.DecodeIds(a[indx].tolist())
                q_sentc = sp.DecodeIds(q[indx].tolist())
                print("Q:", q_sentc)
                print("Pred:", pred_sentc)
                print("Label:", label_sentc)
                wandb.log(val_metrics, step=epoch)
                print("********** eval end **********")
    wandb.finish()

In [12]:
save_path = f'./weight/best_{model_name}.pt'
train(loss_fn, optimizer,save_path, eval_step=1)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdwkim8155[0m ([33mboostcamp-oif[0m). Use [1m`wandb login --relogin`[0m to force relogin


epoch:   0%|          | 0/100 [00:00<?, ?it/s]


**********epoch1 train start**********
epoch: 1, loss: 9.135636329650879
epoch: 1, loss: 0.882185697555542
epoch: 1, loss: 0.8259548544883728
epoch: 1, loss: 0.6149036884307861
epoch: 0, train_loss: 1.1399515186707805
**********train end**********

**********eval start**********


epoch:   1%|1         | 1/100 [03:07<5:09:42, 187.70s/it]


****best model saved****

epoch: 1, val_loss: 0.6459396887669521
Q: J. K. 롤링이 책을 집필하기 위해 자주 이용했던 곳은?
Pred: 
Label: 카페
**********eval end**********

**********epoch2 train start**********
epoch: 2, loss: 0.6310096383094788
epoch: 2, loss: 0.5644517540931702
epoch: 2, loss: 0.5706245303153992
epoch: 2, loss: 0.4819130599498749
epoch: 1, train_loss: 0.5767584623783593
**********train end**********

**********eval start**********


epoch:   2%|2         | 2/100 [06:16<5:07:32, 188.29s/it]


****best model saved****

epoch: 2, val_loss: 0.5387054861119364
Q: 별에서온 그대에서 안재현이 나온 역할은?
Pred: 
Label: 천윤재 역할
**********eval end**********

**********epoch3 train start**********
epoch: 3, loss: 0.5028736591339111
epoch: 3, loss: 0.46147820353507996
epoch: 3, loss: 0.5061978101730347
epoch: 3, loss: 0.4949367642402649
epoch: 2, train_loss: 0.5016547492228625
**********train end**********

**********eval start**********


epoch:   3%|3         | 3/100 [09:23<5:03:31, 187.75s/it]


****best model saved****

epoch: 3, val_loss: 0.5184027927111735
Q: IRB는 최소 몇명 이상으로 구성되어야 하는가?
Pred: 
Label: 5명 이상
**********eval end**********

**********epoch4 train start**********
epoch: 4, loss: 0.5239384770393372
epoch: 4, loss: 0.49657753109931946
epoch: 4, loss: 0.3507120609283447
epoch: 4, loss: 0.41790154576301575
epoch: 3, train_loss: 0.46653400666601075
**********train end**********

**********eval start**********


epoch:   4%|4         | 4/100 [12:29<4:59:01, 186.89s/it]

epoch: 4, val_loss: 0.5298156168608539
Q: 별에서온 그대에서 안재현이 나온 역할은?
Pred: 
Label: 천윤재 역할
**********eval end**********

**********epoch5 train start**********
epoch: 5, loss: 0.47072434425354004
epoch: 5, loss: 0.4913058280944824
epoch: 5, loss: 0.4560984969139099
epoch: 5, loss: 0.4070553779602051
epoch: 4, train_loss: 0.43920606761936126
**********train end**********

**********eval start**********


epoch:   5%|5         | 5/100 [15:35<4:55:51, 186.86s/it]

epoch: 5, val_loss: 0.5534577583317208
Q: 민주노동당 당원 가족들을 불법사찰한 사실을 폭로한 날짜는?
Pred: 
Label: 8월 17일
**********eval end**********

**********epoch6 train start**********
epoch: 6, loss: 0.48735663294792175
epoch: 6, loss: 0.41468673944473267
epoch: 6, loss: 0.442660391330719
epoch: 6, loss: 0.4618214964866638
epoch: 5, train_loss: 0.41947527614048147
**********train end**********

**********eval start**********


epoch:   6%|6         | 6/100 [18:41<4:52:13, 186.53s/it]

epoch: 6, val_loss: 0.5777770615784468
Q: 제천 스포츠센터 건축시 외장재를 어떤 재질로 건축하였는가?
Pred: 
Label: 드라이비트
**********eval end**********

**********epoch7 train start**********
epoch: 7, loss: 0.4341537654399872
epoch: 7, loss: 0.45753028988838196
epoch: 7, loss: 0.4517773389816284
epoch: 7, loss: 0.3548433482646942
epoch: 6, train_loss: 0.40478366281925227
**********train end**********

**********eval start**********


epoch:   7%|7         | 7/100 [21:47<4:48:40, 186.25s/it]

epoch: 7, val_loss: 0.5992654430127777
Q: 오스트레일리아까치는 얼마나 높은 음을 낼 수 있나?
Pred: 일
Label: 4 옥타브 이상
**********eval end**********

**********epoch8 train start**********
epoch: 8, loss: 0.34922415018081665
epoch: 8, loss: 0.40957313776016235
epoch: 8, loss: 0.3464091420173645
epoch: 8, loss: 0.4095548093318939
epoch: 7, train_loss: 0.3932300850338354
**********train end**********

**********eval start**********


epoch:   8%|8         | 8/100 [24:53<4:45:21, 186.10s/it]

epoch: 8, val_loss: 0.6191616443406164
Q: 서독에서 출판된 라살로 고발된 슈테판 하임이 받은 선고는?
Pred: 스
Label: 벌금형
**********eval end**********

**********epoch9 train start**********
epoch: 9, loss: 0.3891252279281616
epoch: 9, loss: 0.47059258818626404
epoch: 9, loss: 0.41786864399909973
epoch: 9, loss: 0.3897298276424408
epoch: 8, train_loss: 0.38339548889459585
**********train end**********

**********eval start**********


epoch:   9%|9         | 9/100 [27:58<4:41:58, 185.92s/it]

epoch: 9, val_loss: 0.6390705135016315
Q: 야마우치 가즈토요가 이에야스에게 제공하기로 하여 환심을 산 성의 이름은?
Pred: 
Label: 가케가와 성
**********eval end**********

**********epoch10 train start**********
epoch: 10, loss: 0.3929477632045746
epoch: 10, loss: 0.3707222640514374
epoch: 10, loss: 0.34809935092926025
epoch: 10, loss: 0.3578353822231293
epoch: 9, train_loss: 0.3744176293569287
**********train end**********

**********eval start**********


epoch:  10%|#         | 10/100 [31:04<4:38:38, 185.76s/it]

epoch: 10, val_loss: 0.6626675070914547
Q: 보잉사가 가진 업무 철학은 무엇인가?
Pred: 리니
Label: 장벽 제거 철학
**********eval end**********

**********epoch11 train start**********
epoch: 11, loss: 0.3556901514530182
epoch: 11, loss: 0.3631409704685211
epoch: 11, loss: 0.3479026257991791
epoch: 11, loss: 0.36990073323249817
epoch: 10, train_loss: 0.3658619427716169
**********train end**********

**********eval start**********


epoch:  11%|#1        | 11/100 [34:09<4:35:25, 185.68s/it]

epoch: 11, val_loss: 0.6959032674806308
Q: 유전형질의 적응도가 더욱 증가하는 경우의 선택은?
Pred: 리
Label: 안정성 선택
**********eval end**********

**********epoch12 train start**********
epoch: 12, loss: 0.3858884274959564
epoch: 12, loss: 0.3681367039680481
epoch: 12, loss: 0.4007795751094818
epoch: 12, loss: 0.35060515999794006
epoch: 11, train_loss: 0.3582985148889812
**********train end**********

**********eval start**********


epoch:  12%|#2        | 12/100 [37:15<4:32:21, 185.69s/it]

epoch: 12, val_loss: 0.7192110577515797
Q: 이어폰 사야지
Pred: 나
Label: 잘 골라보세요.
**********eval end**********

**********epoch13 train start**********
epoch: 13, loss: 0.45323240756988525
epoch: 13, loss: 0.3536745309829712
epoch: 13, loss: 0.33765509724617004
epoch: 13, loss: 0.3931470811367035
epoch: 12, train_loss: 0.3515262484227813
**********train end**********

**********eval start**********


epoch:  13%|#3        | 13/100 [40:21<4:29:18, 185.72s/it]

epoch: 13, val_loss: 0.7483261428048126
Q: 오버워치를 출시한 게임 회사의 이름은?
Pred: 리
Label: 블리자드
**********eval end**********

**********epoch14 train start**********
epoch: 14, loss: 0.33954039216041565
epoch: 14, loss: 0.3155936300754547
epoch: 14, loss: 0.34418460726737976
epoch: 14, loss: 0.33991003036499023
epoch: 13, train_loss: 0.34525459550145104
**********train end**********

**********eval start**********


epoch:  14%|#4        | 14/100 [43:26<4:26:08, 185.68s/it]

epoch: 14, val_loss: 0.7743430185107003
Q: 이어폰 사야지
Pred: 나
Label: 잘 골라보세요.
**********eval end**********

**********epoch15 train start**********
epoch: 15, loss: 0.35961607098579407
epoch: 15, loss: 0.35164675116539
epoch: 15, loss: 0.33552640676498413
epoch: 15, loss: 0.28418052196502686
epoch: 14, train_loss: 0.339341731316696
**********train end**********

**********eval start**********


epoch:  15%|#5        | 15/100 [47:01<4:35:33, 194.51s/it]

epoch: 15, val_loss: 0.7851275063194005
Q: 오스트레일리아까치는 얼마나 높은 음을 낼 수 있나?
Pred: 니이
Label: 4 옥타브 이상
**********eval end**********

**********epoch16 train start**********
epoch: 16, loss: 0.3176496922969818
epoch: 16, loss: 0.3642856478691101
epoch: 16, loss: 0.3594782054424286
epoch: 16, loss: 0.32227662205696106
epoch: 15, train_loss: 0.3336987718939781
**********train end**********

**********eval start**********


epoch:  16%|#6        | 16/100 [1:12:08<13:45:27, 589.62s/it]

epoch: 16, val_loss: 0.7996084136245525
Q: 말제르브가 정치에 개입하기 시작한 해는?
Pred: 권
Label: 1771년
**********eval end**********

**********epoch17 train start**********
epoch: 17, loss: 0.26528438925743103


epoch:  16%|#6        | 16/100 [1:12:23<6:20:01, 271.44s/it] 


KeyboardInterrupt: 

In [22]:
#저장된 best 가중치 불러오기
save_path = f'./weight/best_{model_name}.pt'
state_dict = torch.load(save_path)
model.load_state_dict(state_dict)

model.eval()
prompt = "내일 뭐해?"
output = generate_sentc(model, sp, n_seq, prompt, device)
print("입력 문장:", prompt)
print("출력 문장:", output)

입력 문장: 내일 뭐해?
출력 문장: 년
