In [None]:
# imports
import os
import sys
sys.path.insert(0, os.path.abspath(r'D:\Code Repos\prey_capture'))

import panel as pn
import holoviews as hv
from bokeh.io import export_svgs, export_png
from holoviews import opts, dim
from holoviews.operation import histogram
hv.extension('bokeh')
from bokeh.resources import INLINE

import functions_bondjango as bd
import functions_plotting as fp
import paths

import importlib
import processing_parameters
import datetime
import pandas as pd
import numpy as np
from pprint import pprint
import random
import umap
import scipy.stats as stat
from rastermap import Rastermap


In [None]:
# set up the figure config (reload the helper modules so edits to them are picked up)
importlib.reload(fp)
importlib.reload(processing_parameters)

# folder where the figures of this notebook are saved
save_path = os.path.join(paths.figures_path, 'TC_aggregate')
# define the printing mode (REMEMBER TO CHECK IT'S NOT IN MANUAL IN THE ACTUAL CELL)
# NOTE(review): save_mode is not read by any cell shown here; the saving cells pass target='save' directly
save_mode = True
# target document, used as the prefix of the saved figure file names
target_document = 'paper'

# apply the shared figure theme
fp.set_theme()

# feature labels for the plots and the list of features to aggregate
label_dict, variable_list = processing_parameters.label_dictionary, processing_parameters.variable_list

In [None]:
# Load the tc_consolidate file
importlib.reload(processing_parameters)

# search query for the consolidated analysis files
search_consolidate = processing_parameters.search_consolidate

# query the database and keep only the analysis paths of non-test entries
data_path = [
    entry['analysis_path']
    for entry in bd.query_database('analyzed_data', search_consolidate)
    if 'test' not in entry['analysis_path']
]
pprint(data_path)

In [None]:
# define the target feature and load it
# other targets used before: 'mouse_x__mouse_y', 'cricket_0_mouse_distance__cricket_0_delta_heading',
# 'cricket_0_x__cricket_0_y', 'mouse_speed', 'latent_0__latent_1'
target_feature = 'cricket_0_mouse_distance'

# load the feature from every file that contains it (files without this key are skipped)
data = []
for path in data_path:
    try:
        feature_df = pd.read_hdf(path, target_feature)
    except KeyError:
        continue
    data.append(feature_df)
    print(f'Dimensions of the data (cells by feature): {feature_df.shape}')
if not data:
    # pd.concat([]) would otherwise fail with the less informative "No objects to concatenate"
    raise ValueError(f'Feature {target_feature} was not found in any of the files in data_path')
data = pd.concat(data, axis=0)

# define a dictionary for the combined labels
combo_labels = {
    'mouse_x__mouse_y': 'Mouse Position',
    'cricket_0_mouse_distance__cricket_0_delta_heading': 'Prey angle and distance',
    'cricket_0_x__cricket_0_y': 'Prey position', 
    'mouse_speed__cricket_0_speed': 'Mouse and prey speed'
}
print(data.columns)

In [None]:
# plot fraction of selective cells over animals and time
# select the test columns and the day and animal ones; copy so that adding a column
# below neither raises SettingWithCopyWarning nor touches `data`
analysis_df = data[['Resp_test', 'Cons_test', 'day', 'animal']].copy()
# binary vector with the cells passing both the responsivity and consistency tests
both = ((analysis_df['Resp_test'] > 0) & (analysis_df['Cons_test'] > 0)).to_numpy()
# append it as the last column of the dataframe
analysis_df['Pass_fraction'] = both
# fraction of passing cells per animal and day
sums = analysis_df.groupby(['animal', 'day'], as_index=False)['Pass_fraction'].mean()
# parse the 'mm_dd_yyyy' day strings into dates
sums['day'] = pd.to_datetime(sums['day'], format='%m_%d_%Y')

