In [None]:
%load_ext autoreload
%autoreload 2

# Imports
import itertools
import numpy as np
import pandas as pd
import scipy.stats as st
import panel as pn
import seaborn as sns
import holoviews as hv
import hvplot.pandas
import datashader as dshade
import ipywidgets as widgets
from ipywidgets import interact, interactive
from statsmodels.stats.multitest import multipletests

import os
import sys
import random
import importlib
import datetime
import warnings
import math
import h5py
warnings.filterwarnings('ignore')

from bokeh.resources import INLINE
from bokeh.io import export_svgs, export_png
from bokeh.plotting import show
from holoviews import opts, dim
from holoviews.operation import histogram
from holoviews.operation.datashader import datashade, shade
hv.extension('bokeh')
# hv.extension('matplotlib')
import matplotlib.pyplot as plt

sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
%aimport paths
import processing_parameters
import functions_bondjango as bd
import functions_plotting as fp
import functions_data_handling as fdh
import functions_kinematic as fk
import functions_tuning as tuning
import functions_misc as misc
from wirefree_experiment import WirefreeExperiment, DataContainer
import paths

# Pick up any edits to the plotting helpers, then apply the shared figure theme
importlib.reload(fp)
fp.set_theme()
plt.rcParams.update({"font.family": "Arial"})
# Centimetres -> inches conversion factor (matplotlib figure sizes are in inches)
cm = 1 / 2.54

# Go through all the mice and create aggregate files

In [None]:
importlib.reload(processing_parameters)
mice = ['MM_221109_a', 'MM_221110_a', 'MM_220915_a', 'MM_220928_a',
        'MM_230518_b', 'MM_230705_b', 'MM_230706_a', 'MM_230706_b']
results = ['repeat']    # ['multi', 'fullfield', 'control']
lightings = ['normal', 'dark']
rigs = ['VWheelWF', 'VTuningWF']    # 'ALL' used for everything but repeat aggs

for mouse, result, light, rig in itertools.product(mice, results, lightings, rigs):

    # search string selecting this mouse/result/lighting/rig combination
    search_string = f"mouse:{mouse}, result:{result}, lighting:{light}, rig:{rig}"
    print(search_string)

    parsed_search = fdh.parse_search_string(search_string)

    # get this mouse's existing '_tcconsolidate' files from the database
    file_infos = bd.query_database("analyzed_data", search_string)
    input_paths = np.array([el['analysis_path'] for el in file_infos if ('_tcconsolidate' in el['slug']) and
                            (parsed_search['mouse'].lower() in el['slug'])])
    input_paths = np.array([in_path for in_path in input_paths if os.path.isfile(in_path)])
    print(np.sort(input_paths))

    if len(input_paths) == 0:
        continue

    # the first 10 characters of each file name are used as the session date
    date_list = [os.path.basename(path)[:10] for path in input_paths]

    # assemble the output path
    output_path = os.path.join(paths.analysis_path, f"AGG_{'_'.join(parsed_search.values())}.hdf5")

    data_list = []
    for file, date in zip(input_paths, date_list):
        with pd.HDFStore(file, 'r') as tc:
            store_keys = tc.keys()
            # sessions containing a '/no_ROIs' table are skipped
            if '/no_ROIs' in store_keys:
                continue
            data_dict = {}
            for key in store_keys:
                label = "_".join(key.split('/')[1:])
                data = tc[key]
                data['date'] = date
                data_dict[label] = data
        data_list.append(data_dict)

    if len(data_list) == 0:
        continue

    # Aggregate it all. Iterate over the union of table names (first-seen order) so a table
    # missing from some sessions does not raise a KeyError; each table is built from the
    # sessions that contain it.
    all_keys = dict.fromkeys(key for d in data_list for key in d)
    for key in all_keys:
        df = pd.concat([d[key] for d in data_list if key in d]).reset_index(names='old_index')
        df.to_hdf(output_path, key)

    # assemble the entry data
    entry_data = {
        'analysis_type': 'agg_tc',
        'analysis_path': output_path,
        'date': '',
        'pic_path': '',
        'result': parsed_search['result'],
        'rig': parsed_search['rig'],
        'lighting': parsed_search['lighting'],
        'imaging': 'wirefree',
        'slug': misc.slugify(os.path.basename(output_path)[:-5]),
    }

    # check if the entry already exists, if so, update it, otherwise, create it
    update_url = '/'.join((paths.bondjango_url, 'analyzed_data', entry_data['slug'], ''))
    output_entry = bd.update_entry(update_url, entry_data)
    if output_entry.status_code == 404:
        # build the url for creating an entry
        create_url = '/'.join((paths.bondjango_url, 'analyzed_data', ''))
        output_entry = bd.create_entry(create_url, entry_data)

    print('The output status was %i, reason %s' %
          (output_entry.status_code, output_entry.reason))
    if output_entry.status_code in [500, 400]:
        print(entry_data)

# Load aggregate files from all mice to make plots

In [None]:
# get the search string
search_string = "analysis_type:agg_tc, result:control, lighting:dark"
parsed_search = fdh.parse_search_string(search_string)

# The search string does not constrain the rig. Non-repeat aggregates are made with rig 'ALL'
# (see the aggregation cell), so fall back to that instead of raising a KeyError.
# NOTE(review): assumes parse_search_string only returns the keys present in the string — confirm.
search_rig = parsed_search.get('rig', 'ALL')
save_suffix = "_".join((parsed_search['result'], search_rig, parsed_search['lighting']))

figure_save_path = os.path.join(paths.wf_figures_path, save_suffix)
os.makedirs(figure_save_path, exist_ok=True)

if parsed_search['result'] == 'repeat':
    # repeat analyses must name the rig explicitly
    if parsed_search['rig'] == 'VWheelWF':
        session_types = ['fixed0', 'fixed1']
    else:
        session_types = ['free0', 'free1']
    session_shorthand = session_types
else:
    session_types = ['VWheelWF', 'VTuningWF']
    session_shorthand = ['fixed', 'free']

# get the paths from the database
file_infos = bd.query_database("analyzed_data", search_string)
input_paths = np.array([el['analysis_path'] for el in file_infos if ('agg' in el['slug'])])
print(np.sort(input_paths))

data_list = []
for file in input_paths:
    # NOTE(review): the mouse ID is taken from '_'-separated tokens 10-12 of the file name,
    # which depends on the aggregate naming scheme — confirm it matches these files.
    mouse = '_'.join(os.path.basename(file).split('_')[10:13])
    data_dict = {}
    with pd.HDFStore(file, 'r') as tc:
        for key in tc.keys():
            label = "_".join(key.split('/')[1:])
            data = tc[key]
            data['mouse'] = mouse
            data_dict[label] = data
    data_list.append(data_dict)

if not data_list:
    raise ValueError(f"No aggregate files found for search: {search_string}")

# Aggregate it all (union of table names, so a table missing from one mouse is not fatal)
agg_dict = {}
for key in dict.fromkeys(k for d in data_list for k in d):
    df = pd.concat([d[key] for d in data_list if key in d]).reset_index(drop=True)
    agg_dict[key] = df

# Cell matches

In [None]:
# Histogram of matched-cell counts per recording date across all mice. A mouse without any
# matches contributes a single placeholder value of -5 so it still appears in the plot.
per_mouse_matches = []
for mouse_tables in data_list:
    counts = mouse_tables['cell_matches'].groupby('date').size().to_numpy()
    per_mouse_matches.append(counts if len(counts) != 0 else -5 * np.ones(1))
match_nums = np.concatenate(per_mouse_matches)
frequencies, edges = np.histogram(match_nums, 40)

print('Values: %s, Edges: %s' % (frequencies.shape[0], edges.shape[0]))
cell_match_hist = hv.Histogram((edges, frequencies)).opts(xlabel='# Matched Cells', ylabel='Freq.')

save_path = os.path.join(figure_save_path, f"num_cells_matched_{save_suffix}.png")
cell_match_hist = fp.save_figure(cell_match_hist, save_path=save_path, fig_width=8, dpi=800, fontsize='screen', target='both', display_factor=0.2)

