In [None]:
import numpy as np
from scipy.linalg import block_diag
import scipy
import matplotlib.pyplot as plt
from adaptive_latents.transforms import proSVD
from adaptive_latents.transforms.jpca import align_column_spaces
from tqdm import tqdm

from functools import reduce

# Shared random generator used by the cells below. It is created without a seed,
# so every cell that draws from it before `rng` is re-created gives different numbers per run.
rng = np.random.default_rng()

In [None]:
def column_space_distance(Q1, Q2):
    """Return the sum of squared entries of ``Q1 @ Q1.T - Q2 @ Q2.T``.

    When both inputs have orthonormal columns, the two products are the
    orthogonal projectors onto their column spaces, so the result is the
    squared Frobenius distance between those projectors.
    """
    projector_gap = Q1 @ Q1.T - Q2 @ Q2.T
    return np.square(projector_gap).sum()

In [None]:
# Monte-Carlo comparison of two ways of feeding an 8-channel stream to a rank-2 proSVD.
# Each trial stores (error of psvd1) - (error of psvd2), where "error" is the
# column-space distance of the learned basis to `ideal_basis`.
diffs = []
for _ in tqdm(range(5000)):
    # per-channel scaling: channel 0 is scaled by 1.3, the last channel by 10, the rest by 1
    d = np.ones(8)
    d[0] = 1.3
    d[-1] = 10
    d = np.diag(d)
    
    data = [
        # 5-channel block (first five channels only), 20 samples
        d[:-3,:-3] @ rng.normal(size=(5,20)),
        # 8-channel block, 8 samples
        d @ rng.normal(size=(8,8)),
        # 8-channel block, 20 samples
        d @ rng.normal(size=(8,20)),
        ]
    
    
    psvd1 = proSVD(2)
    psvd2 = proSVD(2)
    
    # psvd1 only ever sees the 8-channel data: initialized on data[1], then fed data[2] one column at a time
    psvd1.initialize(data[1])
    for i in np.arange(data[2].shape[1]):
        psvd1.updateSVD(data[2][:,i:i+1])
    
    # psvd2 starts on the 5-channel block, then add_new_input_channels(3) is called
    # (presumably growing it to 8 input channels -- TODO confirm against the proSVD implementation)
    # before both 8-channel blocks are streamed one column at a time
    psvd2.initialize(data[0])
    psvd2.add_new_input_channels(3)
    for j in [1,2]:
        for i in np.arange(data[j].shape[1]):
            psvd2.updateSVD(data[j][:, i:i + 1])
    
    # reference basis: unit vectors along channel 0 and the last channel (the two channels scaled above)
    ideal_basis = np.zeros((8,2))
    ideal_basis[0,0] = 1
    ideal_basis[-1, 1] = 1

    # positive diff means psvd2 ended closer to the reference basis than psvd1
    diff = column_space_distance(psvd1.Q, ideal_basis) - column_space_distance(psvd2.Q, ideal_basis)
    diffs.append(diff)
diffs = np.array(diffs)

In [None]:
# Histogram of the per-trial error differences; the vertical line marks "no difference".
fig, ax = plt.subplots()
ax.hist(diffs, bins=100)
ax.axvline(0, color='k')
ax.set(
    xlabel="old method error - new method error",
    ylabel="counts",
    title=f"improvement in {np.mean(diffs > 0)*100 :.1f}% of cases",
);

In [None]:
# Image of the basis held by psvd2 after the last simulation trial.
fig, axs = plt.subplots()
axs.imshow(psvd2.Q)

In [None]:
# Per-channel standard deviation over the two 8-channel blocks of the last trial.
np.hstack(data[1:]).std(axis=1)

In [None]:
# Per-channel standard deviation of the first block of the last trial.
data[0].std(axis=1)

In [None]:
# Norm of the first basis column of psvd2.
# The original cell read `Q2`, a name never defined in this notebook (NameError on
# Restart & Run All). NOTE(review): `psvd2.Q` is assumed to be the intended matrix -- confirm.
np.linalg.norm(psvd2.Q[:, 0])

# old

In [None]:
from scipy.stats import special_ortho_group



