# Data sets and models

In [None]:
#| default_exp models

In [None]:
#| hide 
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| hide
from nbdev.showdoc import *


In [None]:
#| export
from enum import Enum, auto
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

import numpy as np
import polars as pl

from pisces.data_sets import DataSetObject, ModelInput1D, ModelInputSpectrogram, ModelOutputType, DataProcessor

## Classifier models

In [None]:
#| export
import abc
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import make_pipeline
import numpy as np


class SleepWakeClassifier(abc.ABC):
    """Abstract base class for the sleep/wake classifiers in this module.

    Subclasses must implement `get_needed_X_y`. The other methods are no-op
    hooks in this base class (they return `None`), and concrete classifiers
    override them:

    - `train` fits the model.
    - `predict` returns class predictions.
    - `predict_probabilities` returns per-class probabilities.
    """

    @abc.abstractmethod
    def get_needed_X_y(self, data_set: DataSetObject, id: str) -> Tuple[np.ndarray, np.ndarray] | None:
        """Return the model-ready `(X, y)` arrays for recording `id`, or `None` if they cannot be built."""

    def train(self, examples_X: List[pl.DataFrame] | None = None, examples_y: List[pl.DataFrame] | None = None,
              pairs_Xy: List[Tuple[pl.DataFrame, pl.DataFrame]] | None = None,
              epochs: int = 10, batch_size: int = 32):
        """Fit the model on data given in one of two forms.

        Data is passed either as parallel `examples_X` / `examples_y` lists, or
        as `(X, y)` tuples in `pairs_Xy`. A `None` argument means that no data
        of that kind was supplied.

        This is a no-op in the base class.
        """

    def predict(self, sample_X: np.ndarray | pl.DataFrame) -> np.ndarray:
        """Return class predictions for `sample_X`. This is a no-op in the base class."""

    def predict_probabilities(self, sample_X: np.ndarray | pl.DataFrame) -> np.ndarray:
        """Return per-class probabilities for `sample_X`. This is a no-op in the base class."""


### SGD Logistic Regression

