In [None]:
# Single parallel (English, Korean) sentence pair used as the whole toy corpus.
pair = ['I want to go home', '나는 집에 가고 싶다.']
# First element is the source sentence, second is the target sentence.
src_sentence, trg_sentence = pair

In [None]:
class Lang():
  """Word-level vocabulary built incrementally from whitespace-split sentences.

  Attributes:
    word2ix: token -> integer id. Ids 0, 1 and 2 are reserved for
      '<pad>', '<sos>' and '<eos>'.
    ix2word: id -> token (inverse of word2ix).
    n_words: number of ids assigned so far (starts at 3 for the reserved tokens).
    word2count: number of times each token has been seen by fromSentence.
  """
  def __init__(self):
    self.word2ix = {'<pad>': 0,
                    '<sos>': 1,
                    '<eos>': 2,
                    }
    # Derive the inverse mapping so the two tables cannot drift apart.
    self.ix2word = {ix: word for word, ix in self.word2ix.items()}

    self.n_words = len(self.word2ix)
    self.word2count = {}

  def fromSentence(self, sentence):
    """Add every whitespace-separated token of `sentence` to the vocabulary.

    New tokens get the next free id; every occurrence (new or known)
    increments the token's entry in word2count. An empty sentence is a no-op.
    """
    for word in sentence.split():
      if word not in self.word2ix:
        self.word2ix[word] = self.n_words
        self.ix2word[self.n_words] = word
        self.n_words += 1
      # Reserved tokens are in word2ix from the start but have no count yet,
      # so use .get() instead of assuming the key exists.
      self.word2count[word] = self.word2count.get(word, 0) + 1

In [None]:
# Build the source-side vocabulary from the single source sentence.
src = Lang()
src.fromSentence(src_sentence)
# Vocabulary size, counting the three reserved tokens that Lang starts with.
src_size = src.n_words

In [None]:
import torch
import torch.nn as nn
class Encoder(nn.Module):
  """Embedding -> dropout -> LSTM encoder.

  Args:
    seq_length: number of rows of the embedding table (it is passed to
      nn.Embedding as the number of embeddings), so every token id fed to
      forward() must be smaller than this value.
    emb_dim: size of each embedding vector (the LSTM input size).
    hid_dim: LSTM hidden size.
    drop_p: dropout probability applied to the embedded tokens.
  """
  def __init__(self, seq_length, emb_dim, hid_dim, drop_p):
    super().__init__()

    self.embedding = nn.Embedding(num_embeddings=seq_length, embedding_dim=emb_dim)
    self.rnn = nn.LSTM(input_size=emb_dim, hidden_size=hid_dim)
    self.dropout = nn.Dropout(p=drop_p)

  def forward(self, src):
    """Encode a tensor of token ids.

    Returns:
      (outputs, (hidden, cell)) exactly as produced by the LSTM.
    """
    token_vectors = self.embedding(src)
    # nn.LSTM already returns the (outputs, (hidden, cell)) structure.
    return self.rnn(self.dropout(token_vectors))

In [None]:
MAX_SEQ_LENGTH = 10  # fixed tensor length, including the <sos> and <eos> slots
batch_size = 1

# Keep at most MAX_SEQ_LENGTH - 2 words so <sos> and <eos> always fit;
# slicing is a no-op for shorter sentences.
src_words = src_sentence.split()[:MAX_SEQ_LENGTH - 2]
n_words = len(src_words)  # number of words actually kept

# Build the id list first so the <eos> position never depends on a leftover
# loop variable (this also handles an empty sentence correctly).
src_ids = ([src.word2ix['<sos>']]
           + [src.word2ix[word] for word in src_words]
           + [src.word2ix['<eos>']])

# Shape [MAX_SEQ_LENGTH, batch_size]; unused trailing rows stay at the <pad> id.
src_batch = torch.full((MAX_SEQ_LENGTH, batch_size), src.word2ix['<pad>'], dtype = torch.int64)
src_batch[:len(src_ids)] = torch.tensor(src_ids, dtype = torch.int64).unsqueeze(1)

In [None]:
# Build the target-side vocabulary from the single target sentence.
trg = Lang()
trg.fromSentence(trg_sentence)
# Number of target tokens, counting the three reserved tokens that Lang starts with.
n_token = trg.n_words

