# Imports and settings

In [None]:
import adaptive_latents as al
from adaptive_latents import mmICA, sjPCA, proSVD, NumpyTimedDataSource, proPLS
import matplotlib.pyplot as plt
import numpy as np
from picard import permute
import sklearn.cross_decomposition
import sklearn.decomposition
import matlab.engine
import scipy.stats as stats
from adaptive_latents.utils import column_space_distance

# Single seeded generator shared by every simulation cell below, for reproducible data.
rng = np.random.default_rng(seed=1)

# Simulated data

In [None]:
# Global matplotlib styling applied to every figure in this notebook.
plt.rc('font', family='sans')
plt.rc('lines', linewidth=2)

# Simulation timing.
T = 60  # total duration, seconds
dt = 0.03  # sampling interval, seconds

# Experiment switches.
iterations_to_run = 3  # independent simulated runs per method
calculate_intra_run_errors = False  # if True, recompute offline baselines at every timepoint (slow)

# Shared plot appearance: one color per method, plus common axis limits and ticks.
colors = ["#F94B00", "#2FA194", "#EC3C8E"]
common_ylim = [.005, 4]
common_xticks = [0, 30, 60]


In [None]:
# Quantile levels matching +/-1 standard deviation of a normal (~15.9% and ~84.1%),
# used for every shaded band below.
_LOWER_Q = stats.norm(0, 1).cdf(-1)
_UPPER_Q = stats.norm(0, 1).cdf(1)


def plot(offline_errors, trajectories, t, i, name, offline_error_t=None):
    """Plot online error trajectories against an offline reference and save an SVG.

    The offline reference is drawn in black:
    - a horizontal median line plus a +/-1 sigma quantile band, when
      ``offline_errors`` holds one scalar per run;
    - a median curve plus a band, when it holds one sequence per run.

    The online ``trajectories`` are reduced across runs (axis 0) to a median
    curve with a +/-1 sigma quantile band, drawn in ``colors[i]``.

    Parameters
    ----------
    offline_errors : sequence
        One scalar error per run, or one per-timepoint error sequence per run.
        NaNs are ignored in the per-timepoint case.
    trajectories : array-like
        One online error trajectory per run, each aligned with ``t``.
    t : array-like
        Times of the online trajectories.
    i : int
        Index into the module-level ``colors`` list.
    name : str
        Used in the title and as the output filename ``<name>.svg`` under
        ``al.CONFIG.plot_save_path``.
    offline_error_t : array-like, optional
        Times of the per-timepoint offline errors. Defaults to ``t``.

    Returns
    -------
    tuple
        ``(fig, ax)`` of the created figure.
    """
    fig, ax = plt.subplots()
    fig: plt.Figure

    if not hasattr(offline_errors[0], "__len__"):
        # One scalar per run: draw a flat reference spanning the online time range.
        lower = np.quantile(offline_errors, _LOWER_Q)
        upper = np.quantile(offline_errors, _UPPER_Q)
        ax.axhline(np.quantile(offline_errors, .5), color='k', alpha=.75)
        ax.fill_between([t[0], t[-1]], upper, lower, color='k', edgecolor=None, alpha=.25)
    else:
        # One sequence per run: use NaN-aware statistics, since some timepoints may be undefined.
        if offline_error_t is None:
            offline_error_t = t
        lower = np.nanquantile(offline_errors, _LOWER_Q, axis=0)
        upper = np.nanquantile(offline_errors, _UPPER_Q, axis=0)
        ax.plot(offline_error_t, np.nanmedian(offline_errors, axis=0), color='k', alpha=.75)
        ax.fill_between(offline_error_t, upper, lower, color='k', edgecolor=None, alpha=.25)

    # Freeze the limits set by the reference so the online artists do not rescale the axes.
    ax.autoscale(False)

    lower = np.quantile(trajectories, _LOWER_Q, axis=0)
    upper = np.quantile(trajectories, _UPPER_Q, axis=0)
    ax.plot(t, np.median(trajectories, axis=0), color=colors[i])
    ax.fill_between(t, upper, lower, color=colors[i], edgecolor=None, alpha=.5)

    ax.set_xlabel('time (s)')
    ax.set_ylabel('error')

    ax.semilogy()
    ax.set_yticks([1, .1, .01])
    ax.minorticks_off()
    ax.set_ylim(common_ylim)
    ax.set_xticks(common_xticks)

    ax.set_title(f"{name} error")
    fig.savefig(al.CONFIG.plot_save_path / f"{name}.svg")

    return fig, ax

