In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import importlib
sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import holoviews as hv
import h5py
import cv2
from skimage import color
from hmmlearn import hmm
from scipy.stats import mannwhitneyu
hv.extension('bokeh')

import paths
import processing_parameters
import functions_bondjango as bd
import functions_misc as fmisc
import functions_matching as fm
import functions_data_handling as fdh
import functions_tuning as tuning
import functions_plotting as fp
import functions_kinematic as fk
from wirefree_experiment import WirefreeExperiment, DataContainer
from functions_wirefree_trigger_fix import get_trial_duration_stats

# fig_path = paths.wf_figures_path
fig_path = r"H:\thesis\figures"

# Load Data

In [None]:
def drop_partial_or_long_trials(df, min_trial_length=4.5, max_trial_length=5.5, frame_rate=None):
    """
    This function drops trials that are shorter than min_trial_length (partial trials) and
    trials that are longer than max_trial_length (errors in trial number indexing) from the dataframe.

    A trial's length is its number of rows (frames) divided by the frame rate. Rows with trial_num <= 0
    are not trials and are always kept.

    Parameters:
    df (DataFrame): The dataframe containing the trials. Must have a 'trial_num' column.
    min_trial_length (float): The minimum length for a trial. Defaults to 4.5.
    max_trial_length (float): The maximum length for a trial. Defaults to 5.5.
    frame_rate (float): Frames per second used to convert frame counts to lengths. Defaults to
        processing_parameters.wf_frame_rate when None.

    Returns:
    DataFrame: The dataframe after dropping the partial and long trials, with a fresh 0..n-1 index.
    """
    if frame_rate is None:
        frame_rate = processing_parameters.wf_frame_rate

    # Length of every trial = number of frames / frame rate (size() also works on an empty selection)
    trial_lengths = df.loc[df.trial_num > 0].groupby('trial_num').size() / frame_rate

    # Partial trials (too short) and mis-indexed trials (too long); lengths exactly at the bounds are kept
    out_of_range = (trial_lengths < min_trial_length) | (trial_lengths > max_trial_length)
    bad_trials = trial_lengths[out_of_range].index

    # Select by boolean mask rather than dropping index labels, so duplicate labels cannot remove unrelated rows
    return df.loc[~df.trial_num.isin(bad_trials)].reset_index(drop=True)


def filter_viewed_trials(kinematics, activity_df):
    """
    Keep only the activity rows of trials in which the head pitch stayed inside the configured range
    for more than the configured fraction of frames (head-pitch filter intended for freely moving data).

    Parameters:
    kinematics (DataFrame): must contain 'head_pitch' and 'trial_num'. NOTE: modified in place — a boolean
        'viewed' column (pitch within processing_parameters.head_pitch_cutoff, inclusive) is added.
    activity_df (DataFrame): must contain 'trial_num'.

    Returns:
    DataFrame: a copy of the rows of activity_df whose trial_num belongs to a viewed trial.
    """
    low_pitch = processing_parameters.head_pitch_cutoff[0]
    high_pitch = processing_parameters.head_pitch_cutoff[1]
    min_viewed_fraction = processing_parameters.view_fraction

    pitch = kinematics['head_pitch'].to_numpy()
    kinematics['viewed'] = (pitch >= low_pitch) & (pitch <= high_pitch)

    # Fraction of frames per trial with the pitch in range (mean of a boolean column)
    fraction_viewed = kinematics.groupby('trial_num')['viewed'].mean()
    viewed_trial_ids = fraction_viewed[fraction_viewed > min_viewed_fraction].index

    return activity_df.loc[activity_df.trial_num.isin(viewed_trial_ids)].copy()


In [136]:
importlib.reload(processing_parameters)

# Search strings for the datasets analysed with this notebook:
#   thesis:        'mouse:MM_221109_a, slug:01_11_2023,'
#   control light: 'mouse:MM_221109_a, slug:,'
#   control dark:  'mouse:MM_221109_a, slug:01_27_2023,'
search_string = 'mouse:MM_221109_a, slug:01_27_2023'
parsed_search = fdh.parse_search_string(search_string)
mouse_slug = parsed_search['mouse'].lower()


