In [None]:
import numpy as np
import matplotlib.pyplot as plt
import warnings
import adaptive_latents as al
import scipy.signal as signal


# Fixed seed so the random baseline (predict_random) is reproducible on Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

In [None]:
from dataclasses import dataclass
from typing import Literal

@dataclass
class Options:
    """Tunable settings for this notebook.

    Attributes
    ----------
    sub_dataset
        Passed to ``Naumann24uDataset`` as ``sub_dataset_identifier``.
    n_neurons
        Column index handed to ``get_rectangular_block``, which uses it to
        choose the first row kept in the cropped block.
    """

    sub_dataset: Literal[1, 2]
    n_neurons: int


options = Options(1, 100)  # sub_dataset=1, n_neurons=100

In [None]:
with warnings.catch_warnings():
    # Mute UserWarning (and only that category) while the dataset is constructed;
    # catch_warnings restores the previous filter state on exit.
    warnings.simplefilter("ignore", category=UserWarning)
    d = al.datasets.Naumann24uDataset(sub_dataset_identifier=options.sub_dataset)

In [None]:
def get_rectangular_block(neural_data, n_neurons=150):
    # type: (al.ArrayWithTime, int) -> al.ArrayWithTime
    """Crop a recording with a ragged NaN border down to a NaN-free rectangle.

    Indexing is ``[row, column]``. The crop is done in three steps:

    1. Drop leading rows until column ``n_neurons`` first has a positive
       running sum (NaNs count as 0).
    2. In that first kept row, count the trailing entries that are NaN or 0
       and drop that many columns from the right.
    3. Drop every row up to and including the last row that still has a NaN.

    Parameters
    ----------
    neural_data : al.ArrayWithTime
        2-D array; must provide ``.slice(start, stop)`` for row slicing.
    n_neurons : int
        Index of the column used in step 1 to choose the first kept row.

    Returns
    -------
    al.ArrayWithTime
        A NaN-free copy of the cropped block.

    Raises
    ------
    ValueError
        If column ``n_neurons`` never reaches a positive running sum.
    """
    started = np.nonzero(np.nancumsum(neural_data[:, n_neurons]) > 0)[0]
    if started.size == 0:
        raise ValueError(
            f"column {n_neurons} never has a positive running sum; cannot choose a first row"
        )
    cutoff1 = started[0]

    # Row cutoff1 has a positive entry in column n_neurons, so this is never empty.
    # The first non-zero running sum from the right = number of trailing NaN/0 entries.
    cutoff2 = np.nonzero(np.nancumsum(neural_data[cutoff1, ::-1]))[0][0]

    # Use an explicit stop index: `[:, :-cutoff2]` selects *no* columns when cutoff2 == 0.
    n_columns = neural_data.shape[1]
    neural_data = neural_data.slice(cutoff1, -1)[:, :n_columns - cutoff2]

    # Drop any remaining rows up to (and including) the last one that contains a NaN.
    nan_rows = np.where(np.isnan(neural_data).any(axis=1))[0]
    if nan_rows.size > 0:
        cutoff3 = nan_rows[-1] + 1
        neural_data = neural_data.slice(cutoff3, -1)

    assert not np.isnan(neural_data).any()
    return neural_data.copy()


In [None]:
# Preprocessing settings.
FILTER_ORDER = 10    # Butterworth filter order
CUTOFF_FREQ = 1 / 2  # low-pass cutoff in units of 1/dt (Hz if dt is in seconds -- confirm)
N_COMPONENTS = 10    # number of proSVD components kept

# Crop the ragged recording to a NaN-free block (assumed time x neuron).
rectangular_data = get_rectangular_block(d.neural_data, options.n_neurons)

# Zero-phase low-pass along axis 0. Second-order sections are used instead of the
# (b, a) transfer function: SciPy warns the polynomial form is numerically
# unreliable at high orders, and this filter is order 10.
sos = signal.butter(N=FILTER_ORDER, Wn=CUTOFF_FREQ, fs=1 / rectangular_data.dt, btype='low', output='sos')
filtered_data = signal.sosfiltfilt(sos, rectangular_data, axis=0)

# Project onto the leading proSVD components; later cells read `neural_data`.
pro = al.proSVD(k=N_COMPONENTS)
neural_data = pro.offline_fit_then_transform(filtered_data)


