In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import os
import sys

import holoviews as hv
hv.extension('bokeh')

import pandas as pd
import matplotlib.pyplot as plt


sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
import paths
import processing_parameters
import functions_bondjango as bd
import functions_misc as fmisc
import functions_matching as fm
import functions_data_handling as fdh
import functions_tuning as tuning
import functions_plotting as fp
import functions_kinematic as fk
from wirefree_experiment import WirefreeExperiment, DataContainer
from functions_wirefree_trigger_fix import get_trial_duration_stats

# --- Notebook configuration -------------------------------------------------
# Machine-specific absolute paths; edit them when running on another computer.
#   data_path: folder holding the 'stats.hdf5' file read by the loading cell below.
#   fig_path:  root folder under which every figure of this notebook is saved.
# Alternative figure root kept for reference:
# fig_path = paths.wf_figures_path
data_path = r"D:\thesis\WF_Figures\full\repeat_normal_VTuningWF"
fig_path = r"D:\thesis\figures"

# Labels of the two matched sessions. Index 0 suffixes the reference-table columns and
# index 1 the comparison-table columns after the join; both also name the output folders.
session_shorthand = ['session1', 'session2']    # ['session1', 'session2'], ['free', 'fixed']
# Key of the database entries used to order query results.
sort = 'slug'   # 'rig', 'slug'

In [None]:
def drop_partial_or_long_trials(df, min_trial_length=4.5, max_trial_length=5.5, frame_rate=None):
    """
    Remove every trial whose duration falls outside [min_trial_length, max_trial_length].

    Trials shorter than min_trial_length are treated as partial trials and trials longer
    than max_trial_length as errors in trial number indexing. Rows with trial_num <= 0
    (or NaN) are never dropped.

    Parameters:
    df (DataFrame): The dataframe containing the trials; must have a 'trial_num' column
        with one row per frame.
    min_trial_length (float): The minimum length for a trial, in seconds. Defaults to 4.5.
    max_trial_length (float): The maximum length for a trial, in seconds. Defaults to 5.5.
    frame_rate (float, optional): Frames per second used to convert frame counts into
        seconds. Defaults to processing_parameters.wf_frame_rate.

    Returns:
    DataFrame: A new dataframe (fresh 0..n-1 index) without the partial and long trials.
        The input dataframe is not modified.
    """
    if frame_rate is None:
        frame_rate = processing_parameters.wf_frame_rate

    trial_num = df['trial_num']
    in_trial = (trial_num > 0).to_numpy()

    # Duration of each trial in seconds = number of frames / frame rate
    trial_lengths = trial_num[in_trial].value_counts(sort=False) / frame_rate

    out_of_range = (trial_lengths < min_trial_length) | (trial_lengths > max_trial_length)
    bad_trials = trial_lengths.index[out_of_range.to_numpy()]

    # Select rows by position rather than by index label so that a frame with a
    # non-unique index does not lose rows belonging to valid trials.
    keep = ~trial_num.isin(bad_trials).to_numpy()
    return df.loc[keep].reset_index(drop=True)


def filter_viewed_trials(kinematics, activity_df, pitch_cutoff=None, view_fraction=None):
    """
    Keep only the activity rows of trials during which the head pitch was mostly in range.

    A frame counts as "viewed" when its head pitch lies inside the inclusive pitch cutoff
    interval. A trial is kept when the fraction of its viewed frames is strictly greater
    than view_fraction.

    Parameters:
    kinematics (DataFrame): Per-frame kinematics with 'head_pitch' and 'trial_num' columns.
        It is not modified.
    activity_df (DataFrame): Per-frame activity with a 'trial_num' column.
    pitch_cutoff (sequence of 2 floats, optional): (lower, upper) head pitch bounds.
        Defaults to processing_parameters.head_pitch_cutoff.
    view_fraction (float, optional): Minimum fraction of viewed frames per trial.
        Defaults to processing_parameters.view_fraction.

    Returns:
    DataFrame: Copy of the rows of activity_df whose trial_num belongs to a viewed trial.
    """
    if pitch_cutoff is None:
        pitch_cutoff = processing_parameters.head_pitch_cutoff
    if view_fraction is None:
        view_fraction = processing_parameters.view_fraction
    pitch_lower_cutoff, pitch_upper_cutoff = pitch_cutoff[0], pitch_cutoff[1]

    head_pitch = kinematics['head_pitch'].to_numpy()
    # Kept as a local Series: writing a 'viewed' column would silently alter the caller's frame
    viewed = pd.Series(np.logical_and(head_pitch >= pitch_lower_cutoff,
                                      head_pitch <= pitch_upper_cutoff))

    # The mean of a boolean series is the fraction of viewed frames in each trial
    fraction_viewed = viewed.groupby(kinematics['trial_num'].to_numpy()).mean()
    viewed_trials = fraction_viewed.index[(fraction_viewed > view_fraction).to_numpy()]

    viewed_activity_df = activity_df.loc[activity_df.trial_num.isin(viewed_trials).to_numpy()].copy()
    return viewed_activity_df


