From 58fe9c286baa9fcd2d08278a57b178b0650c3eff Mon Sep 17 00:00:00 2001
From: Bill Peebles
Date: Tue, 21 Feb 2023 23:51:21 -0800
Subject: [PATCH] Enable TF32 for massive training and sampling speedups on
 Ampere GPUs (A100s etc.)

---
 README.md | 5 +++++
 sample.py | 2 ++
 train.py  | 3 +++
 3 files changed, 10 insertions(+)

diff --git a/README.md b/README.md
index f84ef81..edbc1a2 100644
--- a/README.md
+++ b/README.md
@@ -100,6 +100,11 @@ similar (and sometimes slightly better) results compared to the JAX-trained mode
 These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that
 FID here is computed with 250 DDPM sampling steps, with the `mse` VAE decoder and without guidance (`cfg-scale=1`).
 
+**TF32 Note (important for A100 users).** When we ran the above tests, TF32 matmuls were disabled per PyTorch's defaults.
+We've enabled them at the top of `train.py` and `sample.py` because it makes training and sampling way way way faster on
+A100s (and should for other Ampere GPUs too), but note that the use of TF32 may lead to some differences compared to
+the above results.
+
 ### Enhancements
 Training (and sampling) could likely be sped-up significantly by:
 - [ ] using [Flash Attention](https://github.com/HazyResearch/flash-attention) in the DiT model
diff --git a/sample.py b/sample.py
index 568634b..82238f1 100644
--- a/sample.py
+++ b/sample.py
@@ -7,6 +7,8 @@
 Sample new images from a pre-trained DiT.
 """
 import torch
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
 from torchvision.utils import save_image
 from diffusion import create_diffusion
 from diffusers.models import AutoencoderKL
diff --git a/train.py b/train.py
index c7eef12..7cfee80 100644
--- a/train.py
+++ b/train.py
@@ -8,6 +8,9 @@
 A minimal training script for DiT using PyTorch DDP.
 """
 import torch
+# the first flag below was False when we tested this script but True makes A100 training a lot faster:
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
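
A quick standalone way to see what these two flags do (a minimal sketch, not part of the commit above; the matrix size and iteration count are illustrative assumptions) is to time a large fp32 matmul with TF32 off and then on, on an Ampere-or-newer GPU:

    # tf32_check.py -- minimal sketch, not from the DiT repo.
    # Times a large matmul with and without TF32 to show the speedup the patch targets.
    import time

    import torch


    def time_matmul(allow_tf32: bool, n: int = 8192, iters: int = 20) -> float:
        # Same two flags the patch sets at the top of train.py and sample.py.
        torch.backends.cuda.matmul.allow_tf32 = allow_tf32
        torch.backends.cudnn.allow_tf32 = allow_tf32
        a = torch.randn(n, n, device="cuda")
        b = torch.randn(n, n, device="cuda")
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            a @ b
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters


    if __name__ == "__main__":
        assert torch.cuda.is_available(), "requires a CUDA GPU"
        print(f"fp32 matmul: {time_matmul(False):.4f} s/iter")
        print(f"tf32 matmul: {time_matmul(True):.4f} s/iter")

TF32 keeps fp32's dynamic range but rounds matmul inputs to 10 bits of mantissa, which is why the README note warns that results may differ slightly from the full-fp32 numbers reported in the table.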