In [None]:
%load_ext autoreload
%autoreload 2

# Imports
import numpy as np
import pandas as pd
import panel as pn
import holoviews as hv
import datashader as dshade
import ipywidgets as widgets
from ipywidgets import interact, interactive

import os
import sys
import random
import importlib
import datetime
import warnings
import math
warnings.filterwarnings('ignore')

from bokeh.resources import INLINE
from bokeh.io import export_svgs, export_png
from holoviews import opts, dim
from holoviews.operation import histogram
from holoviews.operation.datashader import datashade, shade
hv.extension('bokeh')
# hv.extension('matplotlib')
import matplotlib.pyplot as plt


sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
%aimport paths
import processing_parameters
import functions_bondjango as bd
import functions_plotting as fp
import functions_data_handling as fdh
import functions_kinematic as fk
import functions_tuning as tuning
import functions_misc as misc
from wirefree_experiment import WirefreeExperiment, DataContainer

importlib.reload(fp)  # pick up edits to functions_plotting without restarting the kernel
# apply the project-wide figure theme, then force the font family
fp.set_theme()
plt.rcParams.update({"font.family": "Arial"})
# matplotlib figure sizes are in inches; multiply centimetre values by this factor
cm = 1 / 2.54

In [None]:
# Root folder for exported figures. Set the FIGURE_SAVE_PATH environment variable to
# run the notebook on a machine other than the original author's.
figure_save_path = os.environ.get(
    'FIGURE_SAVE_PATH',
    r"C:\Users\mmccann\Dropbox\bonhoeffer lab\TAC\TAC5\presentation\media",
)

In [None]:
def plot_compare_matched_vis_tuning(ds_list, cell, tuning_kind='direction', exp_name='', error='std', norm=True, polar=True, axes=None):
    """Plot the visual tuning of one cell for each dataset, one row (subfigure) per dataset.

    Parameters
    ----------
    ds_list : list
        Tuning-property tables, one per experiment; each is forwarded to
        ``fp.plot_tuning_with_stats``.
    cell : int
        Cell to plot, forwarded to ``fp.plot_tuning_with_stats``.
    tuning_kind : str
        Forwarded to ``fp.plot_tuning_with_stats`` (e.g. 'direction').
    exp_name : str or sequence of str
        Experiment keys, one per dataset, looked up in
        ``processing_parameters.wf_label_dictionary`` to title each row.
        Rows with no (or an empty) name get no title.
    error, norm, polar
        Forwarded to ``fp.plot_tuning_with_stats``. ``polar`` also selects a polar
        projection for the tuning axis when the axes are created here.
    axes : list of array of Axes, optional
        Pre-made axes, one entry per dataset. When None, a new figure is built.

    Returns
    -------
    fig : Figure or None
        The figure holding the axes; None if ``axes`` was given but empty.
    axes : list
        The per-dataset axes arrays.
    """
    if axes is None:
        axes = []
        # size from the number of datasets (previously referenced an undefined `data_list`)
        fig = plt.figure(layout='constrained', figsize=(32*cm, 2*len(ds_list)*cm))
        fig.suptitle(f"Cell {cell}", fontsize='x-large')
        subfigs = fig.subfigures(nrows=len(ds_list), ncols=1, hspace=0.07)
        # with a single dataset matplotlib returns one SubFigure rather than an array
        subfigs = list(subfigs.flat) if isinstance(subfigs, np.ndarray) else [subfigs]

        # a bare string means one name; exp_name='' previously raised IndexError
        names = [exp_name] if isinstance(exp_name, str) else list(exp_name)

        for i, subfig in enumerate(subfigs):
            if i < len(names) and names[i]:
                subfig.suptitle(processing_parameters.wf_label_dictionary[names[i]].title())

            if polar:
                ax1 = subfig.add_subplot(121, projection="polar")  # direction tuning
            else:
                ax1 = subfig.add_subplot(121)  # tuning

            ax2 = subfig.add_subplot(222)  # resp (top right)
            ax3 = subfig.add_subplot(224)  # error (bottom right)
            axes.append(np.array([ax1, ax2, ax3]))
    else:
        # `fig` used to be unbound on this path; recover it from the supplied axes
        fig = np.ravel(axes[0])[0].figure if len(axes) else None

    for sub_ax, ds in zip(axes, ds_list):
        fp.plot_tuning_with_stats(ds, cell, tuning_kind=tuning_kind, error=error, norm=norm, polar=polar, axes=sub_ax)

    return fig, axes