# one curve per animal
df_list = []
for idx, (animal, df) in enumerate(sums.groupby('animal')):
    # sort the sessions chronologically: the groupby above sorted the day *strings*,
    # which is out of order as soon as the recordings of an animal span two years
    df = df.sort_values('day').reset_index(drop=True)
    # days elapsed since this animal's first session
    df = df.assign(day=(df['day'] - df['day'].iloc[0]).dt.days)
    # build the label
    label = 'M' + str(idx)
    # generate and format the plot
    plot = hv.Curve(df[['day', 'Pass_fraction']], label=label)
    plot.opts(ylim=(-0.01, 1))
    plot = fp.format_figure(plot, width=600, height=400)
    df_list.append(plot)

# build the overlay and display it
out_plot = hv.Overlay(df_list).opts(show_legend=True, legend_position='top')
out_plot

# to save the figure:
# save_name = os.path.join(save_path, '_'.join(('pass_fraction', target_feature)) + '.png')
# _ = fp.save_figure(out_plot, save_name, fig_width=6, dpi=600)

In [None]:
# plot responsivity and consistency over animals and time

# cells passing both the responsivity and consistency tests; recomputed here so this cell
# does not depend on the global `both` left by another cell (later cells overwrite it)
significant = ((data['Resp_test'] > 0) & (data['Cons_test'] > 0)).to_numpy()

# copy the index columns so the cleanup below does not touch `data`
analysis_df = data[['Resp_index', 'Cons_index', 'day', 'animal']].copy()
# non-finite (NaN/inf) indexes are set to 0
for column in ('Resp_index', 'Cons_index'):
    values = analysis_df[column].to_numpy(dtype=float)
    analysis_df[column] = np.where(np.isfinite(values), values, 0)

# include only cells that are significant
analysis_df = analysis_df.iloc[significant, :]

# mean indexes per day and animal (day and animal come back as columns via as_index=False,
# so only the value columns are selected)
sums = analysis_df.groupby(['day', 'animal'], as_index=False)[['Resp_index', 'Cons_index']].mean()
# parse the 'mm_dd_yyyy' day strings into dates
sums['day'] = pd.to_datetime(sums['day'], format='%m_%d_%Y')

# one responsivity and one consistency curve per animal
resp_list = []
cons_list = []
for idx, (animal, df) in enumerate(sums.groupby('animal')):
    # sort chronologically (string order breaks across years) and count days from the first session
    df = df.sort_values('day').reset_index(drop=True)
    df = df.assign(day=(df['day'] - df['day'].iloc[0]).dt.days)
    # build the label
    label = 'M' + str(idx)
    # plot and format
    resp_plot = hv.Curve(df[['day', 'Resp_index']], label=label)
    resp_plot.opts(ylim=(1, 6))
    resp_plot = fp.format_figure(resp_plot, width=600, height=400)
    cons_plot = hv.Curve(df[['day', 'Cons_index']], label=label)
    cons_plot.opts(ylim=(0, 0.7))
    cons_plot = fp.format_figure(cons_plot, width=600, height=400)
    resp_list.append(resp_plot)
    cons_list.append(cons_plot)

resp_overlay = hv.Overlay(resp_list).opts(legend_position='top', show_legend=True)
cons_overlay = hv.Overlay(cons_list).opts(legend_position='top', show_legend=True)
# display figure
(resp_overlay + cons_overlay).cols(1)

# to save the figures:
# save_name = os.path.join(save_path, '_'.join(('resp_index', target_feature)) + '.png')
# _ = fp.save_figure(resp_overlay, save_name, fig_width=6, dpi=600)
# save_name = os.path.join(save_path, '_'.join(('cons_index', target_feature)) + '.png')
# _ = fp.save_figure(cons_overlay, save_name, fig_width=6, dpi=600)

In [None]:
# # example cells
# random.seed(1)
# # define how many cells to plot per animal
# number_cells = 10
# # define the criterion
# # criterion = 'random'
# criterion_list = ['top', 'random']
# # define the colormap to use
# cmap = 'Purples'

# for criterion in criterion_list:

