In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import KFold
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
import pickle
import numpy as np

# Define constants
DATA_DIR = 'C:/Users/User/Documents/Lie detect data/EEGData'
BATCH_SIZE = 32
EPOCHS = 100  # Increased to allow early stopping
LEARNING_RATE = 0.001
NUM_FOLDS = 5
PATIENCE = 20  # Early stopping patience

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Custom Dataset class for EEG data
class EEGDataset(Dataset):
    def __init__(self, data_dir):
        self.data = []
        self.labels = []
        self.load_data(data_dir)
        self.normalize_data()

    def load_data(self, data_dir):
        max_length = 0
        temp_data = []
        
        for file_name in os.listdir(data_dir):
            file_path = os.path.join(data_dir, file_name)
            with open(file_path, 'rb') as f:
                eeg_data = pickle.load(f)
                label = 1 if 'lie' in file_name else 0  # Assuming file names contain 'lie' or 'truth'
                temp_data.append((eeg_data, label))
                max_length = max(max_length, eeg_data.shape[1])

        for eeg_data, label in temp_data:
            padded_data = np.pad(eeg_data, ((0, 0), (0, max_length - eeg_data.shape[1])), mode='constant')
            self.data.append(padded_data)
            self.labels.append(label)
        
        self.data = [torch.tensor(d, dtype=torch.float32, device=device) for d in self.data]
        self.labels = torch.tensor(self.labels, dtype=torch.long, device=device)
    
    def normalize_data(self):
        all_data = torch.cat([d.unsqueeze(0) for d in self.data], dim=0)
        mean = all_data.mean()
        std = all_data.std()
        self.data = [(d - mean) / std for d in self.data]

    def augment_data(self, data):
        # Advanced augmentations: Gaussian noise, time shift, scaling
        noise = torch.randn_like(data, device=device) * 0.01
        shift = torch.roll(data, shifts=int(data.shape[1] * 0.1), dims=1)
        scale = data * (1 + 0.1 * torch.randn(1, device=device))
        return noise + shift + scale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data, label = self.data[idx], self.labels[idx]
        data = self.augment_data(data)  # Apply augmentation
        return data, label

# Define the EEGNet model
class EEGNet(nn.Module):
    def __init__(self, output_size):
        super(EEGNet, self).__init__()
        self.firstconv = nn.Sequential(
            nn.Conv2d(1, 16, (1, 51), padding=(0, 25)),
            nn.BatchNorm2d(16)
        )
        self.depthwiseConv = nn.Sequential(
            nn.Conv2d(16, 32, (65, 1), groups=16),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(0.5)
        )
        self.separableConv = nn.Sequential(
            nn.Conv2d(32, 32, (1, 15), padding=(0, 7)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(0.5)
        )
        self.classify = nn.Sequential(
            nn.Flatten(),
            nn.Linear(output_size, 2)
        )

    def forward(self, x):
        x = self.firstconv(x)
        x = self.depthwiseConv(x)
        x = self.separableConv(x)
        return self.classify(x)

# Function to determine the output size of the EEGNet model before the linear layer
def get_output_size(model, shape):
    with torch.no_grad():
        x = torch.zeros(shape, device=device)
        x = model.firstconv(x)
        x = model.depthwiseConv(x)
        x = model.separableConv(x)
        return x.view(x.size(0), -1).size(1)

# Load data
dataset = EEGDataset(DATA_DIR)

# Determine the correct input size for the linear layer
dummy_input_shape = (1, 1, 65, max([d.shape[1] for d in dataset.data]))  # (batch_size, channels, height, width)
output_size = get_output_size(EEGNet(output_size=0).to(device), dummy_input_shape)

# K-Fold Cross Validation
kf = KFold(n_splits=NUM_FOLDS, shuffle=True)

accuracies = []
precisions = []
recalls = []
f1s = []
aucs = []
confusion_matrices = []

for fold, (train_index, val_index) in enumerate(kf.split(dataset)):
    print(f'Fold {fold+1}')

    # Creating train and validation samplers
    train_sampler = SubsetRandomSampler(train_index)
    val_sampler = SubsetRandomSampler(val_index)

    # Creating DataLoaders
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
    val_loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=val_sampler)

    # Calculate class weights based on training data
    train_labels = [dataset[i][1].cpu().numpy() for i in train_index]  # Move to CPU and convert to NumPy array
    class_counts = np.bincount(train_labels)
    class_weights = 1. / class_counts
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    # Initialize model, loss function, and optimizer
    model = EEGNet(output_size=output_size).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

    best_val_accuracy = 0
    patience_counter = 0

    # Training loop
    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs.unsqueeze(1))  # Adding channel dimension
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Scheduler step based on running loss
        scheduler.step(running_loss)

        # Intermediate validation
        model.eval()
        correct = 0
        total = 0
        all_labels = []
        all_predictions = []
        all_probs = []
        val_running_loss = 0.0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs.unsqueeze(1))  # Adding channel dimension
                loss = criterion(outputs, labels)
                val_running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                all_labels.extend(labels.cpu().numpy())
                all_predictions.extend(predicted.cpu().numpy())
                all_probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())  # Get the probability of class 1

        val_accuracy = 100 * correct / total
        val_precision = precision_score(all_labels, all_predictions, average='macro')
        val_recall = recall_score(all_labels, all_predictions, average='macro')
        val_f1 = f1_score(all_labels, all_predictions, average='macro')
        val_auc = roc_auc_score(all_labels, all_probs)  # Calculate AUC score
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {running_loss/len(train_loader)}, Validation Loss: {val_running_loss/len(val_loader)}, Validation Accuracy: {val_accuracy:.2f}%, AUC: {val_auc:.2f}")

        # Early stopping
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            patience_counter = 0
            # Save the best model
            torch.save(model.state_dict(), f"best_model_fold_{fold+1}.pth")
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Final validation metrics for the fold
    accuracies.append(best_val_accuracy)
    precisions.append(val_precision)
    recalls.append(val_recall)
    f1s.append(val_f1)
    aucs.append(val_auc)
    print(f'Final Accuracy for fold {fold+1}: {best_val_accuracy:.2f}%')
    print(f'Final Precision for fold {fold+1}: {val_precision:.2f}')
    print(f'Final Recall for fold {fold+1}: {val_recall:.2f}')
    print(f'Final F1-Score for fold {fold+1}: {val_f1:.2f}')
    print(f'Final AUC for fold {fold+1}: {val_auc:.2f}\n')

