In [1]:
import numpy as np
import os
import mne
from mne.preprocessing import ICA
from mne import pick_types
from mne.io import read_raw_eeglab
from mne.time_frequency import psd_array_welch
from mne.time_frequency import tfr_morlet
import torch
import multiprocessing
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset
import json 
from tqdm import tqdm
import time
from torch.autograd import Variable
import copy
import pandas as pd
import logging

# Only surface MNE warnings and errors; INFO-level chatter is suppressed.
logging.getLogger(name='mne').setLevel(level=logging.WARNING)

# Recording grid: subjects 1..num_sub, each with sessions 1..num_sess.
num_sub, num_sess = 20, 12

# Accelerator switches (truthy = enabled) and the CUDA device ordinal.
use_gpu, use_mps = 0, 1
cuda_device = 0

In [None]:
def create_dir(path):
    """Create directory ``path`` (including missing parents) if it is absent.

    ``exist_ok=True`` replaces a separate existence check, which removes the
    check-then-create race and makes repeated calls a no-op.

    Args:
        path: directory to create.

    Raises:
        FileExistsError: if ``path`` already exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

create_dir('prepro_data')

# Preprocess every (subject, session) recording and cache it as a .fif file.
for i in range(1, num_sub+1):
    for j in range(1, num_sess+1):
        # Zero-padded IDs: sub-001 ... sub-020, ses-01 ... ses-12
        sub_id = f'sub-{i:03d}'
        ses_id = f'ses-{j:02d}'
        data_path = f'../ds003774/{sub_id}/{ses_id}/eeg/{sub_id}_{ses_id}_task-MusicListening_run-{j}_eeg.set'
        raw = read_raw_eeglab(data_path, preload=True)

        # High-pass filter at 0.2 Hz
        raw.filter(l_freq=0.2, h_freq=None)

        # Remove 50 Hz line noise
        raw.notch_filter(freqs=[50])

        # Downsample the data to 256 Hz
        raw.resample(256)

        # Welch PSD (2-40 Hz) of the EEG channels. Rows of `psds` follow the
        # order of `picks`, NOT the order of raw.ch_names.
        picks = pick_types(raw.info, eeg=True, exclude=[])
        data, times = raw.get_data(picks=picks, return_times=True)
        psds, freqs = psd_array_welch(data, sfreq=raw.info['sfreq'], fmin=2, fmax=40)

        # Per-channel mean PSD and 3x the per-channel std across frequencies.
        # NOTE(review): each channel is compared only against its own spectrum,
        # not against the other channels; a z-score across channels may be what
        # was intended — confirm.
        psd_mean = psds.mean(axis=-1)
        psd_threshold = 3 * np.std(psds, axis=-1)

        # Flag bad channels. `row` indexes psds (position within picks) while
        # `ch_idx` indexes raw.ch_names; they differ when non-EEG channels exist.
        bad_channels = [
            raw.ch_names[ch_idx]
            for row, ch_idx in enumerate(picks)
            if psd_mean[row] > psd_threshold[row]
            and raw.ch_names[ch_idx] not in raw.info['bads']
        ]
        raw.info['bads'] += bad_channels
        raw.interpolate_bads()

        # ICA decomposition.
        # TODO(review): ica.exclude is never populated, so ica.apply() removes no
        # components and no artifacts are rejected here. Select components (e.g.
        # via ica.find_bads_* helpers or manual inspection) and set ica.exclude
        # before apply(). MNE's tutorials also recommend fitting ICA on data
        # high-passed at ~1 Hz rather than 0.2 Hz.
        ica = ICA(n_components=20, random_state=99, method='fastica')
        ica.fit(raw)
        ica.apply(raw)

        # Re-reference to the common average and apply it to the data now.
        # projection=True would only store an inactive projector, and nothing in
        # this notebook calls apply_proj(), so the reference would never be used.
        raw.set_eeg_reference('average', projection=False)

        # Save preprocessed data
        pre_path = f'prepro_data/pre_eeg_{sub_id}_{ses_id}.fif'
        raw.save(pre_path, overwrite=True)

In [None]:
# Frequencies of interest (40 log-spaced bins from 1 to 40 Hz) and the number
# of wavelet cycles per frequency are the same for every recording, so they
# are computed once outside the loop.
frequencies = np.logspace(np.log10(1), np.log10(40), num=40)
n_cycles = frequencies / 2.  # Different number of cycles per frequency

# Morlet time-frequency power for every preprocessed recording, in
# subject-major / session-minor order.
powers = []
for i in range(1, num_sub+1):
    for j in range(1, num_sess+1):
        pre_path = f'prepro_data/pre_eeg_sub-{i:03d}_ses-{j:02d}.fif'
        pre = mne.io.read_raw_fif(pre_path, preload=True)

        # Compute time-frequency representation with Morlet wavelets
        power = tfr_morlet(pre, freqs=frequencies, n_cycles=n_cycles, use_fft=True, return_itc=False, decim=3, n_jobs=4)

        powers.append(power.get_data())

# Stack into one array; np.stack raises a clear error if recordings differ in
# length instead of silently building a ragged/object array.
powers = np.stack(powers)

# Pick channel to plot
channel_index = 100  # change this to the index of the channel you want to visualize

# Plot the spectrogram of the LAST recording processed above: `power` is the
# value left over from the final loop iteration.
# NOTE(review): continuous data presumably starts at t=0, so baseline=(-0.5, 0)
# covers at most the first sample; tmin/tmax must also lie inside the
# recording — confirm.
power.plot([channel_index], baseline=(-0.5, 0), mode='logratio', title=power.ch_names[channel_index], tmin=30, tmax=40)

In [96]:
# Layout: (recording, channel, frequency, time), with recordings stacked in
# subject/session loop order. Nothing is flattened at this point.
print(np.shape(powers))

(240, 129, 40, 427)


In [27]:
class RawDataset(Dataset):
    """Dataset of fixed-length EEG windows labelled by enjoyment/familiarity.

    Each preprocessed recording ``pre_eeg_sub-XXX_ses-YY.fif`` in ``data_dir``
    that has a matching row in ``behav_file`` is cut into consecutive,
    non-overlapping windows of ``slice_len`` samples (a trailing partial
    window is dropped). Every window inherits the one-hot label of its row.

    Args:
        data_dir: directory containing the preprocessed ``.fif`` files.
        behav_file: CSV path. Columns 0 and 1 are used as subject and session
            numbers (to build the file name); columns 2 and 3 are used as the
            enjoyment and familiarity ratings (see ``hot_encode_target``).
        transform: optional callable applied to each ``(n_channels, slice_len)``
            data array.
        target_transform: optional callable applied to each label.
        slice_len: window length in samples. The default of 1250 is ~4.9 s if
            the recordings were resampled to 256 Hz.
    """

    # Bin and one-hot encode the targets. row[2] is the enjoyment (E) rating
    # and row[3] the familiarity (F) rating; >= 2.5 is High, < 2.5 is Low.
    # One-hot index: 0 = HEHF (high E, high F), 1 = HELF (high E, low F),
    #                2 = LEHF (low E, high F),  3 = LELF (low E, low F)
    # NOTE(review): a NaN rating fails every comparison and lands in LELF.
    def hot_encode_target(self, row):
        # HEHF
        if row[2] >= 2.5 and row[3] >= 2.5:
            return np.array([1, 0, 0, 0]), 'HEHF'
        # HELF
        elif row[2] >= 2.5 and row[3] < 2.5:
            return np.array([0, 1, 0, 0]), 'HELF'
        # LEHF
        elif row[2] < 2.5 and row[3] >= 2.5:
            return np.array([0, 0, 1, 0]), 'LEHF'
        # LELF
        else:
            return np.array([0, 0, 0, 1]), 'LELF'

    @staticmethod
    def _fif_name(subject, session):
        """Return the preprocessed file name for one subject/session pair.

        IDs are cast to int first: pandas may upcast the CSV to float, and
        formatting 1.0 directly would produce a name that never matches.
        """
        return f'pre_eeg_sub-{int(subject):03d}_ses-{int(session):02d}.fif'

    def __init__(self, data_dir, behav_file, transform=None, target_transform=None, slice_len=1250):
        self.data_dir = data_dir
        self.behav_file = behav_file
        self.transform = transform
        self.target_transform = target_transform
        self.slice_len = slice_len
        # sample id -> [fif path, start sample, stop sample (exclusive)]
        self.data_dict = {}

        eeg_label_dict = {}
        class_counts = {tag: 0 for tag in ('HEHF', 'HELF', 'LEHF', 'LELF')}

        # Use the ratings file the caller passed in (it was previously
        # hard-coded to 'behav.csv').
        behav_data = pd.read_csv(self.behav_file).values

        # The directory listing does not change while iterating rows.
        existing_files = set(os.listdir(self.data_dir))

        sample_id = 0
        for row in tqdm(behav_data, total=len(behav_data)):
            file_name = self._fif_name(row[0], row[1])
            if file_name not in existing_files:
                continue

            data_path = os.path.join(self.data_dir, file_name)
            # n_times gives the recording length without loading the samples.
            n_times = mne.io.read_raw_fif(data_path, preload=False).n_times
            target, tag_string = self.hot_encode_target(row)

            for k in range(n_times // slice_len):
                start = k * slice_len
                self.data_dict[sample_id] = [data_path, start, start + slice_len]
                # Copy so every sample owns its label array, as before.
                eeg_label_dict[sample_id] = target.copy()
                class_counts[tag_string] += 1
                sample_id += 1

        self.items = list(eeg_label_dict.items())
        print('Class counts: ', class_counts)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(eeg_window, label)`` for sample ``idx``.

        Only the requested window is read from disk rather than the whole
        recording.
        """
        sample_id, label = self.items[idx]
        data_path, start, stop = self.data_dict[sample_id]

        full_data = mne.io.read_raw_fif(data_path, preload=False)
        eeg_data = full_data.get_data(start=start, stop=stop)

        if self.transform:
            eeg_data = self.transform(eeg_data)
        if self.target_transform:
            label = self.target_transform(label)

        return eeg_data, label

In [28]:
# Select the accelerator(s) enabled in the configuration flags.
if use_gpu:
    torch.cuda.set_device(cuda_device)

if use_mps:
    # Fall back to CPU when the Metal backend is unavailable, instead of
    # failing later when tensors are moved to a device that does not exist.
    if torch.backends.mps.is_available():
        mps_device = torch.device("mps")
    else:
        logging.warning("MPS backend not available; falling back to CPU")
        mps_device = torch.device("cpu")

# ToTensor turns each (channels, samples) window into a tensor of shape
# (1, channels, samples).
data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
    ]),
}

dset = RawDataset('prepro_data', 'behav.csv', data_transforms['train'])

  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_path, preload=False)
  full_data = mne.io.read_raw_fif(data_p

Class counts:  {'HEHF': 2997, 'HELF': 339, 'LEHF': 1906, 'LELF': 978}





In [30]:
# Fetch the first (window, label) pair via normal indexing and show it.
data, target = dset[0]
print(data.shape)
print(target)


torch.Size([1, 129, 1250])
[0 0 0 1]


  full_data = mne.io.read_raw_fif(eeg_index[0], preload=False)
