In [1]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import mne
import numpy as np

## 1. Custom Pytorch dataset

In [6]:
class NeuronAgeDataset(Dataset):
    def __init__(self, csv_file, data_dir, segment_duration=5, overlap=0.5):
        self.df = pd.read_csv(csv_file, sep='\t')
        self.data_dir = data_dir
        self.segment_duration = segment_duration
        self.overlap = overlap
        self.sfreq = None
        self.segments = []
        self.labels = []
        self._prepare_data()

    def _prepare_data(self):
        for idx, row in self.df.iterrows():
            age = row['age']
            rawdata_path_prefix = row['participant_id']
            rawdata_path = os.path.join(self.data_dir, f'{rawdata_path_prefix}_sflip_parc-raw.fif')
            try:
                rawdata = mne.io.read_raw_fif(rawdata_path, preload=True)
            except Exception as e:
                print(f"An error occurred while loading the file: {e}")
                continue
            
            if self.sfreq is None:
                self.sfreq = rawdata.info['sfreq']
            
            bad_segments = []
            for annot in rawdata.annotations:
                if "bad_segment" in annot['description']:
                    start = annot['onset']
                    end = start + annot['duration']
                    bad_segments.append((start, end))
            
            segment_samples = int(self.segment_duration * self.sfreq) # number of timesteps in a segment
            overlap_samples = int(self.overlap * segment_samples) # number of overlapping timesteps between segments
            total_samples = len(rawdata.times) # total number of timesteps in the raw data
            print(f'segment samples: {segment_samples}')
            print(f"Total samples: {total_samples}")


            start_sample = 0
            while start_sample + segment_samples <= total_samples:
                end_sample = start_sample + segment_samples
                segment_times = rawdata.times[start_sample:end_sample]
                
                if not any(start <= t <= end for t in segment_times for start, end in bad_segments):
                    data, _ = rawdata[:52, start_sample:end_sample]
                    self.segments.append(data)
                    self.labels.append(age)
                
                start_sample += segment_samples - overlap_samples

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        segment = self.segments[idx]
        label = self.labels[idx]
        return torch.tensor(segment, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)

## 2. EEG net

In [3]:
class EEGNet(nn.Module):
    def __init__(self):
        super(EEGNet, self).__init__()
        self.T = 120
        
        # Layer 1
        self.conv1 = nn.Conv2d(1, 16, (1, 52), padding = 0)
        self.batchnorm1 = nn.BatchNorm2d(16, False)
        
        # Layer 2
        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))
        self.conv2 = nn.Conv2d(1, 4, (2, 32))
        self.batchnorm2 = nn.BatchNorm2d(4, False)
        self.pooling2 = nn.MaxPool2d(2, 4)
        
        # Layer 3
        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))
        self.conv3 = nn.Conv2d(4, 4, (8, 4))
        self.batchnorm3 = nn.BatchNorm2d(4, False)
        self.pooling3 = nn.MaxPool2d((2, 4))
        
        # FC Layer
        # NOTE: This dimension will depend on the number of timestamps per sample in your data.
        # I have 120 timepoints. 
        self.fc1 = nn.Linear(376, 1)
        

    def forward(self, x):
        """
        (batch_size, 1, num_timepoints, num_channels) -> (batch_size, 1)
        """
        batch_size = x.size(0)
        # Layer 1
        x = F.elu(self.conv1(x)) # -> (batch_size, 16, num_timepoints, 1)
        x = self.batchnorm1(x) # -> (batch_size, 16, num_timepoints, 1)
        x = F.dropout(x, 0.25) # -> (batch_size, 16, num_timepoints, 1)
        x = x.permute(0, 3, 1, 2) # -> (batch_size, 1, num_timepoints, 16)
        
        # Layer 2
        x = self.padding1(x) # -> (batch_size, 1, num_timepoints, 33)
        x = F.elu(self.conv2(x)) # -> (batch_size, 4, num_timepoints, 1)
        x = self.batchnorm2(x) # -> (batch_size, 4, num_timepoints, 1)
        x = F.dropout(x, 0.25) # -> (batch_size, 4, num_timepoints, 1)
        x = self.pooling2(x) # -> (batch_size, 4, num_timepoints//2, 1)
        
        # Layer 3
        x = self.padding2(x)
        x = F.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = F.dropout(x, 0.25)
        x = self.pooling3(x)

        # FC Layer
        x = x.reshape(batch_size, 376)
        x = F.sigmoid(self.fc1(x))
        return x # -> (batch_size, 1)

## 3. Run baseline

In [7]:
data_dir = 'data/raw'
train_csv_file = 'data/train_set.csv'
test_csv_file = 'data/test_set.csv'
train_dataset = NeuronAgeDataset(train_csv_file, data_dir, segment_duration=3, overlap=0.5)
test_dataset = NeuronAgeDataset(test_csv_file, data_dir, segment_duration=3, overlap=0.5)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=True)

