In [None]:
%load_ext autoreload
%autoreload 2

# Imports
import numpy as np
import pandas as pd
import scipy.stats as st
import panel as pn
import seaborn as sns
import holoviews as hv
import hvplot.pandas
import datashader as dshade
import ipywidgets as widgets
from ipywidgets import interact, interactive

import os
import sys
import random
import importlib
import datetime
import warnings
import math
import h5py
warnings.filterwarnings('ignore')

from bokeh.resources import INLINE
from bokeh.io import export_svgs, export_png
from bokeh.plotting import show
from holoviews import opts, dim
from holoviews.operation import histogram
from holoviews.operation.datashader import datashade, shade
hv.extension('bokeh')
# hv.extension('matplotlib')
import matplotlib.pyplot as plt

sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
%aimport paths
import processing_parameters
import functions_bondjango as bd
import functions_plotting as fp
import functions_data_handling as fdh
import functions_kinematic as fk
import functions_tuning as tuning
import functions_misc as misc
from wirefree_experiment import WirefreeExperiment, DataContainer

# reload the plotting helpers so edits to functions_plotting are picked up without restarting the kernel
importlib.reload(fp)
# set up the figure theme
fp.set_theme()
plt.rcParams["font.family"] = "Arial"
# centimetres-to-inches factor (matplotlib figure sizes are given in inches)
cm = 1./2.54
# Root folder for the saved poster figures. Set the FIGURE_SAVE_PATH environment variable
# to override the author's local default path.
figure_save_path = os.environ.get(
    "FIGURE_SAVE_PATH", r"C:\Users\mmccann\Dropbox\bonhoeffer lab\SFN 2023\poster")

# Go through all the mice and create aggregate files

In [None]:
importlib.reload(processing_parameters)
# Build one aggregate HDF5 file per mouse from its per-session tc_consolidate files
for mouse in ['MM_221110_a', 'MM_221109_a', 'MM_220928_a', 'MM_220915_a',
              'MM_230518_b', 'MM_230705_b', 'MM_230706_a', 'MM_230706_b']:

    # get the search string
    search_string = f"mouse:{mouse}, lighting:normal, analysis_type:tc_consolidate, result:repeat"
    parsed_search = fdh.parse_search_string(search_string)

    # get the paths from the database, keeping only tcconsolidate entries of this mouse
    file_infos = bd.query_database("analyzed_data", search_string)
    input_paths = np.array([el['analysis_path'] for el in file_infos if ('_tcconsolidate' in el['slug']) and
                            (parsed_search['mouse'].lower() in el['slug'])])
    print(np.sort(input_paths))
    if len(input_paths) == 0:
        continue

    # the first 10 characters of each file name are used as the session date
    date_list = [os.path.basename(path)[:10] for path in input_paths]
    mouse = parsed_search['mouse']

    # assemble the output path
    output_path = os.path.join(paths.analysis_path, f"AGG_{'_'.join(parsed_search.values())}.hdf5")

    # read every session file, tagging each table with its session date
    data_list = []
    for file, date in zip(input_paths, date_list):
        with pd.HDFStore(file, 'r') as tc:
            # sessions containing a 'no_ROIs' table have nothing to aggregate
            if '/no_ROIs' in tc.keys():
                continue
            data_dict = {}
            for key in tc.keys():
                label = "_".join(key.split('/')[1:])
                data = tc[key]
                data['date'] = date
                data_dict[label] = data
        data_list.append(data_dict)

    # every session may have been skipped; data_list[0] below would then raise IndexError
    # and abort the loop for all remaining mice
    if not data_list:
        print(f'No sessions with ROIs for {mouse}, skipping')
        continue

    # Aggregate it all: concatenate each table across sessions, keeping the per-session
    # row index in an 'old_index' column
    agg_dict = {}
    for key in data_list[0].keys():
        df = pd.concat([d[key] for d in data_list]).reset_index(names='old_index')
        agg_dict[key] = df
        df.to_hdf(output_path, key)

    # assemble the entry data
    entry_data = {
        'analysis_type': 'agg_tc',
        'analysis_path': output_path,
        'date': '',
        'pic_path': '',
        'result': parsed_search['result'],
        'rig': parsed_search['rig'],
        'lighting': parsed_search['lighting'],
        'imaging': 'wirefree',
        'slug': misc.slugify(os.path.basename(output_path)[:-5]),
    }

    # check if the entry already exists, if so, update it, otherwise, create it
    update_url = '/'.join((paths.bondjango_url, 'analyzed_data', entry_data['slug'], ''))
    output_entry = bd.update_entry(update_url, entry_data)
    if output_entry.status_code == 404:
        # build the url for creating an entry
        create_url = '/'.join((paths.bondjango_url, 'analyzed_data', ''))
        output_entry = bd.create_entry(create_url, entry_data)

    print('The output status was %i, reason %s' %
          (output_entry.status_code, output_entry.reason))
    if output_entry.status_code in [500, 400]:
        print(entry_data)

