 <!--
 Author : Harshitha Machiraju
 Date: 20/09/2024
   -->




**Author** : Harshitha Machiraju

 **Date**: 20/09/2024


# Transformers

In [1]:
import os
import time
from copy import deepcopy

import math
from typing import Tuple, Optional
import torch
from torch import Tensor, nn
from torch.utils.data import dataset
from torchtext.datasets import PennTreebank
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from unittest import TestCase

import torch.nn.functional as F

import os
import tempfile
import random
import numpy as np

In [3]:
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-Head Self-Attention module.

    This module implements the Multi-Head Attention mechanism as described in 
    "Attention Is All You Need" (Vaswani et al., 2017). It projects the input into
    multiple heads, applies scaled dot-product attention, and then concatenates
    the results.

    Args:
        dim (int): The input and output dimension of the model.
        n_heads (int): The number of attention heads.
        dropout (Optional[float]): Dropout probability. Defaults to None (no dropout).

    Attributes:
        dim (int): The input and output dimension.
        n_heads (int): The number of attention heads.
        head_dim (int): The dimension of each attention head.
        dropout (nn.Dropout): Dropout layer.
        linear_query (nn.Linear): Linear projection for query.
        linear_key (nn.Linear): Linear projection for key.
        linear_value (nn.Linear): Linear projection for value.
        linear_cat_attn (nn.Linear): Final linear projection after attention.
    """
    def __init__(self, dim: int, n_heads: int, dropout: Optional[float] = None):
        super(MultiHeadSelfAttention, self).__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        assert n_heads * self.head_dim == dim, f"embedding dim={dim} not divisible by n_heads={n_heads}."
        self.dropout = nn.Dropout(p=dropout if dropout is not None else 0.0)
        self.linear_query = nn.Linear(dim, dim)
        self.linear_key = nn.Linear(dim, dim)
        self.linear_key = nn.Linear(dim, dim)
        self.linear_cat_attn = nn.Linear(dim, dim)
        # self.max_positions = 512
        # self.register_buffer("tril", torch.tril(torch.ones(1, 1, self.max_positions, self.max_positions)))
        # self.T = 0

    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """
        Applies multi-head attention to input tensors.

        Args:
            q (Tensor): Query tensor of shape (batch_size, seq_len, dim).
            k (Tensor): Key tensor of shape (batch_size, seq_len, dim).
            v (Tensor): Value tensor of shape (batch_size, seq_len, dim).
            mask (Optional[Tensor]): Attention mask tensor. Defaults to None.

        Returns:
            Tensor: Output tensor after applying multi-head attention.
        """

        # Linear projections
        batch_size, T, _ = q.size()
        self.T = T
        q = self.linear_query(q).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.linear_key(k).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.linear_key(v).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        
        # Apply attention
        if mask is not None:
            # print(mask.shape)
            mask = mask.repeat(self.n_heads, 1, 1)
            # mask = mask.unsqueeze(1).unsqueeze(2)

        attn_output = self.attention(q, k, v, mask)
        
        # Concatenate heads and apply final linear projection
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.dim)

        return self.linear_cat_attn(attn_output)

    def attention(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        d_k = query.size()[-1]
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        # Below line not needed due to the mask shape not matching the scores shape !!
        # if mask is not None:
            # scores = scores.masked_fill(mask == 0, float('-inf'))
            # scores = scores.masked_fill(self.tril[:,:,:self.T,:self.T] == 0, float('-inf')) # this works
        
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        return torch.matmul(attn, value)

In [4]:
class EncoderBlock(nn.Module):
    def __init__(self, attn, d_model: int, dim_feedforward: int, layer_norm_eps=1e-5, dropout=0.1):
        """
        Class for the Encoder block of the transformer model.
        :param attn: multi-head self-attention layer
        :param d_model: hidden dimension of the input tensor
        :param dim_feedforward: hidden dimension of the feedforward network
        :param layer_norm_eps: epsilon for layer normalization
        :param dropout: dropout rate
        
        """
        super(EncoderBlock, self).__init__()
        self.attn = MultiHeadSelfAttention(d_model, d_model, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.activation = nn.ReLU()

        self.linear_model = nn.Sequential(
            self.linear1,
            self.dropout,
            self.activation,
            self.linear2,
            self.dropout
        )


    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """
        Forward pass of the EncoderBlock.
        :param x: input tensor
        :param mask: attention mask
        :return: output tensor of the encoder block
        """
        attn_x = self._self_attn(x, mask)
        x = self.norm1(x + attn_x)

        ff_x = self._feed_forward(x)
        return self.norm2(ff_x)

    def _self_attn(self, x: Tensor, mask: Optional[Tensor]) -> Tensor:
        """
        Apply self-attention to the input tensor.

        :param x: input tensor
        :param mask: attention mask
        :return: tensor after applying self-attention and dropout
        """
        x = self.attn(x, x, x, mask)
        return self.dropout(x)

    def _feed_forward(self, x: Tensor) -> Tensor:
        """
        Apply feed-forward network to the input tensor.

        :param x: input tensor
        :return: tensor after applying feed-forward network and dropout
        """
        x = self.dropout(self.activation(self.linear1(x)))
        return self.dropout(self.linear2(x))

In [5]:
class TransformerEncoder(nn.Module):
    def __init__(self, encoder: EncoderBlock, n_blocks: int):
        super(TransformerEncoder, self).__init__()
        self.encoder_blocks = nn.ModuleList([deepcopy(encoder) for _ in range(n_blocks)])

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        for encoder in self.encoder_blocks:
            x = encoder(x, mask)
        return x

In [6]:
def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
    Unmasked positions are filled with float(0.0).
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

The code below is given as-is, please contact the examinator if there is any issue running this, unrelated to your code above.

In [7]:
class PositionalEncodingTorch(nn.Module):
    """
    Positional encoding module for the Transformer model.

    This module adds positional encoding to the input tensor. The positional encoding
    is learned as a part of the model.

    Args:
        d_model (int): The input and output dimension of the model.
        dropout (float): Dropout probability. Defaults to 0.1.
        max_len (int): Maximum length of the input sequence. Defaults to 5000.

    Attributes:
        dropout (nn.Dropout): Dropout layer.
        pe (Tensor): Positional encoding tensor.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [8]:
