In [14]:
import numpy as np
import matplotlib.pyplot as plt
from adaptive_latents import CenteringTransformer, Pipeline, proSVD, sjPCA, mmICA, AnimationManager, KernelSmoother, proPLS, Bubblewrap
import adaptive_latents as al
from scipy.stats import special_ortho_group
from adaptive_latents import datasets
from tqdm.notebook import tqdm
import IPython.display as ipd


# Short alias for the dataset class used by the proPLS section.
TostadoMarcos24Dataset = datasets.TostadoMarcos24Dataset

# One seeded generator, reused by the stochastic simulation cells below.
SEED = 0
rng = np.random.default_rng(SEED)

## Run proPLS on the first TostadoMarcos24 sub-dataset

In [16]:
d = TostadoMarcos24Dataset(TostadoMarcos24Dataset.sub_datasets[0])

pls = proPLS(k=20)
p = Pipeline([
    CenteringTransformer(input_streams={0:'X'}),
    CenteringTransformer(input_streams={1:'X'}),
    pls,
])
streams = p.offline_run_on([d.neural_data, d.behavioral_data], convinient_return=False)
ts = {}
for stream in streams:
    while np.isnan(streams[stream][0]).any():
        streams[stream].pop(0)
    ts[stream] = np.array([x.t for x in streams[stream]])
    streams[stream] = np.squeeze(streams[stream])
    

## jPCA plane convergence in simulated data

In [17]:
%matplotlib qt

m= 50
n = 3
t = np.linspace(0, (m / 50) * np.pi * 2, m)
circle = np.column_stack([np.cos(t), np.sin(t)]) @ np.diag([10, 10])
C = special_ortho_group(dim=n, seed=rng).rvs()[:, :2]
orth = np.cross(*C.T)
X = (circle @ C.T) + rng.normal(size=(m, n)) + rng.normal(size=(m, 1)) * orth * 0

jpca = sjPCA()
p = Pipeline([
    jpca
])

with AnimationManager(projection='3d', figsize=(5,5)) as am:
    ax = am.axs[0,0]
    outputs = np.zeros_like(X)
    for i, output in enumerate(p.streaming_run_on(X)):
        outputs[i] = output
        
        ax.cla()
        ax.scatter(X[:i,0], X[:i,1], X[:i,2])

        U = jpca.get_U()[:,:2]
        mesh_a, mesh_b = np.meshgrid(np.linspace(-10,10,2), np.linspace(-10,10,2))

        mesh_X, mesh_Y, mesh_Z = (mesh_a[None,:,:].T * U[:,0] + mesh_b[None,:,:].T * U[:,1]).T

        ax.plot_surface(mesh_X, mesh_Y, mesh_Z, alpha=.1)
                
                
        U = U * 10
        ax.plot([0, U[0,0]], [0, U[1,0]], [0, U[2,0]], color='k')
        ax.plot([0, U[0,1]], [0, U[1,1]], [0, U[2,1]], color='r')

        ax.axis('equal')
        ax.set_xlim(-20,20)
        ax.set_ylim(-20,20)
        ax.set_zlim(-20,20)
        
        am.grab_frame()
        
        
        

## jPCA plane convergence in real data

In [32]:
%%capture
d = datsets.Leventhal24uDataset()

centerer = CenteringTransformer()
pro = proSVD(k=3, whiten=False)
jpca = sjPCA()

p = Pipeline([
    KernelSmoother(tau=8),
    centerer,
    pro,
    KernelSmoother(tau=8),
])
X = p.offline_run_on(d.neural_data.a[1000:2000])

In [33]:
%matplotlib qt
"""Shows off finding a plane in real data."""

sub_X = X

with AnimationManager(figsize=(10,5), make_axs=False, fps=40) as am:
    ax1 = am.fig.add_subplot(1,2,1, projection='3d')
    ax2 = am.fig.add_subplot(1,2,2)
    outputs = np.zeros_like(sub_X)
    for i, output in enumerate(tqdm(jpca.streaming_run_on(sub_X))):
        outputs[i] = output
        
        ax1.cla()
        ax1.scatter(sub_X[:i,0], sub_X[:i,1], sub_X[:i,2], s=1, edgecolors=None)
        
        tail_points = sub_X[max(i-5,0):i]
        ax1.plot(tail_points[:,0], tail_points[:i,1], tail_points[:i,2], color='C0')

        U = jpca.get_U()[:,:2]
        mesh_a, mesh_b = np.meshgrid(np.linspace(-1,1,2), np.linspace(-1,1,2))

        mesh_X, mesh_Y, mesh_Z = (mesh_a[None,:,:].T * U[:,0] + mesh_b[None,:,:].T * U[:,1]).T

        ax1.plot_surface(mesh_X, mesh_Y, mesh_Z, alpha=.1)


        U = U * .1
        ax1.plot([0, U[0,0]], [0, U[1,0]], [0, U[2,0]], color='k')
        ax1.plot([0, U[0,1]], [0, U[1,1]], [0, U[2,1]], color='r')

        ax1.axis('equal')
        ax1.set_xlim(-1,1)
        ax1.set_ylim(-1,1)
        ax1.set_zlim(-1,1)
        
        ax2.cla()
        reprojected = jpca.transform(sub_X[:i])
        ax2.scatter(reprojected[:i,0], reprojected[:i,1], s=2, edgecolors=None)

        tail_points = reprojected[max(i-5,0):]
        ax2.plot(tail_points[:,0], tail_points[:i,1], color='C0')

        ax2.plot([0, .1], [0, 0], color='k')
        ax2.plot([0, 0], [0, .1], color='r')
        
        ax2.set_xlim(-1,1)
        ax2.set_ylim(-1,1)
        ax2.set_xticks([])
        ax2.set_yticks([])

        am.grab_frame()
        
# ffmpeg -i movie_2024-08-19-13-33-52.mp4 -vf "split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" output.gif      
        

0it [00:00, ?it/s]