In [1]:
# Imports
import scipy.io as sio
import matplotlib.pyplot as plt
from sklearn.metrics import cohen_kappa_score, accuracy_score, confusion_matrix
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sleepdetector_new import ImprovedSleepdetector
from tqdm import tqdm
import seaborn as sns
from sklearn.model_selection import train_test_split, KFold
from scipy.signal import welch
from imblearn.over_sampling import SMOTE
from torch.optim.lr_scheduler import ReduceLROnPlateau
import optuna
import os
import logging
import json


# Configure root logging: timestamped, level-tagged messages at INFO and above.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU, plus every CUDA device when CUDA is available), then
    switches cuDNN to deterministic mode with autotuning disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import so this cell stays self-contained; the notebook's import
    # cell does not bring in the stdlib ``random`` module.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Covers the current device as well as any additional GPUs.
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed()

# Data loading and preprocessing
def load_data(filepath, labels_file):
    """Read the signal and label ``.mat`` files and return them as tensors.

    Args:
        filepath: Path to a ``.mat`` file holding the arrays ``sig1`` .. ``sig4``.
        labels_file: Path to a ``.mat`` file holding a ``labels`` array.

    Returns:
        Tuple ``(x, y)``. ``x`` is a float tensor built by stacking the four
        signal arrays on a new leading axis, swapping that axis into position
        1 and squeezing the last axis if it has size 1. ``y`` is a long tensor
        of the flattened labels shifted down by one.

    Raises:
        Exception: anything raised while reading or converting the files is
            logged and then re-raised unchanged.
    """
    signal_keys = ('sig1', 'sig2', 'sig3', 'sig4')
    try:
        contents = sio.loadmat(filepath)
        stacked = np.stack([contents[key] for key in signal_keys], axis=0)
        signal_tensor = torch.from_numpy(stacked).float()

        # Stored labels are shifted down by one before being used as targets.
        shifted_labels = sio.loadmat(labels_file)['labels'].flatten() - 1
        label_tensor = torch.from_numpy(shifted_labels).long()

        # Move the stacked-signal axis from position 0 to position 1.
        reordered = signal_tensor.permute(1, 0, 2, 3)
        return reordered.squeeze(-1), label_tensor
    except Exception as e:
        logging.error(f"Error loading data: {e}")
        raise

def extract_spectral_features(x, fs=100, nperseg=1000):
    """Compute summed Welch PSD in four frequency bands for each channel.

    For every channel the power spectral density is estimated with
    ``scipy.signal.welch`` and the PSD bins are summed inside these bands
    (in Hz): delta [0.5, 4], theta (4, 8], alpha (8, 13], beta (13, 30].

    Args:
        x: Iterable of channels. Each channel may be a CPU torch tensor or a
            NumPy array; singleton axes are squeezed away before analysis.
        fs: Sampling frequency in Hz passed to ``welch`` (default 100).
        nperseg: Welch segment length in samples (default 1000). It is capped
            at the channel length, which is what ``welch`` does itself but
            without emitting a warning.

    Returns:
        1-D NumPy array of length ``4 * n_channels`` laid out as
        ``[delta, theta, alpha, beta]`` for channel 0, then channel 1, etc.
    """
    features = []
    for channel in x:
        # np.asarray accepts both ndarrays and CPU tensors, so callers are not
        # forced to wrap NumPy data in a tensor first.
        signal = np.asarray(channel).squeeze()
        segment = min(nperseg, signal.shape[-1]) if signal.ndim else nperseg
        f, psd = welch(signal, fs=fs, nperseg=segment)

        delta = np.sum(psd[(f >= 0.5) & (f <= 4)])
        theta = np.sum(psd[(f > 4) & (f <= 8)])
        alpha = np.sum(psd[(f > 8) & (f <= 13)])
        beta = np.sum(psd[(f > 13) & (f <= 30)])
        features.extend([delta, theta, alpha, beta])
    return np.array(features)

