In [1]:
import os
import sys
import time
import pandas as pd
import pickle
import numpy as np
from fractions import Fraction
import scipy
from scipy.signal import resample_poly, iirnotch, filtfilt
import mne
from mne_bids import BIDSPath, write_raw_bids
import pyedflib
from tqdm import tqdm
from pqdm.processes import pqdm
import getpass
import argparse
import matplotlib.pyplot as plt
import gzip
import shutil
#sys.path.append("util")
#from utils import get_iEEG_data
import datetime
import ast
import warnings

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# definitions

def convert_hup_to_rid(rh_table,hup_id):
    # get number from hup_id
    hup_num = int(hup_id[3:6])
    # get rid from rid_hup_table
    rid = rh_table.loc[rh_table["hupsubjno"] == hup_num]["record_id"].values[0]
    return rid

def get_coherence(signals, labels, bands, interval_length = 1, fs = 200):
    # signals: numpy array of shape (n_channels, n_samples)
    # interval_length: length of interval in seconds
    # fs: sampling rate (Hz)
    # returns: median coherence of pkl

    # if signals is all nan, warn
    if np.isnan(signals).all():
        print("WARNING: signals contains only NaN values!")
    # replace NaN values with interpolated values
    signals = pd.DataFrame(signals).interpolate(axis=1).fillna(method="ffill",axis=1).fillna(method="bfill",axis=1).to_numpy()
    
    # get column names from pickle_data
    channel_names = labels

    indices_to_delete = []

    # subtract mean value from each channel
    signals = signals - np.mean(signals, axis=1, keepdims=True)

    #print("Calculating coherence...")
    
    # initialize list of len(channel_names) by len(channel_names) numpy array of coherences
    coherences_by_band = [np.array([[0.0 for k in range(len(channel_names))] for j in range(len(channel_names))]) for _ in range(len(bands))]

    # for each unique pair of channel names, calculate coherence
    for k in range(len(channel_names)):
        for j in range(k+1,len(channel_names)):
            # get coherence between channels k and j
            f, Cxy = scipy.signal.coherence(signals[k,int(2.5*fs):-int(2.5*fs)],signals[j,int(2.5*fs):-int(2.5*fs)],nperseg = 2*fs)
            for band in bands:
                band_start_hz = band[0]
                band_end_hz = band[1]
                # take the median coherence over the frequency band of interest
                # find the indices of the start and end of the band
                ind_start = np.argmax(f*fs >= band_start_hz)
                ind_end = np.argmax(f*fs >= band_end_hz)
                coherences_by_band[bands.index(band)][k,j] = np.median(Cxy[ind_start:ind_end])

    for k in range(len(coherences_by_band)):
        coherences = coherences_by_band[k]
        # symmetrize coherences
        coherences = coherences + coherences.T - np.diag(np.diag(coherences))
        # self-coherence is 1
        coherences = coherences + np.identity(len(channel_names))
        coherences_by_band[k] = pd.DataFrame(coherences, columns = channel_names, index = channel_names)

    return coherences_by_band

In [4]:
sleep_stage_dict = {
    "R": 1,
    "W": 2,
    "N1": 3,
    "N2": 4,
    "N3": 5
}

# skip N1
sleep_stages_to_run = ["W","N2","N3","R"]
# band cutoffs in Hz
bands = {
    "delta": [1, 4],
    "theta": [4, 8],
    "alpha": [8, 12],
    "beta": [12, 30],
    "gamma": [30, 80],
    "broad": [1, 80]
}

band_names = list(bands.keys())

parent_directory = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))),'data')
atlas_directory = os.path.join(parent_directory,"atlas")
data_final_directory = os.path.join(parent_directory,"data_final")

assert os.path.isdir(parent_directory), f"{parent_directory} does not exist."

# read metadata csv
metadata_csv = pd.read_csv("../../data/data_final/combined_atlas_metadata.csv")

# read signals npz
signals_npz = np.load("../../data/data_final/combined_atlas_signals.npz")

# read channel_harmonize.xlsx
ch_xls = pd.ExcelFile("../channel_harmonize.xlsx")
# read using first row as labels
ch_harmonize = pd.read_excel(ch_xls, "Sheet1", header=0)

amf_xls = pd.ExcelFile("../atlas_metadata_final.xlsx")
# read using first row as labels
amf = pd.read_excel(amf_xls, "Sheet1", header=0)

