In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

### Reading train dataset & tokenizing

In [None]:
!pip3 install konlpy

In [None]:
# Hannanum analyzer from KoNLPy; `morph.morphs` is used below to tokenize the `inp` sentences.
from konlpy.tag import Hannanum 
morph = Hannanum()

In [None]:
# kor.txt holds one tab-separated sentence pair per line: the first column is
# tokenized with NLTK further down (target side), the second with Hannanum
# (source side). The encoding is given explicitly so reading the Korean text
# does not depend on the platform's default locale.
with open('kor.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()

# Drop the trailing newline (it would otherwise stay attached to every source
# sentence), skip blank lines, and keep only the first two columns so an extra
# attribution column does not break the two-way unpacking.
pairs = [line.rstrip('\n').split('\t')[:2] for line in lines if line.strip()]
targ, inp = zip(*pairs)

In [None]:
# Peek at the last source-side sentence (second tab-separated column of kor.txt).
inp[-1]

In [None]:
# Tokenize that same sentence with the Hannanum analyzer.
print(morph.morphs(inp[-1]))

In [None]:
# Last target-side sentence (first column of kor.txt), lower-cased.
targ[-1].lower()

In [None]:
import nltk

# word_tokenize needs the Punkt tokenizer models, which are not bundled with
# NLTK; fetch them on first use so a fresh kernel can run this cell.
try:
    nltk.word_tokenize('probe')
except LookupError:
    for resource in ('punkt', 'punkt_tab'):  # resource name differs across NLTK releases
        nltk.download(resource, quiet=True)

print(nltk.word_tokenize(targ[-1].lower()))

In [None]:
# Tokenize the whole corpus: Hannanum tokens for the source side,
# lower-cased NLTK word tokens for the target side.
x_tokens = list(map(morph.morphs, inp))
y_tokens = [nltk.word_tokenize(sentence.lower()) for sentence in targ]

In [None]:
# Sanity check: both sides should contain the same number of sentences.
print(len(x_tokens), len(y_tokens))

### Encoding text to numeric sequences

In [None]:
def text_encoding(lines):
    """Build a vocabulary and a zero-padded id matrix from tokenized sentences.

    Args:
        lines: list of sentences, each a list of string tokens.

    Returns:
        (arr, vocab) where ``vocab`` maps token -> id (0 = '<pad>',
        1 = '<bos>', 2 = '<eos>', regular tokens from 3 in order of first
        appearance) and ``arr`` is an int32 array of shape
        (len(lines), longest_sentence + 2) holding
        ``<bos> tokens... <eos>`` followed by padding zeros on each row.
    """
    vocab = {'<pad>': 0, '<bos>': 1, '<eos>': 2}

    # First pass: assign ids in order of first appearance.
    for sentence in lines:
        for token in sentence:
            # len(vocab) is always the next unused id because ids are dense.
            vocab.setdefault(token, len(vocab))

    longest = max((len(sentence) for sentence in lines), default=-1)

    # Two extra columns make room for <bos> and <eos>.
    arr = np.zeros((len(lines), longest + 2), dtype='int32')

    # Second pass: write the ids; untouched cells stay 0 (= <pad>).
    for row, sentence in enumerate(lines):
        arr[row, 0] = vocab['<bos>']
        for col, token in enumerate(sentence, start=1):
            arr[row, col] = vocab[token]
        arr[row, len(sentence) + 1] = vocab['<eos>']

    return arr, vocab

In [None]:
# Encode each side with its own vocabulary; the two id spaces are independent.
x_train, x_vocab = text_encoding(x_tokens)
y_train, y_vocab = text_encoding(y_tokens)

In [None]:
# Reverse lookups (id -> token) used to turn id sequences back into text.
inverse_x_vocab = {index: token for token, index in x_vocab.items()}
inverse_y_vocab = {index: token for token, index in y_vocab.items()}

In [None]:
def text_decoding(line, invvocab):
    """Map a sequence of token ids back to their tokens.

    Args:
        line: iterable of token ids.
        invvocab: mapping from id to token.

    Returns:
        List of tokens in the same order as ``line`` (padding ids included).
    """
    tokens = []
    for token_id in line:
        tokens.append(invvocab[token_id])
    return tokens

In [None]:
# Round-trip check on the last source sentence (padding ids included).
print(text_decoding(x_train[-1], inverse_x_vocab))

In [None]:
# Round-trip check on the last target sentence (padding ids included).
print(text_decoding(y_train[-1], inverse_y_vocab))

In [None]:
# Vocabulary sizes (source, target), each including the three special tokens.
print(len(x_vocab), len(y_vocab))

In [None]:
# (number of sentences, longest source sentence + 2 for <bos>/<eos>)
x_train.shape

### Save preprocessed dataset

In [None]:
# np.savez appends the .npz extension, producing kor-eng.npz (the file loaded below).
np.savez('kor-eng', x_train=x_train, y_train=y_train)

In [None]:
import json

# Persist both token -> id vocabularies next to the encoded arrays.
for vocab_path, vocab in (
    ("kor-eng-krvocab.json", x_vocab),
    ("kor-eng-envocab.json", y_vocab),
):
    with open(vocab_path, "w") as f:
        json.dump(vocab, f)

### Load preprocessed dataset

In [None]:
import json

# np.load on an .npz archive returns an NpzFile that keeps the file open;
# the context manager closes it once both arrays have been read.
with np.load('kor-eng.npz') as npzfile:
    x_train = npzfile['x_train']
    y_train = npzfile['y_train']

# Token -> id vocabularies written by the "Save preprocessed dataset" step.
with open("kor-eng-krvocab.json", "rb") as f:
    x_vocab = json.load(f)
with open("kor-eng-envocab.json", "rb") as f:
    y_vocab = json.load(f)

# Reverse lookups (id -> token) for decoding model output.
inverse_x_vocab = {index: token for token, index in x_vocab.items()}
inverse_y_vocab = {index: token for token, index in y_vocab.items()}

### Building models

In [None]:
# Model / training hyper-parameters.
# NOTE(review): BUFFER_SIZE is not referenced anywhere else in the visible notebook.
BUFFER_SIZE = len(x_train)
BATCH_SIZE = 16
embedding_dim = 1024  # size of the token embeddings in both encoder and decoder
latent_dim = 1024  # GRU hidden size in both encoder and decoder
x_vocab_size = len(x_vocab)
y_vocab_size = len(y_vocab)

In [None]:
# PyTorch Dataset
class TranslationDataset(Dataset):
    """Row-aligned pairs of encoded source and target sentences."""

    def __init__(self, x_data, y_data):
        """Store both id matrices as int64 tensors (the dtype nn.Embedding expects)."""
        self.x_data, self.y_data = torch.LongTensor(x_data), torch.LongTensor(y_data)

    def __len__(self):
        """Number of sentence pairs (rows of the source matrix)."""
        return self.x_data.shape[0]

    def __getitem__(self, idx):
        """Return the (source, target) id tensors stored at ``idx``."""
        source, target = self.x_data[idx], self.y_data[idx]
        return source, target

# Shuffled mini-batches over the full padded corpus.
dataset = TranslationDataset(x_train, y_train)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Grab a single batch for the shape checks below and show its first five rows.
for example_input_batch, example_target_batch in dataloader:
    print(example_input_batch[:5])
    print()
    print(example_target_batch[:5])
    break  # only the first batch is needed

In [None]:
# Compact batch (remove trailing zeros)
def compact_batch(batch):
    """Trim columns that are padding for every row of the batch.

    The cut-off is the largest per-row count of non-zero ids, so this relies
    on padding (id 0) appearing only at the end of each row.

    Args:
        batch: 2-D tensor of token ids, shape (batch, seq_len).

    Returns:
        A view of ``batch`` restricted to its first ``longest`` columns.
    """
    tokens_per_row = batch.ne(0).sum(dim=1)
    longest = int(tokens_per_row.max())
    return batch[:, :longest]

# Trim the example batch to the longest sentence it actually contains.
example_input_batch = compact_batch(example_input_batch)
example_target_batch = compact_batch(example_target_batch)

In [None]:
# Shape after compaction: (batch, longest source sentence in this batch).
example_input_batch.shape

In [None]:
class Encoder(nn.Module):
    """Source-side encoder: an embedding followed by three chained GRUs.

    Token id 0 is treated as padding; padded positions are skipped by packing
    the sequences before they reach the GRUs.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.gru1 = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.gru2 = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.gru3 = nn.GRU(hidden_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        """Encode a batch of source id sequences.

        Args:
            x: (batch, seq_len) tensor of token ids, zero-padded on the right.

        Returns:
            Tuple ``(output, s1, s2, s3)``: ``output`` is the last GRU's
            per-step output, (batch, longest_len, hidden_dim), and each
            ``s*`` is the final hidden state of one GRU, (batch, hidden_dim).
        """
        # pack_padded_sequence expects the lengths on the CPU.
        valid_lengths = x.ne(0).sum(dim=1).cpu()

        packed = nn.utils.rnn.pack_padded_sequence(
            self.embedding(x),  # (batch, seq_len, embedding_dim)
            valid_lengths,
            batch_first=True,
            enforce_sorted=False,
        )

        # Feed each GRU's packed output into the next, keeping every final state.
        final_states = []
        for layer in (self.gru1, self.gru2, self.gru3):
            packed, h_n = layer(packed)
            # h_n: (1, batch, hidden_dim) -> (batch, hidden_dim)
            final_states.append(h_n.squeeze(0))

        output, _ = nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)

        return (output, *final_states)

In [None]:
# Instantiate the encoder over the source vocabulary and move it to the chosen device.
encoder = Encoder(x_vocab_size, embedding_dim, latent_dim).to(device)

In [None]:
# Input shape fed to the encoder in the next cell (same value as displayed above).
example_input_batch.shape

In [None]:
# Smoke-test the encoder on the example batch: per-step output plus the final state of each GRU.
last_output, last_state1, last_state2, last_state3 = encoder(example_input_batch.to(device))
print(last_output.shape, last_state1.shape)

In [None]:
class Decoder(nn.Module):
    """Target-side decoder: embedding, three chained GRUs and a vocabulary projection.

    Each GRU is initialised with the hidden state handed in for it (the
    encoder's final states during training, or the decoder's own previous
    states during step-by-step inference). Token id 0 is treated as padding.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.gru1 = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.gru2 = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.gru3 = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, s1, s2, s3):
        """Run the decoder over a batch of target id sequences.

        Args:
            x: (batch, seq_len) tensor of token ids, zero-padded on the right.
            s1, s2, s3: (batch, hidden_dim) initial hidden states for the
                three GRUs.

        Returns:
            Tuple ``(logits, s1, s2, s3)`` where ``logits`` has shape
            (batch, seq_len, vocab_size) - always the same number of time
            steps as ``x`` - and each state is (batch, hidden_dim).
        """
        # Get lengths for packing (pack_padded_sequence expects them on the CPU).
        lengths = (x != 0).sum(dim=1).cpu()
        # A row consisting only of padding (e.g. a predicted <pad> fed back at
        # inference time) would have length 0, which packing rejects.
        lengths = lengths.clamp(min=1)

        embedded = self.embedding(x)  # (batch, seq_len, embedding_dim)

        # Pack sequence so the GRUs skip padded positions.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )

        # Add the layer dimension to the hidden states: (batch, hidden) -> (1, batch, hidden)
        packed_out, out_s1 = self.gru1(packed, s1.unsqueeze(0))
        packed_out, out_s2 = self.gru2(packed_out, s2.unsqueeze(0))
        packed_out, out_s3 = self.gru3(packed_out, s3.unsqueeze(0))

        # Unpack to the input's full length. Without total_length the result is
        # only as long as the longest sequence in the batch, so logits would
        # stop lining up with the targets whenever the batch carries extra
        # padding columns.
        output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=x.size(1)
        )

        logits = self.fc(output)  # (batch, seq_len, vocab_size)

        return logits, out_s1.squeeze(0), out_s2.squeeze(0), out_s3.squeeze(0)