def prepare_data(x, y, test_size=0.2):
    """Split the data, add spectral features and oversample the training set.

    Steps:
      1. Stratified train/test split (fixed ``random_state=42``).
      2. Spectral band features are computed per sample for both splits.
      3. Each training sample is flattened and concatenated with its spectral
         features, and SMOTE oversamples that combined matrix.
      4. The resampled matrix is split back into signals (restored to the
         original per-sample shape) and spectral features.

    Because SMOTE runs on the combined matrix, the spectral features of
    synthetic samples are produced by SMOTE together with the signal, not
    recomputed from the synthetic signal. The test split is never resampled.

    Args:
        x: Tensor of signals, one sample per row of the first axis.
        y: Tensor of integer class labels aligned with ``x``.
        test_size: Fraction of samples held out for the test split.

    Returns:
        Tuple ``(X_train, X_train_spectral, y_train, X_test, X_test_spectral,
        y_test)`` of tensors; inputs are float, labels are long.
    """
    signals = x.numpy()
    labels = y.numpy()
    X_train, X_test, y_train, y_test = train_test_split(
        signals, labels, test_size=test_size, stratify=labels, random_state=42
    )

    def _spectral_batch(samples):
        # One feature vector per sample; the tensor wrapper is kept because
        # extract_spectral_features is called with tensors throughout the file.
        return np.array([extract_spectral_features(torch.from_numpy(sample)) for sample in samples])

    X_train_spectral = _spectral_batch(X_train)
    X_test_spectral = _spectral_batch(X_test)
    n_spectral = X_train_spectral.shape[1]

    # SMOTE needs a 2-D matrix: flatten each sample and append its features.
    X_train_combined = np.concatenate(
        [X_train.reshape(X_train.shape[0], -1), X_train_spectral], axis=1
    )

    smote = SMOTE(random_state=42)
    X_train_resampled, y_train_resampled = smote.fit_resample(X_train_combined, y_train)

    # Undo the concatenation: trailing columns are spectral features, the rest
    # is the flattened signal, restored to the original per-sample shape.
    n_resampled = X_train_resampled.shape[0]
    X_train_final = X_train_resampled[:, :-n_spectral].reshape((n_resampled,) + X_train.shape[1:])
    X_train_spectral_final = X_train_resampled[:, -n_spectral:]

    return (torch.from_numpy(X_train_final).float(),
            torch.from_numpy(X_train_spectral_final).float(),
            torch.from_numpy(y_train_resampled).long(),
            torch.from_numpy(X_test).float(),
            torch.from_numpy(X_test_spectral).float(),
            torch.from_numpy(y_test).long())

# Model definition
class EnsembleModel(nn.Module):
    """Averages the outputs of several independently built detectors.

    Args:
        model_params: Keyword arguments forwarded to every
            ``ImprovedSleepdetector`` constructor call.
        n_models: Number of ensemble members to create.
    """

    def __init__(self, model_params, n_models=3):
        super().__init__()
        members = []
        for _ in range(n_models):
            members.append(ImprovedSleepdetector(**model_params))
        self.models = nn.ModuleList(members)

    def forward(self, x, spectral_features):
        """Return the element-wise mean of all member outputs.

        Each member receives its own clone of both inputs.
        """
        member_outputs = []
        for member in self.models:
            member_outputs.append(member(x.clone(), spectral_features.clone()))
        return torch.stack(member_outputs).mean(dim=0)

# Training and evaluation functions
def train_model(model, train_loader, val_data, optimizer, scheduler, criterion, device, epochs=100):
    """Train ``model`` and return a snapshot of its best-scoring weights.

    After every epoch the model is scored with ``evaluate_model`` on
    ``val_data``. Whenever the accuracy beats the best seen so far, a deep
    copy of the state dict is kept. The copy matters: ``state_dict()`` returns
    references to the live parameter tensors, so holding it without copying
    would make the "best" state silently track the latest weights.

    Args:
        model: Module called as ``model(batch_x, batch_x_spectral)``.
        train_loader: Iterable yielding ``(batch_x, batch_x_spectral, batch_y)``.
        val_data: Tuple passed unchanged to ``evaluate_model``.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Stepped once per epoch with the validation accuracy, so a
            ``ReduceLROnPlateau`` must be created with ``mode='max'``.
        criterion: Loss called as ``criterion(outputs, batch_y)``.
        device: Device every training batch is moved to.
        epochs: Number of passes over ``train_loader``.

    Returns:
        Tuple ``(best_model_state, best_accuracy)``. ``best_model_state`` is
        ``None`` only when ``epochs`` is 0.
    """
    import copy  # local import: the notebook's import cell does not include it

    best_accuracy = 0
    best_model_state = None

    for _ in tqdm(range(epochs), desc="Training Progress"):
        model.train()
        for batch_x, batch_x_spectral, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_x_spectral = batch_x_spectral.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            outputs = model(batch_x, batch_x_spectral)
            loss = criterion(outputs, batch_y)
            loss.backward()
            # Cap the global gradient norm at 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        accuracy = evaluate_model(model, val_data, device)

        # The first epoch always produces a snapshot, even at accuracy 0, so
        # callers never receive None after at least one epoch of training.
        if best_model_state is None or accuracy > best_accuracy:
            best_accuracy = accuracy
            best_model_state = copy.deepcopy(model.state_dict())

        scheduler.step(accuracy)

    return best_model_state, best_accuracy