In [None]:
def _nearest_with_successor(query, training_data, offset):
    """Find the training sample nearest to ``query`` that has a successor ``offset`` steps later.

    ``training_data`` is either one 2-D array (rows are samples) or an iterable
    of such arrays, each treated as a separate contiguous segment. Only samples
    whose successor lies inside the *same* segment are candidates, so a
    prediction never jumps across a gap between segments.

    Returns
    -------
    (nearest, successor) : tuple of 1-D arrays (views into the training data)

    Raises
    ------
    ValueError
        If ``offset`` is negative, or no segment is longer than ``offset``.
    """
    if offset < 0:
        raise ValueError(f"offset must be non-negative, got {offset}")
    segments = [training_data] if isinstance(training_data, np.ndarray) else list(training_data)

    best_dist = np.inf
    best = None
    for segment in segments:
        n_candidates = len(segment) - offset
        if n_candidates <= 0:
            continue  # too short (or empty) to contain any sample with a successor
        dists = np.linalg.norm(segment[:n_candidates] - query, axis=1)
        i = int(np.argmin(dists))
        if dists[i] < best_dist:  # strict '<' keeps the earliest match on ties
            best_dist = dists[i]
            best = (segment[i], segment[i + offset])
    if best is None:
        raise ValueError(f"no training segment is longer than offset={offset}")
    return best


def predict_closest(query, training_data, offset=1):
    """Nearest-neighbour forecast: the sample ``offset`` steps after the closest training sample.

    Parameters
    ----------
    query : 1-D array
        Current state.
    training_data : 2-D array or iterable of 2-D arrays
        Training samples; separate arrays are treated as separate segments.
    offset : int
        How many samples ahead to predict (>= 0).

    Returns
    -------
    A copy of the successor of the nearest training sample.
    """
    _, successor = _nearest_with_successor(query, training_data, offset)
    return np.array(successor)  # copy so callers cannot mutate the training data


def predict_closest_with_offset(query, training_data, offset=1):
    """Nearest-neighbour forecast of the *change*, applied to ``query``.

    Finds the nearest training sample (as in ``predict_closest``) and returns
    ``query + (successor - nearest)``.
    """
    nearest, successor = _nearest_with_successor(query, training_data, offset)
    return query + (successor - nearest)

def predict_query(query, training_data, offset=1):
    """Identity baseline: predict that the state does not change.

    ``training_data`` and ``offset`` are accepted only so the signature
    matches the other predictors; both are ignored.
    """
    del training_data, offset  # unused
    return query

def predict_random(query, training_data, offset=1):
    """Chance-level baseline.

    Ignores ``query`` and ``offset``; stacks all training segments and returns
    one row drawn with the notebook-level ``rng``.
    """
    return rng.choice(np.vstack(training_data))




In [None]:
def evaluate(p, data, folds=5, offset=5):
    """K-fold evaluation of an ``offset``-step-ahead predictor.

    ``data`` is split along axis 0 into ``folds`` contiguous folds. For each
    fold, every sample ``to_predict[i]`` that has a target ``offset`` samples
    later in the same fold is used as a query to
    ``p(query=..., training_data=[data before fold, data after fold], offset=offset)``.
    Each prediction is scored against ``to_predict[i + offset]``.

    Parameters
    ----------
    p : callable
        Predictor with keyword arguments ``query``, ``training_data``, ``offset``.
    data : 2-D array
        Samples along axis 0.
    folds : int
        Number of contiguous folds.
    offset : int
        Prediction horizon in samples (>= 0).

    Returns
    -------
    list of float
        Mean squared error per fold. It is NaN for a fold with no more than
        ``offset`` samples.
    """
    if offset < 0:
        raise ValueError(f"offset must be non-negative, got {offset}")
    edges = np.floor(np.linspace(0, data.shape[0], folds + 1)).astype(int)

    mses = []
    for start, end in zip(edges[:-1], edges[1:]):
        to_predict = data[start:end]
        training_data = [data[:start], data[end:]]

        # Only queries whose target lies inside this fold can be scored.
        n_scored = to_predict.shape[0] - offset
        if n_scored <= 0:
            mses.append(float('nan'))
            continue

        predictions = np.array([
            p(query=to_predict[i], training_data=training_data, offset=offset)
            for i in range(n_scored)
        ])
        # The prediction made from sample i targets sample i + offset.
        targets = to_predict[offset:]
        mses.append(float(((predictions - targets) ** 2).mean()))
    return mses


FOLDS = 10
OFFSET = 39  # prediction horizon, in samples

predictors = {
    'identity': predict_query,
    'close': predict_closest,
    'close plus offset': predict_closest_with_offset,
    'rand': predict_random,
}

fig, ax = plt.subplots()
for label, predictor in predictors.items():
    ax.plot(evaluate(predictor, neural_data, folds=FOLDS, offset=OFFSET), marker='o', label=label)
ax.set(
    xlabel='held-out fold',
    ylabel='mean squared error',
    title=f'{OFFSET}-sample-ahead prediction error per fold',
)
ax.legend();

