<a href="https://colab.research.google.com/github/mohammadp1001/attention_is_all_you_need/blob/main/Encoder_Decoder_Model_with_Attention_ipyn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install --quiet datasets

In [None]:
# @title  Import the necessary libraries
import re
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from collections import Counter
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


In [None]:
# @title Download German-English translation data

# Fetch the German-English configuration of WMT14.
wmt14 = load_dataset("wmt14", "de-en")

# Keep direct handles on the two splits referenced by later cells.
train_data, test_data = wmt14['train'], wmt14['test']

# Sanity check: show what a single training record looks like.
print(train_data[0])

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

train-00000-of-00003.parquet:   0%|          | 0.00/280M [00:00<?, ?B/s]

train-00001-of-00003.parquet:   0%|          | 0.00/265M [00:00<?, ?B/s]

train-00002-of-00003.parquet:   0%|          | 0.00/273M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/474k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/509k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4508785 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3003 [00:00<?, ? examples/s]

{'translation': {'de': 'Wiederaufnahme der Sitzungsperiode', 'en': 'Resumption of the session'}}


In [None]:
# @title Tokenizer Class
class Tokenizer:
    """Word-level tokenizer with a frequency-built vocabulary.

    Text is lowercased, the punctuation marks ``? . ! ,`` are split off as
    separate tokens, and every character that is neither a letter, one of those
    marks, nor an apostrophe is replaced by a space. Letters are detected with
    ``str.isalpha`` so non-ASCII letters (e.g. German ``ä ö ü ß``) are kept.

    Attributes:
        vocab: token -> id mapping (filled by ``build_vocab``).
        inv_vocab: id -> token mapping (filled by ``build_vocab``).
        special_tokens: list of special tokens; they receive ids 0..k-1 in order.
        word_freqs: ``Counter`` of every token seen by ``build_vocab``.
        vocab_size: number of entries in ``vocab``.
        max_length: largest token count (without special tokens) of any text
            passed to ``tokenize`` so far.
    """

    DEFAULT_SPECIAL_TOKENS = ("<PAD>", "<UNK>", "<SOS>", "<EOS>")
    # Non-letter characters that survive cleaning.
    _KEPT_SYMBOLS = "?.!,'"

    def __init__(self, special_tokens=None):
        """
        Initializes the tokenizer with an empty vocabulary and special tokens.

        :param special_tokens: optional iterable of special tokens; defaults to
            ``<PAD>, <UNK>, <SOS>, <EOS>``. ``encode`` looks these four names up
            in the vocabulary, so a custom list should contain them.
        """
        # Copy into a fresh list so instances never share a mutable default.
        if special_tokens is None:
            special_tokens = self.DEFAULT_SPECIAL_TOKENS
        self.special_tokens = list(special_tokens)
        self.vocab = {}  # Word to index mapping
        self.inv_vocab = {}  # Index to word mapping
        self.word_freqs = Counter()
        self.vocab_size = 0  # Will be set when vocabulary is built
        self.max_length = 0  # Longest tokenized text seen so far

    def _clean_text(self, text):
        """
        Preprocess text: lowercase, isolate punctuation, drop non-letters,
        collapse whitespace.
        """
        text = text.lower().strip()
        text = re.sub(r"([?.!,])", r" \1 ", text)  # Surround punctuation with spaces
        # Keep letters of any alphabet; an ASCII-only filter would cut German
        # words apart at every umlaut / sharp s.
        text = "".join(
            ch if (ch.isalpha() or ch in self._KEPT_SYMBOLS) else " " for ch in text
        )
        return re.sub(r"\s+", " ", text).strip()

    def tokenize(self, text):
        """
        Tokenizes a sentence into words.

        Side effect: ``max_length`` is raised to the token count of ``text`` if
        that is larger than any text seen before.
        """
        tokens = self._clean_text(text).split()
        self.max_length = max(self.max_length, len(tokens))
        return tokens

    def build_vocab(self, texts, min_freq=1):
        """
        Builds vocabulary from a list of sentences.
        - `min_freq`: Minimum frequency for a word to be included in vocab.

        Frequencies accumulate across calls; the vocabulary itself is rebuilt
        from scratch on every call. Prints the resulting vocabulary size.
        """
        for text in texts:
            self.word_freqs.update(self.tokenize(text))

        # Special tokens always occupy the first ids.
        self.vocab = {token: i for i, token in enumerate(self.special_tokens)}

        for word, freq in self.word_freqs.items():
            if freq >= min_freq and word not in self.vocab:
                self.vocab[word] = len(self.vocab)

        self.inv_vocab = {idx: word for word, idx in self.vocab.items()}

        self.vocab_size = len(self.vocab)
        print(f"Vocabulary size: {self.vocab_size}")

    def encode(self, text, add_special_tokens=True, padding=True):
        """
        Converts a sentence into a sequence of token IDs.

        :param add_special_tokens: wrap the ids in ``<SOS>`` ... ``<EOS>``.
        :param padding: right-pad with ``<PAD>`` up to ``max_length`` (plus two
            when special tokens are added), so every padded sequence has the
            same length.
        :return: list of int token ids; unknown words map to ``<UNK>``.
        """
        tokens = self.tokenize(text)
        token_ids = [self.vocab.get(token, self.vocab["<UNK>"]) for token in tokens]

        if add_special_tokens:
            token_ids = [self.vocab["<SOS>"]] + token_ids + [self.vocab["<EOS>"]]

        if padding:
            # max_length counts words only, so leave room for <SOS>/<EOS>;
            # otherwise the longest sentences end up longer than the padded ones.
            target_length = self.max_length + (2 if add_special_tokens else 0)
            token_ids += [self.vocab["<PAD>"]] * (target_length - len(token_ids))

        return token_ids

    def decode(self, token_ids, remove_special_tokens=True):
        """
        Converts a sequence of token IDs back into a sentence.

        Ids missing from the vocabulary are rendered as ``<UNK>``; special
        tokens are dropped when ``remove_special_tokens`` is true.
        """
        words = [self.inv_vocab.get(idx, "<UNK>") for idx in token_ids]

        if remove_special_tokens:
            specials = set(self.special_tokens)
            words = [w for w in words if w not in specials]

        return " ".join(words)

    def to_numpy(self, token_ids):
        """
        Converts a sequence of token IDs into a NumPy array.
        """
        return np.array(token_ids)

