# 예제: CNN으로 성씨 분류하기

## 데이터 로드

In [1]:
import pandas as pd

df = pd.read_csv("../../data/surnames_with_splits.csv")
df


Unnamed: 0,nationality,nationality_index,split,surname
0,Arabic,15,train,Totah
1,Arabic,15,train,Abboud
2,Arabic,15,train,Fakhoury
3,Arabic,15,train,Srour
4,Arabic,15,train,Sayegh
...,...,...,...,...
10975,Vietnamese,11,test,Dinh
10976,Vietnamese,11,test,Phung
10977,Vietnamese,11,test,Quang
10978,Vietnamese,11,test,Vu


In [2]:
df['nationality'].value_counts()

nationality
English       2972
Russian       2373
Arabic        1603
Japanese       775
Italian        600
German         576
Czech          414
Spanish        258
Dutch          236
French         229
Chinese        220
Irish          183
Greek          156
Polish         120
Korean          77
Scottish        75
Vietnamese      58
Portuguese      55
Name: count, dtype: int64

In [3]:
df['split'].value_counts()

split
train    7680
test     1660
val      1640
Name: count, dtype: int64

### 데이터 split(train/valid/test)

In [4]:
# 데이터를 다시 train/valid/test로 나눠줌

# train 데이터 
train_df = df[df.split=='train']
train_size = len(train_df)

# valid 데이터 
val_df = df[df.split=='val']
val_size = len(val_df)

# test 데이터 
test_df = df[df.split=='test']
test_size = len(test_df)

In [5]:
lookup_dict = {'train': (train_df, train_size), 
              'val': (val_df, val_size), 
              'test': (test_df, test_size)}


## 2. Vocabulary

In [6]:
class Vocabulary:
    def __init__(self, token_to_idx=None):

        if token_to_idx is None:
            token_to_idx = {}
        self.token_to_idx = token_to_idx

        self.idx_to_token = {idx: token 
                              for token, idx in self.token_to_idx.items()}

    def add_token(self, token):
        
#       만약 해당 토큰이 있으면 토큰 idx만 return
        if token in self.token_to_idx:
            index = self.token_to_idx[token]
            
#       만약 해당 토큰이 없으면 새로운 토큰 만들어줌
        else:
            index = len(self.token_to_idx)
            self.token_to_idx[token] = index
            self.idx_to_token[index] = token
        return index

In [7]:
class SequenceVocabulary(Vocabulary):
    def __init__(self, token_to_idx=None, unk_token="<UNK>",
                 mask_token="<MASK>", begin_seq_token="<BEGIN>",
                 end_seq_token="<END>"):

        super(SequenceVocabulary, self).__init__(token_to_idx)

        self._mask_token = mask_token
        self._unk_token = unk_token
        self._begin_seq_token = begin_seq_token
        self._end_seq_token = end_seq_token

        self.mask_index = self.add_token(self._mask_token)
        self.unk_index = self.add_token(self._unk_token)
        self.begin_seq_index = self.add_token(self._begin_seq_token)
        self.end_seq_index = self.add_token(self._end_seq_token)
        
        

In [8]:
# Vocab 생성
char_vocab = SequenceVocabulary()
nationality_vocab = Vocabulary()

for index, row in df.iterrows():
    for char in row.surname:
        char_vocab.add_token(char)
    nationality_vocab.add_token(row.nationality)

### Char Vocabulary

In [9]:
print(dict(list(char_vocab.token_to_idx.items())[:10]))

{'<MASK>': 0, '<UNK>': 1, '<BEGIN>': 2, '<END>': 3, 'T': 4, 'o': 5, 't': 6, 'a': 7, 'h': 8, 'A': 9}


In [10]:
print(dict(list(char_vocab.idx_to_token.items())[:10]))

{0: '<MASK>', 1: '<UNK>', 2: '<BEGIN>', 3: '<END>', 4: 'T', 5: 'o', 6: 't', 7: 'a', 8: 'h', 9: 'A'}


### 국적 Vocabulary

In [11]:
print(dict(list(nationality_vocab.token_to_idx.items())[:10]))

{'Arabic': 0, 'Chinese': 1, 'Czech': 2, 'Dutch': 3, 'English': 4, 'French': 5, 'German': 6, 'Greek': 7, 'Irish': 8, 'Italian': 9}


In [12]:
print(dict(list(nationality_vocab.idx_to_token.items())[:10]))

