To set up the environment, create a conda environment and install the required dependencies from the requirements.txt file:
conda create --name myenv python=3.11
conda activate myenv
pip install -r requirements.txt

Please download the following datasets to run the experiments:
- Lidar: Link to Lidar dataset
- Single Cell:
  - CITE and Multi: Link to CITE and Multi datasets
  - EB: Link to EB dataset
- Animal Faces HQ (AFHQ): Link to AFHQ dataset
All hyperparameters used for the experiments in the paper are stored in the configs folder, and their definitions are in mfm/train/parsers.py. To specify the data location, use the --working_dir flag.
To choose which experiment to run, use the --config_path flag, for example:
python -m mfm.train.main --config_path ./configs/arch/ot-mfm.yaml

For the arch, sphere, single-cell, and image experiments, evaluation metrics are logged after training. For the arch, lidar, and sphere experiments, plots are also saved to the --working_dir folder at the end of training.
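If you launch runs from a script (for example, to run several configs in sequence), the training command can be assembled programmatically. The following is a minimal sketch that uses only the two documented flags, --config_path and --working_dir; the helper name build_train_command and the example paths are hypothetical, not part of the repository.

```python
import shlex
import sys


def build_train_command(config_path, working_dir, python=sys.executable):
    """Hypothetical helper: assemble the argv list for one training run.

    Only the documented --config_path and --working_dir flags are used.
    Defaulting to sys.executable runs the interpreter of the currently
    active environment (e.g. the conda env created above).
    """
    return [
        python, "-m", "mfm.train.main",
        "--config_path", str(config_path),
        "--working_dir", str(working_dir),
    ]


if __name__ == "__main__":
    # Example paths (hypothetical); replace with your own config and data location.
    cmd = build_train_command("./configs/arch/ot-mfm.yaml", "./data")
    print(shlex.join(cmd))
    # To launch the run: subprocess.run(cmd, check=True)
```

Keeping the command as an argument list rather than a single shell string avoids quoting problems when paths contain spaces, and the list can be passed directly to subprocess.run.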
Model checkpoints are saved within the checkpoints folder under --working_dir. The geopath model can be loaded using the --load_geopath_model_ckpt <checkpoint_path> flag. Training and evaluation can be resumed from a flow model checkpoint using the --resume_flow_model_ckpt <checkpoint_path> flag.
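To resume from the most recent flow model checkpoint, you can look it up under <working_dir>/checkpoints and pass it to --resume_flow_model_ckpt. This is a minimal sketch under stated assumptions: the *.ckpt pattern is a guess at how checkpoint files are named, and because the folder may also hold geopath checkpoints, you may need to narrow the pattern so that only flow model checkpoints match. The helper names latest_checkpoint and build_resume_command are hypothetical.

```python
import sys
from pathlib import Path


def latest_checkpoint(working_dir, pattern="*.ckpt"):
    """Return the most recently modified file in <working_dir>/checkpoints
    that matches `pattern`, or None if nothing matches.

    ASSUMPTION: checkpoint files end in .ckpt; adjust `pattern` to your
    files (use "**/*.ckpt" to search subfolders recursively).
    """
    ckpt_dir = Path(working_dir) / "checkpoints"
    # glob() on a missing directory simply yields nothing.
    candidates = [p for p in ckpt_dir.glob(pattern) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


def build_resume_command(config_path, working_dir, python=sys.executable, pattern="*.ckpt"):
    """Hypothetical helper: argv that resumes from the latest matching checkpoint."""
    ckpt = latest_checkpoint(working_dir, pattern)
    if ckpt is None:
        raise FileNotFoundError(
            f"no checkpoint matching {pattern!r} in {Path(working_dir) / 'checkpoints'}"
        )
    return [
        python, "-m", "mfm.train.main",
        "--config_path", str(config_path),
        "--working_dir", str(working_dir),
        "--resume_flow_model_ckpt", str(ckpt),
    ]
```

Selecting by modification time rather than by file name avoids depending on any particular naming scheme for epochs or steps.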
If you find this repository helpful for your publications, please consider citing our paper:
@article{kapusniak2024metric,
title={Metric Flow Matching for Smooth Interpolations on the Data Manifold},
author={Kapusniak, Kacper and Potaptchik, Peter and Reu, Teodora and Zhang, Leo and Tong, Alexander and Bronstein, Michael and Bose, Avishek Joey and Di Giovanni, Francesco},
journal={arXiv preprint arXiv:2405.14780},
year={2024}
}
mfm
├── dataloaders
│   ├── image_data.py
│   ├── lidar_data.py
│   └── trajectory_data.py
├── flow_matchers
│   ├── ema.py
│   ├── eval_utils.py
│   ├── flow_net_train.py
│   ├── geopath_net_train.py
│   └── models
│       └── mfm.py
├── geo_metrics
│   ├── land.py
│   ├── metric_factory.py
│   └── rbf.py
├── networks
│   ├── flow_networks
│   │   └── mlp.py
│   ├── geopath_networks
│   │   ├── mlp.py
│   │   └── unet.py
│   ├── mlp_base.py
│   ├── unet_base.py
│   └── utils.py
├── train
│   ├── main.py
│   ├── parsers.py
│   └── train_utils.py
└── utils.py
