# Seq2seq NMT with RNN



[Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473)

**NOTE:**

-  use clean bpe data
-  use a piece of triaing data during coding or low in credits

You have to implement:

- Encoder
- Attention (Bahdanau)
- training loop
- extra: BLEU model selection

Goal:

- Loss in training, validation and test





In [1]:
%%capture
!pip install torch==2.3.0
!pip install torchtext==0.18

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import random
import time

In [3]:
#if you dont have bpe data use sacremoese tokenizer
#!pip install sacremoses

#clone repo to access bpe files from week2_files and to put requirements.txt
!git clone https://github.com/fubotz/BMT_2025S
%cd BMT_2025S/week6_files

Cloning into 'BMT_2025S'...
remote: Enumerating objects: 376, done.[K
remote: Counting objects: 100% (34/34), done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 376 (delta 21), reused 8 (delta 8), pack-reused 342 (from 4)[K
Receiving objects: 100% (376/376), 90.16 MiB | 12.00 MiB/s, done.
Resolving deltas: 100% (155/155), done.
/content/BMT_2025S/week6_files


In [4]:
#which libraries are we using!!?
!pip freeze > requirements.txt

In [5]:
!cat requirements.txt

absl-py==1.4.0
accelerate==1.6.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.15
aiosignal==1.3.2
alabaster==1.0.0
albucore==0.0.24
albumentations==2.0.6
ale-py==0.11.0
altair==5.5.0
annotated-types==0.7.0
anyio==4.9.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.7.2
arviz==0.21.0
astropy==7.0.1
astropy-iers-data==0.2025.4.28.0.37.27
astunparse==1.6.3
atpublic==5.1
attrs==25.3.0
audioread==3.0.1
autograd==1.7.0
babel==2.17.0
backcall==0.2.0
backports.tarfile==1.2.0
beautifulsoup4==4.13.4
betterproto==2.0.0b6
bigframes==2.1.0
bigquery-magics==0.9.0
bleach==6.2.0
blinker==1.9.0
blis==1.3.0
blosc2==3.3.1
bokeh==3.7.2
Bottleneck==1.4.2
bqplot==0.12.44
branca==0.8.1
build==1.2.2.post1
CacheControl==0.14.2
cachetools==5.5.2
catalogue==2.0.10
certifi==2025.4.26
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.1
chex==0.1.89
clarabel==0.10.0
click==8.1.8
cloudpathlib==0.21.0
cloudpickle==3.1.1
cmake==3.31.6
cmdstanpy==1.2.5
colorcet==3.1.0
colorlover==0.3.0
colour==0.1.5
commun

In [6]:
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [7]:
import torchtext
dir(torchtext)

['_CACHE_DIR',
 '_TEXT_BUCKET',
 '_TORCHTEXT_DEPRECATION_MSG',
 '_WARN',
 '__all__',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '__version__',
 '_extension',
 '_get_torch_home',
 '_internal',
 '_torchtext',
 'git_version',
 'os',
 'version']

In [8]:
data_path = "../week2_files/Basic-MT_week2_files"

files = {
    "train_en": "train.en-de.bpe.en",
    "train_de": "train.en-de.bpe.de",
    "dev_en":   "dev.en-de.bpe.en",
    "dev_de":   "dev.en-de.bpe.de",
    "test_en":  "test.en-de.bpe.en",
    "test_de":  "test.en-de.bpe.de",
}

data = {}

for key, filename in files.items():
    with open(f"{data_path}/{filename}", encoding="utf-8") as f:
        data[key] = f.read().splitlines()

In [9]:
print(data["train_en"][:3])     #first 3 lines of English training data
print(data["dev_de"][:5])       #first 5 lines of German dev data

['A recent analysis by Ap@@ al@@ de@@ tt@@ i et al. (201@@ 1) suggest@@ s that G@@ ong@@ x@@ ian@@ os@@ aur@@ us was more bas@@ al than V@@ ul@@ can@@ od@@ on, T@@ az@@ ou@@ d@@ as@@ aur@@ us and Is@@ an@@ os@@ aur@@ us, but more der@@ i@@ ved than the early s@@ aur@@ op@@ o@@ ds An@@ t@@ et@@ on@@ it@@ r@@ us, L@@ ess@@ em@@ s@@ aur@@ us, B@@ lik@@ an@@ as@@ aur@@ us, Cam@@ el@@ o@@ ti@@ a and Mel@@ an@@ or@@ os@@ aur@@ us.', 'R@@ ei@@ ch@@ h@@ art also carried out execu@@ tions in C@@ olog@@ ne, Fran@@ k@@ fur@@ t-@@ Pre@@ ung@@ es@@ hei@@ m, Berlin@@ -@@ Pl@@ ö@@ tz@@ en@@ se@@ e, B@@ ran@@ den@@ burg@@ -@@ G@@ ör@@ den and B@@ res@@ lau@@ , where central execu@@ tion sit@@ es had also been construc@@ ted.', 'U@@ ph@@ old the right of all, without disc@@ ri@@ min@@ ation, to a natural and social en@@ vir@@ on@@ ment suppor@@ tive of human di@@ gn@@ ity, bo@@ di@@ ly health and spir@@ itu@@ al w@@ ell@@ -@@ b@@ ein@@ g, with special attention to the rights of in@@ di@@ gen@@ ous pe@@

In [10]:
import torchtext; torchtext.disable_torchtext_deprecation_warning()
import torch
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab
#from torchtext.utils import download_from_url, extract_archive
import io

#0=en 1=de
#soure and target data
#NOTE: USE clean bpe data!

train_filepaths = [f"{data_path}/train.en-de.bpe.en", f"{data_path}/train.en-de.bpe.de"]
val_filepaths   = [f"{data_path}/dev.en-de.bpe.en",   f"{data_path}/dev.en-de.bpe.de"]
test_filepaths  = [f"{data_path}/test.en-de.bpe.en",  f"{data_path}/test.en-de.bpe.de"]



#NB: lines in input files already consists of bpe tokens!!!
#de_tokenizer = get_tokenizer('moses', language='de')
de_tokenizer = None
#en_tokenizer = get_tokenizer('moses', language='en')
en_tokenizer = None


def build_vocab(filepath, tokenizer=None):
  counter = Counter()
  with io.open(filepath, encoding="utf8") as f:
    for string_ in f:
      #counter.update(tokenizer(string_))
      counter.update(string_.split())
  return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

#Vocab
en_vocab = build_vocab(train_filepaths[0], en_tokenizer)
de_vocab = build_vocab(train_filepaths[1], de_tokenizer)

print(dir(en_vocab))
en_vocab.set_default_index(en_vocab['<unk>'])
de_vocab.set_default_index(de_vocab['<unk>'])


def data_process(filepaths):
  raw_en_iter = iter(io.open(filepaths[0], encoding="utf8"))
  raw_de_iter = iter(io.open(filepaths[1], encoding="utf8"))
  data = []
  for (raw_en, raw_de) in zip(raw_en_iter, raw_de_iter):
    en_tensor_ = torch.tensor([en_vocab[token] for token in raw_en.split()], #en_tokenizer(raw_en)
                            dtype=torch.long)
    de_tensor_ = torch.tensor([de_vocab[token] for token in raw_de.split()], #de_tokenizer(raw_de)
                            dtype=torch.long)
    data.append((en_tensor_, de_tensor_))
  return data

#pre-process
train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
test_data = data_process(test_filepaths)

['T_destination', '__annotations__', '__call__', '__class__', '__contains__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__jit_unused_properties__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__prepare_scriptable__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '

In [19]:
print(len(train_data))
print(len(val_data))
print(len(test_data))

10000
466
467


In [16]:
#NOTE: if you are low on credits or testing only use a piece of the data e.g. 1K segments (for trying; 20k for final training)
train_data = train_data[:20000]

In [18]:
print(len(train_data))
print(len(val_data))
print(len(test_data))

10000
466
467


Define the device.

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


Create the iterators.

In [15]:
BATCH_SIZE = 8
PAD_IDX = de_vocab['<pad>']
BOS_IDX = de_vocab['<bos>']
EOS_IDX = de_vocab['<eos>']

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def generate_batch(data_batch):
    en_batch, de_batch = [], []
    for (en_item, de_item) in data_batch:
        de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    de_batch = pad_sequence(de_batch, padding_value=PAD_IDX, batch_first=True)
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX, batch_first=True)
    return en_batch, de_batch

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=generate_batch)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=False, collate_fn=generate_batch)

## Building the Seq2Seq Model

### Encoder




In [19]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True, batch_first=True)       #RNN
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)       #linear layer for projecting 'hidden' to enc_hid_dim
        self.dropout = nn.Dropout(dropout)


    def forward(self, src):
        embedded = self.dropout(self.embedding(src))        #[B, S, E]
        outputs, hidden = self.rnn(embedded)                #outputs: [B, S, enc_hid_dim*2]; hidden: [2, B, enc_hid_dim]

        h1 = hidden[-2, :, :]       #forward
        h2 = hidden[-1, :, :]       #backward
        h_cat = torch.cat((h1, h2), dim=1)                  # [B, enc_hid_dim*2]
        hidden = torch.tanh(self.fc(h_cat))                 # [B, dec_hid_dim]

        return outputs, hidden

