# Sleep Staging Models

---

Links to notebooks in this repository:

[Quickstart Tutorial](./quickstart_tutorial.ipynb) | [Introduction](./00_introduction.ipynb) | [Services](./01_services.ipynb) | [Sleep Staging](./02_sleep_staging.ipynb) | [Ensembling Sleep Staging](./03_ensembling_sleep_staging.ipynb) | [Sleep Dynamics](./04_sleep_dynamics.ipynb) | [Luna Toolbox Integration](./05_luna_integration.ipynb)

---

In this notebook, we show how to work with our tutorial EDF files (from the NSRR) and how to process your own uploaded data. We also show how to run multiple models on a batch of EDF files that you have moved into our shared `input` volume. We use the same helper functions introduced in the previous [Services](./01_services.ipynb) notebook and follow the same simple workflow: load the data, harmonize the signal channels, run the predictions, and visualize the results.


> Helper function for interacting with the SLEEPYLAND services. By wrapping the HTTP POST logic in one function, we can easily send data/parameters to various endpoints of the `manager-api`, simplifying the code in the rest of the notebook.

In [None]:
import requests
import shutil
import os
import pandas as pd
import numpy as np
import plotly.figure_factory as ff
import plotly.subplots as sp
import plotly.graph_objects as go
import yasa
from IPython.display import display  # used by visualize_results to render DataFrames

# Define the base URL for the manager-api
MANAGER_API_BASE_URL = "http://manager-api:8989"

def make_post_request(endpoint, data=None, params=None):
    """
    Helper function to make a POST request to the specified endpoint.

    Parameters:
        endpoint (str): The API endpoint to hit.
        data (dict): The form data to send in the request.
        params (dict): The URL parameters to send in the request.

    Returns:
        dict or None: The parsed JSON response if the request succeeds (HTTP 200), otherwise None.
    """
    url = f"{MANAGER_API_BASE_URL}/{endpoint}"
    response = requests.post(url, data=data, params=params)
    if response.status_code == 200:
        result = response.json()  # parse the body once and reuse it
        print("Success:", result)
        return result
    print(f"Failed with status code {response.status_code}: {response.text}")
    return None
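`make_post_request` sends `data` as form fields. When a value is a list (for example `eeg_channels` or `models`), `requests` encodes it as one repeated field per element, which is the same encoding the standard library produces with `urllib.parse.urlencode(..., doseq=True)`. The sketch below uses that stdlib call to preview a request body offline, without contacting `manager-api`; `preview_form_body` is a hypothetical helper and the values are placeholders, not output from the service.

```python
from urllib.parse import urlencode


def preview_form_body(data):
    """Return the URL-encoded form body for `data`, expanding list values
    into repeated fields (the encoding requests uses for form data)."""
    return urlencode(data, doseq=True)


# Placeholder values for illustration only
example = {
    "models": ["usleep", "yasa"],
    "emg_channels": [""],  # a list holding one empty string is still sent, as an empty field
    "resolution": "30",
}
print(preview_form_body(example))
# models=usleep&models=yasa&emg_channels=&resolution=30
```

This makes it easier to see why, for instance, `emg_channels = ['']` still reaches the server as a field with an empty value rather than being omitted.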

## Sleep staging on NSRR learn dataset

---

### Data loading


Let's first copy the tutorial `.edf` recordings and their corresponding `.xml` annotation files from their original location into the shared `../input/learn/` folder. This consolidates all the files needed for the upcoming analyses in one place that the pipeline can access.

In [None]:
# Define source and destination directories
source_dir = "../lunapi-notebooks/tutorial/edfs/"
destination_dir = "../input/learn/"

# Ensure destination directory exists
os.makedirs(destination_dir, exist_ok=True)

# Get a list of all .edf files in the source directory
edf_files = [f for f in os.listdir(source_dir) if f.endswith(".edf")]

# Copy each .edf file
for file in edf_files:
    shutil.copy(os.path.join(source_dir, file), destination_dir)

xml_files = [f for f in os.listdir(source_dir) if f.endswith(".xml")]

# Copy each .xml annotation file
for file in xml_files:
    shutil.copy(os.path.join(source_dir, file), destination_dir)

print(f"Copied {len(edf_files)} EDF files successfully!")
print(f"Copied {len(xml_files)} XML files successfully!")
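Because the evaluation below compares predictions with the `.xml` annotations, it can be worth checking that every copied EDF has an annotation file before calling the service. The helper below is a hypothetical sketch, not part of SLEEPYLAND: it assumes an annotation belongs to an EDF when the annotation's file name starts with the EDF's stem (e.g. `rec01.edf` and `rec01-annotations.xml`, illustrative names only). Adjust the rule to your own naming scheme.

```python
import os


def find_unannotated_edfs(file_names):
    """Return the EDF names with no .xml file whose name starts with the EDF's stem.

    The prefix-matching rule is an assumption; adapt it if your annotation
    files are named differently.
    """
    xml_names = [f for f in file_names if f.lower().endswith(".xml")]
    missing = []
    for edf in sorted(f for f in file_names if f.lower().endswith(".edf")):
        stem = os.path.splitext(edf)[0]
        if not any(x.startswith(stem) for x in xml_names):
            missing.append(edf)
    return missing


# Hypothetical file names, for illustration only
print(find_unannotated_edfs(["rec01.edf", "rec01-annotations.xml", "rec02.edf"]))
# ['rec02.edf']
```

In the notebook you would call it as `find_unannotated_edfs(os.listdir(destination_dir))`. Note that a plain prefix rule also pairs `rec1.edf` with `rec10.xml`, so make it stricter if your recording IDs overlap.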

### Sleep staging predictions


In the block below, we first use the `get_channels` function to determine which signal channels (EEG and/or EOG) are available in the `learn` dataset. We then pass these channel selections, together with the dataset name and a list of chosen models (e.g., `yasa` and `usleep`), to the `auto_evaluate_data` function, which calls the `auto_evaluate` endpoint. The service harmonizes the data (aligning and preparing the signals), runs sleep-stage predictions for every file in the `learn` dataset, and scores them against the reference annotations, saving ready-to-use results in the specified output folder (always retrievable from the `output` volume).

In [None]:
def get_channels(dataset):
    """
    Retrieve the available EEG, EOG, and EMG channels for the specified dataset.

    Parameters:
    dataset (str): Name of the dataset.

    Returns:
    dict: Dictionary containing available channels.
    """
    params = {'dataset': dataset}
    return make_post_request("get_channels", params=params)


def auto_evaluate_data(folder_root_name, output_folder_name, eeg_channels, eog_channels, emg_channels, dataset, models, resolution):
    """
    Perform both harmonization and evaluation using the specified models.

    Parameters:
    folder_root_name (str): Root folder containing the input data.
    output_folder_name (str): Folder where results will be saved.
    eeg_channels (list): List of EEG channels to use.
    eog_channels (list): List of EOG channels to use.
    emg_channels (list): List of EMG channels to use.
    dataset (str): Name of the dataset.
    models (list): List of models to apply for evaluation.
    resolution (str): Time interval (in seconds) for sleep stage predictions.

    Returns:
    response (dict): Response from the evaluation request.
    """
    data = {
        'folder_root_name': folder_root_name,
        'folder_name': output_folder_name,
        'eeg_channels': eeg_channels,
        'eog_channels': eog_channels,
        'emg_channels': emg_channels,
        'dataset': dataset,
        'models': models,
        'resolution': resolution
    }
    return make_post_request("auto_evaluate", data=data)


def plot_confusion_matrix(df_metrics):
    """
    Plot a confusion matrix using Plotly.

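    Parameters:
    df_metrics (pd.DataFrame): Model evaluation results, one row per (model, file),
                               as returned by `extract_metrics`.
    """
    for _, row in df_metrics.iterrows():
        cm = np.array(row['Confusion Matrix'])
        order = [0, 1, 2, 3, 4]  # Wake, N1, N2, N3, REM
        cm_reordered = cm[np.ix_(order, order)]
        labels = ["Wake", "N1", "N2", "N3", "REM"]

        # Cell annotations (values displayed as percentages); `cm_row` avoids reusing the name `row`
        text = [[f"{value:.2f}%" for value in cm_row] for cm_row in cm_reordered]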

        fig = go.Figure(data=go.Heatmap(
            z=cm_reordered,
            x=labels,
            y=labels,
            colorscale='Blues',
            showscale=True,
            text=text,
            texttemplate="%{text}",
        ))

        fig.update_layout(
            title=f"{row['Model']} - {row['File']}",
            xaxis_title="Predicted Label",
            yaxis_title="True Label",
            xaxis=dict(
                side='top'
            ),
            yaxis=dict(
                autorange='reversed'
            )
        )

        fig.show()


def extract_metrics(data):
    """
    Extract evaluation metrics from the model results.

    Parameters:
    data (list): List of dictionaries containing model evaluation results.

    Returns:
    pd.DataFrame: DataFrame with separate columns for each F1 score per class.
    """
    results = []
    for model in data:
        for model_name, files in model.items():
            for file_data in files:
                file_name = file_data['file']
                metrics = file_data['metrics']

                f1_list = metrics.get('f1_score_per_class', [])
                
                f1_wake = f1_list[0] if len(f1_list) > 0 else None
                f1_n1   = f1_list[1] if len(f1_list) > 1 else None
                f1_n2   = f1_list[2] if len(f1_list) > 2 else None
                f1_n3   = f1_list[3] if len(f1_list) > 3 else None
                f1_rem  = f1_list[4] if len(f1_list) > 4 else None

                results.append({
                    'Model': model_name,
                    'File': file_name,
                    'Accuracy': metrics['accuracy'],
                    'F1 Score': metrics['f1_score'],
                    'F1_Wake': f1_wake,
                    'F1_N1': f1_n1,
                    'F1_N2': f1_n2,
                    'F1_N3': f1_n3,
                    'F1_REM': f1_rem,
                    'Cohen Kappa': metrics['cohen_kappa'],
                    'Recall': metrics['recall'],
                    'Precision': metrics['precision'],
                    'Confusion Matrix': metrics['cm']
                })
    return pd.DataFrame(results)


def visualize_results(data):
    """
    Display evaluation results in a styled DataFrame and
    plot confusion matrices.

    Parameters:
    data (list): Evaluation results returned by `auto_evaluate_data`, a list of
                 {model_name: [per-file results]} dictionaries (accuracy, precision, recall, etc.).
    """
    df_metrics = extract_metrics(data)

    display(df_metrics.drop("Confusion Matrix", axis=1))
    
    # Finally, plot confusion matrices
    plot_confusion_matrix(df_metrics)


# Retrieve the available channels for the dataset
# Channels could include EEG and/or EOG derivations depending on the dataset.
dataset = 'learn'
response = get_channels(dataset)

# Extract EEG, EOG, and EMG channels from the response
eeg_channels = response["eeg_channels"]
eog_channels = response["eog_channels"]
emg_channels = ['']  # (not supported yet)

# Define the models to be used for evaluation
models = ['usleep', 'yasa']

# Define the epoch length (in seconds) for each sleep stage prediction - string formatted
sec_per_prediction = '30'

# Perform automatic evaluation using the selected models
response = auto_evaluate_data(dataset, 'output_learn', eeg_channels, eog_channels, emg_channels, dataset, models, sec_per_prediction)

# Visualize the results from the evaluation
visualize_results(response)
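The table reports standard multi-class agreement metrics. As a reference for reading it, the sketch below computes accuracy, Cohen's kappa and per-class F1 from two integer hypnograms, using textbook definitions and the notebook's stage coding (0=Wake, 1=N1, 2=N2, 3=N3, 4=REM). We have not checked that `manager-api` computes them identically (for example, how it averages F1 across classes or treats stages absent from a recording), so treat `agreement_metrics` as a hypothetical illustration, not a re-implementation of the service.

```python
import numpy as np


def agreement_metrics(y_true, y_pred, n_classes=5):
    """Accuracy, Cohen's kappa and per-class F1 from integer stage labels."""
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)

    # Confusion matrix: rows = true stage, columns = predicted stage
    cm = np.zeros((n_classes, n_classes), dtype=float)
    np.add.at(cm, (y_true, y_pred), 1)
    n = cm.sum()

    accuracy = np.trace(cm) / n
    # Chance agreement from the true (row) and predicted (column) marginals
    p_e = (cm.sum(axis=1) * cm.sum(axis=0)).sum() / n**2
    kappa = (accuracy - p_e) / (1 - p_e)

    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    denom = 2 * tp + fp + fn
    # F1 = 2TP / (2TP + FP + FN); set to 0 for stages absent from both sequences
    f1 = np.divide(2 * tp, denom, out=np.zeros_like(tp), where=denom > 0)
    return accuracy, kappa, f1


# Toy hypnograms (6 epochs), for illustration only
acc, kappa, f1 = agreement_metrics([0, 0, 1, 2, 2, 4], [0, 1, 1, 2, 2, 4])
print(acc, kappa, f1)
```

On this toy pair, five of six epochs agree (accuracy 5/6), chance agreement is 1/4, and kappa is therefore 7/9. Kappa is lower than accuracy because it discounts the agreement expected by chance.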

### Hypnograms and hypnodensity graphs

Use the `create_hypnogram_evaluate` function to compare, for each model and recording, the predicted hypnogram with the ground-truth hypnogram. After loading the prediction files (`majority` folder) and the reference labels (`TRUE_files` folder) from the `output_learn` directory in the `output` volume, it converts the model outputs to stage labels and plots both on the same timeline, so you can quickly see where each model agrees or disagrees with the annotations throughout the night. The `create_hypnodensity_graph` function then shows, epoch by epoch, the probability each model assigns to every stage.

In [None]:
def create_hypnogram_evaluate(folder_name, models_selected):
    """
    Compare predicted sleep stages with ground-truth annotations,
    plotting both on the same timeline for easy visual evaluation.

    Parameters:
    folder_name (str): The folder containing the model outputs and ground truth data.
    models_selected (list): List of models selected for evaluation (e.g., 'usleep', 'yasa').

    Returns:
    None: The function generates and displays an interactive plot comparing predicted and true sleep stages over time.
    """
    for model in models_selected:
        # Paths to the predicted outputs and true labels
        majority_folder = os.path.join('..', 'output', folder_name, model, 'majority')
        true_folder = os.path.join('..', 'output', folder_name, model, 'TRUE_files')

        # Gather the .npy files for predictions and for the ground truth
        majority_files = sorted([file for file in os.listdir(majority_folder) if file.endswith('.npy')])
        true_files = sorted(os.listdir(true_folder))

        for i, (maj_file, true_file) in enumerate(zip(majority_files, true_files)):
            # Load model predictions (argmax selects the stage with highest probability)
            sleep_stages_majority = np.load(os.path.join(majority_folder, maj_file)).argmax(-1).astype(int)
            # Load true labels (already in numeric form)
            sleep_stages_true = np.load(os.path.join(true_folder, true_file)).astype(int).ravel()

            # One point per scored epoch (the x-axis is the epoch index, not time in seconds)
            time = np.arange(len(sleep_stages_majority))

            # Map numeric labels to textual sleep stage names
            sleep_stage_labels = ['Wake', 'NREM1', 'NREM2', 'NREM3', 'REM']
            sleep_stages_labels_majority = [sleep_stage_labels[stage] for stage in sleep_stages_majority]
            sleep_stages_labels_true = [sleep_stage_labels[stage] for stage in sleep_stages_true]

            # Assign colors to each stage index for a visually clear plot
            colors = {
                0: '#58e306',  # Wake
                1: '#2cf7f0',  # NREM1
                2: '#1173ef',  # NREM2
                3: '#4b4d4d',  # NREM3
                4: '#ee0e0e'   # REM
            }

            # Combine both predicted and true labels in a single figure
            fig_combined = go.Figure()

            # Plot predicted labels over time
            fig_combined.add_trace(go.Scatter(
                x=time,
                y=sleep_stages_labels_majority,
                mode='lines+markers',
                line=dict(color='#bdc2c3', width=2, shape='hv'),
                marker=dict(size=5, color=[colors[stage] for stage in sleep_stages_majority]),
                name='Pred'
            ))

            # Plot true labels over time
            fig_combined.add_trace(go.Scatter(
                x=time,
                y=sleep_stages_labels_true,
                mode='lines+markers',
                line=dict(color='#1f77b4', width=2, shape='hv'),
                marker=dict(size=5, color=[colors[stage] for stage in sleep_stages_true]),
                name='True'
            ))

            # Configure axes and title for clarity
            fig_combined.update_layout(
                title=f"{maj_file.split('.')[0].split('_')[0]} (Pred vs True) - {model}",
                xaxis=dict(title='Sleep Epoch'),
                yaxis=dict(
                    title='Sleep Stage',
                    categoryorder='array',
                    categoryarray=['NREM3', 'NREM2', 'NREM1', 'REM', 'Wake']
                ),
                yaxis_range=[-0.5, 4.5]
            )

            # Display the interactive chart
            fig_combined.show()

# Call the function to compare predictions with ground truth for both 'usleep' and 'yasa'
create_hypnogram_evaluate("output_learn", ["usleep", "yasa"])
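`create_hypnogram_evaluate` pairs prediction and ground-truth files by their position in two sorted lists, which silently misaligns recordings if one folder contains an extra or missing file. A slightly safer sketch is to pair files by recording ID, assuming (as the plot titles do) that each file name starts with the recording ID followed by `_`, and that files in both folders share that prefix. `recording_id` and `pair_by_recording` are hypothetical helpers, not part of SLEEPYLAND.

```python
def recording_id(file_name):
    """Recording ID: the part of the name before the first '.' and then the first '_'."""
    return file_name.split(".")[0].split("_")[0]


def pair_by_recording(pred_files, true_files):
    """Return (pred, true) file pairs sharing a recording ID, sorted by ID,
    and report IDs that appear in only one of the two lists."""
    preds = {recording_id(f): f for f in pred_files}
    trues = {recording_id(f): f for f in true_files}
    for rec in sorted(preds.keys() ^ trues.keys()):
        print(f"Unpaired recording: {rec}")
    return [(preds[r], trues[r]) for r in sorted(preds.keys() & trues.keys())]


# Hypothetical file names, for illustration only
pairs = pair_by_recording(["b_pred.npy", "a_pred.npy"],
                          ["a_true.npy", "b_true.npy", "c_true.npy"])
print(pairs)
```

Inside the plotting loop you would iterate over `pair_by_recording(majority_files, true_files)` instead of `zip(majority_files, true_files)`.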

In [None]:
def create_hypnodensity_graph(folder_name, models_selected):
    """
    Generate a hypnodensity-style graph showing cumulative probability distributions
    across all sleep stages for each epoch.

    Parameters:
    folder_name (str): Name of the folder containing the output data.
    models_selected (list): List of model names to visualize.

    Returns:
    None: The function generates and displays an interactive hypnodensity plot for each model.
    """
    for model in models_selected:
        # Path where prediction data (.npy files) is saved
        majority_folder = os.path.join('..', 'output', folder_name, model, 'majority')

        # Collect all .npy files for the model from the majority folder
        majority_files = sorted([file for file in os.listdir(majority_folder) if file.endswith('.npy')])

        for i, maj_file in enumerate(majority_files):
            # Load the per-epoch stage probabilities (the stacked plot assumes each row sums to 1)
            sleep_probabilities_majority = np.load(os.path.join(majority_folder, maj_file))

            # Compute cumulative probabilities over the stages to create "stacked" areas
            cumulative_probs = np.cumsum(sleep_probabilities_majority, axis=1)

            colors = ['#364B9A', '#83B8D7', '#EAECCC', '#F99858', '#A50026']

            # Define names for each stage
            stage_names = ['Wake', 'N1', 'N2', 'N3', 'REM']

            fig = go.Figure()

            for j in range(cumulative_probs.shape[1]):
                fig.add_trace(go.Scatter(
                    x=np.arange(0, len(cumulative_probs)),
                    y=cumulative_probs[:, j],
                    mode='lines',
                    name=stage_names[j],
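                    line=dict(width=0, color=colors[j]),  # Set line color
                    fill='tonexty',  # stack on the previous curve (plotly fills to zero for the first trace)
                    # Convert HEX to RGBA with 50% opacity (`k` avoids reusing the outer loop variable `i`)
                    fillcolor=f'rgba{tuple(int(colors[j][k:k + 2], 16) for k in (1, 3, 5)) + (0.5,)}',
                    hoverinfo='none'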
                ))

            # Configure the layout with titles and axis labels
            fig.update_layout(
                title=f"{maj_file.split('.')[0].split('_')[0]} - {model}",
                xaxis_title='Sleep Epoch',
                yaxis_title='Cumulative Probability',
                yaxis_range=[0, 1],
                showlegend=True
            )

            # Display the plot
            fig.show()

# Example call to create hypnodensity graphs for the specified models
create_hypnodensity_graph("output_learn", ["usleep", "yasa"])
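The stacked hypnodensity only reads correctly if each row of the `majority` array is a probability distribution (non-negative, summing to 1), in which case the top curve sits at 1. If the saved arrays held logits or vote counts instead, `np.cumsum` would not produce a valid plot. The sketch below, assuming a `[n_epochs, n_stages]` array, divides non-negative rows by their sum and otherwise applies a row-wise softmax. `ensure_probabilities` is a hypothetical helper; we have not verified which of these formats SLEEPYLAND actually writes.

```python
import numpy as np


def ensure_probabilities(arr):
    """Return a [n_epochs, n_stages] array whose rows sum to 1.

    Assumption: non-negative rows (probabilities or vote counts) are divided
    by their sum; anything else is treated as logits and softmax-normalized.
    All-zero rows are left as zeros.
    """
    arr = np.asarray(arr, dtype=float)
    if np.all(arr >= 0):
        sums = arr.sum(axis=1, keepdims=True)
        return arr / np.where(sums > 0, sums, 1.0)
    shifted = arr - arr.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)
```

In the cell above, this would be used as `cumulative_probs = np.cumsum(ensure_probabilities(sleep_probabilities_majority), axis=1)`.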


### Sleep parameters computation

Use the `yasa.sleep_statistics` function to quickly compute standard AASM sleep parameters from both the predicted and the true hypnograms.

Below are the key parameters it returns (with all durations in **minutes**, except for the stage percentages and efficiencies):

- Time in Bed (TIB): Total duration of the hypnogram.  
- Sleep Period Time (SPT): Duration from the first to the last period of sleep.  
- Wake After Sleep Onset (WASO): Total wake time within SPT.  
- Total Sleep Time (TST): Sum of N1 + N2 + N3 + REM within SPT.  
- Sleep Efficiency (SE): TST / TIB * 100 (%).
- Sleep Maintenance Efficiency (SME): TST / SPT * 100 (%).
- W, N1, N2, N3, REM: Duration of each stage (NREM = N1 + N2 + N3).  
- Percentages % (W,..., REM): Duration of each stage expressed in % of TST.
- Latencies (e.g., `Lat_REM`, `Lat_N1`): Time from the beginning of the recording to the first epoch of each stage.  
- Sleep Onset Latency (SOL): Latency to the first epoch of any sleep.  

Note that YASA’s REM latency is measured from the start of the record, whereas the AASM definition measures it from the first epoch of sleep. To convert YASA’s `Lat_REM` to the AASM definition, compute `Lat_REM - SOL`.

This numeric summary gives a quick, clinically interpretable view of how each model's output compares with the ground truth.
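To make the definitions above concrete, here is a minimal pure-Python sketch that computes a subset of them from an integer hypnogram (0=Wake, 1=N1, 2=N2, 3=N3, 4=REM, the coding used throughout this notebook), including the AASM-style REM latency `Lat_REM - SOL`. `sleep_stats_sketch` is a hypothetical helper that follows the bullet definitions rather than YASA's source code, and it does not handle stage codes outside 0 to 4; use `yasa.sleep_statistics` for real analyses.

```python
def sleep_stats_sketch(hypno, epoch_sec=30):
    """Subset of the sleep parameters listed above; durations in minutes."""
    hypno = list(hypno)
    to_min = epoch_sec / 60
    sleep_idx = [i for i, s in enumerate(hypno) if s in (1, 2, 3, 4)]
    if not sleep_idx:
        raise ValueError("hypnogram contains no sleep epochs")
    first, last = sleep_idx[0], sleep_idx[-1]
    spt_epochs = hypno[first:last + 1]  # from the first to the last sleep epoch

    tib = len(hypno) * to_min
    spt = len(spt_epochs) * to_min
    waso = spt_epochs.count(0) * to_min
    tst = len(sleep_idx) * to_min  # N1 + N2 + N3 + REM, all inside SPT by construction
    sol = first * to_min
    rem_idx = [i for i, s in enumerate(hypno) if s == 4]
    lat_rem = rem_idx[0] * to_min if rem_idx else None  # measured from the start of the record
    return {
        "TIB": tib, "SPT": spt, "WASO": waso, "TST": tst,
        "SE": 100 * tst / tib, "SME": 100 * tst / spt,
        "SOL": sol, "Lat_REM": lat_rem,
        # AASM REM latency is measured from sleep onset instead
        "Lat_REM_AASM": None if lat_rem is None else lat_rem - sol,
    }


# Toy 14-epoch hypnogram (7 minutes), for illustration only
stats = sleep_stats_sketch([0, 0, 0, 0, 1, 2, 2, 2, 4, 4, 0, 2, 0, 0])
print(stats["SE"], stats["Lat_REM_AASM"])  # 50.0 2.0
```

Here sleep starts at epoch 4 (SOL = 2 min) and REM at epoch 8 (`Lat_REM` = 4 min from the start of the record), so the AASM REM latency is 2 min.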

In [None]:
def plot_sleep_stats(stats_pred, stats_true, file_name, model_name):
    """
    Given two dictionaries of sleep statistics from YASA (predicted vs. true),
    create a grouped bar chart comparing key time-based and percentage metrics.

    Parameters:
    stats_pred (dict): Dictionary of metrics for the predicted hypnogram.
    stats_true (dict): Dictionary of metrics for the ground-truth hypnogram.
    file_name (str): Name of the file being analyzed.
    model_name (str): Name of the model.

    Returns:
    None: The function generates and displays a grouped bar chart comparing predicted and true sleep statistics.
    """
    # Define two groups of metrics:
    #  - time_metrics are in minutes
    #  - percent_metrics are in percentages
    time_metrics = ["TIB", "SPT", "WASO", "TST", "N1", "N2", "N3", "REM", "NREM", "SOL"] 
    percent_metrics = ["%N1", "%N2", "%N3", "%REM", "SE", "SME"] 
    
    # Prepare values for each group, using .get() to safely handle missing keys.
    pred_time_values = [stats_pred.get(m, np.nan) for m in time_metrics]
    true_time_values = [stats_true.get(m, np.nan) for m in time_metrics]

    pred_percent_values = [stats_pred.get(m, np.nan) for m in percent_metrics]
    true_percent_values = [stats_true.get(m, np.nan) for m in percent_metrics]

    # Create a 1-row, 2-column subplot layout
    fig = sp.make_subplots(
        rows=1, cols=2,
        subplot_titles=["Time Metrics (minutes)", "Percentage Metrics"],
        shared_yaxes=False
    )

    # --- Left subplot: Time Metrics ---
    fig.add_trace(
        go.Bar(name='Predicted', x=time_metrics, y=pred_time_values, marker_color='steelblue'),
        row=1, col=1
    )
    fig.add_trace(
        go.Bar(name='True', x=time_metrics, y=true_time_values, marker_color='darkorange'),
        row=1, col=1
    )

    # --- Right subplot: Percentage Metrics ---
    fig.add_trace(
        go.Bar(name='Predicted', x=percent_metrics, y=pred_percent_values, marker_color='steelblue'),
        row=1, col=2
    )
    fig.add_trace(
        go.Bar(name='True', x=percent_metrics, y=true_percent_values, marker_color='darkorange'),
        row=1, col=2
    )

    # Update the layout for a nice grouped bar appearance
    fig.update_layout(
        title=f"Sleep Stats Comparison: {file_name.split('_')[0]} - {model_name}",
        barmode='group',
        width=950, height=400,
    )

    fig.show()

def compute_sleep_stats(folder_name, models_selected):
    """
    For each model in `models_selected` and each file in `folder_name`,
    load the predicted hypnogram and ground-truth labels, then compute
    YASA sleep statistics and visualize them in side-by-side bar charts.

    Parameters:
    folder_name (str): Name of the folder containing the output data.
    models_selected (list): List of models to compute and visualize sleep statistics for.

    Returns:
    None: The function computes YASA sleep statistics and generates plots for each model and file in the specified folder.
    """
    for model in models_selected:
        # Paths to the predicted outputs and true labels
        majority_folder = os.path.join('..', 'output', folder_name, model, 'majority')
        true_folder = os.path.join('..', 'output', folder_name, model, 'TRUE_files')

        # Gather .npy files for predictions and ground truth
        majority_files = sorted([file for file in os.listdir(majority_folder) if file.endswith('.npy')])
        true_files = sorted(os.listdir(true_folder))

        for maj_file, true_file in zip(majority_files, true_files):
            # Load model predictions (argmax if shape=[n_epochs, n_classes])
            sleep_stages_majority = np.load(os.path.join(majority_folder, maj_file)).argmax(-1).astype(int)
            # Load true labels (already 1D, numeric)
            sleep_stages_true = np.load(os.path.join(true_folder, true_file)).astype(int).ravel()

            # Compute YASA stats for predicted & true
            sf_hyp = 1 / 30.0  # each epoch = 30 seconds
            stats_pred = yasa.sleep_statistics(sleep_stages_majority, sf_hyp)
            stats_true = yasa.sleep_statistics(sleep_stages_true, sf_hyp)

            # Plot them side-by-side for easier comparison
            plot_sleep_stats(stats_pred, stats_true, maj_file, model)

            # Also print the raw statistics for quick inspection
            print("-" * 60)
            print(f"File: {maj_file} | Model: {model}")
            print("Predicted Hypnogram Stats:\n", stats_pred)
            print("True Hypnogram Stats:\n", stats_true)
            print("-" * 60)

compute_sleep_stats("output_learn", ["usleep", "yasa"])

## Sleep staging on your own EDF

---

In the second part of the notebook, we show how to run SLEEPYLAND's staging pipeline on your own uploaded `edf` files in a few steps, with no annotations needed. Following a procedure similar to the one used for the tutorial dataset, you can automatically harmonize your recordings, select the relevant channel types (EEG and/or EOG), and generate predictions with the model(s) of your choice.

> NOTE: The system takes as input the channel type(s) you specify (e.g., EEG), then automatically identifies and extracts all the recognized channels of that type from the `edf` file and forwards them to the pre-trained models. The output therefore contains predictions for every combination of the EEG and/or EOG channels the system recognized in the file. We suggest using the majority-vote predictions that the system also outputs (these are what the plotting functions below read from the `majority` folder).
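The exact aggregation SLEEPYLAND uses to build its `majority` output is not described here. As an illustration only, the sketch below shows two common ways to combine per-channel-combination predictions for one recording, assuming they are stacked in an array of shape `[n_combinations, n_epochs, n_stages]`. Hard voting picks the most frequent argmax stage per epoch, resolving ties toward the lower stage index. Soft voting takes the argmax of the averaged probabilities. `hard_vote` and `soft_vote` are hypothetical helpers.

```python
import numpy as np


def hard_vote(preds):
    """Most frequent argmax stage per epoch across channel combinations."""
    preds = np.asarray(preds)
    n_stages = preds.shape[-1]
    labels = preds.argmax(axis=-1)  # [n_combinations, n_epochs]
    # Count the votes each stage receives in each epoch -> [n_epochs, n_stages]
    votes = np.stack([(labels == s).sum(axis=0) for s in range(n_stages)], axis=-1)
    return votes.argmax(axis=-1)  # argmax returns the lowest stage index on ties


def soft_vote(preds):
    """Argmax of the mean probability vector per epoch."""
    return np.asarray(preds).mean(axis=0).argmax(axis=-1)


# Three hypothetical channel combinations, two epochs, five stages (one-hot for clarity)
preds = np.zeros((3, 2, 5))
preds[0, 0, 2] = preds[1, 0, 2] = preds[2, 0, 1] = 1  # epoch 0: N2, N2, N1
preds[0, 1, 4] = preds[1, 1, 0] = preds[2, 1, 4] = 1  # epoch 1: REM, Wake, REM
print(hard_vote(preds), soft_vote(preds))  # [2 4] [2 4]
```

With one-hot inputs the two rules agree. With real probabilities, soft voting also takes into account how confident each channel combination is, so the two can differ on ambiguous epochs.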

In [None]:
# Remove all files and subdirectories inside the input volume before starting.
# WARNING: this permanently deletes everything under this path, including dataset folders copied earlier.
directory = "/app/input"

# Iterate through all items in the directory
for item in os.listdir(directory):
    item_path = os.path.join(directory, item)
    if os.path.isfile(item_path):
        os.remove(item_path)
    elif os.path.isdir(item_path):
        shutil.rmtree(item_path)

### Data loading

Below is an example of how to use the `predict_on_my_edf` function with a single EDF file, `learn-nsrr01.edf`. First, we copy the file into the shared `input` volume, inside a dataset folder (in this case named `learn`) that is created if it does not already exist. Follow the same approach with your own data: choose or create a root folder in the shared `input` volume, then move all the EDF files you want to analyze into that folder before running predictions.


In [None]:
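# Define source and destination directories
# The destination folder name must match the `folder_root_name` passed to `predict_on_my_edf` below
source_dir = "../lunapi-notebooks/tutorial/edfs/"
destination_dir = "../input/learn/"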

# Ensure destination directory exists
os.makedirs(destination_dir, exist_ok=True)

# Get a list of all .edf files in the source directory
edf_files = [f for f in os.listdir(source_dir) if f.endswith(".edf")]

edf_files.sort()

# Copy one .edf file
file_to_copy = edf_files[0]
shutil.copy(os.path.join(source_dir, file_to_copy), destination_dir)

print(f"Copied {file_to_copy} successfully!")
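The same steps apply to your own recordings: pick a root folder name under the shared `input` volume and copy your EDF files into it. A hypothetical helper for this is sketched below. `stage_edfs` and its arguments are illustrative names, not part of SLEEPYLAND, and the default `input_root` assumes the `../input` path used elsewhere in this notebook.

```python
import os
import shutil


def stage_edfs(edf_paths, dataset_name, input_root="../input"):
    """Copy EDF files into <input_root>/<dataset_name>/ and return the copied file names."""
    target = os.path.join(input_root, dataset_name)
    os.makedirs(target, exist_ok=True)  # create the root folder if needed
    copied = []
    for path in edf_paths:
        if not path.lower().endswith(".edf"):
            print(f"Skipping non-EDF file: {path}")
            continue
        shutil.copy(path, target)
        copied.append(os.path.basename(path))
    return sorted(copied)
```

The `dataset_name` you choose here is the value you would later pass as `folder_root_name` when requesting predictions.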

### Sleep staging predictions

> **NOTE** - The exposed `predict_one` endpoint accepts only **one** EDF file at a time.
> Below, we show how to run the prediction on a single EDF file.


In [None]:
# Function to send prediction request for an EDF file
def predict_on_my_edf(folder_root_name, output_folder_name, channels_type, models, resolution):
    """
    Sends a request to perform prediction on an EDF file.

    Parameters:
    folder_root_name (str): Root directory containing the EDF file.
    output_folder_name (str): Directory where the prediction results will be saved.
    channels_type (list): List of channel types (e.g., EEG, EOG).
    models (list): List of models to use for prediction.
    resolution (str): Time interval (in seconds) for sleep stage predictions.

    Returns:
    dict or None: Response from the prediction request.
    """
    data = {
        'folder_root_name': folder_root_name,
        'folder_name': output_folder_name,
        'channels': channels_type,
        'models': models,
        'resolution': resolution
    }
    return make_post_request("predict_one", data=data)

In [None]:
# Define the models to use for prediction
models = ['usleep']

# Define the channel types to use for prediction
channels_type = ['EEG', 'EOG']

# Define the dataset name
dataset = 'learn'

# Define the epoch length (in seconds) for each sleep stage prediction - string formatted
sec_per_prediction = '5'

# Use the predict_on_my_edf function to perform prediction on the specified EDF file
predict_on_my_edf(dataset, 'output_my_edf', channels_type, models, sec_per_prediction)
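Since `predict_one` accepts a single EDF at a time, one way to process several recordings is to give each EDF its own root folder and call the endpoint once per folder. The sketch below assumes that `predict_one` processes the EDF found in `folder_root_name`, which matches how it is used above but is not documented here. `predict_each_edf` is a hypothetical helper. The prediction call is passed in as a function, so the loop can be tried without a running `manager-api`.

```python
import os
import shutil


def predict_each_edf(edf_paths, predict_fn, input_root="../input"):
    """Copy each EDF into its own folder under `input_root` and call
    predict_fn(folder_root_name, output_folder_name) once per recording.

    Returns a {recording_stem: output_folder_name} mapping.
    """
    outputs = {}
    for path in sorted(edf_paths):
        stem = os.path.splitext(os.path.basename(path))[0]
        folder = os.path.join(input_root, stem)
        os.makedirs(folder, exist_ok=True)
        shutil.copy(path, folder)
        output_folder = f"output_{stem}"
        predict_fn(stem, output_folder)  # one request per EDF
        outputs[stem] = output_folder
    return outputs
```

In the notebook, `predict_fn` could be `lambda root, out: predict_on_my_edf(root, out, channels_type, models, sec_per_prediction)`.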

### Hypnograms and hypnodensity graphs

In [None]:
def create_hypnogram_predict(folder_name, models_selected):
    """
    Plot a single hypnogram (predicted stages) for one or more models,
    given a folder of .npy prediction files.

    Parameters:
    folder_name (str): The name of the folder containing the model outputs
                       (e.g. 'output_my_edf').
    models_selected (list): List of models (e.g. ['usleep']).

    Returns:
    None: This function generates and displays a plot for the predicted hypnogram of each model.
    """
    for model in models_selected:
        # Path to the predicted outputs (majority folder)
        majority_folder = os.path.join('..', 'output', folder_name, model, 'majority')
        
        # Gather all .npy files for predictions
        majority_files = sorted([file for file in os.listdir(majority_folder)
                                 if file.endswith('.npy')])

        for maj_file in majority_files:
            # Load model predictions
            # If shape=[n_epochs, n_classes], take argmax to convert to integer-coded stages
            predictions = np.load(os.path.join(majority_folder, maj_file))
            if len(predictions.shape) == 2 and predictions.shape[1] > 1:
                # Argmax across last dimension if the file contains probabilities/logits
                sleep_stages_pred = predictions.argmax(axis=-1).astype(int)
            else:
                # Already integer-coded or single-class
                sleep_stages_pred = predictions.astype(int)

            # Create the corresponding time axis
            time = np.arange(len(sleep_stages_pred))

            # Map numeric labels to textual stage names (0=Wake,1=N1,2=N2,3=N3,4=REM)
            stage_names = ['Wake', 'NREM1', 'NREM2', 'NREM3', 'REM']
            sleep_stages_labels = [stage_names[st] for st in sleep_stages_pred]

            # Define colors for each stage index
            colors = {
                0: '#58e306',  # Wake
                1: '#2cf7f0',  # NREM1
                2: '#1173ef',  # NREM2
                3: '#4b4d4d',  # NREM3
                4: '#ee0e0e'   # REM
            }

            # Create a single Plotly figure (predictions only)
            fig = go.Figure()

            fig.add_trace(go.Scatter(
                x=time,
                y=sleep_stages_labels,
                mode='lines+markers',
                line=dict(color='#bdc2c3', width=2, shape='hv'),
                marker=dict(size=5, color=[colors[s] for s in sleep_stages_pred]),
                name='Pred'
            ))

            fig.update_layout(
                title=f"Predicted Hypnogram - {maj_file.split('_')[0]} - {model}",
                xaxis=dict(title='Sleep Epoch'),
                yaxis=dict(
                    title='Sleep Stage',
                    categoryorder='array',
                    categoryarray=['NREM3', 'NREM2', 'NREM1', 'REM', 'Wake']
                ),
                yaxis_range=[-0.5, 4.5]
            )
            fig.show()


create_hypnogram_predict("output_my_edf", ["usleep"])
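This run used `sec_per_prediction = '5'`, so each point in the hypnogram is a 5-second window rather than a standard 30-second epoch. To compare against 30-s scoring, one option is to average the probabilities of each block of six consecutive 5-s predictions and take the argmax, dropping an incomplete block at the end. This is our own assumption, not a SLEEPYLAND feature. `aggregate_epochs` is a hypothetical helper that expects a `[n_windows, n_stages]` probability array.

```python
import numpy as np


def aggregate_epochs(probs, factor=6):
    """Average consecutive groups of `factor` rows of a [n_windows, n_stages]
    probability array and return the most probable stage per group."""
    probs = np.asarray(probs, dtype=float)
    n_full = len(probs) // factor  # a trailing incomplete group is dropped
    grouped = probs[: n_full * factor].reshape(n_full, factor, probs.shape[1])
    return grouped.mean(axis=1).argmax(axis=-1)
```

Applied to a loaded `majority` array, `aggregate_epochs(array, factor=30 // int(sec_per_prediction))` would give one stage per 30-s epoch, which can then be analyzed with a 30-s epoch length.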

In [None]:
def create_hypnodensity_predict(folder_name, models_selected):
    """
    Generate a hypnodensity-style graph showing cumulative probability
    distributions across all sleep stages for each epoch (predicted only).

    Parameters:
    folder_name (str): The folder containing the .npy output data
                       (e.g. 'output_my_edf').
    models_selected (list): A list of model names (e.g. ['usleep']).

    Returns:
    None: This function generates and displays a hypnodensity graph for each model's predicted stages.
    """
    for model in models_selected:
        majority_folder = os.path.join('..', 'output', folder_name, model, 'majority')
        majority_files = sorted([f for f in os.listdir(majority_folder) if f.endswith('.npy')])

        for maj_file in majority_files:
            # Load the predicted data (probabilities or logits) for each epoch
            preds = np.load(os.path.join(majority_folder, maj_file))

            # Make sure preds is shape [n_epochs, n_stages]
            if len(preds.shape) != 2:
                print(f"Skipping {maj_file}: not in [epochs, stages] format.")
                continue

            # Create cumulative probabilities across columns
            cumulative_probs = np.cumsum(preds, axis=1)

            # Colors for each stage
            colors = ['#364B9A', '#83B8D7', '#EAECCC', '#F99858', '#A50026']
            stage_names = ['Wake', 'N1', 'N2', 'N3', 'REM']

            fig = go.Figure()

            # Add stacked areas
            for j in range(cumulative_probs.shape[1]):
                fig.add_trace(go.Scatter(
                    x=np.arange(len(cumulative_probs)),
                    y=cumulative_probs[:, j],
                    mode='lines',
                    name=stage_names[j],
                    line=dict(width=0, color=colors[j]),
                    fill='tozeroy' if j == 0 else 'tonexty',  # first stage fills down to 0, later stages stack on the previous curve
                    hoverinfo='none'
                ))

            fig.update_layout(
                title=f"Predicted Hypnodensity Graph: {maj_file.split('_')[0]} - {model}",
                xaxis_title='Sleep Epoch',
                yaxis_title='Cumulative Probability',
                yaxis_range=[0, 1],
                showlegend=True
            )

            fig.show()

create_hypnodensity_predict("output_my_edf", ["usleep"])

### Sleep parameters computation

In [None]:
def plot_sleep_stats_pred_only(stats_pred, file_name, model_name):
    """
    Create a grouped bar chart for key time-based and percentage metrics,
    but only for one predicted dataset (no ground truth available).

    Parameters:
    stats_pred (dict): YASA metrics for the predicted hypnogram
                        (e.g., from `yasa.sleep_statistics`).
    file_name (str): Name of the file (e.g., 'myrecord_PRED.npy').
    model_name (str): The model name (e.g., 'usleep').

    Returns:
    None: This function generates and displays a bar chart for the predicted sleep statistics.
    """
    # Metrics to display
    time_metrics = ["TIB", "SPT", "WASO", "TST", "N1", "N2", "N3", "REM", "NREM", "SOL"]
    percent_metrics = ["%N1", "%N2", "%N3", "%REM", "SE", "SME"]

    # Extract values, using np.nan for any missing key
    pred_time_values = [stats_pred.get(m, np.nan) for m in time_metrics]
    pred_percent_values = [stats_pred.get(m, np.nan) for m in percent_metrics]

    # Create a 1-row, 2-column layout
    fig = sp.make_subplots(
        rows=1, cols=2,
        subplot_titles=["Time Metrics (minutes)", "Percentage Metrics"]
    )

    # 1) Left subplot: Time Metrics
    fig.add_trace(
        go.Bar(
            name='Predicted',
            x=time_metrics,
            y=pred_time_values,
            marker_color='steelblue'
        ),
        row=1, col=1
    )

    # 2) Right subplot: Percentage Metrics
    fig.add_trace(
        go.Bar(
            name='Predicted',
            x=percent_metrics,
            y=pred_percent_values,
            marker_color='steelblue'
        ),
        row=1, col=2
    )

    # Remove .npy extension if present
    base_name, _ = os.path.splitext(file_name)

    fig.update_layout(
        title=f"Sleep Stats (No Ground Truth): {base_name} - {model_name}",
        barmode='group',
        width=950,
        height=400
    )
    fig.show()

def compute_predicted_sleep_stats(folder_name, models_selected, epoch_sec=30):
    """
    For each model in `models_selected` and each .npy file in `folder_name`,
    load the predicted hypnogram, compute YASA sleep statistics, and plot them.
    
    This version handles only predicted data (no true labels).
    
    Parameters:
    folder_name (str): Name of the output folder (e.g. 'output_my_edf').
    models_selected (list): List of model names (e.g. ['usleep']).
    epoch_sec (float): Duration of each epoch in seconds (default=30).

    Returns:
    None: This function loads predicted data, computes statistics, and plots them.
    """
    for model in models_selected:
        # Path to your predicted outputs (the 'majority' folder)
        majority_folder = os.path.join('..', 'output', folder_name, model, 'majority')
        
        # Collect all .npy files in that folder
        majority_files = sorted(f for f in os.listdir(majority_folder) if f.endswith('.npy'))
        
        for maj_file in majority_files:
            # Load the prediction array
            pred_array = np.load(os.path.join(majority_folder, maj_file))

            # If shape=[epochs, classes], argmax across last dim => integer-coded stages
            if pred_array.ndim == 2 and pred_array.shape[1] > 1:
                sleep_stages_pred = pred_array.argmax(axis=-1).astype(int)
            else:
                # Already integer-coded
                sleep_stages_pred = pred_array.astype(int)

            # Compute YASA stats
            sf_hyp = 1.0 / epoch_sec  # e.g., 1/30 = 0.0333 for 30-s epochs
            stats_pred = yasa.sleep_statistics(sleep_stages_pred, sf_hyp)

            # Plot the stats (predicted only)
            plot_sleep_stats_pred_only(stats_pred, maj_file, model)

            # Optionally, you can still print them in the console for quick inspection
            print("-" * 60)
            print(f"File: {maj_file} | Model: {model}")
            print("Predicted Hypnogram Stats:\n", stats_pred)
            print("-" * 60)

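# epoch_sec must match the resolution requested from predict_one (sec_per_prediction above)
compute_predicted_sleep_stats("output_my_edf", ["usleep"], epoch_sec=int(sec_per_prediction))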