#     # get a list of all the animals
#     animal_list = np.unique(data['animal'])
#     print(animal_list)
#     # get current bins
#     current_bins = processing_parameters.tc_params[target_feature]
#     plot_list = []
#     # for all the animals
#     for animal in animal_list:
#         # get the cells for this animal
#         current_cells = data.iloc[data.loc[:, 'animal'].to_numpy()==animal, :]
#         # get the cells based on a criterion
#         if criterion == 'top':
#             # get top cells from every animal based on passing both criteria and aggregated score
#             both_idx = ((current_cells.loc[:, 'Resp_test'] > 0) & \
#                 (current_cells.loc[:, 'Cons_test'] > 0)).to_numpy()
#             agg = current_cells.loc[both_idx, 'Resp_index'] * current_cells.loc[both_idx, 'Cons_index']
#             current_cells.insert(current_cells.shape[1], 'agg', agg)
#             top_idx = np.flip(np.argsort(current_cells.loc[both_idx, 'agg']))
#             current_cells = current_cells.loc[both_idx, :]
#             top_cells = current_cells.iloc[top_idx, :]
#         elif criterion == 'random':
#             # get n random cells per animal
#             sel_idx = random.sample(list(np.arange(current_cells.shape[0])), number_cells)
#             top_cells = current_cells.iloc[sel_idx, :]

#         # make sure there are no more cells indexed than detected
#         idx_limit = number_cells if top_cells.shape[0] >= number_cells else top_cells.shape[0]

#         # for the top n
#         for cell_idx in np.arange(idx_limit):
#             # get the current cell
#             current_cell = top_cells.iloc[cell_idx, :]

#             # get the current values
#             current_resp = current_cell['Resp_index']
#             current_cons = current_cell['Cons_index']
#             # get the date
#             current_day = current_cell['day']
#             # get the cell id in the original dataframe
#             cell_id = top_cells.index[cell_idx]
#             # get the variable names
#             var_names = target_feature.split('__')[::-1]
#             # get the three maps
#             # define bins
#             bins0 = np.linspace(current_bins[0], current_bins[1], 10)
# #             bins1 = np.linspace(current_bins[1][0], current_bins[1][1], 10)

#             # full map
#             full_labels = [el for el in current_cell.index if ('bin_' in el) & ('half_' not in el)]
#             full_map = current_cell.loc[full_labels].to_numpy().reshape((10, 10))

#             title = f'{animal}, {current_day}'
#             full_map = hv.Image((bins1, bins0, full_map), kdims=var_names)
#             full_map.opts(title=title, shared_axes=False, xrotation=45, cmap=cmap)
#             plot_list.append(full_map)

#             # half maps
#             half0_labels = [el for el in current_cell.index if ('half_0_bin_' in el)]
#             half0_map = current_cell.loc[half0_labels].to_numpy().reshape((10, 10))

#             title = f'Cell number: {cell_id}'
#             half0_map = hv.Image((bins1, bins0, half0_map), kdims=var_names)
#             half0_map.opts(title=title, shared_axes=False, xrotation=45, cmap=cmap)
#             plot_list.append(half0_map)

#             half1_labels = [el for el in current_cell.index if ('half_1_bin_' in el)]
#             half1_map = current_cell.loc[half1_labels].to_numpy().reshape((10, 10))

#             title = f'Cons: {current_cons:.2f}, Resp: {current_resp:.2f}'
#             half1_map = hv.Image((bins1, bins0, half1_map), kdims=var_names)
#             half1_map.opts(title=title, shared_axes=False, xrotation=45, cmap=cmap)
#             plot_list.append(half1_map)

#     example_layout = hv.Layout(plot_list).opts(shared_axes=False).cols(3)
#     example_layout

#     # assemble the path
#     save_name = os.path.join(save_path, '_'.join(('examples', target_feature, criterion)) + '.png')
#     hv.save(example_layout, save_name, dpi=600)


In [None]:
# # plot defined cells as examples

# # define the number of cells to save
# save_cells = 2 
# # define the target animal
# animal = 'DG_200701_a'
# # reset the rng
# random.seed(1)

# criterion_list = ['top', 'random']
# current_bins = processing_parameters.tc_params[target_feature]
# labels = processing_parameters.label_dictionary
# cmap = 'Purples'
# number_cells = 10

