In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
import warnings

import numpy as np
import pandas as pd
import pycircstat as circ
from scipy import stats

sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
import functions_kinematic as fk

In [94]:
def calculate(label, shifts_ds):

    residuals = np.linalg.norm(shifts_ds[['pref_1', 'pref_2']].to_numpy() - shifts_ds[['pref_1', 'pref_1']].to_numpy(), axis=1)

    # This is a circular measure, so wrap the residuals
    if 'pd' in ds_label:
        bound = 180
        wrapped_residuals = fk.wrap(residuals, bound=bound)
        wrapping_idxs = np.argwhere(residuals > bound).flatten()
        if len(wrapping_idxs) > 0:
            residuals[wrapping_idxs] = bound - wrapped_residuals[wrapping_idxs]
    elif 'po' in ds_label:
        bound = 90
        wrapped_residuals = fk.wrap(residuals, bound=bound)
        wrapping_idxs = np.argwhere(residuals > bound).flatten()
        if len(wrapping_idxs) > 0:
            residuals[wrapping_idxs] = bound - wrapped_residuals[wrapping_idxs]
    else:
        pass

    # Calculate the RMSE of the residuals
    rmse_residual = np.sqrt(np.mean(residuals ** 2))
    print(f"RMSE: {rmse_residual}")

    # Calculate the mean and std of the residuals
    mean_residual = np.mean(residuals)
    std_residual = np.std(residuals)
    print(f"Mean: {mean_residual}, Std: {std_residual}")

    # Calculate the Pearson R and pvalue of the shifts
    pearson_r, pval = stats.pearsonr(shifts_ds['pref_2'], shifts_ds['pref_1'])
    print(f"Pearson's R: {pearson_r}, p: {pval}")

    return residuals, rmse_residual, mean_residual, std_residual, pearson_r, pval

In [101]:
type = 'repeat'  # 'repeat', 'multi'
lighting = 'normal'  # 'normal', 'dark'
rig = 'VWheelWF'  # 'VWheelWF', 'VTuningWF', 'ALL'

data_path = os.path.join(r"H:\thesis\WF_Figures\full", f'{type}_{lighting}_{rig}')
print(data_path)

if 'ALL' in data_path:
    settings_dict = {'session_shorthand': ['free', 'fixed'],
                     'sort': 'rig',
                     'match_type': 'curated',
                     'vis_resp_type': 'both'}
elif 'VWheelWF' in data_path:
    settings_dict = {'session_shorthand': ['session1', 'session2'],
                     'sort': 'slug',
                     'match_type': 'curated',
                     'vis_resp_type': 'both'}
elif 'VTuningWF' in data_path:
    settings_dict = {'session_shorthand': ['session1', 'session2'],
                     'sort': 'slug',
                     'match_type': 'all',
                     'vis_resp_type': 'both'}
else:
    raise ValueError('data_path not recognized')

ds_list = [f'osi_shifts_vis_resp_{settings_dict["match_type"]}_matches_{settings_dict["vis_resp_type"]}_vis_resp',
           f'dsi_shifts_vis_resp_{settings_dict["match_type"]}_matches_{settings_dict["vis_resp_type"]}_vis_resp',
           f'po_shifts_vis_resp_{settings_dict["match_type"]}_matches_{settings_dict["vis_resp_type"]}_vis_resp',
           f'pd_shifts_vis_resp_{settings_dict["match_type"]}_matches_{settings_dict["vis_resp_type"]}_vis_resp']

for ds_label in ds_list:
    print(ds_label)

    with pd.HDFStore(os.path.join(data_path, 'stats.hdf5'), 'r') as f:
        shifts = f[ds_label][:]

    residuals, rmse_residual, mean_residual, std_residual, pearson_r, pval = calculate(ds_label, shifts)

    # Calculate the circular correlation of the residuals (only valid for PO and PD)
    if 'po' in ds_label or 'pd' in ds_label:
        circ_corr = circ.corrcc(np.deg2rad(shifts['pref_2']), np.deg2rad(shifts['pref_1']))
        print(f"Circular Correlation: {circ_corr} \n")

        # Do it on a per-mouse basis
        all_mice = shifts.groupby('mouse').count().index
        mice_with_enough_cells = all_mice[shifts.groupby('mouse').day.count() > 3]

        for mouse in mice_with_enough_cells:
            print(mouse)
            mouse_shifts = shifts[shifts['mouse'] == mouse]
            print(f"Number of matches: {len(mouse_shifts)}")

            mouse_residuals, mouse_rmse_residual, mouse_mean_residual, mouse_std_residual, mouse_pearson_r, mouse_pval = calculate(ds_label, mouse_shifts)

            # Calculate the circular correlation of the residuals (only valid for PO and PD)
            if 'po' in ds_label or 'pd' in ds_label:
                circ_corr_mouse = circ.corrcc(np.deg2rad(mouse_shifts['pref_1']), np.deg2rad(mouse_shifts['pref_2']))
                print(f"Circular Correlation: {circ_corr_mouse}\n")

    print('\n')

H:\thesis\WF_Figures\full\repeat_normal_VWheelWF
osi_shifts_vis_resp_curated_matches_both_vis_resp
RMSE: 0.2383843975211514
Mean: 0.1637060299433885, Std: 0.17328432341587258
Pearson's R: 0.5101345351808639, p: 0.0017439094975474235


dsi_shifts_vis_resp_curated_matches_both_vis_resp
RMSE: 0.31424304400615705
Mean: 0.27261241244708767, Std: 0.15630471293609316
Pearson's R: 0.4938520325815782, p: 0.0025696980676579374


po_shifts_vis_resp_curated_matches_both_vis_resp
RMSE: 22.7335472878654
Mean: 14.77927283137704, Std: 17.27388974334769
Pearson's R: 0.678372614764204, p: 7.522350425499216e-06
Circular Correlation: 0.683249661204028 

MM_230518_b
Number of matches: 24
RMSE: 22.862099897964995
Mean: 15.465931863727455, Std: 16.836881045225415
Pearson's R: 0.5493533485045177, p: 0.005428128843325314
Circular Correlation: 0.6387317192974519

MM_230706_b
Number of matches: 10
RMSE: 12.84492539742555
Mean: 8.368737474949901, Std: 9.744554455735628
Pearson's R: 0.9785727047838929, p: 8.987450

In [102]:
shifts.groupby('mouse').count()

Unnamed: 0_level_0,pref_1,pref_2,delta_pref,day
mouse,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
MM_220915_a,1,1,1,1
MM_230518_b,24,24,24,24
MM_230706_b,10,10,10,10