In [None]:
# Inspect the per-date match fractions. They are computed in the next cell, so on a fresh
# kernel (Restart & Run All) a bare reference raised NameError; show None until they exist.
globals().get('match_frac0')

In [None]:
# Fraction of each session's cells that were matched across the two sessions, per mouse and date.
# Counts are aligned on 'date' rather than by array position, so a date missing from one
# table cannot pair the match count of one day with the cell count of another (or crash on
# mismatched lengths).
props_key0 = f'{session_types[0]}_all_cells_norm_spikes_viewed_props'
props_key1 = f'{session_types[1]}_all_cells_norm_spikes_viewed_props'

num_cells = np.zeros((len(data_list), 2))
num_cells0 = []
num_cells1 = []
match_nums = []
match_frac0 = []
match_frac1 = []

for d in data_list:
    matches_per_date = d['cell_matches'].groupby('date').size()
    cells_per_date0 = d[props_key0].groupby('date').size()
    cells_per_date1 = d[props_key1].groupby('date').size()

    num_cells0.append(cells_per_date0.to_numpy())
    num_cells1.append(cells_per_date1.to_numpy())

    # only dates with matches and with cells in both sessions can be compared
    common_dates = matches_per_date.index.intersection(cells_per_date0.index).intersection(cells_per_date1.index)
    matches = matches_per_date.loc[common_dates]
    match_nums.append(matches.to_numpy())
    match_frac0.append((matches / cells_per_date0.loc[common_dates]).to_numpy())
    match_frac1.append((matches / cells_per_date1.loc[common_dates]).to_numpy())

match_nums = np.concatenate(match_nums) if match_nums else np.array([])
match_frac0 = np.concatenate(match_frac0) if match_frac0 else np.array([])
match_frac1 = np.concatenate(match_frac1) if match_frac1 else np.array([])

scatter = hv.Points((match_frac0, match_frac1))
scatter.opts(xlim=(0, 1), xlabel=f'{session_shorthand[0].title()} Frac. Match',
             ylim=(0, 1), ylabel=f'{session_shorthand[1].title()} Frac. Match',
             size=10, width=400, height=400, tools=['hover'])
# identity line (equal match fraction in both sessions)
line = hv.Curve([(0, 0), (1, 1)]).opts(color='black', alpha=0.5)
frac_cell_match_scatter = line * scatter

save_path = os.path.join(figure_save_path, f"frac_cells_matched_{save_suffix}.png")
frac_cell_match_scatter = fp.save_figure(frac_cell_match_scatter, save_path=save_path, fig_width=10, dpi=800, fontsize='screen', target='both', display_factor=0.2)

# Fraction postural or visual responsive

In [None]:
# Within-animal fraction kinematic or dir/ori responsive, first session type.
# NOTE(review): for non-repeat results session_types[0] is 'VWheelWF' (head-fixed), yet the
# figure is stored as `violinplot_free`; the variable name looks swapped — confirm.
ds_name = f'{session_types[0]}_summary_stats'
ds = agg_dict[ds_name]
kinem_cols = list(ds.columns[4:-8])
vis_cols = list(ds.columns[-8:-2])
# all kinematic columns plus vis_cols[1:3] and vis_cols[4:]
kinem_vis_cols = kinem_cols + vis_cols[1:3] + vis_cols[4:]

head_fixed = any(tag in ds_name for tag in ('fixed', 'VWheelWF'))
kinem_labels = processing_parameters.wf_fixed_kinem_cols if head_fixed else processing_parameters.wf_free_kinem_cols
rename_dict = dict(zip(kinem_vis_cols, kinem_labels + processing_parameters.wf_vis_cols))

frac_vis_resp = ds.loc[ds['old_index'] == 'all_cells'].rename(columns=rename_dict)

plot_columns = list(rename_dict.values())
violinplot_free = frac_vis_resp[plot_columns].hvplot.violin(legend=False)
violinplot_free.opts(xlabel='', ylabel='Significant Frac.', ylim=(-0.05, 1.05), xrotation=45, width=1500, height=1000)

save_path = os.path.join(figure_save_path, f"sig_frac_kinem_vis_{save_suffix}_{session_types[0]}.png")
violinplot_free = fp.save_figure(violinplot_free, save_path=save_path, fig_width=15, dpi=800, fontsize='screen', target='both', display_factor=0.1)

In [None]:
# Within-animal fraction kinematic or dir/ori responsive, second session type.
# NOTE(review): for non-repeat results session_types[1] is 'VTuningWF' (freely moving), yet the
# figure is stored as `violinplot_fixed`; the variable name looks swapped — confirm.
ds_name = f'{session_types[1]}_summary_stats'
ds = agg_dict[ds_name]
kinem_cols = list(ds.columns[4:-8])
vis_cols = list(ds.columns[-8:-2])
# all kinematic columns plus vis_cols[1:3] and vis_cols[4:]
kinem_vis_cols = kinem_cols + vis_cols[1:3] + vis_cols[4:]

head_fixed = any(tag in ds_name for tag in ('fixed', 'VWheelWF'))
kinem_labels = processing_parameters.wf_fixed_kinem_cols if head_fixed else processing_parameters.wf_free_kinem_cols
rename_dict = dict(zip(kinem_vis_cols, kinem_labels + processing_parameters.wf_vis_cols))

frac_vis_resp = ds.loc[ds['old_index'] == 'all_cells'].rename(columns=rename_dict)

plot_columns = list(rename_dict.values())
violinplot_fixed = frac_vis_resp[plot_columns].hvplot.violin(legend=False, violin_fill_color='red')
violinplot_fixed.opts(xlabel='', ylabel='Significant Frac.', ylim=(-0.05, 1.05), xrotation=45, width=1500, height=1000)

save_path = os.path.join(figure_save_path, f"sig_frac_kinem_vis_{save_suffix}_{session_types[1]}.png")
violinplot_fixed = fp.save_figure(violinplot_fixed, save_path=save_path, fig_width=15, dpi=800, fontsize='screen', target='both', display_factor=0.1)

In [None]:
def get_frac_tuned(df):
    """Fractions of cells tuned to kinematic (postural) and/or visual variables.

    Columns are selected by position: ``df.columns[1:-6]`` are the kinematic columns and
    ``df.columns[-6:-2]`` the visual columns. A cell counts as tuned to a modality when its
    row-sum over that modality's columns is > 0, and as untuned when the sum is == 0.
    NOTE(review): assumes these columns hold non-negative significance flags — confirm.

    ``df`` is not modified, so repeated calls on the same frame give the same result
    (adding helper columns would shift the positional slices above).

    Parameters
    ----------
    df : pandas.DataFrame
        One row per cell.

    Returns
    -------
    tuple of float
        (frac_only_kinem, frac_only_vis, frac_vis_tuned, frac_kinem_tuned, frac_mix_tuned);
        all NaN when ``df`` has no rows.
    """
    kinem_cols = list(df.columns[1:-6])
    vis_cols = list(df.columns[-6:-2])

    n_cells = df.shape[0]
    if n_cells == 0:
        # avoid ZeroDivisionError for an empty group
        return (np.nan,) * 5

    kinem_sum = df[kinem_cols].sum(axis=1)
    vis_sum = df[vis_cols].sum(axis=1)

    frac_vis_tuned = (vis_sum > 0).sum() / n_cells
    frac_kinem_tuned = (kinem_sum > 0).sum() / n_cells

    frac_only_kinem = ((vis_sum == 0) & (kinem_sum > 0)).sum() / n_cells
    frac_only_vis = ((kinem_sum == 0) & (vis_sum > 0)).sum() / n_cells
    frac_mix_tuned = ((vis_sum > 0) & (kinem_sum > 0)).sum() / n_cells

    return frac_only_kinem, frac_only_vis, frac_vis_tuned, frac_kinem_tuned, frac_mix_tuned

