In [1]:
#Import necessary packages
from scipy import signal
import numpy as np
import scipy as sp
from scipy import io
from thoi.measures.gaussian_copula import multi_order_measures
import os
import pickle
import pandas as pd

  from tqdm.autonotebook import tqdm


In [2]:
# Path to the pickled analysis results.
# NOTE(review): this is an absolute, machine-specific path; it only resolves on the
# original author's machine. Consider a configurable data directory.
file_path_pkl = '/Users/krisha/Desktop/BCM/Analysis/HOI_implementation/HOI_LLK/Data/PKL files/processed_subject_dict_MOM.pkl' 
# Load the analysis results using 'pickle.load'.
# NOTE(review): unpickling can execute arbitrary code, so only load files from a trusted source.
with open(file_path_pkl, 'rb') as file:
    subject_dict_MOM = pickle.load(file)

## Compute Deltas

### Clean Data

In [3]:
# RS_024 24H is an acquisition error that leads to inf values. Remove that key from the
# dictionary and treat it as missing data.
# The membership guard keeps this cell re-runnable: a bare `del` raised KeyError the second
# time the cell was executed in the same kernel, because the key was already gone.
if '24H' in subject_dict_MOM['RS_024']:
    del subject_dict_MOM['RS_024']['24H']

In [4]:
# Restrict the correlation analysis to the O and S information and, for the frequency
# bands, to Alpha, Beta, Gamma and Theta. The subject -> timepoint -> band nesting is kept.
keys_to_keep = {'o', 's'}
filtered_subject_dict_MOM = {
    subject_id: {
        timepoint_label: {
            band_name: {
                measure_name: measure_values
                for measure_name, measure_values in band_data.items()
                if measure_name in keys_to_keep
            }
            for band_name, band_data in timepoint_bands.items()
            if band_name in {'Alpha', 'Beta', 'Gamma', 'Theta'}
        }
        for timepoint_label, timepoint_bands in subject_timepoints.items()
    }
    for subject_id, subject_timepoints in subject_dict_MOM.items()
}

In [5]:
def compute_deltas(mega_dict, delta_timepoints=('1H', '24H', 'Day7'), variables=('o', 's')):
    """
    Compute within-subject deltas (timepoint minus baseline) for every frequency band.

    Parameters
    ----------
    mega_dict : dict
        Maps subject ID -> {timepoint -> {frequency band -> {variable -> values}}}.
        The baseline timepoint must be stored under the key 'Baseline'.
    delta_timepoints : iterable of str, optional
        Timepoints compared against baseline. Defaults to ('1H', '24H', 'Day7').
    variables : iterable of str, optional
        Variables for which deltas are computed. Defaults to ('o', 's').

    Returns
    -------
    dict
        Maps '<subject_id>_deltas' -> {timepoint -> {'<band>_deltas' -> {'<var>_deltas' -> delta}}}.
        Each delta is whatever `timepoint_values[var] - baseline_values[var]` evaluates to.
        If the timepoint, or the band at that timepoint, is missing, every delta for that
        band is `None`.

    Notes
    -----
    - Subjects without a 'Baseline' entry are skipped entirely.
    - Only bands present at baseline are iterated; a band that exists only at a later
      timepoint is ignored.
    - Only the requested `variables` are returned; no other keys/columns are carried over.
    """
    delta_dict = {}

    for subject_id, timepoints in mega_dict.items():
        baseline_data = timepoints.get('Baseline')

        # Without a baseline there is nothing to subtract from, so skip the subject.
        if baseline_data is None:
            continue

        subject_deltas = {}
        for tp in delta_timepoints:
            # None (or an empty mapping) means the whole timepoint is missing for this subject.
            tp_data = timepoints.get(tp)

            timepoint_deltas = {}
            for freq_band, baseline_values in baseline_data.items():
                tp_values = tp_data.get(freq_band) if tp_data else None

                if tp_values is None:
                    # Missing data: keep the keys so downstream code sees a uniform structure.
                    band_deltas = {f'{var}_deltas': None for var in variables}
                else:
                    band_deltas = {
                        f'{var}_deltas': tp_values[var] - baseline_values[var]
                        for var in variables
                    }

                timepoint_deltas[f'{freq_band}_deltas'] = band_deltas

            subject_deltas[f'{tp}'] = timepoint_deltas

        delta_dict[f'{subject_id}_deltas'] = subject_deltas

    return delta_dict