{0: 'Arabic', 1: 'Chinese', 2: 'Czech', 3: 'Dutch', 4: 'English', 5: 'French', 6: 'German', 7: 'Greek', 8: 'Irish', 9: 'Italian'}


## 3. Vectorizer

In [13]:
# 주어진 토큰에 대응하는 인덱스 반환

def lookup_token(vocabulary_class,token):
    return vocabulary_class.token_to_idx[token]
    

In [14]:
# 주어진 인덱스에 대응하는 토큰 반환

def lookup_index(vocabulary_class, index):
        if index not in vocabulary_class.idx_to_token:
            raise KeyError("the index (%d) is not in the Vocabulary" % index)
        return vocabulary_class.idx_to_token[index]
    

In [15]:
vocab_length = len(char_vocab.token_to_idx)
print("토큰의 수:", vocab_length)

토큰의 수: 88


### 텍스트(surname)에 대한 원 핫 인코딩

In [16]:
import numpy as np 

# vector_length (int): 인덱스 벡터의 길이를 맞추기 위한 매개변수
def vectorize(surname, vector_length=-1):

    indices = [char_vocab.begin_seq_index]
    indices.extend(lookup_token(char_vocab,token) 
                   for token in surname)
    indices.append(char_vocab.end_seq_index)

    if vector_length < 0:
        vector_length = len(indices)
    
    from_vector = np.empty(vector_length, dtype=np.int64)         
    from_indices = indices[:-1]
    from_vector[:len(from_indices)] = from_indices
    from_vector[len(from_indices):] = char_vocab.mask_index

    to_vector = np.empty(vector_length, dtype=np.int64)
    to_indices = indices[1:]
    to_vector[:len(to_indices)] = to_indices
    to_vector[len(to_indices):] = char_vocab.mask_index
        
    return from_vector, to_vector

print("예시")
example = vectorize("Choi", 10)
print(example)
print(len(example[0]))

예시
(array([ 2, 20,  8,  5, 23,  0,  0,  0,  0,  0]), array([20,  8,  5, 23,  3,  0,  0,  0,  0,  0]))
10


### SurnameDataset class

In [17]:
import torch
from torch.utils.data import Dataset

class SurnameDataset(Dataset):
    def __init__(self, surname_df):
        self.surname_df = surname_df
        self.max_seq_length = max(map(len, df.surname)) + 2
        
     

    def __len__(self):
        return len(self.surname_df)

    def __getitem__(self, index):
        row = self.surname_df.iloc[index]
        
        from_vector, to_vector = vectorize(row.surname, self.max_seq_length)
        
        nationality_index = lookup_token(nationality_vocab,row.nationality)

        return {'x_data': from_vector, 
                'y_target': to_vector, 
                'class_index': nationality_index}
    

### 데이터셋 class

In [18]:
# 데이터셋을 인스턴스화 해주어야 로더에 넣어줄 수 있다. 

train_dataset = SurnameDataset(train_df)
train_dataset

valid_dataset = SurnameDataset(val_df)
valid_dataset

test_dataset = SurnameDataset(test_df)
test_dataset


<__main__.SurnameDataset at 0x7f8ae0902f10>

In [19]:
print(train_dataset.max_seq_length)
print(valid_dataset.max_seq_length)
print(test_dataset.max_seq_length)

19
19
19


In [20]:
# 데이터 로더 설정
from torch.utils.data import DataLoader

# drop_last=True -> 배치 사이즈보다 over하면 drop

Traindataloader = DataLoader(dataset=train_dataset, batch_size=512,
                            shuffle=True, drop_last=True)

Validdataloader = DataLoader(dataset=valid_dataset, batch_size=512,
                            shuffle=True, drop_last=True)

Testdataloader = DataLoader(dataset=test_dataset, batch_size=512,
                            shuffle=True, drop_last=True)


In [21]:
print(len(train_dataset),len(Traindataloader))

7680 15


In [22]:
for batch_index, batch_dict in enumerate(Traindataloader):
#     print(batch_index)
    print(batch_dict)
    
    break
    