In [None]:
# Per-mouse fractions of cells tuned to visual and/or postural (kinematic) variables for each
# session type, reshaped to long format for plotting and statistics.
plt_list = []
df_list = []
multimodal_tables = [agg_dict[f'{session_types[0]}_multimodal_tuned'], agg_dict[f'{session_types[1]}_multimodal_tuned']]
for ds, exp_type in zip(multimodal_tables, session_types):
    # get_frac_tuned returns (only_kinem, only_vis, vis_tuned, kinem_tuned, mix_tuned)
    per_mouse = [get_frac_tuned(group) for _, group in ds.groupby('mouse')]
    df = pd.DataFrame({
        'group': [exp_type] * len(per_mouse),
        ' Visual Only': [fracs[1] for fracs in per_mouse],
        ' Postural Only': [fracs[0] for fracs in per_mouse],
        ' Visual': [fracs[2] for fracs in per_mouse],
        ' Postural': [fracs[3] for fracs in per_mouse],
        'Multimodal': [fracs[4] for fracs in per_mouse],
    })
    df_list.append(df)
    plt_list.append(df.hvplot.box(legend=False))

comp_tuning = pd.concat(df_list, axis=0)
comp_tuning = pd.melt(comp_tuning, id_vars=['group'], value_vars=[' Visual Only', ' Postural Only', ' Visual', ' Postural', 'Multimodal'])

In [None]:
# Wilcoxon rank-sum (two-sided Mann-Whitney U) per tuning category between the two session
# types, followed by a Bonferroni correction across categories
comp_tuning_rank_sum = (comp_tuning
                        .groupby(['variable', 'group'], group_keys=True)
                        .value.agg(list)
                        .unstack(level=1))
first_type, second_type = session_types[0], session_types[1]
comp_tuning_stats = comp_tuning_rank_sum.apply(
    lambda row: st.mannwhitneyu(row[first_type], row[second_type], alternative='two-sided'),
    axis=1)
p_vals = [test_result[1] for test_result in comp_tuning_stats.to_numpy()]
multipletests(p_vals, alpha=0.05, method='bonferroni')

In [None]:
# Lists of per-mouse fractions for each tuning variable (rows), one column per session type
comp_tuning_rank_sum

In [None]:
# Violin plot of per-mouse tuned fractions, split by session type
violin_opts = dict(
    xlabel='', ylabel='Significant Frac.',
    width=1000, height=500,
    # ylim=(-0.05, 1),
    violin_color=hv.dim('group'),
    cmap=[fp.hv_blue_hex, 'red'],
    show_legend=False, legend_position='right',
    fontsize={'legend': 10},
)
violinplot = hv.Violin(comp_tuning, ['variable', 'group'], 'value').opts(**violin_opts)
save_path = os.path.join(figure_save_path, f"frac_multimodal_{save_suffix}.png")
violinplot = fp.save_figure(violinplot, save_path=save_path, fig_width=10, dpi=800, fontsize='screen', target='both', display_factor=0.3)

In [None]:
# Seaborn version of the multimodal-tuning comparison. Draw on an explicit figure so the plot
# never lands on whatever axes happen to be current in pyplot's global state.
fig, ax = plt.subplots()
sns.violinplot(data=comp_tuning, x="variable", y="value", hue="group", ax=ax)
ax.spines[['right', 'top']].set_visible(False)
ax.set_ylabel('Significant Frac.')
ax.set_xlabel('')
ax.legend(title='', loc='lower right')
fig.savefig(os.path.join(figure_save_path, f'frac_multimodal_alt_{save_suffix}.png'), dpi=800, format='png')

# Bootstrapped tuning shifts

In [None]:
def plot_compare_matched_vis_tuning(ds_list, cell, tuning_kind='direction', exp_name='', error='std', norm=True, polar=True, axes=None):
    """Plot one cell's tuning for several datasets, one row (subfigure) per dataset.

    Parameters
    ----------
    ds_list : sequence
        Datasets forwarded to ``fp.plot_tuning_with_stats``, one per row.
    cell : hashable
        Cell identifier, forwarded to ``fp.plot_tuning_with_stats`` and used in the title.
    tuning_kind : str
        Tuning kind forwarded unchanged to ``fp.plot_tuning_with_stats``.
    exp_name : sequence of str
        One key into ``processing_parameters.wf_label_dictionary`` per dataset, used as row
        titles (only needed when ``axes`` is None).
    error, norm, polar :
        Forwarded to ``fp.plot_tuning_with_stats``; ``polar`` also selects a polar projection
        for the tuning-curve axes created here.
    axes : sequence of arrays, optional
        Per-row ``[tuning, response, error]`` axes. When None a new figure is created.

    Returns
    -------
    (fig, axes)
        The new figure (None when ``axes`` was supplied) and the per-row axes.
    """
    fig = None
    if axes is None:
        axes = []
        n_rows = len(ds_list)
        # height scales with the number of rows (it previously depended on the global data_list)
        fig = plt.figure(layout='constrained', figsize=(18*cm, 8*n_rows*cm))
        fig.suptitle(f"Cell {cell}", fontsize='x-large')
        # squeeze=False keeps an array even for a single dataset
        subfigs = fig.subfigures(nrows=n_rows, ncols=1, hspace=0.07, squeeze=False).ravel()

        for i, subfig in enumerate(subfigs):
            subfig.suptitle(processing_parameters.wf_label_dictionary[exp_name[i]].title())

            if polar:
                ax1 = subfig.add_subplot(121, projection="polar")  # direction tuning
            else:
                ax1 = subfig.add_subplot(121)  # tuning

            ax2 = subfig.add_subplot(222)  # resp
            ax3 = subfig.add_subplot(224)  # error
            axes.append(np.array([ax1, ax2, ax3]))

    for sub_ax, ds in zip(axes, ds_list):
        fp.plot_tuning_with_stats(ds, cell, tuning_kind=tuning_kind, error=error, norm=norm, polar=polar, axes=sub_ax)

    return fig, axes

def plot_compare_all_vis_tuning(ds_list, cell, tuning_kind=('direction', 'orientation'), exp_name='', error='std', norm=True, polar=True, axes=None):
    """Plot one cell's tuning for every dataset (rows) and every tuning kind (columns).

    Parameters
    ----------
    ds_list : sequence
        Datasets forwarded to ``fp.plot_tuning_with_stats``, one per row.
    cell : hashable
        Cell identifier, forwarded to ``fp.plot_tuning_with_stats`` and used in the title.
    tuning_kind : sequence of str
        Tuning kinds, one per column (previously ignored and hard-coded).
    exp_name : sequence of str
        One key into ``processing_parameters.wf_label_dictionary`` per dataset, used as the
        titles of that dataset's row (only needed when ``axes`` is None).
    error, norm, polar :
        Forwarded to ``fp.plot_tuning_with_stats``; ``polar`` also selects a polar projection
        for the tuning-curve axes created here.
    axes : sequence of arrays, optional
        ``[tuning, response, error]`` axes in row-major order (dataset, then tuning kind).
        When None a new figure is created.

    Returns
    -------
    (fig, axes)
        The new figure (None when ``axes`` was supplied) and the panel axes.
    """
    tuning_kind = list(tuning_kind)
    n_rows, n_cols = len(ds_list), len(tuning_kind)
    fig = None
    if axes is None:
        axes = []
        # height scales with the number of rows (it previously depended on the global data_list)
        fig = plt.figure(layout='constrained', figsize=(36*cm, 8*n_rows*cm))
        fig.suptitle(f"Cell {cell}", fontsize='x-large')
        subfigs = fig.subfigures(nrows=n_rows, ncols=n_cols, hspace=0.07, squeeze=False)

        for row, row_subfigs in enumerate(subfigs):
            for subfig in row_subfigs:
                subfig.suptitle(processing_parameters.wf_label_dictionary[exp_name[row]].title())

                if polar:
                    ax1 = subfig.add_subplot(121, projection="polar")  # tuning
                else:
                    ax1 = subfig.add_subplot(121)  # tuning

                ax2 = subfig.add_subplot(222)  # resp
                ax3 = subfig.add_subplot(224)  # error
                axes.append(np.array([ax1, ax2, ax3]))

    # Every dataset is drawn with every tuning kind, in the same row-major order as the panels
    # (zipping datasets directly with panels only filled the first row).
    for sub_ax, (ds, kind) in zip(axes, itertools.product(ds_list, tuning_kind)):
        fp.plot_tuning_with_stats(ds, cell, tuning_kind=kind, error=error, norm=norm, polar=polar, axes=sub_ax)

    return fig, axes