# cell_list = []
# # get the cells for this animal
# current_cells = data.iloc[data.loc[:, 'animal'].to_numpy()==animal, :]
# number_cells = np.min((number_cells, current_cells.shape[0]))

# def format_colorbar_hook(plot, element):
#     # get the plot dict
#     b = plot.state
#     b.right[0].major_label_text_font_size = '40pt'
#     b.right[0].major_label_text_font_size = '40pt'
#     b.right[0].label_standoff = 25


# # for the criteria on the list
# for criterion in criterion_list:
#     working_cells = current_cells.copy()
#     # get the cells based on a criterion
#     if criterion == 'top':
#         # get top cells from every animal based on passing both criteria and aggregated score
#         both_idx = ((working_cells.loc[:, 'Resp_test'] > 0) & \
#             (current_cells.loc[:, 'Cons_test'] > 0)).to_numpy()
#         agg = current_cells.loc[both_idx, 'Resp_index'] * working_cells.loc[both_idx, 'Cons_index']
#         working_cells.insert(working_cells.shape[1], 'agg', agg)
#         top_idx = np.flip(np.argsort(working_cells.loc[both_idx, 'agg']))
#         working_cells = working_cells.loc[both_idx, :]
#         top_cells = working_cells.iloc[top_idx, :]
#     elif criterion == 'random':
#         # get n random cells per animal
#         sel_idx = random.sample(list(np.arange(working_cells.shape[0])), number_cells)
#         top_cells = current_cells.iloc[sel_idx, :]
#     # plot and save the top n cells from above
#     for cell in np.arange(save_cells):
#         # get the current cell
#         current_cell = top_cells.iloc[cell, :]

#     #         # get the current values
#     #         current_resp = current_cell['Resp_index']
#     #         current_cons = current_cell['Cons_index']
#     #         # get the date
#     #         current_day = current_cell['day']
#     #         # get the cell id in the original dataframe
#     #         cell_id = top_cells.index[cell_idx]
#         # get the variable names
#         var_names = target_feature.split('__')[::-1]
#         # get the three maps
#         # define bins
#         bins0 = np.linspace(current_bins[0][0], current_bins[0][1], 10)
#         bins1 = np.linspace(current_bins[1][0], current_bins[1][1], 10)
        
# #         # convert to polar coordinates
# #         bins0_pol = bins0*np.cos(bins1)
# #         bins1_pol = bins0*np.sin(bins1)
        
# #         bins0 = bins0_pol
# #         bins1 = bins1_pol

#         # full map
#         full_labels = [el for el in current_cell.index if ('bin_' in el) & ('half_' not in el)]
#         full_map = current_cell.loc[full_labels].to_numpy().reshape((10, 10))

#     #         title = f'{animal}, {current_day}'
#         full_map = hv.Image((bins1, bins0, full_map), kdims=[labels[el] for el in var_names])
# #         full_map = hv.HexTiles((bins0, bins1, full_map))
#         full_map.opts(cmap=cmap, colorbar=True, hooks=[format_colorbar_hook])
#     #     full_map.opts(shared_axes=False, xrotation=45, cmap=cmap)
#         full_map = fp.format_figure(full_map, frame_width=400, frame_height=400)
#         cell_list.append(full_map)

#         save_name = os.path.join(save_path, 
#                                  '_'.join(('Example', str(cell), target_feature, criterion)) + '.png')
#         _ = fp.save_figure(full_map, save_name, fig_width=4, dpi=600)        
# # hv.Layout(cell_list).cols(2)

In [None]:
# list the analysis files used below (one per line, as in the query cell)
pprint(data_path)

