Implement MuonClip (qk-clip) for Stabilizing Attention
Medium
Deep Learning

Write a function that applies the MuonClip qk-clip step to the attention projection weights W_q and W_k. Given input features x, a threshold t, and a mixing parameter alpha, your function must: (1) compute the maximum pre-clip QK score, scaled by 1/sqrt(d_head); (2) if this maximum exceeds t, multiply W_q by eta^alpha and W_k by eta^(1-alpha), where eta = t / max_score; (3) return the (possibly rescaled) weights, a boolean indicating whether clipping occurred, and the post-clip max score. Round all returned floating-point values to 4 decimal places so the tests are reproducible.

Example:
Input:
W_q = [[2.0, 0.0],[0.0, 2.0]]
W_k = [[2.0, 0.0],[0.0, 2.0]]
x   = [[[1.0, 0.0],[0.0, 1.0]]]  # (batch=1, seq=2, d_model=2)
print(muonclip_qk_clip(W_q, W_k, x, t=1.0, alpha=0.5))
Output:
([[1.1892, 0.0], [0.0, 1.1892]], [[1.1892, 0.0], [0.0, 1.1892]], True, 1.0)
Reasoning:
Pre-clip max score is 4/√2 ≈ 2.8284 (> 1.0), so clipping applies. With η = 1/2.8284 ≈ 0.3536 and α = 0.5, both W_q and W_k are scaled by √η ≈ 0.5946, so each nonzero entry becomes 2 × 0.5946 ≈ 1.1892 and the max score drops to t = 1.0.
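The arithmetic in the reasoning above can be checked by hand with a few lines of Python. This is only a sketch of the example's numbers, not the full function: the variable names are illustrative, and it omits the small eps that a real implementation may add for numerical safety.

```python
import math

# Values from the example: W_q = W_k = 2*I, rows of x are unit vectors, d_head = 2.
d_head = 2
t, alpha = 1.0, 0.5

# The largest raw dot product q.k is 2 * 2 = 4; scale it by 1/sqrt(d_head).
max_pre = 4.0 / math.sqrt(d_head)

# Shrink factor and how it is split between W_q and W_k.
eta = t / max_pre
q_factor = eta ** alpha
k_factor = eta ** (1.0 - alpha)

# Each nonzero weight entry (2.0) after rescaling.
w_q_entry = 2.0 * q_factor
w_k_entry = 2.0 * k_factor

# A score is a product of one W_q entry and one W_k entry, so it shrinks by eta overall.
max_post = (w_q_entry * w_k_entry) / math.sqrt(d_head)

print(round(max_pre, 4), round(eta, 4), round(w_q_entry, 4), round(max_post, 4))
# 2.8284 0.3536 1.1892 1.0
```

Because q_factor * k_factor = eta^alpha * eta^(1-alpha) = eta, the post-clip maximum lands on t for any choice of alpha; alpha only decides how the shrinkage is shared between the two matrices.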



import numpy as np

def muonclip_qk_clip(W_q: np.ndarray, W_k: np.ndarray, x: np.ndarray, t: float, alpha: float = 0.5, eps: float = 1e-7):
    """
    Apply MuonClip qk-clip to (W_q, W_k).

    Args:
        W_q: (d_head, d_model) query projection weights
        W_k: (d_head, d_model) key  projection weights
        x:   (batch, seq, d_model) input features
        t: threshold for max QK score (after 1/sqrt(d_head) scaling)
        alpha: fraction of rescaling applied to W_q (remainder to W_k)
        eps: small epsilon to avoid division by zero

    Returns:
        (W_q_new, W_k_new, clipped, max_post)
        where W_q_new/W_k_new are lists of lists with 4-decimal rounding,
        clipped is a bool, and max_post is a float rounded to 4 decimals.
    """
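    # Accept nested lists (as in the problem's example) as well as arrays.
    W_q, W_k, x = (np.asarray(a, dtype=float) for a in (W_q, W_k, x))
    d_head, d_model = W_q.shape
    assert W_k.shape == (d_head, d_model), "W_k must match W_q shape"
    assert x.ndim == 3 and x.shape[-1] == d_model, "x must be (batch, seq, d_model)"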

    def max_qk_score(Wq, Wk):
        # q,k: (B, S, d_head)
        q = x @ Wq.T
        k = x @ Wk.T
        # scores per batch: (B, S, S) = q @ k^T
        # scale by 1/sqrt(d_head)
        scale = np.sqrt(max(d_head, eps))
        scores = (q @ k.swapaxes(1, 2)) / scale
        return float(np.max(scores)) if scores.size else 0.0

    # 1) pre-clip max score
    max_pre = max_qk_score(W_q, W_k)

    # 2) decide clipping and rescale if needed
    if max_pre > t + eps:
        eta = t / (max_pre + eps)  # shrink factor (<1)
        # rescale W_q, W_k with mixing alpha
        W_q_new = W_q * (eta ** alpha)
        W_k_new = W_k * (eta ** (1.0 - alpha))
        clipped = True
    else:
        W_q_new = W_q
        W_k_new = W_k
        clipped = False

    # 3) post-clip max score
    max_post = max_qk_score(W_q_new, W_k_new)

    # round outputs to 4 decimals
    W_q_out = [[round(float(v), 4) for v in row] for row in W_q_new]
    W_k_out = [[round(float(v), 4) for v in row] for row in W_k_new]
    return W_q_out, W_k_out, clipped, round(float(max_post), 4)


# Example
if __name__ == "__main__":
    W_q = np.array([[2.0, 0.0],[0.0, 2.0]])
    W_k = np.array([[2.0, 0.0],[0.0, 2.0]])
    x   = np.array([[[1.0, 0.0],[0.0, 1.0]]])  # (B=1,S=2,D=2)
    print(muonclip_qk_clip(W_q, W_k, x, t=1.0, alpha=0.5))
    # ([[1.1892, 0.0], [0.0, 1.1892]], [[1.1892, 0.0], [0.0, 1.1892]], True, 1.0)
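
As a further sanity check, the sketch below tests the property the clip relies on: scores are linear in both W_q and W_k, so scaling them by eta^alpha and eta^(1-alpha) scales every score by eta no matter which alpha is chosen. The helper `max_score`, the random shapes, and the seed are assumptions made for illustration; the sketch also skips the eps guard and assumes the pre-clip maximum is positive and above t.

```python
import numpy as np

def max_score(W_q, W_k, x):
    """Largest scaled QK score over all batches and position pairs (hypothetical helper)."""
    q = x @ W_q.T                      # (B, S, d_head)
    k = x @ W_k.T                      # (B, S, d_head)
    return float(np.max(q @ k.swapaxes(1, 2)) / np.sqrt(W_q.shape[0]))

# Illustrative random inputs: d_head=4, d_model=8, batch=2, seq=5.
rng = np.random.default_rng(0)
W_q = rng.normal(size=(4, 8)) * 3.0
W_k = rng.normal(size=(4, 8)) * 3.0
x = rng.normal(size=(2, 5, 8))
t = 1.0

max_pre = max_score(W_q, W_k, x)
assert max_pre > t, "this sketch assumes clipping is triggered"
eta = t / max_pre

# Try several splits of the shrinkage between W_q and W_k.
post_scores = []
for alpha in (0.0, 0.25, 0.5, 1.0):
    post = max_score(W_q * eta ** alpha, W_k * eta ** (1.0 - alpha), x)
    post_scores.append(post)
    print(alpha, round(post, 4))
```

Each post-clip maximum should match t up to floating-point error. With alpha = 0 all of the shrinkage goes to W_k, and with alpha = 1 all of it goes to W_q.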