In [1]:
from gptopt.optim.attn_utils import *
from gptopt.optim.fast_pdhg import *
from gptopt.optim.least_squares import *
from gptopt.optim.linop import *
from utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch

from gptopt.utils import set_seed
set_seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
maxit = 1000

In [4]:
# from torchmin import TorchLinearOperator

# Vectorization of $\mathcal{A}Z = Z_1^TA_1 + A_2^TZ_2$
$$
K = [(A_1^\top \otimes I_n)P, I_n \otimes A_2^\top],
\qquad 
K^\top = \begin{bmatrix}
P^\top(A_1 \otimes I_n) \\
I_n \otimes A_2
\end{bmatrix},
$$
where $P$ is a permutation matrix s.t $P\text{vec}(Z^\top) = \text{vec}(Z)$

In [5]:
for (m, n) in [(30, 60), (60, 30), (60, 60)]:
    # print(f"{m}x{n}")
    for _ in range(5): 
        A1 = torch.randn((m, n), device=device)
        A2 = torch.randn((m, n), device=device)
        Z = torch.randn((2 * m, n), device=A2.device, dtype=A2.dtype)
        Z1, Z2 = Z[:m, :], Z[m:, :]

        A_linop = attn_linop_from_matrices(A1, A2)

        K = matcal_A_to_kron_Kron(A1, A2) 
        Kz = K @ torch.cat([Z1.reshape(-1), Z2.T.reshape(-1)], dim=0)
        vecAZ = mathcal_A_linop(A1=A1, A2=A2, Z=Z).T.reshape(-1)
        vecAZ2 = A_linop.matvec(Z).T.reshape(-1)
        assert torch.allclose(Kz, vecAZ2, atol=1e-5)
        assert torch.allclose(Kz, vecAZ, atol=1e-5), print(torch.max(torch.abs(Kz - vecAZ)).item())

        Y = torch.randn((n, n), device=A2.device, dtype=A2.dtype)
        KTy = K.T @ Y.T.reshape(-1)
        AadjY = mathcal_A_adj_linop(A1=A1, A2=A2, Y=Y)
        AadjY2 = A_linop.rmatvec(Y)
        vecAadjY2 = torch.cat([AadjY2[:m].reshape(-1), AadjY2[m:].T.reshape(-1)], dim=0)
        vecAadjY = torch.cat([AadjY[:m].reshape(-1), AadjY[m:].T.reshape(-1)], dim=0)
        assert torch.allclose(KTy, vecAadjY2, atol=1e-5)
        assert torch.allclose(KTy, vecAadjY, atol=1e-5), print(torch.max(torch.abs(KTy - vecAadjY)).item())

        Az = A_linop.matvec(Z)
        Aty = A_linop.rmatvec(Y) 
        tr1 = (Az * Y).sum()
        tr2 = (Z * Aty).sum()
        assert torch.allclose(tr1, tr2)

print("PASSED")

PASSED


# Ruiz equilibration

In [6]:
for (m, n) in [(30, 60), (60, 30), (60, 60)]:
    print(f"{m}x{n}")
    for _ in range(5): 
        std2 = 0.1
        std1 = 0.01
        rank_ratio = 1
        A2, A1, G1, G2, A2_np, A1_np, G1_np, G2_np, lamb_max = gaussian_data(m, n, std1=std1, std2=std2, 
                                                                 rank_ratio=rank_ratio, G_in_range=True)

        K = matcal_A_to_kron_Kron(A1, A2)
        print(f"maxcol={K.abs().max(dim=0).values.max().item():.4e}, maxrow={K.abs().max(dim=1).values.max().item():.4e}")
        R, Gamma1, Gamma2= ruiz_equilibration(A1=A1, A2=A2, num_iters=10)
        tildeK = R.T.reshape(-1)[:, None] * K
        tildeK[:, :m*n] *= Gamma1.T.reshape(-1)[None, :]
        tildeK[:, m*n:] *= Gamma2.T.reshape(-1)[None, :]

        print("maxcol", tildeK.abs().max(dim=0).values.max().item(),
            "maxrow", tildeK.abs().max(dim=1).values.max().item())
 