In [None]:
# Columns available in the first session type's matched-cell property table
agg_dict[f'{session_types[0]}_matched_norm_spikes_viewed_props'].columns

In [None]:
# Keys of the matched-cell property tables (all frames and "still" frames) for both sessions
matched_suffixes = ('_matched_norm_spikes_viewed_props', '_matched_norm_spikes_viewed_still_props')
fixed_matched = [session_types[0] + suffix for suffix in matched_suffixes]
free_matched = [session_types[1] + suffix for suffix in matched_suffixes]
# the non-"still" matched table of each session
datasets = [fixed_matched[0], free_matched[0]]

In [None]:
# Orientation-selective subsets of the first session's matched table.
# NOTE(review): strict uses osi > 0.75 while super-strict uses osi >= 0.75 — confirm intended.
ref_ds = agg_dict[datasets[0]]
high_osi = (ref_ds.osi > 0.75).to_numpy()
strict_index = ref_ds.index[high_osi]
super_strict_mask = ((ref_ds.osi >= 0.75) & (ref_ds.bootstrap_p_osi >= 0.95)).to_numpy()
super_strict_index = ref_ds.index[super_strict_mask]

In [None]:
# Preferred-orientation (PO) shift between sessions for the strictly orientation-selective cells.
# Rows of the two matched tables are paired by position; `strict_index` holds labels of the first
# table, which equal positions because the loading cell uses reset_index(drop=True).
# NOTE(review): assumes row i of both matched tables describes the same cell — confirm.
MAX_TUNED_CI_WIDTH = 20    # deg; a wider 95% CI on the bootstrapped first-session PO => "not tuned"

tuning_shifts = []

for (idxRow, cell_fixed), (_, cell_free) in zip(agg_dict[datasets[0]].iloc[strict_index, :].iterrows(),
                                                agg_dict[datasets[1]].iloc[strict_index, :].iterrows()):

    resultant_fixed = fk.wrap(cell_fixed.bootstrap_resultant_ori[:, -1], bound=180)
    resultant_fixed = resultant_fixed[~np.isnan(resultant_fixed)]
    # ci_fixed = st.t.interval(alpha=0.95, df=len(resultant_fixed)-1, loc=np.mean(resultant_fixed), scale=st.sem(resultant_fixed))
    ci_fixed = st.norm.interval(confidence=0.95, loc=np.mean(resultant_fixed), scale=st.sem(resultant_fixed))
    del_ci_fixed = ci_fixed[-1] - ci_fixed[0]

    resultant_free = fk.wrap(cell_free.bootstrap_resultant_ori[:, -1], bound=180)
    resultant_free = resultant_free[~np.isnan(resultant_free)]
    # ci_free = st.t.interval(alpha=0.95, df=len(resultant_free)-1, loc=np.mean(resultant_free), scale=st.sem(resultant_free))
    ci_free = st.norm.interval(confidence=0.95, loc=np.mean(resultant_free), scale=st.sem(resultant_free))
    del_ci_free = ci_free[-1] - ci_free[0]

    # Reset every per-cell output so an untuned cell never reuses the previous cell's values
    po_fixed = po_free = del_po = del_po_wrapped = np.nan
    sig_shift = 0

    # Written as "<= limit" so a NaN CI width (no valid bootstrap samples) counts as not
    # orientation tuned instead of falling through to the tuned branch.
    if del_ci_fixed <= MAX_TUNED_CI_WIDTH:
        # Get delta preferred orientation
        po_fixed = fk.wrap(cell_fixed.resultant_ori[-1], bound=180)
        po_free = fk.wrap(cell_free.resultant_ori[-1], bound=180)
        if np.isnan(po_fixed) or np.isnan(po_free):
            po_fixed = cell_fixed.pref_ori
            po_free = cell_free.pref_ori
        del_po = po_free - po_fixed
        del_po_wrapped = fk.wrap(del_po, bound=180)

        # Significant when each session's PO lies outside the other session's CI. Use `not`,
        # not `~`: on a plain Python bool `~` yields -1/-2, which are both truthy.
        if not (ci_fixed[0] <= po_free <= ci_fixed[1]) and not (ci_free[0] <= po_fixed <= ci_free[1]):
            sig_shift = 1

    tuning_shifts.append([idxRow, po_fixed, po_free, del_po, del_po_wrapped, sig_shift, cell_fixed.mouse, cell_fixed.date])

tuning_shifts = pd.DataFrame(data=tuning_shifts, columns=['', f'po_{session_shorthand[0]}', f'po_{session_shorthand[1]}', 'delta_po', 'delta_po_wrapped', 'is_sig_po', 'mouse', 'date'])
tuning_shifts = tuning_shifts.set_index(tuning_shifts.columns[0])

In [None]:
# Preferred orientation in the first vs. second session for each analysed cell
po_columns = [f'po_{session_shorthand[0]}', f'po_{session_shorthand[1]}']
scatter = hv.Points(tuning_shifts[po_columns].to_numpy())
scatter.opts(xlim=(-1, 181), xlabel=f'{session_shorthand[0].title()} Pref. Ori.', xticks=[0, 45, 90, 135, 180],
             ylim=(-1, 181), ylabel=f'{session_shorthand[1].title()} Pref. Ori.', yticks=[0, 45, 90, 135, 180],
             size=10, width=400, height=400)
# identity line over the full 0-180 deg range (np.arange(0, 180) stopped at 179)
line = hv.Curve([(0, 0), (180, 180)]).opts(color='black', alpha=0.5)
shift_po_plot = line * scatter
save_path = os.path.join(figure_save_path, f"shift_po_{save_suffix}.png")
shift_po_plot = fp.save_figure(shift_po_plot, save_path=save_path, fig_width=10, dpi=800, fontsize='screen', target='both', display_factor=0.2)

In [None]:
# Orientation responsivity of matched cells in both sessions; the red line marks 0.25.
# Drawn on an explicit figure rather than pyplot's current axes.
fig, ax = plt.subplots()
bins = np.arange(0, 1.01, 0.05)
ax.hist([agg_dict[datasets[0]].responsivity_ori, agg_dict[datasets[1]].responsivity_ori], bins=bins, edgecolor="black")
ax.legend(session_shorthand)
ax.axvline(0.25, c='r')
ax.set(title='responsivity', xlabel='responsivity_ori', ylabel='Count');

In [None]:
# Number of analysed cells per mouse in tuning_shifts (count of non-NaN is_sig_po entries)
tuning_shifts.groupby('mouse').apply(lambda x: x.is_sig_po.count())

In [None]:
# Fraction of analysed cells with a significant PO shift: overall (printed) and per mouse (displayed)
overall_sig_po_shift_frac = tuning_shifts.is_sig_po.sum() / tuning_shifts.is_sig_po.count() 
print(overall_sig_po_shift_frac)
tuning_shifts.groupby('mouse').apply(lambda x: x.is_sig_po.sum() / x.is_sig_po.count())

In [None]:
# Histogram (36 bins) of the unwrapped PO change for cells with a significant shift
sig_delta_po = tuning_shifts.loc[tuning_shifts.is_sig_po == 1, 'delta_po'].to_numpy()
frequencies, edges = np.histogram(sig_delta_po, bins=36)
hv.Histogram((edges, frequencies)).opts(width=600)

In [None]:
# Per-mouse histograms (36 bins) of the PO change for significantly shifted cells
sig_shifts_by_mouse = tuning_shifts[tuning_shifts.is_sig_po == 1].groupby('mouse')
x = [np.histogram(mouse_rows.delta_po.to_numpy(), 36) for _, mouse_rows in sig_shifts_by_mouse]
layout = hv.Layout([hv.Histogram((bin_edges, counts)).opts(width=400) for counts, bin_edges in x]).opts(shared_axes=False)
layout

