In [None]:
import adaptive_latents as al
import numpy as np
import warnings
import scipy.signal as signal
import matplotlib.pyplot as plt
import naumann_utility_functions as nuf

# Seed the generator so the optional shuffle controls are reproducible under Restart & Run All;
# change SEED to draw a different shuffle.
SEED = 0
rng = np.random.default_rng(SEED)

In [None]:
from dataclasses import dataclass
from typing import Literal

@dataclass
class Options:
    """Settings that control loading, preprocessing, and proSVD fitting in this notebook."""
    # Passed as `sub_dataset_identifier` to al.datasets.Naumann24uDataset.
    sub_dataset: Literal[1,2]
    # Passed to get_rectangular_block; the kept block starts once column index `n_neurons` has data.
    n_neurons: int
    # Shuffle controls: 'row' -> rng.shuffle(..., axis=1), 'column' -> rng.shuffle(..., axis=0).
    # A bare string is also accepted and wrapped into a one-element set.
    shuffle: set
    # None, 'centering' (al.CenteringTransformer), or 'filter' (Butterworth high-pass at `low_f`;
    # becomes a band-pass when `lowpass` is also 'filter').
    highpass: Literal[None, 'centering', 'filter']
    # None, 'smoothing' (al.KernelSmoother(tau=2)), or 'filter' (Butterworth low-pass at `high_f`;
    # becomes a band-pass when `highpass` is also 'filter').
    lowpass: Literal[None, 'smoothing', 'filter']
    # None, 'streaming' (al.ZScoringTransformer), or 'batch' (subtract mean, divide by std with ddof=1).
    zscore: Literal[None, 'streaming' , 'batch']
    # Data the offline proSVD basis is trained on: 'both' (all data), 'vis' (up to
    # end_of_visual_period_time), or 'opto' (from end_of_visual_period_time). Values other than
    # 'both' require offline=True.
    space_train_on: Literal['both', 'vis', 'opto']
    # 'whole' z-scores all data together; 'half' z-scores the parts before and after
    # end_of_visual_period_time separately.
    zscore_per: Literal['whole', 'half'] = 'whole'
    # Selects the scipy filtering function used by the 'filter' options; intended to request causal
    # (single forward pass) filtering. NOTE(review): check that the `filt` assignment maps this correctly.
    causal:bool = False
    # True: centering / 'streaming' z-scoring use offline_fit_then_transform, and proSVD is trained
    # then frozen. False: those steps use offline_run_on.
    offline:bool = True
    # Upper cutoff frequency (low-pass, or top of the band-pass), in cycles per unit of
    # `neural_data.t` since the filters use fs = 1/dt.
    high_f:float = 1/2
    # Lower cutoff frequency (high-pass, or bottom of the band-pass), same units as `high_f`.
    low_f:float = 1/50


# Settings for this run; fields left out keep the defaults declared on Options.
options = Options(
    # data selection
    sub_dataset=2,
    n_neurons=150,
    space_train_on='both',
    # preprocessing
    highpass=None,
    lowpass='filter',
    zscore=None,
    zscore_per='whole',
    # controls: no shuffling
    shuffle=set(),
)

In [None]:
# Load the requested sub-dataset, suppressing UserWarnings only while the loader runs.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)
    d = al.datasets.Naumann24uDataset(
        sub_dataset_identifier=options.sub_dataset,
    )

