### Configuration

In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from joblib import Parallel, delayed

from scipy.stats import circmean, circvar
from astropy.stats import rayleightest
from statsmodels.stats.multitest import multipletests

import utils__helpers_revisions
from utils__helpers_revisions import rayleightest_zstat
import utils__config

In [2]:
# Switch to the project working directory configured in utils__config so that the
# relative 'Cache/...' paths used by later cells resolve against it.
os.chdir(utils__config.working_directory)
# Last expression of the cell: echo the directory now in effect.
os.getcwd()

'G:\\My Drive\\Residency\\Research\\Lab - Damisah\\Project - Sleep\\Revisions'

### Parameters

This permutation-based Rayleigh test analysis, which assesses the significance of phase distributions, has several steps.

But first note that, in the previous script, we generated a dataset by merging spike times with the nearest sampled LFP time point in order to get the closest slow-wave-band phase angle. Thus, for each unit-channel pair, we have a list of phase angles with timestamps (the timestamps are used to merge the data with other information such as sleep stage). This was done n x m times where n is the number of unit spike trains and m is the number of macroelectrode channels. 

- Step 1: Perform Rayleigh's Test on "original" phase angles (independently for every unit-channel pair) to produce a z-statistic (and p-value, although we won't use it).
- Step 2: Jitter each phase angle in the "original" phase angles by a random amount spanning up to half a wavelength (pi) in total, i.e. up to pi/2 in either direction. A transform is applied to re-circularize the phase angles if angles near the edge (-pi or pi) are jittered out of the allowable circular range. Repeat this step to generate 1000 sets of surrogate phase angle data.
- Step 3: Perform Rayleigh's Test independently on all 1000 surrogate datasets. 
- Step 4: Assign a permutation p-value to the original phase-angle dataset. This permutation p-value is equal to the proportion of surrogate z-statistics that are equal to or larger than the original z-statistic.
- Step 5: FDR correct the permutation p-values based on how many unit-channel pairs were analyzed.

Other notes on the analysis:
- The analysis is conducted independently across several strata (e.g. only taking spike phases during NREM, during W/REM, or during high-delta-power NREM).
- The mean phase angle of the original phase angles is also calculated and preserved for later analysis.
- The Rayleigh's Test outputs a p-value and a z-statistic. We can use either one for the permutation test, but the z-statistic is preferred due to being non-parametric. Note that z-statistics are interpreted in the opposite direction from p-values: whereas a smaller p-value is more significant, a larger z-statistic is more significant.

In [3]:
# Session selection: exactly one input/output pair should be active at a time.
# input_path  - CSV read by the import cell (spike-phase data for one recording session).
# output_path - CSV written at the end of the notebook with the permutation-test results.
# Both are relative paths, resolved against the working directory set above.
input_path = 'Cache/S01_Feb02_sf_coupling.csv'
output_path = 'Cache/S01_Feb02_permuted_rayleigh.csv'

# Alternative sessions: uncomment one pair (and comment out the active pair) to switch.
# input_path = 'Cache/S05_Jul11_sf_coupling.csv'
# output_path = 'Cache/S05_Jul11_permuted_rayleigh.csv'

# input_path = 'Cache/S05_Jul12_sf_coupling.csv'
# output_path = 'Cache/S05_Jul12_permuted_rayleigh.csv'

# input_path = 'Cache/S05_Jul13_sf_coupling.csv'
# output_path = 'Cache/S05_Jul13_permuted_rayleigh.csv'

In [4]:
# Number of jittered surrogate datasets generated per unit-channel pair; used as the
# default for the `permutations` argument of permutation_test.
n_permutations = 1000
# Number of worker processes intended for the joblib-parallel run.
n_cores = 6

### Import and Format

In [5]:
# Import data
# Later cells read the columns 'phase', 'stage', 'DREM', 'channel' and 'unit_id' from this frame.
data = pd.read_csv(input_path)

### Define functions for Permutation Rayleigh's Tests

