In [None]:
import adaptive_latents as al
from adaptive_latents import ArrayWithTime
from adaptive_latents.input_sources import LDS
import matplotlib.pyplot as plt
from adaptive_latents import VJF, Bubblewrap, KalmanFilter
import numpy as np

# Fixed seed so any draws from `rng` are reproducible under Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)


In [None]:
# Default-constructed Bubblewrap, created here only to display its default
# `input_streams` mapping; `b` is not referenced in the later cells shown here.
b = Bubblewrap()

b.input_streams

In [None]:
# Default-constructed VJF model; `v` is not referenced in the later cells shown here.
v = VJF()

In [None]:
# Simulated system: `X` is the trajectory plotted and regressed on below, and `stim`
# is the stimulation signal fed to StimRegressor; the first return value is unused.
# NOTE(review): `60 + 1/np.pi` is presumably chosen non-integer so the rotation never
# lines up exactly with the sampling grid — confirm and consider naming it as a constant.
# (Removed a no-op `stim.t = stim.t` self-assignment that followed this call.)
_, X, stim = LDS.nest_dynamical_system(100, transitions_per_rotation=60 + 1/np.pi)

In [None]:
%matplotlib inline
# Two orthogonal views of the first three dimensions of the simulated trajectory.
views = [('top view (elev=90)', 90), ('side view (elev=0)', 0)]
fig, axs = plt.subplots(ncols=len(views), subplot_kw={'projection': '3d'})
for ax, (title, elev) in zip(axs, views):
    ax.plot(X[:, 0], X[:, 1], X[:, 2])
    ax.axis('equal')
    ax.view_init(elev=elev, azim=0, roll=0)
    ax.set(title=title, xlabel='x[0]', ylabel='x[1]', zlabel='x[2]')
plt.show()


In [None]:
from adaptive_latents.transformer import DecoupledTransformer
from adaptive_latents.regressions import BaseVanillaOnlineRegressor, BaseKNearestNeighborRegressor

In [None]:
class StimRegressor(DecoupledTransformer):
    """Online one-step-ahead predictor that separates autonomous and stimulus-driven parts.

    Two roles of input stream are recognised (configured via ``input_streams``):
      * ``'X'``    -- observations; each new sample is predicted from the previous one.
      * ``'stim'`` -- stimulation signal; a sample with any nonzero entry counts as
        "stimulated".

    Training rule, applied when an ``'X'`` sample arrives:
      * If the most recent stim sample was not stimulated (or no stim sample has been
        seen yet), ``reg`` is updated on the pair (previous X, current X).
      * Otherwise ``reg`` is left untouched. ``stim_reg`` (k-nearest-neighbours, k=2)
        is instead updated on the residual ``current X - reg.predict(previous X)``.

    Before each update, the prediction for the new sample is recorded in
    ``predictions`` (the sum), ``auto_pred`` (the ``reg`` part) and ``stim_pred``
    (the ``stim_reg`` part, zeros when not applied).
    """

    def __init__(self, input_streams=None, spatial_stim_response=True, *args, **kwargs):
        """
        Parameters
        ----------
        input_streams : dict, optional
            Maps stream index -> role (``'X'`` or ``'stim'``); defaults to ``{}``.
        spatial_stim_response : bool
            If True, predictions following a stimulated step include
            ``stim_reg``'s output. If False, only ``reg`` is used for prediction,
            although ``stim_reg`` is still trained.
        *args, **kwargs
            Forwarded to ``DecoupledTransformer.__init__``.
        """
        input_streams = input_streams or {}
        super().__init__(*args, input_streams=input_streams, **kwargs)
        self.reg = BaseVanillaOnlineRegressor()  # previous X -> current X, fit on unstimulated steps
        self.stim_reg = BaseKNearestNeighborRegressor(k=2)  # residual left by `reg` on stimulated steps
        self.spatial_stim_response = spatial_stim_response
        self.last_seen = None  # most recent 'X' sample
        self.last_seen_stim = None  # most recent 'stim' sample (None until one arrives)
        self.predictions = []  # auto_pred + stim_pred for each predicted step
        self.auto_pred = []  # `reg` component of each prediction
        self.stim_pred = []  # `stim_reg` component of each prediction

    def _partial_fit(self, data, stream):
        """Consume one sample from ``stream``.

        On ``'X'`` samples this records a prediction first, then updates a regressor.
        On ``'stim'`` samples it only remembers the value.
        """
        role = self.input_streams[stream]
        if role == 'X':
            if self.last_seen is not None:
                # An X sample can arrive before any stim sample; treat that as unstimulated
                # instead of failing on `None.any()`.
                stimulated = self.last_seen_stim is not None and bool(self.last_seen_stim.any())

                # `reg` has not been updated since the previous sample, so this single
                # prediction serves both as the recorded forecast and as the baseline
                # for the residual that `stim_reg` learns below.
                auto_raw = self.reg.predict(self.last_seen)
                auto_pred = ArrayWithTime(auto_raw, data.t)

                stim_raw = np.zeros(shape=data.shape)
                if stimulated and self.spatial_stim_response:
                    stim_raw += self.stim_reg.predict(self.last_seen)
                stim_pred = ArrayWithTime(stim_raw, data.t)

                self.predictions.append(auto_pred + stim_pred)
                self.auto_pred.append(auto_pred)
                self.stim_pred.append(stim_pred)

                if stimulated:
                    self.stim_reg.observe(self.last_seen, data - auto_raw)
                else:
                    self.reg.observe(self.last_seen, data)

            self.last_seen = data

        elif role == 'stim':
            self.last_seen_stim = data

    def transform(self, data, stream=0, return_output_stream=False):
        """Pass-through: this transformer only fits, and returns ``(data, stream)`` unchanged.

        NOTE(review): ``return_output_stream`` is ignored and a tuple is always
        returned — confirm this matches what ``DecoupledTransformer`` expects.
        """
        return data, stream
        