In [None]:
#| export
class SGDLogisticRegression(SleepWakeClassifier):
    """Uses Sk-Learn's `SGDClassifier` to train a logistic regression model.

    The SGD aspect allows for online learning, or for custom training regimes
    through the underlying classifier's `partial_fit` method.

    The model is trained with a balanced class weight and uses L1
    regularization. The input data is scaled with a `StandardScaler` before
    being passed to the model. The scaler and the classifier are chained in a
    single pipeline, so `train` fits the scaler, and `predict` /
    `predict_probabilities` reuse it.
    """
    def __init__(self, data_processor: DataProcessor, lr: float = 0.15, ):
        """
        Args:
            data_processor: Source of the 1D features. It must be configured
                with a `ModelInput1D` input and `ModelOutputType.SleepWake`
                output.
            lr: Initial learning rate (`eta0`) for the adaptive SGD schedule.

        Raises:
            ValueError: If `data_processor` is not configured for 1D
                sleep/wake data.
        """
        self.model = SGDClassifier(loss='log_loss',
                                   learning_rate='adaptive',
                                   penalty='l1',
                                   eta0=lr,
                                   class_weight='balanced',
                                   warm_start=True)
        self.scaler = StandardScaler()
        # No `memory` caching is configured, so the pipeline fits these same
        # estimator objects in place. That is what lets `predict` use
        # `self.scaler` and `self.model` directly after `train`.
        self.pipeline = make_pipeline(self.scaler, self.model)
        if not isinstance(data_processor.model_input, ModelInput1D):
            raise ValueError("Model input must be set to 1D on the data processor")
        if not data_processor.model_output == ModelOutputType.SleepWake:
            raise ValueError("Model output must be set to SleepWake on the data processor")
        self.data_processor = data_processor

    def get_needed_X_y(self, id: str) -> Tuple[np.ndarray, np.ndarray] | None:
        """Return the 1D `(X, y)` arrays for recording `id`, or `None`.

        Unlike the base-class signature, this takes only `id`. It uses the
        data processor given at construction.
        """
        return self.data_processor.get_1D_X_y(id)

    def train(self,
              examples_X: List[np.ndarray] | None = None,
              examples_y: List[np.ndarray] | None = None,
              pairs_Xy: List[Tuple[np.ndarray, np.ndarray]] | None = None,
              ):
        """
        Fit the scaler + classifier pipeline.

        Assumes data is already preprocessed using `get_needed_X_y` and is
        ready to be passed to the model. Supply the training data in exactly
        one of two forms:

        - parallel `examples_X` / `examples_y` lists, or
        - `(X, y)` tuples in `pairs_Xy` (the form `get_needed_X_y` returns).

        All recordings are concatenated along axis 0. Samples with a negative
        label are dropped before fitting. Features are passed unscaled; the
        pipeline scales them itself.

        Raises:
            ValueError: If only one of `examples_X` / `examples_y` is given,
                if their lengths differ, if both examples and pairs are given,
                or if no training data is given.
        """
        examples_X = list(examples_X) if examples_X is not None else []
        examples_y = list(examples_y) if examples_y is not None else []
        pairs_Xy = list(pairs_Xy) if pairs_Xy is not None else []

        if bool(examples_X) != bool(examples_y):
            raise ValueError("If providing examples, must provide both X and y")
        if len(examples_X) != len(examples_y):
            raise ValueError(
                f"examples_X and examples_y must have the same length "
                f"(got {len(examples_X)} and {len(examples_y)})")
        if pairs_Xy:
            if examples_X:
                raise ValueError("Provide either examples_X/examples_y or pairs_Xy, not both")
            examples_X = [X for X, _ in pairs_Xy]
            examples_y = [y for _, y in pairs_Xy]
        if not examples_X:
            raise ValueError("No training data provided")

        Xs = np.concatenate(examples_X, axis=0)
        ys = np.concatenate(examples_y, axis=0)

        # Keep only samples with a non-negative label. Negative labels are
        # presumably a "missing / masked label" marker from the data
        # processor; confirm against DataProcessor.
        selector = ys >= 0
        Xs = Xs[selector]
        ys = ys[selector]

        # Fit on the raw features. The pipeline fits the scaler first and then
        # trains the classifier on the scaled data. Pre-scaling here would
        # call an unfitted scaler and would double-scale.
        self.pipeline.fit(Xs, ys)

    def _input_preprocessing(self, X: np.ndarray | pl.DataFrame) -> np.ndarray:
        """Scale `X` with the scaler fitted during `train`. Polars frames are converted to NumPy first."""
        if isinstance(X, pl.DataFrame):
            X = X.to_numpy()
        return self.scaler.transform(X)

    def predict(self, sample_X: np.ndarray | pl.DataFrame) -> np.ndarray:
        """
        Return the predicted class label for each row of `sample_X`.

        Assumes data is already preprocessed using `get_needed_X_y` and that
        `train` has been called.
        """
        return self.model.predict(self._input_preprocessing(sample_X))

    def predict_probabilities(self, sample_X: np.ndarray | pl.DataFrame) -> np.ndarray:
        """
        Return per-class probabilities for each row of `sample_X`.

        Assumes data is already preprocessed using `get_needed_X_y` and that
        `train` has been called.
        """
        return self.model.predict_proba(self._input_preprocessing(sample_X))

### Mads Olsen et al. classifier

