In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import frame_whitening as fw
import frame_whitening.simulation_new as sim

In [None]:
def simulate_one(n, eta_g, n_batch, batch_size, frame):
    """Run one adaptive frame-whitening simulation and summarise its error.

    A random n x n covariance ``Cxx`` is drawn (random orthogonal basis,
    eigenvalues ``linspace(1, 5, n)`` plus small Gaussian jitter), a frame
    ``W`` with ``k = n * (n + 1) // 2`` vectors is built, and the gains ``g``
    are adapted over ``n_batch`` mini-batches with the update
    ``g += eta_g * (mean(z**2) - 1)``, where ``z = W.T @ y`` and
    ``y = (W diag(g) W.T)^{-1} x``.

    Parameters
    ----------
    n : int
        Input dimensionality.
    eta_g : float
        Step size of the gain update.
    n_batch : int
        Number of mini-batches (gain updates) to run.
    batch_size : int
        Number of samples requested from ``fw.sample_x`` per mini-batch.
    frame : str
        ``"GRASSMANN"`` (``fw.get_grassmannian(n, k, niter=400)``) or
        ``"RANDN"`` (Gaussian matrix passed through ``fw.normalize_frame``).

    Returns
    -------
    pandas.DataFrame
        A single-row frame with columns ``n``, ``k``, ``n_batch``,
        ``batch_size``, ``eta_g``, ``kappa0`` (condition number of ``Cxx``),
        ``error`` (array of per-batch ``||I - Cyy||_F**2 / n**2``),
        ``error_fro``, ``error_trace``, ``error_bures`` (evaluated on a fresh
        4096-sample test batch) and ``frame``.

    Raises
    ------
    ValueError
        If ``frame`` is not one of the supported frame names.

    Notes
    -----
    Randomness comes from the global ``np.random`` state; seed it beforehand
    for reproducible runs.
    ``fw.sample_x`` is assumed to return an ``(n, batch_size)`` array, since
    it is fed directly to ``np.linalg.solve`` and ``np.cov`` -- TODO confirm.
    """
    # Reject unsupported frames before doing any work; otherwise W would stay
    # unbound and the failure would surface later as an UnboundLocalError.
    supported_frames = ("GRASSMANN", "RANDN")
    if frame not in supported_frames:
        raise ValueError(
            f"unknown frame {frame!r}; expected one of {supported_frames}"
        )

    # setup
    k = n * (n + 1) // 2
    print(f"n = {n}, k = {k}")
    V, _ = np.linalg.qr(np.random.randn(n, n))
    s = np.linspace(1, 5, n) + np.random.randn(n) * 0.1
    Cxx = V @ np.diag(s) @ V.T
    Lxx = np.linalg.cholesky(Cxx)
    kappa0 = np.linalg.cond(Cxx)
    if frame == "GRASSMANN":
        # Only the frame itself is needed; the other two return values are unused.
        W, _, _ = fw.get_grassmannian(n, k, niter=400)
    else:  # frame == "RANDN"
        W = fw.normalize_frame(np.random.randn(n, k))

    def whiten(X, gains):
        # (W * gains) scales column j of W by gains[j], i.e. W @ diag(gains),
        # without materialising the k x k diagonal matrix.
        return np.linalg.solve((W * gains) @ W.T, X)

    # run whitening
    error = []
    g = np.ones(k)
    Inn = np.eye(n)
    for _ in range(n_batch):
        X = fw.sample_x(Lxx, batch_size)
        Y = whiten(X, g)
        Z = W.T @ Y
        # Push the variance of each frame projection towards 1.
        dg = np.mean(Z**2, -1) - 1.0
        g += eta_g * dg
        # Y was computed with the pre-update gains, so this is the error
        # before the current step takes effect.
        Cyy = np.cov(Y)
        err_sq = np.linalg.norm(Inn - Cyy) ** 2
        error.append(err_sq)
    err = np.array(error) / n**2

    # Final evaluation on a fresh, larger batch with the learned gains.
    X_test = fw.sample_x(Lxx, 4096)
    Y_test = whiten(X_test, g)
    Cyy = np.cov(Y_test)
    error_fro = (np.linalg.norm(Inn - Cyy) ** 2) / n**2
    error_trace = (np.trace(Inn - Cyy)) / n
    error_bures = fw.bures_dist(Inn, Cyy) ** 2
    df_sim = pd.DataFrame(
        [
            {
                "n": n,
                "k": k,
                "n_batch": n_batch,
                "batch_size": batch_size,
                "eta_g": float(eta_g),
                "kappa0": float(kappa0),
                "error": err,
                "error_fro": float(error_fro),
                "error_trace": float(error_trace),
                "error_bures": float(error_bures),
                "frame": frame,
            }
        ],
    )
    return df_sim