class TransformerModelManualAttn(nn.Module):
    """
    Transformer model for language modeling. With EncodingBlock, positional encoding and Attention implemented manually.
    Args:
        emsize (int): The embedding dimension.
        ntoken (int): The size of the vocabulary.
        d_model (int): The hidden dimension.
        nhead (int): The number of attention heads.
        d_hid (int): The hidden dimension of the feedforward network.
        nlayers (int): The number of encoder layers.
        dropout (float): Dropout probability. Defaults to 0.5.

    """

    def __init__(self, emsize:int, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.pos_encoder = PositionalEncodingTorch(d_model, dropout)
        encoder_layers = EncoderBlock(attn=MultiHeadSelfAttention(dim=emsize, n_heads=nhead, dropout=dropout),
                                      d_model=d_model, dim_feedforward=d_hid, dropout=dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        """
        Perform forward pass through the transformer model.
        
        :param src: Source tensor
        :param src_mask: Source mask tensor
        :return: Model output
        """
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        if src_mask is None:
            # print("Generating mask", src.size(0))
            src_mask = generate_square_subsequent_mask(src.size(0)).to(src.device)
        output = self.transformer_encoder(src, src_mask)
        output = self.linear(output)
        return output

In [9]:
def data_process(raw_text_iter: dataset.IterableDataset, vocab, tokenizer) -> Tensor:
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))


def batchify(data: Tensor, bsz: int, device: torch.device = None) -> Tensor:
    seq_len = data.shape[0] // bsz
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data.to(device)


def get_batch(source: Tensor, i: int, bptt: int) -> Tuple[Tensor, Tensor]:
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].reshape(-1)
    return data, target

In [10]:
def train(
        model,
        train_data: Tensor,
        bptt: int,
        criterion,
        ntokens: int,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        epoch: int = 0,
        device: torch.device = None,
        use_causal_mask: bool = True
) -> None:
    model.train()
    total_loss = 0.
    log_interval = 1
    start_time = time.time()
    src_mask = None  # We'll generate the mask for each batch
    num_batches = len(train_data) // bptt
    for batch, i in enumerate(range(0, train_data.shape[0] - 1, bptt)):
        data, targets = get_batch(train_data, i, bptt=bptt)
        # print(f"Batch {batch}:, data shape: {data.shape}, target shape: {targets.shape}")
        if use_causal_mask:
            src_mask = generate_square_subsequent_mask(data.size(0)).to(device)
        
        # print(f"Batch {batch}:")
        # print(f"  Data shape: {data.shape}")
        # print(f"  Mask shape: {src_mask.shape if src_mask is not None else 'None'}")
        
        output = model(src=data, src_mask=src_mask)
        loss = criterion(output.view(-1, ntokens), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f}'
                  f' | ppl {ppl:8.2f}'
                  )
            total_loss = 0
            start_time = time.time()