def _query_analysis(analysis_type, slug_contains=''):
    """
    Query `analyzed_data` for the current search string and one analysis type.

    Only entries whose slug contains the (lower-case) mouse name are kept, sorted by rig.

    Parameters:
    analysis_type (str): appended to the search string as `analysis_type:<value>`.
    slug_contains (str): extra substring an entry's slug must contain for its path to be returned
        ('' returns all paths).

    Returns:
    tuple: (entries sorted by rig, np.ndarray of the matching entries' 'analysis_path' values in that order)
    """
    entries = bd.query_database('analyzed_data', search_string + f', analysis_type:{analysis_type}')
    entries = [entry for entry in entries if mouse_slug in entry['slug']]
    entries.sort(key=lambda entry: entry['rig'])
    paths = np.array([entry['analysis_path'] for entry in entries
                      if entry['analysis_type'] == analysis_type and slug_contains in entry['slug']])
    return entries, paths


# Raw experiment entries, sorted by rig
exp_query = bd.query_database('vr_experiment', search_string)
exp_query.sort(key=lambda x: x['rig'])

# Preprocessing, tuning-curve and cross-session cell-matching files
preproc_query, preproc_paths = _query_analysis('preprocessing')
preproc_paths = np.sort(preproc_paths)
tc_query, tc_paths = _query_analysis('tc_analysis')
tc_paths = np.sort(tc_paths)
cell_match_query, cell_matching_path = _query_analysis('cellmatching', slug_contains='daycellmatch')

# Calcium file names are derived from the preprocessing file names
calcium_paths = np.array([p.replace('preproc', 'calciumraw') for p in preproc_paths])

# Parse the rig from the calcium file name (7th underscore-separated field) and order all paths by rig.
# The tc paths are reordered with the same permutation, so both lists must describe the same sessions.
assert len(tc_paths) == len(preproc_paths), 'Number of tc_analysis and preprocessing files differs'
rigs = np.array([os.path.basename(file).split('_')[6] for file in calcium_paths])
sort_array = np.argsort(rigs)

rigs = rigs[sort_array]
calcium_paths = calcium_paths[sort_array]
preproc_paths = preproc_paths[sort_array]
tc_paths = tc_paths[sort_array]

print(cell_matching_path)
print(calcium_paths)
print(preproc_paths)
print(tc_paths)
print(rigs)

# Dataset labels used to build the curated-matches file name; must agree with the dataset chosen in search_string
parsed_search['lighting'] = 'dark'
parsed_search['result'] = 'control'

In [137]:
# Load the matching assignments and find the column that corresponds to each file
assignments =  fm.match_cells(cell_matching_path[0])
new_cols = [col.split('_')[-2] for col in assignments.columns]
assignments.columns = new_cols
col_sort_idx = np.argsort(assignments.columns)
assignments = assignments[assignments.columns[col_sort_idx]]

# Use number of non-NaNs in each row to filter out components that were not registered in enough sessions
assignments_filtered = assignments.dropna().astype(int).to_numpy()
unassigned = np.array(assignments[np.sum(~np.isnan(assignments), axis=1) < 2])
unassigned = [unassigned[~np.isnan(unassigned[:, 0]), 0].astype(int), unassigned[~np.isnan(unassigned[:, 1]), 1].astype(int)]
unassigned = [np.sort(np.unique(unassigned[0])), np.sort(np.unique(unassigned[1]))]

# Specify the path to the curated cell matches file
curated_cell_matches_path = os.path.join(r"C:\Users\mmccann\Desktop", 
                                f"curated_cell_matches_{parsed_search['result']}_{parsed_search['lighting']}_{parsed_search['rig']}.xlsx")

try:
    # Read all sheets into a list of dataframes
    curated_matches_dict = pd.read_excel(curated_cell_matches_path, sheet_name=None)

    # Concatenate the dataframes into a single dataframe
    curated_matches = pd.concat(curated_matches_dict.values(), ignore_index=True)

    # Get the hand-picked matches for the current experiment
    day_mouse_curated_idxs = curated_matches[(curated_matches['mouse'] == parsed_search['mouse']) & 
                                            (curated_matches['day'] == parsed_search['slug'])]['index'].values
    
    if len(day_mouse_curated_idxs) == 0:
        raise Exception("No curated matches found for the current experiment. Continuing with CaImAn matches...")
    else:
        day_mouse_curated_matches = assignments_filtered[day_mouse_curated_idxs, :]

