In [None]:
import numpy as np
import adaptive_latents as al
import warnings
import matplotlib.pyplot as plt
import scipy.signal as signal

# Fixed seed so that any random draws made through `rng` repeat exactly after
# a kernel restart. Change SEED to obtain a different (but still reproducible)
# stream.
SEED = 0
rng = np.random.default_rng(SEED)

In [None]:
# Load the Naumann24u recording selected by `sub_dataset_identifier`.
sub_dataset_identifier = 2

# UserWarnings emitted while the dataset object is constructed are silenced.
# The filter only applies inside this `with` block; the previous warning
# configuration is restored on exit.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)
    d = al.datasets.Naumann24uDataset(sub_dataset_identifier=sub_dataset_identifier)

In [None]:
# extract a rectangular block

def get_rectangular_block(neural_data, n_neurons=150):
    # type: (al.ArrayWithTime, int) -> al.ArrayWithTime
    """Crop `neural_data` to a NaN-free rectangular (row x column) block.

    The crop is chosen in three steps:

    1. Rows: drop everything before the first row at which the NaN-ignoring
       cumulative sum of column `n_neurons` becomes positive.
    2. Columns: in that first kept row, take the NaN-ignoring cumulative sum
       from the last column backwards; the number of leading zeros is the
       number of trailing columns that get dropped.
    3. Rows again: drop everything up to and including the last row of the
       remaining block that still contains a NaN.

    Args:
        neural_data: 2-D array that also provides `.slice(start, stop)` and
            `.copy()` (annotated as `al.ArrayWithTime`).
        n_neurons: index of the column used in step 1.

    Returns:
        A copy of the cropped block.

    Raises:
        IndexError: if the cumulative sum of column `n_neurons` never becomes
            positive.
    """
    cutoff1 = np.nonzero(np.nancumsum(neural_data[:, n_neurons]) > 0)[0][0]
    cutoff2 = np.nonzero(np.nancumsum(neural_data[cutoff1, ::-1]))[0][0]

    # Use an explicit end index for the columns: `[:, :-cutoff2]` would select
    # *no* columns when cutoff2 == 0 (i.e. when the last column already holds
    # data in the first kept row).
    n_columns = neural_data.shape[1]
    # NOTE(review): `.slice(start, -1)` is kept exactly as written; if it
    # follows Python slice semantics it also drops the final row each time it
    # is called — confirm that this is intended.
    neural_data = neural_data.slice(cutoff1, -1)[:, :n_columns - cutoff2]

    # If no row contains a NaN there is nothing more to trim from the top;
    # indexing `[-1]` on the empty result would raise IndexError.
    nan_rows = np.where(np.isnan(neural_data).any(axis=1))[0]
    cutoff3 = nan_rows[-1] + 1 if nan_rows.size else 0
    neural_data = neural_data.slice(cutoff3, -1)

    # Sanity check: every row at or before the last NaN row has been removed.
    assert not np.isnan(neural_data).any()
    return neural_data.copy()

# Crop the raw recording to its NaN-free block, using column 100 to locate
# the first row to keep.
neural_data = get_rectangular_block(d.neural_data, n_neurons=100)

## Visual PSTH (peri-stimulus time histogram)

In [None]:
# Cut a 100-sample window (50 before to 50 after) around every visual-stimulus
# sample index. The windows are taken from the *raw* recording
# (`d.neural_data`), restricted to as many columns as the cropped block has.
n_block_columns = neural_data.shape[1]
psths = []
for s in d.visual_stimuli.loc[:, 'sample']:
    window = slice(s - 50, s + 50)
    psths.append(d.neural_data[window, :n_block_columns])
# Stack the per-stimulus windows along a new leading axis.
psths = np.array(psths)

In [None]:
# Per-column baseline: the 5% quantile of the cropped recording along axis 0.
baseline = np.quantile(neural_data, 0.05, axis=0)
responses = psths - baseline

# Show the baseline-subtracted window of stimulus number 6 as an image.
plt.matshow(responses[6])

In [None]:
# 5% quantile of every column of the cropped recording, plotted per column.
baseline_per_column = np.quantile(neural_data.as_array(), 0.05, axis=0)
plt.plot(baseline_per_column)

In [None]:
%matplotlib qt
fig, ax = plt.subplots()

custom_kernel = np.zeros(10)

# smoothed = al.Pipeline([al.KernelSmoother(tau=6), al.KernelSmoother(custom_kernel=)]).offline_run_on(neural_data)
# ax.plot(smoothed.t, smoothed[:,10].as_array());
ax.plot(neural_data.t, neural_data.as_array())

for t in d.visual_stimuli.time:
    ax.axvline(t, color='k')


In [None]:

