In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
import warnings

import numpy as np
import pandas as pd
import pycircstat as circ
from scipy import stats

sys.path.insert(0, os.path.abspath(r'C:/Users/mmccann/repos/bonhoeffer/prey_capture/'))
import functions_kinematic as fk

In [16]:
def calculate(label, shifts_ds):
    """Summarise how well preferences measured in two sessions agree.

    Parameters
    ----------
    label : str
        Dataset label used to decide whether the measure is circular. It is
        checked by substring, in this order:
        - a label containing 'pd' is wrapped with a half-period of 180;
        - otherwise, one containing 'po' is wrapped with a half-period of 90;
        - any other label (e.g. the OSI/DSI datasets) is left unwrapped.
    shifts_ds : pandas.DataFrame
        Must contain numeric 'pref_1' and 'pref_2' columns, one row per matched
        cell. The frame is not modified.

    Returns
    -------
    tuple
        ``(residuals, rmse_residual, mean_residual, std_residual, pearson_r, pval)``.
        ``residuals`` is a float ndarray of absolute shifts ``|pref_2 - pref_1|``,
        wrapped onto the shorter arc for circular measures. ``std_residual`` is
        the population std (``np.std`` default, ddof=0). ``pearson_r``/``pval``
        come from ``scipy.stats.pearsonr(pref_1, pref_2)``.
    """
    # Distance of each (pref_1, pref_2) point from the identity line along the
    # pref_2 axis, i.e. the absolute session-to-session shift. Cast to float so
    # integer preferences do not make the wrapped assignment below truncate.
    residuals = np.abs(shifts_ds['pref_2'].to_numpy(dtype=float)
                       - shifts_ds['pref_1'].to_numpy(dtype=float))

    # Use the label passed in (previously a notebook global leaked in here).
    if 'pd' in label:
        bound = 180  # presumably preferred direction (period 360)
    elif 'po' in label:
        bound = 90  # presumably preferred orientation (period 180)
    else:
        bound = None  # linear measure: no wrapping

    if bound is not None:
        # A shift larger than half the period is really a shorter shift the other
        # way round the circle.
        # NOTE(review): relies on fk.wrap(x, bound) mapping x into [0, bound), so
        # that bound - wrap(x) == 2*bound - x for bound < x < 2*bound. Confirm this
        # in functions_kinematic.
        wrapped_residuals = np.asarray(fk.wrap(residuals, bound=bound))
        needs_wrap = residuals > bound
        residuals[needs_wrap] = bound - wrapped_residuals[needs_wrap]

    # Root-mean-square of the residuals
    rmse_residual = np.sqrt(np.mean(residuals ** 2))
    print(f"RMSE: {rmse_residual}")

    # Mean and (population) standard deviation of the residuals
    mean_residual = np.mean(residuals)
    std_residual = np.std(residuals)
    print(f"Residual mean: {mean_residual}, Residual std: {std_residual}")

    # Pearson's r treats the preferences as linear values. For the circular PO/PD
    # measures a circular correlation is the more appropriate statistic.
    pearson_r, pval = stats.pearsonr(shifts_ds['pref_1'], shifts_ds['pref_2'])
    print(f"Pearson's R: {pearson_r}, p: {pval}")

    return residuals, rmse_residual, mean_residual, std_residual, pearson_r, pval

In [20]:
# Selection of the saved statistics to analyse.
experiment_type = 'multi'  # 'repeat', 'multi' (not named `type`, which would shadow the builtin)
lighting = 'normal'  # 'normal', 'dark'
rig = 'ALL'  # 'VWheelWF', 'VTuningWF', 'ALL'
match_type = 'curated'  # 'curated', 'all'
# A mouse is analysed on its own only when it has MORE than this many matches (non-null 'day').
min_matches_per_mouse = 3

data_root = r"H:\thesis\WF_Figures\full"  # NOTE(review): machine-specific absolute path
data_path = os.path.join(data_root, f'{experiment_type}_{lighting}_{rig}')
print(data_path)

# Branch on the rig selector itself rather than substring-searching the full path.
# The two single-rig options use identical settings.
if rig == 'ALL':
    settings_dict = {'session_shorthand': ['free', 'fixed'],
                     'sort': 'rig',
                     'vis_resp_type': 'both'}
