In [None]:
from adaptive_latents import ArrayWithTime
from collections import deque
from adaptive_latents.input_sources import LDS, KalmanFilter
from adaptive_latents.transformer import StreamingTransformer
import matplotlib.pyplot as plt
from adaptive_latents import VJF, Bubblewrap, StreamingKalmanFilter
from adaptive_latents.predictor import Predictor
from adaptive_latents.regressions import BaseKNearestNeighborRegressor
import numpy as np
import itertools



In [None]:
# Fixed seed so the simulated system is identical after every Restart & Run All.
rng = np.random.default_rng(1)

# Simulate the dynamical system. Of the three returned values only the last two
# are used in this notebook: `X` (plotted below as a 3-D trajectory) and `stim`
# (the stimulus stream fed to StimRegressor).
# NOTE(review): 60 + 1/np.pi is irrational, presumably so the trajectory never
# revisits exactly the same phase; `early_shift=1e-8` presumably moves the stim
# timestamps slightly ahead of X's so that stim is processed first -- confirm
# against LDS.nest_dynamical_system.
_, X, stim = LDS.nest_dynamical_system(500, transitions_per_rotation=60 + 1/np.pi, u_function='curvy', rng=rng, early_shift=1e-8)

In [None]:
%matplotlib inline
fig, axs = plt.subplots(ncols=2, subplot_kw={'projection': '3d'})
for ax in axs:
    ax.plot(X[:,0], X[:,1], X[:,2])
    ax.axis('equal')
axs[0].view_init(elev=90, azim=0, roll=0)
axs[1].view_init(elev=0, azim=0, roll=0)
plt.show()


In [None]:
class StimRegressor(StreamingTransformer):
    """Wrap an autoregressive predictor and learn a correction for stimulated samples.

    Stream roles (after mapping the incoming stream through ``input_streams``):

    * ``'stim'``: the sample is remembered (only the most recent one is kept)
      and passed through unchanged.
    * ``'X'``: a finite observation is fed to ``autoreg``. When the remembered
      stim is truthy, the one-step prediction and its residual are first handed
      to ``stim_reg``, and ``autoreg`` observes the sample with parameter
      fitting switched off. The output is ``autoreg``'s state as a single row.
    * ``'dt_X'``: the output is ``autoreg``'s prediction for the requested
      number of steps, plus ``stim_reg``'s predicted residual when the
      prediction is finite, the remembered stim is truthy and
      ``attempt_correction`` is set.

    Samples arriving on any other stream are returned untouched.
    """

    def __init__(self, autoreg=None, stim_reg=None, attempt_correction=True, input_streams=None, output_streams=None, log_level=None,):
        """
        Parameters
        ----------
        autoreg : Predictor, optional
            Autoregressive model; a fresh ``StreamingKalmanFilter()`` when omitted.
        stim_reg : optional
            Regressor from predicted state to residual; a fresh
            ``BaseKNearestNeighborRegressor(k=2)`` when omitted.
        attempt_correction : bool
            Whether ``'dt_X'`` predictions get the learned residual added.
        input_streams, output_streams, log_level
            Forwarded to ``StreamingTransformer``; ``input_streams`` falls back to
            ``{1:'X', 0:'stim', 2:'dt_X'}`` when falsy.
        """
        input_streams = input_streams or {1:'X', 0:'stim', 2:'dt_X'}
        super().__init__(input_streams=input_streams, output_streams=output_streams, log_level=log_level)
        self.autoreg: Predictor = StreamingKalmanFilter() if autoreg is None else autoreg
        self.attempt_correction = attempt_correction
        self.stim_reg = BaseKNearestNeighborRegressor(k=2) if stim_reg is None else stim_reg
        # Length-1 buffer: appending a new stim sample silently drops the previous one.
        self.last_seen_stims = deque(maxlen=1)

    def _partial_fit_transform(self, data, stream, return_output_stream):
        """Process one sample according to the role of the stream it arrived on.

        Returns ``(data, stream)`` when ``return_output_stream`` is truthy,
        otherwise just ``data``.
        """
        role = self.input_streams[stream]

        if role == 'stim':
            self.last_seen_stims.append(data)

        elif role == 'X':
            n_rows = 1
            assert data.shape[0] == n_rows, data.shape

            # Non-finite observations are not shown to either model.
            if np.isfinite(data).all():
                stim_on = bool(self.last_seen_stims and self.last_seen_stims[-1])
                if stim_on:
                    # Predict before observing, so the residual measures what the
                    # autoregressor could not explain about this sample.
                    expected = self.autoreg.predict(n_steps=1)
                    self.stim_reg.observe(expected, data - expected)
                    # Freeze parameter fitting for the stimulated sample only.
                    self.autoreg.toggle_parameter_fitting(False)
                self.autoreg.observe(data, stream=role)
                if stim_on:
                    self.autoreg.toggle_parameter_fitting(True)

            state_row = self.autoreg.get_state().reshape(n_rows, -1)
            data = ArrayWithTime.from_transformed_data(state_row, data)

        elif role == 'dt_X':
            n_ahead = self.autoreg.data_to_n_steps(data)
            forecast = self.autoreg.predict(n_steps=n_ahead)

            # The condition order matters: the stim buffer is only inspected
            # once the forecast is known to be finite.
            if np.isfinite(forecast).all() and self.last_seen_stims and self.last_seen_stims[-1] and self.attempt_correction:
                forecast = forecast + self.stim_reg.predict(forecast)

            data = ArrayWithTime.from_transformed_data(forecast, data)

        if return_output_stream:
            return data, stream
        return data

    def get_params(self, deep=True):
        """Return the parent's parameters extended with this class's constructor arguments."""
        inherited = super().get_params(deep)
        return inherited | {
            'autoreg': self.autoreg,
            'stim_reg': self.stim_reg,
            'attempt_correction': self.attempt_correction,
        }

    # this is mostly for testing
    def expected_data_streams(self, rng, DIM):
        """Return sample ``(data, stream)`` pairs for the API-compatibility test."""
        # TODO: do this better
        # NOTE(review): no 'stim' sample is produced, and 'toggle_parameter_fitting'
        # is not among this class's default input streams, so the stimulated code
        # path is not exercised by these samples.
        observation = rng.normal(size=(1, DIM))
        one_step_query = np.ones((1,1))
        toggle = np.zeros((1,1)) * (rng.random() > .9)
        return [
            (observation, 'X'),
            (one_step_query, 'dt_X'),
            (toggle, 'toggle_parameter_fitting'),
        ]