def random_mat_leq_theta(theta, d=3):
    """Rejection-sample a matrix from ``special_ortho_group(dim=d)``.

    Draws are repeated while the determinant is positive and
    ``arccos((trace(Q) - 1) / 2)`` exceeds ``theta``; the first draw that
    fails that test is returned. Uses the notebook-level ``rng`` as the seed.

    NOTE(review): ``(trace - 1) / 2`` is the cosine of the rotation angle for
    3x3 rotations only; for ``d != 3`` the compared quantity is not a rotation
    angle and the ``arccos`` argument may leave [-1, 1], giving NaN, which
    ends the loop.
    """
    sampler = special_ortho_group(dim=d, seed=rng)
    while True:
        Q = sampler.rvs()
        # short-circuit `and`: the angle is only evaluated for positive determinants
        angle_too_large = np.linalg.det(Q) > 0 and np.arccos((np.trace(Q) - 1) / 2) > theta
        if not angle_too_large:
            return Q

In [None]:
def trivial_make_G(B, k):
    """Return the first ``k`` left and right singular vectors of ``B``.

    Returns:
        (G_u, G_v): columns ``0..k-1`` of ``U`` and of ``V`` from ``B = U S V^T``.
    """
    left, _, right_t = np.linalg.svd(B)
    # rows of V^T are the right singular vectors; take k of them and put them in columns
    return left[:, :k], right_t[:k].T

In [None]:
def prosvd_make_G(B,k):
    """Build the left/right update matrices with a Procrustes-style rotation.

    The first ``k`` left singular vectors of ``B`` are rotated by the transpose
    of the orthogonal factor ``U~ V~^T`` obtained from the SVD of their leading
    ``k x k`` block.

    Returns:
        (G_u, G_v): the rotated left vectors, and the full, untruncated matrix
        of right singular vectors (callers slice it themselves).
    """
    left, _, right_t = np.linalg.svd(B)

    block_u, _, block_vt = np.linalg.svd(left[:k, :k])
    rotation = (block_u @ block_vt).T

    return left[:, :k] @ rotation, right_t.T

In [None]:
def general_incremental_svd(A_sequence, make_G_matrices=trivial_make_G, max_k=10):
    """Incrementally factor the column-concatenation of ``A_sequence`` as ``Q @ B @ W.T``.

    The first block is QR-factored; every later block is split into its
    component inside ``span(Q)`` and a residual, the factors are enlarged, and
    then compressed with the matrices returned by ``make_G_matrices``.

    Args:
        A_sequence: sequence of 2-D arrays with the same number of rows.
        make_G_matrices: callable ``(B_bigger, max_k) -> (G_u, G_v)``.
        max_k: upper bound on the number of columns kept from ``G_u``/``G_v``.

    Returns:
        ``(Q, B, W, history)``. ``history`` has one dict per block with keys
        ``Q``, ``B``, ``W``; entries after the first also hold ``B_bigger`` and
        ``Q_bigger`` (the enlarged factors before compression).
    """
    first = A_sequence[0]
    Q, B = np.linalg.qr(first)
    W = np.eye(first.shape[1])
    k = first.shape[1]
    history = [{"Q": Q, "B": B, "W": W}]

    for block in A_sequence[1:]:
        n_new = block.shape[1]

        # component of the new columns inside span(Q), and the QR of what is left over
        coeffs = Q.T @ block
        Q_resid, B_resid = np.linalg.qr(block - Q @ coeffs)

        Q_bigger = np.hstack([Q, Q_resid])  # Q hat
        B_bigger = np.vstack([
            np.hstack([B, coeffs]),
            np.hstack([np.zeros((B_resid.shape[0], B.shape[1])), B_resid]),
        ])
        W_bigger = block_diag(W, np.eye(n_new))

        # rank grows by the number of new columns until it hits max_k
        k = min(k + n_new, max_k)
        G_u, G_v = (G[:, :k] for G in make_G_matrices(B_bigger, max_k))

        Q, B, W = Q_bigger @ G_u, G_u.T @ B_bigger @ G_v, W_bigger @ G_v
        history.append({"Q": Q, "B": B, "W": W, "B_bigger": B_bigger, "Q_bigger": Q_bigger})

    return Q, B, W, history

In [None]:
# Sanity check: with the default max_k=10 and ten single-column blocks,
# both G constructions must reproduce the stacked data as Q @ B @ W.T.
A_sequence = [rng.normal(size=(2, 1)) for _ in range(10)]
A_full = np.hstack(A_sequence)