30x60
maxcol=3.9262e-02, maxrow=3.9262e-02
maxcol 1.7660601139068604 maxrow 1.7660601139068604
maxcol=3.9224e-02, maxrow=3.9224e-02
maxcol 1.7172335386276245 maxrow 1.7172335386276245
maxcol=4.4791e-02, maxrow=4.4791e-02
maxcol 1.9314849376678467 maxrow 1.9314849376678467
maxcol=3.6551e-02, maxrow=3.6551e-02
maxcol 1.7974989414215088 maxrow 1.7974989414215088
maxcol=3.5328e-02, maxrow=3.5328e-02
maxcol 1.6759533882141113 maxrow 1.6759533882141113
60x30
maxcol=3.4277e-02, maxrow=3.4277e-02
maxcol 2.0323734283447266 maxrow 2.0323734283447266
maxcol=4.2954e-02, maxrow=4.2954e-02
maxcol 2.184307813644409 maxrow 2.184307813644409
maxcol=3.4019e-02, maxrow=3.4019e-02
maxcol 1.8383127450942993 maxrow 1.8383127450942993
maxcol=3.7454e-02, maxrow=3.7454e-02
maxcol 1.9406521320343018 maxrow 1.9406521320343018
maxcol=4.1577e-02, maxrow=4.1577e-02
maxcol 2.243189573287964 maxrow 2.243189573287964
60x60
maxcol=3.7261e-02, maxrow=3.7261e-02
maxcol 1.78279709815979 maxrow 1.78279709815979
maxcol=3.60

# Basic tests

In [7]:
def cvxpy_prox_l1(x0, rho, R=None):
    x = cp.Variable(x0.shape)
    obj = rho * cp.sum(cp.abs(x)) 
    if R is not None:
        W = 1 / R**0.5
        obj += (1/2) * cp.sum_squares(cp.multiply(W, x - x0))
    else:
        obj += (1/2) * cp.sum_squares(x - x0)
    objective = cp.Minimize(obj)
    prob = cp.Problem(objective, [])
    prob.solve(solver=cp.CLARABEL, max_iter=10000, tol_gap_abs=1e-12, tol_gap_rel=1e-12)
    assert prob.status in ["optimal"], print(prob.status)
    return x.value

In [8]:
for _ in range(50):
    n  = torch.randint(25, 80, ()).item()
    m  = torch.randint(25, 80, ()).item()
    R = torch.rand(m, n, dtype=torch.float64) + 0.1
    X0 = torch.randn(m, n, dtype=torch.float64) * 100
    rho = torch.rand(1).item() * 0.5 + 0.1
    X1 = prox_l1(X0, rho, R=R).cpu().numpy()
    X2 = cvxpy_prox_l1(X0.cpu().numpy(), rho, R=R.cpu().numpy()) 
    assert np.allclose(X1, X2, rtol=1e-4, atol=1e-4), "prox_l1 mismatch!"
print("PASSED")

PASSED


## Diagonal scaling for linear operator to shrink operator number

In [9]:
def A_op(Z1, Z2, A, B):
    # Y = Z1^T B + A^T Z2
    return Z1.T @ B + A.T @ Z2

def A_adj(Y, A, B):
    # (B Y^T, A Y)
    return (B @ Y.T, A @ Y)

def spec_norm_implicit(A, B, R, g1, g2, it=120):
    # power iteration for preconditioned operator G^{1/2} A^* R A G^{1/2} 
    st = R.sqrt()        # (n,n)
    ss1 = g1.sqrt()      # (p1,1)
    ss2 = g2.sqrt()      # (p2,1)
    p1, n = B.shape
    p2, _ = A.shape
    Z1 = torch.randn(p1, n); Z2 = torch.randn(p2, n)
    Z1 /= Z1.norm(); Z2 /= Z2.norm() 
    for _ in range(it): 
        Y  = st * A_op(ss1 * Z1, ss2 * Z2, A, B)   # R^{1/2} A G^{1/2} Z
        U  = st * Y                    # R A G^{1/2} Z
        G1, G2 = A_adj(U, A, B)        # A^* R A G^{1/2} Z
        Z1, Z2 = ss1 * G1, ss2 * G2     # G^{1/2} A^* R A G^{1/2} Z
        s = (Z1.norm()**2 + Z2.norm()**2).sqrt()
        Z1 /= s; Z2 /= s
    return (st * A_op(ss1*Z1, ss2*Z2, A, B)).norm().item()


