In [None]:
import requests
import torch
import torch.nn.functional as F
import torchtext

import tarfile
def iwslt15(train_test):
  """Download ``<train_test>-en-vi.tgz`` from the nmt-en-vi mirror and extract it into ./iwslt15.

  Args:
    train_test: archive prefix on the mirror, e.g. "train" or "test-2013".

  Raises:
    requests.HTTPError: if the server answers with a 4xx/5xx status.
  """
  url = "https://github.com/stefan-it/nmt-en-vi/raw/master/data/"
  # The timeout bounds connect/read stalls so the cell cannot hang forever.
  r = requests.get(url + train_test + "-en-vi.tgz", timeout=60)
  # Fail loudly here rather than saving an error page and then hitting a
  # confusing tarfile.ReadError.
  r.raise_for_status()
  filename = train_test + "-en-vi.gz"
  with open(filename, 'wb') as f:
    f.write(r.content)
  # Extract only after the `with` above has closed (and flushed) the file;
  # reading it while it is still open for writing can see a truncated archive.
  # NOTE(review): extractall trusts member paths inside the archive; on Python
  # versions that support it, consider passing filter="data".
  with tarfile.open(filename, 'r:gz') as tar:
    tar.extractall("iwslt15")

iwslt15("train")
iwslt15("test-2013")

def _read_tokenized(path):
  """Return the lines of a text file, each split on whitespace into a token list."""
  # Explicit encoding: the .vi files contain Vietnamese, which the platform's
  # default codec (e.g. cp1252 on Windows) may not decode.
  # NOTE: assumes the corpus files are UTF-8 encoded.
  # `with` closes the file even if reading raises.
  with open(path, encoding="utf-8") as f:
    return [line.split() for line in f]

train_en = _read_tokenized("iwslt15/train.en")
train_vi = _read_tokenized("iwslt15/train.vi")
test_en = _read_tokenized("iwslt15/tst2013.en")
test_vi = _read_tokenized("iwslt15/tst2013.vi")

# Eyeball the first ten aligned sentence pairs (English, then Vietnamese),
# then report the corpus sizes.
for idx in range(10):
  for side in (train_en, train_vi):
    print(side[idx])
print("# of line", len(train_en), len(train_vi), len(test_en), len(test_vi))

# Training / checkpoint configuration.
MODELNAME = "iwslt15-en-vi-lstm_dropout.model"  # state_dict path: saved by train(), loaded by test()
EPOCH = 25        # passes over the training batches
BATCHSIZE = 64    # sentence pairs per batch
LR = 1e-3         # Adam learning rate
DEVICE = "cpu"
if torch.cuda.is_available():
  DEVICE = "cuda"

def make_vocab(train_data, min_freq):
  """Build a vocabulary from tokenized training sentences.

  Args:
    train_data: iterable of token lists.
    min_freq: minimum count for a token to receive its own id.

  Returns:
    (vocablist, vocabidx). vocablist is a list of (token, count) pairs indexed
    by id. It starts with the specials '<unk>', '<pad>', '<cls>', '<eos>'
    (ids 0-3, count 0), followed by the kept tokens in first-seen order.
    vocabidx maps token -> id.
  """
  counts = {}
  for sentence in train_data:
    for tok in sentence:
      counts[tok] = counts.get(tok, 0) + 1

  specials = ['<unk>', '<pad>', '<cls>', '<eos>']
  vocablist = [(name, 0) for name in specials]
  vocabidx = {}
  for tok, n in counts.items():
    if n >= min_freq:
      vocabidx[tok] = len(vocablist)
      vocablist.append((tok, n))
  # Assigned last so the special ids always win, even if a special string
  # also occurs as an ordinary token in the data.
  for special_id, name in enumerate(specials):
    vocabidx[name] = special_id
  return vocablist, vocabidx

# Keep tokens seen at least 3 times in training; everything rarer is later
# mapped to <unk> by preprocess().
vocablist_en, vocabidx_en = make_vocab(train_en, 3)
print("vocab size en:", len(vocablist_en))
vocablist_vi, vocabidx_vi = make_vocab(train_vi, 3)
print("vocab size vi:", len(vocablist_vi))

