In [1]:
import numpy as np
import os
import mne
from mne.preprocessing import ICA
from mne import pick_types
from mne.io import read_raw_eeglab
from mne.time_frequency import psd_array_welch
from mne.time_frequency import tfr_morlet
import torch
import multiprocessing
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, WeightedRandomSampler
import json 
from tqdm import tqdm
import time
from torch.autograd import Variable
import copy
import pandas as pd
import logging

# Only show warnings and errors from MNE's logger.
logging.getLogger('mne').setLevel(logging.WARNING)

# Number of subjects and sessions per subject looped over by the preprocessing cell.
num_sub = 20
num_sess = 12
# Accelerator flags (truthy/falsy): use_gpu is checked first in the training loop, then use_mps.
use_gpu = 0
use_mps = 1
# CUDA device index passed to torch.cuda.set_device when use_gpu is set.
cuda_device = 0
# Directories the preprocessed .fif recordings are written to and read from.
train_dir = '../prepro_data/train'
val_dir = '../prepro_data/val'
# CSV files read by the dataset classes (columns 0-3 are used there).
train_behav_file = 'train_behav.csv'
val_behav_file = 'val_behav.csv'
# Learning-rate schedule: lr = base_lr * decay_weight ** (epoch // epoch_decay).
base_lr = 0.0001
decay_weight = 0.1 
epoch_decay = 5 
# DataLoader batch size and number of training epochs.
b_size = 3
n_epochs = 10

### Preprocessing the data