In [None]:
# Use only the first N sentence pairs of the training split.
N = 8000

# Read the N rows with a single slice instead of indexing the dataset row by
# row once per language.
translation_pairs = train_data[:N]['translation']
english_sentences = [pair['en'] for pair in translation_pairs]
german_sentences = [pair['de'] for pair in translation_pairs]

# Initialize tokenizer (English)
en_tokenizer = Tokenizer()

# Build vocabulary from sentences (English)
en_tokenizer.build_vocab(english_sentences, min_freq=1)

# Initialize tokenizer (German)
de_tokenizer = Tokenizer()

# Build vocabulary from sentences (German)
de_tokenizer.build_vocab(german_sentences, min_freq=1)


english_sentences_tokenized = [en_tokenizer.tokenize(sentence) for sentence in english_sentences]
german_sentences_tokenized = [de_tokenizer.tokenize(sentence) for sentence in german_sentences]

# Leave the sequences unpadded here: collate_fn pads each batch to its own
# longest sequence, so padding everything to the corpus-wide maximum would only
# make every batch as long as the longest sentence in the corpus.
english_sentences_encoded = [en_tokenizer.encode(sentence, padding=False) for sentence in english_sentences]
german_sentences_encoded = [de_tokenizer.encode(sentence, padding=False) for sentence in german_sentences]


Vocabulary size: 9742
Vocabulary size: 15998


In [None]:
from torch.utils.data import Dataset

