## 엘만 RNN 구현해보기

In [1]:
import torch.nn as nn

class ElmanRNN(nn.Module):
    def __init__(self, input_size, hidden_size, batch_first=False):
        
        super(ElmanRNN,self).__init__()
        
#       RNN 셀을 정의
        self.rnn_cell = nn.RNNCell(input_size, hidden_size)
        
        self.batch_first = batch_first
        self.hidden_size = hidden_size
        
#       hidden state를 초기화
#       batch_size * hidden_size 차원을 가진 torch tensor로 초기화 
    def _initial_hidden(self, batch_size):
        return torch.zeros((batch_size, self.hidden_size)) # 초기 initial_hidden
    
    def forward(self, x_in, initial_hidden=None):
        
#       batch_first 역할 : 신경망에 입력되는 텐서의 첫번째 차원값이 batch_size가 된다. 
        if self.batch_first: 
            batch_size, seq_size, feat_size = x_in.size()
            x_in = x_in.permute(1,0,2) # 순서 바꿔주기 seq_size, batch_size, feat_size
        else:
            seq_size, batch_size, feat_size = x_in.size()
            
#       hidden state를 저장할 리스트       
        hiddens=[]
        
        if initial_hidden is None:
            initial_hidden = self._initial_hidden(batch_size)
#             initial_hidden = initial_hidden.to(x_in.device)
        hidden_t = initial_hidden
        
#       각 sequence 대해 순회
        for t in range(seq_size):
#       현재의 input과 hidden state를 RNN 셀에 전달하여 다음 숨겨진 상태를 계산
            hidden_t = self.rnn_cell(x_in[t], hidden_t)
            hiddens.append(hidden_t)
            
        hiddens = torch.stack(hiddens) # 리스트로부터 텐서를 생성
        
        if self.batch_first: # batch_size, seq_size, feat_size
            hiddens = hiddens.permute(1,0,2) # 순서 바꿔주기 seq_size, batch_size, feat_size
            
        return hiddens # output으로 hidden state list를 다시 RNN으로 보내줌
            

# 예제: RNN으로 성씨 국적 분류하기

## 1. 데이터 로드

In [2]:
import pandas as pd

df = pd.read_csv("../../data/surnames_with_splits.csv")
df


Unnamed: 0,nationality,nationality_index,split,surname
0,Arabic,15,train,Totah
1,Arabic,15,train,Abboud
2,Arabic,15,train,Fakhoury
3,Arabic,15,train,Srour
4,Arabic,15,train,Sayegh
...,...,...,...,...
10975,Vietnamese,11,test,Dinh
10976,Vietnamese,11,test,Phung
10977,Vietnamese,11,test,Quang
10978,Vietnamese,11,test,Vu


In [3]:
df['split'].value_counts()

split
train    7680
test     1660
val      1640
Name: count, dtype: int64

### 데이터 split(train/valid/test)

In [4]:
# 데이터를 다시 train/valid/test로 나눠줌

# train 데이터 
train_df = df[df.split=='train']
train_size = len(train_df)

# valid 데이터 
val_df = df[df.split=='val']
val_size = len(val_df)

# test 데이터 
test_df = df[df.split=='test']
test_size = len(test_df)

## 2. Vocabulary

In [5]:
# 1. _token_to_idx 생성
# 2. _idx_to_token 생성

class Vocabulary(object):
    
    def __init__(self, token_to_idx=None):
#         token_to_idx (dict): 기존 토큰-인덱스 매핑 딕셔너리
        
        if token_to_idx is None:
            token_to_idx = {}
        self._token_to_idx = token_to_idx

        self._idx_to_token = {idx: token 
                              for token, idx in self._token_to_idx.items()}
        
    def add_token(self, token):

        if token in self._token_to_idx:
            index = self._token_to_idx[token]
        else:
            index = len(self._token_to_idx)
            self._token_to_idx[token] = index
            self._idx_to_token[index] = token
            
        return index

In [6]:
class SequenceVocabulary(Vocabulary):        
    def __init__(self, token_to_idx=None,
                 unk_token="<UNK>",
                  mask_token="<MASK>", 
                  begin_seq_token="<BEGIN>",
                  end_seq_token="<END>"):
        
        super(SequenceVocabulary, self).__init__(token_to_idx) #부모 클래스의 생성자를 호출
        
#       여러 토큰 추가
        self._mask_token = mask_token
        self._unk_token = unk_token
        self._begin_seq_token = begin_seq_token
        self._end_seq_token = end_seq_token

        self.mask_index = self.add_token(self._mask_token)
        self.unk_index = self.add_token(self._unk_token)
        self.begin_seq_index = self.add_token(self._begin_seq_token)
        self.end_seq_index = self.add_token(self._end_seq_token)

#         def add_token(self, token): -> 부모꺼 사용

In [7]:
char_vocab = SequenceVocabulary() 
nationality_vocab = Vocabulary() # label은 위와 달리 기본 Vocabulary 사용 (중요)


# vocabulary에 추가
for index, row in df.iterrows():
    for char in row.surname:
        char_vocab.add_token(char)
    nationality_vocab.add_token(row.nationality)


### char vocab

In [54]:
print(len(char_vocab._token_to_idx))

88


In [8]:
print(dict(list(char_vocab._token_to_idx.items())[:10]))

