Reference repository for the paper *Unbalanced Diffusion Schrödinger Bridge*. Cite as:

```bibtex
@misc{pariset2023unbalanced,
  title={Unbalanced Diffusion Schr\"odinger Bridge},
  author={Matteo Pariset and Ya-Ping Hsieh and Charlotte Bunne and Andreas Krause and Valentin De Bortoli},
  year={2023},
  eprint={2306.09099},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```
Create and activate a dedicated conda environment:

```shell
conda env create -n udsb -f udsb.yml
conda activate udsb
```
Install JAX:

```shell
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
```
If you do not have hardware acceleration, you can find more information here: https://github.com/google/jax#installation.
Install the remaining dependencies:

```shell
pip install --upgrade dm-haiku==0.0.9
conda install -c conda-forge optax
conda install -c conda-forge ott-jax
```
Below, we detail the organization of this repository.
Raw data and preprocessing pipelines are contained in the `data/` folder:

- `data/`
  - `create_datasets.ipynb`: generate toy datasets
  - `prepare_cells.ipynb`: process the cell drug response dataset
  - `prepare_flights.ipynb`: generate the country embedding based on flights
  - `prepare_covid_variants.ipynb`: process the COVID-19 variant spread dataset
  - `2d/`: folder containing 2D representations of empirical distributions and killing zones
  - `4i/`: folder with raw and processed data belonging to the cell experiment
  - `flights/`: folder with flight data and the country embedding
  - `covid/`: folder with raw and processed data belonging to the COVID experiment
Datasets DOI: https://doi.org/10.3929/ethz-b-000631091.
The two algorithms presented in our paper are contained in:

- `udsb_td/`: code and experiments involving our UDSB-TD algorithm
- `udsb_f/`: code and experiments involving our UDSB-F algorithm
Both algorithms can be used with a similar interface. To initialize an experiment, use:

```python
Experiment.create(config, ...)
```

and to reload a trained model with name `tag`, call:

```python
Experiment.load(dataset_name, tag)
```
Path sampling and plots can be obtained using the `Viewer(key, experiment)` object, while training is performed via the snippet:

```python
trainer = Trainer(key, experiment)
trainer.train(...)
```
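As a rough illustration of this interface, here is a minimal pure-Python stand-in for the calling pattern above. The real `Experiment` and `Trainer` live in `udsb_td/` and `udsb_f/` and take more arguments; every body below is a placeholder, not the actual UDSB logic:

```python
# Minimal stand-in classes mimicking the interface described above.
# The real implementations live in udsb_td/ and udsb_f/; all method
# bodies here are placeholders.
class Experiment:
    def __init__(self, config):
        self.config = config

    @classmethod
    def create(cls, config):
        # Initialize a fresh experiment from a config.
        return cls(config)

    @classmethod
    def load(cls, dataset_name, tag):
        # The real method restores a trained model saved under `tag`.
        return cls({"dataset": dataset_name, "tag": tag})


class Trainer:
    def __init__(self, key, experiment):
        self.key, self.experiment = key, experiment

    def train(self, num_steps=1):
        # Placeholder for the IPF training loop.
        return {"steps_run": num_steps}


key = 0  # stand-in for a jax.random PRNG key
experiment = Experiment.create({"dataset": "2d"})
trainer = Trainer(key, experiment)
history = trainer.train(num_steps=3)
```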
For additional information, please refer to the documentation of UDSB-F. We highlight here some invariants respected throughout our codebase:
- Direction-dependent entities:
  - Many entities exist in pairs: one instance per SDE direction. When this happens, the pair is grouped in a dictionary indexed by the direction (`FORWARD` or `BACKWARD`).
  - In the context of training, the assigned direction is the one for which the network is updated: e.g., forward corresponds to the IPF pass in which the forward score is learned.
  - The names of direction-indexed dictionaries are singular. Examples: `model`, `ipf_loss`, `optimizer`, ...
  - `broadcast()` provides a shortcut to execute a computation that is indexed by the direction.
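A minimal sketch of this convention follows; the string values of `FORWARD`/`BACKWARD` and the signature of `broadcast()` are assumptions for illustration, not the repository's actual helper:

```python
# Sketch of the direction-indexed dictionary convention.
# The FORWARD/BACKWARD values and the broadcast() signature are assumptions.
FORWARD, BACKWARD = "forward", "backward"

def broadcast(fn, *dicts):
    # Run `fn` once per direction, passing that direction's entry of each dict.
    return {d: fn(d, *(dct[d] for dct in dicts)) for d in (FORWARD, BACKWARD)}

# Two direction-indexed dictionaries with singular names, as in the codebase.
ipf_loss = {FORWARD: 0.25, BACKWARD: 0.5}
weight = {FORWARD: 2.0, BACKWARD: 4.0}

# One computation, executed once per direction.
weighted_loss = broadcast(lambda d, l, w: l * w, ipf_loss, weight)
# weighted_loss == {"forward": 0.5, "backward": 2.0}
```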
- Random operations:
  - Functions that use randomness take a `key` argument.
  - If a function returns a `key`, no splitting needs to be done by the caller: the function internally splits the key and returns a fresh state (e.g., for `key, ... = train(key, ...)`). Otherwise, the caller is in charge of managing the state of the PRNG (as with `SDE.sample_f_trajectory()`).
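The two conventions can be sketched as follows. A toy `split()` stands in for `jax.random.split`, and the function bodies are placeholders, not the repository's code:

```python
# Toy stand-in for jax.random.split: derives deterministic child keys.
def split(key, num=2):
    return tuple((key, i) for i in range(num))

def train(key):
    # Convention 1: the function returns a key, so it splits internally;
    # the caller just threads the returned key onward.
    key, subkey = split(key)
    loss = 0.0  # placeholder computation that would consume subkey
    return key, loss

def sample_trajectory(key):
    # Convention 2: no key is returned; the caller must split beforehand
    # (as with SDE.sample_f_trajectory() in the repository).
    return [key]  # placeholder trajectory

key = (0,)
key, loss = train(key)     # no manual split needed by the caller
key, subkey = split(key)   # here the caller manages the PRNG state itself
trajectory = sample_trajectory(subkey)
```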
- We offer the following jit-compiled functions:
  - `training_step`: used by `train()` to execute one IPF iteration.
  - `fast_sample_trajectory_evo`: a faster alternative to the standard (slower) `sample_trajectory()` call (note that all functions in `Viewer` use the latter).
The `reproducibility/` folder contains some pre-computed baseline predictions.