# Load aggregate files from all mice to make plots

In [None]:
# get the search string
search_string = f"analysis_type:agg_tc, result:multi"
parsed_search = fdh.parse_search_string(search_string)

# get the paths from the database
file_infos = bd.query_database("analyzed_data", search_string)
input_paths = np.array([el['analysis_path'] for el in file_infos if ('agg' in el['slug'])])
print(np.sort(input_paths))
if len(input_paths) == 0:
    raise FileNotFoundError(f'No aggregate files found for query "{search_string}"')

# read each mouse's aggregate file, tagging every table with the mouse ID
data_list = []
for file in input_paths:
    data_dict = {}
    # the mouse ID is taken from fields 10-12 of the underscore-separated file name
    mouse = '_'.join(os.path.basename(file).split('_')[10:13])
    with pd.HDFStore(file, 'r') as tc:
        for key in tc.keys():
            label = "_".join(key.split('/')[1:])
            data = tc[key]
            data['mouse'] = mouse
            data_dict[label] = data
    data_list.append(data_dict)

# Aggregate it all. Not every mouse necessarily has every table, so take the union of table
# names (first-seen order) and concatenate each table over the mice that actually have it.
all_keys = list(dict.fromkeys(key for d in data_list for key in d))
agg_dict = {}
for key in all_keys:
    df = pd.concat([d[key] for d in data_list if key in d]).reset_index(drop=True)
    agg_dict[key] = df

In [None]:
# number of rows in each mouse's 'cell_matches' table, one value per mouse
match_nums = list(map(len, (mouse_tables['cell_matches'] for mouse_tables in data_list)))
plt.hist(match_nums, bins=100, edgecolor="black")
plt.xlabel('Cells')
plt.title('Num. cells matched')
plt.ylabel('Count')

In [None]:
# Within-animal fraction kinematic or dir/ori responsive (freely moving)
ds_name = 'VTuningWF_summary_stats'
ds = agg_dict[ds_name]
# summary-stats layout: kinematic columns in [4:-6], visual columns in [-6:-2]
kinem_cols = list(ds.columns[4:-6])
vis_cols = list(ds.columns[-6:-2])
kinem_vis_cols = kinem_cols + vis_cols

if 'VWheelWF' in ds_name:
    kinem_labels = list(processing_parameters.wf_fixed_kinem_cols)
else:
    kinem_labels = list(processing_parameters.wf_free_kinem_cols)
vis_labels = list(processing_parameters.wf_vis_cols)
# zip() silently truncates: a length mismatch would shift visual labels onto kinematic columns
# (or drop columns) without any error, so check the counts explicitly
assert len(kinem_cols) == len(kinem_labels), \
    f"{len(kinem_cols)} kinematic columns but {len(kinem_labels)} kinematic labels"
assert len(vis_cols) == len(vis_labels), \
    f"{len(vis_cols)} visual columns but {len(vis_labels)} visual labels"
rename_dict = dict(zip(kinem_vis_cols, kinem_labels + vis_labels))