def parse_trial_frames(df, pre_trial=0, post_trial=0, frame_rate=None):
    """
    Cut a continuous recording into per-trial windows, optionally padded before and after.

    For every trial (trial_num >= 1) a window is built that starts pre_trial seconds before
    the first frame of the trial and ends post_trial seconds after its last frame. Windows
    are clipped to the extent of df. Inside each window the 'trial_num', 'direction',
    'direction_wrapped' and 'orientation' columns are overwritten with the values found at
    the first frame of the trial, so padding frames carry the label of the trial they pad.

    All frame numbers are row positions in df (identical to index labels when df has the
    default 0..n-1 index).

    Parameters:
    df (DataFrame): One row per frame, with 'trial_num', 'direction', 'direction_wrapped'
        and 'orientation' columns.
    pre_trial (float): Seconds to include before each trial. Defaults to 0.
    post_trial (float): Seconds to include after each trial. Defaults to 0.
    frame_rate (float, optional): Frames per second used to convert the padding into a
        number of frames. Defaults to processing_parameters.wf_frame_rate.

    Returns:
    tuple:
        traces (DataFrame): All windows concatenated (fresh index), with an extra
            'zero_idx_shift' column holding, for each window, the difference between the
            largest pre-trial padding of any window and this window's own padding.
        trial_idx_frames (ndarray of int, shape (n_trials, 4)): For each trial
            [window start, trial start, trial end (inclusive), window end (exclusive)].
    """
    if frame_rate is None:
        frame_rate = processing_parameters.wf_frame_rate

    # Slice bounds must be integers even when the frame rate is stored as a float
    pre_frames = int(round(pre_trial * frame_rate))
    post_frames = int(round(post_trial * frame_rate))

    trial_nums = df['trial_num'].to_numpy()
    in_trial = trial_nums >= 1.0

    if not in_trial.any():
        # Nothing to cut: return empty results with the expected layout
        empty_traces = df.iloc[0:0].copy().reset_index(drop=True)
        empty_traces['zero_idx_shift'] = np.array([], dtype=int)
        return empty_traces, np.empty((0, 4), dtype=int)

    # First and last row position of every trial, ordered by trial number
    positions = pd.Series(np.flatnonzero(in_trial))
    bounds = positions.groupby(trial_nums[in_trial]).agg(['first', 'last'])
    starts = bounds['first'].to_numpy()
    ends = bounds['last'].to_numpy()

    # Clip every window to the recording. The window end is exclusive, so its largest
    # valid value is len(df), which keeps the very last frame reachable.
    window_starts = np.maximum(starts - pre_frames, 0)
    window_ends = np.minimum(ends + post_frames + 1, len(df))
    trial_idx_frames = np.column_stack([window_starts, starts, ends, window_ends]).astype(int)

    # Get the shifts from the zero point (important for plotting)
    max_zero_idx_shift = np.max(trial_idx_frames[:, 1] - trial_idx_frames[:, 0])

    label_columns = ['trial_num', 'direction', 'direction_wrapped', 'orientation']
    traces = []
    for window_start, trial_start, _, window_end in trial_idx_frames:
        df_slice = df.iloc[window_start:window_end, :].copy()
        # Give padding frames the identity of the trial they surround
        for column in label_columns:
            df_slice[column] = df[column].iat[trial_start]
        df_slice['zero_idx_shift'] = np.abs((trial_start - window_start) - max_zero_idx_shift)
        traces.append(df_slice)

    traces = pd.concat(traces, axis=0).reset_index(drop=True)
    return traces, trial_idx_frames