#NB:
#B = Batch size
#S = Source sequence length (number of tokens in the input sentence)
#T = Target sequence length (number of tokens in the output sentence)

# Attention

## Luong Attention




In [20]:
class LuongAttention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()

        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)


    def forward(self, hidden, encoder_outputs):     #keys, query
        """
        Compute Luong-style (concat) attention weights.

        Args:
            hidden (Tensor): [B, dec_hid_dim] --> current decoder hidden state
            encoder_outputs (Tensor): [B, S, enc_hid_dim*2] --> encoder outputs

        Returns:
            attention_weights (Tensor): [B, S] --> normalized attention scores
        """

        batch_size = encoder_outputs.shape[0]
        src_len = encoder_outputs.shape[1]

        #repeat decoder hidden state to match encoder output length
        #hidden: [B, 1, dec_hid_dim] --> [B, S, dec_hid_dim]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

        #concatenate decoder hidden state with encoder outputs along the last dimension
        #shape: [B, S, (enc_hid_dim*2 + dec_hid_dim)]
        scores = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2)))


        #compute raw attention scores and remove last dimension
        #scores: [B, S, dec_hid_dim] --> attention: [B, S]
        attention = self.v(scores).squeeze(2)

        #normalize scores over source sequence length
        return F.softmax(attention, dim=1)

