In [1]:
import json
import os

import numpy as np
import pandas as pd
import tabulate

In [2]:
# Channel labels analysed in this notebook, in the order they appear in the result tables.
# They are used as dict keys for the per-channel spindle frames and, prefixed with
# 'EEG ', to build the array keys read from the .npz prediction files.
relevant_channels = [
    'F3-A2',
    'F4-A1',
    'C3-A2',
    'C4-A1',
]

In [3]:
# Recording id -> cohort label. Downstream code compares the values against the
# lowercase strings 'bp' and 'hc'.
# Use a context manager so the file handle is closed after reading.
with open("../../output/annotations/nikitin/cohort_mapping.json", "r") as mapping_file:
    cohort_mapping = json.load(mapping_file)
# NOTE(review): uppercase display labels; not referenced by the visible cells -- confirm they are still needed.
cohorts = ['BP', 'HC']

# Base functions

In [4]:
def overlap(spindle_df, sleepstages):
    """Keep only spindles whose temporal center falls into an N2 half-epoch.

    Parameters
    ----------
    spindle_df : pd.DataFrame
        Spindles with numeric ``start`` and ``end`` columns. The center is
        bucketed with ``// 15``, so start/end are presumably in seconds and
        epochs presumably 30 s long -- confirm.
    sleepstages : sequence of str
        One sleep-stage label per scoring epoch. Each label is repeated twice to
        obtain half-epoch bins; only the label 'N2' counts as a match.

    Returns
    -------
    pd.DataFrame
        Row subset of ``spindle_df`` (original index preserved). Spindles whose
        center bin lies beyond the end of ``sleepstages`` are dropped and a
        warning with their count is printed. An empty input returns an empty
        frame with the same columns.
    """
    # Half-epoch resolution: every epoch label covers two consecutive 15-unit bins.
    is_n2 = np.repeat(np.asarray(sleepstages, dtype=str), 2) == 'N2'
    centers = (spindle_df[['start', 'end']].sum(axis=1) // 2).astype(int).to_numpy()
    bins = centers // 15

    # A bin index is valid iff it is < len(is_n2); compare on the index itself
    # so the equality boundary case is also reported.
    in_range = bins < len(is_n2)
    n_outside = int(np.count_nonzero(~in_range))
    if n_outside:
        print(f'Warning: sleepstages shorter than spindles, '
              f'ignoring {n_outside} spindles')

    keep = np.zeros(len(bins), dtype=bool)
    keep[in_range] = is_n2[bins[in_range]]
    return spindle_df.loc[keep]


def merge_and_filter_spindles(spindle_df, merge_dur=0.3, merge_dist=0.1, min_dur=0.3, max_dur=2.5):
    """Merge fragmented short spindles, then drop spindles of implausible duration.

    Spindles are sorted by ``start``. Two neighbours are merged when both last
    less than ``merge_dur`` and the gap between them is below ``merge_dist``.
    The surviving (first) row keeps its own ``start`` and all other columns.
    Its ``end`` is extended to the end of the LAST spindle of a merge chain,
    so in a chain A-B-C the result spans A.start..C.end. Afterwards, spindles
    shorter than ``min_dur`` or longer than ``max_dur`` are removed. Times are
    presumably in seconds -- confirm.

    The input frame is not modified; a new frame is returned. Empty and
    single-row inputs are supported.
    """
    # sort_values returns a new frame, so the caller's frame is never mutated.
    spindles = spindle_df.sort_values('start')
    starts = spindles['start'].to_numpy()
    ends = spindles['end'].to_numpy().copy()

    durations = ends - starts
    gaps = starts[1:] - ends[:-1]
    # to_merge[i] is True when spindle i and spindle i + 1 should be fused.
    to_merge = (durations[:-1] < merge_dur) & (durations[1:] < merge_dur) & (gaps < merge_dist)

    # Walk backwards so the head of a chain inherits the end of its last member.
    for i in np.flatnonzero(to_merge)[::-1]:
        ends[i] = ends[i + 1]

    # Every spindle fused into its predecessor is dropped.
    absorbed = np.zeros(len(spindles), dtype=bool)
    absorbed[1:] = to_merge
    spindles = spindles.assign(end=ends).loc[~absorbed]

    durations = (spindles['end'] - spindles['start']).to_numpy()
    to_filter = (durations < min_dur) | (durations > max_dur)
    spindles = spindles.loc[~to_filter]

    print(f'Merged {np.sum(to_merge)} spindles, filtered {np.sum(to_filter)} spindles')

    return spindles


def _band_summary(band_df, n2_amount):
    """Summary statistics of one recording's spindles within one frequency band.

    ``n2_amount`` is the per-recording normaliser from ``characteristics``
    (presumably minutes of N2 sleep -- confirm), so ``spm`` is a rate.
    Amplitudes are multiplied by 1e6 (presumably V -> uV -- confirm).
    Means over an empty band are NaN.
    """
    return {
        'spm': len(band_df) / n2_amount,
        'frequency': band_df['frequency'].mean(),
        'duration': (band_df['end'] - band_df['start']).mean(),
        'amplitude_ptp': band_df['amplitude ptp'].mean() * 1e6,
        'amplitude_hilbert': band_df['amplitude hilbert'].mean() * 1e6,
    }


def calc_metrics_cohort_channel(spindles_dfs, characteristics, channel, cohort,
                                freq_threshold=13, mapping=None):
    """Per-recording spindle metrics for one channel and one cohort.

    Parameters
    ----------
    spindles_dfs : dict[str, pd.DataFrame]
        Channel name -> spindle table with columns ``file``, ``start``, ``end``,
        ``frequency``, ``amplitude ptp`` and ``amplitude hilbert``.
    characteristics : dict
        Recording id -> normaliser used for the ``spm`` rate.
    channel : str
        Key into ``spindles_dfs``.
    cohort : str
        Cohort label; recordings are those whose mapping value equals it.
    freq_threshold : float, default 13
        Spindles with ``frequency > freq_threshold`` are "fast", those with
        ``frequency <= freq_threshold`` are "slow" (NaN frequencies are in neither).
    mapping : dict, optional
        Recording id -> cohort label. Defaults to the notebook-level ``cohort_mapping``.

    Returns
    -------
    dict[str, list]
        Keys ``{spm,frequency,duration,amplitude_ptp,amplitude_hilbert}_{fast,slow}``,
        each a list with one value per recording of the cohort (mapping order).
    """
    if mapping is None:
        mapping = cohort_mapping

    metric_names = ('spm', 'frequency', 'duration', 'amplitude_ptp', 'amplitude_hilbert')
    metrics = {f'{name}_{band}': [] for band in ('fast', 'slow') for name in metric_names}

    channel_df = spindles_dfs[channel]
    for recording in [k for k, v in mapping.items() if v == cohort]:
        rec_df = channel_df[channel_df['file'] == recording]
        freq = rec_df['frequency']
        bands = (('fast', rec_df[freq > freq_threshold]),
                 ('slow', rec_df[freq <= freq_threshold]))
        for band, band_df in bands:
            for name, value in _band_summary(band_df, characteristics[recording]).items():
                metrics[f'{name}_{band}'].append(value)

    return metrics


def calc_mean_std_p(x, y):
    """Descriptive statistics for two samples plus two two-sided group tests.

    Returns a 6-tuple
    ``(mean(x), std(x), mean(y), std(y), p_welch, p_ranksum)`` where the
    standard deviations use NumPy's default ``ddof=0``, ``p_welch`` is the
    p-value of Welch's unequal-variance t-test and ``p_ranksum`` the p-value of
    the Wilcoxon rank-sum test.

    References:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ttest_ind.html
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ranksums.html#scipy.stats.ranksums
    """
    from scipy.stats import ranksums, ttest_ind

    p_welch = ttest_ind(x, y, equal_var=False)[1]
    p_ranksum = ranksums(x, y)[1]

    x_stats = (np.mean(x), np.std(x))
    y_stats = (np.mean(y), np.std(y))
    return (*x_stats, *y_stats, p_welch, p_ranksum)


def _format_metric_table(rows):
    """Flatten ``{metric: {channel: stats}}`` into rows for ``tabulate``.

    ``stats`` is the 6-tuple returned by ``calc_mean_std_p`` called with HC
    first, i.e. (HC mean, HC std, BP mean, BP std, p t-test, p rank-sum).
    The metric name is shown only on the first channel row of each metric group.
    """
    table = []
    for metric, per_channel in rows.items():
        for i, (channel, stats) in enumerate(per_channel.items()):
            hc_mean, hc_std, bp_mean, bp_std, p_tt, p_rs = stats
            table.append([metric if i == 0 else '', channel,
                          f"{hc_mean:.2f} ({hc_std:.2f})", f"{bp_mean:.2f} ({bp_std:.2f})",
                          f"{p_tt:.3f}", f"{p_rs:.3f}"])
    return table


def calc_metrics(bp_spindles_dfs, hc_spindles_dfs, characteristics):
    """Compare HC vs BP spindle metrics per channel and print the result tables.

    For every channel in ``relevant_channels`` and for fast and slow spindles
    separately, per-recording metrics are computed for each cohort and compared
    with ``calc_mean_std_p``. Four tables are printed: fast and slow in GitHub
    markdown, then fast and slow again in LaTeX. Returns None.

    Raises
    ------
    AssertionError
        If either cohort's dict lacks one of ``relevant_channels``.
    """
    # Validate both cohorts up front so a missing channel fails with a clear message.
    for label, spindles_dfs in (('BP', bp_spindles_dfs), ('HC', hc_spindles_dfs)):
        ch_keys = sorted(spindles_dfs.keys())
        assert all(c in ch_keys for c in relevant_channels), \
            f'Not all relevant channels are present in {ch_keys} ({label})'

    # Table label -> key prefix produced by calc_metrics_cohort_channel.
    metric_keys = {
        'Density': 'spm',
        'Frequency': 'frequency',
        'Duration': 'duration',
        'Amplitude PTP': 'amplitude_ptp',
        'Amplitude Hilbert': 'amplitude_hilbert',
    }

    per_channel = {
        ch: (calc_metrics_cohort_channel(bp_spindles_dfs, characteristics, ch, 'bp'),
             calc_metrics_cohort_channel(hc_spindles_dfs, characteristics, ch, 'hc'))
        for ch in relevant_channels
    }

    tables = {}
    for band in ('fast', 'slow'):
        rows = {
            metric: {
                ch: calc_mean_std_p(hc[f'{key}_{band}'], bp[f'{key}_{band}'])
                for ch, (bp, hc) in per_channel.items()
            }
            for metric, key in metric_keys.items()
        }
        tables[band] = _format_metric_table(rows)

    headers = ['metric', 'channel', 'HC', 'BP', 'p-value t-test', 'p-value ranksum']
    print("Fast spindles")
    print(tabulate.tabulate(tables['fast'], headers=headers, tablefmt='github'))
    print()
    print("Slow spindles")
    print(tabulate.tabulate(tables['slow'], headers=headers, tablefmt='github'))
    print()
    print(tabulate.tabulate(tables['fast'], headers=headers, tablefmt='latex'))
    print()
    print(tabulate.tabulate(tables['slow'], headers=headers, tablefmt='latex'))

# Spindles (Python SUMO, aggregated over 4 EEG channels, no additional shift, filtered by RSN sleep stages)

In [5]:
# load predicted sleep stages (one hypnogram file per recording)
# Own variable name so this directory cannot be confused with the spindle
# directory loaded in the next cell.
stages_path = "../../output/annotations/nikitin/pred_sleep_stages/"

pred_sleep_stages = {}
# sorted(): os.listdir order is arbitrary; keep the run reproducible.
for fname in sorted(os.listdir(stages_path)):
    if fname.endswith('_sleepstages.txt'):
        # atleast_1d: np.loadtxt returns a 0-d array when the file holds a single value.
        s_stages = np.atleast_1d(
            np.loadtxt(os.path.join(stages_path, fname), dtype=str, delimiter=" "))
        # The files encode N2 as '2'; normalise to the 'N2' label used by overlap().
        s_stages = ['N2' if s == '2' else s for s in s_stages]
        pred_sleep_stages[fname.split("_")[0]] = s_stages

# Recording id -> number of N2 epochs / 2 (presumably minutes of N2 for 30-s epochs -- confirm).
pred_char = {s_id: sum(1 for s in stages if s == 'N2') / 2 for s_id, stages in pred_sleep_stages.items()}

In [6]:
# load predicted spindles, keep those in N2, merge/filter, and split by cohort
pred_path = "../../output/annotations/nikitin/pred_spindles_w_pred_stages/"
# sorted(): os.listdir order is arbitrary; keep row order reproducible.
pred_files = sorted(f for f in os.listdir(pred_path) if f.endswith('.npz'))

f_ids = [f.split("_")[0] for f in pred_files]

columns = ['start', 'end', 'frequency', 'amplitude ptp', 'amplitude hilbert']

# Collect per-channel pieces and concatenate once at the end;
# repeated pd.concat inside the loop is quadratic.
bp_parts = {name: [] for name in relevant_channels}
hc_parts = {name: [] for name in relevant_channels}

for f_id, f in zip(f_ids, pred_files):
    # Context manager closes the underlying .npz file handle.
    with np.load(os.path.join(pred_path, f)) as pred_data:
        for ch in relevant_channels:
            df = pd.DataFrame(pred_data[f"{f_id}_EEG {ch}_fold-0_agg"], columns=columns)
            df["file"] = f_id

            df = overlap(df, pred_sleep_stages[f_id])
            df = merge_and_filter_spindles(df)

            # Anything not labelled 'bp' is routed to the HC collection (as before).
            parts = bp_parts if cohort_mapping[f_id] == 'bp' else hc_parts
            parts[ch].append(df)

# Channel -> concatenated spindle table (empty DataFrame if the cohort has no files).
BP_pred = {ch: pd.concat(p, axis=0, ignore_index=True) if p else pd.DataFrame()
           for ch, p in bp_parts.items()}
HC_pred = {ch: pd.concat(p, axis=0, ignore_index=True) if p else pd.DataFrame()
           for ch, p in hc_parts.items()}

Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filtered 0 spindles
Merged 0 spindles, filter

In [7]:
# print HC-vs-BP comparison tables (GitHub markdown and LaTeX) for fast and slow spindles
calc_metrics(BP_pred, HC_pred, pred_char)

Fast spindles
| metric            | channel   | HC            | BP            |   p-value t-test |   p-value ranksum |
|-------------------|-----------|---------------|---------------|------------------|-------------------|
| Density           | F3-A2     | 3.43 (1.70)   | 2.15 (1.46)   |            0.008 |             0.006 |
|                   | F4-A1     | 3.39 (1.68)   | 2.16 (1.39)   |            0.009 |             0.009 |
|                   | C3-A2     | 4.37 (2.01)   | 2.80 (1.98)   |            0.011 |             0.011 |
|                   | C4-A1     | 4.37 (2.00)   | 2.77 (1.91)   |            0.008 |             0.01  |
| Frequency         | F3-A2     | 13.91 (0.14)  | 13.91 (0.18)  |            0.99  |             0.765 |
|                   | F4-A1     | 13.92 (0.13)  | 13.92 (0.12)  |            0.983 |             0.781 |
|                   | C3-A2     | 13.98 (0.22)  | 13.90 (0.18)  |            0.184 |             0.403 |
|                   | C4-A1     | 13.98 (