def preprocess(data, vocabidx):
  """Frame each sentence with <cls>/<eos> and replace unknown tokens by <unk>.

  Args:
    data: iterable of token lists.
    vocabidx: mapping whose keys are the known tokens.

  Returns:
    A new list of token lists of the form ['<cls>', t1, ..., tn, '<eos>'],
    where any token not in ``vocabidx`` is replaced by '<unk>'.
  """
  def known_or_unk(token):
    return token if token in vocabidx else '<unk>'

  return [['<cls>'] + [known_or_unk(tok) for tok in sentence] + ['<eos>']
          for sentence in data]

# Wrap every sentence in <cls> ... <eos> and replace OOV tokens by <unk>.
# The Vietnamese test side is left raw: it is only used as the BLEU reference.
train_en_prep, train_vi_prep, test_en_prep = (
    preprocess(train_en, vocabidx_en),
    preprocess(train_vi, vocabidx_vi),
    preprocess(test_en, vocabidx_en),
)

for row in range(5):
  for prepared in (train_en_prep, train_vi_prep, test_en_prep):
    print(prepared[row])

# Pair sources with targets and sort by (source length, target length), so
# consecutive pairs, and hence each batch, have similar lengths.
train_data = sorted(zip(train_en_prep, train_vi_prep),
                    key=lambda pair: (len(pair[0]), len(pair[1])))
# Test triples: (prepared source tokens, raw source tokens, raw reference tokens).
test_data = list(zip(test_en_prep, test_en, test_vi))

for k in range(5):
  print(train_data[k])
for k in range(5):
  print(test_data[k])

def make_batch(data, batchsize):
  """Group (source, target) pairs into consecutive batches.

  Args:
    data: iterable of (source, target) pairs.
    batchsize: number of pairs per batch; the final batch may be smaller.

  Returns:
    List of (sources, targets) tuples, each holding parallel lists.
  """
  batches = []
  src_buf, tgt_buf = [], []
  for pair in data:
    src, tgt = pair
    src_buf.append(src)
    tgt_buf.append(tgt)
    if len(src_buf) >= batchsize:
      batches.append((src_buf, tgt_buf))
      src_buf, tgt_buf = [], []
  # Flush the leftover partial batch, if any.
  if src_buf:
    batches.append((src_buf, tgt_buf))
  return batches

# Chunk the length-sorted pairs into BATCHSIZE-sized batches (the last one
# may be smaller). Note: this rebinds train_data from pairs to batches.
train_data = make_batch(train_data, BATCHSIZE)

for batch_no in range(5):
  print(train_data[batch_no])

def padding_batch(b):
  """Right-pad, in place, every token list in ``b`` with '<pad>' to the longest length in ``b``.

  Args:
    b: list of token lists. Each inner list is mutated in place, so any other
       reference to the same list also sees the padding. An empty ``b`` is a
       no-op (previously ``max([])`` raised ValueError).
  """
  maxlen = max((len(x) for x in b), default=0)
  for tkl in b:
    # A non-positive count yields an empty list, so full-length rows are untouched.
    tkl.extend(['<pad>'] * (maxlen - len(tkl)))

def padding(bb):
  """Pad, in place, both the source side and the target side of every batch in ``bb``.

  Args:
    bb: list of (sources, targets) batches as produced by make_batch().
  """
  for pair in bb:
    src_side, tgt_side = pair
    for side in (src_side, tgt_side):
      padding_batch(side)

# Right-pad every batch in place so all sentences on one side of a batch have
# equal length, which torch.tensor() in train() needs for a rectangular tensor.
padding(train_data)

for batch_no in range(3):
  print(train_data[batch_no])

# Map every token to its vocabulary id. Tokens were already OOV-replaced by
# preprocess() and padded with '<pad>', so every lookup succeeds.
train_data = [
    (
        [[vocabidx_en[tok] for tok in sent] for sent in src_batch],
        [[vocabidx_vi[tok] for tok in sent] for sent in tgt_batch],
    )
    for src_batch, tgt_batch in train_data
]
# Only the prepared source is converted; raw source/reference stay as tokens.
test_data = [
    ([vocabidx_en[tok] for tok in src_prep], src_raw, tgt_raw)
    for src_prep, src_raw, tgt_raw in test_data
]

for k in range(3):
  print(train_data[k])
for k in range(3):
  print(test_data[k])