# Smoke test: run the library's API-compatibility check on a default instance.
StimRegressor().test_if_api_compatible();
        


In [None]:
n_steps = 1
# One prediction query per sample from index 1 onward, timestamped at the mean
# of the stim and X timestamps at that index; every query carries `n_steps`.
qX = ArrayWithTime(n_steps * np.ones((stim.shape[0] - 1, 1)), (stim.t[1:] + X.t[1:])/2)


def _run_stim_regressor(autoreg, attempt_correction):
    """Build a StimRegressor around `autoreg` and run it offline on [stim, X, qX].

    Every run uses its own BaseKNearestNeighborRegressor(k=10, maxlen=200).
    Returns the regressor together with whatever `offline_run_on` returned.
    """
    regressor = StimRegressor(
        autoreg=autoreg,
        stim_reg=BaseKNearestNeighborRegressor(k=10, maxlen=200),
        attempt_correction=attempt_correction,
    )
    # `convinient_return` is spelled exactly as in the original call; check the
    # library before "correcting" it.
    outputs = regressor.offline_run_on([stim, X, qX], show_tqdm=True, convinient_return=2)
    return regressor, outputs


# Three autoregressors with the stim correction enabled, then a Kalman-filter
# baseline with the correction disabled.
s1, o1 = _run_stim_regressor(StreamingKalmanFilter(), attempt_correction=True)
s2, o2 = _run_stim_regressor(Bubblewrap(), attempt_correction=True)
s3, o3 = _run_stim_regressor(VJF(), attempt_correction=True)
s4, o4 = _run_stim_regressor(StreamingKalmanFilter(), attempt_correction=False)



In [None]:
%matplotlib inline
_, axs = plt.subplots(squeeze=False)

order = list(reversed(['kalman filter', 'bubblewrap', 'vjf', 'kalman filter (no correction)']))

test_dim = 0
on_stim = 0
s = (stim.t[stim.flatten() == on_stim], test_dim)
for o, name in zip([o4, o3, o2, o1], order):
    error = ArrayWithTime.subtract_aligned_indices(o, X)
    to_plot = error.slice_by_time(*s, all_axes=True)
    true = X.slice_by_time(*s, all_axes=True)
    
    print(f"{name: <30} {np.nanstd(to_plot)/ np.nanstd(true):.2f}")
    axs[0,0].plot(to_plot.t, to_plot, '.')



axs[0,0].legend(order)
axs[0,0].set_title(f'comparison of d{test_dim} on {"stim" if on_stim else "non-stim"} samples')
axs[0,0].set_ylabel('prediction error')
axs[0,0].set_xlabel('time (a.u., technically rotations)');



In [None]:
%matplotlib qt
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})

XX, YY = np.meshgrid(np.linspace(-10,10,14), np.linspace(-10,10,12))
Z = 0 * XX

def S(state, s):
    return s * state[0] / np.linalg.norm(state[:2])

for i_x, i_y in itertools.product(range(XX.shape[0]), range(XX.shape[1])):
    Z[i_x, i_y] = S([XX[i_x,i_y], YY[i_x,i_y], None], 1)

ax.plot_surface(XX, YY, Z, zorder=10)
ax.plot(s1.stim_reg.history[:,0], s1.stim_reg.history[:,1], s1.stim_reg.history[:,5], '.', zorder=10)


In [None]:
%matplotlib qt
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})

n_points = 100

XX, YY = np.meshgrid(np.linspace(-10,10,n_points), np.linspace(-10,10,n_points))
Z = 0 * XX

def S(state, s):
    return s * state[0] / np.linalg.norm(state[:2])

for i_x, i_y in itertools.product(range(XX.shape[0]), range(XX.shape[1])):
    Z[i_x, i_y] = s1.stim_reg.predict(np.array([XX[i_x, i_y], YY[i_x, i_y], 0]))[-1] #  - S([XX[i_x,i_y], YY[i_x,i_y], None], 1)

ax.plot_surface(XX, YY, Z, zorder=10)
print(Z.mean())