# --- slow, **correct** K build by explicit indexing (column-major y, mixed z) ---
def build_K_slow(A, B):
    """
    y = vec_col(Y) with Y = Z1^T B + A^T Z2
    z = [ vec_col(Z1^T) ; vec_col(Z2) ] = [ Z1.reshape(-1) ; Z2.T.reshape(-1) ]
    K shape: (n^2) x (n*p1 + n*p2)
    """
    p1, n = B.shape
    p2, _ = A.shape
    K = torch.zeros(n*n, n*(p1+p2), dtype=A.dtype, device=A.device)

    # helper: row index for y(i,k) in vec_col(Y) = i + k*n
    def ridx(i,k): return i + k*n

    # left block: Z1 part (coeffs B[j,k] for var Z1[j,i])
    # z1 index: vec_col(Z1^T) == row-major(Z1): idx1 = j*n + i
    for i in range(n):       # row in Y
        for k in range(n):   # col in Y
            r = ridx(i,k)
            for j in range(p1):
                c = j*n + i
                K[r, c] = B[j, k]

    # right block: Z2 part (coeffs A[j,i] for var Z2[j,k])
    # z2 index: vec_col(Z2) == column-major: idx2 = k*p2 + j
    base = n*p1
    for i in range(n):
        for k in range(n):
            r = ridx(i,k)
            for j in range(p2):
                c = base + k*p2 + j
                K[r, c] = A[j, i]

    return K



# --- full test on small sizes (slow but reliable) ---
def test_scaling_with_explicit_K(num_cases=30, eta=0.99, seed=0, verbose=True):
    torch.manual_seed(seed)
    fails = 0
    for t in range(num_cases):
        n  = torch.randint(3, 10, ()).item()    # keep small (explicit K is O(n^3))
        p1 = torch.randint(2, 10, ()).item()
        p2 = torch.randint(2, 10, ()).item()
        A = torch.randn(p2, n)
        B = torch.randn(p1, n)

        # scaling
        Rm, g1, g2 = pdhg_diagonal_scaling(A, B, eta=eta, debug=True)

        # explicit K, R, G
        K = build_K_slow(A, B)                        # (n^2, n(p1+p2))
        I_n = torch.eye(n, dtype=A.dtype, device=A.device)
        K2 = torch.cat([torch.kron(B.T.contiguous(), I_n), torch.kron(I_n, A.T.contiguous())], dim=1)
        assert torch.allclose(K, K2), "K build mismatch!"

        R_diag = Rm.T.reshape(-1)                     # vec_col(R) 
        R_half = torch.diag(torch.sqrt(R_diag))

        # G diag for z = [ vec_col(Z1^T) ; vec_col(Z2) ]
        # G1 = diag(s1) x I_n  (Z1 row j repeated across n columns)
        # G2 = I_n x diag(s2)  (Z2 column-major)
        s1 = g1.squeeze(-1); s2 = g2.squeeze(-1) 
        G_half = torch.block_diag(torch.kron(torch.diag(torch.sqrt(s1)), torch.eye(n)),
                                  torch.kron(torch.eye(n), torch.diag(torch.sqrt(s2))))

        # operator wiring check vs direct computation
        Z1 = torch.randn(p1, n); Z2 = torch.randn(p2, n)
        Y  = Z1.T @ B + A.T @ Z2
        y  = Y.T.reshape(-1)                            # vec_col(Y)
        z  = torch.cat([Z1.reshape(-1), Z2.T.reshape(-1)])
        K_err = (y - K @ z).abs().max().item()

        # bounds and spectral norm 
        smax = torch.linalg.svdvals(R_half @ K @ G_half)[0].item()

        # implicit (should match SVD)
        smax_impl = spec_norm_implicit(A, B, Rm, g1, g2, it=5000)

        ok = (K_err <= 1e-6 + 1e-8 * torch.linalg.vector_norm(y).item()) and \
            (abs(smax - smax_impl) <= 1e-7 + 1e-6 * max(1.0, smax, smax_impl)) and \
                (smax <= eta*(1+1e-8))

        if verbose:
            print(f"[{t:02d}] n={n} p1={p1} p2={p2}  "
                  f"K_err={K_err:.2e}  "
                  f"||R^1/2 K G^1/2||_2={smax:.6f}  (impl={smax_impl:.6f})  -> {'OK' if ok else 'FAIL'}")
        fails += 0 if ok else 1
    print(f"\nSummary: {num_cases - fails} / {num_cases} passed (eta={eta}).")
    return fails

