In [None]:
import adaptive_latents as al
from adaptive_latents import ArrayWithTime
from collections import deque
from adaptive_latents.input_sources import LDS, KalmanFilter
from adaptive_latents.transformer import StreamingTransformer
import matplotlib.pyplot as plt
from adaptive_latents import VJF, Bubblewrap, StreamingKalmanFilter
from adaptive_latents.regressions import BaseKNearestNeighborRegressor
import numpy as np

# Fixed seed so that "Restart Kernel -> Run All" reproduces the same random draws.
SEED = 0
rng = np.random.default_rng(SEED)


In [None]:
# Simulate the nested dynamical system: X is the state trajectory, stim the input signal.
# NOTE(review): the meaning of the first positional argument (100) and of the
# non-integer `transitions_per_rotation` (presumably chosen so the rotation period
# is incommensurate with the sampling) is defined by LDS.nest_dynamical_system — confirm there.
# To run the same experiment with no input, pass u_function=lambda **k: np.zeros(3) instead.
_, X, stim = LDS.nest_dynamical_system(100, transitions_per_rotation=60 + 1/np.pi, u_function='constant')

In [None]:
%matplotlib inline
# Show the simulated trajectory from two camera angles so its 3D shape is readable.
fig, axs = plt.subplots(ncols=2, subplot_kw={'projection': '3d'})
views = [
    ('top view (elev=90)', dict(elev=90, azim=0, roll=0)),
    ('side view (elev=0, azim=0)', dict(elev=0, azim=0, roll=0)),
]
for ax, (title, view) in zip(axs, views):
    ax.plot(X[:, 0], X[:, 1], X[:, 2])
    ax.axis('equal')
    ax.view_init(**view)
    ax.set(title=title, xlabel='X[:, 0]', ylabel='X[:, 1]', zlabel='X[:, 2]')
fig.suptitle('Simulated trajectory X')
plt.show()


In [None]:
class StimRegressor(StreamingTransformer):
    """Streaming predictor that learns a stimulus-dependent correction to an autoregressive model.

    Three input streams are consumed (default stream ids in parentheses):

    * ``'X'`` (2): observed samples, fed to ``autoreg``. While the most recent stimulus
      has any nonzero entry, the residual between ``autoreg``'s one-step prediction and
      the observed sample is recorded in a k-nearest-neighbor regressor keyed on
      ``autoreg.state``.
    * ``'stim'`` (1): stimulus samples; only the latest one is kept.
    * ``'dt_X'`` (0): prediction queries, passed to ``autoreg`` on its ``'dt_X'`` stream.
      The prediction is emitted; if ``attempt_correction`` is true and the latest
      stimulus is nonzero, the learned residual is subtracted from it first.

    Parameters
    ----------
    autoreg : optional
        Model that accepts ``'X'`` and ``'dt_X'`` streams and exposes ``.state``;
        defaults to a fresh ``StreamingKalmanFilter``.
    attempt_correction : bool
        Whether to apply the learned stimulus residual to ``'dt_X'`` outputs.
    input_streams, output_streams, log_level
        Forwarded to ``StreamingTransformer``.
    step_dt : float, optional
        Horizon of the one-step prediction whose residual is learned (the spacing of
        ``'X'`` samples). If ``None`` it is inferred from the ``.t`` timestamps of
        consecutive ``'X'`` samples.
    """

    def __init__(self, autoreg=None, attempt_correction=True, input_streams=None, output_streams=None, log_level=None, step_dt=None):
        input_streams = input_streams or {2: 'X', 1: 'stim', 0: 'dt_X'}
        super().__init__(input_streams=input_streams, output_streams=output_streams, log_level=log_level)
        if autoreg is None:
            autoreg = StreamingKalmanFilter()
        self.autoreg = autoreg
        self.attempt_correction = attempt_correction
        self.step_dt = step_dt
        self.stim_reg = BaseKNearestNeighborRegressor(k=2)
        # Only the most recent stimulus value is used for gating.
        self.last_seen_stims = deque(maxlen=1)
        # Timestamp of the previous 'X' sample; used to infer the step when step_dt is None.
        self._last_X_t = None

    def _stim_active(self):
        """Return True if a stimulus has been seen and the latest one has any nonzero entry."""
        return bool(self.last_seen_stims) and bool(np.array(self.last_seen_stims).any())

    @staticmethod
    def _latest_time(data):
        """Return the last timestamp in ``data.t`` as a float, or None if ``data`` has no ``.t``."""
        t = getattr(data, 't', None)
        return None if t is None else float(np.asarray(t).ravel()[-1])

    def _step_for(self, t_now):
        """Return the one-step prediction horizon, or None if it cannot be known yet."""
        if self.step_dt is not None:
            return self.step_dt
        if t_now is None:
            raise ValueError("StimRegressor: 'X' samples carry no timestamps (.t); pass step_dt explicitly")
        if self._last_X_t is None:
            return None  # first 'X' sample: no previous timestamp to measure the step from
        return t_now - self._last_X_t

    def _partial_fit_transform(self, data, stream, return_output_stream):
        """Route one sample by stream name; see the class docstring for per-stream behavior."""
        stream_name = self.input_streams[stream]

        if stream_name == 'X':
            t_now = self._latest_time(data)
            if self._stim_active():
                dt = self._step_for(t_now)
                if dt is not None:
                    # Predict this sample from the pre-update state, then learn (prediction - actual).
                    prediction = self.autoreg.partial_fit_transform(np.array([[dt]]), 'dt_X')
                    self.stim_reg.observe(self.autoreg.state, prediction - data)
            self.autoreg.partial_fit_transform(data, 'X')
            self._last_X_t = t_now
        elif stream_name == 'dt_X':
            data = self.autoreg.partial_fit_transform(data, 'dt_X')
            if self.attempt_correction and self._stim_active():
                # The regressor learned (prediction - actual), so subtracting it corrects the prediction.
                residual = self.stim_reg.predict(self.autoreg.state)
                data = data - residual
        elif stream_name == 'stim':
            self.last_seen_stims.append(data)

        return (data, stream) if return_output_stream else data

    def get_params(self, deep=True):
        """Return constructor parameters, including every argument specific to this subclass."""
        return super().get_params(deep) | dict(
            autoreg=self.autoreg,
            attempt_correction=self.attempt_correction,
            step_dt=self.step_dt,
        )