In [None]:
# for every feature: per-day fraction of significant cells and mean quality index,
# plus the responsivity/consistency indexes of the significant cells
fraction_list = []
resp_list = []
cons_list = []
for target_feature in variable_list:
    # load this feature from every file that contains it
    data = []
    for file in data_path:
        try:
            data.append(pd.read_hdf(file, target_feature))
        except KeyError:
            continue
    print(target_feature)
    if not data:
        raise ValueError(f'Feature {target_feature} was not found in any of the files in data_path')
    data = pd.concat(data, axis=0)
    feature_label = label_dict[target_feature]

    # copy the columns of interest so the edits below do not touch `data`
    analysis_df = data[['Resp_test', 'Cons_test', 'Qual_test', 'Qual_index', 'day', 'animal']].copy()
    # non-finite (NaN/inf) test values are set to 0
    for column in ('Resp_test', 'Cons_test'):
        values = analysis_df[column].to_numpy(dtype=float)
        analysis_df[column] = np.where(np.isfinite(values), values, 0)
    # cells count as significant when they pass the quality test
    both = (analysis_df['Qual_test'] > 0).to_numpy()
    analysis_df['Pass_fraction'] = both
    # keep the magnitude of the quality index, masked to NaN where the quality test is 0
    analysis_df['Qual_index'] = np.abs(analysis_df['Qual_index'].mask(analysis_df['Qual_test'].to_numpy() == 0))

    # per-day NaN-skipping means; the columns must be selected with a list, since
    # selecting them with a tuple was removed from groupby in pandas 2.0
    counts = analysis_df.groupby(['day'], as_index=False)[['Pass_fraction', 'Qual_index']].mean()
    counts['Feature'] = feature_label
    fraction_list.append(counts[['Feature', 'Pass_fraction', 'Qual_index']])

    # responsivity and consistency indexes of the significant cells, labeled like the fractions
    resp_df = data.loc[both, ['Resp_index']].assign(Feature=feature_label)
    print(resp_df.shape)
    resp_list.append(resp_df)

    cons_df = data.loc[both, ['Cons_index']].assign(Feature=feature_label)
    cons_list.append(cons_df)


# TC box plot: fraction of significant cells per feature

In [None]:
# box plot of the fraction of significant cells per feature (one value per day)
importlib.reload(fp)
plot_array = pd.concat(fraction_list)

box_opts = dict(width=800, height=800, xrotation=45, xlabel='',
                box_line_width=1, whisker_line_width=1, outlier_line_width=1)
whisker = hv.BoxWhisker(plot_array, ['Feature'], ['Pass_fraction']).opts(ylabel='Significant fraction', **box_opts)

# save the figure as <target_document>_TC_fraction.png
save_name = os.path.join(save_path, target_document + '_TC_fraction.png')
fig = fp.save_figure(whisker, save_name, fig_width=5, dpi=1200, fontsize='small', target='save')


# TC box plot: mean quality (tuning) index per feature

In [None]:
# box plot of the mean quality (tuning) index per feature (one value per day)
importlib.reload(fp)
plot_array = pd.concat(fraction_list)

box_opts = dict(width=800, height=800, xrotation=45, xlabel='',
                box_line_width=1, whisker_line_width=1, outlier_line_width=1)
whisker = hv.BoxWhisker(plot_array, ['Feature'], ['Qual_index']).opts(ylabel='Tuning index', **box_opts)

# save the figure as <target_document>_TC_index.png
save_name = os.path.join(save_path, target_document + '_TC_index.png')
fig = fp.save_figure(whisker, save_name, fig_width=5, dpi=1200, fontsize='small', target='save')


In [None]:
# distribution of the responsivity index of the significant cells, per feature
plot_array = pd.concat(resp_list)
whisker = hv.BoxWhisker(plot_array, ['Feature'], ['Resp_index']).opts(width=600, height=800, xrotation=45)
whisker

In [None]:
# distribution of the consistency index of the significant cells, per feature
plot_array = pd.concat(cons_list)
whisker = hv.BoxWhisker(plot_array, ['Feature'], ['Cons_index']).opts(width=600, height=800, xrotation=45)
whisker

In [None]:
# preview `data` (after the feature loop above it holds the last feature in variable_list)
data.head()