# run
_ = test_scaling_with_explicit_K(num_cases=20, eta=0.99, seed=1234, verbose=True)

Diagonal PDHG scaling computed.
1.0766e-01 +- 2.0185e-02, 1.5832e-01 +- 1.9312e-02, 1.6263e-01 +- 6.0898e-02
A.shape=torch.Size([8, 9]), rank_tol=8.0000e+00, sigma_max=4.8292e+00, fro_norm=7.9894e+00
A.shape=torch.Size([5, 9]), rank_tol=5.0000e+00, sigma_max=4.0618e+00, fro_norm=6.1722e+00
[00] n=9 p1=5 p2=8  K_err=9.54e-07  ||R^1/2 K G^1/2||_2=0.704098  (impl=0.704098)  -> OK
Diagonal PDHG scaling computed.
8.6957e-02 +- 2.0812e-02, 1.7292e-01 +- 4.5533e-02, 1.6880e-01 +- 7.4941e-02
A.shape=torch.Size([8, 8]), rank_tol=8.0000e+00, sigma_max=5.1265e+00, fro_norm=8.4192e+00
A.shape=torch.Size([7, 8]), rank_tol=7.0000e+00, sigma_max=4.2295e+00, fro_norm=7.1814e+00
[01] n=8 p1=7 p2=8  K_err=9.54e-07  ||R^1/2 K G^1/2||_2=0.664996  (impl=0.664996)  -> OK
Diagonal PDHG scaling computed.
1.5254e-01 +- 3.8808e-02, 1.5136e-01 +- 2.4287e-02, 1.4593e-01 +- 4.1929e-02
A.shape=torch.Size([4, 8]), rank_tol=4.0000e+00, sigma_max=4.5893e+00, fro_norm=6.5231e+00
A.shape=torch.Size([4, 8]), rank_tol=4.0

In [10]:
def cvxpy_proj_subgrad_l1(AZ_np, Y_np):
    m, n = AZ_np.shape
    S = cp.Variable((m, n))
    obj = cp.sum_squares(S - AZ_np)
    constraints = []
    for i in range(m):
        for j in range(n):
            if Y_np[i, j] > 0:
                constraints.append(S[i, j] == 1)
            elif Y_np[i, j] < 0:
                constraints.append(S[i, j] == -1)
            else:
                constraints.append(S[i, j] <= 1)
                constraints.append(S[i, j] >= -1)
    objective = cp.Minimize(obj)
    prob = cp.Problem(objective, constraints)
    prob.solve(solver=cp.CLARABEL, max_iter=10000, tol_gap_abs=1e-12, tol_gap_rel=1e-12)
    assert prob.status in ["optimal"], print(prob.status)
    return obj.value ** 0.5


In [11]:
for _ in range(20):
    Y = torch.randn(50, 50)
    AZ = torch.randn(50, 50)
    Y[Y.abs() < 0.1] = 0.0
    val1, val1_rel = proj_subgrad_l1(AZ, Y)
    val2 = cvxpy_proj_subgrad_l1(AZ.cpu().numpy(), Y.cpu().numpy())
    assert np.allclose(val1, val2, rtol=1e-4, atol=1e-4), print(val1, val2)
print("PASSED")

PASSED
