# **Train Seq2Seq + Attention model for ANY->ENG translation task**

- This notebook contains everything required to train a basic seq2seq RNN model with any type of Luong attention.
- It provides the flexibility to select the encoder architecture (number of layers, uni- or bi-directional), the decoder architecture (number of layers, uni- or bi-directional), and the attention mechanism ("dot", "concat", "general").
- It uses a single BPE-based tokenizer for both languages (common vocabulary).
- It can also be trained for bidirectional translation.
- It provides data preprocessing utils like Unicode normalization, deduplication, length filtering, and length-ratio filtering.

## Import basic stuff

In [None]:
import os
import re
import random
import pathlib
import requests
import zipfile
import unicodedata
import numpy as np
import tokenizers
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import tqdm

## Data utils

### Configuration

In [None]:
# Notebook-wide path settings; both are left empty here and must be filled in
# before running cells that rely on them.
# NOTE(review): presumably a scratch directory for downloads/intermediate
# files — not referenced in the visible cells; confirm against later cells.
temp_dir = ""
# NOTE(review): presumably the directory holding the parallel corpus files —
# not referenced in the visible cells; confirm against later cells.
data_dir = ""

### Data preparation

#### Data loading

In [None]:
def load_data(path):
  """Read a text file and return its lines without line terminators.

  Args:
    path: Location of the text file to read.

  Returns:
    list[str]: One entry per line, as produced by ``str.splitlines``.

  Raises:
    FileNotFoundError: If ``path`` does not exist. This is a subclass of
      ``Exception``, so callers catching ``Exception`` keep working.
  """
  if not os.path.exists(path):
    raise FileNotFoundError(f"File not found: {path}")
  # Translation corpora are multilingual: decode explicitly as UTF-8 instead
  # of depending on the platform's default encoding.
  with open(path, "r", encoding="utf-8") as f:
    return f.read().splitlines()

In [None]:
def download_data(url, path, timeout=60):
  """Download ``url`` and write the response body to ``path``.

  Args:
    url: Address of the resource to fetch (redirects are followed).
    path: Destination file; it is overwritten if it already exists.
    timeout: Seconds to wait for the server before giving up. Without a
      timeout ``requests`` can block indefinitely on a stalled connection.

  Returns:
    bool: ``True`` once the file has been written.

  Raises:
    requests.HTTPError: If the server answers with a 4xx/5xx status, so an
      error page is never saved as if it were the dataset.
  """
  result = requests.get(url, allow_redirects=True, timeout=timeout)
  result.raise_for_status()
  with open(path, 'wb') as f:
    f.write(result.content)
  return True

#### Data pre-processing

In [None]:
def null_filter(pair):
  """Keep a pair only when both its source and target are non-empty."""
  return len(pair[0]) != 0 and len(pair[1]) != 0

In [None]:
def normalize(pair):
    """Return the (source, target) pair with both strings NFKC-normalized."""
    return tuple(unicodedata.normalize('NFKC', pair[side]) for side in (0, 1))

In [None]:
def deduplication(pairs):
  """Drop repeated pairs while keeping the first occurrence of each, in order.

  Args:
    pairs: Iterable of (source, target) pairs.

  Returns:
    list: The unique pairs in their original order.
  """
  seen_hashable = set()     # O(1) membership for tuples and other hashables
  seen_unhashable = []      # fallback for e.g. list pairs, compared with ==
  unique = []
  for pair in pairs:
    try:
      is_new = pair not in seen_hashable
      if is_new:
        seen_hashable.add(pair)
    except TypeError:
      # Unhashable pair: fall back to the linear equality scan.
      is_new = pair not in seen_unhashable
      if is_new:
        seen_unhashable.append(pair)
    if is_new:
      unique.append(pair)
  return unique

In [None]:
def length_filter(pair, min_length, max_length):
  """Tell whether both sides of the pair have a length within [min_length, max_length]."""
  sides = (pair[0], pair[1])
  return all(min_length <= len(text) <= max_length for text in sides)

In [None]:
def length_ratio_filter(pair, min_ratio, max_ratio):
  """Tell whether the length ratio of the pair lies within bounds in both directions.

  Pairs with an empty side are rejected outright (the ratio is undefined).
  """
  src, tgt = pair[0], pair[1]
  if not len(src) or not len(tgt):
    return False
  src_over_tgt = len(src) / len(tgt)
  if not (min_ratio <= src_over_tgt <= max_ratio):
    return False
  tgt_over_src = len(tgt) / len(src)
  return min_ratio <= tgt_over_src <= max_ratio

### Dataset formation

In [None]:
class TranslationDataset:
  """Map-style dataset that tokenizes (source, target) text pairs on access.

  Both sides go through the same ``tokenizer``; the target is wrapped in the
  start/end marker strings before encoding.

  Args:
    pairs: Indexable collection of (source, target) strings.
    tokenizer: Object exposing ``encode(text).ids`` and ``token_to_id(token)``.
    start_token: Marker string prepended to every target.
    end_token: Marker string appended to every target.
    pad_token: Token whose id is used to pad batches.
  """

  def __init__(self, pairs, tokenizer, start_token="[start]", end_token="[end]", pad_token = "[pad]"):
    self.pairs = pairs
    self.tokenizer = tokenizer
    self.start_token = start_token
    self.end_token = end_token
    self.pad_token = pad_token
    # Id used by collate_fn to fill the shorter sequences of a batch.
    self.padding_value = self.tokenizer.token_to_id(self.pad_token)

  def __len__(self):
    """Return the number of pairs."""
    return len(self.pairs)

  def __getitem__(self, idx):
    """Return ``(src_ids, tgt_ids)`` as 1-D ``torch.long`` tensors for pair ``idx``."""
    src, tgt = self.pairs[idx][0], self.pairs[idx][1]
    src_ids = self.tokenizer.encode(src).ids
    tgt_ids = self.tokenizer.encode(f"{self.start_token}{tgt}{self.end_token}").ids
    return torch.tensor(src_ids, dtype=torch.long), torch.tensor(tgt_ids, dtype=torch.long)

  def collate_fn(self, batch):
    """Pad a list of ``(src_ids, tgt_ids)`` items into batch-first tensors.

    Returns:
      tuple: ``(src_padded, tgt_padded, src_lengths)`` where the padded
      tensors are ``(batch, max_len)`` and ``src_lengths`` holds the true
      (unpadded) length of every source sequence.
    """
    src_batch, tgt_batch = zip(*batch)
    src_padded = torch.nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=self.padding_value)
    tgt_padded = torch.nn.utils.rnn.pad_sequence(tgt_batch, batch_first=True, padding_value=self.padding_value)
    # Take the lengths from the unpadded tensors: counting non-pad ids in the
    # padded batch undercounts whenever the pad id occurs inside a sequence.
    src_lengths = torch.tensor([seq.size(0) for seq in src_batch], dtype=torch.long)
    return src_padded, tgt_padded, src_lengths

  def get_dataloader(self, batch_size, shuffle=True):
    """Build a ``DataLoader`` over this dataset that batches with ``collate_fn``."""
    return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate_fn)

#### Test

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

# Smoke test for TranslationDataset: builds a tiny word-level tokenizer and
# checks length, item types, collation, DataLoader iteration and the
# empty-string edge case. The names defined here are notebook globals.

# -----------------------------
# Create dummy tokenizer
# -----------------------------
# The special-token strings are listed among the training sentences and are
# also registered with the trainer as special tokens.
sentences = ["hello world", "how are you", "fine", "[pad]", "[start]", "[end]"]
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(special_tokens=["[pad]", "[start]", "[end]"])
tokenizer.train_from_iterator(sentences, trainer)
tokenizer.token_to_id = tokenizer.model.token_to_id  # route token_to_id lookups (used by TranslationDataset) to the trained model

# -----------------------------
# Dummy translation pairs
# -----------------------------
pairs = [
    ("hello world", "how are you"),
    ("how are you", "fine"),
    ("hello", "world"),
    ("", ""),  # edge case: empty string on both sides
]