We have downloaded the saved model weights from a [research repository from Mads Olsen's group](https://github.com/MADSOLSEN/SleepStagePrediction), and converted those into a saved Keras model to remove the need to re-define all of the layers. This conversion process is shown in `../analyses/convert_mads_olsen_model_to_keras.ipynb`.

Thus, we have a TensorFlow model that we can run inference on, and we could train it if we wanted to.

For simplicity, we are only going to run inference. One twist is that the classifier expects two high-resolution spectrograms as inputs:
1. 3-axis Accelerometer data
2. PPG (photoplethysmogram) data

Based on visual inspection of examples from the paper, we construct a substitute input by flipping the accelerometer spectrogram along the frequency axis. The paper's images suggest that high-frequency accelerometer data resembles low-frequency PPG data. Surprisingly, this seems to work well.

In [None]:
#| export
from functools import partial
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
import warnings

from pisces.mads_olsen_support import *
from pisces.utils import split_analysis


class MOResUNetPretrained(SleepWakeClassifier):
    """Pretrained ResUNet sleep model from Mads Olsen's group, used for inference only.

    `load_saved_keras()` runs once, when this class body executes. The
    resulting model is shared by all instances as the `tf_model` class
    attribute.
    """
    tf_model = load_saved_keras()
    config = MO_PREPROCESSING_CONFIG

    def __init__(
        self,
        sampling_hz: int = FS,
    ) -> None:
        """
        Initialize the MOResUNetPretrained classifier.

        Args:
            sampling_hz (int, optional): The sampling frequency in Hz. Defaults to FS.
        """
        super().__init__()
        self.sampling_hz = sampling_hz

    def prepare_set_for_training(self, 
                                 data_processor: DataProcessor, 
                                 ids: List[str],
                                 max_workers: int | None = None 
                                 ) -> List[Tuple[np.ndarray, np.ndarray] | None]:
        """
        Prepare the data set for training by running `get_needed_X_y` for every ID in parallel.

        Args:
            data_processor (DataProcessor): The data processor used to build each ID's inputs.
            ids (List[str]): The IDs to prepare.
            max_workers (int, optional): The number of workers to use for
                parallel processing. Defaults to None, which uses all available
                cores. 0 also uses all cores.

                A negative number leaves that many cores unused. For example,
                on a machine with 4 cores:
                - `max_workers=-1` uses 3 = 4 - 1 cores;
                - `max_workers=-3` uses 1 = 4 - 3 cores.

                Values above the core count are clamped, with a warning.

        Returns:
            List[Tuple[np.ndarray, np.ndarray] | None]: One entry per ID, in the
            same order as `ids`. Each entry is the result of `get_needed_X_y`
            for that ID, which may be `None`.

        Raises:
            ValueError: If a negative `max_workers` would leave fewer than one worker.
        """
        num_cores = multiprocessing.cpu_count()
        workers_to_use = self._resolve_worker_count(max_workers, num_cores)

        print(f"Using {workers_to_use} of {num_cores} cores ({int(100 * workers_to_use / num_cores)}%) for parallel preprocessing.")
        print("This can cause memory or heat issues if max_workers is too high; if you run into problems, call prepare_set_for_training() again with max_workers = -1, going more negative if needed. (See the docstring for more info.)")
        with ProcessPoolExecutor(max_workers=workers_to_use) as executor:
            # `Executor.map` passes one positional argument per iterable. The
            # processor and the ids therefore go in as two parallel iterables,
            # not as a single list of (processor, id) tuples.
            results = list(
                executor.map(
                    self.get_needed_X_y,
                    [data_processor] * len(ids),
                    ids,
                ))

        return results

    @staticmethod
    def _resolve_worker_count(max_workers: int | None, num_cores: int) -> int:
        """Turn the user-facing `max_workers` setting into a concrete worker count.

        The rules are described in `prepare_set_for_training`.
        """
        workers_to_use = max_workers if max_workers is not None else num_cores
        if workers_to_use > num_cores:
            warnings.warn(f"Attempting to use {max_workers} but only have {num_cores}. Running with {num_cores} workers.")
            workers_to_use = num_cores
        if workers_to_use <= 0:
            # Non-positive values mean "leave this many cores unused".
            workers_to_use = num_cores + max_workers
        if workers_to_use < 1:
            # Checked separately (not elif) so the adjusted value is validated too.
            raise ValueError(f"With `max_workers` == {max_workers}, we end up with max_workers + num_cores ({max_workers} + {num_cores}) which is less than 1. This is an error.")
        return workers_to_use

    def get_needed_X_y(self, data_processor: DataProcessor, id: str) -> Tuple[np.ndarray, np.ndarray] | None:
        """Return the spectrogram inputs for `id`, as produced by `data_processor.get_spectrogram`."""
        return data_processor.get_spectrogram(id)

    def train(self, 
              examples_X: List[pl.DataFrame] | None = None, 
              examples_y: List[pl.DataFrame] | None = None, 
              pairs_Xy: List[Tuple[pl.DataFrame, pl.DataFrame]] | None = None, 
              epochs: int = 10, batch_size: int = 32):
        """Training is not implemented yet for this model. You can run inference, though, using `predict_probabilities` and `predict`."""
        pass

    def predict(self, sample_X: np.ndarray | pl.DataFrame) -> np.ndarray:
        """Return the most probable class index for each prediction.

        The argmax is taken over the last (class) axis. This handles both 2D
        `(samples, classes)` probabilities and the 3D `(batch, time, classes)`
        output that `evaluate_data_set` indexes.
        """
        return np.argmax(self.predict_probabilities(sample_X), axis=-1)

    def predict_probabilities(self, sample_X: np.ndarray | pl.DataFrame) -> np.ndarray:
        """Run the TensorFlow model on `sample_X` and return its raw output."""
        if isinstance(sample_X, pl.DataFrame):
            sample_X = sample_X.to_numpy()
        return self._evaluate_tf_model(sample_X)

    def _evaluate_tf_model(self, inputs: np.ndarray) -> np.ndarray:
        """Cast `inputs` to float32 and run `tf_model.predict` on them."""
        inputs = inputs.astype(np.float32)
        return self.tf_model.predict(inputs)

    def evaluate_data_set(self, 
                          data_processor: DataProcessor, 
                          exclude: List[str] | None = None, 
                          max_workers: int | None = None) -> Tuple[Dict[str, dict], list]:
        """Run inference on every recording in `data_processor.data_set` and score it.

        Args:
            data_processor: Supplies the data set and builds the model inputs.
            exclude: IDs to skip. Defaults to none.
            max_workers: Parallelism setting passed through to `prepare_set_for_training`.

        Returns:
            A tuple `(evaluations, mo_preprocessed_data)`:
            - `evaluations` maps each evaluated ID to the result of `split_analysis`.
            - `mo_preprocessed_data` is the list of `((X, y), id)` pairs that
              were evaluated. IDs whose preprocessing returned `None` are
              omitted.
        """
        excluded = set(exclude) if exclude is not None else set()
        data_set = data_processor.data_set
        filtered_ids = [data_id for data_id in data_set.ids if data_id not in excluded]
        prepared = self.prepare_set_for_training(data_processor, filtered_ids, max_workers=max_workers)
        mo_preprocessed_data = [
            (d, data_id)
            for d, data_id in zip(prepared, filtered_ids)
            if d is not None
        ]

        evaluations: Dict[str, dict] = {}
        for i, ((X, y), data_id) in enumerate(mo_preprocessed_data):
            y_hat_proba = self.predict_probabilities(X)
            # Sleep probability is taken as 1 - P(class 0). This presumes
            # class 0 is wake; confirm against the model's label order.
            y_hat_sleep_proba = (1 - y_hat_proba[:, :, 0]).reshape(-1,)
            analysis = split_analysis(y, y_hat_sleep_proba)
            evaluations[data_id] = analysis
            print(f"Processing {i+1} of {len(mo_preprocessed_data)} ({data_id})... AUROC: {analysis['auc']}")
        return evaluations, mo_preprocessed_data




### Training tools

In [None]:
#| export
from typing import Type
from tqdm import tqdm
from sklearn.model_selection import LeaveOneOut


class SplitMaker:
    """Base class for strategies that split a list of IDs into train/test folds.

    `split` returns an iterable of `(train_indices, test_indices)` pairs. The
    indices are positions in the `ids` list that was passed in.
    """
    def split(self, ids: List[str]) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
        """Yield `(train_indices, test_indices)` pairs for `ids`. Subclasses must override this."""
        raise NotImplementedError
    
class LeaveOneOutSplitter(SplitMaker):
    """Leave-one-out folds: each ID is the test set exactly once, and all other IDs form the training set."""
    def split(self, ids: List[str]) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
        """Return scikit-learn's lazy `LeaveOneOut` split of `ids` as `(train_indices, test_indices)` arrays."""
        return LeaveOneOut().split(ids)

def run_split(train_indices, 
              preprocessed_data_set: List[Tuple[Tuple[np.ndarray, np.ndarray] | None, str]], 
              swc: SleepWakeClassifier) -> SleepWakeClassifier:
    """Train `swc` on the recordings selected by `train_indices` and return it.

    Args:
        train_indices: Positions in `preprocessed_data_set` to train on.
        preprocessed_data_set: `(Xy, id)` entries, as built by `run_splits`.
            `Xy` is either the `(X, y)` result of `get_needed_X_y` or `None`.
            Entries whose `Xy` is `None` are skipped.
        swc: The classifier to train. Its `train` method is called on this
            same instance.

    Returns:
        The same `swc` instance, after training.
    """
    training_pairs = [
        xy
        for xy in (preprocessed_data_set[i][0] for i in train_indices)
        if xy is not None
    ]
    swc.train(pairs_Xy=training_pairs)

    return swc

def run_splits(split_maker: SplitMaker, w: DataSetObject, swc_class: Type[SleepWakeClassifier], exclude: List[str] | None = None) -> Tuple[
        List[SleepWakeClassifier], 
        List[Tuple[Tuple[np.ndarray, np.ndarray] | None, str]],
        List[List[List[int]]]]:
    """Train one fresh `swc_class` model per train/test split of the data set `w`.

    Args:
        split_maker: Produces the `(train_indices, test_indices)` folds.
        w: The data set. Its `ids` are split, and it is passed to `get_needed_X_y`.
        swc_class: Classifier class. It is instantiated with no arguments, once
            for feature extraction and once per trained split.
        exclude: IDs to leave out entirely. Defaults to none.

    Returns:
        A tuple `(split_models, preprocessed_data, splits)`:
        - `split_models`: one trained model per split. A split is skipped when
          the data for its first test ID is `None`.
        - `preprocessed_data`: `(Xy, id)` for every non-excluded ID. `Xy` is
          `get_needed_X_y(w, id)` and may be `None`.
        - `splits`: `[train_index, test_index]` for each trained model. The
          indices are positions in `preprocessed_data`.
    """
    excluded = set(exclude) if exclude is not None else set()
    # Split over the same filtered list that `preprocessed_data` is built from.
    # This keeps the split indices aligned with `preprocessed_data` when some
    # IDs are excluded.
    ids = [i for i in w.ids if i not in excluded]

    feature_extractor = swc_class()
    preprocessed_data = [(feature_extractor.get_needed_X_y(w, i), i) for i in ids]

    split_models: List[SleepWakeClassifier] = []
    splits = []
    for train_index, test_index in tqdm(split_maker.split(ids)):
        # Only the first test index is checked for usable data.
        if preprocessed_data[test_index[0]][0] is None:
            continue
        model = run_split(train_indices=train_index,
                          preprocessed_data_set=preprocessed_data,
                          swc=swc_class())
        split_models.append(model)
        splits.append([train_index, test_index])

    return split_models, preprocessed_data, splits



In [None]:
#| hide
# Write this notebook's `#| export` cells out to the library module via nbdev.
from nbdev import nbdev_export
nbdev_export()