In [None]:
%load_ext autoreload
%autoreload 2

# Imports
import itertools
import numpy as np
import pandas as pd
import scipy.stats as st
import panel as pn
import seaborn as sns
import holoviews as hv
import hvplot.pandas
import datashader as dshade
import ipywidgets as widgets
from ipywidgets import interact, interactive
from statsmodels.stats.multitest import multipletests

import os
import sys
import random
import importlib
import warnings
import math
import h5py
warnings.filterwarnings('ignore')

from bokeh.resources import INLINE
from bokeh.io import export_svgs, export_png
from bokeh.plotting import show
from holoviews import opts, dim
from holoviews.operation import histogram
from holoviews.operation.datashader import datashade, shade
hv.extension('bokeh')
# hv.extension('matplotlib')
import matplotlib.pyplot as plt

sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
%aimport paths
import processing_parameters
import functions_bondjango as bd
import functions_loaders as fl
import functions_plotting as fp
import functions_data_handling as fdh
import functions_kinematic as fk
import functions_tuning as tuning
import functions_misc as misc
from wirefree_experiment import WirefreeExperiment, DataContainer

importlib.reload(fp)
# set up the figure theme
fp.set_theme()
plt.rcParams["font.family"] = "Arial"
# centimetre -> inch conversion factor (matplotlib figure sizes are given in inches)
cm = 1./2.54
# Output folder for figures and cached arrays. Defaults to the lab network drive;
# set the WF_FIGURE_SAVE_PATH environment variable to redirect it on machines where
# Z: is not mapped.
figure_save_path = os.environ.get('WF_FIGURE_SAVE_PATH', r"Z:\Prey_capture\WF_Figures")

In [None]:
from umap.umap_ import UMAP
from rastermap import Rastermap
import sklearn.preprocessing as preproc

In [None]:
# Fetch the path/query groups to analyse; index 0 of all_paths provides the mouse names.
all_paths, all_queries = fl.query_search_list()
mice = [
    '_'.join(os.path.basename(p).split('_')[7:10])
    for p in all_paths[0]
]
print(all_paths)

# Load every group's preprocessed data. Only the first returned item is collected;
# `metadata` ends up holding the value returned for the last group.
data_list = []
for path, queries in zip(all_paths, all_queries):
    data, _, metadata = fl.load_preprocessing(path, queries, latents_flag=False)
    data_list.append(data)

# Imaging frame rate and the kinematic variable sets, ordered (free, fixed)
frame_rate = processing_parameters.wf_frame_rate
kinem_vars = [
    processing_parameters.variable_list_free,
    processing_parameters.variable_list_fixed,
]


In [None]:
def _split_rows_in_half(frame):
    """Split `frame` row-wise into two contiguous halves.

    Same sizes as ``np.array_split(frame, 2, axis=0)``: for an odd row count the
    first half gets the extra row. Positional ``iloc`` slicing is used because
    ``np.array_split`` on a DataFrame relies on ``DataFrame.swapaxes``, which is
    deprecated in recent pandas.
    """
    midpoint = (len(frame) + 1) // 2
    return frame.iloc[:midpoint], frame.iloc[midpoint:]


data_dict = {}
for ds_list, kinem_var_set, label in zip(data_list, kinem_vars, ['free', 'fixed']):
    half1 = []
    half2 = []

    for ds in ds_list:
        # NOTE: this adds the column to the loaded frame in place, so data_list sees it too
        if 'wheel_speed' in ds.columns:
            ds['wheel_speed_abs'] = ds['wheel_speed'].abs()

        data = ds[kinem_var_set]
        first_half, second_half = _split_rows_in_half(data)
        half1.append(first_half)
        half2.append(second_half)

    # Keys are e.g. 'free1' / 'free2'; missing values are set to 0 before scaling
    data_dict[label + '1'] = pd.concat(half1, axis=0).fillna(0)
    data_dict[label + '2'] = pd.concat(half2, axis=0).fillna(0)

In [None]:
def _standardize(frame):
    """Return `frame` as a NumPy array with every column z-scored (scaler fit on `frame` itself)."""
    scaler = preproc.StandardScaler()
    return scaler.fit_transform(frame.to_numpy())


