# Introduction to Attention Mechanism

The Attention mechanism is a technique that allows neural networks to focus on specific parts of the input sequence when making predictions. It was introduced to address the limitations of traditional sequence-to-sequence models like LSTMs (Long Short-Term Memory) and GRUs (Gated Recurrent Units), which often struggle with long-range dependencies in sequences.

## How Attention Works

The core idea of the Attention mechanism is to compute a weighted sum of all input elements, where the weights are dynamically calculated based on the relevance of each input element to the current output element being generated. This allows the model to "attend" to different parts of the input sequence as needed.

### Steps Involved in Attention Mechanism:

1. **Score Calculation**: For each output time step, calculate a score for each input element. This score represents the relevance of the input element to the current output element.
2. **Softmax**: Apply the softmax function to the scores to obtain attention weights. These weights sum to 1 and indicate the importance of each input element.
3. **Context Vector**: Compute the context vector as the weighted sum of the input elements, using the attention weights.
4. **Output Generation**: Use the context vector to generate the output for the current time step.

### Types of Attention Mechanisms:

- **Global Attention**: Considers all input elements when computing the context vector.
- **Local Attention**: Focuses on a subset of input elements, typically around a specific position.

## Improvements Over LSTM/GRU

1. **Handling Long-Range Dependencies**: Attention mechanisms can effectively capture long-range dependencies in sequences, which LSTMs and GRUs often struggle with.
2. **Parallelization**: Attention mechanisms, especially in the Transformer architecture, allow for parallel processing of input sequences, leading to faster training times compared to the sequential nature of LSTMs and GRUs.
3. **Interpretability**: The attention weights provide insights into which parts of the input sequence the model is focusing on, making the model more interpretable.

Overall, the Attention mechanism has significantly improved the performance of sequence-to-sequence models in various tasks such as machine translation, text summarization, and more.

### Example implementation of Machine Translate using Bi-LSTM with Attention

In [1]:
source_sentences = [
    "The weather is nice today.",
    "I love programming.",
    "How old are you?",
    "Where do you live?",
    "What is your favorite color?",
    "I enjoy reading books.",
    "Do you like music?",
    "What time is it?",
    "Can you help me?",
    "I am going to the store.",
    "She is my best friend.",
    "We are having dinner.",
    "This is a beautiful place.",
    "I need to study for my exam.",
    "He is a great teacher.",
    "They are playing soccer.",
    "I want to learn Spanish.",
    "Do you speak English?",
    "I have a pet dog.",
    "She likes to dance.",
    "We are watching a movie.",
    "This is my favorite song.",
    "I am feeling happy today.",
    "He is very talented.",
    "They are traveling to Europe.",
    "I need to buy groceries.",
    "She is reading a novel.",
    "We are going to the beach.",
    "This is an interesting book.",
    "I am learning to cook.",
    "He is working on a project.",
    "They are visiting their grandparents.",
    "I want to go for a walk.",
    "Do you like to swim?",
    "I have a lot of homework.",
    "She is painting a picture.",
    "We are celebrating a birthday.",
    "This is a challenging task.",
    "I am practicing yoga."
]