# StimRegressor().test_if_api_compatible();
        


In [None]:
# Prediction horizon, in units of X's sample spacing; the query stream's values are this horizon.
n_steps = 1
# One query per sample, timestamped halfway between each X sample and the matching stim sample.
assert len(X.t) == len(stim.t), "X and stim must share a time base to build the query timestamps"
qX = ArrayWithTime(np.full((len(X), 1), X.dt * n_steps), (X.t + stim.t) / 2)


def run_stim_regressor(attempt_correction):
    """Run a fresh StimRegressor (with its own Kalman filter) over the query, stim and X streams.

    Returns the fitted regressor and its output stream.
    """
    regressor = StimRegressor(autoreg=StreamingKalmanFilter(), attempt_correction=attempt_correction)
    return regressor, regressor.offline_run_on([qX, stim, X], show_tqdm=True)


s1, o1 = run_stim_regressor(attempt_correction=False)
s2, o2 = run_stim_regressor(attempt_correction=True)



In [None]:
%matplotlib inline
# Compare prediction errors at stimulation times with and without the stimulus correction.
_, axs = plt.subplots(squeeze=False)
ax = axs[0, 0]

# Times at which the stimulus equals 1 (loop-invariant, so computed once).
stim_times = stim.t[stim.flatten() == 1]
for label, output in [('attempt_correction=False', o1), ('attempt_correction=True', o2)]:
    # Shift each output's timestamps forward by the prediction horizon so it lines up
    # with the X sample it predicts (assumes output.dt matches X's spacing — confirm).
    shifted = ArrayWithTime(output, output.t + output.dt * n_steps)
    error = ArrayWithTime.subtract_aligned_indices(shifted, X)
    idx = error.time_to_sample(stim_times)
    ax.plot(error.t[idx], error[idx, 2], '.', label=label)

ax.set(xlabel='time', ylabel='error[:, 2]', title='Prediction error at stimulation times')
ax.legend()
plt.show()
    