def plot_single_group_responses(group_n, responses_by_group):
    """Plot each response of one group, the group mean, and a per-neuron map.

    Figure layout (two rows):
      * top row: one panel per stimulation in the group;
      * bottom row, columns 0-1 merged: the mean over the group's stimulations;
      * bottom row, column 4: the map drawn by `plot_per_neuron`, with the
        group's target neuron marked in blue.

    Relies on notebook-level names that are not defined in this cell:
    `d`, `target_neurons` and `plot_per_neuron`.

    Args:
        group_n: index into `responses_by_group` and `target_neurons`.
        responses_by_group: sequence of 3-D arrays. Axis 0 indexes the
            stimulations of a group; judging by the axis labels and the final
            print, axis 1 is samples and axis 2 is neurons — TODO confirm.

    Returns:
        None. Creates a new figure and prints the index (along the last axis)
        of the largest value of the mean response.
    """
    single_group_responses = responses_by_group[group_n]
    target_neuron = target_neurons[group_n]
    n_stimulations = single_group_responses.shape[0]

    # The bottom row uses columns 0-1 (mean panel) and column 4 (neuron map),
    # so the grid needs at least five columns. The previous minimum of two
    # made `axs[1, 4]` raise IndexError for every group with fewer than five
    # stimulations.
    n_columns = max(n_stimulations, 5)
    fig, axs = plt.subplots(ncols=n_columns, nrows=2, figsize=(10, 5), sharey='row')

    # Replace the first bottom-row axes with one wide axis spanning two columns.
    gs = axs[1, 0].get_gridspec()
    for ax in axs[1, :4]:
        ax.remove()
    axs[1, 0] = fig.add_subplot(gs[1, :2])

    # Stimulations are numbered consecutively across groups.
    stimuli_in_previous_groups = sum(r.shape[0] for r in responses_by_group[:group_n])
    for i, ax in enumerate(axs[0, :]):
        if i >= n_stimulations:
            # Padding column (only present to make room for the bottom row):
            # there is no response to draw, so hide the empty panel.
            ax.set_visible(False)
            continue
        stimulus_number = stimuli_in_previous_groups + i
        ax.plot(single_group_responses[i, :, :])
        ax.set_xlabel('samples from stim')
        ax.set_title(f'stim {stimulus_number}, neuron = {d.opto_stimulations.loc[stimulus_number,"target_neuron"]}', fontsize='small')

    mean_responses = single_group_responses.mean(axis=0)
    axs[1, 0].plot(mean_responses)

    axs[0, 0].set_ylabel('response magnitude (a.u.)')

    axs[1, 0].set_ylabel('response magnitude (a.u.)')
    axs[1, 0].set_title(f'average response for group {group_n}')
    axs[1, 0].set_xlabel('samples from stim')

    # Turn the mean response (averaged over axis 0) into marker sizes.
    # NOTE(review): 7.3 and 15 are undocumented hand-picked scale factors;
    # entries smaller than 5 are blanked out with NaN.
    sizes = np.mean(mean_responses, axis=0)
    sizes = np.abs(sizes / 7.3) * 15
    sizes[sizes < 5] = np.nan
    plot_per_neuron(axs[1, 4], sizes, d)
    axs[1, 4].scatter(d.neuron_df.loc[target_neuron, 'x'], d.neuron_df.loc[target_neuron, 'y'], s=10, color='blue')
    fig.tight_layout()

    print(f"Neuron with the highest average peak: {np.unravel_index(np.nanargmax(mean_responses), mean_responses.shape)[1]}")


## make latents

In [None]:
# Start the latent pipeline from the cropped recording. This is a plain name
# binding, not a copy: `latents` and `neural_data` refer to the same object.
latents = neural_data

In [None]:

# Zero-phase Butterworth band-pass between 1/20 and 1/2 (in the units of
# `fs`), applied along axis 0. The filter is designed as second-order
# sections: a 10th-order band-pass in (b, a) polynomial form is numerically
# fragile, and SciPy recommends the SOS representation for filtering.
sos = signal.butter(N=10, Wn=[1/20, 1/2], fs=1/neural_data.dt, btype='band', output='sos')

# The SciPy filter functions return a plain ndarray. Writing the result into a
# copy of the input keeps the input's container type (and whatever metadata
# `.copy()` carries over), which later cells rely on when they read `latents.t`.
# NOTE: this cell rebinds `latents`, so running it twice filters twice.
filtered_latents = latents.copy()
filtered_latents[:] = signal.sosfiltfilt(sos, latents, axis=0)
latents = filtered_latents


In [None]:
# Pre-declare the handles that the walrus assignments below may bind, so the
# names exist even while the corresponding pipeline step is commented out.
pro = None
jpca = None
centerer = None
# Only the proSVD step is active; the commented-out entries are alternative
# steps that can be switched on by uncommenting them.
p = al.Pipeline([
    # al.KernelSmoother(tau=2),
    # centerer:=al.CenteringTransformer(init_size=50),
    pro:=al.proSVD(k=3),
    # ica:=al.mmICA(init_size=200),
    # jpca:=al.sjPCA()
])


# First pass: run the pipeline over `latents` before any step is frozen.
online_output = p.offline_run_on(latents)

