In [None]:
import os

import numpy as np
from matplotlib import pyplot as plt
from scipy.integrate import simpson
from scipy.signal import welch

In [None]:
# Notebook configuration.
hand = 0  # 0 selects the "left" recording file, any other value the "right" one
subject = 1  # subject number substituted into the recording file name
# Backslashes are escaped explicitly: "\D" is not a valid escape sequence and
# recent Python versions warn about it. The resulting value is unchanged.
data_dir = "Z:\\Desktop\\"

In [None]:
# Pick the recording that matches the configured hand (0 -> left, otherwise right).
side = "left" if hand == 0 else "right"
file_name = "sub_{}_{}.npz".format(subject, side)
print("processing file :", file_name)

# Read the array inside a context manager so the .npz archive handle is closed
# as soon as the data has been pulled into memory.
with np.load(os.path.join(data_dir, file_name)) as dataset:
    X = dataset["X_train"]
del dataset

# Swap axis 2 with the last axis, then drop any singleton dimensions.
# NOTE(review): downstream cells index X as [epoch, channel, time] - confirm
# that this is the layout produced here.
X = np.swapaxes(X, 2, -1).squeeze()

In [None]:
sf = 250  # sampling frequency handed to the spectral estimates below

# Which 1-D trace to inspect; named so the choice is visible and easy to change.
# NOTE(review): presumably X is [epoch, channel, time] - confirm.
EPOCH_IDX = 100
CHANNEL_IDX = 20

data = X[EPOCH_IDX, CHANNEL_IDX, :]
print(data.shape)
# Sample index divided by the sampling frequency gives the time axis in seconds.
time = np.arange(data.size) / sf

In [None]:
 # Plot the signal
# Draw on the Axes returned by subplots instead of going through pyplot state.
fig, ax = plt.subplots(1, 1, figsize=(12, 4))
ax.plot(time, data, lw=1.5, color='k')
ax.set_xlabel('Time (seconds)')
ax.set_ylabel('Voltage')
ax.set_xlim([time.min(), time.max()])
# NOTE(review): the title looks inherited from the tutorial this notebook
# follows - confirm it describes the trace that is actually plotted.
ax.set_title('N3 sleep EEG data (F3)')
plt.show()

In [None]:
# Welch power spectral density of the selected signal.
# nperseg is a number of samples, so floor division keeps it an int (250/2
# produced the float 125.0) and ties the segment length to the sampling rate.
nperseg = sf // 2
noverlap = None  # None lets welch use its default overlap (nperseg // 2)
freqs, psd = welch(data, sf, nperseg=nperseg, noverlap=noverlap)  # default window: Hann

print(freqs)
print(psd)



In [None]:
# Full-range view of the Welch estimate, drawn through the Axes interface.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(freqs, psd, color='k', lw=2)
ax.set_xlabel('Frequency (Hz)')
# NOTE(review): this label says V^2 while the band plots say uV^2 - confirm the unit.
ax.set_ylabel('Power spectral density (V^2 / Hz)')
ax.set_title("Welch's periodogram")
ax.set_xlim([0, freqs.max()])
ax.set_ylim([0, psd.max() * 1.1])  # 10 % headroom above the tallest peak
plt.show()

In [None]:
# One figure per frequency band: the PSD curve with that band shaded.
band_edges = [(1, 4), (4, 8), (8, 10), (10, 13), (13, 30), (30, 70)]
for low, high in band_edges:
    # Boolean mask of the frequency bins inside the band (both edges inclusive).
    in_band = (freqs >= low) & (freqs <= high)
    print("logical map for bands ({},{}):\n {}".format(low, high, in_band))

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(freqs, psd, lw=2, color='k')
    ax.fill_between(freqs, psd, where=in_band, color='skyblue')
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power spectral density (uV^2 / Hz)')
    ax.set_title("Welch's periodogram")
    ax.set_xlim([0, 70])
    ax.set_ylim([0, psd.max() * 1.1])
    plt.show()


In [None]:
# `scipy.integrate.simps` was removed in SciPy 1.14 in favour of `simpson`
# (imported with the other dependencies). The alias keeps the old name
# available to any code that still refers to it.
simps = simpson

