In [None]:
import numpy as np
import matplotlib.pyplot as plt
from adaptive_latents import ArrayWithTime
import adaptive_latents as al
import pandas as pd
import itertools

from adaptive_latents.transformer import DecoupledTransformer, Concatenator
from adaptive_latents.regressions import BaseVanillaOnlineRegressor, BaseKNearestNeighborRegressor, OnlineRegressor


# Seed the shared generator so the simulated noise (and therefore every figure below) is
# reproducible under Restart & Run All. Change SEED to draw a different realization.
# (1 rather than 0 so it does not replay the seed-0 stream used for the stim schedule.)
SEED = 1
rng = np.random.default_rng(SEED)

## StimRegressor

In [None]:
class StimRegressor(DecoupledTransformer):
    """Online one-step-ahead predictor for an 'X' stream that separates autonomous dynamics from stim effects.

    Two regressors are maintained:
      * ``reg`` (``BaseVanillaOnlineRegressor``) maps the previous X sample to the next one. It is
        updated only on steps where the most recent stim sample is all zeros (or none has arrived yet).
      * ``stim_reg`` (``BaseKNearestNeighborRegressor`` with k=2) maps the previous X sample to the
        residual ``X - reg.predict(previous X)``. It is updated only on stimulated steps.

    Every X sample after the first is predicted as ``auto_pred + stim_pred`` *before* either regressor
    sees it. The pieces are appended to ``self.auto_pred``, ``self.stim_pred`` and ``self.predictions``,
    each stamped with the incoming sample's time ``data.t``.

    Parameters
    ----------
    input_streams : dict, optional
        Maps stream ids to roles; the roles handled here are ``'X'`` and ``'stim'``.
    spatial_stim_response : bool
        If False, ``stim_reg`` is still trained on stimulated steps but its prediction is never added,
        so every ``stim_pred`` entry is zero.
    """

    def __init__(self, input_streams=None, spatial_stim_response=True, *args, **kwargs):
        input_streams = input_streams or {}  # avoid sharing a mutable default between instances
        super().__init__(input_streams=input_streams, *args, **kwargs)
        self.reg = BaseVanillaOnlineRegressor()
        self.stim_reg = BaseKNearestNeighborRegressor(k=2)
        self.spatial_stim_response = spatial_stim_response
        self.last_seen = None  # most recent X sample (input for the next prediction)
        self.last_seen_stim = None  # most recent stim sample; None until one arrives
        self.predictions = []  # auto_pred + stim_pred, one entry per predicted X sample
        self.auto_pred = []
        self.stim_pred = []

    def _partial_fit(self, data, stream):
        """Predict-then-update on an X sample; just remember the latest sample on a stim stream."""
        role = self.input_streams[stream]
        if role == 'X':
            if self.last_seen is not None:
                # An X sample may arrive before any stim sample: treat that as unstimulated
                # instead of failing on ``None.any()``.
                stimulated = self.last_seen_stim is not None and bool(np.any(self.last_seen_stim))

                # reg is not updated until the end of this branch, so one prediction serves both as
                # the reported autonomous part and as the baseline for stim_reg's residual target.
                raw_auto_pred = self.reg.predict(self.last_seen)
                auto_pred = ArrayWithTime(raw_auto_pred, data.t)

                stim_pred = np.zeros(shape=data.shape)
                if stimulated and self.spatial_stim_response:
                    stim_pred += self.stim_reg.predict(self.last_seen)
                stim_pred = ArrayWithTime(stim_pred, data.t)

                self.predictions.append(auto_pred + stim_pred)
                self.auto_pred.append(auto_pred)
                self.stim_pred.append(stim_pred)

                if stimulated:
                    # stim_reg learns only what the autonomous model fails to explain.
                    self.stim_reg.observe(self.last_seen, data - raw_auto_pred)
                else:
                    self.reg.observe(self.last_seen, data)

            self.last_seen = data

        elif role == 'stim':
            self.last_seen_stim = data

    def transform(self, data, stream=0, return_output_stream=False):
        # NOTE(review): return_output_stream is ignored and a (data, stream) tuple is always
        # returned — confirm this is what DecoupledTransformer expects from subclasses.
        return data, stream
        

## Trivial manual example

In [None]:

# Toy 1-D series: only column 1 of the table is kept (column 0 is just the step index).
# Ordinary steps raise the value by 1; the stimulated rows (5, 7, 9) raise it by 3.
X = np.array([
    [0, 0],
    [1, 1],
    [2, 2],
    [3, 3],
    [4, 4],
    [5, 7],    # stim
    [6, 8],
    [7, 11],   # stim
    [8, 12],
    [9, 15],   # stim
    [10, 16],
])[:, 1:2]

n_steps = X.shape[0]
stim = np.zeros(shape=(n_steps, 1))
stim[[5, 7, 9]] = 1

# Stim timestamps sit just before the matching X timestamps (presumably so each stim sample
# is delivered to the regressor before the X sample it affects).
X = al.ArrayWithTime(X, np.arange(n_steps))
stim = al.ArrayWithTime(stim, np.arange(n_steps) - 0.001)

s = StimRegressor(input_streams={1: 'X', 0: 'stim'}, spatial_stim_response=True)
s.offline_run_on([stim, X]);

In [None]:
def operation(x):
    """Stack a list of timestamped arrays into one 2-D ``ArrayWithTime`` and keep only its last column.

    ``drop_early_nans=True`` is passed to ``from_list`` (presumably removing leading all-NaN rows,
    e.g. predictions made before a regressor had any data — confirm in adaptive_latents).
    """
    stacked = al.ArrayWithTime.from_list(x, drop_early_nans=True, squeeze_type='to_2d')
    return stacked[:, -1:]

# Collapse the per-step lists recorded by the regressor into single arrays (last column only)
# so they can be laid side by side in the table below.
pred = operation(s.predictions)  # combined prediction: auto_pred + stim_pred
p1 = operation(s.auto_pred)  # autonomous-dynamics component
# NOTE(review): the first stim_pred entry is skipped, presumably so p2 has as many rows as
# p1/pred after `operation` drops their leading NaN prediction(s) — confirm the lengths match.
p2 = operation(s.stim_pred[1:])  # stimulation component
# Observed X over (presumably) the same trailing window as pred.
true = operation(X.slice(-pred.shape[0],None))
# NOTE(review): hard-coded offset of 2 — assumes the first X sample has no prediction and one more
# leading prediction is dropped as NaN, so stim rows 2.. line up with pred; confirm if data changes.
st = stim.slice(2,None)

In [None]:

# Side-by-side table: stim flag, autonomous part (p1), stim part (p2), combined prediction, truth.
stacked = np.hstack([st, p1, p2, pred, true])
d = true.shape[-1]
column_labels = ['stim']
for label in ('p1', 'p2', 'pred', 'true'):
    column_labels += [label] * d
df = pd.DataFrame(stacked, columns=column_labels)
df

In [None]:
# Observed series versus the regressor's one-step-ahead predictions for the toy example.
fig, axs = plt.subplots(squeeze=False)
ax = axs[0, 0]
ax.plot(X.t, X[:, -1], '.-', label='observed X')
ax.plot(pred.t, pred, '.-', label='prediction')
ax.set(xlabel='time', ylabel='X[:, -1]', title='Trivial example: one-step-ahead predictions')
ax.legend();  # trailing ';' keeps the Legend repr out of the cell output



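## Nest example

A simulated 3-D system: the first two dimensions rotate at a constant rate, the third decays by a factor of 0.99 per step, and each stimulation adds a kick of size $\cos\theta$ to the third dimension, where $\theta$ is the current phase of the first two dimensions.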

In [None]:
%matplotlib qt
dt = 0.05  # simulation time step; also the rotation angle (radians) applied per step

# Linear part of the dynamics: rotate dims 0-1 by dt per step and shrink dim 2 by 0.99 per step.
A = np.array([
    [np.cos(dt),  -np.sin(dt), 0],
    [np.sin(dt),   np.cos(dt), 0],
    [         0,            0, .99]
])


def C(x, y):
    """Return the stimulation kick added to dim 2.

    ``y`` is the stim value and ``x`` a state whose first two entries define a phase theta;
    the result is ``y * cos(theta)``. Divides by zero when ``x[0] == x[1] == 0``.
    """
    return y * x[0] / np.linalg.norm(x[:2])


ts = np.arange(0, 500*2.1*np.pi, dt)

# Binary stim channel with int(0.1 * ts.max()) stimulated steps at seeded random indices.
n_stim = int(ts.max() * .1)
stim = np.zeros((ts.size, 1))
stim[np.random.default_rng(0).choice(ts.size, size=n_stim, replace=False)] = 1

X_true = np.zeros((ts.size, 3))
X_true[0, 1] = 3

for i in range(1, ts.size):
    X_true[i] = A @ X_true[i-1]
    # Index stim[i, 0] (a scalar): adding the length-1 row stim[i] would assign a size-1 array
    # into a scalar slot, which NumPy >= 1.25 deprecates.
    X_true[i, 2] += C(X_true[i-1], stim[i, 0])
    X_true[i] += rng.normal(0, 0.01, X_true[i].shape)  # process noise

X = X_true + rng.normal(0, 0.01, X_true.shape)  # observation noise

X = ArrayWithTime(X, ts)
# Stim timestamps sit just before X's, presumably so stim[i] is delivered before X[i], which it affects.
stim = ArrayWithTime(stim, ts-1e-8)

fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.plot(X[:, 0], X[:, 1], X[:, 2])
ax.set(xlabel='X[:, 0]', ylabel='X[:, 1]', zlabel='X[:, 2]', title='Simulated nest trajectory')
ax.axis('equal')
plt.show()


In [None]:
def fit_stim_regressor(spatial_stim_response):
    """Build a StimRegressor on streams (0: X, 1: stim) and run it over the simulated data."""
    model = StimRegressor(input_streams={0: 'X', 1: 'stim'}, spatial_stim_response=spatial_stim_response)
    model.offline_run_on([X, stim], show_tqdm=True)
    return model


s1 = fit_stim_regressor(True)   # with the stim-response regressor
s2 = fit_stim_regressor(False)  # autonomous model only


In [None]:
%matplotlib inline
_, axs = plt.subplots(nrows=2, squeeze=False, sharex=True)

# Loop variables get fresh names so the trivial example's `s` / `pred` globals are not overwritten.
panels = [(s2, 'no stim regression'), (s1, 'with stim regression')]
for row, (model, title) in enumerate(panels):
    model_pred = operation(model.predictions)
    ax = axs[row, 0]
    ax.plot(X.t, X[:, -1])
    ax.plot(model_pred.t, model_pred[:, -1])
    ax.legend(['true system', 'predicted'])
    ax.set_xlim([2511.5, 2513])  # fixed zoom window in time units
    ax.set_ylim([-1, 2])
    ax.set_ylabel('X[:, -1]')
    ax.set_title(title)

axs[-1, 0].set_xlabel('time');



In [None]:
%matplotlib inline
_, axs = plt.subplots(squeeze=False)
ax = axs[0, 0]

for model in [s2, s1]:
    model_pred = operation(model.predictions)
    n = model_pred.shape[0]
    # Predictions cover only the trailing n samples, so align stim and X with their last n rows.
    stimulated = np.asarray(stim)[-n:, 0] == 1
    # `operation` keeps only the last predicted column; compare it with X's last column directly
    # rather than broadcasting the (n, 1) prediction against all of X's columns.
    last_dim_error = np.asarray(model_pred)[:, -1] - np.asarray(X)[-n:, -1]
    ax.plot(np.asarray(model_pred.t)[stimulated], last_dim_error[stimulated], '.')

ax.legend(['no stim regression', 'with stim regression'])
ax.set_title('Error comparison for timepoints with stimulation')
ax.set(xlabel='time', ylabel='prediction error (last dim)');



In [None]:
%matplotlib qt
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})

# Evaluate the true stim response C (with unit stim) on a 10x10 grid over dims 0-1. Grid names
# are distinct so the simulated data `X` used by the earlier cells is not overwritten.
grid_x, grid_y = np.meshgrid(np.linspace(-6, 6, 10), np.linspace(-6, 6, 10))
grid_z = np.array([
    [C([x_val, y_val, None], 1) for x_val, y_val in zip(row_x, row_y)]
    for row_x, row_y in zip(grid_x, grid_y)
])

ax.plot_surface(grid_x, grid_y, grid_z, zorder=10)
# NOTE(review): assumes stim_reg.history rows are [input (3 dims), target (3 dims)], so column 5 is
# the learned residual in the last dimension — confirm against BaseKNearestNeighborRegressor.
history = s1.stim_reg.history
ax.plot(history[:, 0], history[:, 1], history[:, 5], '.', zorder=10)
ax.set(xlabel='X[:, 0]', ylabel='X[:, 1]', zlabel='stim response');

In [None]:
# Sanity check: the stimulation channel must be strictly binary (0 = off, 1 = on).
stim_levels = np.unique(stim)
assert np.isin(stim_levels, [0, 1]).all(), f"unexpected stim levels: {stim_levels}"
stim_levels