{'x_data': tensor([[ 2, 45,  7,  ...,  0,  0,  0],
        [ 2, 29,  5,  ...,  0,  0,  0],
        [ 2, 30, 11,  ...,  0,  0,  0],
        ...,
        [ 2,  4, 43,  ...,  0,  0,  0],
        [ 2, 40, 25,  ...,  0,  0,  0],
        [ 2, 37,  7,  ...,  0,  0,  0]]), 'y_target': tensor([[45,  7, 25,  ...,  0,  0,  0],
        [29,  5, 15,  ...,  0,  0,  0],
        [30, 11, 23,  ...,  0,  0,  0],
        ...,
        [ 4, 43,  8,  ...,  0,  0,  0],
        [40, 25, 43,  ...,  0,  0,  0],
        [37,  7, 21,  ...,  0,  0,  0]]), 'class_index': tensor([14, 14,  4, 14,  4,  4,  4,  6, 14,  4,  4,  6,  0, 14,  6,  4,  0,  4,
        14,  4, 14, 14,  0, 14,  0,  3,  4,  7, 14,  5,  0,  4,  4,  0,  5,  4,
         4, 10,  6,  4,  6,  0, 14,  4, 14, 14,  4, 14, 14, 15,  4,  0,  4,  0,
        10, 14,  4,  4,  4, 14, 14,  0,  9,  9,  4, 14, 10,  0, 14, 16,  4, 16,
         4,  4, 14,  4,  1,  4, 14,  4,  7,  0,  4,  0, 10, 14,  0,  4,  4,  4,
        14,  1,  4, 14,  0, 14, 14,  4,  0,  9,  4, 

## 모델 정의

In [23]:
# 사전 정의 함수
# 주어진 배치의 각 데이터 포인트에 대해 시퀀스의 마지막 벡터를 추출
# => y_out에 있는 각 데이터 포인트에서 마지막 벡터 추출
def column_gather(y_out, x_lengths):
    
#     x_lengths = x_lengths.long().detach().cpu().numpy() - 1
    x_lengths = x_lengths-1
    out = []
    for batch_index, length in enumerate(x_lengths):
        out.append(y_out[batch_index, length])

    return torch.stack(out)

### column_gather 출력예시

In [24]:
import torch

# 가상의 입력 데이터
y_out = torch.tensor([
    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],  # 첫 번째 시퀀스: 길이 3
    [[10, 11, 12], [13, 14, 15], [16, 17, 18]],  # 두 번째 시퀀스: 길이 3 (길이를 맞춤)
])

x_lengths = torch.tensor([3, 3])  # 각 시퀀스의 길이 (동일하게 맞춤)
# column_gather 함수 호출
result = column_gather(y_out, x_lengths)
print(result)

tensor([[ 7,  8,  9],
        [16, 17, 18]])


### 조건이 없는 성씨 생성 모델 

In [25]:
import torch.nn.functional as F 
import torch.nn as nn

class SurnameGenerationModel(nn.Module):
    def __init__(self, char_embedding_size, char_vocab_size, rnn_hidden_size, 
                 batch_first=True, padding_idx=0, dropout_p=0.5):
        
        super(SurnameGenerationModel, self).__init__()
        
        self.emb = nn.Embedding(num_embeddings = char_vocab_size,
                               embedding_dim = char_embedding_size,
                               padding_idx = padding_idx)
        
        self.rnn = nn.GRU(input_size=char_embedding_size, 
                          hidden_size=rnn_hidden_size,
                          batch_first=batch_first)
        
        self.fc = nn.Linear(in_features=rnn_hidden_size, 
                            out_features=char_vocab_size)
        
        self.dropout_p = dropout_p
        
    def forward(self, x_in, apply_softmax=False):
        x_embedded = self.emb(x_in)

        y_out, _ = self.rnn(x_embedded)

        batch_size, seq_size, feat_size = y_out.shape
        y_out = y_out.contiguous().view(batch_size * seq_size, feat_size)

        y_out = self.fc(F.dropout(y_out, p=self.dropout_p))
                         
        if apply_softmax:
            y_out = F.softmax(y_out, dim=1)
            
        new_feat_size = y_out.shape[-1]
        y_out = y_out.view(batch_size, seq_size, new_feat_size)
            
        return y_out


In [26]:
len(char_vocab.token_to_idx)

88

In [27]:
len(nationality_vocab.token_to_idx)

18

In [28]:
char_embedding_size = 32
rnn_hidden_size = 32

model = SurnameGenerationModel(char_embedding_size=char_embedding_size,
                               char_vocab_size=len(char_vocab.token_to_idx),
                               rnn_hidden_size=rnn_hidden_size,
                               padding_idx=char_vocab.mask_index)
model

SurnameGenerationModel(
  (emb): Embedding(88, 32, padding_idx=0)
  (rnn): GRU(32, 32, batch_first=True)
  (fc): Linear(in_features=32, out_features=88, bias=True)
)

### 옵티마이저, loss function

In [29]:
lr = 0.001
num_epochs = 100

In [30]:
# 옵티마이저
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr = lr)
optimizer


Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)