def trial_average_response(ds, cells_to_match, stim_type):
    """
    Stack the per-trial traces of each cell and average them across trials.

    Rows containing NaN are ignored (the caller's frame is left untouched). Traces are
    stacked with fmisc.list_lists_to_array using left alignment.

    Parameters:
    ds (DataFrame): One row per frame with 'trial_num' and 'zero_idx_shift' columns, one
        column per entry of cells_to_match and, unless stim_type is 'vis', a column named
        after stim_type.
    cells_to_match (list of str): Names of the cell columns to process.
    stim_type (str): 'orientation', 'direction' or 'direction_wrapped' to group trials by
        that stimulus column, or 'vis' to pool every trial together.

    Returns:
    tuple (trial_averages, trial_array):
        For a stimulus column: two DataFrames indexed by stimulus value. In trial_array
        each cell column holds the stacked trials array of that stimulus; in
        trial_averages it holds the np.nanmean of that array over its first axis. Both
        keep the aggregated 'trial_num' column.
        For 'vis': two Series indexed by cell name holding, respectively, the across-trial
        mean and the stacked trials array.

    Raises:
    ValueError: If stim_type is not one of the supported values.
    """
    per_stimulus_types = ('orientation', 'direction', 'direction_wrapped')
    if stim_type not in per_stimulus_types and stim_type != 'vis':
        raise ValueError('Invalid stim_type')

    # Non-mutating replacement for dropna(inplace=True)
    ds = ds.dropna()

    # Smallest zero_idx_shift recorded for each trial, indexed by trial number
    idxs_shifts = ds.groupby('trial_num')['zero_idx_shift'].min().to_frame('zero_idx_shift')

    if stim_type in per_stimulus_types:
        trials_per_stim = ds.groupby([stim_type, 'trial_num'])[cells_to_match].agg(list).reset_index()
        trials_per_stim = trials_per_stim.join(idxs_shifts, on='trial_num')
        trials_per_stim = trials_per_stim.groupby([stim_type]).agg(list)

        trial_array = trials_per_stim.copy()
        for stim_value, row in trials_per_stim.iterrows():
            for cell in cells_to_match:
                # .at stores the whole array as one object in a single table cell
                trial_array.at[stim_value, cell] = fmisc.list_lists_to_array(row[cell], alignment='left')

        # DataFrame.applymap was renamed DataFrame.map in pandas 2.1
        if hasattr(pd.DataFrame, 'map'):
            elementwise = trial_array.map
        else:
            elementwise = trial_array.applymap
        trial_averages = elementwise(lambda stacked: np.nanmean(stacked, axis=0))

        trial_averages = trial_averages.drop('zero_idx_shift', axis=1)
        trial_array = trial_array.drop('zero_idx_shift', axis=1)

    else:
        trials_per_stim = ds.groupby('trial_num')[cells_to_match].agg(list)
        trials_per_stim = trials_per_stim.join(idxs_shifts, on='trial_num')

        # The first row only serves as a template Series indexed by column name
        trial_averages = trials_per_stim.iloc[0, :].copy()
        trial_array = trials_per_stim.iloc[0, :].copy()

        for cell in cells_to_match:
            trials_agg = fmisc.list_lists_to_array(trials_per_stim[cell].to_list(), alignment='left')
            trial_array[cell] = trials_agg
            trial_averages[cell] = np.nanmean(trials_agg, axis=0)

        trial_averages = trial_averages.drop('zero_idx_shift')
        trial_array = trial_array.drop('zero_idx_shift')

    return trial_averages, trial_array