for f in (trivial_make_G, prosvd_make_G):
    Q, B, W, _ = general_incremental_svd(A_sequence, make_G_matrices=f)
    assert np.allclose(Q @ B @ W.T, A_full)

In [None]:
# Cross-check: the notebook's incremental SVD with prosvd_make_G should give the
# same basis Q as the library proSVD fed the same blocks.
A_sequence = [rng.normal(size=(3,4)) for _ in range(10)]
A_full = np.column_stack(A_sequence)

# rank used for both implementations
d=4
Q, B, W, _ = general_incremental_svd(A_sequence, prosvd_make_G, max_k=d)



p = proSVD(d)

# first block initializes the library object, the remaining blocks are streamed in
p.initialize(A_sequence[0])
for A in A_sequence[1:]:
    # NOTE(review): preupdate/postupdate are called around every update here;
    # their effect is not visible in this notebook -- see the proSVD implementation
    p.preupdate()
    p.updateSVD(A)
    p.postupdate()

assert np.allclose(p.Q,  Q)

$$
\begin{bmatrix}A_{t-1} & a_t\end{bmatrix} = U_t \Sigma_t V_t^\top
$$

## Visualization

In [None]:
# 2-D point cloud elongated along `direction`, streamed one point at a time into a rank-1 incremental SVD.
direction = np.array([-3, 1]) / np.linalg.norm([-3, 1])
basis = np.column_stack([direction, [-direction[1], direction[0]]])
# covariance = basis @ diag(5, .11) @ basis.T, scaled by .25
cov = ((basis @ np.diag([5, .11])) * .25) @ basis.T

# from here on the notebook-level generator is seeded
rng = np.random.default_rng(24)
A_sequence = [rng.multivariate_normal(mean=[2, 2], cov=cov)[:, None] for _ in range(15)]
del A_sequence[2]
A_sequence[1] -= 0.1
A_full = np.column_stack(A_sequence)

Q, B, W, history = general_incremental_svd(A_sequence, max_k=1)

In [None]:
# Left: the points with the origin marked in black. Right: the same data as an image.
fig, ax = plt.subplots(ncols=2, figsize=(9, 4))
points_ax, image_ax = ax

points_ax.plot(A_full[0], A_full[1], '.')
points_ax.axis('scaled')
points_ax.set(xlim=[-4, 4], ylim=[-4, 4])
points_ax.plot(0, 0, '.k')

image_ax.imshow(A_full)

In [None]:
# Sign-flipped enlarged factors of the first update step, shown as (B_bigger, Q_bigger).
tuple(-history[1][key] for key in ("B_bigger", "Q_bigger"))

In [None]:
# Step-by-step picture of the rank-1 incremental SVD: one row per step (5 rows), four panels per row.
fig, axs = plt.subplots(ncols=4, nrows=5, squeeze=False, layout='tight', figsize=(10,10))

# Row 0: only the first point exists; the second panel also shows the initial Q axis
# and the vector Q * B (red).
i = 0
for j in range(2):
    axs[i,j].plot(A_full[0,:i], A_full[1,:i], '.', color="C0", alpha=.5)
    axs[i,j].plot(A_full[0,i], A_full[1,i], '.', color="C0")
axs[i,1].axline((0,0), history[i]['Q'][:,0], color='k', alpha=.5)

axs[i,1].plot([0, history[i]['Q'][0,0] * history[i]['B'][0,0]], [0,history[i]['Q'][1,0] * history[i]['B'][0,0]], color='C3')

# parameter for drawing the unit circle that is mapped to an ellipse below
t = np.linspace(0,2*np.pi, 50)

