In [None]:
class MultiHeadLatentAttention(nn.Module):
    """Multi-head cross-attention from learned latent queries to an input sequence.

    A fixed set of ``num_latents`` learnable vectors act as queries; keys and
    values are projected from the input. The output therefore has one row per
    latent, not per input token: ``(B, T, d_model) -> (B, num_latents, d_model)``.

    Args:
        d_model: Feature width of the input, the latents and the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
        num_latents: Number of learned latent query vectors.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, num_latents=8):
        super().__init__()
        # Fail early with a clear message: otherwise the per-head ``view`` in
        # ``forward`` is the first thing to break, with an opaque shape error.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature width
        self.num_latents = num_latents

        # Learnable latent queries; the leading 1 is expanded to the batch size
        # in ``forward``.
        self.latents = nn.Parameter(torch.randn(1, num_latents, d_model))

        # Projection layers. Keys and values share one fused projection that is
        # split in half along the feature axis in ``forward``.
        self.q_proj = nn.Linear(d_model, d_model)
        self.kv_proj = nn.Linear(d_model, d_model * 2)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend from the latents to ``x``.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``C == d_model``.

        Returns:
            Tensor of shape ``(B, num_latents, C)``.
        """
        B, T, C = x.shape

        # Project keys and values from the input.
        kv = self.kv_proj(x)  # (B, T, 2*C)
        k, v = kv.chunk(2, dim=-1)  # each (B, T, C)

        # Share the same latents across the batch (expand creates a view, not a copy).
        latents = self.latents.expand(B, -1, -1)  # (B, L, C)
        q = self.q_proj(latents)

        # Split the feature axis into heads and move heads ahead of the sequence axis.
        q = q.view(B, self.num_latents, self.num_heads, self.d_k).transpose(1, 2)  # (B, H, L, d_k)
        k = k.view(B, T, self.num_heads, self.d_k).transpose(1, 2)                # (B, H, T, d_k)
        v = v.view(B, T, self.num_heads, self.d_k).transpose(1, 2)                # (B, H, T, d_k)

        # Scaled dot-product attention; the softmax normalises over the T input positions.
        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)             # (B, H, L, T)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = attn_weights @ v                                            # (B, H, L, d_k)

        # Merge heads back into the feature axis; contiguous() is needed before
        # view() because transpose() leaves the tensor non-contiguous.
        out = attn_output.transpose(1, 2).contiguous().view(B, self.num_latents, C)
        return self.out_proj(out)  # (B, L, C)


In [None]:
class TransformerBlockWithLatents(nn.Module):
    """Pre-norm block that summarises a sequence into ``num_latents`` vectors.

    The input is layer-normalised and passed through
    ``MultiHeadLatentAttention``; the mean of the raw input over the sequence
    axis is added to every latent, and a feed-forward sub-layer with a residual
    connection follows.

    Note that the output has ``num_latents`` positions, not the input's ``T``:
    ``(B, T, C) -> (B, num_latents, C)``.

    Args:
        d_model: Feature width ``C``.
        num_heads: Number of attention heads for the latent attention.
        d_ff: Forwarded to ``FeedForward`` as its second argument.
        dropout: Forwarded to ``FeedForward`` as its third argument.
        num_latents: Number of latent vectors produced.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1, num_latents=8):
        super().__init__()
        self.latent_attn = MultiHeadLatentAttention(d_model, num_heads, num_latents)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return the latent summary of ``x``.

        Args:
            x: Tensor of shape ``(B, T, C)``.

        Returns:
            Tensor of shape ``(B, num_latents, C)`` — assuming ``FeedForward``
            preserves the shape of its input (TODO confirm against its
            definition).
        """
        latents = self.latent_attn(self.ln1(x))  # (B, L, C)

        # Sequence-mean of the un-normalised input, added to every latent.
        # Broadcasting (B, 1, C) against (B, L, C) yields the same values as an
        # explicit repeat along dim 1, without materialising the extra copy.
        pooled = x.mean(dim=1, keepdim=True)  # (B, 1, C)
        combined = pooled + latents  # (B, L, C)

        return combined + self.ff(self.ln2(combined))


In [None]:
# Fusion of latent + causal self-attention
class HybridAttentionBlock(nn.Module):
    """Block combining per-token causal self-attention with a latent summary.

    The token branch runs ``CausalSelfAttention`` on the layer-normalised
    input. The latent branch runs ``MultiHeadLatentAttention`` on the raw
    input and averages its latents into a single vector per batch element,
    which is then broadcast-added to every token position. A feed-forward
    sub-layer with a residual connection follows.

    Args:
        d_model: Feature width.
        num_heads: Number of heads for both attention modules.
        d_ff: Forwarded to ``FeedForward`` as its second argument.
        num_latents: Number of latent queries in the latent branch.
    """

    def __init__(self, d_model, num_heads, d_ff, num_latents=8):
        super().__init__()
        self.token_attn = CausalSelfAttention(d_model, num_heads)
        self.latent_attn = MultiHeadLatentAttention(
            d_model, num_heads, num_latents=num_latents
        )
        self.ff = FeedForward(d_model, d_ff)
        # NOTE(review): this single LayerNorm is applied at two different sites
        # in ``forward`` (before token attention and before the feed-forward),
        # so both share one set of affine parameters — confirm that is intended.
        self.ln = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x):
        """Fuse token-level attention output with the averaged latents.

        Args:
            x: Input tensor of shape ``(B, T, C)``.

        Returns:
            The fused representation plus its feed-forward update. Presumably
            ``(B, T, C)``, assuming ``CausalSelfAttention`` and ``FeedForward``
            preserve shape — TODO confirm against their definitions.
        """
        # Token branch sees the normalised input.
        token_out = self.token_attn(self.ln(x))

        # NOTE(review): the latent branch receives the un-normalised input, and
        # no residual from ``x`` is added below — verify both are deliberate.
        latent_out = self.latent_attn(x)  # (B, L, C)
        latent_summary = latent_out.mean(dim=1, keepdim=True)  # (B, 1, C)

        # The averaged latent vector is broadcast across the token positions.
        fused = token_out + latent_summary
        ff_update = self.ff(self.ln(fused))
        return fused + ff_update