In [6]:
# Compute the baseline-relative deltas for every subject in the filtered dictionary.
subject_MOM_deltas=compute_deltas(filtered_subject_dict_MOM)

In [7]:
# Check the structure of the new dictionary: list the band keys for one subject/timepoint.
# NOTE(review): the original remark "At this point, RS_024 24H is already buggy" presumably
# refers to the acquisition error removed in the cleaning cell — confirm.
print(subject_MOM_deltas['RS_004_deltas']['24H'].keys())

dict_keys(['Alpha_deltas', 'Beta_deltas', 'Gamma_deltas', 'Theta_deltas'])


In [8]:
# Check the delta row count: size of the first dimension of one subject's 24H Alpha 'o' deltas.
print(subject_MOM_deltas['RS_004_deltas']['24H']['Alpha_deltas']['o_deltas'].shape[0])

32647


### Separate Data Into Ketamine and Midazolam Dictionaries

In [9]:
# Subject IDs of the midazolam group, written with the '_deltas' suffix that
# compute_deltas appends to every key of its result.
midazolam_subjects=['RS_004_deltas', 'RS_007_deltas', 'RS_010_deltas', 'RS_016_deltas', 'RS_017_deltas', 'RS_019_deltas', 'RS_024_deltas', 'RS_025_deltas', 'RS_028_deltas', 'RS_029_deltas', 'RS_036_deltas', 'RS_041_deltas', 'RS_042_deltas'] 

In [10]:
# Copy the midazolam subjects out of the mega dictionary into their own dictionary.
# Listed IDs that have no entry in the deltas are silently left out.
midazolam_deltas_dict = {
    subject_key: subject_MOM_deltas[subject_key]
    for subject_key in midazolam_subjects
    if subject_key in subject_MOM_deltas
}

In [11]:
# There are only two conditions, so every subject that is NOT in the midazolam list
# is treated as a ketamine subject and collected into a separate dictionary.
ketamine_deltas_dict = {
    subject_key: subject_deltas
    for subject_key, subject_deltas in subject_MOM_deltas.items()
    if subject_key not in midazolam_subjects
}

### Compute Correlations

In [40]:
# Load the clinical CADSS data from a CSV file resolved relative to the current working
# directory. The frame printed in the next cell shows a CADSS_1H_rawchange column.
CADSS_deltas= pd.read_csv('CADSS_1H_deltas.csv')

In [41]:
# Inspect the loaded CADSS table (plain-text dump of the whole frame).
print(CADSS_deltas)

        SubjectID  group  CADSS_Day0  CADSS_Day1_40mins  CADSS_1H_rawchange
0   RS_006_deltas      1           0                  0                   0
1   RS_009_deltas      1           0                  4                   4
2   RS_011_deltas      1           5                 35                  30
3   RS_012_deltas      1           5                 57                  52
4   RS_018_deltas      1           1                  2                   1
5   RS_020_deltas      1           0                  3                   3
6   RS_021_deltas      1           0                 25                  25
7   RS_030_deltas      1           0                  0                   0
8   RS_032_deltas      1           0                  0                   0
9   RS_033_deltas      1           0                 11                  11
10  RS_034_deltas      1           0                  4                   4
11  RS_035_deltas      1           0                  7                   7
12  RS_037_d

In [42]:
# Set the index of the CADSS data to the subject ID so relevant rows can be looked up
# per subject later on.
# The guard keeps this cell re-runnable: once "SubjectID" has become the index it is no
# longer a column, and calling set_index a second time would raise KeyError.
if "SubjectID" in CADSS_deltas.columns:
    CADSS_deltas = CADSS_deltas.set_index("SubjectID")

In [43]:
# Confirm which columns remain after SubjectID was moved into the index.
print(CADSS_deltas.columns)

Index(['group', 'CADSS_Day0', 'CADSS_Day1_40mins', 'CADSS_1H_rawchange'], dtype='object')


