In [None]:
import adaptive_latents as al
from adaptive_latents import ArrayWithTime
from collections import deque
from adaptive_latents.input_sources import LDS, KalmanFilter
from adaptive_latents.transformer import StreamingTransformer
import matplotlib.pyplot as plt
from adaptive_latents import VJF, Bubblewrap, StreamingKalmanFilter
from adaptive_latents.predictor import Predictor
from adaptive_latents.regressions import BaseKNearestNeighborRegressor
import numpy as np
import pytest_timeout



In [None]:
# Simulate a nested linear dynamical system driven by a stimulation input.
# The first positional argument (100) is presumably the number of rotations — confirm against LDS.
# NOTE(review): the extra 1/pi in transitions_per_rotation presumably keeps samples from landing on the
# same phases every rotation — confirm.
rng = np.random.default_rng(0)

_, X, stim = LDS.nest_dynamical_system(100, transitions_per_rotation=60 + 1/np.pi, stims_per_rotation=.05, u_function='constant', rng=rng)

# Alternative: same system with zero input (u_function returns zeros).
# _, X, stim = LDS.nest_dynamical_system(100, transitions_per_rotation=60 + 1/np.pi, u_function=lambda **k: np.zeros(3))

In [None]:
%matplotlib inline
fig, axs = plt.subplots(ncols=2, subplot_kw={'projection': '3d'})
for ax in axs:
    ax.plot(X[:,0], X[:,1], X[:,2])
    ax.axis('equal')
axs[0].view_init(elev=90, azim=0, roll=0)
axs[1].view_init(elev=0, azim=0, roll=0)
plt.show()


In [None]:
class StimRegressor(StreamingTransformer):
    """Streaming predictor that wraps an autoregressive model and optionally adjusts its
    predictions while the most recently seen stimulation is nonzero.

    Input streams (default ``{1: 'X', 0: 'stim', 2: 'dt_X'}``):

    * ``'X'``: observations with one row per call. Finite rows go to ``autoreg.observe``,
      and the output is ``autoreg``'s current state reshaped to one row.
    * ``'stim'``: stimulation values. Only the most recent one is kept.
    * ``'dt_X'``: prediction queries. Each is converted to a step count with
      ``autoreg.data_to_n_steps`` and answered with ``autoreg.predict``.

    Parameters
    ----------
    autoreg : Predictor, optional
        Underlying autoregressive model. A new ``StreamingKalmanFilter`` is used if None.
    attempt_correction : bool
        If True, finite predictions made while the last seen stim is nonzero get
        ``stim_offset`` added at index ``stim_dim``. If False, predictions pass through unchanged.
    input_streams, output_streams, log_level
        Forwarded to ``StreamingTransformer``.
    stim_offset : float
        Hand-coded stim effect added to the prediction. This is presumably a stand-in for a
        learned correction; ``stim_reg`` is not used yet.
    stim_dim : int
        Index of the prediction entry that ``stim_offset`` is added to.
    """

    def __init__(self, autoreg=None, attempt_correction=True, input_streams=None, output_streams=None, log_level=None, stim_offset=80, stim_dim=2):
        input_streams = input_streams or {1: 'X', 0: 'stim', 2: 'dt_X'}
        super().__init__(input_streams=input_streams, output_streams=output_streams, log_level=log_level)
        if autoreg is None:
            autoreg = StreamingKalmanFilter()
        self.autoreg: Predictor = autoreg
        self.attempt_correction = attempt_correction
        self.stim_offset = stim_offset
        self.stim_dim = stim_dim
        # Not used by the active _partial_fit_transform; only the commented-out draft below uses it.
        self.stim_reg = BaseKNearestNeighborRegressor(k=2)
        # maxlen=1: only the most recent stim decides whether a prediction is corrected.
        self.last_seen_stims = deque(maxlen=1)

    def _partial_fit_transform(self, data, stream, return_output_stream):
        """Dispatch one chunk of data on its stream name; see the class docstring."""
        stream_name = self.input_streams[stream]
        if stream_name == 'X':
            data_depth = 1
            assert data.shape[0] == data_depth, data.shape

            # Non-finite rows are not observed, but the current state is still emitted.
            if np.isfinite(data).all():
                self.autoreg.observe(data, stream=stream_name)
            data = ArrayWithTime.from_transformed_data(self.autoreg.get_state().reshape(data_depth, -1), data)

        elif stream_name == 'dt_X':
            steps = self.autoreg.data_to_n_steps(data)
            pred = self.autoreg.predict(n_steps=steps)

            if self.attempt_correction and np.isfinite(pred).all() and np.any(self.last_seen_stims):
                # Copy so the offset cannot leak into an array the predictor may still hold.
                pred = pred.copy()
                # NOTE(review): this indexes axis 0, like the original `pred[2]`, which assumes
                # predict() returns a 1-D state vector. Confirm.
                pred[self.stim_dim] = pred[self.stim_dim] + self.stim_offset

            data = ArrayWithTime.from_transformed_data(pred, data)

        elif stream_name == 'stim':
            self.last_seen_stims.append(data)

        return (data, stream) if return_output_stream else data

    def get_params(self, deep=True):
        """Return the base-class params plus this class's constructor arguments."""
        return super().get_params(deep) | dict(
            autoreg=self.autoreg,
            attempt_correction=self.attempt_correction,
            stim_offset=self.stim_offset,
            stim_dim=self.stim_dim,
        )

    # this is mostly for testing
    def expected_data_streams(self, rng, DIM):
        """Return sample (data, stream_name) pairs: one observation, one query, one toggle.

        The 'stim' stream is not included, so the stim correction is never triggered by these samples.
        """
        # TODO: do this better
        return [
            (rng.normal(size=(1, DIM)), 'X'),
            (np.ones((1,1)), 'dt_X'),
            (np.zeros((1,1)) * (rng.random() > .9), 'toggle_parameter_fitting'),
        ]

    # Earlier, inactive draft of _partial_fit_transform that learns the stim correction with
    # stim_reg. It is kept for reference only; note that it reads the global `X` (`X.dt`).
    #
    # def _partial_fit_transform(self, data, stream, return_output_stream):
    #     if self.input_streams[stream] == 'X':
    #         if self.last_seen_stims and np.array(self.last_seen_stims).any():
    #             prediction = self.autoreg.partial_fit_transform(np.array([[X.dt]]), 'dt_X')
    #             self.stim_reg.observe(self.autoreg.state, prediction - data)
    #         self.autoreg.partial_fit_transform(data, 'X')
    #
    #     if self.input_streams[stream] == 'dt_X':
    #         data = self.autoreg.partial_fit_transform(data, 'dt_X')
    #         if self.last_seen_stims and np.array(self.last_seen_stims).any() and self.attempt_correction:
    #             # raise  Exception()
    #             pred = self.stim_reg.predict(self.autoreg.state)
    #             data = data * np.nan
    #
    #     if self.input_streams[stream] == 'stim':
    #         self.last_seen_stims.append(data)
    #
    #     return (data, stream) if return_output_stream else data