In [None]:
# Instantiate the decoder and smoke-test it, seeding each GRU with the matching encoder state.
decoder = Decoder(y_vocab_size, embedding_dim, latent_dim).to(device)
logits, s1, s2, s3 = decoder(example_target_batch.to(device), last_state1, last_state2, last_state3)

In [None]:
# Shapes of the decoder logits and of the three returned hidden states.
print(logits.shape, s1.shape, s2.shape, s3.shape)

### Loss function

In [None]:
def batch_loss(y_true, y_pred):
    """Mean cross-entropy over the non-padding target tokens of a batch.

    Args:
        y_true: (batch, seq_len) tensor of target ids; id 0 marks padding.
        y_pred: (batch, seq_len, vocab_size) tensor of unnormalised logits.

    Returns:
        Scalar tensor: summed per-token loss divided by the number of
        non-padding tokens.
    """
    n_rows, n_steps, n_classes = y_pred.shape

    # cross_entropy wants (N, classes) logits and (N,) targets, so flatten the
    # batch and time axes, then restore them on the per-token losses.
    per_token = nn.functional.cross_entropy(
        y_pred.reshape(-1, n_classes),
        y_true.reshape(-1),
        reduction='none',
    ).reshape(n_rows, n_steps)

    # 1.0 for real tokens, 0.0 for padding.
    keep = y_true.ne(0).float()

    return (per_token * keep).sum() / keep.sum()