Opening raw data file data/raw/sub-CC320089_sflip_parc-raw.fif...
    Range : 23000 ... 163499 =     92.000 ...   653.996 secs
Ready.
Reading 0 ... 140499  =      0.000 ...   561.996 secs...


segment samples: 750
Total samples: 140500
Opening raw data file data/raw/sub-CC520209_sflip_parc-raw.fif...
    Range : 37750 ... 178249 =    151.000 ...   712.996 secs
Ready.
Reading 0 ... 140499  =      0.000 ...   561.996 secs...
segment samples: 750
Total samples: 140500
Opening raw data file data/raw/sub-CC110045_sflip_parc-raw.fif...
    Range : 24000 ... 164499 =     96.000 ...   657.996 secs
Ready.
Reading 0 ... 140499  =      0.000 ...   561.996 secs...
segment samples: 750
Total samples: 140500
Opening raw data file data/raw/sub-CC610052_sflip_parc-raw.fif...
    Range : 23000 ... 163999 =     92.000 ...   655.996 secs
Ready.
Reading 0 ... 140999  =      0.000 ...   563.996 secs...
segment samples: 750
Total samples: 141000
Opening raw data file data/raw/sub-CC221031_sflip_parc-raw.fif...
    Range : 4500 ... 145249 =     18.000 ...   580.996 secs
Ready.
Reading 0 ... 140749  =      0.000 ...   562.996 secs...
segment samples: 750
Total samples: 140750
Opening raw data file 

  rawdata = mne.io.read_raw_fif(rawdata_path, preload=True)


segment samples: 750
Total samples: 141000
Opening raw data file data/raw/sub-CC420198_sflip_parc-raw.fif...
    Range : 29750 ... 170249 =    119.000 ...   680.996 secs
Ready.
Reading 0 ... 140499  =      0.000 ...   561.996 secs...
segment samples: 750
Total samples: 140500
Opening raw data file data/raw/sub-CC721891_sflip_parc-raw.fif...
    Range : 14250 ... 155249 =     57.000 ...   620.996 secs
Ready.
Reading 0 ... 140999  =      0.000 ...   563.996 secs...
segment samples: 750
Total samples: 141000
Opening raw data file data/raw/sub-CC721704_sflip_parc-raw.fif...
    Range : 14500 ... 156499 =     58.000 ...   625.996 secs
Ready.
Reading 0 ... 141999  =      0.000 ...   567.996 secs...
segment samples: 750
Total samples: 142000
Opening raw data file data/raw/sub-CC620619_sflip_parc-raw.fif...
    Range : 17000 ... 157249 =     68.000 ...   628.996 secs
Ready.
Reading 0 ... 140249  =      0.000 ...   560.996 secs...
segment samples: 750
Total samples: 140250
Opening raw data file

In [11]:
num_epochs = 20
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EEGNet().to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train_losses = []
validation_losses = []

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0    
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs = inputs.unsqueeze(1).permute(0, 1, 3, 2).to(device)
        labels = labels.unsqueeze(1).to(device)
    
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    model.eval()
    correct = 0
    total = 0
    validation_loss = 0.0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            inputs = inputs.unsqueeze(1).permute(0, 1, 3, 2).to(device)
            labels = labels.unsqueeze(1).to(device)


            outputs = model(inputs)
            loss = criterion(outputs, labels)
            validation_loss += loss.item()
            total += labels.size(0)

    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss/len(train_loader)}, Validation Loss: {validation_loss/len(test_loader)}')
    train_losses.append(train_loss/len(train_loader))
    validation_losses.append(validation_loss/len(test_loader))

    model.train()

print('Finished Training')

Epoch 1/20, Training Loss: 50.06726532408645, Validation Loss: 49.05195320977105
Epoch 2/20, Training Loss: 50.164938921368986, Validation Loss: 49.05733557807075
Epoch 3/20, Training Loss: 50.14719420704762, Validation Loss: 49.05204806857639
Epoch 4/20, Training Loss: 50.11443326726306, Validation Loss: 49.05568449232313
Epoch 5/20, Training Loss: 50.128919590784854, Validation Loss: 49.06312950981988
Epoch 6/20, Training Loss: 50.110126559294805, Validation Loss: 49.06112255520291
Epoch 7/20, Training Loss: 50.0897216796875, Validation Loss: 49.06508517795139
Epoch 8/20, Training Loss: 50.094015153426696, Validation Loss: 49.05911560058594
Epoch 9/20, Training Loss: 50.170245740666736, Validation Loss: 49.0503414577908
Epoch 10/20, Training Loss: 50.09830764685263, Validation Loss: 49.056429375542535
Epoch 11/20, Training Loss: 50.13641941347602, Validation Loss: 49.04831127590603
Epoch 12/20, Training Loss: 50.06662920733404, Validation Loss: 49.05882746378581
Epoch 13/20, Training