In [None]:
%load_ext autoreload
%autoreload 2
import sys
import warnings

import numpy as np
import wandb

sys.path.append('../')
from utils.multisession_utils import align_sessions
from utils.config_utils import get_config
from utils.plot.pcs import plot_pcs
from utils.t5_utils import load_toolkit_datasets, get_trialized_data

In [None]:
# Load the experiment config. Later cells change it only between config.defrost() / config.freeze().
config = get_config()

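## Alignment to trial start (`start_time`)

Fit the cross-session alignment on OL trials aligned to `start_time`, project each session's held-in smoothed spikes into the shared space, and plot the projected trials with `plot_pcs`.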

In [None]:
# wandb.init(project='plots', name='Alignment Move Onset Testing')

In [None]:
# OL window used to fit the start-time alignment: trials aligned to 'start_time' over [0, 2500].
config.defrost()
config.data.ol_align_field, config.data.ol_align_range = 'start_time', [0, 2500]
config.freeze()

In [None]:
# Fit per-session alignment parameters. Below, session i's held-in spikes are projected as
# spikes @ alignment_matrices[i].T + alignment_biases[i], where i is the session's position
# in config.data.sessions.
# NOTE(review): presumably fit on the OL window set in the previous cell — confirm in align_sessions.
alignment_matrices, alignment_biases = align_sessions(config)

In [None]:
# Load the per-session datasets; indexed below as datasets[session] (each exposing .heldin_channels).
datasets = load_toolkit_datasets(config)

In [None]:
# Trialization windows for the start-time analysis. The OL settings repeat the alignment cell
# so this cell is self-contained: OL trials aligned to 'start_time' over [0, 2500],
# CL trials aligned to 'start_time' over [0, 2000].
config.defrost()
config.data.ol_align_field = 'start_time'
config.data.ol_align_range = [0, 2500]
config.data.cl_align_field = 'start_time'
config.data.cl_align_range = [0, 2000]
config.freeze()

In [None]:
# Cut each session into trials using the OL/CL windows set above. Indexed below as
# trialized_data[session]['ol_trial_data'] / ['cl_trial_data'] and grouped by 'condition' and 'trial_id'.
trialized_data = get_trialized_data(config, datasets)

In [None]:
# Project each session's held-in smoothed spikes into the shared aligned space
# (spikes @ alignment_matrices[i].T + alignment_biases[i]) and collect
# (trajectories, condition_ids) pairs for plot_pcs. Condition 0 is skipped, and a
# condition id can appear once per session:
#   ol_cond_avg     -- per OL condition, the mean over kept trials (kept trials share a length)
#   ol_single_trial -- per OL condition, all kept trials concatenated along axis 0
#   cl_single_trial -- per CL condition, all kept trials concatenated along axis 0
ol_cond_avg = ([], [])
ol_single_trial = ([], [])
cl_single_trial = ([], [])

# Bin count an OL trial must have to be kept, so kept trials can be stacked and averaged.
trial_len = (config.data.ol_align_range[1] - config.data.ol_align_range[0]) / config.data.bin_size

# CL trials are kept only if they have more than CL_SKIP_BINS and fewer than CL_MAX_BINS rows.
# The first CL_SKIP_BINS rows are then dropped before projection. Longer trials are already
# excluded by the filter, so the slice needs no upper bound.
CL_SKIP_BINS = 45
CL_MAX_BINS = 400

for idx, session in enumerate(config.data.sessions):
    # Per-session invariants, hoisted out of the trial loops.
    heldin_channels = datasets[session].heldin_channels
    weights_t = alignment_matrices[idx].T
    bias = np.array(alignment_biases[idx])

    for cond_id, trials in trialized_data[session]['ol_trial_data'].groupby('condition'):
        if cond_id == 0:
            continue
        low_d_trials = []
        for _, trial in trials.groupby('trial_id'):
            heldin_spikes = trial.spikes_smth.to_numpy()[:, heldin_channels]
            if heldin_spikes.shape[0] == trial_len:
                low_d_trials.append(np.dot(heldin_spikes, weights_t) + bias)
        # np.concatenate raises (and averaging yields NaN) on an empty list, so skip explicitly.
        if not low_d_trials:
            warnings.warn(f'session {session}: no OL trials with {trial_len:g} bins '
                          f'for condition {cond_id}; skipping')
            continue
        ol_single_trial[0].append(np.concatenate(low_d_trials, 0))
        ol_single_trial[1].append(cond_id)
        ol_cond_avg[0].append(np.stack(low_d_trials).mean(0))
        ol_cond_avg[1].append(cond_id)

    for cond_id, trials in trialized_data[session]['cl_trial_data'].groupby('condition'):
        if cond_id == 0:
            continue
        low_d_trials = []
        for _, trial in trials.groupby('trial_id'):
            if CL_SKIP_BINS < trial.shape[0] < CL_MAX_BINS:
                heldin_spikes = trial.spikes_smth.to_numpy()[CL_SKIP_BINS:, heldin_channels]
                low_d_trials.append(np.dot(heldin_spikes, weights_t) + bias)
        if not low_d_trials:
            warnings.warn(f'session {session}: no CL trials with {CL_SKIP_BINS} < bins < '
                          f'{CL_MAX_BINS} for condition {cond_id}; skipping')
            continue
        cl_single_trial[0].append(np.concatenate(low_d_trials, 0))
        cl_single_trial[1].append(cond_id)

In [None]:
# Condition-averaged OL trajectories (one per non-zero condition per session) under the
# start_time alignment. The title names the alignment so it is distinct from the speed-onset figure.
fig = plot_pcs(*ol_cond_avg, 'OL Condition Averaged (Start Time)', return_fig=True)
fig.show()

In [None]:
# All kept OL trials (concatenated per condition per session) under the start_time alignment.
# The title names the alignment so it is distinct from the speed-onset figure.
fig = plot_pcs(*ol_single_trial, 'OL Single Trial (Start Time)', return_fig=True)
fig.show()

In [None]:
# fig = plot_pcs(*cl_single_trial, 'CL Single Trial', return_fig=True)
# fig.show()

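## Alignment to speed onset (`speed_onset`)

Refit the alignment on OL trials aligned to `speed_onset` (window `[-700, 1250]`), re-trialize, and plot the projected OL trials with `plot_pcs`.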

In [None]:
# OL window used to fit the speed-onset alignment: trials aligned to 'speed_onset' over [-700, 1250].
config.defrost()
config.data.ol_align_field, config.data.ol_align_range = 'speed_onset', [-700, 1250]
config.freeze()

In [None]:
# Refit the alignment for the speed-onset analysis. It is stored under so_* so the start-time
# alignment_matrices / alignment_biases are kept.
# NOTE(review): presumably uses the speed-onset OL window set in the previous cell — confirm.
so_alignment_matrices, so_alignment_biases = align_sessions(config)

In [None]:
# Trialization windows for the speed-onset analysis: OL aligned to 'speed_onset' over
# [-700, 1250] (repeated so this cell is self-contained), CL aligned to 'start_time' over [500, 2000].
config.defrost()
config.data.ol_align_field, config.data.ol_align_range = 'speed_onset', [-700, 1250]
config.data.cl_align_field, config.data.cl_align_range = 'start_time', [500, 2000]
config.freeze()

In [None]:
# Re-trialize with the speed-onset windows set above; this replaces the start-time trialized_data.
trialized_data = get_trialized_data(config, datasets)

In [None]:
# Speed-onset section: project held-in smoothed spikes with the speed-onset alignment
# (spikes @ so_alignment_matrices[i].T + so_alignment_biases[i]) and collect
# (trajectories, condition_ids) pairs for plot_pcs. Condition 0 is skipped:
#   ol_cond_avg     -- per OL condition, the mean over kept trials
#   ol_single_trial -- per OL condition, kept trials concatenated along axis 0
#   cl_single_trial -- per CL condition, all trials concatenated along axis 0
ol_cond_avg = ([], [])
ol_single_trial = ([], [])
cl_single_trial = ([], [])

# Expected bin counts for the configured OL / CL windows.
trial_len = (config.data.ol_align_range[1] - config.data.ol_align_range[0]) / config.data.bin_size
# TODO(review): cl_trial_len is not used below, so CL trials are not length-filtered here.
cl_trial_len = (config.data.cl_align_range[1] - config.data.cl_align_range[0]) / config.data.bin_size

# Leading bins of each full-length OL trial that are dropped before projection.
# NOTE(review): the so_* alignment was presumably fit on the full OL window. Confirm that
# projecting only the part after the first OL_SKIP_BINS bins is intended.
OL_SKIP_BINS = 75

for idx, session in enumerate(config.data.sessions):
    # Per-session invariants, hoisted out of the trial loops.
    heldin_channels = datasets[session].heldin_channels
    weights_t = so_alignment_matrices[idx].T
    bias = np.array(so_alignment_biases[idx])

    for cond_id, trials in trialized_data[session]['ol_trial_data'].groupby('condition'):
        if cond_id == 0:
            continue
        low_d_trials = []
        for _, trial in trials.groupby('trial_id'):
            heldin_spikes = trial.spikes_smth.to_numpy()[:, heldin_channels]
            # Length is checked on the full trial, before the leading bins are dropped.
            if heldin_spikes.shape[0] == trial_len:
                low_d_trials.append(np.dot(heldin_spikes[OL_SKIP_BINS:], weights_t) + bias)
        # np.concatenate raises (and averaging yields NaN) on an empty list, so skip explicitly.
        if not low_d_trials:
            warnings.warn(f'session {session}: no OL trials with {trial_len:g} bins '
                          f'for condition {cond_id}; skipping')
            continue
        ol_single_trial[0].append(np.concatenate(low_d_trials, 0))
        ol_single_trial[1].append(cond_id)
        ol_cond_avg[0].append(np.stack(low_d_trials).mean(0))
        ol_cond_avg[1].append(cond_id)

    for cond_id, trials in trialized_data[session]['cl_trial_data'].groupby('condition'):
        if cond_id == 0:
            continue
        # Every CL trial is projected in full (no length filter; see cl_trial_len TODO above).
        low_d_trials = [
            np.dot(trial.spikes_smth.to_numpy()[:, heldin_channels], weights_t) + bias
            for _, trial in trials.groupby('trial_id')
        ]
        if not low_d_trials:
            warnings.warn(f'session {session}: no CL trials for condition {cond_id}; skipping')
            continue
        cl_single_trial[0].append(np.concatenate(low_d_trials, 0))
        cl_single_trial[1].append(cond_id)

In [None]:
# Condition-averaged OL trajectories (one per non-zero condition per session) under the
# speed_onset alignment. The title names the alignment so it is distinct from the start-time figure.
fig = plot_pcs(*ol_cond_avg, 'OL Condition Averaged (Speed Onset)', return_fig=True)
fig.show()

In [None]:
# All kept OL trials (concatenated per condition per session) under the speed_onset alignment.
# The title names the alignment so it is distinct from the start-time figure.
fig = plot_pcs(*ol_single_trial, 'OL Single Trial (Speed Onset)', return_fig=True)
fig.show()

In [None]:
# fig = plot_pcs(*cl_single_trial, 'CL Single Trial', return_fig=True)
# fig.show()