# Inference

> Run the pretrained PFTSleep models on EDF files to predict sleep stages.

In [None]:
#| default_exp inference

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os, torch, numpy as np, warnings, edfio, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from pftsleep.train import PatchTFTSleepStage, PatchTFTSimpleLightning
from pftsleep.heads import RNNProbingHeadExperimental

from scipy.signal import resample
from huggingface_hub import hf_hub_download
from scipy.ndimage import median_filter

from pftsleep.signal import butterworth
from pathlib import Path
import json

# Sampling rate (Hz) that every channel is resampled to before inference.
FREQUENCY_DEFAULT = 125
# Hypnogram defaults: sampling rate (Hz), scored epoch length (seconds) and padding value.
HYPNOGRAM_FREQUENCY_DEFAULT = 1
HYPNOGRAM_EPOCH_SECONDS_DEFAULT = 30
HYPNOGRAM_PADDING_DEFAULT = -100
# Input length (seconds) expected by the model: 8 hours.
SEQUENCE_LENGTH_SECONDS_DEFAULT = 8 * 3600

# Default hyper-parameters of the PFTSleep encoder.
ENCODER_DEFAULTS = {
    'c_in': 7,
    'win_length': 750,
    'hop_length': 750,
    'max_seq_len': SEQUENCE_LENGTH_SECONDS_DEFAULT * FREQUENCY_DEFAULT,  # 8 h at 125 Hz
    'use_revin': True,
    'dim1reduce': False,
    'use_flash_attn': False,
    'affine': True,  # original note: still to be tested with both True and False
    'augmentations': ['jitter_zero_mask'],  # other options noted originally: 'reverse_sequence', 'shuffle_channels'
    'mask_ratio': 0.1,
    'n_layers': 3,
    'd_model': 512,
    'n_heads': 4,
    'shared_embedding': False,
    'd_ff': 2048,
    'norm': 'BatchNorm',
    'attn_dropout': 0.0,
    'dropout': 0.1,
    'act': 'gelu',
    'res_attention': True,
    'pre_norm': False,
    'store_attn': False,
    'pretrain_head': True,
    'pretrain_head_n_layers': 1,
    'pretrain_head_dropout': 0.0,
}

# Hyper-parameters of the sleep-stage classifier head (load_pftsleep_models: "DO NOT CHANGE!").
CLASSIFIER_HEAD_DEFAULTS = {
    'c_in': 7,
    'input_size': 512,
    'hidden_size': 1024,
    'predict_every_n_patches': 5,
    'n_classes': 5,
    'num_rnn_layers': 2,
    'contrastive': False,
    'rnn_dropout': 0.1,
    'module': 'GRU',
    'bidirectional': True,
    'affine': True,
    'pool': 'average',
    'pre_norm': False,
    'mlp_final_head': True,
    'linear_dropout': 0.1,
    'temperature': 2,
}

# Per-channel [low, high] filter settings, in model input order.
# process_edf applies [cutoff, None] as a low-pass filter at `cutoff`.
FREQUENCY_FILTERS_ORDERED_DEFAULT = [
    [0.5, 40],    # ECG
    [0.3, 30],    # EOG L
    [0.3, 30],    # Chin EMG
    [0.3, 30],    # EEG
    [0.4, None],  # SpO2
    [0.5, None],  # Thoracic RR
    [0.5, None],  # Abdominal RR
]

# Default median-filter kernel size.
MEDIAN_FILTER_KERNEL_SIZE_DEFAULT = 3

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export
def load_pftsleep_models(models_dir='', # The directory of the saved models
                         encoder_model_name='pft_sleep_encoder.ckpt', # the name of the encoder model
                         classifier_model_name='pft_sleep_classifier.ckpt', # the name of the classifier model
                         classifier_head_defaults=CLASSIFIER_HEAD_DEFAULTS, # the defaults for the classifier head, DO NOT CHANGE!
                        ):
    """
    Load the pretrained PFTSleep encoder and the sleep-stage classifier built on top of it.

    Both checkpoints are loaded with map_location='cpu'; move the returned modules to
    another device yourself if needed.

    Args:
        models_dir (str): Directory holding both checkpoint files
        encoder_model_name (str): File name of the encoder checkpoint inside models_dir
        classifier_model_name (str): File name of the classifier checkpoint inside models_dir
        classifier_head_defaults (dict): Keyword arguments for the probing head, DO NOT CHANGE!
    Returns:
        encoder (PatchTFTSimpleLightning): The encoder model
        ss_classifier (PatchTFTSleepStage): The classifier model
    """
    encoder_path = os.path.join(models_dir, encoder_model_name)
    encoder = PatchTFTSimpleLightning.load_from_checkpoint(encoder_path, map_location='cpu')

    # The classifier checkpoint is restored around the already-loaded encoder
    # and a freshly constructed probing head.
    probing_head = RNNProbingHeadExperimental(**classifier_head_defaults)
    classifier_path = os.path.join(models_dir, classifier_model_name)
    ss_classifier = PatchTFTSleepStage.load_from_checkpoint(
        classifier_path,
        preloaded_model=encoder,
        map_location='cpu',
        linear_probing_head=probing_head,
    )
    return encoder, ss_classifier