In [None]:
def create_dir(path):
    """Create directory `path`, including missing parents, if it is not there yet.

    Uses ``exist_ok=True`` instead of a separate existence check so that the
    call is idempotent and free of the check-then-create race.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If `path` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

create_dir(train_dir)
create_dir(val_dir)

# Preprocess every subject/session recording and write it as a .fif file.
# Subjects 1-16 go to the training directory, the remaining ones to validation.
for i in range(1, num_sub+1):
    for j in range(1, num_sess+1):
        sub_tag = f'sub-0{i//10}{i%10}'
        ses_tag = f'ses-{j//10}{j%10}'
        data_path = f'../ds003774/{sub_tag}/{ses_tag}/eeg/{sub_tag}_{ses_tag}_task-MusicListening_run-{j}_eeg.set'
        raw = read_raw_eeglab(data_path, preload=True)

        # High-pass filter at 0.2 Hz
        raw.filter(l_freq=0.2, h_freq=None)

        # Remove 50 Hz line noise
        raw.notch_filter(freqs=[50])

        # Downsample the data to 256 Hz
        raw.resample(256)

        # Extract EEG data and calculate PSD using Welch's method
        picks = pick_types(raw.info, eeg=True, exclude=[])
        data = raw.get_data(picks=picks)
        psds, freqs = psd_array_welch(data, sfreq=raw.info['sfreq'], fmin=2, fmax=40)

        # Per-channel mean power and threshold (3 x std along the frequency axis)
        psd_mean = psds.mean(axis=-1)
        psd_threshold = 3 * np.std(psds, axis=-1)

        # Identify bad channels based on spectral criteria.
        # Rows of `psds` follow the order of `picks`, so index them by position
        # in `picks` and use the pick value only to look up the channel name.
        bad_channels = [
            raw.ch_names[p]
            for row, p in enumerate(picks)
            if psd_mean[row] > psd_threshold[row]
        ]
        # `exclude=[]` keeps already-marked channels in `picks`; do not list them twice.
        raw.info['bads'] += [ch for ch in bad_channels if ch not in raw.info['bads']]
        raw.interpolate_bads()

        # Artifact rejection using ICA
        # NOTE(review): no components are added to `ica.exclude` before `apply`,
        # so confirm this step actually removes any artifact components.
        ica = ICA(n_components=20, random_state=99, method='fastica')
        ica.fit(raw)
        ica.apply(raw)

        # Re-reference the data to the average
        # NOTE(review): with projection=True the reference is stored as a projector;
        # confirm that downstream readers apply it (e.g. via apply_proj).
        raw.set_eeg_reference('average', projection=True)

        # Save preprocessed data
        pre_path = f'pre_eeg_{sub_tag}_{ses_tag}_eeg.fif'
        if i <= 16:
            pre_path = os.path.join(train_dir, pre_path)
        else:
            pre_path = os.path.join(val_dir, pre_path)
        raw.save(pre_path, overwrite=True)

### Dataset class for the V vs. Time data

In [2]:
class RawDataset(Dataset):
    """Dataset of fixed-length windows cut from preprocessed EEG recordings.

    Every row of the behavioural CSV names a subject (column 0) and session
    (column 1) plus two ratings (columns 2 and 3). If the matching
    ``pre_eeg_sub-XXX_ses-YY_eeg.fif`` file exists in `data_dir`, the recording
    is split into consecutive, non-overlapping windows of 1250 samples and
    each window becomes one item labelled with the row's class.

    Only (path, start, stop) triples are kept in memory; the samples are read
    from disk in `__getitem__`.

    NOTE(review): the original comment calls these "5 second slices", which is
    true for 1250 samples at 250 Hz; confirm against the sampling rate of the
    saved recordings.
    """

    # Bin and hot encode our labels for our targets
    # Bins: [HEHF, HELF, LEHF, LELF]
    # The first letter pair refers to column 2, the second to column 3.
    # High is >= 2.5
    # Low is < 2.5
    def get_target(self, row):
        """Map a behavioural row to ``(class_index, tag)`` using columns 2 and 3."""
        # HEHF
        if row[2] >= 2.5 and row[3] >= 2.5:
            return 0, 'HEHF'
        # HELF
        elif row[2] >= 2.5 and row[3] < 2.5:
            return 1, 'HELF'
        # LEHF
        elif row[2] < 2.5 and row[3] >= 2.5:
            return 2, 'LEHF'
        # LELF
        else:
            return 3, 'LELF'

    def __init__(self, data_dir, behav_file, transform=None, target_transform=None):
        """Index all windows available for the rows of `behav_file`.

        Args:
            data_dir: Directory containing the preprocessed .fif recordings.
            behav_file: CSV file with subject, session and the two ratings.
            transform: Optional callable applied to each window.
            target_transform: Optional callable applied to each label.
        """
        self.data_dir = data_dir
        self.behav_file = behav_file
        self.transform = transform
        self.target_transform = target_transform
        self.data_dict = {}
        self.class_counts = {tag: 0 for tag in ['HEHF', 'HELF', 'LEHF', 'LELF']}

        eeg_label_dict = {}
        behav_data = pd.read_csv(self.behav_file).values

        # The directory listing does not change while indexing: read it once.
        existing_files = set(os.listdir(self.data_dir))

        sample_id = 0
        for row in tqdm(behav_data):
            # `.values` may upcast ids to float when other columns are float;
            # cast back so the file name is built from integers.
            subject, session = int(row[0]), int(row[1])
            file_name = f'pre_eeg_sub-0{subject//10}{subject%10}_ses-{session//10}{session%10}_eeg.fif'
            if file_name not in existing_files:
                continue

            file_path = os.path.join(self.data_dir, file_name)
            recording = mne.io.read_raw_fif(file_path, preload=False)
            target, tag_string = self.get_target(row)

            # Splitting the full EEG recording into 1250-sample windows.
            # `n_times` gives the length without loading the samples.
            num_intervals = recording.n_times // 1250
            for k in range(num_intervals):
                self.data_dict[sample_id] = [file_path, k*1250, k*1250+1250]
                eeg_label_dict[sample_id] = target
                self.class_counts[tag_string] += 1
                sample_id += 1

        self.items = list(eeg_label_dict.items())
        print('Class counts: ', self.class_counts)

    def get_class_counts(self):
        """Return ``{class_index: count}`` in the insertion order of the tags."""
        return {index: count for index, count in enumerate(self.class_counts.values())}

    def get_label(self, idx):
        """Return the integer class label of item `idx` without reading any data."""
        return self.items[idx][1]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Read window `idx` from disk and return ``(window, label)``."""
        label = self.items[idx][1]
        file_path, start, stop = self.data_dict[self.items[idx][0]]

        recording = mne.io.read_raw_fif(file_path, preload=False)
        # Read only the requested sample range instead of the whole recording.
        eeg_data = recording.get_data(start=start, stop=stop)

        if self.transform:
            eeg_data = self.transform(eeg_data)
        if self.target_transform:
            label = self.target_transform(label)

        # The [0] drops the leading axis that torchvision's ToTensor adds to a
        # 2-D array. NOTE(review): with transform=None this instead returns only
        # the first row of the window - confirm that is acceptable.
        return eeg_data[0], label

