In [None]:
from gptopt.optim.pdhg import *
from gptopt.optim.fast_pdhg import *
from gptopt.optim.least_squares import *
from utils_pdhg import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch

from gptopt.utils import set_seed
set_seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
maxit = 1000

# Vectorization of $\mathcal{A}Z = Z_1^TA_1 + A_2^TZ_2$

In [4]:
for (m, n) in [(30, 60), (60, 30), (60, 60)]:
    # print(f"{m}x{n}")
    for _ in range(5): 
        A1 = torch.randn((m, n), device=device)
        A2 = torch.randn((m, n), device=device)
        Z = torch.randn((2 * m, n), device=A2.device, dtype=A2.dtype)
        Z1, Z2 = Z[:m, :], Z[m:, :]

        K = matcal_A_to_kron_Kron(A1, A2) 
        Kz = K @ torch.cat([Z1.reshape(-1), Z2.T.reshape(-1)], dim=0)
        vecAZ = mathcal_A_linop(A1=A1, A2=A2, Z=Z).T.reshape(-1)
        assert torch.allclose(Kz, vecAZ, atol=1e-5), print(torch.max(torch.abs(Kz - vecAZ)).item())

        Y = torch.randn((n, n), device=A2.device, dtype=A2.dtype)
        KTy = K.T @ Y.T.reshape(-1)
        AadjY = mathcal_A_adj_linop(A1=A1, A2=A2, Y=Y)
        vecAadjY = torch.cat([AadjY[:m].reshape(-1), AadjY[m:].T.reshape(-1)], dim=0)
        assert torch.allclose(KTy, vecAadjY, atol=1e-5), print(torch.max(torch.abs(KTy - vecAadjY)).item())

print("PASSED")

PASSED


# Ruiz equilibration

In [None]:
for (m, n) in [(30, 60), (60, 30), (60, 60)]:
    print(f"{m}x{n}")
    for _ in range(5): 
        std2 = 0.1
        std1 = 0.01
        rank_ratio = 1
        A2, A1, G1, G2, A2_np, A1_np, G1_np, G2_np, lamb_max = gaussian_data(m, n, std1=std1, std2=std2, 
                                                                 rank_ratio=rank_ratio, G_in_range=True)

        K = matcal_A_to_kron_Kron(A1, A2)
        print(f"maxcol={K.abs().max(dim=0).values.max().item():.4e}, maxrow={K.abs().max(dim=1).values.max().item():.4e}")
        R, Gamma1, Gamma2= ruiz_equilibration(A1=A1, A2=A2, num_iters=10)
        tildeK = R.T.reshape(-1)[:, None] * K
        tildeK[:, :m*n] *= Gamma1.T.reshape(-1)[None, :]
        tildeK[:, m*n:] *= Gamma2.T.reshape(-1)[None, :]

        print("maxcol", tildeK.abs().max(dim=0).values.max().item(),
            "maxrow", tildeK.abs().max(dim=1).values.max().item())
 

30x60
maxcol=3.9490e-02, maxrow=3.9490e-02
maxcol 1.6630417108535767 maxrow 1.6630417108535767
maxcol=3.5600e-02, maxrow=3.5600e-02
maxcol 1.7202492952346802 maxrow 1.7202492952346802
maxcol=3.7595e-02, maxrow=3.7595e-02
maxcol 1.7226592302322388 maxrow 1.7226592302322388
maxcol=4.6105e-02, maxrow=4.6105e-02
maxcol 1.9020411968231201 maxrow 1.9020411968231201
maxcol=3.6167e-02, maxrow=3.6167e-02
maxcol 1.63478684425354 maxrow 1.63478684425354
60x30
maxcol=3.8961e-02, maxrow=3.8961e-02
maxcol 2.3841490745544434 maxrow 2.3841490745544434
maxcol=4.4216e-02, maxrow=4.4216e-02
maxcol 2.2840683460235596 maxrow 2.2840683460235596
maxcol=3.3483e-02, maxrow=3.3483e-02
maxcol 1.9658482074737549 maxrow 1.9658482074737549
maxcol=3.3683e-02, maxrow=3.3683e-02
maxcol 2.0430147647857666 maxrow 2.0430147647857666
maxcol=3.6487e-02, maxrow=3.6487e-02
maxcol 2.3817639350891113 maxrow 2.3817639350891113
60x60
maxcol=4.2152e-02, maxrow=4.2152e-02
maxcol 2.0316972732543945 maxrow 2.0316972732543945
maxcol=