for i in range(1, axs.shape[0]):
    current_point = A_full[:,i:i+1]
    old_Q = history[i-1]['Q']
    # enlarged factors before compression; used to map the unit circle to an ellipse
    weighted_col_space = history[i]['Q_bigger'] @ history[i]['B_bigger']
    # every panel in the row: earlier points (faded), the current point, and the previous Q axis
    for j in range(axs.shape[1]):
        axs[i,j].plot(A_full[0,:i], A_full[1,:i], '.', color="C0", alpha=.5)
        axs[i,j].plot(current_point[0], current_point[1], '.', color="C0")
        axs[i,j].axline((0,0), old_Q[:,0], color='k', alpha=.25)
        

    # split the current point into its projection on old_Q (orange) and the residual (green)
    along = old_Q * (old_Q.T @ current_point)
    direction_of_along = np.sign(old_Q.T @ current_point)
    residual = current_point - along
    axs[i,1].plot([0, along[0,0]], [0,along[1,0]], color='C1')
    axs[i,1].plot([along[0,0], (along+residual)[0,0] ], [along[1,0], (along+residual)[1,0]], color='C2')

    # image of the unit circle under Q_bigger @ B_bigger
    circle_points = np.column_stack([np.cos(t), np.sin(t)])
    ellipse_points = weighted_col_space @ circle_points.T
    axs[i,2].plot(ellipse_points[0,:], ellipse_points[1,:])

    # red: |B_bigger[0,0]| along old_Q, pointing opposite to the sign of the current projection
    saved_variance_vector = np.abs(history[i]['B_bigger'][0,0]) * -direction_of_along * old_Q
    axs[i,2].plot([0, saved_variance_vector[0,0] ], [0, saved_variance_vector[1,0]], color='C3')
    axs[i,2].plot([0, along[0,0]], [0,along[1,0]], color='C1')
    axs[i,2].plot([0, residual[0,0] ], [0, residual[1,0]], color='C2')

    # last panel: the same ellipse together with the updated Q axis
    axs[i,-1].plot(ellipse_points[0,:], ellipse_points[1,:])
    axs[i,-1].axline((0,0), history[i]['Q'][:,0], color='k', alpha=.5)

# common cosmetics: mark the origin, hide ticks, equal scaling and fixed limits on every panel
for ax in axs.flatten():
    ax.plot(0,0, '.k')
    
    ax.set_xticks([])
    ax.set_yticks([])
    
    ax.axis('scaled')
    ax.set_xlim([-4,4])
    ax.set_ylim([-4,4])

In [None]:
# Norm of the change in Q between consecutive steps of the rank-1 run.
Q_per_step = np.squeeze([h['Q'] for h in history])
plt.plot(np.linalg.norm(np.diff(Q_per_step, axis=0), axis=1));

In [None]:
def tex_for_history(history, step):
    """Format one update step of an incremental-SVD history as a LaTeX ``align*`` block.

    The output shows Q and R (stored under key "B") before the step, the
    enlarged factors, and Q and R after the step, with every number wrapped in
    ``\\mathcolor``.

    Only the entries indexed below are rendered: a 2x1 ``Q``, a 1x1 ``B`` and
    the 2x2 enlarged factors -- so this assumes a history from 2-row data with
    rank 1. ``history[step]`` must hold the "Q_bigger"/"B_bigger" keys.

    Args:
        history: list of per-step dicts with keys "Q", "B", "Q_bigger", "B_bigger".
        step: index of the step to render; ``history[step-1]`` is read as the previous state.

    Returns:
        The LaTeX source as a string.
    """
    Q_old = history[step-1]["Q"]
    R_old = history[step-1]["B"]
    Q_bigger = history[step]["Q_bigger"]
    R_bigger = history[step]["B_bigger"]
    Q = history[step]["Q"]
    R = history[step]["B"]
    s = fr"""\begin{{align*}}
    Q_{{{step-1}}} =
    \begin{{bmatrix}}
        \mathtt{{\mathcolor{{gray}}{{ {Q_old[0,0]:.2f} }}}} \\
        \mathtt{{\mathcolor{{gray}}{{ {Q_old[1,0]:.2f} }}}} \\
    \end{{bmatrix}}
    \quad&
    R_{{{step-1}}} =
    \begin{{bmatrix}}
        \mathtt{{\mathcolor{{C3}}{{ {R_old[0,0]:.2f} }}}}
    \end{{bmatrix}}
    \\
    \hat Q_{{{step}}} =
    \begin{{bmatrix}}
        \mathtt{{\mathcolor{{gray}}{{ {Q_bigger[0,0]:.2f}  }}}} & \mathtt{{\mathcolor{{kk}}{{ {Q_bigger[0,1]:.2f} }}}} \\
        \mathtt{{\mathcolor{{gray}}{{ {Q_bigger[1,0]:.2f} }}}} & \mathtt{{\mathcolor{{kk}}{{ {Q_bigger[1,1]:.2f} }}}} \\
    \end{{bmatrix}}
    \quad& 
    \hat R_{{{step}}} =
    \begin{{bmatrix}}
        \mathtt{{\mathcolor{{C3}}{{ {R_bigger[0,0]:.2f}  }}}} & \mathtt{{\mathcolor{{C1}}{{ {R_bigger[0,1]:.2f} }}}} \\
        \mathtt{{\mathcolor{{kk}}{{0}}}} & \mathtt{{\mathcolor{{C2}}{{ {R_bigger[1,1]:.2f} }}}} \\
    \end{{bmatrix}}
    \\
    Q_{{{step}}} =
    \begin{{bmatrix}}
        \mathtt{{\mathcolor{{gray}}{{ {Q[0,0]:.2f}  }}}} \\
        \mathtt{{\mathcolor{{gray}}{{ {Q[1,0]:.2f} }}}} \\
    \end{{bmatrix}}
    \quad& 
    R_{{{step}}} =
    \begin{{bmatrix}}
        \mathtt{{\mathcolor{{C3}}{{ {R[0,0]:.2f}  }}}}
    \end{{bmatrix}}
\end{{align*}}
"""
    return s