In [None]:
#| export
def download_pftsleep_models(write_dir='', # The directory to write the models to
                             token=None # Your hugging face token to use to download the models
                             ):
    """
    Download the PFTSleep encoder and classifier checkpoints from the Hugging Face Hub
    (repository "benmfox/PFTSleep").

    Args:
        write_dir (str): Local directory the checkpoint files are written to
        token (str): Optional Hugging Face access token used for the download
    """
    # Encoder first, then classifier - the two files load_pftsleep_models expects by default.
    for checkpoint_name in ("pft_sleep_encoder.ckpt", "pft_sleep_classifier.ckpt"):
        hf_hub_download(repo_id="benmfox/PFTSleep", local_dir=write_dir, filename=checkpoint_name, token=token)

In [None]:
#| export
def process_edf(edf_file_path, # The edf file path to perform inference on
                channels, # the channels to read from the edf file
                reference_channels_dict=None, # the reference channels to subtract from the channels. The keys are the channels to subtract from, and the values are the reference channels.
                frequency=FREQUENCY_DEFAULT, # the frequency to resample the channels to. Do not change this!
                sample_length=SEQUENCE_LENGTH_SECONDS_DEFAULT, # the length of the sequence to pad the channels to, expected by the model. Do not change this!
                frequency_filters_ordered=FREQUENCY_FILTERS_ORDERED_DEFAULT, # the frequency filters to apply to the channels. Do not change this!
                median_filter_kernel_size=MEDIAN_FILTER_KERNEL_SIZE_DEFAULT, # the kernel size for the median filter. Do not change this!
                overwrite_edf_duration=False, # whether to overwrite the duration of the edf file to the sample length, if the edf file duration key is corrupted. 
                verbose=True, # whether to print the verbose output.
                ):
    """
    Process the edf file to prepare it for inference.

    Reads the requested channels, optionally re-references them, resamples every channel to
    `frequency`, applies a median filter and a Butterworth filter per channel, inserts zero
    vectors for dummy channels, and truncates / zero-pads the result along time to
    `sample_length * frequency` samples.
    Do not change the default parameters (frequency, sample_length, frequency_filters_ordered, median_filter_kernel_size) of this function!

    Args:
        edf_file_path (str): The path to the edf file to perform inference on
        channels (list): The channels to read from the edf file, in model input order. Entries that are
            None or 'dummy' are returned as zero vectors.
        reference_channels_dict (dict | None): Maps a channel to the reference channel subtracted from it.
            Entries whose key is None/'dummy' or whose value is None are ignored. Names are matched
            case-insensitively. None (the default) means no re-referencing.
        frequency (int): The frequency to resample the channels to. Do not change this!
        sample_length (int): The length (seconds) of the sequence expected by the model. Do not change this!
        frequency_filters_ordered (list): One [low, high] filter per entry of `channels` (dummy slots included).
            [None, cutoff] is applied as a high-pass, [cutoff, None] as a low-pass, otherwise band-pass.
            Do not change this!
        median_filter_kernel_size (int | None): The kernel size for the median filter; None disables it. Do not change this!
        overwrite_edf_duration (bool): Whether to overwrite the duration of the edf file to the sample length, if the edf file duration key is corrupted.
        verbose (bool): Whether to print the verbose output.

    Returns:
        signals (torch.Tensor): float32 tensor of shape (len(channels), sample_length * frequency)
        sequence_padding_mask (torch.Tensor): tensor of shape (sample_length * frequency,), 1 on padded samples and 0 elsewhere

    Raises:
        ValueError: if a requested channel or reference channel is missing from the edf file
        AssertionError: if reference keys are not among `channels` or the filter count does not match `channels`
    """
    def _is_dummy(channel):
        # None and 'dummy' (any case) both mark a channel that is absent and fed to the model as zeros.
        return channel is None or str(channel).lower() == 'dummy'

    reference_channels_dict = reference_channels_dict or {}
    # Only real (non-dummy) channels that actually have a reference take part in re-referencing.
    active_references = {k: v for k, v in reference_channels_dict.items() if not _is_dummy(k) and not _is_dummy(v)}
    requested_upper = {str(c).upper() for c in channels if not _is_dummy(c)}
    assert {str(k).upper() for k in active_references}.issubset(requested_upper), 'The reference channels must be a subset of the channels'
    assert len(channels) == len(frequency_filters_ordered), 'The number of channels and the number of frequency filters must be the same'
    if sample_length != SEQUENCE_LENGTH_SECONDS_DEFAULT:
        warnings.warn(f'Sample length is not set to the default of {SEQUENCE_LENGTH_SECONDS_DEFAULT} seconds. This will likely cause issues with the model.')
    if frequency != FREQUENCY_DEFAULT:
        warnings.warn(f'Frequency is not set to the default of {FREQUENCY_DEFAULT} Hz. This will likely cause issues with the model.')

    f = edfio.read_edf(edf_file_path, lazy_load_data=True)
    num_channels = f.num_signals
    available_channels = [channel.upper() for channel in f.labels]
    # Full ordered channel list ('DUMMY' marks zero-filled slots) and the subset actually read from the file.
    channels_all = ['DUMMY' if _is_dummy(c) else str(c).upper() for c in channels]
    signal_channels = [c for c in channels_all if c != 'DUMMY']

    missing_channels = set(signal_channels) - set(available_channels)
    if missing_channels:
        raise ValueError(f'Missing channels in edf file: {missing_channels}')
    reference_map = {str(k).upper(): str(v).upper() for k, v in active_references.items()}
    missing_references = set(reference_map.values()) - set(available_channels)
    if missing_references:
        raise ValueError(f'Missing reference channels in edf file: {missing_references}')
    # Indices follow the requested order, so each signal stays paired with its own channel name.
    channels_idxs = [available_channels.index(c) for c in signal_channels]

    duration = sample_length if overwrite_edf_duration else int(f.duration)
    if duration < sample_length / 2:
        warnings.warn(f'Duration of edf file is less than half of the expected (~ 8 Hrs) sample length: {duration} < {sample_length / 2}')
    signal_headers = [{'label':f.signals[i].label, 
                        'dimension':f.signals[i].physical_dimension,
                        'sample_rate':f.signals[i].sampling_frequency if frequency is None else frequency,
                        'sample_frequency':f.signals[i].sampling_frequency if frequency is None else frequency,
                        'physical_max': f.signals[i].physical_max,
                        'physical_min': f.signals[i].physical_min,
                        'digital_max': f.signals[i].digital_max,
                        'digital_min': f.signals[i].digital_min,
                        'prefilter': f.signals[i].prefiltering,
                        'transducer': f.signals[i].transducer_type} for i in channels_idxs]
    if verbose:
        print(f'EDF file: {edf_file_path}')
        print(f'Duration: {duration}')
        print(f'Number of channels: {num_channels}')
        print(f'Available channels: {available_channels}')
        print(f'Signal headers: {signal_headers}')

    required_length = duration * int(frequency)

    def _read_resampled(idx):
        # Bring every signal to the common length so channels (and references) recorded at
        # different rates line up sample-for-sample.
        data = f.signals[idx].data
        if len(data) != required_length:
            data = resample(data, required_length)
        return data

    signals = []
    for channel_name, c_idx in zip(signal_channels, channels_idxs):
        signal = _read_resampled(c_idx)
        reference_name = reference_map.get(channel_name)
        if reference_name is not None:
            # resampling is linear, so subtracting after resampling equals resampling the difference
            signal = signal - _read_resampled(available_channels.index(reference_name))
        signals.append(signal)
    signals = np.array(signals, dtype=np.float32)
    assert len(signals) == len(signal_channels), f"Signals shape is not equal to the number of channels: {len(signals)} != {len(signal_channels)}"

    # Pair each real channel with its own filter; skipping dummy slots keeps later filters
    # from shifting onto earlier channels.
    signal_filters = [flt for c, flt in zip(channels_all, frequency_filters_ordered) if c != 'DUMMY']
    filtered_signals = []
    for s, freq_range in zip(signals, signal_filters):
        if median_filter_kernel_size is not None:
            s = median_filter(s, size=median_filter_kernel_size, mode='nearest')
        btype = 'highpass' if freq_range[0] is None else 'lowpass' if freq_range[1] is None else 'bandpass'
        cutoff = freq_range[1] if freq_range[0] is None else freq_range[0] if freq_range[1] is None else freq_range
        s = butterworth(s, freq_range=cutoff, btype=btype, fs=frequency, order=2)
        filtered_signals.append(s)

    # Insert zero vectors for dummy slots; inserting in ascending index order leaves each at its requested position.
    dummy_length = len(filtered_signals[0]) if filtered_signals else required_length
    for i, c in enumerate(channels_all):
        if c == 'DUMMY':
            filtered_signals.insert(i, np.zeros(dummy_length, dtype=np.float32))

    signals = torch.from_numpy(np.array(filtered_signals, dtype=np.float32))
    sample_length_idx = int(sample_length * frequency)
    # Truncate along time (last axis) to the length the model expects.
    signals = signals[..., :sample_length_idx]
    sequence_padding_mask = torch.zeros([signals.shape[-1]])
    n_pad = sample_length_idx - signals.shape[-1]
    if n_pad > 0:
        sequence_padding_mask = F.pad(sequence_padding_mask, (0, n_pad), 'constant', value=1) # 1 marks padded samples
        signals = F.pad(signals, (0, n_pad), 'constant', value=0) # zero pad
    return signals, sequence_padding_mask