# Report average performance across all folds
print(f'Average Accuracy: {np.mean(accuracies):.2f}%')
print(f'Average Precision: {np.mean(precisions):.2f}')
print(f'Average Recall: {np.mean(recalls):.2f}')
print(f'Average F1-Score: {np.mean(f1s):.2f}')
print(f'Average AUC: {np.mean(aucs):.2f}')


# Print confusion matrices for each fold
for i, cm in enumerate(confusion_matrices):
    print(f'Confusion Matrix for fold {i+1}:')
    print(cm)

Using device: cuda




Fold 1




Epoch 1/100, Loss: 0.8394130865732828, Validation Loss: 0.6580803990364075, Validation Accuracy: 77.78%, AUC: 0.64
Epoch 2/100, Loss: 0.6368493636449178, Validation Loss: 0.6487430930137634, Validation Accuracy: 55.56%, AUC: 0.84
Epoch 3/100, Loss: 0.5582949817180634, Validation Loss: 0.6363502740859985, Validation Accuracy: 38.89%, AUC: 0.79
Epoch 4/100, Loss: 0.4406478901704152, Validation Loss: 0.5951131582260132, Validation Accuracy: 61.11%, AUC: 0.77
Epoch 5/100, Loss: 0.45410295327504474, Validation Loss: 0.5822038650512695, Validation Accuracy: 61.11%, AUC: 0.77
Epoch 6/100, Loss: 0.5262102087338766, Validation Loss: 0.574840247631073, Validation Accuracy: 66.67%, AUC: 0.88
Epoch 7/100, Loss: 0.40798385938008624, Validation Loss: 0.5974298119544983, Validation Accuracy: 66.67%, AUC: 0.91
Epoch 8/100, Loss: 0.3562670548756917, Validation Loss: 0.5751838684082031, Validation Accuracy: 66.67%, AUC: 0.91
Epoch 9/100, Loss: 0.25946616133054096, Validation Loss: 0.5293633341789246, Va



Epoch 4/100, Loss: 0.41754483183224994, Validation Loss: 0.6461873054504395, Validation Accuracy: 50.00%, AUC: 0.66
Epoch 5/100, Loss: 0.34874128301938373, Validation Loss: 0.6759926080703735, Validation Accuracy: 66.67%, AUC: 0.68
Epoch 6/100, Loss: 0.3288443088531494, Validation Loss: 0.6626077890396118, Validation Accuracy: 72.22%, AUC: 0.72
Epoch 7/100, Loss: 0.43387940526008606, Validation Loss: 0.5666847229003906, Validation Accuracy: 72.22%, AUC: 0.84
Epoch 8/100, Loss: 0.3355404684940974, Validation Loss: 0.488547682762146, Validation Accuracy: 77.78%, AUC: 0.80
Epoch 9/100, Loss: 0.2894390920797984, Validation Loss: 0.4291003346443176, Validation Accuracy: 83.33%, AUC: 0.80
Epoch 10/100, Loss: 0.261380930741628, Validation Loss: 0.4704250693321228, Validation Accuracy: 88.89%, AUC: 0.78
Epoch 11/100, Loss: 0.23963954548041025, Validation Loss: 0.525185227394104, Validation Accuracy: 88.89%, AUC: 0.76
Epoch 12/100, Loss: 0.15991431723038355, Validation Loss: 0.561599850654602, 