class TranslationDataset(Dataset):
    """Index-aligned view over two parallel collections of sequences.

    Item ``i`` is the pair ``(src_sequences[i], tgt_sequences[i])``; the
    reported length is that of the source collection.
    """

    def __init__(self, src_sequences, tgt_sequences):
        self.src, self.tgt = src_sequences, tgt_sequences

    def __len__(self):
        # Only the source side is consulted for the size.
        return len(self.src)

    def __getitem__(self, idx):
        source_item = self.src[idx]
        target_item = self.tgt[idx]
        return source_item, target_item

def collate_fn(batch):
    """Turn a list of (src_ids, tgt_ids) pairs into two right-padded tensors.

    Each side is padded independently, with value 0, to the length of its
    longest sequence in the batch. Returns ``(src_padded, tgt_padded)``, both
    batch-first.
    """
    def _stack_with_padding(sequences):
        as_tensors = [torch.tensor(ids) for ids in sequences]
        return pad_sequence(as_tensors, batch_first=True, padding_value=0)  # 0 = <pad>

    sources, targets = zip(*batch)
    return _stack_with_padding(sources), _stack_with_padding(targets)



# Pair the encoded English sentences (source) with the German ones (target).
dataset = TranslationDataset(
    src_sequences=english_sentences_encoded,
    tgt_sequences=german_sentences_encoded,
)

# Shuffled mini-batches of 32 pairs; collate_fn builds the padded tensors.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)


In [None]:
class Encoder(nn.Module):
  """Embedding layer followed by a batch-first GRU."""

  def __init__(self, input_size: int, embedding_dim : int, hidden_size: int, num_layers: int =1):
    """Build the embedding table and the GRU.

    :param input_size: Number of rows in the embedding table
    :param embedding_dim: Width of each embedding vector
    :param hidden_size: Width of the GRU hidden state
    :param num_layers: Number of stacked GRU layers
    """
    super().__init__()
    self._hidden_size, self._num_layers = hidden_size, num_layers
    self._embedding = nn.Embedding(num_embeddings=input_size, embedding_dim=embedding_dim)
    self._rnn = nn.GRU(
        input_size=embedding_dim,
        hidden_size=hidden_size,
        num_layers=num_layers,
        batch_first=True,
    )

  def forward(self, x):
    """Embed the token IDs and run them through the GRU.

    :param x: Input tensor of shape (batch_size, seq_len)
    :return output: Output tensor of shape (batch_size, seq_len, hidden_size)
    :return hidden: Hidden state tensor of shape (num_layers, batch_size, hidden_size)
    """
    # nn.GRU already yields the (output, hidden) pair that is returned here.
    return self._rnn(self._embedding(x))