In [11]:
def evaluate(
        model,
        eval_data: Tensor,
        bptt: int,
        ntokens: int,
        criterion,
        device: torch.device = None,
        use_causal_mask: bool = True,
) -> float:
    model.eval()
    total_loss = 0.
    src_mask = generate_square_subsequent_mask(bptt).to(device) if use_causal_mask else None
    with torch.no_grad():
        for i in range(0, eval_data.shape[0] - 1, bptt):
            data, targets = get_batch(eval_data, i, bptt=bptt)
            if data.shape[0] < bptt and src_mask is not None:
                src_mask = generate_square_subsequent_mask(data.shape[0]).to(device)
            # print(data, src_mask)
            output = model(data, src_mask=src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += data.shape[0] * criterion(output_flat, targets).item()
    return total_loss / (len(eval_data) - 1)

In [12]:
## Load and batch data
train_iter = PennTreebank(split='train')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

train_iter, val_iter, test_iter = PennTreebank()
train_data = data_process(train_iter, vocab=vocab, tokenizer=tokenizer)
val_data = data_process(val_iter, vocab=vocab, tokenizer=tokenizer)
test_data = data_process(test_iter, vocab=vocab, tokenizer=tokenizer)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

bptt = 35
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size, device=device)
val_data = batchify(val_data, eval_batch_size, device=device)
test_data = batchify(test_data, eval_batch_size, device=device)

# Model Parameters
ntokens = len(vocab)
emsize = 200
d_hid = 200
nlayers = 2
nhead = 2
dropout = 0.1
use_causal_mask = True

# Create Model
model = TransformerModelManualAttn(emsize=emsize, ntoken= ntokens, d_model= d_hid, nhead=nhead, d_hid=d_hid, nlayers=nlayers, dropout=dropout).to(device)
### Run model
criterion = nn.CrossEntropyLoss()
lr = 5.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr, )
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

best_val_loss = float('inf')
epochs = 10

# Train
home = os.path.join(os.path.expanduser("~"), "transformer_test")
save_dir = os.path.join(os.path.join(home, 'pytorch_example')) # feel free to change path
os.makedirs(save_dir, exist_ok=True)
best_model_params_path = os.path.join(save_dir, "best_model_params.pt")

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model, train_data, bptt, criterion, ntokens, optimizer, scheduler, epoch, device, use_causal_mask)
    val_loss = evaluate(model, val_data, bptt, ntokens, criterion, device, use_causal_mask)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_params_path)

    scheduler.step()
model.load_state_dict(torch.load(best_model_params_path))  # load best model states

# Test
test_loss = evaluate(model, test_data, bptt, ntokens, criterion, device)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | '
      f'test ppl {test_ppl:8.2f}')
print('=' * 89)

| epoch   1 |     1/ 1320 batches | lr 5.00 | ms/batch 453.86 | loss 18.08 | ppl 71114753.31
| epoch   1 |     2/ 1320 batches | lr 5.00 | ms/batch 147.39 | loss  9.28 | ppl 10755.84
| epoch   1 |     3/ 1320 batches | lr 5.00 | ms/batch 159.54 | loss  9.74 | ppl 16946.52
| epoch   1 |     4/ 1320 batches | lr 5.00 | ms/batch 209.50 | loss 11.02 | ppl 60844.06
| epoch   1 |     5/ 1320 batches | lr 5.00 | ms/batch 160.96 | loss 10.67 | ppl 42830.99
| epoch   1 |     6/ 1320 batches | lr 5.00 | ms/batch 154.59 | loss  9.65 | ppl 15476.97
| epoch   1 |     7/ 1320 batches | lr 5.00 | ms/batch 149.53 | loss  9.09 | ppl  8887.89
| epoch   1 |     8/ 1320 batches | lr 5.00 | ms/batch 148.10 | loss  9.55 | ppl 13978.37
| epoch   1 |     9/ 1320 batches | lr 5.00 | ms/batch 150.11 | loss  9.33 | ppl 11217.22
| epoch   1 |    10/ 1320 batches | lr 5.00 | ms/batch 150.93 | loss  9.88 | ppl 19500.56
| epoch   1 |    11/ 1320 batches | lr 5.00 | ms/batch 149.12 | loss  9.50 | ppl 13415.97
| epoch

