In [51]:
import os
import mne
import numpy as np
from mne.time_frequency import tfr_multitaper
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
import torch
import torch.optim as optim
import torch.nn as nn

In [52]:
# Channel names handed to raw.pick_channels() inside load_and_extract_intervals;
# only these channels end up in the extracted epochs.
chan_IDs = [
    'FP1',
    'FP2',
    'F7',
    'F8',
]

In [53]:
# Visualization settings.
# NOTE: the cell that follows assigns these same four names again with the same values.
pipeline_visualizations_semiautomated = True  # when True, the PSD of each file is plotted
vis_freq_min, vis_freq_max = 2, 57            # passed to raw.plot_psd as fmin / fmax
freq_to_plot = [6, 10, 20, 30, 55]            # not read anywhere in the visible cells

In [54]:
# --- Visualization options -------------------------------------------------
pipeline_visualizations_semiautomated = True  # gates raw.plot_psd() in load_and_extract_intervals
vis_freq_min, vis_freq_max = 2, 57            # frequency limits for that PSD plot
freq_to_plot = [6, 10, 20, 30, 55]            # not read anywhere in the visible cells

# --- Processing options ----------------------------------------------------
task_EEG_processing = False   # True makes load_and_extract_intervals raise NotImplementedError
segment_data = True           # epoch extraction only happens when this is True
segment_length = 2            # duration passed to mne.make_fixed_length_events
average_rereference = True    # add an average-reference projector to each recording

# The two flags below are defined here but not read by any visible cell.
segment_interpolation, segment_rejection = True, False

In [55]:
# Folder that holds the EDF recordings (resolved against the current working directory).
src_folder_name = 'eeg_data'

# os.listdir() returns entries in arbitrary, platform-dependent order. Sorting makes
# the order in which recordings are loaded (and therefore the assembled dataset)
# the same on every run.
file_names = sorted(f for f in os.listdir(src_folder_name) if f.endswith('.edf'))

In [56]:
# Per-file bookkeeping. Every key gets its own, independent empty list; in the
# visible cells only "File_Length_In_Secs" is appended to.
metrics = {
    metric_name: []
    for metric_name in (
        "Number_ICs_Rejected",
        "File_Length_In_Secs",
        "Percent_Variance_Kept_of_Post_Waveleted_Data",
    )
}

In [57]:
# Select the matplotlib backend used by the plots produced later in the notebook
# (e.g. raw.plot_psd in load_and_extract_intervals).
# NOTE(review): TkAgg presumably opens figures in separate GUI windows and needs a
# working Tk installation/display - confirm this is wanted instead of inline figures.
import matplotlib
matplotlib.use('TkAgg')  # Switch to the TkAgg backend

In [58]:
def _sample_interval_to_seconds(interval, sfreq, max_time):
    """Convert a (start, stop) pair of sample indices to seconds, clamped to [0, max_time]."""
    start_sec = min(max(0, interval[0] / sfreq), max_time)
    stop_sec = min(max_time, interval[1] / sfreq)
    if start_sec >= stop_sec:
        raise ValueError(
            f"Interval {interval} (samples) is empty after clamping to a recording "
            f"of {max_time:.3f} s at {sfreq} Hz."
        )
    return start_sec, stop_sec


