In [1]:
from gptopt.optim.attn_utils import *
from gptopt.optim.least_squares import *
from gptopt.optim.linop import * 
from gptopt.optim.least_squares_torchmin import * 

from utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch

from gptopt.utils import set_seed
set_seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
maxit = 1000

# Pack and unpack functions

In [4]:
n_head = 7
for (m, n) in [(35, 60), (20, 45), (60, 20), (50, 50)]:
    for _ in range(5):
        Z = torch.randn(2*m*n_head, n)
        z = pack_Z(Z, m, n, n_head=n_head)
        assert torch.allclose(pack_Z(unpack_Z(z, m, n, n_head=n_head), m, n, n_head=n_head), z)

        Z1, Z2 = rearrange(Z_unpack_Z1_Z2_heads(Z, n_head=n_head), 
                           "zs n_head n_att n_embd -> zs (n_head n_att) n_embd")
        assert torch.allclose(Z, Z1_Z2_pack_Z_heads(Z1, Z2, n_head=n_head))

        Y = torch.randn(n_head*n, n)
        y = pack_Y(Y, n, n_head=n_head)
        assert torch.allclose(pack_Y(unpack_Y(y, n, n_head=n_head), n, n_head=n_head), y)

print("PASSED")

PASSED


# Minimize $\|\mathcal{A}^*(Y) + G\|_F^2$

In [5]:
maxit = 1000
n_head = 7
for (m, n) in [(30, 60), (60, 30), (60, 60)]:
    print(f"{m}x{n}")
    for _ in range(5): 
        std2 = 0.1
        std1 = 1
        rank_ratio = 0.5
        A2, A1, G1, G2 = gaussian_data_heads(m, n, n_head=n_head, std1=std1, std2=std2, 
                                                                 rank_ratio=rank_ratio, G_in_range=True)


        A_linop = attn_linop_from_matrices_heads(A1, A2, n_head=n_head)
        Grad = Z1_Z2_pack_Z_heads(G1, G2, n_head=n_head)
        
        Y_hat, res_lsmr, itn = solve_lsmr_Y_lstsq(A_linop, Grad, maxiter=maxit, n_head=n_head, atol=1e-12, btol=1e-12)
        print(f"LSMR: res={res_lsmr:.4e}, {itn=}") 

30x60
LSMR: res=4.4960e-07, itn=60
LSMR: res=6.3945e-07, itn=60
LSMR: res=5.3619e-07, itn=60
LSMR: res=5.6447e-07, itn=60
LSMR: res=7.0572e-07, itn=60
60x30
LSMR: res=2.4606e-07, itn=40
LSMR: res=3.8404e-07, itn=40
LSMR: res=2.3281e-07, itn=40
LSMR: res=3.6399e-07, itn=40
LSMR: res=3.6958e-07, itn=40
60x60
LSMR: res=2.5039e-06, itn=820
LSMR: res=2.0960e-06, itn=850
LSMR: res=2.5452e-06, itn=830
LSMR: res=1.9902e-06, itn=560
LSMR: res=2.5514e-06, itn=890


In [6]:
n_head = 5
for (m, n) in [(50, 50)]:
    print(f"{m}x{n}")
    for _ in range(5): 
        std2 = 0.1
        std1 = 1
        rank_ratio = 0.5
        A2, A1, G1, G2 = gaussian_data_heads(m, n, n_head=n_head, std1=std1, std2=std2, 
                                                                 rank_ratio=rank_ratio, G_in_range=False)

        A_linop = attn_linop_from_matrices_heads(A1, A2, n_head=n_head)
        Grad = Z1_Z2_pack_Z_heads(G1, G2, n_head=n_head) 
        Y_hat, res_lsmr, itn = solve_lsmr_Y_lstsq(A_linop, Grad, n_head=n_head, maxiter=maxit)
        print(f"LSMR: res={res_lsmr:.4e}, {itn=}")

50x50
LSMR: res=7.1175e-01, itn=270
LSMR: res=7.1125e-01, itn=360
LSMR: res=7.0737e-01, itn=340
LSMR: res=7.0771e-01, itn=460
LSMR: res=7.1287e-01, itn=550


# Minimize $\|\mathcal{A}(Z) +\beta \mathbf{sign}(Y^0)\|_F^2$, where  $\mathcal{A}(Z) = Z_1^\top A_1 + A_2^\top Z_2$

In [7]:
beta = 0.1
n_head = 7
for (m, n) in [(10, 10), (30, 30), (40, 40)]:
    print(f"{m}x{n}")
    for _ in range(5): 
        std2 = 0.1
        std1 = 1
        rank_ratio = 0.5
        A2, A1, G1, G2 = gaussian_data_heads(m, n, n_head=n_head, std1=std1, std2=std2, 
                                             rank_ratio=rank_ratio, G_in_range=True)
        A_linop = attn_linop_from_matrices_heads(A1, A2, n_head=n_head)
        Grad = Z1_Z2_pack_Z_heads(G1, G2, n_head=n_head)
        
        Y0, res_lsmr_y, itn_y = solve_lsmr_Y_lstsq(A_linop, Grad, n_head=n_head, maxiter=maxit) 
        Z_hat, res_lsmr, itn = solve_lsmr_Z_lstsq(A_linop, beta, Y0, n_head=n_head, maxiter=maxit)
        print(f"LSMR: res={res_lsmr:.4e}, {itn=}; {res_lsmr_y=:.4e}, {itn_y=}")