In [None]:
%matplotlib inline
# Interactive viewer of direction/orientation tuning for cells with a significant PO shift,
# using the "still" matched tables. NOTE: this rebinds `datasets`, which earlier cells also use.
datasets = [f'{s_type}_matched_norm_spikes_viewed_still_props' for s_type in (session_types[0], session_types[1])]

sig_shift_cells = tuning_shifts[tuning_shifts.is_sig_po == 1].index
matched_vis_fig = interactive(
    plot_compare_all_vis_tuning,
    ds_list=widgets.fixed([agg_dict[table] for table in datasets]),
    cell=sig_shift_cells,
    tuning_kind=widgets.fixed(['direction', 'orientation']),
    exp_name=widgets.fixed(session_types),
    error=['std', 'sem'],
    polar=[True, False],
    norm=widgets.fixed(True),
    axes=widgets.fixed(None),
)
matched_vis_fig

In [None]:
# Inspect the tuning-shift entry of one example cell.
# NOTE(review): 67 is a hard-coded cell label; .loc raises KeyError if that cell is not in
# tuning_shifts for the current search — confirm it still exists.
tuning_shifts.loc[67]

In [None]:
import pycircstat as circ

# DSI/OSI Distributions

In [None]:
# Matched-cell table keys for both sessions, and the single (non-"still") table used for DSI/OSI
matched_suffixes = ('_matched_norm_spikes_viewed_props', '_matched_norm_spikes_viewed_still_props')
free_matched = [session_types[1] + suffix for suffix in matched_suffixes]
fixed_matched = [session_types[0] + suffix for suffix in matched_suffixes]
matched_free_dir = free_matched[:1]
matched_fixed_dir = fixed_matched[:1]

In [None]:
# Look up the selected matched tables and name the selectivity columns of interest
ds_list_matched_free = [agg_dict[table_key] for table_key in matched_free_dir]
ds_list_matched_fixed = [agg_dict[table_key] for table_key in matched_fixed_dir]
cols = ['osi', 'dsi_abs']

## Matched Cells

In [None]:
# Per-mouse fraction of matched cells whose |OSI| or |DSI| exceeds 0.5, per session
plt_list = []
df_list = []
ds_list = [agg_dict[f'{s_type}_matched_norm_spikes_viewed_props'] for s_type in session_types]
for ds, exp_type in zip(ds_list, session_shorthand):
    # ori_tuned = (np.abs(group.osi) > 0.25) & (np.abs(group.dsi_nasal_temporal) <= 0.25)
    # dir_tuned = (np.abs(group.osi) <= 0.25) & (np.abs(group.dsi_nasal_temporal) > 0.25)
    frac_ori_tuned = (ds.osi.abs() > 0.5).groupby(ds.mouse).mean().to_list()
    frac_dir_tuned = (ds.dsi_abs.abs() > 0.5).groupby(ds.mouse).mean().to_list()

    df = pd.DataFrame({
        'group': [exp_type] * len(frac_ori_tuned),
        'Ori. Tuned': frac_ori_tuned,
        'Dir. Tuned': frac_dir_tuned,
    })
    df_list.append(df)
    plt_list.append(df.hvplot.box(legend=False))

matched_dsi_osi = pd.concat(df_list, axis=0)
matched_dsi_osi = pd.melt(matched_dsi_osi, id_vars=['group'], value_vars=['Ori. Tuned', 'Dir. Tuned'])

frac_violin_opts = dict(
    xlabel='', ylabel='Significant Frac.',
    width=1000, height=600,
    ylim=(0, 1.1),
    violin_color=hv.dim('group'),
    cmap=[fp.hv_blue_hex, fp.hv_orange_hex],
    # show_legend=True, legend_position='right',
    fontsize={'legend': 10},
)
violinplot_matched_dsi_osi = hv.Violin(matched_dsi_osi, ['variable', 'group'], 'value').opts(**frac_violin_opts)
save_path = os.path.join(figure_save_path, "matched_frac_ori_dir_tuned.png")
violinplot_matched_dsi_osi = fp.save_figure(violinplot_matched_dsi_osi, save_path=save_path, fig_width=16, dpi=800, fontsize='screen', target='screen', display_factor=0.1)

In [None]:
# Wilcoxon rank-sum (two-sided Mann-Whitney U) per tuning class between the two sessions,
# followed by a Bonferroni correction across classes
comp_tuning_rank_sum = (matched_dsi_osi
                        .groupby(['variable', 'group'], group_keys=True)
                        .value.agg(list)
                        .unstack(level=1))
first_label, second_label = session_shorthand[0], session_shorthand[1]
comp_tuning_stats = comp_tuning_rank_sum.apply(
    lambda row: st.mannwhitneyu(row[first_label], row[second_label], alternative='two-sided'),
    axis=1)
p_vals = [test_result[1] for test_result in comp_tuning_stats.to_numpy()]
print(comp_tuning_stats)
multipletests(p_vals, alpha=0.05, method='bonferroni')

In [None]:
# |OSI| and |DSI| of every matched cell, per session, in long format
plt_list = []
df_list = []
ds_list = [agg_dict[f'{s_type}_matched_norm_spikes_viewed_props'] for s_type in session_types]
for ds, exp_type in zip(ds_list, session_shorthand):
    df = pd.DataFrame({
        'OSI': ds.osi.abs().to_numpy(),
        'DSI': ds.dsi_abs.abs().to_numpy(),
        'group': [exp_type] * len(ds.osi),
    })
    df_list.append(df)
    # plt_list.append(df.hvplot.box(legend=False))

matched_selectivity = pd.concat(df_list, axis=0)
matched_selectivity = pd.melt(matched_selectivity, id_vars=['group'], value_vars=['OSI', 'DSI'])

selectivity_violin_opts = dict(
    xlabel='', ylabel='Selectivity',
    width=1000, height=600,
    ylim=(-0.1, 1.1),
    violin_color=hv.dim('group'),
    cmap=[fp.hv_blue_hex, fp.hv_orange_hex],
    # show_legend=True, legend_position='right',
    fontsize={'legend': 10},
)
violinplot_selectivity_matched = hv.Violin(matched_selectivity, ['variable', 'group'], 'value').opts(**selectivity_violin_opts)
save_path = os.path.join(figure_save_path, "matched_selectivity.png")
violinplot_selectivity_matched = fp.save_figure(violinplot_selectivity_matched, save_path=save_path, fig_width=16, dpi=800, fontsize='screen', target='screen', display_factor=0.1)

In [None]:
def replace_inf(x, perc=99):
    """Cap every value of `x` above its `perc`-th percentile at that percentile, in place.

    Despite the name, this is a generic upper-tail cap; it does not look for infinite
    values specifically.

    Parameters
    ----------
    x : numpy.ndarray
        Numeric array. It is modified in place.
    perc : float, optional
        Percentile (0-100) used as the cap. Defaults to 99, the previously hard-coded value.

    Returns
    -------
    numpy.ndarray
        The same array object `x`, after capping.
    """
    # compute the cap once; both the mask and the fill value use the un-capped data
    cap = np.percentile(x, perc)
    x[x > cap] = cap
    return x

# Matched cells: overlay the freely-moving vs. head-fixed OSI and DSI distributions.
# Values are clipped to [0, 1], and both groups are binned on the same 20 bins over [0, 1]
# (range=(0, 1)) so the overlaid bars line up; without a shared range np.histogram derives
# separate bin edges from each group's own min/max.
a = matched_selectivity[matched_selectivity.variable == 'OSI'].fillna(0).drop(columns='variable')
# one-row frame: column = group, row 0 = list of that group's values
a = a.groupby('group').agg(list).T.reset_index(drop=True)

frequencies_free_osi, edges_free_osi = np.histogram(np.clip(a['Freely Moving'][0], 0, 1), bins=20, range=(0, 1))
frequencies_fixed_osi, edges_fixed_osi = np.histogram(np.clip(a['Head Fixed'][0], 0, 1), bins=20, range=(0, 1))

