# EEGNet

In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import KFold
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
import pickle
import numpy as np

# Define constants
DATA_DIR = 'C:/Users/User/Documents/Lie detect data/EEGData'
BATCH_SIZE = 32
EPOCHS = 100  # Increased to allow early stopping
LEARNING_RATE = 0.001
NUM_FOLDS = 5
PATIENCE = 20  # Early stopping patience

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Custom Dataset class for EEG data
class EEGDataset(Dataset):
    def __init__(self, data_dir):
        self.data = []
        self.labels = []
        self.load_data(data_dir)
        self.normalize_data()

    def load_data(self, data_dir):
        max_length = 0
        temp_data = []
        
        for file_name in os.listdir(data_dir):
            file_path = os.path.join(data_dir, file_name)
            with open(file_path, 'rb') as f:
                eeg_data = pickle.load(f)
                label = 1 if 'lie' in file_name else 0  # Assuming file names contain 'lie' or 'truth'
                temp_data.append((eeg_data, label))
                max_length = max(max_length, eeg_data.shape[1])

        for eeg_data, label in temp_data:
            padded_data = np.pad(eeg_data, ((0, 0), (0, max_length - eeg_data.shape[1])), mode='constant')
            self.data.append(padded_data)
            self.labels.append(label)
        
        self.data = [torch.tensor(d, dtype=torch.float32, device=device) for d in self.data]
        self.labels = torch.tensor(self.labels, dtype=torch.long, device=device)
    
    def normalize_data(self):
        all_data = torch.cat([d.unsqueeze(0) for d in self.data], dim=0)
        mean = all_data.mean()
        std = all_data.std()
        self.data = [(d - mean) / std for d in self.data]

    def augment_data(self, data):
        # Advanced augmentations: Gaussian noise, time shift, scaling
        noise = torch.randn_like(data, device=device) * 0.01
        shift = torch.roll(data, shifts=int(data.shape[1] * 0.1), dims=1)
        scale = data * (1 + 0.1 * torch.randn(1, device=device))
        return noise + shift + scale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data, label = self.data[idx], self.labels[idx]
        data = self.augment_data(data)  # Apply augmentation
        return data, label

# Define the EEGNet model
class EEGNet(nn.Module):
    def __init__(self, output_size):
        super(EEGNet, self).__init__()
        self.firstconv = nn.Sequential(
            nn.Conv2d(1, 16, (1, 51), padding=(0, 25)),
            nn.BatchNorm2d(16)
        )
        self.depthwiseConv = nn.Sequential(
            nn.Conv2d(16, 32, (65, 1), groups=16),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(0.5)
        )
        self.separableConv = nn.Sequential(
            nn.Conv2d(32, 32, (1, 15), padding=(0, 7)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(0.5)
        )
        self.classify = nn.Sequential(
            nn.Flatten(),
            nn.Linear(output_size, 2)
        )

    def forward(self, x):
        x = self.firstconv(x)
        x = self.depthwiseConv(x)
        x = self.separableConv(x)
        return self.classify(x)

# Function to determine the output size of the EEGNet model before the linear layer
def get_output_size(model, shape):
    with torch.no_grad():
        x = torch.zeros(shape, device=device)
        x = model.firstconv(x)
        x = model.depthwiseConv(x)
        x = model.separableConv(x)
        return x.view(x.size(0), -1).size(1)

# Load data
dataset = EEGDataset(DATA_DIR)

# Determine the correct input size for the linear layer
dummy_input_shape = (1, 1, 65, max([d.shape[1] for d in dataset.data]))  # (batch_size, channels, height, width)
output_size = get_output_size(EEGNet(output_size=0).to(device), dummy_input_shape)

# K-Fold Cross Validation
kf = KFold(n_splits=NUM_FOLDS, shuffle=True)

accuracies = []
precisions = []
recalls = []
f1s = []
aucs = []
val_losses = []  


for fold, (train_index, val_index) in enumerate(kf.split(dataset)):
    print(f'Fold {fold+1}')
    
    train_subset = Subset(dataset, train_index)
    val_subset = Subset(dataset, val_index)

    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize model, loss function, and optimizer
    model = EEGNet(output_size=output_size).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    best_val_accuracy = 0
    patience_counter = 0

    # Training loop
    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs.unsqueeze(1))  # Adding channel dimension
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        scheduler.step()

        # Intermediate validation
        model.eval()
        correct = 0
        total = 0
        all_labels = []
        all_predictions = []
        all_probs = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs.unsqueeze(1))  # Adding channel dimension
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                all_labels.extend(labels.cpu().numpy())
                all_predictions.extend(predicted.cpu().numpy())
                all_probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())  # Get the probability of class 1

        val_accuracy = 100 * correct / total
        val_precision = precision_score(all_labels, all_predictions, average='macro')
        val_recall = recall_score(all_labels, all_predictions, average='macro')
        val_f1 = f1_score(all_labels, all_predictions, average='macro')
        val_auc = roc_auc_score(all_labels, all_probs)  # Calculate AUC score
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {running_loss/len(train_loader)}, Validation Accuracy: {val_accuracy:.2f}%, AUC: {val_auc:.2f}")

        # Early stopping
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Final validation metrics for the fold
    accuracies.append(best_val_accuracy)
    precisions.append(val_precision)
    recalls.append(val_recall)
    f1s.append(val_f1)
    aucs.append(val_auc)
    print(f'Final Accuracy for fold {fold+1}: {best_val_accuracy:.2f}%')
    print(f'Final Precision for fold {fold+1}: {val_precision:.2f}')
    print(f'Final Recall for fold {fold+1}: {val_recall:.2f}')
    print(f'Final F1-Score for fold {fold+1}: {val_f1:.2f}')
    print(f'Final AUC for fold {fold+1}: {val_auc:.2f}\n')