# emit the LaTeX for the third update of the rank-1 visualization run
print(tex_for_history(history, 3))

# Biased proSVD

In [None]:
def biased_pro_svd(A_sequence, make_G_matrices=trivial_make_G, max_k=10):
    """Run the incremental SVD with the Procrustes-rotated update and keep every step.

    Args:
        A_sequence: sequence of 2-D arrays with the same number of rows.
        make_G_matrices: accepted for signature parity with
            ``general_incremental_svd`` but never called; the rotated update
            is computed inline.
        max_k: size of the leading block used for the rotation and upper bound
            on the number of columns kept.

    Returns:
        List of ``(Q, B, W)`` tuples, one per element of ``A_sequence``.
    """
    first = A_sequence[0]
    Q, B = np.linalg.qr(first)
    W = np.eye(first.shape[1])
    k = first.shape[1]
    QBWs = [(Q, B, W)]

    for block in A_sequence[1:]:
        n_new = block.shape[1]

        # component of the new columns inside span(Q), and the QR of the remainder
        coeffs = Q.T @ block
        Q_resid, B_resid = np.linalg.qr(block - Q @ coeffs)

        Q_hat = np.hstack([Q, Q_resid])
        B_hat = np.vstack([
            np.hstack([B, coeffs]),
            np.hstack([np.zeros((B_resid.shape[0], B.shape[1])), B_resid]),
        ])
        W_hat = block_diag(W, np.eye(n_new))

        # rotate the leading left singular vectors by the orthogonal factor of their top block
        left, _, right_t = np.linalg.svd(B_hat)
        block_u, _, block_vt = np.linalg.svd(left[:max_k, :max_k])
        rotation = (block_u @ block_vt).T

        k = min(k + n_new, max_k)
        G_u = (left[:, :max_k] @ rotation)[:, :k]
        G_v = right_t.T[:, :k]

        Q, B, W = Q_hat @ G_u, G_u.T @ B_hat @ G_v, W_hat @ G_v
        QBWs.append((Q, B, W))
    return QBWs

In [None]:
# Synthetic stream in d dimensions: the first three coordinates have variance 10, the rest variance 1.
n_points = 50
d = 10
max_k = 3
cov = np.diag([10, 10, 10] + [1] * (d - 3))
A_sequence = []

# first half: the mean moves along the first coordinate
for i in range(n_points):
    mean = np.zeros(d)
    mean[0] = i / n_points * 100
    A = rng.multivariate_normal(mean=mean, cov=cov, size=max_k)
    A_sequence.append(A.T)

# second half: the mean moves on a circle of radius 100 in the first two coordinates
for i in range(n_points):
    angle = i / n_points
    mean = np.zeros(d)
    mean[0] = np.cos(angle) * 100
    mean[1] = np.sin(angle) * 100
    A = rng.multivariate_normal(mean=mean, cov=cov, size=max_k)
    A_sequence.append(A.T)

A_full = np.column_stack(A_sequence)