# Neither quantity depends on the band, so compute them once up front.
freq_res = freqs[1] - freqs[0]  # spacing between neighbouring frequency bins
total_power = simpson(psd, dx=freq_res)

for low, high in [(1, 4), (4, 8), (8, 10), (10, 13), (13, 30),
                  (30, 70)]:
    print("processing bands (low, high) : ({},{})".format(low, high))
    # Frequency bins inside the band, both edges inclusive.
    idx_band = np.logical_and(freqs >= low, freqs <= high)

    # Absolute power: area under the PSD over the band (Simpson's rule).
    power = simpson(psd[idx_band], dx=freq_res)
    print('Absolute power: {:.4f} uV^2'.format(power))

    # Relative power: band power as a fraction of the whole spectrum.
    rel_power = power / total_power
    print('Relative power: {:.4f}'.format(rel_power))
    
    

In [None]:
def bandpower_1d(data, sf, band, nperseg=250, relative=False):
    """
        Compute the power of a 1-D signal in a specific frequency band.

        The power spectral density is estimated with Welch's method and the
        band power is the area under it between the band edges (both edges
        inclusive), approximated with Simpson's rule.
        https://raphaelvallat.com/bandpower.html
    Args:
        data (1d-array):
            Input signal in the time-domain.
        sf (float):
            Sampling frequency of the data.
        band (list):
            Lower and upper frequencies of the band of interest.
        nperseg (int):
            Number of samples per Welch segment (default 250).
        relative (boolean):
            If True, return the relative power (= divided by the total power of the signal).
            If False (default), return the absolute power.

    Returns:
        bp (float):
            Absolute or relative band power.
    """
    low, high = band

    # Compute the modified periodogram (Welch)
    freqs, psd = welch(data, sf, nperseg=nperseg)

    # Frequency resolution: spacing between neighbouring frequency bins
    freq_res = freqs[1] - freqs[0]

    # Frequency bins that fall inside the band
    idx_band = np.logical_and(freqs >= low, freqs <= high)

    # Integral approximation of the spectrum using Simpson's rule.
    # `simpson` replaces `scipy.integrate.simps`, which SciPy 1.14 removed.
    bp = simpson(psd[idx_band], dx=freq_res)

    if relative:
        bp /= simpson(psd, dx=freq_res)
    return bp


# def bandpower(x, fs, bands, window_sec=None, relative=True):
#         # x shape [n_epoch, n_channel, n_times]
#         # bandpower [n_epoch, n_channel, 1]
#         n_epoch, n_channel, _ = x.shape
#         bp = np.zeros((n_epoch, n_channel, 1))
#         for idx, b in enumerate(bands):
#             print(b)
#             print(idx)
#             for epoch in range(n_epoch):
#                 for channel in range(n_channel):
#                     bp[epoch, channel] = bandpower_1d(x[epoch, channel, idx], fs, [fmax, fmin],
#                                                         window_sec=window_sec, relative=relative)
#
#         return bp


def bandpower(x, fs, fmin, fmax, nperseg=250, relative=True):
    """
    Compute the power of the multi-channel signal x in a specific frequency band.

    The Welch PSD is estimated along the last axis for every epoch/channel in a
    single call and then integrated over [fmin, fmax] (both edges inclusive)
    with Simpson's rule.
    Args:
        x (nd-array): [n_epoch, n_channel, n_times]
           The epoched input data.
        fs (float):
            Sampling frequency of the data.
        fmin (int): Low-band frequency.
        fmax (int): High-band frequency.
        nperseg (int):
            Number of samples per Welch segment (default 250).
        relative (boolean):
            If True (default), return the relative power (= divided by the total power of the signal).
            If False, return the absolute power.

    Returns:
        bp (nd-array): [n_epoch, n_channel, 1]
            Absolute or relative band power.
    """
    # Unpacking also rejects input that is not 3-dimensional.
    n_epoch, n_channel, _ = x.shape

    # With no epochs or no channels there is nothing to estimate.
    if n_epoch == 0 or n_channel == 0:
        return np.zeros((n_epoch, n_channel, 1))

    # One vectorised Welch call replaces the per-epoch / per-channel Python loop.
    freqs, psd = welch(x, fs, nperseg=nperseg, axis=-1)
    freq_res = freqs[1] - freqs[0]

    # Frequency bins that fall inside the band.
    idx_band = np.logical_and(freqs >= fmin, freqs <= fmax)

    # Area under the PSD over the band, computed for every row at once.
    # `simpson` replaces `scipy.integrate.simps`, which SciPy 1.14 removed.
    bp = simpson(psd[..., idx_band], dx=freq_res, axis=-1)
    if relative:
        bp = bp / simpson(psd, dx=freq_res, axis=-1)

    # Trailing singleton axis keeps the documented [n_epoch, n_channel, 1] shape.
    return np.asarray(bp, dtype=float)[..., np.newaxis]