def hv_plot_vis_trial_averages(trial_averages, trials, cells, stim_type):
    """
    Build HoloViews overlays showing single trials (black) and their mean (red).

    Parameters:
    trial_averages: Mapping from cell name to the mean response(s) of that cell, as
        returned by trial_average_response.
    trials: Mapping from cell name to the stacked single-trial responses of that cell, as
        returned by trial_average_response.
    cells (list of str): Cells to plot, in output order.
    stim_type (str): 'orientation', 'direction' or 'direction_wrapped' to get one overlay
        per stimulus value and cell, or 'vis' to get one overlay per cell.

    Returns:
    list of hv.Overlay: Ordered by cell and, within a cell, by stimulus value. For the
        first cell only, each per-stimulus overlay is titled with its stimulus value.

    Raises:
    ValueError: If stim_type is not one of the supported values.
    """
    per_stimulus_types = ('orientation', 'direction', 'direction_wrapped')
    if stim_type not in per_stimulus_types and stim_type != 'vis':
        raise ValueError('Invalid stim_type')

    def _overlay(resps, mean, title=None):
        # One faint curve per trial (row of resps) topped by the mean curve
        trial_curves = [hv.Curve(resps[r, :]).opts(color='k', alpha=0.25, line_width=0.75)
                        for r in range(resps.shape[0])]
        mean_curve = hv.Curve(mean).opts(color='r', xlabel='', ylabel='')
        if title is not None:
            # Use the returned object so this works whether or not opts() clones
            mean_curve = mean_curve.opts(title=title)
        return hv.Overlay(trial_curves + [mean_curve])

    plt_list = []
    for i, cell in enumerate(cells):
        cell_resps = trials[cell]
        cell_mean = trial_averages[cell]

        if stim_type == 'vis':
            plt_list.append(_overlay(cell_resps, cell_mean))
            continue

        for k in range(cell_resps.shape[-1]):
            # Only the first cell carries the stimulus value as a title
            title = f"{cell_resps.index[k]:.1f}" if i == 0 else None
            plt_list.append(_overlay(cell_resps.iloc[k], cell_mean.iloc[k], title))

    return plt_list

In [None]:
# Read the reference and comparison tables of matched cells from the stats file
with pd.HDFStore(os.path.join(data_path, 'stats.hdf5'), 'r') as f:
    ref = f['ref_cells_all_matches_both_vis_resp'][:]
    comp = f['comp_cells_all_matches_both_vis_resp'][:]

# Same subset of columns on both sides
stat_columns = ['mouse', 'day', 'cell', 'is_vis_resp', 'vis_resp_pval', 'fit_osi', 'pref_ori', 'fit_dsi', 'pref_dir']
ref_sub = ref[stat_columns]
comp_sub = comp[stat_columns]

# Join side by side on the index, keep a single mouse/day pair (taken from the
# comparison table) and order the rows by mouse and day
first_tag, second_tag = session_shorthand[0], session_shorthand[1]
df = (
    ref_sub
    .join(comp_sub, lsuffix=f'_{first_tag}', rsuffix=f'_{second_tag}')
    .drop(columns=[f'mouse_{first_tag}', f'day_{first_tag}'])
    .rename(columns={f'mouse_{second_tag}': 'mouse', f'day_{second_tag}': 'day'})
    .sort_values(by=['mouse', 'day'])
    .reset_index(drop=True)
)

# One row per (mouse, day) with the lists of matched cell names of both sessions
cell_columns = [f"cell_{first_tag}", f"cell_{second_tag}"]
slugs = df.groupby(['mouse', 'day'])[cell_columns].agg(list).reset_index()

In [None]:
# Show the joined reference/comparison table built in the previous cell
df

In [None]:
# Load and process data
preproc_df_list_fixed = []
tc_df_list_fixed = []

preproc_df_list_free = []
tc_df_list_free = []

preproc_cols = ['trial_num', 'time_vector', 'direction', 'direction_wrapped', 'orientation', 'grating_phase']