Epoch 4/100, Loss: 0.5184783438841502, Validation Loss: 0.4995233118534088, Validation Accuracy: 94.44%, AUC: 0.97
Epoch 5/100, Loss: 0.4785986344019572, Validation Loss: 0.41728702187538147, Validation Accuracy: 94.44%, AUC: 0.99
Epoch 6/100, Loss: 0.4291989902655284, Validation Loss: 0.32324132323265076, Validation Accuracy: 94.44%, AUC: 0.99
Epoch 7/100, Loss: 0.3695099949836731, Validation Loss: 0.29678136110305786, Validation Accuracy: 94.44%, AUC: 0.99
Epoch 8/100, Loss: 0.3941962917645772, Validation Loss: 0.27829065918922424, Validation Accuracy: 88.89%, AUC: 0.99
Epoch 9/100, Loss: 0.3949030538400014, Validation Loss: 0.23695942759513855, Validation Accuracy: 94.44%, AUC: 0.99
Epoch 10/100, Loss: 0.28111350536346436, Validation Loss: 0.2014644742012024, Validation Accuracy: 100.00%, AUC: 1.00
Epoch 11/100, Loss: 0.2911706864833832, Validation Loss: 0.17360498011112213, Validation Accuracy: 100.00%, AUC: 1.00
Epoch 12/100, Loss: 0.2537191982070605, Validation Loss: 0.1685986518



Epoch 4/100, Loss: 0.5683049658934275, Validation Loss: 0.652737021446228, Validation Accuracy: 61.11%, AUC: 0.69
Epoch 5/100, Loss: 0.3951188127199809, Validation Loss: 0.5815105438232422, Validation Accuracy: 77.78%, AUC: 0.75
Epoch 6/100, Loss: 0.3656119604905446, Validation Loss: 0.5720276832580566, Validation Accuracy: 77.78%, AUC: 0.81
Epoch 7/100, Loss: 0.29989171028137207, Validation Loss: 0.5994699597358704, Validation Accuracy: 88.89%, AUC: 0.79
Epoch 8/100, Loss: 0.36214444041252136, Validation Loss: 0.6168983578681946, Validation Accuracy: 72.22%, AUC: 0.80
Epoch 9/100, Loss: 0.29865515728791553, Validation Loss: 0.5246726870536804, Validation Accuracy: 88.89%, AUC: 0.80
Epoch 10/100, Loss: 0.2010005588332812, Validation Loss: 0.47496169805526733, Validation Accuracy: 83.33%, AUC: 0.81
Epoch 11/100, Loss: 0.2372224728266398, Validation Loss: 0.4699316918849945, Validation Accuracy: 83.33%, AUC: 0.81
Epoch 12/100, Loss: 0.1903054416179657, Validation Loss: 0.5026645660400391



Epoch 1/100, Loss: 0.8577163418134054, Validation Loss: 0.6934369802474976, Validation Accuracy: 50.00%, AUC: 0.34
Epoch 2/100, Loss: 0.6622242331504822, Validation Loss: 0.7093409895896912, Validation Accuracy: 33.33%, AUC: 0.73
Epoch 3/100, Loss: 0.5073525110880533, Validation Loss: 0.6601888537406921, Validation Accuracy: 66.67%, AUC: 0.71
Epoch 4/100, Loss: 0.5666225751241049, Validation Loss: 0.6272764801979065, Validation Accuracy: 72.22%, AUC: 0.68
Epoch 5/100, Loss: 0.4282403488953908, Validation Loss: 0.6139181852340698, Validation Accuracy: 66.67%, AUC: 0.80
Epoch 6/100, Loss: 0.36879361669222516, Validation Loss: 0.5994619727134705, Validation Accuracy: 72.22%, AUC: 0.86
Epoch 7/100, Loss: 0.30099841952323914, Validation Loss: 0.5856037139892578, Validation Accuracy: 72.22%, AUC: 0.86
Epoch 8/100, Loss: 0.284267783164978, Validation Loss: 0.5881832242012024, Validation Accuracy: 72.22%, AUC: 0.88
Epoch 9/100, Loss: 0.2867511808872223, Validation Loss: 0.5895050764083862, Val