def plot_compare_all_vis_tuning(ds_list, cell, tuning_kind=['direction', 'orientation'], exp_name=[''], error='std', norm=True, polar=True, axes=None, **kwargs):
    """Plot one cell's visual tuning for every dataset in a grid of subfigures.

    The grid has one row per entry of ``exp_name`` and two columns. ``ds_list`` is
    expected to cycle through ``tuning_kind`` in order (e.g. direction, orientation,
    direction, orientation, ...).

    Parameters
    ----------
    ds_list : list
        Tuning-property tables forwarded to ``fp.plot_tuning_with_stats``.
    cell : int
        Cell to plot.
    tuning_kind : str or sequence of str
        Tuning kinds, cycled across ``ds_list`` (previously ignored in favour of a
        hard-coded ['direction', 'orientation'] * 2).
    exp_name : sequence
        One entry per experiment; only its length (the number of rows) is used.
    error, norm, polar
        Forwarded to ``fp.plot_tuning_with_stats``. ``polar`` also selects a polar
        projection for the tuning axis when the axes are created here.
    axes : list of array of Axes, optional
        Pre-made axes, one entry per dataset. When None, a new figure is built.
    **kwargs
        ``figsize`` (inches) overrides the default figure size.

    Returns
    -------
    fig : Figure or None
        The figure holding the axes; None if ``axes`` was given but empty.
    axes : list
        The per-dataset axes arrays.
    """
    kinds = [tuning_kind] if isinstance(tuning_kind, str) else list(tuning_kind)

    # computed up front: it is also needed when `axes` is supplied (was a NameError)
    figsize = kwargs.get('figsize')
    if figsize is None:
        figsize = (22*cm, 5*cm*(len(ds_list)//len(kinds)))

    if axes is None:
        axes = []
        fig = plt.figure(layout='constrained', figsize=figsize)
        subfigs = fig.subfigures(nrows=len(exp_name), ncols=2, hspace=0.07)

        for subfig in subfigs.flatten():
            if polar:
                ax1 = subfig.add_subplot(121, projection="polar")  # tuning
            else:
                ax1 = subfig.add_subplot(121)  # tuning

            ax2 = subfig.add_subplot(122)  # resp
            axes.append(np.array([ax1, ax2]))
    else:
        # `fig` used to be unbound on this path; recover it from the supplied axes
        fig = np.ravel(axes[0])[0].figure if len(axes) else None

    for i, (ds, sub_ax) in enumerate(zip(ds_list, axes)):
        fp.plot_tuning_with_stats(ds, cell, tuning_kind=kinds[i % len(kinds)], error=error, norm=norm, polar=polar,
                                  axes=sub_ax, figsize=(figsize[0]/2, figsize[1]/2))

    return fig, axes

In [None]:
def plot_comapre_kine_tuning(data_list, cell, exp_name='', axes=None):
    if axes is None:
        axes = []
        fig = plt.figure(layout='constrained', figsize=(35*cm, 10*len(data_list)*cm))
        fig.suptitle(f"Cell {cell}", fontsize='x-large')
        subfigs = fig.subfigures(nrows=len(data_list), ncols=1, hspace=0.07)

        try:
            for i, (ds, subfig) in enumerate(zip(data_list, subfigs)):
                # subfig.suptitle(processing_parameters.wf_label_dictionary[exp_name[i]].title())
                kinem_keys = [key for key in ds.keys() if ('props' not in key) and 
                      ('counts' not in key) and ('edges' not in key) and (key != "cell_matches") and (key != 'rig')]
                if len(kinem_keys) > 5:
                    rows = 2
                    cols = math.ceil(len(kinem_keys) / rows)
                    ax = subfig.subplots(ncols=cols, nrows=rows, sharey=True)
                else:
                    rows = 1
                    cols = len(kinem_keys)
                    ax = subfig.subplots(ncols=cols, nrows=rows, sharey=True)
                axes.append(ax.flatten())
                
        except TypeError:
                # subfigs.suptitle(processing_parameters.wf_label_dictionary[exp_name].title())
                kinem_keys = [key for key in data_list[0].keys() if ('props' not in key) and 
                      ('counts' not in key) and ('edges' not in key) and (key != "cell_matches") and (key != 'rig')]
                if len(kinem_keys) > 5:
                    rows = 2
                    cols = math.ceil(len(kinem_keys) / rows)
                    ax = subfigs.subplots(ncols=cols, nrows=rows, sharey=True)
                else:
                    rows = 1
                    cols = len(kinem_keys)
                    ax = subfigs.subplots(ncols=cols, nrows=rows, sharey=True)
                axes.append(ax.flatten())

    for i, (sub_ax, data_dict) in enumerate(zip(axes, data_list)):
        kinem_keys = [key for key in data_dict.keys() if ('props' not in key) and 
                      ('counts' not in key) and ('edges' not in key) and (key != "cell_matches") and (key != 'rig')]
        
        for j, (ax, k_key) in enumerate(zip(sub_ax, kinem_keys)):
            data = data_dict[k_key]
            bin_cols = [col for col in data.columns if ('half' not in col) and ('bin'in col)]
            current_bins = processing_parameters.tc_params[k_key]
            bins0 = np.linspace(current_bins[0], current_bins[1], processing_parameters.bin_number)
            ax.plot(bins0, data.loc[cell, bin_cols].fillna(0).to_numpy())
            ax.set_xlabel(processing_parameters.wf_label_dictionary[k_key])
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            if j % cols != 0:
                ax.spines['left'].set_visible(False)
                ax.yaxis.set_ticks_position('none') 
            else:
                ax.set_ylabel('Activity [a.u.]')

    return fig, axes

In [None]:
def plot_comapre_kine_tuning_hv(data_list, cell, exp_name='', axes=None):
    """Build HoloViews curves of one cell's kinematic tuning for every dataset.

    Parameters
    ----------
    data_list : list of dict
        Tuning-curve dicts: key -> DataFrame with 'bin*' columns, indexed by cell.
    cell
        Row label passed to ``DataFrame.loc`` to select the cell.
    exp_name, axes
        Unused; kept so the signature matches the ``interactive`` call.

    Returns
    -------
    list of hv.Curve
        One curve per kinematic key, for all datasets in order.
    """
    # collected across all datasets: it used to be reset per dataset (only the last
    # dataset's curves were returned) and was unbound for an empty data_list
    plot_list = []
    for data_dict in data_list:
        kinem_keys = [key for key in data_dict.keys() if ('props' not in key) and
                      ('counts' not in key) and ('edges' not in key) and (key != "cell_matches") and (key != 'rig')]

        for k_key in kinem_keys:
            data = data_dict[k_key]
            # full-session bins only; the 'half_*' bin columns are excluded.
            # TODO: the split-half columns could drive an hv.Spread error band.
            bin_cols = [col for col in data.columns if ('half' not in col) and ('bin' in col)]

            current_bins = processing_parameters.tc_params[k_key]
            bins0 = np.linspace(current_bins[0], current_bins[1], processing_parameters.bin_number)
            var = processing_parameters.wf_label_dictionary[k_key]
            curve = hv.Curve((bins0, data.loc[cell, bin_cols].fillna(0).to_numpy()), str(var), 'activity').opts(
                xlabel=var, ylabel='Activity [a.u.]', line_width=2)
            plot_list.append(curve)

    return plot_list

In [None]:
importlib.reload(processing_parameters)

# get the search string
search_string = processing_parameters.search_string  # + r", analysis_type:tc_analysis"
parsed_search = fdh.parse_search_string(search_string)

# get the paths from the database
# keep only entries whose slug contains '_tcday' and the lower-cased mouse name from the
# search; the same filter is applied to both lists so tc_paths and rigs stay aligned
file_infos = bd.query_database("analyzed_data", search_string)
tc_paths = np.array([el['analysis_path'] for el in file_infos if ('_tcday' in el['slug']) and
                        (parsed_search['mouse'].lower() in el['slug'])])
rigs = np.array([el['rig'] for el in file_infos if ('_tcday' in el['slug']) and
                 (parsed_search['mouse'].lower() in el['slug'])])
# order both arrays by rig name, descending (reversed argsort)
rev_sort = np.argsort(rigs)[::-1]
rigs = rigs[rev_sort]
tc_paths = tc_paths[rev_sort]
# matching preprocessed file: same path with 'tcday' -> 'preproc'
# NOTE(review): str.replace swaps every 'tcday' in the full path, including directory names
preproc_paths = [file.replace('tcday', 'preproc') for file in tc_paths]

# show the selected files and rigs
print(tc_paths)
print(rigs)

# load every table of each tuning-curve HDF5 file into a dict keyed by the
# last component of the HDF5 key (the leading group path is stripped)
tc_data_list = []
for file in tc_paths:
    data_dict = {}
    with pd.HDFStore(file, 'r') as tc:
        for key in tc.keys():
            data_dict[key.split('/')[-1]] = tc[key]
    tc_data_list.append(data_dict)

# load each preprocessed experiment and attach dF/F plus normalized traces
preproc_exp_list = []
for file in preproc_paths:
    exp = WirefreeExperiment(file, use_xarray=False)
    exp.dff = tuning.calculate_dff(exp.raw_fluor)
    exp.norm_spikes = tuning.normalize_responses(exp.raw_spikes)
    exp.norm_fluor = tuning.normalize_responses(exp.raw_fluor)
    exp.norm_dff = tuning.normalize_responses(exp.dff)
    preproc_exp_list.append(exp)

# Match cells
# match table from the first file, keeping only rows without missing values;
# its columns are looked up by rig name below (presumably one column per rig — confirm)
matches = tc_data_list[0]['cell_matches'].dropna()
matched_tc_data_list = []
matched_preproc_exp = []

# per-experiment lists of kinematic tuning-curve keys (not read by later cells)
kinem_keys = []

for tc_data, preproc_exp, rig in zip(tc_data_list, preproc_exp_list, rigs):
    matched_dict = {}
    # first match-table column whose name contains this rig's name
    match_col = np.where([rig in el for el in matches.columns])[0][0]
    # this experiment's cell indices for the matched cells; used as row positions below
    match_idxs = matches.iloc[:, match_col].to_numpy(dtype=int)
    matched_dict['rig'] = rig

    # handle tuning curves
    kinem = [key for key in tc_data.keys() if ('props' not in key) and ('counts' not in key) 
             and ('edges' not in key) and (key != "cell_matches") and (key != 'rig')]
    kinem_keys.append(kinem)
    
    for feature in tc_data.keys():
        # Save matched TCs
        # rows picked by position; the old row label is moved into 'original_cell_id'
        # and rows are renumbered 0..n-1 in match order
        if ('counts' not in feature) and ('edges' not in feature) and (feature != "cell_matches"):
            matched_dict[feature] = tc_data[feature].iloc[match_idxs, :].reset_index(names='original_cell_id')

    matched_tc_data_list.append(matched_dict)

    # handle preproc data
    # keep the non-cell columns plus the matched cell columns, renaming the latter to
    # cell_0000, cell_0001, ... in match order. The experiment object is modified in
    # place, so preproc_exp_list and matched_preproc_exp hold the same objects.
    noncell_cols = [col for col in preproc_exp.dff.columns if 'cell' not in col]
    matched_cell_cols = [f'cell_{idx:04}' for idx in match_idxs]
    rename_cell_cols = [f'cell_{idx:04}' for idx in np.arange(len(match_idxs))]
    rename_dict = dict(zip(matched_cell_cols, rename_cell_cols))
    preproc_exp.matched_dff = preproc_exp.dff.loc[:, noncell_cols + matched_cell_cols].rename(columns=rename_dict)
    preproc_exp.matched_norm_spikes = preproc_exp.norm_spikes.loc[:, noncell_cols + matched_cell_cols].rename(columns=rename_dict)
    preproc_exp.matched_norm_fluor = preproc_exp.norm_fluor.loc[:, noncell_cols + matched_cell_cols].rename(columns=rename_dict)
    preproc_exp.matched_norm_dff = preproc_exp.norm_dff.loc[:, noncell_cols + matched_cell_cols].rename(columns=rename_dict)
    matched_preproc_exp.append(preproc_exp)

In [None]:
# `ds` was only bound by a later cell (ds = tc_data_list[idx] with idx = 1), so this
# failed on a fresh kernel; reference the tuning-curve dict explicitly instead.
np.std(tc_data_list[1]['head_direction'].half_1_bin_0)

In [None]:
# `ds` was only bound by a later cell (ds = tc_data_list[idx] with idx = 1), so this
# failed on a fresh kernel; reference the tuning-curve dict explicitly instead.
np.std(tc_data_list[1]['head_direction'].half_0_bin_0)

In [None]:
# `ds` was only bound by a later cell (ds = tc_data_list[idx] with idx = 1), so this
# failed on a fresh kernel; reference the tuning-curve dict explicitly instead.
tc_data_list[1]['head_direction'].bin_0

# View Tuning Curves

In [None]:
# `datasets` is only built in the next cell; look up the same entry (first matched
# experiment, first still 'props' table) directly so this cell runs on a fresh kernel.
_still_props_keys = [key for key in matched_tc_data_list[0].keys() if ('props' in key) and ('still' in key)]
matched_tc_data_list[0][_still_props_keys[0]].shuffle_dsi_abs[1]

In [None]:
# names of the 'still' tuning-property tables, and those tables for every matched experiment
ds_names = [name for name in matched_tc_data_list[0] if 'props' in name and 'still' in name]
datasets = [tc_dict[name] for tc_dict in matched_tc_data_list for name in ds_names]
print(ds_names)
print(datasets[0].columns)

In [None]:
# Shuffle distribution of |DSI| for row 1 of the first dataset, observed value dashed in red.
# Copy before cleaning so zeroing non-finite values does not write back into `datasets`.
hist_resp = np.array(datasets[0].shuffle_dsi_abs[1], dtype=float)
hist_resp[~np.isfinite(hist_resp)] = 0
real_resp = datasets[0].dsi_abs[1]
# 21 edges covering [0, 1]; np.arange(0, 1, 0.05) stopped at 0.95 and dropped values above it
edges = np.linspace(0, 1, 21)
plt.hist(np.abs(hist_resp), bins=edges, edgecolor="black")
plt.axvline(x=abs(real_resp), color='r', linestyle='dashed', linewidth=2)
plt.xlabel('|DSI|')
plt.ylabel('Count');

In [None]:
# number of (matched experiment, still tuning-property table) combinations
len(datasets)

In [None]:
# Bootstrap OSI distribution of one cell, with its observed OSI marked in red.
osi_ds = datasets[1]
cell = 9
# copy before cleaning so the dataset's stored array is not modified in place, and zero
# every non-finite value (NaN/±inf), since any of them breaks the histogram range
boot_osi = np.array(osi_ds.bootstrap_osi[cell], dtype=float)
boot_osi[~np.isfinite(boot_osi)] = 0
plt.hist(boot_osi, bins=20)
plt.axvline(osi_ds.osi[cell], color='r')
plt.xlabel('OSI')
plt.ylabel('Count');

In [None]:
# 'trial_resp_norm' entry for row 1 of the second dataset (name suggests normalized per-trial responses)
datasets[1].trial_resp_norm[1]

In [None]:
# rebuild the still tuning-property tables for all matched experiments
ds_names = [name for name in matched_tc_data_list[0] if 'props' in name and 'still' in name]
datasets = [tc_dict[name] for tc_dict in matched_tc_data_list for name in ds_names]

# browse matched cells; cell index, error type and projection are interactive
matched_vis_fig = interactive(
    plot_compare_all_vis_tuning,
    ds_list=widgets.fixed(datasets),
    cell=np.arange(len(matches)),
    tuning_kind=widgets.fixed(['direction', 'orientation']),
    exp_name=widgets.fixed(rigs),
    error=['std', 'sem'],
    polar=[True, False],
    norm=widgets.fixed(True),
    axes=widgets.fixed(None),
)
matched_vis_fig

In [None]:
# Save the matched-cell visual tuning figure under the currently selected cell number
# (previously every cell overwrote the same file); create Fig3/ if it is missing.
selected_cell = matched_vis_fig.kwargs['cell']
fig3_dir = os.path.join(figure_save_path, 'Fig3')
os.makedirs(fig3_dir, exist_ok=True)
matched_vis_fig.result[0].savefig(os.path.join(fig3_dir, f'matched_cell{selected_cell}_vis_tc.png'), dpi=1000, format='png')

In [None]:
# browse the kinematic tuning of each matched cell across rigs
matched_kinem_fig = interactive(
    plot_comapre_kine_tuning,
    data_list=widgets.fixed(matched_tc_data_list),
    cell=np.arange(len(matches)),
    exp_name=widgets.fixed(rigs),
    axes=widgets.fixed(None),
)
matched_kinem_fig

In [None]:
# Name the file after the cell currently selected in the widget (was hard-coded to
# cell 8); create Fig3/ if it is missing.
selected_cell = matched_kinem_fig.kwargs['cell']
fig3_dir = os.path.join(figure_save_path, 'Fig3')
os.makedirs(fig3_dir, exist_ok=True)
matched_kinem_fig.result[0].savefig(os.path.join(fig3_dir, f'matched_cell{selected_cell}_tc.svg'), dpi=1000, format='svg')

In [None]:
# fraction of matched cells passing each test for head-direction tuning in the second
# experiment: number of True entries divided by the number of non-missing entries
frac_cons, frac_resp, frac_qual = (
    passed.sum() / passed.count()
    for passed in (matched_tc_data_list[1]['head_direction'][test_name]
                   for test_name in ('Cons_test', 'Resp_test', 'Qual_test'))
)

## Plot cells with the highest responsivity index

In [None]:
ds_names = [key for key in matched_tc_data_list[0].keys() if ('props' in key) and ('still' in key)]

# experiment to inspect (position in `rigs` / `tc_data_list`)
idx = 1
rig = rigs[idx]
ds = tc_data_list[idx]
# rank this experiment's cells by responsivity, most responsive first. reset_index moves
# the original row label into 'original_cell_id' and renumbers rows 0..n-1, so the sorted
# index holds row positions.
sorted_responsivity = (
    ds['norm_spikes_viewed_still_direction_props']
    .reset_index(names='original_cell_id')
    .sort_values('responsivity', ascending=False)
)
responsive_cells = sorted_responsivity.index.to_list()

vis_tc_fig = interactive(
    plot_compare_all_vis_tuning,
    ds_list=widgets.fixed([ds[name] for name in ds_names]),
    cell=responsive_cells,
    tuning_kind=widgets.fixed(['direction', 'orientation']),
    exp_name=widgets.fixed([rig]),
    error=['std', 'sem'],
    polar=[True, False],
    norm=widgets.fixed(True),
    axes=widgets.fixed(None),
)
vis_tc_fig

In [None]:
# Name the file after the cell currently selected in the widget (was hard-coded to
# cell 20); create Fig2/ if it is missing.
selected_cell = vis_tc_fig.kwargs['cell']
fig2_dir = os.path.join(figure_save_path, 'Fig2')
os.makedirs(fig2_dir, exist_ok=True)
vis_tc_fig.result[0].savefig(os.path.join(fig2_dir, f'cell{selected_cell}_free_vis_tc.png'), dpi=1000, format='png')

In [None]:
# matplotlib kinematic tuning curves of the selected experiment, cells ordered by responsivity
kinem_tc_fig = interactive(
    plot_comapre_kine_tuning,
    data_list=widgets.fixed([ds]),
    cell=responsive_cells,
    exp_name=widgets.fixed(rigs[idx]),
    axes=widgets.fixed(None),
)
kinem_tc_fig

In [None]:
# Name the file after the cell currently selected in the widget (was hard-coded to
# cell 20); create Fig2/ if it is missing.
selected_cell = kinem_tc_fig.kwargs['cell']
fig2_dir = os.path.join(figure_save_path, 'Fig2')
os.makedirs(fig2_dir, exist_ok=True)
kinem_tc_fig.result[0].savefig(os.path.join(fig2_dir, f'cell{selected_cell}_free_kinem_tc.png'), dpi=1000, format='png')

In [None]:
# HoloViews version of the kinematic tuning curves.
# NOTE: this rebinds `kinem_tc_fig`, so the matplotlib export cell above must run first.
kinem_tc_fig = interactive(
    plot_comapre_kine_tuning_hv,
    data_list=widgets.fixed([ds]),
    cell=responsive_cells,
    exp_name=widgets.fixed(rigs[idx]),
    axes=widgets.fixed(None),
)
kinem_tc_fig

In [None]:
# Export each HoloViews panel for the selected cell (file names were hard-coded to
# cell 61), then show all panels in one layout.
selected_cell = kinem_tc_fig.kwargs['cell']
fig2_dir = os.path.join(figure_save_path, 'Fig2')
os.makedirs(fig2_dir, exist_ok=True)

fig_list = []
for panel_idx, panel in enumerate(kinem_tc_fig.result):
    panel.opts(width=500, height=300, xrotation=45, ylabel='', yaxis=None)
    saved_panel = fp.save_figure(panel,
                                 save_path=os.path.join(fig2_dir, f'cell{selected_cell}_free_kinem_tc_{panel_idx}.png'),
                                 fig_width=7, dpi=1000, fontsize='poster', target='save', display_factor=0.1)
    fig_list.append(saved_panel)

layout = hv.Layout(fig_list).cols(5)
layout

# View heatmaps

In [None]:
%matplotlib widget
cell = 'cell_0089'
angle_type = 'direction_wrapped'
# gather this cell's trace per trial, then group the trial traces by stimulus angle
# and stack them (presumably trials x frames per angle — confirm list_lists_to_array)
a = preproc_exp_list[0].norm_spikes.groupby([angle_type, 'trial_num'])[cell].agg(list)
a = a.droplevel('trial_num').groupby(angle_type).agg(list)
a = a.apply(misc.list_lists_to_array).reset_index()
# drop rows with the -1000 angle value (a sentinel; presumably "no stimulus" — confirm)
a = a.drop(a[a[angle_type] == -1000].index)
a['mean_trace'] = a[cell].map(lambda x: np.nanmean(x, axis=0))

fig, ax = plt.subplots(figsize=(10, 5))
# `_` instead of `idx`: the loop used to overwrite the global `idx` read by `rigs[idx]` cells
for _, row in a.iterrows():
    ax.plot(row['mean_trace'][20:120])  # frames 20-119 of each angle's mean trace
ax.legend([f'{val:.1f}' for val in a[angle_type].values], ncol=3, title=angle_type)
ax.set_title(cell)
ax.set_xlabel('Frame (offset by 20)')
ax.set_ylabel('Mean norm. spikes (a.u.)');

In [None]:
def raster_spikes(data, fig=None):
    """Show a 2-D activity array as a heatmap with a colorbar to its right.

    Parameters
    ----------
    data : 2-D array-like
        Passed to ``imshow``; each row becomes one horizontal line of the image.
    fig : matplotlib Figure, optional
        Figure to draw into. Its first axes is used (one is added if it has none).
        When None, a new 15 x 5 inch figure is created.

    Returns
    -------
    Figure
        The figure containing the heatmap and colorbar.
    """
    if fig is None:
        fig = plt.figure(figsize=(15, 5))
    # previously `ax` only existed when the figure was created here (NameError otherwise)
    ax = fig.axes[0] if fig.axes else fig.add_subplot()

    # draw on this axes explicitly rather than whatever pyplot considers "current"
    im = ax.imshow(data, cmap='Reds', aspect='auto')

    # colorbar in its own axes just right of the heatmap, matching its height
    pos = ax.get_position()
    cax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.02, pos.height])
    cbar = fig.colorbar(im, cax=cax)
    cbar.set_label('Activity (a.u.)', rotation=270, labelpad=20)

    return fig