KeyboardInterrupt: 

# Notes

### Calculating Attention Complexity
$Q, K, V ∈ \mathbb{R}^{n×d}$ are the query, key, and value matrices, where $n$ is the sequence length and $d$ is the hidden dimension.

$Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d}})V$

#### Computational Complexity calculation: 
1. *Matrix Multiplication:* $QK^T$ is a matrix multiplication of size $n×n$. --> Cost of $O(n^2d)$
2. *Scaling:* Divide by $\sqrt{d}$. --> Cost of $O(n^2)$
3. *Softmax:* Compute softmax along rows. --> Cost of $O(n^2)$
4. *Weighted Sum:* Multiply softmax scores with $V$. --> Cost of $O(n^2d)$

Assuming $d$ is a constant and $d <<< n$, the total cost $O(n^2)$.

#### Memory Cost
- Is also $O(n^2)$ memory complexity
- Since this the largest storage is for the attention matrix, whose size is $n×n$.




### Ideas to Reduce Attention Costs

1. **Sparsity:**
   - Sparse Attention: $\text{Attention}(Q, K, V) = \text{softmax}(\frac{(Q \odot M)K^T}{\sqrt{d}})V, \quad M \in \{0,1\}^{n \times n}, |M| = O(n)$ (from Sparse Transformer like BigBird)
   - Element wise product only has cost $O(n)$ instead of $O(n^2)$ for matrix multiplication.


2. **Lower rank approximations:**
   If the attention matrix is low-rank, it can be approximated using a low-rank factorization. This reduces the complexity to **O(n)**.
   As follows $\text{Attention}(Q, K, V) \approx \text{softmax}(\frac{E^T(FV)}{\sqrt{d}}), \quad E,F \in \mathbb{R}^{k \times n}, k \ll n$ (from Linformer)


3. **Memory-Efficient Attention:**
   Similar to how DCT can be perfromed in blocks to speeden up computation, attention can also be performed in blocks.
   
   $\text{Attention}(Q, K, V) = \text{BlockSoftmax}(\frac{QK^T}{\sqrt{d}})V, \quad \text{computed in blocks } B_{ij} \in \mathbb{R}^{b \times b}$ (from FlashAttention)


4. **Hierarchical Approaches**
   - Similar to how images can be processed at different scales and then combined across scales of the input, attention can also be processed at different heirachies. In this case, the heirachies can be processed in parallel and then combined.Thus, the complexity is reduced. The heirachies would be different levels of abstraction of the input/local-global features etc. 



here we used causal masking for the attention as follows : $$\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d}} \odot M)V$$

Where $$M_{ij} = \begin{cases} 0 & \text{if } i > j \\ -\infty & \text{otherwise} \end{cases}$$

Mathematically, this ensures that the attention mechanism only looks at the past tokens and not the future tokens. This is useful for autoregressive models like LSTMs, Transformers etc. where the model should not have access to future tokens.
However, we can have other versions of masking like:

1. **Bidirectional Masking:**
   This is the opposite of causal masking. Here, the model can look at both past and future tokens. This is useful for models like BERT where the model should have access to both past and future tokens.
    $$M_{ij} = 1 \quad \forall i,j$$

2. **Block Masking:**
   Where b is the block size. This groups tokens into blocks, allowing for efficient processing of long sequences.
   $$M_{ij} = \begin{cases} 0 & \text{if } \lfloor i/b \rfloor = \lfloor j/b \rfloor \\ -\infty & \text{otherwise} \end{cases}$$
3. **Random Masking:**
   Randomly mask some tokens. This is useful for training robust models that can handle missing tokens.
   $$M_{ij} = \begin{cases} 0 & \text{with probability } p \\ -\infty & \text{otherwise} \end{cases}$$

4. **Window Masking:**
   This restricts attention to a local context, useful for tasks with primarily local dependencies.
   $$M_{ij} = \begin{cases} 0 & \text{if } |i-j| \leq w \\ -\infty & \text{otherwise} \end{cases}$$
Where w is the window size. 