In [None]:
import numpy as np
import matplotlib.pyplot as plt
import adaptive_latents as al
from naumann_utility_functions import make_responses, find_decompositions

# Fixed seed so stochastic results drawn from `rng` (e.g. the null distribution of
# random-direction angles) are reproducible under Restart & Run All.
RNG_SEED = 0
rng = np.random.default_rng(RNG_SEED)


In [None]:
d = al.datasets.Naumann24uDataset(sub_dataset_identifier=2)
# Fit the decomposition on the NaN-free responses (200 random restarts) and keep
# element [0] of the result — presumably the best fit; TODO confirm the ordering
# returned by naumann_utility_functions.find_decompositions.
model = find_decompositions(make_responses(d, non_nan=True), n_restarts=200)[0]
# NaNs in d.neural_data are intentionally not zeroed here: the fit above does not
# depend on it, and `d` is reloaded (and cleaned) before the pipeline uses it.

In [None]:
# Reload the dataset and overwrite every missing (NaN) entry with 0, in place,
# presumably so the streaming pipeline below never propagates NaNs.
d = al.datasets.Naumann24uDataset(sub_dataset_identifier=2)
nan_mask = np.isnan(d.neural_data)
d.neural_data[nan_mask] = 0

# Trial-aligned responses built from the cleaned data.
responses = make_responses(d, non_nan=True)

In [None]:
# Build the streaming pipeline: kernel smoothing (tau=2) followed by a 3-component
# proSVD. Optional stages are toggled by (un)commenting their lines. `pro` and
# `jpca` are pre-set to None so later cells can test whether a stage is active
# (`pro` is immediately rebound by the walrus; `jpca` stays None while sjPCA is
# commented out).
pro = None
jpca = None
p = al.Pipeline([
    al.KernelSmoother(tau=2),
    # al.CenteringTransformer(init_size=100),
    pro:=al.proSVD(k=3),
    # jpca:=al.sjPCA()
])

# Split the recording at the last visual-stimulus sample.
# NOTE(review): visual_stimuli_data is not used in this section of the notebook.
# NOTE(review): if `.slice` follows Python slice semantics, the `-1` stop drops the
# final sample; use `None` if the last sample should be included — confirm intent.
visual_stimuli_data = d.neural_data.slice(slice(0, d.last_visual_sample))
opto_stimuli_data = d.neural_data.slice(slice(d.last_visual_sample, -1))

# First pass over the opto-stimulation data, with every step still able to update
# (presumably proSVD adapts its subspace as it streams).
online_output = p.offline_run_on(opto_stimuli_data)

# Replace the smoother with a fresh instance built from the same parameters
# (discarding any state from the first pass) and freeze every other step, then run
# again over the same data. The order matters: this must come after the first run.
for i, step in enumerate(p.steps):
    if isinstance(step, al.KernelSmoother):
        p.steps[i] = al.KernelSmoother(**step.get_params())
    else:
        step.freeze()
offline_output = p.offline_run_on(opto_stimuli_data)


In [None]:
%matplotlib qt

output = offline_output

# Offsets (in samples) of the highlighted window, counted from the first output
# sample that comes after each opto-stimulation time.
STIM_WINDOW_START = 8
STIM_WINDOW_STOP = 20

# Latent trajectory in the first three output dimensions, with the post-stimulus
# window of each opto stimulation overlaid in color C1.
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.plot(output[:,0], output[:,1], output[:,2])
ax.set(xlabel='dim 0', ylabel='dim 1', zlabel='dim 2')
o = output.as_array()
# NOTE(review): the final stimulation is skipped via [:-1] — confirm intent.
for i, stim_t in enumerate(d.opto_stimulations.time[:-1]):
    samples_after = np.nonzero(output.t > stim_t)[0]
    if samples_after.size == 0:
        # Stimulation happens after the last output sample: nothing to highlight
        # (previously this raised IndexError).
        continue
    stim = samples_after[0]
    s = slice(stim + STIM_WINDOW_START, stim + STIM_WINDOW_STOP)
    ax.plot(o[s, 0], o[s, 1], o[s, 2], 'C1')
    # ax.text(o[s.start,0], o[s.start,1], o[s.start,2], s=d.opto_stimulations.loc[i,'stim_name'])


