In [1]:
from gptopt.optim.attn_utils import * 
from gptopt.optim.linop import * 
from gptopt.gpt_model import *
from einops import rearrange, einsum
from utils import *
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch

from gptopt.utils import set_seed
set_seed(42)  # project-wide seeding helper from gptopt.utils; called before any random tensor is drawn

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [3]:
maxit: int = 1_000  # iteration cap; NOTE(review): not referenced by any cell shown here

# Linear operator and its adjoint
$$\mathcal{A}(Z) = 
\begin{bmatrix}
    (Z^1_1)^\top A^1_1 + (A^1_2)^\top Z^1_2 \\
    \vdots \\
    (Z^h_1)^\top A^h_1 + (A^h_2)^\top Z^h_2
\end{bmatrix},
\qquad 
\mathcal{A}^*(Y) = (A^1_1Y_1^\top, A^1_2Y_1, \ldots, A^h_1Y_h^\top, A^h_2Y_h)
$$

### Vectorization
$$
K = [(A_1^\top \otimes I_n)P, I_n \otimes A_2^\top],
\qquad 
K^\top = \begin{bmatrix}
P^\top(A_1 \otimes I_n) \\
I_n \otimes A_2
\end{bmatrix},
$$
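where $\text{vec}$ stacks columns and $P$ is the permutation (commutation) matrix such that $P\,\text{vec}(Z) = \text{vec}(Z^\top)$, so that $(A_1^\top \otimes I_n)P\,\text{vec}(Z_1) = (A_1^\top \otimes I_n)\,\text{vec}(Z_1^\top) = \text{vec}(Z_1^\top A_1)$.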

In [4]:
# Two checks per random instance:
#   (i)  adjointness: <A(Z), Y> == <Z, A*(Y)> for matvec / rmatvec;
#   (ii) the explicit per-head Kronecker matrix K reproduces A and A^T on
#        column-major vectorizations of each head's blocks.
n_head = 5
dtype = torch.float64
shapes = [(30, 60), (60, 30), (60, 60)]
for m, n in shapes:
    for _trial in range(5):
        # Draw order (A1, A2, Z, Y) is kept fixed so seeded runs stay reproducible.
        A1 = torch.randn((n_head * m, n), device=device).to(dtype)
        A2 = torch.randn((n_head * m, n), device=device).to(dtype)
        Z = torch.randn((2 * n_head * m, n), device=A2.device, dtype=A2.dtype)
        Y = torch.randn((n_head * n, n), device=A2.device, dtype=A2.dtype)

        A_linop = attn_linop_from_matrices_heads(A1, A2, n_head=n_head)
        Az, Aty = A_linop.matvec(Z), A_linop.rmatvec(Y)

        # (i) adjoint identity via Frobenius inner products
        assert torch.allclose((Az * Y).sum(), (Z * Aty).sum())

        # (ii) per-head views of the operator, its inputs and its outputs
        A1_heads, A2_heads = A1_A2_unpack_heads(A_linop.A1, A_linop.A2, n_head)
        Z1_heads, Z2_heads = Z_unpack_Z1_Z2_heads(Z, n_head=n_head)
        Y_heads = rearrange(Y, "(n_head n_embd1) n_embd2 -> n_head n_embd1 n_embd2",
                            n_head=n_head)
        Az_heads = rearrange(Az, "(n_head n_emb1) n_emb2 -> n_head n_emb1 n_emb2", n_head=n_head)
        Aty_heads = rearrange(Aty, "(n_head zs n_att) n_embd -> n_head zs n_att n_embd",
                              n_head=n_head, zs=2)

        vec_Kz = torch.empty_like(Az.flatten())
        vec_Az = torch.empty_like(Az.flatten())
        vec_Kty = torch.empty_like(Aty.flatten())
        vec_Aty = torch.empty_like(Aty.flatten())

        y_block = n * n        # entries of one head's n x n output block
        z_block = 2 * m * n    # entries of one head's stacked (Z1, Z2) input block
        for h in range(n_head):
            K = matcal_A_to_kron_Kron(A1_heads[h], A2_heads[h])
            y_sl = slice(h * y_block, (h + 1) * y_block)
            z_sl = slice(h * z_block, (h + 1) * z_block)
            z_mid = h * z_block + m * n

            # Forward: row-major flatten of Z1 (= column-major vec of Z1^T) stacked
            # with the column-major vec of Z2 (row-major flatten of Z2^T).
            z_vec = torch.cat([Z1_heads[h].reshape(-1), Z2_heads[h].T.reshape(-1)], dim=0)
            vec_Kz[y_sl] = K @ z_vec
            vec_Az[y_sl] = Az_heads[h].T.reshape(-1)

            # Adjoint: K^T on the column-major vec of this head's Y block.
            vec_Kty[z_sl] = K.T @ Y_heads[h].T.reshape(-1)
            vec_Aty[h * z_block : z_mid] = Aty_heads[h, 0].reshape(-1)
            vec_Aty[z_mid : (h + 1) * z_block] = Aty_heads[h, 1].T.reshape(-1)

        assert torch.allclose(vec_Kz, vec_Az, atol=1e-5)
        assert torch.allclose(vec_Kty, vec_Aty, atol=1e-5)

print("PASSED")

PASSED


# Kronecker matrix and packing/unpacking

In [5]:
# Same operator, checked through the packed (vectorized) interface on the full
# multi-head Kronecker matrix K:
#   K   @ pack_Z(Z) == pack_Y(A(Z))
#   K^T @ pack_Y(Y) == pack_Z(A*(Y))
n_head = 3
dtype = torch.float64
for m, n in ((15, 30), (30, 15), (20, 20)):
    for _trial in range(5):
        A1 = torch.randn((n_head * m, n), device=device).to(dtype)
        A2 = torch.randn((n_head * m, n), device=device).to(dtype)
        Z = torch.randn((2 * n_head * m, n), device=A2.device, dtype=A2.dtype)
        Y = torch.randn((n_head * n, n), device=A2.device, dtype=A2.dtype)

        A_linop = attn_linop_from_matrices_heads(A1, A2, n_head=n_head)
        K = kron_mat_A_linop_heads(A1, A2, n_head)

        # forward map
        z_packed = pack_Z(Z, m, n, n_head)
        Az_packed = pack_Y(A_linop.mv(Z), n, n_head)
        assert torch.allclose(K @ z_packed, Az_packed, rtol=1e-5)

        # adjoint map
        y_packed = pack_Y(Y, n, n_head)
        Aty_packed = pack_Z(A_linop.rmv(Y), m, n, n_head)
        assert torch.allclose(K.T @ y_packed, Aty_packed, rtol=1e-5)

print("PASSED")

PASSED


In [6]:
# Round trip: splitting Z into per-head (Z1, Z2) blocks and packing them back must return Z.
n_head = 7
for m, n in [(35, 60), (20, 45), (60, 20), (50, 50)]:
    for _trial in range(5):
        Z = torch.randn(2 * m * n_head, n)
        Z1_Z2_heads = Z_unpack_Z1_Z2_heads(Z, n_head=n_head)
        # fold the head axis back into the rows: (zs, n_head, n_att, n_embd) -> (zs, n_head*n_att, n_embd)
        Z1, Z2 = rearrange(Z1_Z2_heads, "zs n_head n_att n_embd -> zs (n_head n_att) n_embd")
        Z_roundtrip = Z1_Z2_pack_Z_heads(Z1, Z2, n_head=n_head)
        assert torch.allclose(Z, Z_roundtrip)

print("PASSED")

PASSED


## Diagonal scaling such that $\|R^{1/2} A \Gamma^{1/2}\|_{op} < 1$

In [7]:
n_head: int = 4  # number of heads for the diagonal-scaling check

In [16]:
# Diagonal-scaling check. For random A1, A2, the scalings (Rm, Gamma) returned by
# diagonal_scaling_heads_op_norm must match, entry by entry,
#     sqrt(eta / (l1 norm of the matching row / column of the explicit K + eps))
# (zero for all-zero rows / columns), and the scaled operator R^{1/2} K Gamma^{1/2}
# must have spectral norm <= eta. The spectral norm is computed exactly (SVD of the
# explicit matrix) and, for comparison, by matrix-free power iteration.
num_cases, seed, verbose = 20, 1234, True
n_head = 4               # set here so the result does not depend on which cell last bound n_head
eta, eps = 0.99, 1e-8    # target operator-norm bound and regularizer (loop-invariant)
tol = eta * (1 + 1e-8)   # acceptance threshold for sigma_max (tiny slack for round-off)

torch.manual_seed(seed)
fails = 0
for t in range(num_cases):
    n = torch.randint(3, 10, ()).item()    # keep small (explicit K is O(n^3))
    m = torch.randint(2, 10, ()).item()
    # A2 is drawn before A1; keep this order so seeded cases stay reproducible.
    A2 = torch.randn(n_head * m, n)
    A1 = torch.randn(n_head * m, n)
    A_linop = attn_linop_from_matrices_heads(A1, A2, n_head=n_head)

    # scalings under test
    Rm, Gamma = diagonal_scaling_heads_op_norm(A_linop, eta=eta, eps=eps)

    # explicit K and the packed square-root scalings as diagonal matrices
    K = kron_mat_A_linop_heads(A_linop.A1, A_linop.A2, n_head)
    R_diag_sqrt = torch.diag(pack_Y(torch.sqrt(Rm), n, n_head))
    G_diag_sqrt = torch.diag(pack_Z(torch.sqrt(Gamma), m, n, n_head))

    # reference square-root scalings from the l1 row / column norms of K
    row_l1 = K.norm(dim=1, p=1)   # (h*n*n,)
    col_l1 = K.norm(dim=0, p=1)   # (2*h*m*n,)
    r_sqrt_ref = torch.where(row_l1 > 0, eta**0.5 / torch.sqrt(row_l1 + eps), torch.zeros_like(row_l1))
    g_sqrt_ref = torch.where(col_l1 > 0, eta**0.5 / torch.sqrt(col_l1 + eps), torch.zeros_like(col_l1))
    assert torch.allclose(r_sqrt_ref, R_diag_sqrt.diagonal(), rtol=1e-3)
    assert torch.allclose(g_sqrt_ref, G_diag_sqrt.diagonal(), rtol=1e-3)

    # spectral norm of the scaled operator: exact vs. power iteration
    sigma_max = torch.linalg.svdvals(R_diag_sqrt @ K @ G_diag_sqrt)[0].item()
    sigma_powit = op_norm_power_iteration(A_linop, Rm.sqrt(), Gamma.sqrt(), num_iters=500)
    ok = sigma_max <= tol
    if verbose:
        print(f"[{t:02d}] n={n} m={m}  "
              f"{sigma_max=:.6f} {sigma_powit=:.6f}  -> {'OK' if ok else 'FAIL'}")
    if not ok:
        fails += 1

print(f"\nSummary: {num_cases - fails} / {num_cases} passed (eta={eta}).")
# Make a violated bound fail the cell instead of only being counted.
assert fails == 0, f"{fails} case(s) exceeded the operator-norm bound eta={eta}"

[00] n=9 m=5  sigma_max=0.774346 sigma_powit=0.774346  -> OK
[01] n=4 m=6  sigma_max=0.866952 sigma_powit=0.866952  -> OK
[02] n=4 m=5  sigma_max=0.868909 sigma_powit=0.868909  -> OK
[03] n=8 m=2  sigma_max=0.935572 sigma_powit=0.935506  -> OK
[04] n=9 m=7  sigma_max=0.762751 sigma_powit=0.762751  -> OK
[05] n=9 m=7  sigma_max=0.756849 sigma_powit=0.756849  -> OK
[06] n=4 m=7  sigma_max=0.825347 sigma_powit=0.825347  -> OK
[07] n=7 m=2  sigma_max=0.917056 sigma_powit=0.917056  -> OK
[08] n=3 m=6  sigma_max=0.876481 sigma_powit=0.876481  -> OK
[09] n=3 m=9  sigma_max=0.835567 sigma_powit=0.835567  -> OK
[10] n=8 m=8  sigma_max=0.729278 sigma_powit=0.729278  -> OK
[11] n=3 m=8  sigma_max=0.845582 sigma_powit=0.845582  -> OK
[12] n=5 m=3  sigma_max=0.891582 sigma_powit=0.891582  -> OK
[13] n=6 m=8  sigma_max=0.757297 sigma_powit=0.757297  -> OK
[14] n=4 m=9  sigma_max=0.805240 sigma_powit=0.805240  -> OK
[15] n=3 m=5  sigma_max=0.911359 sigma_powit=0.910819  -> OK
[16] n=3 m=6  sigma_max=

# Test slicing in `CausalSelfAttention` for $\mathcal{A}$

In [9]:
# Attention block under test: 768 channels split over 12 heads (64 per head),
# probed with a batch of 32 sequences of length 1024.
n_embed = 768
n_head = 12
batch_size = 32
context_length = 1024
C, B, T = n_embed, batch_size, context_length  # short aliases used by the shape checks

attn = CausalSelfAttention(GPTConfig(n_embd=n_embed, n_head=n_head))

In [10]:
# c_attn's weight stacks the q, k, v projections along dim 0 (n_embed rows each);
# the linear operator uses A1 = W_q and A2 = W_k.
p = attn.c_attn.weight
A1 = p[:n_embed]
A2 = p[n_embed:2 * n_embed]
# random input batch on the weight's device / dtype
x = torch.randn((batch_size, context_length, n_embed), device=p.device, dtype=p.dtype)

In [11]:
# Project x with c_attn and cut the last axis into q, k, v: einops vs. Tensor.split.
qkv = attn.c_attn(x)
q, k, v = rearrange(qkv, "batch seqlen (size n) -> size batch seqlen n", size=3)
q2, k2, v2 = qkv.split(n_embed, dim=2)
for ein, ref in ((q, q2), (k, k2), (v, v2)):
    assert torch.allclose(ein, ref)

# Split channels over heads, (B, T, C) -> (B, n_head, T, C // n_head): einops ...
head_pattern = "batch seqlen (n_head n_att) -> batch n_head seqlen n_att"
q_heads, k_heads, v_heads = (rearrange(tensor, head_pattern, n_head=n_head) for tensor in (q, k, v))

# ... vs. the view + transpose formulation.
hs = C // n_head
q2 = q.view(B, T, n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
k2 = k.view(B, T, n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
v2 = v.view(B, T, n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
for ein, ref in ((q_heads, q2), (k_heads, k2), (v_heads, v2)):
    assert torch.allclose(ein, ref)

In [12]:
# Unscaled per-head attention logits, two ways: batched matmul vs. einsum.
qk2 = torch.matmul(q2, k2.transpose(-2, -1))
qk = einsum(q_heads, k_heads, "batch n_head seqlen1 n_att, batch n_head seqlen2 n_att\
            -> batch n_head seqlen1 seqlen2")
assert torch.allclose(qk, qk2)

In [13]:
# Rebuild the same logits straight from x and the per-head weights,
# (x W_q^h)(x W_k^h)^T for every head h, and compare with qk.
Wq_heads, Wk_heads = A1_A2_unpack_heads(A1, A2, n_head)  # each (n_head, n_att, n_embd)
A1_heads, A2_heads = Wq_heads, Wk_heads
proj_pattern = "batch seqlen n_embd, n_head n_att n_embd \
                           -> batch n_head seqlen n_att"
XWq, XWk = (einsum(x, W, proj_pattern) for W in (Wq_heads, Wk_heads))
QK = einsum(XWq, XWk, "batch n_head seqlen1 n_att, batch n_head seqlen2 n_att\
            -> batch n_head seqlen1 seqlen2")
assert torch.allclose(QK, qk)

In [14]:
# End-to-end check: rebuild CausalSelfAttention's forward pass by hand from the
# per-head W_q / W_k / W_v slices of c_attn (plus the module's rotary embedding,
# causal mask, softmax and c_proj) and compare with attn(x), for several freshly
# initialized modules.
# No gradients are needed; no_grad avoids keeping autograd buffers for the
# (batch, n_head, seqlen, seqlen) attention tensors.
with torch.no_grad():
    for _ in range(5):
        attn = CausalSelfAttention(GPTConfig(n_embd=n_embed, n_head=n_head))
        # Read this module's weight before sampling x, so x no longer takes its
        # device / dtype from a `p` left over from an earlier cell or iteration.
        p = attn.c_attn.weight
        x = torch.randn((batch_size, context_length, n_embed), device=p.device, dtype=p.dtype)

        res2 = attn(x)

        # A1 = W_q, A2 = W_k
        A1, A2, Wv = p[:n_embed, :], p[n_embed:2 * n_embed, :], p[2 * n_embed:3 * n_embed, :]
        # per-head weights, each (n_head, n_att, n_embd)
        A1_heads, A2_heads = A1_A2_unpack_heads(A1, A2, n_head)
        Wq_heads, Wk_heads = A1_heads, A2_heads
        Wv_heads = rearrange(Wv, "(n_head n_att) n_embd -> n_head n_att n_embd",
                             n_head=n_head)
        q = einsum(x, Wq_heads, "batch seqlen n_embd, n_head n_att n_embd \
                            -> batch n_head seqlen n_att")
        k = einsum(x, Wk_heads, "batch seqlen n_embd, n_head n_att n_embd \
                            -> batch n_head seqlen n_att")
        v = einsum(x, Wv_heads, "batch seqlen n_embd, n_head n_att n_embd \
                            -> batch n_head seqlen n_att")

        # apply the module's rotary embedding (attn.rope) to q and k
        cos, sin = attn.rope.get_embed(T, x.device, x.dtype)
        q = apply_rotary_pos_emb(q, cos, sin)
        k = apply_rotary_pos_emb(k, cos, sin)

        # scaled logits -> causal mask -> softmax over keys
        res = einsum(q, k, "batch n_head seqlen1 n_att, batch n_head seqlen2 n_att\
                -> batch n_head seqlen1 seqlen2") / (k.shape[-1])**0.5
        res = res.masked_fill(attn.causal_mask[:,:,:T,:T] == 0, float('-inf'))
        res = F.softmax(res, dim=-1)
        # weight the values, merge heads back into channels, output projection
        res = einsum(res, v, "batch n_head seqlen1 seqlen2, batch n_head seqlen2 n_embd \
                        -> batch n_head seqlen1 n_embd")
        res = rearrange(res, "batch n_head seqlen n_embd -> batch seqlen (n_head n_embd)")
        res = attn.c_proj(res)

        assert torch.allclose(res, res2)

print("PASSED")

PASSED
