
benmfox/PhysioJEPA


PhysioJEPA

Overview

PhysioJEPA is a Python package for modeling physiological signals with joint embedding predictive architectures (JEPAs). The repository is built with nbdev, so the package is developed in Jupyter notebooks.
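The core JEPA idea is to predict the *embeddings* of masked parts of a signal from the embeddings of the visible parts, rather than reconstructing raw samples. The NumPy sketch below illustrates that objective in the style of I-JEPA; it is not PhysioJEPA's implementation. The toy linear encoder, the mean-pooling predictor, the patch and embedding sizes, and the EMA momentum of 0.996 are all assumptions chosen for illustration.

```python
import numpy as np

def encode(weights: np.ndarray, patches: np.ndarray) -> np.ndarray:
    """Toy stand-in for a transformer encoder: one linear layer plus tanh per patch."""
    return np.tanh(patches @ weights)

def jepa_loss(context_w, target_w, predictor_w, pos_emb, patches, target_idx):
    """Latent-space prediction loss in the style of I-JEPA (illustrative, not the repo's code)."""
    context_idx = np.setdiff1d(np.arange(len(patches)), target_idx)
    # The context encoder only sees the visible (unmasked) patches.
    ctx = encode(context_w, patches[context_idx])
    # The target encoder embeds every patch; the masked ones become regression targets.
    # In a real framework this branch is excluded from backpropagation.
    targets = encode(target_w, patches)[target_idx]
    # Toy predictor: pooled context plus a positional embedding for each masked slot.
    pred = (ctx.mean(axis=0) + pos_emb[target_idx]) @ predictor_w
    # The loss compares embeddings, not raw signal samples.
    return float(np.mean((pred - targets) ** 2))

def ema_update(target_w, context_w, momentum=0.996):
    """Target-encoder weights follow the context encoder as an exponential moving average."""
    return momentum * target_w + (1.0 - momentum) * context_w

# Hypothetical sizes: 60 patches of 16 samples each, 32-dimensional embeddings.
rng = np.random.default_rng(0)
n_patches, patch_len, dim = 60, 16, 32
patches = rng.normal(size=(n_patches, patch_len))
context_w = rng.normal(scale=0.1, size=(patch_len, dim))
target_w = context_w.copy()
predictor_w = np.eye(dim)
pos_emb = rng.normal(scale=0.02, size=(n_patches, dim))
target_idx = rng.choice(n_patches, size=15, replace=False)

loss = jepa_loss(context_w, target_w, predictor_w, pos_emb, patches, target_idx)
target_w = ema_update(target_w, context_w)
```

Using an EMA copy of the encoder for the targets (instead of the encoder itself) is the usual way JEPA-style methods avoid the trivial solution where every input maps to the same embedding.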

PhysioJEPA was published at the ML4H Conference! See the paper here.

Install

pip install physiojepa

Repository Structure

This is an nbdev repository, which means the package is developed in Jupyter notebooks located in the nbs/ directory. Any modifications or additions to the PhysioJEPA package should be made by editing these notebooks.

To build the package, run nbdev_prepare in the terminal. This exports the notebooks to Python modules in the PhysioJEPA/ directory (and runs the notebook tests), so the package can be imported and used in other Python projects.

To add new functionality, create a new notebook or extend an existing one in the nbs/ directory, following the nbdev documentation. Then run nbdev_prepare to regenerate the PhysioJEPA package with the new functionality.
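As a minimal sketch of that workflow, the listing below flattens three consecutive notebook cells into one Python file. The module name signal_utils and the zscore helper are hypothetical, not part of PhysioJEPA. In nbdev, the #| default_exp directive names the module a notebook exports to, #| export marks the cells copied into that module, and unmarked cells (like the final assertion) stay in the notebook, where nbdev_test runs them as tests.

```python
# --- Notebook cell 1: names the target module (hypothetical name) ---
#| default_exp signal_utils

# --- Notebook cell 2: copied into the package's signal_utils.py on export ---
#| export
import numpy as np

def zscore(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardize a 1-D signal to zero mean and unit variance."""
    return (x - x.mean()) / (x.std() + eps)

# --- Notebook cell 3: not exported; executed as a test by nbdev_test ---
x = np.array([1.0, 2.0, 3.0])
assert abs(zscore(x).mean()) < 1e-9
```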

Directory Structure:

  • nbs/: Contains the source notebooks that generate the Python package
  • jobs/: Contains processing and training scripts
    • convert_to_zarr.py: Converts the MIMIC-III dataset to Zarr files
    • label_processing/
      • create_hypotension_outcome_df.py: Creates minute-level hypotension and shock labels from the ABP waveforms in the Zarr files
      • create_hypotension_shock_labels.ipynb: Creates training labels (5 continuous minutes of shock or hypotension) from the minute-level labels
    • jepa/:
      • train_patch_jepa.py: Trains the initial JEPA foundation transformer model (3-channel, 30-min segments) using the train_patch_jepa.yaml config
      • train_hypotension.py: Trains the Attentive classifier to predict hypotension using the trained JEPA encoder with the train_hypotension.yaml config
      • train_shock_index.py: Trains the Attentive classifier to predict shock index using the trained JEPA encoder with the train_shock_index.yaml config
    • patchtst/:
      • train_patchtst.py: Trains the initial PatchTST foundation transformer model (3-channel, 30-min segments) using the train_patchtst.yaml config
      • train_hypotension.py: Trains the Attentive classifier to predict hypotension using the trained PatchTST encoder with the train_hypotension.yaml config
      • train_shock_index.py: Trains the Attentive classifier to predict shock index using the trained PatchTST encoder with the train_shock_index.yaml config
    • ecgjepa/ (based on https://arxiv.org/abs/2410.08559):
      • train_ecgjepa.py: Trains the initial ECG-JEPA foundation transformer model (3-channel, 30-min segments) using the train_ecgjepa.yaml config
      • train_hypotension_ecgjepa.py: Trains the Attentive classifier (without batch/channel melting) to predict hypotension using the trained ECG-JEPA encoder with the train_hypotension_ecgjepa.yaml config
      • train_shock_index_ecgjepa.py: Trains the Attentive classifier (without batch/channel melting) to predict shock index using the trained ECG-JEPA encoder with the train_shock_index_ecgjepa.yaml config
    • baselines/:
      • fcn_baseline_hypotension.py: Trains a supervised FCN model to predict hypotension using the fcn_baseline_hypotension.yaml config
      • fcn_baseline_si.py: Trains a supervised FCN model to predict shock index using the fcn_baseline_si.yaml config
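The two label-processing steps could look roughly like the sketch below. It is not the code in create_hypotension_outcome_df.py or create_hypotension_shock_labels.ipynb: the thresholds (MAP below 65 mmHg, shock index of 1.0 or more), the per-minute averaging of the ABP waveform, and all function names are assumptions chosen for illustration. Shock index is heart rate divided by systolic blood pressure; the 5-minute rule keeps only minutes that belong to a run of at least five consecutive positive minutes.

```python
from typing import List, Optional, Sequence

# Assumed thresholds (common clinical conventions, not taken from the repository).
MAP_THRESHOLD_MMHG = 65.0
SHOCK_INDEX_THRESHOLD = 1.0
MIN_RUN_MINUTES = 5

def minute_map(abp: Sequence[float], fs: int) -> List[float]:
    """Approximate mean arterial pressure for each full minute as the mean of the ABP waveform."""
    samples_per_min = 60 * fs
    n_minutes = len(abp) // samples_per_min  # a trailing partial minute is dropped
    return [sum(abp[m * samples_per_min:(m + 1) * samples_per_min]) / samples_per_min
            for m in range(n_minutes)]

def shock_index(heart_rate: float, systolic_bp: float) -> float:
    """Shock index: heart rate (beats/min) divided by systolic blood pressure (mmHg)."""
    return heart_rate / systolic_bp

def continuous_event_labels(minute_flags: Sequence[bool],
                            min_run: int = MIN_RUN_MINUTES) -> List[bool]:
    """Keep only minutes that belong to a run of at least `min_run` consecutive positive minutes."""
    labels = [False] * len(minute_flags)
    run_start: Optional[int] = None
    # Appending a False sentinel closes a run that reaches the end of the record.
    for i, flag in enumerate(list(minute_flags) + [False]):
        if flag and run_start is None:
            run_start = i
        elif not flag and run_start is not None:
            if i - run_start >= min_run:
                labels[run_start:i] = [True] * (i - run_start)
            run_start = None
    return labels

# Usage with made-up per-minute values.
maps = [70, 64, 63, 60, 62, 61, 66, 50, 55]
hypotension = continuous_event_labels([m < MAP_THRESHOLD_MMHG for m in maps])
hr_sbp = [(80, 120), (110, 100), (115, 100), (120, 105), (118, 110), (125, 112)]
shock = continuous_event_labels(
    [shock_index(hr, sbp) >= SHOCK_INDEX_THRESHOLD for hr, sbp in hr_sbp])
```

In this example the five hypotensive minutes at indices 1 to 5 are labeled, while the two-minute dip at the end is not.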
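The foundation models above consume 3-channel, 30-minute segments. PatchTST-style models split each channel into non-overlapping patches, and "batch/channel melting" most likely refers to folding the channel axis into the batch axis so each channel is encoded as an independent sequence. That reading is an assumption, as are the 125 Hz sampling rate (the native rate of the MIMIC-III waveform database), the 2-second patch length, and the helper names in this NumPy sketch.

```python
import numpy as np

FS_HZ = 125          # assumed sampling rate
SEGMENT_MIN = 30     # 30-minute segments, as described for the foundation models
PATCH_SEC = 2        # hypothetical patch length

def patchify(x: np.ndarray, patch_len: int) -> np.ndarray:
    """(batch, channels, time) -> (batch, channels, n_patches, patch_len), non-overlapping patches."""
    b, c, t = x.shape
    n_patches = t // patch_len  # an incomplete trailing patch is dropped
    return x[:, :, :n_patches * patch_len].reshape(b, c, n_patches, patch_len)

def melt_channels(tokens: np.ndarray) -> np.ndarray:
    """Fold channels into the batch: (batch, channels, n, d) -> (batch * channels, n, d)."""
    b, c, n, d = tokens.shape
    return tokens.reshape(b * c, n, d)

def unmelt_channels(tokens: np.ndarray, channels: int) -> np.ndarray:
    """Undo melt_channels: (batch * channels, n, d) -> (batch, channels, n, d)."""
    bc, n, d = tokens.shape
    return tokens.reshape(bc // channels, channels, n, d)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, SEGMENT_MIN * 60 * FS_HZ)).astype(np.float32)
patches = patchify(x, PATCH_SEC * FS_HZ)      # (2, 3, 900, 250)
melted = melt_channels(patches)               # (6, 900, 250): each channel is its own sequence
restored = unmelt_channels(melted, channels=3)
```

Melting lets one shared encoder process every channel, at the cost of losing direct cross-channel attention inside the encoder; the ECG-JEPA classifiers are described as skipping this step.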
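The Attentive classifier used by the downstream scripts is not documented in this README. As a rough illustration of the general idea behind V-JEPA-style attentive probing, the sketch below has a learned query cross-attend over frozen encoder tokens and a linear head map the pooled vector to an output. It is a single-head, untrained NumPy example with hypothetical names and random weights; the sigmoid output suits a binary target such as hypotension, while a regression target would use the raw output instead.

```python
import numpy as np

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

class AttentivePoolingHead:
    """Single-head attentive probe (hypothetical): a learned query attends over encoder tokens."""

    def __init__(self, dim: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.query = rng.normal(scale=0.02, size=(dim,))       # learned query vector
        self.w_k = rng.normal(scale=dim ** -0.5, size=(dim, dim))
        self.w_v = rng.normal(scale=dim ** -0.5, size=(dim, dim))
        self.w_out = rng.normal(scale=dim ** -0.5, size=(dim,))
        self.b_out = 0.0

    def __call__(self, tokens: np.ndarray):
        """tokens: (batch, n_tokens, dim) -> (probabilities (batch,), attention (batch, n_tokens))."""
        keys = tokens @ self.w_k                                  # (B, N, D)
        values = tokens @ self.w_v                                # (B, N, D)
        scores = keys @ self.query / np.sqrt(tokens.shape[-1])    # (B, N)
        attn = softmax(scores, axis=-1)                           # weights over tokens
        pooled = np.einsum("bn,bnd->bd", attn, values)            # (B, D)
        logits = pooled @ self.w_out + self.b_out                 # (B,)
        return 1.0 / (1.0 + np.exp(-logits)), attn

# Usage on random stand-in "encoder outputs": 4 segments, 20 tokens, 16-dim embeddings.
tokens = np.random.default_rng(1).normal(size=(4, 20, 16))
head = AttentivePoolingHead(dim=16)
probs, attn = head(tokens)
```

Compared with mean pooling, the attention weights let the head emphasize the parts of a segment most relevant to the prediction while the encoder stays frozen.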

About

PhysioJEPA: Joint Embedding Representations of Physiological Signals for Real Time Risk Estimation in the Intensive Care Unit