In [None]:
class Attention(nn.Module):
  """Scaled dot-product attention over a set of keys stored by ``init_keys``.

  Usage: call ``init_keys(keys)`` once per source batch, then call the module
  with a query for every decoding step. The attention weights of the most
  recent call are available (detached) through ``get_alphas``.
  """

  def __init__(self,  hidden_size: int, input_size: int= None, proj_values: bool=False) -> None:
    """
    :param hidden_size: Output width of the query/key/value projections; also
        used as d_k when scaling the scores
    :param input_size: Feature width of the incoming queries and keys, i.e. the
        input width of the projections; defaults to ``hidden_size``
    :param proj_values: If True the values are a linear projection of the keys,
        otherwise the raw keys are used as values
    """
    super().__init__()
    self._hidden_size = hidden_size
    self._input_size = hidden_size if input_size is None else input_size
    self._proj_values = proj_values
    self.d_k = hidden_size
    self._alphas = None
    # Filled by init_keys(); None means "no keys stored yet".
    self._keys = None
    self._proj_keys = None
    self._values = None

    # Affine transformation for Q,K and V.
    self._query_linear = nn.Linear(self._input_size, self._hidden_size)
    self._key_linear = nn.Linear(self._input_size, self._hidden_size)
    self._value_linear = nn.Linear(self._input_size, self._hidden_size)

  def init_keys(self,keys):
    """Store the keys and precompute their projection (and the values).

    :param keys: Tensor of shape (N, L, input_size)
    """
    self._keys = keys
    self._proj_keys = self._key_linear(self._keys)
    self._values = self._value_linear(self._keys) if self._proj_values else self._keys

  def score_function(self,query):
    """Return the raw (pre-softmax) scaled dot-product scores.

    :param query: Tensor of shape (N, 1, input_size)
    :return: Scores of shape (N, 1, L). The softmax is deliberately applied in
        ``forward`` so that masking can be done on the raw scores first.
    """
    proj_query = self._query_linear(query)
    # (N,L,H) -> (N,H,L), then (N,1,H) x (N,H,L) -> (N,1,L)
    dot_products = torch.bmm(proj_query, self._proj_keys.permute(0,2,1))
    return dot_products / (self.d_k ** 0.5)

  def forward(self, query, mask=None):
    """Compute the context vector for ``query``.

    :param query: Tensor of shape (N, 1, input_size)
    :param mask: Optional tensor broadcastable to (N, 1, L); non-zero / True
        marks positions that may be attended to, zero / False marks positions
        to ignore (e.g. padding)
    :return: Context tensor of shape (N, 1, value_dim)
    :raises RuntimeError: if ``init_keys`` has not been called yet
    """
    if self._proj_keys is None:
      raise RuntimeError("Attention.init_keys() must be called before forward().")

    scores = self.score_function(query)
    if mask is not None:
      # Blank out the positions where the mask is 0/False. Filling where the
      # mask is True would suppress exactly the real tokens and leave only the
      # padding to attend to. Out-of-place so the raw scores stay untouched.
      scores = scores.masked_fill(mask == 0, -1e9)
    alphas = torch.softmax(scores, dim=-1)
    self._alphas = alphas.detach()
    context = torch.bmm(alphas, self._values)
    return context

  def get_alphas(self):
    """Return the detached attention weights of the last forward call (or None)."""
    return self._alphas

In [None]:
class DecoderAttn(nn.Module):
    """Single-step GRU decoder whose output is combined with an attention context."""

    def __init__(self, output_size: int, embedding_dim: int, hidden_size: int, num_layers: int =1) -> None:
        """Create the embedding, GRU, output projection and attention module.

        :param output_size: Size of the output vocabulary
        :param embedding_dim: Dimension of the embedding layer
        :param hidden_size: Dimension of the hidden state
        :param num_layers: Number of layers in the RNN
        """
        super().__init__()
        self._embedding = nn.Embedding(output_size, embedding_dim)
        self._rnn = nn.GRU(embedding_dim, hidden_size, num_layers, batch_first=True)
        # The projection sees the context and the GRU output side by side,
        # hence twice the hidden width on its input.
        self._fc = nn.Linear(hidden_size * 2, output_size)
        self._attn = Attention(hidden_size)

    def init_hidden(self, encoder_output):
        """Hand the encoder outputs to the attention module as its keys.

        :param encoder_output: Output tensor of shape (batch_size, seq_len, hidden_size)
        """
        self._attn.init_keys(encoder_output)

    def forward(self, x, encoder_hidden, mask=None):
        """Run one decoding step.

        :param x: Input token IDs, shape (batch_size, 1)
        :param encoder_hidden: Hidden state fed to the GRU, shape (num_layers, batch_size, hidden_size)
        :param mask: Attention mask, shape (batch_size, 1, encoder_seq_len); passed to the attention module as is
        :return: output: (batch_size, 1, vocab_size),
        :return: hidden: (num_layers, batch_size, hidden_size)
        """
        rnn_output, hidden = self._rnn(self._embedding(x), encoder_hidden)
        # The GRU output doubles as the attention query.
        context = self._attn(rnn_output, mask)
        logits = self._fc(torch.cat((context, rnn_output), dim=-1))
        return logits, hidden

    def get_attention_score(self):
        """Return whatever the attention module reports through ``get_alphas``."""
        return self._attn.get_alphas()