In [None]:
# Use the TCs for each cell and feature to embed with UMAP
# build one wide table: for every feature, its index columns and full-map bin columns
# prefixed with the feature name (day and animal are taken from the first feature only)
tc_whole = []
for idx, target_feature in enumerate(variable_list):
    # load this feature from every file that contains it
    data = []
    for file in data_path:
        try:
            data.append(pd.read_hdf(file, target_feature))
        except KeyError:
            continue
    if not data:
        raise ValueError(f'Feature {target_feature} was not found in any of the files in data_path')
    data = pd.concat(data, axis=0)

    # index columns plus the full-map bins (the half-map bin columns are excluded)
    bin_columns = [el for el in data.columns if ('bin' in el) and ('half' not in el)]
    id_columns = ['day', 'animal'] if idx == 0 else []
    target_columns = id_columns + ['Resp_index', 'Cons_index', 'Qual_index'] + bin_columns
    data = data.loc[:, target_columns]

    # prefix the bin and index columns with the feature name so they stay unique
    new_names = {el: target_feature + '_' + el if ('bin' in el) or ('index' in el) else el
                 for el in target_columns}
    data = data.rename(columns=new_names)
    tc_whole.append(data)

# concatenate side by side; rows are aligned on the index
# NOTE(review): this assumes every feature yields the same rows/index across files — confirm
tc_whole = pd.concat(tc_whole, axis=1)
# replace NaN and inf quality indexes with 0
cleanup_columns = [el for el in tc_whole.columns if 'Qual' in el]
cleanup_data = tc_whole.loc[:, cleanup_columns].to_numpy()
tc_whole.loc[:, cleanup_columns] = np.where(np.isfinite(cleanup_data), cleanup_data, 0)

print(tc_whole.shape)


In [None]:
# count cells (rows of tc_whole) per animal and recording day
animal_counts = tc_whole.groupby(['animal', 'day']).size().to_frame('cells')
animal_counts

# Correlation plot 

In [None]:
# get the correlation matrix between the quality indexes of all features
# (despite the name, resp_columns / resp_indexes hold the Qual_index columns)
resp_columns = [el for el in tc_whole.columns if 'Qual_index' in el]
resp_indexes = tc_whole.loc[:, resp_columns]

# get rid of the cells with any NaN or inf index
nonan_vector = np.all(np.isfinite(resp_indexes.to_numpy()), axis=1)
resp_indexes = resp_indexes.iloc[nonan_vector, :]

# Spearman rank correlation between features (columns) across cells (rows)
correlation_matrix, pvalue_matrix = stat.spearmanr(resp_indexes.to_numpy())
# zero out the correlations whose p-value is above 0.05
correlation_matrix[pvalue_matrix > 0.05] = 0

In [None]:
# plot the lower triangle of the correlation matrix
importlib.reload(fp)

# ticks centered on each raster pixel, labeled with the feature names
ticks = [(idx + 0.5, label_dict[feature]) for idx, feature in enumerate(variable_list)]
# keep the lower triangle (diagonal included) and blank every zero entry, i.e. the upper
# triangle and the correlations that were zeroed as non-significant
plot_matrix = np.tril(correlation_matrix, k=0)
plot_matrix[plot_matrix == 0] = np.nan
raster = hv.Raster(plot_matrix).opts(
    width=1050, height=800, yticks=ticks, xticks=ticks, colorbar=True,
    cmap='RdBu_r', clim=(-1, 1), xrotation=45, xlabel='', ylabel='')

# save the figure as <target_document>_TC_correlation.png
save_name = os.path.join(save_path, target_document + '_TC_correlation.png')
fig = fp.save_figure(raster, save_name, fig_width=7, dpi=1200, fontsize='small', target='save')

# Tunings plot 

In [None]:
# plot the tunings across the population, with the cells ordered by Rastermap
model = Rastermap(n_components=2, n_X=30, nPC=200, init='pca')
plot_matrix = resp_indexes.to_numpy().copy()
model.fit(plot_matrix)
# reorder the rows (cells) by the sort order found by Rastermap
plot_matrix = plot_matrix[model.isort, :]

plot = hv.Raster(plot_matrix).opts(
    width=1000, height=600, tools=['hover'], cmap='RdBu_r', xticks=ticks, ylabel='Cells',
    clim=(-0.5, 0.5), xlabel='', xrotation=45, colorbar=True)

