# Imports

In [None]:
import numpy as np
from adaptive_latents.ica import BaseMMICA
from adaptive_latents.prosvd import  BaseProSVD
import adaptive_latents
import matplotlib.pyplot as plt
import timeit
from scipy.stats import kurtosis, ortho_group
from picard import permute

from sklearn.decomposition import FastICA, PCA

# Shared random generator for the notebook. No seed is passed, so every kernel
# run draws different sources and a different mixing matrix.
rng = np.random.default_rng()

# Define a Dataset

In [None]:
# Error curves collected across runs: the plotting cell inserts the newest
# curve at index 0 and draws the remaining ones as faded background lines.
performances = []

In [None]:
# Synthetic data layout: (blocks, samples per block, dimensions).
n_blocks = 700
n_points_per_block = 10
n_dimensions = 10
size = (n_blocks, n_points_per_block, n_dimensions)

# Non-Gaussian sources. Alternative distributions that were tried:
#   A_true = rng.laplace(size=size)
#   A_true = rng.standard_cauchy(size=size)
A_true = rng.poisson(lam=.5, size=size)

# One timestamp per sample at a spacing of 0.01 (plotted later as seconds),
# arranged as (block, sample). Each block is stamped with its last sample.
t = np.arange(n_blocks * n_points_per_block).reshape(n_blocks, n_points_per_block) * 0.01
block_t = t[:, -1]

# Square Gaussian mixing matrix applied along the last axis of the sources.
# Orthogonal alternative: mixing_matrix = ortho_group.rvs(A_true.shape[2])
mixing_matrix = rng.normal(size=(A_true.shape[2], A_true.shape[2]))

A = np.matmul(A_true, mixing_matrix.T)

# Run mmICA

In [None]:
# Streaming pipeline: a proSVD projection followed by the online ICA estimator.
pro = BaseProSVD(k=n_dimensions)
# ica = mmICA(density='huber', track_extra_info=True, maxiter_cg=20, tol=1e-15)
ica = BaseMMICA(alpha=.7)

# Initialise proSVD on the first `start_index` blocks, flattened so that the
# samples run along the second axis.
start_index = 5
pro.initialize(A[:start_index].reshape((-1, A.shape[-1])).T)

# Keep only the timestamps of the blocks that are streamed below, so the time
# axes have one entry per collected result.
t = t[start_index - A.shape[0]:]
block_t = block_t[start_index - A.shape[0]:]

ica.set_p(pro.k)

Ws, Qs = [], []
ica_times, pro_times = [], []

for raw_block in A[start_index:]:
    # Time the projection step.
    # NOTE(review): the explicit update `pro.updateSVD(block.T)` is disabled
    # here; whether `project_down` itself changes Q is not visible from this
    # notebook.
    tic = timeit.default_timer()
    projected = pro.project_down(raw_block.T)
    pro_times.append(timeit.default_timer() - tic)

    # Time the ICA update on the projected block.
    tic = timeit.default_timer()
    ica.observe_new_batch(projected)
    ica_times.append(timeit.default_timer() - tic)

    # Snapshot both transforms after every block.
    Ws.append(ica.W.copy())
    Qs.append(pro.Q.copy())

Ws, Qs = np.array(Ws), np.array(Qs)
ica_times, pro_times = np.array(ica_times), np.array(pro_times)

In [None]:
# Error of each streamed estimate: norm of the difference between the
# permuted overall transform W @ Q.T @ mixing_matrix and the identity.
performance = []
for W_step, Q_step in zip(Ws, Qs):
    overall = W_step @ Q_step.T @ mixing_matrix
    performance.append(np.linalg.norm(permute(overall) - np.eye(n_dimensions)))
performances.insert(0, performance)

# Reference level: batch FastICA fitted on all of the data at once.
offline_components = FastICA(max_iter=5000).fit(A.reshape((-1, n_dimensions))).components_
offline_performance = np.linalg.norm(
    permute(offline_components @ mixing_matrix) - np.eye(n_dimensions)
)