In [31]:
df['nationality'].value_counts()

nationality
English       2972
Russian       2373
Arabic        1603
Japanese       775
Italian        600
German         576
Czech          414
Spanish        258
Dutch          236
French         229
Chinese        220
Irish          183
Greek          156
Polish         120
Korean          77
Scottish        75
Vietnamese      58
Portuguese      55
Name: count, dtype: int64

In [32]:
# numSample_list = df['nationality'].value_counts().tolist()
# numSample_list
# # weights 계산
# weights = [1 - (x / sum(numSample_list)) for x in numSample_list]

# # weights를 torch.FloatTensor로 변환
# weights = torch.FloatTensor(weights)
# weights

In [33]:
def normalize_sizes(y_pred, y_true):
    """텐서 크기 정규화
    
    매개변수:
        y_pred (torch.Tensor): 모델의 출력
            3차원 텐서이면 행렬로 변환합니다.
        y_true (torch.Tensor): 타깃 예측
            행렬이면 벡터로 변환합니다.
    """
    if len(y_pred.size()) == 3:
        y_pred = y_pred.contiguous().view(-1, y_pred.size(2))
    if len(y_true.size()) == 2:
        y_true = y_true.contiguous().view(-1)
    return y_pred, y_true

In [34]:
# loss function

def sequence_loss(y_pred, y_true, mask_index):
    y_pred, y_true = normalize_sizes(y_pred, y_true)
    return F.cross_entropy(y_pred, y_true, ignore_index=mask_index)

### Train

In [35]:
def compute_accuracy(y_pred, y_true, mask_index):
    y_pred, y_true = normalize_sizes(y_pred, y_true)

    _, y_pred_indices = y_pred.max(dim=1)
    
    correct_indices = torch.eq(y_pred_indices, y_true).float()
    valid_indices = torch.ne(y_true, mask_index).float()
    
    n_correct = (correct_indices * valid_indices).sum().item()
    n_valid = valid_indices.sum().item()

    return n_correct / n_valid * 100

In [36]:
# Train state 초기화 
def make_train_state():
    return {
        'stop_early':False,
        'early_stopping_step':0,
        'early_stopping_best_val':1e8,
        'early_stopping_criteria' : 10,
        'epoch_index' : 0,
        'train_loss': [], 
        'train_acc' :[], 
        'val_loss' : [],
        'val_acc' : [], 
        'test_loss' : [],
        'test_acc' : [],
         
#       모델 저장파일
        'model_filename' : 'model.pth'
    } 


# Train update 
def update_train_state(model, train_state):
    
#   학습시작하면 초기에 모델 저장하기 
    
    if train_state['epoch_index'] == 0:
        torch.save(model.state_dict(),train_state['model_filename'])
        
#   모델 성능이 향상되면 모델 저장(valid loss가 더 낮아지면)
    elif train_state['epoch_index'] >=1 :
        loss_t = train_state['val_loss'][-1]
#        loss가 나빠지면 early stop step 업데이트
        if loss_t >= train_state['early_stopping_best_val']:
            train_state['early_stopping_step']+=1
            
#        loss가 좋아지면   
        else:
#            early stop step 0으로 다시 초기화        
            train_state['early_stopping_step']=0
    
#           최저 loss이면 모델 저장 
            if loss_t < train_state['early_stopping_best_val']:
                train_state['early_stopping_best_val'] = loss_t
                torch.save(model.state_dict(),train_state['model_filename'])

#       기준점 넘으면 early stop 
        if train_state['early_stopping_step'] >= train_state['early_stopping_criteria']:
            train_state['stop_early'] = True
        
        return train_state


In [37]:
# 모델 진행 상황 함수 초기화
train_state = make_train_state()
train_state

{'stop_early': False,
 'early_stopping_step': 0,
 'early_stopping_best_val': 100000000.0,
 'early_stopping_criteria': 10,
 'epoch_index': 0,
 'train_loss': [],
 'train_acc': [],
 'val_loss': [],
 'val_acc': [],
 'test_loss': [],
 'test_acc': [],
 'model_filename': 'model.pth'}

In [38]:
import tqdm

mask_index = char_vocab.mask_index

# 에포크만큼
for epoch in tqdm.tqdm(range(num_epochs)):

#     print('epoch',epoch)
#     print(train_state['epoch_index']) 
    train_state['epoch_index'] +=1 

    running_loss = 0.0
    running_acc = 0.0


