# In [None]:
import torch
import torch.nn.functional as F

# ----------
# Hyperparams
# ----------
batch_size = 1
seq_len    = 2048    # Large "context window"
d_model    = 32    # Embedding dimension
# Q, K, V are built as (seq_len x intrinsic_rank) @ (intrinsic_rank x d_model),
# so their rank is at most min(intrinsic_rank, d_model). With
# intrinsic_rank >= d_model (as here) the factorization imposes no extra rank
# constraint; choose a value below d_model (e.g. 8) to genuinely force
# low-rank Q, K, V.
intrinsic_rank = 128
proj_k     = seq_len // 4     # Linformer projection dimension (sequence axis)

torch.manual_seed(0)

# ---------------------------------------------------------------------
# 1) Create Low-Rank Q, K, V
#    We'll do Q = (Q_base) x (Q_proj) so that rank(Q) <= intrinsic_rank
# ---------------------------------------------------------------------
def make_low_rank_matrix(seq_len, d_model, intrinsic_rank, *, normalize=False):
    """
    Create a random matrix of shape (seq_len, d_model) with bounded rank.

    The matrix is the product of two standard-normal factors:
       [seq_len x intrinsic_rank] @ [intrinsic_rank x d_model]
    so its rank is at most min(seq_len, intrinsic_rank, d_model). The
    intrinsic_rank argument only constrains the rank when it is smaller than
    both seq_len and d_model.

    Args:
        seq_len: Number of rows.
        d_model: Number of columns.
        intrinsic_rank: Inner dimension of the factorization.
        normalize: If True, divide the product by sqrt(intrinsic_rank) so each
            entry has unit variance instead of variance intrinsic_rank.
            Defaults to False, which keeps the original unscaled output.

    Returns:
        A float tensor of shape (seq_len, d_model), drawn from torch's global
        random number generator (two torch.randn calls, in a fixed order).
    """
    base = torch.randn(seq_len, intrinsic_rank)
    proj = torch.randn(intrinsic_rank, d_model)
    full = base @ proj  # shape: (seq_len, d_model)
    if normalize and intrinsic_rank > 0:
        # Each entry is a sum of intrinsic_rank products of independent N(0, 1)
        # values, hence variance intrinsic_rank; rescale to unit variance.
        # (intrinsic_rank == 0 yields an all-zero matrix; skip to avoid 0/0.)
        full = full / intrinsic_rank ** 0.5
    return full

# Draw `batch_size` independent low-rank samples for each of Q, K and V and
# stack them along a new leading batch axis. With batch_size == 1 this consumes
# the RNG exactly like one make_low_rank_matrix(...).unsqueeze(0) per tensor,
# so the demo's numbers are unchanged.
Q = torch.stack(
    [make_low_rank_matrix(seq_len, d_model, intrinsic_rank) for _ in range(batch_size)]
)  # (batch_size, seq_len, d_model)
K = torch.stack(
    [make_low_rank_matrix(seq_len, d_model, intrinsic_rank) for _ in range(batch_size)]
)  # (batch_size, seq_len, d_model)
V = torch.stack(
    [make_low_rank_matrix(seq_len, d_model, intrinsic_rank) for _ in range(batch_size)]
)  # (batch_size, seq_len, d_model)

# ---------------------------------------------
# 2) Vanilla Attention (Full Softmax Attention)
# ---------------------------------------------
# Scaled dot-product scores: every query is compared against every key,
#   (batch, seq_len, d_model) @ (batch, d_model, seq_len) -> (batch, seq_len, seq_len)
# so this tensor grows quadratically with seq_len.
scores_vanilla = Q.bmm(K.transpose(-2, -1)) / (d_model**0.5)
# Row-wise softmax: each query's weights over all keys sum to 1.
attn_weights_vanilla = scores_vanilla.softmax(dim=-1)
# Weighted sum of values -> (batch, seq_len, d_model).
output_vanilla = attn_weights_vanilla.bmm(V)

