In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, einsum

#GQA Implementation

(MQA can be implemented by setting num_head_groups=1 in GQA, so that all query heads use a single key and value head)

In [None]:
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA).

    The ``num_heads`` query heads are divided into ``num_head_groups`` groups
    and all query heads of one group share a single key/value head. Hence:

    * ``num_head_groups == num_heads`` is standard multi-head attention,
    * ``num_head_groups == 1`` is multi-query attention (MQA).

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of query heads; must be divisible by
            ``num_head_groups``.
        num_head_groups: Number of key/value heads (one per query group).
    """

    def __init__(self, hidden_size, num_heads, num_head_groups):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        assert num_heads % num_head_groups == 0, "num_heads must be divisible by num_head_groups"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_head_groups = num_head_groups

        self.head_dim = hidden_size // num_heads
        # One key/value head per group of query heads.
        self.kv_heads = num_head_groups
        self.queries_per_group = num_heads // num_head_groups

        self.query = nn.Linear(hidden_size, hidden_size)
        # K/V only need kv_heads * head_dim features; this is where GQA saves
        # parameters and KV-cache memory compared to multi-head attention.
        self.key = nn.Linear(hidden_size, self.kv_heads * self.head_dim)
        self.value = nn.Linear(hidden_size, self.kv_heads * self.head_dim)

        self.fc_out = nn.Linear(hidden_size, hidden_size)

        # Scores are divided by sqrt(head_dim) so that their variance stays
        # close to 1 and the softmax does not saturate.
        self.scale = self.head_dim ** 0.5

    def forward(self, query, key, value, mask=None):
        """Apply grouped-query attention.

        Args:
            query: Tensor of shape (batch, q_len, hidden_size).
            key: Tensor of shape (batch, kv_len, hidden_size).
            value: Tensor of shape (batch, kv_len, hidden_size).
            mask: Optional tensor broadcastable to
                (batch, num_heads, q_len, kv_len). Positions where the mask
                equals 0/False are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, hidden_size).
        """
        batch_size, q_len, _ = query.shape
        kv_len = key.shape[1]

        # Query head h belongs to group h // queries_per_group.
        # (b, q_len, hidden) -> (b, groups, queries_per_group, q_len, head_dim)
        Q = self.query(query).view(
            batch_size, q_len, self.kv_heads, self.queries_per_group, self.head_dim
        ).permute(0, 2, 3, 1, 4)

        # (b, kv_len, kv_heads * head_dim) -> (b, groups, 1, kv_len, head_dim)
        # The singleton axis broadcasts each K/V head over its query group,
        # so K and V are never physically repeated.
        K = self.key(key).view(
            batch_size, kv_len, self.kv_heads, 1, self.head_dim
        ).permute(0, 2, 3, 1, 4)
        V = self.value(value).view(
            batch_size, kv_len, self.kv_heads, 1, self.head_dim
        ).permute(0, 2, 3, 1, 4)

        # (b, groups, queries_per_group, q_len, kv_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Flatten (groups, queries_per_group) -> num_heads so that the mask
        # uses the conventional (batch, heads, q_len, kv_len) layout.
        scores = scores.reshape(batch_size, self.num_heads, q_len, kv_len)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = attn_weights.reshape(
            batch_size, self.kv_heads, self.queries_per_group, q_len, kv_len
        )

        # (b, groups, queries_per_group, q_len, head_dim)
        attn_output = torch.matmul(attn_weights, V)

        # Merge heads back: -> (b, q_len, hidden_size)
        attn_output = attn_output.permute(0, 3, 1, 2, 4).reshape(
            batch_size, q_len, self.hidden_size
        )

        return self.fc_out(attn_output)

#SWA Implementation

In [None]:
class SlidingWindowAttention(nn.Module):
    """Causal multi-head self-attention restricted to a sliding window.

    Token ``i`` attends to itself and to at most ``window_size`` preceding
    tokens, i.e. to positions ``j`` with ``0 <= i - j <= window_size``.

    Scores are computed on overlapping chunks (size ``2 * window_size``,
    stride ``window_size``) and then scattered into a dense
    ``(seq_len, seq_len)`` matrix, so memory use is still quadratic in the
    sequence length.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        window_size: Number of past tokens each token may attend to
            (must be positive).
    """

    def __init__(self, hidden_size, num_heads, window_size):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        assert window_size > 0, "window_size must be positive"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Number of past tokens to attend to (window_overlap in
        # https://amaarora.github.io/posts/2024-07-04%20SWA.html)
        self.window_size = window_size

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

        self.fc_out = nn.Linear(hidden_size, hidden_size)

        # Plain float: no device/dtype bookkeeping needed when dividing scores.
        self.scale = self.head_dim ** 0.5

    def _chunk(self, hidden_states, window_overlap):
        """Split the sequence axis into overlapping, zero-padded chunks.

        Chunk ``c`` holds positions ``[c * w, c * w + 2w)`` with
        ``w = window_overlap``; positions past the end of the sequence are
        left as zeros.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, ...); any number
                of trailing dimensions is preserved.
            window_overlap: Stride between chunks (half the chunk size).

        Returns:
            Tensor of shape (batch, num_chunks, 2 * window_overlap, ...) with
            ``num_chunks = ceil(seq_len / window_overlap)``, on the same
            device and with the same dtype as ``hidden_states``.
        """
        batch_size, seq_len = hidden_states.shape[:2]
        trailing_shape = tuple(hidden_states.shape[2:])
        chunk_size = 2 * window_overlap
        num_chunks = (seq_len + window_overlap - 1) // window_overlap

        chunked = hidden_states.new_zeros(
            (batch_size, num_chunks, chunk_size) + trailing_shape
        )

        for chunk_idx in range(num_chunks):
            start = chunk_idx * window_overlap
            end = min(seq_len, start + chunk_size)
            chunked[:, chunk_idx, : end - start] = hidden_states[:, start:end]

        return chunked

    def forward(self, query, key, value, mask=None):
        """Apply causal sliding-window attention.

        Args:
            query: Tensor of shape (batch, seq_len, hidden_size).
            key: Tensor of shape (batch, seq_len, hidden_size).
            value: Tensor of shape (batch, seq_len, hidden_size).
            mask: Optional tensor broadcastable to
                (batch, num_heads, seq_len, seq_len). Positions where the
                mask equals 0/False are excluded in addition to the
                sliding-window restriction.

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).
        """
        batch_size, seq_len, _ = query.size()

        # (batch, seq_len, hidden) -> (batch, seq_len, heads, head_dim)
        Q = self.query(query).view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = self.key(key).view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = self.value(value).view(batch_size, seq_len, self.num_heads, self.head_dim)

        window_overlap = self.window_size
        chunk_size = 2 * window_overlap

        # (batch, num_chunks, chunk_size, heads, head_dim)
        Q_chunked = self._chunk(Q, window_overlap)
        K_chunked = self._chunk(K, window_overlap)

        # Per-chunk scores: (batch, num_chunks, heads, chunk_size, chunk_size)
        scores = torch.einsum("bnxhd,bnyhd->bnhxy", Q_chunked, K_chunked) / self.scale
        num_chunks = scores.size(1)

        # Scatter the chunk blocks onto the diagonal of the dense score matrix.
        # Every (i, j) with 0 <= i - j <= window_overlap lies inside at least
        # one chunk; anything never written stays at -inf.
        full_scores = scores.new_full(
            (batch_size, self.num_heads, seq_len, seq_len), float("-inf")
        )
        for chunk_idx in range(num_chunks):
            start = chunk_idx * window_overlap
            end = min(seq_len, start + chunk_size)
            chunk_len = end - start  # drops the zero padding of the last chunks
            full_scores[:, :, start:end, start:end] = scores[
                :, chunk_idx, :, :chunk_len, :chunk_len
            ]

        # Chunks span up to 2w positions, so enforce the exact causal band:
        # key j is visible to query i iff 0 <= i - j <= window_overlap.
        positions = torch.arange(seq_len, device=scores.device)
        distance = positions.unsqueeze(1) - positions.unsqueeze(0)
        allowed = (distance >= 0) & (distance <= window_overlap)
        full_scores = full_scores.masked_fill(~allowed, float("-inf"))

        if mask is not None:
            full_scores = full_scores.masked_fill(mask == 0, float("-inf"))

        attn_weights = F.softmax(full_scores, dim=-1)  # [batch, heads, seq_len, seq_len]

        V = V.transpose(1, 2)  # [batch, heads, seq_len, head_dim]
        attn_output = torch.matmul(attn_weights, V)  # [batch, heads, seq_len, head_dim]

        # [batch, heads, seq_len, head_dim] -> [batch, seq_len, hidden_size]
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.hidden_size
        )

        return self.fc_out(attn_output)  # [batch, seq_len, hidden_size]