cell_match_hist_osi = hv.Overlay([hv.Histogram((frequencies_free_osi, edges_free_osi), label='Freely Moving').opts(alpha=0.5, fill_color=fp.hv_blue_rgb),
                                  hv.Histogram((frequencies_fixed_osi, edges_fixed_osi), label='Head Fixed').opts(alpha=0.5, fill_color=fp.hv_orange_rgb)])

b = matched_selectivity[matched_selectivity.variable == 'DSI'].fillna(0).drop(columns='variable')
b = b.groupby('group').agg(list).T.reset_index(drop=True)

frequencies_free_dsi, edges_free_dsi = np.histogram(np.clip(b['Freely Moving'][0], 0, 1), bins=20, range=(0, 1))
frequencies_fixed_dsi, edges_fixed_dsi = np.histogram(np.clip(b['Head Fixed'][0], 0, 1), bins=20, range=(0, 1))

# DSI panel is unlabelled: the legend is shown once, on the OSI panel
cell_match_hist_dsi = hv.Overlay([hv.Histogram((frequencies_free_dsi, edges_free_dsi)).opts(alpha=0.5, fill_color=fp.hv_blue_rgb),
                                  hv.Histogram((frequencies_fixed_dsi, edges_fixed_dsi)).opts(alpha=0.5, fill_color=fp.hv_orange_rgb)])

layout_matched_still = cell_match_hist_dsi.opts(height=300, width=400, xlabel='DSI') + cell_match_hist_osi.opts(height=300, width=500, xlabel='OSI', ylabel='', legend_position='right', fontsize={'legend': 10})
layout_matched_still
# save_path = os.path.join(figure_save_path, "Fig4", "unmatched_dsi_osi_hist.png")
# layout_matched_still = fp.save_figure(layout_matched_still, save_path=save_path, fig_width=18, dpi=1000, fontsize='poster', target='both', display_factor=0.1)

## Unmatched Cells

In [None]:
# Unmatched cells: per mouse, the fraction of cells whose |OSI| / |DSI| exceeds 0.75,
# compared across session types with one violin per (index, session type).
plt_list = []
df_list = []
ds_list = [agg_dict[f'{s_type}_unmatched_norm_spikes_viewed_props'] for s_type in session_types]
for ds, exp_type in zip(ds_list, session_shorthand):
    by_mouse = ds.groupby('mouse')
    # mean of a boolean "strongly tuned" flag within a mouse = fraction of its cells tuned
    frac_ori_tuned = [(np.abs(mouse_cells.osi) > 0.75).mean() for _, mouse_cells in by_mouse]
    frac_dir_tuned = [(np.abs(mouse_cells.dsi_abs) > 0.75).mean() for _, mouse_cells in by_mouse]

    df = pd.DataFrame({
        'group': [exp_type] * len(frac_ori_tuned),
        'Ori. Tuned': frac_ori_tuned,
        'Dir. Tuned': frac_dir_tuned,
    })
    df_list.append(df)
    plt_list.append(df.hvplot.box(legend=False))

unmatched_dsi_osi = pd.melt(pd.concat(df_list, axis=0), id_vars=['group'], value_vars=['Ori. Tuned', 'Dir. Tuned'])

violinplot_dsi_osi_unmatched = hv.Violin(unmatched_dsi_osi, ['variable', 'group'], 'value').opts(
    xlabel='',
    ylabel='Significant Frac.',
    width=1000,
    height=400,
    ylim=(0, 1),
    violin_color=hv.dim('group'),
    cmap=[fp.hv_blue_hex, fp.hv_orange_hex],
    fontsize={'legend': 10},
)
save_path = os.path.join(figure_save_path, "unmatched_frac_ori_dir_tuned.png")
violinplot_dsi_osi_unmatched = fp.save_figure(
    violinplot_dsi_osi_unmatched,
    save_path=save_path,
    fig_width=16,
    dpi=1000,
    fontsize='screen',
    target='screen',
    display_factor=0.1,
)

In [None]:
# Two-sided Mann-Whitney U test (equivalent to the Wilcoxon rank-sum test) of the per-mouse
# tuned fractions, first vs. second session type, Bonferroni-corrected across variables.
# The group columns are taken from session_shorthand (the labels written into 'group'),
# matching the matched-cell test, instead of hard-coded names.
comp_tuning_rank_sum = unmatched_dsi_osi.groupby(['variable', 'group'], group_keys=True).value.agg(list).unstack(level=1)
comp_tuning_stats = comp_tuning_rank_sum.apply(
    lambda x: st.mannwhitneyu(x[session_shorthand[0]], x[session_shorthand[1]], alternative='two-sided'),
    axis=1,
)
p_vals = [test_result.pvalue for test_result in comp_tuning_stats]
print(comp_tuning_stats)
multipletests(p_vals, alpha=0.05, method='bonferroni')

In [None]:
# Unmatched cells: per-cell |OSI| and |DSI| for every session type, drawn as one violin per
# (selectivity index, session type). Uses the same blue/orange session colours as the
# matched-cell selectivity violin so the two figures are directly comparable.
plt_list = []
df_list = []
ds_list = [agg_dict[f'{s_type}_unmatched_norm_spikes_viewed_props'] for s_type in session_types]
for ds, exp_type in zip(ds_list, session_shorthand):
    df = pd.DataFrame()
    df['OSI'] = ds.osi.abs().to_numpy()
    df['DSI'] = ds.dsi_abs.abs().to_numpy()
    df['group'] = [exp_type] * len(ds.osi)
    df_list.append(df)

unmatched_selectivity = pd.concat(df_list, axis=0)
unmatched_selectivity = pd.melt(unmatched_selectivity, id_vars=['group'], value_vars=['OSI', 'DSI'])

violinplot_selectivity_unmatched = hv.Violin(unmatched_selectivity, ['variable', 'group'], 'value').opts(xlabel='', ylabel='Selectivity',
                                                                  width=1000, height=400,
                                                                  ylim=(-0.1, 1.1),
                                                                  violin_color=hv.dim('group'),
                                                                  cmap=[fp.hv_blue_hex, fp.hv_orange_hex],
                                                                  fontsize={'legend': 10}
                                                                 )
save_path = os.path.join(figure_save_path, "unmatched_selectivity.png")
violinplot_selectivity_unmatched = fp.save_figure(violinplot_selectivity_unmatched, save_path=save_path, fig_width=16, dpi=1000, fontsize='poster', target='screen', display_factor=0.1)

In [None]:
# Unmatched cells: overlay the freely-moving vs. head-fixed OSI and DSI distributions.
# As for the matched cells, values are clipped to [0, 1] and both groups share the same
# 20 bins over [0, 1] so the overlaid bars line up (np.histogram otherwise derives separate
# bin edges from each group's own min/max).
a = unmatched_selectivity[unmatched_selectivity.variable == 'OSI'].fillna(0).drop(columns='variable')
# one-row frame: column = group, row 0 = list of that group's values
a = a.groupby('group').agg(list).T.reset_index(drop=True)

frequencies_free_osi, edges_free_osi = np.histogram(np.clip(a['Freely Moving'][0], 0, 1), bins=20, range=(0, 1))
frequencies_fixed_osi, edges_fixed_osi = np.histogram(np.clip(a['Head Fixed'][0], 0, 1), bins=20, range=(0, 1))

unmatched_hist_osi = hv.Overlay([hv.Histogram((frequencies_free_osi, edges_free_osi), label='Freely Moving').opts(alpha=0.5, fill_color=fp.hv_blue_rgb),
                                  hv.Histogram((frequencies_fixed_osi, edges_fixed_osi), label='Head Fixed').opts(alpha=0.5, fill_color=fp.hv_orange_rgb)])

b = unmatched_selectivity[unmatched_selectivity.variable == 'DSI'].fillna(0).drop(columns='variable')
b = b.groupby('group').agg(list).T.reset_index(drop=True)