In [None]:
#| export
def view_edf_channels(edf_file_path, # The path to the edf file to view the channels of
                      uppercase=True # Whether to return the channels in uppercase
                      ):
    """
    List the signal labels stored in an edf file.

    The file is opened with `lazy_load_data=True`; only the labels are read here.

    Args:
        edf_file_path (str): The path to the edf file to view the channels of
        uppercase (bool): If True, return upper-cased copies of the labels; otherwise return them as stored

    Returns:
        channels (list): The channels in the edf file
    """
    labels = edfio.read_edf(edf_file_path, lazy_load_data=True).labels
    return [label.upper() for label in labels] if uppercase else labels

In [None]:
#| export
class EDFDataset(Dataset):
    """Dataset yielding one processed (signals, sequence_padding_mask) pair per edf file."""

    def __init__(self, 
                edf_file_paths, # The paths to the edf files to perform inference on
                eeg_channel, # the EEG channel name in the EDF. The model was trained with C4-M1 and C3-M2 referenced EEG channels. However, 
                left_eog_channel, # the left EOG channel name in the EDF. The model was trained with M2 referenced left EOG channels.
                chin_emg_channel, # the chin EMG channel name in the EDF. The model was trained with chin refenced (chin 2 or chin 3) EMG channels.
                ecg_channel, # the ECG channel name in the EDF. The model was trained with augmented lead 2 ecg channels
                spo2_channel, # the SpO2 channel name in the EDF.
                abdomen_rr_channel, # the abdomen RR channel name in the EDF. 
                thoracic_rr_channel, # the thoracic RR channel name in the EDF.
                eeg_reference_channel=None, # the EEG reference channel name in the EDF. The model was trained with C4-M1 and C3-M2 referenced EEG channels. This will reference the channels, if they havent already been referenced. 
                left_eog_reference_channel=None, # the left EOG reference channel name in the EDF. The model was trained with M2 referenced left EOG channels. This will reference the channels, if they havent already been referenced. 
                chin_emg_reference_channel=None, # the chin EMG reference channel name in the EDF. The model was trained with chin refenced (chin 2 or chin 3) EMG channels. This will reference the channels, if they havent already been referenced. 
                ecg_reference_channel=None, # the ECG reference channel name in the EDF. The model was trained with augmented lead 2 ecg channels. This will reference the channels, if they havent already been referenced. 
                **kwargs
                ):
        """
        A dataset class for performing inference on multiple edf files.

        A channel given as None or 'dummy' is passed to the model as a zero vector.

        Args:
            edf_file_paths (list): The paths to the edf files to perform inference on
            eeg_channel (str): The name of the EEG channel in the EDF
            left_eog_channel (str): The name of the left EOG channel in the EDF
            chin_emg_channel (str): The name of the chin EMG channel in the EDF
            ecg_channel (str): The name of the ECG channel in the EDF
            spo2_channel (str): The name of the SpO2 channel in the EDF
            abdomen_rr_channel (str): The name of the abdomen RR channel in the EDF
            thoracic_rr_channel (str): The name of the thoracic RR channel in the EDF
            eeg_reference_channel (str): The name of the EEG reference channel in the EDF
            left_eog_reference_channel (str): The name of the left EOG reference channel in the EDF
            chin_emg_reference_channel (str): The name of the chin EMG reference channel in the EDF
            ecg_reference_channel (str): The name of the ECG reference channel in the EDF
            **kwargs: Additional keyword arguments for process_edf function (may include `verbose`,
                which otherwise defaults to False)
        """
        self.edf_file_paths = edf_file_paths
        # Model input order: ECG, left EOG, chin EMG, EEG, SpO2, thoracic RR, abdominal RR.
        self.channels = [ecg_channel, left_eog_channel, chin_emg_channel, eeg_channel, spo2_channel, thoracic_rr_channel, abdomen_rr_channel]
        self.channels = [c if c is not None and c != 'dummy' else 'dummy' for c in self.channels]
        channel_reference_pairs = [(ecg_channel, ecg_reference_channel),
                                   (left_eog_channel, left_eog_reference_channel),
                                   (chin_emg_channel, chin_emg_reference_channel),
                                   (eeg_channel, eeg_reference_channel)]
        # Keep only real channels that were given a reference: a None/'dummy' key never matches an
        # entry of self.channels and would fail process_edf's subset check.
        self.reference_channels_dict = {channel: reference for channel, reference in channel_reference_pairs
                                        if channel is not None and str(channel).lower() != 'dummy' and reference is not None}
        self.kwargs = kwargs

    def __len__(self):
        """Number of edf files in the dataset."""
        return len(self.edf_file_paths)

    def __getitem__(self, idx):
        """Process the idx-th edf file and return (signals, sequence_padding_mask)."""
        # verbose defaults to False but can be overridden through kwargs without a duplicate-keyword error.
        process_kwargs = {'verbose': False, **self.kwargs}
        signals, sequence_padding_mask = process_edf(self.edf_file_paths[idx],
                    channels=self.channels,
                    reference_channels_dict=self.reference_channels_dict,
                    **process_kwargs)
        return signals, sequence_padding_mask


