In [1]:
import os
from glob import glob
from os.path import join, basename, isdir

import pandas as pd
import numpy as np
from biosppy.signals.tools import welch_spectrum
import seaborn as sns
import matplotlib.pyplot as plt

In [6]:
# Root of the shared dataset volume. Set the MMIS_DATA_ROOT environment variable
# when the volume is mounted somewhere else; the default is the original mount point.
DATA_ROOT = os.environ.get("MMIS_DATA_ROOT", "/Volumes/MMIS-Saraiv/Datasets")

# Input folder: one sub-folder per subject, each holding per-epoch .txt files
common_path = join(DATA_ROOT, "Izmir", "EC", "as_given_denoised_2s-epoched")
#common_path = join(DATA_ROOT, "Sapienza", "denoised_txt_epochs_matlab")
# Destination CSV for the per-subject IAF/TF table
out_path = join(DATA_ROOT, "Izmir", "iaf_tf_my-boudaries.csv")
# Sampling frequency of the recordings, in Hz
sf = 128.0

In [7]:
# One entry per subject folder. glob's order depends on the filesystem, so sort it to get
# a reproducible row order in the output table. Keep directories only, so stray files in
# the dataset root are not mistaken for subjects (they would contain no epoch files).
all_subjects = sorted(path for path in glob(join(common_path, '*')) if isdir(path))
all_subjects

In [8]:
# Empty results table: one row per subject with its code, IAF and TF
res = pd.DataFrame(columns=('Subject', 'IAF', 'TF'))

In [9]:
# Welch segment length in seconds (the segment is sf*MULTIPLIER - 1 samples long)
MULTIPLIER = 2
# Column positions (0-based) of the occipital electrodes O1 (18th), Oz (19th), O2 (20th column)
OCCIPITAL_COLUMNS = (17, 18, 19)
# Search bands in Hz, half-open [low, high)
IAF_BAND = (8, 12)  # individual alpha frequency: maximum of the averaged PSD
TF_BAND = (4, 8)    # transition frequency: minimum of the averaged PSD
# Set to True to draw each subject's PSDs with the detected IAF/TF marked
SHOW_PLOTS = False


def _band_extremum(freqs, power, band, pick):
    """Return the frequency in the half-open ``band`` selected by ``pick`` on ``power``.

    ``pick`` is ``np.argmax`` (spectral peak) or ``np.argmin`` (spectral trough).
    Bins are chosen by their actual frequency rather than by index arithmetic, so the
    result does not depend on a particular bin spacing.
    Raises ValueError if no frequency bin falls inside the band.
    """
    low, high = band
    in_band = (freqs >= low) & (freqs < high)
    if not in_band.any():
        raise ValueError("no spectral bins between {} and {} Hz".format(low, high))
    return freqs[in_band][pick(power[in_band])]


records = []
for subject_path in all_subjects:
    # NOTE(review): assumes folder names like "<prefix> <code>-<rest>"; IndexError otherwise
    subject = basename(subject_path).split('-')[0].split(' ')[1]
    #subject = basename(subject_path).split('PARTICIPANT')[1]
    print("Subject Code", subject)
    all_files = glob(join(subject_path, '*.txt'))
    if not all_files:
        # With no epochs there is no spectrum to average, so skip instead of crashing
        print("No epoch files found, skipping subject", subject)
        continue

    power = []
    freqs = None
    for file in all_files:  # one file per epoch
        data = pd.read_csv(file, sep='\t', dtype=float)  # time x channels
        # NOTE(review): read_csv treats the first line as a header row; confirm that the
        # epoch files actually have one (this may explain the "-1" in the window size).
        spectra = [
            welch_spectrum(data.iloc[:, col], sampling_rate=sf, decibel=True,
                           size=int(sf * MULTIPLIER - 1))
            for col in OCCIPITAL_COLUMNS
        ]
        channel_freqs, channel_power = zip(*spectra)
        epoch_freqs = channel_freqs[0]
        assert all(np.array_equal(f, epoch_freqs) for f in channel_freqs)
        assert freqs is None or np.array_equal(freqs, epoch_freqs), \
            "frequency grid differs between epochs"
        freqs = epoch_freqs
        # Average the (dB) power of the occipital electrodes for this epoch
        power.append(sum(channel_power) / len(channel_power))

    # Average the occipital spectrum across epochs
    power_avg = np.mean(np.array(power), axis=0)

    # IAF: frequency of the PSD peak in IAF_BAND
    iaf = round(_band_extremum(freqs, power_avg, IAF_BAND, np.argmax), 1)
    print("IAF", iaf)
    # TF: frequency of the PSD minimum in TF_BAND
    tf = round(_band_extremum(freqs, power_avg, TF_BAND, np.argmin), 1)
    print("TF", tf)

    if SHOW_PLOTS:
        # Per-epoch PSDs (faint) and their average (bold), low-frequency bins only
        fig, ax = plt.subplots()
        sns.set_palette("husl")
        for p in power:
            sns.lineplot(x=freqs[:46], y=p[:46], alpha=0.05, color='orange', ax=ax)
        sns.lineplot(x=freqs[:46], y=power_avg[:46], linewidth=3, color='black', ax=ax)
        ax.axvline(iaf, color='red', linestyle='--')
        ax.axvline(tf, color='blue', linestyle='--')
        sns.despine(ax=ax)
        ax.set(title="Subject Code {}".format(subject),
               xlabel="Frequency (Hz)", ylabel="Power (dB)")
        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per subject

    records.append({'Subject': subject, 'IAF': iaf, 'TF': tf})

# Build the table in one step (DataFrame.append was removed in pandas 2.0)
res = pd.DataFrame(records, columns=['Subject', 'IAF', 'TF'])

In [10]:
# Use the subject code as the row index. The guard keeps this cell re-runnable:
# once 'Subject' has moved from the columns into the index, there is nothing left to do.
if 'Subject' in res.columns:
    res = res.set_index('Subject')
res

In [11]:
# Write the per-subject IAF/TF table (including its index) to the configured CSV path
res.to_csv(path_or_buf=out_path)