target_sentences = [
    "Thời tiết hôm nay thật đẹp.",
    "Tôi yêu lập trình.",
    "Bạn bao nhiêu tuổi?",
    "Bạn sống ở đâu?",
    "Màu sắc yêu thích của bạn là gì?",
    "Tôi thích đọc sách.",
    "Bạn có thích âm nhạc không?",
    "Mấy giờ rồi?",
    "Bạn có thể giúp tôi không?",
    "Tôi đang đi đến cửa hàng.",
    "Cô ấy là bạn thân nhất của tôi.",
    "Chúng tôi đang ăn tối.",
    "Đây là một nơi đẹp.",
    "Tôi cần học cho kỳ thi của mình.",
    "Anh ấy là một giáo viên tuyệt vời.",
    "Họ đang chơi bóng đá.",
    "Tôi muốn học tiếng Tây Ban Nha.",
    "Bạn có nói tiếng Anh không?",
    "Tôi có một con chó cưng.",
    "Cô ấy thích nhảy múa.",
    "Chúng tôi đang xem phim.",
    "Đây là bài hát yêu thích của tôi.",
    "Hôm nay tôi cảm thấy hạnh phúc.",
    "Anh ấy rất tài năng.",
    "Họ đang du lịch đến Châu Âu.",
    "Tôi cần mua hàng tạp hóa.",
    "Cô ấy đang đọc một cuốn tiểu thuyết.",
    "Chúng tôi đang đi đến bãi biển.",
    "Đây là một cuốn sách thú vị.",
    "Tôi đang học nấu ăn.",
    "Anh ấy đang làm việc trên một dự án.",
    "Họ đang thăm ông bà của họ.",
    "Tôi muốn đi dạo.",
    "Bạn có thích bơi không?",
    "Tôi có rất nhiều bài tập về nhà.",
    "Cô ấy đang vẽ một bức tranh.",
    "Chúng tôi đang tổ chức sinh nhật.",
    "Đây là một nhiệm vụ khó khăn.",
    "Tôi đang tập yoga."
]

# Create random translation data
translation_data = [{"source": src, "target": tgt} for src, tgt in zip(source_sentences, target_sentences)]

# Save into a list
train_data = list(translation_data)

# Print example data
for example in train_data[:5]:  # Show first 5 examples
    print(example)

{'source': 'The weather is nice today.', 'target': 'Thời tiết hôm nay thật đẹp.'}
{'source': 'I love programming.', 'target': 'Tôi yêu lập trình.'}
{'source': 'How old are you?', 'target': 'Bạn bao nhiêu tuổi?'}
{'source': 'Where do you live?', 'target': 'Bạn sống ở đâu?'}
{'source': 'What is your favorite color?', 'target': 'Màu sắc yêu thích của bạn là gì?'}


In [3]:
import torch
import numpy as np 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [4]:
# Pre-process data 
word2index = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}

# add source_sentences & target_sentences to word2index
for example in train_data:
    for word in example["source"].split():
        if word not in word2index:
            word2index[word] = len(word2index)
    for word in example["target"].split():
        if word not in word2index:
            word2index[word] = len(word2index)

index2word = {index: word for word, index in word2index.items()}
vocab_size = len(word2index)

class MachineTranslateDataset:

    def __init__(self, data, word2index, max_length=50):
        self.data = data
        self.word2index = word2index
        self.max_length = max_length

    def __len__(self):
        return len(self.data)
    
    def pad_sequence(self, sequence, max_len):
        # Truncate if sequence is longer than max_len
        if len(sequence) > max_len:
            return sequence[:max_len]
        # Pad with <pad> token if sequence is shorter
        else:
            return sequence + [self.word2index["<pad>"]] * (max_len - len(sequence))
    
    def __getitem__(self, index):
        src = self.data[index]["source"]
        trg = self.data[index]["target"]
        
        # Convert words to indices
        src_indexes = [self.word2index.get(word, self.word2index["<unk>"]) 
                      for word in src.split()]
        trg_indexes = [self.word2index.get(word, self.word2index["<unk>"]) 
                      for word in trg.split()]
        
        # Add <sos> and <eos> tokens
        src_indexes = [self.word2index["<sos>"]] + src_indexes + [self.word2index["<eos>"]]
        trg_indexes = [self.word2index["<sos>"]] + trg_indexes + [self.word2index["<eos>"]]
        
        # Pad sequences
        src_indexes = self.pad_sequence(src_indexes, self.max_length)
        trg_indexes = self.pad_sequence(trg_indexes, self.max_length)
        
        return (torch.tensor(src_indexes, dtype=torch.long), 
                torch.tensor(trg_indexes, dtype=torch.long))


In [5]:
machine_translate_dataset =  MachineTranslateDataset(translation_data, word2index)
machine_translate_dataset.__getitem__(0)

(tensor([1, 4, 5, 6, 7, 8, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0]),
 tensor([ 1,  9, 10, 11, 12, 13, 14,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]))

