In [14]:
import os
import numpy as np
import pandas as pd
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader, Dataset, Subset
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight


# Directory containing data files (.pkl recordings) and directory where the
# per-fold checkpoints are written. Either can be overridden through the
# EEG_DATA_DIR / EEG_MODEL_DIR environment variables so the notebook runs on
# another machine without editing the hard-coded Windows paths.
data_dir = os.environ.get('EEG_DATA_DIR', r'C:\Users\User\Documents\Lie detect data\56M_AugmentedEEGData')
model_save_dir = os.environ.get('EEG_MODEL_DIR', r'C:\Users\User\Documents\Lie detect data\Model')
os.makedirs(model_save_dir, exist_ok=True)

# Function to load and label data
def load_data(data_dir):
    """Load every ``.pkl`` file in *data_dir* and stack them into one array.

    Each file is read with ``pd.read_pickle`` and its rows (axis 0) are
    treated as samples; all files are concatenated with ``np.vstack``.

    Labels come from the file name: a name containing ``'lie'`` gives
    label 0, any other ``.pkl`` file gives label 1. Every sample from the
    same file gets the same group id, so ``GroupKFold`` keeps a file's
    samples together in one fold.

    Files are processed in sorted name order. ``os.listdir`` order is
    arbitrary, and sorting keeps the row order, group ids and therefore the
    CV folds reproducible between runs and machines.

    Args:
        data_dir: directory to scan (non-recursive).

    Returns:
        Tuple ``(X, y, groups)``: stacked samples, integer labels, and
        per-sample group ids ``0, 1, 2, ...`` in sorted file order.

    Raises:
        ValueError: if *data_dir* contains no ``.pkl`` files.
    """
    # NOTE(security): pd.read_pickle can execute arbitrary code while
    # unpickling; only point this at trusted files.
    pkl_files = sorted(name for name in os.listdir(data_dir) if name.endswith('.pkl'))
    if not pkl_files:
        raise ValueError(f"No .pkl files found in {data_dir!r}")

    X, y, groups = [], [], []
    for group_id, file_name in enumerate(pkl_files):
        data = pd.read_pickle(os.path.join(data_dir, file_name))
        label = 0 if 'lie' in file_name else 1
        n_rows = data.shape[0]
        X.append(data)
        y.extend([label] * n_rows)
        groups.extend([group_id] * n_rows)  # one group per source file

    return np.vstack(X), np.array(y), np.array(groups)

# Load every recording once: X holds the stacked samples, y the per-sample
# labels and groups the per-file ids consumed by GroupKFold further down.
X, y, groups = load_data(data_dir=data_dir)

# Define dataset class
class EEGDataset(Dataset):
    """In-memory map-style dataset pairing each sample with its class label.

    Both arrays are copied into tensors up front: samples as float32,
    labels as int64 (the dtype ``nn.CrossEntropyLoss`` expects for targets).
    """

    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.long)
        self.X, self.y = features, targets

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

