# Setup

In [None]:
import numpy as np
import pandas as pd
from soft_aggregation_alg import estimate_decomposition
import time
import copy

# Helper functions

In [None]:
def generate_v(p: int, r: int, anchor_prob: float, num_anch: int, rng: np.random.Generator):
    """Generate a random column-stochastic (p, r) matrix with anchor states.

    Each of the ``r`` meta-states (columns) is assigned ``num_anch`` distinct
    rows ("anchor states") drawn from a random permutation of the ``p`` rows.
    For each anchor row, every entry is scaled by ``1 - anchor_prob`` and
    ``anchor_prob`` is then added to the entry in the row's own meta-state
    column. With ``anchor_prob == 1`` an anchor row is therefore non-zero
    only in its own column. Columns are renormalised to sum to 1 at the end.

    Args:
        p: Number of states (rows of V).
        r: Number of meta-states (columns of V).
        anchor_prob: Weight moved onto the anchor's own meta-state.
        num_anch: Number of anchor states per meta-state.
        rng: NumPy random generator used for all draws.

    Returns:
        A tuple ``(anchor_states, V)`` where ``anchor_states[k]`` is the list
        of row indices anchoring meta-state ``k`` and ``V`` is the (p, r)
        matrix whose columns each sum to 1.

    Raises:
        ValueError: If ``num_anch`` is negative or ``r * num_anch > p``, i.e.
            there are not enough states to give every meta-state its anchors.
    """
    # Without this check the slices below silently come back short or empty
    # and some meta-states end up with fewer anchors than requested.
    if num_anch < 0 or r * num_anch > p:
        raise ValueError(
            f"need 0 <= r * num_anch <= p to assign anchors, "
            f"got p={p}, r={r}, num_anch={num_anch}"
        )

    V = rng.random((p, r))
    V /= V.sum(axis=0)[np.newaxis, :]
    # Assumption: there is at least one disaggregation anchor state,
    # for each meta-state
    anchor_indices = rng.permutation(p)
    anchor_states = [[] for _ in range(r)]
    for meta in range(r):
        # Consecutive, non-overlapping chunks of the permutation give each
        # meta-state its own disjoint set of anchor rows.
        rows = anchor_indices[num_anch * meta: num_anch * (meta + 1)]
        V[rows, :] *= (1 - anchor_prob)
        V[rows, meta] += anchor_prob
        anchor_states[meta] += list(rows)
    # Editing the anchor rows broke the column normalisation; restore it.
    V /= V.sum(axis=0)
    return anchor_states, V

def generate_n(p: int, r: int, anchor_prob: float, num_anch: int, rng: np.random.Generator, T: int):
    """Simulate a Markov chain with a rank-``r`` transition matrix and count transitions.

    Builds a row-stochastic (p, r) matrix ``U`` and a column-stochastic
    (p, r) matrix ``V`` (via ``generate_v``), forms ``P = U @ V.T`` and then
    runs the chain for ``T`` steps, counting every observed transition.

    Args:
        p: Number of states.
        r: Number of meta-states (inner dimension of the factorisation).
        anchor_prob: Forwarded to ``generate_v``.
        num_anch: Number of anchor states per meta-state, forwarded to
            ``generate_v``.
        rng: NumPy random generator; every random draw (including the
            initial state) comes from it, so a fixed seed reproduces the run.
        T: Number of transitions to simulate.

    Returns:
        A tuple ``(U, V, P, N, anchor_states, elapsed)`` where ``N[i, j]`` is
        the number of observed transitions from state ``i`` to state ``j``
        and ``elapsed`` is the time in seconds spent inside this function.
    """
    start_time = time.perf_counter()
    U = rng.random((p, r))
    U /= U.sum(axis=1)[:, np.newaxis]
    assert np.allclose(U.sum(axis=1), 1)
    anchor_states, V = generate_v(p, r, anchor_prob, num_anch, rng)
    assert np.allclose(V.sum(axis=0), 1)
    P = U @ V.T
    assert np.allclose(P.sum(axis=1), 1)

    # Draw T iterations from the Markov chain with transition matrix P.
    # The initial state is taken from `rng` (not the global NumPy RNG) so
    # that the whole simulation is controlled by the caller's seed.
    X = rng.integers(p)
    N = np.zeros((p, p), dtype=int)
    for _ in range(T):
        next_X = rng.choice(a=p, p=P[X, :])
        N[X, next_X] += 1
        X = next_X
    end_time = time.perf_counter()

    return U, V, P, N, anchor_states, end_time - start_time