def evaluate_model(model, data, device):
    """Return the classification accuracy of ``model`` on ``data``.

    The model is put in eval mode (and left there) and the whole dataset is
    pushed through it in a single forward pass with gradients disabled.

    Args:
        model: Module called as ``model(X, X_spectral)``.
        data: Tuple ``(X, X_spectral, y)`` of tensors.
        device: Device the two input tensors are moved to.

    Returns:
        Accuracy as computed by ``accuracy_score``.
    """
    model.eval()
    inputs, spectral_inputs, targets = data
    with torch.no_grad():
        logits = model(inputs.to(device), spectral_inputs.to(device))
        # Predicted class = index of the largest output along dimension 1.
        predicted = logits.argmax(dim=1)
    return accuracy_score(targets.cpu().numpy(), predicted.cpu().numpy())

# Hyperparameter optimization
def objective(trial, X_train, X_train_spectral, y_train, X_test, X_test_spectral, y_test, device):
    """Optuna objective: train one candidate model briefly and score it.

    Samples model and training hyperparameters from ``trial``, trains a fresh
    ``ImprovedSleepdetector`` for 10 epochs and returns the best accuracy
    reached on ``(X_test, X_test_spectral, y_test)``.

    Args:
        trial: Optuna trial used to sample hyperparameters.
        X_train, X_train_spectral, y_train: Training tensors.
        X_test, X_test_spectral, y_test: Tensors the candidate is scored on.
        device: Device used for the model and the batches.

    Returns:
        Best accuracy observed across the 10 training epochs.
    """
    model_params = {
        # NOTE(review): Optuna documents categorical choices as None, bool,
        # int, float or str; list choices trigger a warning. Left unchanged
        # because callers read ``n_filters`` back from ``study.best_params``
        # as a list.
        'n_filters': trial.suggest_categorical('n_filters', [[32, 64, 128], [64, 128, 256]]),
        'lstm_hidden': trial.suggest_int('lstm_hidden', 64, 512),
        'lstm_layers': trial.suggest_int('lstm_layers', 1, 3),
        'dropout': trial.suggest_float('dropout', 0.1, 0.5)
    }

    train_params = {
        'lr': trial.suggest_float('lr', 1e-5, 1e-2, log=True),
        'batch_size': trial.suggest_categorical('batch_size', [16, 32, 64, 128])
    }

    model = ImprovedSleepdetector(**model_params).to(device)
    optimizer = optim.Adam(model.parameters(), lr=train_params['lr'])
    # train_model steps the scheduler with accuracy (higher is better), so the
    # plateau detector must run in 'max' mode; the default 'min' would treat
    # every improvement as a regression.
    scheduler = ReduceLROnPlateau(optimizer, mode='max')
    train_loader = DataLoader(
        TensorDataset(X_train, X_train_spectral, y_train),
        batch_size=train_params['batch_size'],
        shuffle=True,
    )

    _, accuracy = train_model(
        model, train_loader, (X_test, X_test_spectral, y_test),
        optimizer, scheduler, nn.CrossEntropyLoss(), device, epochs=10,
    )

    return accuracy