{'<MASK>': 0, '<UNK>': 1, '<BEGIN>': 2, '<END>': 3, 'T': 4, 'o': 5, 't': 6, 'a': 7, 'h': 8, 'A': 9}


In [9]:
print(dict(list(char_vocab._idx_to_token.items())[:10]))

{0: '<MASK>', 1: '<UNK>', 2: '<BEGIN>', 3: '<END>', 4: 'T', 5: 'o', 6: 't', 7: 'a', 8: 'h', 9: 'A'}


### nationality vocab

In [55]:
print(len(nationality_vocab._token_to_idx))

18


In [10]:
print(dict(list(nationality_vocab._token_to_idx.items())[:10]))

{'Arabic': 0, 'Chinese': 1, 'Czech': 2, 'Dutch': 3, 'English': 4, 'French': 5, 'German': 6, 'Greek': 7, 'Irish': 8, 'Italian': 9}


In [11]:
print(dict(list(nationality_vocab._idx_to_token.items())[:10]))

{0: 'Arabic', 1: 'Chinese', 2: 'Czech', 3: 'Dutch', 4: 'English', 5: 'French', 6: 'German', 7: 'Greek', 8: 'Irish', 9: 'Italian'}


## 3. vectorize

In [12]:
# 주어진 토큰에 대응하는 인덱스 반환

def lookup_token(voca, token):
        return voca._token_to_idx[token]

In [13]:
# 주어진 인덱스에 대응하는 토큰 반환
def lookup_index(voca, index):
        if index not in voca._idx_to_token:
            raise KeyError("the index (%d) is not in the Vocabulary" % index)
        return voca._idx_to_token[index]

### 텍스트(surname)에 대한 원 핫 인코딩

In [14]:
import numpy as np
def vectorize(surname, vector_length=-1):
        """
        매개변수:
            title (str): 문자열
            vector_length (int): 인덱스 벡터의 길이를 맞추기 위한 매개변수
        """
#       begin 토큰 추가
        indices = [char_vocab.begin_seq_index]
    
#       char에 대한 인덱스 추가
        indices.extend(lookup_token(char_vocab,token) 
                       for token in surname)
#       end 토큰 추가
        indices.append(char_vocab.end_seq_index)

        if vector_length < 0:
            vector_length = len(indices)

        out_vector = np.zeros(vector_length, dtype=np.int64)         
        out_vector[:len(indices)] = indices
#       문장 길이가 작아서 padding 진행하면 해당 부분 mask 처리
        out_vector[len(indices):] = char_vocab.mask_index
        
        return out_vector, len(indices)

print("2로 시작, 3으로 끝나는 토큰 생성")
print(vectorize("choigoun",vector_length=20))

2로 시작, 3으로 끝나는 토큰 생성
(array([ 2, 43,  8,  5, 23, 19,  5, 11, 25,  3,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0]), 10)


## 4. Dataset class

In [15]:
import torch
from torch.utils.data import Dataset

class SurnameDataset(Dataset):
    def __init__(self, surname_df):
        
        self.surname_df = surname_df
#       max 시퀀스 길이 구해주기 (padding)
        self.max_seq_length = max(map(len, surname_df.surname))+2 #시작 및 끝 토큰 고려

    def __len__(self):
        return len(self.surname_df)

    def __getitem__(self, index):
        
        row = self.surname_df.iloc[index]
        
        surname_vector, vec_length = vectorize(row.surname, self.max_seq_length)
        nationality_index = lookup_token(nationality_vocab,row.nationality)

        return {'x_data': surname_vector,
                'y_target': nationality_index,
                'x_length': vec_length} # vec_length 추가로 return
    

In [16]:
# 데이터셋을 인스턴스화 해주어야 로더에 넣어줄 수 있다. 

train_dataset = SurnameDataset(train_df)
train_dataset

valid_dataset = SurnameDataset(val_df)
valid_dataset

test_dataset = SurnameDataset(test_df)
test_dataset


<__main__.SurnameDataset at 0x7fcdf29b6040>

In [56]:
train_dataset.max_seq_length

19

In [17]:
# 데이터 로더 설정
from torch.utils.data import DataLoader

# drop_last=True -> 배치 사이즈보다 over하면 drop

Traindataloader = DataLoader(dataset=train_dataset, batch_size=512,
                            shuffle=True, drop_last=True)

Validdataloader = DataLoader(dataset=valid_dataset, batch_size=512,
                            shuffle=False, drop_last=True)

Testdataloader = DataLoader(dataset=test_dataset, batch_size=512,
                            shuffle=False, drop_last=True)


In [18]:
print(len(train_dataset),len(Traindataloader))

7680 15


In [53]:
for batch_index, batch_dict in enumerate(Traindataloader):
#     print(batch_index)
    print(len(batch_dict))     # x_in.size()
    print(batch_dict)


    break
    

