In [103]:
import os
import mne
from mne.preprocessing import ICA
import numpy as np
from mne.time_frequency import tfr_multitaper
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
import torch
import torch.optim as optim
import torch.nn as nn

In [119]:
# Channels kept from every recording (picked inside load_and_extract_intervals).
chan_IDs = ['FP1', 'FP2', 'F7', 'F8']

In [120]:
# Visualisation settings for the preprocessing pipeline.
# NOTE(review): the same four settings are assigned again, with identical
# values, at the top of the next cell.

# When True, the power spectrum of each filtered recording is plotted.
pipeline_visualizations_semiautomated = False

# Frequency range handed to the power-spectrum plot as fmin / fmax.
vis_freq_min, vis_freq_max = 2, 57

# Not read by any code visible in this notebook.
freq_to_plot = [6, 10, 20, 30, 55]

In [121]:
# --- Visualisation -----------------------------------------------------------
# When True, the power spectrum of each filtered recording is plotted.
pipeline_visualizations_semiautomated = False
# Frequency range handed to the power-spectrum plot as fmin / fmax.
vis_freq_min, vis_freq_max = 2, 57
freq_to_plot = [6, 10, 20, 30, 55]

# --- Pipeline switches (read as globals by load_and_extract_intervals) --------
# Task-related EEG is not supported: the loader raises NotImplementedError
# when this is True.
task_EEG_processing = False
# Enables the epoching stage of the loader.
segment_data = True
segment_length = 2  # presumably seconds -- TODO confirm
# The next two flags are not read by any code visible in this notebook.
segment_interpolation = True
segment_rejection = False
# When True, an average EEG reference is added to the recording.
average_rereference = True

In [122]:
# Folder holding the EDF recordings, relative to the notebook's working directory.
src_folder_name = 'eeg_data'

# os.listdir returns entries in arbitrary order, so sort them: the order of the
# files decides the order of the samples, and with it every later shuffle/split.
# The extension is matched case-insensitively so that '.EDF' files are not missed.
file_names = sorted(
    f for f in os.listdir(src_folder_name) if f.lower().endswith('.edf')
)

In [123]:
# Per-file bookkeeping: one list per metric, appended to once per processed file.
_METRIC_NAMES = (
    "Number_ICs_Rejected",
    "File_Length_In_Secs",
    "Percent_Variance_Kept_of_Post_Waveleted_Data",
)
metrics = {metric_name: [] for metric_name in _METRIC_NAMES}

In [125]:
def _extract_interval_epochs(raw, sample_interval, duration, overlap, file_path):
    """Crop ``raw`` to one interval and return its fixed-length epochs as an array."""
    sfreq = raw.info['sfreq']
    max_time = raw.times[-1]

    # The bounds are divided by the sampling frequency to obtain seconds and
    # then clamped to the recording.
    # NOTE(review): this treats the bounds as sample indices -- confirm they
    # are not already expressed in seconds.
    tmin = max(0, sample_interval[0] / sfreq)
    tmax = min(max_time, sample_interval[1] / sfreq)
    if tmin >= tmax:
        raise ValueError(
            f"Interval {tuple(sample_interval)} (samples) lies outside the "
            f"recording {file_path!r}, which ends at {max_time:.3f} s."
        )

    segment = raw.copy().crop(tmin=tmin, tmax=tmax)
    epochs = mne.make_fixed_length_epochs(segment, duration=duration, overlap=overlap)
    return epochs.get_data()


