In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cpu


### Reading train dataset & tokenizing

In [2]:
!pip3 install konlpy

Collecting konlpy
  Downloading konlpy-0.6.0-py2.py3-none-any.whl.metadata (1.9 kB)
Collecting JPype1>=0.7.0 (from konlpy)
  Downloading jpype1-1.6.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.0 kB)
Downloading konlpy-0.6.0-py2.py3-none-any.whl (19.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m19.4/19.4 MB[0m [31m41.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jpype1-1.6.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (495 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m495.9/495.9 kB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: JPype1, konlpy
Successfully installed JPype1-1.6.0 konlpy-0.6.0


In [3]:
from konlpy.tag import Hannanum
morph = Hannanum()

In [4]:
import json
import urllib.request
import os

# Download data files if not exist
data_urls = {
    'kor.txt': 'https://github.com/nongaussian/class-2026-lginnotek-llm/raw/refs/heads/main/nmt/kor.txt'
}

for filename, url in data_urls.items():
    if not os.path.exists(filename):
        print(f'Downloading {filename}...')
        urllib.request.urlretrieve(url, filename)
        print(f'{filename} downloaded.')
    else:
        print(f'{filename} already exists.')

Downloading kor.txt...
kor.txt downloaded.


In [5]:
with open('kor.txt', 'r') as f:
    lines = f.readlines()

targ, inp = zip(*[line.split('\t') for line in lines])

In [6]:
inp[-1]

'의심의 여지 없이 세상에는 어떤 남자이든 정확히 딱 알맞는 여자와 결혼하거나 그 반대의 상황이 존재하지. 그런데 인간이 수백 명의 사람만 알고 지내는 사이가 될 기회를 갖는다고 생각해 보면, 또 그 수백 명 중 열여 명 쯤 이하만 잘 알 수 있고, 그리고 나서 그 열여 명 중에 한두 명만 친구가 될 수 있다면, 그리고 또 만일 우리가 이 세상에 살고 있는 수백만 명의 사람들만 기억하고 있다면, 딱 맞는 남자는 지구가 생겨난 이래로 딱 맞는 여자를 단 한번도 만난 적이 없을 수도 있을 거라는 사실을 쉽게 눈치챌 수 있을 거야.\n'

In [7]:
print(morph.morphs(inp[-1]))

['의심', '의', '여', '이', '지', '없이', '세상', '에는', '어떤', '남자', '이', '든', '정확히', '딱', '알맞', '는', '여자', '와', '결혼', '하', '거나', '그', '반대', '의', '상황', '이', '존재', '하', '지', '.', '그런데', '인간', '이', '수백', '명', '의', '사람', '만', '알', '고', '지내', '는', '사이', '가', '되', 'ㄹ', '기회', '를', '갖', '는다', '고', '생각', '하', '어', '보', '면', ',', '또', '그', '수백', '명', '중', '열', '여', '명', '쯤', '이하', '만', '잘', '알', 'ㄹ', '수', '있', '고', ',', '그리고', '나', '서', '그', '열', '여', '명', '중', '에', '한두', '명', '만', '친구', '가', '되', 'ㄹ', '수', '있', '다면', ',', '그리고', '또', '만', '이', 'ㄹ', '우리', '가', '이', '세상', '에', '살', '고', '있', '는', '수백만', '명', '의', '사람들', '만', '기억', '하고', '있', '다면', ',', '딱', '맞', '는', '남자', '는', '지구', '가', '생기', '어', '나', 'ㄴ', '이래', '로', '딱', '맞', '는', '여자', '를', '달', 'ㄴ', '한번', '도', '만나', 'ㄴ', '적', '이', '없', '을', '수', '도', '있', '을', '것', '이', '라는', '사실', '을', '쉽', '게', '눈치채', 'ㄹ', '수', '있', '을', '것', '이', '야', '.']


In [8]:
targ[-1].lower()

'doubtless there exists in this world precisely the right woman for any given man to marry and vice versa; but when you consider that a human being has the opportunity of being acquainted with only a few hundred people, and out of the few hundred that there are but a dozen or less whom he knows intimately, and out of the dozen, one or two friends at most, it will easily be seen, when we remember the number of millions who inhabit this world, that probably, since the earth was created, the right man has never yet met the right woman.'

In [9]:
!pip install nltk



In [10]:
import nltk
nltk.download('punkt_tab')
print(nltk.word_tokenize(targ[-1].lower()))

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


['doubtless', 'there', 'exists', 'in', 'this', 'world', 'precisely', 'the', 'right', 'woman', 'for', 'any', 'given', 'man', 'to', 'marry', 'and', 'vice', 'versa', ';', 'but', 'when', 'you', 'consider', 'that', 'a', 'human', 'being', 'has', 'the', 'opportunity', 'of', 'being', 'acquainted', 'with', 'only', 'a', 'few', 'hundred', 'people', ',', 'and', 'out', 'of', 'the', 'few', 'hundred', 'that', 'there', 'are', 'but', 'a', 'dozen', 'or', 'less', 'whom', 'he', 'knows', 'intimately', ',', 'and', 'out', 'of', 'the', 'dozen', ',', 'one', 'or', 'two', 'friends', 'at', 'most', ',', 'it', 'will', 'easily', 'be', 'seen', ',', 'when', 'we', 'remember', 'the', 'number', 'of', 'millions', 'who', 'inhabit', 'this', 'world', ',', 'that', 'probably', ',', 'since', 'the', 'earth', 'was', 'created', ',', 'the', 'right', 'man', 'has', 'never', 'yet', 'met', 'the', 'right', 'woman', '.']


In [11]:
x_tokens = [ morph.morphs(x) for x in inp ]
y_tokens = [ nltk.word_tokenize(x.lower()) for x in targ ]

In [12]:
print(len(x_tokens), len(y_tokens))

3729 3729


### Encoding text to numeric sequences

In [13]:
def text_encoding(lines):
    vocab, index = {}, 3  # start indexing from 3
    vocab['<pad>'] = 0  # add a padding token
    vocab['<bos>'] = 1  # begin of sentence
    vocab['<eos>'] = 2  # end of sentence

    maxlen = -1
    for sentence in lines:
        for token in sentence:
            if token not in vocab:
                vocab[token] = index
                index += 1

        if maxlen < len(sentence):
            maxlen = len(sentence)

    arr = np.zeros((len(lines), maxlen+2), dtype='int32')

    for i, sentence in enumerate(lines):
        for j, token in enumerate(sentence):
            arr[i, j+1] = vocab[token]
        arr[i, 0] = vocab['<bos>']
        arr[i, len(sentence)+1] = vocab['<eos>']

    return arr, vocab

In [14]:
x_train, x_vocab = text_encoding(x_tokens)
y_train, y_vocab = text_encoding(y_tokens)

In [15]:
inverse_x_vocab = {index: token for token, index in x_vocab.items()}
inverse_y_vocab = {index: token for token, index in y_vocab.items()}

In [16]:
def text_decoding(line, invvocab):
    return [ invvocab[x] for x in line]

In [17]:
print(text_decoding(x_train[-1], inverse_x_vocab))

['<bos>', '의심', '의', '여', '이', '지', '없이', '세상', '에는', '어떤', '남자', '이', '든', '정확히', '딱', '알맞', '는', '여자', '와', '결혼', '하', '거나', '그', '반대', '의', '상황', '이', '존재', '하', '지', '.', '그런데', '인간', '이', '수백', '명', '의', '사람', '만', '알', '고', '지내', '는', '사이', '가', '되', 'ㄹ', '기회', '를', '갖', '는다', '고', '생각', '하', '어', '보', '면', ',', '또', '그', '수백', '명', '중', '열', '여', '명', '쯤', '이하', '만', '잘', '알', 'ㄹ', '수', '있', '고', ',', '그리고', '나', '서', '그', '열', '여', '명', '중', '에', '한두', '명', '만', '친구', '가', '되', 'ㄹ', '수', '있', '다면', ',', '그리고', '또', '만', '이', 'ㄹ', '우리', '가', '이', '세상', '에', '살', '고', '있', '는', '수백만', '명', '의', '사람들', '만', '기억', '하고', '있', '다면', ',', '딱', '맞', '는', '남자', '는', '지구', '가', '생기', '어', '나', 'ㄴ', '이래', '로', '딱', '맞', '는', '여자', '를', '달', 'ㄴ', '한번', '도', '만나', 'ㄴ', '적', '이', '없', '을', '수', '도', '있', '을', '것', '이', '라는', '사실', '을', '쉽', '게', '눈치채', 'ㄹ', '수', '있', '을', '것', '이', '야', '.', '<eos>']


In [18]:
print(text_decoding(y_train[-1], inverse_y_vocab))

['<bos>', 'doubtless', 'there', 'exists', 'in', 'this', 'world', 'precisely', 'the', 'right', 'woman', 'for', 'any', 'given', 'man', 'to', 'marry', 'and', 'vice', 'versa', ';', 'but', 'when', 'you', 'consider', 'that', 'a', 'human', 'being', 'has', 'the', 'opportunity', 'of', 'being', 'acquainted', 'with', 'only', 'a', 'few', 'hundred', 'people', ',', 'and', 'out', 'of', 'the', 'few', 'hundred', 'that', 'there', 'are', 'but', 'a', 'dozen', 'or', 'less', 'whom', 'he', 'knows', 'intimately', ',', 'and', 'out', 'of', 'the', 'dozen', ',', 'one', 'or', 'two', 'friends', 'at', 'most', ',', 'it', 'will', 'easily', 'be', 'seen', ',', 'when', 'we', 'remember', 'the', 'number', 'of', 'millions', 'who', 'inhabit', 'this', 'world', ',', 'that', 'probably', ',', 'since', 'the', 'earth', 'was', 'created', ',', 'the', 'right', 'man', 'has', 'never', 'yet', 'met', 'the', 'right', 'woman', '.', '<eos>']


In [19]:
print(len(x_vocab), len(y_vocab))

2896 2527


In [20]:
x_train.shape

(3729, 169)

### Save preprocessed dataset

In [21]:
np.savez('kor-eng', x_train=x_train, y_train=y_train)

In [22]:
import json
with open("kor-eng-krvocab.json", "w") as f:
    json.dump(x_vocab, f)
with open("kor-eng-envocab.json", "w") as f:
    json.dump(y_vocab, f)

### Load preprocessed dataset

In [23]:
import json

npzfile = np.load('kor-eng.npz')
x_train = npzfile['x_train']
y_train = npzfile['y_train']

with open("kor-eng-krvocab.json", "rb") as f:
    x_vocab = json.load(f)
with open("kor-eng-envocab.json", "rb") as f:
    y_vocab = json.load(f)

inverse_x_vocab = {index: token for token, index in x_vocab.items()}
inverse_y_vocab = {index: token for token, index in y_vocab.items()}

### Building models

In [24]:
BUFFER_SIZE = len(x_train)
BATCH_SIZE = 16
embedding_dim = 1024
latent_dim = 1024
x_vocab_size = len(x_vocab)
y_vocab_size = len(y_vocab)

In [25]:
# PyTorch Dataset
class TranslationDataset(Dataset):
    def __init__(self, x_data, y_data):
        self.x_data = torch.LongTensor(x_data)
        self.y_data = torch.LongTensor(y_data)

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        return self.x_data[idx], self.y_data[idx]

dataset = TranslationDataset(x_train, y_train)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [26]:
for example_input_batch, example_target_batch in dataloader:
    print(example_input_batch[:5])
    print()
    print(example_target_batch[:5])
    break

tensor([[   1,  470, 2260,  110,  860,   52, 1045, 1100,  144,   11,    2,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,  

In [27]:
# Compact batch (remove trailing zeros)
def compact_batch(batch):
    max_len = (batch != 0).sum(dim=1).max().item()
    return batch[:, :max_len]

example_input_batch = compact_batch(example_input_batch)
example_target_batch = compact_batch(example_target_batch)

In [28]:
print(example_input_batch[:5])

tensor([[   1,  470, 2260,  110,  860,   52, 1045, 1100,  144,   11,    2,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0],
        [   1, 1110, 1111,  564,   48,  201,  323,   48, 1264, 2394,   48,   40,
            5,    2,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0],
        [   1,   60,   48,  108,   27,    5,    2,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0],
        [   1,   45,  359,  132,  593,   46,  779,   52,  104,    8,  112,   46,
          325,  274,    5,    2,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0],
        [   1,   58,   59,  110,  157,   27,    5,    2,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    

In [29]:
example_input_batch.shape

torch.Size([16, 29])

In [30]:
max_len = (example_input_batch != 0).sum(dim=1).cpu()
example_input_batch_packed = nn.utils.rnn.pack_padded_sequence(
    example_input_batch, max_len, batch_first=True, enforce_sorted=False
)

In [31]:
example_input_batch_packed

PackedSequence(data=tensor([   1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,  783, 2850,   45, 1110,   25,  271,   58,   60,
         470, 1688,  162,  643,   58,   60,   58,  466,   72, 2851,  359, 1111,
          46, 2119,   59,  910, 2260,  136,   21,  976,   59,   48,   46,   48,
          42, 2852,  132,  564,  755, 1467,  110, 1044,  110,   45,  274,   46,
         110,  108,  486,  137,   38, 2853,  593,   48,   61, 2292,  138,   46,
         860,   46,  134,  126,  157,   27,   94,    5,   48,  218,   46,  201,
        1814,  960,  139, 2582,   52, 1689,   21,   48,   27,    5,    5,    2,
        1873, 2854,  779,  323,  478,  359,  806,  576, 1045,   48,  633,   40,
           5,    2,    2,   21,  422,   52,   48, 1616,  946,  118, 1154, 1100,
          99,  144,   11,    2,   52, 1536,  104, 1264,  122,   48,   52,  265,
         144,   97,    5,    2,  797, 2855,    8, 2394,  790,  135,   56,  135,
          11,    5, 

In [32]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=3):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True)

    def forward(self, x):
        # x: (batch, seq_len)
        # Create mask for padding
        lengths = (x != 0).sum(dim=1).cpu()

        embedded = self.embedding(x)  # (batch, seq_len, embedding_dim)

        # Pack sequence for efficient computation
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )

        packed_out, hidden = self.gru(packed)

        # Unpack sequence
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # hidden: (num_layers, batch, hidden_dim)
        # Return each layer's hidden state separately for compatibility
        s1 = hidden[0]  # (batch, hidden_dim)
        s2 = hidden[1]  # (batch, hidden_dim)
        s3 = hidden[2]  # (batch, hidden_dim)

        return output, s1, s2, s3

In [33]:
encoder = Encoder(x_vocab_size, embedding_dim, latent_dim).to(device)
print(encoder)

Encoder(
  (embedding): Embedding(2896, 1024, padding_idx=0)
  (gru): GRU(1024, 1024, num_layers=3, batch_first=True)
)


In [34]:
last_output, last_state1, last_state2, last_state3 = \
    encoder(example_input_batch.to(device))
print(last_output.shape, last_state1.shape)

torch.Size([16, 29, 1024]) torch.Size([16, 1024])


In [35]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=3):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, s1, s2, s3):
        # x: (batch, seq_len)
        # s1, s2, s3: (batch, hidden_dim)

        # Get lengths for packing
        lengths = (x != 0).sum(dim=1).cpu()
        lengths = lengths.clamp(min=1)  # Ensure minimum length of 1

        embedded = self.embedding(x)  # (batch, seq_len, embedding_dim)

        # Pack sequence
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )

        # Stack hidden states from all layers: (num_layers, batch, hidden)
        hidden = torch.stack([s1, s2, s3], dim=0)

        packed_out, out_hidden = self.gru(packed, hidden)

        # Unpack sequence
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        logits = self.fc(output)  # (batch, seq_len, vocab_size)

        # Return each layer's hidden state separately for compatibility
        out_s1 = out_hidden[0]  # (batch, hidden_dim)
        out_s2 = out_hidden[1]  # (batch, hidden_dim)
        out_s3 = out_hidden[2]  # (batch, hidden_dim)

        return logits, out_s1, out_s2, out_s3

In [36]:
decoder = Decoder(y_vocab_size, embedding_dim, latent_dim).to(device)
logits, s1, s2, s3 = decoder(example_target_batch.to(device), \
                             last_state1, last_state2, last_state3)
print(decoder)

Decoder(
  (embedding): Embedding(2527, 1024, padding_idx=0)
  (gru): GRU(1024, 1024, num_layers=3, batch_first=True)
  (fc): Linear(in_features=1024, out_features=2527, bias=True)
)


In [37]:
print(logits.shape, s1.shape, s2.shape, s3.shape)

torch.Size([16, 18, 2527]) torch.Size([16, 1024]) torch.Size([16, 1024]) torch.Size([16, 1024])


### Loss function

In [38]:
def batch_loss(y_true, y_pred):
    # y_true: (batch, seq_len)
    # y_pred: (batch, seq_len, vocab_size)
    loss_fn = nn.CrossEntropyLoss(reduction='none')

    # Reshape for CrossEntropyLoss: (batch * seq_len, vocab_size)
    batch_size, seq_len, vocab_size = y_pred.shape
    y_pred_flat = y_pred.reshape(-1, vocab_size)
    y_true_flat = y_true.reshape(-1)

    loss = loss_fn(y_pred_flat, y_true_flat)
    loss = loss.reshape(batch_size, seq_len)

    # Mask padding tokens
    mask = (y_true != 0).float()
    loss = loss * mask

    return loss.sum() / mask.sum()

In [39]:
batch_loss(example_target_batch[:, 1:].to(device), logits[:, :-1, :])

tensor(7.8361, grad_fn=<DivBackward0>)

### Training

In [40]:
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)

In [41]:
def predict(x_batch, y_batch, training=True):
    encoder.train(training)
    decoder.train(training)

    _, s1, s2, s3 = encoder(x_batch)
    logits, _, _, _ = decoder(y_batch, s1, s2, s3)
    return logits

In [42]:
def train_step(x_batch, y_batch):
    # Compact batch tensors
    x_batch = compact_batch(x_batch).to(device)
    y_batch = compact_batch(y_batch).to(device)

    optimizer.zero_grad()

    # Encoder & decoder
    logits = predict(x_batch, y_batch, training=True)

    # Loss: compare y_batch[:, 1:] with logits[:, :-1, :]
    loss = batch_loss(y_batch[:, 1:], logits[:, :-1, :])

    loss.backward()
    optimizer.step()

    return loss.item()

In [None]:
for epoch in range(50):
    start = time.time()

    loss_sum = 0
    for x_batch, y_batch in dataloader:
        loss = train_step(x_batch, y_batch)
        loss_sum += loss

    print('Time for epoch {} is {:.2f} sec: training loss = {:.6f}'.format(
        epoch + 1, time.time() - start, loss))

Time for epoch 1 is 476.55 sec: training loss = 4.098871
Time for epoch 2 is 468.05 sec: training loss = 7.214556
Time for epoch 3 is 464.21 sec: training loss = 3.637597
Time for epoch 4 is 466.61 sec: training loss = 2.035397
Time for epoch 5 is 462.48 sec: training loss = 1.346120
Time for epoch 6 is 458.19 sec: training loss = 3.879420


In [None]:
# Save model weights
torch.save(encoder.state_dict(), 'nmt-wo-attention.encoder.pt')
torch.save(decoder.state_dict(), 'nmt-wo-attention.decoder.pt')

### Test translation

In [None]:
def translate(src, max_steps=100):
    encoder.eval()
    decoder.eval()

    # Tokenization
    src_tokens = np.array([x_vocab['<bos>']] + [x_vocab[x] for x in morph.morphs(src)] + [x_vocab['<eos>']])

    print([inverse_x_vocab[x] for x in src_tokens])

    # Add the batch axis
    x_test = torch.LongTensor(src_tokens).unsqueeze(0).to(device)

    with torch.no_grad():
        # Compute encoder and get hidden states
        _, s1, s2, s3 = encoder(x_test)

        # y_test: add the batch axis
        y_test = torch.LongTensor([[y_vocab['<bos>']]]).to(device)
        output_seq = []

        for _ in range(max_steps):
            logits, s1, s2, s3 = decoder(y_test, s1, s2, s3)

            # Greedily use the token with the highest logit
            y_test = logits.argmax(dim=2)
            pred = y_test.squeeze(0).item()

            # If prediction is eos, output sequence is complete
            if pred == y_vocab['<eos>']:
                break
            output_seq.append(pred)

    return ' '.join([inverse_y_vocab[x] for x in output_seq])

In [None]:
translate('잘 안된다.')

In [None]:
def beam_translate(src, max_steps=100, k=16):
    encoder.eval()
    decoder.eval()

    # Tokenization
    src_tokens = np.array([x_vocab['<bos>']] + [x_vocab[x] for x in morph.morphs(src)] + [x_vocab['<eos>']])
    print(morph.morphs(src))

    # Add the batch axis
    x_test = torch.LongTensor(src_tokens).unsqueeze(0).to(device)

    with torch.no_grad():
        # Compute encoder and get hidden states
        _, s1, s2, s3 = encoder(x_test)

        # Init candidates: (score, last_token, s1, s2, s3, output_seq, eos)
        last_token = torch.LongTensor([[y_vocab['<bos>']]]).to(device)
        candidates = [(0., last_token, s1, s2, s3, [y_vocab['<bos>']], False)]

        for _ in range(max_steps):
            new_candidates = []

            for score, token, c_s1, c_s2, c_s3, output_seq, eos in candidates:
                # If the candidate already ends
                if eos:
                    new_candidates.append((score, token, c_s1, c_s2, c_s3, output_seq, eos))
                    continue

                # Compute the prob. of following tokens
                logits, new_s1, new_s2, new_s3 = decoder(token, c_s1, c_s2, c_s3)
                # shape of logits: (1, 1, vocab_size)
                probs = torch.log_softmax(logits, dim=2)

                # Use the token with the top-k logits
                values, indices = torch.topk(probs.squeeze(), k=k)

                for prob, idx in zip(values, indices):
                    idx_val = idx.item()
                    # If prediction is eos, output sequence is complete
                    is_eos = (idx_val == y_vocab['<eos>'])

                    new_token = torch.LongTensor([[idx_val]]).to(device)
                    new_candidates.append(
                        (score + prob.item(), new_token, new_s1, new_s2, new_s3,
                         output_seq + [idx_val], is_eos)
                    )

            candidates = sorted(new_candidates, key=lambda t: -t[0])[:k]

    return [(candidates[i][0], ' '.join([inverse_y_vocab[x] for x in candidates[i][5]])) for i in range(k)]

In [None]:
beam_translate('잘 안된다.')