In [21]:
class BahdanauAttention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super(BahdanauAttention, self).__init__()

        self.Wa = nn.Linear(dec_hid_dim, dec_hid_dim, bias=False)
        self.Ua = nn.Linear(enc_hid_dim * 2, dec_hid_dim, bias=False)
        self.Va = nn.Linear(dec_hid_dim, 1, bias=False)


    def forward(self, hidden, encoder_outputs):     #keys, query
        """
        Compute Bahdanau-style (additive) attention weights.

        Args:
            hidden (Tensor): [B, dec_hid_dim] --> current decoder hidden state
            encoder_outputs (Tensor): [B, S, enc_hid_dim*2] --> encoder outputs

        Returns:
            weights (Tensor): [B, S] --> normalized attention scores
        """

        batch_size = encoder_outputs.shape[0]
        src_len = encoder_outputs.shape[1]

        #repeat decoder hidden state across src_len
        #hidden: [B, 1, dec_hid_dim] --> [B, S, dec_hid_dim]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

        #apply additive attention: Va(tanh(Wa(hidden) + Ua(encoder_outputs)))
        scores = self.Va(torch.tanh(self.Wa(hidden) + self.Ua(encoder_outputs)))        #[B, S, 1]
        scores = scores.squeeze(2)      #[B, S]

        #softmax over source sequence length
        weights = F.softmax(scores, dim=1)      #[B, S]

        return weights

### Decoder



In [23]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim, batch_first=True)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

        #attention
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)


    def forward(self, input, hidden, encoder_outputs):
        """
        Args:
            input: [B] --> current target token IDs
            hidden: [B, dec_hid_dim] --> previous decoder hidden state
            encoder_outputs: [B, S, enc_hid_dim*2] --> all encoder outputs


        Returns:
            prediction: [B, output_dim] --> logits for next token
            hidden: [B, dec_hid_dim] --> new decoder hidden state
        """

        input = input.unsqueeze(0)      #[1, B] add sequence length dimension

        embedded = self.dropout(self.embedding(input))      #[1, B, emb_dim]

        a = self.attention(hidden, encoder_outputs)     #[B, S] attention weights

        a = a.unsqueeze(1)      #[B, 1, S] for batch matrix multiplication

        #encoder_outputs: [B, S, enc_hid_dim*2]
        weighted = torch.bmm(a, encoder_outputs)        #[B, 1, enc_hid_dim*2] context vector

        weighted = weighted.permute(1, 0, 2)        #[1, B, enc_hid_dim*2] match RNN input

        rnn_input = torch.cat((embedded, weighted), dim=2)      #[1, B, emb_dim + enc_hid_dim*2]

        rnn_input = rnn_input.permute(1, 0, 2)      #[B, 1, emb_dim + enc_hid_dim*2]
        hidden = hidden.unsqueeze(0)                #[1, B, dec_hid_dim] match GRU expected input

        output, hidden = self.rnn(rnn_input, hidden)
        #output: [B, 1, dec_hid_dim]
        #hidden: [1, B, dec_hid_dim]

        embedded = embedded.squeeze(0)      #[B, emb_dim]
        output = output.squeeze(1)          #[B, dec_hid_dim]
        weighted = weighted.squeeze(0)      #[B, enc_hid_dim*2]

        #concatenate GRU output, context vector, and embedding; then predict next token
        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))        #[B, output_dim]

        return prediction, hidden.squeeze(0)        #return predicted logits and new hidden state