except Exception as e:
    print(f"Could not find the file {curated_cell_matches_path}. Continuing with CaImAn matches...")
    day_mouse_curated_matches = assignments_filtered

In [138]:
# Load freely moving data
exp_free = WirefreeExperiment(exp_info=exp_query[0], preproc_info=preproc_query[0], tc_info=tc_query[0])
exp_free._load_preprocessing()
exp_free._load_tc()

exp_free.deconv_fluor = drop_partial_or_long_trials(exp_free.deconv_fluor.copy())
exp_free.norm_deconv_fluor = tuning.normalize_responses(exp_free.deconv_fluor.copy())
exp_free.deconv_fluor_viewed = filter_viewed_trials(exp_free.kinematics, exp_free.deconv_fluor.copy())
exp_free.norm_deconv_fluor_viewed = filter_viewed_trials(exp_free.kinematics, exp_free.norm_deconv_fluor.copy())

# load head fixed data
exp_fixed = WirefreeExperiment(exp_info=exp_query[1], preproc_info=preproc_query[1], tc_info=tc_query[1])
exp_fixed._load_preprocessing()
exp_fixed._load_tc()

exp_fixed.deconv_fluor = drop_partial_or_long_trials(exp_fixed.deconv_fluor.copy())
exp_fixed.norm_deconv_fluor = tuning.normalize_responses(exp_fixed.deconv_fluor.copy())
exp_fixed.deconv_fluor_viewed = filter_viewed_trials(exp_fixed.kinematics, exp_fixed.deconv_fluor.copy())
exp_fixed.norm_deconv_fluor_viewed = exp_fixed.norm_deconv_fluor.copy()

# Test the thresholds for Mann-Whitney U test-based responsivity classification

In [139]:
use_exp = 'fixed'  # which session to analyse: 'free' or 'fixed'
this_std = 8  # NOTE(review): not referenced in the cells shown here

# load the used dataset from processing params
used_activity_ds = processing_parameters.activity_datasets[0]
print(used_activity_ds)

experiments = {'free': exp_free, 'fixed': exp_fixed}
if use_exp not in experiments:
    raise ValueError(f"use_exp must be one of {list(experiments)}, got {use_exp!r}")
this_exp = experiments[use_exp]

this_tc = getattr(this_exp.visual_tcs, f'{used_activity_ds}_props')
# Re-indexed copy: the experiment's stored dataframe is left untouched (no in-place reset_index)
activity_df = this_exp.norm_deconv_fluor_viewed.reset_index(drop=True)
cells = [col for col in activity_df.columns if 'cell' in col]

In [140]:
# --- 1. Calculate Responsivity --- #

# Trapezoidal integration: NumPy >= 2.0 names it `trapezoid` (`trapz` is deprecated there)
integrate = getattr(np, 'trapezoid', None) or np.trapz
trial_keys = ['trial_num', 'direction_wrapped', 'orientation']

# -- 1.0 Get std of each cell across whole experiment
std_activity = activity_df.loc[:, cells].apply(np.std)

# -- 1.1 Max, mean and AUC of each cell's response during each trial (rows with trial_num > 0)
trial_groups = activity_df.loc[activity_df.trial_num > 0, :].groupby(trial_keys)[cells]
trial_max_activity = trial_groups.agg('max').reset_index()
trial_mean_activity = trial_groups.agg('mean').reset_index()
trial_auc_activity = trial_groups.agg(integrate).reset_index()


def _iti_summary(trial_frames, cell_cols, trial_labels, summarize):
    """
    Summarise each cell's ITI activity (rows with trial_num == 0) inside every parsed trial window.

    Parameters:
    trial_frames (DataFrame): output of tuning.parse_trial_frames, grouped here by 'frame_num'.
    cell_cols (list): cell columns to summarise.
    trial_labels (DataFrame): supplies 'direction_wrapped' and 'orientation' for each window.
    summarize (callable): maps the ITI rows (DataFrame of cell_cols) to one value per cell.

    Returns:
    DataFrame: one row per window ('frame_num' renamed to 'trial_num'), with the two label columns inserted.
    NOTE(review): labels are matched by row index (both are 0..n-1 after reset_index), so this assumes the
    windows and trial_labels have the same length and order — not checked here.
    """
    summary = (trial_frames.groupby('frame_num')
               .apply(lambda x: summarize(x.loc[x.trial_num == 0, cell_cols]))
               .reset_index(names='trial_num'))
    summary.insert(1, 'direction_wrapped', trial_labels['direction_wrapped'])
    summary.insert(2, 'orientation', trial_labels['orientation'])
    return summary