# the 'all_cells' row holds the fraction over all cells of a session
frac_vis_resp = ds[ds.old_index == 'all_cells'].rename(columns=rename_dict)

violinplot_free = frac_vis_resp[list(rename_dict.values())].hvplot.violin(legend=False, cmap=['blue']*len(kinem_cols) +  ['red']*len(vis_cols))
violinplot_free.opts(xlabel='', ylabel='Significant Frac.', ylim=(-0.05, 1.05), xrotation=45, width=800, height=600)

save_path = os.path.join(figure_save_path, "Fig2", "sig_frac_kinem_vis_free.png")
violinplot_free = fp.save_figure(violinplot_free, save_path=save_path, fig_width=11, dpi=1000, fontsize='poster', target='both', display_factor=0.1)

In [None]:
# Within-animal fraction kinematic or dir/ori responsive (head fixed)
ds_name = 'VWheelWF_summary_stats'
ds = agg_dict[ds_name]
# summary-stats layout: kinematic columns in [4:-6], visual columns in [-6:-2]
kinem_cols = list(ds.columns[4:-6])
vis_cols = list(ds.columns[-6:-2])
kinem_vis_cols = kinem_cols + vis_cols

if 'VWheelWF' in ds_name:
    kinem_labels = list(processing_parameters.wf_fixed_kinem_cols)
else:
    kinem_labels = list(processing_parameters.wf_free_kinem_cols)
vis_labels = list(processing_parameters.wf_vis_cols)
# zip() silently truncates: a length mismatch would shift visual labels onto kinematic columns
# (or drop columns) without any error, so check the counts explicitly
assert len(kinem_cols) == len(kinem_labels), \
    f"{len(kinem_cols)} kinematic columns but {len(kinem_labels)} kinematic labels"
assert len(vis_cols) == len(vis_labels), \
    f"{len(vis_cols)} visual columns but {len(vis_labels)} visual labels"
rename_dict = dict(zip(kinem_vis_cols, kinem_labels + vis_labels))

# the 'all_cells' row holds the fraction over all cells of a session
frac_vis_resp = ds[ds.old_index == 'all_cells'].rename(columns=rename_dict)

violinplot = frac_vis_resp[list(rename_dict.values())].hvplot.violin(legend=False, cmap=['blue']*len(kinem_cols) +  ['red']*len(vis_cols))
violinplot.opts(xlabel='', ylabel='Significant Frac.', ylim=(-0.05, 1.05), xrotation=45, width=600, height=600)

save_path = os.path.join(figure_save_path, "Fig2", "sig_frac_kinem_vis_fixed.png")
violinplot = fp.save_figure(violinplot, save_path=save_path, fig_width=8, dpi=1000, fontsize='poster', target='both', display_factor=0.1)

In [None]:
def get_frac_tuned(df, kinem_slice=slice(1, -6), vis_slice=slice(-6, -2)):
    """Compute fractions of cells tuned to kinematic and/or visual variables.

    A cell (row) counts as kinematically tuned when the sum over its kinematic columns is > 0,
    and likewise for the visual columns. The input frame is NOT modified.

    :param df: one row per cell; columns are selected positionally by the two slices
    :param kinem_slice: positional slice of the kinematic tuning columns (default skips the
        first column and the last six)
    :param vis_slice: positional slice of the visual tuning columns (default: the four columns
        preceding the last two)
    :return: (frac_only_kinem, frac_only_vis, frac_vis_tuned, frac_kinem_tuned, frac_mix_tuned);
        all NaN when ``df`` has no rows
    """
    n_cells = df.shape[0]
    if n_cells == 0:
        # avoid 0/0 division on an empty group
        return (np.nan,) * 5

    kinem_cols = list(df.columns[kinem_slice])
    vis_cols = list(df.columns[vis_slice])

    # per-cell sums computed once, without adding helper columns to the caller's frame
    # (added columns would shift the positional slices on any later call with the same frame)
    kinem_sum = df[kinem_cols].sum(axis=1)
    vis_sum = df[vis_cols].sum(axis=1)

    frac_vis_tuned = (vis_sum > 0).sum() / n_cells
    frac_kinem_tuned = (kinem_sum > 0).sum() / n_cells

    frac_only_kinem = ((vis_sum == 0) & (kinem_sum > 0)).sum() / n_cells
    frac_only_vis = ((kinem_sum == 0) & (vis_sum > 0)).sum() / n_cells
    frac_mix_tuned = ((vis_sum > 0) & (kinem_sum > 0)).sum() / n_cells

    return frac_only_kinem, frac_only_vis, frac_vis_tuned, frac_kinem_tuned, frac_mix_tuned

