
Diffusion Transformer

Implementation of the Diffusion Transformer (DiT) model from the paper:

Scalable Diffusion Models with Transformers.

See the official PyTorch implementation for reference.

Dependencies

  • Python 3.8
  • TensorFlow 2.12

Training AutoencoderKL

Use --train_file_pattern=<file_pattern> and --test_file_pattern=<file_pattern> to specify the training and test dataset paths.

python ae_train.py --train_file_pattern='./train_dataset_path/*.png' --test_file_pattern='./test_dataset_path/*.png' 
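
For context, here is a minimal sketch of how a glob pattern like the one above is typically turned into a tf.data pipeline. This is illustrative only, not the repo's actual loader; IMAGE_SIZE, the batch size, and the preprocessing are assumptions.

import tensorflow as tf

IMAGE_SIZE = 256  # assumed; set to match ae_config.py

def load_image(path):
    # Decode a PNG, resize, and scale pixel values to [-1, 1].
    image = tf.io.decode_png(tf.io.read_file(path), channels=3)
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
    return tf.cast(image, tf.float32) / 127.5 - 1.0

def make_dataset(file_pattern, batch_size=8):
    # Expand the glob, shuffle files, and build a batched pipeline.
    return (tf.data.Dataset.list_files(file_pattern, shuffle=True)
            .map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(batch_size, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE))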

Training Diffusion Transformer

Use --file_pattern=<file_pattern> to specify the dataset path.

python ldt_train.py --file_pattern='./dataset_path/*.png'

Note: training the DiT requires a pretrained AutoencoderKL. Use ae_dir and ae_name in the ldt_config.py file to specify the AutoencoderKL path.
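
For illustration, the relevant entries in ldt_config.py might look like the following. Only the names ae_dir and ae_name come from this README; the values are placeholders.

# ldt_config.py (excerpt; placeholder values)
ae_dir = './ae'        # directory containing the pretrained AutoencoderKL
ae_name = 'ae_model'   # name of the AutoencoderKL checkpoint inside ae_dir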

Sampling

Use --model_dir=<model_dir> and --ldt_name=<ldt_name> to specify the pre-trained model. For example:

python sample.py --model_dir=ldt --ldt_name=model_1 --diffusion_steps=40
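
For intuition, below is a minimal sketch of deterministic DDIM sampling with an angle-based cosine schedule, in the spirit of the beresandras tutorial code this repo builds on (see implementation notes). The model input signature and the signal-rate clipping constants are assumptions; --diffusion_steps corresponds to steps here.

import tensorflow as tf

MAX_SIGNAL_RATE, MIN_SIGNAL_RATE = 0.95, 0.02  # assumed schedule clipping

def diffusion_schedule(t):
    # Cosine schedule via angles, so signal_rate**2 + noise_rate**2 == 1.
    start, end = tf.acos(MAX_SIGNAL_RATE), tf.acos(MIN_SIGNAL_RATE)
    angle = start + t * (end - start)
    return tf.cos(angle), tf.sin(angle)  # (signal_rate, noise_rate)

def ddim_sample(network, shape, steps=40):
    x = tf.random.normal(shape)  # start from pure noise in the AE latent space
    step_size = 1.0 / steps
    for i in range(steps):
        t = tf.ones((shape[0], 1, 1, 1)) * (1.0 - i * step_size)
        signal, noise = diffusion_schedule(t)
        pred_noise = network([x, noise**2], training=False)  # assumed model inputs
        pred_latent = (x - noise * pred_noise) / signal      # implied clean latent
        next_signal, next_noise = diffusion_schedule(t - step_size)
        x = next_signal * pred_latent + next_noise * pred_noise  # deterministic DDIM step
    return pred_latent  # decode with the AutoencoderKL to get images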

Hyperparameter settings

Adjust hyperparameters in the ae_config.py and ldt_config.py files.

Implementation notes:

  • LDT is designed to offer reasonable performance using a single GPU (RTX 3080 Ti).
  • LDT largely follows the original DiT model.
  • DiT block with adaLN-Zero (see the sketch after this list).
  • Diffusion Transformer with Linformer attention.
  • Cosine schedule.
  • DDIM sampler.
  • FID evaluation.
  • AutoencoderKL with PatchGAN discriminator and hinge loss.
  • This implementation uses code from the beresandras repo (MIT licence).
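
As referenced above, here is a minimal sketch of a DiT block with adaLN-Zero conditioning, assumed to mirror the paper's design. Standard multi-head attention is used for brevity (the repo itself uses Linformer attention), and all layer and argument names are assumptions.

import tensorflow as tf
from tensorflow import keras

class DiTBlock(keras.layers.Layer):
    def __init__(self, dim, n_heads, mlp_ratio=4, **kwargs):
        super().__init__(**kwargs)
        # Norms without learned affine params: modulation supplies shift/scale.
        self.norm1 = keras.layers.LayerNormalization(center=False, scale=False)
        self.attn = keras.layers.MultiHeadAttention(n_heads, dim // n_heads)
        self.norm2 = keras.layers.LayerNormalization(center=False, scale=False)
        self.mlp = keras.Sequential([
            keras.layers.Dense(dim * mlp_ratio, activation='gelu'),
            keras.layers.Dense(dim),
        ])
        # adaLN-Zero: regress 6 modulation vectors from the conditioning;
        # zero-init so every block starts as the identity function.
        self.ada = keras.layers.Dense(6 * dim, kernel_initializer='zeros',
                                      bias_initializer='zeros')

    def call(self, x, cond):
        # cond: (batch, dim) timestep embedding.
        params = self.ada(tf.nn.silu(cond))[:, None, :]  # (batch, 1, 6*dim)
        s1, b1, g1, s2, b2, g2 = tf.split(params, 6, axis=-1)
        h = self.norm1(x) * (1 + s1) + b1   # modulated pre-attention norm
        x = x + g1 * self.attn(h, h)        # gated residual attention
        h = self.norm2(x) * (1 + s2) + b2   # modulated pre-MLP norm
        return x + g2 * self.mlp(h)         # gated residual MLP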

Samples

Curated samples from FFHQ

Licence

MIT