# -- 1.2 Std, mean, max and AUC of each cell's response during the ITI.
#    Parse the dataframe into trial frames, which are the trial + 1.5 sec of the preceding inter-trial interval.
#    This will be used for evaluating per-trial responsivity.
trial_frames_short_iti, _ = tuning.parse_trial_frames(activity_df, pre_trial=1.5)

iti_std_activity = _iti_summary(trial_frames_short_iti, cells, trial_max_activity, lambda d: d.std())
iti_mean_activity = _iti_summary(trial_frames_short_iti, cells, trial_max_activity, lambda d: d.mean())
iti_max_activity = _iti_summary(trial_frames_short_iti, cells, trial_max_activity, lambda d: d.max())
iti_auc_activity = _iti_summary(trial_frames_short_iti, cells, trial_max_activity, lambda d: d.apply(integrate))

In [152]:
# Display the responsivity table built by the "1.3 Responsivity Evaluations" cell. That cell comes later in the
# notebook, so on Restart & Run All the name does not exist yet: globals().get() then yields None (no output)
# instead of raising NameError.
globals().get('vis_drive_test')

In [151]:
# -- 1.3 Responsivity Evaluations

# 1.3.1 Determine if the cell is visually responsive.
#    This is done by comparing the AUC activity during each 5 sec ITI to the AUC activity during the trial. If a
#    cell passes a Mann-Whitney U test where the test checks if activity during the trials is greater than that
#    during the ITI (i.e. alternative='greater'), then the cell is considered visually driven.

# Trapezoidal integration: NumPy >= 2.0 names it `trapezoid` (`trapz` is deprecated there)
integrate = getattr(np, 'trapezoid', None) or np.trapz

trial_frames_long_iti, _ = tuning.parse_trial_frames(activity_df, pre_trial=5.0)
long_iti_auc_activity = (trial_frames_long_iti.groupby('frame_num')
                         .apply(lambda x: x.loc[x.trial_num == 0, cells].apply(integrate))
                         .reset_index(names='trial_num'))
# Trial labels are taken from trial_max_activity by row index (both are 0..n-1 after reset_index)
long_iti_auc_activity.insert(1, 'direction_wrapped', trial_max_activity['direction_wrapped'])
long_iti_auc_activity.insert(2, 'orientation', trial_max_activity['orientation'])

# One one-sided test per cell column: trial AUCs vs. long-ITI AUCs
stats, pvals = mannwhitneyu(trial_auc_activity[cells], long_iti_auc_activity[cells],
                            alternative='greater', axis=0)
vis_drive_test = pd.DataFrame(index=cells, data={'statistic': stats, 'pvalue': pvals})

# Three-way split on the one-sided p-value:
#   p < p_cutoff                       -> visually responsive (trial activity greater than ITI)
#   p > 1 - p_cutoff                   -> not visually responsive
#   p_cutoff <= p <= 1 - p_cutoff      -> modulated
p_cutoff = 0.25
vis_drive_test['vis_resp'] = vis_drive_test.pvalue < p_cutoff
vis_drive_test['not_vis_resp'] = vis_drive_test.pvalue > 1 - p_cutoff
vis_drive_test['mod_vis_resp'] = vis_drive_test.pvalue.between(p_cutoff, 1 - p_cutoff)

# Direction selectivity is judged on the DSI and orientation selectivity on the OSI (these were swapped before).
# NOTE(review): these masks are combined with vis_drive_test columns below; confirm this_tc is indexed by the
# same cell names, otherwise the rows will not line up.
dir_sel = this_tc.fit_dsi >= processing_parameters.selectivity_idx_cutoff
ori_sel = this_tc.fit_osi >= processing_parameters.selectivity_idx_cutoff

