In [36]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, accuracy_score
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import StandardScaler
import torch.nn.functional as F 

# Constants
EEG_DATA_DIR = r'C:\Users\User\Documents\Lie detect data\6M_AugmentedEEGData'

# Define EEGNet model (same as in your training script)
class EEGNet(nn.Module):
    def __init__(self, num_classes=2):
        super(EEGNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, (1, 63), padding='same')
        self.batchnorm1 = nn.BatchNorm2d(16)
        self.depthwiseConv2d = nn.Conv2d(16, 32, (65, 1), groups=16, padding='same')
        self.batchnorm2 = nn.BatchNorm2d(32)
        self.activation = nn.ELU()
        self.pooling = nn.AvgPool2d((1, 4))
        self.dropout = nn.Dropout(0.5)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(32 * 65 * 62, num_classes)

    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dimension
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.depthwiseConv2d(x)
        x = self.batchnorm2(x)
        x = self.activation(x)
        x = self.pooling(x)
        x = self.dropout(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

# Function to load and label data (same as in your training script)
def load_data(data_dir):
    X = []
    y = []
    
    for file_name in os.listdir(data_dir):
        if file_name.endswith('.pkl'):
            file_path = os.path.join(data_dir, file_name)
            data = pd.read_pickle(file_path)
            label = 0 if 'lie' in file_name else 1
            X.extend(data)
            y.extend([label] * data.shape[0])
    
    X = np.array(X)
    y = np.array(y)
    return X, y

# Define dataset class
class EEGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


# Main script
if __name__ == "__main__":
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the saved model
    model_path = r'C:\Users\User\Documents\Lie detect data\Model\model_fold_5.pth'
    model = EEGNet(num_classes=2).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval

    # Move the model to the appropriate device
    model.to(device)

    # Load and preprocess the training data to create the scaler
    X, y = load_data(EEG_DATA_DIR)

    X_scaler = StandardScaler()
    X = X_scaler.fit_transform(X.reshape(-1, X.shape[-1])).reshape(X.shape)

    test_dataset = EEGDataset(X, y)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

    all_labels = []
    all_predictions = []

with torch.no_grad():
    for X_batch, y_batch in test_loader:
        print(f"batch shape: {X_batch.shape}")
        
        X_batch = X_batch.to(device)
        labels = y_batch.to(device)
        
        outputs = model(X_batch)
        _, predicted = torch.max(outputs.data, 1)
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

# Convert lists to numpy arrays for metric calculations
all_labels = np.array(all_labels)
all_predictions = np.array(all_predictions)

# Calculate metrics
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions)
recall = recall_score(all_labels, all_predictions)
f1 = f1_score(all_labels, all_predictions)
auc = roc_auc_score(all_labels, all_predictions)
conf_matrix = confusion_matrix(all_labels, all_predictions)


# Print the metrics
print(f"Accuracy: {accuracy:.4f}")
print(f'Precision: {precision}, Recall: {recall}, F1-score: {f1}, AUC: {auc}')
print('Confusion Matrix:')
print(conf_matrix)

    

batch shape: torch.Size([128, 65, 250])
batch shape: torch.Size([128, 65, 250])
batch shape: torch.Size([128, 65, 250])
batch shape: torch.Size([110, 65, 250])
Accuracy: 0.6579
Precision: 0.7031963470319634, Recall: 0.8876080691642652, F1-score: 0.7847133757961784, AUC: 0.5016271638338333
Confusion Matrix:
[[ 17 130]
 [ 39 308]]