#     모델을 학습 모드로 설정 -> 드롭아웃 및 배치 정규화와 같은 학습 중에만 적용되는 기법들이 활성화
#     모델을 평가 모드로 전환하려면 model.eval()을 사용
    model.train()
# 배치 만큼
    for batch_idx, batch_data in enumerate(Traindataloader):
        

#       1. 옵티마이저 그레디언트 0으로 초기화
        optimizer.zero_grad()
#       2. 모델에 데이터 넣어서 출력받기
        y_pred = model(x_in=batch_data['x_data'])
#       3. loss 계산하기
        loss =  sequence_loss(y_pred, batch_dict['y_target'], mask_index)

    
#       4. gradient 계산하기
        loss.backward()

#       5. 옵티마이저 가중치 업데이트
        optimizer.step()

#       Accuracy 계산
        # 이동 손실과 이동 정확도를 계산
        running_loss += (loss.item() - running_loss) / (batch_idx + 1)
        acc_t = compute_accuracy(y_pred, batch_dict['y_target'], mask_index)
        running_acc += (acc_t - running_acc) / (batch_idx + 1)


    train_state['train_loss'].append(running_loss)
    train_state['train_acc'].append(running_acc)


#   valid에 대한 계산

    running_loss = 0.0
    running_acc = 0.0

    model.eval() # 모델 파라미터를 수정하지 못 하게 비활성화

    for batch_idx, batch_data in enumerate(Validdataloader):

#       1. 모델의 출력값(y_pred)계산
        y_pred = model(x_in=batch_data['x_data'])

#       2. loss 계산
        loss_t = sequence_loss(y_pred, batch_dict['y_target'], mask_index)
        running_loss += (loss_t.item() - running_loss) / (batch_idx + 1)

#       3. Accuracy 계산
        acc_t = compute_accuracy(y_pred,batch_data['y_target'],mask_index)
        running_acc += (acc_t - running_acc) / (batch_idx + 1)
    
    print("val_loss",running_loss)
    print("val_acc",running_acc)

    train_state['val_loss'].append(running_loss)
    train_state['val_acc'].append(running_acc)
    

#   전체 loss, acc 저장
    train_state = update_train_state(model=model,
                                     train_state=train_state)
#   early stop해라고 했으면 학습 멈추기    
    if train_state['stop_early']:
        break



  1%|▍                                          | 1/100 [00:03<05:02,  3.05s/it]

val_loss 4.371586958567302
val_acc 4.5338280686133805


  2%|▊                                          | 2/100 [00:05<04:26,  2.72s/it]

val_loss 4.1335768699646
val_acc 9.210562744860738


  3%|█▎                                         | 3/100 [00:07<03:45,  2.33s/it]

val_loss 3.7114296754201255
val_acc 10.713930249351195


  4%|█▋                                         | 4/100 [00:09<03:16,  2.05s/it]

val_loss 3.3486293156941733
val_acc 12.832106259635193


  5%|██▏                                        | 5/100 [00:10<02:59,  1.89s/it]

val_loss 3.190239906311035
val_acc 15.025515695310691


  6%|██▌                                        | 6/100 [00:12<02:59,  1.91s/it]

val_loss 3.123265345891317
val_acc 16.53255833183104


  7%|███                                        | 7/100 [00:14<02:50,  1.83s/it]

val_loss 3.0531802972157798
val_acc 17.52826504710769


  8%|███▍                                       | 8/100 [00:15<02:44,  1.79s/it]

val_loss 3.022765318552653
val_acc 17.91532824044064


  9%|███▊                                       | 9/100 [00:17<02:41,  1.77s/it]

val_loss 2.9727914333343506
val_acc 18.183255007401975


 10%|████▏                                     | 10/100 [00:19<02:37,  1.75s/it]

val_loss 2.958892504374186
val_acc 18.390078172670172


 11%|████▌                                     | 11/100 [00:21<02:32,  1.71s/it]

val_loss 2.9337037404378257
val_acc 19.00673104415413


 12%|█████                                     | 12/100 [00:22<02:29,  1.70s/it]

val_loss 2.916237990061442
val_acc 18.78534056276807


 13%|█████▍                                    | 13/100 [00:24<02:25,  1.67s/it]

val_loss 2.896769126256307
val_acc 19.045606446731853


 14%|█████▉                                    | 14/100 [00:25<02:22,  1.65s/it]

val_loss 2.88787833849589
val_acc 18.918670029134436


 15%|██████▎                                   | 15/100 [00:27<02:19,  1.64s/it]