def _query_analyzed_data(search_string, analysis_type, mouse):
    """Return the 'analyzed_data' entries of one analysis type whose slug contains the mouse name, ordered by `sort`."""
    entries = bd.query_database('analyzed_data', search_string + f', analysis_type:{analysis_type}')
    entries = [entry for entry in entries if mouse.lower() in entry['slug']]
    entries.sort(key=lambda entry: entry[sort])
    return entries


def _load_session(exp_info, preproc_info, tc_info, cell_ids, mouse, day, filter_by_view):
    """
    Load one session and extract the traces and tuning properties of the matched cells.

    Parameters:
    exp_info, preproc_info, tc_info: Database entries handed to WirefreeExperiment.
    cell_ids (list of str): Names of the matched cells of this session.
    mouse, day: Identifiers written into the 'mouse' and 'day' columns of both outputs.
    filter_by_view (bool): When True, keep only the trials that pass filter_viewed_trials.

    Returns:
    tuple (exp, preproc_df, tc_df): The experiment object (with the normalized traces
        attached as attributes), the per-frame normalized traces of the matched cells and
        the rows of the tuning-property table for those cells.
    """
    exp = WirefreeExperiment(exp_info=exp_info, preproc_info=preproc_info, tc_info=tc_info)
    exp._load_preprocessing()
    exp._load_tc()

    exp.norm_deconv_fluor = tuning.normalize_responses(exp.deconv_fluor.copy(), columnwise=True)
    exp.norm_deconv_fluor = drop_partial_or_long_trials(exp.norm_deconv_fluor.copy())
    if filter_by_view:
        exp.norm_deconv_fluor_viewed = filter_viewed_trials(exp.kinematics, exp.norm_deconv_fluor.copy())
    else:
        exp.norm_deconv_fluor_viewed = exp.norm_deconv_fluor.copy()
    exp.norm_inferred_spikes = tuning.normalize_responses(exp.inferred_spikes.copy(), columnwise=True)
    exp.norm_inferred_spikes = drop_partial_or_long_trials(exp.norm_inferred_spikes)

    preproc_df = exp.norm_deconv_fluor_viewed.loc[:, preproc_cols + cell_ids].copy()
    preproc_df['mouse'] = mouse
    preproc_df['day'] = day

    # Copy before adding columns so the experiment's own table is never written to
    tc_df = exp.visual_tcs.deconvolved_fluor_viewed_props.loc[cell_ids, :].copy()
    tc_df['mouse'] = mouse
    tc_df['day'] = day

    return exp, preproc_df, tc_df


for _, row in slugs.iterrows():
    search_string = f'mouse:{row.mouse}, slug:{row.day}'
    parsed_search = fdh.parse_search_string(search_string)

    # get the paths from the database

    # get the raw experiment
    exp_query = bd.query_database('vr_experiment', search_string)
    exp_query.sort(key=lambda x: x[sort])

    # Get the preprocessing and tuning curve files
    preproc_query = _query_analyzed_data(search_string, 'preprocessing', parsed_search['mouse'])
    tc_query = _query_analyzed_data(search_string, 'tc_analysis', parsed_search['mouse'])

    # Entry 0 and entry 1 of every query are used below, so both must exist
    n_found = (len(exp_query), len(preproc_query), len(tc_query))
    if min(n_found) < 2:
        raise ValueError(f'Expected two sessions for "{search_string}" but found '
                         f'{n_found[0]} experiment, {n_found[1]} preprocessing and '
                         f'{n_found[2]} tuning curve entries')

    # Load and save freely moving data (first entry, trials filtered by head pitch)
    exp_free, free_preproc_df, free_tc_df = _load_session(
        exp_query[0], preproc_query[0], tc_query[0],
        row[f'cell_{session_shorthand[0]}'], row.mouse, row.day, filter_by_view=True)
    preproc_df_list_free.append(free_preproc_df)
    tc_df_list_free.append(free_tc_df)

    # load head fixed data (second entry, every trial kept)
    exp_fixed, fixed_preproc_df, fixed_tc_df = _load_session(
        exp_query[1], preproc_query[1], tc_query[1],
        row[f'cell_{session_shorthand[1]}'], row.mouse, row.day, filter_by_view=False)
    preproc_df_list_fixed.append(fixed_preproc_df)
    tc_df_list_fixed.append(fixed_tc_df)