10x10
LSMR: res=2.5851e-05, itn=210; res_lsmr_y=1.3494e-05, itn_y=160
LSMR: res=3.8559e-05, itn=210; res_lsmr_y=1.4018e-05, itn_y=180
LSMR: res=5.5704e-05, itn=250; res_lsmr_y=1.6631e-05, itn_y=260
LSMR: res=7.3756e-06, itn=210; res_lsmr_y=1.7648e-05, itn_y=190
LSMR: res=2.2676e-05, itn=180; res_lsmr_y=1.3461e-05, itn_y=150
30x30
LSMR: res=9.9417e-05, itn=510; res_lsmr_y=2.4625e-05, itn_y=330
LSMR: res=6.6463e-05, itn=490; res_lsmr_y=2.2664e-05, itn_y=360
LSMR: res=7.2364e-05, itn=490; res_lsmr_y=2.1890e-05, itn_y=260
LSMR: res=6.6709e-05, itn=490; res_lsmr_y=2.6328e-05, itn_y=370
LSMR: res=1.3953e-04, itn=560; res_lsmr_y=2.9557e-05, itn_y=460
40x40
LSMR: res=6.3285e-05, itn=470; res_lsmr_y=2.3335e-05, itn_y=290
LSMR: res=1.0972e-04, itn=680; res_lsmr_y=2.9221e-05, itn_y=470
LSMR: res=7.7659e-05, itn=560; res_lsmr_y=2.5834e-05, itn_y=350
LSMR: res=5.1405e-04, itn=1000; res_lsmr_y=2.8799e-05, itn_y=470
LSMR: res=2.2618e-04, itn=800; res_lsmr_y=3.3985e-05, itn_y=610


In [8]:
beta = 0.1
for (m, n) in [(10, 10), (30, 30), (40, 40)]:
    print(f"{m}x{n}")
    for _ in range(5): 
        std2 = 0.1
        std1 = 1
        rank_ratio = 0.5
        A2, A1, G1, G2 = gaussian_data_heads(m, n, n_head=n_head, std1=std1, std2=std2, 
                                                                 rank_ratio=rank_ratio, G_in_range=False)
        A_linop = attn_linop_from_matrices_heads(A1, A2, n_head=n_head)
        Grad = Z1_Z2_pack_Z_heads(G1, G2, n_head=n_head)
        
        Y0, res_lsmr, itn = solve_lsmr_Y_lstsq(A_linop, Grad, n_head=n_head, maxiter=maxit) 
        Z_hat, res_lsmr, itn = solve_lsmr_Z_lstsq(A_linop, beta, Y0, n_head=n_head, maxiter=maxit)
        print(f"LSMR: res={res_lsmr:.4e}, {itn=}") 

10x10
LSMR: res=8.6768e-05, itn=210
LSMR: res=4.1127e-05, itn=220
LSMR: res=3.5852e-05, itn=210
LSMR: res=8.8768e-05, itn=210
LSMR: res=1.2673e-04, itn=180
30x30
LSMR: res=3.9184e-04, itn=560
LSMR: res=7.0624e-04, itn=480
LSMR: res=1.9948e-03, itn=340
LSMR: res=4.6140e-04, itn=400
LSMR: res=2.3296e-03, itn=660
40x40
LSMR: res=7.9479e-04, itn=670
LSMR: res=1.3673e-03, itn=750
LSMR: res=6.6236e-04, itn=750
LSMR: res=4.7189e-04, itn=570
LSMR: res=7.2151e-04, itn=810


In [9]:
beta = 0.1
for (m, n, G_in_range) in [(500, 500, True), (500, 500, False)]:
    print(f"{m}x{n}, {G_in_range=}")
    for _ in range(5): 
        std2 = 0.1
        std1 = 1
        rank_ratio = 1
        A2, A1, G1, G2 = gaussian_data_heads(m, n, n_head=n_head, std1=std1, std2=std2, 
                                                                 rank_ratio=rank_ratio, G_in_range=False)

        A_linop = attn_linop_from_matrices_heads(A1, A2, n_head=n_head)
        Grad = Z1_Z2_pack_Z_heads(G1, G2, n_head=n_head)

        Y0, res_lsmr, itn = solve_lsmr_Y_lstsq(A_linop, Grad, n_head=n_head, maxiter=maxit)
        Z_hat, res_lsmr, itn = solve_lsmr_Z_lstsq(A_linop, beta, Y0, n_head=n_head, maxiter=maxit)
        print(f"LSMR: res={res_lsmr:.4e}, {itn=}") 

500x500, G_in_range=True
LSMR: res=1.5792e-02, itn=1000
LSMR: res=1.7599e-02, itn=1000
LSMR: res=1.6666e-02, itn=1000
LSMR: res=2.5672e-02, itn=1000
LSMR: res=1.6349e-02, itn=1000
500x500, G_in_range=False
LSMR: res=1.8495e-02, itn=1000
LSMR: res=1.6003e-02, itn=1000
LSMR: res=1.6839e-02, itn=1000
LSMR: res=1.6437e-02, itn=1000
LSMR: res=2.0513e-02, itn=1000