frequencies_free_dsi, edges_free_dsi = np.histogram(np.clip(b['Freely Moving'][0], 0, 1), bins=20, range=(0, 1))
frequencies_fixed_dsi, edges_fixed_dsi = np.histogram(np.clip(b['Head Fixed'][0], 0, 1), bins=20, range=(0, 1))

# DSI panel is unlabelled: the legend is shown once, on the OSI panel
unmatched_hist_dsi = hv.Overlay([hv.Histogram((frequencies_free_dsi, edges_free_dsi)).opts(alpha=0.5, fill_color=fp.hv_blue_rgb),
                                  hv.Histogram((frequencies_fixed_dsi, edges_fixed_dsi)).opts(alpha=0.5, fill_color=fp.hv_orange_rgb)])

layout_unmatched = unmatched_hist_dsi.opts(height=300, width=400, xlabel='DSI') + unmatched_hist_osi.opts(height=300, width=500, xlabel='OSI', ylabel='', legend_position='right', fontsize={'legend': 10})
layout_unmatched

## Compare matched vs. unmatched cells: selectivity distributions and fraction of strongly tuned cells

In [None]:
# Compare the selectivity (|OSI|, |DSI|) of matched vs. unmatched cells within each
# (variable, group) pair. The super_group label is attached to copies via .assign, so
# matched_selectivity / unmatched_selectivity are not mutated and the cell can be re-run safely.
all_selectivity = pd.concat([
    matched_selectivity.assign(super_group='matched'),
    unmatched_selectivity.assign(super_group='unmatched'),
]).fillna(0)
violinplot_selectivity = hv.Violin(all_selectivity, ['variable', 'group', 'super_group'], 'value').opts(xlabel='', ylabel='Selectivity',
                                                                  width=1000, height=200,
                                                                  ylim=(-0.05, 1.05),
                                                                  violin_color=hv.dim('super_group'),
                                                                  cmap='Category10',
                                                                  fontsize={'legend': 10}
                                                                 )

# Wilcoxon rank-sum test (matched vs. unmatched) per (variable, group), Bonferroni-corrected
rank_sum = all_selectivity.groupby(['variable', 'group', 'super_group'], group_keys=True).value.agg(list).unstack(level=-1)
stats = rank_sum.apply(lambda x: st.ranksums(x['matched'], x['unmatched'], alternative='two-sided'), axis=1)
p_vals = [test_result.pvalue for test_result in stats]
print(p_vals)
print(multipletests(p_vals, alpha=0.05, method='bonferroni'))

save_path = os.path.join(figure_save_path, "selectivity_dist_comp.png")
violinplot_selectivity = fp.save_figure(violinplot_selectivity, save_path=save_path, fig_width=30, dpi=1200, fontsize='screen', target='screen', display_factor=0.1)

In [None]:
# Summarise the combined long-format table (count / mean / quartiles per population,
# session type and index) instead of dumping every row.
all_selectivity.groupby(['super_group', 'group', 'variable']).value.describe()

In [None]:
# Matched vs. unmatched selectivity histograms, one panel per (variable, group) pair.
# Panels (and the file index i) follow the sorted (variable, group) order of the groupby.
# Both histograms in a panel share 20 bins over [0, 1] so the overlaid bars line up, and
# each is divided by its own peak count (so the y axis is a peak-normalised count, not a
# probability), which keeps populations of different size visually comparable.
selectivity_plot_list = []
for i, (_, pair_df) in enumerate(all_selectivity.groupby(['variable', 'group'])):
    save_path = os.path.join(figure_save_path, f"selectivity_dist_comp_histogram_{i}.png")
    is_first = i == 0

    # pick the two populations explicitly rather than relying on row order
    matched_vals = np.clip(pair_df.loc[pair_df.super_group == 'matched', 'value'].to_numpy(dtype=float), 0, 1)
    unmatched_vals = np.clip(pair_df.loc[pair_df.super_group == 'unmatched', 'value'].to_numpy(dtype=float), 0, 1)

    frequencies_matched, edges_matched = np.histogram(matched_vals, bins=20, range=(0, 1))
    frequencies_unmatched, edges_unmatched = np.histogram(unmatched_vals, bins=20, range=(0, 1))
    # peak-normalise; max(..., 1) avoids 0/0 if one population is empty for this pair
    frequencies_matched = frequencies_matched / max(frequencies_matched.max(), 1)
    frequencies_unmatched = frequencies_unmatched / max(frequencies_unmatched.max(), 1)

    # only the first panel keeps its y axis; the remaining panels are narrower
    panel_opts = dict(alpha=0.5, xlabel='', width=800 if is_first else 600, height=600)
    if not is_first:
        panel_opts['yaxis'] = None

    matched_hist = hv.Histogram((frequencies_matched, edges_matched), label='matched').opts(
        fill_color=fp.hv_mpi_green_rgb, ylabel='Norm. count', **panel_opts)
    unmatched_hist = hv.Histogram((frequencies_unmatched, edges_unmatched), label='unmatched').opts(
        fill_color=fp.hv_mpi_yellow_rgb, **panel_opts)

    hist_overlay = hv.Overlay([matched_hist, unmatched_hist])
    hist_overlay.opts(show_legend=False)
    hist_overlay = fp.save_figure(hist_overlay, save_path=save_path, fig_width=8 if is_first else 6,
                                  dpi=1200, fontsize='screen', target='screen', display_factor=0.1)

    selectivity_plot_list.append(hist_overlay)

selectivity_layout = hv.Layout(selectivity_plot_list)
selectivity_layout

In [None]:
# Compare the per-mouse fraction of strongly tuned cells between matched and unmatched cells
# within each (variable, group) pair. The super_group label is attached to copies via .assign,
# so matched_dsi_osi / unmatched_dsi_osi are not mutated and the cell can be re-run safely.
all_dsi_osi = pd.concat([
    matched_dsi_osi.assign(super_group='matched'),
    unmatched_dsi_osi.assign(super_group='unmatched'),
]).fillna(0)
violinplot = hv.Violin(all_dsi_osi, ['variable', 'group', 'super_group'], 'value').opts(xlabel='', ylabel='Significant Frac.',
                                                                  width=1000, height=200,
                                                                  ylim=(-0.05, 1),
                                                                  violin_color=hv.dim('super_group'),
                                                                  cmap=[fp.hv_mpi_green_hex, fp.hv_mpi_yellow_hex],
                                                                  fontsize={'legend': 10}
                                                                 )

# Wilcoxon rank-sum test (matched vs. unmatched) per (variable, group), Bonferroni-corrected
rank_sum = all_dsi_osi.groupby(['variable', 'group', 'super_group'], group_keys=True).value.agg(list).unstack(level=-1)
stats = rank_sum.apply(lambda x: st.ranksums(x['matched'], x['unmatched'], alternative='two-sided'), axis=1)
p_vals = [test_result.pvalue for test_result in stats]
print(p_vals)
print(multipletests(p_vals, alpha=0.05, method='bonferroni'))

save_path = os.path.join(figure_save_path, "sig_frac_selectivity_dsi_osi.png")
violinplot = fp.save_figure(violinplot, save_path=save_path, fig_width=30, dpi=1200, fontsize='poster', target='both', display_factor=0.1)

In [None]:
# Matched vs. unmatched histograms of the per-mouse tuned fractions, one panel per
# (variable, group) pair. The fractions lie in [0, 1], so both histograms share 20 bins over
# [0, 1] and the overlaid bars line up. Colours follow the matched = green /
# unmatched = yellow scheme used by the tuned-fraction violin plot.
dsi_osi_plot_list = []
for _, pair_df in all_dsi_osi.groupby(['variable', 'group']):
    # pick the two populations explicitly rather than relying on row order
    matched_vals = pair_df.loc[pair_df.super_group == 'matched', 'value'].to_numpy(dtype=float)
    unmatched_vals = pair_df.loc[pair_df.super_group == 'unmatched', 'value'].to_numpy(dtype=float)

    frequencies_matched, edges_matched = np.histogram(matched_vals, bins=20, range=(0, 1))
    frequencies_unmatched, edges_unmatched = np.histogram(unmatched_vals, bins=20, range=(0, 1))

    hist_overlay = hv.Overlay([
        hv.Histogram((frequencies_unmatched, edges_unmatched), label='unmatched').opts(alpha=0.5, fill_color=fp.hv_mpi_yellow_rgb),
        hv.Histogram((frequencies_matched, edges_matched), label='matched').opts(alpha=0.5, fill_color=fp.hv_mpi_green_rgb),
    ])

    dsi_osi_plot_list.append(hist_overlay)

