In [None]:
%load_ext autoreload
%autoreload 2

# Imports
import numpy as np
import pandas as pd
import scipy.stats as st
import seaborn as sns
import holoviews as hv
import hvplot.pandas
import datashader as dshade
import ipywidgets as widgets
from ipywidgets import interact, interactive
from statsmodels.stats.multitest import multipletests

import os
import sys
import importlib
import warnings
from tqdm import tqdm
warnings.filterwarnings('ignore')

from bokeh.resources import INLINE
from bokeh.io import export_svgs, export_png
from bokeh.plotting import show
from holoviews import opts, dim
from holoviews.operation import histogram
from holoviews.operation.datashader import datashade, shade
hv.extension('bokeh')
import matplotlib.pyplot as plt

sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
import paths
import processing_parameters
import functions_bondjango as bd
import functions_plotting as fp
import functions_data_handling as fdh
import functions_kinematic as fk
import functions_misc as misc
from wirefree_experiment import WirefreeExperiment, DataContainer

# Re-import the plotting helpers so that edits to functions_plotting are picked up in this kernel
importlib.reload(fp)
# set up the figure theme
fp.set_theme()
plt.rcParams["font.family"] = "Arial"
# Centimetre-to-inch conversion factor; multiplied into the figsize values passed to plt.figure below
cm = 1./2.54
# Directory into which the figure-saving calls of this notebook write their output
figure_save_path = r"C:\Users\mmccann\Dropbox\bonhoeffer lab\thesis\figures\figure_media"

# Load aggregate files from all mice to make plots

In [None]:
# If saving file for the first time, set save to True
save = False
agg_dict = {}
search_string = "result:repeat, lighting:normal"


if save:
    # Build the aggregate from the per-mouse files, write it to disk and
    # register the resulting file in the database.
    search_string += ", analysis_type:agg_tc"
    parsed_search = fdh.parse_search_string(search_string)
    output_path = os.path.join(paths.analysis_path, f"AGG_{'_'.join(parsed_search.values())}.hdf5")

    # get the paths from the database
    file_infos = bd.query_database("analyzed_data", search_string)
    input_paths = np.array([el['analysis_path'] for el in file_infos if ('agg' in el['slug'])])

    data_list = []
    for file in tqdm(input_paths, desc="Loading files"):
        print(file)
        data_dict = {}
        # The mouse identifier is assembled from fields 10-12 of the underscore-separated file name
        mouse = '_'.join(os.path.basename(file).split('_')[10:13])

        with pd.HDFStore(file, 'r') as tc:

            # HDFStore.keys() returns paths with a leading '/', so the plain string
            # 'no_ROIs' can never be found in that list. Membership on the store itself
            # accepts the key with or without the leading slash.
            if 'no_ROIs' in tc:
                continue

            for key in tc.keys():
                # Drop the leading '/' and flatten any nested group path into one label
                label = "_".join(key.split('/')[1:])
                data = tc[key]
                if 'animal' in data.columns:
                    data = data.drop(columns='animal')

                if '08_31_2023_VWheelWF_fixed2' in data.columns:
                    data = data.drop(columns='08_31_2023_VWheelWF_fixed2')

                data['mouse'] = mouse
                data_dict[label] = data

            data_list.append(data_dict)

    if not data_list:
        raise ValueError(f"No usable aggregate files were found for the search '{search_string}'")

    # Aggregate it all: one concatenated table per dataset label, using the labels of the first file
    agg_dict = {}

    for key in data_list[0].keys():
        df = pd.concat([d[key] for d in data_list]).reset_index(drop=True)
        # `key` is passed by keyword; newer pandas versions deprecate passing it positionally
        df.to_hdf(output_path, key=key)
        agg_dict[key] = df

    # assemble the entry data
    entry_data = {
        'analysis_type': 'agg_all',
        'analysis_path': output_path,
        'date': '',
        'pic_path': '',
        'result': parsed_search['result'],
        'rig': parsed_search['rig'],
        'lighting': parsed_search['lighting'],
        'imaging': 'wirefree',
        'slug': misc.slugify(os.path.basename(output_path)[:-5]),
        }

    # check if the entry already exists, if so, update it, otherwise, create it
    update_url = '/'.join((paths.bondjango_url, 'analyzed_data', entry_data['slug'], ''))
    output_entry = bd.update_entry(update_url, entry_data)
    if output_entry.status_code == 404:
        # build the url for creating an entry
        create_url = '/'.join((paths.bondjango_url, 'analyzed_data', ''))
        output_entry = bd.create_entry(create_url, entry_data)

    print('The output status was %i, reason %s' % (output_entry.status_code, output_entry.reason))
    if output_entry.status_code in [500, 400]:
        print(entry_data)

else:
    # Load a previously saved aggregate file
    search_string += ", analysis_type:agg_all"
    parsed_search = fdh.parse_search_string(search_string)

    file_infos = bd.query_database("analyzed_data", search_string)
    input_paths = np.array([el['analysis_path'] for el in file_infos])
    print(np.sort(input_paths))

    if input_paths.size == 0:
        raise ValueError(f"No aggregate file matches the search '{search_string}'; run once with save = True")

    # Only the first returned file is read
    with pd.HDFStore(input_paths[0], 'r') as tc:
        for key in tc.keys():
            label = "_".join(key.split('/')[1:])
            data = tc[key]
            agg_dict[label] = data

# Pick the two session identifiers to compare and the short labels used for them in figures
if parsed_search['result'] != 'repeat':
    # The two rigs are compared against each other
    session_types = ['VWheelWF', 'VTuningWF']
    session_shorthand = ['fixed', 'free']
else:
    # Two repeated sessions; which pair depends on the rig named in the search
    session_types = ['fixed0', 'fixed1'] if parsed_search['rig'] == 'VWheelWF' else ['free0', 'free1']
    session_shorthand = session_types

# Suffix used in the names of saved files
save_suffix = "{}_{}_{}".format(parsed_search['result'], parsed_search['lighting'], parsed_search['rig'])

# Specify the path to the curated cell matches file
excel_file_path = os.path.join(r"C:\Users\mmccann\Desktop",
                               f"curated_cell_matches_{parsed_search['result']}_{parsed_search['lighting']}_{parsed_search['rig']}.xlsx")

# Loading is best-effort: a failure is reported and the notebook continues,
# in which case `curated_matches` is not (re)defined by this cell.
try:
    # Read all sheets into a dict of dataframes (sheet name -> dataframe)
    curated_matches_dict = pd.read_excel(excel_file_path, sheet_name=None)

    # Concatenate the dataframes into a single dataframe
    curated_matches = pd.concat(curated_matches_dict.values(), ignore_index=True)
except FileNotFoundError:
    print(f"Could not find the file {excel_file_path}")
except Exception as err:
    # Report the real cause (unreadable workbook, missing Excel engine, ...) instead of
    # mislabelling it as a missing file; KeyboardInterrupt is no longer swallowed.
    print(f"Could not load the file {excel_file_path}: {err!r}")

In [None]:
# Middle part of the agg_dict keys used below: f'{session}_{cell_kind}_{activity_dataset}'
cell_kind = 'all_cells'
# Last part of the agg_dict keys used below
activity_dataset = 'norm_spikes_viewed_props'

# Cell matches

In [None]:
# Number of matched cells per (mouse, day): non-null entries in the first session's column of the match table
match_nums = agg_dict['cell_matches'].groupby(['mouse', 'day']).apply(lambda x: x.loc[:, session_types[0]].count()).values

# Total number of cells per (mouse, day) in each of the two sessions
num0 = agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}'].groupby(['mouse', 'day']).apply(lambda x: len(x)).values
num1 = agg_dict[f'{session_types[1]}_{cell_kind}_{activity_dataset}'].groupby(['mouse', 'day']).apply(lambda x: len(x)).values

# Fraction of each session's cells that were matched.
# NOTE(review): the arrays are divided positionally (.values), which assumes that all three
# groupbys produce the same (mouse, day) groups in the same order - confirm.
match_frac0 = match_nums/num0
match_frac1 = match_nums/num1

# 20-bin histograms of the matched fractions
frequencies0, edges0 = np.histogram(match_frac0, 20)
frequencies1, edges1 = np.histogram(match_frac1, 20)

# print('Values: %s, Edges: %s' % (frequencies.shape[0], edges.shape[0]))
frac_cell_match_hist0 = hv.Histogram((edges0, frequencies0), label=session_shorthand[0]).opts(xlabel='Frac. Matched Cells', ylabel='Freq.', alpha=0.5)
frac_cell_match_hist1 = hv.Histogram((edges1, frequencies1), label=session_shorthand[1]).opts(xlabel='Frac. Matched Cells', ylabel='Freq.', alpha=0.5)

save_path = os.path.join(figure_save_path, f"frac_cells_matched_{parsed_search['result']}_{parsed_search['rig']}.png")

# Overlay both sessions in one plot
frac_cell_match_overlay = frac_cell_match_hist0 * frac_cell_match_hist1
# NOTE(review): called with target='screen'; presumably this only displays the figure and does not write save_path - confirm in fp.save_figure
frac_cell_match_overlay = fp.save_figure(frac_cell_match_overlay, save_path=save_path, fig_width=10, dpi=800, fontsize='screen', target='screen', display_factor=0.2)

In [None]:
# One point per (mouse, day): matched fraction in the first session vs. the second session
scatter = hv.Points((match_frac0, match_frac1))
scatter.opts(xlim=(0, 1), xlabel=f'Frac. Match {session_shorthand[0].title()}',
             ylim=(0, 1), ylabel=f'Frac. Match {session_shorthand[1].title()}',
             size=10, color='blue', width=500, height=500)
# Identity line from (0, 0) to (1, 1) as a visual reference
line = hv.Curve((np.linspace(0, 1, 101), np.linspace(0, 1, 101))).opts(color='gray')
frac_cell_match_scatter = line * scatter
# This bare expression is not the last statement of the cell, so it produces no output
frac_cell_match_scatter

save_path = os.path.join(figure_save_path, f"frac_cells_matched_{save_suffix}.png")
frac_cell_match_scatter = fp.save_figure(frac_cell_match_scatter, save_path=save_path, fig_width=5, dpi=500, fontsize='paper', target='both', display_factor=0.4)

# Fraction visually responsive

In [None]:
def violin_swarm(ds, save_path, backend='hvplot', save=False, cmap='blue', 
                 xlabel='', ylabel='',
                 width=1500, height=1000, font_size='screen', dpi=800):
    """Draw one violin per column of `ds`, optionally with the raw points on top.

    Every column of `ds` is renamed through
    `processing_parameters.wf_label_dictionary_wo_units` before plotting, so each
    column name must be a key of that dictionary.

    Parameters
    ----------
    ds : pandas.DataFrame
        One column per violin.
    save_path : str
        Target file for the figure.
    backend : {'hvplot', 'seaborn'}
        'hvplot' draws the violins and passes the plot through `fp.save_figure`;
        'seaborn' draws violins plus a strip plot of the points with matplotlib.
    save : bool
        Whether the figure is written to `save_path`.
    cmap : colour used for the violins.
    xlabel, ylabel : str
        Axis labels.
    width, height : numbers
        Plot size. Passed to holoviews `opts` for 'hvplot' and used as the matplotlib
        `figsize` for 'seaborn', so the sensible magnitudes differ between backends.
    font_size : str
        Key into `fp.font_sizes_raw`; only used by the 'seaborn' backend (the
        'hvplot' backend always passes fontsize='screen' to `fp.save_figure`).
    dpi : int
        Resolution used when saving.

    Returns
    -------
    The holoviews object returned by `fp.save_figure` ('hvplot') or the matplotlib
    axes returned by seaborn ('seaborn').

    Raises
    ------
    ValueError
        If `backend` is not one of the supported values.
    """
    # Validate first so that a typo fails loudly instead of handing an Exception
    # object back to the caller as if it were a plot.
    if backend not in ('hvplot', 'seaborn'):
        raise ValueError('Invalid backend')

    rename_dict = {col: processing_parameters.wf_label_dictionary_wo_units[col] for col in ds.columns}
    ds = ds.rename(columns=rename_dict)
    plot_cols = list(rename_dict.values())

    if backend == 'hvplot':
        violinplot = ds[plot_cols].hvplot.violin(legend=False, inner='quartiles', color=cmap)
        violinplot.opts(xlabel=xlabel, ylabel=ylabel, ylim=(-0.05, 1.05), xrotation=45, width=width, height=height)
        # Only the target differs between the saving and the display-only case
        target = 'both' if save else 'screen'
        return fp.save_figure(violinplot, save_path=save_path, fig_width=width, dpi=dpi,
                              fontsize='screen', target=target, display_factor=0.1)

    # backend == 'seaborn'
    swarm_palette = {k: 'k' for k in plot_cols}
    fig, ax = plt.subplots(figsize=(width, height))
    # Draw explicitly on the axes created above rather than relying on the current axes
    sns.violinplot(data=ds[plot_cols], color=cmap, native_scale=True, width=1, ax=ax)
    violinplot = sns.stripplot(data=ds[plot_cols], size=2, palette=swarm_palette, marker="x", linewidth=1, ax=ax)
    ax.set_ylim((-0.05, 1.05))
    violinplot.spines[['right', 'top']].set_visible(False)
    # The stored value is a string whose last two characters are stripped before the int conversion
    label_size = int(fp.font_sizes_raw[font_size]['xlabel'][:-2])
    violinplot.set_xlabel(xlabel, fontsize=label_size)
    violinplot.set_ylabel(ylabel, fontsize=label_size)
    plt.xticks(rotation=45)
    plt.tight_layout()

    if save:
        plt.savefig(save_path, dpi=dpi, format='png')

    return violinplot
    
def hv_hist(ds, key, label, drop_na=True, xlabel=''):
    """Return a 20-bin holoviews histogram of column `key` of `ds`.

    With `drop_na` set, infinite values are turned into NaN, NaNs are dropped and
    negative values are removed as well before binning.
    """
    values = ds[key].copy()

    if drop_na:
        values = values.replace([np.inf, -np.inf], np.nan).dropna()
        # Besides the missing values, negative entries are discarded too
        values = values[values >= 0]

    counts, bin_edges = np.histogram(values, 20)
    return hv.Histogram((bin_edges, counts), label=label).opts(xlabel=xlabel, ylabel='Freq.')

In [None]:
def vis_frac_responsive(ds):
    """Flag cells whose 'osi' or 'dsi_abs' exceeds 0.5 and report the fractions.

    Returns the boolean mask of cells passing either criterion, followed by the
    fraction passing either, the fraction passing the 'osi' criterion and the
    fraction passing the 'dsi_abs' criterion.
    """
    def _fraction(mask):
        return mask.sum() / mask.count()

    ori_mask = ds['osi'] > 0.5
    dir_mask = ds['dsi_abs'] > 0.5
    # A cell counts when at least one of the two criteria is met
    is_vis_resp = ori_mask | dir_mask

    return is_vis_resp, _fraction(is_vis_resp), _fraction(ori_mask), _fraction(dir_mask)

def get_sig_tuned_vis_cells(agg_dict, exp_kind, which_cells):
    """Tag the cells of one dataset by direction selectivity; returns an empty summary frame.

    Side effect: a boolean column 'is_dir_seleective' (dsi_abs > 0.3) is written
    into the dataframe stored in `agg_dict`. The returned frame only carries the
    summary column names and no rows.
    """
    ds_name = '_'.join((exp_kind, which_cells, 'norm_spikes_viewed_props'))
    data = agg_dict[ds_name]

    # Assign direction vs orientation selectivity (written into the shared dataframe)
    data['is_dir_seleective'] = data['dsi_abs'] > 0.3

    return pd.DataFrame(columns=['is_vis_resp', 'frac_vis_resp', 'frac_ori_resp', 'frac_dir_resp'])

## Responsivity to orientation or direction stimuli

In [None]:
# Histogram labels come from session_shorthand so that they stay correct for 'repeat'
# searches, where the two sessions are e.g. fixed0/fixed1 rather than fixed/free.
# The *_fixed variables hold session_types[0], the *_free variables session_types[1].
resp_dir_fixed = hv_hist(agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}'], 'responsivity_dir', session_shorthand[0], xlabel='responsivity')
resp_dir_free = hv_hist(agg_dict[f'{session_types[1]}_{cell_kind}_{activity_dataset}'], 'responsivity_dir', session_shorthand[1], xlabel='responsivity')
overlay_resp_dir = resp_dir_free.opts(alpha=0.5) * resp_dir_fixed.opts(alpha=0.5)

resp_ori_fixed = hv_hist(agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}'], 'responsivity_ori', session_shorthand[0], xlabel='responsivity')
resp_ori_free = hv_hist(agg_dict[f'{session_types[1]}_{cell_kind}_{activity_dataset}'], 'responsivity_ori', session_shorthand[1], xlabel='responsivity')
overlay_resp_ori = resp_ori_free.opts(alpha=0.5) * resp_ori_fixed.opts(alpha=0.5)

# Side-by-side layout: direction responsivity on the left, orientation on the right
responsivity_hists = overlay_resp_dir.opts(title='Resp. Dir.', width=500) + overlay_resp_ori.opts(title='Resp. Ori.', width=500)
responsivity_hists

## Distributions of direction and orientation selectivity

In [None]:
# Histogram labels come from session_shorthand so that they stay correct for 'repeat'
# searches, where the two sessions are e.g. fixed0/fixed1 rather than fixed/free.
# The *_fixed variables hold session_types[0], the *_free variables session_types[1].
dsi_fixed = hv_hist(agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}'], 'dsi_abs', session_shorthand[0], xlabel='selectivity')
dsi_free = hv_hist(agg_dict[f'{session_types[1]}_{cell_kind}_{activity_dataset}'], 'dsi_abs', session_shorthand[1], xlabel='selectivity')
overlay_dsi = dsi_free.opts(alpha=0.5) * dsi_fixed.opts(alpha=0.5)

osi_fixed = hv_hist(agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}'], 'osi', session_shorthand[0], xlabel='selectivity')
osi_free = hv_hist(agg_dict[f'{session_types[1]}_{cell_kind}_{activity_dataset}'], 'osi', session_shorthand[1], xlabel='selectivity')
overlay_osi = osi_free.opts(alpha=0.5) * osi_fixed.opts(alpha=0.5)

# Side-by-side layout: DSI on the left, OSI on the right
selctivity_hists = overlay_dsi.opts(title='DSI', width=500) + overlay_osi.opts(title='OSI', width=500)
selctivity_hists

## Define cells as direction or orientation tuned
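A cell counts as direction- or orientation-tuned when both its responsivity to that stimulus and its selectivity index (DSI or OSI) reach a threshold. The function defaults are responsivity ≥ 0.3 and selectivity ≥ 0.5; the calls below use responsivity ≥ 0.5 and the default selectivity threshold.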

In [None]:
def get_vis_tuned_cells(ds, vis_stim='dir', resp_thresh=0.3, sel_tresh=0.5, drop_na=True):
    """Select the rows (cells) of `ds` that meet a visual tuning criterion.

    Parameters
    ----------
    ds : pandas.DataFrame
        Needs the columns 'responsivity_dir' and 'responsivity_ori' for 'vis' and
        'untuned', or 'responsivity_<vis_stim>' plus 'dsi_abs' / 'osi' for 'dir' / 'ori'.
    vis_stim : {'dir', 'ori', 'vis', 'untuned'}
        'dir' / 'ori': |responsivity| >= resp_thresh and |selectivity| >= sel_tresh.
        'vis': |responsivity| >= resp_thresh for both direction and orientation.
        'untuned': |responsivity| < resp_thresh for both direction and orientation.
    resp_thresh : float
        Responsivity threshold.
    sel_tresh : float
        Selectivity threshold (only used for 'dir' and 'ori').
    drop_na : bool
        Currently unused. Rows with NaN in a tested column never satisfy a
        comparison, so they are absent from every result.

    Returns
    -------
    pandas.DataFrame
        The selected rows of a copy of `ds`, with the original index labels.

    Raises
    ------
    ValueError
        If `vis_stim` is not one of the supported values.
    """
    selectivity_columns = {'dir': 'dsi_abs', 'ori': 'osi'}

    # Raise instead of returning an Exception object, which callers would
    # otherwise mistake for a dataframe of cells.
    if vis_stim not in ('dir', 'ori', 'vis', 'untuned'):
        raise ValueError('Invalid vis_stim')

    data = ds.copy()

    if vis_stim in selectivity_columns:
        responsivity = data[f'responsivity_{vis_stim}'].abs()
        selectivity = data[selectivity_columns[vis_stim]].abs()
        return data[(responsivity >= resp_thresh) & (selectivity >= sel_tresh)]

    resp_dir = data['responsivity_dir'].abs()
    resp_ori = data['responsivity_ori'].abs()
    if vis_stim == 'vis':
        return data[(resp_dir >= resp_thresh) & (resp_ori >= resp_thresh)]
    # vis_stim == 'untuned'
    return data[(resp_dir < resp_thresh) & (resp_ori < resp_thresh)]

In [None]:
# Create dataframes to store binary tuning information
# Each frame starts as a copy of the identifying columns of one session's dataset:
# fixed_* is built from session_types[0], free_* from session_types[1]
fixed_cell_tunings = agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}'][['old_index', 'mouse', 'day']].copy()
free_cell_tunings = agg_dict[f'{session_types[1]}_{cell_kind}_{activity_dataset}'][['old_index', 'mouse', 'day']].copy()

In [None]:
# Second session (session_types[1], free_* variables): cells that meet the direction / orientation tuning criteria
free_dir_tuned = get_vis_tuned_cells(agg_dict[f'{session_types[1]}_{cell_kind}_{activity_dataset}'], vis_stim='dir', resp_thresh=0.5)
free_ori_tuned = get_vis_tuned_cells(agg_dict[f'{session_types[1]}_{cell_kind}_{activity_dataset}'], vis_stim='ori', resp_thresh=0.5)

# First session (session_types[0], fixed_* variables): cells that meet the direction / orientation tuning criteria
fixed_dir_tuned = get_vis_tuned_cells(agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}'], vis_stim='dir', resp_thresh=0.5)
fixed_ori_tuned = get_vis_tuned_cells(agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}'], vis_stim='ori', resp_thresh=0.5)

# Cells that meet visual responsivity criteria (responsive to both direction and orientation stimuli)
free_vis_resp = get_vis_tuned_cells(agg_dict[f'{session_types[1]}_{cell_kind}_{activity_dataset}'], vis_stim='vis', resp_thresh=0.5)
fixed_vis_resp = get_vis_tuned_cells(agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}'], vis_stim='vis', resp_thresh=0.5)

In [None]:
# Find cells that are both direction and orientation tuned, and figure out what to do with them.
# NOTE: this cell overwrites its own inputs (free_dir_tuned, free_vis_resp, ...), so it must be
# run exactly once after the selection cell above; re-running it gives different results.
intersect, comm1, comm2 = np.intersect1d(free_dir_tuned.index, free_ori_tuned.index, return_indices=True)
free_both_tuned = free_dir_tuned.iloc[comm1].copy()

# Remove cells tuned to both from each category
free_dir_tuned = free_dir_tuned.drop(free_dir_tuned.index[comm1])
free_ori_tuned = free_ori_tuned.drop(free_ori_tuned.index[comm2])

intersect, comm1, comm2 = np.intersect1d(fixed_dir_tuned.index, fixed_ori_tuned.index, return_indices=True)
fixed_both_tuned = fixed_dir_tuned.iloc[comm1].copy()
fixed_dir_tuned = fixed_dir_tuned.drop(fixed_dir_tuned.index[comm1])
fixed_ori_tuned = fixed_ori_tuned.drop(fixed_ori_tuned.index[comm2])

# Make sure that every tuned cell is also contained in the visually responsive set.
# The set difference has to be "tuned cells missing from vis_resp"; taking it the other way
# round only re-adds rows that are already present, which is a no-op.
free_resp_cells = np.unique(np.concatenate([free_dir_tuned.index, free_ori_tuned.index, free_both_tuned.index]))
not_in_free_resp_cells = np.setdiff1d(free_resp_cells, free_vis_resp.index)
# The values are index labels, so select them with .loc rather than by position
free_vis_resp = pd.concat([free_vis_resp, agg_dict[f'{session_types[1]}_{cell_kind}_{activity_dataset}'].loc[not_in_free_resp_cells, :]])
# After reset_index the original cell labels live in the 'index' column
free_vis_resp = free_vis_resp.reset_index().drop_duplicates(subset=['index'])

fixed_resp_cells = np.unique(np.concatenate([fixed_dir_tuned.index, fixed_ori_tuned.index, fixed_both_tuned.index]))
not_in_fixed_resp_cells = np.setdiff1d(fixed_resp_cells, fixed_vis_resp.index)
fixed_vis_resp = pd.concat([fixed_vis_resp, agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}'].loc[not_in_fixed_resp_cells, :]])
fixed_vis_resp = fixed_vis_resp.reset_index().drop_duplicates(subset=['index'])


In [None]:
# Assign the tunings to the dataframes
# The *_vis_resp frames had reset_index() applied when they were built, so their own index is a
# fresh 0..n-1 range and the original cell labels live in the 'index' column. Comparing against
# .index would flag the first n cells instead of the visually responsive ones.
free_vis_ids = free_vis_resp['index'] if 'index' in free_vis_resp.columns else free_vis_resp.index
fixed_vis_ids = fixed_vis_resp['index'] if 'index' in fixed_vis_resp.columns else fixed_vis_resp.index

# The *_dir_tuned / *_ori_tuned frames no longer contain the cells tuned to both,
# so those cells are False in both of the columns below.
free_cell_tunings['is_vis_resp'] = free_cell_tunings.index.isin(free_vis_ids)
free_cell_tunings['is_dir_tuned'] = free_cell_tunings.index.isin(free_dir_tuned.index)
free_cell_tunings['is_ori_tuned'] = free_cell_tunings.index.isin(free_ori_tuned.index)
fixed_cell_tunings['is_vis_resp'] = fixed_cell_tunings.index.isin(fixed_vis_ids)
fixed_cell_tunings['is_dir_tuned'] = fixed_cell_tunings.index.isin(fixed_dir_tuned.index)
fixed_cell_tunings['is_ori_tuned'] = fixed_cell_tunings.index.isin(fixed_ori_tuned.index)

In [None]:
# Get fraction of cells that are orientation, direction, and visually tuned
def _frac_tuned_per_day(tuned_cells, cells_per_day, name):
    """Fraction of the cells of each (mouse, day) that appear in `tuned_cells`.

    Days without any tuned cell are counted as a fraction of 0; without the reindex
    they would be missing from the numerator and come out as NaN.
    """
    tuned_per_day = tuned_cells.groupby(['mouse', 'day']).size().reindex(cells_per_day.index, fill_value=0)
    return (tuned_per_day / cells_per_day).rename(name)


free_cell_per_day = free_cell_tunings.groupby(['mouse', 'day']).size()
frac_free_dir_tuned = _frac_tuned_per_day(free_dir_tuned, free_cell_per_day, 'direction').reset_index()
frac_free_ori_tuned = _frac_tuned_per_day(free_ori_tuned, free_cell_per_day, 'orientation').reset_index()
frac_free_vis_resp = _frac_tuned_per_day(free_vis_resp, free_cell_per_day, 'visual').reset_index()
# Keep only the three fraction columns; the repeated mouse/day columns are dropped
frac_vis_resp_free = pd.concat([frac_free_dir_tuned, frac_free_ori_tuned, frac_free_vis_resp], axis=1).drop(['mouse', 'day'], axis=1)

fixed_cell_per_day = fixed_cell_tunings.groupby(['mouse', 'day']).size()
frac_fixed_dir_tuned = _frac_tuned_per_day(fixed_dir_tuned, fixed_cell_per_day, 'direction').reset_index()
frac_fixed_ori_tuned = _frac_tuned_per_day(fixed_ori_tuned, fixed_cell_per_day, 'orientation').reset_index()
frac_fixed_vis_resp = _frac_tuned_per_day(fixed_vis_resp, fixed_cell_per_day, 'visual').reset_index()
frac_vis_resp_fixed = pd.concat([frac_fixed_dir_tuned, frac_fixed_ori_tuned, frac_fixed_vis_resp], axis=1).drop(['mouse', 'day'], axis=1)

In [None]:
# Violin + strip plot of the per-(mouse, day) tuned fractions for the free_* data, saved as PNG
save_path = os.path.join(figure_save_path, f"frac_vis_tuned_free.png")
violinplot_free_vis = violin_swarm(frac_vis_resp_free, save_path, cmap=fp.hv_blue_hex, font_size='poster', backend='seaborn', width=4.5, height=5, save=True)

In [None]:
# Violin + strip plot of the per-(mouse, day) tuned fractions for the fixed_* data, saved as PNG
save_path = os.path.join(figure_save_path, f"frac_vis_tuned_fixed.png")
violinplot_fixed_vis = violin_swarm(frac_vis_resp_fixed, save_path, cmap='red', font_size='poster', backend='seaborn', width=4.5, height=5, save=True)

### Histograms of only dir or ori tuned cells

In [None]:
# Visualize the non-overlapping cells
# Histogram labels come from session_shorthand so that they stay correct for 'repeat'
# searches; the free_* frames derive from session_types[1], the fixed_* frames from session_types[0].
dsi_free = hv_hist(free_dir_tuned, 'dsi_abs', session_shorthand[1], xlabel='selectivity')
dsi_fixed = hv_hist(fixed_dir_tuned, 'dsi_abs', session_shorthand[0], xlabel='selectivity')
overlay_dsi = dsi_fixed.opts(alpha=0.5) * dsi_free.opts(alpha=0.5)

osi_free = hv_hist(free_ori_tuned, 'osi', session_shorthand[1], xlabel='selectivity')
osi_fixed = hv_hist(fixed_ori_tuned, 'osi', session_shorthand[0], xlabel='selectivity')
overlay_osi = osi_fixed.opts(alpha=0.5) * osi_free.opts(alpha=0.5)

selctivity_hists = overlay_dsi.opts(title='DSI - dir selective', width=500) + overlay_osi.opts(title='OSI - ori selective', width=500)


resp_dir_free = hv_hist(free_dir_tuned, 'responsivity_dir', session_shorthand[1], xlabel='responsivity')
resp_dir_fixed = hv_hist(fixed_dir_tuned, 'responsivity_dir', session_shorthand[0], xlabel='responsivity')
overlay_resp_dir = resp_dir_fixed.opts(alpha=0.5) * resp_dir_free.opts(alpha=0.5)

resp_ori_free = hv_hist(free_ori_tuned, 'responsivity_ori', session_shorthand[1], xlabel='responsivity')
resp_ori_fixed = hv_hist(fixed_ori_tuned, 'responsivity_ori', session_shorthand[0], xlabel='responsivity')
overlay_resp_ori = resp_ori_fixed.opts(alpha=0.5) * resp_ori_free.opts(alpha=0.5)

responsivity_hists = overlay_resp_dir.opts(title='Resp. Dir.', width=500) + overlay_resp_ori.opts(title='Resp. Ori.', width=500)

# Two-column grid: responsivity in the first row, selectivity in the second
both_hists = responsivity_hists + selctivity_hists
both_hists.cols(2)

### Histograms of cells that are both ori and dir tuned

In [None]:
# Visualize the overlapping cells
# Histogram labels come from session_shorthand so that they stay correct for 'repeat'
# searches; the free_* frames derive from session_types[1], the fixed_* frames from session_types[0].
dsi_free = hv_hist(free_both_tuned, 'dsi_abs', session_shorthand[1], xlabel='selectivity')
dsi_fixed = hv_hist(fixed_both_tuned, 'dsi_abs', session_shorthand[0], xlabel='selectivity')
overlay_dsi = dsi_free.opts(alpha=0.5) * dsi_fixed.opts(alpha=0.5)

osi_free = hv_hist(free_both_tuned, 'osi', session_shorthand[1], xlabel='selectivity')
osi_fixed = hv_hist(fixed_both_tuned, 'osi', session_shorthand[0], xlabel='selectivity')
overlay_osi = osi_free.opts(alpha=0.5) * osi_fixed.opts(alpha=0.5)

selctivity_hists = overlay_dsi.opts(title='DSI - both selective', width=500) + overlay_osi.opts(title='OSI - both selective', width=500)

resp_dir_free = hv_hist(free_both_tuned, 'responsivity_dir', session_shorthand[1], xlabel='responsivity')
resp_dir_fixed = hv_hist(fixed_both_tuned, 'responsivity_dir', session_shorthand[0], xlabel='responsivity')
overlay_resp_dir = resp_dir_free.opts(alpha=0.5) * resp_dir_fixed.opts(alpha=0.5)

resp_ori_free = hv_hist(free_both_tuned, 'responsivity_ori', session_shorthand[1], xlabel='responsivity')
resp_ori_fixed = hv_hist(fixed_both_tuned, 'responsivity_ori', session_shorthand[0], xlabel='responsivity')
overlay_resp_ori = resp_ori_free.opts(alpha=0.5) * resp_ori_fixed.opts(alpha=0.5)

responsivity_hists = overlay_resp_dir.opts(title='Resp. Dir.- both selective', width=500) + overlay_resp_ori.opts(title='Resp. Ori.- both selective', width=500)

# Two-column grid: responsivity in the first row, selectivity in the second
both_hists = responsivity_hists + selctivity_hists
both_hists.cols(2)

# Fraction self-motion responsive

In [None]:
def kine_fraction_tuned(ds, use_test=True, include_responsivity=True, include_consistency=False):
    """Fraction of the rows (cells) of `ds` that count as tuned.

    A cell is tuned when its quality criterion holds and, if requested, its
    responsivity and/or consistency criteria hold as well.

    Parameters
    ----------
    ds : pandas.DataFrame
        With `use_test` the columns 'Resp_test', 'Qual_test' and 'Cons_test' are read;
        otherwise 'Resp_index', 'Qual_index' and 'Cons_index' are thresholded against
        the cutoffs in `processing_parameters` (divided by 100).
    use_test : bool
        Use the stored test outcomes instead of thresholding the indices.
    include_responsivity : bool
        Additionally require the responsivity criterion.
    include_consistency : bool
        Additionally require the consistency criterion.

    Returns
    -------
    float
        Number of tuned cells divided by the number of cells.
    """
    def _as_flag(series):
        # Treat missing entries as "criterion not met" and everything else by truthiness
        return series.fillna(0).astype(bool)

    if use_test:
        resp = ds['Resp_test']
        qual = ds['Qual_test']
        cons = ds['Cons_test']

    else:
        resp = ds['Resp_index'] >= processing_parameters.tc_resp_qual_cutoff/100
        qual = ds['Qual_index'] >= processing_parameters.tc_resp_qual_cutoff/100
        cons = ds['Cons_index'] >= processing_parameters.tc_consistency_cutoff/100

    # is tuned if quality is true
    is_tuned = qual

    # Combine the criteria with a logical AND. Adding the columns and testing "> 1" only
    # works for 0/1 integers: for boolean columns the sum is a logical OR that never
    # exceeds 1, which made the fraction always 0.
    if include_responsivity:
        is_tuned = _as_flag(is_tuned) & _as_flag(resp)

    if include_consistency:
        is_tuned = _as_flag(is_tuned) & _as_flag(cons)

    frac_is_tuned = is_tuned.sum() / is_tuned.count()
    return frac_is_tuned

def get_sig_tuned_kinem_cells(agg_dict, exp_kind, which_cells, vars, use_test=True, include_responsivity=True, include_consistency=False):
    """Per-(mouse, day) fraction of tuned cells for each variable in `vars`.

    For every variable the dataset '<exp_kind>_<which_cells>_<var>' is looked up in
    `agg_dict`; if present it is grouped by mouse and day and `kine_fraction_tuned`
    is applied to each group. Variables without a dataset keep an unfilled column.
    """
    fractions = pd.DataFrame(columns=vars)

    for var in vars:
        dataset_key = '_'.join([exp_kind, which_cells, var])
        if dataset_key not in agg_dict:
            continue
        per_day = agg_dict[dataset_key].groupby(['mouse', 'day'])
        fractions[var] = per_day.apply(kine_fraction_tuned, use_test, include_responsivity, include_consistency)

    return fractions


In [None]:
# get the head fixed tunings


In [None]:
# Display the dataset of the first session (session_types[0]) for inspection
agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}']

In [None]:
save_path = os.path.join(figure_save_path, "sig_frac_kinem_free.png")

# Within-animal fraction kinematic
# NOTE(review): the experiment name 'VTuningWF' is hard-coded here rather than taken from
# session_types, and the plot goes through fp.violin_swarm while the visual-tuning figures use
# the violin_swarm defined in this notebook - confirm that both are intended.
frac_kine_resp_free = get_sig_tuned_kinem_cells(agg_dict, 'VTuningWF', cell_kind, processing_parameters.variable_list_free, use_test=True, include_responsivity=True, include_consistency=False)
violinplot_free_kinem = fp.violin_swarm(frac_kine_resp_free, save_path, cmap=fp.hv_blue_hex, backend='seaborn', width=15, height=5, save=True)

In [None]:
# Within-animal fraction kinematic (head-fixed sessions)
save_path = os.path.join(figure_save_path, "sig_frac_kinem_fixed.png")

# NOTE(review): the experiment name 'VWheelWF' is hard-coded here rather than taken from
# session_types - confirm that this is intended for 'repeat' searches.
frac_kine_resp_fixed = get_sig_tuned_kinem_cells(agg_dict, 'VWheelWF', cell_kind, processing_parameters.variable_list_fixed, use_test=True, include_responsivity=True, include_consistency=False)
# Stored under its own name so that the figure of the freely moving sessions is not overwritten
violinplot_fixed_kinem = fp.violin_swarm(frac_kine_resp_fixed, save_path, cmap='red', backend='seaborn', width=3, height=5, save=True)

# Fraction Visual, Self-Motion, or Multimodal Tuned

In [None]:
def get_frac_tuned(df):
    """Fractions of rows (cells) that are kinematic-, visual- or mixed-tuned.

    Columns are picked by position: columns[1:-5] are summed as kinematic tuning
    flags and columns[-5:-3] as visual tuning flags. The helper columns
    'sum_kinem', 'sum_vis' and 'sum_mix' are written into `df`, which shifts those
    positions, so the function should see each frame only once.

    Returns (only kinematic, only visual, visual, kinematic, mixed) fractions.
    """
    kinem_cols = list(df.columns[1:-5])
    vis_cols = list(df.columns[-5:-3])

    kinem_total = df[kinem_cols].sum(axis=1)
    vis_total = df[vis_cols].sum(axis=1)

    # Side effect: the per-cell sums are stored on the input frame
    df['sum_kinem'] = kinem_total
    df['sum_vis'] = vis_total
    df['sum_mix'] = df[['sum_kinem', 'sum_vis']].sum(axis=1)

    n_cells = df.shape[0]
    has_vis = df['sum_vis'] > 0
    has_kinem = df['sum_kinem'] > 0

    frac_vis_tuned = df['sum_vis'].loc[has_vis].count() / n_cells
    frac_kinem_tuned = df['sum_kinem'].loc[has_kinem].count() / n_cells

    no_vis = vis_total == 0
    no_kinem = kinem_total == 0
    frac_only_kinem = df[no_vis & (kinem_total > 0)].shape[0] / n_cells
    frac_only_vis = df[no_kinem & (vis_total > 0)].shape[0] / n_cells
    frac_mix_tuned = df[(vis_total > 0) & (kinem_total > 0)].shape[0] / n_cells

    return frac_only_kinem, frac_only_vis, frac_vis_tuned, frac_kinem_tuned, frac_mix_tuned

In [None]:
# Collect, per experiment type, the per-mouse tuning fractions and a box plot of them
plt_list = []
df_list = []
# NOTE(review): the dataset names are hard-coded to the two rigs rather than derived from session_types - confirm
for ds, exp_type in zip([agg_dict['VTuningWF_multimodal_tuned'], agg_dict['VWheelWF_multimodal_tuned']], ['Freely Moving', 'Head Fixed']):
    only_kinem = []
    only_vis = []
    vis_tuned = []
    kinem_tuned = []
    mix_tuned = []
    df = pd.DataFrame()
    
    # One set of fractions per mouse
    for name, group in ds.groupby('mouse'):
        frac_only_kinem, frac_only_vis, frac_vis_tuned, frac_kinem_tuned, frac_mix_tuned = get_frac_tuned(group)
        only_kinem.append(frac_only_kinem)
        only_vis.append(frac_only_vis)
        vis_tuned.append(frac_vis_tuned)
        kinem_tuned.append(frac_kinem_tuned)
        mix_tuned.append(frac_mix_tuned)

    # One row per mouse; some column names carry a leading space
    df['group'] = [exp_type] * len(vis_tuned)
    df[' Visual Only'] = only_vis
    df[' Postural Only'] = only_kinem
    df[' Visual'] = vis_tuned
    df[' Postural'] = kinem_tuned
    df['Multimodal'] = mix_tuned
    df_list.append(df)
    plt_list.append(df.hvplot.box(legend=False))

# Long format: one row per (group, variable) value
comp_tuning = pd.concat(df_list, axis=0)
comp_tuning = pd.melt(comp_tuning, id_vars=['group'], value_vars=[' Visual Only', ' Postural Only', ' Visual', ' Postural', 'Multimodal'])

In [None]:
# Two-sided Mann-Whitney U (Wilcoxon rank-sum) test per tuning category between the
# 'Freely Moving' and 'Head Fixed' groups, followed by a Bonferroni correction
# Rows: tuning category; columns: group; each cell holds the list of per-mouse fractions
comp_tuning_rank_sum = comp_tuning.groupby(['variable', 'group'], group_keys=True).value.agg(list).unstack(level=1)
comp_tuning_stats = comp_tuning_rank_sum.apply(lambda x: st.mannwhitneyu(x['Freely Moving'], x['Head Fixed'], alternative='two-sided'), axis=1)
# Element 1 of each test result is the p-value
p_vals = [comp_tuning_stats.to_numpy()[i][1] for i in np.arange(comp_tuning_stats.shape[0])]
# The corrected results are the cell output; they are not stored in a variable
multipletests(p_vals, alpha=0.05, method='bonferroni')

In [None]:
# Violins of the per-mouse fractions, grouped by tuning category and coloured by experiment type
violinplot = hv.Violin(comp_tuning, ['variable', 'group'], 'value').opts(xlabel='', ylabel='Significant Frac.', 
                                                                  width=1000, height=400,
                                                                  # ylim=(-0.05, 1),
                                                                  violin_color=hv.dim('group'),
                                                                  cmap=['red', fp.hv_blue_hex],
                                                                  show_legend=False, legend_position='right',
                                                                  fontsize={'legend': 10}
                                                                 )
save_path = os.path.join(figure_save_path, "frac_multimodal.png")
# violinplot
violinplot = fp.save_figure(violinplot, save_path=save_path, fig_width=25, dpi=800, fontsize='screen', target='both', display_factor=0.1)

In [None]:
# Alternative rendering of the same comparison, saved as a separate PNG
# NOTE(review): sns_scatter_swarm is neither defined nor imported in the part of the notebook
# above this cell - confirm where it comes from, otherwise this fails on a fresh kernel.
ax = sns_scatter_swarm(comp_tuning, "variable", "value", "group")
ax.set_ylabel('Significant Frac.')
plt.savefig(os.path.join(figure_save_path, 'frac_multimodal_alt.png'), dpi=800, format='png')

# Bootstrapped tuning shifts

In [None]:
def plot_compare_matched_vis_tuning(ds_list, cell, tuning_kind='direction', exp_name='', error='std', norm=True, polar=True, axes=None):
    """Plot the tuning of one cell for several datasets, one figure row per dataset.

    Parameters
    ----------
    ds_list : sequence
        Datasets handed one by one to `fp.plot_tuning_with_stats`.
    cell : identifier of the cell, passed through and shown in the figure title.
    tuning_kind, error, norm, polar : passed through to `fp.plot_tuning_with_stats`;
        `polar` also selects a polar projection for the tuning axis.
    exp_name : sequence of str
        One key of `processing_parameters.wf_label_dictionary` per dataset, used as
        row titles. Only read when the axes are created here.
    axes : list of arrays of 3 axes, optional
        Existing (tuning, resp, error) axes per dataset. When omitted a new figure
        is created.

    Returns
    -------
    (fig, axes)
        `fig` is the newly created figure, or None when `axes` was supplied.
    """
    # Stays None when the caller supplies the axes; previously `fig` was unbound in that case
    fig = None

    if axes is None:
        axes = []
        # Size the figure by the datasets passed in, not by the unrelated global `data_list`
        fig = plt.figure(layout='constrained', figsize=(18*cm, 2*len(ds_list)*cm))
        fig.suptitle(f"Cell {cell}", fontsize='x-large')
        # squeeze=False keeps an array even for a single dataset, so the loop below always works
        subfigs = fig.subfigures(nrows=len(ds_list), ncols=1, hspace=0.07, squeeze=False).flatten()

        for i, subfig in enumerate(subfigs):
            subfig.suptitle(processing_parameters.wf_label_dictionary[exp_name[i]].title())

            if polar:
                ax1 = subfig.add_subplot(121, projection="polar") # direction tuning
            else:
                ax1 = subfig.add_subplot(121) # tuning

            ax2 = subfig.add_subplot(222) #  resp
            ax3 = subfig.add_subplot(224) #  error
            axes.append(np.array([ax1, ax2, ax3]))

    for sub_ax, ds in zip(axes, ds_list):
        fp.plot_tuning_with_stats(ds, cell, tuning_kind=tuning_kind, error=error, norm=norm, polar=polar, axes=sub_ax)

    return fig, axes

def plot_compare_all_vis_tuning(ds_list, cell, tuning_kind=['direction', 'orientation'], exp_name='', error='std', norm=True, polar=True, axes=None):
    """Plot several tuning kinds of one cell in a 2 x 2 grid of sub-figures.

    Parameters
    ----------
    ds_list : sequence
        Datasets handed one by one to `fp.plot_tuning_with_stats`, in the order of
        the flattened 2 x 2 grid.
    cell : identifier of the cell, passed through and shown in the figure title.
    tuning_kind : str or sequence of str
        Tuning kind(s), cycled over the datasets. The default alternates
        'direction' and 'orientation'.
    exp_name : sequence of str
        Keys of `processing_parameters.wf_label_dictionary`; entry i // 2 titles
        sub-figure i. Only read when the axes are created here.
    error, norm, polar : passed through to `fp.plot_tuning_with_stats`; `polar`
        also selects a polar projection for the tuning axis.
    axes : list of arrays of 3 axes, optional
        Existing (tuning, resp, error) axes per dataset. When omitted a new figure
        is created.

    Returns
    -------
    (fig, axes)
        `fig` is the newly created figure, or None when `axes` was supplied.
    """
    # Stays None when the caller supplies the axes; previously `fig` was unbound in that case
    fig = None

    if axes is None:
        axes = []
        # Size the figure by the datasets passed in, not by the unrelated global `data_list`
        fig = plt.figure(layout='constrained', figsize=(36*cm, 2*len(ds_list)*cm))
        fig.suptitle(f"Cell {cell}", fontsize='x-large')
        subfigs = fig.subfigures(nrows=2, ncols=2, hspace=0.07)

        for i, subfig in enumerate(subfigs.flatten()):
            subfig.suptitle(processing_parameters.wf_label_dictionary[exp_name[i//2]].title())

            if polar:
                ax1 = subfig.add_subplot(121, projection="polar") # tuning
            else:
                ax1 = subfig.add_subplot(121) # tuning

            ax2 = subfig.add_subplot(222) #  resp
            ax3 = subfig.add_subplot(224) #  error
            axes.append(np.array([ax1, ax2, ax3]))

    # Honour the argument instead of overwriting it; with the default this gives the same
    # direction / orientation alternation as before.
    kinds = [tuning_kind] if isinstance(tuning_kind, str) else list(tuning_kind)
    for i, (ds, sub_ax) in enumerate(zip(ds_list, axes)):
        fp.plot_tuning_with_stats(ds, cell, tuning_kind=kinds[i % len(kinds)], error=error, norm=norm, polar=polar, axes=sub_ax)

    return fig, axes

In [None]:
def find_tuned_cell_indices(ref_data, comp_data, stim_kind='orientation', cutoff=0.5, tuning_criteria='both',):
    """Return the index labels of cells whose selectivity reaches `cutoff`.

    Parameters
    ----------
    ref_data, comp_data : pandas.DataFrame
        Reference and comparison datasets; need an 'osi' column for
        stim_kind='orientation' or a 'dsi_abs' column for stim_kind='direction'.
    stim_kind : {'orientation', 'direction'}
        Which selectivity index is thresholded.
    cutoff : float
        Minimum absolute selectivity.
    tuning_criteria : {'both', 'ref', 'comp'}
        'both' returns the labels tuned in both datasets (sorted intersection),
        'ref' / 'comp' the labels tuned in the respective dataset.

    Returns
    -------
    numpy.ndarray of index labels.

    Raises
    ------
    ValueError
        For an unknown `stim_kind` or `tuning_criteria`.
    """
    selectivity_columns = {'orientation': 'osi', 'direction': 'dsi_abs'}

    # Raise instead of returning an Exception object that callers would index with
    if stim_kind not in selectivity_columns:
        raise ValueError('Invalid stim_kind')
    if tuning_criteria not in ('both', 'ref', 'comp'):
        raise ValueError('Invalid tuning_criteria')

    column = selectivity_columns[stim_kind]

    # Find the cells in each dataset that are tuned
    index_ref = ref_data.loc[ref_data[column].abs() >= cutoff].index.to_numpy()
    index_comp = comp_data.loc[comp_data[column].abs() >= cutoff].index.to_numpy()

    if tuning_criteria == 'both':
        return np.intersect1d(index_ref, index_comp)
    if tuning_criteria == 'ref':
        return index_ref
    return index_comp

def tuning_shifts(ds1, ds2, ci_width_cutoff=20, stim_kind='orientation', method='fit'):
    """Compare the preferred angle of each cell between two datasets.

    Rows of ``ds1`` and ``ds2`` are paired positionally (``zip`` over
    ``iterrows``; extra rows in the longer frame are silently ignored).

    Parameters
    ----------
    ds1, ds2 : pandas.DataFrame
        Must contain the preference column and its bootstrap column
        (``pref_<kind>`` / ``bootstrap_pref_<kind>`` for ``method='fit'``,
        ``resultant_<kind>`` / ``bootstrap_resultant_<kind>`` for
        ``method='resultant'``, with ``<kind>`` = first three letters of
        ``stim_kind``), plus ``mouse`` and ``day`` in ``ds1``.
    ci_width_cutoff : float
        Both confidence-interval widths must be at most this value for a
        shift to be eligible as significant.
    stim_kind : str
        'orientation' (angles wrapped to 180) or anything else (wrapped to 360).
    method : {'fit', 'resultant'}
        Which preferred-angle estimate to use. For 'resultant' the last element
        of the value / last column of the bootstrap array is taken.

    Returns
    -------
    pandas.DataFrame
        Indexed by the ``ds1`` row label, with columns ``pref_1``, ``ci_width_1``,
        ``pref_2``, ``ci_width_2``, ``delta_pref``, ``is_sig``, ``mouse``, ``date``.
        Cells with a NaN resultant preference or an empty bootstrap distribution
        are skipped.

    Raises
    ------
    ValueError
        If ``method`` is not 'fit' or 'resultant'.
    """
    shifts = []

    stim_kind = stim_kind[:3]
    # Orientation lives on a half circle, direction on the full circle
    multiplier = 2 if stim_kind == 'ori' else 1

    if method == 'fit':
        dist_key = f'bootstrap_pref_{stim_kind}'
        pref_key = f'pref_{stim_kind}'
    elif method == 'resultant':
        dist_key = f'bootstrap_resultant_{stim_kind}'
        pref_key = f'resultant_{stim_kind}'
    else:
        # Raise instead of returning the exception object to the caller
        raise ValueError(f"Invalid method: {method!r}")

    for (idxRow, cell_1), (_, cell_2) in zip(ds1.iterrows(), ds2.iterrows()):

        # Get preferred angle
        pref_1 = cell_1[pref_key]
        pref_2 = cell_2[pref_key]

        # Get bootstrap distributions
        pref_dist_1 = cell_1[dist_key]
        pref_dist_2 = cell_2[dist_key]

        if method == 'resultant':
            pref_1 = pref_1[-1]
            pref_2 = pref_2[-1]
            pref_dist_1 = pref_dist_1[:, -1]
            pref_dist_2 = pref_dist_2[:, -1]

            # Skip cells without a defined preference (was a no-op `pass`)
            if np.isnan(pref_1) or np.isnan(pref_2):
                continue

        pref_dist_1 = pref_dist_1[~np.isnan(pref_dist_1)]
        pref_dist_2 = pref_dist_2[~np.isnan(pref_dist_2)]

        # Nothing to build a confidence interval from (was a no-op `pass`)
        if pref_dist_1.size == 0 or pref_dist_2.size == 0:
            continue

        # Wrap angles
        pref_dist_1 = fk.wrap(pref_dist_1, 360/multiplier + 0.1)
        pref_dist_2 = fk.wrap(pref_dist_2, 360/multiplier + 0.1)
        # NOTE(review): plain absolute difference, not a circular one - confirm this is intended
        delta_pref = np.abs(pref_2 - pref_1)

        # Calculate confidence intervals
        # NOTE(review): normal interval around the linear mean of the bootstrap samples using
        # their SEM; the SEM shrinks with the number of bootstrap samples - confirm that a
        # percentile / std-based interval was not intended.
        ci_1 = st.norm.interval(confidence=0.95, loc=np.nanmean(pref_dist_1), scale=st.sem(pref_dist_1, nan_policy='omit'))
        ci_2 = st.norm.interval(confidence=0.95, loc=np.nanmean(pref_dist_2), scale=st.sem(pref_dist_2, nan_policy='omit'))

        # Get CI widths
        ci_width_1 = np.abs(ci_1[-1] - ci_1[0])
        ci_width_2 = np.abs(ci_2[-1] - ci_2[0])

        # Check if tuned or not
        both_tuned = (ci_width_1 <= ci_width_cutoff) and (ci_width_2 <= ci_width_cutoff)

        # A shift is significant when each preference lies outside the other's interval.
        # Logical `not` is used because bitwise `~` on a plain Python bool is always truthy.
        pref_2_outside_ci_1 = not (ci_1[0] <= pref_2 <= ci_1[-1])
        pref_1_outside_ci_2 = not (ci_2[0] <= pref_1 <= ci_2[-1])

        sig_shift = 1 if (both_tuned and pref_2_outside_ci_1 and pref_1_outside_ci_2) else 0

        # wrap to negative domain for plotting
        pref_1 = fk.wrap_negative(pref_1, bound=360/(2*multiplier) + 0.1)
        pref_2 = fk.wrap_negative(pref_2, bound=360/(2*multiplier) + 0.1)

        shifts.append([idxRow, pref_1, ci_width_1, pref_2, ci_width_2, delta_pref, sig_shift, cell_1.mouse, cell_1.day])

    # The first (unnamed) column carries the ds1 row label and becomes the index
    shifts = pd.DataFrame(data=shifts, columns=['', 'pref_1', 'ci_width_1', 'pref_2', 'ci_width_2', 'delta_pref', 'is_sig', 'mouse', 'date'])
    shifts = shifts.set_index(shifts.columns[0])

    return shifts

In [None]:
# Display the curated match table (presumably built in an earlier cell - not visible here)
curated_matches

In [None]:
# Unique (day, mouse) sessions that have curated matches. This was previously computed
# as a bare expression whose result was discarded and then recomputed in the loop header.
curated_sessions = curated_matches[['day', 'mouse']].drop_duplicates().to_numpy()

# NOTE(review): `cell_kind` and `activity_dataset` are assigned in the following cell in the
# visible notebook - confirm they are already defined when this cell runs on a fresh kernel.
fixed_matches = agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}']
free_matches = agg_dict[f'{session_types[1]}_{cell_kind}_{activity_dataset}']

curated_ref_df_list = []
curated_comp_df_list = []
for day, mouse in curated_sessions:
    # Curated cell indices for this session
    curated_idxs = curated_matches.loc[(curated_matches.mouse == mouse) & (curated_matches.day == day)]['index'].to_numpy()

    # Rows of the same session in the two aggregate datasets
    ref_df = fixed_matches.loc[(fixed_matches.mouse == mouse) & (fixed_matches.day == day)]
    comp_df = free_matches.loc[(free_matches.mouse == mouse) & (free_matches.day == day)]

    # Keep only the cells whose `old_index` is in the curated list
    curated_ref_df = ref_df.loc[ref_df.old_index.isin(curated_idxs)]
    curated_comp_df = comp_df.loc[comp_df.old_index.isin(curated_idxs)]

    curated_ref_df_list.append(curated_ref_df.copy())
    curated_comp_df_list.append(curated_comp_df.copy())

# Stack all sessions; the fresh RangeIndex makes labels and positions coincide
curated_ref_df = pd.concat(curated_ref_df_list).reset_index(drop=True)
curated_comp_df = pd.concat(curated_comp_df_list).reset_index(drop=True)

In [None]:
cell_kind = 'matched'
activity_dataset = 'norm_spikes_viewed_props'

# Reference dataset: the curated frame is used as-is.
# Alternative source: agg_dict[f'{session_types[0]}_{cell_kind}_{activity_dataset}'].copy()
ref_ds = curated_ref_df
# Per-stimulus views without the cells that lack the corresponding selectivity index
ref_ds_ori = ref_ds.dropna(subset=['osi']).copy()
ref_ds_dir = ref_ds.dropna(subset=['dsi_abs']).copy()

# Comparison dataset, treated the same way.
# Alternative source: agg_dict[f'{session_types[1]}_{cell_kind}_{activity_dataset}'].copy()
comp_ds = curated_comp_df
comp_ds_ori = comp_ds.dropna(subset=['osi']).copy()
comp_ds_dir = comp_ds.dropna(subset=['dsi_abs']).copy()

In [None]:
ori_cutoff_high = 0.5
dir_cutoff_high = 0.5
lower_cutoff = 0.3

# Index labels of the cells that pass the selectivity cutoff in both datasets
ori_indices = find_tuned_cell_indices(ref_ds.copy(), comp_ds.copy(), cutoff=ori_cutoff_high, stim_kind='orientation', tuning_criteria='both')
dir_indices = find_tuned_cell_indices(ref_ds.copy(), comp_ds.copy(), cutoff=dir_cutoff_high, stim_kind='direction', tuning_criteria='both')

# find_tuned_cell_indices returns index *labels*, so select with .loc rather than the
# positional .iloc (the two only coincide while both frames carry a default RangeIndex).
po_shifts = tuning_shifts(ref_ds.loc[ori_indices, :].copy(), comp_ds.loc[ori_indices, :].copy(),
                          ci_width_cutoff=45, stim_kind='orientation', method='resultant')

pd_shifts = tuning_shifts(ref_ds.loc[dir_indices, :].copy(), comp_ds.loc[dir_indices, :].copy(),
                          ci_width_cutoff=90, stim_kind='direction', method='resultant')


In [None]:
# Difference between the preferred orientation in the two sessions
# NOTE(review): this is a linear difference of wrapped angles - confirm whether it should be
# wrapped to the orientation half circle before squaring.
residual_ori = po_shifts['pref_2'].to_numpy() - po_shifts['pref_1'].to_numpy()
# Root-mean-square error: sqrt(mean(r^2)). The previous mean(sqrt(r^2)) is the mean
# absolute error, which does not match the 'RMSE' label used in the plot.
rmse_residual_ori = np.sqrt(np.mean(residual_ori**2))
print(po_shifts.shape)

In [None]:
# Identity line: points on it have the same preferred orientation in both sessions
unity_line = hv.Curve((np.arange(-90, 90, 1), np.arange(-90, 90, 1))).opts(color='black')
# One point per cell: preference in session 1 (x) against session 2 (y)
scatter_ori = hv.Scatter(po_shifts[['pref_1', 'pref_2']], kdims=['pref_1'], vdims=['pref_2'], label='sig').opts(color='purple', size=8)
# Overlay the line, the points and a text annotation with the RMSE value
scatter_PO = unity_line * scatter_ori * hv.Text(-60, 80, f'RMSE: {rmse_residual_ori:.1f}°').opts(color='black', fontsize=15)
scatter_PO.opts(show_legend=False, width=500, height=500,
                xlabel=f'{processing_parameters.wf_label_dictionary_wo_units[session_types[0]]} Pref. Ori. [°]', 
                ylabel=f'{processing_parameters.wf_label_dictionary_wo_units[session_types[1]]} Pref. Ori. [°]')
# NOTE(review): the x ticks omit -90 while the y ticks include it - confirm this is deliberate
scatter_PO.opts(hv.opts.Scatter(xlim=(-90, 90), ylim=(-90, 90), xticks=[-45, 0, 45, 90], yticks=[-90, -45, 0, 45, 90]))

# NOTE(review): relies on `parsed_search` having 'result' and 'rig' entries - confirm it is
# defined on a fresh run. target='screen' presumably renders without writing the file; verify in fp.save_figure.
save_path = os.path.join(figure_save_path, f"delta_PO_{parsed_search['result']}_{parsed_search['rig']}.png")
scatter_PO = fp.save_figure(scatter_PO, save_path=save_path, fig_width=13, dpi=800, fontsize='poster', target='screen', display_factor=0.1)

In [None]:
# Cells whose preferred-orientation shift was flagged significant (is_sig == 1), in red
scatter_sig_ori = hv.Scatter(po_shifts[po_shifts.is_sig == 1][['pref_1', 'pref_2']], kdims=['pref_1'], vdims=['pref_2'], label='sig').opts(color='red', size=5)

# Remaining cells (is_sig == 0), drawn in translucent gray
scater_insig_ori = hv.Scatter(po_shifts[po_shifts.is_sig == 0][['pref_1', 'pref_2']], kdims=['pref_1'], vdims=['pref_2'], label='inisg').opts(alpha=0.5, color='gray', size=5)

# Identity line plus both point clouds; this rebinds `scatter_PO` from the previous cell
unity_line = hv.Curve((np.arange(-90, 90, 1), np.arange(-90, 90, 1))).opts(color='black')
scatter_PO = unity_line * scater_insig_ori * scatter_sig_ori
scatter_PO.opts(show_legend=False, xlabel=f'{session_types[0].title()} Pref. Ori. [°]', ylabel=f'{session_types[1].title()} Pref. Ori. [°]', width=500, height=500)
scatter_PO.opts(hv.opts.Scatter(xlim=(-90, 90), ylim=(-90, 90), xticks=[-45, 0, 45, 90], yticks=[-90, -45, 0, 45, 90]))

# NOTE(review): same file name as the RMSE scatter in the previous cell - if both are ever
# written to disk the later one overwrites the earlier one.
save_path = os.path.join(figure_save_path, f"delta_PO_{parsed_search['result']}_{parsed_search['rig']}.png")
scatter_PO = fp.save_figure(scatter_PO, save_path=save_path, fig_width=14, dpi=800, fontsize='screen', target='screen', display_factor=0.1)

# scatter_sig_dir = hv.Scatter(pd_shifts[pd_shifts.is_sig == 1][['pref_1', 'pref_2']], kdims=['pref_1'], vdims=['pref_2'], label='sig').opts(color='red', size=5)

# scater_insig_dir = hv.Scatter(pd_shifts[pd_shifts.is_sig == 0][['pref_1', 'pref_2']], kdims=['pref_1'], vdims=['pref_2'], label='inisg').opts(alpha=0.5, color='gray', size=5)

# a = np.arange(-180, 180, 1)
# unity_line = hv.Curve((a, a)).opts(color='black')
# scatter_PD = unity_line * scater_insig_dir * scatter_sig_dir
# scatter_PD.opts(show_legend=False, xlabel=f'{session_types[0].title()} Pref. Dir. [°]', ylabel=f'{session_types[1].title()} Pref. Dir. [°]', width=500, height=500)
# scatter_PD.opts(hv.opts.Scatter(xlim=(-180, 180), ylim=(-180, 180), xticks=[-135, -90, -45, 0, 45, 90, 135, 180], yticks=[-180, -135, -90, -45, 0, 45, 90, 135, 180]))

# save_path = os.path.join(figure_save_path, f"delta_PD_{parsed_search['result']}_{parsed_search['rig']}.png")
# scatter_PD = fp.save_figure(scatter_PD, save_path=save_path, fig_width=14, dpi=800, fontsize='screen', target='save', display_factor=0.1)

# hv.Layout([scatter_PO, scatter_PD]).opts(shared_axes=False).cols(2)

In [None]:
# Number of cells per mouse in the preferred-orientation shift table.
# `tuning_shifts` is the function defined above, not a DataFrame, so the table it returned
# (`po_shifts`, significance column `is_sig`) is used instead.
# NOTE(review): confirm `po_shifts` is the table this cell was meant to summarise.
po_shifts.groupby('mouse').apply(lambda x: x.is_sig.count())

In [None]:
# Histogram of direction responsivity for the first two entries of `datasets`
# NOTE(review): in the visible notebook `datasets` is only assigned in a later cell, and its
# first two entries there are both VWheelWF datasets, which does not match the
# 'fixed' / 'free' legend - confirm which datasets this cell expects.
bins = np.arange(0, 1.01, 0.05)
plt.hist([agg_dict[datasets[0]].responsivity_dir, agg_dict[datasets[1]].responsivity_dir], bins=bins, color=[fp.hv_blue_hex, fp.hv_red_hex], edgecolor="black")
plt.legend(['fixed', 'free'])
# Vertical reference line at 0.25
plt.axvline(0.25, c='r')
plt.title('responsivity - direction')

In [None]:
# Histogram of orientation responsivity for the first two entries of `datasets`
# NOTE(review): in the visible notebook `datasets` is only assigned in a later cell, and its
# first two entries there are both VWheelWF datasets, which does not match the
# 'fixed' / 'free' legend - confirm which datasets this cell expects.
bins = np.arange(0, 1.01, 0.05)
plt.hist([agg_dict[datasets[0]].responsivity_ori, agg_dict[datasets[1]].responsivity_ori], bins=bins, color=[fp.hv_blue_hex, fp.hv_red_hex], edgecolor="black")
plt.legend(['fixed', 'free'])
# Vertical reference line at 0.25
plt.axvline(0.25, c='r')
plt.title('responsivity - orientation')

In [None]:
# Fraction of cells with a significant preferred-orientation shift, overall and per mouse.
# `tuning_shifts` is the function defined above, not a DataFrame, so the table it returned
# (`po_shifts`, significance column `is_sig`) is used instead.
# NOTE(review): confirm `po_shifts` is the table this cell was meant to summarise.
overall_sig_po_shift_frac = po_shifts.is_sig.sum() / po_shifts.is_sig.count()
print(overall_sig_po_shift_frac)
po_shifts.groupby('mouse').apply(lambda x: x.is_sig.sum() / x.is_sig.count())

In [None]:
# Distribution of the shift magnitude for the significantly shifted cells (36 bins).
# `tuning_shifts` is the function defined above, so the returned table `po_shifts`
# (columns `is_sig`, `delta_pref`) is used instead.
# NOTE(review): confirm `po_shifts` is the table this cell was meant to plot.
frequencies, edges = np.histogram(po_shifts[po_shifts.is_sig == 1].delta_pref.to_numpy(), 36)
hv.Histogram((edges, frequencies)).opts(width=600)

In [None]:
# Same histogram as above, one panel per mouse.
# `tuning_shifts` is the function defined above, so the returned table `po_shifts`
# (columns `is_sig`, `delta_pref`) is used instead.
# NOTE(review): confirm `po_shifts` is the table this cell was meant to plot.
x = po_shifts[po_shifts.is_sig == 1].groupby('mouse').apply(lambda x: np.histogram(x.delta_pref.to_numpy(), 36)).to_list()
# np.histogram returns (frequencies, edges); hv.Histogram is given (edges, frequencies)
layout = hv.Layout([hv.Histogram((edges, frequencies)).opts(width=400) for frequencies, edges in x]).opts(shared_axes=False)
layout

In [None]:
%matplotlib inline
datasets = ['VWheelWF_matched_norm_spikes_viewed_still_direction_props', 'VWheelWF_matched_norm_spikes_viewed_still_orientation_props',
            'VTuningWF_matched_norm_spikes_viewed_still_direction_props', 'VTuningWF_matched_norm_spikes_viewed_still_orientation_props']

matched_vis_fig = interactive(plot_compare_all_vis_tuning,
                              ds_list=widgets.fixed([agg_dict[dataset] for dataset in datasets]),
                              cell=tuning_shifts[tuning_shifts.is_sig_po == 1].index, 
                              tuning_kind = widgets.fixed(['direction', 'orientation']),
                              exp_name = widgets.fixed(['VWheelWF', 'VTuningWF']),
                              error = ['std', 'sem'],
                              polar=[True, False],
                              norm=widgets.fixed(True),
                              axes=widgets.fixed(None))
matched_vis_fig

In [None]:
# NOTE(review): leftover ad-hoc lookup. `tuning_shifts` is the function defined above, which
# has no `.loc`; this presumably targeted a shift table (e.g. `po_shifts`) - confirm or remove.
tuning_shifts.loc[67]

In [None]:
import pycircstat as circ

# DSI/OSI Distributions

In [None]:
# Keys for the matched-cell datasets: all viewed trials, and the "still" subset
free_matched = ['VTuningWF_matched_norm_spikes_viewed_props',
                'VTuningWF_matched_norm_spikes_viewed_still_props']
fixed_matched = ['VWheelWF_matched_norm_spikes_viewed_props',
                 'VWheelWF_matched_norm_spikes_viewed_still_props']
# Keep only the first key of each list (still wrapped in a one-element list)
matched_free_dir = free_matched[0:1]
matched_fixed_dir = fixed_matched[0:1]

In [None]:
# Look up the selected datasets in the aggregate dictionary
ds_list_matched_free = [agg_dict[d] for d in matched_free_dir]
ds_list_matched_fixed = [agg_dict[d] for d in matched_fixed_dir]
# Selectivity columns of interest
cols = ['osi', 'dsi_abs']

## Matched Cells

In [None]:
# Per-mouse fraction of matched cells whose |OSI| / |DSI| exceeds 0.5, for each session type
plt_list = []
df_list = []
for ds, exp_type in zip([agg_dict['VTuningWF_matched_norm_spikes_viewed_props'], agg_dict['VWheelWF_matched_norm_spikes_viewed_props']], ['Freely Moving', 'Head Fixed']):
    frac_ori_tuned = []
    frac_dir_tuned = []
    df = pd.DataFrame()
    
    for name, group in ds.groupby('mouse'):
        # Boolean masks; a NaN index compares False, so such cells count as untuned
        # while still contributing to the denominator below
        ori_tuned = np.abs(group.osi) > 0.5
        dir_tuned = np.abs(group.dsi_abs) > 0.5

        # ori_tuned = (np.abs(group.osi) > 0.25) & (np.abs(group.dsi_nasal_temporal) <= 0.25)
        # dir_tuned = (np.abs(group.osi) <= 0.25) & (np.abs(group.dsi_nasal_temporal) > 0.25)
        frac_ori_tuned.append(ori_tuned.sum() / ori_tuned.count())
        frac_dir_tuned.append(dir_tuned.sum() / dir_tuned.count())


    # One row per mouse for this session type
    df['group'] = [exp_type] * len(frac_ori_tuned)
    df['Ori. Tuned'] = frac_ori_tuned
    df['Dir. Tuned'] = frac_dir_tuned
    df_list.append(df)
    # NOTE(review): `plt_list` is filled but not used again in the visible cells
    plt_list.append(df.hvplot.box(legend=False))

# Long format: one row per (mouse, session type, tuning kind)
matched_dsi_osi = pd.concat(df_list, axis=0)
matched_dsi_osi= pd.melt(matched_dsi_osi, id_vars=['group'], value_vars=['Ori. Tuned', 'Dir. Tuned'])

violinplot_matched_dsi_osi = hv.Violin(matched_dsi_osi, ['variable', 'group'], 'value').opts(xlabel='', ylabel='Significant Frac.', 
                                                                  width=1000, height=600,
                                                                  ylim=(0, 1.1),
                                                                  violin_color=hv.dim('group'),
                                                                  cmap=[fp.hv_blue_hex, fp.hv_orange_hex],
                                                                  # show_legend=True, legend_position='right',
                                                                  fontsize={'legend': 10}
                                                                 )
save_path = os.path.join(figure_save_path, "matched_frac_ori_dir_tuned.png")
# violinplot
violinplot_matched_dsi_osi = fp.save_figure(violinplot_matched_dsi_osi, save_path=save_path, fig_width=16, dpi=800, fontsize='screen', target='both', display_factor=0.1)

In [None]:
# Swarm/scatter version of the tuned-fraction plot (`sns_scatter_swarm` is defined outside this view)
ax = sns_scatter_swarm(matched_dsi_osi, "variable", "value", "group")
ax.set_ylabel('Significant Frac.')
# NOTE(review): same file name as the violin plot saved in the previous cell - this overwrites it
save_path = os.path.join(figure_save_path, "matched_frac_ori_dir_tuned.png")
plt.savefig(save_path, dpi=800, format='png')

In [None]:
# Rank-sum (Mann-Whitney U) test of freely moving vs. head fixed for each tuning kind,
# followed by a Bonferroni correction over the resulting p-values
comp_tuning_rank_sum = (
    matched_dsi_osi
    .groupby(['variable', 'group'], group_keys=True)
    .value.agg(list)
    .unstack(level=1)
)
comp_tuning_stats = comp_tuning_rank_sum.apply(
    lambda row: st.mannwhitneyu(row['Freely Moving'], row['Head Fixed'], alternative='two-sided'),
    axis=1,
)
# Element 1 of each test result is the p-value
p_vals = [test_result[1] for test_result in comp_tuning_stats.to_numpy()]
print(comp_tuning_stats)
multipletests(p_vals, alpha=0.05, method='bonferroni')

In [None]:
# Per-cell |OSI| and |DSI| of the matched cells for each session type
plt_list = []
df_list = []

for ds, exp_type in zip([agg_dict['VTuningWF_matched_norm_spikes_viewed_props'], agg_dict['VWheelWF_matched_norm_spikes_viewed_props']], ['Freely Moving', 'Head Fixed']):
    df = pd.DataFrame()
    df['OSI'] = ds.osi.abs().to_numpy()
    df['DSI'] = ds.dsi_abs.abs().to_numpy()
    df['group'] = [exp_type] * len(ds.osi)
    df_list.append(df)
    
    # plt_list.append(df.hvplot.box(legend=False))

# Long format: one row per (cell, session type, index kind)
matched_selectivity = pd.concat(df_list, axis=0)
matched_selectivity = pd.melt(matched_selectivity, id_vars=['group'], value_vars=['OSI', 'DSI'])

violinplot_selectivity_matched = hv.Violin(matched_selectivity, ['variable', 'group'], 'value').opts(xlabel='', ylabel='Selectivity', 
                                                                  width=1000, height=600,
                                                                  ylim=(-0.1, 1.1),
                                                                  violin_color=hv.dim('group'),
                                                                  cmap=[fp.hv_blue_hex, fp.hv_orange_hex],
                                                                  # show_legend=True, legend_position='right',
                                                                  fontsize={'legend': 10}
                                                                 )
save_path = os.path.join(figure_save_path, "matched_selectivity.png")
violinplot_selectivity_matched = fp.save_figure(violinplot_selectivity_matched, save_path=save_path, fig_width=16, dpi=800, fontsize='screen', target='both', display_factor=0.1)

In [None]:
# Swarm/scatter version of the selectivity plot (`sns_scatter_swarm` is defined outside this view)
ax = sns_scatter_swarm(matched_selectivity.dropna(), "variable", "value", "group")
# The plotted values are the OSI / DSI themselves, so label the axis like the violin plot
# of the same data (the previous 'Significant Frac.' label was carried over from the fraction plot)
ax.set_ylabel('Selectivity')
# NOTE(review): same file name as the violin plot of this data - this overwrites it
save_path = os.path.join(figure_save_path, "matched_selectivity.png")
plt.savefig(save_path, dpi=800, format='png')

In [None]:
def replace_inf(x):
    """Cap the values of ``x`` at its 99th percentile, in place, and return ``x``.

    Every element strictly greater than ``np.percentile(x, 99)`` is overwritten
    with that percentile. The input object itself is modified and returned.
    """
    upper_cap = np.percentile(x, 99)
    above_cap = x > upper_cap
    x[above_cap] = upper_cap
    return x

# OSI values per session type; missing values are replaced by 0 before binning
a = matched_selectivity[matched_selectivity.variable == 'OSI'].fillna(0).drop(columns='variable')
# One column per group, each holding the list of values in row 0
a = a.groupby('group').agg(list).T.reset_index(drop=True)


# 20-bin histograms of the values clipped to [0, 1]
frequencies_free_osi, edges_free_osi = np.histogram(np.clip(a['Freely Moving'][0], 0, 1)	, 20)
frequencies_fixed_osi, edges_fixed_osi = np.histogram(np.clip(a['Head Fixed'][0], 0, 1), 20)

cell_match_hist_osi = hv.Overlay([hv.Histogram((frequencies_free_osi, edges_free_osi), label='Freely Moving').opts(alpha=0.5, fill_color=fp.hv_blue_rgb), 
                                  hv.Histogram((frequencies_fixed_osi, edges_fixed_osi), label='Head Fixed').opts(alpha=0.5, fill_color=fp.hv_orange_rgb)])

# Same procedure for the DSI values
b = matched_selectivity[matched_selectivity.variable == 'DSI'].fillna(0).drop(columns='variable')
b = b.groupby('group').agg(list).T.reset_index(drop=True)

frequencies_free_dsi, edges_free_dsi = np.histogram(np.clip(b['Freely Moving'][0], 0, 1), 20)
frequencies_fixed_dsi, edges_fixed_dsi = np.histogram(np.clip(b['Head Fixed'][0], 0, 1), 20)

cell_match_hist_dsi = hv.Overlay([hv.Histogram((frequencies_free_dsi, edges_free_dsi)).opts(alpha=0.5, fill_color=fp.hv_blue_rgb), 
                                  hv.Histogram((frequencies_fixed_dsi, edges_fixed_dsi)).opts(alpha=0.5, fill_color=fp.hv_orange_rgb)])

# DSI panel on the left, OSI panel (with the legend) on the right
layout_matched_still = cell_match_hist_dsi.opts(height=300, width=400, xlabel='DSI') + cell_match_hist_osi.opts(height=300, width=500, xlabel='OSI', ylabel='', legend_position='right', fontsize={'legend': 10})
layout_matched_still
# save_path = os.path.join(figure_save_path, "Fig4", "unmatched_dsi_osi_hist.png")
# layout_matched_still = fp.save_figure(layout_matched_still, save_path=save_path, fig_width=18, dpi=1000, fontsize='poster', target='both', display_factor=0.1)

In [None]:
# Identity line on [0, 1]
unity_line = hv.Curve(zip(np.linspace(0, 1), np.linspace(0,1))).opts(color='black')

# Head-fixed (x) against freely-moving (y) OSI, values clipped to [0, 1].
# The two lists are paired by position - presumably the same cell order in both datasets; verify.
scatter_OSI = hv.Scatter(zip(np.clip(a['Head Fixed'][0], 0, 1), np.clip(a['Freely Moving'][0], 0, 1)), kdims=['OSI_fixed'], vdims=['OSI_free']).opts(xlabel='Fixed OSI', ylabel='Free OSI', size=5, width=500, height=500)
scatter_OSI = scatter_OSI * unity_line

# Same for the DSI
scatter_DSI = hv.Scatter(zip(np.clip(b['Head Fixed'][0], 0, 1), np.clip(b['Freely Moving'][0], 0, 1)), kdims=['DSI_fixed'], vdims=['DSI_free']).opts(xlabel='Fixed DSI', ylabel='Free DSI', size=5, width=500, height=500)
scatter_DSI = scatter_DSI * unity_line

# NOTE(review): relies on `a`, `b` from the histogram cell above and on `parsed_search`
# having 'result' and 'rig' entries
save_path = os.path.join(figure_save_path, f"delta_OSI_{parsed_search['result']}_{parsed_search['rig']}.png")
scatter_OSI = fp.save_figure(scatter_OSI, save_path=save_path, fig_width=16, dpi=800, fontsize='screen', target='save', display_factor=0.1)

save_path = os.path.join(figure_save_path, f"delta_DSI_{parsed_search['result']}_{parsed_search['rig']}.png")
scatter_DSI = fp.save_figure(scatter_DSI, save_path=save_path, fig_width=16, dpi=800, fontsize='screen', target='save', display_factor=0.1)

scatter_OSI + scatter_DSI


## Unmatched Cells

In [None]:
# Per-mouse fraction of unmatched cells whose |OSI| / |DSI| exceeds 0.75, for each session type
# NOTE(review): the matched-cell version of this cell uses a 0.5 threshold, and the two
# results are compared against each other further down - confirm the thresholds are meant to differ.
plt_list = []
df_list = []
for ds, exp_type in zip([agg_dict['VTuningWF_unmatched_norm_spikes_viewed_props'], agg_dict['VWheelWF_unmatched_norm_spikes_viewed_props']], 
                        ['Freely Moving', 'Head Fixed']):
    frac_ori_tuned = []
    frac_dir_tuned = []
    df = pd.DataFrame()
    
    for name, group in ds.groupby('mouse'):
        # Boolean masks; a NaN index compares False and still counts in the denominator
        ori_tuned = np.abs(group.osi) > 0.75
        dir_tuned = np.abs(group.dsi_abs) > 0.75
        frac_ori_tuned.append(ori_tuned.sum() / ori_tuned.count())
        frac_dir_tuned.append(dir_tuned.sum() / dir_tuned.count())


    # One row per mouse for this session type
    df['group'] = [exp_type] * len(frac_ori_tuned)
    df['Ori. Tuned'] = frac_ori_tuned
    df['Dir. Tuned'] = frac_dir_tuned
    df_list.append(df)
    plt_list.append(df.hvplot.box(legend=False))

# Long format: one row per (mouse, session type, tuning kind)
unmatched_dsi_osi = pd.concat(df_list, axis=0)
unmatched_dsi_osi= pd.melt(unmatched_dsi_osi, id_vars=['group'], value_vars=['Ori. Tuned', 'Dir. Tuned'])

violinplot_dsi_osi_unmatched = hv.Violin(unmatched_dsi_osi, ['variable', 'group'], 'value').opts(xlabel='', ylabel='Significant Frac.', 
                                                                  width=1000, height=400,
                                                                  ylim=(0, 1),
                                                                  violin_color=hv.dim('group'),
                                                                  cmap=[fp.hv_blue_hex, fp.hv_orange_hex],
                                                                  # show_legend=True, legend_position='right',
                                                                  fontsize={'legend': 10}
                                                                 )
save_path = os.path.join(figure_save_path, "unmatched_frac_ori_dir_tuned.png")
# violinplot
violinplot_dsi_osi_unmatched = fp.save_figure(violinplot_dsi_osi_unmatched, save_path=save_path, fig_width=16, dpi=1000, fontsize='screen', target='both', display_factor=0.1)

In [None]:
# Rank-sum (Mann-Whitney U) test of freely moving vs. head fixed for each tuning kind,
# followed by a Bonferroni correction over the resulting p-values
comp_tuning_rank_sum = (
    unmatched_dsi_osi
    .groupby(['variable', 'group'], group_keys=True)
    .value.agg(list)
    .unstack(level=1)
)
comp_tuning_stats = comp_tuning_rank_sum.apply(
    lambda row: st.mannwhitneyu(row['Freely Moving'], row['Head Fixed'], alternative='two-sided'),
    axis=1,
)
# Element 1 of each test result is the p-value
p_vals = [test_result[1] for test_result in comp_tuning_stats.to_numpy()]
print(comp_tuning_stats)
multipletests(p_vals, alpha=0.05, method='bonferroni')

In [None]:
# Per-cell |OSI| and |DSI| of the unmatched cells for each session type
plt_list = []
df_list = []

for ds, exp_type in zip([agg_dict['VTuningWF_unmatched_norm_spikes_viewed_props'], agg_dict['VWheelWF_unmatched_norm_spikes_viewed_props']], 
                        ['Freely Moving', 'Head Fixed']):
    df = pd.DataFrame()
    df['OSI'] = ds.osi.abs().to_numpy()
    df['DSI'] = ds.dsi_abs.abs().to_numpy()
    df['group'] = [exp_type] * len(ds.osi)
    df_list.append(df)
    
    # plt_list.append(df.hvplot.box(legend=False))

# Long format: one row per (cell, session type, index kind)
unmatched_selectivity = pd.concat(df_list, axis=0)
unmatched_selectivity = pd.melt(unmatched_selectivity, id_vars=['group'], value_vars=['OSI', 'DSI'])

# NOTE(review): unlike the matched-cell violin plot, no `cmap` is set here and the figure is
# saved with fontsize='poster' - confirm whether the two plots should be styled alike.
violinplot_selectivity_unmatched = hv.Violin(unmatched_selectivity, ['variable', 'group'], 'value').opts(xlabel='', ylabel='Selectivity', 
                                                                  width=1000, height=400,
                                                                  ylim=(-0.1, 1.1),
                                                                  violin_color=hv.dim('group'),
                                                                  # show_legend=True, legend_position='right',
                                                                  fontsize={'legend': 10}
                                                                 )
save_path = os.path.join(figure_save_path, "unmatched_selectivity.png")
violinplot_selectivity_unmatched = fp.save_figure(violinplot_selectivity_unmatched, save_path=save_path, fig_width=16, dpi=1000, fontsize='poster', target='both', display_factor=0.1)

In [None]:
# OSI values per session type; missing values are replaced by 0 before binning.
# This rebinds `a` / `b` and the frequency / edge variables used by the matched-cell cells above.
a = unmatched_selectivity[unmatched_selectivity.variable == 'OSI'].fillna(0).drop(columns='variable')
# One column per group, each holding the list of values in row 0
a = a.groupby('group').agg(list).T.reset_index(drop=True)

# 20-bin histograms (values are not clipped here, unlike in the matched-cell version)
frequencies_free_osi, edges_free_osi = np.histogram(a['Freely Moving'][0], 20)
frequencies_fixed_osi, edges_fixed_osi = np.histogram(a['Head Fixed'][0], 20)

unmatched_hist_osi = hv.Overlay([hv.Histogram((frequencies_free_osi, edges_free_osi), label='Freely Moving').opts(alpha=0.5, fill_color=fp.hv_blue_rgb), 
                                  hv.Histogram((frequencies_fixed_osi, edges_fixed_osi), label='Head Fixed').opts(alpha=0.5, fill_color=fp.hv_orange_rgb)])

# Same procedure for the DSI values
b = unmatched_selectivity[unmatched_selectivity.variable == 'DSI'].fillna(0).drop(columns='variable')
b = b.groupby('group').agg(list).T.reset_index(drop=True)

frequencies_free_dsi, edges_free_dsi = np.histogram(b['Freely Moving'][0], 20)
frequencies_fixed_dsi, edges_fixed_dsi = np.histogram(b['Head Fixed'][0], 20)

unmatched_hist_dsi = hv.Overlay([hv.Histogram((frequencies_free_dsi, edges_free_dsi)).opts(alpha=0.5, fill_color=fp.hv_blue_rgb), 
                                  hv.Histogram((frequencies_fixed_dsi, edges_fixed_dsi)).opts(alpha=0.5, fill_color=fp.hv_orange_rgb)])

# DSI panel on the left, OSI panel (with the legend) on the right
layout_unmatched = unmatched_hist_dsi.opts(height=300, width=400, xlabel='DSI') + unmatched_hist_osi.opts(height=300, width=500, xlabel='OSI', ylabel='', legend_position='right', fontsize={'legend': 10})
layout_unmatched

## Compare matched/unmatched strongly tuned

In [None]:
# Tag both long-format tables (adds a column to the existing frames in place) and stack them;
# missing selectivity values are replaced by 0
matched_selectivity['super_group'] = 'matched'
unmatched_selectivity['super_group'] = 'unmatched'
all_selectivity = pd.concat([matched_selectivity, unmatched_selectivity]).fillna(0)
violinplot_selectivity = hv.Violin(all_selectivity, ['variable', 'group', 'super_group'], 'value').opts(xlabel='', ylabel='Selectivity', 
                                                                  width=1000, height=200,
                                                                  ylim=(-0.05, 1.05),
                                                                  violin_color=hv.dim('super_group'),
                                                                  cmap='Category10',
                                                                  # show_legend=True, legend_position='right',
                                                                  fontsize={'legend': 10}
                                                                 )

# Do a Wilcoxon Rank-Sum  with Bonferroni correction
# One row per (variable, group); the columns 'matched' / 'unmatched' hold the value lists
rank_sum = all_selectivity.groupby(['variable', 'group', 'super_group'], group_keys=True).value.agg(list).unstack(level=-1)
stats = rank_sum.apply(lambda x: st.ranksums(x['matched'], x['unmatched'], alternative='two-sided'), axis=1)
# Element 1 of each test result is the p-value
p_vals = [stats.to_numpy()[i][1] for i in np.arange(stats.shape[0])]
print(p_vals)
print(multipletests(p_vals, alpha=0.05, method='bonferroni'))

save_path = os.path.join(figure_save_path, "selectivity_dist_comp.png")
violinplot_selectivity = fp.save_figure(violinplot_selectivity, save_path=save_path, fig_width=30, dpi=1200, fontsize='screen', target='screen', display_factor=0.1)

In [None]:
# Display the combined matched / unmatched selectivity table
all_selectivity

In [None]:
# Overlaid matched / unmatched histograms for every (variable, group) combination
save_paths = []
# One row per (variable, group, super_group); each row holds the list of values
selectivity_array = all_selectivity.groupby(['variable', 'group', 'super_group']).agg(list).values
# NOTE(review): `a` is whatever an earlier cell left behind, and `num_comps` / `save_paths`
# are not used in the visible cells - confirm or remove.
num_comps = a.shape[0]/2

selectivity_plot_list = []
# Even rows are paired with the following odd rows (matched, unmatched); this loop rebinds `a` and `b`
for i, (a, b) in enumerate(zip(selectivity_array[::2], selectivity_array[1::2])):
    save_path = os.path.join(figure_save_path, f"selectivity_dist_comp_histogram_{i}.png")
    
    # 20-bin histograms of the clipped values, each scaled so its tallest bin equals 1
    # NOTE(review): the y label says "Probability" but the counts are max-normalised
    frequencies_matched, edges_matched = np.histogram(np.clip(a[0], 0, 1), 20)
    frequencies_matched = frequencies_matched.astype(float) / np.max(frequencies_matched)
    frequencies_unmatched, edges_unmatched = np.histogram(np.clip(b[0], 0, 1), 20)
    frequencies_unmatched = frequencies_unmatched.astype(float) / np.max(frequencies_unmatched.astype(float))

    matched_hist = hv.Histogram((frequencies_matched, edges_matched), label='matched').opts(alpha=0.5, xlabel='', width=800, height=600, fill_color=fp.hv_mpi_green_rgb, ylabel="Probability")
    unmatched_hist = hv.Histogram((frequencies_unmatched, edges_unmatched), label='unmatched').opts(alpha=0.5, xlabel='', width=800, height=600, fill_color=fp.hv_mpi_yellow_rgb)
    
    # Only the first panel keeps its y axis; the others are drawn narrower without it
    if i > 0:
        matched_hist.opts(yaxis=None, width=600)
        unmatched_hist.opts(yaxis=None, width=600)
    
    hist_overlay = hv.Overlay([matched_hist, unmatched_hist])
    hist_overlay.opts(show_legend=False)

    if i == 0:
        hist_overlay = fp.save_figure(hist_overlay, save_path=save_path, fig_width=8, dpi=1200, fontsize='screen', target='screen', display_factor=0.1)
    else:
        hist_overlay = fp.save_figure(hist_overlay, save_path=save_path, fig_width=6, dpi=1200, fontsize='screen', target='screen', display_factor=0.1)

    selectivity_plot_list.append(hist_overlay)

selectivity_layout = hv.Layout(selectivity_plot_list)
selectivity_layout

In [None]:
# Tag both per-mouse fraction tables (adds a column to the existing frames in place) and stack them
# NOTE(review): the matched fractions were computed with a 0.5 threshold and the unmatched
# ones with 0.75 in the cells above - confirm the comparison below is meant to mix the two.
matched_dsi_osi['super_group'] = 'matched'
unmatched_dsi_osi['super_group'] = 'unmatched'
all_dsi_osi = pd.concat([matched_dsi_osi, unmatched_dsi_osi]).fillna(0)
violinplot = hv.Violin(all_dsi_osi, ['variable', 'group', 'super_group'], 'value').opts(xlabel='', ylabel='Significant Frac.', 
                                                                  width=1000, height=200,
                                                                  ylim=(-0.05, 1),
                                                                  violin_color=hv.dim('super_group'),
                                                                  cmap=[fp.hv_mpi_green_hex, fp.hv_mpi_yellow_hex],
                                                                  fontsize={'legend': 10}
                                                                 )

# Do a Wilcoxon Rank-Sum  with Bonferroni correction
# One row per (variable, group); the columns 'matched' / 'unmatched' hold the value lists
rank_sum = all_dsi_osi.groupby(['variable', 'group', 'super_group'], group_keys=True).value.agg(list).unstack(level=-1)
stats = rank_sum.apply(lambda x: st.ranksums(x['matched'], x['unmatched'], alternative='two-sided'), axis=1)
# Element 1 of each test result is the p-value
p_vals = [stats.to_numpy()[i][1] for i in np.arange(stats.shape[0])]
print(p_vals)
print(multipletests(p_vals, alpha=0.05, method='bonferroni'))

save_path = os.path.join(figure_save_path, "sig_frac_selectivity_dsi_osi.png")
violinplot = fp.save_figure(violinplot, save_path=save_path, fig_width=30, dpi=1200, fontsize='poster', target='both', display_factor=0.1)

In [None]:
# One row per (variable, group, super_group); each row holds the list of per-mouse fractions
sig_dsi_osi_array = all_dsi_osi.groupby(['variable', 'group', 'super_group']).agg(list).values
# NOTE(review): `a` is whatever an earlier cell left behind, and `num_comps` is not used in
# the visible cells - confirm or remove.
num_comps = a.shape[0]/2

dsi_osi_plot_list = []
# Even rows are paired with the following odd rows (matched, unmatched); this loop rebinds `a` and `b`
for a, b in zip(sig_dsi_osi_array[::2], sig_dsi_osi_array[1::2]):
    frequencies_matched, edges_matched = np.histogram(a[0], 20)
    frequencies_unmatched, edges_unmatched = np.histogram(b[0], 20)

    hist_overlay =  hv.Overlay([hv.Histogram((frequencies_unmatched, edges_unmatched), label='unmatched').opts(alpha=0.5, fill_color=fp.hv_orange_rgb), 
                                hv.Histogram((frequencies_matched, edges_matched), label='matched').opts(alpha=0.5, fill_color=fp.hv_blue_rgb)])

    dsi_osi_plot_list.append(hist_overlay)

dsi_osi_layout = hv.Layout(dsi_osi_plot_list)
dsi_osi_layout

# UMAP

In [None]:
from umap.umap_ import UMAP
from rastermap import Rastermap
import sklearn.preprocessing as preproc

In [None]:
# Collect one tuning column per predictor for the chosen cell population.
cell_kind = 'all_cells'   # options: 'all_cells', 'matched', 'unmatched'
kinem_label_list = processing_parameters.variable_list_free + processing_parameters.variable_list_fixed
labels = [f"{cell_kind}_{label}" for label in kinem_label_list + ['norm_spikes_viewed_props']]

umap_dict = {}
for label in labels:
    for key, ds in agg_dict.items():
        if label not in key:
            continue

        # Drop the cell_kind prefix to recover the bare variable name.
        base_label = '_'.join(label.split('_')[len(cell_kind.split('_')):])
        if base_label in kinem_label_list:
            # Kinematic predictors contribute their 'Qual_index' column.
            umap_dict[base_label] = ds['Qual_index']
        else:
            # The remaining dataset contributes DSI / OSI, clipped to [0, 1].
            # Every matching key writes the same two columns, so the last wins.
            tuning_dsi, tuning_osi = ds['dsi_abs'], ds['osi']
            umap_dict['dsi'] = np.clip(tuning_dsi, 0, 1)
            umap_dict['osi'] = np.clip(tuning_osi, 0, 1)


# Missing entries are treated as "no tuning".
raw_tunings = pd.DataFrame.from_dict(umap_dict).fillna(0)

In [None]:
# TEST - Run UMAP on matched cells
cell_kind = 'matched'   # options: 'all_cells', 'matched', 'unmatched'

# Restrict the kinematic predictors to the rig that was queried; fall back to
# both variable lists for any other search.
if (parsed_search['result'] == 'repeat') and (parsed_search['rig'] == 'VWheelWF'):
    kinem_label_list = processing_parameters.variable_list_fixed
elif (parsed_search['result'] == 'repeat') and (parsed_search['rig'] == 'VTuningWF'):
    kinem_label_list = processing_parameters.variable_list_free
else:
    kinem_label_list = processing_parameters.variable_list_free + processing_parameters.variable_list_fixed

label_list = kinem_label_list + [activity_dataset]
labels = [f"_{cell_kind}_{label}" for label in label_list]

umap_dict = {}
for label in labels:
    data_keys = [key for key in agg_dict.keys() if label in key]

    for key in data_keys:
        ds = agg_dict[key]
        # Strip the leading '_' and the cell_kind tokens to get the variable name.
        base_label = '_'.join(label.split('_')[1 + len(cell_kind.split('_')):])

        if base_label in kinem_label_list:
            tuning = ds['Qual_index']
            umap_dict[base_label] = tuning
        else:
            # The rig name is the first token of the dataset key.
            this_rig = key.split('_')[0]
            tuning_dsi = ds['dsi_abs'].abs().to_numpy()
            tuning_osi = ds['osi'].to_numpy()
            umap_dict[f'dsi_{this_rig}'] = np.clip(tuning_dsi, 0, 1)
            umap_dict[f'osi_{this_rig}'] = np.clip(tuning_osi, 0, 1)

if not umap_dict:
    raise ValueError(f"No datasets in agg_dict matched any of the labels: {labels}")

# Pad every column with NaN up to the longest one so the DataFrame can be built.
# The previous guard compared len(key) (characters in the column NAME) with the
# row count, so whether a short column got padded depended on its name. Padding
# unconditionally keeps every column a plain positional float array.
max_col_length = max(len(col) for col in umap_dict.values())
for key, value in list(umap_dict.items()):
    new_val = np.full(max_col_length, np.nan)
    new_val[:len(value)] = value
    umap_dict[key] = new_val

raw_tunings = pd.DataFrame.from_dict(umap_dict)
# Missing / padded entries are treated as "no tuning".
raw_tunings = raw_tunings.fillna(0)

In [None]:
# Standardize each tuning column before embedding.
tuning_scaler = preproc.StandardScaler()
tunings = tuning_scaler.fit_transform(raw_tunings.to_numpy())

In [None]:
# Draw every row exactly once, in random order (a full-size sample without replacement).
tuning_row_count = tunings.shape[0]
subset_row_order = np.random.choice(tuning_row_count, tuning_row_count, replace=False)
tuning_subset = tunings[subset_row_order]

In [None]:
# Fit UMAP on the shuffled, standardized tunings and keep the embedding.
reducer1 = UMAP(n_neighbors=50, min_dist=0.1)
embedded_data1 = reducer1.fit_transform(tuning_subset)

In [None]:
# One UMAP scatter per predictor, coloured by that predictor's (absolute,
# percentile-clipped) standardized tuning value.
perc = 99
predictor_columns = umap_dict.keys()    #['dsi', 'osi', 'responsivity']
plot_list = []

for i, predictor_column in enumerate(predictor_columns):
    # Column i of the tuning matrix belongs to the i-th key of umap_dict.
    label_idx = [i]
    raw_labels = tuning_subset[:, label_idx]

    raw_labels = np.abs(raw_labels)

    # Clamp both tails so a few extreme cells do not dominate the colour scale.
    # The lower bound is deliberately computed after the upper clamp.
    upper_bound = np.percentile(raw_labels, perc)
    raw_labels[raw_labels > upper_bound] = upper_bound
    lower_bound = np.percentile(raw_labels, 100 - perc)
    raw_labels[raw_labels < lower_bound] = lower_bound

    plot_data = np.concatenate([embedded_data1, raw_labels.reshape((-1, 1))], axis=1)

    umap_plot = hv.Scatter(plot_data, vdims=['Dim 2', 'Parameter'], kdims=['Dim 1'])
    # umap_plot = hv.HexTiles(umap_data, kdims=['Dim 1', 'Dim 2'])
    # The original if/else on the last panel had identical branches, so it is
    # collapsed into one call.
    umap_plot.opts(colorbar=False, color='Parameter', cmap='Spectral_r', alpha=1, xaxis=None, yaxis=None, tools=['hover'])

    if any([predictor_column.startswith('dsi'), predictor_column.startswith('osi')]):
        # DSI / OSI columns are named '<dsi|osi>_<rig>'; the title uses the rig's label.
        umap_plot.opts(title=f"{predictor_column[:3].upper()} {processing_parameters.wf_label_dictionary[predictor_column.split('_')[-1]]}")
    else:
        umap_plot.opts(title=processing_parameters.wf_label_dictionary[predictor_column])

    # Size every panel. This call used to sit inside the else branch, which left
    # the DSI / OSI panels at their default size.
    umap_plot.opts(width=300, height=300, size=2)

    save_name = os.path.join(figure_save_path, f"UMAP_{predictor_column}_{cell_kind}.png")
    umap_plot = fp.save_figure(umap_plot, save_path=save_name, fig_width=6, dpi=1000, fontsize='screen', target='screen', display_factor=0.1)
    plot_list.append(umap_plot)

In [None]:
# Tile the collected UMAP panels, five per row, and display them.
layout = hv.Layout(plot_list)
layout = layout.cols(5)
layout

In [None]:
# Larger stand-alone UMAP scatter with a colorbar.
# plot_data is not built in this cell; it is reused from an earlier cell.
scatter_style = dict(colorbar=True, color='Parameter', cmap='Spectral_r', tools=['hover'], alpha=1)
scatter_size = dict(width=500, height=500, size=5)

umap_plot = hv.Scatter(plot_data, kdims=['Dim 1'], vdims=['Dim 2', 'Parameter'])
# umap_plot = hv.HexTiles(plot_data, vdims=['Dim 2', 'Parameter'], kdims=['Dim 1'])
umap_plot.opts(**scatter_style)
umap_plot.opts(**scatter_size)

# Saving is currently disabled:
# save_name = os.path.join(save_path, '_'.join(('poster', 'UMAP')) + '.png')
# fig = fp.save_figure(umap_plot, save_name, fig_width=15, dpi=1200, fontsize='poster', target='screen')

In [None]:
# plot the tunings from MINE
# NOTE(review): predictor_columns can contain 'dsi_<rig>' / 'osi_<rig>' entries,
# which the UMAP cells look up by rig suffix only -- confirm wf_label_dictionary
# also has keys for the full column names used here.
ticks = [(idx+0.5, processing_parameters.wf_label_dictionary[el]) for idx, el in enumerate(predictor_columns)]
# Transposed copy of the tuning table; the copy keeps raw_tunings untouched.
plot_matrix = raw_tunings.dropna().to_numpy().copy().T
# Zero out values below 0.05.
plot_matrix[plot_matrix<0.05] = 0
# Reorder the rows of the matrix with Rastermap's sorting.
model = Rastermap(n_clusters=2, n_PCs=200)
model.fit(plot_matrix)
plot_matrix = plot_matrix[model.isort, :]
plot = hv.Raster(plot_matrix)
plot.opts(width=1000, height=600, cmap='RdBu_r', tools=['hover'], clim=(-1, 1), xticks=ticks, xrotation=45, xlabel='', ylabel='Cells', colorbar=True)

# assemble the file name
# Join onto the figure directory: `save_path` holds the PNG file path of an
# earlier figure, so joining onto it produced a path nested under a file name.
save_name = os.path.join(figure_save_path, '_'.join(('Fig5', 'MINE_tunings')) + '.png')
# save the figure
fig = fp.save_figure(plot, save_name, fig_width=7, dpi=1200, fontsize='poster', target='screen')

In [None]:
# Run UMAP separately for each cell population and save one scatter per predictor.
cell_kinds = ['all_cells', 'matched', 'unmatched']  # options: 'all_cells', 'matched', 'unmatched'
reducer = UMAP(min_dist=0.1, n_neighbors=50)
for cell_kind in cell_kinds:
    labels = [f"_{cell_kind}_{label}" for label in label_list]

    # If the cells are matched, we always have the same number of cells
    if cell_kind == "matched":
        umap_dict = {}
        for label in labels:
            data_keys = [key for key in agg_dict.keys() if label in key]

            for key in data_keys:
                ds = agg_dict[key]
                # Strip the leading '_' and the cell_kind tokens to get the variable name.
                base_label = '_'.join(label.split('_')[1 + len(cell_kind.split('_')):])

                if base_label in kinem_label_list:
                    tuning = ds['Qual_index']
                    umap_dict[base_label] = tuning
                else:
                    # The rig name is the first token of the dataset key.
                    this_rig = key.split('_')[0]
                    tuning_dsi = ds['dsi_abs'].abs().to_numpy()
                    tuning_osi = ds['osi'].to_numpy()
                    umap_dict[f'dsi_{this_rig}'] = np.clip(tuning_dsi, 0, 1)
                    umap_dict[f'osi_{this_rig}'] = np.clip(tuning_osi, 0, 1)

        raw_tunings = pd.DataFrame.from_dict(umap_dict)
        raw_tunings = raw_tunings.fillna(0)

        tunings = preproc.StandardScaler().fit_transform(raw_tunings.to_numpy())

        # perform umap on the fit cell tuning
        embedded_data = reducer.fit_transform(tunings)

        # A single embedding covers every predictor.
        embedding_sets = [(umap_dict, embedded_data, tunings)]

    else:
        # Here there may be uneven numbers of cells between sessions, so the
        # predictors are split into two tables that are embedded independently:
        # *_1 holds the variable_list_free / VTuningWF columns, *_2 the rest.
        umap_dict_1 = {}
        umap_dict_2 = {}

        for label in labels:
            data_keys = [key for key in agg_dict.keys() if label in key]

            for key in data_keys:
                ds = agg_dict[key]
                base_label = '_'.join(label.split('_')[1 + len(cell_kind.split('_')):])

                if base_label in kinem_label_list:
                    tuning = ds['Qual_index']

                    if base_label in processing_parameters.variable_list_free:
                        umap_dict_1[base_label] = tuning
                    else:
                        umap_dict_2[base_label] = tuning
                else:
                    this_rig = key.split('_')[0]
                    tuning_dsi = ds['dsi_abs'].abs().to_numpy()
                    tuning_osi = ds['osi'].to_numpy()

                    if this_rig == 'VTuningWF':
                        umap_dict_1[f'dsi_{this_rig}'] = np.clip(tuning_dsi, 0, 1)
                        umap_dict_1[f'osi_{this_rig}'] = np.clip(tuning_osi, 0, 1)
                    else:
                        umap_dict_2[f'dsi_{this_rig}'] = np.clip(tuning_dsi, 0, 1)
                        umap_dict_2[f'osi_{this_rig}'] = np.clip(tuning_osi, 0, 1)

        raw_tunings_1 = pd.DataFrame.from_dict(umap_dict_1)
        raw_tunings_1 = raw_tunings_1.fillna(0)
        raw_tunings_2 = pd.DataFrame.from_dict(umap_dict_2)
        raw_tunings_2 = raw_tunings_2.fillna(0)

        tunings_1 = preproc.StandardScaler().fit_transform(raw_tunings_1.to_numpy())
        tunings_2 = preproc.StandardScaler().fit_transform(raw_tunings_2.to_numpy())

        # perform umap on the fit cell tuning
        embedded_data_1 = reducer.fit_transform(tunings_1)
        embedded_data_2 = reducer.fit_transform(tunings_2)

        embedding_sets = [(umap_dict_1, embedded_data_1, tunings_1),
                          (umap_dict_2, embedded_data_2, tunings_2)]

    perc = 99
    # plot_list is reset for every cell kind, so after the loop it only holds the
    # panels of the last entry in cell_kinds.
    plot_list = []

    # Shared plotting loop. Both branches previously carried their own copy of
    # it, and neither appended to plot_list, so the layout built afterwards was
    # always empty.
    for umap_dict, embedded_data, tunings in embedding_sets:

        predictor_columns = umap_dict.keys()

        for i, predictor_column in enumerate(predictor_columns):
            # Column i of the tuning matrix belongs to the i-th key of umap_dict.
            label_idx = [i]
            raw_labels = tunings[:, label_idx]

            raw_labels = np.abs(raw_labels)

            # Clamp both tails so a few extreme cells do not dominate the colour
            # scale. The lower bound is deliberately computed after the upper clamp.
            upper_bound = np.percentile(raw_labels, perc)
            raw_labels[raw_labels > upper_bound] = upper_bound
            lower_bound = np.percentile(raw_labels, 100 - perc)
            raw_labels[raw_labels < lower_bound] = lower_bound

            plot_data = np.concatenate([embedded_data, raw_labels.reshape((-1, 1))], axis=1)

            umap_plot = hv.Scatter(plot_data, vdims=['Dim 2', 'Parameter'], kdims=['Dim 1'])
            # umap_plot = hv.HexTiles(umap_data, kdims=['Dim 1', 'Dim 2'])
            # The original if/else on the last panel had identical branches.
            umap_plot.opts(colorbar=False, color='Parameter', cmap='Spectral_r', alpha=1, xaxis=None,
                           yaxis=None,
                           tools=['hover'])

            if any([predictor_column.startswith('dsi'), predictor_column.startswith('osi')]):
                # DSI / OSI columns are named '<dsi|osi>_<rig>'; the title uses the rig's label.
                umap_plot.opts(title=f"{predictor_column[:3].upper()} "
                                     f"{processing_parameters.wf_label_dictionary[predictor_column.split('_')[-1]]}")
            else:
                umap_plot.opts(title=processing_parameters.wf_label_dictionary[predictor_column])

            umap_plot.opts(width=300, height=300, size=2)

            save_name = os.path.join(figure_save_path, f"{cell_kind}_UMAP_{predictor_column}.png")
            umap_plot = fp.save_figure(umap_plot, save_path=save_name, fig_width=6, dpi=800, fontsize='paper',
                                       target='screen', display_factor=0.1)
            plot_list.append(umap_plot)

In [None]:
# Tile the collected UMAP panels, five per row, and display them.
layout = hv.Layout(plot_list)
layout = layout.cols(5)
layout