In [3]:
# Select the accelerator requested by the configuration flags.
if use_gpu:
    torch.cuda.set_device(cuda_device)
if use_mps:
    mps_device = torch.device("mps")

# Both phases share the same pipeline: convert the NumPy array to a tensor.
data_transforms = {
    phase: transforms.Compose([transforms.ToTensor()])
    for phase in ('train', 'val')
}

In [4]:
def exp_lr_scheduler(optimizer, epoch, init_lr=base_lr, lr_decay_epoch=epoch_decay):
    """Step-decay the learning rate of `optimizer` for the given epoch.

    The rate is ``init_lr * decay_weight ** (epoch // lr_decay_epoch)`` and is
    written into every parameter group. A message is printed on the epochs
    where a new decay step starts.

    Returns:
        The same optimizer instance, with its learning rate updated.
    """
    steps_done, offset_in_step = divmod(epoch, lr_decay_epoch)
    lr = init_lr * (decay_weight**steps_done)

    if offset_in_step == 0:
        print('LR is set to {}'.format(lr))

    for group in optimizer.param_groups:
        group['lr'] = lr

    return optimizer

### CNN for the V vs. Time data

In [5]:
class EEGCNN(nn.Module):
    """Three-block 2-D CNN classifying raw EEG windows.

    Input to `forward` is a float tensor of shape (batch, n_channels, n_samples);
    a singleton channel axis is added internally. Each block is
    conv -> batch norm -> leaky ReLU -> max-pool (2, 4).

    Args:
        num_classes: Number of output logits.
        n_channels: Height of the input (EEG channels). Default 129.
        n_samples: Width of the input (time samples). Default 1250.

    The size of the first fully connected layer is derived from the input size,
    so the defaults reproduce the original 19456 features (64 x 16 x 19).
    """

    def __init__(self, num_classes=4, n_channels=129, n_samples=1250):
        super(EEGCNN, self).__init__()
        # Height is preserved by every convolution (kernel 3, padding 1).
        # Width grows by one sample for the even kernels (40 and 10) and is
        # preserved for the odd kernel (25).
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(3, 40), padding=(1, 20))
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(3, 25), padding=(1, 12))
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=(3, 10), padding=(1, 5))
        self.bn3 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d((2, 4))  # Reduces height by 2, width by 4
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

        # Output dimensions for the default input (1, 129, 1250):
        # After first block:  (16, 64, 312)
        # After second block: (32, 32, 78)
        # After third block:  (64, 16, 19)
        height, width = n_channels, n_samples
        for conv in (self.conv1, self.conv2, self.conv3):
            height, width = self._block_output_size(height, width, conv, self.pool)
        if height < 1 or width < 1:
            raise ValueError(
                'input of size ({}, {}) is too small for three conv/pool blocks'.format(
                    n_channels, n_samples))

        self.fc1 = nn.Linear(self.conv3.out_channels * height * width, 100)
        self.fc2 = nn.Linear(100, num_classes)

    @staticmethod
    def _block_output_size(height, width, conv, pool):
        """Spatial size after a stride-1 convolution followed by max-pooling."""
        sizes = []
        for size, kernel, pad, window in zip(
                (height, width), conv.kernel_size, conv.padding, pool.kernel_size):
            after_conv = size + 2 * pad - kernel + 1
            sizes.append(after_conv // window)
        return sizes[0], sizes[1]

    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dimension
        x = F.leaky_relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = F.leaky_relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = F.leaky_relu(self.bn3(self.conv3(x)))
        x = self.pool(x)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.leaky_relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

In [6]:
def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=10):
    """Train `model` and keep the copy with the best validation accuracy.

    Each epoch runs a 'train' phase followed by a 'val' phase over the
    module-level `dset_loaders`. The target device follows the module-level
    flags: CUDA if `use_gpu`, else `mps_device` if `use_mps`, else CPU.

    Args:
        model: Network to train; it must already live on the target device.
        criterion: Loss function. It is assumed to return the mean loss over
            the batch (the default for nn.CrossEntropyLoss), so each batch loss
            is weighted by the batch size before averaging over the epoch.
        optimizer: Optimizer over the model parameters.
        lr_scheduler: Callable ``(optimizer, epoch) -> optimizer`` invoked at
            the start of every training phase.
        num_epochs: Number of epochs to run.

    Returns:
        ``(best_model, accuracies, losses)`` where the last two are dicts with
        per-epoch lists under the keys 'train' and 'val'.
    """
    since = time.time()

    if use_gpu:
        device = torch.device('cuda')  # current CUDA device
    elif use_mps:
        device = mps_device
    else:
        device = torch.device('cpu')

    best_model = model
    best_acc = 0.0

    accuracies = {'train': [], 'val': []}
    losses = {'train': [], 'val': []}

    for epoch in range(num_epochs):
        print('-' * 10)
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                optimizer = lr_scheduler(optimizer, epoch)
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            samples_seen = 0

            for counter, (inputs, labels) in enumerate(dset_loaders[phase]):
                try:
                    # Cast on every device (the datasets yield float64 arrays)
                    # and move the batch next to the model.
                    inputs = inputs.float().to(device)
                    labels = labels.long().to(device)
                except Exception:
                    print("ERROR! here are the inputs and labels before we print the full stack trace:")
                    print(inputs, labels)
                    raise

                optimizer.zero_grad()

                # No autograd graph is needed during validation.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                preds = outputs.argmax(dim=1)

                if counter%100 == 0:
                    print('Reached batch iteration', counter)

                if is_train:
                    loss.backward()
                    optimizer.step()

                batch_size = inputs.size(0)
                # loss is a batch mean: weight it so the epoch value is a per-sample mean.
                running_loss += loss.item() * batch_size
                running_corrects += (preds == labels).sum().item()
                samples_seen += batch_size

            # Guard against an empty loader.
            denominator = max(samples_seen, 1)
            epoch_loss = running_loss / denominator
            epoch_acc = running_corrects / float(denominator)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            accuracies[phase].append(epoch_acc)
            losses[phase].append(epoch_loss)

            # deep copy the model
            if phase == 'val':
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model = copy.deepcopy(model)
                    print('new best accuracy =', best_acc)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    print('returning and looping back')

    return best_model, accuracies, losses