In [6]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hidden_dim, n_layers, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # x shape: (batch_size, seq_length)
    
        embedded = self.dropout(self.embedding(x))
        # embedded = [ batch size, seq_length, emb dim]
        
        outputs, (hidden, cell) = self.rnn(embedded)
        # outputs = [batch size, seq_length, hidden dim * n directions]
        # hidden = [n layers * n directions, batch size, hidden dim]
        # cell = [n layers * n directions, batch size, hidden dim]
        
        # n directions =  1 if LSTM, n directions = 2 if bidirectional LSTM
        return outputs, hidden, cell

encoder = Encoder(vocab_size, 256, 512, 2, 0.5)
outputs, hidden, cell = encoder(torch.tensor([[1, 2, 3, 4, 5]]))
print(outputs.shape, hidden.shape, cell.shape)


torch.Size([1, 5, 512]) torch.Size([2, 1, 512]) torch.Size([2, 1, 512])


# Luong Attention Mechanism - Quick Overview

## Core Concept
Luong attention calculates attention weights between encoder and decoder hidden states to help the decoder focus on relevant input parts during sequence generation.

## Key Components

### 1. Three Score Functions
$$
\text{Dot:} \quad score(h_t, \bar{h}_s) = h_t^T\bar{h}_s
$$
$$
\text{General:} \quad score(h_t, \bar{h}_s) = h_t^TW_a\bar{h}_s
$$
$$
\text{Concat:} \quad score(h_t, \bar{h}_s) = v_a^T\tanh(W_a[h_t;\bar{h}_s])
$$

### 2. Attention Weight Calculation
$$
\alpha_{ts} = \text{softmax}(score(h_t, \bar{h}_s))
$$

### 3. Context Vector
$$
c_t = \sum_s \alpha_{ts}\bar{h}_s
$$

## Key Features
1. Offers both global and local attention variants
2. Uses input feeding for richer representations
3. Simpler than Bahdanau attention
4. Computationally efficient

## Common Use Cases
- Neural Machine Translation
- Text Summarization
- Question Answering
- Speech Recognition

## Implementation Tips
1. Start with dot product scoring
2. Use global attention for short sequences
3. Consider local attention for long sequences
4. Monitor attention weights during training

In [32]:
class LuongAttention(nn.Module):
    
    def __init__(self, hidden_dim):
        super().__init__()
        
        self.hidden_dim = hidden_dim

    def forward(self, decoder_hidden, encoder_outputs):

        """
        Compute dot attention weights and weighted sum of encoder outputs

        Args:
            decoder_hidden: Current decoder hidden state
                Shape: [n layers * n directions, batch size, hidden dim]
            encoder_outputs: All encoder hidden states
                Shape: [batch_size, seq_len, hidden_dim]
        
        Returns:
            attention_weights: Attention weights for each encoder state
                Shape: [batch_size, seq_len]
        """

        latest_decode_hidden = decoder_hidden[-1].unsqueeze(1)
        # latest_decode_hidden = [batch_size, 1, hidden_dim] x [batch_size, hidden_dim, seq_len] = [batch_size, 1, seq_len]

        scores = torch.bmm(latest_decode_hidden, encoder_outputs.transpose(1, 2)).squeeze(1)

        attention_weights = F.softmax(scores, dim=1)

        return attention_weights
        

In [33]:
# Example test dot LuongAttention
attention = LuongAttention(512)
example_hidden = torch.randn(3, 8, 256)
example_encoder_outputs = torch.randn(8, 10, 256)
attention_weights = attention(example_hidden, example_encoder_outputs)
print(attention_weights.shape)


torch.Size([8, 10])