# load rid_hup_table.csv
rid_hup_table = pd.read_csv("/mnt/leif/littlab/users/ianzyong/sleep-atlas/util/rid_hup_table.csv")

# calculate the atlas
# first, get coherence matrices

# get unique values of metadata_csv["pt"]
all_pts = sorted(list(set(metadata_csv["pt"])))

# get set of unique values in "reg" column from metadata_csv
regions = set(metadata_csv["reg"])
# remove nans
regions = [x for x in regions if str(x) != 'nan']
# sort and convert to list
regions = sorted(list(regions))
print(f"Regions: {regions}")

atlases = [[0 for _ in range(len(band_names))] for _ in range(len(sleep_stages_to_run))]
# set every value to an empty list
for a in range(len(atlases)):
    for b in range(len(atlases[a])):
        atlases[a][b] = pd.DataFrame(np.zeros((len(regions),len(regions))),columns=regions,index=regions,dtype=object)
        for m in range(len(atlases[a][b])):
            for n in range(len(atlases[a][b])):
                atlases[a][b].iloc[m,n] = []

#atlases = np.array(atlases)
band_cutoffs = [bands[x] for x in band_names]

Regions: ['Amyg_Hipp_L', 'Amyg_Hipp_R', 'Cingulum_L', 'Cingulum_R', 'FMO_Rect_L', 'FMO_Rect_R', 'FSM_SMA_L', 'FSM_SMA_R', 'Frontal_Mid_All_L', 'Frontal_Mid_All_R', 'Frontal_Sup_All_L', 'Frontal_Sup_All_R', 'Frontal_inf_All_L', 'Frontal_inf_All_R', 'Fusiform_L', 'Fusiform_R', 'Insula_L', 'Insula_R', 'Occipital_Lat_L', 'Occipital_Lat_R', 'Occipital_Med_L', 'Occipital_Med_R', 'ParaHippocampal_L', 'ParaHippocampal_R', 'Parietal_Sup_Inf_L', 'Parietal_Sup_Inf_R', 'Postcentral_L', 'Postcentral_R', 'Precentral_L', 'Precentral_R', 'Precuneus_PCL_L', 'Precuneus_PCL_R', 'SupraMarginal_Angular_L', 'SupraMarginal_Angular_R', 'Temporal_Inf_L', 'Temporal_Inf_R', 'Temporal_Mid_L', 'Temporal_Mid_R', 'Temporal_Sup_L', 'Temporal_Sup_R', 'thalam_limbic_L', 'thalam_limbic_R']


In [4]:
# for each row in metadata_csv
coherence_results = []
for pt in all_pts:
    print(f"Calculating coherences for {pt}... (patient {all_pts.index(pt)+1} of {len(all_pts)}))")
    # get dataframe where metadata_csv["pt"] == pt and normative is True
    this_pt_metadata = metadata_csv[(metadata_csv["pt"] == pt)]
    # get index of rows where metadata_csv["pt"] == pt
    this_pt_indices = this_pt_metadata.index
    this_pt_normative_indices = this_pt_metadata[this_pt_metadata["normative"] == True].index
    # get channel names
    this_pt_channels = this_pt_metadata["name"]
    
    for stage in sleep_stages_to_run:
        stage_index = sleep_stages_to_run.index(stage)
        signals = signals_npz[stage][this_pt_indices,:]
        if stage == "W":
            print("Shape of signals: ", signals.shape)
        coherence_result = get_coherence(signals,this_pt_channels,band_cutoffs)
        coherence_results.append(coherence_result)
        
        print(f"Adding values to {stage} atlas...")

        # add coherences from normative channel pairs to the corresponding atlas
        for k in range(len(this_pt_normative_indices)):
            for j in range(k,len(this_pt_normative_indices)):
                ind_1 = this_pt_normative_indices[k]
                ind_2 = this_pt_normative_indices[j]
                # get the channel names from the metadata dictionary
                row_ch = metadata_csv["name"].iloc[ind_1]
                col_ch = metadata_csv["name"].iloc[ind_2]
                # get the region names from the metadata dictionary
                row_reg = metadata_csv["reg"].iloc[ind_1]
                col_reg = metadata_csv["reg"].iloc[ind_2]
                # add the value to the corresponding list in this_atlas
                for m in range(len(band_names)):
                    this_atlas = atlases[stage_index][m]
                    this_atlas.loc[row_reg,col_reg].append(coherence_result[m].loc[row_ch,col_ch])