QBWs = biased_pro_svd(A_sequence, max_k=max_k)
Qs, Bs, Ws = zip(*QBWs)

# prints whether the final factors reproduce the full data matrix
print(f"{np.allclose(Qs[-1] @ Bs[-1] @ Ws[-1].T, A_full)}")

In [None]:
# Synthetic zero-mean stream in which the variance of the fourth coordinate changes in the second half.
n_points = 50
d = 10
max_k = 3
mean = np.zeros(d, dtype=int)
A_sequence = []

# first half: variances (10, 10, 10, 1, 1, ...)
cov = np.diag([10, 10, 10, 1] + [1] * (d - 4))
for i in range(n_points):
    A = rng.multivariate_normal(mean=mean, cov=cov, size=max_k)
    A_sequence.append(A.T)

# second half: the fourth variance equals the loop index, so it starts at 0 and grows to n_points - 1
for i in range(n_points):
    cov = np.diag([10, 10, 10, i] + [1] * (d - 4))
    A = rng.multivariate_normal(mean=mean, cov=cov, size=max_k)
    A_sequence.append(A.T)

A_full = np.column_stack(A_sequence)

QBWs = biased_pro_svd(A_sequence, max_k=max_k)
Qs, Bs, Ws = zip(*QBWs)

# prints whether the final factors reproduce the full data matrix
print(f"{np.allclose(Qs[-1] @ Bs[-1] @ Ws[-1].T, A_full)}")

In [None]:
# Per-column norm of the change in Q between consecutive steps.
Q_stack = np.array(Qs)
step_change = np.linalg.norm(np.diff(Q_stack, axis=0), axis=1)
plt.plot(np.arange(1, len(Qs)), step_change)
plt.xlabel("step")
plt.ylabel("norm stepwise change in the nth column of Q")

In [None]:
# The full data matrix as an image (rows = coordinates, columns = samples).
plt.imshow(A_full, interpolation="nearest", aspect='auto')

In [None]:
# Final basis Q: columns as line plots (left) and the matrix as an image with colorbar (right).
fig, ax = plt.subplots(nrows=1, ncols=2, squeeze=False, figsize=(10, 5))
lines_ax, image_ax = ax[0]
Q_final = Qs[-1]

im = image_ax.imshow(Q_final, interpolation='nearest')
fig.colorbar(im);
lines_ax.plot(Q_final)

In [None]:
# H = B @ W.T from the final step: image with colorbar (left) and its rows as time series (right).
fig, ax = plt.subplots(nrows=1, ncols=2, squeeze=False, figsize=(13, 5))
image_ax, lines_ax = ax[0]
H = Bs[-1] @ Ws[-1].T

im = image_ax.imshow(H, aspect='auto', interpolation='nearest')
fig.colorbar(im);
lines_ax.plot(H.T)

In [None]:
# One data row against the same row after projecting the data onto the final basis.
idx = 3
Q_final = Qs[-1]
projected = Q_final @ Q_final.T @ A_full
plt.plot(A_full[idx])
plt.plot(projected[idx])

# real data

In [None]:
import timeit
import numpy as np
import adaptive_latents.input_sources as ins
from adaptive_latents import default_rwd_parameters, Bubblewrap, SymmetricNoisyRegressor
from proSVD import proSVD

# Load the first "buzaki" dataset and put the behavior on the same time base as the observations.
identifier = ins.datasets.individual_identifiers["buzaki"][0]
# bin width passed to the loader (presumably seconds -- TODO confirm against construct_buzaki_data)
bin_width = 0.03
obs, position_data, obs_t, position_data_t = ins.datasets.construct_buzaki_data(individual_identifier=identifier, bin_width=bin_width)

# resample the behavior at the observation bin centers, then keep only its first two columns
position_data = ins.functional.resample_behavior(raw_behavior=position_data, bin_centers=obs_t, t=position_data_t)
position_data = position_data[:,:2]


In [None]:
# Trajectory traced by the two retained behavior columns.
x_pos, y_pos = position_data[:, 0], position_data[:, 1]
plt.plot(x_pos, y_pos)

In [None]:
# Feed the first 100 * stride time bins of the recording to the incremental SVD, stride bins per block.
max_k = 6
stride = 5