val_loss 2.883155902226766
val_acc 19.12320328532728


 16%|██████▋                                   | 16/100 [00:29<02:16,  1.62s/it]

val_loss 2.8685765266418457
val_acc 19.288503417685632


 17%|███████▏                                  | 17/100 [00:30<02:14,  1.62s/it]

val_loss 2.8631296157836914
val_acc 19.075272284281862


 18%|███████▌                                  | 18/100 [00:32<02:12,  1.61s/it]

val_loss 2.858396132787069
val_acc 19.258894623114184


 19%|███████▉                                  | 19/100 [00:33<02:09,  1.60s/it]

val_loss 2.850728432337443
val_acc 19.484739166814368


 20%|████████▍                                 | 20/100 [00:35<02:07,  1.60s/it]

val_loss 2.8450376192728677
val_acc 19.723560441311168


 21%|████████▊                                 | 21/100 [00:37<02:06,  1.60s/it]

val_loss 2.8420214653015137
val_acc 19.40906483625827


 22%|█████████▏                                | 22/100 [00:38<02:04,  1.59s/it]

val_loss 2.8391596476236978
val_acc 19.549606188027294


 23%|█████████▋                                | 23/100 [00:40<02:02,  1.59s/it]

val_loss 2.829118251800537
val_acc 19.78057852117794


 24%|██████████                                | 24/100 [00:41<02:01,  1.60s/it]

val_loss 2.8381129105885825
val_acc 19.586309917298454


 25%|██████████▌                               | 25/100 [00:43<02:00,  1.61s/it]

val_loss 2.829327424367269
val_acc 19.498473526073433


 26%|██████████▉                               | 26/100 [00:45<01:58,  1.60s/it]

val_loss 2.8263725439707437
val_acc 19.46207101463547


 27%|███████████▎                              | 27/100 [00:46<01:58,  1.62s/it]

val_loss 2.8221234480539956
val_acc 19.837305495624538


 28%|███████████▊                              | 28/100 [00:48<01:57,  1.63s/it]

val_loss 2.8253773053487143
val_acc 19.586282959464338


 29%|████████████▏                             | 29/100 [00:50<01:55,  1.63s/it]

val_loss 2.8185150623321533
val_acc 19.435228836164


 30%|████████████▌                             | 30/100 [00:51<01:53,  1.62s/it]

val_loss 2.8203254540761313
val_acc 19.28231656223044


 31%|█████████████                             | 31/100 [00:53<01:51,  1.61s/it]

val_loss 2.8084877332051597
val_acc 19.802302813929156


 32%|█████████████▍                            | 32/100 [00:54<01:48,  1.60s/it]

val_loss 2.8178109327952066
val_acc 19.552524493293376


 33%|█████████████▊                            | 33/100 [00:56<01:47,  1.60s/it]

val_loss 2.808476765950521
val_acc 19.796634684773046


 34%|██████████████▎                           | 34/100 [00:57<01:45,  1.60s/it]

val_loss 2.811448017756144
val_acc 19.847259554178752


 35%|██████████████▋                           | 35/100 [00:59<01:43,  1.60s/it]

val_loss 2.810793161392212
val_acc 19.48083755941952


 36%|███████████████                           | 36/100 [01:01<01:41,  1.59s/it]

val_loss 2.800457795461019
val_acc 19.727502206095267


 37%|███████████████▌                          | 37/100 [01:02<01:41,  1.61s/it]

val_loss 2.8075652917226157
val_acc 19.90372202406522


 38%|███████████████▉                          | 38/100 [01:04<01:42,  1.65s/it]

val_loss 2.8020026683807373
val_acc 19.841288439906577


 39%|████████████████▍                         | 39/100 [01:06<01:41,  1.66s/it]

val_loss 2.8036988576253257
val_acc 19.64059573531469


 40%|████████████████▊                         | 40/100 [01:07<01:38,  1.65s/it]

val_loss 2.8030402660369873
val_acc 20.03275340976394


 41%|█████████████████▏                        | 41/100 [01:09<01:36,  1.63s/it]

val_loss 2.8009393215179443
val_acc 19.79637791396655


 42%|█████████████████▋                        | 42/100 [01:11<01:34,  1.62s/it]

val_loss 2.7992873986562095
val_acc 19.80497768081722


 43%|██████████████████                        | 43/100 [01:12<01:33,  1.63s/it]

val_loss 2.7975541750590005
val_acc 19.75468209231582


 44%|██████████████████▍                       | 44/100 [01:14<01:30,  1.62s/it]

