In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import os
import sys
import numpy as np
import pandas as pd
import scipy.signal as sig
import matplotlib.pyplot as plt

sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))

import paths
import processing_parameters
import functions_loaders as fl
import functions_data_handling as fdh
import functions_bondjango as bd
from functions_misc import list_lists_to_array, find_nearest

import panel as pn
import holoviews as hv
from holoviews import opts, dim
from holoviews.operation import histogram
hv.extension('bokeh')
from bokeh.resources import INLINE

In [None]:
# Resolve the search into groups of file paths, paired one-to-one with all_queries
all_paths, all_queries = fl.query_search_list()

# Mouse identifiers: underscore-separated fields 7 to 9 of every file name in the first path group
first_group_names = map(os.path.basename, all_paths[0])
mice = ['_'.join(file_name.split('_')[7:10]) for file_name in first_group_names]
print(all_paths)

# load the data, pooling the datasets of every path group into one flat list
data_list = []
for path, queries in zip(all_paths, all_queries):
    data, _, metadata = fl.load_preprocessing(path, queries, latents_flag=False)
    data_list.extend(data)

# Acquisition frame rate, used below to convert between samples and seconds
frame_rate = processing_parameters.wf_frame_rate
# Variables to analyse
# (alternative: processing_parameters.variable_list_free + processing_parameters.variable_list_fixed)
kinem_vars = processing_parameters.variable_list_free


# Test autocorrelation of kinematic variables

In [None]:
window_size = 10    # seconds; set to 'all' to autocorrelate each whole trace instead of chunks