## NLP example

In [None]:
class Seq2SeqWithAttension(nn.Module):
    """Ties an encoder and a step-wise decoder together with teacher forcing."""

    def __init__(self, encoder, decoder, trg_vocab_size, pad_idx=0):
        """
        encoder: Module called as ``encoder(src)`` returning ``(outputs, hidden)``
        decoder: Module exposing ``init_hidden(encoder_outputs)`` and called as
                 ``decoder(input_step, hidden, mask)`` returning ``(logits, hidden)``
        trg_vocab_size: Size of the last dimension of the decoder logits
        pad_idx: Source token ID that marks padding when building the mask
        """
        super(Seq2SeqWithAttension, self).__init__()
        self._encoder = encoder
        self._decoder = decoder
        self._trg_vocab_size = trg_vocab_size
        self._pad_idx = pad_idx

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """
        src: Source token IDs (batch_size, src_seq_len)
        trg: Target token IDs (batch_size, trg_seq_len); column 0 is the first decoder input
        teacher_forcing_ratio: Probability of using ground truth during training

        Returns logits of shape (batch_size, trg_seq_len, trg_vocab_size).
        Position 0 is never predicted and stays all zeros.
        """
        batch_size = src.shape[0]
        trg_len = trg.shape[1]

        # Allocate directly on the source's device instead of CPU-then-copy.
        outputs = torch.zeros(batch_size, trg_len, self._trg_vocab_size, device=src.device)

        # Encode the source sequence and hand the encoder outputs to the decoder
        output_encoder, hidden = self._encoder(src)
        self._decoder.init_hidden(output_encoder)

        # First decoder input (e.g., <SOS> token in NLP)
        input_step = trg[:, 0].unsqueeze(1)
        # True at every source position that is not padding; shape (batch_size, 1, src_seq_len)
        mask = (src != self._pad_idx).unsqueeze(1)

        for t in range(1, trg_len):
            output, hidden = self._decoder(input_step, hidden, mask)
            outputs[:, t, :] = output.squeeze(1)

            # One random draw per step decides between ground truth and the
            # model's own greedy prediction as the next input.
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            if teacher_force:
                input_step = trg[:, t].unsqueeze(1)
            else:
                input_step = output.argmax(2)  # shape: (batch_size, 1), token IDs only

        return outputs

In [None]:
def train_seq2seq(model, dataloader, optimizer, criterion, device, num_epochs=10, pad_idx=0):
    """Train a sequence-to-sequence model and report the mean loss per epoch.

    :param model: Module called as ``model(src_batch, tgt_batch)`` that returns
        logits of shape (batch_size, tgt_len, vocab_size)
    :param dataloader: Iterable yielding ``(src_batch, tgt_batch)`` tensor pairs
    :param optimizer: Optimizer over the model's parameters
    :param criterion: Loss called as ``criterion(logits_2d, targets_1d)``
    :param device: Device the model and every batch are moved to
    :param num_epochs: Number of passes over ``dataloader``
    :param pad_idx: Not used by this function; kept so existing calls keep
        working. Padding only stays out of the loss if ``criterion`` itself is
        configured for it (e.g. ``ignore_index``).
    :return: List with the mean batch loss of every epoch (``nan`` for an epoch
        in which the dataloader yielded no batch)
    """
    model = model.to(device)
    epoch_losses = []

    for epoch in range(num_epochs):
        # Re-assert training mode every epoch in case something switched the
        # model to eval mode in between.
        model.train()
        running_loss = 0.0
        num_batches = 0

        for src_batch, tgt_batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            src_batch = src_batch.to(device)
            tgt_batch = tgt_batch.to(device)

            optimizer.zero_grad()

            logits = model(src_batch, tgt_batch)

            # Position 0 holds no prediction (it corresponds to the first
            # decoder input), so drop it on both sides.
            logits = logits[:, 1:, :]
            target = tgt_batch[:, 1:]

            # Flatten to (batch_size * steps, vocab_size) vs (batch_size * steps)
            loss = criterion(logits.reshape(-1, logits.size(-1)), target.reshape(-1))

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Count the batches actually seen so an empty dataloader cannot cause a
        # division by zero.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")

    return epoch_losses