### Seq2Seq




In [24]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device


    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        """
        Args:
            src: [B, S] --> source token indices
            trg: [B, T] --> target token indices
            teacher_forcing_ratio: float --> probability of using ground truth as the next input

        Returns:
            outputs: [B, T, output_dim] --> predicted token logits for each time step
        """

        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        #tensor to store decoder outputs
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)

        #run the encoder on the full source sequence
        #encoder_outputs: [B, src_len, enc_hid_dim * 2] --> all hidden states of the input sequence
        #hidden: [B, dec_hid_dim] --> final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src)

        #first input to the decoder is the <sos> tokens
        input = trg[:,0]

        # unroll RNN
        for t in range(1, trg_len):
            # Decode one token
            # input: [B]
            # hidden: [B, dec_hid_dim]
            # encoder_outputs: [B, src_len, enc_hid_dim*2]
            output, hidden = self.decoder(input, hidden, encoder_outputs)

            #predictions
            outputs[:, t] = output      #[B, output_dim]

            #whether to use teacher forcing
            teacher_force = random.random() < teacher_forcing_ratio

            #greedy search
            top1 = output.argmax(1)     #[B]

            #if teacher forcing, use gold token as next input
            #if not, use predicted token
            input = trg[:, t] if teacher_force else top1

        return outputs

## Training the Seq2Seq Model



In [25]:
#voc sizes
INPUT_DIM = len(en_vocab)
OUTPUT_DIM = len(de_vocab)

#embedding and hidden dimensions
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.2

#attention mechanism (OR)
attn = LuongAttention(ENC_HID_DIM, DEC_HID_DIM)

#alternative: use Bahdanau attention instead
#attn = BahdanauAttention(ENC_HID_DIM, DEC_HID_DIM)

#build encoder and decoder
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, device).to(device)

In [27]:
print(len(en_vocab))        #BPE size 16k approx --> no?
print(len(de_vocab))

5871
6423


In [28]:
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(5871, 256)
    (rnn): GRU(256, 512, batch_first=True, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): LuongAttention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=1, bias=False)
    )
    (embedding): Embedding(6423, 256)
    (rnn): GRU(1280, 512, batch_first=True)
    (fc_out): Linear(in_features=1792, out_features=6423, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
    (attn): Linear(in_features=1536, out_features=512, bias=True)
    (v): Linear(in_features=512, out_features=1, bias=False)
  )
)

In [29]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 21,884,439 trainable parameters


We create an optimizer.

In [30]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

We initialize the loss function.

In [31]:
TRG_PAD_IDX = de_vocab['<pad>']     #index of <pad> token in target voc

#https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
#CrossEntropyLoss expects raw logits (not softmaxed) and ignores padding tokens in loss computation
#it will compute loss only over non-pad positions
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

In [32]:
print(TRG_PAD_IDX)

1


In [34]:
def train(model, iterator, optimizer, criterion, clip):
    #set model to training mode
    model.train()

    epoch_loss = 0

    for (src, trg) in tqdm(iterator):
        #move source and target tensors to device
        src, trg = src.to(model.device), trg.to(model.device)

        #zero the gradients from the previous batch
        optimizer.zero_grad()

        #forward pass through the model
        output = model(src, trg)

        #output: [B, trg_len, output_dim]
        #trg:    [B, trg_len]

        output = output.permute(1, 0, 2)        #[trg_len, B, output_dim]
        trg = trg.permute(1, 0)                 #[trg_len, B]

        output_dim = output.shape[-1]

        #skip the first token (<sos>) in both prediction and target
        output = output[1:].reshape(-1, output_dim)     #[(trg_len-1)*B, output_dim]
        trg = trg[1:].reshape(-1)                       #[(trg_len-1)*B]

        #compute the loss (ignoring padding index)
        loss = criterion(output, trg)

        #backpropagate the gradients
        loss.backward()

        #clip gradients to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        #update model parameters
        optimizer.step()

        epoch_loss += loss.item()

    #return average loss per batch
    return epoch_loss / len(iterator)