In [None]:
def get_rectangular_block(neural_data, n_neurons=150):
    # type: (al.ArrayWithTime, int) -> al.ArrayWithTime
    """Cut `neural_data` down to a rectangular, NaN-free block of samples x columns.

    Presumably columns (neurons) that are not yet observed are NaN, so the recording only becomes
    "full" after some time — TODO confirm against the dataset loader.

    Steps:
      1. Start at the first row where column index `n_neurons` has accumulated a positive
         (NaN-ignoring) running sum.
      2. In that row, drop the trailing columns whose NaN-ignoring running sum (from the right)
         is still zero.
      3. Drop every row up to and including the last row that still contains a NaN.

    Args:
        neural_data: array with a `.t` attribute and a `.slice(start, stop)` method
            (al.ArrayWithTime), shaped (samples, columns).
        n_neurons: column index whose first data marks the start of the block.

    Returns:
        A copy of the NaN-free block, keeping its timestamps.
    """
    # Step 1: first sample at which column `n_neurons` has data.
    start_row = np.nonzero(np.nancumsum(neural_data[:, n_neurons]) > 0)[0][0]
    # Step 2: count trailing columns that are empty (NaN/zero) in that row.
    n_trailing_empty = np.nonzero(np.nancumsum(neural_data[start_row, ::-1]))[0][0]
    # Use an explicit stop index: `[:, :-0]` would select zero columns when nothing is trailing.
    n_keep = neural_data.shape[1] - n_trailing_empty
    # Stop of None keeps the final sample (a stop of -1 would silently drop it).
    neural_data = neural_data.slice(start_row, None)[:, :n_keep]

    # Step 3: skip past the last row that still has any NaN.
    rows_with_nan = np.nonzero(np.isnan(neural_data).any(axis=1))[0]
    if rows_with_nan.size > 0:
        neural_data = neural_data.slice(rows_with_nan[-1] + 1, None)

    assert not np.isnan(neural_data).any()
    return neural_data.copy()

# Trim the raw recording to a NaN-free rectangular block (see get_rectangular_block).
neural_data = get_rectangular_block(
    neural_data=d.neural_data,
    n_neurons=options.n_neurons,
)


In [None]:
# Optional shuffle controls. A bare string is accepted as shorthand for a one-element set.
if isinstance(options.shuffle, str):
    options.shuffle = {options.shuffle}

# Fail loudly on typos: an unrecognized option would otherwise silently disable the control.
unknown_shuffles = set(options.shuffle) - {'row', 'column'}
if unknown_shuffles:
    raise ValueError(f"unknown shuffle option(s): {sorted(unknown_shuffles)}")

# 'row': permute along axis 1 (the same column permutation is applied to every row).
if 'row' in options.shuffle:
    rng.shuffle(neural_data, axis=1)

# 'column': permute along axis 0 (the same row permutation is applied to every column);
# only the array contents move, the `.t` timestamps stay in order.
if 'column' in options.shuffle:
    rng.shuffle(neural_data, axis=0)


In [None]:

# scipy's `filtfilt` runs the filter forward and then backward (zero phase, but non-causal), while
# `lfilter` is a single forward pass (causal). Use the causal one when `options.causal` is set.
filt = signal.lfilter if options.causal else signal.filtfilt
# Keep the timestamps so the preprocessed result can be re-wrapped as an ArrayWithTime later.
old_t = neural_data.t


In [None]:
if options.highpass == 'centering':
    centerer = al.CenteringTransformer(init_size=50)
    if options.offline:
        neural_data = centerer.offline_fit_then_transform(neural_data)
    else:
        neural_data = centerer.offline_run_on(neural_data)

elif options.highpass == 'filter' and options.lowpass != 'filter':
    # High-pass at low_f (the band-pass cell handles highpass == lowpass == 'filter').
    # At order 10, (b, a) coefficients are prone to numerical problems, so use second-order
    # sections as the scipy.signal.butter documentation recommends.
    # sosfilt is a single causal pass; sosfiltfilt runs forward and backward (zero phase, non-causal).
    sos = signal.butter(N=10, Wn=options.low_f, fs=1/neural_data.dt, btype='high', output='sos')
    sos_filter = signal.sosfilt if options.causal else signal.sosfiltfilt
    neural_data = sos_filter(sos, neural_data, axis=0)


In [None]:
if options.lowpass == 'smoothing':
    smoother = al.KernelSmoother(tau=2)
    neural_data = smoother.offline_run_on(neural_data)

