In [None]:
import copy
import functools
import warnings
from collections import namedtuple
from types import SimpleNamespace

import cvxpy as cp
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from scipy.linalg import null_space
from tqdm.autonotebook import tqdm

import adaptive_latents as al
from adaptive_latents import proSVD
from adaptive_latents.prediction_regression_run import pred_reg_run
# Seeded so the toy 3-D proSVD demo (the only consumer of `rng`) is reproducible under Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)


In [None]:
# Load the `Odoherty21Dataset` once; later cells read `dataset.neural_data` and
# `dataset.behavioral_data` from this module-level object.
dataset = al.datasets.Odoherty21Dataset()

In [None]:
# Sanity check of the order-statistic idiom used by `towards_prosvd_direction_keep_top`:
# after an ascending `np.sort`, index `-k` is the k-th largest value.
assert np.sort([1, 2, 3])[-1] == 3
assert np.sort([3, 1, 2])[-2] == 2

In [None]:
# Default stimulation times (in the dataset's `.t` units) used by the serial-evaluation cells.
stim_times = [100, 110, 200]

def normalize(x):
    """Return ``x`` divided by its ``np.linalg.norm`` (2-norm of a vector, Frobenius norm of a matrix)."""
    length = np.linalg.norm(x)
    return x / length

def get_pipeline_proj_matrix(p):
    """Return the projection matrix of the last ``al.proSVD`` step in ``p.steps``.

    The steps are scanned from the end. The first proSVD met yields ``step.Q.T``, so rows
    are the basis vectors. Any ``al.sjPCA`` / ``al.mmICA`` step met before it (i.e. later in
    the pipeline) only emits a warning; it is not folded into the returned matrix.
    If no proSVD step exists, a warning is emitted and the scalar ``1`` is returned.
    """
    for stage in reversed(p.steps):
        if isinstance(stage, al.proSVD):
            return stage.Q.T
        # Kept as two statements so each warning keeps its own source line.
        if isinstance(stage, al.sjPCA):
            warnings.warn('check that this order of ops is correct')
        if isinstance(stage, al.mmICA):
            warnings.warn('check that this order of ops is correct')
    else:
        warnings.warn('check that this shape works')
        return 1

def give_decay(x, decay_time=6, divisor=1.5):
    """Expand a stimulation vector into an exponentially decaying response.

    Returns an array of shape ``(decay_time, x.size)`` whose row ``k`` is
    ``exp(-k / divisor) * x.flatten()``; each row is later added to one neural sample.
    """
    envelope = np.exp(-np.arange(decay_time) / divisor)
    return np.outer(envelope, x.flatten())

def in_out_ratio(pipeline, ratio=.5, magnitude=10):
    """Stimulation blending an in-subspace and an out-of-subspace direction.

    * in-subspace part: row 0 of the pipeline's projection matrix;
    * out-of-subspace part: the all-ones vector minus ``P.T @ P @ ones``, unit-normalized
      (this is the orthogonal complement only if P's rows are orthonormal — assumed for proSVD).

    The blend ``ratio * in + (1 - ratio) * out`` is not renormalized, then scaled by
    ``magnitude`` and passed through ``give_decay``.
    """
    basis = get_pipeline_proj_matrix(p=pipeline)
    ones = np.ones(shape=basis.shape[1])
    residual = ones - basis.T @ basis @ ones
    out_of_space = residual.flatten() / np.linalg.norm(residual)
    in_space = basis[0].flatten()
    blended = (ratio * in_space + (1 - ratio) * out_of_space) * magnitude
    return give_decay(blended)
    
def single_neuron_null(pipeline, magnitude=10):
    """Stimulate exactly one column: the one with the smallest summed |loading| in the projection matrix.

    That column contributes least to the current subspace, so the stimulation lies mostly
    outside it. The chosen entry is set to ``magnitude`` and the vector goes through ``give_decay``.
    """
    basis = get_pipeline_proj_matrix(p=pipeline)
    least_loaded = np.abs(basis).sum(axis=0).argmin()
    stim = np.zeros(basis.shape[1])
    stim[least_loaded] = magnitude
    return give_decay(stim)