print("=== Vanilla Attention ===")
print("Attention matrix shape:", attn_weights_vanilla.shape, "=> (batch, seq_len, seq_len)")
print("Output shape:", output_vanilla.shape, "=> (batch, seq_len, d_model)")

# --------------------------------------------------------
# 3) Demonstrate (Approx.) Low Rank via Singular Values
#    We'll do SVD on a single attention matrix instance
# --------------------------------------------------------
# Inspect the attention matrix of the first batch element only.
A = attn_weights_vanilla[0]  # shape (seq_len, seq_len)
# Only the singular values are needed, so use torch.linalg.svdvals instead of
# torch.linalg.svd: the latter also materializes U and Vh (each
# seq_len x seq_len with the default full_matrices=True), which were never used.
S = torch.linalg.svdvals(A)  # returned in descending order

print("\nSingular Values of the Vanilla Attention Matrix:")
print(S)

# Count singular values above an absolute cutoff as a rough "numerical rank".
threshold = 1e-3
significant_sv = (S > threshold).sum().item()
print(f"Number of singular values > {threshold}: {significant_sv} out of {S.numel()}")

# ----------------------------------------------------------------
# 4) Linformer Attention: Project K, V along sequence dimension
# ----------------------------------------------------------------
# E_K, E_V are trainable or fixed in practice; here, random for demo.
# Entries are drawn from N(0, 1/proj_k), the Gaussian random-projection
# construction used in the Linformer analysis: in expectation it preserves the
# norm of each length-seq_len column of K / V. Unscaled N(0, 1) entries would
# inflate the projected keys/values by ~sqrt(proj_k) and saturate the softmax.
E_K = torch.randn(seq_len, proj_k) / proj_k**0.5
E_V = torch.randn(seq_len, proj_k) / proj_k**0.5

# Compress the sequence axis of K and V from seq_len down to proj_k:
#   (b, seq_len, d_model) x (seq_len, proj_k) -> (b, proj_k, d_model)
K_linformer = torch.einsum('bsd,sp->bpd', K, E_K)  # => (b, proj_k, d_model)
V_linformer = torch.einsum('bsd,sp->bpd', V, E_V)  # => (b, proj_k, d_model)

# Then compute Q x K_linformer^T => [b, seq_len, d_model] x [b, d_model, proj_k] => [b, seq_len, proj_k]
scores_linformer = torch.bmm(Q, K_linformer.transpose(1,2)) / (d_model**0.5)
attn_weights_linformer = F.softmax(scores_linformer, dim=-1)  # [b, seq_len, proj_k]
output_linformer = torch.bmm(attn_weights_linformer, V_linformer)  # [b, seq_len, d_model]

print("\n=== Linformer Attention ===")
print("Projected K shape:", K_linformer.shape, "-> (batch, proj_k, d_model)")
print("Projected V shape:", V_linformer.shape, "-> (batch, proj_k, d_model)")
print("QK'^T shape:", scores_linformer.shape, "-> (batch, seq_len, proj_k)")
print("Linformer attention matrix shape:", attn_weights_linformer.shape, "-> (batch, seq_len, proj_k)")
print("Linformer final output shape:", output_linformer.shape, "-> (batch, seq_len, d_model)")

# ----------------------------------------------
# 5) Demonstrate Reduced Memory / Compare Shapes
# ----------------------------------------------
# Both figures are entry counts of a single attention matrix (one batch
# element), not byte sizes: full attention grows quadratically in seq_len,
# Linformer's grows linearly (for a fixed proj_k).
full_attn_entries = seq_len ** 2
linformer_attn_entries = seq_len * proj_k
print("\nMemory (Full attn matrix): seq_len x seq_len =", full_attn_entries)
print("Memory (Linformer attn matrix): seq_len x proj_k =", linformer_attn_entries)
