In [1]:
import numpy as np
import matplotlib.pyplot as plt
from utils import print_graphs
from sklearn.preprocessing import LabelEncoder



dataset1 = np.load("../../datasets/beta/data1.npy")
# dataset1 = np.load("data1SNR.npy")
dataset2 = np.load("../../datasets/beta/data2.npy")
print(dataset1)
print(dataset1.shape)



labels1 = np.load("../../datasets/beta/labels1.npy")
labels2 = np.load("../../datasets/beta/labels2.npy")
print(labels1)
print(labels1.shape)

def print_graphs(data1, data2, FS=250):
    fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(15, 10))

    # Plot time domain for dataset1
    for i in range(data1.shape[0]):
        axes[0, 0].plot(data1[i, :])
    axes[0, 0].set_title('Time Domain - Dataset 1')

    # Plot time domain for dataset2
    for i in range(data2.shape[0]):
        axes[0, 1].plot(data2[i, :])
    axes[0, 1].set_title('Time Domain - Dataset 2')

    # Plot frequency domain for dataset1
    for i in range(data1.shape[0]):
        axes[1, 0].psd(data1[i, :], Fs=FS)
    axes[1, 0].set_title('Frequency Domain - Dataset 1')

    # Plot frequency domain for dataset2
    for i in range(data2.shape[0]):
        axes[1, 1].psd(data2[i, :], Fs=FS)
    axes[1, 1].set_title('Frequency Domain - Dataset 2')

    # Plot spectrogram for dataset1
    for i in range(data1.shape[0]):
        axes[2, 0].specgram(data1[i, :], Fs=FS)
    axes[2, 0].set_title('Spectrogram - Dataset 1')

    # Plot spectrogram for dataset2
    for i in range(data2.shape[0]):
        axes[2, 1].specgram(data2[i, :], Fs=FS)
    axes[2, 1].set_title('Spectrogram - Dataset 2')

    plt.tight_layout()
    plt.show()


# print_graphs(dataset1, dataset2)

    