# Columns '<class>_dir_sel' then '<class>_ori_sel' (order matters for the summary printed below)
resp_classes = ['vis_resp', 'not_vis_resp', 'mod_vis_resp']
for sel_name, sel_mask in [('dir_sel', dir_sel), ('ori_sel', ori_sel)]:
    for resp_class in resp_classes:
        vis_drive_test[f'{resp_class}_{sel_name}'] = np.logical_and(vis_drive_test[resp_class], sel_mask)


def _cells_flagged(flag_column):
    """Names of the cells (rows of vis_drive_test) whose boolean flag_column is True."""
    return vis_drive_test.loc[vis_drive_test[flag_column]].index.to_list()


vis_resp_cells = _cells_flagged('vis_resp')
not_vis_resp_cells = _cells_flagged('not_vis_resp')
mod_resp_cells = _cells_flagged('mod_vis_resp')

vis_resp_dir_sel_cells = _cells_flagged('vis_resp_dir_sel')
not_vis_resp_dir_sel_cells = _cells_flagged('not_vis_resp_dir_sel')
mod_resp_dir_sel_cells = _cells_flagged('mod_vis_resp_dir_sel')

vis_resp_ori_sel_cells = _cells_flagged('vis_resp_ori_sel')
not_vis_resp_ori_sel_cells = _cells_flagged('not_vis_resp_ori_sel')
mod_resp_ori_sel_cells = _cells_flagged('mod_vis_resp_ori_sel')

# Number of cells in each category (all boolean columns)
print(vis_drive_test.iloc[:, 2:].sum())

In [142]:
# Traces of up to 10 randomly drawn visually responsive, direction-selective cells
n_cells_shown = min(len(vis_resp_dir_sel_cells), 10)
cells_shown = np.random.choice(vis_resp_dir_sel_cells, n_cells_shown, replace=False)
fp.plot_dff_spikes_trials(this_exp, os.path.join(fig_path, '2_head_fixed_cells_traces_ethogram'), save=False,
                          plot_spikes=False, plot_trials=True, plot_running=False,
                          fig_width=15, fontsize='paper', cells_to_plot=cells_shown)

In [146]:
# One direction tuning plot per visually responsive, direction-selective cell
tuning_fig_size = (3 / fp.constant_in2cm, 3 / fp.constant_in2cm)
for cell in vis_resp_dir_sel_cells:
    fig = plt.figure(layout='constrained', figsize=tuning_fig_size)
    this_fig_axes = fp.plot_tuning_with_stats(this_tc, cell, subfig=fig, tuning_kind='direction',
                                              plot_selectivity=False, font_size='paper', plot_trials=False)


In [143]:
# Traces of up to 10 randomly drawn non-responsive, direction-selective cells
n_cells_shown = min(len(not_vis_resp_dir_sel_cells), 10)
cells_shown = np.random.choice(not_vis_resp_dir_sel_cells, n_cells_shown, replace=False)
fp.plot_dff_spikes_trials(this_exp, os.path.join(fig_path, '2_head_fixed_cells_traces_ethogram'), save=False,
                          plot_spikes=False, plot_trials=True, plot_running=False,
                          fig_width=15, fontsize='paper', cells_to_plot=cells_shown)

In [145]:
# One direction tuning plot per non-responsive, direction-selective cell
tuning_fig_size = (3 / fp.constant_in2cm, 3 / fp.constant_in2cm)
for cell in not_vis_resp_dir_sel_cells:
    fig = plt.figure(layout='constrained', figsize=tuning_fig_size)
    this_fig_axes = fp.plot_tuning_with_stats(this_tc, cell, subfig=fig, tuning_kind='direction',
                                              plot_selectivity=False, font_size='paper', plot_trials=False)


In [144]:
# Traces of up to 8 randomly drawn modulated, direction-selective cells
n_cells_shown = min(len(mod_resp_dir_sel_cells), 8)
cells_shown = np.random.choice(mod_resp_dir_sel_cells, n_cells_shown, replace=False)
fp.plot_dff_spikes_trials(this_exp, os.path.join(fig_path, '2_head_fixed_cells_traces_ethogram'), save=False,
                          plot_spikes=False, plot_trials=True, plot_running=False,
                          fig_width=15, fontsize='paper', cells_to_plot=cells_shown)

