# Introduction

Sleep is arguably one of the most mysterious yet essential aspects of our lives. Sleep deprivation can lead to problems such as impaired cognitive function. Unfortunately, sleep disorders affect many people, disrupting both the quantity and quality of sleep. One method for identifying sleep disorders is **analyzing sleep patterns**.

Currently, sleep stage annotation is typically performed by trained experts, who can take several hours to annotate a single sleep study session. This creates an opportunity for automation, where machine learning models could significantly reduce the time and effort required for this task.

However, one major challenge with using machine learning models in medical environments is their "black-box" nature, meaning their decision-making processes are often **not easily interpretable**. Recent research has aimed to address this issue by developing **interpretable models**, such as in this [study](https://arxiv.org/pdf/2105.11043.pdf), which uses transformers to enhance both performance and interpretability in sleep stage classification.

Building on this progress, the objective of this project is to employ a transformer-based model as the backbone and further enhance it by:
- exploring a **lightweight transformer** architecture ([Linformer](https://arxiv.org/abs/2006.04768)) for faster training and inference.
- including **personalization** using subject-specific information, such as gender and age, as sleep patterns are known to vary across different demographics, as mentioned [here](https://www.sleepfoundation.org/stages-of-sleep).
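
To illustrate why Linformer is lighter than standard self-attention, the following is a minimal single-head NumPy sketch, not the model used in this project. It assumes the core idea of the Linformer paper: keys and values are projected along the sequence dimension from length `n` to a fixed length `k_dim`, so the attention matrix has shape `(n, k_dim)` instead of `(n, n)`. The function and variable names are illustrative, and the projection matrices are random here, whereas they would be learned in a real model.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def linformer_attention(q, k, v, e_proj, f_proj):
    """Single-head Linformer-style attention (illustrative sketch).

    q, k, v        : arrays of shape (n, d)
    e_proj, f_proj : projection matrices of shape (k_dim, n)
    """
    d = q.shape[-1]
    k_low = e_proj @ k                  # (k_dim, d): keys compressed along the sequence axis
    v_low = f_proj @ v                  # (k_dim, d): values compressed along the sequence axis
    scores = q @ k_low.T / np.sqrt(d)   # (n, k_dim) instead of (n, n)
    weights = softmax(scores, axis=-1)
    return weights @ v_low, weights     # output has shape (n, d)


rng = np.random.default_rng(0)
n, d, k_dim = 128, 16, 8                # hypothetical sizes
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
e_proj = rng.standard_normal((k_dim, n)) / np.sqrt(n)
f_proj = rng.standard_normal((k_dim, n)) / np.sqrt(n)

out, weights = linformer_attention(q, k, v, e_proj, f_proj)
print(out.shape, weights.shape)
```

With these sizes the attention weights hold `128 * 8` entries rather than `128 * 128`, which is where the savings in time and memory come from.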

## Dataset

## Definitions

Since this project contains a lot of domain-specific terminology, this section is dedicated to listing all the terms in the sleep-staging and transformer modeling domains.

### Sleep-related terminology

- **Polysomnography (PSG)**: a comprehensive diagnostic tool used in sleep studies to record various physiological functions during sleep.
- **PSG epoch**: a 30-second frame of a PSG recording.
- **PSG sequence**: the series of consecutive PSG epochs that make up one recording session.
- **Electroencephalography (EEG)**: a measurement of brain wave activity, which helps to determine the different sleep stages. It is one of the components of PSG.
- **Fpz-Cz EEG**: the *electrode placement* on the scalp used to record the EEG signal. **Fpz** refers to the midpoint of the forehead, while **Cz** is the center of the scalp. See the picture below.

<img src="https://www.orimtec.com/images/illustration/pop/eeg/electrode_placement_diagram.jpg" alt="EEG 10-20 placement" width="200" height="200">
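
To make the epoch definition concrete, here is a small sketch of cutting a continuous single-channel signal into 30-second epochs. It assumes a 100 Hz sampling rate (the rate reported for the recordings in this notebook) and uses a synthetic signal. The helper name `split_into_epochs` is hypothetical and not part of the dataset's tooling.

```python
import numpy as np

EPOCH_SECONDS = 30  # length of one PSG epoch, as defined above


def split_into_epochs(signal, fs, epoch_seconds=EPOCH_SECONDS):
    """Split a 1-D signal into non-overlapping epochs.

    Returns an array of shape (n_epochs, samples_per_epoch, 1).
    Trailing samples that do not fill a whole epoch are discarded.
    """
    samples_per_epoch = int(round(fs * epoch_seconds))
    n_epochs = len(signal) // samples_per_epoch
    trimmed = signal[: n_epochs * samples_per_epoch]
    return trimmed.reshape(n_epochs, samples_per_epoch, 1)


fs = 100.0                                                        # Hz (assumed)
signal = np.random.default_rng(0).standard_normal(int(fs * 95))   # 95 s of synthetic data
epochs = split_into_epochs(signal, fs)
print(epochs.shape)  # (3, 3000, 1)
```

At 100 Hz, each 30-second epoch holds 3000 samples, so 95 seconds of signal yields three full epochs and the remaining 5 seconds are dropped.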

In general, sleep stage annotations are divided into eight categories:


1. **Wake (W)**: the stage when a person is awake. It is characterized by *low-voltage, mixed-frequency brain waves* on the EEG.
2. **Stage 1 Non-REM Sleep (N1)**: the *lightest stage of non-REM sleep*. It marks the transition from wakefulness to sleep.
3. **Stage 2 Non-REM Sleep (N2)**: a deeper stage of sleep characterized by *sleep spindles* (sudden bursts of brain activity) and *K-complexes* (large waves that occur in response to external stimuli).
4. **Stage 3 Non-REM Sleep (N3)**: the deepest stage of non-REM sleep.
5. **Stage 4 Non-REM Sleep (N4)**: historically, N3 and N4 were scored as separate stages, but the current AASM scoring rules merge them into a single stage (N3).
6. **Rapid Eye Movement (REM)**: a unique sleep stage characterized by rapid eye movements, paradoxical brain activity (similar to being awake), and muscle atonia (loss of muscle tone).
7. **Movement**: refers to periods of muscle activity or movement during sleep.
8. **Unknown**: when the sleep stage is not clear or cannot be classified due to poor signal quality, artifacts, or insufficient data.
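
Since N3 and N4 are now treated as one stage, a common preprocessing step is to collapse the eight categories into five classes. The sketch below shows one way to do this. It assumes that Movement and Unknown epochs are simply discarded, which is a modeling choice rather than something this notebook prescribes. The label strings (`"MOVE"`, `"UNK"`, and so on) and the integer class ids are hypothetical.

```python
import numpy as np

# Hypothetical mapping: N3 and N4 share one class id.
STAGE_TO_CLASS = {"W": 0, "N1": 1, "N2": 2, "N3": 3, "N4": 3, "REM": 4}


def to_five_classes(stages):
    """Map stage labels to five classes and drop unmapped labels.

    Returns (classes, keep), where `keep` is a boolean mask over the input
    marking the epochs that were retained.
    """
    keep = np.array([str(s) in STAGE_TO_CLASS for s in stages], dtype=bool)
    classes = np.array(
        [STAGE_TO_CLASS[str(s)] for s, k in zip(stages, keep) if k], dtype=int
    )
    return classes, keep


stages = ["W", "N1", "N2", "N3", "N4", "REM", "MOVE", "UNK", "N2"]
classes, keep = to_five_classes(stages)
print(classes)     # [0 1 2 3 3 4 2]
print(keep.sum())  # 7
```

The `keep` mask can be applied to the signal array as well, so that signals and labels stay aligned after the discarded epochs are removed.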




# Exploratory Data Analysis

In [36]:
from google.cloud import storage

import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import numpy as np
import io
import os

In [23]:
# Helper functions

def download_data_from_gcs(bucket, file_name):
    """
    Download a file from a GCS bucket and load it into memory.
    The file is expected to be in NumPy's .npz format.
    """

    blob = bucket.blob(file_name)
    npz_bytes = blob.download_as_bytes()
    data = np.load(io.BytesIO(npz_bytes), allow_pickle=True)

    return data
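
The loading step of the helper above can be tried without GCS access. The sketch below builds a small `.npz` archive in memory and reads it back through `io.BytesIO`, the same way the downloaded bytes are handled. The array contents are synthetic and the function name `load_npz_from_bytes` is hypothetical.

```python
import io

import numpy as np


def load_npz_from_bytes(npz_bytes):
    """Load a .npz archive from raw bytes, as done after the GCS download."""
    return np.load(io.BytesIO(npz_bytes), allow_pickle=True)


# Build a tiny synthetic archive in memory instead of downloading one.
buffer = io.BytesIO()
np.savez(buffer, x=np.zeros((2, 3000, 1)), y=np.array([0, 1]), fs=100.0)

data = load_npz_from_bytes(buffer.getvalue())
print(data.files)
print(data["x"].shape, data["y"].shape, float(data["fs"]))
```

Note that `allow_pickle=True` should only be used with trusted files, since unpickling can execute arbitrary code.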

In [4]:
# Create a GCS client and get the specified bucket.
# PROJECT_ID and BUCKET_NAME are assumed to be defined earlier in the notebook.
client = storage.Client(project=PROJECT_ID)
bucket = client.get_bucket(BUCKET_NAME)

In [16]:
# List the data files.
# Each file contains one PSG recording. The dataset covers 78 subjects,
# with each subject having a 2-day recording.

data_files = !gsutil ls gs://{BUCKET_NAME}
# Keep only the .npz files (the original check `f[-4:]` was always truthy).
data_files = [os.path.basename(f) for f in data_files if f.endswith(".npz")]
print(f"Number of files: {len(data_files)}")

Number of files: 2


In [24]:
# Load the data

data = []
for f in data_files:
    data.append(download_data_from_gcs(bucket, f))

In [35]:
# Sneak peek into the structure of the data
sample = data[0]

print(f"EEG signal shape   : {sample['x'].shape}")
print(f"EEG label shape    : {sample['y'].shape}")
print(f"EEG channel labels : {sample['ch_label']}")
print(f"Sampling frequency : {sample['fs']} Hz")
print(f"Subject details    : {sample['header_raw'].item()['local_subject_id']}")

EEG signal shape   : (841, 3000, 1)
EEG label shape    : (841,)
EEG channel labels : EEG Fpz-Cz
Sampling frequency : 100.0 Hz
Subject details    : X F X Female_33yr
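
The output above already contains what the personalization step needs. The sketch below extracts gender and age from the subject string and checks the recording length implied by the shapes. It assumes that every subject string ends with a `<Gender>_<age>yr` token, as in the single example shown here; other recordings may not follow this format. Both helper names are hypothetical.

```python
import re


def parse_subject_details(details):
    """Extract (gender, age) from a string such as 'X F X Female_33yr'.

    Returns (None, None) if the expected pattern is not found.
    """
    match = re.search(r"(Male|Female)_(\d+)yr", details)
    if match is None:
        return None, None
    return match.group(1), int(match.group(2))


def recording_hours(n_epochs, samples_per_epoch, fs):
    """Total recording duration in hours."""
    return n_epochs * samples_per_epoch / fs / 3600


gender, age = parse_subject_details("X F X Female_33yr")
print(gender, age)       # Female 33

hours = recording_hours(n_epochs=841, samples_per_epoch=3000, fs=100.0)
print(round(hours, 2))   # 7.01
```

Each epoch spans 3000 / 100 = 30 seconds, so the 841 epochs in this sample correspond to roughly seven hours of recording.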


# References

- https://arxiv.org/pdf/2105.11043
- https://researchdata.ntu.edu.sg/dataverse/attnSleep
- https://www.youtube.com/watch?v=ISNdQcPhsts&t=5553s