Enable TF32 for massive training and sampling speedups on Ampere GPUs (A100s etc.)
wpeebles committed Feb 22, 2023
1 parent 70b9f00 commit 58fe9c2
Showing 3 changed files with 10 additions and 0 deletions.
README.md: 5 additions & 0 deletions
@@ -100,6 +100,11 @@ similar (and sometimes slightly better) results compared to the JAX-trained models
These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that FID
here is computed with 250 DDPM sampling steps, with the `mse` VAE decoder and without guidance (`cfg-scale=1`).

**TF32 Note (important for A100 users).** When we ran the tests above, TF32 matmuls were disabled per PyTorch's defaults.
We have enabled them at the top of `train.py` and `sample.py` because they make training and sampling substantially faster on
A100s (and should on other Ampere GPUs as well). Note that using TF32 may lead to small numerical differences compared to
the results above.

### Enhancements
Training (and sampling) could likely be sped up significantly by:
- [ ] using [Flash Attention](https://github.com/HazyResearch/flash-attention) in the DiT model
sample.py: 2 additions & 0 deletions
@@ -7,6 +7,8 @@
Sample new images from a pre-trained DiT.
"""
import torch
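# TF32 matmuls are disabled by default in PyTorch; enabling them makes sampling much faster on Ampere GPUs (e.g. A100s):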
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
train.py: 3 additions & 0 deletions
@@ -8,6 +8,9 @@
A minimal training script for DiT using PyTorch DDP.
"""
import torch
# The first flag below was False when we tested this script, but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
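
For reference, a minimal sketch (not part of this commit, assuming PyTorch's current defaults) of what the top of `train.py` / `sample.py` would look like with the matmul flag reverted, for anyone who wants to reproduce the FID numbers in the README table exactly:

```python
import torch

# Revert the matmul flag to PyTorch's default (False) to match the
# pre-TF32 results reported in the README, at the cost of slower matmuls.
torch.backends.cuda.matmul.allow_tf32 = False
# cudnn.allow_tf32 already defaults to True in PyTorch, so the flag set by
# this commit only makes the existing default explicit.
torch.backends.cudnn.allow_tf32 = True
```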