def towards_null_direction(pipeline, magnitude=10):
    """Stimulation along the all-ones vector with its component in the projection subspace removed.

    The residual ``ones - P.T @ P @ ones`` is unit-normalized, scaled by ``magnitude`` and
    passed through ``give_decay``. It is orthogonal to the subspace assuming P has orthonormal rows.
    """
    basis = get_pipeline_proj_matrix(p=pipeline)
    ones = np.ones(shape=basis.shape[1])
    outside = ones - basis.T @ basis @ ones
    unit_outside = outside / np.linalg.norm(outside)
    return give_decay(unit_outside * magnitude)

def towards_prosvd_direction(pipeline, magnitude=10, component=0):
    """Stimulation along row ``component`` of the projection matrix, rescaled to norm ``magnitude``."""
    row = get_pipeline_proj_matrix(p=pipeline)[component]
    return give_decay(magnitude * normalize(row))

def towards_prosvd_direction_keep_top(pipeline, magnitude=10, component=0, zero_negative=False, keep_top=30):
    """Sparsified stimulation along one row of the pipeline's projection matrix.

    Keeps the ``keep_top`` largest-magnitude loadings of row ``component``. Loadings tied
    with the cutoff are kept too. Optionally zeroes negative loadings, rescales to norm
    ``magnitude`` and passes the vector through ``give_decay``.

    ``keep_top`` larger than the row length keeps every loading.

    Raises:
        ValueError: if ``keep_top < 1``, or if nothing non-zero survives the filtering
            (normalizing would otherwise produce a NaN stimulation).
    """
    if keep_top < 1:
        raise ValueError(f"keep_top must be >= 1, got {keep_top}")

    mat = get_pipeline_proj_matrix(p=pipeline)
    # `mat` is `Q.T` of the live proSVD step, so `mat[component]` is a view into the basis.
    # Copy before the in-place zeroing below, otherwise the pipeline's Q gets overwritten.
    response_direction = mat[component].copy()

    squared = response_direction ** 2
    n_keep = min(keep_top, squared.size)
    threshold = np.sort(squared)[-n_keep]  # n_keep-th largest squared loading
    response_direction[squared < threshold] = 0
    if zero_negative:
        response_direction[response_direction < 0] = 0

    if not np.any(response_direction):
        raise ValueError("no non-zero loadings left after keep_top/zero_negative filtering")
    response_direction = normalize(response_direction) * magnitude

    return give_decay(response_direction)

def exp_resp(pipeline, magnitude=10):
    """Stimulation along the dataset's mean neural-activity vector.

    The column-mean of ``dataset.neural_data`` is unit-normalized *before* scaling, so the
    first response row has norm ``magnitude``. This matches the other response generators.
    Previously the vector was scaled and then normalized, which silently discarded ``magnitude``.

    ``pipeline`` is unused; it is accepted so this has the common ``make_response(pipeline)`` signature.
    """
    mean_activity = np.mean(dataset.neural_data, axis=0)
    response_direction = normalize(mean_activity) * magnitude

    return give_decay(response_direction)

def zero_resp(pipeline):
    """All-zero control 'stimulation' with one column per column of ``dataset.neural_data``.

    ``pipeline`` is unused; it is accepted for the common ``make_response(pipeline)`` signature.
    """
    n_columns = dataset.neural_data.shape[1]
    return give_decay(np.zeros(n_columns))


In [None]:
%matplotlib qt
p = al.proSVD(k=2)

points = rng.normal(size=(300,3)) * np.array([.1,1,1])
for x in points.reshape(3,-1,3):
    p.partial_fit_transform(x)

x = lambda: None
x.steps = [p]

response = towards_null_direction(x, magnitude=5)

fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.scatter(*points.T)


ax.scatter(*response.T, color='r')
ax.set_xlabel('null axis')
ax.set_ylabel('manifold axis 1')
ax.set_zlabel('manifold axis 2')

ax.set_title('null space stimulation control')

ax.legend(['observed points', 'stimulation'])

ax.axis('equal');