In [None]:
import torch.nn as nn
class LSTM(nn.Module):
    """LSTM encoder-decoder with dot-product attention for translation.

    Encoder and decoder are 2-layer ``nn.LSTM``s in the default sequence-first
    layout, so id tensors are shaped (seq_len, batch), as built by train()/test().
    At each decoder step:
      1. The top-layer decoder hidden state is dotted with every encoder output.
      2. The softmax-weighted context is concatenated with that hidden state.
      3. A tanh layer plus a linear projection produce target-vocabulary logits.

    ``emb_dim``, ``hid_dim`` and ``dropout_p`` default to the original
    hard-coded sizes (256, 516, 0.5). Changing them makes previously saved
    state_dicts incompatible.

    NOTE(review): the encoder still reads right-padded source positions (no
    sequence packing). Attention and the loss do ignore <pad>.
    """

    def __init__(self, vocablist_x, vocabidx_x, vocablist_y, vocabidx_y,
                 emb_dim=256, hid_dim=516, dropout_p=0.5):
        super(LSTM, self).__init__()
        # Padding ids are kept so attention and the loss can skip <pad>.
        self.pad_x = vocabidx_x['<pad>']
        self.pad_y = vocabidx_y['<pad>']

        self.enc_emb = nn.Embedding(len(vocablist_x), emb_dim, padding_idx=self.pad_x)
        self.dropout = nn.Dropout(dropout_p)
        self.enc_lstm = nn.LSTM(emb_dim, hid_dim, 2, dropout=dropout_p)

        self.dec_emb = nn.Embedding(len(vocablist_y), emb_dim, padding_idx=self.pad_y)
        self.dec_lstm = nn.LSTM(emb_dim, hid_dim, 2, dropout=dropout_p)
        self.attn_linear = nn.Linear(2 * hid_dim, hid_dim)
        self.attn_bias = nn.Parameter(torch.zeros(hid_dim))
        self.output_linear = nn.Linear(hid_dim, len(vocablist_y))
        self.output_bias = nn.Parameter(torch.zeros(len(vocablist_y)))

    def _encode(self, src):
        """Encode ``src`` (src_len, batch) and prepare what every decoder step needs.

        Returns:
            (enc_bt, pad_mask, hidden, cell). enc_bt is the batch-major encoder
            output (batch, src_len, hid_dim). pad_mask is a (batch, src_len)
            bool mask that is True at source <pad> positions to be ignored.
            hidden and cell are the encoder's final LSTM states.
        """
        enc_embedded = self.dropout(self.enc_emb(src))
        encoder_outputs, (hidden, cell) = self.enc_lstm(enc_embedded)
        # Permute once here instead of twice per decoder step.
        enc_bt = encoder_outputs.permute(1, 0, 2)
        pad_mask = (src == self.pad_x).transpose(0, 1)
        # A fully-masked row would give softmax(all -inf) = NaN; leave such
        # (degenerate, all-<pad>) rows unmasked.
        pad_mask = pad_mask & ~pad_mask.all(dim=1, keepdim=True)
        return enc_bt, pad_mask, hidden, cell

    def _decode_step(self, dec_input, hidden, cell, enc_bt, pad_mask):
        """Run one attentive decoder step on ``dec_input`` (1, batch) and return (logits, hidden, cell)."""
        dec_embedded = self.dropout(self.dec_emb(dec_input))
        _, (hidden, cell) = self.dec_lstm(dec_embedded, (hidden, cell))
        top = hidden[-1]  # last layer's hidden state: (batch, hid_dim)
        attn_energies = torch.bmm(enc_bt, top.unsqueeze(2)).squeeze(2)  # (batch, src_len)
        # Never attend to padded source positions.
        attn_energies = attn_energies.masked_fill(pad_mask, float('-inf'))
        attn_weights = F.softmax(attn_energies, dim=1).unsqueeze(1)
        context = torch.bmm(attn_weights, enc_bt).squeeze(1)
        combined = torch.cat((context, top), dim=1)
        h_combined = torch.tanh(self.attn_linear(combined) + self.attn_bias)
        logits = self.output_linear(h_combined) + self.output_bias
        return logits, hidden, cell

    def forward(self, x):
        """Compute the teacher-forced training loss for one batch.

        Args:
            x: tuple (src, tgt) of int64 id tensors shaped (seq_len, batch).
               Shorter sequences are right-padded with <pad>.

        Returns:
            Scalar tensor: the sum over decoder steps of the mean cross-entropy
            of that step's non-<pad> targets.
        """
        src, tgt = x[0], x[1]
        enc_bt, pad_mask, hidden, cell = self._encode(src)

        loss = torch.tensor(0., dtype=torch.float32, device=src.device)
        for j in range(tgt.shape[0] - 1):
            logits, hidden, cell = self._decode_step(
                tgt[j].unsqueeze(0), hidden, cell, enc_bt, pad_mask)
            target = tgt[j + 1]
            # <pad> targets must not be learned. Sum / number of real tokens
            # equals the default 'mean' reduction, but gives 0 (not NaN) when
            # every target at this step is <pad>.
            step_loss = F.cross_entropy(logits, target, ignore_index=self.pad_y,
                                        reduction='sum')
            n_real = (target != self.pad_y).sum().clamp(min=1)
            loss = loss + step_loss / n_real
        return loss

    @torch.no_grad()
    def evaluate(self, src, vocablist_y, vocabidx_y, max_len=50):
        """Greedily decode a single source sentence.

        Call ``model.eval()`` first; otherwise dropout stays active.

        Args:
            src: int64 id tensor shaped (src_len, 1); the batch must be exactly one sentence.
            vocablist_y: list of (token, count) pairs indexed by target id.
            vocabidx_y: target token -> id mapping containing '<cls>' and '<eos>'.
            max_len: maximum number of tokens to emit.

        Returns:
            List of predicted target tokens, without the terminating <eos>.
        """
        enc_bt, pad_mask, hidden, cell = self._encode(src)
        eos_id = vocabidx_y['<eos>']
        y = torch.tensor([[vocabidx_y['<cls>']]], device=src.device)  # (1, 1)
        predictions = []
        for _ in range(max_len):
            logits, hidden, cell = self._decode_step(y, hidden, cell, enc_bt, pad_mask)
            pred_id = logits.squeeze().argmax().item()
            if pred_id == eos_id:
                break
            predictions.append(vocablist_y[pred_id][0])
            y = torch.tensor([[pred_id]], device=src.device)
        return predictions