# Create the dataset (default start/end/pad marker strings)
dataset = TranslationDataset(pairs, tokenizer)

# -----------------------------
# Test 1: Length
# -----------------------------
print("Test 1: Dataset length")
assert len(dataset) == len(pairs)
print("✅ Passed")

# -----------------------------
# Test 2: __getitem__ output shapes
# -----------------------------
print("Test 2: __getitem__ returns tensors of correct shape")
src, tgt = dataset[0]
assert isinstance(src, torch.Tensor) and isinstance(tgt, torch.Tensor)
assert src.ndim == 1 and tgt.ndim == 1
print("✅ Passed")

# -----------------------------
# Test 3: Collate Function
# -----------------------------
print("Test 3: Collate function pads sequences correctly")
# Only the three non-empty pairs are collated here.
batch = [dataset[i] for i in range(3)]
src_padded, tgt_padded, src_lengths = dataset.collate_fn(batch)
assert src_padded.ndim == 2 and tgt_padded.ndim == 2
assert src_padded.shape[0] == 3  # batch size
assert src_lengths.shape == (3,)
print("✅ Passed")

# -----------------------------
# Test 4: Dataloader iterability
# -----------------------------
print("Test 4: Dataloader can be iterated")
loader = dataset.get_dataloader(batch_size=2)
for src_padded, tgt_padded, src_lengths in loader:
    assert src_padded.ndim == 2
    assert tgt_padded.ndim == 2
    assert src_lengths.ndim == 1
print("✅ Passed")

# -----------------------------
# Test 5: Edge Case — empty strings
# -----------------------------
print("Test 5: Edge case — empty source/target handled")
src, tgt = dataset[3]  # ("", "")
assert len(src) == 0 and len(tgt) > 0  # the target still carries the start/end markers
print("✅ Passed")


Test 1: Dataset length
✅ Passed
Test 2: __getitem__ returns tensors of correct shape
✅ Passed
Test 3: Collate function pads sequences correctly
✅ Passed
Test 4: Dataloader can be iterated
✅ Passed
Test 5: Edge case — empty source/target handled
✅ Passed


### Tokenization

In [None]:
def get_tokenizer(path=None, downlaod= False, save_dir=None):
  """Load a tokenizer from a local file, or fetch it by name when allowed.

  Args:
    path: Local tokenizer file; when it does not exist and ``downlaod`` is
      set, the same value is passed to ``Tokenizer.from_pretrained``.
    downlaod: Allow fetching with ``from_pretrained`` when ``path`` is not a
      local file (the parameter name is kept as-is for existing callers).
    save_dir: Where a fetched tokenizer is saved; falls back to ``path``.

  Returns:
    The loaded ``tokenizers.Tokenizer``.

  Raises:
    Exception: If no path is given, the file is missing and fetching is not
      allowed, or the loaded tokenizer is falsy.
  """
  if not path:
    raise Exception("No path provided")

  if os.path.exists(path):
    tokenizer = tokenizers.Tokenizer.from_file(path)
  elif downlaod:
    tokenizer = tokenizers.Tokenizer.from_pretrained(path)
    # Keep a local copy so the next call can load it from disk.
    tokenizer.save(save_dir if save_dir else path)
  else:
    raise Exception("File not found")

  if not tokenizer:
    raise Exception("Tokenizer not found")

  return tokenizer

In [None]:
def train_tokenizer(path, text_strs, vocab_size = 16000):
  """Train a byte-level BPE tokenizer on ``text_strs`` and save it to ``path``.

  Args:
    path: File the trained tokenizer is written to (pretty-printed JSON).
    text_strs: Iterable of training sentences (both languages share one vocabulary).
    vocab_size: Target vocabulary size, special tokens included.

  Returns:
    The trained ``tokenizers.Tokenizer`` with padding enabled on ``[pad]``.
  """
  tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[unk]"))
  tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=True)
  tokenizer.decoder = tokenizers.decoders.ByteLevel()
  trainer = tokenizers.trainers.BpeTrainer(
      vocab_size=vocab_size,
      # "[unk]" is appended last so the ids of the first three tokens stay the
      # same; the model's unk_token must exist in the vocabulary to be usable.
      special_tokens=["[start]", "[end]", "[pad]", "[unk]"],
      # Seed the vocabulary with every byte-level symbol so text containing
      # bytes never seen during training can still be encoded.
      initial_alphabet=tokenizers.pre_tokenizers.ByteLevel.alphabet(),
      show_progress=True
    )
  tokenizer.train_from_iterator(text_strs, trainer=trainer)
  tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[pad]"), pad_token="[pad]")
  tokenizer.save(path, pretty=True)
  return tokenizer

## Networks

### Encoder

In [None]:
class Encoder(nn.Module):
  """Recurrent encoder: embedding -> (bi)RNN -> linear projection.

  Args:
    input_size: Vocabulary size of the embedding table.
    embedding_size: Width of the token embeddings.
    hidden_size: Hidden width of the RNN and of the projected outputs.
    num_layers: Number of stacked RNN layers.
    dropout: Dropout on the embeddings and between stacked RNN layers.
    bidirectional: Run the RNN in both directions.
    arch: Name of the ``torch.nn`` recurrent class, upper-cased before lookup
      (e.g. "gru", "lstm").
    batch_first: Forwarded to the RNN constructor. ``forward`` always packs
      its input as batch-first.
  """

  def __init__(self,
      input_size, embedding_size, hidden_size,
      num_layers=1,
      dropout=0.01,
      bidirectional=False,
      arch = "gru",
      batch_first=True
    ):
    super().__init__()
    self.input_size = input_size
    self.embedding_size = embedding_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.dropout_rate = dropout
    self.directions = 2 if bidirectional else 1
    self.arch = arch
    self.batch_first = batch_first

    self.embedding = nn.Embedding(input_size, embedding_size)
    self.rnn = getattr(nn, arch.upper())(
        embedding_size,
        hidden_size,
        num_layers,
        # Inter-layer dropout does nothing for a single layer and makes
        # PyTorch emit a UserWarning, so only pass it to stacked RNNs.
        dropout=dropout if num_layers > 1 else 0.0,
        bidirectional=bidirectional,
        batch_first=batch_first
    )
    # Folds the concatenated directions back down to hidden_size.
    self.fc = nn.Linear(self.hidden_size*self.directions, hidden_size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, src_padded, src_lengths):
    """Encode a padded batch.

    Args:
      src_padded: (batch_size, max_length) token ids.
      src_lengths: True length of every sequence (tensor or list); it does
        not need to be sorted.

    Returns:
      tuple: ``(outputs, hidden)`` where ``outputs`` is
      (max_length, batch_size, hidden_size) and ``hidden`` is the RNN's final
      state — (h_n, c_n) for an LSTM, otherwise a tensor of shape
      (num_layers * directions, batch_size, hidden_size).
    """
    embedded = self.dropout(self.embedding(src_padded))
    # pack_padded_sequence needs the lengths on the CPU even when the batch
    # itself lives on an accelerator.
    lengths = torch.as_tensor(src_lengths).cpu()
    packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
    outputs, hidden = self.rnn(packed)
    # outputs: (batch_size, max_length, hidden_size * directions)
    outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
    # outputs: (batch_size, max_length, hidden_size)
    outputs = self.fc(outputs)
    outputs = outputs.permute(1, 0, 2)  # (max_length, batch_size, hidden_size)
    return outputs, hidden

###### Test

In [None]:
# Smoke test for Encoder: runs one padded batch through a 3-layer
# bidirectional LSTM encoder and prints the resulting shapes.
# === Define dummy input ===
batch_size = 3
max_seq_len = 5
vocab_size = 10       # dummy vocab size
pad_id = 0            # padding index
embedding_size = 10
hidden_size = 16