def _normalized_autocorr(trace):
    """Return the one-sided, normalized autocorrelation of a 1D trace.

    The trace is mean-subtracted and correlated with itself; only the
    non-negative lags are kept (lag 0 first) and they are divided by the
    variance and the length of the trace, so lag 0 equals 1 for a
    non-constant trace.
    """
    centered = trace - trace.mean()
    full = sig.correlate(centered, centered, mode='full')
    # The middle element of the 'full' output is lag 0
    return full[full.size//2:] / np.var(trace) / len(centered)


autocorr_dict = {}
mouse_speed_list = []
mouse_angular_speed_list = []
pupil_diam_list = []    # not filled in this cell
for ds in data_list:
    # Drops every row containing a NaN, in place (the datasets in data_list are modified)
    ds.dropna(inplace=True)

    # Collect the raw speed traces for the histogram cells
    if 'mouse_angular_speed' in ds.columns:
        mouse_angular_speed_list.append(ds['mouse_angular_speed'].to_numpy())
    if 'mouse_speed' in ds.columns:
        mouse_speed_list.append(ds['mouse_speed'].to_numpy())

    # Datasets with a wheel_speed column are treated as head fixed
    head_fixed = 'wheel_speed' in ds.columns

    for kvar in kinem_vars:
        if kvar not in ds.columns:
            continue

        # Handle exception for x and y position in head fixed data
        if head_fixed and kvar in ('mouse_x_m', 'mouse_y_m'):
            continue

        x = ds[kvar].to_numpy()
        if x.size == 0:
            # Nothing left after dropping NaN rows
            continue

        if window_size == 'all':
            result = _normalized_autocorr(x)
        else:
            # Number of chunks that are each at least window_size*frame_rate samples long
            # (assumes frame_rate is in frames per second - TODO confirm)
            n_chunks = int(x.size // (window_size * frame_rate))
            if n_chunks < 1:
                # Trace shorter than one window: skip it. Previously this fell through
                # and stored an undefined or stale `result` under this variable.
                continue
            chunk_autocorrs = [_normalized_autocorr(chunk)
                               for chunk in np.array_split(x, n_chunks)]
            # Project helper; presumably stacks the chunks into a 2D (chunk, lag) array - TODO confirm
            result = list_lists_to_array(chunk_autocorrs)

        # NOTE(review): this overwrites the entry of any earlier dataset, so only the
        # last dataset containing kvar is kept - confirm that this is intended.
        autocorr_dict[kvar] = result

In [None]:
# Pool mouse_speed over every dataset that has the column.
# np.concatenate raises on an empty list, so fall back to an empty array.
mouse_speed = np.concatenate(mouse_speed_list) if mouse_speed_list else np.array([])
# mouse_speed *= 360/(2*np.pi)
mouse_speed = mouse_speed[~np.isnan(mouse_speed)]

if mouse_speed.size == 0:
    print('No mouse_speed samples available, skipping the histogram')
else:
    # Percentile bounds, only used by the optional trimming below
    lower_thresh = np.percentile(mouse_speed, 2)
    upper_thresh = np.percentile(mouse_speed, 99.5)
    # mouse_speed = mouse_speed[mouse_speed <= upper_thresh]
    # mouse_speed = mouse_speed[mouse_speed >= lower_thresh]

    fig, ax = plt.subplots()
    ax.hist(mouse_speed, bins=200)
    # Log counts so that the sparse bins stay visible
    ax.set_yscale('log')
    ax.set_xlabel('mouse_speed')
    ax.set_ylabel('Count')
    ax.set_title('Distribution of mouse_speed (all datasets pooled)')
    plt.show()

In [None]:
# Pool mouse_angular_speed over every dataset that has the column.
# np.concatenate raises on an empty list, so fall back to an empty array.
mouse_angular_speed = (np.concatenate(mouse_angular_speed_list)
                       if mouse_angular_speed_list else np.array([]))
# mouse_angular_speed *= 360/(2*np.pi)
mouse_angular_speed = mouse_angular_speed[~np.isnan(mouse_angular_speed)]

if mouse_angular_speed.size == 0:
    print('No mouse_angular_speed samples available, skipping the histogram')
else:
    # Percentile bounds, only used by the optional trimming below
    lower_thresh = np.percentile(mouse_angular_speed, 2)
    upper_thresh = np.percentile(mouse_angular_speed, 99.5)
    # mouse_angular_speed = mouse_angular_speed[mouse_angular_speed <= upper_thresh]
    # mouse_angular_speed = mouse_angular_speed[mouse_angular_speed >= lower_thresh]

    fig, ax = plt.subplots()
    ax.hist(mouse_angular_speed, bins=200)
    # ax.set_yscale('log')
    ax.set_xlabel('mouse_angular_speed')
    ax.set_ylabel('Count')
    ax.set_title('Distribution of mouse_angular_speed (all datasets pooled)')
    plt.show()

In [None]:
# One panel per variable, six panels per row. The grid grows with the number of
# variables so none is silently dropped (the fixed 2x6 grid held at most twelve).
n_cols = 6
n_vars = len(autocorr_dict)
n_rows = max(1, int(np.ceil(n_vars / n_cols)))
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(10, 2.5 * n_rows), squeeze=False)
flat_axes = axes.flatten()

mean_autocorr_dict = {}
for ax, key in zip(flat_axes, autocorr_dict.keys()):
    # With window_size == 'all' the stored trace is 1D; promote it to a single row so
    # the mean over axis 0 stays a trace instead of collapsing to a scalar
    autocorr = np.atleast_2d(autocorr_dict[key])
    mean_autocorr = np.mean(autocorr, axis=0)
    mean_autocorr_dict[key] = mean_autocorr

    # Lag axis in seconds (last array axis divided by the frame rate)
    lags = np.arange(0, autocorr.shape[-1])/frame_rate
    if window_size != 'all':
        # Individual traces in the background, their mean in red on top
        ax.plot(lags, autocorr.T, alpha=0.2)
        ax.plot(lags, mean_autocorr, color='r')
        ax.set_xlim(-1, window_size+1)
        ax.hlines(0, -1, window_size+1, color='k', linestyle='--')

    else:
        ax.plot(lags, autocorr.T)

    ax.set_title(key)
    ax.set_xlabel('Time (s)')

# Hide the panels that did not receive a variable
for unused_ax in flat_axes[n_vars:]:
    unused_ax.set_visible(False)

plt.tight_layout()

In [None]:
# Report, per variable, the first zero crossing and the minimum of the mean autocorrelation
for key, autocorr in mean_autocorr_dict.items():
    autocorr = np.atleast_1d(autocorr)

    # A sign change between samples i and i+1 shows up as a non-zero entry i of the diff.
    # NaN entries give NaN diffs, which would count as non-zero, so they are excluded.
    sign_change = np.diff(np.sign(autocorr))
    crossings = np.flatnonzero(np.isfinite(sign_change) & (sign_change != 0))

    if crossings.size > 0:
        zero = crossings[0]
        print(f'Autocorrelation zero crossing for {key} is {zero/frame_rate} seconds')
    else:
        # Previously this case raised an IndexError
        print(f'No autocorrelation zero crossing found for {key}')

    if np.isnan(autocorr).all():
        # np.nanargmin raises on an all-NaN (or empty) input
        print(f'Autocorrelation for {key} has no finite values\n')
    else:
        print(f'Autocorrelation minimum is {np.nanmin(autocorr)} at {np.nanargmin(autocorr)/frame_rate} seconds\n')

# Look at how consistent traces are across session halves

In [None]:
# Prepare every dataset once, instead of once per variable as before
for ds in data_list:
    # Drops every row containing a NaN, in place (the datasets in data_list are modified)
    ds.dropna(inplace=True)

    if 'wheel_speed' in ds.columns:
        # Absolute wheel speed as an extra column, so it can be requested as a variable
        ds['wheel_speed_abs'] = ds['wheel_speed'].abs()

# First and second half of every trace, keyed by variable name
trace1_dict = {}
trace2_dict = {}
for kvar in kinem_vars:
    kvar1_list = []
    kvar2_list = []

    for ds in data_list:
        if kvar not in ds.columns:
            continue

        # Handle exception for x and y position in head fixed data
        if ('wheel_speed' in ds.columns) and kvar in ('mouse_x_m', 'mouse_y_m'):
            continue

        # Split the trace into two halves (the first gets the extra sample if the length is odd)
        first_half, second_half = np.array_split(ds[kvar].to_numpy(), 2)
        kvar1_list.append(first_half)
        kvar2_list.append(second_half)

    if not kvar1_list:
        # No dataset provided this variable; do not hand an empty list to list_lists_to_array
        continue

    # Project helper; presumably stacks the traces into one 2D array - TODO confirm
    trace1_dict[kvar] = list_lists_to_array(kvar1_list)
    trace2_dict[kvar] = list_lists_to_array(kvar2_list)

In [None]:
# One overlay per variable: first-half traces on top of second-half traces
plot_list = []
for feature in trace1_dict.keys():
    traces1 = trace1_dict[feature]
    traces2 = trace2_dict[feature]

    # hv.Path takes (x, y) with one path per COLUMN of a 2D y, and x matching its rows.
    # NOTE(review): this assumes list_lists_to_array returns (trace, sample) arrays, as the
    # autocorrelation plot does (last axis is time, plotted via .T) - confirm against the helper.
    # The previous version used shape[0] for x, which made every sample its own path.
    samples1 = np.arange(traces1.shape[-1])
    samples2 = np.arange(traces2.shape[-1])
    plt1 = hv.Path((samples1, traces1.T))
    plt2 = hv.Path((samples2, traces2.T))

    overlay = plt1 * plt2
    overlay.opts(title=feature, xrotation=45)
    plot_list.append(overlay)

# Five panels per row, each with its own axis ranges
hv.Layout(plot_list).cols(5).opts(shared_axes=False)