# NOTE(review): the z-scored [obs, position] matrix is replaced by the raw
# observations on the next line, so only `obs` is actually used -- confirm which is intended.
to_partition = scipy.stats.zscore(np.hstack([obs, position_data]))
to_partition = obs

# each block is transposed so that columns are time bins
A_sequence = [to_partition[start:start + stride].T for start in range(0, 100 * stride, stride)]

A_full = np.column_stack(A_sequence)
QBWs = biased_pro_svd(A_sequence, max_k=max_k)
Qs, Bs, Ws = zip(*QBWs)

In [None]:
# Reconstruction from the final factors; the printed number is the norm of its difference from the data.
Q_last, B_last, W_last = Qs[-1], Bs[-1], Ws[-1]
reconstruction = Q_last @ B_last @ W_last.T

fig, ax = plt.subplots(figsize=(20, 5))
im = ax.imshow(reconstruction, aspect='auto', interpolation='nearest')
print(np.linalg.norm(reconstruction - A_full))
fig.colorbar(im);

In [None]:
# Final basis Q as an image (left) and as one line per column (right).
fig, (image_ax, lines_ax) = plt.subplots(1, 2)
image_ax.imshow(Qs[-1])
lines_ax.plot(Qs[-1]);

In [None]:
# Rows of B @ W.T from the final step, plotted as time series.
latents = Bs[-1] @ Ws[-1].T
plt.plot(latents.T);

In [None]:
# Per-column norm of the step-to-step change in Q, skipping the initial Q.
Qs_slice = Qs[1:]
step_change = np.linalg.norm(np.diff(np.array(Qs_slice), axis=0), axis=1)
plt.plot(np.arange(1, len(Qs_slice)), step_change);

In [None]:
from sklearn.decomposition import non_negative_factorization

In [None]:
# Non-negative factorization of the same data with max_k components, for comparison.
# Note: this rebinds the notebook-level names W and H used by earlier cells.
W, H, _ = non_negative_factorization(A_full,n_components=max_k)

In [None]:
# NMF reconstruction; the printed number is the norm of its difference from the data.
reconstruction = W @ H
reconstruction_error = np.linalg.norm(reconstruction - A_full)

fig, ax = plt.subplots(figsize=(20, 5))
im = ax.imshow(reconstruction, aspect='auto', interpolation='nearest')
print(reconstruction_error)
fig.colorbar(im);

In [None]:
# NMF factor W as an image (left) and as one line per component (right).
fig, (image_ax, lines_ax) = plt.subplots(1, 2)
image_ax.imshow(W)
lines_ax.plot(W);

In [None]:
# NMF factor H, one line per component.
plt.plot(np.transpose(H))

### Procrustes

In [None]:
# Random 3x3 target matrix with entries drawn uniformly from [0, 1).
A = rng.random((3, 3))

In [None]:
# Show the target matrix A as an image.
fig, ax = plt.subplots()
ax.imshow(A);

In [None]:
# Sample 10,000 random special-orthogonal matrices and record, for each one, the
# sum of squared differences from A and its first column.
n_dims = A.shape[0]
identity = np.eye(n_dims)
random_mat_maker = special_ortho_group(dim=n_dims, seed=rng).rvs

magnitudes, points = [], []
for _ in range(10_000):
    q = random_mat_maker()
    # every sample must be orthogonal
    assert np.allclose(q.T @ q, identity) and np.allclose(q @ q.T, identity)
    points.append(q[:, 0])
    magnitudes.append(np.square(A - q).sum())

In [None]:
# First two coordinates of the sampled first columns.
points = np.array(points)
plt.plot(*points[:, :2].T, '.')
plt.axis("equal");

In [None]:
# Orthogonal Procrustes solution: q = U @ V^T from the SVD of A. Its distance to A
# (red line) is compared with the distances of the random samples (histogram).
# NOTE(review): U @ V^T may have determinant -1, whereas the samples come from the
# special orthogonal group (determinant +1) -- confirm this is the intended comparison.
u, s, vh = np.linalg.svd(A)
q = u @ vh
identity = np.eye(A.shape[0])
assert np.allclose(q.T @ q, identity) and np.allclose(q @ q.T, identity)
m = np.square(A - q).sum()

plt.axvline(m, color='red')
plt.hist(magnitudes, bins=100);