In [None]:
def evaluate_in_series(stim_times, make_response):
    """Stream the dataset through one pipeline, adding stimulation responses to the neural data.

    The pipeline comes from ``pred_reg_run`` with zero-width behavioral/target data, so only
    neural data drives it. For each entry of ``stim_times``, the first stream-0 sample whose
    ``.t`` exceeds it calls ``make_response(pipeline)`` on the live pipeline. The rows of the
    returned array are then added, one per stream-0 sample, to the incoming data before it is
    fit. Responses that overlap in time are summed.

    Streaming stops once an output's ``.t`` exceeds ``stim_times[-1] + 10``. This presumably
    assumes ``stim_times`` is non-empty and sorted; TODO confirm with callers.

    Returns ``(latents, evaluation)``:
        latents: the stream-0 outputs stacked with ``ArrayWithTime.from_list``.
        evaluation: the ``pred_reg_run`` result whose pipeline was used.
    """
    fired = [False] * len(stim_times)

    evaluation = pred_reg_run(
        neural_data=dataset.neural_data,
        behavioral_data=dataset.behavioral_data[:, :0],
        target_data=dataset.behavioral_data[:, :0],
        dim_red_method='pro',
        predict=False, evaluate=False)
    pipeline = evaluation.pipeline

    active_responses = []  # per triggered stim: the response rows not yet injected
    collected = []

    for data, stream in al.Pipeline().streaming_run_on(evaluation.sources, return_output_stream=True):
        if stream == 0:
            for idx, stim_t in enumerate(stim_times):
                if fired[idx] or not data.t > stim_t:
                    continue
                active_responses.append(list(make_response(pipeline)))
                fired[idx] = True

            for remaining in active_responses:
                if remaining:
                    data = data + remaining.pop(0)

        data, stream = pipeline.partial_fit_transform(data, stream, return_output_stream=True)
        if stream == 0:
            collected.append(data)

        if data.t > stim_times[-1] + 10:
            break

    latents = al.ArrayWithTime.from_list(collected, squeeze_type="to_2d")
    return latents, evaluation
        

In [None]:
def subtract_aligned_indices(a, b):
    """Subtract two timestamped arrays after aligning them on time.

    The input whose first ``.t`` is later is the minuend. The sign of the result therefore
    does not depend on argument order: it is always later-starting minus earlier-starting.
    If both start together, the result is ``a - b``.

    For every timestamp of the later array, ``np.searchsorted`` (left side) finds the
    matching row of the earlier array. An exact timestamp match is taken when present;
    otherwise the next later row is used. Timestamps past the end of the earlier array
    raise IndexError. The result is stamped with the matched times of the earlier array.
    """
    earlier, later = (a, b) if a.t[0] < b.t[0] else (b, a)
    order = np.argsort(earlier.t)
    matched = order[np.searchsorted(earlier.t, later.t, sorter=order)]
    return al.ArrayWithTime(later - earlier[matched], earlier.t[matched])

In [None]:
# One simulated run:
#   pipeline    - its own pipeline
#   response    - response rows still to be injected (consumed front-first)
#   outputs     - stream-0 outputs collected so far
#   expire_time - how long (in `.t` units) it keeps running after its first output
Sim = namedtuple('Sim', ['pipeline', 'response', 'outputs', 'expire_time'])


def _sim_expired(sim, t):
    """Whether more than ``sim.expire_time`` has passed between ``sim``'s first output and ``t``.

    A sim with no outputs yet measures from ``t`` itself (elapsed 0), so it is not expired
    for a non-negative ``expire_time``.
    """
    start = sim.outputs[0].t if sim.outputs else t
    return t - start > sim.expire_time


def _find_prosvd_step(pipeline):
    """Return the last ``al.proSVD`` step of ``pipeline``: the step ``freeze_prosvd`` freezes."""
    for step in reversed(pipeline.steps):
        if isinstance(step, al.proSVD):
            return step
    raise ValueError("freeze_prosvd=True, but the pipeline has no proSVD step")