In [38]:
import random
class Decoder(nn.Module):
    
    def __init__(self, output_dim, hidden_dim, n_layers, dropout, attention):
        super().__init__()
        
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.dropout = dropout
        self.attention = attention
        
        self.embedding = nn.Embedding(output_dim, hidden_dim)
        
        self.rnn = nn.GRU(hidden_dim * 2, hidden_dim, n_layers, dropout = dropout)
        
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input, hidden, encoder_outputs):
        # input = [batch size]
        # hidden = [n layers * n directions, batch size, hidden dim]
        # encoder_outputs = [batch size, seq length, hidden dim]
        input = input.unsqueeze(0)
        
        embedded = self.dropout(self.embedding(input))
        # embedded = [1, batch size, hidden dim]
        
        attention_weights = self.attention(hidden, encoder_outputs)
        # attention_weights = [batch size, seq length]
        
        attention_weights = attention_weights.unsqueeze(1)
        # attention_weights = [batch size, 1, seq length]
        
        weighted = torch.bmm(attention_weights, encoder_outputs)
        # weighted = [batch size, 1, hidden dim]
        
        weighted = weighted.permute(1, 0, 2)
        # weighted = [1, batch size, hidden dim]
        
        rnn_input = torch.cat((embedded, weighted), dim=2)
        # rnn_input = [1, batch size, hidden dim * 2]
        
        output, hidden = self.rnn(rnn_input, hidden)
        # output = [seq length, batch size, hidden dim]
        
        prediction = self.fc_out(output.squeeze(0))
        # prediction = [batch size, output dim]
        
        return prediction, hidden

class Seq2SeqAtt(nn.Module):

    def __init__(self, encoder, decoder, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        # src = [batch size, src len]
        # trg = [batch size, src len]
        # teacher_forcing_ratio is probability to use teacher forcing
        # e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time
        
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim
        
        # tensor to store decoder outputs
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)

        # encoder_outputs is all hidden states of the input sequence, back and forwards
        # hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden, cell = self.encoder(src)

        # first input to the decoder is the <sos> tokens
        input = trg[:, 0]

        for t in range(1, trg_len):
            
            output, hidden = self.decoder(input, hidden, encoder_outputs)

            outputs[:, t] = output

            teacher_force = random.random() < teacher_forcing_ratio

            top1 = output.argmax(1) 

            input = trg[:, t] if teacher_force else top1

        return outputs


In [41]:
# Define the model
VOCAB_SIZE = len(word2index)
EMB_DIM = 100
HIDDEN_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_DROP = 0.5
DEC_DROP = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

enc = Encoder(VOCAB_SIZE, EMB_DIM, HIDDEN_DIM, ENC_LAYERS, ENC_DROP).to(device)
attn = LuongAttention(HIDDEN_DIM).to(device)
dec = Decoder(VOCAB_SIZE, HIDDEN_DIM, DEC_LAYERS, DEC_DROP, attn).to(device)
model = Seq2SeqAtt(enc, dec, device).to(device)
print(model)

# Create Dataloader
from torch.utils.data import DataLoader
data_loader = DataLoader(machine_translate_dataset, batch_size=8, shuffle=True)

# define hypermeters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
num_epochs = 200
learning_rate = 0.001