# Check that StimRegressor satisfies the StreamingTransformer interface. This presumably drives
# the sample streams from expected_data_streams; confirm.
StimRegressor().test_if_api_compatible();
        


In [None]:
# Prediction queries: one per index after the first. Each query is timed halfway between the
# stim time and the observation time at that index. Its value is `n_steps`, which the regressor
# converts with autoreg.data_to_n_steps.
n_steps = 1
qX = ArrayWithTime(
    np.full((stim.shape[0] - 1, 1), n_steps, dtype=float),
    (stim.t[1:] + X.t[1:]) / 2,
)

# Streams are listed in the order of StimRegressor's default input_streams:
# 0 -> stim, 1 -> X, 2 -> dt_X.
s1 = StimRegressor(autoreg=StreamingKalmanFilter(), attempt_correction=False, log_level=1)
o1 = s1.offline_run_on([stim, X, qX], show_tqdm=True, convinient_return=2)

# Comparison run with attempt_correction=True.
# NOTE(review): the stream order below does not match the default input_streams
# (0 -> stim, 1 -> X, 2 -> dt_X). Confirm before re-enabling.
# s2 = StimRegressor(autoreg=StreamingKalmanFilter(), attempt_correction=True)
# o2 = s2.offline_run_on([qX, stim, X], show_tqdm=True)



In [None]:
%matplotlib qt
_, axs = plt.subplots(squeeze=False)

for o in [o1]:
    axs[0,0].plot(X.t, X[:,2], '.')
    axs[0,0].plot(o.t, o[:,2], '.')

    # s = error.time_to_sample(stim.t[stim.flatten() == 1])+0
    # axs[0,0].plot(error.t[s], error[s,2], '.')
l,r = 33.46, 33.69

for t in stim.t[stim.flatten() == 1]:
    axs[0,0].axvline(t, color='k', linestyle='--')

# for t in X.t[(l <= X.t) & (X.t < r)]:
#     axs[0,0].axvline(t, alpha=.3, color='b')
# 
# for t in qX.t[(l <= qX.t) & (qX.t < r)]:
#     axs[0,0].axvline(t, alpha=.3, color='r')
    
axs[0,0].set_xlim(l,r)
axs[0,0].set_ylabel('prediction error')
axs[0,0].set_xlabel('time (a.u., technically radians)')