def evaluate_in_parallel(stim_times, make_response, end_t=200, response_cutoff=5, freeze_prosvd=False):
    """Compare an unstimulated run with one forked, stimulated run per stimulation time.

    A baseline pipeline (``pred_reg_run``, zero-width behavioral/target data) is streamed
    over the dataset. When a stream-0 sample's ``.t`` first exceeds the next stimulation
    time, the following happens:

    * ``make_response`` is called on the baseline pipeline;
    * the baseline pipeline is deep-copied into a new sim;
    * the response rows are added to that sim's consecutive stream-0 samples.

    Every sim sees the same stream (each on its own deep copy of the sample).

    Args:
        stim_times: stimulation times; they are sorted and must be unique.
        make_response: ``pipeline -> array`` whose rows are injected one per stream-0 sample.
        end_t: the baseline stops once its outputs' ``.t`` exceeds its first output's ``.t``
            by more than this. It is a duration, not an absolute time.
        response_cutoff: same, for each stimulated sim (measured from its own first output).
        freeze_prosvd: freeze the proSVD step of each stimulated copy at stimulation time.

    Returns:
        ``(latents, evaluation)``, where ``latents[0]`` is the baseline and ``latents[i]``
        the i-th stimulated run (in time order), each from ``ArrayWithTime.from_list``.
    """
    pending_stims = sorted(stim_times)
    assert len(pending_stims) == len(set(pending_stims)), "stim_times must not contain duplicates"

    evaluation = pred_reg_run(
        neural_data=dataset.neural_data,
        behavioral_data=dataset.behavioral_data[:, :0],
        target_data=dataset.behavioral_data[:, :0],
        dim_red_method='pro',
        predict=False, evaluate=False)

    sims = [Sim(evaluation.pipeline, [], [], end_t)]

    for d, s in al.Pipeline().streaming_run_on(evaluation.sources, return_output_stream=True):
        if s == 0 and pending_stims and d.t > pending_stims[0]:
            pending_stims.pop(0)
            response = make_response(sims[0].pipeline)
            stim_sim = Sim(copy.deepcopy(sims[0].pipeline), list(response), [], response_cutoff)
            if freeze_prosvd:
                # Locate the proSVD step instead of assuming it sits at steps[-3].
                _find_prosvd_step(stim_sim.pipeline).freeze()
            sims.append(stim_sim)

        any_active = False
        for sim in sims:
            if _sim_expired(sim, d.t):
                continue
            any_active = True
            # Each sim gets its own copy so injected responses never leak into the shared stream.
            inner_d = copy.deepcopy(d)
            inner_s = copy.deepcopy(s)
            if inner_s == 0 and sim.response:
                inner_d += sim.response.pop(0)

            inner_d, inner_s = sim.pipeline.partial_fit_transform(inner_d, inner_s, return_output_stream=True)

            if inner_s == 0:
                sim.outputs.append(inner_d)

        # The previous `if not len(sims): break` could never fire, because sims is never shrunk.
        # As a result, the whole remaining dataset was streamed after every sim had expired.
        # Stop once nothing is running and no stimulation is pending. This assumes the
        # stream's `.t` is non-decreasing, so an expired sim stays expired
        # (`evaluate_in_series` makes the same assumption).
        if not any_active and not pending_stims:
            break

    latents = [al.ArrayWithTime.from_list(sim.outputs, squeeze_type="to_2d") for sim in sims]

    return latents, evaluation
    

In [None]:
ys = np.linspace(.1, 50, 10)  # stimulation magnitudes
xs = np.linspace(0, 1, 7)     # proportion of the stimulation lying in the proSVD subspace
rrs = []
for freeze in [False, True]:
    rr = []
    for magnitude in tqdm(ys):
        rms_by_ratio = []
        for ratio in xs:
            make_response = functools.partial(in_out_ratio, ratio=ratio, magnitude=magnitude)
            latents, _ = evaluate_in_parallel([100], make_response, end_t=120, freeze_prosvd=freeze)
            diff = subtract_aligned_indices(latents[1].slice(None, 60), latents[0])
            rms_by_ratio.append(np.sqrt((diff ** 2).mean()))
        # Express each RMS response relative to the fully in-space stimulation (last ratio, 1.0).
        rr.append(np.array(rms_by_ratio) / rms_by_ratio[-1])
    rr = np.array(rr)
    rrs.append(rr)
rrs = np.array(rrs)

In [None]:
%matplotlib inline
fig, axs = plt.subplots(ncols=2, figsize=(10,4), sharex=True, sharey=True)

colors = plt.cm.jet(np.linspace(0,1,rr.shape[0]))
for ax, inner_rr in zip(axs, rrs):
    for idx, rrr in enumerate(inner_rr):
        ax.plot(xs, rrr, color=colors[idx])
        
    ax.set_xlabel('proportion of stimulation in-space')
    ax.set_ylabel('magnitude of total response (compared to in-space stim)')
ax.legend([f'{y:.1f}'for y in ys], title='stim. magnitude')
ax.set_title('large stimulations scale nonlinearly')