# Cross-validation
def cross_validate(X, X_spectral, y, model_params, train_params, n_splits=5):
    """Run shuffled K-fold cross-validation and return per-fold accuracies.

    For each fold a fresh ``ImprovedSleepdetector`` is trained for 50 epochs
    on the training indices and scored on the held-out indices. The model is
    placed on the module-level ``device``.

    NOTE(review): folds come from plain ``KFold`` (not stratified), so class
    proportions may differ between folds.

    Args:
        X: Signal tensor, indexed along its first axis by the fold indices.
        X_spectral: Spectral feature tensor aligned with ``X``.
        y: Label tensor aligned with ``X``.
        model_params: Keyword arguments for ``ImprovedSleepdetector``.
        train_params: Dict providing ``'lr'`` and ``'batch_size'``.
        n_splits: Number of folds.

    Returns:
        List with the best validation accuracy reached in each fold.
    """
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    scores = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(X)):
        X_train_fold, X_val_fold = X[train_idx], X[val_idx]
        X_spectral_train_fold, X_spectral_val_fold = X_spectral[train_idx], X_spectral[val_idx]
        y_train_fold, y_val_fold = y[train_idx], y[val_idx]

        model = ImprovedSleepdetector(**model_params).to(device)
        optimizer = optim.Adam(model.parameters(), lr=train_params['lr'])
        # The scheduler is stepped with accuracy, so it must watch for a
        # plateau of a maximised metric; the default 'min' mode would cut the
        # learning rate while accuracy is still improving.
        scheduler = ReduceLROnPlateau(optimizer, mode='max')
        train_loader = DataLoader(
            TensorDataset(X_train_fold, X_spectral_train_fold, y_train_fold),
            batch_size=train_params['batch_size'],
            shuffle=True,
        )

        _, accuracy = train_model(
            model, train_loader, (X_val_fold, X_spectral_val_fold, y_val_fold),
            optimizer, scheduler, nn.CrossEntropyLoss(), device, epochs=50,
        )
        scores.append(accuracy)

        logging.info(f"Fold {fold + 1} Accuracy: {accuracy:.4f}")

    logging.info(f"Average Accuracy: {np.mean(scores):.4f} (+/- {np.std(scores):.4f})")
    return scores

# Confusion matrix plotting
def plot_confusion_matrix(y_true, y_pred, normalize=False, title=None, cmap=plt.cm.Blues):
    """Draw the confusion matrix as an annotated heatmap and return the figure.

    The matrix always has one row and one column per entry of the class-name
    list below (labels 0..4), even when a class is absent from both
    ``y_true`` and ``y_pred``, so the tick labels stay aligned with the cells.

    Args:
        y_true: Ground-truth integer labels.
        y_pred: Predicted integer labels.
        normalize: If True, each row is divided by its total so cells show the
            fraction of that true class; rows with no samples are shown as 0.
        title: Figure title; a default depending on ``normalize`` is used when
            empty or None.
        cmap: Colormap handed to ``seaborn.heatmap``.

    Returns:
        The matplotlib ``Figure`` holding the heatmap.
    """
    if not title:
        title = 'Normalized confusion matrix' if normalize else 'Confusion matrix, without normalization'

    # Class names in the correct order (0 to 4)
    class_names = ['N3', 'N2', 'N1', 'REM', 'Awake']

    # Pin the label set so the matrix is always len(class_names) square.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Guard against classes with no true samples (row total of 0), which
        # would otherwise produce NaN cells and a runtime warning.
        cm = np.divide(
            cm.astype('float'),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd',
                cmap=cmap, square=True, xticklabels=class_names, yticklabels=class_names,
                ax=ax)

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    ax.set_title(title)

    fig.tight_layout()
    return fig

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Quick sanity check: load the recordings and confirm the tensor shapes.
# NOTE(review): this cell reads from './data/' while the pipeline cell further
# down reads from '../data/' - confirm which location is the intended one.
x, y = load_data('./data/data.mat', './data/labels.mat')
print(x.shape, y.shape)

torch.Size([1027, 4, 3000]) torch.Size([1027])


In [None]:
logging.info(f"Using device: {device}")

# Stages executed by this cell, in order; the progress bar gets exactly one
# tick per entry. (Final evaluation lives in the next cell.)
steps = [
    "Load and prepare data",
    "Save test data",
    "Hyperparameter optimization",
    "Cross-validation",
    "Train final model",
    "Save best model",
]

# Created outside the try block so it always exists when the cleanup runs.
pbar = tqdm(total=len(steps), desc="Overall Progress")

