# Inference

> Run pftsleep sleep-stage inference on EDF files, either one file at a time or over a whole dataset.

In [None]:
#| default_exp inference

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os, torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from pftsleep.train import PatchTFTSleepStage, PatchTFTSimpleLightning
from pftsleep.heads import RNNProbingHeadExperimental
from pftsleep.inference_tools import process_edf, CLASSIFIER_HEAD_DEFAULTS, FREQUENCY_DEFAULT, SEQUENCE_LENGTH_SECONDS_DEFAULT, FREQUENCY_FILTERS_ORDERED_DEFAULT, MEDIAN_FILTER_KERNEL_SIZE_DEFAULT

In [None]:
#| export
def load_pftsleep_models(models_dir='', # The directory of the saved models
                         encoder_model_name='pft_sleep_encoder.ckpt', # the name of the encoder model
                         classifier_model_name='pft_sleep_classifier.ckpt', # the name of the classifier model
                         classifier_head_defaults=CLASSIFIER_HEAD_DEFAULTS, # the defaults for the classifier head, DO NOT CHANGE!
                        ):
    """
    Load the pretrained pftsleep encoder and the sleep-stage classifier built on top of it.

    Both checkpoints are loaded with `map_location='cpu'`; move the returned models to
    another device yourself if needed.

    Args:
        models_dir (str): Directory containing both checkpoint files ('' means the current directory)
        encoder_model_name (str): File name of the encoder checkpoint inside `models_dir`
        classifier_model_name (str): File name of the classifier checkpoint inside `models_dir`
        classifier_head_defaults (dict): Keyword arguments used to construct the classifier's
            RNN probing head, DO NOT CHANGE!
    Returns:
        encoder (PatchTFTSimpleLightning): The encoder model
        ss_classifier (PatchTFTSleepStage): The classifier model, wrapping `encoder`
    """
    encoder_ckpt = os.path.join(models_dir, encoder_model_name)
    classifier_ckpt = os.path.join(models_dir, classifier_model_name)

    # The classifier checkpoint is loaded around the already-loaded encoder and a probing head
    # constructed from `classifier_head_defaults`, so the encoder must be loaded first.
    encoder = PatchTFTSimpleLightning.load_from_checkpoint(encoder_ckpt, map_location='cpu')
    probing_head = RNNProbingHeadExperimental(**classifier_head_defaults)
    ss_classifier = PatchTFTSleepStage.load_from_checkpoint(
        classifier_ckpt,
        preloaded_model=encoder,
        map_location='cpu',
        linear_probing_head=probing_head,
    )
    return encoder, ss_classifier