In [None]:
def train():
  """Train a fresh LSTM on train_data for EPOCH epochs with Adam and save its state_dict to MODELNAME.

  Prints the batch loss every 100 steps and the summed loss after each epoch.
  """
  model = LSTM(vocablist_en, vocabidx_en, vocablist_vi, vocabidx_vi).to(DEVICE)
  optimizer = torch.optim.Adam(model.parameters(), lr=LR)
  for epoch in range(EPOCH):
    epoch_loss = 0
    for step, (src_ids, tgt_ids) in enumerate(train_data):
      # Batches are stored (batch, seq_len); the model expects (seq_len, batch).
      src = torch.tensor(src_ids, dtype=torch.int64).transpose(0, 1).to(DEVICE)
      tgt = torch.tensor(tgt_ids, dtype=torch.int64).transpose(0, 1).to(DEVICE)
      optimizer.zero_grad()
      batchloss = model((src, tgt))
      batchloss.backward()
      optimizer.step()
      batch_value = batchloss.item()
      epoch_loss = epoch_loss + batch_value
      if step % 100 == 0:
        print("step:", step, "batch loss:", batch_value)
    print("epoch", epoch, ": loss", epoch_loss)
  torch.save(model.state_dict(), MODELNAME)

train()

In [None]:
from torchtext.data.metrics import bleu_score

# BLEU = 0.19

In [None]:
def test():
  """Translate every tst2013 sentence with the saved model and print the corpus BLEU score.

  Loads the state_dict at MODELNAME and greedily decodes each source sentence.
  Prints INPUT / REF / MT for each sentence, then the sentence count and the
  torchtext BLEU score.
  """
  model = LSTM(vocablist_en, vocabidx_en, vocablist_vi, vocabidx_vi).to(DEVICE)
  # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
  model.load_state_dict(torch.load(MODELNAME, map_location=DEVICE))
  model.eval()  # disable dropout while decoding
  ref = []
  pred = []
  # Inference only: avoid building autograd graphs for every decoder step.
  with torch.no_grad():
    for enprep, en, vi in test_data:
      # (1, src_len) -> (src_len, 1): a batch of one, sequence-first.
      src = torch.tensor([enprep], dtype=torch.int64).transpose(0, 1).to(DEVICE)
      p = model.evaluate(src, vocablist_vi, vocabidx_vi)
      print("INPUT", en)
      print("REF", vi)
      print("MT", p)
      # bleu_score expects a list of reference lists per candidate.
      ref.append([vi])
      pred.append(p)

  bleu = bleu_score(pred, ref)
  print("total:", len(test_data))
  print("bleu:", bleu)

test()