3
{'x_data': tensor([[ 2,  9, 36,  ...,  0,  0,  0],
        [ 2, 44,  7,  ...,  0,  0,  0],
        [ 2,  4, 27,  ...,  0,  0,  0],
        ...,
        [ 2, 29,  5,  ...,  0,  0,  0],
        [ 2, 30, 23,  ...,  0,  0,  0],
        [ 2, 24,  5,  ...,  0,  0,  0]]), 'y_target': tensor([ 0, 14, 14,  0,  4,  4, 12,  4,  7,  0,  0, 12,  0, 14, 14,  9,  4, 14,
        12,  4,  6,  4,  4, 14,  6, 14,  4,  5, 16,  4,  4, 14, 14,  4,  0, 14,
         4,  0, 14,  5,  9,  0, 10, 14,  9,  0,  0, 14,  9,  4, 14, 16,  6,  4,
         6, 14,  5,  4,  4,  1, 14,  0,  4,  4,  9,  0,  3,  7,  4,  4, 10,  9,
        14, 10,  0,  0, 14,  6,  9,  4,  0,  0, 10, 14,  4,  5, 10,  4, 14,  4,
         2,  0, 14, 10,  4,  9,  4,  9, 10, 14, 14, 10, 14, 16,  0,  8,  7,  1,
         5,  4,  4,  9, 16,  4,  8,  0,  0,  9,  9,  4, 14,  0, 14, 14,  7,  5,
        14, 10,  0,  7,  4,  4,  4,  0,  0,  9, 10,  8, 10, 14,  0,  9, 14, 14,
         0,  4, 14, 14, 14,  4,  5,  0,  4,  0, 14, 14, 14,  5,  4,  4, 12,  9,


## 5. 모델 정의 및 옵티마이저, loss func 설정

In [20]:
# 사전 정의 함수
# 주어진 배치의 각 데이터 포인트에 대해 시퀀스의 마지막 벡터를 추출
# => y_out에 있는 각 데이터 포인트에서 마지막 벡터 추출
def column_gather(y_out, x_lengths):
    
#     x_lengths = x_lengths.long().detach().cpu().numpy() - 1
    x_lengths = x_lengths-1
    out = []
    for batch_index, length in enumerate(x_lengths):
        out.append(y_out[batch_index, length])

    return torch.stack(out)

### column_gather 출력예시

In [21]:
import torch

# 가상의 입력 데이터
y_out = torch.tensor([
    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],  # 첫 번째 시퀀스: 길이 3
    [[10, 11, 12], [13, 14, 15], [16, 17, 18]],  # 두 번째 시퀀스: 길이 3 (길이를 맞춤)
])

x_lengths = torch.tensor([3, 3])  # 각 시퀀스의 길이 (동일하게 맞춤)
# column_gather 함수 호출
result = column_gather(y_out, x_lengths)
print(result)

tensor([[ 7,  8,  9],
        [16, 17, 18]])


In [22]:
import torch.nn as nn

class ElmanRNN(nn.Module):
    def __init__(self, input_size, hidden_size, batch_first=False):
        
        super(ElmanRNN,self).__init__()
        
#       RNN 셀을 정의
        self.rnn_cell = nn.RNNCell(input_size, hidden_size)
        
        self.batch_first = batch_first
        self.hidden_size = hidden_size
        
#       hidden state를 초기화
#       batch_size * hidden_size 차원을 가진 torch tensor로 초기화 
    def _initial_hidden(self, batch_size):
        return torch.zeros((batch_size, self.hidden_size)) # 초기 initial_hidden
    
    def forward(self, x_in, initial_hidden=None):
        
#       batch_first 역할 : 신경망에 입력되는 텐서의 첫번째 차원값이 batch_size가 된다. 
        if self.batch_first: 
            batch_size, seq_size, feat_size = x_in.size()
            x_in = x_in.permute(1,0,2) # 순서 바꿔주기 seq_size, batch_size, feat_size
        else:
            seq_size, batch_size, feat_size = x_in.size()
            
#       hidden state를 저장할 리스트       
        hiddens=[]
        
        if initial_hidden is None:
            initial_hidden = self._initial_hidden(batch_size)
#             initial_hidden =. initial_hidden.to(x_in.device)
        hidden_t = initial_hidden
        
#       각 sequence 대해 순회
        for t in range(seq_size):
#       현재의 input과 hidden state를 RNN 셀에 전달하여 다음 숨겨진 상태를 계산
            hidden_t = self.rnn_cell(x_in[t], hidden_t)
            hiddens.append(hidden_t)
            
        hiddens = torch.stack(hiddens) # 리스트로부터 텐서를 생성
        
        if self.batch_first: # batch_size, seq_size, feat_size
            hiddens = hiddens.permute(1,0,2) # 순서 바꿔주기 seq_size, batch_size, feat_size
            
        return hiddens 
            

In [57]:
# RNN으로 특성 추출 -> MLP로 분류 모델(Transfer Learning)
import torch.nn.functional as F 
class SurnameClassifier(nn.Module):
    def __init__(self, embedding_size, num_embeddings, num_classes,
                rnn_hidden_size, batch_first=True, padding_idx=0):
        
        super(SurnameClassifier, self).__init__()
        
        self.emb = nn.Embedding(num_embeddings = num_embeddings,
                               embedding_dim = embedding_size,
                               padding_idx = padding_idx)
        self.rnn = ElmanRNN(input_size = embedding_size, # 100
                           hidden_size = rnn_hidden_size, # 64 * 
                           batch_first = batch_first)
        self.fc1 = nn.Linear(in_features = rnn_hidden_size,
                            out_features = rnn_hidden_size)
        self.fc2 = nn.Linear(in_features = rnn_hidden_size,
                            out_features = num_classes)
        
    def forward(self, x_in, x_lengths=None, apply_softmax=False):
#         x_lengths, 각 시퀀스의 길이 (시퀀스의 마지막 벡터를 찾는데 사용)
        
        x_embedded = self.emb(x_in) # 19(max length) * 512(input)
#         x_embedded -> 512 * 100
        y_out = self.rnn(x_embedded) 
#         y_out -> batch_size * seq_size * hidden_size
#                  
        if x_lengths is not None:
            y_out = column_gather(y_out, x_lengths)
        else:
            y_out = y_out[:, -1, :] # 각 데이터 포인트의 마지막 시간 스텝에 해당하는 벡터를 추출

#       y_out -> batch_size * hidden_size # seq_size에서 마지막 요소만 들고올 것이다. -> y_out[:, -1, :]
        y_out = F.relu(self.fc1(F.dropout(y_out, 0.5)))
#       batch_size * hidden_size(64)
        y_out = self.fc2(F.dropout(y_out, 0.5))
#       batch_size * hidden_size(18)
        if apply_softmax:
            y_out = F.softmax(y_out, dim=1)
            # batch_size * output_dim(18), 활성화 함수는 차원을 건들이지 않는다 !⭐️
        return y_out # tensor.shape는 (batch, output_dim) 
    

In [24]:
char_embedding_size = 100
rnn_hidden_size = 64

classifier = SurnameClassifier(embedding_size=char_embedding_size, 
                               num_embeddings=len(char_vocab._token_to_idx),
                               num_classes=len(nationality_vocab._token_to_idx),
                               rnn_hidden_size=rnn_hidden_size,
                               padding_idx=char_vocab.mask_index)

classifier

SurnameClassifier(
  (emb): Embedding(88, 100, padding_idx=0)
  (rnn): ElmanRNN(
    (rnn_cell): RNNCell(100, 64)
  )
  (fc1): Linear(in_features=64, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=18, bias=True)
)

### 옵티마이저, loss function

In [25]:
lr = 0.001
num_epochs = 100

In [26]:
# 옵티마이저
import torch.optim as optim

optimizer = optim.Adam(classifier.parameters(), lr = lr)
optimizer


Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)

In [27]:
loss_func = nn.CrossEntropyLoss()
loss_func

CrossEntropyLoss()

## Train

In [28]:
def compute_accuracy(y_pred, y_target):
#      예측값과 타겟값을 비교하여 일치하는 개수를 계산
    _, y_pred_indices = y_pred.max(dim=1)
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / len(y_pred_indices) * 100
   

In [29]:
# Train state 초기화 
def make_train_state():
    return {
        'stop_early':False,
        'early_stopping_step':0,
        'early_stopping_best_val':1e8,
        'early_stopping_criteria' : 10,
        'epoch_index' : 0,
        'train_loss': [], 
        'train_acc' :[], 
        'val_loss' : [],
        'val_acc' : [], 
        'test_loss' : [],
        'test_acc' : [],
         
#       모델 저장파일
        'model_filename' : 'model.pth'
    } 


# Train update 
def update_train_state(model, train_state):
    
#   학습시작하면 초기에 모델 저장하기 
    
    if train_state['epoch_index'] == 0:
        torch.save(model.state_dict(),train_state['model_filename'])
        
#   모델 성능이 향상되면 모델 저장(valid loss가 더 낮아지면)
    elif train_state['epoch_index'] >=1 :
        loss_t = train_state['val_loss'][-1]
#        loss가 나빠지면 early stop step 업데이트
        if loss_t >= train_state['early_stopping_best_val']:
            train_state['early_stopping_step']+=1
            
#        loss가 좋아지면   
        else:
#            early stop step 0으로 다시 초기화        
            train_state['early_stopping_step']=0
    
#           최저 loss이면 모델 저장 
            if loss_t < train_state['early_stopping_best_val']:
                train_state['early_stopping_best_val'] = loss_t
                torch.save(model.state_dict(),train_state['model_filename'])

#       기준점 넘으면 early stop 
        if train_state['early_stopping_step'] >= train_state['early_stopping_criteria']:
            train_state['stop_early'] = True
        
        return train_state


In [30]:
# 모델 진행 상황 함수 초기화
train_state = make_train_state()
train_state

{'stop_early': False,
 'early_stopping_step': 0,
 'early_stopping_best_val': 100000000.0,
 'early_stopping_criteria': 10,
 'epoch_index': 0,
 'train_loss': [],
 'train_acc': [],
 'val_loss': [],
 'val_acc': [],
 'test_loss': [],
 'test_acc': [],
 'model_filename': 'model.pth'}

In [31]:
import tqdm

# 에포크만큼
for epoch in tqdm.tqdm(range(num_epochs)):

#     print('epoch',epoch)
#     print(train_state['epoch_index']) 
    train_state['epoch_index'] +=1 

    running_loss = 0.0
    running_acc = 0.0


#     모델을 학습 모드로 설정 -> 드롭아웃 및 배치 정규화와 같은 학습 중에만 적용되는 기법들이 활성화
#     모델을 평가 모드로 전환하려면 classifier.eval()을 사용
    classifier.train()
# 배치 만큼
    for batch_idx, batch_data in enumerate(Traindataloader):

        

#       1. 옵티마이저 그레디언트 0으로 초기화
        optimizer.zero_grad()
#       2. 모델에 데이터 넣어서 출력받기
        y_pred = classifier(x_in=batch_data['x_data'])
#       3. loss 계산하기
        loss =  loss_func(y_pred, batch_data['y_target'])
    
#       tensor(0.3190) -> 0.3190, item()으로 스칼라 값만 추출
        loss_t = loss.item()

#       배치에서의 평균 loss 구하기
        running_loss += (loss_t - running_loss) / (batch_idx + 1)

#       4. gradient 계산하기
        loss.backward()

#       5. 옵티마이저 가중치 업데이트
        optimizer.step()

#       Accuracy 계산
        acc_t = compute_accuracy(y_pred, batch_data['y_target'])
        running_acc += (acc_t - running_acc) / (batch_idx + 1)



    train_state['train_loss'].append(running_loss)
    train_state['train_acc'].append(running_acc)


#   valid에 대한 계산

    running_loss = 0.0
    running_acc = 0.0

    classifier.eval() # 모델 파라미터를 수정하지 못 하게 비활성화

    for batch_idx, batch_data in enumerate(Validdataloader):

#       1. 모델의 출력값(y_pred)계산
        y_pred = classifier(x_in=batch_data['x_data'])

#       2. loss 계산
        loss = loss_func(y_pred,batch_data['y_target'])
        loss_t = loss.item()
        running_loss += (loss_t - running_loss) / (batch_idx + 1)

#       3. Accuracy 계산
        acc_t = compute_accuracy(y_pred,batch_data['y_target'])
        running_acc += (acc_t - running_acc) / (batch_idx + 1)
    
    print("val_loss",running_loss)
    print("val_acc",running_acc)

    train_state['val_loss'].append(running_loss)
    train_state['val_acc'].append(running_acc)
    

#   전체 loss, acc 저장
    train_state = update_train_state(model=classifier,
                                     train_state=train_state)
#   early stop해라고 했으면 학습 멈추기    
    if train_state['stop_early']:
        break



  1%|▍                                          | 1/100 [00:02<03:26,  2.08s/it]

val_loss 2.467093070348104
val_acc 28.645833333333336


  2%|▊                                          | 2/100 [00:02<02:13,  1.36s/it]

val_loss 2.320603291193644
val_acc 21.2890625


  3%|█▎                                         | 3/100 [00:03<01:52,  1.16s/it]

val_loss 2.244600852330526
val_acc 29.036458333333336


  4%|█▋                                         | 4/100 [00:04<01:42,  1.07s/it]

val_loss 2.2516547044118247
val_acc 24.8046875


  5%|██▏                                        | 5/100 [00:05<01:35,  1.01s/it]

val_loss 2.2337876160939536
val_acc 26.432291666666668


  6%|██▌                                        | 6/100 [00:06<01:32,  1.02it/s]

val_loss 2.220968723297119
val_acc 27.408854166666664


  7%|███                                        | 7/100 [00:07<01:31,  1.01it/s]

val_loss 2.1808332602183023
val_acc 28.3203125


  8%|███▍                                       | 8/100 [00:08<01:34,  1.03s/it]

val_loss 2.1571085453033447
val_acc 27.34375


  9%|███▊                                       | 9/100 [00:09<01:29,  1.01it/s]

val_loss 2.12507164478302
val_acc 32.161458333333336


 10%|████▏                                     | 10/100 [00:10<01:27,  1.03it/s]

val_loss 2.0280267794926963
val_acc 32.03125


 11%|████▌                                     | 11/100 [00:11<01:24,  1.05it/s]

val_loss 1.9843840599060059
val_acc 33.203125


 12%|█████                                     | 12/100 [00:12<01:23,  1.06it/s]

val_loss 1.9524056514104207
val_acc 32.096354166666664


 13%|█████▍                                    | 13/100 [00:13<01:20,  1.08it/s]

val_loss 1.92887544631958
val_acc 33.268229166666664


 14%|█████▉                                    | 14/100 [00:14<01:19,  1.08it/s]

val_loss 1.8980967998504639
val_acc 32.161458333333336


 15%|██████▎                                   | 15/100 [00:15<01:17,  1.10it/s]

val_loss 1.925618569056193
val_acc 34.765625


 16%|██████▋                                   | 16/100 [00:16<01:16,  1.09it/s]

val_loss 1.868672768274943
val_acc 36.328125


 17%|███████▏                                  | 17/100 [00:16<01:16,  1.09it/s]

val_loss 1.8321406443913777
val_acc 40.364583333333336


 18%|███████▌                                  | 18/100 [00:17<01:14,  1.10it/s]

val_loss 1.8007342020670574
val_acc 45.833333333333336


 19%|███████▉                                  | 19/100 [00:18<01:14,  1.09it/s]

val_loss 1.7428789933522542
val_acc 48.177083333333336


 20%|████████▍                                 | 20/100 [00:19<01:12,  1.10it/s]

val_loss 1.6933106184005737
val_acc 50.325520833333336


 21%|████████▊                                 | 21/100 [00:20<01:11,  1.10it/s]

val_loss 1.7144356568654378
val_acc 51.041666666666664


 22%|█████████▏                                | 22/100 [00:21<01:10,  1.10it/s]

val_loss 1.6887750228246052
val_acc 50.065104166666664


 23%|█████████▋                                | 23/100 [00:22<01:09,  1.11it/s]

val_loss 1.6398818095525105
val_acc 51.041666666666664


 24%|██████████                                | 24/100 [00:23<01:08,  1.11it/s]

val_loss 1.603166937828064
val_acc 54.1015625


 25%|██████████▌                               | 25/100 [00:24<01:09,  1.07it/s]

val_loss 1.5928110281626384
val_acc 53.385416666666664


 26%|██████████▉                               | 26/100 [00:25<01:10,  1.05it/s]

val_loss 1.6243962446848552
val_acc 55.078125


 27%|███████████▎                              | 27/100 [00:26<01:10,  1.04it/s]

val_loss 1.5719356934229534
val_acc 55.338541666666664


 28%|███████████▊                              | 28/100 [00:27<01:10,  1.01it/s]

val_loss 1.570383628209432
val_acc 54.6875


 29%|████████████▏                             | 29/100 [00:28<01:10,  1.01it/s]

val_loss 1.5534924666086833
val_acc 56.119791666666664


 30%|████████████▌                             | 30/100 [00:29<01:09,  1.00it/s]

val_loss 1.5692803462346394
val_acc 54.752604166666664


 31%|█████████████                             | 31/100 [00:30<01:09,  1.01s/it]

val_loss 1.5437227884928386
val_acc 55.46875


 32%|█████████████▍                            | 32/100 [00:31<01:08,  1.01s/it]

val_loss 1.5349803765614827
val_acc 56.770833333333336


 33%|█████████████▊                            | 33/100 [00:32<01:07,  1.00s/it]

val_loss 1.4940434694290161
val_acc 57.291666666666664


 34%|██████████████▎                           | 34/100 [00:33<01:06,  1.00s/it]

val_loss 1.4717923005421956
val_acc 58.203125


 35%|██████████████▋                           | 35/100 [00:34<01:04,  1.01it/s]

val_loss 1.4800435304641724
val_acc 58.072916666666664


 36%|███████████████                           | 36/100 [00:35<01:03,  1.01it/s]

val_loss 1.475186785062154
val_acc 57.6171875


 37%|███████████████▌                          | 37/100 [00:36<01:03,  1.01s/it]

val_loss 1.4835165739059448
val_acc 57.096354166666664


 38%|███████████████▉                          | 38/100 [00:37<01:02,  1.01s/it]

val_loss 1.4277714093526204
val_acc 59.505208333333336


 39%|████████████████▍                         | 39/100 [00:38<01:02,  1.03s/it]

val_loss 1.4065616925557454
val_acc 60.221354166666664


 40%|████████████████▊                         | 40/100 [00:39<01:02,  1.04s/it]

val_loss 1.441463828086853
val_acc 58.268229166666664


 41%|█████████████████▏                        | 41/100 [00:40<01:01,  1.04s/it]

val_loss 1.4810031652450562
val_acc 57.8125


 42%|█████████████████▋                        | 42/100 [00:41<01:00,  1.04s/it]

val_loss 1.4045114119847615
val_acc 59.5703125


 43%|██████████████████                        | 43/100 [00:42<01:00,  1.06s/it]

val_loss 1.3920566240946453
val_acc 60.416666666666664


 44%|██████████████████▍                       | 44/100 [00:43<00:57,  1.04s/it]

val_loss 1.3886419137318928
val_acc 61.067708333333336


 45%|██████████████████▉                       | 45/100 [00:44<00:57,  1.05s/it]

val_loss 1.35954749584198
val_acc 61.1328125


 46%|███████████████████▎                      | 46/100 [00:45<00:55,  1.02s/it]

val_loss 1.3834171295166016
val_acc 60.9375


 47%|███████████████████▋                      | 47/100 [00:46<00:54,  1.03s/it]

val_loss 1.3761032025019329
val_acc 59.765625


 48%|████████████████████▏                     | 48/100 [00:47<00:53,  1.03s/it]

val_loss 1.3662521839141846
val_acc 61.002604166666664


 49%|████████████████████▌                     | 49/100 [00:48<00:51,  1.00s/it]

val_loss 1.4216807683308919
val_acc 59.700520833333336


 50%|█████████████████████                     | 50/100 [00:49<00:50,  1.01s/it]

val_loss 1.3389455874760945
val_acc 62.760416666666664


 51%|█████████████████████▍                    | 51/100 [00:50<00:50,  1.03s/it]

val_loss 1.3390223979949951
val_acc 62.825520833333336


 52%|█████████████████████▊                    | 52/100 [00:51<00:50,  1.05s/it]

val_loss 1.3468927542368572
val_acc 62.3046875


 53%|██████████████████████▎                   | 53/100 [00:52<00:48,  1.02s/it]

val_loss 1.3579210837682087
val_acc 62.239583333333336


 54%|██████████████████████▋                   | 54/100 [00:53<00:46,  1.02s/it]

val_loss 1.3900660673777263
val_acc 59.9609375


 55%|███████████████████████                   | 55/100 [00:54<00:46,  1.04s/it]

val_loss 1.3357168833414714
val_acc 61.9140625


 56%|███████████████████████▌                  | 56/100 [00:55<00:44,  1.02s/it]

val_loss 1.355055848757426
val_acc 61.1328125


 57%|███████████████████████▉                  | 57/100 [00:56<00:43,  1.01s/it]

val_loss 1.4399052858352661
val_acc 59.635416666666664


 58%|████████████████████████▎                 | 58/100 [00:57<00:42,  1.01s/it]

val_loss 1.32172429561615
val_acc 63.020833333333336


 59%|████████████████████████▊                 | 59/100 [00:58<00:41,  1.00s/it]

val_loss 1.3093653519948323
val_acc 63.0859375


 60%|█████████████████████████▏                | 60/100 [01:00<00:41,  1.04s/it]

val_loss 1.295585036277771
val_acc 64.12760416666667


 61%|█████████████████████████▌                | 61/100 [01:01<00:39,  1.02s/it]

val_loss 1.3259456555048625
val_acc 63.0859375


 62%|██████████████████████████                | 62/100 [01:02<00:39,  1.04s/it]

val_loss 1.2992665370305378
val_acc 64.0625


 63%|██████████████████████████▍               | 63/100 [01:03<00:38,  1.04s/it]

val_loss 1.3260613679885864
val_acc 62.630208333333336


 64%|██████████████████████████▉               | 64/100 [01:04<00:37,  1.04s/it]

val_loss 1.3193140029907227
val_acc 63.606770833333336


 65%|███████████████████████████▎              | 65/100 [01:05<00:36,  1.03s/it]

val_loss 1.2746777137120564
val_acc 64.19270833333333


 66%|███████████████████████████▋              | 66/100 [01:06<00:34,  1.02s/it]

val_loss 1.2791197299957275
val_acc 64.38802083333333


 67%|████████████████████████████▏             | 67/100 [01:07<00:34,  1.06s/it]

val_loss 1.2659575541814168
val_acc 65.234375


 68%|████████████████████████████▌             | 68/100 [01:08<00:33,  1.03s/it]

val_loss 1.3139220078786213
val_acc 63.671875


 69%|████████████████████████████▉             | 69/100 [01:09<00:31,  1.03s/it]

val_loss 1.2709432045618694
val_acc 65.36458333333333


 70%|█████████████████████████████▍            | 70/100 [01:10<00:31,  1.04s/it]

val_loss 1.2843952576319377
val_acc 64.84375


 71%|█████████████████████████████▊            | 71/100 [01:11<00:29,  1.03s/it]

val_loss 1.2756272157033284
val_acc 63.932291666666664


 72%|██████████████████████████████▏           | 72/100 [01:12<00:28,  1.02s/it]

val_loss 1.251684546470642
val_acc 65.36458333333333


 73%|██████████████████████████████▋           | 73/100 [01:13<00:27,  1.01s/it]

val_loss 1.245513916015625
val_acc 66.14583333333333


 74%|███████████████████████████████           | 74/100 [01:14<00:25,  1.02it/s]

val_loss 1.2909352382024128
val_acc 64.32291666666667


 75%|███████████████████████████████▌          | 75/100 [01:15<00:25,  1.00s/it]

val_loss 1.2475098768870037
val_acc 64.97395833333333


 76%|███████████████████████████████▉          | 76/100 [01:16<00:24,  1.03s/it]

val_loss 1.248478849728902
val_acc 65.8203125


 77%|████████████████████████████████▎         | 77/100 [01:17<00:23,  1.02s/it]

val_loss 1.2425848245620728
val_acc 66.86197916666667


 78%|████████████████████████████████▊         | 78/100 [01:18<00:22,  1.03s/it]

val_loss 1.2763192256291707
val_acc 66.40625


 79%|█████████████████████████████████▏        | 79/100 [01:19<00:21,  1.01s/it]

val_loss 1.2496333916982014
val_acc 64.38802083333333


 80%|█████████████████████████████████▌        | 80/100 [01:20<00:20,  1.00s/it]

val_loss 1.259058952331543
val_acc 65.49479166666667


 81%|██████████████████████████████████        | 81/100 [01:21<00:19,  1.00s/it]

val_loss 1.224747935930888
val_acc 66.92708333333333


 82%|██████████████████████████████████▍       | 82/100 [01:22<00:17,  1.00it/s]

val_loss 1.2868573069572449
val_acc 63.932291666666664


 83%|██████████████████████████████████▊       | 83/100 [01:23<00:17,  1.01s/it]

val_loss 1.2776494820912678
val_acc 66.34114583333333


 84%|███████████████████████████████████▎      | 84/100 [01:24<00:16,  1.02s/it]

val_loss 1.275307059288025
val_acc 65.75520833333333


 85%|███████████████████████████████████▋      | 85/100 [01:25<00:15,  1.02s/it]

val_loss 1.2635083198547363
val_acc 65.8203125


 86%|████████████████████████████████████      | 86/100 [01:26<00:14,  1.02s/it]

val_loss 1.2199093103408813
val_acc 66.34114583333333


 87%|████████████████████████████████████▌     | 87/100 [01:27<00:13,  1.00s/it]

val_loss 1.199521541595459
val_acc 67.578125


 88%|████████████████████████████████████▉     | 88/100 [01:28<00:11,  1.00it/s]

val_loss 1.224696159362793
val_acc 66.40625


 89%|█████████████████████████████████████▍    | 89/100 [01:29<00:11,  1.00s/it]

val_loss 1.1850648323694866
val_acc 67.83854166666667


 90%|█████████████████████████████████████▊    | 90/100 [01:30<00:10,  1.01s/it]

val_loss 1.2093619505564372
val_acc 67.3828125


 91%|██████████████████████████████████████▏   | 91/100 [01:31<00:09,  1.00s/it]

val_loss 1.2785049279530842
val_acc 65.234375


 92%|██████████████████████████████████████▋   | 92/100 [01:32<00:08,  1.01s/it]

val_loss 1.242878754933675
val_acc 66.796875


 93%|███████████████████████████████████████   | 93/100 [01:33<00:06,  1.00it/s]

val_loss 1.203527847925822
val_acc 66.86197916666667


 94%|███████████████████████████████████████▍  | 94/100 [01:34<00:06,  1.00s/it]

val_loss 1.1859810749689739
val_acc 68.81510416666667


 95%|███████████████████████████████████████▉  | 95/100 [01:35<00:04,  1.01it/s]

val_loss 1.247092644373576
val_acc 66.66666666666667


 96%|████████████████████████████████████████▎ | 96/100 [01:36<00:04,  1.03s/it]

val_loss 1.1984477639198303
val_acc 67.25260416666667


 97%|████████████████████████████████████████▋ | 97/100 [01:37<00:03,  1.03s/it]

val_loss 1.2083359162012737
val_acc 66.47135416666667


 98%|█████████████████████████████████████████▏| 98/100 [01:38<00:02,  1.06s/it]

val_loss 1.1952077349026997
val_acc 67.31770833333333


 98%|█████████████████████████████████████████▏| 98/100 [01:40<00:02,  1.02s/it]

val_loss 1.2099319696426392
val_acc 67.70833333333333





### Test 진행

In [34]:
# 가장 좋은 모델을 사용해 테스트 세트의 손실과 정확도를 계산합니다

classifier.load_state_dict(torch.load(train_state['model_filename']))

running_loss = 0.0
running_acc = 0.0

# 가중치 업데이트 하지 못 하게
classifier.eval()

for batch_idx, batch_data in enumerate(Testdataloader):
    
    y_pred = classifier(x_in=batch_data['x_data'])
    loss = loss_func(y_pred,batch_data['y_target'])
    loss_t = loss.item()
    running_loss += (loss_t - running_loss) / (batch_idx + 1)
    
    acc_t = compute_accuracy(y_pred, batch_data['y_target'])
    running_acc += (acc_t - running_acc) / (batch_idx + 1)

train_state['test_loss'] = running_loss
train_state['test_acc'] = running_acc

In [35]:
print("테스트 손실: {:.3f}".format(train_state['test_loss']))
print("테스트 정확도: {:.2f}".format(train_state['test_acc']))

테스트 손실: 1.428
테스트 정확도: 61.00


In [36]:
train_state

{'stop_early': True,
 'early_stopping_step': 10,
 'early_stopping_best_val': 1.1850648323694866,
 'early_stopping_criteria': 10,
 'epoch_index': 99,
 'train_loss': [2.7951792240142823,
  2.407435099283854,
  2.334129889806112,
  2.301321776707967,
  2.2955955664316816,
  2.278027407328288,
  2.244572480519613,
  2.2203129768371586,
  2.186063543955485,
  2.1220308462778728,
  2.03430503209432,
  2.0013904651006063,
  1.9797150770823158,
  1.954682397842407,
  1.937115732828776,
  1.9110198815663657,
  1.885982131958008,
  1.8544119834899901,
  1.7949352423350016,
  1.732702994346619,
  1.6998157103856404,
  1.6906292597452797,
  1.667766523361206,
  1.6324763139088947,
  1.5971126556396482,
  1.6109055519104003,
  1.5660906314849854,
  1.5333666642506918,
  1.5267679532368978,
  1.5142848412195842,
  1.495674673716227,
  1.50391906897227,
  1.4682870388031006,
  1.4551185131072997,
  1.4667138894399008,
  1.4279104868570964,
  1.4173740228017173,
  1.413165275255839,
  1.38957155545552

### 추론

In [44]:
def predict_nationality(surname, classifier):
    vectorized_surname, vec_length = vectorize(surname)
    vectorized_surname = torch.tensor(vectorized_surname).unsqueeze(dim=0)
    vec_length = torch.tensor([vec_length], dtype=torch.int64)
    
    result = classifier(vectorized_surname, vec_length, apply_softmax=True)
    probability_values, indices = result.max(dim=1)
    
    index = indices.item()
    prob_value = probability_values.item()

    predicted_nationality = lookup_index(nationality_vocab,index)

    return {'nationality': predicted_nationality, 'probability': prob_value, 'surname': surname}



In [45]:
classifier
for surname in ['McMahan', 'Nakamoto', 'Wan', 'Cho']:
    print(predict_nationality(surname, classifier))

{'nationality': 'Arabic', 'probability': 0.6438420414924622, 'surname': 'McMahan'}
{'nationality': 'Arabic', 'probability': 0.8882786631584167, 'surname': 'Nakamoto'}
{'nationality': 'Arabic', 'probability': 0.913679838180542, 'surname': 'Wan'}
{'nationality': 'Arabic', 'probability': 0.9940536618232727, 'surname': 'Cho'}


In [47]:
surname = input("Enter a surname: ")
print(predict_nationality(surname, classifier))

Enter a surname: danaka
{'nationality': 'Japanese', 'probability': 0.4318922758102417, 'surname': 'danaka'}