In [None]:
def _fit_stim_regressor(spatial_stim_response):
    """Build a StimRegressor over the (X, stim) streams and run it offline on the whole recording."""
    regressor = StimRegressor(input_streams={0: 'X', 1: 'stim'}, spatial_stim_response=spatial_stim_response)
    regressor.offline_run_on([X, stim], show_tqdm=True)
    return regressor


# s1 adds the stimulus-response regressor to its predictions; s2 is the autonomous-only baseline.
s1 = _fit_stim_regressor(spatial_stim_response=True)
s2 = _fit_stim_regressor(spatial_stim_response=False)


In [None]:
%matplotlib inline
# Bind the figure to `fig` rather than `_`: assigning `_` shadows IPython's last-output variable.
fig, axs = plt.subplots(squeeze=False)


def operation(x):
    """Stack a list of per-step ArrayWithTime predictions and keep only the last column.

    ``drop_early_nans=True`` presumably discards leading NaN predictions made before
    the regressor had enough data — confirm against ``ArrayWithTime.from_list``.
    The column is kept as a 2-D ``(n, 1)`` slice rather than flattened.
    """
    stacked = al.ArrayWithTime.from_list(x, drop_early_nans=True, squeeze_type='to_2d')
    return stacked[:, -1:]

# Signed prediction error on the last state dimension, restricted to stimulated timepoints.
# Labels travel with each model so the legend can't drift out of sync with the loop order.
for s, label in [(s2, 'no stim regression'), (s1, 'with stim regression')]:
    pred = operation(s.predictions)
    # Predictions start late (the first X sample has no predecessor, and early entries
    # may be dropped by `operation`), so align X and stim by their trailing samples.
    n_pred = pred.shape[0]
    sl = stim[-n_pred:] == 1
    error = (pred - X[-n_pred:])
    axs[0, 0].plot(error.t[sl], error[sl, -1], '.', label=label)
axs[0, 0].legend()
axs[0, 0].set_title('Error comparison for timepoints with stimulation')
axs[0, 0].set_xlabel('time')
axs[0, 0].set_ylabel('prediction error (last dimension)')
plt.show()