In [39]:
def evaluate(model, iterator, criterion):
    #set model to evaluation mode
    model.eval()

    epoch_loss = 0

    with torch.no_grad():
        for (src, trg) in iterator:
            #move data to device again
            src, trg = src.to(model.device), trg.to(model.device)

            #forward pass with no teacher forcing
            output = model(src, trg, teacher_forcing_ratio=0)       #NB: why turn off here?

            #trg: [B, trg_len]
            #output: [B, trg_len, output_dim]

            output = output.permute(1, 0, 2)        #[trg_len, B, output_dim]
            trg = trg.permute(1, 0)                 #[trg_len, B]

            output_dim = output.shape[-1]

            output = output[1:].reshape(-1, output_dim)     #[(trg_len-1)*B, output_dim]
            trg = trg[1:].reshape(-1)                       #[(trg_len-1)*B]

            #compute loss without backpropagation
            loss = criterion(output, trg)

            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

In [40]:
from torchtext.data.metrics import bleu_score

def compute_bleu(model, iterator, vocab, device):
    model.eval()
    trgs = []
    preds = []

    with torch.no_grad():
        for src, trg in iterator:
            src, trg = src.to(device), trg.to(device)

            output = model(src, trg, teacher_forcing_ratio=0)       #greedy decoding
            output = output.argmax(2)       #[B, trg_len]

            for i in range(trg.shape[0]):
                #skip <sos>, stop at <eos>
                trg_tokens = [vocab.lookup_token(tok.item()) for tok in trg[i, 1:] if tok.item() != vocab['<pad>']]
                pred_tokens = [vocab.lookup_token(tok.item()) for tok in output[i, 1:] if tok.item() != vocab['<pad>']]

                #ensure at least one token is present (BLEU requires non-empty candidates and refs)
                if len(pred_tokens) > 0 and len(trg_tokens) > 0:
                    preds.append(pred_tokens)
                    trgs.append([trg_tokens])       #BLEU expects list of references

    return bleu_score(preds, trgs) * 100        #BLEU in percentage

In [41]:
N_EPOCHS = 5
CLIP = 1

best_valid_loss = float('inf')
best_bleu = 0.0

for epoch in range(N_EPOCHS):

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)

    bleu = compute_bleu(model, valid_iter, de_vocab, device)


    if valid_loss < best_valid_loss or bleu > best_bleu:
        best_valid_loss = valid_loss
        best_bleu = bleu
        torch.save(model.state_dict(), 'model.pt')


    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}\tTrain PPL: {np.exp(train_loss):7.3f}')
    print(f'\t Validation Loss: {valid_loss:.3f}\tValidation PPL: {np.exp(valid_loss):7.3f}')
    print(f'\tValidation BLEU: {bleu:.2f}')

100%|██████████| 125/125 [00:31<00:00,  3.98it/s]


Epoch: 01
	Train Loss: 7.701	Train PPL: 2209.549
	 Validation Loss: 7.457	Validation PPL: 1731.264
	Validation BLEU: 0.00


100%|██████████| 125/125 [00:31<00:00,  4.00it/s]


Epoch: 02
	Train Loss: 7.218	Train PPL: 1363.552
	 Validation Loss: 7.438	Validation PPL: 1699.539
	Validation BLEU: 0.00


100%|██████████| 125/125 [00:31<00:00,  3.97it/s]


Epoch: 03
	Train Loss: 7.064	Train PPL: 1169.537
	 Validation Loss: 7.545	Validation PPL: 1891.274
	Validation BLEU: 0.00


100%|██████████| 125/125 [00:31<00:00,  3.94it/s]


Epoch: 04
	Train Loss: 6.865	Train PPL: 958.539
	 Validation Loss: 7.504	Validation PPL: 1814.769
	Validation BLEU: 0.00


100%|██████████| 125/125 [00:31<00:00,  3.97it/s]


Epoch: 05
	Train Loss: 6.640	Train PPL: 764.934
	 Validation Loss: 7.551	Validation PPL: 1902.677
	Validation BLEU: 0.00


In [42]:
#NOTE: load model from file
model.load_state_dict(torch.load('model.pt', map_location=device))

test_loss = evaluate(model, test_iter, criterion)
bleu = compute_bleu(model, test_iter, de_vocab, device)

print(f'\tTest Loss: {test_loss:.3f}\tTest PPL: {np.exp(test_loss):7.3f}')
print(f'\tTest BLEU: {bleu:.2f}')

	Test Loss: 7.427	Test PPL: 1681.099
	Test BLEU: 0.00


In [44]:
#NB: bad results:
#solution:
    #increase train_data = train_data[:20000]

In [43]:
#clean mem
del model
del train_iter
del valid_iter
del test_iter
torch.cuda.empty_cache()