val_loss 2.7965813477834067
val_acc 19.9672979850907


 45%|██████████████████▉                       | 45/100 [01:15<01:29,  1.63s/it]

val_loss 2.7958484490712485
val_acc 20.008969343824802


 46%|███████████████████▎                      | 46/100 [01:17<01:28,  1.64s/it]

val_loss 2.794985214869181
val_acc 19.791685710816104


 47%|███████████████████▋                      | 47/100 [01:19<01:26,  1.63s/it]

val_loss 2.795844078063965
val_acc 19.677495081226514


 48%|████████████████████▏                     | 48/100 [01:20<01:23,  1.62s/it]

val_loss 2.7936272621154785
val_acc 20.047904543744316


 49%|████████████████████▌                     | 49/100 [01:22<01:22,  1.61s/it]

val_loss 2.7915585041046143
val_acc 20.0420528822765


 50%|█████████████████████                     | 50/100 [01:24<01:20,  1.62s/it]

val_loss 2.7933963934580484
val_acc 19.774537412839898


 51%|█████████████████████▍                    | 51/100 [01:25<01:18,  1.61s/it]

val_loss 2.7910653750101724
val_acc 19.712296929447767


 52%|█████████████████████▊                    | 52/100 [01:27<01:17,  1.61s/it]

val_loss 2.7933882077534995
val_acc 19.865518715060375


 53%|██████████████████████▎                   | 53/100 [01:28<01:16,  1.63s/it]

val_loss 2.793864091237386
val_acc 19.784379602800183


 54%|██████████████████████▋                   | 54/100 [01:30<01:14,  1.63s/it]

val_loss 2.795107285181681
val_acc 19.868655573282982


 55%|███████████████████████                   | 55/100 [01:32<01:13,  1.63s/it]

val_loss 2.7876636187235513
val_acc 19.914197171538206


 56%|███████████████████████▌                  | 56/100 [01:33<01:11,  1.62s/it]

val_loss 2.789115826288859
val_acc 19.869675478662447


 57%|███████████████████████▉                  | 57/100 [01:35<01:09,  1.62s/it]

val_loss 2.78832213083903
val_acc 19.765876517988747


 58%|████████████████████████▎                 | 58/100 [01:37<01:08,  1.62s/it]

val_loss 2.792834758758545
val_acc 19.685286130812774


 59%|████████████████████████▊                 | 59/100 [01:38<01:06,  1.62s/it]

val_loss 2.7841528256734214
val_acc 19.841875951340796


 60%|█████████████████████████▏                | 60/100 [01:40<01:04,  1.62s/it]

val_loss 2.785991827646891
val_acc 19.91310280684384


 61%|█████████████████████████▌                | 61/100 [01:41<01:03,  1.62s/it]

val_loss 2.786848545074463
val_acc 19.878133944572323


 62%|██████████████████████████                | 62/100 [01:43<01:01,  1.61s/it]

val_loss 2.7887869675954184
val_acc 19.83070484389385


 63%|██████████████████████████▍               | 63/100 [01:45<00:59,  1.60s/it]

val_loss 2.7879863580067954
val_acc 19.847825263441717


 64%|██████████████████████████▉               | 64/100 [01:46<00:57,  1.60s/it]

val_loss 2.7836272716522217
val_acc 19.703591179757673


 65%|███████████████████████████▎              | 65/100 [01:48<00:56,  1.61s/it]

val_loss 2.784733215967814
val_acc 20.01200646433569


 66%|███████████████████████████▋              | 66/100 [01:49<00:54,  1.61s/it]

val_loss 2.7780936559041343
val_acc 19.800835206892632


 67%|████████████████████████████▏             | 67/100 [01:51<00:53,  1.61s/it]

val_loss 2.783039093017578
val_acc 19.775765744635375


 68%|████████████████████████████▌             | 68/100 [01:53<00:52,  1.64s/it]

val_loss 2.787869135538737
val_acc 19.90638948511476


 69%|████████████████████████████▉             | 69/100 [01:54<00:50,  1.63s/it]

val_loss 2.7849205334981284
val_acc 19.65123604539941


 70%|█████████████████████████████▍            | 70/100 [01:56<00:48,  1.62s/it]

val_loss 2.778562545776367
val_acc 19.879846159342893


 71%|█████████████████████████████▊            | 71/100 [01:57<00:46,  1.61s/it]

val_loss 2.7813854217529297
val_acc 19.805967269014


 72%|██████████████████████████████▏           | 72/100 [01:59<00:45,  1.63s/it]