In [6]:
def permutation_test(data, permutations=n_permutations, rng=None):
    """Permutation-based Rayleigh test for one group of spike phases.

    The observed Rayleigh z-statistic is compared against z-statistics from
    surrogate datasets in which every phase is independently jittered by a
    uniform random offset in [-pi/2, pi/2) and wrapped back into [-pi, pi).

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a 'phase' column of phase angles (radians; the wrapping
        below assumes the [-pi, pi] convention).
    permutations : int, optional
        Number of surrogate datasets to generate. Must be at least 1.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Random source used for the jitter. When omitted, NumPy's global
        random state is used, exactly as before; pass a seeded generator to
        make the result reproducible.

    Returns
    -------
    tuple
        ``(observed_p_value, permutation_p_value)``: the parametric Rayleigh
        p-value of the unjittered phases, and the proportion of surrogate
        z-statistics that are greater than or equal to the observed one.

    Raises
    ------
    ValueError
        If ``permutations`` is smaller than 1 (the proportion is undefined).
    """
    if permutations < 1:
        raise ValueError("permutations must be a positive integer")

    # Loop-invariant lookups are done once instead of on every iteration.
    phases = data['phase']
    sampler = np.random if rng is None else rng
    half_jitter = np.pi / 2

    observed_p_value = rayleightest(phases)
    observed_zstat = rayleightest_zstat(phases)

    permuted_zstats = []

    for _ in range(permutations):

        # Jitter phases by adding a random value between -pi/2 and pi/2
        jitter = sampler.uniform(-half_jitter, half_jitter, size=phases.shape)

        # Ensure that the jittered phases wrap correctly around the circle
        jittered_phases = np.mod(phases + jitter + np.pi, 2 * np.pi) - np.pi

        # Conduct Rayleigh's test on the permuted data
        permuted_zstats.append(rayleightest_zstat(jittered_phases))

    # Calculate the p-value for the permutation test
    # (note that the sign must be changed from ">=" to "<=" if using p-values instead of z-statistics)
    # NOTE(review): this estimate is exactly 0 whenever no surrogate reaches the observed
    # statistic; a (count + 1) / (permutations + 1) estimator would avoid zero p-values.
    # Left unchanged here so results stay consistent with the documented Step 4.
    permutation_p_value = np.sum(np.array(permuted_zstats) >= observed_zstat) / permutations

    return observed_p_value, permutation_p_value

In [7]:
def analyze_group(subset_name, channel, unit_id, group_data):
    """Run permutation_test on one unit-channel group and package the outcome as a result row.

    Returns a dict with the identifying keys ('subset', 'channel', 'unit_id')
    followed by the two p-values returned by permutation_test
    ('p_original', 'p_permuted').
    """
    p_values = permutation_test(group_data)

    row = dict(subset=subset_name, channel=channel, unit_id=unit_id)
    row['p_original'] = p_values[0]
    row['p_permuted'] = p_values[1]
    return row

### Run permutation tests

In [8]:
# Define the subsets (strata) that are analysed independently.
# 'NREM' and 'WREM' are selected via the 'stage' column; 'DREM' via the 0/1 'DREM' flag column
# (presumably the high-delta-power NREM stratum described in the notes above; confirm).
subsets = {
    'NREM': data[data['stage'] == 'NREM'],
    'WREM': data[data['stage'] == 'WREM'],
    'DREM': data[data['DREM'] == 1]
}

# One result dict per (subset, channel, unit_id), accumulated across all subsets.
results = []

for subset_name, subset_data in subsets.items():
    # Prepare tasks: one per unit-channel pair present in this subset
    tasks = [(subset_name, channel, unit_id, group) for (channel, unit_id), group in subset_data.groupby(['channel', 'unit_id'])]

    # Run tasks in parallel, honouring the worker count chosen in the parameters cell
    # (previously hard-coded to n_jobs=-1, which ignored n_cores and used every core).
    parallel_results = Parallel(n_jobs=n_cores)(delayed(analyze_group)(*task) for task in tqdm(tasks, desc=f'Analyzing {subset_name}'))

    results.extend(parallel_results)

# Convert results to DataFrame
results_df = pd.DataFrame(results)

Analyzing NREM: 100%|██████████| 2516/2516 [01:46<00:00, 23.55it/s]
Analyzing WREM: 100%|██████████| 2516/2516 [01:45<00:00, 23.95it/s]
Analyzing DREM: 100%|██████████| 2516/2516 [01:16<00:00, 32.89it/s]


### Add mean phase angle and mean resultant length

In [9]:
# Mean resultant length (MRL) of a circular sample
def mrl(samples, high=np.pi, low=-np.pi, axis=None):
    """Return 1 minus scipy's circular variance of `samples`.

    `high` and `low` give the bounds of the circular range and `axis` the axis
    to reduce over; all three are forwarded unchanged to `circvar`.
    """
    return 1 - circvar(samples, high=high, low=low, axis=axis)

# Per-group circular summary: mean direction plus mean resultant length
def calculate_stats(group, high_val = np.pi, low_val = -np.pi):
    """Summarise one group of phase angles.

    Returns a pandas Series indexed by 'mean_phase' (circular mean of `group`)
    and 'mean_rl' (mean resultant length from `mrl`), both computed on the
    circular range [low_val, high_val].
    """
    bounds = dict(high=high_val, low=low_val)

    summary = {
        'mean_phase': circmean(group, **bounds),
        'mean_rl': mrl(group, **bounds),
    }
    return pd.Series(summary)

# Build one summary row per (subset, channel, unit_id): circular mean, mean resultant
# length and spike count. Rows are collected as records and converted to a DataFrame once,
# so each statistic and each spike count stays attached to the group key it was computed
# from. (Previously n_spikes was attached by positional index from a separately computed
# frame, which shifts the counts whenever a group is missing from the pivoted statistics.)
summary_columns = ['subset', 'channel', 'unit_id', 'mean_phase', 'mean_rl', 'n_spikes']
summary_records = []