# Every half is scaled with its own fitted mean/std
free1_tunings = _standardize(data_dict['free1'])
free2_tunings = _standardize(data_dict['free2'])
fixed1_tunings = _standardize(data_dict['fixed1'])
fixed2_tunings = _standardize(data_dict['fixed2'])

In [None]:
# UMAP embedding of each freely-moving half.
# fit_transform refits the reducer on every call, so the two embeddings are independent
# fits whose axes are not directly comparable; afterwards reducer1 holds the fit to free2.
# random_state fixes the stochastic optimisation so results survive a kernel restart
# (umap-learn then runs single-threaded).
reducer1 = UMAP(min_dist=0.1, n_neighbors=20, random_state=42)
embedded_data1 = reducer1.fit_transform(free1_tunings)
embedded_data2 = reducer1.fit_transform(free2_tunings)

In [None]:
# Store the freely-moving embeddings as .npy files inside figure_save_path.
# (Joining with a separator; plain string concatenation would write the files next to,
# not into, the folder.)
np.save(os.path.join(figure_save_path, 'free1_umap.npy'), embedded_data1)
np.save(os.path.join(figure_save_path, 'free2_umap.npy'), embedded_data2)

In [None]:
# UMAP embedding of each head-fixed half.
# fit_transform refits the reducer on every call, so the two embeddings are independent
# fits whose axes are not directly comparable; afterwards reducer2 holds the fit to fixed2.
# random_state fixes the stochastic optimisation so results survive a kernel restart
# (umap-learn then runs single-threaded).
reducer2 = UMAP(min_dist=0.1, n_neighbors=10, random_state=42)
embedded_data3 = reducer2.fit_transform(fixed1_tunings)
embedded_data4 = reducer2.fit_transform(fixed2_tunings)

In [None]:
# Store the head-fixed embeddings as .npy files inside figure_save_path.
# (Joining with a separator; plain string concatenation would write the files next to,
# not into, the folder.)
np.save(os.path.join(figure_save_path, 'fixed1_umap.npy'), embedded_data3)
np.save(os.path.join(figure_save_path, 'fixed2_umap.npy'), embedded_data4)

In [None]:
def _clip_abs_to_percentiles(values, perc):
    """Return |values| clipped to the [100 - perc, perc] percentiles of |values|.

    Keeps a few extreme samples from stretching the colour scale.
    """
    magnitudes = np.abs(values)
    upper = np.percentile(magnitudes, perc)
    lower = np.percentile(magnitudes, 100 - perc)
    return np.clip(magnitudes, lower, upper)


# Colour the freely-moving UMAP embeddings by each kinematic variable.
# One panel per (variable, half) is appended to plot_list.
perc = 99
predictor_columns = kinem_vars[0]
plot_list = []

for col_idx, predictor_column in enumerate(predictor_columns):
    variable_label = processing_parameters.wf_label_dictionary[predictor_column]

    for i, (raw_data, embedded_data) in enumerate(zip([free1_tunings, free2_tunings], [embedded_data1, embedded_data2])):
        # Columns of the scaled arrays follow kinem_vars[0], the column order used to build data_dict
        raw_labels = _clip_abs_to_percentiles(raw_data[:, col_idx], perc)
        half_label = np.full_like(raw_labels, i + 1)

        # rows: [Dim 1, Dim 2, Parameter, Half]
        plot_data = np.column_stack([embedded_data, raw_labels, half_label])

        umap_plot = hv.Scatter(plot_data, kdims=['Dim 1'], vdims=['Dim 2', 'Parameter', 'Half'])
        umap_plot.opts(colorbar=False, color='Parameter', cmap='Spectral_r', alpha=1, xaxis=None, yaxis=None, tools=['hover'])
        umap_plot.opts(width=300, height=300, size=2, title=f'{variable_label} (half {i + 1})')

        # To export, define `cell_subset` and uncomment:
        # save_name = os.path.join(figure_save_path, f"UMAP_{predictor_column}_half{i + 1}_{cell_subset}.png")
        # umap_plot = fp.save_figure(umap_plot, save_path=save_name, fig_width=6, dpi=1000, fontsize='screen', target='save', display_factor=0.1)
        plot_list.append(umap_plot)

In [None]:
# Arrange all collected UMAP panels in a grid five panels wide and display it
layout = hv.Layout(plot_list)
layout = layout.cols(5)
layout