In [None]:
cell = 'cell_0109'
# Recompute the per-angle mean traces for *this* cell: the heatmap previously reused `a`
# from the cell above (built for cell_0089) while titling and saving it as cell_0109.
# `angle_type` is defined in the cell above.
per_trial = preproc_exp_list[0].norm_spikes.groupby([angle_type, 'trial_num'])[cell].agg(list)
per_angle = per_trial.droplevel('trial_num').groupby(angle_type).agg(list)
per_angle = per_angle.apply(misc.list_lists_to_array).reset_index()
per_angle = per_angle.drop(per_angle[per_angle[angle_type] == -1000].index)
per_angle['mean_trace'] = per_angle[cell].map(lambda x: np.nanmean(x, axis=0))

directions = per_angle[angle_type].to_numpy()
mean_resp = misc.list_lists_to_array(per_angle['mean_trace'].to_list())

frame_period_s = 0.05  # seconds per frame, as used for the original time-axis labels
fig, ax = plt.subplots(figsize=(15, 5))
im = ax.imshow(mean_resp[:, :120], cmap='Reds', aspect='auto', vmax=1.0)
t_label = np.arange(0, 121, 20)
# ':g' avoids float artefacts such as 3.0000000000000004 in the tick labels
ax.set_xticks(t_label, [f'{t:g}' for t in t_label * frame_period_s])
ax.set_yticks(np.arange(0, mean_resp.shape[0]), [f'{d:.2f}' for d in directions])
ax.set_title(cell.replace("_", " ").capitalize())
ax.set_xlabel('Time [sec]')
ax.set_ylabel(r'Directions [$^{\circ}$]')  # raw string: '\c' is an invalid escape

pos = ax.get_position()
cax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.02, pos.height])
cbar = fig.colorbar(im, cax=cax)
cbar.set_label('Norm. Activity (a.u.)', rotation=270, labelpad=20)

poster_dir = r'C:\Users\mmccann\Dropbox\bonhoeffer lab\SFN 2023\poster'
os.makedirs(poster_dir, exist_ok=True)
save_path = os.path.join(poster_dir, f'spikes_{cell}.png')
fig.savefig(save_path, dpi=600)

In [None]:
# every column whose name contains 'cell', and the matching normalized spike traces as a plain array
cells = list(filter(lambda name: 'cell' in name, preproc_exp_list[0].raw_spikes.columns))
a = preproc_exp_list[0].norm_spikes.loc[:, cells].to_numpy()

In [None]:
# heatmap of all cells' normalized spikes; transposed so each cell becomes a row
# (rows of norm_spikes are presumably time frames — confirm)
fig = raster_spikes(a.T)