In [None]:
import numpy as np
import matplotlib.pyplot as plt
from adaptive_latents import ArrayWithTime
from adaptive_latents.input_sources.autoregressor import AdamOptimizer
import adaptive_latents as al
import pandas as pd

from adaptive_latents.transformer import DecoupledTransformer, Concatenator
from adaptive_latents.regressions import BaseVanillaOnlineRegressor, BaseNearestNeighborRegressor


# Fixed seed so the simulated stimulus times and noise are identical on every Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

In [None]:
class AR_K:
    """Order-``k`` linear autoregressive model with an exogenous stimulus term, fit by least squares.

    After ``fit``, each activity row is modelled as::

        activity[i] ~= v + sum_j activity[i-k+j] @ As[j] + sum_j stim[i-k+j] @ Bs[j],   j = 0..k-1

    with ``As`` of shape ``(k, neuron_d, neuron_d)``, ``Bs`` of shape ``(k, stim_d, neuron_d)`` and
    the intercept ``v`` of shape ``(1, neuron_d)``. ``predict`` rolls exactly this model forward.
    """

    def __init__(self, *, k=1, rank_limit=None, rng=None, init_method='full_rank', iter_limit=200):
        # type: (AR_K, int, int | None, Any, Literal['full_rank', 'random'], int) -> None
        """Store hyperparameters; no fitting happens here.

        Parameters
        ----------
        k : int
            Number of lagged samples (of both activity and stim) used to predict the next activity row.
        rank_limit : int or None
            Only validated by ``fit`` (must not exceed ``neuron_d``); the full-rank fit does not use it.
        rng : numpy Generator or None
            Stored as ``self.rng`` (``np.random.default_rng(0)`` when falsy); unused by the full-rank fit.
        init_method, iter_limit
            Stored but unused by the full-rank fit.
        """
        self.k = k
        self.rank_limit = rank_limit
        self.init_method = init_method
        self.iter_limit = iter_limit

        # Dimensions, set by fit().
        self.neuron_d = None
        self.stim_d = None

        # Fitted parameters, set by fit(). Cs is never assigned by the methods of this class.
        self.As = None
        self.Bs = None
        self.Cs = None
        self.v = None

        # Design matrix and targets from the most recent fit(), kept for inspection.
        self._X = None
        self._Y = None

        self.rng = rng or np.random.default_rng(0)

    def fit(self, activity, stim):
        """Fit ``As``, ``Bs`` and ``v`` by ordinary least squares.

        Parameters
        ----------
        activity : array of shape (n_samples, neuron_d)
        stim : array of shape (n_samples, stim_d)
            Row-aligned with ``activity``: ``stim[i]`` is a regressor alongside ``activity[i]``
            when predicting later rows.

        Raises
        ------
        ValueError
            If ``activity`` has ``k`` or fewer rows, so no (history, target) pair exists.
        """
        self.neuron_d = activity.shape[1]
        self.stim_d = stim.shape[1]

        assert self.rank_limit is None or not self.rank_limit > self.neuron_d

        k = self.k
        n_samples = activity.shape[0]
        if n_samples <= k:
            raise ValueError(f"AR_K.fit needs more than k={k} samples, got {n_samples}")

        # Design row for target i: [activity[i-k:i] flattened, stim[i-k:i] flattened, 1 (intercept)].
        X = np.array([
            np.concatenate([np.ravel(activity[i - k:i]), np.ravel(stim[i - k:i]), [1.0]])
            for i in range(k, n_samples)
        ])
        Y = np.array(activity[k:])

        As, Bs, v = self._fit_full_rank(X, Y)

        self.As = As
        self.Bs = Bs
        self.v = v
        self._X = X
        self._Y = Y

    def _fit_full_rank(self, X, Y):
        """Solve ``min ||X @ beta - Y||`` and split ``beta`` into ``(As, Bs, v)``.

        The column order of ``X`` must match ``fit``: k*neuron_d activity lags, then k*stim_d
        stim lags, then one intercept column.
        """
        # rcond=None selects the machine-precision cutoff explicitly and silences the
        # FutureWarning NumPy 1.x emits when rcond is omitted.
        beta, _residuals, _rank, _singular_values = np.linalg.lstsq(X, Y, rcond=None)

        n_act = self.neuron_d * self.k
        n_stim = self.stim_d * self.k
        As = beta[:n_act].reshape(self.k, self.neuron_d, self.neuron_d)
        Bs = beta[n_act:n_act + n_stim].reshape(self.k, self.stim_d, self.neuron_d)
        v = beta[n_act + n_stim:]

        return As, Bs, v

    def predict(self, initial_observations, stim, n_steps=100):
        """Simulate the fitted model forward for ``n_steps`` steps, feeding back its own predictions.

        Parameters
        ----------
        initial_observations : array of shape (>= k, neuron_d)
            Its last ``k`` rows seed the history.
        stim : array of shape (>= n_steps + k - 1, stim_d)
            Row ``j`` pairs with history row ``j``; rows ``0..k-1`` pair with the seed rows.
        n_steps : int
            Number of future rows to generate.

        Returns
        -------
        ndarray of shape (n_steps, neuron_d)

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        ValueError
            If ``stim`` has fewer than ``n_steps + k - 1`` rows (a short slice would otherwise
            silently drop or broadcast the stimulus term).
        """
        if self.As is None:
            raise RuntimeError("AR_K.predict called before fit")

        k = self.k
        needed = n_steps + k - 1
        if stim.shape[0] < needed:
            raise ValueError(f"stim needs at least {needed} rows for n_steps={n_steps}, k={k}; got {stim.shape[0]}")

        history = np.full((n_steps + k, self.neuron_d), np.nan)
        history[:k] = initial_observations[-k:]

        for i in range(k, n_steps + k):
            # Same affine form that fit() estimated: intercept + lagged activity + lagged stim.
            autonomous = (history[i - k:i, None] @ self.As).sum(axis=0)
            driven = (stim[i - k:i, None] @ self.Bs).sum(axis=0)
            history[i] = (self.v + autonomous + driven).ravel()

        # history[k:] (not history[-n_steps:]) so n_steps == 0 yields an empty result.
        return history[k:]