# Define optimizer and loss function
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    for i, (source, target) in enumerate(data_loader):
        optimizer.zero_grad()
        source = source.to(device)
        target = target.to(device)
        output = model(source, target)
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        target = target[1:].view(-1)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if (i + 1) % 100 == 0:
            print(f"Epoch: {epoch}, Iteration: {i}, Loss: {loss.item()}")
    
    # Print epoch statistics
    avg_loss = total_loss / len(data_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

Seq2SeqAtt(
  (encoder): Encoder(
    (embedding): Embedding(243, 100)
    (rnn): LSTM(100, 256, num_layers=3, batch_first=True, dropout=0.5)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): LuongAttention()
    (embedding): Embedding(243, 256)
    (rnn): GRU(512, 256, num_layers=3, dropout=0.5)
    (fc_out): Linear(in_features=256, out_features=243, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)
Epoch [1/200], Average Loss: 3.4211
Epoch [2/200], Average Loss: 1.0490
Epoch [3/200], Average Loss: 0.9375
Epoch [4/200], Average Loss: 0.8560
Epoch [5/200], Average Loss: 0.8272
Epoch [6/200], Average Loss: 0.7661
Epoch [7/200], Average Loss: 0.7305
Epoch [8/200], Average Loss: 0.7366
Epoch [9/200], Average Loss: 0.6932
Epoch [10/200], Average Loss: 0.6986
Epoch [11/200], Average Loss: 0.6657
Epoch [12/200], Average Loss: 0.6929
Epoch [13/200], Average Loss: 0.6838
Epoch [14/200], Average Loss: 0.6788
Epoch [15/200], Average Loss: 0.6586


In [62]:

# Training loop
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    for i, (source, target) in enumerate(data_loader):
        optimizer.zero_grad()
        source = source.to(device)
        target = target.to(device)
        output = model(source, target)
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        target = target[1:].view(-1)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if (i + 1) % 100 == 0:
            print(f"Epoch: {epoch}, Iteration: {i}, Loss: {loss.item()}")
    
    # Print epoch statistics
    avg_loss = total_loss / len(data_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

Epoch [1/200], Average Loss: 0.1909
Epoch [2/200], Average Loss: 0.1747
Epoch [3/200], Average Loss: 0.1881
Epoch [4/200], Average Loss: 0.1881
Epoch [5/200], Average Loss: 0.1856
Epoch [6/200], Average Loss: 0.2001
Epoch [7/200], Average Loss: 0.2165
Epoch [8/200], Average Loss: 0.1817
Epoch [9/200], Average Loss: 0.1991
Epoch [10/200], Average Loss: 0.1782
Epoch [11/200], Average Loss: 0.1945
Epoch [12/200], Average Loss: 0.2172
Epoch [13/200], Average Loss: 0.1879
Epoch [14/200], Average Loss: 0.2111
Epoch [15/200], Average Loss: 0.1766
Epoch [16/200], Average Loss: 0.1734
Epoch [17/200], Average Loss: 0.1690
Epoch [18/200], Average Loss: 0.1674
Epoch [19/200], Average Loss: 0.2014
Epoch [20/200], Average Loss: 0.1994
Epoch [21/200], Average Loss: 0.1878
Epoch [22/200], Average Loss: 0.1956
Epoch [23/200], Average Loss: 0.2017
Epoch [24/200], Average Loss: 0.1917
Epoch [25/200], Average Loss: 0.1895
Epoch [26/200], Average Loss: 0.1788
Epoch [27/200], Average Loss: 0.1843
Epoch [28/

In [63]:
def pad_sequence(sequence, max_len, word2index):
    if len(sequence) > max_len:
        return sequence[:max_len]
    else:
        return sequence + [word2index["<pad>"]] * (max_len - len(sequence))

def translate(model, src_sentence, word2index, device, max_length=50):

    model.eval()
    # Convert sentence to indices
    src_indexes = [word2index.get(token, word2index["<unk>"]) for token in src_sentence]
    src_indexes = [word2index["<sos>"]] + src_indexes + [word2index["<eos>"]]
    src_indexes = pad_sequence(src_indexes, max_length, word2index)

    # Convert to tensor
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)

    with torch.no_grad():
        # Get encoder outputs
        encoder_outputs, hidden, cell = model.encoder(src_tensor)
        
        # Initialize decoder input with <sos> token
        decoder_input = torch.LongTensor([word2index["<sos>"]]).to(device)
        
        # Store all decoder outputs
        decoded_words = []
        
        for _ in range(max_length):
            # Run decoder for one step
            decoder_output, hidden = model.decoder(decoder_input, hidden, encoder_outputs)
            
            # Get the most likely word
            topv, topi = decoder_output.squeeze().data.topk(1)
            decoded_token = topi.item()
            
            # Add the token to results
            if decoded_token == word2index["<eos>"]:
                break
            elif decoded_token == word2index["<pad>"]:
                continue
            else:
                # Convert index back to word
                decoded_words.append(
                    next(word for word, index in word2index.items() 
                        if index == decoded_token)
                )
            
            # Next input is the decoded token
            decoder_input = torch.LongTensor([decoded_token]).to(device)
    
    return decoded_words

In [64]:
# Example usage:
translated = translate(
    model=model,
    src_sentence="I am going to the store.".split(),
    word2index=word2index,
    device=device,
    max_length=50
)
print(" ".join(translated))

Tôi đang đi đến cửa hàng.