def load_and_extract_intervals(file_path, no_stress_samples=(300, 600),
                               stressed_samples=(1500, 1860),
                               epoch_duration=1, epoch_overlap=0.5):
    """Load one EDF recording, preprocess it and epoch two fixed intervals.

    The recording is restricted to ``chan_IDs``, band-pass filtered (1-50 Hz),
    optionally average-referenced, and the two intervals are cut into
    fixed-length, overlapping epochs.

    Reads the notebook-level settings ``chan_IDs``, ``segment_data``,
    ``task_EEG_processing``, ``average_rereference``,
    ``pipeline_visualizations_semiautomated``, ``vis_freq_min`` and
    ``vis_freq_max``, and appends one entry per call to
    ``metrics["Number_ICs_Rejected"]`` and ``metrics["File_Length_In_Secs"]``.

    Parameters
    ----------
    file_path : str
        Path of the ``.edf`` file to read.
    no_stress_samples, stressed_samples : tuple of (start, stop)
        Interval bounds; each bound is divided by the sampling frequency to
        obtain seconds and clamped to the recording.
    epoch_duration : float
        Epoch length passed to ``mne.make_fixed_length_epochs``.
    epoch_overlap : float
        Overlap between consecutive epochs, passed to the same call.

    Returns
    -------
    tuple of (array, array)
        ``(no_stress_epochs, stressed_epochs)`` as returned by
        ``Epochs.get_data()``.

    Raises
    ------
    ValueError
        If ``segment_data`` is False, or an interval lies outside the recording.
    NotImplementedError
        If ``task_EEG_processing`` is True.
    """
    # Fail fast, before any expensive I/O: without segmentation there is
    # nothing to return, and task-based epoching is not implemented.
    if not segment_data:
        raise ValueError(
            "segment_data must be True: load_and_extract_intervals returns epoched data."
        )
    if task_EEG_processing:
        raise NotImplementedError("Task EEG not implemented yet.")

    # Load the recording into memory and keep only the configured channels
    # (pick_channels modifies the Raw object in place).
    raw = mne.io.read_raw_edf(file_path, preload=True)
    raw.pick_channels(chan_IDs)

    # Band-pass filter: 1 Hz high-pass, 50 Hz low-pass.
    raw.filter(1, 50, fir_design='firwin')

    # Plot the power spectrum if visualization is enabled
    if pipeline_visualizations_semiautomated:
        raw.plot_psd(fmin=vis_freq_min, fmax=vis_freq_max)

    # Fit an ICA decomposition.
    # NOTE(review): no component is marked for exclusion and the ICA is never
    # applied to ``raw`` in this function, so it does not alter the returned
    # data and the "Number_ICs_Rejected" metric records the default exclusion list.
    ica = ICA(n_components=2, random_state=97, max_iter=800)
    ica.fit(raw)

    # Re-reference data to average if specified
    if average_rereference:
        raw.set_eeg_reference('average', projection=True)

    array_no_stressed = _extract_interval_epochs(
        raw, no_stress_samples, epoch_duration, epoch_overlap, file_path
    )
    array_stressed = _extract_interval_epochs(
        raw, stressed_samples, epoch_duration, epoch_overlap, file_path
    )

    # Record the per-file metrics only once the file was processed successfully.
    metrics["Number_ICs_Rejected"].append(len(ica.exclude))
    metrics["File_Length_In_Secs"].append(raw.times[-1])
    return array_no_stressed, array_stressed

In [136]:
# Initialize lists to hold data and labels
# Build the dataset: every epoch of every recording becomes one labelled sample
# (label 1 = stressed interval, label 0 = no-stress interval).
if not file_names:
    raise ValueError(f"No .edf files found in {src_folder_name!r}")

epoch_chunks = []
label_chunks = []

for file in file_names:
    file_path = os.path.join(src_folder_name, file)
    no_stressed, stressed = load_and_extract_intervals(file_path)
    # Keep all epochs of both intervals; indexing with [0] would discard
    # everything except the first epoch of each interval.
    epoch_chunks.append(stressed)
    label_chunks.append(np.ones(len(stressed), dtype=np.int64))
    epoch_chunks.append(no_stressed)
    label_chunks.append(np.zeros(len(no_stressed), dtype=np.int64))

# NOTE(review): the loader's default epoching uses overlapping windows, so
# neighbouring epochs of one recording share samples; a purely random
# train/test split can place such neighbours on both sides of the split.
all_data = np.concatenate(epoch_chunks)
all_labels = np.concatenate(label_chunks)
assert len(all_data) == len(all_labels), "every epoch needs exactly one label"

In [127]:
# Standardise the data: collapse every leading axis so the last axis becomes
# the feature axis, fit/transform, then restore the original shape.
# NOTE(review): the scaler is fitted on the full dataset before the train/test
# split, so test-set statistics influence the scaling -- consider fitting on
# the training portion only.
scaler = StandardScaler()
data_shape = all_data.shape
flat_view = all_data.reshape(-1, data_shape[-1])
all_data = scaler.fit_transform(flat_view).reshape(data_shape)

