In [3]:
import torch
from torch.nn import functional as F
from torch import nn
import math
import os
import shutil
from sklearn.model_selection import train_test_split
import os


In [None]:
class SelfAttention(nn.Module):
  """
  Multi-head self-attention.

  Projects the input to queries, keys and values with one fused linear layer,
  splits them into ``n_heads`` heads of size ``embed_dim // n_heads``, applies
  scaled dot-product attention per head (optionally with a causal mask), then
  concatenates the heads and applies an output projection.

  Args:
      n_heads: Number of attention heads. Must be positive and divide
          ``embed_dim`` exactly.
      embed_dim: Size of the last dimension of the input and of the output.
      in_proj_bias: Whether the fused Q/K/V projection has a bias term.
      out_proj_bias: Whether the output projection has a bias term.

  Raises:
      ValueError: If ``n_heads`` is not positive or does not divide
          ``embed_dim``.
  """
  def __init__(self, n_heads, embed_dim, in_proj_bias = True, out_proj_bias = True):
    super().__init__()
    # Fail fast with a clear message. Otherwise a non-divisible embed_dim only
    # shows up later as an opaque `view` shape error inside forward().
    if n_heads <= 0 or embed_dim % n_heads != 0:
      raise ValueError(
        f"embed_dim ({embed_dim}) must be divisible by a positive n_heads ({n_heads})"
      )
    self.n_heads = n_heads
    # Per-head feature size; the heads are concatenated back to embed_dim.
    self.d_head = embed_dim // n_heads
    # One fused projection yields Q, K and V side by side (3 * embed_dim wide).
    self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias = in_proj_bias)
    # Mixes the concatenated head outputs back into embed_dim features.
    self.out_proj = nn.Linear(embed_dim, embed_dim, bias = out_proj_bias)

  def forward(self, x, casual_mask = False):
    """
    Apply multi-head self-attention to ``x``.

    Args:
        x: Tensor of shape (batch_size, seq_len, embed_dim).
        casual_mask: If True, apply a *causal* mask so that each position
            attends only to itself and earlier positions. (The parameter
            name keeps its original spelling for backward compatibility.)

    Returns:
        Tensor of shape (batch_size, seq_len, embed_dim).
    """
    batch_size, seq_len, embed_dim = x.shape
    # (batch, seq, embed) -> (batch, seq, n_heads, d_head)
    head_shape = (batch_size, seq_len, self.n_heads, self.d_head)

    q, k, v = self.in_proj(x).chunk(3, dim = -1)
    # Move heads next to batch, (batch, n_heads, seq, d_head), so the
    # matmuls below run independently per head.
    q = q.view(head_shape).transpose(1, 2)
    k = k.view(head_shape).transpose(1, 2)
    v = v.view(head_shape).transpose(1, 2)

    # Scaled attention scores: (batch, n_heads, seq, seq).
    weight = (q @ k.transpose(-1, -2)) / math.sqrt(self.d_head)

    if casual_mask:
      # Strictly-upper-triangular entries are "future" positions. A single
      # (seq, seq) mask broadcasts over batch and heads, so there is no need
      # to allocate a full-size one.
      mask = torch.ones(seq_len, seq_len, dtype = torch.bool, device = weight.device).triu(1)
      weight = weight.masked_fill(mask, -torch.inf)

    weight = F.softmax(weight, dim = -1)

    # Weighted sum of values, then merge heads: (batch, seq, embed_dim).
    output = (weight @ v).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)

    return self.out_proj(output)