In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import warnings
# warnings.simplefilter('ignore', UserWarning)

import numpy as np
import pandas as pd
import pycircstat as circ
from scipy import stats

sys.path.insert(0, os.path.abspath(rf'{os.path.expanduser("~")}/repos/bonhoeffer/prey_capture/'))
import functions_kinematic as fk

In [None]:
def calculate(ds_label, shifts_ds):
    """Print and return linear summary statistics comparing ``pref_1`` and ``pref_2``.

    The absolute and signed differences between the paired ``pref_2`` and
    ``pref_1`` columns are computed. For circular measures the absolute
    residuals are folded back onto ``[0, bound]``.

    Parameters
    ----------
    ds_label : str
        Dataset label. If it contains ``'pd'`` the residuals are wrapped with a
        bound of 180; otherwise, if it contains ``'po'``, with a bound of 90.
        Any other label is treated as a non-circular measure (no wrapping).
    shifts_ds : pandas.DataFrame
        Table with paired ``pref_1`` and ``pref_2`` columns (presumably in
        degrees for the circular measures -- TODO confirm).

    Returns
    -------
    tuple
        ``(residuals, rmse_residual, mean_residual, std_residual, pearson_r, pval)``
        where ``residuals`` is the array of (wrapped) absolute differences and
        ``pval`` is the p-value of the Pearson correlation.
    """
    pref_1 = shifts_ds['pref_1'].values
    pref_2 = shifts_ds['pref_2'].values

    signed_residuals = pref_2 - pref_1
    # Absolute paired difference. The cast to float yields an independent
    # float array, so the in-place wrapping below never touches the signed values.
    residuals = np.abs(signed_residuals).astype(float)

    # Circular measures: pick the wrapping bound from the dataset label.
    if 'pd' in ds_label:
        bound = 180
    elif 'po' in ds_label:
        bound = 90
    else:
        bound = None

    if bound is not None:
        # Fold differences larger than the bound back onto [0, bound].
        # Assumes fk.wrap returns the values modulo `bound` -- TODO confirm.
        wrapped_residuals = fk.wrap(residuals, bound=bound)
        over_bound = residuals > bound
        residuals[over_bound] = bound - wrapped_residuals[over_bound]

        # NOTE(review): the signed residuals are reported unwrapped (the
        # assignment that applied this wrapping was disabled in the original
        # analysis). The helper call is retained unchanged in case it has side
        # effects on its input -- confirm against fk.wrap_negative and remove.
        wrapped_signed_residuals = fk.wrap_negative(signed_residuals, bound=bound)

    # Run Kruskal Wallis test
    k_stat, kw_pval = stats.kruskal(shifts_ds['pref_2'], shifts_ds['pref_1'])
    print(f"Kruskal Wallis: {k_stat}, p: {kw_pval}")

    # Also do a t-test
    t_stat, t_pval = stats.ttest_ind(shifts_ds['pref_2'], shifts_ds['pref_1'])
    print(f"t-stat: {t_stat}, p: {t_pval}")

    # Calculate the mean and std of the residuals
    mean_residual = np.mean(residuals)
    std_residual = np.std(residuals)
    print(f"Residual mean: {mean_residual}, Residual std: {std_residual}")

    # Calculate mean difference
    mean_signed_residuals = np.mean(signed_residuals)
    print(f"Mean Signed Residual: {mean_signed_residuals}")

    # Calculate the RMSE of the residuals
    rmse_residual = np.sqrt(np.mean(residuals ** 2))
    print(f"RMSE: {rmse_residual}")

    # Calculate the Pearson R and pvalue of the shifts
    pearson_r, pearson_pval = stats.pearsonr(shifts_ds['pref_2'], shifts_ds['pref_1'])
    print(f"Pearson's R: {pearson_r}, p: {pearson_pval}")

    return residuals, rmse_residual, mean_residual, std_residual, pearson_r, pearson_pval


def circ_corr_pval(alpha1, alpha2, rho=None):
    """Circular-circular correlation coefficient and its two-sided p-value.

    Parameters
    ----------
    alpha1, alpha2 : array-like
        Paired angle samples in radians (``np.sin`` is applied to them directly).
    rho : float, optional
        Pre-computed correlation coefficient. When ``None`` it is computed here
        from the sine deviations about each sample's circular mean.

    Returns
    -------
    rho : float
        The correlation coefficient (the supplied value if one was given).
    pval : float
        Two-sided p-value from a standard-normal approximation of the test
        statistic.
    """
    n = len(alpha1)

    # Sine deviations from the circular means; computed once and reused.
    sin_dev1 = np.sin(alpha1 - circ.mean(alpha1))
    sin_dev2 = np.sin(alpha2 - circ.mean(alpha2))
    sq_dev1 = sin_dev1 ** 2
    sq_dev2 = sin_dev2 ** 2

    # Compute correlation coefficient
    if rho is None:
        rho = np.sum(sin_dev1 * sin_dev2) / np.sqrt(np.sum(sq_dev1) * np.sum(sq_dev2))

    # Compute p-value
    l20 = np.mean(sq_dev1)
    l02 = np.mean(sq_dev2)
    l22 = np.mean(sq_dev1 * sq_dev2)

    ts = np.sqrt((n * l20 * l02) / l22) * rho
    # Survival function instead of 1 - cdf: avoids the p-value rounding to
    # exactly 0 for large |ts| due to cancellation.
    pval = 2 * stats.norm.sf(abs(ts))
    return rho, pval