# save this_atlas as a .csv file
for k in range(len(sleep_stages_to_run)):
    for m in range(len(band_names)):
        this_atlas = atlases[k][m]
        this_atlas.to_csv(os.path.join(data_final_directory,"atlas",f"{sleep_stages_to_run[k]}_{band_names[m]}_atlas.csv"))

print(f"Atlases saved.")

Calculating coherences for MNI001... (patient 1 of 211))
Shape of signals:  (14, 6000)
Adding values to W atlas...
Adding values to N2 atlas...
Adding values to N3 atlas...
Adding values to R atlas...
Calculating coherences for MNI002... (patient 2 of 211))
Shape of signals:  (15, 6000)
Adding values to W atlas...
Adding values to N2 atlas...
Adding values to N3 atlas...
Adding values to R atlas...
Calculating coherences for MNI003... (patient 3 of 211))
Shape of signals:  (11, 6000)
Adding values to W atlas...
Adding values to N2 atlas...
Adding values to N3 atlas...
Adding values to R atlas...
Calculating coherences for MNI004... (patient 4 of 211))
Shape of signals:  (7, 6000)
Adding values to W atlas...
Adding values to N2 atlas...
Adding values to N3 atlas...
Adding values to R atlas...
Calculating coherences for MNI005... (patient 5 of 211))
Shape of signals:  (17, 6000)
Adding values to W atlas...
Adding values to N2 atlas...
Adding values to N3 atlas...
Adding values to R atlas

  Cxy = np.abs(Pxy)**2 / Pxx / Pyy


Adding values to W atlas...


  Cxy = np.abs(Pxy)**2 / Pxx / Pyy


Adding values to N2 atlas...


  Cxy = np.abs(Pxy)**2 / Pxx / Pyy


Adding values to N3 atlas...


  Cxy = np.abs(Pxy)**2 / Pxx / Pyy


Adding values to R atlas...
Calculating coherences for sub-RID0060... (patient 124 of 211))
Shape of signals:  (69, 6000)
Adding values to W atlas...
Adding values to N2 atlas...
Adding values to N3 atlas...
Adding values to R atlas...
Calculating coherences for sub-RID0063... (patient 125 of 211))
Shape of signals:  (44, 6000)


  Cxy = np.abs(Pxy)**2 / Pxx / Pyy


Adding values to W atlas...


  Cxy = np.abs(Pxy)**2 / Pxx / Pyy


Adding values to N2 atlas...


  Cxy = np.abs(Pxy)**2 / Pxx / Pyy


Adding values to N3 atlas...


  Cxy = np.abs(Pxy)**2 / Pxx / Pyy


Adding values to R atlas...
Calculating coherences for sub-RID0064... (patient 126 of 211))
Shape of signals:  (92, 6000)
Adding values to W atlas...
Adding values to N2 atlas...
Adding values to N3 atlas...
Adding values to R atlas...
Calculating coherences for sub-RID0065... (patient 127 of 211))
Shape of signals:  (96, 6000)
Adding values to W atlas...
Adding values to N2 atlas...
Adding values to N3 atlas...
Adding values to R atlas...
Calculating coherences for sub-RID0068... (patient 128 of 211))
Shape of signals:  (78, 6000)
Adding values to W atlas...
Adding values to N2 atlas...
Adding values to N3 atlas...
Adding values to R atlas...
Calculating coherences for sub-RID0070... (patient 129 of 211))
Shape of signals:  (78, 6000)
Adding values to W atlas...
Adding values to N2 atlas...
Adding values to N3 atlas...
Adding values to R atlas...
Calculating coherences for sub-RID0089... (patient 130 of 211))
Shape of signals:  (109, 6000)
Adding values to W atlas...
Adding values to 

In [29]:
# save numpy to disk
coherence_results_np = np.array(coherence_results, dtype=object)

# get every 4th entry
c_W = coherence_results_np[0::4]
c_N2 = coherence_results_np[1::4]
c_N3 = coherence_results_np[2::4]
c_R = coherence_results_np[3::4]

# save coherence results
np.save(os.path.join(data_final_directory,'coherence_results.npy'), coherence_results_np)
np.savez(os.path.join(data_final_directory,'coherence_results.npz'), W=c_W, N2=c_N2, N3=c_N3, R=c_R)