In [None]:
def TV(a: np.ndarray, b: np.ndarray):
    """Sum over the rows of ``a`` of the L1 distance to the closest row of ``b``.

    For every row of ``a`` the absolute-difference (L1) distance to each row
    of ``b`` is computed, the smallest one is kept, and the minima are summed.
    The matching is not one-to-one: several rows of ``a`` may pick the same
    row of ``b``, and the result is not symmetric in ``a`` and ``b``.

    Args:
        a: 2-D array of shape (n_a, d).
        b: 2-D array of shape (n_b, d); must have the same number of columns
            as ``a`` and at least one row.

    Returns:
        The scalar sum of per-row minimum L1 distances.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Process one row of `a` at a time. Broadcasting everything at once
    # allocates an (n_a, n_b, d) temporary, which for two 1000 x 1000
    # float64 inputs is about 8 GB (and np.abs makes a second copy).
    nearest = [np.abs(row - b).sum(axis=1).min() for row in a]
    return np.sum(nearest)

def L2(a: np.ndarray, b: np.ndarray, typ='V'):
    """Sum over the rows of ``a`` of the squared L2 distance to the closest row of ``b``.

    For every row of ``a`` the squared Euclidean distance to each row of
    ``b`` is computed, the smallest one is kept, and the minima are summed.
    The matching is not one-to-one and the result is not symmetric in ``a``
    and ``b``.

    Args:
        a: 2-D array of shape (n_a, d).
        b: 2-D array of shape (n_b, d); must have the same number of columns
            as ``a`` and at least one row.
        typ: Unused; kept so existing callers that pass it keep working.

    Returns:
        The scalar sum of per-row minimum squared distances.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Process one row of `a` at a time. Broadcasting everything at once
    # allocates an (n_a, n_b, d) temporary, which for two 1000 x 1000
    # float64 inputs is about 8 GB (and the squaring makes a second copy).
    nearest = [((row - b) ** 2).sum(axis=1).min() for row in a]
    return np.sum(nearest)

# Experiment

In [None]:
# One seeded generator is shared by every simulated run below.
SEED = 623
rng = np.random.default_rng(SEED)
r = 6

# Per-run results, one list per output column. The order of the keys here
# is the column order of the final DataFrame.
res = {
    column: []
    for column in (
        'n', 'p', 'r', 'anchor_prob', 'num_anchors',
        'TV_err_V', 'TV_err_U', 'TV_err_P_hat', 'TV_err_P',
        'L2_err_V', 'L2_err_U', 'L2_err_P_hat', 'L2_err_P',
        'sim_time', 'decomp_time', 'trial',
    )
}

# Maps each error function to the prefix of the columns it fills.
err_func_to_name = {TV: 'TV_err', L2: 'L2_err'}

t5, t6, t7 = (10**exponent for exponent in (5, 6, 7))

# Sweep chain length and anchors-per-meta-state, 5 trials per setting.
for T in [8*t5, t6, 2*t6, 5*t6, t7, 3*t7]:
    print(f"Starting T={T}")
    for anch in [1, 2, 5, 10]:
        print(f"\tStarting num_anch={anch}")
        for p in [1000]:
            for anchor_prob in [1]:
                for trial in range(5):
                    print(f"\t\tStarting trial={trial}")
                    U, V, P, N, anchor_states, sim_time = generate_n(
                        p, r, anchor_prob, anch, rng, T
                    )

                    # Time only the decomposition itself.
                    decomp_start = time.time()
                    ret = estimate_decomposition(N, r)
                    decomp_end = time.time()

                    # Settings that identify this run.
                    for key, value in (
                        ('n', T),
                        ('p', p),
                        ('r', r),
                        ('anchor_prob', anchor_prob),
                        ('num_anchors', anch),
                        ('trial', trial),
                    ):
                        res[key].append(value)

                    # Errors of the estimated factors, of the returned P_hat,
                    # and of the product of the estimated factors.
                    for err_func, label in err_func_to_name.items():
                        res[f'{label}_V'].append(err_func(V.T, ret['V_hat'].T))
                        res[f'{label}_U'].append(err_func(U, ret['U_hat']))
                        res[f'{label}_P_hat'].append(err_func(P, ret['P_hat']))
                        res[f'{label}_P'].append(err_func(P, ret['U_hat'] @ ret['V_hat'].T))

                    res['sim_time'].append(sim_time)
                    res['decomp_time'].append(decomp_end - decomp_start)

res = pd.DataFrame(res)