In [None]:
# Signed per-latent difference (stimulated run minus baseline). `latents` still holds the
# last run of the ratio sweep above (hidden state: re-run that cell first).
difference = subtract_aligned_indices(latents[1].slice(None, 60), latents[0])

fig, ax = plt.subplots()
ax.plot(difference.t, difference)
ax.set_xlabel('experiment time')
ax.set_ylabel('stimulated - unstimulated latent (a.u.)')
ax.set_title('Difference between stimulated and unstimulated latents');



In [None]:
# Raw latents around the stimulation at t=100: baseline (left) vs. stimulated run (right).
# (The unused recomputation of `difference` that used to open this cell was removed.)
fig, axs = plt.subplots(ncols=2, figsize=(10, 5), sharex=True)
for ax, run, title in zip(axs, latents[:2], ['unstimulated latents', 'stimulated latents']):
    ax.plot(run.t, run)
    ax.set_xlabel('experiment time')
    ax.set_title(title)
axs[0].set_xlim([99.9, 105.1])  # x is shared, so this zooms both panels
axs[0].set_ylabel('latent magnitude (a.u.)');



In [None]:

# Sweep how many of the largest-magnitude loadings of proSVD component 0 the stimulus keeps.
# The last entry is the "full vector" case (the plot below labels it so). Negative loadings
# are kept there, and nothing is zeroed if 130 >= the neuron count.
# NOTE(review): 130 presumably equals the neuron count; confirm.
n_components = [1, 5, 10, 20, 50, 130]
latents = []
for n in tqdm(n_components):
    keep_full_vector = n == n_components[-1]
    make_response = functools.partial(
        towards_prosvd_direction_keep_top, magnitude=20, keep_top=n, zero_negative=not keep_full_vector)
    l, _ = evaluate_in_parallel([100], make_response, end_t=120, freeze_prosvd=True)
    latents.append(l[1])  # stimulated run only; `l[0]` of the last run is the baseline used below


In [None]:
fig, axs = plt.subplots(nrows=3, ncols=len(n_components), sharex=True, sharey='row', figsize=(15, 5))

baseline = l[0]  # unstimulated run left over from the last iteration of the sweep cell
for col, (n_kept, resp) in enumerate(zip(n_components, latents)):
    axs[0, col].plot(resp)
    axs[1, col].plot(resp - latents[-1])
    axs[2, col].plot(subtract_aligned_indices(resp, baseline))
    axs[0, col].set_title(f'kept top {n_kept}')
axs[0, -1].set_title('full vector')
for row, label in enumerate(['resp.', 'resp. - resp.[-1]', 'resp. - no stim']):
    axs[row, 0].set_ylabel(label)



In [None]:
%matplotlib qt
fig, ax = plt.subplots()
l1, _ = evaluate_in_series(stim_times, functools.partial(towards_null_direction, magnitude=10))

ax.plot(l1.t, l1, '.-')
ax.set_xlim([99, 104])
ax.set_title("Stimulated latents - unstimulated latents")

In [None]:
%matplotlib qt
fig, ax = plt.subplots()

l1, _ = evaluate_in_series(stim_times, functools.partial(towards_null_direction, magnitude=100))
l2, _ = evaluate_in_series(stim_times, zero_resp)

ax.plot(l1.t, l1 - l2, '.-');
ax.set_xlim([99, 104])
ax.set_title("Stimulated latents - unstimulated latents")

In [None]:
# Per-column spread of the unstimulated control run `l2` (from the cell above), ignoring NaNs.
# It serves as a reference scale for the stimulated-minus-unstimulated differences.
# Presumably l2 is (time, latent); confirm.
np.nanstd(l2, axis=0)

In [None]:
%matplotlib inline

differences = []
magnitudes = [.5,1,2,4,8, 16, 32, 64, 128]
for m in tqdm(magnitudes):
    l1, _ = evaluate_in_series(stim_times, functools.partial(towards_null_direction, magnitude=m))
    l2, _ = evaluate_in_series(stim_times, zero_resp)
    differences.append(np.nanmean(np.abs(l1 - l2)))

fig, ax = plt.subplots()
ax.plot(magnitudes, differences, '.')
ax.set_xlabel('stimulation magnitude')
ax.set_ylabel('perturbation magnitude')

In [None]:
%matplotlib qt
plt.plot(l1.t, l1);