In [None]:
!mkdir models
!mkdir plots
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from classifier import NappaSleepNet
from utils.dataset_classes import NappaDataset
from utils.dataset_preprocess import HybridScaler
from utils.plots import plot_confusion_matrix, plot_time_series, box_plot

from sklearn.metrics import accuracy_score, matthews_corrcoef, confusion_matrix

In [10]:
# All inference in this notebook runs on the CPU.
device = torch.device('cpu')

# Feature scaler from utils.dataset_preprocess; see that module for what method='global' means.
scaler = HybridScaler(method='global')

# Collapse the five scored sleep stages into three 0-based class ids.
mapping = {
    'N3'  : 0,
    'N2'  : 0,
    'N1'  : 1,
    'REM' : 1,
    'Wake': 2,
}

# NOTE(review): the dataset is read from a .pkl file — presumably unpickled; only load trusted files.
nappa_dataset = NappaDataset('nappa_dataset.pkl').labelsToNumeric(mapping).sortById()

# Display names, indexed by the class id assigned in `mapping`.
sleep_classes = ['N2/N3', 'N1/REM', 'Wake']

# NUM_CLASSES (model head size, confusion-matrix axes) is derived from `sleep_classes`,
# while the labels actually come from `mapping`: fail fast if the two drift apart.
assert sorted(set(mapping.values())) == list(range(len(sleep_classes))), \
    'mapping class ids and sleep_classes are out of sync'

NUM_FEATURES = nappa_dataset.features.shape[1]
NUM_CLASSES = len(sleep_classes)

model = NappaSleepNet(n_features=NUM_FEATURES, n_classes=NUM_CLASSES).to(device)

In [11]:
def model_predict(model, features, n_classes=None):
    """Run a single recording through ``model`` and return per-step predictions.

    The model is switched to eval mode and called once with a batch of size 1 as
    ``model(x, rec_lengths=[x.shape[1]])``.

    Args:
        model: torch module accepting ``(x, rec_lengths=...)``. Left in eval mode.
        features: array-like of features for one recording. Presumably shaped
            (n_epochs, n_features) — TODO confirm. A leading batch dim is added here.
        n_classes: number of output classes; defaults to the notebook-level NUM_CLASSES.

    Returns:
        (predicted_probabilities, predicted_classes):
            softmax probabilities of shape (-1, n_classes), and the argmax class
            per row shifted to 1-based labels in [1, n_classes]. Both are CPU
            tensors, so callers can use ``.numpy()`` whatever device the model is on.
    """
    if n_classes is None:
        n_classes = NUM_CLASSES

    # Feed the input on the same device as the model; parameter-less models default to CPU.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    with torch.no_grad():
        # as_tensor avoids the extra copy/warning torch.tensor emits for tensor inputs.
        x = torch.as_tensor(features, dtype=torch.float, device=model_device).unsqueeze(0)

        output = model(x, rec_lengths=[x.shape[1]]).reshape(-1, n_classes)
        predicted_probabilities = torch.softmax(output, dim=-1)
        predicted_classes = torch.argmax(predicted_probabilities, dim=-1) + 1  # Shift sleep class labels from [0, C-1] to [1, C]

    return predicted_probabilities.cpu(), predicted_classes.cpu()

In [None]:
# Leave-one-subject-out evaluation: each subject is scored with the model trained without
# it, after normalizing its features with statistics from all remaining subjects.

# 1-based class labels, matching the +1 shift applied to both targets and predictions.
class_labels = np.arange(1, NUM_CLASSES + 1)

test_metrics = np.zeros((len(nappa_dataset), 2))                # columns: accuracy, MCC
class_accuracies = np.zeros((len(nappa_dataset), NUM_CLASSES))  # per-class recall (NaN if class absent)

all_preds, all_targets = [], []

for subject_idx, test_subject in enumerate(nappa_dataset):
    # Load the model specifically trained excluding the current test subject
    model_path = f'models/model_subject_{test_subject.id}.pth'
    # NOTE(review): torch.load unpickles the checkpoint — only load files you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    print(f'model loaded from {model_path}')

    # Construct training dataset by excluding the current test subject
    train_dataset = NappaDataset([subject for subject in nappa_dataset if subject.id != test_subject.id])

    # Normalize test subject features using training dataset statistics
    features = scaler(test_subject.features, is_testset=True,
                      trainset_mean=train_dataset.features.mean(axis=0),
                      trainset_std=train_dataset.features.std(axis=0))

    targets = torch.tensor(test_subject.labels) + 1  # Shift sleep class labels from [0, C-1] to [1, C]

    predicted_probabilities, preds = model_predict(model, features)

    acc, mcc = accuracy_score(targets, preds), matthews_corrcoef(targets, preds)
    # Explicit labels keep the matrix C x C even when a subject never has (or is never
    # predicted as) some class. Transposed: rows = predicted class, columns = true class.
    cm = confusion_matrix(targets, preds, labels=class_labels).T

    # Per-class recall: diagonal over true-class (column) totals. Classes absent from this
    # subject's ground truth get NaN instead of a divide-by-zero warning.
    true_counts = cm.sum(axis=0)
    recall = np.full(NUM_CLASSES, np.nan)
    np.divide(np.diag(cm), true_counts, out=recall, where=true_counts > 0)

    # Store metrics by position in the (id-sorted) dataset, so ids need not be 1..N
    test_metrics[subject_idx] = [acc, mcc]
    class_accuracies[subject_idx] = recall
    all_preds.append(preds)
    all_targets.append(targets)

    # Generate and save plots; save before showing, then close to bound memory in the loop
    ts_fig = plot_time_series(targets, predicted_probabilities.numpy(), preds, sleep_classes)
    ts_fig.savefig(f'plots/time series subject {test_subject.id}.png', bbox_inches='tight')
    plt.show()
    plt.close(ts_fig)

    cm_fig = plot_confusion_matrix(cm, sleep_classes, title=f'MCC: {mcc:.2f}, Accuracy: {acc:.0%}')
    cm_fig.savefig(f'plots/confusion matrix subject {test_subject.id}.png')
    plt.show()
    plt.close(cm_fig)

# Aggregate predictions and labels
all_targets, all_preds = np.concatenate(all_targets), np.concatenate(all_preds)

In [None]:
# Compute aggregate performance metric scores over every evaluated sleep epoch
overall_acc, overall_mcc = accuracy_score(all_targets, all_preds), matthews_corrcoef(all_targets, all_preds)

# Counts are derived from the data actually evaluated rather than hard-coded
n_epochs = len(all_targets)
n_subjects = len(test_metrics)

# Create and save overall plots. Save before plt.show(): some backends (e.g. inline)
# close figures once they are shown.
cm_all = confusion_matrix(all_targets, all_preds, labels=np.arange(1, NUM_CLASSES + 1)).T  # rows = predicted, cols = true
cm_fig_all = plot_confusion_matrix(
    cm_all, sleep_classes,
    title=f'Aggregate Confusion Matrix (n={n_epochs} sleep epochs)\nMCC: {overall_mcc:.2f}, Accuracy: {overall_acc:.0%}')
cm_fig_all.savefig('plots/confusion_matrix_all.png')
plt.show()

# Per-subject accuracy (column 0) and MCC (column 1) distributions
bp = box_plot(test_metrics[:, 0], test_metrics[:, 1], title=f'Whole dataset (n={n_subjects} subjects)')
bp.savefig('plots/box_plot_all.png')
plt.show()