In [None]:
class Decoder(nn.Module):
  """Single-step LSTM decoder: one target token in, logits over the vocabulary out.

  Args:
    seq_size: number of decoding positions; stored as `self.seq_size` for
      the model that drives this decoder.
    n_token: target vocabulary size. It sizes both the embedding table and
      the output projection.
    emb_dim: size of each embedding vector (the LSTM input size).
    hid_dim: LSTM hidden size.
    drop_p: dropout probability applied to the embedded token.
  """
  def __init__(self, seq_size, n_token, emb_dim, hid_dim, drop_p = 0.75):
    super().__init__()
    self.seq_size = seq_size
    self.n_token = n_token
    # The table is indexed by target token ids, so it needs one row per
    # vocabulary entry; sizing it by the sequence length only works while the
    # vocabulary happens to be no larger than the sequence length.
    self.embedding = nn.Embedding(n_token, emb_dim)
    self.rnn = nn.LSTM(emb_dim, hid_dim)
    self.dropout = nn.Dropout(drop_p)
    self.fc = nn.Linear(hid_dim, n_token)

  def forward(self, trg, hidden, cell):
    """Run one decoding step.

    Args:
      trg: token ids for the current step, shape [batch_size].
      hidden, cell: LSTM state carried over from the previous step.

    Returns:
      (logits, (hidden, cell)) where logits has shape [1, batch_size, n_token].
    """
    trg = trg.unsqueeze(0) # [batch_size] -> [1, batch_size]
    embedded = self.dropout(self.embedding(trg))
    output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
    logits = self.fc(output)

    return logits, (hidden, cell)


In [None]:
MAX_SEQ_LENGTH = 10  # fixed tensor length, including the <sos> and <eos> slots
batch_size = 1

# Keep at most MAX_SEQ_LENGTH - 2 words so <sos> and <eos> always fit;
# slicing is a no-op for shorter sentences.
trg_words = trg_sentence.split()[:MAX_SEQ_LENGTH - 2]
n_words = len(trg_words)  # number of words actually kept

# Build the id list first so the <eos> position never depends on a leftover
# loop variable (this also handles an empty sentence correctly).
trg_ids = ([trg.word2ix['<sos>']]
           + [trg.word2ix[word] for word in trg_words]
           + [trg.word2ix['<eos>']])

# Shape [MAX_SEQ_LENGTH, batch_size]; unused trailing rows stay at the <pad> id.
trg_batch = torch.full((MAX_SEQ_LENGTH, batch_size), trg.word2ix['<pad>'], dtype = torch.int64)
trg_batch[:len(trg_ids)] = torch.tensor(trg_ids, dtype = torch.int64).unsqueeze(1)

In [None]:
class Seq2Seq(nn.Module):
  """Encoder-decoder wrapper that decodes one target position at a time.

  The decoder must expose `seq_size` (number of decoding positions) and
  `n_token` (size of its output distribution).
  """
  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.seq_size = decoder.seq_size
    self.n_token = decoder.n_token

  def forward(self, src, trg, teacher_forcing_ratio = 0.5):
    """Encode `src`, then decode `seq_size - 1` steps.

    Args:
      src: source token ids, shape [src_len, batch_size].
      trg: target token ids, shape [seq_size, batch_size]. Row 0 must hold
        the start token; it is always the first decoder input.
      teacher_forcing_ratio: in training mode, probability (per step) of
        feeding the ground-truth token `trg[i]` as the next input instead of
        the model's own prediction. Ignored in eval mode, where decoding is
        always greedy, so rows 1.. of `trg` are not read.

    Returns:
      Logits of shape [seq_size, batch_size, n_token]. Row 0 is never
      written and stays all zeros.
    """
    # Use the encoder owned by this module, not a same-named global.
    _, (hidden, cell) = self.encoder(src)

    batch_size = trg.shape[1]
    outputs = torch.zeros((self.seq_size, batch_size, self.n_token), device = trg.device)

    decoder_input = trg[0, :]
    for i in range(1, self.seq_size):
      logit, (hidden, cell) = self.decoder(decoder_input, hidden, cell)
      step_logits = logit.squeeze(0)  # [1, batch, n_token] -> [batch, n_token]
      outputs[i] = step_logits

      # Advance the decoder input; without this every step would see trg[0].
      teacher_force = self.training and torch.rand(1).item() < teacher_forcing_ratio
      decoder_input = trg[i] if teacher_force else step_logits.argmax(-1)

    return outputs

In [None]:
# Seed before building the model so weight initialisation (and dropout) are
# reproducible across Restart & Run All.
torch.manual_seed(0)

# Encoder's first argument sizes its embedding table, so it must be the source
# vocabulary size (src_size), not the padded sequence length.
encoder = Encoder(src_size, emb_dim = 20, hid_dim = 40, drop_p = 0.75)
decoder = Decoder(MAX_SEQ_LENGTH, n_token, emb_dim = 20, hid_dim = 40)
model = Seq2Seq(encoder, decoder)
# Smoke test: one forward pass over the toy pair.
outputs = model(src_batch, trg_batch)

In [None]:
# The whole "dataset" is a single (source, target) tensor pair.
batch = [(src_batch, trg_batch)]
# Id 0 is '<pad>' in Lang.word2ix, so padded target positions are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index = 0)
learning_rate =0.001
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