In [None]:
#| export
def map_stage(stage):
    """
    Translate a classifier stage index into a hypjson stage code.

    Indices 0-3 are returned unchanged, 4 becomes 5 (REM in the hypjson legend),
    and anything else becomes -1 (undefined).
    """
    if stage in (0, 1, 2, 3):
        return stage
    return 5 if stage == 4 else -1
 
def create_hypjson(epochs):
    """
    Wrap a list of stage codes in the hypjson document layout.

    Args:
        epochs (list): stage codes stored under Data["10sEpochs"]

    Returns:
        dict: hypjson document with an empty-study "header", the "Data" section and the stage "Legend"
    """
    header = {"study_date": "", "study_time": "", "study_id": "", "version": "6.0.1.24"}
    legend = {
        "undefined": -1,
        "awake": 0,
        "stage1": 1,
        "stage2": 2,
        "stage3": 3,
        "stage4": 4,
        "REM": 5,
    }
    return {"header": header, "Data": {"10sEpochs": epochs}, "Legend": legend}

def write_pred_to_hypjson(predictions, hypjson_path):
    """
    Turn sleep-stage logits into hard stage labels and save them as a hypjson file.

    Args:
        predictions (torch.Tensor): classifier logits with the class dimension first; presumably
            (n_classes, n_30s_epochs) for a single recording - TODO confirm against the classifier output.
        hypjson_path (str): destination path of the JSON file
    """
    probabilities = torch.softmax(predictions, dim=0)
    stages_30s = probabilities.argmax(0).cpu().numpy().astype(int)
    # each 30 s prediction is written as three consecutive 10 s hypjson epochs
    stages_10s = stages_30s.repeat(3)
    hyp_json = create_hypjson([map_stage(stage) for stage in stages_10s.tolist()])
    with open(hypjson_path, 'w') as out_file:
        json.dump(hyp_json, out_file, indent=4)
    print(f'Saved hypjson file to {hypjson_path}')
    print(f'Saved hypjson file to {hypjson_path}')