In [None]:
#| export
def infer_on_edf(edf_file_path, # The edf file path to perform inference on
                eeg_channel, # the EEG channel name in the EDF. The model was trained with C4-M1 and C3-M2 referenced EEG channels.
                left_eog_channel, # the left EOG channel name in the EDF. The model was trained with M2 referenced left EOG channels.
                chin_emg_channel, # the chin EMG channel name in the EDF. The model was trained with chin referenced (chin 2 or chin 3) EMG channels.
                ecg_channel, # the ECG channel name in the EDF. The model was trained with augmented lead 2 ECG channels.
                spo2_channel, # the SpO2 channel name in the EDF.
                abdomen_rr_channel, # the abdomen RR channel name in the EDF.
                thoracic_rr_channel, # the thoracic RR channel name in the EDF.
                eeg_reference_channel=None, # the EEG reference channel name in the EDF. The model was trained with C4-M1 and C3-M2 referenced EEG channels. This will reference the channel, if it hasn't already been referenced.
                left_eog_reference_channel=None, # the left EOG reference channel name in the EDF. The model was trained with M2 referenced left EOG channels. This will reference the channel, if it hasn't already been referenced.
                chin_emg_reference_channel=None, # the chin EMG reference channel name in the EDF. The model was trained with chin referenced (chin 2 or chin 3) EMG channels. This will reference the channel, if it hasn't already been referenced.
                ecg_reference_channel=None, # the ECG reference channel name in the EDF. The model was trained with augmented lead 2 ECG channels. This will reference the channel, if it hasn't already been referenced.
                models_dir='', # the directory of the saved models
                encoder_model_name='pft_sleep_encoder.ckpt', # the name of the encoder model
                classifier_model_name='pft_sleep_classifier.ckpt', # the name of the classifier model
                device="cpu", # the device to run the model on
                **kwargs
                ):
    """
    Performs inference on a single edf file using the pftsleep models.
    If you specify a channel as None or 'dummy', the channel will be passed through as a zero vector. This allows you to use the model even if some channels are not present in the edf file.

    Args:
        edf_file_path (str): The path to the edf file to perform inference on
        eeg_channel (str | None): The name of the EEG channel in the EDF
        left_eog_channel (str | None): The name of the left EOG channel in the EDF
        chin_emg_channel (str | None): The name of the chin EMG channel in the EDF
        ecg_channel (str | None): The name of the ECG channel in the EDF
        spo2_channel (str | None): The name of the SpO2 channel in the EDF
        abdomen_rr_channel (str | None): The name of the abdomen RR channel in the EDF
        thoracic_rr_channel (str | None): The name of the thoracic RR channel in the EDF
        eeg_reference_channel (str | None): The name of the EEG reference channel in the EDF
        left_eog_reference_channel (str | None): The name of the left EOG reference channel in the EDF
        chin_emg_reference_channel (str | None): The name of the chin EMG reference channel in the EDF
        ecg_reference_channel (str | None): The name of the ECG reference channel in the EDF
            (reference channels are only forwarded for modalities whose channel is not None/'dummy')
        models_dir (str): The directory of the saved models
        encoder_model_name (str): The name of the encoder model
        classifier_model_name (str): The name of the classifier model
        device (str): The device to run the model on
        **kwargs: Additional keyword arguments for process_edf function

    Returns:
        out (torch.Tensor): The sleep stage logit outputs of the classifier for each sleep epoch in the edf file,
            with the leading batch dimension removed. The tensor is left on `device`.

    Raises:
        ValueError: if a model checkpoint or the edf file does not exist, or if loading the models,
            processing the edf file, or running the classifier fails (the original error is chained as the cause).
    """
    encoder_path = os.path.join(models_dir, encoder_model_name)
    classifier_path = os.path.join(models_dir, classifier_model_name)
    # Report full paths: models_dir defaults to '' which would make a bare directory message useless.
    if not os.path.exists(encoder_path):
        raise ValueError(f"Encoder model not found at {encoder_path}")
    if not os.path.exists(classifier_path):
        raise ValueError(f"Classifier model not found at {classifier_path}")
    if not os.path.exists(edf_file_path):
        raise ValueError(f"EDF file not found at {edf_file_path}")

    try:
        _, ss_classifier = load_pftsleep_models(models_dir, encoder_model_name, classifier_model_name)
    except Exception as e:
        raise ValueError(f"Trouble loading models: {e}") from e

    # Channel order expected by the model; missing modalities (None) are normalized to 'dummy'.
    channels = [ecg_channel, left_eog_channel, chin_emg_channel, eeg_channel, spo2_channel, thoracic_rr_channel, abdomen_rr_channel]
    channels = ['dummy' if c is None else c for c in channels]

    # Only real channels can be re-referenced. Skipping None/'dummy' modalities avoids meaningless
    # entries that would otherwise collide under a shared None/'dummy' key.
    referenced_pairs = (
        (ecg_channel, ecg_reference_channel),
        (left_eog_channel, left_eog_reference_channel),
        (chin_emg_channel, chin_emg_reference_channel),
        (eeg_channel, eeg_reference_channel),
    )
    reference_channels_dict = {
        channel: reference
        for channel, reference in referenced_pairs
        if channel is not None and channel != 'dummy'
    }

    try:
        signals, sequence_padding_mask = process_edf(edf_file_path,
                                                     channels=channels,
                                                     reference_channels_dict=reference_channels_dict,
                                                     **kwargs
                                                     )
    except Exception as e:
        raise ValueError(f"Trouble processing edf file: {e}") from e

    try:
        # fine-tuned sleep stage classifier, recommend using a GPU
        ss_classifier = ss_classifier.to(device)
        ss_classifier.eval()

        with torch.no_grad():
            # add a batch dimension of 1 for the single recording
            signals = signals.unsqueeze(0).to(device)
            sequence_padding_mask = sequence_padding_mask.unsqueeze(0).to(device)
            out = ss_classifier(signals, sequence_padding_mask=sequence_padding_mask)
        return out.squeeze(0)
    except Exception as e:
        raise ValueError(f"Trouble inferring on edf file: {e}") from e

In [None]:
# output = infer_on_edf(edf_file_path='/Users/benfox/Downloads/CROWD_133_RR_170927.edf', 
#              eeg_channel='C4-M1', 
#              left_eog_channel='E1-M2', 
#              chin_emg_channel='EMG1-EMG2', 
#              ecg_channel='ECG1-ECG2', 
#              spo2_channel='SPO2', 
#              abdomen_rr_channel='ABDO', 
#              thoracic_rr_channel='dummy',
#              device='mps'
#              )