## sjPCA

In [None]:
# sjPCA: streaming estimate vs. MATLAB's offline jPCA on simulated rotational data.
eng = matlab.engine.start_matlab()
trajectories = []
offline_errors = []

# Options for MATLAB jPCA: no mean-subtraction or normalization; suppress its plots and text output.
params = dict(
    meanSubtract=False,
    normalize=False,
    suppressBWrosettes=True,
    suppressHistograms=True,
    suppressText=True,
)

try:
    for _ in range(iterations_to_run):
        X, _, true_variables = al.jpca.generate_circle_embedded_in_high_d(rng, m=int(T/dt), n=6, stddev=1)

        # Offline reference: MATLAB jPCA on the full dataset. The error is the summed absolute
        # principal angles between its first two jPCs and the true embedding `C`.
        _proj, summary = eng.jPCA({'A': X}, [], params, nargout=2)
        offline_U = np.array(summary['jPCs_highD'])
        offline_error = np.abs(al.utils.principle_angles(offline_U[:, :2], true_variables['C'])).sum()
        offline_errors.append(offline_error)

        # Streaming estimate, fed one sample every `dt` seconds.
        jp = sjPCA(log_level=2)
        jp.offline_run_on(NumpyTimedDataSource(X, timepoints=np.arange(X.shape[0]) * dt))
        distances, t = jp.get_distance_from_subspace_over_time(true_variables['C'])
        trajectories.append(distances[:, 0])
finally:
    # Shut down the MATLAB process even if a run fails, instead of leaking it.
    eng.quit()

fig, ax = plot(offline_errors, trajectories, t, 0, 'sjPCA')

## mmICA