In [None]:
#| export
def infer_on_edf(edf_file_path, # The edf file path to perform inference on
                eeg_channel, # the EEG channel name in the EDF. The model was trained with C4-M1 and C3-M2 referenced EEG channels. However, 
                left_eog_channel, # the left EOG channel name in the EDF. The model was trained with M2 referenced left EOG channels.
                chin_emg_channel, # the chin EMG channel name in the EDF. The model was trained with chin refenced (chin 2 or chin 3) EMG channels.
                ecg_channel, # the ECG channel name in the EDF. The model was trained with augmented lead 2 ecg channels
                spo2_channel, # the SpO2 channel name in the EDF.
                abdomen_rr_channel, # the abdomen RR channel name in the EDF. 
                thoracic_rr_channel, # the thoracic RR channel name in the EDF.
                eeg_reference_channel=None, # the EEG reference channel name in the EDF. The model was trained with C4-M1 and C3-M2 referenced EEG channels. This will reference the channels, if they havent already been referenced. 
                left_eog_reference_channel=None, # the left EOG reference channel name in the EDF. The model was trained with M2 referenced left EOG channels. This will reference the channels, if they havent already been referenced. 
                chin_emg_reference_channel=None, # the chin EMG reference channel name in the EDF. The model was trained with chin refenced (chin 2 or chin 3) EMG channels. This will reference the channels, if they havent already been referenced. 
                ecg_reference_channel=None, # the ECG reference channel name in the EDF. The model was trained with augmented lead 2 ecg channels. This will reference the channels, if they havent already been referenced. 
                models_dir='', # the directory of the saved models
                encoder_model_name='pft_sleep_encoder.ckpt', # the name of the encoder model
                classifier_model_name='pft_sleep_classifier.ckpt', # the name of the classifier model
                device="cpu", # the device to run the model on
                **kwargs
                ):
    """
    Performs inference on a single edf file using the pftsleep models.
    If you specify a channel as None or 'dummy', the channel will be passed through as a zero vector. This allows you to use the model even if some channels are not present in the edf file.

    Args:
        edf_file_path (str): The path to the edf file to perform inference on
        eeg_channel (str): The name of the EEG channel in the EDF
        left_eog_channel (str): The name of the left EOG channel in the EDF
        chin_emg_channel (str): The name of the chin EMG channel in the EDF
        ecg_channel (str): The name of the ECG channel in the EDF
        spo2_channel (str): The name of the SpO2 channel in the EDF
        abdomen_rr_channel (str): The name of the abdomen RR channel in the EDF
        thoracic_rr_channel (str): The name of the thoracic RR channel in the EDF
        eeg_reference_channel (str): The name of the EEG reference channel in the EDF
        left_eog_reference_channel (str): The name of the left EOG reference channel in the EDF
        chin_emg_reference_channel (str): The name of the chin EMG reference channel in the EDF
        ecg_reference_channel (str): The name of the ECG reference channel in the EDF
        models_dir (str): The directory of the saved models
        encoder_model_name (str): The name of the encoder model
        classifier_model_name (str): The name of the classifier model
        device (str): The device to run the model on
        **kwargs: Additional keyword arguments for process_edf function

    Returns:
        out (torch.Tensor): The sleep stage logit outputs of the classifier for each sleep epoch in the edf file

    Raises:
        ValueError: if a model or the edf file is missing, or if loading, processing or inference fails
            (the original exception is chained as the cause)
    """
    if not os.path.exists(os.path.join(models_dir, encoder_model_name)):
        raise ValueError(f"Encoder model not found in {models_dir}")
    if not os.path.exists(os.path.join(models_dir, classifier_model_name)):
        raise ValueError(f"Classifier model not found in {models_dir}")
    if not os.path.exists(edf_file_path):
        raise ValueError(f"EDF file not found in {edf_file_path}")
    try:
        _, ss_classifier = load_pftsleep_models(models_dir, encoder_model_name, classifier_model_name)
    except Exception as e:
        raise ValueError(f"Trouble loading models: {e}") from e

    # Model input order: ECG, left EOG, chin EMG, EEG, SpO2, thoracic RR, abdominal RR.
    channels = [ecg_channel, left_eog_channel, chin_emg_channel, eeg_channel, spo2_channel, thoracic_rr_channel, abdomen_rr_channel]
    channels = [c if c is not None and c != 'dummy' else 'dummy' for c in channels]
    channel_reference_pairs = [(ecg_channel, ecg_reference_channel),
                               (left_eog_channel, left_eog_reference_channel),
                               (chin_emg_channel, chin_emg_reference_channel),
                               (eeg_channel, eeg_reference_channel)]
    # Keep only real channels that were given a reference: a None/'dummy' key never matches an
    # entry of `channels` and would fail process_edf's subset check.
    reference_channels_dict = {channel: reference for channel, reference in channel_reference_pairs
                               if channel is not None and str(channel).lower() != 'dummy' and reference is not None}
    try:
        signals, sequence_padding_mask = process_edf(edf_file_path,
                                                    channels=channels,
                                                    reference_channels_dict=reference_channels_dict,
                                                    **kwargs
                                                    )
    except Exception as e:
        raise ValueError(f"Trouble processing edf file: {e}") from e

    try:
        # fine-tuned sleep stage classifier, recommend using a GPU
        ss_classifier = ss_classifier.to(device)
        ss_classifier.eval()

        with torch.no_grad():
            batch = signals.unsqueeze(0).to(device)
            batch_padding_mask = sequence_padding_mask.unsqueeze(0).to(device)
            out = ss_classifier(batch, sequence_padding_mask=batch_padding_mask)
            out = out.squeeze(0).cpu()
        # NOTE(review): for recordings shorter than the model's sequence length the output presumably
        # also covers the zero-padded tail; trimming it via sequence_padding_mask is not done here.
        return out
    except Exception as e:
        raise ValueError(f"Trouble inferring on edf file: {e}") from e