elif rig in ('VWheelWF', 'VTuningWF'):
    settings_dict = {'session_shorthand': ['session1', 'session2'],
                     'sort': 'slug',
                     'vis_resp_type': 'both'}
else:
    raise ValueError('data_path not recognized')

ds_list = [f'{measure}_shifts_vis_resp_{match_type}_matches_{settings_dict["vis_resp_type"]}_vis_resp'
           for measure in ('osi', 'dsi', 'po', 'pd')]

# Load every dataset up front so the HDF5 file is opened only once.
with pd.HDFStore(os.path.join(data_path, 'stats.hdf5'), 'r') as store:
    shifts_by_label = {label: store[label] for label in ds_list}

for ds_label in ds_list:
    print(ds_label)
    # `shifts` is left holding the last dataset after the loop (inspected in the next cell).
    shifts = shifts_by_label[ds_label]

    residuals, rmse_residual, mean_residual, std_residual, pearson_r, pval = calculate(ds_label, shifts)

    # Circular correlation of the preferences themselves. It is only meaningful for the
    # angular PO and PD measures, which are stored in degrees.
    if 'po' in ds_label or 'pd' in ds_label:
        circ_corr = circ.corrcc(np.deg2rad(shifts['pref_1']), np.deg2rad(shifts['pref_2']))
        print(f"Circular Correlation: {circ_corr} \n")

        # Repeat the analysis per mouse, for mice with enough matched cells
        matches_per_mouse = shifts.groupby('mouse')['day'].count()
        mice_with_enough_cells = matches_per_mouse.index[matches_per_mouse > min_matches_per_mouse]

        for mouse in mice_with_enough_cells:
            print(mouse)
            mouse_shifts = shifts[shifts['mouse'] == mouse]
            print(f"Number of matches: {len(mouse_shifts)}")

            mouse_residuals, mouse_rmse_residual, mouse_mean_residual, mouse_std_residual, mouse_pearson_r, mouse_pval = calculate(ds_label, mouse_shifts)

            # Already inside the PO/PD branch, so no second measure check is needed here.
            circ_corr_mouse = circ.corrcc(np.deg2rad(mouse_shifts['pref_1']), np.deg2rad(mouse_shifts['pref_2']))
            print(f"Circular Correlation: {circ_corr_mouse}\n")

    print('\n')

H:\thesis\WF_Figures\full\multi_normal_ALL
osi_shifts_vis_resp_curated_matches_both_vis_resp
RMSE: 0.44949213499943674
Residual mean: 0.3731451808062739, Residual std: 0.2506109603896148
Pearson's R: 0.10474046175987313, p: 0.3989337516923427


dsi_shifts_vis_resp_curated_matches_both_vis_resp
RMSE: 0.30304174823537855
Residual mean: 0.25089098631857876, Residual std: 0.16996474386661875
Pearson's R: 0.18275466653361142, p: 0.13880442271982488


po_shifts_vis_resp_curated_matches_both_vis_resp
RMSE: 40.94404794412297
Residual mean: 32.3895552298627, Residual std: 25.046592064836165
Pearson's R: 0.2963422576913399, p: 0.014892443532134543
Circular Correlation: 0.317045409205032 

MM_221109_a
Number of matches: 7
RMSE: 41.75474520179431
Residual mean: 36.07214428857715, Residual std: 21.029958470972684
Pearson's R: 0.8967864379367536, p: 0.006213691940324025
Circular Correlation: 0.7962573652027235

MM_230518_b
Number of matches: 18
RMSE: 50.5451424615673
Residual mean: 43.34669338677354

In [21]:
# Non-null entries per column for each mouse, for whichever dataset the previous
# cell's loop left in `shifts` (its last iteration).
shifts.groupby(by='mouse').agg('count')

Unnamed: 0_level_0,pref_1,pref_2,delta_pref,day
mouse,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
MM_220915_a,1,1,1,1
MM_221109_a,7,7,7,7
MM_230518_b,18,18,18,18
MM_230705_b,2,2,2,2
MM_230706_a,10,10,10,10
MM_230706_b,29,29,29,29