In [None]:
class StimRegressor(DecoupledTransformer):
    """Online one-step-ahead predictor that splits each prediction into an autonomous and a stimulus part.

    ``reg`` (a BaseVanillaOnlineRegressor) maps the previous 'X' sample to the next one and is
    trained only on transitions where the most recent 'stim' sample is all zeros. ``stim_reg``
    (a BaseNearestNeighborRegressor) is trained on the residual ``data - reg.predict(last_seen)``
    for transitions where the most recent stim sample is non-zero. Its prediction is added to the
    autonomous one only when ``spatial_stim_response`` is True.

    Every 'X' sample after the first appends one entry to ``predictions``, ``auto_pred`` and
    ``stim_pred`` (ArrayWithTime stamped with that sample's time).
    """

    def __init__(self, input_streams=None, spatial_stim_response=True, *args, **kwargs):
        input_streams = input_streams or {0:'X', 1:'stim'}
        super().__init__(input_streams=input_streams, *args, **kwargs)
        self.reg = BaseVanillaOnlineRegressor()
        self.stim_reg = BaseNearestNeighborRegressor()
        self.spatial_stim_response = spatial_stim_response
        self.last_seen = None  # most recent 'X' sample
        self.last_seen_stim = None  # most recent 'stim' sample (None until one arrives)
        self.predictions = []  # auto_pred + stim_pred, one per predicted 'X' sample
        self.auto_pred = []
        self.stim_pred = []

    def _partial_fit(self, data, stream):
        """Predict ``data`` from the previous 'X' sample, record the prediction, then update a regressor."""
        if self.input_streams[stream] == 'X':
            if self.last_seen is not None:
                # No stim sample seen yet is treated the same as an all-zero stim sample.
                stim_active = self.last_seen_stim is not None and bool(np.any(self.last_seen_stim))

                auto_raw = self.reg.predict(self.last_seen)
                auto_pred = ArrayWithTime(auto_raw, data.t)

                stim_pred = np.zeros(shape=data.shape)
                if stim_active and self.spatial_stim_response:
                    stim_pred += self.stim_reg.predict(self.last_seen)
                stim_pred = ArrayWithTime(stim_pred, data.t)

                self.predictions.append(auto_pred + stim_pred)
                self.auto_pred.append(auto_pred)
                self.stim_pred.append(stim_pred)

                if stim_active:
                    # The stimulus model learns what the autonomous model could not explain.
                    # reg has not been updated since auto_raw was computed, so it is reused here.
                    self.stim_reg.observe(self.last_seen, data - auto_raw)
                else:
                    self.reg.observe(self.last_seen, data)

            self.last_seen = data

        if self.input_streams[stream] == 'stim':
            self.last_seen_stim = data

    def transform(self, data, stream=0, return_output_stream=False):
        """Pass-through: always returns ``(data, stream)``; ``return_output_stream`` is not consulted."""
        return data, stream
        

In [None]:
%matplotlib qt
dt = 0.05

# True dynamics: rotation by dt radians in the (x0, x1) plane; x2 is multiplied by 0.99 each step.
A = np.array([
    [np.cos(dt),  -np.sin(dt), 0],
    [np.sin(dt),   np.cos(dt), 0],
    [         0,            0, .99]
])


def C(x, y):
    """Stimulus kick added to x[2]: the stim value ``y`` times x[0] / ||x[:2]|| (cos of the phase in the x0/x1 plane)."""
    return y * x[0] / np.linalg.norm(x[:2])


ts = np.arange(0, 500 * 2.1 * np.pi, dt)

# Binary stimulus: int(0.1 * ts.max()) distinct time-steps, chosen without replacement, are set to 1.
stim = np.zeros((ts.size, 1))
n_stim_events = int(ts.max() * .1)
stim[rng.choice(stim.shape[0], size=n_stim_events, replace=False)] = 1

X_true = np.zeros((ts.size, 3))
X_true[0, 1] = 3