try:
    # Load and prepare data
    # NOTE(review): an earlier cell loads from './data/'; confirm which
    # relative path is correct for where this notebook is run from.
    x, y = load_data('../data/data.mat', '../data/labels.mat')
    X_train, X_train_spectral, y_train, X_test, X_test_spectral, y_test = prepare_data(x, y)
    logging.info("Data loaded and prepared successfully")
    pbar.update(1)

    # Save test data as plain lists so it can be reloaded without torch.
    data = {
        'X_test': X_test.tolist(),
        'X_test_spectral': X_test_spectral.tolist(),
        'y_test': y_test.tolist()
    }

    with open('test_data.json', 'w') as f:
        json.dump(data, f)

    logging.info("Test data saved successfully")
    pbar.update(1)

    # Hyperparameter optimization
    # NOTE(review): trials are scored on the test split, and the same split is
    # used below for checkpoint selection and for the final evaluation, so the
    # reported test metrics are optimistically biased. Consider carving a
    # validation split out of the training data instead.
    study = optuna.create_study(direction='maximize')
    study.optimize(
        lambda trial: objective(trial, X_train, X_train_spectral, y_train,
                                X_test, X_test_spectral, y_test, device),
        n_trials=100,
    )

    best_params = study.best_params
    best_model_params = {k: v for k, v in best_params.items() if k in ['n_filters', 'lstm_hidden', 'lstm_layers', 'dropout']}
    best_train_params = {k: v for k, v in best_params.items() if k in ['lr', 'batch_size']}
    logging.info(f"Best hyperparameters: {best_params}")

    # Save the parameters
    params_to_save = {
        'best_params': best_params,
        'best_model_params': best_model_params,
        'best_train_params': best_train_params
    }

    with open('best_params_ensemble.json', 'w') as f:
        json.dump(params_to_save, f, indent=4)

    # The message now names the file that is actually written.
    logging.info("Best parameters saved to 'best_params_ensemble.json'")
    pbar.update(1)

    # Cross-validation
    cv_scores = cross_validate(X_train, X_train_spectral, y_train, best_model_params, best_train_params)
    pbar.update(1)

    # Train final model (ensemble)
    ensemble_model = EnsembleModel(best_model_params, n_models=3).to(device)
    train_loader = DataLoader(TensorDataset(X_train, X_train_spectral, y_train), batch_size=best_train_params['batch_size'], shuffle=True)
    optimizer = optim.Adam(ensemble_model.parameters(), lr=best_train_params['lr'], weight_decay=1e-5)
    # `verbose` is deprecated for PyTorch schedulers, so it is no longer
    # passed; use scheduler.get_last_lr() to inspect the learning rate.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=10)

    best_model_state, _ = train_model(ensemble_model, train_loader, (X_test, X_test_spectral, y_test), optimizer, scheduler, nn.CrossEntropyLoss(), device)
    pbar.update(1)

    # Save best model
    torch.save(best_model_state, "best_ensemble_model.pth")

    logging.info("Best ensemble model saved")
    pbar.update(1)

except Exception as e:
    # Log with traceback and re-raise: the following cells depend on the
    # variables created here, so continuing silently would only surface as a
    # confusing NameError later.
    logging.exception(f"An error occurred: {e}")
    raise
finally:
    pbar.close()

In [None]:
# Final evaluation: restore the best checkpoint and score it on the test split.
ensemble_model.load_state_dict(best_model_state)
ensemble_model.eval()

# One forward pass feeds both metrics (and the confusion matrices plotted in
# the next cell via `predicted`), instead of running the model twice.
with torch.no_grad():
    outputs = ensemble_model(X_test.to(device), X_test_spectral.to(device))
    _, predicted = torch.max(outputs, 1)

y_test_np = y_test.cpu().numpy()
predicted_np = predicted.cpu().numpy()
final_accuracy = accuracy_score(y_test_np, predicted_np)
final_kappa = cohen_kappa_score(y_test_np, predicted_np)

logging.info(f"Ensemble Model - Final Test Accuracy: {final_accuracy:.4f}")
logging.info(f"Ensemble Model - Final Cohen's Kappa: {final_kappa:.4f}")

In [None]:
# Convert once; both plots use the same label arrays.
y_true_np = y_test.cpu().numpy()
y_pred_np = predicted.cpu().numpy()

# Plot and save normalized confusion matrix
fig_norm = plot_confusion_matrix(y_true_np, y_pred_np,
                                 normalize=True,
                                 title='Normalized Confusion Matrix')
# Saving is enabled so the log message below is true.
fig_norm.savefig('confusion_matrix_normalized.png')
logging.info("Normalized confusion matrix plot saved as 'confusion_matrix_normalized.png'")

# Plot and save non-normalized confusion matrix
fig_non_norm = plot_confusion_matrix(y_true_np, y_pred_np,
                                     normalize=False,
                                     title='Confusion Matrix, without normalization')
fig_non_norm.savefig('confusion_matrix_non_normalized.png')
logging.info("Non-normalized confusion matrix plot saved as 'confusion_matrix_non_normalized.png'")

# The figures are deliberately left open so the notebook can render them
# inline; call plt.close('all') afterwards when running as a script.