def calculate_circular(ds_label, shifts_ds):
    """Print circular statistics comparing ``pref_1`` and ``pref_2``.

    Prints the circular mean and standard deviation of the paired difference,
    the circular correlation with its p-value, and the results of
    ``circ.cmtest`` and ``circ.watson_williams``.

    Parameters
    ----------
    ds_label : str
        Dataset label; must contain ``'pd'`` (wrapping bound 180) or ``'po'``
        (wrapping bound 90).
    shifts_ds : pandas.DataFrame
        Table with paired ``pref_1`` and ``pref_2`` columns in degrees (they
        are converted with ``np.deg2rad``).

    Raises
    ------
    ValueError
        If ``ds_label`` contains neither ``'pd'`` nor ``'po'``. Previously this
        case only failed later, with an unbound ``bound`` variable.
    """
    if 'pd' in ds_label:
        bound = 180
    elif 'po' in ds_label:
        bound = 90
    else:
        raise ValueError(
            f"calculate_circular only supports 'pd' or 'po' datasets, got {ds_label!r}"
        )

    # NOTE(review): angles are converted to radians as-is. If the 'po' values
    # only span half a circle (axial data), circular statistics are usually
    # computed on doubled angles -- TODO confirm this is intended.
    pref1 = np.deg2rad(shifts_ds['pref_1'].values)
    pref2 = np.deg2rad(shifts_ds['pref_2'].values)

    # Circular mean / std of the paired difference, reported in degrees.
    circ_diff = circ.cdiff(pref2, pref1)
    circ_mean_diff = fk.wrap_negative(np.rad2deg(circ.mean(circ_diff)), bound=bound)
    circ_std_diff = fk.wrap_negative(np.rad2deg(circ.std(circ_diff)), bound=bound)
    print(f"Circular Mean Difference: {circ_mean_diff}, Circular Std Difference: {circ_std_diff}")

    # Circular-circular correlation; the coefficient comes from pycircstat and
    # only the p-value is taken from circ_corr_pval.
    circ_corr = circ.corrcc(pref1, pref2)
    _, p = circ_corr_pval(pref1, pref2, rho=circ_corr)
    print(f"Circular Correlation: {circ_corr}, p: {p}")

    cm_test = circ.cmtest(pref1, pref2)
    print(f"CM Test p: {cm_test[0]}, cm test statistic: {cm_test[1]}")

    watson_test = circ.watson_williams(pref1, pref2)
    print(f"Watson Williams Test p: {watson_test[0]}")

In [None]:
# --- Dataset selection ---
# NOTE: `type` shadows the builtin; the name is kept because other cells may
# reference it.
type = 'multi'  # 'repeat', 'multi'
lighting = 'normal'  # 'normal', 'dark'
rig = 'ALL'  # 'VWheelWF', 'VTuningWF', 'ALL'
match_type = 'curated'  # 'curated', 'all'

data_path = os.path.join(r"H:\thesis\WF_Figures\full", '_'.join((type, lighting, rig)))
print(data_path)

# --- Per-rig settings ---
# The two single-rig configurations share the same settings.
if 'ALL' in data_path:
    settings_dict = dict(session_shorthand=['free', 'fixed'],
                         sort='rig',
                         vis_resp_type='both')
elif 'VWheelWF' in data_path or 'VTuningWF' in data_path:
    settings_dict = dict(session_shorthand=['session1', 'session2'],
                         sort='slug',
                         vis_resp_type='both')
else:
    raise ValueError('data_path not recognized')

# --- Keys of the shift tables stored in stats.hdf5, one per tuning metric ---
ds_list = [
    f'{metric}_shifts_vis_resp_{match_type}_matches_{settings_dict["vis_resp_type"]}_vis_resp'
    for metric in ('osi', 'dsi', 'po', 'pd')
]

# --- Run the statistics for every table ---
for ds_label in ds_list:
    print(ds_label)

    with pd.HDFStore(os.path.join(data_path, 'stats.hdf5'), 'r') as f:
        shifts = f[ds_label][:]

    (residuals, rmse_residual, mean_residual,
     std_residual, pearson_r, pval) = calculate(ds_label, shifts)

    # Circular statistics are only meaningful for the angular measures (PO and PD)
    if any(tag in ds_label for tag in ('po_', 'pd_')):
        calculate_circular(ds_label, shifts)

        # Disabled: the same analysis on a per-mouse basis
        # all_mice = shifts.groupby('mouse').count().index
        # mice_with_enough_cells = all_mice[shifts.groupby('mouse').day.count() > 3]

        # for mouse in mice_with_enough_cells:
        #     print(mouse)
        #     mouse_shifts = shifts[shifts['mouse'] == mouse]
        #     print(f"Number of matches: {len(mouse_shifts)}")

        #     mouse_residuals, mouse_rmse_residual, mouse_mean_residual, mouse_std_residual, mouse_pearson_r, mouse_pval = calculate(ds_label, mouse_shifts)
        #     calculate_circular(ds_label, mouse_shifts)
        #     print('\n')

    print('\n')