def bandpower_multi(x, fs, bands,  nperseg=250, relative=True):
    """
    Compute the power of the multi-channel signal x in multiple frequency bands.

    The Welch PSD does not depend on the band, so it is estimated once for all
    epochs/channels and then integrated over each band (both edges inclusive)
    with Simpson's rule.
    Args:
        x (nd-array): [n_epoch, n_channel, n_times]
           The epoched input data.
        fs (float):
            Sampling frequency of the data.
        bands (list): list of bands to compute the bandpower. Each band is a tuple of fmin and fmax.
        nperseg (int):
            Number of samples per Welch segment (default 250).
        relative (boolean):
            If True (default), return the relative power (= divided by the total power of the signal).
            If False, return the absolute power.

    Returns:
        bp (nd-array): [n_epoch, n_channel, n_bands]
            Absolute or relative bands power.

    Raises:
        ValueError: if x is not 3-dimensional or bands is empty.
    """
    # Unpacking also rejects input that is not 3-dimensional.
    n_epoch, n_channel, _ = x.shape

    bands = list(bands)
    if not bands:
        raise ValueError("bands must contain at least one (fmin, fmax) pair")

    # With no epochs or no channels there is nothing to estimate.
    if n_epoch == 0 or n_channel == 0:
        return np.zeros((n_epoch, n_channel, len(bands)))

    # Estimate the PSD a single time and reuse it for every band.
    freqs, psd = welch(x, fs, nperseg=nperseg, axis=-1)
    freq_res = freqs[1] - freqs[0]
    # `simpson` replaces `scipy.integrate.simps`, which SciPy 1.14 removed.
    total_power = simpson(psd, dx=freq_res, axis=-1) if relative else None

    bp_list = []
    for fmin, fmax in bands:
        # Frequency bins that fall inside this band.
        idx_band = np.logical_and(freqs >= fmin, freqs <= fmax)
        power = simpson(psd[..., idx_band], dx=freq_res, axis=-1)
        if relative:
            power = power / total_power
        bp_list.append(power)

    # Bands become the last axis: [n_epoch, n_channel, n_bands].
    bp = np.stack(bp_list, axis=-1).astype(float)

    return bp


In [None]:
### all channels for 1 epoch
fs = 250
# nperseg is a number of samples, so keep it an int (250/2 gave the float 125.0).
nperseg = fs // 2
data = X[1000, :, :].squeeze()
print(data.shape)
# bandpower_multi unpacks a 3-D shape, so restore a leading axis of length 1.
data = np.expand_dims(data, 0)
bands = [(1, 4), (4, 8), (8, 10), (10, 13), (13, 30), (30, 70)]
bp = bandpower_multi(data, fs, bands,  nperseg=nperseg, relative=True).squeeze()
print(bp.shape)

In [None]:
# Heat map of the band powers: one row per channel, one column per band.
fig, ax = plt.subplots(figsize=[10, 4])
cmap = "RdBu_r"
im = ax.pcolormesh(bp, cmap=cmap)
fig.colorbar(im, ax=ax)
ax.set_title("RPS")
ax.set_ylabel("Channels")
ax.set_xlabel("Bands")
ax.locator_params(axis="y", nbins=5)
# Column j spans [j, j + 1] on the x axis, so j + 0.5 centres each label.
ax.set_xticks([0.5, 1.5, 2.5, 3.5, 4.5, 5.5])
ax.set_xticklabels(
    ["delta", "theta", "low-alpha", "high-alpha", "beta", "low-gamma"]  # was misspelled "tetha"
)
plt.show()