In [75]:


# load coherence_results
coherence_results_np = np.load(os.path.join(data_final_directory,'coherence_results.npy'), allow_pickle=True)
z_scores_np = np.copy(coherence_results_np)
result_counter = 0

# load atlas csvs into numpy array (stages x bands)
atlases = np.array([[pd.read_csv(os.path.join(data_final_directory,"atlas",f"{stage}_{band}_atlas.csv"), index_col=0) for band in band_names] for stage in sleep_stages_to_run])
    
# for each row in metadata_csv
for pt in all_pts:
    print(f"Calculating z-scores for {pt}... (patient {all_pts.index(pt)+1} of {len(all_pts)}))")
    # get dataframe where metadata_csv["pt"] == pt and normative is True
    this_pt_metadata = metadata_csv[(metadata_csv["pt"] == pt)]
    # get index of rows where metadata_csv["pt"] == pt
    this_pt_indices = this_pt_metadata.index
    this_pt_normative_indices = this_pt_metadata[this_pt_metadata["normative"] == True].index
    # get channel names
    this_pt_channels = this_pt_metadata["name"]
    # get channel regions
    this_pt_regions = this_pt_metadata["reg"]
    
    for stage in sleep_stages_to_run:
        stage_index = sleep_stages_to_run.index(stage)
        z_scores_by_band = z_scores_np[result_counter]
        result_counter += 1
        
        for band in band_names:
            band_index = band_names.index(band)
            # get the coherence results for this band
            z_scores = z_scores_by_band[band_index]
            this_atlas = pd.DataFrame(atlases[stage_index][band_index],columns=regions,index=regions,dtype=object)

            # for each row and column of z_scores
            for i in range(len(this_pt_regions)):
                for j in range(i,len(this_pt_regions)):
                    feature_val = z_scores.iloc[i,j]
                    row_reg = this_pt_regions.iloc[i]
                    col_reg = this_pt_regions.iloc[j]
                    atlas_conns = this_atlas.loc[row_reg,col_reg]
                    #print(f"atlas_conns: {atlas_conns}")
                    if atlas_conns == "[]":
                        #print(f"{stage}, {band_name}: no atlas distribution for {row_reg} to {col_reg}.")
                        z_scores.iloc[i,j] = np.nan
                        continue
                    # convert string of list to list
                    atlas_conns = [float(s.strip()) for s in atlas_conns[1:-1].split(',')]
                    # if the feature_val is in atlas_conns, remove it (leave this normative connection for this patient out of the atlas distribution)
                    if np.count_nonzero(atlas_conns == feature_val) > 0:
                        #print("Connection from this patient removed from atlas distribution before scoring.")
                        atlas_conns = np.delete(atlas_conns, np.where(atlas_conns == feature_val)[0])
                    # suppress RunTimeWarning for np.nanmean and np.nanstd
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore", category=RuntimeWarning)
                        # calculate the absolute value z-score for this value
                        z_scores.iloc[i,j] = abs((feature_val - np.nanmean(atlas_conns)) / np.nanstd(atlas_conns))
            
            # symmetrize z_scores by overwriting bottom triangle with top triangle
            z_scores = np.triu(z_scores) + np.triu(z_scores).T
            # set diagonal to np.nan
            np.fill_diagonal(z_scores, np.nan)
            z_scores_np[result_counter-1][band_index] = z_scores
            

# get every 4th entry of z_scores_np
z_scores_W = z_scores_np[0::4]
z_scores_N2 = z_scores_np[1::4]
z_scores_N3 = z_scores_np[2::4]
z_scores_R = z_scores_np[3::4]

# save this_atlas as a .csv file
np.savez(os.path.join(data_final_directory,'coherence_atlas_z_scores.npz'), W=z_scores_W, N2=z_scores_N2, N3=z_scores_N3, R=z_scores_R)

print(f"Z-scores saved.")