In [147]:
# One direction tuning plot per modulated, direction-selective cell
tuning_fig_size = (3 / fp.constant_in2cm, 3 / fp.constant_in2cm)
for cell in mod_resp_dir_sel_cells:
    fig = plt.figure(layout='constrained', figsize=tuning_fig_size)
    this_fig_axes = fp.plot_tuning_with_stats(this_tc, cell, subfig=fig, tuning_kind='direction',
                                              plot_selectivity=False, font_size='paper', plot_trials=False)


In [None]:
# Traces of up to 8 randomly drawn visually responsive cells
n_cells_shown = min(len(vis_resp_cells), 8)
cells_shown = np.random.choice(vis_resp_cells, n_cells_shown, replace=False)
fp.plot_dff_spikes_trials(this_exp, os.path.join(fig_path, '2_head_fixed_cells_traces_ethogram'), save=False,
                          plot_spikes=False, plot_trials=True, plot_running=False,
                          fig_width=15, fontsize='paper', cells_to_plot=cells_shown)

In [None]:
# NOTE(review): `dir_resp_cells` was never defined in this notebook, so this cell raised NameError.
# Assumed intent: the visually responsive, direction-selective cells — confirm.
dir_resp_cells = vis_resp_dir_sel_cells
fp.plot_dff_spikes_trials(this_exp, os.path.join(fig_path, '2_head_fixed_cells_traces_ethogram'), save=False,
                          plot_spikes=False, plot_trials=True, plot_running=False,
                          fig_width=15, fontsize='paper', cells_to_plot=np.random.choice(dir_resp_cells, min(len(dir_resp_cells), 8), replace=False))

# Select cells with particular visual response properties

In [None]:
def get_vis_tuned_cells(ds, vis_stim='dir', sel_thresh=0.3, drop_na=True):
    """
    Select the rows of a tuning-properties table that meet one visual response criterion.

    Parameters:
    ds (DataFrame): tuning properties, one row per cell. Columns used depend on vis_stim.
    vis_stim (str):
        'dir' - is_dir_responsive == 1 and fit_dsi >= sel_thresh
        'ori' - is_ori_responsive == 1 and fit_osi >= sel_thresh
        'vis' - is_vis_responsive == 1 and is_gen_responsive == 0
        'gen' - is_gen_responsive == 1
    sel_thresh (float): minimum selectivity index; only used for 'dir' and 'ori'.
    drop_na (bool): currently unused; kept for backward compatibility.

    Returns:
    DataFrame: the selected rows (taken from a copy of ds).

    Raises:
    ValueError: if vis_stim is not one of the values above.
    """
    data = ds.copy()

    if vis_stim == 'dir':
        # Only responsiveness and DSI are required; excluding orientation-selective or generally visual cells
        # is currently not applied.
        return data[(data['is_dir_responsive'] == 1) & (data['fit_dsi'] >= sel_thresh)]

    if vis_stim == 'ori':
        # Only responsiveness and OSI are required; excluding direction-selective or generally visual cells
        # is currently not applied.
        return data[(data['is_ori_responsive'] == 1) & (data['fit_osi'] >= sel_thresh)]

    if vis_stim == 'vis':
        # Visually responsive cells that are not also generally responsive
        return data[(data['is_vis_responsive'] == 1) & (data['is_gen_responsive'] == 0)]

    if vis_stim == 'gen':
        return data[data['is_gen_responsive'] == 1]

    # Previously the exception was returned instead of raised, which let invalid input pass silently
    raise ValueError('Invalid vis_stim')