# save the figure as <target_document>_TC_tunings.png
save_name = os.path.join(save_path, target_document + '_TC_tunings.png')
fig = fp.save_figure(plot, save_name, fig_width=7, dpi=1200, fontsize='small', target='save')

In [None]:
%%time
# run the UMAP embedding on the cleaned quality-index columns
umap_data = tc_whole.loc[:, cleanup_columns].to_numpy()

# a fixed random_state makes the embedding reproducible across runs
# (umap-learn then disables its parallel execution, so this is slower)
reducer = umap.UMAP(min_dist=0.5, n_neighbors=30, random_state=1)
embedded_data = reducer.fit_transform(umap_data)

# UMAP TC plot 

In [None]:
# plot the UMAP results, colored by each quality index and by animal/day
importlib.reload(fp)
# define the interval between points
interv = 1
# percentile above which index values are saturated
perc = 95

label_list = resp_columns + ['animal', 'day']
umap_list = []
for target_key in label_list:
    # skip responsivity/consistency columns, should any be in label_list
    if ('Cons' in target_key) or ('Resp' in target_key):
        continue
    if 'index' not in target_key:
        # categorical column: number each unique value and rescale to [0, 1]
        unique_values, raw_labels = np.unique(tc_whole.loc[:, target_key].to_numpy(), return_inverse=True)
        raw_labels = raw_labels.astype(np.float64)
        label_range = raw_labels.max() - raw_labels.min()
        # with a single category the rescale would be 0/0, turning every label into NaN
        if label_range > 0:
            raw_labels = (raw_labels - raw_labels.min()) / label_range
        title = target_key
    else:
        raw_labels = tc_whole.loc[:, target_key].to_numpy().astype(np.float64)
        # saturate values above the percentile so a few outliers do not set the color scale
        upper_limit = np.percentile(raw_labels, perc)
        raw_labels[raw_labels > upper_limit] = upper_limit
        # strip the '_Qual_index' suffix (11 characters) to look up the feature label
        title = label_dict[target_key[:-11]]
    compiled_labels = np.expand_dims(raw_labels, axis=1)

    # append the color values as a third column next to the 2D embedding
    umap_data = np.concatenate((embedded_data, compiled_labels), axis=1)

    compiled_labels = compiled_labels[::interv]
    umap_data = umap_data[::interv, :]

    umap_plot = hv.Scatter(umap_data, vdims=['Dim 2', target_key], kdims=['Dim 1'])
    umap_plot.opts(color=target_key, colorbar=False, cmap='Spectral_r', size=1, tools=['hover'], clim=(-np.nanmax(compiled_labels), np.nanmax(compiled_labels)))
    umap_plot.opts(height=600, width=800, colorbar_opts={'title': 'QI', 'title_standoff': 15}, xaxis=None, yaxis=None, title=title)

    save_name = os.path.join(save_path, '_'.join((target_document, 'TC_UMAP', target_key)) + '.png')
    # save the figure
    fig = fp.save_figure(umap_plot, save_name, fig_width=7.7, dpi=1200, fontsize=target_document, target='save', display_factor=0.1)

    umap_list.append(umap_plot)

# hv.Layout(umap_list).cols(2)

In [None]:
def gini(array):
    """Calculate the Gini coefficient of a numpy array. From https://neuroplausible.com/gini

    The input is flattened; if it has negative values it is shifted so its minimum
    is 0, and a small epsilon (1e-7) is added so that no value is exactly 0.

    Parameters
    ----------
    array : array_like
        Numeric values (any shape, any numeric dtype, lists accepted).

    Returns
    -------
    float
        The Gini coefficient of the values.
    """
    # work on a float64 copy: the caller's data is never modified, and integer input
    # no longer breaks the in-place float operations below (a casting error before)
    values = np.array(array, dtype=np.float64).ravel()
    if np.amin(values) < 0:
        # Values cannot be negative:
        values -= np.amin(values)
    # Values cannot be 0:
    values += 0.0000001
    # Values must be sorted:
    values = np.sort(values)
    # Number of array elements and index per element (1-based):
    n = values.shape[0]
    index = np.arange(1, n + 1)
    # Gini coefficient:
    return np.sum((2 * index - n - 1) * values) / (n * np.sum(values))