elif options.lowpass == 'filter' and options.highpass != 'filter':
    # Low-pass at high_f (the band-pass cell handles highpass == lowpass == 'filter').
    # At order 10, (b, a) coefficients are prone to numerical problems, so use second-order
    # sections as the scipy.signal.butter documentation recommends.
    # sosfilt is a single causal pass; sosfiltfilt runs forward and backward (zero phase, non-causal).
    sos = signal.butter(N=10, Wn=options.high_f, fs=1/neural_data.dt, btype='low', output='sos')
    sos_filter = signal.sosfilt if options.causal else signal.sosfiltfilt
    neural_data = sos_filter(sos, neural_data, axis=0)


In [None]:
if options.highpass == options.lowpass == 'filter':
    # Band-pass between low_f and high_f when both high- and low-pass filtering are requested.
    # A band-pass Butterworth has order 2*N (20 here), where (b, a) coefficients are numerically
    # unreliable; second-order sections are the form scipy recommends.
    # sosfilt is a single causal pass; sosfiltfilt runs forward and backward (zero phase, non-causal).
    sos = signal.butter(N=10, Wn=[options.low_f, options.high_f], fs=1/neural_data.dt, btype='band', output='sos')
    sos_filter = signal.sosfilt if options.causal else signal.sosfiltfilt
    neural_data = sos_filter(sos, neural_data, axis=0)

In [None]:
# Re-wrap the (possibly plain-ndarray) preprocessing output as an ArrayWithTime, using the
# timestamps saved before filtering. NOTE(review): this assumes centering, smoothing, and filtering
# all preserve the number of samples so `old_t` still lines up — confirm for CenteringTransformer
# and KernelSmoother.
neural_data = al.ArrayWithTime(neural_data, old_t)

In [None]:
# Split the data into the segments that are z-scored independently.
if options.zscore_per == 'whole':
    to_zscore = [neural_data]
elif options.zscore_per == 'half':
    # Normalize the parts before and after end_of_visual_period_time separately.
    to_zscore = [
        neural_data.slice_by_time(None, d.end_of_visual_period_time),
        neural_data.slice_by_time(d.end_of_visual_period_time, None),
    ]
else:
    raise ValueError()

zscored = []
for to_score in to_zscore:
    if options.zscore == 'streaming':
        zscorer = al.ZScoringTransformer()
        if options.offline:
            scored = zscorer.offline_fit_then_transform(to_score)
        else:
            scored = zscorer.offline_run_on(to_score)
    elif options.zscore == 'batch':
        # Per-column batch z-score with the sample (ddof=1) standard deviation.
        centered = to_score - to_score.mean(axis=0)
        scored = centered / centered.std(axis=0, ddof=1)
    elif options.zscore is None:
        scored = to_score
    else:
        raise ValueError()
    zscored.append(scored)


def _trailing_times(segment):
    """Return the last ``segment.shape[0]`` timestamps of ``segment.t`` as a flat array.

    Slicing with an explicit start index (rather than ``t[-n:]``) keeps an empty segment
    (n == 0) from selecting every timestamp.
    """
    t = np.ravel(np.asarray(segment.t))
    return t[max(t.size - segment.shape[0], 0):]


neural_data = al.ArrayWithTime(np.vstack(zscored), np.hstack([_trailing_times(x) for x in zscored]))


In [None]:
# Only the offline path can train the basis on a subset of the data; online runs must use 'both'.
assert options.offline is True or options.space_train_on == 'both'

pro = al.proSVD(k=3)

if not options.offline:
    # Online: proSVD is never frozen here, so it runs over all of the data in one pass.
    latents = pro.offline_run_on(neural_data)
else:
    if options.space_train_on == 'vis':
        train_data = neural_data.slice_by_time(None, d.end_of_visual_period_time)
    elif options.space_train_on == 'opto':
        train_data = neural_data.slice_by_time(d.end_of_visual_period_time, None)
    elif options.space_train_on == 'both':
        train_data = neural_data
    else:
        raise ValueError()

    # Fit on the training data, freeze, then run the frozen model over all of the data.
    pro.offline_run_on(train_data)
    pro.freeze()
    latents = pro.offline_run_on(neural_data)
    