In [None]:
# Per-mouse fractions of visually-only, posturally-only, visual, postural and multimodal cells,
# computed separately for the freely moving and head-fixed multimodal tables
plt_list = []
df_list = []
experiments = {'Freely Moving': agg_dict['VTuningWF_multimodal_tuned'],
               'Head Fixed': agg_dict['VWheelWF_multimodal_tuned']}
frac_columns = [' Visual Only', ' Postural Only', ' Visual', ' Postural', 'Multimodal']

for exp_type, ds in experiments.items():
    records = []
    for _, group in ds.groupby('mouse'):
        frac_only_kinem, frac_only_vis, frac_vis_tuned, frac_kinem_tuned, frac_mix_tuned = get_frac_tuned(group)
        records.append({
            'group': exp_type,
            ' Visual Only': frac_only_vis,
            ' Postural Only': frac_only_kinem,
            ' Visual': frac_vis_tuned,
            ' Postural': frac_kinem_tuned,
            'Multimodal': frac_mix_tuned,
        })
    df = pd.DataFrame(records, columns=['group'] + frac_columns)
    df_list.append(df)
    plt_list.append(df.hvplot.box(legend=False))

# long format: one row per (mouse, experiment type, fraction kind)
both = pd.melt(pd.concat(df_list, axis=0), id_vars=['group'], value_vars=frac_columns)

In [None]:
# Fractions per tuning category, split by experiment type (colour = group)
violin_opts = dict(
    xlabel='',
    ylabel='Significant Frac.',
    width=1000,
    height=400,
    violin_color=hv.dim('group'),
    show_legend=True,
    legend_position='right',
    fontsize={'legend': 10},
)
violinplot = hv.Violin(both, ['variable', 'group'], 'value').opts(**violin_opts)
save_path = os.path.join(figure_save_path, "Fig2", "frac_multimodal.png")
violinplot = fp.save_figure(violinplot, save_path=save_path, fig_width=17, dpi=1000, fontsize='poster', target='both', display_factor=0.1)

In [None]:
# Same comparison drawn with seaborn (hue = experiment type), saved as an alternative PNG
ax = sns.violinplot(data=both, x="variable", y="value", hue="group")
for side in ('right', 'top'):
    ax.spines[side].set_visible(False)
ax.set_xlabel('')
ax.set_ylabel('Significant Frac.')
ax.legend(title='', loc='lower right')
plt.savefig(os.path.join(figure_save_path, 'Fig2', 'frac_multimodal_alt.png'), format='png', dpi=1000)

# Bootstrapped tuning shifts