# Report average performance across all folds
print(f'Average Accuracy: {np.mean(accuracies):.2f}%')
print(f'Average Precision: {np.mean(precisions):.2f}')
print(f'Average Recall: {np.mean(recalls):.2f}')
print(f'Average F1-Score: {np.mean(f1s):.2f}')
print(f'Average AUC: {np.mean(aucs):.2f}')


Using device: cuda
Fold 1


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 1/100, Loss: 0.5831657946109772, Validation Accuracy: 50.00%, AUC: 0.46
Epoch 2/100, Loss: 0.5716206034024557, Validation Accuracy: 50.00%, AUC: 0.63
Epoch 3/100, Loss: 0.4606303970019023, Validation Accuracy: 44.44%, AUC: 0.63
Epoch 4/100, Loss: 0.42014991243680316, Validation Accuracy: 72.22%, AUC: 0.73
Epoch 5/100, Loss: 0.39527322848637897, Validation Accuracy: 77.78%, AUC: 0.78
Epoch 6/100, Loss: 0.41237648328145343, Validation Accuracy: 83.33%, AUC: 0.81
Epoch 7/100, Loss: 0.3233715345462163, Validation Accuracy: 88.89%, AUC: 0.83
Epoch 8/100, Loss: 0.2921660716334979, Validation Accuracy: 88.89%, AUC: 0.84
Epoch 9/100, Loss: 0.2981623113155365, Validation Accuracy: 88.89%, AUC: 0.85
Epoch 10/100, Loss: 0.29900231460730237, Validation Accuracy: 83.33%, AUC: 0.85
Epoch 11/100, Loss: 0.2427449276049932, Validation Accuracy: 83.33%, AUC: 0.84
Epoch 12/100, Loss: 0.250828559199969, Validation Accuracy: 83.33%, AUC: 0.83
Epoch 13/100, Loss: 0.21404471496740976, Validation Accura

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 3/100, Loss: 0.46156957745552063, Validation Accuracy: 88.89%, AUC: 1.00
Epoch 4/100, Loss: 0.41970967253049213, Validation Accuracy: 94.44%, AUC: 1.00
Epoch 5/100, Loss: 0.33507755398750305, Validation Accuracy: 83.33%, AUC: 1.00
Epoch 6/100, Loss: 0.586733857790629, Validation Accuracy: 88.89%, AUC: 1.00
Epoch 7/100, Loss: 0.2819896588722865, Validation Accuracy: 100.00%, AUC: 1.00
Epoch 8/100, Loss: 0.2882344424724579, Validation Accuracy: 100.00%, AUC: 1.00
Epoch 9/100, Loss: 0.34858495990435284, Validation Accuracy: 100.00%, AUC: 1.00
Epoch 10/100, Loss: 0.2477224717537562, Validation Accuracy: 94.44%, AUC: 1.00
Epoch 11/100, Loss: 0.3191291193167369, Validation Accuracy: 94.44%, AUC: 1.00
Epoch 12/100, Loss: 0.27952421208222705, Validation Accuracy: 100.00%, AUC: 1.00
Epoch 13/100, Loss: 0.241177166501681, Validation Accuracy: 100.00%, AUC: 1.00
Epoch 14/100, Loss: 0.21709581216176352, Validation Accuracy: 100.00%, AUC: 1.00
Epoch 15/100, Loss: 0.23138651251792908, Validati

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 1/100, Loss: 0.7455568512280782, Validation Accuracy: 61.11%, AUC: 0.61
Epoch 2/100, Loss: 0.6485156118869781, Validation Accuracy: 61.11%, AUC: 0.64
Epoch 3/100, Loss: 0.5286488433678945, Validation Accuracy: 61.11%, AUC: 0.77
Epoch 4/100, Loss: 0.4335412383079529, Validation Accuracy: 72.22%, AUC: 0.81
Epoch 5/100, Loss: 0.42256633440653485, Validation Accuracy: 77.78%, AUC: 0.86
Epoch 6/100, Loss: 0.36623485883076984, Validation Accuracy: 77.78%, AUC: 0.86
Epoch 7/100, Loss: 0.34683312972386676, Validation Accuracy: 77.78%, AUC: 0.92
Epoch 8/100, Loss: 0.3000007172425588, Validation Accuracy: 72.22%, AUC: 0.86
Epoch 9/100, Loss: 0.32396040360132855, Validation Accuracy: 77.78%, AUC: 0.87
Epoch 10/100, Loss: 0.2100615600744883, Validation Accuracy: 83.33%, AUC: 0.88
Epoch 11/100, Loss: 0.18650387724240622, Validation Accuracy: 83.33%, AUC: 0.90
Epoch 12/100, Loss: 0.32408642768859863, Validation Accuracy: 83.33%, AUC: 0.87
Epoch 13/100, Loss: 0.21269037326176962, Validation Acc

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 1/100, Loss: 0.8180278738339742, Validation Accuracy: 55.56%, AUC: 0.89
Epoch 2/100, Loss: 0.6523479024569193, Validation Accuracy: 61.11%, AUC: 0.95
Epoch 3/100, Loss: 0.5604995290438334, Validation Accuracy: 83.33%, AUC: 0.94
Epoch 4/100, Loss: 0.5014299750328064, Validation Accuracy: 83.33%, AUC: 0.86
Epoch 5/100, Loss: 0.4648139476776123, Validation Accuracy: 83.33%, AUC: 0.85
Epoch 6/100, Loss: 0.3899699052174886, Validation Accuracy: 83.33%, AUC: 0.86
Epoch 7/100, Loss: 0.32616376380125683, Validation Accuracy: 88.89%, AUC: 0.88
Epoch 8/100, Loss: 0.3362744053204854, Validation Accuracy: 88.89%, AUC: 0.85
Epoch 9/100, Loss: 0.3619592587153117, Validation Accuracy: 88.89%, AUC: 0.85
Epoch 10/100, Loss: 0.3005932966868083, Validation Accuracy: 88.89%, AUC: 0.85
Epoch 11/100, Loss: 0.2872677693764369, Validation Accuracy: 88.89%, AUC: 0.84
Epoch 12/100, Loss: 0.2578426351149877, Validation Accuracy: 88.89%, AUC: 0.85
Epoch 13/100, Loss: 0.31096895039081573, Validation Accuracy