In [44]:
def extract_subject_cadss(subject_data, cadss_data):
    '''
    Return the rows of `cadss_data` that belong to the same group as `subject_data`.

    The group is decided by looking up the FIRST key of `subject_data` in `cadss_data`
    (which must therefore be indexed by subject ID) and reading its 'group' value:
    if that value is 0 the group-0 rows are returned, otherwise the group-1 rows.

    Note that all rows of the selected group are returned, not only the subjects
    present in `subject_data`.

    NOTE(review): in this notebook's CADSS file the group-1 rows are subjects that are not
    in the midazolam list, so 1 presumably codes ketamine and 0 midazolam — confirm.
    '''
    # The first subject stands in for the whole dictionary.
    first_subject_id = list(subject_data.keys())[0]
    first_group = cadss_data.loc[first_subject_id]['group']

    # Anything other than group 0 is mapped to group 1 (only two groups are expected).
    target_group = 0 if first_group == 0 else 1
    return cadss_data[cadss_data['group'] == target_group]


# Extracts the CADSS raw-change column; the timepoint argument does not affect the result.
def get_target_cadss(cadss_data, timepoint):
    '''
    Return the 'CADSS_1H_rawchange' column of `cadss_data`.

    `timepoint` is accepted for signature compatibility but is not used: the same
    column is returned for every timepoint.
    '''
    target_column = 'CADSS_1H_rawchange'
    return cadss_data[target_column]

def get_measure_row_data(subject_data, timepoint, band, measure, row):
    '''
    Returns one row of data across all subjects for a specific timepoint, band, and measure.
    `row` specifies the row number for the given band and measure.

    Subjects whose measure is None (missing data) are skipped. The returned Series is
    indexed by subject ID so that each value stays attached to the subject it came from;
    this is what allows the correlation step to pair it with the right CADSS score even
    when some subjects were skipped.
    '''
    subject_ids = []
    row_values = []
    for subject_id, subject_entry in subject_data.items():
        measure_values = subject_entry[timepoint][band][measure]
        # Missing data for this subject/timepoint: leave the subject out entirely.
        if measure_values is None:
            continue
        subject_ids.append(subject_id)
        row_values.append(measure_values[row])
    return pd.Series(row_values, index=subject_ids)

def correlate_x_y(measure_row_data, cadss_data):
    '''
    Calculates the Pearson correlation between a row of measure data and CADSS data.

    Pairing rules:
    - If `measure_row_data` carries labels (i.e. not a default RangeIndex) that also occur
      in a uniquely-indexed `cadss_data`, the two are paired by label, so subjects missing
      from either side are dropped instead of shifting the pairing.
    - Otherwise both are paired purely by position, as before.

    `cadss_data` is never modified; the previous in-place index reset destroyed the
    caller's subject labels after the first call.
    '''
    pair_by_label = (
        not isinstance(measure_row_data.index, pd.RangeIndex)
        and cadss_data.index.is_unique
        and measure_row_data.index.isin(cadss_data.index).any()
    )
    if not pair_by_label:
        # Positional fallback: give both sides a fresh 0..n-1 index (copies, not in place).
        cadss_data = cadss_data.reset_index(drop=True)
        measure_row_data = measure_row_data.reset_index(drop=True)
    # Series.corr aligns the two Series on their index before computing the coefficient.
    return cadss_data.corr(measure_row_data, method="pearson")


def create_band_measure_corr(subject_data, cadss_data, timepoint, band, measure, measure_size):
    '''
    Creates a series of correlations for each row of data within a specific timepoint, band, and measure.
    Returns a series where each element is the correlation for a row of data
    (rows 0 .. measure_size - 1).

    When `cadss_data` is indexed by the same subject IDs as `subject_data`, every subject's
    measure value is correlated against that same subject's CADSS score, including when
    some subjects have no data for this timepoint.
    '''
    correlations = [
        correlate_x_y(
            get_measure_row_data(subject_data, timepoint, band, measure, row),
            cadss_data,
        )
        for row in range(measure_size)
    ]
    return pd.Series(correlations)  # one correlation per measure row

