https://github.com/CompVis/tread
dit.py: Core implementation of the Diffusion Transformer modelsrouting_module.py: Implementation of the TREAD token routing mechanismedm.py: Diffusion process and sampling methodstrain.py: Main training loop and infrastructurefid.py: Evaluation metrics calculationautoencoder.py: VAE implementation for latent space encoding/decoding
In order to train a diffusion model, we offer a minimalistic training script in train.py. In its simplest form it can be started using:
accelerate launch train.pywith configs/config.yaml having all the relevant information and settings for the actual training run. Please adjust this as needed before training.
Note: We expect precomputed latents in this version.
Under model one can decide between dit and tread which are the preconfigured versions here with the former being the standard dit and the latter being supported by TREAD. How these changes are implemented can be seen in dit.py and routing_module.py.
In our paper, we show that TREAD can also work on other architectures. In practice, one needs to be more careful with the routing process in order to adhere to the characteristics of the specific architecture as some have a spatial bias (RWKV, Mamba, etc.). For simplicity, we only provide code for the Transformer architecture as it is the most widely used while being robust and easy to work with.
For sampling, we use the EDM sampling, and the FID calculation is done via the ADM evaluation suite. We provide a fid.py to evaluate our models during training using the same reference batches as ADM.
If you use this codebase or otherwise found our work valuable, please cite our paper:
@misc{krause2025treadtokenroutingefficient,
title={TREAD: Token Routing for Efficient Architecture-agnostic Diffusion Training},
author={Felix Krause and Timy Phan and Vincent Tao Hu and Björn Ommer},
year={2025},
eprint={2501.04765},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2501.04765},
}Thanks to the open source codebases such as DiT, MaskDiT, ADM, and EDM. Our codebase is built on them.