### Training loop for CNN & V vs. Time data

In [None]:
# Build the datasets for the raw (voltage vs. time) windows.
dsets = {
    'train': RawDataset(train_dir, train_behav_file, data_transforms['train']),
    'val': RawDataset(val_dir, val_behav_file, data_transforms['val']),
}

dset_sizes = {split: len(dsets[split]) for split in ['train', 'val']}

# One loader per split, sampling items with probability inversely proportional
# to the size of their class.
# NOTE(review): the validation loader is resampled as well, so validation
# metrics are computed on a class-balanced random draw - confirm this is intended.
dset_loaders = {}
for split in ['train', 'val']:
    dataset = dsets[split]
    # Integer labels so they can be used directly as keys of `class_weights`.
    targets = np.array([dataset.get_label(i) for i in range(len(dataset))], dtype=np.int64)
    class_counts = dataset.get_class_counts()
    class_weights = {tag: 1.0 / count if count > 0 else 0 for tag, count in class_counts.items()}
    weights = np.array([class_weights[int(tag)] for tag in targets])
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

    dset_loaders[split] = torch.utils.data.DataLoader(dataset, batch_size=b_size, num_workers=0, sampler=sampler)
    print('done making loader: ', split)
model_ft = EEGCNN(num_classes=4)

criterion = nn.CrossEntropyLoss()

if use_gpu:
    criterion.cuda()
    model_ft.cuda()

if use_mps:
    criterion.to(mps_device)
    model_ft.to(mps_device)

# Start from the configured base learning rate; exp_lr_scheduler resets it every epoch.
optimizer_ft = optim.Adam(model_ft.parameters(), lr=base_lr, weight_decay=1e-5)

model_ft, accuracies, losses = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=n_epochs)

for split in ['train', 'val']:
    print(split, 'accuracies by epoch:', accuracies[split])
    print(split, 'losses by epoch:', losses[split])


### Dataset class for spectrogram data