# Define EEGNet model
class EEGNet(nn.Module):
    """EEGNet-style CNN for inputs shaped ``(batch, num_channels, num_samples)``.

    Pipeline: temporal conv ``(1, 63)`` -> BatchNorm -> depthwise spatial conv
    ``(num_channels, 1)`` (16 groups, 2 filters per group) -> BatchNorm ->
    ELU -> average pool ``(1, 4)`` -> dropout -> linear classifier.
    Both convolutions use ``padding='same'``, so the feature map keeps its
    ``num_channels x num_samples`` size until pooling divides the time axis
    by 4 (floor).

    Args:
        num_classes: number of output logits.
        num_channels: EEG channels per sample (input dim 1). The default of 65
            reproduces the original hard-coded layer sizes.
        num_samples: time points per sample (input dim 2). The default of 250
            reproduces the original hard-coded layer sizes.
    """

    def __init__(self, num_classes=2, num_channels=65, num_samples=250):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, (1, 63), padding='same')
        self.batchnorm1 = nn.BatchNorm2d(16)
        # Spatial filter spanning every EEG channel; with the default 65 this
        # is the original (65, 1) kernel.
        self.depthwiseConv2d = nn.Conv2d(16, 32, (num_channels, 1), groups=16, padding='same')
        self.batchnorm2 = nn.BatchNorm2d(32)
        self.activation = nn.ELU()
        self.pooling = nn.AvgPool2d((1, 4))
        self.dropout = nn.Dropout(0.5)
        self.flatten = nn.Flatten()
        # 'same' padding keeps (num_channels, num_samples); the (1, 4) pool
        # floors the time axis to num_samples // 4 (250 -> 62 by default).
        self.fc = nn.Linear(32 * num_channels * (num_samples // 4), num_classes)

    def forward(self, x):
        x = x.unsqueeze(1)  # add the single input-feature-map dimension
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.depthwiseConv2d(x)
        x = self.batchnorm2(x)
        x = self.activation(x)
        x = self.pooling(x)
        x = self.dropout(x)
        x = self.flatten(x)
        return self.fc(x)

# Compute device shared by training and evaluation: CUDA when available,
# otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def _fold_train_labels(loader):
    """Return the training labels behind *loader* as a 1-D NumPy array."""
    labels = getattr(loader.dataset, 'y', None)
    if labels is None:
        # Dataset does not expose ``.y`` (e.g. a Subset): collect it with one pass.
        labels = torch.cat([y_batch for _, y_batch in loader])
    if torch.is_tensor(labels):
        labels = labels.cpu()
    return np.asarray(labels)


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over *loader* and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for X_batch, y_batch in loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        loss = criterion(model(X_batch), y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(model, loader, criterion):
    """Return ``(mean batch loss, accuracy)`` of *model* on *loader*, no gradients."""
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            running_loss += criterion(outputs, y_batch).item()
            correct += (outputs.argmax(dim=1) == y_batch).sum().item()
            total += y_batch.size(0)
    return running_loss / len(loader), correct / total


def train_and_evaluate(train_loader, val_loader, fold_idx, num_epochs=100, patience=20):
    """Train a fresh EEGNet on one CV fold with early stopping on validation loss.

    Each epoch trains on *train_loader*, evaluates on *val_loader*, prints the
    losses and accuracy, and steps a cosine-annealing-with-warm-restarts
    schedule (``T_0=10``). When the validation loss improves, a checkpoint
    (model and optimizer state, epoch, loss) is written to
    ``model_save_dir/56M_model_fold_{fold_idx}.pth``. Training stops after
    *patience* consecutive epochs without improvement.

    Args:
        train_loader: loader of ``(X, y)`` training batches.
        val_loader: loader of ``(X, y)`` validation batches.
        fold_idx: fold number, used only in the checkpoint file name.
        num_epochs: maximum number of epochs (default 100).
        patience: epochs without validation improvement before stopping
            (default 20).

    Returns:
        The model with the weights of its best-validation-loss epoch. The
        weights are restored before returning, so callers evaluate the same
        model that was checkpointed rather than the last epoch's weights.
    """
    model = EEGNet(num_classes=2).to(device)

    # Class weights come from this fold's training labels only. Using the
    # global ``y`` would let the validation fold's class balance leak into
    # training.
    train_labels = _fold_train_labels(train_loader)
    class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1)

    fold_model_path = os.path.join(model_save_dir, f'56M_model_fold_{fold_idx}.pth')
    best_val_loss = float('inf')
    best_state = None
    trigger_times = 0

    for epoch in range(num_epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_accuracy = _evaluate(model, val_loader, criterion)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {train_loss}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')

        scheduler.step()

        # Early Stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            trigger_times = 0
            # Keep an in-memory copy so the best weights can be restored below.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'val_loss': best_val_loss,
            }, fold_model_path)
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print('Early stopping!')
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

class FocalLoss(nn.Module):
    """Focal loss built on per-sample multi-class cross-entropy.

    ``loss = alpha * (1 - p_t) ** gamma * CE``, where ``CE`` is
    ``F.cross_entropy(inputs, targets, reduction='none')`` and
    ``p_t = exp(-CE)`` is the softmax probability of the target class.
    ``inputs`` are therefore treated as raw (unnormalised) logits.

    Args:
        alpha: scalar weight multiplied into every sample's loss.
        gamma: focusing exponent; ``gamma=0`` reduces to ``alpha * CE``.
        logits: kept only for backward compatibility. The original two
            branches were identical, so this flag has never changed the
            result, and it still does not.
        reduce: if True return the batch mean, otherwise the per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return torch.mean(focal_loss) if self.reduce else focal_loss

# K-Fold Cross Validation
# GroupKFold keeps every sample of a source file (``groups``) in one fold, so
# validation files are never seen during training.
kf = GroupKFold(n_splits=5)
all_labels = []
all_predictions = []
# Softmax probability of class 1 for every validation sample. ROC-AUC needs a
# continuous score; thresholded 0/1 predictions collapse the ROC curve to a
# single point.
all_scores = []
fold_idx = 1
n_channels, n_samples = 65, 250  # per-sample layout fed to EEGNet

for train_index, val_index in kf.split(X, y, groups):
    print(f'Fold {fold_idx}')

    # Split data
    X_train, X_val = X[train_index], X[val_index]
    y_train, y_val = y[train_index], y[val_index]

    # Normalize data: min-max scale each flattened (channel, time) feature,
    # fitting the scaler on the training fold only.
    scaler = MinMaxScaler()
    X_train = scaler.fit_transform(X_train.reshape(X_train.shape[0], -1))
    X_train = X_train.reshape(-1, n_channels, n_samples)
    X_val = scaler.transform(X_val.reshape(X_val.shape[0], -1))
    X_val = X_val.reshape(-1, n_channels, n_samples)

    # Create datasets
    train_dataset = EEGDataset(X_train, y_train)
    val_dataset = EEGDataset(X_val, y_val)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

    model = train_and_evaluate(train_loader, val_loader, fold_idx)

    # Collect out-of-fold labels, hard predictions and class-1 probabilities.
    model.eval()
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            outputs = model(X_batch.to(device))
            all_labels.extend(y_batch.numpy())
            all_predictions.extend(outputs.argmax(dim=1).cpu().numpy())
            all_scores.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())

    fold_idx += 1

# Calculate additional metrics over the pooled out-of-fold predictions.
# Files whose name contains 'lie' are label 0, so precision/recall/F1 below
# (scikit-learn's default pos_label=1) treat the non-'lie' class as positive.
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions)
recall = recall_score(all_labels, all_predictions)
f1 = f1_score(all_labels, all_predictions)
auc = roc_auc_score(all_labels, all_scores)
conf_matrix = confusion_matrix(all_labels, all_predictions)

print(f'Accuracy: {accuracy},Precision: {precision}, Recall: {recall}, F1-score: {f1}, AUC: {auc}')
print('Confusion Matrix:')
print(conf_matrix)


Fold 1
Epoch 1/100, Loss: 1.749316434065501, Validation Loss: 0.7688215970993042, Validation Accuracy: 0.4222222222222222
Epoch 2/100, Loss: 1.0399823586146038, Validation Loss: 0.6932893991470337, Validation Accuracy: 0.4444444444444444
Epoch 3/100, Loss: 0.7709147532780966, Validation Loss: 0.7402219772338867, Validation Accuracy: 0.4222222222222222
Epoch 4/100, Loss: 0.6240840554237366, Validation Loss: 0.6718201041221619, Validation Accuracy: 0.4888888888888889
Epoch 5/100, Loss: 0.5898752411206564, Validation Loss: 0.7152655124664307, Validation Accuracy: 0.4222222222222222
Epoch 6/100, Loss: 0.48006927967071533, Validation Loss: 0.7308868169784546, Validation Accuracy: 0.43333333333333335
Epoch 7/100, Loss: 0.45040900508562726, Validation Loss: 0.6785053610801697, Validation Accuracy: 0.5111111111111111
Epoch 8/100, Loss: 0.452277531226476, Validation Loss: 0.6709713339805603, Validation Accuracy: 0.5222222222222223
Epoch 9/100, Loss: 0.4125521381696065, Validation Loss: 0.670800