tc_df_fixed = pd.concat(tc_df_list_fixed)
tc_df_free = pd.concat(tc_df_list_free)


# Polar Plots

In [None]:
# Head Fixed
# One polar tuning plot per matched cell, saved as tc_<row number>_<cell name>.png

fixed_polar_dir = os.path.join(fig_path, '19_cherry_picked_freely_moving_repeat_FOV_TCs', f'{session_shorthand[1]}_polar_plots')
os.makedirs(fixed_polar_dir, exist_ok=True)   # savefig does not create missing folders

for i, cell in enumerate(tc_df_fixed.index.to_list()):
    fig = plt.figure(layout='constrained', figsize=(3/fp.constant_in2cm, 3/fp.constant_in2cm))

    # Select by position: cell names may repeat across mice/days in the concatenated table
    tc_df_row = tc_df_fixed.iloc[i].to_frame().T
    this_fig_axes = fp.plot_tuning_with_stats(tc_df_row, cell, subfig=fig, tuning_kind='direction', 
                                              plot_selectivity=False, font_size='paper', plot_trials=False)
    save_path = os.path.join(fixed_polar_dir, f'tc_{i}_{cell}.png')
    fig.savefig(save_path, dpi=800, format='png')
    # Release the figure once it is on disk; remove this line to also display it inline
    plt.close(fig)


In [None]:
# Freely Moving
# One polar tuning plot per matched cell, saved as tc_<row number>_<cell name>.png

free_polar_dir = os.path.join(fig_path, '19_cherry_picked_freely_moving_repeat_FOV_TCs', f'{session_shorthand[0]}_polar_plots')
os.makedirs(free_polar_dir, exist_ok=True)   # savefig does not create missing folders

for i, cell in enumerate(tc_df_free.index.to_list()):
    fig = plt.figure(layout='constrained', figsize=(3/fp.constant_in2cm, 3/fp.constant_in2cm))

    # Select by position: cell names may repeat across mice/days in the concatenated table
    tc_df_row = tc_df_free.iloc[i].to_frame().T
    this_fig_axes = fp.plot_tuning_with_stats(tc_df_row, cell, subfig=fig, tuning_kind='direction', 
                                              plot_selectivity=False, font_size='paper', plot_trials=False)
    save_path = os.path.join(free_polar_dir, f'tc_{i}_{cell}.png')
    fig.savefig(save_path, dpi=800, format='png')
    # Release the figure once it is on disk; remove this line to also display it inline
    plt.close(fig)


# Trial Average Direction Responses

In [None]:
# Head Fixed
pre_trial_period = 2
post_trial_period = 2

cell_count = 0
fixed_plot_list = []
for trials_df in preproc_df_list_fixed:
        cells = [col for col in trials_df.columns if 'cell' in col] 
        trials_df.reset_index(drop=True, inplace=True)
        dff_trials, tc_frames = parse_trial_frames(trials_df, pre_trial=pre_trial_period, post_trial=post_trial_period)
        dir_averages, dir_trials = trial_average_response(dff_trials.copy(), cells, 'direction_wrapped')

        time_vector = np.linspace(-pre_trial_period, 5 + post_trial_period, (5 + pre_trial_period + post_trial_period) * processing_parameters.wf_frame_rate) 
        trial_vector = np.zeros_like(time_vector)
        trial_vector[(time_vector >= 0) & (time_vector <= 5)] = 1
        trial_plot = hv.Area(trial_vector).opts(color='gray', alpha=0.15)

        norm_dff_dirs = hv_plot_vis_trial_averages(dir_averages, dir_trials, cells, 'direction')

        save_paths = []
        for cell in cells:
                for dir in np.arange(dir_averages.shape[0]):
                        save_paths.append(os.path.join(fig_path, '19_cherry_picked_freely_moving_repeat_FOV_TCs', f'{session_shorthand[0]}_tcs_by_dir', f'tc_{cell_count}_{cell}_dir_{dir}.png'))
                cell_count += 1

        for i, norm_dff_i in enumerate(norm_dff_dirs):
                norm_dff_i = norm_dff_i * trial_plot
                norm_dff_i = norm_dff_i.opts(hv.opts.Curve(width=100, height=100, xlabel=None, ylabel=None, xaxis=None, yaxis=None, title=''))
                norm_dff_i = fp.save_figure(norm_dff_i, save_path=save_paths[i], fig_width=1, dpi=800, fontsize='paper', target='save', display_factor=0.2)
                fixed_plot_list.append(norm_dff_i)