EDF file: ~/Downloads/CROWD_133_RR_170927.edf
Duration: 26393
Number of channels: 31
Available channels: ['F3-M2', 'F4-M1', 'C3-M2', 'C4-M1', 'O1-M2', 'O2-M1', 'ECG1-ECG2', 'EMG1-EMG2', 'LLEG1 - LLEG2', 'RLEG1 - RLEG2', 'E1-M2', 'E2-M2', 'SNORE1', 'AIRFLOW', 'ABDO', 'THERM', 'SNORE', 'THOR EFFORT', 'ABDO EFFORT', 'POSITION', 'OX STATUS', 'PULSE', 'SPO2', 'NASAL PRESSURE', 'CPAP FLOW', 'PRO-TECH', 'CPAP PRESS', 'PRO-TECH DC', 'PLETH', 'HRATE', 'LIGHT']
Signal headers: [{'label': 'SpO2', 'dimension': '%', 'sample_rate': 125, 'sample_frequency': 125, 'physical_max': 100.0, 'physical_min': 0.0, 'digital_max': 32767, 'digital_min': -32768, 'prefilter': '', 'transducer': ''}, {'label': 'E1-M2', 'dimension': 'uV', 'sample_rate': 125, 'sample_frequency': 125, 'physical_max': 500.0, 'physical_min': -500.0, 'digital_max': 32767, 'digital_min': -32768, 'prefilter': '', 'transducer': 'EEG E1'}, {'label': 'EMG1-EMG2', 'dimension': 'mV', 'sample_rate': 125, 'sample_frequency': 125, 'physical_max': 0

In [None]:
#| export
def infer_on_edf_dataset(edf_dataset, # the edf dataset to perform inference on
                        device='cpu', # the device to run the model on
                        num_workers=1, # the number of workers to use for inference
                        batch_size=1, # the batch size to use for inference
                        models_dir='', # the directory of the saved models
                        encoder_model_name='pft_sleep_encoder.ckpt', # the name of the encoder model
                        classifier_model_name='pft_sleep_classifier.ckpt', # the name of the classifier model
                        ):
    """
    Performs inference on an EDFDataset.

    Each batch produced by the DataLoader is unpacked as `(signals, sequence_padding_mask)`,
    so every dataset item is expected to yield that pair.

    Args:
        edf_dataset (EDFDataset): The dataset to perform inference on
        device (str): The device to run the model on
        num_workers (int): Number of DataLoader worker processes. With the 'spawn' start method
            (e.g. macOS), workers must be able to import the dataset's class; a class defined only
            in `__main__`/the notebook fails with "Can't get attribute ..." in the workers. Use 0 to
            load in the main process.
        batch_size (int): The batch size to use for inference
        models_dir (str): The directory of the saved models
        encoder_model_name (str): The name of the encoder model
        classifier_model_name (str): The name of the classifier model

    Returns:
        preds (list[torch.Tensor]): One tensor of predicted sleep stage logits per DataLoader batch
            (i.e. up to `batch_size` edf files each, in dataset order), moved to the CPU.
    """
    _, ss_classifier = load_pftsleep_models(models_dir, encoder_model_name, classifier_model_name)
    # shuffle defaults to False, so predictions come back in dataset order
    data_loader = DataLoader(edf_dataset, batch_size=batch_size, pin_memory=False, persistent_workers=False, num_workers=num_workers)
    preds = []
    ss_classifier = ss_classifier.to(device)
    ss_classifier.eval()
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Predicting"):
            x, sequence_padding_mask = batch
            x = x.to(device)
            sequence_padding_mask = sequence_padding_mask.to(device)
            pred = ss_classifier(x, sequence_padding_mask=sequence_padding_mask)
            # move results off the accelerator so accumulated predictions don't hold device memory
            preds.append(pred.cpu())
    return preds

In [None]:
# dataset = EDFDataset(edf_file_paths=['/Users/benfox/Downloads/CROWD_133_RR_170927.edf'], 
#              eeg_channel='C4-M1', 
#              left_eog_channel='E1-M2', 
#              chin_emg_channel='EMG1-EMG2', 
#              ecg_channel='ECG1-ECG2', 
#              spo2_channel='SPO2', 
#              abdomen_rr_channel='ABDO', 
#              thoracic_rr_channel='dummy')

In [None]:
# data_loader = DataLoader(dataset, batch_size=6, pin_memory=False, persistent_workers=False, num_workers=8)

In [None]:
# next(iter(data_loader))

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/miniconda3/envs/pftsleep/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/miniconda3/envs/pftsleep/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'EDFDataset' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/miniconda3/envs/pftsleep/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/miniconda3/envs/pftsleep/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'EDFDataset' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/miniconda3/envs/p

RuntimeError: DataLoader worker (pid(s) 74931) exited unexpectedly

In [None]:
# preds = infer_on_edf_dataset(dataset, device='mps')

Predicting:   0%|          | 0/1 [00:00<?, ?it/s]Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/miniconda3/envs/pftsleep/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/miniconda3/envs/pftsleep/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'EDFDataset' on <module '__main__' (built-in)>
Predicting:   0%|          | 0/1 [00:01<?, ?it/s]


RuntimeError: DataLoader worker (pid(s) 73888) exited unexpectedly

In [None]:
#| hide
# Export this notebook's `#| export` cells to the library module named by `#| default_exp` (nbdev).
import nbdev; nbdev.nbdev_export()