def load_and_extract_intervals(file_path, interval_1=(300, 600), interval_2=(1500, 1860),
                               epoch_duration=1, epoch_overlap=0.5):
    """Load one EDF recording, preprocess it, and epoch two intervals of it.

    The recording is restricted to the channels in the module-level ``chan_IDs``,
    band-pass filtered between 1 and 50 Hz, optionally average-referenced
    (``average_rereference``), and two sub-intervals are cut out and split into
    fixed-length, overlapping epochs.

    Parameters
    ----------
    file_path : str
        Path of the ``.edf`` file to read.
    interval_1, interval_2 : tuple of (start, stop)
        Bounds of the "no stress" and "stressed" intervals. They are divided by
        the sampling frequency before cropping, i.e. they are treated as sample
        indices. NOTE(review): if these were meant to be seconds (300-600 s,
        1500-1860 s) the division by ``sfreq`` must be removed - confirm with the
        experiment protocol.
    epoch_duration : float
        Length of each epoch, passed as ``duration`` to
        ``mne.make_fixed_length_epochs``.
    epoch_overlap : float
        Overlap between consecutive epochs, passed as ``overlap``.

    Returns
    -------
    (array_no_stressed, array_stressed)
        The ``get_data()`` arrays of the epochs cut from ``interval_1`` and
        ``interval_2`` respectively (MNE documents these as
        ``(n_epochs, n_channels, n_times)``).

    Raises
    ------
    ValueError
        If ``segment_data`` is False (there would be nothing to return) or an
        interval is empty once clamped to the recording length.
    NotImplementedError
        If ``task_EEG_processing`` is True.

    Side effects
    ------------
    Appends the recording's last time stamp to ``metrics["File_Length_In_Secs"]``
    and, when ``pipeline_visualizations_semiautomated`` is True, plots the PSD.
    """
    # Fail fast: previously segment_data=False fell through to the return
    # statement with unbound locals (UnboundLocalError) after all the work was done.
    if not segment_data:
        raise ValueError("segment_data must be True: this function returns epoched intervals.")
    if task_EEG_processing:
        # If task-related EEG, use predefined conditions (not applicable here)
        raise NotImplementedError("Task EEG not implemented yet.")

    # Load the raw data
    raw = mne.io.read_raw_edf(file_path, preload=True)

    montage_channels = ['T4', 'A2', 'C3', 'C4', 'A1', 'T3', 'O1', 'T5', 'O2',
                        'F8', 'T6', 'FZ', 'FP2', 'F7', 'FP1', 'CZ']

    raw.pick_channels(montage_channels)

    # Electrode coordinates keyed by channel name. Keying by name (instead of
    # zipping a positional array with the channels that happen to be present)
    # keeps every coordinate attached to its electrode even when some channels
    # are missing from the recording.
    channel_positions = {
        'T4': [5.18e-15, -84.5, -8.85],
        'A2': [3.68e-15, -60.1, -60.1],
        'C3': [3.87e-15, 63.2, 56.9],
        'C4': [-3.87e-15, -63.2, 56.9],
        'A1': [3.68e-15, 60.1, -60.1],
        'T3': [5.18e-15, 84.5, -8.85],
        'O1': [-80.8, 26.1, -4],
        'T5': [-49.9, 68.4, -7.49],
        'O2': [-80.8, -26.1, -4],
        'F8': [49.9, -68.4, -7.49],
        'T6': [-49.9, -68.4, -7.49],
        'FZ': [60.7, 0, 59.5],
        'FP2': [80.8, -26.1, -4],
        'F7': [49.9, 68.4, -7.49],
        'FP1': [80.8, 26.1, -4],
        'CZ': [5.2e-15, 0, 85],
        'F3': [57.6, 48.2, 39.9],
        'F4': [57.6, -48.1, 39.9],
        'FCz': [32.9, 0, 78.4],
    }
    available_channels = [ch for ch in montage_channels if ch in raw.info['ch_names']]

    # Create the montage for the channels that are actually present.
    # NOTE(review): as in the original code the montage is built but never applied
    # with raw.set_montage(); the coordinates also look like millimetres - confirm
    # the unit before applying it.
    montage = mne.channels.make_dig_montage(
        ch_pos={ch: np.array(channel_positions[ch], dtype=float) for ch in available_channels},
        coord_frame='head'
    )

    raw = raw.pick_channels(chan_IDs)

    # Band-pass filter the data between 1 Hz and 50 Hz (this is not a notch filter)
    raw.filter(1, 50, fir_design='firwin')

    # Plot the power spectrum if visualization is enabled
    if pipeline_visualizations_semiautomated:
        raw.plot_psd(fmin=vis_freq_min, fmax=vis_freq_max)

    # Re-reference data to average if specified (added as a projector)
    if average_rereference:
        raw.set_eeg_reference('average', projection=True)

    # Convert the sample-index intervals to seconds and clamp them to the recording
    sfreq = raw.info['sfreq']
    max_time = raw.times[-1]
    interval_1_sec = _sample_interval_to_seconds(interval_1, sfreq, max_time)
    interval_2_sec = _sample_interval_to_seconds(interval_2, sfreq, max_time)

    # Cut both intervals from copies so the filtered recording itself stays intact
    no_stress_interval = raw.copy().crop(tmin=interval_1_sec[0], tmax=interval_1_sec[1])
    stressed_interval = raw.copy().crop(tmin=interval_2_sec[0], tmax=interval_2_sec[1])

    # Split each interval into fixed-length, overlapping epochs
    epochs_no_stress = mne.make_fixed_length_epochs(
        no_stress_interval, duration=epoch_duration, overlap=epoch_overlap)
    array_no_stressed = epochs_no_stress.get_data()

    epochs_stressed = mne.make_fixed_length_epochs(
        stressed_interval, duration=epoch_duration, overlap=epoch_overlap)
    array_stressed = epochs_stressed.get_data()

    # Record metrics for each file
    metrics["File_Length_In_Secs"].append(raw.times[-1])
    return array_no_stressed, array_stressed

In [59]:
# Assemble the dataset: every epoch of every recording plus its class label
# (1 = epoch from the "stressed" interval, 0 = epoch from the "no stress" interval).
# Previously only the first epoch of each interval was kept (`stressed[0]`), which
# discarded any further epochs and raised IndexError when an interval produced none.
epoch_arrays = []
label_arrays = []