In [None]:
def plot_compare_matched_vis_tuning(ds_list, cell, tuning_kind='direction', exp_name='', error='std', norm=True, polar=True, axes=None):
    """Plot one cell's tuning in several datasets, one row (subfigure) per dataset.

    Each row has a tuning axis on the left (polar when ``polar``) and two stacked axes on
    the right; all three are filled by ``fp.plot_tuning_with_stats``.

    :param ds_list: datasets to plot, one per row, each forwarded to fp.plot_tuning_with_stats
    :param cell: cell identifier forwarded to fp.plot_tuning_with_stats and shown in the title
    :param tuning_kind: tuning kind forwarded to fp.plot_tuning_with_stats for every row
    :param exp_name: sequence of keys into processing_parameters.wf_label_dictionary, one per
        row, used as row titles; rows are left untitled when it is empty
    :param error: error measure forwarded to fp.plot_tuning_with_stats
    :param norm: forwarded to fp.plot_tuning_with_stats
    :param polar: use a polar projection for the tuning axis (also forwarded)
    :param axes: optional list with one array of three axes per row; a new figure is
        created when None
    :return: (fig, axes); fig is None when the caller supplied ``axes``
    """
    # height of one dataset row, in cm
    row_height_cm = 8

    fig = None
    if axes is None:
        axes = []
        # figure height scales with the number of rows, i.e. the number of datasets
        fig = plt.figure(layout='constrained', figsize=(18*cm, row_height_cm*len(ds_list)*cm))
        fig.suptitle(f"Cell {cell}", fontsize='x-large')
        # squeeze=False always returns a 2D array, so a single dataset is still iterable
        subfigs = fig.subfigures(nrows=len(ds_list), ncols=1, hspace=0.07, squeeze=False).ravel()

        for i, subfig in enumerate(subfigs):
            if exp_name:
                subfig.suptitle(processing_parameters.wf_label_dictionary[exp_name[i]].title())

            if polar:
                ax1 = subfig.add_subplot(121, projection="polar")  # direction tuning
            else:
                ax1 = subfig.add_subplot(121)  # tuning

            ax2 = subfig.add_subplot(222)  # resp
            ax3 = subfig.add_subplot(224)  # error
            axes.append(np.array([ax1, ax2, ax3]))

    for sub_ax, ds in zip(axes, ds_list):
        fp.plot_tuning_with_stats(ds, cell, tuning_kind=tuning_kind, error=error, norm=norm, polar=polar, axes=sub_ax)

    return fig, axes