# True (unpadded) length of each sample in the batch
src_lengths = torch.tensor([5, 3, 4])  # not sorted, which is fine due to enforce_sorted=False

# Create dummy sequences (as list of tensors of different lengths)
src_seqs = [
    torch.tensor([1, 2, 3, 4, 5]),      # len 5
    torch.tensor([6, 7, 8]),           # len 3
    torch.tensor([9, 3, 4, 5])         # len 4
]

# Pad the batch to (batch_size, max_seq_len)
src_padded = torch.nn.utils.rnn.pad_sequence(src_seqs, batch_first=True, padding_value=pad_id)
print("src_padded:\n", src_padded)
print("src_lengths:", src_lengths)

# === Initialize encoder ===
encoder = Encoder(
    input_size=vocab_size,
    embedding_size=embedding_size,
    hidden_size=hidden_size,
    num_layers=3,
    dropout=0.1,
    bidirectional=True,
    arch="lstm",
    batch_first=True
)

# === Forward pass ===
outputs, hidden = encoder(src_padded, src_lengths)

# === Output shapes ===
print("\nEncoder outputs shape:", outputs.shape)  # (max_seq_len, batch_size, hidden_size): Encoder.forward permutes to time-major
if isinstance(hidden, tuple):  # LSTM returns (h_n, c_n)
    print("Encoder hidden shape (h, c):", hidden[0].shape, hidden[1].shape)
else:  # GRU returns a single tensor
    print("Encoder hidden shape:", hidden.shape)

src_padded:
 tensor([[1, 2, 3, 4, 5],
        [6, 7, 8, 0, 0],
        [9, 3, 4, 5, 0]])
src_lengths: tensor([5, 3, 4])

Encoder outputs shape: torch.Size([3, 5, 16])
Encoder hidden shape (h, c): torch.Size([6, 3, 16]) torch.Size([6, 3, 16])


### Decoder

#### Attention