In [None]:
def train(model, batch, loss_fn, optimizer):
  """Run one optimisation pass over `batch`.

  Args:
    model: callable as model(src, trg), returning logits whose last
      dimension is the number of classes.
    batch: iterable of (src, trg) pairs.
    loss_fn: loss taking (flattened logits, flattened targets).
    optimizer: optimizer updating the model's parameters.

  Returns:
    (sum of the per-pair losses, model).
  """
  model.train()
  total_loss = 0.0
  for source, target in batch:
    logits = model(source, target)
    # Position 0 is dropped from both tensors before flattening for the loss.
    flat_logits = logits[1:].view(-1, logits.size(-1))
    flat_target = target[1:].view(-1)

    optimizer.zero_grad()
    step_loss = loss_fn(flat_logits, flat_target)
    total_loss += step_loss.item()
    step_loss.backward()
    optimizer.step()

  return total_loss, model

In [None]:
def evaluate(model, batch, loss_fn):
  """Compute the summed loss over `batch` without updating the model.

  Puts the model in eval mode and disables gradient tracking.

  Returns:
    (sum of the per-pair losses, model).
  """
  model.eval()
  total_loss = 0.0
  with torch.no_grad():
    for source, target in batch:
      logits = model(source, target)
      n_classes = logits.size(-1)
      # Position 0 is dropped from both tensors before flattening for the loss.
      flat_logits = logits[1:].view(-1, n_classes)
      flat_target = target[1:].view(-1)
      total_loss += loss_fn(flat_logits, flat_target).item()

  return total_loss, model

In [None]:
def run(model, batch, loss_fn, optimizer, num_epochs = 5, print_every = 100):
  """Train for `num_epochs` epochs and return the best model seen.

  Each epoch calls train() and then evaluate() on the same `batch`. Losses
  are printed on the first epoch and on every `print_every`-th epoch.

  Args:
    model, batch, loss_fn, optimizer: forwarded to train() / evaluate().
    num_epochs: number of train + evaluate rounds.
    print_every: logging interval in epochs.

  Returns:
    A snapshot (deep copy) of the model taken at the epoch with the lowest
    validation loss, or None if no epoch was run.
  """
  import copy  # local import so this cell does not depend on another cell's imports

  best_model = None
  min_loss = float("inf")
  for epoch in range(num_epochs):
    train_loss, model = train(model, batch, loss_fn, optimizer)
    val_loss, model = evaluate(model, batch, loss_fn)
    if (epoch+1) % print_every == 0 or epoch == 0:
      print(f'Epoch| {epoch+1}/{num_epochs}')
      print(f'train loss: {train_loss}')
      print(f'val loss: {val_loss}')
    if val_loss < min_loss:
      min_loss = val_loss
      # Snapshot the weights: keeping a plain reference would alias the live
      # model, which later epochs keep mutating.
      best_model = copy.deepcopy(model)

  return best_model

In [None]:
# Train for 500 epochs; `run` prints the losses on the first epoch and on every 100th epoch.
best_model = run(model, batch, loss_fn, optimizer, num_epochs = 500)

Epoch| 1/500
train loss: 0.0016592040192335844
val loss: 0.0015384580474346876
Epoch| 100/500
train loss: 0.001490377588197589
val loss: 0.0013969524297863245
Epoch| 200/500
train loss: 0.0014575652312487364
val loss: 0.0012718020007014275
Epoch| 300/500
train loss: 0.0012493958929553628
val loss: 0.0011633248068392277
Epoch| 400/500
train loss: 0.001154080149717629
val loss: 0.0010671226773411036
Epoch| 500/500
train loss: 0.0010473668808117509
val loss: 0.0009749621385708451


In [None]:
inference = []
best_model.eval()
with torch.no_grad():
  # The decoder starts from row 0 of the target tensor, so that row must hold
  # <sos> (as it does during training); the remaining rows are placeholders.
  decode_input = torch.zeros((MAX_SEQ_LENGTH, batch_size), dtype = torch.long)
  decode_input[0] = trg.word2ix['<sos>']

  outputs = best_model(src_batch, decode_input)  # [MAX_SEQ_LENGTH, batch_size, n_token]
  preds = outputs.argmax(-1)[:, 0]  # predicted ids for the first (only) sentence

  # Row 0 of `outputs` is never written by the model, so it is not a
  # prediction; start at row 1 and stop at the first <eos>.
  for pred in preds[1:]:
    word = trg.ix2word[pred.item()]
    if word == '<eos>':
      break
    inference.append(word)

In [None]:
# Show the decoded target-side tokens.
inference

['<pad>', '나는', '집에', '가고', '싶다.', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>']