dsi_osi_layout = hv.Layout(dsi_osi_plot_list)
dsi_osi_layout

# UMAP

In [None]:
from umap.umap_ import UMAP
from rastermap import Rastermap
import sklearn.preprocessing as preproc

In [None]:
# Collect one tuning feature per column for the cells in `cell_subset`: the 'Qual_index' of
# every kinematic variable, plus DSI and OSI (clipped to [0, 1]) from the visual properties.
cell_subset = '_matched'
kinem_label_list = processing_parameters.variable_list_free + processing_parameters.variable_list_fixed
labels = kinem_label_list + ['norm_spikes_viewed_props']
labels = ['_'.join((cell_subset, label)) for label in labels]
# number of '_'-separated pieces in the prefix; stripped again below to recover the bare name
n_prefix_parts = len(cell_subset.split('_'))

data_dict = {}
for label in labels:
    base_label = '_'.join(label.split('_')[n_prefix_parts:])
    agg_keys = [key for key in agg_dict.keys() if label in key]

    # NOTE(review): when several agg_dict keys contain `label` (e.g. one
    # '<session_type>_matched_norm_spikes_viewed_props' entry per session type), each key
    # overwrites the previous data_dict entry, so only the last one is kept. Confirm intended.
    for key in agg_keys:
        ds = agg_dict[key]
        if base_label in kinem_label_list:
            # bound to `qual_index`, not `tuning`, so the `functions_tuning as tuning`
            # module alias imported at the top of the notebook is not overwritten
            qual_index = ds['Qual_index']
            data_dict[base_label] = qual_index
        else:
            data_dict['dsi'] = np.clip(ds['dsi_abs'], 0, 1)
            data_dict['osi'] = np.clip(ds['osi'], 0, 1)


raw_tunings = pd.DataFrame.from_dict(data_dict)
raw_tunings = raw_tunings.fillna(0)

In [None]:
# Z-score every tuning feature (column) so all features contribute on the same scale
tunings = preproc.StandardScaler().fit_transform(X=raw_tunings.values)

In [None]:
# (number of cells, number of tuning features)
np.shape(tunings)

In [None]:
# Shuffle the cell (row) order once. This is a full permutation, not a subsample, and it
# uses a fixed seed so the UMAP input and the scatter draw order are reproducible across runs.
# shuffle_idx maps each row of tuning_subset back to its row in tunings / raw_tunings.
shuffle_rng = np.random.default_rng(42)
shuffle_idx = shuffle_rng.permutation(tunings.shape[0])
tuning_subset = tunings[shuffle_idx]

In [None]:
# perform umap on the fit cell tuning
reducer1 = UMAP(min_dist=0.1, n_neighbors=20)
embedded_data1 = reducer1.fit_transform(tuning_subset)

In [None]:
# One UMAP scatter per tuning feature, each coloured by that feature.
perc = 99    # colour values are clipped to the [100 - perc, perc] percentile band
predictor_columns = data_dict.keys()    #['dsi', 'osi', 'responsivity']
plot_list = []

for col_pos, predictor_column in enumerate(predictor_columns):
    # data_dict keys are unique, so each feature occupies exactly one column of tuning_subset;
    # a one-element index list keeps raw_labels 2-D
    label_idx = [col_pos]
    # NOTE(review): tuning_subset holds standardized (z-scored) values, so np.abs colours by
    # distance from the feature mean; cells far below the mean look as "high" as cells far
    # above it. Confirm this is intended rather than colouring by the raw tuning value.
    raw_labels = np.abs(tuning_subset[:, label_idx])

    # clip the upper tail first; the lower percentile is then taken on the upper-clipped values
    raw_labels = np.where(raw_labels > np.percentile(raw_labels, perc), np.percentile(raw_labels, perc), raw_labels)
    raw_labels = np.where(raw_labels < np.percentile(raw_labels, 100 - perc), np.percentile(raw_labels, 100 - perc), raw_labels)

    plot_data = np.column_stack((embedded_data1, raw_labels.reshape(-1, 1)))

    umap_plot = hv.Scatter(plot_data, kdims=['Dim 1'], vdims=['Dim 2', 'Parameter'])
    # umap_plot = hv.HexTiles(umap_data, kdims=['Dim 1', 'Dim 2'])
    umap_plot.opts(
        colorbar=False, color='Parameter', cmap='Spectral_r', alpha=1, xaxis=None, yaxis=None, tools=['hover'],
        width=300, height=300, size=2, title=processing_parameters.wf_label_dictionary[predictor_column],
    )

    save_name = os.path.join(figure_save_path, f"UMAP_{predictor_column}_{cell_subset}.png")
    # umap_plot = fp.save_figure(umap_plot, save_path=save_name, fig_width=6, dpi=1000, fontsize='screen', target='save', display_factor=0.1)
    plot_list.append(umap_plot)

In [None]:
# Arrange the per-feature UMAP panels in a grid, five per row
layout = hv.Layout(plot_list)
layout = layout.cols(5)
layout

In [None]:
# Colour values for a single feature, `target_field`, on the shuffled UMAP embedding.
# `perc` comes from the per-feature UMAP cell above.
predictor_columns = data_dict.keys()
target_field = 'dsi'

# one-element index list, so the slice of tuning_subset stays 2-D
label_idx = [idx for idx, el in enumerate(predictor_columns) if el == target_field]
raw_labels = np.abs(tuning_subset[:, label_idx])

# clip the upper tail first; the lower percentile is then taken on the upper-clipped values
raw_labels = np.where(raw_labels > np.percentile(raw_labels, perc), np.percentile(raw_labels, perc), raw_labels)
raw_labels = np.where(raw_labels < np.percentile(raw_labels, 100 - perc), np.percentile(raw_labels, 100 - perc), raw_labels)

plot_data = np.column_stack((embedded_data1, raw_labels.reshape(-1, 1)))

In [None]:
# umap plot of the fit cell tunings
umap_plot = hv.Scatter(plot_data, vdims=['Dim 2', 'Parameter'], kdims=['Dim 1'])
# umap_plot = hv.HexTiles(plot_data, vdims=['Dim 2', 'Parameter'], kdims=['Dim 1'])
umap_plot.opts(colorbar=True, color='Parameter', cmap='Spectral_r', tools=['hover'], alpha=1)
umap_plot.opts(width=500, height=500, size=5)

# assemble the file name
# save_name = os.path.join(save_path, '_'.join(('poster', 'UMAP')) + '.png')
# # save the figure
# fig = fp.save_figure(umap_plot, save_name, fig_width=15, dpi=1200, fontsize='poster', target='screen')

In [None]:
# plot the tunings from MINE
ticks = [(idx+0.5, processing_parameters.wf_label_dictionary[el]) for idx, el in enumerate(predictor_columns)]
plot_matrix = raw_tunings.dropna().to_numpy().copy().T
plot_matrix[plot_matrix<0.05] = 0
model = Rastermap(n_clusters=2, n_PCs=200)
model.fit(plot_matrix)
plot_matrix = plot_matrix[model.isort, :]
plot = hv.Raster(plot_matrix)
plot.opts(width=1000, height=600, cmap='RdBu_r', tools=['hover'], clim=(-1, 1), xticks=ticks, xrotation=45, xlabel='', ylabel='Cells', colorbar=True)

# assemble the file name
save_name = os.path.join(save_path, '_'.join(('Fig5', 'MINE_tunings')) + '.png')
# save the figure
fig = fp.save_figure(plot, save_name, fig_width=7, dpi=1200, fontsize='poster', target='screen')