In [7]:
class SpectroDataset(Dataset):
    """Dataset of Morlet time-frequency windows from preprocessed EEG recordings.

    Every row of the behavioural CSV names a subject (column 0) and session
    (column 1) plus two ratings (columns 2 and 3). For each matching
    ``pre_eeg_sub-XXX_ses-YY_eeg.fif`` file in `data_dir`, windows of 84
    decimated samples are indexed; the time-frequency power is computed in
    `__getitem__`.

    NOTE(review): the number of windows per recording is
    ``(n_times // 3) // (84 * 5)`` while windows are 84 decimated samples long
    with a stride of 84, so only the first part of each recording is covered.
    The original comment calls these "5 second slices"; confirm the intended
    window length and stride.
    """

    # Bin and hot encode our labels for our targets
    # Bins: [HEHF, HELF, LEHF, LELF]
    # The first letter pair refers to column 2, the second to column 3.
    # High is >= 2.5
    # Low is < 2.5
    def get_target(self, row):
        """Map a behavioural row to ``(class_index, tag)`` using columns 2 and 3."""
        # HEHF
        if row[2] >= 2.5 and row[3] >= 2.5:
            return 0, 'HEHF'
        # HELF
        elif row[2] >= 2.5 and row[3] < 2.5:
            return 1, 'HELF'
        # LEHF
        elif row[2] < 2.5 and row[3] >= 2.5:
            return 2, 'LEHF'
        # LELF
        else:
            return 3, 'LELF'

    def __init__(self, data_dir, behav_file, transform=None, target_transform=None):
        """Index all windows available for the rows of `behav_file`.

        Args:
            data_dir: Directory containing the preprocessed .fif recordings.
            behav_file: CSV file with subject, session and the two ratings.
            transform: Optional callable applied to each power window.
            target_transform: Optional callable applied to each label.
        """
        self.data_dir = data_dir
        self.behav_file = behav_file
        self.transform = transform
        self.target_transform = target_transform
        self.data_dict = {}
        self.class_counts = {tag: 0 for tag in ['HEHF', 'HELF', 'LEHF', 'LELF']}

        eeg_label_dict = {}
        behav_data = pd.read_csv(self.behav_file).values

        # The directory listing does not change while indexing: read it once.
        existing_files = set(os.listdir(self.data_dir))

        sample_id = 0
        for row in tqdm(behav_data):
            # `.values` may upcast ids to float when other columns are float;
            # cast back so the file name is built from integers.
            subject, session = int(row[0]), int(row[1])
            file_name = f'pre_eeg_sub-0{subject//10}{subject%10}_ses-{session//10}{session%10}_eeg.fif'
            if file_name not in existing_files:
                continue

            file_path = os.path.join(self.data_dir, file_name)
            recording = mne.io.read_raw_fif(file_path, preload=False)
            target, tag_string = self.get_target(row)

            # `n_times` gives the length without loading the samples; the
            # division by 3 mirrors decim=3 used in __getitem__.
            num_intervals = (recording.n_times // 3) // (84 * 5)
            for k in range(num_intervals):
                self.data_dict[sample_id] = [file_path, k*84, k*84+84]
                eeg_label_dict[sample_id] = target
                self.class_counts[tag_string] += 1
                sample_id += 1

        self.items = list(eeg_label_dict.items())
        print('Class counts: ', self.class_counts)

    def get_class_counts(self):
        """Return ``{class_index: count}`` in the insertion order of the tags."""
        return {index: count for index, count in enumerate(self.class_counts.values())}

    def get_label(self, idx):
        """Return the integer class label of item `idx` without reading any data."""
        return self.items[idx][1]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Compute the power window for item `idx` and return ``(window, label)``.

        NOTE(review): the time-frequency transform is computed over the whole
        recording for every item and then sliced, which is very expensive;
        consider precomputing or caching the power per file.
        """
        label = self.items[idx][1]
        file_path, start, stop = self.data_dict[self.items[idx][0]]

        recording = mne.io.read_raw_fif(file_path, preload=False)
        # Define frequencies of interest (log-spaced, 1-40 Hz)
        frequencies = np.logspace(np.log10(1), np.log10(40), num=40)
        n_cycles = frequencies / 2.  # Different number of cycle per frequency
        # Compute time-frequency representation with Morlet wavelets
        power = tfr_morlet(recording, freqs=frequencies, n_cycles=n_cycles, use_fft=True, return_itc=False, decim=3, n_jobs=1)
        # power data is indexed as (channel, frequency, decimated time);
        # the original author observed shape (129, 40, 11609).

        eeg_data = power.get_data()[:, :, start:stop]

        if self.transform:
            eeg_data = self.transform(eeg_data)
        if self.target_transform:
            label = self.target_transform(label)

        # Flatten the last two axes into one.
        # Without a transform the result is (n_channels, 40 * 84).
        # torchvision's ToTensor moves the last axis of a 3-D array to the
        # front, so with that transform the result is (84, n_channels * 40),
        # i.e. (84, 5160) for 129 channels.
        # NOTE(review): confirm this axis layout is the intended model input.
        new_shape = (eeg_data.shape[0], eeg_data.shape[1] * eeg_data.shape[2])
        eeg_data = eeg_data.reshape(new_shape)

        return eeg_data, label

In [14]:
class SpectrogramCNN(nn.Module):
    """Single-block 2-D CNN with a direct linear read-out.

    Input to `forward` is a float tensor of shape (batch, in_height, in_width);
    a singleton channel axis is added internally. The network is
    conv -> batch norm -> leaky ReLU -> max-pool (2, 4) -> dropout -> linear.

    Args:
        num_classes: Number of output logits.
        in_height: Height of the input. Default 84.
        in_width: Width of the input. Default 5160.

    The size of the linear layer is derived from the input size; the defaults
    reproduce the original 1733760 features (32 x 42 x 1290).
    """

    def __init__(self, num_classes=4, in_height=84, in_width=5160):
        super(SpectrogramCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 10), padding=(1, 5))  # One layer, more channels
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d((2, 4))  # Pooling to reduce dimensions

        self.dropout = nn.Dropout(0.25)

        # Stride-1 convolution: out = in + 2 * padding - kernel + 1, then the
        # pooling divides each axis by its window (rounding down).
        pooled = []
        for size, kernel, pad, window in zip(
                (in_height, in_width), self.conv1.kernel_size,
                self.conv1.padding, self.pool.kernel_size):
            pooled.append((size + 2 * pad - kernel + 1) // window)
        if pooled[0] < 1 or pooled[1] < 1:
            raise ValueError(
                'input of size ({}, {}) is too small for the conv/pool block'.format(
                    in_height, in_width))

        self.fc1 = nn.Linear(self.conv1.out_channels * pooled[0] * pooled[1], num_classes)  # Direct connection to output

    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dimension
        x = F.leaky_relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x


### Training loop for CNN & spectrogram data

In [15]:
# Build the datasets for the spectrogram windows.
dsets = {
    'train': SpectroDataset(train_dir, train_behav_file, data_transforms['train']),
    'val': SpectroDataset(val_dir, val_behav_file, data_transforms['val']),
}

dset_sizes = {split: len(dsets[split]) for split in ['train', 'val']}

# One loader per split, sampling items with probability inversely proportional
# to the size of their class.
# NOTE(review): the validation loader is resampled as well, so validation
# metrics are computed on a class-balanced random draw - confirm this is intended.
dset_loaders = {}
for split in ['train', 'val']:
    dataset = dsets[split]
    # Integer labels so they can be used directly as keys of `class_weights`.
    targets = np.array([dataset.get_label(i) for i in range(len(dataset))], dtype=np.int64)
    class_counts = dataset.get_class_counts()
    class_weights = {tag: 1.0 / count if count > 0 else 0 for tag, count in class_counts.items()}
    weights = np.array([class_weights[int(tag)] for tag in targets])
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

    dset_loaders[split] = torch.utils.data.DataLoader(dataset, batch_size=b_size, num_workers=0, sampler=sampler)
    print('done making loader: ', split)
model_ft = SpectrogramCNN(num_classes=4)

criterion = nn.CrossEntropyLoss()

if use_gpu:
    criterion.cuda()
    model_ft.cuda()

if use_mps:
    criterion.to(mps_device)
    model_ft.to(mps_device)

# Start from the configured base learning rate; exp_lr_scheduler resets it every epoch.
optimizer_ft = optim.Adam(model_ft.parameters(), lr=base_lr, weight_decay=1e-5)

model_ft, accuracies, losses = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=n_epochs)

for split in ['train', 'val']:
    print(split, 'accuracies by epoch:', accuracies[split])
    print(split, 'losses by epoch:', losses[split])


100%|██████████| 192/192 [00:10<00:00, 19.09it/s]


Class counts:  {'HEHF': 2482, 'HELF': 238, 'LEHF': 1366, 'LELF': 874}


100%|██████████| 48/48 [00:01<00:00, 25.21it/s]


Class counts:  {'HEHF': 507, 'HELF': 100, 'LEHF': 533, 'LELF': 100}
done making loader:  train
done making loader:  val
----------
Epoch 0/9
----------
LR is set to 0.0001
0
Reached batch iteration 0
1
2