In [None]:
class LuongAttention(nn.Module):
    """Luong-style attention with "dot", "general" or "concat" scoring.

    Args:
        method: One of ``'dot'``, ``'general'``, ``'concat'``.
        hidden_size: Width shared by the decoder query and the encoder outputs.

    Raises:
        ValueError: If ``method`` is not a supported scoring function.
    """

    _METHODS = ('dot', 'general', 'concat')

    def __init__(self, method, hidden_size):
        super().__init__()
        # Fail at construction time instead of with an UnboundLocalError on
        # the first forward pass.
        if method not in self._METHODS:
            raise ValueError(
                f"Unknown attention method {method!r}; expected one of {self._METHODS}"
            )
        self.method = method
        self.hidden_size = hidden_size

        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            # torch.FloatTensor(...) hands back uninitialised memory (possibly
            # NaN/inf); allocate and initialise the parameter explicitly.
            self.v = nn.Parameter(torch.empty(1, hidden_size))
            nn.init.xavier_uniform_(self.v)

    def dot_score(self, hidden, encoder_output):
        """Score by the dot product of the query with every encoder state."""
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        """Score by the dot product of the query with linearly mapped encoder states."""
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        """Score with ``v · tanh(W [query; encoder_state])``."""
        # Broadcast the single query step across all source positions.
        hidden = hidden.expand(encoder_output.size(0), -1, -1)
        energy = self.attn(torch.cat([hidden, encoder_output], 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        """Compute the context vector for one decoder step.

        Args:
            hidden: (1, batch_size, hidden_size) decoder query.
            encoder_outputs: (seq_len, batch_size, hidden_size) encoder states.

        Returns:
            tuple: ``(context, attn_weights)`` shaped
            (batch_size, 1, hidden_size) and (batch_size, 1, seq_len).
        """
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        else:  # 'dot' — the method was validated in __init__
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # (seq_len, batch_size) -> (batch_size, seq_len)
        attn_energies = attn_energies.t()
        # attn_weights: (batch_size, 1, seq_len)
        attn_weights = nn.functional.softmax(attn_energies, dim=1).unsqueeze(1)
        # context: (batch_size, 1, hidden_size)
        context = torch.bmm(attn_weights, encoder_outputs.transpose(0, 1))
        return context, attn_weights

##### Test

In [None]:
def test_luong_attention(method='dot'):
    """Run one attention step on random tensors and check the output shapes."""
    batch_size, seq_len, hidden_size = 2, 4, 8

    # Time-major encoder states, then a single decoder query step.
    encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)
    hidden = torch.randn(1, batch_size, hidden_size)

    attention = LuongAttention(method=method, hidden_size=hidden_size)
    context, attn_weights = attention(hidden, encoder_outputs)

    print(f"\nTesting Luong Attention ({method.upper()})")
    print("Context shape:      ", context.shape)
    print("Attention weights:  ", attn_weights.shape)

    expected_context_shape = (batch_size, 1, hidden_size)
    expected_weights_shape = (batch_size, 1, seq_len)
    assert context.shape == expected_context_shape, "Context shape incorrect"
    assert attn_weights.shape == expected_weights_shape, "Attention weights shape incorrect"

# Exercise every supported scoring function.
for method in ('dot', 'general', 'concat'):
    test_luong_attention(method)


Testing Luong Attention (DOT)
Context shape:       torch.Size([2, 1, 8])
Attention weights:   torch.Size([2, 1, 4])

Testing Luong Attention (GENERAL)
Context shape:       torch.Size([2, 1, 8])
Attention weights:   torch.Size([2, 1, 4])

Testing Luong Attention (CONCAT)
Context shape:       torch.Size([2, 1, 8])
Attention weights:   torch.Size([2, 1, 4])


#### Decoder

In [None]:
class Decoder(nn.Module):
  """Recurrent decoder that feeds a Luong attention context into its RNN.

  Args:
    vocab_size: Size of the (shared) vocabulary; used for both the embedding
      table and the output logits.
    embedding_size: Width of the token embeddings.
    hidden_size: Hidden width of the RNN and of the encoder outputs attended over.
    num_layers: Number of stacked RNN layers.
    dropout: Dropout on the embeddings and between stacked RNN layers.
    bidirectional: Run the RNN in both directions.
    batch_first: Forwarded to the RNN constructor.
    arch: Name of the ``torch.nn`` recurrent class, upper-cased before lookup.
    attn: Scoring method handed to ``LuongAttention``.
  """

  def __init__(self,
      vocab_size, embedding_size, hidden_size,
      num_layers=1,
      dropout=0.01,
      bidirectional=False,
      batch_first=True,
      arch = "gru",
      attn = "dot",
    ):
    super().__init__()
    self.input_size = vocab_size
    self.embedding_size = embedding_size
    self.hidden_size = hidden_size
    self.output_size = vocab_size
    self.num_layers = num_layers
    self.dropout_rate = dropout
    self.directions = 2 if bidirectional else 1
    self.batch_first = batch_first
    self.arch = arch
    self.attn_method = attn

    self.embedding = nn.Embedding(self.input_size, embedding_size)
    self.attention = LuongAttention(attn, hidden_size)
    self.rnn = getattr(nn, arch.upper())(
        self.embedding_size+self.hidden_size,
        self.hidden_size,
        num_layers,
        # Inter-layer dropout does nothing for a single layer and makes
        # PyTorch emit a UserWarning, so only pass it to stacked RNNs.
        dropout=dropout if num_layers > 1 else 0.0,
        bidirectional=bidirectional,
        batch_first=batch_first
    )
    # The constructor accepts any casing of ``arch`` (it is upper-cased for
    # the lookup), so detect an LSTM from the built module rather than by
    # comparing the raw string with "lstm".
    self._is_lstm = isinstance(self.rnn, nn.LSTM)
    # Collapses the stack of layer/direction states into one attention query.
    self.hid_attn = nn.Linear(self.num_layers * self.directions, 1)
    self.fc = nn.Sequential(
            nn.Linear(self.hidden_size*self.directions, self.embedding_size), nn.LeakyReLU(),
            nn.Linear(self.embedding_size, self.output_size),
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, tgt_padded, hidden, enc_out):
    """Decode the given target tokens.

    Args:
      tgt_padded: Target token ids; the batch size is read from dimension 0
        (assumes batch-first input — confirm if ``batch_first=False`` is used).
      hidden: Previous RNN state ((h, c) for an LSTM) or ``None`` to start
        from zeros.
      enc_out: (seq_len, batch_size, hidden_size) encoder outputs.

    Returns:
      tuple: ``(output, hidden)`` — vocabulary logits for every input position
      and the updated RNN state.
    """
    if hidden is None:
       h_0 = torch.zeros(
          self.num_layers*self.directions,
          tgt_padded.size(0),
          self.hidden_size,
          device=tgt_padded.device
       )
       hidden = (h_0, torch.zeros_like(h_0)) if self._is_lstm else h_0

    # Attention is queried with the hidden state h (not the LSTM cell state).
    hidden_ = hidden[0] if isinstance(hidden, tuple) else hidden

    # [L*D, B, H]
    hidden_ = hidden_.permute(1, 2, 0)           # [B, H, L*D]
    attn_inp = self.hid_attn(hidden_)            # [B, H, 1]
    attn_inp = attn_inp.permute(2, 0, 1)         # [1, B, H]

    context, attn_weights = self.attention(attn_inp, enc_out)
    embedded = self.dropout(self.embedding(tgt_padded))
    # The same context vector is attached to every target position.
    context = context.repeat(1, embedded.size(1), 1)
    rnn_input = torch.cat([embedded, context], dim=2)
    # hidden: (h_n, c_n) if LSTM else (num_layers * directions, batch_size, hidden_size)
    outputs, hidden = self.rnn(rnn_input, hidden)
    output = self.fc(outputs)
    return output, hidden

##### Test

In [None]:
# Smoke test for Decoder: one real batch from TranslationDataset is decoded
# against a random stand-in for the encoder outputs.

#---------------------------------------
# 1. Load Tokenizer
#---------------------------------------

tokenizer = get_tokenizer(path="/content/en_tokenizer.json")

# ---------------------------------------
# 2. Create dummy translation pairs
# ---------------------------------------
pairs = [
    ("hello world", "how are you"),
    ("i am fine", "this is good"),
    ("nice to meet you", "hello world")
]

# ---------------------------------------
# 3. Use your TranslationDataset
# ---------------------------------------
dataset = TranslationDataset(pairs, tokenizer)
dataloader = dataset.get_dataloader(batch_size=2)

# ---------------------------------------
# 4. Get one batch from the dataloader
# ---------------------------------------
src_padded, tgt_padded, src_lengths = next(iter(dataloader))
print("Source padded:", src_padded)
print("Target padded:", tgt_padded)
print("Source lengths:", src_lengths)

# ---------------------------------------
# 5. Create a dummy encoder output
# ---------------------------------------
batch_size, seq_len = src_padded.size()
hidden_size = 32
embedding_size = 16
output_size = len(tokenizer.get_vocab())

# Time-major (seq_len, batch_size, hidden_size), the layout the attention expects.
enc_out = torch.randn(seq_len, batch_size, hidden_size)

# ---------------------------------------
# 6. Run your Decoder
# ---------------------------------------
# Decoder takes a single ``vocab_size`` (shared by its embedding and output
# layer); it has no ``input_size``/``output_size`` keywords.
decoder = Decoder(
    vocab_size=output_size,
    embedding_size=embedding_size,
    hidden_size=hidden_size,
    num_layers=1,
    dropout=0.1,
    bidirectional=False,
    batch_first=True,
    arch="lstm",
    attn="dot"
)

# hidden=None makes the decoder start from a zero state.
output, hidden = decoder(tgt_padded, None, enc_out)

print("Decoder output shape:", output.shape)
print("Decoder hidden shape:", hidden[0].shape if isinstance(hidden, tuple) else hidden.shape)

Source padded: tensor([[3548,  987,    2,    2],
        [ 854,   80,  605,   81]])
Target padded: tensor([[   0,  219,  157,   81,    1],
        [   0, 3548,  987,    1,    2]])
Source lengths: tensor([2, 4])
Decoder output shape: torch.Size([2, 5, 8000])
Decoder hidden shape: torch.Size([1, 2, 32])


### Seq2Seq

In [None]:
class Seq2Seq(nn.Module):
    """Encoder-decoder translation model with attention.

    The encoder and decoder are instantiated from the classes passed in. A
    linear ``enc_dec`` bridge maps the encoder's stack of final states
    (``enc_num_layers * enc_dirs``) onto the stack the decoder expects
    (``dec_num_layers * dec_dirs``).

    Args:
        encoder: Encoder class, called as ``encoder(vocab_size, enc_embed_sz,
            hidden_sz, num_layers=, bidirectional=, batch_first=, arch=, dropout=)``.
        decoder: Decoder class, called like the encoder plus ``attn=``.
        vocab_size: Size of the shared vocabulary.
        hidden_sz: Hidden width used by both networks.
        enc_embed_sz, dec_embed_sz: Embedding widths.
        enc_num_layers, dec_num_layers: Number of stacked RNN layers.
        enc_bidir, dec_bidir: Whether each RNN is bidirectional.
        enc_dropout, dec_dropout: Dropout rates.
        arch: Recurrent cell name (e.g. "gru", "lstm").
        attn: Attention scoring method.
        device: Device the model is moved to and on which ``generate``
            creates its tensors.
    """

    def __init__(self,
                 encoder, decoder,
                 vocab_size, hidden_sz, enc_embed_sz = 512, dec_embed_sz = 512,
                 enc_num_layers=1, dec_num_layers=1, enc_bidir = True, dec_bidir = False,
                 enc_dropout=0.0, dec_dropout=0.0,
                 arch = "gru",  attn="dot", device="cpu"
        ):
        super().__init__()
        self.batch_first=True
        self.vocab_size = vocab_size
        self.hidden_sz = hidden_sz
        self.arch = arch
        self.attn = attn
        self.enc_dirs = 2 if enc_bidir else 1
        self.dec_dirs = 2 if dec_bidir else 1
        self.enc_num_layers = enc_num_layers
        self.dec_num_layers = dec_num_layers
        self.device = device
        self.encoder = encoder(
            vocab_size,
            enc_embed_sz,
            hidden_sz,
            num_layers=enc_num_layers,
            bidirectional=enc_bidir,
            batch_first = self.batch_first,
            arch = arch,
            dropout=enc_dropout,
            )

        self.decoder = decoder(
            vocab_size,
            dec_embed_sz,
            hidden_sz,
            num_layers=dec_num_layers,
            bidirectional=dec_bidir,
            batch_first = self.batch_first,
            arch = arch,
            dropout=dec_dropout,
            attn = attn
        )
        # [batch_sz, hidden_sz, enc_num_layers*enc_dirs] -> [batch_sz, hidden_sz, dec_num_layers*dec_dirs]
        self.enc_dec = nn.Linear(enc_num_layers*self.enc_dirs, dec_num_layers*self.dec_dirs)
        # Moves the encoder, decoder and bridge in one go.
        self.to(device)

    def _bridge_hidden(self, enc_hidden):
        """Map the encoder's final state(s) onto the decoder's state layout."""
        def project(state):
            # [enc_L*D, B, H] -> [B, H, enc_L*D] -> linear -> [B, H, dec_L*D] -> [dec_L*D, B, H]
            # .contiguous(): the permuted view is not contiguous, and
            # accelerated RNN kernels reject non-contiguous initial states.
            return self.enc_dec(state.permute(1, 2, 0)).permute(2, 0, 1).contiguous()

        # An LSTM encoder returns (h_n, c_n); checking the type rather than
        # the ``arch`` string keeps this correct for any casing of ``arch``.
        if isinstance(enc_hidden, tuple):
            return tuple(project(state) for state in enc_hidden)
        return project(enc_hidden)

    def forward(self, input_seq, input_len, target_seq = None, tfr=0.5):
        """Predict next-token logits for every target position after the first.

        Args:
            input_seq: [batch_size, src_len] source token ids.
            input_len: True source lengths, forwarded to the encoder.
            target_seq: [batch_size, tgt_len] target token ids beginning with
                the start token. Required.
            tfr: Teacher-forcing ratio — probability, drawn once per step, of
                feeding the gold token instead of the model's own prediction.

        Returns:
            Tensor: [batch_size, tgt_len - 1, vocab_size] logits.

        Raises:
            ValueError: If ``target_seq`` is not provided.
        """
        if target_seq is None:
            raise ValueError("forward() requires target_seq; use generate() for inference")
        target_seq = target_seq.to(self.device)
        _, target_len = target_seq.shape

        enc_out, enc_hidden = self.encoder(input_seq, input_len)
        hidden = self._bridge_hidden(enc_hidden)

        outputs = []
        dec_in = target_seq[:, :1]
        for t in range(target_len-1):
            # last input token and hidden state -> logits for the next token
            pred, hidden = self.decoder(dec_in, hidden, enc_out)
            pred = pred[:, -1:, :] # [batch, 1, vocab]
            outputs.append(pred)

            if torch.rand(1).item() < tfr:
                # teacher forcing: feed the gold token
                dec_in = target_seq[:, t+1:t+2]
            else:
                # free running: feed the model's own prediction
                dec_in = pred.argmax(dim=2)

        return torch.cat(outputs, dim=1)

    def generate(self, input_seq, tokenizer, max_len=50):
        """Greedily translate a single source string.

        Note: this switches the model to eval mode and leaves it there.

        Args:
            input_seq (str): Source text; it is encoded with ``tokenizer``.
            tokenizer: Must provide ``token_to_id``, ``encode(text).ids`` and
                ``decode(ids)``; "[pad]", "[start]" and "[end]" must be in its
                vocabulary.
            max_len (int): Maximum output length, counting the start token.

        Returns:
            str: Decoded text, cut at the first end token and with special
            tokens removed. Empty when the source encodes to no tokens.
        """
        self.eval()
        pad_idx = tokenizer.token_to_id("[pad]")
        sos_idx = tokenizer.token_to_id("[start]")
        eos_idx = tokenizer.token_to_id("[end]")

        input_ids = tokenizer.encode(input_seq).ids
        if not input_ids:
            # Nothing to encode (a zero-length sequence cannot be packed).
            return ""
        input_tensor = torch.tensor(input_ids, dtype=torch.long, device=self.device).unsqueeze(0)
        # Sequence lengths must stay on the CPU for pack_padded_sequence.
        input_len = (input_tensor != pad_idx).sum(dim=1).cpu()

        generated_ids = []
        with torch.no_grad():
            enc_out, enc_hidden = self.encoder(input_tensor, input_len)
            hidden = self._bridge_hidden(enc_hidden)

            input_token = torch.full((1, 1), sos_idx, dtype=torch.long, device=self.device)
            for _ in range(max_len - 1):
                logits, hidden = self.decoder(input_token, hidden, enc_out)
                input_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                token_id = input_token.item()
                if token_id == eos_idx:
                    break
                generated_ids.append(token_id)

        # Work on the flat id list: filtering the nested [[...]] list would
        # compare whole lists with ints and never remove anything.
        special_ids = (pad_idx, sos_idx, eos_idx)
        clean_tokens = [token for token in generated_ids if token not in special_ids]
        # Decode in one call so multi-token (byte-level) characters are
        # reassembled correctly.
        return tokenizer.decode(clean_tokens)


#### Test

In [None]:
def test_seq2seq_forward_pass():
    """Push a few toy batches through Seq2Seq.forward and check the logits' shape."""
    pairs = [
        ("hello world", "goodbye world"),
        ("I am GPT", "hello world"),
        ("i am fine", "this is good"),
        ("nice to meet you", "hello world"),
    ]

    tokenizer = get_tokenizer(path="/content/en_tokenizer.json")
    dataloader = TranslationDataset(pairs, tokenizer).get_dataloader(batch_size=2)

    # Deliberately asymmetric encoder/decoder depths to exercise the
    # encoder-to-decoder state bridge.
    vocab_size = len(tokenizer.get_vocab())
    model_config = dict(
        vocab_size=vocab_size,
        hidden_sz=32,
        enc_embed_sz=32,
        dec_embed_sz=32,
        arch="lstm",
        attn="dot",
        enc_bidir=True,
        dec_bidir=True,
        enc_num_layers=5,
        dec_num_layers=2,
        enc_dropout=0.1,
        dec_dropout=0.1,
    )
    model = Seq2Seq(encoder=Encoder, decoder=Decoder, **model_config)

    for src_batch, tgt_batch, src_lengths in dataloader:
        print("Input shape:", src_batch.shape)
        print("Target shape:", tgt_batch.shape)
        output = model(src_batch, src_lengths, tgt_batch, tfr=0.5)
        print("Output shape:", output.shape)

        expected_batch, expected_steps = tgt_batch.shape[0], tgt_batch.shape[1] - 1
        assert output.shape[0] == expected_batch   # batch size
        assert output.shape[1] == expected_steps   # one prediction per next token
        assert output.shape[2] == vocab_size       # logits over vocab
        print("✅ Test passed.")

test_seq2seq_forward_pass()

Input shape: torch.Size([2, 3])
Target shape: torch.Size([2, 5])
Output shape: torch.Size([2, 4, 8000])
✅ Test passed.
Input shape: torch.Size([2, 4])
Target shape: torch.Size([2, 4])
Output shape: torch.Size([2, 3, 8000])
✅ Test passed.


## Train

In [None]:
def _seq2seq_batch_loss(model, criterion, src_tagged, tgt_tagged, src_len, device):
    """Run one forward pass on a batch and return the criterion's loss tensor."""
    src_padded = src_tagged.to(device)
    tgt_padded = tgt_tagged.to(device)
    output = model(src_padded, src_len, tgt_padded)
    # Flatten the 3D logits to 2D and the targets (first column dropped) to 1D.
    # reshape() is used for both because view() fails on non-contiguous tensors.
    logits = output.reshape(-1, output.shape[-1])
    targets = tgt_padded[:, 1:].reshape(-1)
    return criterion(logits, targets)


def train(model, trainloader, optimizer, criterion, device, epochs=1, evaloader=None):
    """Train ``model`` for ``epochs`` epochs, validating on every fifth epoch.

    Args:
        model: callable as ``model(src, src_len, tgt)``; must also provide
            ``train()`` and ``eval()``. It is not moved to ``device`` here.
        trainloader: sized iterable yielding ``(src, tgt, src_len)`` batches.
        optimizer: optimizer stepped once per training batch.
        criterion: loss called with 2D logits and 1D targets.
        device: device the source and target tensors are moved to. ``src_len``
            is passed to the model unchanged.
        epochs: number of passes over ``trainloader``. A value <= 0 is a no-op.
        evaloader: optional loader with the same batch layout. When truthy,
            a no-grad validation pass runs after every 5th epoch.

    Returns:
        None. Progress and losses are printed.

    Raises:
        ValueError: if ``trainloader`` contains no batches.
    """
    if epochs <= 0:
        # Previously this fell through to a division by zero in the summary.
        print("No epochs requested; nothing to train.")
        return

    num_train_batches = len(trainloader)
    if num_train_batches == 0:
        # Fail early with a clear message rather than a ZeroDivisionError
        # when averaging the epoch loss.
        raise ValueError("trainloader is empty; cannot train on zero batches")

    train_loss = 0
    for e in range(epochs):
        epoch_loss = 0
        model.train()
        for src_tagged, tgt_tagged, src_len in tqdm.tqdm(trainloader, desc="Training"):
            optimizer.zero_grad()
            loss = _seq2seq_batch_loss(model, criterion, src_tagged, tgt_tagged, src_len, device)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        mean_epoch_loss = epoch_loss / num_train_batches
        # The last epoch gets no per-epoch line; the summary after the loop
        # reports the mean over all epochs instead.
        if e < epochs - 1:
            print(f"Epoch {e+1}/{epochs} completed with epoch loss of {mean_epoch_loss:.4f}")
        train_loss += mean_epoch_loss

        # Validate only on every 5th epoch, and only when a loader was given.
        if (e + 1) % 5 != 0 or not evaloader:
            continue

        eval_loss = 0
        model.eval()
        with torch.no_grad():
            for src_tagged, tgt_tagged, src_len in tqdm.tqdm(evaloader, desc="Validation"):
                loss = _seq2seq_batch_loss(model, criterion, src_tagged, tgt_tagged, src_len, device)
                eval_loss += loss.item()
        print(f"Evaluation completed with epoch loss of {eval_loss/len(evaloader):.4f}")

    print(f"\nTraining completed with training loss of {train_loss/epochs:.4f}")

#### Testing on dummy data

##### Dummy data

In [None]:
# Toy corpus of (source, target) English sentence pairs used to sanity-check
# the training loop; each target is roughly a paraphrase of its source.
# NOTE(review): apostrophes are a mix of ASCII (') and typographic (’)
# characters — confirm whether that is intentional before normalizing.
eng_pairs = [
    ("How are you?", "How have you been?"),
    ("I'm going to the store.", "I'm heading to the shop."),
    ("She likes reading books.", "She enjoys reading."),
    ("He is very tired.", "He's extremely exhausted."),
    ("Can you help me?", "Would you give me a hand?"),
    ("It's raining outside.", "There is rain falling outside."),
    ("I love chocolate.", "Chocolate is my favorite."),
    ("What time is it?", "Do you know the time?"),
    ("I'm learning Python.", "I'm studying Python."),
    ("That movie was great.", "I really liked that film."),
    ("Let's take a break.", "Let's pause for a moment."),
    ("The food was delicious.", "The meal tasted amazing."),
    ("He runs fast.", "He's a quick runner."),
    ("She's my best friend.", "She's my closest friend."),
    ("Do you speak English?", "Can you talk in English?"),
    ("I need some water.", "I'm thirsty."),
    ("This book is interesting.", "This is a fascinating read."),
    ("Turn off the light.", "Switch the light off."),
    ("Please be quiet.", "Can you lower your voice?"),
    ("I lost my keys.", "I can't find my keys."),
    ("Are you okay?", "Is everything alright?"),
    ("I have no idea.", "I don't know."),
    ("It's too expensive.", "That costs a lot."),
    ("Where are you from?", "What is your hometown?"),
    ("She sings beautifully.", "She has a lovely voice."),
    ("He's my brother.", "He's my sibling."),
    ("We are late.", "We’re behind schedule."),
    ("I'm going to bed.", "I'm off to sleep."),
    ("That's a good idea.", "That sounds great."),
    ("The weather is nice.", "It’s a beautiful day."),
    ("I'm sorry.", "I apologize."),
    ("I don't understand.", "I’m confused."),
    ("Could you repeat that?", "Can you say that again?"),
    ("That was funny.", "I found it hilarious."),
    ("Do you want coffee?", "Would you like some coffee?"),
    ("I forgot my password.", "I can't remember my password."),
    ("It's getting dark.", "The sun is going down."),
    ("Close the door.", "Shut the door."),
    ("She is very smart.", "She’s really intelligent."),
    ("I feel sick.", "I'm not feeling well."),
    ("We should go now.", "It's time to leave."),
    ("This is my house.", "I live here."),
    ("He lives nearby.", "He stays close."),
    ("What's your name?", "May I know your name?"),
    ("Don't worry.", "It’s all right."),
    ("Can I ask you something?", "May I ask you a question?"),
    ("I can't believe it!", "That's unbelievable!"),
    ("They look the same.", "They appear identical."),
    ("I'm really busy.", "I have a lot to do."),
    ("He arrived late.", "He came after the scheduled time."),
    ("She is very kind.", "She’s really nice."),
    ("It’s not my fault.", "I’m not to blame."),
    ("Please sit down.", "Have a seat."),
    ("We’re having lunch.", "We’re eating now."),
    ("Let me see.", "Let me take a look."),
    ("I hate waiting.", "I dislike delays."),
    ("I'm on my way.", "I’m coming."),
    ("Don't be late.", "Be on time."),
    ("He made a mistake.", "He messed up."),
    ("She passed the exam.", "She succeeded in the test."),
    ("It’s broken.", "It doesn’t work."),
    ("I missed the bus.", "I didn’t catch the bus."),
    ("They are married.", "They’re husband and wife."),
    ("I’m feeling cold.", "It’s chilly."),
    ("She was surprised.", "She didn’t expect that."),
    ("I can't hear you.", "You're too quiet."),
    ("Let’s go shopping.", "Let’s buy some things."),
    ("He's a good driver.", "He drives well."),
    ("That’s not fair.", "That’s unjust."),
    ("I agree with you.", "You’re right."),
    ("Let’s start.", "Let’s begin."),
    ("I need a break.", "I want to rest."),
    ("This is confusing.", "I don’t get it."),
    ("She looks tired.", "She seems exhausted."),
    ("He’s always late.", "He never arrives on time."),
    ("I made it myself.", "I did it on my own."),
    ("Can you drive?", "Do you know how to drive?"),
    ("It’s your turn.", "Now it’s up to you."),
    ("We had a great time.", "We really enjoyed ourselves."),
    ("Please call me.", "Give me a call."),
    ("That’s enough.", "Stop now."),
    ("I’m full.", "I can’t eat more."),
    ("Don’t forget.", "Remember that."),
    ("I was born here.", "This is my birthplace."),
    ("It’s near here.", "It’s close by."),
    ("You did well.", "Good job."),
    ("He’s rich.", "He has a lot of money."),
    ("This is fun.", "I’m enjoying this."),
    ("He’s very tall.", "He’s really big."),
    ("We’re done.", "We’ve finished."),
    ("Be careful.", "Watch out."),
    ("She’s shy.", "She’s introverted."),
    ("Try again.", "Give it another shot."),
    ("It’s not working.", "Something’s wrong."),
    ("I’ll be back soon.", "See you in a bit."),
    ("She won the prize.", "She got the award."),
    ("It’s your choice.", "You decide."),
    ("Keep going.", "Continue."),
    ("I’m listening.", "Go ahead, I’m paying attention.")
]

##### Training Loop

In [None]:
# Build the tokenizer, dataset and dataloader for the toy sentence pairs.
tokenizer = get_tokenizer(path="/content/en_tokenizer.json")
dataset = TranslationDataset(eng_pairs, tokenizer)
dataloader = dataset.get_dataloader(batch_size=2)
# Encoder and Decoder are passed uncalled; presumably Seq2Seq instantiates
# them itself — verify against the Seq2Seq constructor.
encoder = Encoder
decoder = Decoder
model = Seq2Seq(
    encoder=encoder,
    decoder=decoder,
    vocab_size=len(tokenizer.get_vocab()),
    hidden_sz=32,
    enc_num_layers=2,
    dec_num_layers=2,
    arch="lstm",
    attn="dot"
)

# Train on the same data the model is later evaluated on (a memorization
# check); no evaloader is passed, so no validation pass runs.
# NOTE(review): CrossEntropyLoss is created with defaults, so no ignore_index
# is set for a padding id — confirm padded target positions should count
# toward the loss.
train(model, dataloader, optimizer=torch.optim.Adam(model.parameters()), criterion=nn.CrossEntropyLoss(), device="cpu", epochs=100)

Training: 100%|██████████| 50/50 [00:12<00:00,  3.98it/s]


Epoch 1/100 completed with epoch loss of 6.8778


Training: 100%|██████████| 50/50 [00:12<00:00,  3.98it/s]


Epoch 2/100 completed with epoch loss of 4.1669


Training: 100%|██████████| 50/50 [00:12<00:00,  3.97it/s]


Epoch 3/100 completed with epoch loss of 3.6537


Training: 100%|██████████| 50/50 [00:12<00:00,  4.12it/s]


Epoch 4/100 completed with epoch loss of 3.3617


Training: 100%|██████████| 50/50 [00:12<00:00,  4.15it/s]


Epoch 5/100 completed with epoch loss of 3.0598


Training: 100%|██████████| 50/50 [00:11<00:00,  4.25it/s]


Epoch 6/100 completed with epoch loss of 2.8040


Training: 100%|██████████| 50/50 [00:11<00:00,  4.48it/s]


Epoch 7/100 completed with epoch loss of 2.5444


Training: 100%|██████████| 50/50 [00:11<00:00,  4.54it/s]


Epoch 8/100 completed with epoch loss of 2.2566


Training: 100%|██████████| 50/50 [00:12<00:00,  4.10it/s]


Epoch 9/100 completed with epoch loss of 2.0064


Training: 100%|██████████| 50/50 [00:11<00:00,  4.48it/s]


Epoch 10/100 completed with epoch loss of 1.8166


Training: 100%|██████████| 50/50 [00:11<00:00,  4.18it/s]


Epoch 11/100 completed with epoch loss of 1.6023


Training: 100%|██████████| 50/50 [00:11<00:00,  4.31it/s]


Epoch 12/100 completed with epoch loss of 1.4304


Training: 100%|██████████| 50/50 [00:11<00:00,  4.28it/s]


Epoch 13/100 completed with epoch loss of 1.1593


Training: 100%|██████████| 50/50 [00:10<00:00,  4.69it/s]


Epoch 14/100 completed with epoch loss of 0.9467


Training: 100%|██████████| 50/50 [00:11<00:00,  4.33it/s]


Epoch 15/100 completed with epoch loss of 0.7584


Training: 100%|██████████| 50/50 [00:12<00:00,  4.07it/s]


Epoch 16/100 completed with epoch loss of 0.6206


Training: 100%|██████████| 50/50 [00:12<00:00,  3.92it/s]


Epoch 17/100 completed with epoch loss of 0.5369


Training: 100%|██████████| 50/50 [00:12<00:00,  3.87it/s]


Epoch 18/100 completed with epoch loss of 0.5097


Training: 100%|██████████| 50/50 [00:11<00:00,  4.31it/s]


Epoch 19/100 completed with epoch loss of 0.4754


Training: 100%|██████████| 50/50 [00:11<00:00,  4.21it/s]


Epoch 20/100 completed with epoch loss of 0.3521


Training: 100%|██████████| 50/50 [00:12<00:00,  3.93it/s]


Epoch 21/100 completed with epoch loss of 0.2532


Training: 100%|██████████| 50/50 [00:12<00:00,  3.93it/s]


Epoch 22/100 completed with epoch loss of 0.1811


Training: 100%|██████████| 50/50 [00:13<00:00,  3.81it/s]


Epoch 23/100 completed with epoch loss of 0.1416


Training: 100%|██████████| 50/50 [00:13<00:00,  3.74it/s]


Epoch 24/100 completed with epoch loss of 0.1263


Training: 100%|██████████| 50/50 [00:13<00:00,  3.80it/s]


Epoch 25/100 completed with epoch loss of 0.0993


Training: 100%|██████████| 50/50 [00:13<00:00,  3.85it/s]


Epoch 26/100 completed with epoch loss of 0.0775


Training: 100%|██████████| 50/50 [00:12<00:00,  4.11it/s]


Epoch 27/100 completed with epoch loss of 0.0719


Training: 100%|██████████| 50/50 [00:13<00:00,  3.82it/s]


Epoch 28/100 completed with epoch loss of 0.0619


Training: 100%|██████████| 50/50 [00:11<00:00,  4.51it/s]


Epoch 29/100 completed with epoch loss of 0.0494


Training: 100%|██████████| 50/50 [00:12<00:00,  3.87it/s]


Epoch 30/100 completed with epoch loss of 0.0406


Training: 100%|██████████| 50/50 [00:11<00:00,  4.22it/s]


Epoch 31/100 completed with epoch loss of 0.0354


Training: 100%|██████████| 50/50 [00:12<00:00,  3.96it/s]


Epoch 32/100 completed with epoch loss of 0.0315


Training: 100%|██████████| 50/50 [00:11<00:00,  4.37it/s]


Epoch 33/100 completed with epoch loss of 0.0264


Training: 100%|██████████| 50/50 [00:10<00:00,  4.57it/s]


Epoch 34/100 completed with epoch loss of 0.0232


Training: 100%|██████████| 50/50 [00:11<00:00,  4.39it/s]


Epoch 35/100 completed with epoch loss of 0.0204


Training: 100%|██████████| 50/50 [00:11<00:00,  4.41it/s]


Epoch 36/100 completed with epoch loss of 0.0183


Training: 100%|██████████| 50/50 [00:12<00:00,  4.08it/s]


Epoch 37/100 completed with epoch loss of 0.0169


Training: 100%|██████████| 50/50 [00:12<00:00,  3.92it/s]


Epoch 38/100 completed with epoch loss of 0.0150


Training: 100%|██████████| 50/50 [00:10<00:00,  4.71it/s]


Epoch 39/100 completed with epoch loss of 0.0142


Training: 100%|██████████| 50/50 [00:12<00:00,  4.17it/s]


Epoch 40/100 completed with epoch loss of 0.0129


Training: 100%|██████████| 50/50 [00:12<00:00,  3.92it/s]


Epoch 41/100 completed with epoch loss of 0.0120


Training: 100%|██████████| 50/50 [00:12<00:00,  4.08it/s]


Epoch 42/100 completed with epoch loss of 0.0112


Training: 100%|██████████| 50/50 [00:11<00:00,  4.27it/s]


Epoch 43/100 completed with epoch loss of 0.0104


Training: 100%|██████████| 50/50 [00:11<00:00,  4.51it/s]


Epoch 44/100 completed with epoch loss of 0.0096


Training: 100%|██████████| 50/50 [00:12<00:00,  3.97it/s]


Epoch 45/100 completed with epoch loss of 0.0089


Training: 100%|██████████| 50/50 [00:12<00:00,  3.99it/s]


Epoch 46/100 completed with epoch loss of 0.0084


Training: 100%|██████████| 50/50 [00:11<00:00,  4.29it/s]


Epoch 47/100 completed with epoch loss of 0.0077


Training: 100%|██████████| 50/50 [00:11<00:00,  4.23it/s]


Epoch 48/100 completed with epoch loss of 0.0073


Training: 100%|██████████| 50/50 [00:12<00:00,  4.13it/s]


Epoch 49/100 completed with epoch loss of 0.0070


Training: 100%|██████████| 50/50 [00:11<00:00,  4.18it/s]


Epoch 50/100 completed with epoch loss of 0.0064


Training: 100%|██████████| 50/50 [00:12<00:00,  4.06it/s]


Epoch 51/100 completed with epoch loss of 0.0061


Training: 100%|██████████| 50/50 [00:13<00:00,  3.79it/s]


Epoch 52/100 completed with epoch loss of 0.0057


Training: 100%|██████████| 50/50 [00:12<00:00,  3.87it/s]


Epoch 53/100 completed with epoch loss of 0.0054


Training: 100%|██████████| 50/50 [00:11<00:00,  4.25it/s]


Epoch 54/100 completed with epoch loss of 0.0051


Training: 100%|██████████| 50/50 [00:11<00:00,  4.25it/s]


Epoch 55/100 completed with epoch loss of 0.0050


Training: 100%|██████████| 50/50 [00:12<00:00,  3.98it/s]


Epoch 56/100 completed with epoch loss of 0.0046


Training: 100%|██████████| 50/50 [00:13<00:00,  3.73it/s]


Epoch 57/100 completed with epoch loss of 0.0044


Training: 100%|██████████| 50/50 [00:12<00:00,  4.16it/s]


Epoch 58/100 completed with epoch loss of 0.0041


Training: 100%|██████████| 50/50 [00:13<00:00,  3.78it/s]


Epoch 59/100 completed with epoch loss of 0.0039


Training: 100%|██████████| 50/50 [00:12<00:00,  4.01it/s]


Epoch 60/100 completed with epoch loss of 0.0038


Training: 100%|██████████| 50/50 [00:12<00:00,  3.96it/s]


Epoch 61/100 completed with epoch loss of 0.0035


Training: 100%|██████████| 50/50 [00:12<00:00,  4.07it/s]


Epoch 62/100 completed with epoch loss of 0.0034


Training: 100%|██████████| 50/50 [00:11<00:00,  4.36it/s]


Epoch 63/100 completed with epoch loss of 0.0032


Training: 100%|██████████| 50/50 [00:11<00:00,  4.36it/s]


Epoch 64/100 completed with epoch loss of 0.0031


Training: 100%|██████████| 50/50 [00:12<00:00,  4.10it/s]


Epoch 65/100 completed with epoch loss of 0.0029


Training: 100%|██████████| 50/50 [00:12<00:00,  3.96it/s]


Epoch 66/100 completed with epoch loss of 0.0029


Training: 100%|██████████| 50/50 [00:11<00:00,  4.35it/s]


Epoch 67/100 completed with epoch loss of 0.0027


Training: 100%|██████████| 50/50 [00:12<00:00,  3.94it/s]


Epoch 68/100 completed with epoch loss of 0.0026


Training: 100%|██████████| 50/50 [00:12<00:00,  4.15it/s]


Epoch 69/100 completed with epoch loss of 0.0024


Training: 100%|██████████| 50/50 [00:13<00:00,  3.75it/s]


Epoch 70/100 completed with epoch loss of 0.0023


Training: 100%|██████████| 50/50 [00:12<00:00,  4.08it/s]


Epoch 71/100 completed with epoch loss of 0.0022


Training: 100%|██████████| 50/50 [00:11<00:00,  4.24it/s]


Epoch 72/100 completed with epoch loss of 0.0021


Training: 100%|██████████| 50/50 [00:10<00:00,  4.81it/s]


Epoch 73/100 completed with epoch loss of 0.0020


Training: 100%|██████████| 50/50 [00:13<00:00,  3.69it/s]


Epoch 74/100 completed with epoch loss of 0.0019


Training: 100%|██████████| 50/50 [00:13<00:00,  3.72it/s]


Epoch 75/100 completed with epoch loss of 0.0019


Training: 100%|██████████| 50/50 [00:12<00:00,  3.93it/s]


Epoch 76/100 completed with epoch loss of 0.0018


Training: 100%|██████████| 50/50 [00:12<00:00,  3.92it/s]


Epoch 77/100 completed with epoch loss of 0.0017


Training: 100%|██████████| 50/50 [00:12<00:00,  3.85it/s]


Epoch 78/100 completed with epoch loss of 0.0016


Training: 100%|██████████| 50/50 [00:12<00:00,  4.12it/s]


Epoch 79/100 completed with epoch loss of 0.0016


Training: 100%|██████████| 50/50 [00:11<00:00,  4.28it/s]


Epoch 80/100 completed with epoch loss of 0.0015


Training: 100%|██████████| 50/50 [00:11<00:00,  4.41it/s]


Epoch 81/100 completed with epoch loss of 0.0015


Training: 100%|██████████| 50/50 [00:13<00:00,  3.84it/s]


Epoch 82/100 completed with epoch loss of 0.0014


Training: 100%|██████████| 50/50 [00:13<00:00,  3.64it/s]


Epoch 83/100 completed with epoch loss of 0.0013


Training: 100%|██████████| 50/50 [00:11<00:00,  4.49it/s]


Epoch 84/100 completed with epoch loss of 0.0013


Training: 100%|██████████| 50/50 [00:13<00:00,  3.64it/s]


Epoch 85/100 completed with epoch loss of 0.0012


Training: 100%|██████████| 50/50 [00:12<00:00,  3.86it/s]


Epoch 86/100 completed with epoch loss of 0.0012


Training: 100%|██████████| 50/50 [00:12<00:00,  3.96it/s]


Epoch 87/100 completed with epoch loss of 0.0011


Training: 100%|██████████| 50/50 [00:12<00:00,  3.98it/s]


Epoch 88/100 completed with epoch loss of 0.0011


Training: 100%|██████████| 50/50 [00:13<00:00,  3.70it/s]


Epoch 89/100 completed with epoch loss of 0.0011


Training: 100%|██████████| 50/50 [00:11<00:00,  4.45it/s]


Epoch 90/100 completed with epoch loss of 0.0010


Training: 100%|██████████| 50/50 [00:12<00:00,  3.92it/s]


Epoch 91/100 completed with epoch loss of 0.0010


Training: 100%|██████████| 50/50 [00:11<00:00,  4.51it/s]


Epoch 92/100 completed with epoch loss of 0.0010


Training: 100%|██████████| 50/50 [00:11<00:00,  4.43it/s]


Epoch 93/100 completed with epoch loss of 0.0009


Training: 100%|██████████| 50/50 [00:11<00:00,  4.24it/s]


Epoch 94/100 completed with epoch loss of 0.0009


Training: 100%|██████████| 50/50 [00:12<00:00,  3.97it/s]


Epoch 95/100 completed with epoch loss of 0.0008


Training: 100%|██████████| 50/50 [00:11<00:00,  4.27it/s]


Epoch 96/100 completed with epoch loss of 0.0008


Training: 100%|██████████| 50/50 [00:13<00:00,  3.64it/s]


Epoch 97/100 completed with epoch loss of 0.0008


Training: 100%|██████████| 50/50 [00:13<00:00,  3.82it/s]


Epoch 98/100 completed with epoch loss of 0.0008


Training: 100%|██████████| 50/50 [00:11<00:00,  4.31it/s]


Epoch 99/100 completed with epoch loss of 0.0007


Training: 100%|██████████| 50/50 [00:13<00:00,  3.84it/s]


Training completed with training loss of 0.4246





## Predicting

In [None]:
def predict(model, tokenizer, input_seq):
    """Generate the model's output for ``input_seq`` and return it.

    Args:
        model: object exposing ``generate(input_seq, tokenizer)``.
        tokenizer: forwarded unchanged to ``model.generate``.
        input_seq: the input to generate from; forwarded unchanged.

    Returns:
        Whatever ``model.generate`` returns. Previously the result was
        assigned to a local and dropped, so callers always received ``None``.
    """
    return model.generate(input_seq, tokenizer)

# Show the source, reference target and model prediction for the first four pairs.
for src, tgt in eng_pairs[:4]:
    print(f"Source: {src}")
    print(f"Target: {tgt}")
    print("Predicted: ")
    # The return value of predict() is not used here, so any prediction text
    # shown must be printed from within the call — NOTE(review): confirm that
    # model.generate prints its output.
    predict(model, tokenizer, src)
    print()


Source: How are you?
Target: How have you been?
Predicted: 
 ow have you been?
Source: I'm going to the store.
Target: I'm heading to the shop.
Predicted: 
 'm heading to the shop.
Source: She likes reading books.
Target: She enjoys reading.
Predicted: 
 he enjoys reading.
Source: He is very tired.
Target: He's extremely exhausted.
Predicted: 
 e's extremely exhausted.


# extra

In [None]:
# Serialize the entire model object (not just its state_dict) to model.pth.
# NOTE(review): whole-model saving relies on pickle, so loading the file later
# needs the same class definitions importable — consider saving
# model.state_dict() instead.
torch.save(model, 'model.pth')