In [None]:
import numpy as np


def compute_psd_de(data, window, fs, f_bands=None):
    """
    Compute DE (differential entropy) and PSD (power spectral density) band features.

    The signal is cut into non-overlapping segments of ``window`` seconds
    (trailing points that do not fill a whole segment are dropped). Each
    segment is transformed with an FFT, and one PSD and one DE value are
    computed per frequency band.

    input:
        data    - array-like [n, m]: n channels, m points of each time course
        window  - segment length in seconds, e.g. 1 or 2; fractional values
                  are accepted as long as window * fs rounds to >= 1 point
        fs      - sampling rate in Hz, e.g. 200
        f_bands - optional iterable of (low, high) pairs in Hz, both ends
                  inclusive; default delta, theta, alpha, beta, gamma:
                  [(1, 4), (4, 7), (8, 13), (14, 29), (30, 47)]

    output:
        psd, de - arrays of shape [bands, channels, samples], where
                  samples = m // round(window * fs). A band containing no
                  FFT bins yields NaN for that band.

    raises:
        ValueError - if data is not 2-D, if window * fs is below one point,
                     or if data is shorter than one segment.
    """
    data = np.asarray(data)
    channels, lens = data.shape
    # round() guards against float error, e.g. 0.29 * 100 == 28.999...
    segment_lens = int(round(window * fs))
    if segment_lens < 1:
        raise ValueError(f"window * fs must be at least 1 point, got {window * fs}")
    samples = lens // segment_lens
    if samples == 0:
        raise ValueError(
            f"data has {lens} points, fewer than one segment of {segment_lens} points"
        )
    # segment the data: [channels, samples, segment_lens]
    data = data[:, :samples * segment_lens].reshape(channels, samples, segment_lens)

    if f_bands is None:
        f_bands = [(1, 4), (4, 7), (8, 13), (14, 29), (30, 47)]  # delta, theta, alpha, beta, gamma

    # One-sided magnitude spectrum. For real input, rfft yields the same
    # non-negative-frequency bins as fft at roughly half the cost. Keep the
    # first segment_lens // 2 bins, which excludes the Nyquist bin.
    n_bins = segment_lens // 2
    f = np.fft.rfftfreq(segment_lens, 1 / fs)[:n_bins]
    fxx = np.abs(np.fft.rfft(data, axis=-1)[..., :n_bins])

    psd_bands = []
    de_bands = []
    for f_low, f_high in f_bands:
        f_mask = (f >= f_low) & (f <= f_high)
        data_bands = fxx[:, :, f_mask]

        # PSD: mean power over this band's bins only (e.g. 7 bins for 1-4 Hz
        # with a 2 s window). The alternative normalization
        # np.sum(data_bands**2, axis=-1) / n_bins matches
        # scipy.signal.periodogram(...) * fs summed over the band.
        psd = np.mean(data_bands**2, axis=-1)
        # DE of a Gaussian, 0.5 * log(2*pi*e*var), expressed in bits (log2).
        # NOTE(review): var is the variance of the spectral magnitudes across
        # the band's bins, not the band-limited signal variance. Kept as-is
        # for compatibility; confirm this is intended.
        de = np.log2(2 * np.pi * np.e * data_bands.var(axis=-1)) / 2

        psd_bands.append(psd)
        de_bands.append(de)
    return np.stack(psd_bands), np.stack(de_bands)

# make examples and test: 4000 points at 200 Hz with 2 s windows -> 10 segments of 400 points
rng = np.random.default_rng(0)  # seeded so the demo output is reproducible
data = rng.random((1, 4000))
fs = 200
window = 2
psd, de = compute_psd_de(data, window, fs)
# expected: 5 default bands x 1 channel x (4000 // (window * fs)) segments
assert psd.shape == de.shape == (5, 1, 4000 // (window * fs))
print(psd.shape, de.shape)

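# # test with periodogram: psd_scipy * fs equals sum(|X|^2) / (segment_lens // 2) over the band,
# # so it matches psd only with that normalization, not with the per-band mean used for psd.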
# from scipy.signal import periodogram
# channels, lens = np.shape(data)
# segment_lens = window * fs
# samples = lens // segment_lens
# data = data[:, :samples*segment_lens]
# data = data.reshape(channels, samples, -1) 
# f, psd_scipy = periodogram(data, fs)
# f_mask = (f >= 1) & (f <= 4)
# psd_scipy = np.sum(psd_scipy[:, :, f_mask], axis=-1)
# print(psd[0], psd_scipy*fs)

(5, 1, 10) (5, 1, 10)