for i in range(1, ts.size):
    X_true[i] = A @ X_true[i - 1]
    # stim[i, 0] is a scalar, so a scalar (not a length-1 array) is added into the scalar slot.
    X_true[i, 2] += C(X_true[i], stim[i, 0])
    X_true[i] += rng.normal(0, 0.01, X_true[i].shape)

# Observation noise on top of the process noise.
X = X_true + rng.normal(0, 0.01, X_true.shape)

X = ArrayWithTime(X, ts)
# Stim timestamps are shifted by -1e-8, presumably so each stim sample is ordered just before
# the X sample with the same nominal time.
stim = ArrayWithTime(stim, ts - 1e-8)

fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.plot(X[:, 0], X[:, 1], X[:, 2])
ax.set(xlabel='x0', ylabel='x1', zlabel='x2', title='Simulated observations')
ax.axis('equal')
plt.show()


In [None]:

# Stream 0 carries the simulated observations, stream 1 the stimulus.
s = StimRegressor(
    input_streams={0: 'X', 1: 'stim'},
    spatial_stim_response=True,
)
s.offline_run_on([X, stim], show_tqdm=False);

In [None]:

# Toy 1-D check: the kept column rises by 1 per step, except on the stim rows (5, 7 and 9),
# where it rises by 3. Only column 1 of the table below is kept as the signal.
X = np.array([
    [0, 0],
    [1, 1],
    [2, 2],
    [3, 3],
    [4, 4],
    [5, 7],   # stim
    [6, 8],
    [7, 11],  # stim
    [8, 12],
    [9, 15],  # stim
    [10, 16],
])[:, 1:2]

stim = np.zeros(shape=(X.shape[0], 1))
stim[[5, 7, 9]] = 1

# Each stim sample is stamped 0.001 before the X sample with the same index.
X = al.ArrayWithTime(X, np.arange(X.shape[0]))
stim = al.ArrayWithTime(stim, np.arange(stim.shape[0]) - 0.001)

# Here stream 0 is the stimulus and stream 1 the signal.
s = StimRegressor(input_streams={1: 'X', 0: 'stim'}, spatial_stim_response=True)
s.offline_run_on([stim, X]);

In [None]:
def operation(x):
    """Stack ``x`` with ``al.ArrayWithTime.from_list`` into a 2-D array and keep only its last column.

    ``drop_early_nans=True`` presumably removes leading NaN rows (e.g. predictions made before a
    regressor had enough data) — confirm against adaptive_latents.
    """
    stacked = al.ArrayWithTime.from_list(x, drop_early_nans=True, squeeze_type='to_2d')
    return stacked[:, -1:]


# The offsets stim_pred[1:] and stim.slice(2, None) appear chosen so that these arrays end up
# with matching row counts — TODO confirm they stay aligned if the data changes.
pred = operation(s.predictions)
p1 = operation(s.auto_pred)
p2 = operation(s.stim_pred[1:])
true = operation(X.slice(-pred.shape[0], None))
st = stim.slice(2, None)

In [None]:

# Side-by-side table: stim flag, autonomous part, stimulus part, total prediction, observed value.
d = true.shape[-1]
df = pd.DataFrame(
    np.hstack([st, p1, p2, pred, true]),
    columns=['stim'] + [label for label in ('p1', 'p2', 'pred', 'true') for _ in range(d)],
)
df

In [None]:
# Observed signal (last dimension) against the one-step-ahead predictions.
fig, axs = plt.subplots(squeeze=False)
axs[0, 0].plot(X.t, X[:, -1], '.-', label='observed (last dim)')
axs[0, 0].plot(pred.t, pred, '.-', label='one-step prediction')
axs[0, 0].set(xlabel='t', ylabel='value', title='Observed vs. predicted')
axs[0, 0].legend();



In [None]:
fig, axs = plt.subplots(squeeze=False)
# Compare each prediction with its own sample: take the LAST pred.shape[0] rows of X as a slice
# (and only its last column, matching pred), not a single row broadcast against every prediction.
error = np.asarray(pred) - np.asarray(X[-pred.shape[0]:, -1:])
axs[0, 0].plot(pred.t, error[:, -1], '.-')
axs[0, 0].set(xlabel='t', ylabel='prediction - observed', title='One-step prediction error');



In [None]:
# Prediction timestamps next to observation timestamps, to eyeball their alignment.
(pred.t, X.t)

In [None]:
# `ar` was never defined earlier, so fit it here on the data currently in scope.
# NOTE(review): X, stim and s were reassigned for the 1-D toy example above, while A is still the
# 3x3 rotation matrix of the 3-D simulation, so the third panel is not ground truth for the first two.
# NOTE(review): AR_K pairs stim[i-1] with the transition into activity[i]; StimRegressor presumably pairs
# the stim sample stamped just before X[i] with that transition — the stim alignments may differ by one sample.
ar = AR_K(k=1)
ar.fit(np.asarray(X), np.asarray(stim))

fig, axs = plt.subplots(ncols=3)
axs[0].matshow(ar.As[0])
axs[0].set_title('AR_K As[0]')
axs[1].matshow(s.reg.get_beta())
axs[1].set_title('StimRegressor reg beta')
axs[2].matshow(A.T)
axs[2].set_title('A.T (3-D simulation)');
