
DataCTE/Decentralized-Diffusion-Models


Decentralized Diffusion Model (DDM) Training

Author: Datacte (GitHub: Datavoid)
Date: March 15, 2025
License: MIT License

Overview

This implementation provides a streamlined approach to training Decentralized Diffusion Models (DDM). The system clusters the training data into partitions, trains one expert model per partition, and uses a learned router for dynamic expert selection during inference.
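Assigning training images to expert partitions can be sketched as follows. This is a minimal, single-stage k-means over assumed per-image feature vectors. It is not the repository's two-stage clustering in `run_clustering.py`, and `kmeans_partition` is a hypothetical helper. For a purely random partition, the labels could instead be drawn uniformly from `range(num_experts)`.

```python
import numpy as np


def kmeans_partition(features: np.ndarray, num_experts: int, iters: int = 20, seed: int = 0):
    """Assign each feature vector to one of `num_experts` partitions with plain k-means.

    features: (N, D) float array, e.g. one embedding per training image (assumed input).
    Returns (labels, centroids); labels is (N,) with values in [0, num_experts).
    """
    rng = np.random.default_rng(seed)
    # Initialize centroids from distinct, randomly chosen samples (astype copies the rows).
    centroids = features[rng.choice(len(features), num_experts, replace=False)].astype(float)
    labels = np.zeros(len(features), dtype=int)
    for _ in range(iters):
        # Squared Euclidean distance from every sample to every centroid: (N, K).
        dists = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for k in range(num_experts):
            members = features[labels == k]
            if len(members) > 0:  # leave a centroid in place if its cluster is empty
                centroids[k] = members.mean(axis=0)
    return labels, centroids
```

Expert `e` would then train only on the images whose label equals `e`. The dense (N, K, D) distance array is fine for a toy example but not at the ~1M-image scale mentioned under Limitations; there, distances would need to be computed in chunks or with a dedicated clustering library.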

Key Features:

  • Automatic Clustering - Integrated two-stage clustering assigns training data to expert partitions
  • Dynamic Expert Selection - The router learns which experts to combine for each input
  • Efficient Inference - Top-k expert selection evaluates only k experts per step, reducing compute
  • Precomputed Latents - GPU-optimized latent caching avoids re-encoding images during training
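The routing idea behind Dynamic Expert Selection and Efficient Inference can be sketched as below. It assumes the router emits one logit per expert and that each expert is a callable `expert(x, t)` returning a tensor shaped like `x`. `topk_predict` is a hypothetical helper rather than the repository's API, and renormalizing the selected scores with a softmax is an assumed design choice.

```python
import torch


def topk_predict(router_logits, experts, x, t, top_k):
    """Mix predictions from the top-k experts per sample, running each expert only where selected.

    router_logits: (B, E) relevance scores from the router
    experts:       list of E callables, expert(x, t) -> tensor shaped like x
    x, t:          noisy inputs (B, ...) and timesteps (B,)
    """
    top_vals, top_idx = router_logits.topk(top_k, dim=-1)    # (B, k)
    weights = torch.softmax(top_vals, dim=-1)                # renormalize over selected experts
    out = torch.zeros_like(x)
    for e in top_idx.unique().tolist():
        chosen = top_idx == e                                # (B, k) bool mask
        rows = chosen.any(dim=-1).nonzero(as_tuple=True)[0]  # samples that selected expert e
        w = (weights * chosen).sum(dim=-1)[rows]             # expert e's weight for those samples
        pred = experts[e](x[rows], t[rows])                  # run expert e on its rows only
        out[rows] += w.view(-1, *([1] * (x.dim() - 1))) * pred
    return out
```

Because each expert only sees the rows that selected it, the number of expert forward passes scales with `top_k` rather than with the total number of experts.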

Prerequisites

  • Python 3.8+
  • PyTorch 2.0+
  • Basic dependencies:
    pip install torch torchvision numpy tqdm

Quick Start

  1. Prepare Dataset

    # 1. Run feature extraction and clustering
    python run_clustering.py \
      --dataset_path /path/to/images \
      --feature_path ./cache/features
      
    # 2. Precompute VAE latents
    python precompute_latents.py \
      --dataset_path /path/to/images \
      --output_dir ./cache/latents
      
    # 3. Generate cluster assignments
    python generate_clusters.py --config config.py

  2. Start Training

    torchrun --nproc-per-node=4 train.py \
      --num_experts 8 \
      --batch_size 32 \
      --top_k 2

  3. Generate Samples

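    from trainers.sampling import ddm_sample

    # `router` and `experts` are the trained router and list of expert models,
    # loaded from the checkpoints written during training.
    images = ddm_sample(
        router,
        experts,
        shape=(4, 3, 256, 256),  # (batch, channels, height, width)
        steps=50,                # number of sampling steps
        top_k=2                  # experts evaluated per step
    )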
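What a sampler like `ddm_sample` does at each step can be sketched with a plain Euler integrator. This is a hypothetical `euler_sample`, not the repository's implementation. It assumes a flow-matching convention in which t=1 is pure noise and t=0 is data, that each expert predicts a velocity `expert(x, t)`, and that `router(x, t)` returns one logit per expert.

```python
import torch


@torch.no_grad()
def euler_sample(router, experts, shape, steps=50, top_k=2, generator=None):
    """Hypothetical sampler: integrate dx/dt = v(x, t) from t=1 (noise) to t=0 (data)."""
    x = torch.randn(shape, generator=generator)
    ts = torch.linspace(1.0, 0.0, steps + 1)
    extra = [1] * (len(shape) - 1)                    # broadcast helper for per-sample weights
    for i in range(steps):
        t = ts[i].expand(shape[0])                    # same time for every sample
        dt = ts[i] - ts[i + 1]                        # positive step size
        top_vals, top_idx = router(x, t).topk(top_k, dim=-1)       # (B, k)
        weights = torch.softmax(top_vals, dim=-1)                  # renormalized over selected experts
        # For brevity every expert is evaluated here; a compute-efficient sampler
        # would run only the selected experts.
        preds = torch.stack([expert(x, t) for expert in experts], dim=1)  # (B, E, ...)
        idx = top_idx.view(*top_idx.shape, *extra).expand(-1, -1, *shape[1:])
        v = (weights.view(*weights.shape, *extra) * preds.gather(1, idx)).sum(dim=1)
        x = x - dt * v                                # Euler step toward t = 0
    return x
```

Euler steps are the simplest choice; a higher-order ODE solver would typically need fewer `steps` for the same quality.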

Key Components

| Component   | Description                                                        |
|-------------|--------------------------------------------------------------------|
| Experts     | Specialized diffusion models trained on clustered data partitions |
| Router      | Lightweight network predicting expert relevance scores            |
| Coordinator | Manages distributed training and model checkpoints                |
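A lightweight router of the kind described in the table could look like the sketch below. `TinyRouter` and its layer layout are assumptions for illustration, not the repository's architecture; `router_dim` is borrowed from the training configuration, and `in_channels` is whatever the router sees (for example 4 for VAE latents or 3 for pixels).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyRouter(nn.Module):
    """Hypothetical lightweight router: (noisy sample, timestep) -> one logit per expert."""

    def __init__(self, in_channels: int, num_experts: int, router_dim: int = 256):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, router_dim // 4, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(router_dim // 4, router_dim, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),  # global pooling: output size independent of input resolution
            nn.Flatten(),             # (B, router_dim)
        )
        self.time_embed = nn.Linear(1, router_dim)  # simple scalar-timestep conditioning
        self.head = nn.Linear(router_dim, num_experts)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        h = self.features(x) + self.time_embed(t.view(-1, 1).float())
        return self.head(F.silu(h))  # (B, num_experts) relevance logits
```

Keeping the router this small means its cost stays negligible next to even a single expert forward pass.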

Training Configuration

class TrainingConfig:
    num_experts = 8           # Number of parallel experts
    expert_dim = 1024         # Model hidden dimension  
    router_dim = 256          # Router network size
    batch_size = 32           # Per-expert batch size
    top_k = 1                 # Active experts during inference
    steps = 1000000           # Total training steps
    lr = 1e-4                 # Learning rate
    warmup = 5000             # LR warmup steps
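The `lr` and `warmup` fields suggest a warmup schedule. The sketch below assumes a linear ramp to `lr` over `warmup` steps followed by a constant rate; the schedule after warmup is not stated above, so treat that part as an assumption. `lr_at` is a hypothetical helper.

```python
def lr_at(step: int, lr: float = 1e-4, warmup: int = 5000) -> float:
    """Learning rate at `step`: linear warmup to `lr`, then held constant (assumed schedule)."""
    if step < warmup:
        return lr * (step + 1) / warmup  # ramps from lr / warmup up to lr
    return lr
```

In PyTorch the same shape can be expressed as a multiplier for `torch.optim.lr_scheduler.LambdaLR`, for example `lr_lambda=lambda s: lr_at(s) / 1e-4`.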

Implementation Highlights

  • Decentralized Flow Matching - Implements equations 4-7 of McAllister et al. (2025)
  • Expert Isolation - Each expert trains only on its own data partition
  • Cluster-Aware Routing - The router uses the precomputed cluster assignments
  • Efficient Sampling - Evaluating only the top-k experts reduces expert compute by 4-8x versus running all of them (with 8 experts, top-2 gives 4x and top-1 gives 8x)
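Expert Isolation and Cluster-Aware Routing can be sketched as two independent training objectives. Equations 4-7 of the paper are not reproduced here; this assumes the common linear flow-matching path x_t = (1 - t) x0 + t * noise with velocity target noise - x0, and a router trained with cross-entropy to recover each sample's precomputed cluster from x_t. `expert_loss` and `router_loss` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def _noisy(x0):
    """Sample t ~ U(0, 1) and the point x_t = (1 - t) * x0 + t * noise on the linear path."""
    noise = torch.randn_like(x0)
    t = torch.rand(x0.shape[0], device=x0.device)
    t_b = t.view(-1, *([1] * (x0.dim() - 1)))  # broadcast t over non-batch dims
    return (1 - t_b) * x0 + t_b * noise, t, noise


def expert_loss(expert, x0):
    """Flow-matching loss for one expert; x0 comes only from that expert's partition."""
    xt, t, noise = _noisy(x0)
    return F.mse_loss(expert(xt, t), noise - x0)  # target velocity d(x_t)/dt


def router_loss(router, x0, cluster_ids):
    """Cross-entropy: predict the precomputed cluster of x0 from the noisy x_t."""
    xt, t, _ = _noisy(x0)
    return F.cross_entropy(router(xt, t), cluster_ids)
```

Since `expert_loss` never touches data from other partitions, each expert can be trained on a separate machine without communicating with the others; only the router needs samples from every cluster.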

Benchmarks (256x256 Images)

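| Experts | Top-k | FID ↓ | Sampling Time ⏱️ | VRAM Usage 💾 |
|---------|-------|-------|------------------|---------------|
| 4       | 1     | 12.7  | 1.2 s/img        | 18 GB         |
| 8       | 2     | 9.8   | 1.8 s/img        | 22 GB         |
| 16      | 3     | 7.4   | 2.4 s/img        | 28 GB         |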

Limitations

  • The number of experts is a tradeoff: in the benchmarks above, configurations with more experts (and higher top-k) reach lower FID but need more sampling time and VRAM
  • Router training requires sufficiently diverse data
  • Larger models benefit from distributed training
  • Requires precomputing clusters and latents (~30 min for 1M images)
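The latent precomputation step (`precompute_latents.py` in the Quick Start) amounts to encoding every image once and serving the cached tensors during training. The sketch below assumes `encode` is any callable mapping an image batch to latents, such as a VAE encoder; `cache_latents`, `LatentDataset`, and the shard naming are hypothetical, and the temporary-directory default exists only for illustration.

```python
import os
import tempfile

import torch
from torch.utils.data import Dataset


@torch.no_grad()
def cache_latents(encode, images, out_dir=None, batch_size=64):
    """Run `encode` once over `images` and save the results as numbered shards.

    Returns the directory holding the shards (a fresh temp dir if none is given).
    """
    out_dir = out_dir or tempfile.mkdtemp(prefix="latents_")
    os.makedirs(out_dir, exist_ok=True)
    for start in range(0, len(images), batch_size):
        latents = encode(images[start:start + batch_size]).cpu()
        torch.save(latents, os.path.join(out_dir, f"shard_{start // batch_size:06d}.pt"))
    return out_dir


class LatentDataset(Dataset):
    """Serves cached latents so training never calls the encoder."""

    def __init__(self, out_dir):
        shards = sorted(f for f in os.listdir(out_dir) if f.endswith(".pt"))
        self.latents = torch.cat([torch.load(os.path.join(out_dir, f)) for f in shards])

    def __len__(self):
        return len(self.latents)

    def __getitem__(self, i):
        return self.latents[i]
```

Loading every shard into memory keeps the sketch short; for millions of images, loading shards lazily would keep memory bounded.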

Acknowledgments

Builds on foundational work from:

  • Peebles & Xie (2023) - Diffusion Transformers
  • McAllister et al. (2025) - Decentralized Diffusion
  • Ho et al. (2020) - DDPM framework

License

This project is licensed under the MIT License. See the LICENSE file for details.