In [None]:
# out = infer_on_edf(edf_file_path='',
#              eeg_channel='C4-M1', 
#              left_eog_channel='E1-M2', 
#              chin_emg_channel='EMG1-EMG2', 
#              ecg_channel='ECG1-ECG2', 
#              spo2_channel='SPO2', 
#              abdomen_rr_channel='ABDO', 
#              thoracic_rr_channel='dummy',
#              device='mps',
#              models_dir='../'
#              )

In [None]:
#| export
def infer_on_edf_dataset(edf_dataloader, # the DataLoader (or EDFDataset) to perform inference on
                        device='cpu', # the device to run the model on
                        models_dir='', # the directory of the saved models
                        encoder_model_name='pft_sleep_encoder.ckpt', # the name of the encoder model
                        classifier_model_name='pft_sleep_classifier.ckpt', # the name of the classifier model
                        ):
    """
    Performs inference on every batch produced from an EDFDataset.

    Args:
        edf_dataloader (DataLoader | Dataset): Yields batches of (signals, sequence_padding_mask).
            A bare Dataset (e.g. an EDFDataset) is wrapped in a DataLoader with batch_size=1 so
            the model always receives batched inputs.
        device (str): The device to run the model on
        models_dir (str): The directory of the saved models
        encoder_model_name (str): The name of the encoder model
        classifier_model_name (str): The name of the classifier model

    Returns:
        preds (torch.Tensor): The predicted sleep stage logits of all edf files, concatenated along the batch dimension

    Raises:
        ValueError: if the dataloader yields no batches
    """
    if isinstance(edf_dataloader, Dataset):
        # Iterating a Dataset directly would feed unbatched (C, T) samples to the model.
        edf_dataloader = DataLoader(edf_dataloader, batch_size=1)
    _, ss_classifier = load_pftsleep_models(models_dir, encoder_model_name, classifier_model_name)
    ss_classifier = ss_classifier.to(device)
    ss_classifier.eval()
    preds = []
    with torch.no_grad():
        for x, sequence_padding_mask in tqdm(edf_dataloader, desc="Predicting"):
            x = x.to(device)
            sequence_padding_mask = sequence_padding_mask.to(device)
            pred = ss_classifier(x, sequence_padding_mask=sequence_padding_mask) # [bs, n_classes, pred_len_seconds]
            preds.append(pred.cpu())
    if not preds:
        raise ValueError('The dataloader yielded no batches; nothing to predict')
    return torch.cat(preds)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()