In [None]:
# Next-token objective: the logits at step t are scored against the target token at step t+1.
batch_loss(example_target_batch[:, 1:].to(device), logits[:, :-1, :])

### Training

In [None]:
# One Adam optimizer updates the encoder and decoder parameters together.
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)

In [None]:
def predict(x_batch, y_batch, training=True):
    """Teacher-forced forward pass through the global encoder and decoder.

    Args:
        x_batch: (batch, src_len) source ids.
        y_batch: (batch, tgt_len) target ids fed to the decoder as input.
        training: switches both modules between train and eval mode.

    Returns:
        Decoder logits for every position of ``y_batch``.
    """
    for module in (encoder, decoder):
        module.train(training)

    # Only the final hidden states are handed to the decoder; the encoder's
    # per-step output is not used.
    encoder_states = encoder(x_batch)[1:]
    return decoder(y_batch, *encoder_states)[0]

In [None]:
def train_step(x_batch, y_batch):
    """Run one optimisation step on a batch and return its loss as a float.

    Args:
        x_batch: (batch, src_len) padded source ids.
        y_batch: (batch, tgt_len) padded target ids.
    """
    # Strip all-padding columns before moving the batch to the device.
    source = compact_batch(x_batch).to(device)
    target = compact_batch(y_batch).to(device)

    optimizer.zero_grad()

    logits = predict(source, target, training=True)

    # Position t of the logits predicts token t+1 of the target, so drop the
    # last logit step and the leading <bos> of the target.
    step_loss = batch_loss(target[:, 1:], logits[:, :-1, :])

    step_loss.backward()
    optimizer.step()

    return step_loss.item()