def plot_compare_all_vis_tuning(ds_list, cell, tuning_kind=['direction', 'orientation'], exp_name='', error='std', norm=True, polar=True, axes=None):
    """Plot one cell's tuning in a 2x2 grid: rows are experiments, columns are tuning kinds.

    Panels are filled row by row, so ``ds_list`` is expected in the order
    (exp 0, kind 0), (exp 0, kind 1), (exp 1, kind 0), (exp 1, kind 1).

    :param ds_list: up to four datasets, each forwarded to fp.plot_tuning_with_stats
    :param cell: cell identifier forwarded to fp.plot_tuning_with_stats and shown in the title
    :param tuning_kind: the tuning kind of each grid column (a single string is used for
        both columns); forwarded to fp.plot_tuning_with_stats per panel
    :param exp_name: sequence of two keys into processing_parameters.wf_label_dictionary used
        as titles of the two experiment rows; panels are untitled when it is empty
    :param error: error measure forwarded to fp.plot_tuning_with_stats
    :param norm: forwarded to fp.plot_tuning_with_stats
    :param polar: use a polar projection for the tuning axis (also forwarded)
    :param axes: optional list with one array of three axes per panel; a new figure is
        created when None
    :return: (fig, axes); fig is None when the caller supplied ``axes``
    """
    # height of one experiment row, in cm
    row_height_cm = 8

    # honour the caller's tuning kinds (one per column), repeated for both experiment rows
    if isinstance(tuning_kind, str):
        column_kinds = [tuning_kind, tuning_kind]
    else:
        column_kinds = list(tuning_kind)
    panel_kinds = column_kinds * 2

    fig = None
    if axes is None:
        axes = []
        fig = plt.figure(layout='constrained', figsize=(36*cm, 2*row_height_cm*cm))
        fig.suptitle(f"Cell {cell}", fontsize='x-large')
        subfigs = fig.subfigures(nrows=2, ncols=2, hspace=0.07)

        for i, subfig in enumerate(subfigs.flatten()):
            # two panels per row, so panel i belongs to experiment row i // 2
            if exp_name:
                subfig.suptitle(processing_parameters.wf_label_dictionary[exp_name[i//2]].title())

            if polar:
                ax1 = subfig.add_subplot(121, projection="polar")  # tuning
            else:
                ax1 = subfig.add_subplot(121)  # tuning

            ax2 = subfig.add_subplot(222)  # resp
            ax3 = subfig.add_subplot(224)  # error
            axes.append(np.array([ax1, ax2, ax3]))

    for ds, sub_ax, kind in zip(ds_list, axes, panel_kinds):
        fp.plot_tuning_with_stats(ds, cell, tuning_kind=kind, error=error, norm=norm, polar=polar, axes=sub_ax)

    return fig, axes

In [None]:
# inspect the column layout of the matched direction-property table
direction_props = agg_dict['VTuningWF_matched_norm_spikes_viewed_direction_props']
direction_props.columns

In [None]:
# Matched-cell tuning-property tables per tuning kind: VTuningWF (freely moving) and VWheelWF (head fixed)
matched_kinds = ['direction', 'orientation', 'still_direction', 'still_orientation']
free_matched = [f'VTuningWF_matched_norm_spikes_viewed_{kind}_props' for kind in matched_kinds]
fixed_matched = [f'VWheelWF_matched_norm_spikes_viewed_{kind}_props' for kind in matched_kinds]

In [None]:
# Compare the same matched cells head fixed (datasets[0]) vs freely moving (datasets[1])
tuning_kind = 'still_orientation'
datasets = [f'{rig}_matched_norm_spikes_viewed_{tuning_kind}_props' for rig in ('VWheelWF', 'VTuningWF')]

# cell selection is based on the head-fixed table only
ref_ds = agg_dict[datasets[0]]
# NOTE(review): strict_index keeps cells with responsivity BELOW 0.25, the opposite of the
# >= 0.25 criterion used for super_strict_index; confirm this is intended.
low_resp_mask = ref_ds.responsivity < 0.25
strict_index = ref_ds.loc[low_resp_mask].index
high_resp_sig_mask = (ref_ds.responsivity >= 0.25) & (ref_ds.p_responsivity >= 0.95)
super_strict_index = ref_ds.loc[high_resp_sig_mask].index

In [None]:
tuning_shifts = []

# Direction angles live on a 360-degree circle, orientation angles on a 180-degree one
if 'direction' in tuning_kind:
    multiplier = 1.
else:
    multiplier = 2.
angle_bound = 360. / multiplier


def _bootstrap_ci(bootstrap_resultant, bound, confidence=0.95):
    """Return the t-based CI (low, high) of the wrapped last-column bootstrap angles and its width.

    NaN samples are dropped first. Fewer than two valid samples give a NaN interval.
    """
    samples = fk.wrap(bootstrap_resultant[:, -1], bound=bound)
    samples = samples[~np.isnan(samples)]
    # confidence is passed positionally: SciPy renamed the keyword from `alpha` to `confidence`
    ci = st.t.interval(confidence, df=len(samples)-1, loc=np.mean(samples), scale=st.sem(samples))
    return ci, ci[1] - ci[0]


for (idxRow, cell_fixed), (_, cell_free) in zip(agg_dict[datasets[0]].iloc[strict_index, :].iterrows(), agg_dict[datasets[1]].iloc[strict_index, :].iterrows()):

    ci_fixed, del_ci_fixed = _bootstrap_ci(cell_fixed.bootstrap_resultant, angle_bound)
    ci_free, _ = _bootstrap_ci(cell_free.bootstrap_resultant, angle_bound)

    # Check if tuned: a wide or undefined (NaN) head-fixed CI means the cell is not tuned.
    # A NaN width must not fall through to the significance test, where NaN bounds would make
    # every comparison False and flag the shift as significant.
    if not (del_ci_fixed <= 40):
        del_po = np.nan
        del_po_wrapped = np.nan
        sig_shift = 0
    else:
        # Get delta preferred angle
        po_fixed = fk.wrap(cell_fixed.resultant[-1], bound=angle_bound)
        po_free = fk.wrap(cell_free.resultant[-1], bound=angle_bound)
        if np.isnan(po_fixed) or np.isnan(po_free):
            po_fixed = cell_fixed.pref
            po_free = cell_free.pref
        del_po = po_free - po_fixed
        del_po_wrapped = fk.wrap(del_po, bound=angle_bound)

        # The shift is significant only if each condition's preferred angle lies outside the other
        # condition's CI. Use `not` rather than `~`: ~ on a Python bool gives -1/-2, which is always truthy.
        free_outside_fixed = not (ci_fixed[0] <= po_free <= ci_fixed[1])
        fixed_outside_free = not (ci_free[0] <= po_fixed <= ci_free[1])
        if np.isnan(del_po):
            sig_shift = 0
        else:
            sig_shift = int(free_outside_fixed and fixed_outside_free)

    tuning_shifts.append([idxRow, del_po, del_po_wrapped, sig_shift, cell_fixed.mouse, cell_fixed.date])

tuning_shifts = pd.DataFrame(data=tuning_shifts, columns=['', 'delta_po', 'delta_po_wrapped', 'is_sig_po', 'mouse', 'date'])
tuning_shifts = tuning_shifts.set_index(tuning_shifts.columns[0])

In [None]:
# Responsivity distributions of the matched cells (head fixed vs freely moving); red line at 0.25
bins = np.arange(0, 1.01, 0.05)
fixed_resp, free_resp = (agg_dict[name].responsivity for name in datasets[:2])
plt.hist([fixed_resp, free_resp], bins=bins, edgecolor="black")
plt.legend(['fixed', 'free'])
plt.axvline(0.25, c='r')
plt.title(f'{tuning_kind} responsivity')

In [None]:
# Fraction of analysed cells with a significant preferred-angle shift, pooled and then per mouse
is_sig = tuning_shifts['is_sig_po']
overall_sig_po_shift_frac = is_sig.sum() / is_sig.count()
print(overall_sig_po_shift_frac)
tuning_shifts.groupby('mouse').apply(lambda per_mouse: per_mouse['is_sig_po'].sum() / per_mouse['is_sig_po'].count())

In [None]:
# 36-bin histogram of preferred-angle shifts of the significantly shifted cells
sig_delta_po = tuning_shifts.loc[tuning_shifts['is_sig_po'] == 1, 'delta_po'].to_numpy()
frequencies, edges = np.histogram(sig_delta_po, bins=36)
hv.Histogram((edges, frequencies)).opts(width=600)

In [None]:
# One 36-bin histogram of significant preferred-angle shifts per mouse, with independent axes
sig_shift_rows = tuning_shifts.loc[tuning_shifts['is_sig_po'] == 1]
x = (sig_shift_rows
     .groupby('mouse')
     .apply(lambda per_mouse: np.histogram(per_mouse['delta_po'].to_numpy(), bins=36))
     .to_list())
layout = hv.Layout([hv.Histogram((bin_edges, counts)).opts(width=400) for counts, bin_edges in x]).opts(shared_axes=False)
layout

In [None]:
%matplotlib inline
datasets = ['VWheelWF_matched_norm_spikes_viewed_still_direction_props', 'VWheelWF_matched_norm_spikes_viewed_still_orientation_props',
            'VTuningWF_matched_norm_spikes_viewed_still_direction_props', 'VTuningWF_matched_norm_spikes_viewed_still_orientation_props']

matched_vis_fig = interactive(plot_compare_all_vis_tuning,
                              ds_list=widgets.fixed([agg_dict[dataset] for dataset in datasets]),
                              cell=tuning_shifts[tuning_shifts.is_sig_po == 1].index, 
                              tuning_kind = widgets.fixed(['direction', 'orientation']),
                              exp_name = widgets.fixed(['VWheelWF', 'VTuningWF']),
                              error = ['std', 'sem'],
                              polar=[True, False],
                              norm=widgets.fixed(True),
                              axes=widgets.fixed(None))
matched_vis_fig

In [None]:
# inspect the tuning-shift record of a single cell of interest
example_cell = 67
tuning_shifts.loc[example_cell]