# freeze the pipeline
# KernelSmoother steps are replaced by their `blank_copy()`; every other step
# has `freeze()` called on it.
# NOTE(review): presumably this stops the steps from adapting any further —
# confirm against the adaptive_latents documentation.
for i, step in enumerate(p.steps):
    if isinstance(step, al.KernelSmoother):
        p.steps[i] = step.blank_copy()
    else:
        step.freeze()
        
# Second pass over the same input, now through the frozen pipeline.
offline_output = p.offline_run_on(latents)


# Index 1 selects `offline_output`; use index 0 to continue with
# `online_output` instead. This rebinds `latents`, so re-running the cell
# feeds the previous output back in as the pipeline input.
latents = [online_output, offline_output][1]

In [None]:
# Plot `pro.Q` from the proSVD step in a fresh figure.
q_axes = plt.figure().add_subplot(1, 1, 1)
q_axes.plot(pro.Q)

In [None]:
%matplotlib qt
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.plot(latents[:,0], latents[:,1], latents[:,2])

In [None]:
# Image of the cropped recording, transposed so that axis 0 of `neural_data`
# runs horizontally.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.matshow(neural_data.T)

In [None]:
def _plot_post_stimulus_segments(ax, trajectory, stim_times, first_offset, last_offset, fmt,
                                 relative_to_start=False):
    """Overlay a short stretch of `trajectory` after each stimulus on a 3-D axis.

    For every time in `stim_times`, the first sample whose timestamp
    (`trajectory.t`) is strictly later is located, and the samples from
    `first_offset` up to (not including) `last_offset` after it are plotted,
    using columns 0, 1 and 2 as x, y and z.

    Args:
        ax: 3-D matplotlib axis to draw on.
        trajectory: 2-D array with a `.t` attribute holding one timestamp per row.
        stim_times: iterable of stimulus times, compared against `trajectory.t`.
        first_offset: first plotted sample, counted from the post-stimulus sample.
        last_offset: end of the plotted range (exclusive).
        fmt: matplotlib format string for the segment (e.g. 'C1').
        relative_to_start: if True, subtract the segment's first sample so
            that every segment starts at the origin.
    """
    for stim_t in stim_times:
        later_samples = np.nonzero(trajectory.t > stim_t)[0]
        if later_samples.size == 0:
            # The stimulus lies after the last sample: nothing to draw.
            continue
        stim = later_samples[0]
        segment = trajectory[stim + first_offset:stim + last_offset]
        if relative_to_start and len(segment):
            segment = segment - segment[0]
        ax.plot(segment[:, 0], segment[:, 1], segment[:, 2], fmt)


fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
ax.plot(latents[:,0], latents[:,1], latents[:,2])

# Visual stimuli ('C2'): samples 15-19 after each onset. The last stimulus is
# left out and stimuli at or before the first latent timestamp are ignored.
first_latent_time = latents.t.min()
visual_times = [stim_t for stim_t in d.visual_stimuli.time[:-1] if stim_t > first_latent_time]
_plot_post_stimulus_segments(ax, latents, visual_times, 15, 20, 'C2')

# Optogenetic stimulations ('C1'): samples 13-14 after each onset.
_plot_post_stimulus_segments(ax, latents, d.opto_stimulations.time, 13, 15, 'C1')

# Mark the origin of the latent space.
ax.scatter(0,0,0, color='k')


# arrow = np.vstack([np.zeros(3), np.ones(151) @ pro.Q]).T
# ax.plot(arrow[0], arrow[1], arrow[2], color='C6', label='1s vector')


In [None]:
%matplotlib qt
fig, ax = plt.subplots(figsize=(18, 5))

# N is order
# sos = butter(N=5, Wn=[1/20, 1/2], fs=1/neural_data.dt, btype='band', output='sos')
# sos = butter(N=5, Wn=1/2, fs=1/neural_data.dt, btype='low', output='sos')
# filtered = sosfiltfilt(sos, neural_data, axis=0)

# b,a = butter(N=5, Wn=1/5, fs=1/neural_data.dt, btype='low', output='ba')
# filtered=lfilter(b,a, neural_data, axis=0)

b,a = signal.butter(N=5, Wn=[1/20, 1/5], fs=1/neural_data.dt, btype='band', output='ba')
# filtered=lfilter(b,a, neural_data, axis=0)
filtered=signal.filtfilt(b,a, neural_data, axis=0)

# ax.plot(filtered)
ax.plot(neural_data.t,filtered[:,:])
# ax.plot(neural_data.t, neural_data[:,83])
# plt.plot(neural_data.t, neural_data.as_array())

for stim_s in d.visual_stimuli.loc[:,'time']:
    ax.axvline(stim_s, color='k')

for stim_s in d.opto_stimulations.loc[:,'time']:
    ax.axvline(stim_s, color='k')

ax.set_xlim([500,800])

In [None]:
# Image of the band-pass filtered recording.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.matshow(filtered)