In [None]:
# Train for 50 epochs, reporting the wall-clock time and the mean batch loss
# of each epoch.
for epoch in range(50):
    start = time.time()

    loss_sum = 0
    for x_batch, y_batch in dataloader:
        loss_sum += train_step(x_batch, y_batch)

    # Report the epoch average rather than whatever the last batch happened to
    # score; max(..., 1) guards against an empty dataloader.
    mean_loss = loss_sum / max(len(dataloader), 1)

    print('Time for epoch {} is {:.2f} sec: training loss = {:.6f}'.format(
        epoch + 1, time.time() - start, mean_loss))

In [None]:
# Save model weights (state_dicts only; the Encoder/Decoder classes are needed to reload them)
torch.save(encoder.state_dict(), 'nmt-wo-attention.encoder.pt')
torch.save(decoder.state_dict(), 'nmt-wo-attention.decoder.pt')

### Test translation

In [None]:
def translate(src, max_steps=100):
    """Greedily translate one source sentence with the global encoder/decoder.

    Args:
        src: raw source sentence; it is tokenized with ``morph.morphs``.
        max_steps: upper bound on the number of generated tokens.

    Returns:
        The generated target tokens joined by spaces (without <bos>/<eos>).
    """
    encoder.eval()
    decoder.eval()

    # Tokenization. The vocabulary has no unknown-token entry, so tokens never
    # seen in training are dropped (with a notice) instead of raising KeyError.
    morphemes = morph.morphs(src)
    unknown = [m for m in morphemes if m not in x_vocab]
    if unknown:
        print(f'Ignoring out-of-vocabulary tokens: {unknown}')
    src_tokens = np.array(
        [x_vocab['<bos>']]
        + [x_vocab[m] for m in morphemes if m in x_vocab]
        + [x_vocab['<eos>']]
    )

    print([inverse_x_vocab[x] for x in src_tokens])

    # Add the batch axis
    x_test = torch.LongTensor(src_tokens).unsqueeze(0).to(device)

    eos_id = y_vocab['<eos>']

    with torch.no_grad():
        # Compute encoder and get hidden states
        _, s1, s2, s3 = encoder(x_test)

        # y_test: (1, 1) tensor holding the token fed to the decoder next
        y_test = torch.LongTensor([[y_vocab['<bos>']]]).to(device)
        output_seq = []

        for _ in range(max_steps):
            # The returned states are carried over to the next step.
            logits, s1, s2, s3 = decoder(y_test, s1, s2, s3)

            # Greedily use the token with the highest logit
            y_test = logits.argmax(dim=2)
            pred = y_test.squeeze(0).item()

            # If prediction is eos, output sequence is complete
            if pred == eos_id:
                break
            output_seq.append(pred)

    return ' '.join([inverse_y_vocab[x] for x in output_seq])