In [45]:
def compute_subject_correlations_df(subject_data, cadss_data_raw, timepoint, measure_size=32647):
    '''
    Compute, for one subject timepoint, the correlation between every measure row and the
    CADSS target, and return the results as a DataFrame.

    Parameters
    ----------
    subject_data : dict
        Delta dictionary of one group (subject -> timepoint -> band -> measure).
    cadss_data_raw : DataFrame
        Raw CADSS table; passed to `extract_subject_cadss` to select the matching group.
    timepoint : str
        Timepoint key of `subject_data` whose deltas are correlated (e.g. '1H', '24H', 'Day7').
    measure_size : int, optional
        Number of measure rows to correlate. Defaults to 32647, the value previously
        hard-coded here; pass a different value if the measure size changes.

    Returns
    -------
    DataFrame
        `measure_size` rows and one column per band/measure combination, named
        '<band>_<measure>_1HCADSS_corr'.
    '''
    # Select the clinical rows of this group, then the CADSS target column.
    cadss_data_km = extract_subject_cadss(subject_data, cadss_data_raw)
    cadss_data = get_target_cadss(cadss_data_km, timepoint)

    bands = ['Alpha_deltas', 'Beta_deltas', 'Gamma_deltas', 'Theta_deltas']
    measures = ['o_deltas', 's_deltas']
    # Label of the CADSS target used in the column names.
    cadss_label = "1HCADSS"

    correlations_df = pd.DataFrame(index=range(measure_size))

    for band in bands:
        # Strip the suffix so the column names stay short.
        band_name = band.replace("_deltas", "")

        for measure in measures:
            measure_name = measure.replace("_deltas", "")
            col_name = f"{band_name}_{measure_name}_{cadss_label}_corr"
            correlations_df[col_name] = create_band_measure_corr(
                subject_data, cadss_data, timepoint, band, measure, measure_size
            )

    return correlations_df

In [46]:
# Run the correlations function for every timepoint of the ketamine dictionary
(ket_correlations_CADSS1_1H,
 ket_correlations_CADSS1_24H,
 ket_correlations_CADSS1_Day7) = (
    compute_subject_correlations_df(ketamine_deltas_dict, CADSS_deltas, tp)
    for tp in ('1H', '24H', 'Day7')
)

# Run the correlations function for every timepoint of the midazolam dictionary
(mid_correlations_CADSS1_1H,
 mid_correlations_CADSS1_24H,
 mid_correlations_CADSS1_Day7) = (
    compute_subject_correlations_df(midazolam_deltas_dict, CADSS_deltas, tp)
    for tp in ('1H', '24H', 'Day7')
)


In [47]:
# Square every correlation coefficient to obtain R squared for all the dataframes

ket_rsquared_CADSS1_1H = ket_correlations_CADSS1_1H ** 2
ket_rsquared_CADSS1_24H = ket_correlations_CADSS1_24H ** 2
ket_rsquared_CADSS1_Day7 = ket_correlations_CADSS1_Day7 ** 2
mid_rsquared_CADSS1_1H = mid_correlations_CADSS1_1H ** 2
mid_rsquared_CADSS1_24H = mid_correlations_CADSS1_24H ** 2
mid_rsquared_CADSS1_Day7 = mid_correlations_CADSS1_Day7 ** 2

In [48]:
# Summarise each column of a DataFrame by the position and size of its maximum
def get_max_info(df):
    '''
    Return a DataFrame with one row per column of `df` and two columns:
    'Max_Index' (index label at which the column reaches its maximum) and
    'Max_Value' (that maximum).
    '''
    summary_columns = {
        'Max_Index': df.idxmax(),
        'Max_Value': df.max(),
    }
    return pd.DataFrame(summary_columns)
# Apply the function to each R-squared DataFrame; each summary is kept in its own variable

(ketamine_1H_CADSS1_maxind,
 ketamine_24H_CADSS1_maxind,
 ketamine_Day7_CADSS1_maxind,
 midazolam_1H_CADSS1_maxind,
 midazolam_24H_CADSS1_maxind,
 midazolam_Day7_CADSS1_maxind) = map(
    get_max_info,
    (
        ket_rsquared_CADSS1_1H,
        ket_rsquared_CADSS1_24H,
        ket_rsquared_CADSS1_Day7,
        mid_rsquared_CADSS1_1H,
        mid_rsquared_CADSS1_24H,
        mid_rsquared_CADSS1_Day7,
    ),
)