# Earlier runs in translucent black, the newest run on top, and the offline
# reference as a red horizontal line.
latest_run, *earlier_runs = performances
for earlier_run in earlier_runs:
    plt.plot(block_t, earlier_run, color='k', alpha=.5)
plt.plot(block_t, latest_run)
plt.axhline(offline_performance, color='r')
plt.semilogy()
plt.ylabel('error')
plt.xlabel('time (s)');
# plt.ylim([0,10])

In [None]:
# Sanity check: none of the recorded unmixing matrices contains a NaN.
assert not np.isnan(Ws).any()

In [None]:
# Diagnostics that are only drawn when the estimator recorded an iteration history.
if ica.hit_iter_history:
    plt.figure()
    plt.plot(np.array(ica.hit_iter_history), '.', alpha=.1, color='k')

    fig, ax = plt.subplots(layout='tight', figsize=(9,3))
    # NOTE(review): this replaces the estimator's own history attribute with
    # an ndarray; confirm nothing appends to it afterwards.
    ica.hit_norm_history = np.array(ica.hit_norm_history)
    # m = ica.hit_norm_history[:200,7,:]
    # Largest (NaN-ignoring) value along axis 1 for every block and iteration.
    worst_case = np.nanmax(ica.hit_norm_history[:, :, :], axis=1)
    log_worst_case = np.log(np.squeeze(worst_case)).T
    c = ax.matshow(log_worst_case, vmin=np.log(ica.tol), origin='lower', aspect='auto')
    ax.set_xlabel("data block #")
    ax.set_ylabel("inner loop iteration #")
    ax.set_title("max convergence metric across all rows")

    cbar = fig.colorbar(c)
    # Push the rotated colorbar label away from the tick labels.
    cbar.ax.get_yaxis().labelpad = 15
    cbar.set_label("log convergence metric", rotation=270)



# Assess performance

## Time to run

In [None]:
# Per-block wall-clock cost of each stage, converted from seconds to milliseconds.
# NOTE(review): the mmICA label prints W.shape[0] on both sides of the arrow;
# that is only accurate while W is square.
timing_series = (
    (pro_times, f"proSVD {pro.Q.shape[0]} -> {pro.Q.shape[1]}"),
    (ica_times, f"mmICA {ica.W.shape[0]} -> {ica.W.shape[0]}"),
)
for seconds, series_label in timing_series:
    plt.plot(seconds * 1_000, label=series_label)

plt.xlabel("block number")
plt.ylabel("iteration time (ms)")
plt.legend()

## Kurtosis

In [None]:
fig, ax = plt.subplots()

# Project every sample through proSVD. The input is passed with samples along
# the second axis; the code below treats the output the same way.
dim_reduced_data = pro.project_down(A.reshape((-1, A.shape[-1])).T)

# Candidate representations to compare, each stored with samples along axis 1.
datasets = {
    'input': dim_reduced_data,
    'mmICA': ica.W @ dim_reduced_data,
}

# The scikit-learn estimators are fed the transposed data, and their outputs
# are transposed back to match the other entries.
sk_ica = FastICA(max_iter=800, whiten='unit-variance').fit(dim_reduced_data.T)
datasets['FastICA'] = sk_ica.transform(dim_reduced_data.T).T

sk_pca = PCA().fit(dim_reduced_data.T)
datasets['pca'] = sk_pca.transform(dim_reduced_data.T).T


