In [1]:
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

import copy, time
import random
import pickle
import scipy

import timeit

import mlrfit as mf
import mfmodel as mfm
import numba as nb

from scipy.sparse import coo_matrix
from scipy.linalg import block_diag, pinvh

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed NumPy's legacy global RNG and Python's `random` module with one shared value,
# so the permutation, model generation and test vectors below are reproducible.
SEED = 1001
for seed_fn in (np.random.seed, random.seed):
    seed_fn(SEED)

# The true model is an MLR factor model with SNR = 4

Let $\Sigma = FF^T + D$ be an MLR covariance matrix. We generate samples as
$$
y = Fz + e, \qquad z \sim N(0, I), \qquad e \sim N(0, D).
$$

In [3]:
mtype = "small_mlr_hier"
n = 100000              # dimension: Sigma is n x n and the test vectors below have n rows
signal_to_noise = 4     # target signal-variance / noise-variance ratio for generate_mfmodel

# NOTE(review): not used in the cells shown here; presumably consumed later in the notebook.
nsamples = 20

# Per-level ranks, coarsest level first (one entry per level of the hierarchical partition).
# Smaller alternative for quick runs: np.array([5, 4, 3, 2, 1])
ranks = np.array([30, 20, 10, 5, 1])
# Derive the number of levels from `ranks` instead of hard-coding it, so the two cannot drift apart.
L = len(ranks)
rank = ranks.sum()

In [4]:
# Hierarchical partition used for both rows and columns: a single random permutation of the
# n indices plus, per level, an array of contiguous-group boundaries running from 0 to n.
pi_rows = np.random.permutation(n)
level_bounds = [np.linspace(0, n, ngroups, endpoint=True, dtype=int)
                for ngroups in (2, 5, 9, 17, n + 1)]
# Removing an interior boundary merges its two neighbouring groups, making the middle levels
# uneven: level 1 goes from 4 to 3 groups and level 2 from 8 to 7 groups.
level_bounds[1] = np.delete(level_bounds[1], -2)
level_bounds[2] = np.delete(level_bounds[2], -4)
# Rows and columns intentionally share the same permutation and the same boundary list object.
hpart = {'rows': {'pi': pi_rows, 'lk': level_bounds},
         'cols': {'pi': pi_rows, 'lk': level_bounds}}
part_sizes = mfm.print_hpart_numgroups(hpart)
mfm.valid_hpart(hpart)

level=0, num_groups=1, mean_size=100000.0
level=1, num_groups=3, mean_size=33333.3
level=2, num_groups=7, mean_size=14285.7
level=3, num_groups=16, mean_size=6250.0
level=4, num_groups=100000, mean_size=1.0


In [5]:
# Partition for the factor part F: same permutation as `hpart`, but without the finest level
# (the singleton groups). NOTE(review): presumably that level is represented by the diagonal D.
F_hpart = {"pi": hpart['rows']["pi"], "lk": hpart['rows']["lk"][:-1]}
true_mfm = mfm.generate_mfmodel(mfm.MFModel(), n, F_hpart, ranks, signal_to_noise, debug=False)
# Record the inverse permutation computed by the generated model alongside the partition.
F_hpart["pi_inv"] = true_mfm.pi_inv

signal_var=81.15506713197232, noise_var=16.242432778950498
SNR=3.9964847160793444, signal_to_noise=4


In [6]:
# Size summary: dimension n, factor count of the generated model, number of levels L, and the
# sum of per-level ranks. NOTE(review): the recorded 240 factors equal 30*1 + 20*3 + 10*7 + 5*16
# (rank x number of groups per level, singleton level excluded) — confirm against mfmodel.
size_summary = (n, true_mfm.num_factors(), L, ranks.sum())
size_summary

(100000, 240, 5, 66)

In [7]:
# Check that `solve` inverts `matvec`: for a random right-hand side v, applying `matvec` to the
# computed solution should reproduce v, so the relative residual ||matvec(x) - v|| / ||v|| is tiny.
v = np.random.randn(n, 1)
hat_x = true_mfm.solve(v, eps=1e-12, max_iter=20, printing=False)
reldiff = np.linalg.norm(true_mfm.matvec(hat_x) - v) / np.linalg.norm(v)
print(f"solve {reldiff=}")
# Fail loudly under Restart & Run All if the solver regresses (the recorded run reached ~1.6e-12,
# so this tolerance leaves several orders of magnitude of headroom).
SOLVE_RTOL = 1e-8
assert reldiff < SOLVE_RTOL, f"solve residual too large: {reldiff=:.3e}"

solve reldiff=1.5558171315505062e-12


In [8]:
# Benchmark one full solve at the same tolerance as the accuracy check above. Each call takes
# seconds, so in the recorded output timeit settled on 1 loop x 7 runs (~7 solves for this cell).
%timeit true_mfm.solve(v, eps=1e-12, max_iter=20, printing=False)

2.96 s ± 75.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
