jpeaceau/HVRTDiffuser

HVRTDiffuser — Proof of Concept

License: AGPL v3

A single-pass generative image model built on FastHVRT, the Hierarchical Variance-Retaining Transformer. Instead of iterative denoising, it samples synthetic latent vectors non-parametrically from a distribution learned by a convolutional autoencoder, then decodes them in one forward pass.


Motivation

Diffusion models produce high-quality images but require hundreds of denoising steps at inference time. VAE reparameterisation is fast but destroys the latent-space structure that would allow faithful resampling — even a small KL weight (β ≥ 0.01) measurably degrades generation quality.

FastHVRT addresses this directly. It partitions an encoded dataset into hierarchical regions that preserve local variance, then samples from per-partition non-parametric kernels. Fitting is O(n · d) and generation is a single expand() call — no stochastic inference chain, no gradient descent at inference time.

This repository tests the hypothesis:

A zero-KL autoencoder + FastHVRT latent sampler can match VAE-class generation quality at a fraction of the compute, while preserving enough class structure for conditional generation.


How It Works

  Real images
       |
  [Encoder]  (ConvAE, beta=0 — no KL penalty)
       |
  Latent Z  (n x d float32 array)
       |
  [FastHVRT.fit(Z, labels=y)]
       |                    |
  Unconditional          Conditional
  .expand(n)          per-class .expand(n)
       |                    |
  Z_synth  ----------------+
       |
  [Decoder]
       |
  Synthetic images

The conditional path fits a dedicated FastHVRT per class on the class-specific training latents. No label embedding, no classifier guidance — just non-parametric density estimation within each class cluster.
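
The per-class path can be pictured as a dictionary of independent samplers, one per label. The sketch below uses a deliberately simple stand-in sampler (bootstrap resampling plus Gaussian jitter) rather than FastHVRT itself. `JitterSampler`, `fit_per_class`, and the noise scale are hypothetical names and values chosen for illustration; they are not part of the hvrt API.

```python
import numpy as np


class JitterSampler:
    """Stand-in latent sampler: bootstrap rows and add small Gaussian noise.

    This is NOT FastHVRT; it only mimics the fit/expand shape of the workflow.
    """

    def __init__(self, noise_scale=0.05, seed=0):
        self.noise_scale = noise_scale
        self.rng = np.random.default_rng(seed)

    def fit(self, Z):
        self.Z = np.asarray(Z, dtype=np.float32)
        # Per-dimension spread sets the jitter size.
        self.std = self.Z.std(axis=0)
        return self

    def expand(self, n):
        idx = self.rng.integers(0, len(self.Z), size=n)
        noise = self.rng.normal(size=(n, self.Z.shape[1])) * self.std * self.noise_scale
        return (self.Z[idx] + noise).astype(np.float32)


def fit_per_class(Z, y):
    """Fit one independent sampler on the latents of each class."""
    return {int(c): JitterSampler(seed=int(c)).fit(Z[y == c]) for c in np.unique(y)}


# Toy latents: two well-separated classes in 4 dimensions.
rng = np.random.default_rng(42)
Z = np.concatenate([rng.normal(-3, 1, (100, 4)), rng.normal(3, 1, (100, 4))]).astype(np.float32)
y = np.array([0] * 100 + [1] * 100)

samplers = fit_per_class(Z, y)
Z_class1 = samplers[1].expand(50)  # conditional draw: class 1 only
print(Z_class1.shape, float(Z_class1.mean()))
```

Because each sampler only ever sees one class, no label needs to be passed at generation time; choosing the dictionary key is the conditioning.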


Key Results (MNIST)

All runs: 60k training samples, NVIDIA RTX 4090 Laptop GPU.

Generation quality — full HPO optimisation

| Config | FID | Latent KS | IS |
|---|---|---|---|
| VAE reparam (β=1.0, baseline) | 325.4 | 0.504 | 1.03 |
| HVRT(AE, bootstrap, dim=16, 100ep) | 18.4 | 0.064 | 2.06 |
| HVRT(AE, epanechnikov, dim=16, 100ep) | 15.5 | 0.066 | 2.04 |
| HVRT(AE, epanechnikov, dim=64, 20ep) | 19.7 | 0.033 | 1.99 |
| HVRT(AE, epanechnikov, dim=64, min_leaf=5) | 19.5 | 0.031 | 1.95 |

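The best HVRT(AE) configuration (epanechnikov, dim=16, 100 epochs; FID 15.5) achieves roughly 21× lower FID than the VAE reparameterisation baseline (FID 325.4) at equal epochs.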
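
The Latent KS column is not defined in this README. One plausible reading, stated here as an assumption, is the two-sample Kolmogorov-Smirnov statistic computed per latent dimension and averaged over dimensions. The sketch below implements that reading in NumPy; `latent_ks_sketch` is a hypothetical name and is not the `latent_ks` function in metrics.py.

```python
import numpy as np


def ks_statistic(a, b):
    """Two-sample KS statistic: largest gap between the two empirical CDFs."""
    a, b = np.sort(a), np.sort(b)
    grid = np.concatenate([a, b])  # the gap can only change at observed values
    cdf_a = np.searchsorted(a, grid, side="right") / len(a)
    cdf_b = np.searchsorted(b, grid, side="right") / len(b)
    return float(np.max(np.abs(cdf_a - cdf_b)))


def latent_ks_sketch(Z_real, Z_synth):
    """Average the per-dimension KS statistic over all latent dimensions."""
    d = Z_real.shape[1]
    return float(np.mean([ks_statistic(Z_real[:, j], Z_synth[:, j]) for j in range(d)]))


rng = np.random.default_rng(0)
Z_real = rng.normal(size=(2000, 16))
same = latent_ks_sketch(Z_real, rng.normal(size=(2000, 16)))              # same distribution
shifted = latent_ks_sketch(Z_real, rng.normal(loc=1.0, size=(2000, 16)))  # mean shifted by 1
print(f"same: {same:.3f}  shifted: {shifted:.3f}")
```

Under this definition a value near 0 means the synthetic marginals track the real ones closely, and a value near 1 means they barely overlap.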

Downstream utility — TSTR (Train-on-Synthetic, Test-on-Real)

Classifier trained only on synthetic images, tested on 10k real MNIST images. Optimal config: ConvAE β=0, epanechnikov, latent_dim=64, min_samples_leaf=5.

| Training data | N | Test accuracy |
|---|---|---|
| Real (full 60k) | 60 000 | 99.21% |
| Real (equal-N) | 10 000 | 98.28% |
| HVRT_AE_cond (synthetic) | 10 000 | 98.01% |
| VAE reparam (β=0.05) | 10 000 | n/a (collapsed at dim=64) |
| Random N(0,1) decoded | 10 000 | 13.7% |

The synthetic-only classifier scores 98.01%, which is 0.27 percentage points below the equal-N real-data baseline (98.28%). A classifier trained exclusively on HVRT-generated digits therefore transfers to real images with 98% accuracy.
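
The TSTR protocol itself is simple: fit a classifier on synthetic data only, then score it on held-out real data and compare against the same classifier trained on real data. The toy sketch below shows the protocol with a nearest-centroid classifier on Gaussian blobs. The data, the classifier, and the jitter-based "generator" are all stand-ins chosen so the example runs without the repository; they are assumptions, not what tstr_mnist.py does.

```python
import numpy as np


def fit_centroids(X, y):
    """Nearest-centroid classifier: one mean vector per class."""
    classes = np.unique(y)
    return classes, np.stack([X[y == c].mean(axis=0) for c in classes])


def predict(model, X):
    classes, centroids = model
    # Squared distance from every sample to every centroid.
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return classes[d2.argmin(axis=1)]


def accuracy(model, X, y):
    return float((predict(model, X) == y).mean())


rng = np.random.default_rng(0)


def make_real(n_per_class):
    """Toy 'real' data: three Gaussian blobs in 8 dimensions."""
    centres = np.array([[-4.0] * 8, [0.0] * 8, [4.0] * 8])
    X = np.concatenate([rng.normal(c, 1.0, (n_per_class, 8)) for c in centres])
    y = np.repeat(np.arange(3), n_per_class)
    return X, y


X_train, y_train = make_real(300)
X_test, y_test = make_real(100)  # held-out real data, never used for fitting

# Stand-in generator: resample real training points and jitter them.
idx = rng.integers(0, len(X_train), size=len(X_train))
X_synth = X_train[idx] + rng.normal(0, 0.1, X_train[idx].shape)
y_synth = y_train[idx]

acc_real = accuracy(fit_centroids(X_train, y_train), X_test, y_test)  # train on real
acc_tstr = accuracy(fit_centroids(X_synth, y_synth), X_test, y_test)  # train on synthetic
print(f"real: {acc_real:.3f}  TSTR: {acc_tstr:.3f}  gap: {acc_real - acc_tstr:+.3f}")
```

The quantity of interest is the gap between the two accuracies at equal N, which is how the 0.27 percentage point figure above is defined.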

Generation speed

| Method | Time for 5 000 images |
|---|---|
| DDPM (reference) | ~800 s |
| VAE reparameterisation | 0.37 s |
| FastHVRT expand + decode | 0.02–0.04 s |

Single-pass generation is roughly 9–18× faster than VAE reparameterisation for equal sample counts (0.37 s versus 0.02–0.04 s), and about four orders of magnitude faster than the DDPM reference.


HPO Findings Summary

Five sweeps were run to characterise the design space: generation strategy, encoder beta, latent dimension, ensemble size, and tree parameters. All used MNIST 60k, 20 training epochs, and 2 000 generated samples.

Generation strategy (ConvAE β=0, dim=16)

| Strategy | FID | TSTR |
|---|---|---|
| epanechnikov | 21.68 | 96.65% |
| multivariate_kde | 22.16 | 96.67% |
| bootstrap_noise | 23.80 | 96.83% |
| univariate_kde_copula | 24.07 | 95.96% |

Epanechnikov gives the best FID; TSTR is flat across all strategies (~96–97%).
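
For readers unfamiliar with the kernel: the Epanechnikov density is K(u) = 0.75 (1 - u²) on [-1, 1], so perturbations have bounded support, unlike Gaussian noise. The sketch below draws from it with the classic three-uniform rule and applies the draw as jitter around a latent point. How the hvrt package implements its epanechnikov strategy is not shown in this README; the function name and the scale `h` are hypothetical.

```python
import numpy as np


def sample_epanechnikov(shape, rng):
    """Draw from K(u) = 0.75 * (1 - u^2) on [-1, 1].

    Three-uniform rule: take U1, U2, U3 ~ Uniform(-1, 1); if |U3| is the
    largest of the three in magnitude return U2, otherwise return U3.
    """
    u1, u2, u3 = rng.uniform(-1.0, 1.0, size=(3, *shape))
    take_u2 = (np.abs(u3) >= np.abs(u2)) & (np.abs(u3) >= np.abs(u1))
    return np.where(take_u2, u2, u3)


rng = np.random.default_rng(0)
u = sample_epanechnikov((200_000,), rng)
# Support is [-1, 1]; the theoretical variance of the kernel is 1/5.
print(float(u.min()), float(u.max()), float(u.var()))

# Used as jitter around a latent point, with a hypothetical per-dimension scale h.
z = np.zeros(16)
h = 0.1
z_new = z + h * sample_epanechnikov((16,), rng)
```

The bounded support means a jittered point can never land further than `h` from its source point in any dimension.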

Encoder beta (epanechnikov, dim=16)

FID doubles by β=0.025 and quintuples by β=0.2. TSTR collapses to 85.6% at β=0.2. Use β=0 (pure AE).

Latent dimension (ConvAE β=0, epanechnikov)

| dim | FID | TSTR |
|---|---|---|
| 8 | 38.3 | 94.75% |
| 16 | 21.6 | 96.50% |
| 32 | 20.0 | 96.45% |
| 64 | 19.7 | 97.23% |

Diminishing returns above 32. Note that a β tuned at one latent dimension does not transfer to another: β=0.05 works at dim=16 but collapses completely at dim=64, because the KL penalty scales with d.
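
The scaling with d follows from the closed form of the KL term for a diagonal Gaussian posterior against a standard normal prior. This is the standard β-VAE objective; that BetaVAE in this repository uses exactly this form is an assumption.

```latex
\mathcal{L}(x) = \mathcal{L}_{\text{rec}}(x)
  + \beta \, D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I_d)\big),
\qquad
D_{\mathrm{KL}} = \frac{1}{2} \sum_{j=1}^{d}
  \left( \mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right)
```

The sum has d terms, so at a fixed average per-dimension divergence the penalty at d=64 is four times the penalty at d=16, while the reconstruction term is unchanged. Under that assumption, keeping the same effective pressure when moving from dim=16 to dim=64 would need β reduced by roughly a factor of 4.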

Ensemble K (EnsembleHVRT, bootstrap_noise)

K=1 through K=7 span less than 1 FID point. Single-model generation is sufficient.

Tree parameters (ConvAE β=0, epanechnikov, dim=64)

| Parameter | Finding |
|---|---|
| max_depth | Auto-tune resolves to depth ~10; deeper adds nothing |
| min_samples_leaf | Auto-tune is too conservative (187 partitions); min_samples_leaf=5 gives 1 500 partitions and FID 19.45 vs 21.29 |
| bandwidth | No effect with epanechnikov (applies only to Gaussian-kernel strategies) |

The single most impactful tuning decision after encoder architecture is overriding the default min_samples_leaf.
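
To put the partition counts in perspective, the short calculation below converts them into average leaf sizes. It assumes the sampler is fitted on all 60 000 training latents, as stated for the sweeps; the variable names are illustrative.

```python
n_train = 60_000  # MNIST training latents, as stated for the sweeps

partition_counts = {"auto-tuned": 187, "min_samples_leaf=5": 1_500}
avg_leaf_size = {label: n_train / k for label, k in partition_counts.items()}

for label, size in avg_leaf_size.items():
    print(f"{label:>20}: {partition_counts[label]:5d} partitions, ~{size:.0f} samples per leaf")

# Upper bound on the leaf count if every leaf held exactly the minimum of 5 samples.
max_leaves = n_train // 5
print(max_leaves)
```

The auto-tuned tree averages about 321 samples per leaf against 40 with min_samples_leaf=5, and the 1 500 partitions observed are still far below the 12 000-leaf ceiling, so the depth limit rather than the leaf minimum is likely what stops further splitting.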


Optimal Configuration

import torch

from hvrt_diffuser.config import HVRTConfig, TrainingConfig
from hvrt_diffuser.models.beta_vae import BetaVAE
from hvrt_diffuser.hvrt_wrapper import HVRTDiffuser
from hvrt_diffuser.training import train_model, encode_dataset

device = "cuda" if torch.cuda.is_available() else "cpu"

# train_loader is assumed to be an existing DataLoader over the MNIST training
# set (datasets.py provides get_dataloader; its arguments are not shown here).

# 1. Train encoder with no KL regularisation (beta=0 makes BetaVAE a plain AE)
ae = BetaVAE(latent_dim=64, beta=0.0, dataset="mnist")
train_model(ae, train_loader, TrainingConfig(latent_dim=64, epochs=20))

# 2. Encode training set with labels
Z_train, y_train = encode_dataset(ae, train_loader, device)

# 3. Fit HVRT with finer partitioning than the auto-tuned default
cfg = HVRTConfig(strategy="epanechnikov", min_samples_leaf=5)
diffuser = HVRTDiffuser(ae, cfg)
diffuser.fit(Z_train, labels=y_train)

# 4. Generate, either unconditionally or for a single class
images = diffuser.generate(n=1000, device=device)
images_3 = diffuser.generate_conditional(n=200, class_id=3, device=device)

Installation

python -m venv .venv
.venv\Scripts\activate          # Windows
source .venv/bin/activate       # Linux / macOS
pip install -e .

Requires Python ≥ 3.9, PyTorch ≥ 2.0 (CUDA build recommended for HVRT fitting on large datasets), and the hvrt package.
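
A quick way to confirm the prerequisites before running any experiment is sketched below. It only checks the interpreter version and whether the `torch` and `hvrt` packages can be found; it does not verify the PyTorch version or CUDA availability. `check_environment` is a hypothetical helper, not part of this repository.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 9), packages=("torch", "hvrt")):
    """Report whether the interpreter and required packages are available."""
    report = {"python_ok": sys.version_info >= min_python}
    for name in packages:
        # find_spec locates a top-level package without importing it.
        report[name] = importlib.util.find_spec(name) is not None
    return report


report = check_environment()
print(report)
```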


Entry Points

| Command | Description |
|---|---|
| hvrt-phase1 | Phase 1: ConvAE + LowBetaVAE + StdVAE on MNIST; FID/IS/latent-KS table |
| hvrt-phase2 | Phase 2: EnsembleHVRT on CIFAR-10 |
| hvrt-phase3 | Phase 3: Conditional generation + warm-start on CIFAR-10 |
| hvrt-tstr | TSTR: classifier trained on synthetic, tested on real MNIST |
| hvrt-hpo | HPO sweeps: strategy / beta / dim / ensemble / tree |

Quick smoke test (strategy sweep only, 5k training samples, 500 generated samples, 5 epochs):

python -m hvrt_diffuser.experiments.hpo_mnist \
  --sweeps strategy --gen-epochs 5 --n-gen 500 \
  --n-per-class 100 --clf-epochs 3 --n-train 5000 --device cuda

Full TSTR with optimal config:

python -m hvrt_diffuser.experiments.tstr_mnist \
  --latent-dim 64 --gen-epochs 20 --strategy epanechnikov \
  --min-samples-leaf 5 --device cuda

Project Structure

hvrt_diffuser/
  config.py              -- DatasetConfig, TrainingConfig, HVRTConfig, ...
  models/
    beta_vae.py          -- BetaVAE(latent_dim, beta, dataset); beta=0 -> AE
    _encoder_mnist.py / _decoder_mnist.py
    _encoder_cifar.py / _decoder_cifar.py
  hvrt_wrapper.py        -- HVRTDiffuser, EnsembleHVRT
  training.py            -- train_model, encode_dataset, decode_to_tensor
  metrics.py             -- fid_score, inception_score, latent_ks, ...
  datasets.py            -- get_dataset, get_dataloader (MNIST + CIFAR-10)
  checkpointing.py       -- save/load checkpoint
  experiments/
    phase1_mnist.py      -- Phase 1 benchmark
    phase2_cifar.py      -- Phase 2 CIFAR ensemble
    phase3_conditional.py -- Conditional + warm-start
    tstr_mnist.py        -- TSTR utility evaluation
    hpo_mnist.py         -- HPO sweep runner (5 sweeps)
results/
  results_phase1.json
  results_tstr.json
  results_hpo.json

Current Scope

This is a proof of concept on MNIST and CIFAR-10.

  • Encoder architectures are fixed for 28×28 (MNIST) and 32×32 (CIFAR). Higher-resolution domains require new encoder/decoder pairs.
  • Conditional generation requires integer class labels at training time. Attribute-conditioned or text-conditioned generation is not implemented.
  • FID ~19 on MNIST is strong for a single-pass method but not comparable to large-scale diffusion models on natural images. The proposition is speed and inspectability, not state-of-the-art fidelity at scale.
  • The CIFAR-10 experiments (Phase 2/3) are present but the primary validated results are on MNIST.

Dependency

Generation quality depends on the hvrt package (pip install hvrt). FastHVRT is the Hierarchical Variance-Retaining Transformer — a decision-tree-based non-parametric sampler that partitions the latent space hierarchically and samples from per-leaf kernels. The key property exploited here is that FastHVRT preserves the marginal and joint distribution of its training data far more faithfully than VAE reparameterisation (KS statistic 0.031 vs 0.504 at equal sample counts).
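
To make "partition hierarchically, then sample from per-leaf kernels" concrete, here is a toy NumPy sketch of the general idea. It is not the hvrt package's algorithm: the split rule (median split on the highest-variance dimension), the Gaussian per-leaf jitter, and the names `build_leaves` and `expand` are all assumptions chosen for brevity.

```python
import numpy as np


def build_leaves(Z, min_samples_leaf=5, max_depth=10, depth=0):
    """Recursively split on the highest-variance dimension at its median.

    Returns a list of arrays, one per leaf. The split rule is an illustrative
    assumption, not the rule used by the hvrt package.
    """
    if depth >= max_depth or len(Z) < 2 * min_samples_leaf:
        return [Z]
    dim = int(Z.var(axis=0).argmax())
    threshold = np.median(Z[:, dim])
    left, right = Z[Z[:, dim] <= threshold], Z[Z[:, dim] > threshold]
    if len(left) < min_samples_leaf or len(right) < min_samples_leaf:
        return [Z]
    return (build_leaves(left, min_samples_leaf, max_depth, depth + 1)
            + build_leaves(right, min_samples_leaf, max_depth, depth + 1))


def expand(leaves, n, rng, noise_scale=0.1):
    """Draw n points: pick a leaf by size, pick a member, jitter by the leaf's spread."""
    sizes = np.array([len(leaf) for leaf in leaves])
    leaf_ids = rng.choice(len(leaves), size=n, p=sizes / sizes.sum())
    out = np.empty((n, leaves[0].shape[1]))
    for i, k in enumerate(leaf_ids):
        leaf = leaves[k]
        base = leaf[rng.integers(len(leaf))]
        # Noise is scaled by the local (per-leaf) standard deviation.
        out[i] = base + noise_scale * leaf.std(axis=0) * rng.normal(size=base.shape)
    return out


rng = np.random.default_rng(0)
Z = rng.normal(size=(2000, 8))
leaves = build_leaves(Z, min_samples_leaf=5)
Z_synth = expand(leaves, 500, rng)
print(len(leaves), Z_synth.shape)
```

Because the noise scale comes from each leaf rather than from the whole dataset, dense regions receive small perturbations and sparse regions larger ones, which is the intuition behind preserving local variance.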