In [None]:
%matplotlib qt

def plot_for_times(data, window=100, dataset=None):
    """Plot every column of `data` against `data.t` in three zoomed time windows.

    Each panel draws a solid black vertical line at every visual-stimulus time and a dashed
    black line at every opto-stimulation time. The windows, each `window` time units wide, are:
      0. starting at end_of_visual_period_time / 2,
      1. centered on end_of_visual_period_time,
      2. starting halfway between end_of_visual_period_time and the last sample.

    Args:
        data: array with a `.t` attribute and an `.as_array()` method (al.ArrayWithTime).
        window: width of each zoomed window, in units of `data.t`.
        dataset: object providing `visual_stimuli.time`, `opto_stimulations.time`, and
            `end_of_visual_period_time`; defaults to the notebook-level dataset `d`.

    Returns:
        (fig, axs) from plt.subplots.
    """
    if dataset is None:
        dataset = d
    boundary = dataset.end_of_visual_period_time

    fig, axs = plt.subplots(figsize=(15, 10), nrows=3)
    for ax in axs:
        ax.plot(data.t, data.as_array())
        for t in dataset.visual_stimuli.time:
            ax.axvline(t, color='k')
        for t in dataset.opto_stimulations.time:
            ax.axvline(t, color='k', linestyle='--')
    axs[-1].set_xlabel('time')

    span = np.array([0, window])
    axs[0].set_xlim(boundary / 2 + span)
    axs[0].set_title('window at end_of_visual_period_time / 2')
    axs[1].set_xlim(boundary + span - window / 2)
    axs[1].set_title('window centered on end_of_visual_period_time')
    axs[2].set_xlim((data.t[-1] + boundary) / 2 + span)
    axs[2].set_title('window midway between end_of_visual_period_time and the last sample')
    return fig, axs

plot_for_times(neural_data);


In [None]:
# Same three time windows as for the preprocessed data, now showing the proSVD latents.
plot_for_times(latents);


In [None]:
%matplotlib qt
def plot_event_locked(ax, data, event_times, window, color):
    """Draw one 3D segment of `data` per event, spanning [t + window[0], t + window[1]].

    Uses the first three columns of `data` as x, y, z; `window` is in units of `data.t`.
    """
    offsets = np.asarray(window)
    for event_time in event_times:
        segment = data.slice_by_time(*(event_time + offsets))
        ax.plot(segment[:, 0], segment[:, 1], segment[:, 2], color=color)


fig, axs = plt.subplots(subplot_kw={'projection': '3d'}, ncols=2, figsize=(15, 10))

vis_period = latents.slice_by_time(None, d.end_of_visual_period_time)
opto_period = latents.slice_by_time(d.end_of_visual_period_time, None)

# Left: the full latent trajectory, drawn as the visual and opto periods.
# NOTE(review): both periods use 'C0', so the split is not visible — intended?
axs[0].plot(vis_period[:, 0], vis_period[:, 1], vis_period[:, 2], color='C0')
axs[0].plot(opto_period[:, 0], opto_period[:, 1], opto_period[:, 2], color='C0')
axs[0].axis('equal')
axs[0].set_title('full trajectory')

# Right: stimulus-locked segments (time offsets after each event, in units of latents.t).
OPTO_WINDOW = (3, 9)
VISUAL_WINDOW = (7, 12)
plot_event_locked(axs[1], latents, d.opto_stimulations.time, OPTO_WINDOW, 'C1')
plot_event_locked(axs[1], latents, d.visual_stimuli.time, VISUAL_WINDOW, 'C2')
axs[1].set_title('event-locked segments (C1: opto, C2: visual)')

# Keep the two panels' 3D view angles in sync.
axs[0].shareview(axs[1])