def filter_vis_selectivity(fixed_exp_tcs, free_exp_tcs, matches, vis_stim, sel_thresh=0.3, change_thresh=0.15):
    """
    Compare the selectivity of matched cells between the head-fixed and freely moving sessions.

    Parameters:
    fixed_exp_tcs (DataFrame): tuning properties of the head-fixed session.
    free_exp_tcs (DataFrame): tuning properties of the freely moving session.
    matches (array-like of int, shape (n, 2)): column 0 = row position in free_exp_tcs, column 1 = row
        position in fixed_exp_tcs (positional, via iloc).
    vis_stim (str): 'dir' (uses |fit_dsi|) or 'ori' (uses |fit_osi|).
    sel_thresh (float): selectivity index at or above which a cell counts as selective.
    change_thresh (float): minimum |free - fixed| change for 'strengthened' / 'weakened'. Defaults to 0.15.

    Returns:
    DataFrame: one row per match (index = row of `matches`) with columns 'fixed', 'free', 'diff' (free - fixed)
        and the boolean flags 'kept', 'lost', 'gained', 'strengthened', 'weakened'.

    Raises:
    ValueError: if vis_stim is not 'dir' or 'ori'.
    """
    if vis_stim == 'dir':
        sel_var = 'fit_dsi'
    elif vis_stim == 'ori':
        sel_var = 'fit_osi'
    else:
        raise ValueError('Invalid vis_stim')

    matches = np.asarray(matches)
    free = free_exp_tcs.iloc[matches[:, 0]][sel_var].abs().to_numpy()
    fixed = fixed_exp_tcs.iloc[matches[:, 1]][sel_var].abs().to_numpy()
    sel_matched = pd.DataFrame({'fixed': fixed, 'free': free, 'diff': free - fixed})

    # Explicit >= / < comparisons (not negation) so NaN selectivity falls in no category
    fixed_sel = sel_matched['fixed'] >= sel_thresh
    free_sel = sel_matched['free'] >= sel_thresh
    fixed_not_sel = sel_matched['fixed'] < sel_thresh
    free_not_sel = sel_matched['free'] < sel_thresh
    kept = fixed_sel & free_sel

    sel_matched['kept'] = kept
    sel_matched['lost'] = fixed_sel & free_not_sel
    sel_matched['gained'] = fixed_not_sel & free_sel
    sel_matched['strengthened'] = kept & (sel_matched['diff'] > change_thresh)
    sel_matched['weakened'] = kept & (sel_matched['diff'] < -change_thresh)

    return sel_matched

In [None]:
free_tcs = getattr(exp_free.visual_tcs, f'{used_activity_ds}_props')
fixed_tcs = getattr(exp_fixed.visual_tcs, f'{used_activity_ds}_props')


def _classify_tuned_cells(tcs, sel_thresh):
    """
    Split one session's tuning table into response categories.

    Returns:
    tuple: (gen_resp, vis_resp, dir_tuned, ori_tuned, both_tuned), where
        gen_resp   - generally responsive cells (not specific to visual stimuli),
        vis_resp   - visually responsive cells, extended with any direction/orientation tuned cell it lacked
                     (reset_index() is applied, so the former index is in an 'index' column),
        dir_tuned  - direction tuned only, ori_tuned - orientation tuned only,
        both_tuned - cells meeting both the direction and the orientation criteria.
    """
    gen_resp = get_vis_tuned_cells(tcs, vis_stim='gen', sel_thresh=sel_thresh)
    vis_resp = get_vis_tuned_cells(tcs, vis_stim='vis', sel_thresh=sel_thresh)
    dir_tuned = get_vis_tuned_cells(tcs, vis_stim='dir', sel_thresh=sel_thresh)
    ori_tuned = get_vis_tuned_cells(tcs, vis_stim='ori', sel_thresh=sel_thresh)

    # Cells tuned to both direction and orientation get their own category and are removed from the others
    _, dir_pos, ori_pos = np.intersect1d(dir_tuned.index, ori_tuned.index, return_indices=True)
    both_tuned = dir_tuned.iloc[dir_pos].copy()
    dir_tuned = dir_tuned.drop(dir_tuned.index[dir_pos])
    ori_tuned = ori_tuned.drop(ori_tuned.index[ori_pos])

    # Make sure every tuned cell is also contained in vis_resp. The setdiff arguments were previously swapped,
    # which re-added cells already in vis_resp and made this step a no-op.
    tuned_cells = np.unique(np.concatenate([dir_tuned.index, ori_tuned.index, both_tuned.index]))
    missing_from_vis = np.setdiff1d(tuned_cells, vis_resp.index, assume_unique=True)
    vis_resp = pd.concat([vis_resp, tcs.loc[missing_from_vis, :]])
    vis_resp = vis_resp.reset_index().drop_duplicates(subset=['index'])

    return gen_resp, vis_resp, dir_tuned, ori_tuned, both_tuned