# Mark the origin of the plotted space.
ax.scatter(0, 0, 0, color='k')

# Reference segment along the all-ones direction, from the origin to (20, 20, 20).
# `arrow` has one row per axis (x, y, z) holding [start, end].
arrow = np.column_stack([np.zeros(3), np.full(3, 20.0)])
x_coords, y_coords, z_coords = arrow[0], arrow[1], arrow[2]
ax.plot(x_coords, y_coords, z_coords, color='C6', label='1s vector')


# Project the first TCA factor (the loading) into the plotted coordinates and
# draw it as a segment from the origin.
loading = model.factors[0]
# NOTE(review): pro.Q is truncated to its first `loading.size` rows, which assumes
# the neurons kept by make_responses(non_nan=True) are the leading rows of the
# data — confirm.
v = loading @ pro.Q[:loading.size]

if jpca is not None:
    U = jpca.get_U()
    if pro.Q.shape[1] == 3:
        # Keep the first two jPCA axes and replace the third with their cross
        # product, i.e. the normal of the plane they span.
        normal = np.cross(U[:, 0], U[:, 1])
        U = np.column_stack([U[:, 0], U[:, 1], normal])
    v = v @ U

# Rows of `arrow` are coordinates (x, y, z, ...); columns are [origin, tip].
origin = np.zeros((1, v.shape[1]))
arrow = np.concatenate([origin, v]).T
x_coords, y_coords, z_coords = arrow[0], arrow[1], arrow[2]
ax.plot(x_coords, y_coords, z_coords, color='C3', label='direction from TCA')

ax.legend()
ax.axis('equal')



In [None]:
# Is the TCA loading unusually close to the 2-D plane spanned by the first two
# latent axes? Compare its angle to that plane against the angles of uniformly
# random unit directions in the same space.
N_RANDOM_DIRECTIONS = 10_000

loading = model.factors[0].T
loading = loading / np.linalg.norm(loading)

# With jPCA, use its first two axes; without it (jpca is None whenever sjPCA is
# commented out of the pipeline), fall back to the first two proSVD components.
# Previously jpca.get_U() was called unconditionally and crashed when jpca was None.
if jpca is not None:
    high_d_plane = (pro.Q @ jpca.get_U())[:, :2]
else:
    high_d_plane = pro.Q[:, :2]
# NOTE(review): keeps only the first `loading.size` rows (neurons) — same
# assumption as the projection cell above; confirm the neuron ordering matches.
high_d_plane = high_d_plane[:loading.size, :]

angles = []
for _ in range(N_RANDOM_DIRECTIONS):
    random_direction = rng.normal(size=(high_d_plane.shape[0], 1))
    random_direction = random_direction / np.linalg.norm(random_direction)
    angles.append(np.degrees(al.utils.column_space_distance(random_direction, high_d_plane)))

angle = np.degrees(al.utils.column_space_distance(loading, high_d_plane))
# Empirical fraction of random directions at least as close to the plane.
fraction_closer = np.mean(np.asarray(angles) <= angle)

fig, ax = plt.subplots()
ax.axvline(angle, color='r', label='TCA loading')
ax.hist(angles, 100, label='random directions')
ax.set(
    xlabel='angle to plane (degrees)',
    ylabel='count',
    title=f'{fraction_closer:.3f} of random directions are at least as close',
)
ax.legend();

In [None]:
%matplotlib inline
fig, ax = plt.subplots()
# Last factor matrix of the decomposition, one line per component
# (assumes it is stored as (rank, n_samples) so .T is (n_samples, rank) — TODO confirm).
time_factors = model.factors[-1].T
ax.plot(time_factors)

# Two candidate response shapes on a fixed 24-sample axis, both clipped at 0;
# the constants appear to be hand-tuned to the factor curves.
x = np.arange(24)
sine_bump = np.maximum(np.sin(x / 6 - 1.15) / 8.9, 0)
quadratic_bump = np.maximum(-(x - 16.2) ** 2 / 500 + .115, 0)
ax.plot(x, sine_bump)
ax.plot(x, quadratic_bump)