In [None]:
def gini2(array, bins=30):
    """Calculate the Gini coefficient according to de Oliveira and Kim et al.

    The magnitudes ``|array|`` are histogrammed into ``bins`` equal-width bins, and the
    coefficient is derived from the discretized Lorenz curve built from the bin
    occupancy fractions and the bin centers.

    Parameters
    ----------
    array : array_like
        Values whose inequality is measured (only their absolute values are binned).
    bins : int
        Number of histogram bins.

    Returns
    -------
    float
        The Gini coefficient.
    """
    counts, bin_edges, _ = stat.binned_statistic(np.abs(array), array, bins=bins, statistic='count')
    fractions = counts / counts.sum()
    centers = (bin_edges[1:] + bin_edges[:-1]) / 2

    # cumulative (unnormalized) share up to each bin, and the same shifted by one bin
    cumulative = np.cumsum(fractions * centers)
    previous = np.concatenate(([0], cumulative[:-1]))

    # 1 minus twice the trapezoidal area under the normalized Lorenz curve
    return 1 - np.sum(fractions * (previous + cumulative)) / cumulative[-1]
    

In [None]:
# Calculate the Gini coefficients: one per animal/day and feature, computed from the
# quality indexes of all the cells recorded that day
gini_records = []
# for all the animal/day groups
for animal_date, current_day in tc_whole.groupby(['animal', 'day'], as_index=False):
    # for all the features
    for feature in variable_list:
        current_feat = current_day[feature + '_Qual_index'].to_numpy()
        gini_records.append((label_dict[feature], gini2(current_feat, bins=20)))

# build the table once instead of concatenating one-row dataframes
gini_array = pd.DataFrame(gini_records, columns=['Feature', 'Gini'])
print(gini_array)

In [None]:
%%time
# calculate a null distribution of Gini coefficients: for every animal/day and feature,
# draw as many uniform random values as there are cells, spanning the observed
# [min, max] of the quality index, and compute their Gini coefficient

# define the number of shuffles
number_shuffles = 100
# seed of the random draws, so the null distribution is reproducible
shuffle_seed = 1
shuffle_rng = np.random.default_rng(shuffle_seed)

# the range and size of each animal/day x feature sample do not change across shuffles
sample_specs = []
for animal_date, current_day in tc_whole.groupby(['animal', 'day'], as_index=False):
    for feature in variable_list:
        current_feat = current_day[feature + '_Qual_index'].to_numpy().astype(np.float64)
        sample_specs.append((label_dict[feature], np.min(current_feat), np.max(current_feat), current_feat.shape[0]))

shuffle_records = []
for shuff in np.arange(number_shuffles):
    for feature_label, a, b, n_cells in sample_specs:
        # uniform draws in [a, b)
        current_feat = shuffle_rng.uniform(a, b, n_cells)
        shuffle_records.append((feature_label, gini2(current_feat, bins=20)))

# build the table once instead of concatenating one-row dataframes
shuffle_gini = pd.DataFrame(shuffle_records, columns=['Feature', 'Gini'])

# TC Gini plot: sparsity of the quality indexes versus the resampled null

In [None]:
# Plot the gini coefficients: observed values (whisker0) overlaid on the resampled null (whisker1)
importlib.reload(fp)

whisker0 = hv.BoxWhisker(gini_array, ['Feature'], ['Gini']).opts(
    width=1500, height=600, xrotation=45, ylabel='Sparsity', xlabel='')
whisker1 = hv.BoxWhisker(shuffle_gini, ['Feature'], ['Gini'])
whisker = (whisker0 * whisker1).opts(
    opts.BoxWhisker(box_line_width=1, whisker_line_width=1, outlier_line_width=1))

# save the figure as <target_document>_TC_gini.png
save_name = os.path.join(save_path, target_document + '_TC_gini.png')
fig = fp.save_figure(whisker, save_name, fig_width=10, dpi=1200, fontsize='small', target='save')

In [None]:
# quantify non-linearity of the mixed selectivity