# Sliding Window Attention (SWA)

## Introduction

![SWA](Mistral%207B%20SWA.png)

Sliding Window Attention (SWA) is an efficient attention mechanism that restricts each token to a fixed-size window of neighboring tokens (the preceding tokens, in the case of causal LMs) rather than the whole sequence. This reduces the time and memory complexity of attention from $O(n^2)$ to $O(n \cdot w)$, where $n$ is the sequence length and $w$ is the window size ($w \ll n$). Because attention layers are stacked, tokens outside the window still influence next-token prediction: at each layer, information can move forward by up to $w$ tokens, so after $k$ layers it can move forward by up to $k \times w$ tokens.
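To make the cost and receptive-field claims concrete, here is a small pure-Python sketch. It assumes the inclusive window convention used in the formula in this section (position $i$ attends to positions $i-w+1$ through $i$), so each layer reaches back $w-1$ positions, which stays within the $k \times w$ upper bound. The function names are hypothetical and chosen only for illustration.

```python
def causal_pairs(n: int) -> int:
    """Number of (query, key) pairs scored by full causal attention."""
    return n * (n + 1) // 2


def windowed_pairs(n: int, w: int) -> int:
    """Number of (query, key) pairs scored by causal SWA with window w."""
    # Position i sees min(i + 1, w) keys: fewer than w only near the start.
    return sum(min(i + 1, w) for i in range(n))


def earliest_visible(position: int, w: int, num_layers: int) -> int:
    """Earliest input position that can influence `position` after num_layers."""
    earliest = position
    for _ in range(num_layers):
        # Each layer extends the reach by w - 1 positions (window includes self).
        earliest = max(0, earliest - (w - 1))
    return earliest


print(causal_pairs(16))            # 136
print(windowed_pairs(16, 4))       # 58
print(earliest_visible(15, 4, 3))  # 6, i.e. 3 * (4 - 1) = 9 positions back
```

The pair counts grow quadratically in `n` for full attention and roughly as `n * w` for SWA, while the reach grows linearly with depth.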

In 2020, the Allen Institute for AI published Longformer, which combined dilated sliding windows with optional global tokens to achieve linear scaling on long documents.

In 2020, Google Research published BigBird, which combined sliding-window, random, and global attention into a single sparse scheme and showed expressivity close to that of full attention.

Mistral 7B, released in 2023, adopted SWA with a window size of $w = 4096$ and a depth of $L = 32$ layers. The resulting theoretical attention span, $L \times w = 131{,}072$ (about 131K) tokens, comfortably covers its 32K maximum context length.
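The arithmetic behind that claim can be checked in a few lines. This sketch uses only the figures quoted above and the $k \times w$ upper bound on information flow; the variable names are hypothetical.

```python
import math

# Figures quoted in the text above.
window_size = 4096
num_layers = 32
context_length = 32 * 1024  # 32K tokens

# Upper bound on how far information can travel through the full stack.
theoretical_span = num_layers * window_size

# Fewest layers whose combined reach (k * window_size) spans the context.
min_layers_to_cover = math.ceil(context_length / window_size)

print(theoretical_span)                   # 131072
print(theoretical_span / context_length)  # 4.0
print(min_layers_to_cover)                # 8
```

Under this bound, the stack has about four times the reach needed for the stated context length.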

For the query at position $i$, SWA is defined as:

$$
\text{SWA}(Q, K, V)_i = \text{softmax}\left(\frac{Q_i K_{[i-w+1:i]}^T}{\sqrt{d_k}}\right) V_{[i-w+1:i]}
$$


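Where:

- $w$ is the window size
- $Q_i$ is the query at position $i$
- $d_k$ is the dimension of each key vector
- $K_{[i-w+1:i]}$ are the keys at positions $i-w+1$ through $i$ (inclusive), clipped at the start of the sequence
- $V_{[i-w+1:i]}$ are the values at the same positions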
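As a minimal sketch of the formula, the following pure-Python function evaluates it for a single query and a single head, using plain lists instead of tensors. The function name `swa_single_query` is hypothetical, and the sketch assumes the window is clipped at position 0 when $i < w - 1$.

```python
import math


def swa_single_query(q, keys, values, i, w):
    """Evaluate the SWA formula for the query at position i (one head).

    q: list of d_k floats; keys, values: one list of floats per position.
    """
    # The formula's range [i-w+1 : i] is inclusive, hence the i + 1 in the slice.
    start = max(0, i - w + 1)
    window_keys = keys[start:i + 1]
    window_values = values[start:i + 1]

    # Scaled dot-product scores against the keys in the window only.
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in window_keys]

    # Numerically stable softmax over the window.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    # Weighted sum of the values in the window.
    dim_v = len(window_values[0])
    return [sum(wt * v[d] for wt, v in zip(weights, window_values))
            for d in range(dim_v)]


# With a zero query every score is 0, so the output is the window's mean value.
keys = [[1.0, 0.0]] * 5
values = [[0.0], [1.0], [2.0], [3.0], [4.0]]
print(swa_single_query([0.0, 0.0], keys, values, i=4, w=2))  # [3.5]
```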

## Implementation

```python
import torch
import torch.nn.functional as F

def sliding_window_attention(Q, K, V, window_size):
    """
    Causal Sliding Window Attention (no dilation).
    Q, K, V: [B, T, H, D]
    window_size: int W (tokens to the left, inclusive of current position)
    Returns: [B, T, H, D]
    """
    B, T, H, D = Q.shape
    device = Q.device

    # For each position i, gather the keys/values at positions [i-W+1, i],
    # clipped to the valid range [0, T-1].

    # Build indices: idx[i, k] = j, the k-th window slot for position i.
    # k runs over 0..W-1 and maps to j = i - (W-1 - k).
    W = min(window_size, T)
    base = torch.arange(T, device=device).unsqueeze(1)               # [T, 1]
    offsets = torch.arange(-(W - 1), 1, device=device).unsqueeze(0)  # [1, W]
    raw_idx = base + offsets                 # [T, W], negative before seq start
    valid = raw_idx >= 0                     # [T, W], True for real positions
    idx = raw_idx.clamp(min=0)               # [T, W], safe indices for lookup

    # Gather K, V windows with advanced indexing on the time dimension.
    # K[:, idx] is [B, T, W, H, D]; permute to [B, T, H, W, D].
    K_win = K[:, idx].permute(0, 1, 3, 2, 4)
    V_win = V[:, idx].permute(0, 1, 3, 2, 4)

    # Compute scores: [B, T, H, W]
    # scores_{i,k} = <Q_{i}, K_{idx[i,k]}>
    Q_exp = Q.unsqueeze(3)                      # [B, T, H, 1, D]
    scores = (Q_exp * K_win).sum(dim=-1)        # dot product -> [B, T, H, W]
    scores = scores / (D ** 0.5)

    # Mask window slots that fall before the sequence start. Clamping mapped
    # them to position 0, which would otherwise be counted several times.
    # Causality holds by construction because offsets are never positive.
    scores = scores.masked_fill(~valid.view(1, T, 1, W), float('-inf'))

    # Softmax over the W window entries and weighted sum
    attn = F.softmax(scores, dim=-1)             # [B, T, H, W]
    out = torch.einsum('bthw,bthwd->bthd', attn, V_win)  # [B, T, H, D]
    return out
```
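
A banded implementation is easy to get subtly wrong near the start of the sequence, so it can help to have a slow reference to compare against. The sketch below computes the full $T \times T$ score matrix and applies a banded causal mask; it costs $O(T^2)$ and is meant only for testing on small inputs. The name `dense_sliding_window_attention` is hypothetical, and the `[B, T, H, D]` layout is assumed to match the function above.

```python
import torch
import torch.nn.functional as F


def dense_sliding_window_attention(Q, K, V, window_size):
    """Reference SWA using a dense T x T mask. Q, K, V: [B, T, H, D]."""
    B, T, H, D = Q.shape
    # Move heads ahead of time so matmul runs over the [T, D] matrices.
    q, k, v = (x.transpose(1, 2) for x in (Q, K, V))     # [B, H, T, D]
    scores = q @ k.transpose(-2, -1) / (D ** 0.5)        # [B, H, T, T]

    # dist[i, j] = i - j; key j is allowed when 0 <= i - j < window_size.
    pos = torch.arange(T, device=Q.device)
    dist = pos.unsqueeze(1) - pos.unsqueeze(0)           # [T, T]
    allowed = (dist >= 0) & (dist < window_size)         # [T, T]

    scores = scores.masked_fill(~allowed, float("-inf"))
    attn = F.softmax(scores, dim=-1)                     # [B, H, T, T]
    return (attn @ v).transpose(1, 2)                    # [B, T, H, D]


torch.manual_seed(0)
Q, K, V = (torch.randn(2, 6, 3, 4) for _ in range(3))
out = dense_sliding_window_attention(Q, K, V, window_size=3)
print(out.shape)  # torch.Size([2, 6, 3, 4])
```

On small random inputs, the output of `sliding_window_attention(Q, K, V, window_size)` can be compared with this reference using `torch.allclose` with a small tolerance such as `atol=1e-6`.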