def whiten(x):
    """Return a centred, ZCA-whitened copy of ``x``.

    Each row of ``x`` is treated as one variable and each column as one
    sample. The rows are centred and then multiplied by the inverse square
    root of their sample covariance (``n - 1`` normalisation), so the result
    has identity sample covariance whenever that covariance is full rank.

    Directions whose singular value is numerically zero (for example a
    constant row) cannot be rescaled to unit variance. They are mapped to
    zero instead of being divided by zero, which would otherwise spread
    ``inf``/``nan`` through the whole output.

    Parameters
    ----------
    x : array_like, shape (n_variables, n_samples)

    Returns
    -------
    numpy.ndarray, shape (n_variables, n_samples)
    """
    centered = x - x.mean(axis=1)[:, None]
    u, s, _ = np.linalg.svd(centered, full_matrices=False)
    # Singular values of the centred data -> standard deviations along each
    # principal direction.
    s = s / np.sqrt(centered.shape[1] - 1)

    # Same rank tolerance convention as numpy.linalg.matrix_rank.
    tol = s.max(initial=0.0) * max(centered.shape) * np.finfo(s.dtype).eps
    inv_s = np.zeros_like(s)
    np.divide(1.0, s, out=inv_s, where=s > tol)

    # (u * inv_s) scales the columns of u, i.e. it equals u @ diag(inv_s).
    return (u * inv_s) @ u.T @ centered

# Baselines built from the true sources, when those are available.
if A_true is not None:
    true_sources = A_true.reshape((-1, A.shape[-1])).T
    datasets['ground truth'] = true_sources
    datasets['whitened ground truth'] = whiten(true_sources)

whitened_data = whiten(dim_reduced_data)

# These two entries are callables rather than arrays: every call applies a
# freshly drawn random square transform (orthogonal / Gaussian) to the
# whitened projected data.
datasets.update({
    'random whitened': lambda: ortho_group.rvs(ica.W.shape[0]) @ whitened_data,
    'random': lambda: rng.normal(size=[ica.W.shape[0]]*2) @ whitened_data,
})


# Line style for every dataset that is drawn as a single curve.
single_line_styles = {
    'mmICA': {'color': 'C0', 'linestyle': '-', 'marker': '.'},
    'input': {'color': 'orange'},
    'ground truth': {'color': '#D4AF37'},
    'whitened ground truth': {'color': '#D4AF37', 'marker': '.'},
    'FastICA': {'color': 'red'},
    'pca': {'color': 'green'},
}

# Per-component statistics. Each callable takes an array whose rows are
# components and whose columns are samples, and returns one value per row.
evaluation_metrics = {
    'kurtosis': lambda comps: kurtosis(comps.T),
    '4th power': lambda comps: ((comps.T ** 4 - 3)/4).mean(axis=0),
    'log cosh': lambda comps: (np.log(np.cosh(comps.T)) - .375).mean(axis=0),
    'exp': lambda comps: -np.exp((-.5 * comps.T**2) - .707).sum(axis=0),
    'var': lambda comps: np.var(comps.T, axis=0),
}

# Statistic used by the plots that follow.
metric_name = 'kurtosis'
evaluation_metric = evaluation_metrics[metric_name]

# One curve per fixed dataset: components sorted by the chosen metric.
# The style dict is walked in insertion order (instead of iterating a set
# intersection) so that draw order and legend order are the same in every
# kernel session; set iteration order of strings depends on hash
# randomisation.
for key, style in single_line_styles.items():
    if key not in datasets:
        continue
    dataset = datasets[key]
    idx = np.argsort(evaluation_metric(dataset))
    sorted_dataset = dataset[idx, :]
    # Store the reordered components: later cells read these entries back.
    datasets[key] = sorted_dataset
    ax.plot(evaluation_metric(sorted_dataset), label=key, **style)

multi_line_styles = {
    'random whitened': dict(alpha=.1, color='k'),
    # 'random': dict(alpha=.1, linestyle='--', color='k'),
}
# Randomised baselines: each entry is a callable, sampled 10 times and drawn
# as faint lines. Only the first draw is labelled, so the legend gets a
# single entry per baseline.
for key, style in multi_line_styles.items():
    if key not in datasets:
        continue
    for i in range(10):
        dataset = datasets[key]()
        idx = np.argsort(evaluation_metric(dataset))
        sorted_dataset = dataset[idx, :]
        label = key if i == 0 else None
        ax.plot(evaluation_metric(sorted_dataset), label=label, **style)