In [None]:
# Try greedy decoding on a short sample sentence.
translate('잘 안된다.')

In [None]:
def beam_translate(src, max_steps=100, k=16):
    """Translate one source sentence with beam search.

    Args:
        src: raw source sentence; it is tokenized with ``morph.morphs``.
        max_steps: upper bound on the number of decoding steps.
        k: beam width, i.e. number of hypotheses kept after every step.

    Returns:
        List of ``(score, text)`` pairs, best first, where ``score`` is the
        summed log-probability of the hypothesis and ``text`` its tokens
        (including <bos> and, when reached, <eos>) joined by spaces. The list
        holds ``k`` entries whenever that many hypotheses exist.
    """
    encoder.eval()
    decoder.eval()

    # Tokenization (done once and reused for the echo below). The vocabulary
    # has no unknown-token entry, so tokens never seen in training are dropped
    # (with a notice) instead of raising KeyError.
    morphemes = morph.morphs(src)
    print(morphemes)
    unknown = [m for m in morphemes if m not in x_vocab]
    if unknown:
        print(f'Ignoring out-of-vocabulary tokens: {unknown}')
    src_tokens = np.array(
        [x_vocab['<bos>']]
        + [x_vocab[m] for m in morphemes if m in x_vocab]
        + [x_vocab['<eos>']]
    )

    # Add the batch axis
    x_test = torch.LongTensor(src_tokens).unsqueeze(0).to(device)

    eos_id = y_vocab['<eos>']

    with torch.no_grad():
        # Compute encoder and get hidden states
        _, s1, s2, s3 = encoder(x_test)

        # Init candidates: (score, last_token, s1, s2, s3, output_seq, eos)
        last_token = torch.LongTensor([[y_vocab['<bos>']]]).to(device)
        candidates = [(0., last_token, s1, s2, s3, [y_vocab['<bos>']], False)]

        for _ in range(max_steps):
            # Once every hypothesis has emitted <eos>, further steps would
            # only copy the beam unchanged.
            if all(candidate[6] for candidate in candidates):
                break

            new_candidates = []

            for score, token, c_s1, c_s2, c_s3, output_seq, eos in candidates:
                # If the candidate already ends, carry it over untouched
                if eos:
                    new_candidates.append((score, token, c_s1, c_s2, c_s3, output_seq, eos))
                    continue

                # Compute the log-prob. of following tokens
                logits, new_s1, new_s2, new_s3 = decoder(token, c_s1, c_s2, c_s3)
                # shape of logits: (1, 1, vocab_size) -> flat (vocab_size,)
                log_probs = torch.log_softmax(logits, dim=2).reshape(-1)

                # Expand with the top-k tokens; topk cannot return more
                # entries than the vocabulary holds.
                values, indices = torch.topk(log_probs, k=min(k, log_probs.numel()))

                for log_prob, idx_val in zip(values.tolist(), indices.tolist()):
                    # If prediction is eos, output sequence is complete
                    is_eos = (idx_val == eos_id)

                    new_token = torch.LongTensor([[idx_val]]).to(device)
                    new_candidates.append(
                        (score + log_prob, new_token, new_s1, new_s2, new_s3,
                         output_seq + [idx_val], is_eos)
                    )

            # Keep the k best hypotheses by accumulated log-probability
            candidates = sorted(new_candidates, key=lambda t: -t[0])[:k]

    # Iterate over the candidates that exist; there can be fewer than k
    # (e.g. max_steps=0 or a vocabulary smaller than k).
    return [
        (candidate[0], ' '.join([inverse_y_vocab[x] for x in candidate[5]]))
        for candidate in candidates
    ]

In [None]:
# Same sample sentence with beam search; shows every kept hypothesis with its score.
beam_translate('잘 안된다.')