Calculating z-scores for MNI001... (patient 1 of 211))
Calculating z-scores for MNI002... (patient 2 of 211))
Calculating z-scores for MNI003... (patient 3 of 211))
Calculating z-scores for MNI004... (patient 4 of 211))
Calculating z-scores for MNI005... (patient 5 of 211))
Calculating z-scores for MNI006... (patient 6 of 211))
Calculating z-scores for MNI007... (patient 7 of 211))
Calculating z-scores for MNI008... (patient 8 of 211))
Calculating z-scores for MNI009... (patient 9 of 211))
Calculating z-scores for MNI010... (patient 10 of 211))
Calculating z-scores for MNI011... (patient 11 of 211))
Calculating z-scores for MNI012... (patient 12 of 211))
Calculating z-scores for MNI013... (patient 13 of 211))
Calculating z-scores for MNI014... (patient 14 of 211))
Calculating z-scores for MNI015... (patient 15 of 211))
Calculating z-scores for MNI016... (patient 16 of 211))
Calculating z-scores for MNI017... (patient 17 of 211))
Calculating z-scores for MNI018... (patient 18 of 211))
C

In [47]:
# load z-scores
z_scores_npz = np.load(os.path.join(data_final_directory,'coherence_atlas_z_scores.npz'), allow_pickle=True)

final_z_scores = []

for stage in sleep_stages_to_run:
    print(f"Calculating final z-scores for {stage}...")
    this_band_scores = z_scores_npz[stage]
    # take 75th percentile of z-scored data
    # arrange as (channels x bands)
    final_z_scores.append(np.vstack([np.transpose(np.vstack([np.nanquantile(x, 0.75, axis=0) for x in this_band_scores[y]])) for y in range(len(this_band_scores))]))
    
np.savez(os.path.join(data_final_directory,'combined_atlas_coherence_normative.npz'), W=final_z_scores[0], N2=final_z_scores[1], N3=final_z_scores[2], R=final_z_scores[3])
print("Saved final z-scores to file.")


Calculating final z-scores for W...


  result = np.apply_along_axis(_nanquantile_1d, axis, a, q,
  subtract(b, diff_b_a * (1 - t), out=lerp_interpolation, where=t >= 0.5)
  diff_b_a = subtract(b, a)
  lerp_interpolation = asanyarray(add(a, diff_b_a * t, out=out))


Calculating final z-scores for N2...
Calculating final z-scores for N3...
Calculating final z-scores for R...
Saved final z-scores to file.


In [15]:
bp_npz = np.load(os.path.join(data_final_directory,'combined_atlas_bandpower_normative.npz'), allow_pickle=True)

In [17]:
bp_npz['W'].shape

(8979, 5)

In [46]:
a = z_scores_npz['W'][0][0]
#print(a)
# take 75th percentile
#a = np.quantile(a, 0.75)
print(np.nanquantile(a, 0.75, axis=0))
b = np.vstack([np.transpose(np.vstack([np.nanquantile(x, 0.75, axis=0) for x in z_scores_npz['W'][y]])) for y in range(len(z_scores_npz['W']))])
print(pd.DataFrame(b))
print(b.shape)

[0.59303222 0.61858896 0.64911486 0.78476947 0.88125483 1.00648916
 0.84109694 0.84109694 0.700735   0.64601339 0.73746132 0.78735107
 0.79332803 0.82700746]


  result = np.apply_along_axis(_nanquantile_1d, axis, a, q,
  subtract(b, diff_b_a * (1 - t), out=lerp_interpolation, where=t >= 0.5)
  diff_b_a = subtract(b, a)
  lerp_interpolation = asanyarray(add(a, diff_b_a * t, out=out))


             0         1         2         3         4         5
0     0.593032  0.835182  0.635489  0.523970  0.705390  0.763221
1     0.618589  0.780027  0.551182  0.475351  0.619678  0.580298
2     0.649115  0.633615  0.581227  0.407827  0.645108  0.476063
3     0.784769  0.826231  0.558500  0.493589  0.498515  0.473783
4     0.881255  0.896117  0.946699  0.787405  0.826642  0.944673
...        ...       ...       ...       ...       ...       ...
8974  7.379353  1.180578  1.815784  1.115036  1.211963  1.209485
8975  7.562499  1.223285  1.804566  0.791875  1.433807  1.350296
8976  1.473163  0.876093  1.103916  1.011843  0.835531  1.016463
8977  1.727455  0.833897  1.267967  1.074397  0.843194  1.193091
8978  1.561171  0.814169  1.126735  1.311727  0.988390  1.036492

[8979 rows x 6 columns]
(8979, 6)


In [20]:
band_names

['delta', 'theta', 'alpha', 'beta', 'gamma', 'broad']