ax.set_title(f"Sorted {metric_name} per component")
ax.legend();



## Compare offline and online

In [None]:
fig, ax = plt.subplots()
k1, k2 = 'mmICA', 'FastICA'
online, offline = datasets[k1], datasets[k2]

# corrcoef stacks the rows of both inputs; keep only the block that pairs
# online rows with offline columns, as absolute values.
stacked_corr = np.corrcoef(online, offline)
cov = np.abs(stacked_corr[:-offline.shape[0], online.shape[0]:])
# Reorder with picard's `permute` (scaling disabled) before plotting.
cov = permute(cov, scale=False)

c = ax.matshow(cov, vmin=0.1, vmax=1)
fig.colorbar(c)
ax.set_title("Offline vs Online latent correlations");

## Correlation with true latents

In [None]:
# Side-by-side heat maps of |correlation| between the true sources (rows) and
# the components recovered by each method (columns).
if A_true is not None:
    fig, ax = plt.subplots(ncols=2)
    a = datasets['ground truth']
    n_true = a.shape[0]
    ax[0].set_ylabel('ground truth')

    for panel, name in zip(ax, ('mmICA', 'FastICA')):
        # Cross block of the stacked correlation matrix: truth rows x method columns.
        cov = np.corrcoef(a, datasets[name])[:n_true, n_true:]
        # The reordering is only applied when the block is square.
        if cov.shape[0] == cov.shape[1]:
            cov = permute(cov, scale=False)
        panel.matshow(np.abs(cov), vmin=.1, vmax=1)
        panel.set_title(name)




## Convergence

In [None]:
fig, ax = plt.subplots(nrows=1, layout='tight')
ax = [ax]

# Relative change of every row of W between consecutive blocks: the norm of
# the row difference divided by the norm of the earlier row.
changes = np.linalg.norm(np.diff(Ws, axis=0), axis=2) / np.linalg.norm(Ws, axis=2)[:-1]

# Use the block timestamps for the time axis so it matches the dataset's
# sampling step and the error-vs-time plot, instead of a separately
# hard-coded seconds-per-sample factor. `block_t` has one entry per W
# snapshot; the first snapshot has no predecessor, so its stamp is dropped.
ax[0].plot(block_t[1:], changes)
ax[0].set_title("Convergence of the rows of W (mmICA)")
ax[0].set_xlabel("time (s)")
ax[0].set_ylabel(r"$\frac{\Delta r}{\Vert r \Vert}$")
# ax[0].semilogy()
# ax[0].set_ylim([0,.1])

# changes = np.linalg.norm(np.diff(Qs, axis=0), axis=1)/ np.linalg.norm(Qs, axis=1)[:-1]
# ax[1].plot((np.arange(changes.shape[0]) + 1)*n_points_per_block,changes)
# ax[1].set_title("Convergence of the columns of Q")
# ax[1].set_xlabel("points seen")
# ax[1].set_ylabel(r"$\frac{\Delta r}{\Vert r \Vert}$");



## The distribution metric

In [None]:
# Push all samples through the projection and the learned unmixing matrix,
# then average the density's log-probability over samples for each component.
all_samples = A.reshape((-1, n_dimensions))
projected_samples = pro.project_down(all_samples.T)
demixed_data = ica.W @ projected_samples

log_probabilities = ica.density.logp(demixed_data)
statistics = log_probabilities.mean(axis=1)

component_index = np.arange(len(statistics))
plt.bar(x=component_index, height=statistics)

In [None]:
# Per-component mean log-probability against the estimator's `cumulants` attribute.
plt.scatter(x=statistics, y=ica.cumulants)