# hv.Layout(new_norm_dff_dirs).cols(12)

In [None]:
# Freely Moving
pre_trial_period = 2
post_trial_period = 2

cell_count = 0
free_plot_list = []
for trials_df in preproc_df_list_free:
        cells = [col for col in trials_df.columns if 'cell' in col] 
        trials_df.reset_index(drop=True, inplace=True)
        trials, tc_frames = parse_trial_frames(trials_df, pre_trial=pre_trial_period, post_trial=post_trial_period)
        dir_averages, dir_trials = trial_average_response(trials.copy(), cells, 'direction_wrapped')

        time_vector = np.linspace(-pre_trial_period, 5 + post_trial_period, (5 + pre_trial_period + post_trial_period) * processing_parameters.wf_frame_rate) 
        trial_vector = np.zeros_like(time_vector)
        trial_vector[(time_vector >= 0) & (time_vector <= 5)] = 1
        trial_plot = hv.Area(trial_vector).opts(color='gray', alpha=0.15)

        norm_dff_dirs = hv_plot_vis_trial_averages(dir_averages, dir_trials, cells, 'direction')

        save_paths = []
        for cell in cells:
                for dir in np.arange(dir_averages.shape[0]):
                        save_paths.append(os.path.join(fig_path, '19_cherry_picked_freely_moving_repeat_FOV_TCs', f'{session_shorthand[1]}_tcs_by_dir', f'tc_{cell_count}_{cell}_dir_{dir}.png'))
                cell_count += 1

        for i, norm_dff_i in enumerate(norm_dff_dirs):
                norm_dff_i = norm_dff_i * trial_plot
                norm_dff_i = norm_dff_i.opts(hv.opts.Curve(width=100, height=100, xlabel=None, ylabel=None, xaxis=None, yaxis=None, title=''))
                norm_dff_i = fp.save_figure(norm_dff_i, save_path=save_paths[i], fig_width=1, dpi=800, fontsize='paper', target='save', display_factor=0.2)
                free_plot_list.append(norm_dff_i)

# hv.Layout(new_norm_dff_dirs).cols(12)

In [None]:
# Hand-picked index lists, one per dataset; only the uncommented line is active.
# NOTE(review): idx_list is not used by any other cell visible in this notebook —
# confirm what the indices refer to before relying on them.
# idx_list = [3, 4, 55, 59, 48, 31, 2, 6]   # Full exp
# idx_list = [30, 33, 5, 10, 7]  # repeat HF
idx_list = [7, 4, 5, 1, 3]  # repeat FM

In [None]:
# Preferred orientation and direction of every matched cell in both sessions
preferred_angle_columns = [f'pref_{kind}_{tag}' for kind in ('ori', 'dir') for tag in session_shorthand[:2]]
df[preferred_angle_columns]

In [None]:
# Identity and orientation/direction selectivity of every matched cell in both sessions
selectivity_columns = ['mouse', 'day'] + [f'{field}_{tag}'
                                          for field in ('cell', 'fit_osi', 'fit_dsi')
                                          for tag in session_shorthand[:2]]
df[selectivity_columns]

In [None]:
# Which cell of the first session is matched to which cell of the second one
identity_columns = ['mouse', 'day'] + [f'cell_{tag}' for tag in session_shorthand[:2]]
df[identity_columns]

In [None]:
# Matched cells whose fitted OSI exceeds 0.7 in both sessions
high_osi_first = df[f'fit_osi_{session_shorthand[0]}'] > 0.7
high_osi_second = df[f'fit_osi_{session_shorthand[1]}'] > 0.7
df[high_osi_first & high_osi_second]