val_loss 2.78399928410848
val_acc 19.81870972675603


 73%|██████████████████████████████▋           | 73/100 [02:01<00:43,  1.61s/it]

val_loss 2.7800671259562173
val_acc 19.881977450707442


 74%|███████████████████████████████           | 74/100 [02:02<00:41,  1.61s/it]

val_loss 2.78303329149882
val_acc 19.602702172872757


 75%|███████████████████████████████▌          | 75/100 [02:04<00:40,  1.61s/it]

val_loss 2.78485377629598
val_acc 19.864646774174187


 75%|███████████████████████████████▌          | 75/100 [02:06<00:42,  1.68s/it]

val_loss 2.784484624862671
val_acc 19.472807831836885





### Test 진행

In [44]:
# 가장 좋은 모델을 사용해 테스트 세트의 손실과 정확도를 계산합니다

model.load_state_dict(torch.load(train_state['model_filename']))

running_loss = 0.0
running_acc = 0.0

# 가중치 업데이트 하지 못 하게
model.eval()

for batch_idx, batch_data in enumerate(Testdataloader):
    
    y_pred = model(x_in=batch_data['x_data'])
    loss = sequence_loss(y_pred,batch_data['y_target'],mask_index)
    loss_t = loss.item()
    running_loss += (loss_t - running_loss) / (batch_idx + 1)
    
    acc_t = compute_accuracy(y_pred, batch_data['y_target'],mask_index)
    running_acc += (acc_t - running_acc) / (batch_idx + 1)

train_state['test_loss'] = running_loss
train_state['test_acc'] = running_acc

In [45]:
print("테스트 손실: {:.3f}".format(train_state['test_loss']))
print("테스트 정확도: {:.2f}".format(train_state['test_acc']))

테스트 손실: 2.841
테스트 정확도: 20.11


### 추론

In [58]:
def decode_samples(sampled_indices):
    """인덱스를 성씨 문자열로 변환합니다
    
    매개변수:
        sampled_indices (torch.Tensor): `sample_from_model` 함수에서 얻은 인덱스
    """
    decoded_surnames = []
    vocab = char_vocab
    
    for sample_index in range(sampled_indices.shape[0]):
        surname = ""
        for time_step in range(sampled_indices.shape[1]):
            sample_item = sampled_indices[sample_index, time_step].item()
            if sample_item == vocab.begin_seq_index:
                continue
            elif sample_item == vocab.end_seq_index:
                break
            else:
                surname += lookup_index(vocab,sample_item)
        decoded_surnames.append(surname)
    return decoded_surnames


In [59]:
def sample_from_model(model, num_samples=1, sample_size=20, 
                      temperature=1.0):
    """모델이 만든 인덱스 시퀀스를 샘플링합니다.
    
    매개변수:
        model (SurnameGenerationModel): 훈련 모델
        num_samples (int): 샘플 개수
        sample_size (int): 샘플의 최대 길이
        temperature (float): 무작위성 정도
            0.0 < temperature < 1.0 이면 최대 값을 선택할 가능성이 높습니다
            temperature > 1.0 이면 균등 분포에 가깝습니다
    반환값:
        indices (torch.Tensor): 인덱스 행렬
        shape = (num_samples, sample_size)
    """
    begin_seq_index = [char_vocab.begin_seq_index 
                       for _ in range(num_samples)]
    begin_seq_index = torch.tensor(begin_seq_index, 
                                   dtype=torch.int64).unsqueeze(dim=1)
    indices = [begin_seq_index]
    h_t = None
    
    for time_step in range(sample_size):
        x_t = indices[time_step]
        x_emb_t = model.emb(x_t)
        rnn_out_t, h_t = model.rnn(x_emb_t, h_t)
        prediction_vector = model.fc(rnn_out_t.squeeze(dim=1))
        probability_vector = F.softmax(prediction_vector / temperature, dim=1)
        indices.append(torch.multinomial(probability_vector, num_samples=1))
    indices = torch.stack(indices).squeeze().permute(1, 0)
    return indices

In [60]:
# 생성할 이름 개수
num_names = 10
model = model.cpu()
# 이름 생성
sampled_surnames = decode_samples(
    sample_from_model(model, num_samples=num_names))
# 결과 출력
print ("-"*15)
for i in range(num_names):
    print (sampled_surnames[i])

---------------
Bimektvv
Aavcvg
Ksygm
Minilg
Lklalue
Aoiceyvrk
Ketlr
Giinmllhrdh
Mnpaejn
Shneol