In [None]:
# English -> German: the encoder reads the English vocabulary, the decoder
# writes the German one; both sides share embedding and hidden widths.
encoder = Encoder(
    input_size=en_tokenizer.vocab_size,
    embedding_dim=256,
    hidden_size=512,
    num_layers=1,
)
decoder = DecoderAttn(
    output_size=de_tokenizer.vocab_size,
    embedding_dim=256,
    hidden_size=512,
    num_layers=1,
)
model = Seq2SeqWithAttension(encoder, decoder, trg_vocab_size=de_tokenizer.vocab_size)

In [None]:
# Token ID 0 is the padding index, so the loss skips those target positions.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train on the GPU when one is available, otherwise on the CPU.
train_seq2seq(
    model,
    dataloader,
    optimizer,
    criterion,
    torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    num_epochs=15,
    pad_idx=0,
)

Epoch 1:  21%|██        | 52/250 [58:51<3:43:39, 67.77s/it]

In [None]:
def translate_en_to_de(sentence, model, en_tokenizer, de_tokenizer, device, max_len=50):
    """Greedily translate one English sentence with the trained model.

    :param sentence: Raw English text
    :param model: Object exposing ``eval()``, ``_encoder`` and ``_decoder``
        (the decoder offering ``init_hidden`` and being callable as
        ``decoder(input_token, hidden, mask)``)
    :param en_tokenizer: Source tokenizer providing ``encode`` and ``vocab``
    :param de_tokenizer: Target tokenizer providing ``decode`` and ``vocab``
    :param device: Device on which the input tensors are created
    :param max_len: Maximum number of tokens to generate
    :return: Decoded German text (generation stops at ``<EOS>`` or ``max_len``)
    :raises ValueError: if an encoded source ID lies outside the source vocabulary
    """
    model.eval()

    # Encode without padding: a single sentence needs none, and trailing <PAD>
    # tokens would be fed through the encoder before its final hidden state is
    # taken.
    tokens = en_tokenizer.encode(sentence, add_special_tokens=True, padding=False)

    max_index = max(tokens)
    if max_index >= len(en_tokenizer.vocab):
        raise ValueError(f"Token index {max_index} out of bounds for vocab size {len(en_tokenizer.vocab)}")

    src_tensor = torch.tensor(tokens, device=device).unsqueeze(0)
    # True at every source position that is not <PAD>; shape (1, 1, src_len)
    mask = (src_tensor != en_tokenizer.vocab["<PAD>"]).unsqueeze(1)
    eos_id = de_tokenizer.vocab["<EOS>"]

    translated_tokens = []

    # Inference only: one no_grad scope covers encoding and the decoding loop.
    with torch.no_grad():
        encoder_output, hidden = model._encoder(src_tensor)
        model._decoder.init_hidden(encoder_output)

        input_token = torch.tensor([[de_tokenizer.vocab["<SOS>"]]], device=device)
        for _ in range(max_len):
            output, hidden = model._decoder(input_token, hidden, mask)
            pred_token = output.argmax(2).item()

            if pred_token == eos_id:
                break
            translated_tokens.append(pred_token)
            # Feed the prediction back in as the next decoder input.
            input_token = torch.tensor([[pred_token]], device=device)

    return de_tokenizer.decode(translated_tokens)


In [None]:
# Same device choice as for training: GPU when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

english_sentence = "The weather is very good today."
print("English:", english_sentence)

# Translate after echoing the input so the source line is shown first.
german_translation = translate_en_to_de(english_sentence, model, en_tokenizer, de_tokenizer, device)
print("German:", german_translation)