
BiomedSciAI/fuse-med-ml


Effective Code Reuse across ML projects!

A Python framework that accelerates ML-based discovery in the medical field by encouraging code reuse. Batteries included :)

FuseMedML is part of the PyTorch Ecosystem.


Motivation - "Oh, the pain!"

Analyzing many ML research projects, we discovered that:

  • Bringing up new projects takes far too long, even when very similar projects were already done in the past by the same lab!
  • Porting individual components across projects was painful, resulting in "reinventing the wheel" time after time.

How the magic happens

1. A simple yet super effective design concept

Data is kept in a nested (hierarchical) dictionary

This is a key aspect of FuseMedML ("fuse" for short). It is a key driver of flexibility and makes it easy to deal with multi-modality information.

from fuse.utils import NDict

sample_ndict = NDict()
# `...` stands for your actual data (e.g. numpy arrays or torch tensors)
sample_ndict["input.mri"] = ...
sample_ndict["input.ct_view_a"] = ...
sample_ndict["input.ct_view_b"] = ...
sample_ndict["groundtruth.disease_level_label"] = ...

This data can represent a single sample, a minibatch, an entire epoch, or anything else that is desired. The nested key ("a.b.c.d.etc") is called a "path key", since it can be seen as a path inside the nested dictionary.
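To make the "path key" idea concrete, here is a minimal sketch over plain Python dicts. The helpers `set_path` and `get_path` are hypothetical illustrations of the concept, not FuseMedML's `NDict` implementation:

```python
from typing import Any, Dict


def set_path(nested: Dict[str, Any], path_key: str, value: Any) -> None:
    """Store value at a dotted path key, creating intermediate dicts as needed."""
    *parents, leaf = path_key.split(".")
    node = nested
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


def get_path(nested: Dict[str, Any], path_key: str) -> Any:
    """Read the value (or the whole sub-dict) stored at a dotted path key."""
    node: Any = nested
    for part in path_key.split("."):
        node = node[part]
    return node


sample: Dict[str, Any] = {}
set_path(sample, "input.ct_view_a", "ct_a")
set_path(sample, "input.ct_view_b", "ct_b")
set_path(sample, "groundtruth.disease_level_label", 2)

print(get_path(sample, "input"))  # {'ct_view_a': 'ct_a', 'ct_view_b': 'ct_b'}
print(get_path(sample, "groundtruth.disease_level_label"))  # 2
```

Reading a prefix such as "input" returns the whole branch, which is what makes grouping several modalities under one key convenient.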

Components are written so that their input and output keys (the path keys they read from and write to the nested dict) are defined by the user. See a short introduction video (3 minutes) on how FuseMedML components work:

fuse-intro.mp4

Examples - using FuseMedML-style components

A multi-head model, written as a FuseMedML-style component, allows easy reuse across projects:

ModelMultiHead(
    conv_inputs=(("data.input.img", 1),),  # input to the backbone model
    backbone=BackboneResnet3D(in_channels=1),  # PyTorch nn.Module
    heads=[  # list of heads - supports multi-task / multi-head approaches
        Head3D(
            head_name="classification",
            mode="classification",
            conv_inputs=[("model.backbone_features", 512)],  # input to the classification head
        ),
    ],
)

Our default loss implementation creates an easy wrapper around a callable function, while remaining FuseMedML-style:

LossDefault(
    pred='model.logits.classification',          # input - model prediction scores
    target='data.label',                         # input - ground truth labels
    callable=torch.nn.functional.cross_entropy   # callable - function that will get the prediction scores and labels extracted from batch_dict and compute the loss
)
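A minimal sketch of the wrapping idea, in pure Python so it runs without PyTorch. `SimpleLossWrapper`, `get_path`, and `mean_squared_error` are hypothetical names used only for illustration; the real `LossDefault` works on PyTorch tensors and its internals may differ:

```python
from typing import Any, Callable, Dict, List


def get_path(nested: Dict[str, Any], path_key: str) -> Any:
    """Follow a dotted path key such as 'model.output' through nested dicts."""
    node: Any = nested
    for part in path_key.split("."):
        node = node[part]
    return node


class SimpleLossWrapper:
    """Hypothetical stand-in illustrating the LossDefault idea; not FuseMedML's implementation."""

    def __init__(self, pred: str, target: str, fn: Callable[[Any, Any], float]) -> None:
        self.pred = pred  # path key of the prediction scores
        self.target = target  # path key of the ground truth
        self.fn = fn  # any function taking (prediction, target)

    def __call__(self, batch_dict: Dict[str, Any]) -> float:
        # extract both inputs from the nested batch dict, then delegate to the wrapped function
        return self.fn(get_path(batch_dict, self.pred), get_path(batch_dict, self.target))


def mean_squared_error(pred: List[float], target: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


batch = {"model": {"output": [0.5, 1.0]}, "data": {"label": [0.0, 1.0]}}
loss = SimpleLossWrapper(pred="model.output", target="data.label", fn=mean_squared_error)
print(loss(batch))  # 0.125
```

Swapping the wrapped function or the path keys requires no change to the wrapper itself, which is the point of the design.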

An example metric that can be used:

MetricAUCROC(
    pred='model.output', # input - model prediction scores
    target='data.label'  # input - ground truth labels
)
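As a sketch of what such a metric computes, the block below reads predictions and labels by path key and evaluates AUC-ROC with the pairwise (Mann-Whitney) formulation: the fraction of (positive, negative) pairs in which the positive sample scores higher, with ties counted as half. `SimpleMetricAUCROC` and `auc_roc` are hypothetical names, and this quadratic-time version is for illustration only; `MetricAUCROC` itself may be implemented differently:

```python
from typing import Any, Dict, List


def get_path(nested: Dict[str, Any], path_key: str) -> Any:
    """Follow a dotted path key through nested dicts."""
    node: Any = nested
    for part in path_key.split("."):
        node = node[part]
    return node


def auc_roc(scores: List[float], labels: List[int]) -> float:
    """Fraction of (positive, negative) pairs ranked correctly; ties count as half."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC-ROC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


class SimpleMetricAUCROC:
    """Hypothetical metric: the path keys are configuration, the math is fixed."""

    def __init__(self, pred: str, target: str) -> None:
        self.pred = pred
        self.target = target

    def __call__(self, results: Dict[str, Any]) -> float:
        return auc_roc(get_path(results, self.pred), get_path(results, self.target))


results = {"model": {"output": [0.9, 0.8, 0.3, 0.1]}, "data": {"label": [1, 0, 1, 0]}}
print(SimpleMetricAUCROC(pred="model.output", target="data.label")(results))  # 0.75
```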

Note that several components return their results directly instead of writing them into the nested dictionary. This is perfectly fine: to allow maximum flexibility, we do not require any use of output path keys.

Creating a custom FuseMedML component

Creating custom FuseMedML components is easy - in the following example we add a new data processing operator:

A data pipeline operator

from typing import List, Optional

import numpy as np

from fuse.data.ops.op_base import OpBase
from fuse.utils import NDict


class OpPad(OpBase):
    def __call__(
        self,
        sample_dict: NDict,
        key_in: str,
        padding: List[int],
        fill: int = 0,
        mode: str = "constant",
        key_out: Optional[str] = None,
    ) -> NDict:
        # extract the element stored at the given key (for example 'input.xray_img')
        img = sample_dict[key_in]
        assert isinstance(img, np.ndarray), f"Expected np.ndarray but got {type(img)}"

        # np.pad only accepts constant_values when mode == 'constant'
        pad_kwargs = {"constant_values": fill} if mode == "constant" else {}
        processed_img = np.pad(img, pad_width=padding, mode=mode, **pad_kwargs)

        # store the result in the requested output key (or in key_in if no key_out is provided)
        if key_out is None:
            key_out = key_in
        sample_dict[key_out] = processed_img

        # return the modified nested dict
        return sample_dict

Since the key location isn't hardcoded, this module can be easily reused across different research projects with very different data sample structures. More code reuse - Hooray!
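To illustrate the reuse argument without installing FuseMedML, here is a pure-Python sketch of the same pattern applied to two projects with different sample layouts. `op_pad_1d`, `get_path`, and `set_path` are hypothetical helpers, and the ECG/EEG keys are made-up examples:

```python
from typing import Any, Dict, List, Optional


def get_path(nested: Dict[str, Any], path_key: str) -> Any:
    node: Any = nested
    for part in path_key.split("."):
        node = node[part]
    return node


def set_path(nested: Dict[str, Any], path_key: str, value: Any) -> None:
    *parents, leaf = path_key.split(".")
    node = nested
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


def op_pad_1d(sample: Dict[str, Any], key_in: str, pad: int, fill: Any = 0,
              key_out: Optional[str] = None) -> Dict[str, Any]:
    """Pad a 1-D list on both sides; keys are arguments, so no sample layout is assumed."""
    values: List[Any] = get_path(sample, key_in)
    padded = [fill] * pad + list(values) + [fill] * pad
    set_path(sample, key_out if key_out is not None else key_in, padded)
    return sample


# Project A keeps its signal under 'input.ecg' and pads in place
sample_a = op_pad_1d({"input": {"ecg": [1, 2, 3]}}, key_in="input.ecg", pad=1)

# Project B uses a deeper, different layout and keeps the unpadded original
sample_b = op_pad_1d({"data": {"signals": {"eeg": [7, 8]}}}, key_in="data.signals.eeg",
                     pad=2, key_out="data.signals.eeg_padded")
print(sample_a)  # {'input': {'ecg': [0, 1, 2, 3, 0]}}
print(sample_b)  # {'data': {'signals': {'eeg': [7, 8], 'eeg_padded': [0, 0, 7, 8, 0, 0]}}}
```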

FuseMedML-style components in general are any classes or functions that define which key paths will be written and which will be read. Arguments can be freely named, and you don't even have to write anything to the nested dict. Some FuseMedML components return a value directly - for example, loss functions.

2. "Batteries included" key components, built using the same design concept

fuse.data - A declarative super flexible data processing pipeline

  • Easily deals with complex multi-modality scenarios
  • Advanced caching, including periodic audits to automatically detect stale caches
  • Default ready-to-use Dataset and Sampler classes
  • See detailed introduction here
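The declarative style can be sketched as a list of (operator, keyword-arguments) pairs applied in order. This is a hypothetical, pure-Python illustration of the idea (`run_pipeline`, `op_scale`, and `op_clip` are made-up names, and dotted keys are stored as flat dict keys for brevity); it does not reproduce fuse.data's actual pipeline or caching classes:

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

Sample = Dict[str, Any]
Step = Tuple[Callable[..., Sample], Dict[str, Any]]


def op_scale(sample: Sample, key_in: str, factor: float, key_out: Optional[str] = None) -> Sample:
    """Multiply every value under key_in; write to key_out (defaults to key_in)."""
    sample[key_out or key_in] = [v * factor for v in sample[key_in]]
    return sample


def op_clip(sample: Sample, key_in: str, low: float, high: float,
            key_out: Optional[str] = None) -> Sample:
    """Clamp every value under key_in to [low, high]."""
    sample[key_out or key_in] = [min(max(v, low), high) for v in sample[key_in]]
    return sample


def run_pipeline(sample: Sample, steps: List[Step]) -> Sample:
    """Apply each (op, kwargs) step in order; the kwargs carry the keys and parameters."""
    for op, kwargs in steps:
        sample = op(sample, **kwargs)
    return sample


# The pipeline is plain data: which op runs, and which keys it reads and writes
steps: List[Step] = [
    (op_scale, dict(key_in="input.img", factor=2)),
    (op_clip, dict(key_in="input.img", low=0, high=5, key_out="input.img_clipped")),
]
sample = run_pipeline({"input.img": [1, 2, 3, -1]}, steps)
print(sample)  # {'input.img': [2, 4, 6, -2], 'input.img_clipped': [2, 4, 5, 0]}
```

Because the steps are data rather than code, the same operators can be rearranged or re-keyed per project without editing them.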

fuse.eval - a standalone library for evaluating ML models (not necessarily trained with FuseMedML)

The package includes a collection of off-the-shelf metrics and utilities such as statistical significance tests, calibration, thresholding, model comparison and more. See detailed introduction here

fuse.dl - reusable dl (deep learning) model architecture components, loss functions, etc.

Supported DL libraries

Some components depend on PyTorch. For example, fuse.data is oriented towards PyTorch's Dataset, DataLoader, Sampler, etc., and fuse.dl makes heavy use of PyTorch models. Other components, such as fuse.eval, do not depend on any specific DL library.

Broadly speaking, the supported DL libraries are:

Before you ask - pytorch-lightning and FuseMedML play along very nicely and, in practice, have orthogonal and additive benefits :) See Simple FuseMedML + PytorchLightning Example for simple supervised learning cases, and this example for completely custom usage of pytorch-lightning and FuseMedML, which is useful for advanced scenarios such as Reinforcement Learning and generative models.

Domain Extensions

fuse-med-ml, the core library, is completely domain agnostic! Domain extensions are optionally installable packages that deal with specific (sub) domains. For example:

  • fuseimg which was battle-tested in many medical imaging related projects (different organs, imaging modalities, tasks, etc.)
  • fusedrug (to be released soon) which focuses on molecular biology and chemistry - prediction, generation and more

Domain extensions contain concrete implementations of components and component parts within the relevant domain.

The recommended directory structure mimics the fuse-med-ml core structure:

your_package
    data   # everything related to datasets, samplers, data processing pipeline Ops, etc.
    dl     # everything related to deep learning architectures, optimizers, loss functions, etc.
    eval   # evaluation metrics
    utils  # any utilities
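If you want to start from that layout, a small script can create it. This is a sketch under one assumption not stated above: each directory is a regular Python package with an empty `__init__.py`. `scaffold_package` and `SUBPACKAGES` are hypothetical names:

```python
import tempfile
from pathlib import Path
from typing import List

# Sub-packages from the recommended layout above
SUBPACKAGES = ["data", "dl", "eval", "utils"]


def scaffold_package(root: Path, name: str = "your_package") -> List[Path]:
    """Create root/name and its sub-packages, each with an empty __init__.py."""
    package_dir = root / name
    dirs = [package_dir] + [package_dir / sub for sub in SUBPACKAGES]
    for directory in dirs:
        directory.mkdir(parents=True, exist_ok=True)
        # an empty __init__.py makes each directory an importable Python package
        (directory / "__init__.py").touch()
    return dirs


if __name__ == "__main__":
    # dry run in a throwaway directory; pass Path(".") to scaffold your real project
    with tempfile.TemporaryDirectory() as tmp:
        for directory in scaffold_package(Path(tmp)):
            print(directory.relative_to(tmp))
```

The function is idempotent (`exist_ok=True`, `touch()`), so re-running it on an existing project does not overwrite any files.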

You are highly encouraged to create additional domain extensions and/or contribute to the existing ones! There's no need to wait for any approval; you can create domain extensions in your own repos right away.

Note: in general, we find it helpful to follow the same directory structure even in small, specific research projects that use FuseMedML, for consistency and to make it easy for newcomers to find their way around your project :)

Installation

FuseMedML is tested on Python >= 3.7 and PyTorch >= 1.5.

We recommend using a Conda environment.

Create a conda environment using the following commands (you can replace FUSEMEDML with your preferred environment name):

conda create -n FUSEMEDML python=3.9
conda activate FUSEMEDML

Next, install PyTorch and its corresponding CUDA toolkit. See here for the exact command that suits your local environment. For example:

conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia

Then follow Option 1 or Option 2 below inside the activated conda environment.

Option 1: Install from source (recommended)

The best way to install FuseMedML is to clone the repository and, from its root directory, install it in editable mode using pip (quoting the extras keeps shells such as zsh from interpreting the brackets):

$ pip install -e ".[all]"

This installs all currently publicly available domain extensions (fuseimg as of now; fusedrug will be added soon).

To also install the included collection of examples, use:

$ pip install -e ".[all,examples]"

Option 2: Install from PyPI

$ pip install "fuse-med-ml[all]"

or with examples:

$ pip install "fuse-med-ml[all,examples]"
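After installation, a quick sanity check run inside the activated environment can compare the interpreter and the installed PyTorch against the tested minimums stated above, and confirm that the `fuse` package imports. `version_tuple` and `check_environment` are hypothetical helpers written for this sketch, not part of FuseMedML:

```python
import re
import sys
from typing import Tuple


def version_tuple(version: str, parts: int = 2) -> Tuple[int, ...]:
    """Leading numeric components of a version string, e.g. '2.1.0+cu118' -> (2, 1)."""
    numbers = []
    for piece in version.split("+")[0].split(".")[:parts]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    return tuple(numbers)


def check_environment() -> bool:
    """Print and return whether Python, PyTorch and FuseMedML meet the stated minimums."""
    python_ok = sys.version_info >= (3, 7)
    print(f"Python {sys.version.split()[0]}: {'OK' if python_ok else 'needs >= 3.7'}")
    try:
        import torch
    except ImportError:
        print("PyTorch: not installed")
        return False
    torch_ok = version_tuple(torch.__version__) >= (1, 5)
    print(f"PyTorch {torch.__version__}: {'OK' if torch_ok else 'needs >= 1.5'}")
    try:
        import fuse  # noqa: F401 - same top-level package as `from fuse.utils import NDict`
    except ImportError:
        print("FuseMedML: not importable")
        return False
    print("FuseMedML: importable")
    return python_ok and torch_ok
```

Call `check_environment()` from a Python session in the activated conda environment; a `False` result points at the first missing or outdated piece.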

Examples

Walkthrough template

  • Walkthrough Template - includes several TODO notes, marking the minimal scope of code required to get your pipeline up and running. The template also includes useful explanations and tips.

Community support - join the discussion!

Citation

If you use FuseMedML in a scientific context, please consider citing our JOSS paper:

@article{Golts2023,
        doi = {10.21105/joss.04943},
        url = {https://doi.org/10.21105/joss.04943},
        year = {2023},
        publisher = {The Open Journal},
        volume = {8},
        number = {81},
        pages = {4943},
        author = {Alex Golts and Moshe Raboh and Yoel Shoshan and Sagi Polaczek and Simona Rabinovici-Cohen and Efrat Hexter},
        title = {FuseMedML: a framework for accelerated discovery in machine learning based biomedicine},
        journal = {Journal of Open Source Software}
}