for file in file_names:
    file_path = os.path.join(src_folder_name, file)
    no_stressed, stressed = load_and_extract_intervals(file_path)
    for epochs, label in ((stressed, 1), (no_stressed, 0)):
        if len(epochs) == 0:
            # Nothing usable in this interval; skip it instead of crashing.
            print(f"No epochs with label {label} extracted from {file}; skipping.")
            continue
        epoch_arrays.append(epochs)
        label_arrays.append(np.full(len(epochs), label, dtype=np.int64))

if not epoch_arrays:
    raise RuntimeError(f"No epochs could be extracted from the files in '{src_folder_name}'.")

# Stack along the epoch axis; labels are concatenated in the same order so that
# all_labels[i] belongs to all_data[i].
all_data = np.concatenate(epoch_arrays, axis=0)
all_labels = np.concatenate(label_arrays, axis=0)

Extracting EDF parameters from /Users/ashishupadhyay/Desktop/E Club Secy/Decison Lab/eeg_data/0_20170726_102501.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 356599  =      0.000 ...  1426.396 secs...
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 50 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 50.00 Hz
- Upper transition bandwidth: 12.50 Hz (-6 dB cutoff frequency: 56.25 Hz)
- Filter length: 825 samples (3.300 s)

NOTE: plot_psd() is 

  raw.plot_psd(fmin=vis_freq_min, fmax=vis_freq_max)
  raw.plot_psd(fmin=vis_freq_min, fmax=vis_freq_max)


In [11]:
# Standardise the data with scikit-learn's StandardScaler.
# All leading axes are collapsed so the scaler sees a 2-D matrix whose columns are
# the entries of the last axis; the result is folded back into the original shape.
# NOTE(review): the scaler is fitted on the full dataset before the train/test
# split, so test data influences the scaling statistics. The cell also overwrites
# all_data, so running it twice rescales already-scaled values.
scaler = StandardScaler()
original_shape = all_data.shape
flattened = all_data.reshape(-1, original_shape[-1])
all_data = scaler.fit_transform(flattened).reshape(original_shape)

In [12]:
# Shuffle samples and labels with one shared permutation so they stay aligned.
# A seeded generator replaces the unseeded global np.random.shuffle: without a
# seed the input order to train_test_split changed on every run, so the split was
# not reproducible even though it uses random_state=42.
shuffle_rng = np.random.default_rng(42)
indices = shuffle_rng.permutation(all_data.shape[0])

all_data = all_data[indices]
all_labels = all_labels[indices]

In [13]:
# Hold out 20% of the samples for testing; random_state fixes the split for a
# given input order.
X_train, X_test, y_train, y_test = train_test_split(
    all_data,
    all_labels,
    test_size=0.2,
    random_state=42,
)

In [14]:
def _to_tensor_dataset(features, targets):
    """Wrap a feature/label pair as a TensorDataset (float32 inputs, int64 labels)."""
    return TensorDataset(
        torch.tensor(features, dtype=torch.float32),
        torch.tensor(targets, dtype=torch.long),
    )


train_dataset = _to_tensor_dataset(X_train, y_train)
test_dataset = _to_tensor_dataset(X_test, y_test)