[[[ -4.69486829  -6.52983055  -3.95334362 ...   3.72489206   4.25573282
     5.82455553]
  [ -5.33445041  -6.76139056  -4.07902495 ...   3.18477989   4.20659904
     6.26154253]
  [ -4.77436454  -6.48497251  -3.54051562 ...   3.52458628   3.78675351
     5.95949217]
  ...
  [ -0.46235387   0.77656142  -0.48785122 ...   7.12372488  -1.62361477
    -4.63128411]
  [ -7.28200609  -1.29611109   0.12011583 ...  12.52936475  -2.0870292
    -6.04624113]
  [ -7.89620338  -1.50433314  -0.05607805 ...  12.51303487  -1.61776013
    -4.76391951]]

 [[ -0.96306876   4.9064513    8.69057364 ...   3.41482295   2.21397367
     1.88819077]
  [ -0.89004462   5.18518717   9.84095634 ...   4.14281332   2.2715642
     1.55857415]
  [ -0.81564931   4.89107417   9.81066398 ...   4.64164113   2.49749588
     1.67278824]
  ...
  [  0.43148382  -5.46596552  -4.85513581 ...  -7.51398549  -8.20400771
    -6.37305748]
  [ -4.64719463  -6.41642149  -1.78199546 ...  -4.65240932  -9.67715177
    -6.9097699 ]
  [ -4.66

In [2]:
import mne

n_channels = 64

sfreq = 250
## ch names geradas
ch_ideal = ["PZ", "PO3", "PO5", "PO4", "PO6", "POZ", "O1", "OZ", "O2"]
ch_names = ["FP1", "FPZ", "FP2", "AF3", "AF4", "F7", "F5", "F3", "F1", "FZ", "F2", "F4", "F6", "F8", "FT7", "FC5", "FC3", "FC1", "FCZ", "FC2", "FC4", "FC6", "FT8", "T7", "C5", "C3", "C1", "CZ", "C2", "C4", "C6", "T8", "M1", "TP7", "CP5", "CP3", "CP1", "CPZ", "CP2", "CP4", "CP6", "TP8", "M2", "CB1", "CB2","P7", "P5", "P3", "P1", "PZ", "P2", "P4", "P6", "P8", "PO7", "PO5", "PO3", "POZ", "PO4", "PO6", "PO8", "O1", "OZ", "O2"]
ch_types = ['eeg'] * n_channels
info = mne.create_info(ch_names, sfreq=sfreq, ch_types=ch_types)


def create_mne_epochs(data, labels, ch_names, filename=""):
    sfreq=250
    n_channels = len(ch_names)
    ch_types = ['eeg'] * n_channels
    info = mne.create_info(ch_names, sfreq=sfreq, ch_types=ch_types)

    le = LabelEncoder()
    events = np.column_stack((
        np.arange(len(labels)),
        np.zeros(len(labels), dtype=int),
        le.fit_transform(labels)
    ))

    event_dict = {str(value): index for index, value in enumerate(sorted(set(labels)))}

    mne_data = mne.EpochsArray(data, info, events, event_id=event_dict)


    drop_channels = [ch for ch in mne_data.info['ch_names'] if ch not in ch_ideal]

    mne_data = mne_data.drop_channels(drop_channels)

    freq_low=7
    freq_high=17
    filtered_mne_data = mne_data.filter(freq_low, freq_high)
    
    filtered_mne_data.save(filename, overwrite=True)

    return mne_data



mne_data1 = create_mne_epochs(dataset1, labels1, ch_names, "../beta_epo1.fif")

mne_data2 = create_mne_epochs(dataset2, labels2, ch_names,"../beta_epo2.fif")



#convert mnes to np and plot
mne_data1_np = mne_data1.get_data()
mne_data2_np = mne_data2.get_data()

# print_graphs(mne_data1_np, mne_data2_np)





Not setting metadata
160 matching events found
No baseline correction applied
0 projection items activated
Setting up band-pass filter from 7 - 17 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 7.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 6.00 Hz)
- Upper passband edge: 17.00 Hz
- Upper transition bandwidth: 4.25 Hz (-6 dB cutoff frequency: 19.12 Hz)
- Filter length: 413 samples (1.652 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 287 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 449 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    0.2s
[Parallel(n_jobs=1)]: Done 881 tasks      | elapsed:    0.2s


Overwriting existing file.
Not setting metadata
160 matching events found
No baseline correction applied
0 projection items activated


[Parallel(n_jobs=1)]: Done 1151 tasks      | elapsed:    0.2s
  filtered_mne_data.save(filename, overwrite=True)


Setting up band-pass filter from 7 - 17 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 7.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 6.00 Hz)
- Upper passband edge: 17.00 Hz
- Upper transition bandwidth: 4.25 Hz (-6 dB cutoff frequency: 19.12 Hz)
- Filter length: 413 samples (1.652 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 287 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 449 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 881 tasks      | elapsed:    0.2s


Overwriting existing file.


[Parallel(n_jobs=1)]: Done 1151 tasks      | elapsed:    0.2s
  filtered_mne_data.save(filename, overwrite=True)


In [3]:


# def plot_psd(mne_data, ncols=8):
#     n_epochs = len(mne_data)
#     nrows = (n_epochs + ncols - 1) // ncols  

#     fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 3 * nrows))

#     for i, ax in enumerate(axes.flatten()):
#         if i < n_epochs:
#             epoch = mne_data[i]
#             view = epoch.compute_psd(method='welch', fmin=7, fmax=17, verbose=False)
#             view.plot(show=False, axes=ax)
#             ax.set_title(f'Epoch {i+1}')
#             ax.axvline(x=float(list(epoch.event_id)[0]), linestyle='--', color='green')

#     plt.tight_layout()
#     plt.show()



### Epochs Dataset 3

In [4]:
# plot_psd(mne_data1)

### Epochs Dataset 4

In [5]:
# plot_psd(mne_data2)