In [128]:
# Shuffle samples and labels with one shared permutation so each sample keeps
# its label. A seeded generator makes the order reproducible; an unseeded
# shuffle here would also make the later train/test split differ on every run,
# despite its fixed random_state.
shuffle_rng = np.random.default_rng(42)
indices = shuffle_rng.permutation(all_data.shape[0])

all_data = all_data[indices]
all_labels = all_labels[indices]

In [129]:
# Hold out 20% of the samples for evaluation; the fixed random_state pins the
# split for a given input order.
X_train, X_test, y_train, y_test = train_test_split(
    all_data,
    all_labels,
    test_size=0.2,
    random_state=42,
)

In [130]:
def _as_tensor_dataset(features, targets):
    """Wrap features as float32 tensors and targets as int64 tensors."""
    return TensorDataset(
        torch.tensor(features, dtype=torch.float32),
        torch.tensor(targets, dtype=torch.long),
    )


train_dataset = _as_tensor_dataset(X_train, y_train)
test_dataset = _as_tensor_dataset(X_test, y_test)

# Only the training batches are reshuffled; evaluation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [131]:
class EEGNet(nn.Module):
    """Small 1-D CNN: Conv1d -> ReLU -> MaxPool1d -> Linear -> ReLU -> Dropout -> Linear.

    Parameters
    ----------
    n_channels : int, default 4
        Number of input channels (second axis of the input batch).
    n_samples : int or None, default None
        Length of the last axis of the input batch. When None, it is read from
        the notebook-level ``X_train`` (``X_train.shape[2]``), which keeps
        ``EEGNet()`` working as before.
    n_classes : int, default 2
        Number of output logits.

    The input to ``forward`` is a batch of shape
    ``(batch, n_channels, n_samples)``; the output has shape
    ``(batch, n_classes)``.
    """

    def __init__(self, n_channels=4, n_samples=None, n_classes=2):
        super().__init__()
        if n_samples is None:
            # Backward-compatible default: size the network from the training data.
            n_samples = X_train.shape[2]

        # The convolution preserves the length (kernel 3, padding 1) and the
        # pooling halves it (floor division), so this is the flattened size
        # fed to the first linear layer.
        self.flat_features = 16 * (n_samples // 2)

        self.conv1 = nn.Conv1d(in_channels=n_channels, out_channels=16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(self.flat_features, 64)
        self.fc2 = nn.Linear(64, n_classes)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the class logits for a batch ``x``."""
        x = self.pool(self.relu(self.conv1(x)))
        # Flatten per sample. Unlike ``view(-1, F)``, this never re-batches an
        # input of unexpected length: a mismatch makes ``fc1`` raise instead.
        x = x.flatten(start_dim=1)
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

# Seed PyTorch's global random number generator before any weights are created,
# so that weight initialisation and dropout are reproducible across
# Restart-and-Run-All.
torch.manual_seed(42)

model = EEGNet()
criterion = nn.CrossEntropyLoss()  # takes raw logits and integer class labels
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [135]:
num_epochs = 20

for epoch in range(num_epochs):
    model.train()  # training mode: dropout active
    batch_losses = []

    for inputs, labels in train_loader:
        # Standard optimisation step: reset gradients, forward, backward, update.
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    # Report the mean loss per batch for this epoch.
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {sum(batch_losses)/len(train_loader)}')

In [134]:
model.eval()  # evaluation mode: dropout disabled

# Collect predictions and true labels batch by batch; gradients are not needed.
predicted_batches = []
label_batches = []
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        # The predicted class is the index of the largest logit.
        predicted_batches.append(outputs.argmax(dim=1))
        label_batches.append(labels)

total = sum(batch.size(0) for batch in label_batches)
if total == 0:
    # Avoid a ZeroDivisionError when the test split is empty.
    correct = 0
    print('Accuracy: n/a (test set is empty)')
else:
    predicted = torch.cat(predicted_batches)
    labels = torch.cat(label_batches)
    correct = (predicted == labels).sum().item()
    # Show predictions next to the true labels once for the whole test set,
    # so a model that always predicts one class is easy to spot.
    print(predicted)
    print(labels)
    print(f'Accuracy: {100 * correct / total}%')