In [None]:
# mmICA: streaming estimate vs. sklearn's offline FastICA on Laplace sources.
# No mixing matrix is applied, so the ideal unmixing matrix is the identity. picard's
# `permute` reorders and rescales W toward the identity before the comparison.
n = 6  # number of sources; each block is n x n
m = int((T/dt)//n)  # number of blocks

trajectories = []
offline_errors = []
for _ in range(iterations_to_run):
    X = rng.laplace(size=(m, n, n))

    # Streaming estimate; one block arrives every n*dt seconds.
    ica = mmICA(alpha=.5, maxiter_cg=20, tol=1e-20, log_level=2)
    input_data = NumpyTimedDataSource(X, timepoints=np.arange(m) * n * dt)
    ica.offline_run_on(input_data)
    logged = list(zip(ica.log['W'], ica.log['t']))
    t = np.squeeze([log_t for _, log_t in logged])
    trajectories.append(np.squeeze([
        np.linalg.norm(permute(log_W) - np.eye(log_W.shape[0])) for log_W, _ in logged
    ]))

    # Offline reference: FastICA on all data at once. Transposing each block before
    # flattening turns the blocks' columns into rows of the (-1, n) design matrix.
    W = sklearn.decomposition.FastICA(max_iter=5000).fit(X.transpose([0, 2, 1]).reshape(-1, n)).components_
    offline_error = np.linalg.norm(permute(W) - np.eye(W.shape[0]))
    offline_errors.append(offline_error)

fig, ax = plot(offline_errors, trajectories, t, 1, 'mmICA')


## proSVD

In [None]:
# proSVD: streaming estimate vs. a batch SVD of the same data.


def _top_right_singular_vectors(data, k):
    """Return the top-``k`` right singular vectors of ``data`` as columns."""
    # np.linalg.svd returns singular values in descending order, so the first k rows of Vt
    # are the top k. The reduced SVD avoids building a (n_samples, n_samples) U matrix.
    _, _, Vt = np.linalg.svd(data, full_matrices=False)
    return Vt[:k].T


trajectories = []
offline_errors = []
# Use the shared run count, like every other method cell (was hard-coded to 10).
for _ in range(iterations_to_run):
    X, _, true_variables = al.jpca.generate_circle_embedded_in_high_d(rng, m=int(T/dt), n=8, stddev=1)
    pro = proSVD(k=4, log_level=2)
    X_with_time = NumpyTimedDataSource(X, timepoints=np.arange(X.shape[0]) * dt)
    pro.offline_run_on(X_with_time)
    Q_error, t = pro.get_distance_from_subspace_over_time(true_variables['C'])
    trajectories.append(Q_error)

    if calculate_intra_run_errors:
        # At timepoint i, use samples 0..i inclusive, which is what the streaming estimate has
        # seen. With fewer than k samples there is no k-dimensional subspace yet, so record NaN.
        offline_errors.append([
            column_space_distance(_top_right_singular_vectors(X[:i + 1], pro.k), true_variables['C'])
            if i + 1 >= pro.k else np.nan
            for i in range(X.shape[0])
        ])
    else:
        offline_errors.append(
            column_space_distance(_top_right_singular_vectors(X, pro.k), true_variables['C'])
        )

fig, ax = plot(offline_errors, trajectories, t, 2, 'proSVD', X_with_time.t)

## proPLS


In [None]:
# proPLS: streaming estimate vs. sklearn's offline PLSRegression on X/Y data that share
# a `common_d`-dimensional latent signal.
high_d = (10, 9)  # (X dimension, Y dimension)
n_points = int(T/dt)
common_d = 2
snr = 100

trajectories = []
offline_errors = []
for _ in range(iterations_to_run):
    X = rng.normal(size=(n_points, high_d[0]))
    Y = rng.normal(size=(n_points, high_d[1]))

    # Overwrite the first `common_d` columns of X and Y with the same latent signal plus
    # independent noise, scaled by 1/sqrt(1 + snr**2) so each column has unit variance.
    common = rng.normal(size=(n_points, common_d))
    Y[:, :common_d] = (snr * common + rng.normal(size=(n_points, common_d))) / np.sqrt(1 + snr**2)
    X[:, :common_d] = (snr * common + rng.normal(size=(n_points, common_d))) / np.sqrt(1 + snr**2)

    # Randomly rotate each space, drawing from the seeded `rng` for reproducibility.
    # Since X_new = X @ M, the shared directions e_1..e_d of X land on M.T @ e_j.
    x_mixing_matrix = stats.special_ortho_group(dim=X.shape[1]).rvs(random_state=rng)
    X = X @ x_mixing_matrix
    x_common_basis = x_mixing_matrix.T @ np.eye(high_d[0])[:, :common_d]

    Y = Y @ stats.special_ortho_group(dim=Y.shape[1]).rvs(random_state=rng)

    X_with_time = NumpyTimedDataSource(X, timepoints=np.arange(X.shape[0]) * dt)
    Y_with_time = NumpyTimedDataSource(Y, timepoints=np.arange(Y.shape[0]) * dt)

    pls = proPLS(k=3, log_level=2)
    pls.offline_run_on([X_with_time, Y_with_time])
    U_error, t = pls.get_distance_from_subspace_over_time(x_common_basis, variable='X')
    trajectories.append(U_error)

    if calculate_intra_run_errors:
        intra_run_offline_errors = []
        for i in range(n_points):
            # Use samples 0..i inclusive, which is what the streaming estimate has seen.
            try:
                sk_weights = sklearn.cross_decomposition.PLSRegression(n_components=pls.k).fit(X[:i + 1], Y[:i + 1]).x_weights_
                intra_run_offline_errors.append(column_space_distance(sk_weights, x_common_basis))
            except (ValueError, np.linalg.LinAlgError):
                # Too few samples for a stable fit at this timepoint; leave it undefined.
                intra_run_offline_errors.append(np.nan)
        offline_errors.append(intra_run_offline_errors)
    else:
        sk_weights = sklearn.cross_decomposition.PLSRegression(n_components=pls.k).fit(X, Y).x_weights_
        offline_errors.append(column_space_distance(sk_weights, x_common_basis))

# `colors` has only three entries, so proPLS reuses proSVD's color index.
fig, ax = plot(offline_errors, trajectories, t, 2, 'proPLS', offline_error_t=X_with_time.t)