In [52]:
# Sanity check: confirm that get_max_info returned a DataFrame.
# NOTE(review): leftover interactive check; can be removed from the final notebook.
type(ketamine_1H_CADSS1_maxind)

pandas.core.frame.DataFrame

In [54]:
#Export to CSV (files are written to the current working directory):

#R squared
rsquared_exports = [
    (ket_rsquared_CADSS1_1H, "Ketamine 1 Hour CADSS 1H R Squared.csv"),
    (ket_rsquared_CADSS1_24H, "Ketamine 24 Hour CADSS 1H R Squared.csv"),
    (ket_rsquared_CADSS1_Day7, "Ketamine Day 7 CADSS 1H R Squared.csv"),
    (mid_rsquared_CADSS1_1H, "Midazolam 1 Hour CADSS 1H R Squared.csv"),
    (mid_rsquared_CADSS1_24H, "Midazolam 24 Hour CADSS 1H R Squared.csv"),
    (mid_rsquared_CADSS1_Day7, "Midazolam Day 7 CADSS 1H R Squared.csv"),
]

#Maximum Indices
maxind_exports = [
    (ketamine_1H_CADSS1_maxind, "Ketamine 1 Hour CADSS 1H Highest Correlations.csv"),
    (ketamine_24H_CADSS1_maxind, "Ketamine 24 Hour CADSS 1H Highest Correlations.csv"),
    (ketamine_Day7_CADSS1_maxind, "Ketamine Day 7 CADSS 1H Highest Correlations.csv"),
    (midazolam_1H_CADSS1_maxind, "Midazolam 1 Hour CADSS 1H Highest Correlations.csv"),
    (midazolam_24H_CADSS1_maxind, "Midazolam 24 Hour CADSS 1H Highest Correlations.csv"),
    (midazolam_Day7_CADSS1_maxind, "Midazolam Day 7 CADSS 1H Highest Correlations.csv"),
]

# Write the files in the same order as before: R squared first, then the maxima.
for export_frame, export_filename in rsquared_exports + maxind_exports:
    export_frame.to_csv(export_filename)

In [None]:
'''
Dictionary version (Unrefined- look at it again)

def compute_subject_correlations(subject_data, cadss_data_raw, timepoint):
    #Extract relevant clinical data for the subject dictionary in the parameter (ketamine or midazolam) from the raw CADSS data
    cadss_data_km = extract_subject_cadss(subject_data, cadss_data_raw)
    #Extract target CADSS columns, depending on the subject timepoint 
    cadss_data = get_target_cadss_tp(cadss_data_km, timepoint)

    #Initialize a correlations dictionary
    correlations={}
    #Create a list of the bands and measures to loop through in your dictionary
    bands = ['Alpha_deltas', 'Beta_deltas', 'Gamma_deltas', 'Theta_deltas']
    measures = ['o_deltas', 's_deltas']
    timepoint_names = ["24HCADSS", "7DAYCADSS"]
    #Size of the HOI measure rows (change if the size changes 
    MEASURE_SIZE = 32647 
    #
    for band in bands:
        for measure in measures:
            if timepoint =='1H':
                for cadss_tp_index in range(len(cadss_data)):                        
                   result = create_band_measure_corr(subject_data, cadss_data[cadss_tp_index], timepoint, band, measure, MEASURE_SIZE) 
                   correlations.setdefault(band, {}).setdefault(measure, {})[f"{band.replace("_deltas", "")}_{measure.replace("_deltas", "")}_{timepoint_names[cadss_tp_index]}_corr"] = result
            elif timepoint == '24H':
                result = create_band_measure_corr(subject_data, cadss_data, timepoint, band, measure, MEASURE_SIZE)
                correlations.setdefault(band, {}).setdefault(measure, {})[f"{band.replace("_deltas", "")}_{measure.replace("_deltas", "")}_{timepoint_names[0]}_corr"] = result
            else:
                result = create_band_measure_corr(subject_data, cadss_data, timepoint, band, measure, MEASURE_SIZE)
                correlations.setdefault(band, {}).setdefault(measure, {})[f"{band.replace("_deltas", "")}_{measure.replace("_deltas", "")}_{timepoint_names[1]}_corr"] = result
    
    return correlations
'''