sel_cutoff = processing_parameters.selectivity_idx_cutoff
free_gen_resp, free_vis_resp, free_dir_tuned, free_ori_tuned, free_both_tuned = _classify_tuned_cells(free_tcs, sel_cutoff)
fixed_gen_resp, fixed_vis_resp, fixed_dir_tuned, fixed_ori_tuned, fixed_both_tuned = _classify_tuned_cells(fixed_tcs, sel_cutoff)


In [None]:
# Direction-only and orientation-only tuned cells: first line freely moving, second line head-fixed
for dir_tuned_cells, ori_tuned_cells in ((free_dir_tuned, free_ori_tuned), (fixed_dir_tuned, fixed_ori_tuned)):
    print(dir_tuned_cells.index.to_list(), ori_tuned_cells.index.to_list())

In [None]:
# For thesis
# final_idxs = [4, 11, 12, 21, 27, 36, 39, 19]
# free ['cell_0013', 'cell_0026', 'cell_0027', 'cell_0049', 'cell_0075', 'cell_0119', 'cell_0133', 'cell_0045']
# fixed ['cell_0014', 'cell_0036', 'cell_0039', 'cell_0054', 'cell_0092', 'cell_0128', 'cell_0141', 'cell_0061']

# Number of matched cell pairs to select
num_matches = min(day_mouse_curated_matches.shape[0], 8)

ori_matched = filter_vis_selectivity(fixed_tcs, free_tcs, day_mouse_curated_matches, 'ori', sel_thresh=processing_parameters.selectivity_idx_cutoff)
dir_matched = filter_vis_selectivity(fixed_tcs, free_tcs, day_mouse_curated_matches, 'dir', sel_thresh=processing_parameters.selectivity_idx_cutoff)

# Row positions (into day_mouse_curated_matches) of the pairs in each selectivity-change category
ori_kept = ori_matched[ori_matched['kept']].index.values
ori_gained = ori_matched[ori_matched['gained']].index.values
ori_lost = ori_matched[ori_matched['lost']].index.values
ori_strong = ori_matched[ori_matched['strengthened']].index.values
ori_weak = ori_matched[ori_matched['weakened']].index.values

dir_kept = dir_matched[dir_matched['kept']].index.values
dir_gained = dir_matched[dir_matched['gained']].index.values
dir_lost = dir_matched[dir_matched['lost']].index.values
dir_strong = dir_matched[dir_matched['strengthened']].index.values
dir_weak = dir_matched[dir_matched['weakened']].index.values

# Choose cells somewhat at random: one from each non-empty category (up to 10 picks, no repeats),
# topped up with random pairs until num_matches pairs are selected
category_idxs = [ori_kept, ori_lost, ori_gained, ori_strong, ori_weak, dir_kept, dir_lost, dir_gained, dir_strong, dir_weak]
chosen_idxs = np.unique([np.random.choice(arr, 1)[0] for arr in category_idxs if arr.size > 0]).astype(int)

# There can be more category picks than requested pairs; keep a random subset so the count never exceeds
# num_matches (otherwise the top-up size below would be negative and np.random.choice would raise)
if len(chosen_idxs) > num_matches:
    chosen_idxs = np.sort(np.random.choice(chosen_idxs, num_matches, replace=False))

# Randomly choose the remaining pairs from those not already chosen
remaining_idxs = np.setdiff1d(np.arange(len(day_mouse_curated_matches)), chosen_idxs)
random_idxs = np.random.choice(remaining_idxs, num_matches - len(chosen_idxs), replace=False)
final_idxs = np.concatenate([chosen_idxs, random_idxs]).astype(int)

final_cells = day_mouse_curated_matches[final_idxs, :]

exp_free.cells_to_match = [f'cell_{cell_id:04d}' for cell_id in final_cells[:, 0]]
exp_fixed.cells_to_match = [f'cell_{cell_id:04d}' for cell_id in final_cells[:, 1]]
print(final_idxs)
print(exp_free.cells_to_match)
print(exp_fixed.cells_to_match)