# Use subsets from prior cell
for subset_name, subset_data in subsets.items():

    # Iterate over the 'phase' values of every (channel, unit_id) group in this subset
    for (channel, unit_id), phases in subset_data.groupby(['channel', 'unit_id'])['phase']:

        # Circular mean and mean resultant length for this group
        stats = calculate_stats(phases)

        summary_records.append({
            'subset': subset_name,
            'channel': channel,
            'unit_id': unit_id,
            'mean_phase': stats['mean_phase'],
            'mean_rl': stats['mean_rl'],
            # Number of spikes for this subset-channel-unit_id pairing (kept as an integer count)
            'n_spikes': len(phases),
        })

mean_angles = pd.DataFrame(summary_records, columns=summary_columns)

# Keep the same row set as before: groups for which neither statistic could be computed
# (both NaN) were omitted by the earlier pivot_table step, so they are dropped here too.
mean_angles = mean_angles.dropna(subset=['mean_phase', 'mean_rl'], how='all').reset_index(drop=True)
mean_angles

Unnamed: 0,subset,channel,unit_id,mean_phase,mean_rl,n_spikes
0,NREM,LA1,S01_Ch195_neg_Unit3,1.499263,0.024515,12582.0
1,NREM,LA1,S01_Ch195_pos_Unit2,1.695347,0.031067,13288.0
2,NREM,LA1,S01_Ch196_neg_Unit1,1.602211,0.044770,8231.0
3,NREM,LA1,S01_Ch196_neg_Unit3,1.249254,0.037708,6270.0
4,NREM,LA1,S01_Ch196_neg_Unit4,0.769845,0.023571,7347.0
...,...,...,...,...,...,...
7543,DREM,RPI6,S01_Ch245_neg_Unit2,0.097985,0.063904,229.0
7544,DREM,RPI6,S01_Ch245_neg_Unit4,0.215068,0.031564,981.0
7545,DREM,RPI6,S01_Ch245_pos_Unit2,1.533068,0.143973,284.0
7546,DREM,RPI6,S01_Ch245_pos_Unit4,0.965964,0.061968,1448.0


In [10]:
# Attach the circular summary statistics to the permutation-test results.
# The left side is rebuilt from `results` so that re-running this cell never merges onto an
# already-merged frame (which would create duplicated *_x / *_y columns).
# validate='one_to_one' asserts that each (subset, channel, unit_id) key occurs once per side.
results_df = pd.merge(
    pd.DataFrame(results),
    mean_angles,
    on=['subset', 'channel', 'unit_id'],
    how='inner',
    validate='one_to_one',
)

### FDR Correction

In [11]:
# Benjamini-Hochberg FDR adjustment of the permutation p-values, applied jointly to
# every row of results_df; element [1] of the returned tuple holds the adjusted p-values.
fdr_output = multipletests(results_df['p_permuted'], method='fdr_bh', alpha=0.05)
results_df['p_corrected'] = fdr_output[1]

In [12]:
# Persist the complete results table (without the index column)
results_df.to_csv(output_path, index=False)

# Cell output: only the rows whose FDR-corrected p-value is below 0.05
results_df.loc[results_df['p_corrected'] < 0.05]

Unnamed: 0,subset,channel,unit_id,p_original,p_permuted,mean_phase,mean_rl,n_spikes,p_corrected
1,NREM,LA1,S01_Ch195_pos_Unit2,2.693084e-06,0.011,1.695347,0.031067,13288.0,0.045124
2,NREM,LA1,S01_Ch196_neg_Unit1,6.840269e-08,0.004,1.602211,0.044770,8231.0,0.021443
6,NREM,LA1,S01_Ch197_pos_Unit1,1.898493e-08,0.005,0.682886,0.050604,6943.0,0.025535
8,NREM,LA1,S01_Ch200_neg_Unit2,1.263946e-14,0.000,1.259090,0.069251,6673.0,0.000000
13,NREM,LA1,S01_Ch203_neg_Unit4,4.451290e-27,0.000,2.217300,0.128337,3684.0,0.000000
...,...,...,...,...,...,...,...,...,...
7516,DREM,RPI6,S01_Ch197_neg_Unit1,2.069736e-08,0.005,-2.053519,0.160365,688.0,0.025535
7524,DREM,RPI6,S01_Ch203_neg_Unit4,2.324741e-19,0.000,0.617878,0.254007,665.0,0.000000
7527,DREM,RPI6,S01_Ch231_neg_Unit3,3.725749e-08,0.006,1.036860,0.203760,412.0,0.029313
7538,DREM,RPI6,S01_Ch244_neg_Unit2,5.418309e-08,0.005,0.941173,0.085513,2288.0,0.025535