# Mini-batches of 32; only the training loader reshuffles between epochs.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [4]:
class EEGNet(nn.Module):
    """Small convolutional classifier for multi-channel EEG epochs.

    Expected input: a float tensor of shape ``(batch, 1, chans, samples)``.
    The forward pass returns raw class scores (logits) of shape
    ``(batch, num_classes)``; no softmax is applied.

    Parameters
    ----------
    num_classes : int
        Number of output classes.
    chans : int
        Number of EEG channels (height of the input). The depthwise convolution
        uses a ``(chans, 1)`` kernel, so the channel axis collapses to 1 only when
        the input height equals ``chans``.
    samples : int
        Number of time samples per epoch (width of the input).

    Raises
    ------
    ValueError
        If ``samples`` is so small that nothing is left on the time axis after the
        two pooling stages.
    """

    def __init__(self, num_classes=2, chans=4, samples=128):
        super().__init__()

        # First Conv2D layer (temporal convolution along the time axis)
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1, 64), padding=(0, 32), bias=False)
        self.batchnorm1 = nn.BatchNorm2d(16)

        # Depthwise convolution across the channel axis (2 filters per input map)
        self.depthwiseConv = nn.Conv2d(16, 32, kernel_size=(chans, 1), groups=16, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(32)
        self.elu = nn.ELU()
        self.pool1 = nn.AvgPool2d(kernel_size=(1, 4))
        self.dropout1 = nn.Dropout(0.25)

        # Second temporal convolution. Despite the attribute name this is a
        # regular (non-grouped) Conv2d, not a depthwise-separable one.
        self.separableConv = nn.Conv2d(32, 32, kernel_size=(1, 16), padding=(0, 8), bias=False)
        self.batchnorm3 = nn.BatchNorm2d(32)
        self.pool2 = nn.AvgPool2d(kernel_size=(1, 8))
        self.dropout2 = nn.Dropout(0.25)

        # Length of the time axis after each stage. Both convolutions use an even
        # kernel with padding = kernel // 2, so each one lengthens the axis by 1;
        # the pools floor-divide by their kernel width. The former shortcut
        # `samples // 32` disagrees with this for many lengths (e.g. 124, 252),
        # which made fc1 reject the flattened features at run time.
        time_after_pool1 = (samples + 1) // 4
        time_after_pool2 = (time_after_pool1 + 1) // 8
        if time_after_pool2 < 1:
            raise ValueError(
                f"samples={samples} is too short: the time axis is empty after pooling."
            )

        # Fully Connected Layer
        self.fc1 = nn.Linear(32 * time_after_pool2, num_classes)

    def forward(self, x):
        """Map a ``(batch, 1, chans, samples)`` tensor to ``(batch, num_classes)`` logits."""
        # First convolutional block (temporal, then depthwise across channels)
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.depthwiseConv(x)
        x = self.batchnorm2(x)
        x = self.elu(x)
        x = self.pool1(x)
        x = self.dropout1(x)

        # Second convolutional block
        x = self.separableConv(x)
        x = self.batchnorm3(x)
        x = self.elu(x)
        x = self.pool2(x)
        x = self.dropout2(x)

        # Flatten and fully connected layer
        x = x.view(x.size(0), -1)
        x = self.fc1(x)

        return x

# Define model, loss function, and optimizer.
# Seed torch first so the initial weights are the same on every run.
torch.manual_seed(42)

# Take the channel count and epoch length from the training data instead of
# hard-coding chans=4: the training loop feeds X_train batches through
# unsqueeze(1), so axis 1 is the network's channel (height) axis and axis 2 its
# time (width) axis.
model = EEGNet(num_classes=2, chans=X_train.shape[1], samples=X_train.shape[2])
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [16]:
# Choose the compute device: CUDA when torch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move the model's parameters to that device (the returned module is the cell output).
model.to(device)

EEGNet(
  (conv1): Conv2d(1, 16, kernel_size=(1, 64), stride=(1, 1), padding=(0, 32), bias=False)
  (batchnorm1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (depthwiseConv): Conv2d(16, 32, kernel_size=(4, 1), stride=(1, 1), groups=16, bias=False)
  (batchnorm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (elu): ELU(alpha=1.0)
  (pool1): AvgPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0)
  (dropout1): Dropout(p=0.25, inplace=False)
  (separableConv): Conv2d(32, 32, kernel_size=(1, 16), stride=(1, 1), padding=(0, 8), bias=False)
  (batchnorm3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): AvgPool2d(kernel_size=(1, 8), stride=(1, 8), padding=0)
  (dropout2): Dropout(p=0.25, inplace=False)
  (fc1): Linear(in_features=224, out_features=2, bias=True)
)

In [49]:
num_epochs = 20  # Number of full passes over the training set

for epoch in range(num_epochs):
    # Training mode: dropout active, batch-norm statistics updated
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass; unsqueeze(1) inserts the single input-channel axis Conv2d expects
        batch_with_channel = inputs.unsqueeze(1)
        outputs = model(batch_with_channel)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Accumulate the batch-mean loss for the epoch report
        running_loss += loss.item()

    # Report the mean of the per-batch losses for this epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

print("Training complete!")

In [50]:
# Evaluate the trained model on the held-out test set.
# Evaluation mode: dropout disabled, batch-norm uses its running statistics.
model.eval()

# Counters for correct predictions and total samples
correct = 0
total = 0

# No gradients are needed for evaluation
with torch.no_grad():
    for inputs, labels in test_loader:
        # Move inputs and labels to the appropriate device (CPU or GPU)
        inputs, labels = inputs.to(device), labels.to(device)

        # Add the single input-channel axis expected by the first Conv2d
        outputs = model(inputs.unsqueeze(1))

        # Predicted class = index of the largest logit (the model emits raw
        # scores, not log-probabilities). `.data` is unnecessary under no_grad.
        predicted = outputs.argmax(dim=1)

        # Update the running totals
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Print the predicted and true labels of this batch for inspection
        print(f"Predicted: {predicted.cpu().numpy()}")
        print(f"Labels: {labels.cpu().numpy()}")

# An empty test loader used to surface as a bare ZeroDivisionError below.
if total == 0:
    raise RuntimeError("test_loader yielded no samples; accuracy cannot be computed.")

# Calculate and print the accuracy as a percentage
accuracy = 100 * correct / total
print(f'Accuracy: {accuracy}%')
