add diffusion prior trainer, which automatically takes care of the exponential moving average (training and sampling), as well as mixed precision and gradient clipping
lucidrains committed May 6, 2022
1 parent 878b555 commit 740d644
Showing 5 changed files with 159 additions and 4 deletions.
62 changes: 62 additions & 0 deletions README.md
@@ -786,6 +786,68 @@ mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
```

### Diffusion Prior Training

Similarly, one can use the `DiffusionPriorTrainer` to automatically instantiate and keep track of an exponentially moving averaged copy of the prior, with mixed precision and gradient clipping handled for you.

```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# prior network (with transformer)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

diffusion_prior_trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = 3e-4,
    wd = 1e-2,
    ema_beta = 0.99,
    ema_update_after_step = 1000,
    ema_update_every = 10,
)

loss = diffusion_prior_trainer(text, images)
loss.backward()
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior

# after running the above three lines in a loop for many steps
# you can sample from the exponential moving average of the diffusion prior, identically to how you would with DiffusionPrior

image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
```
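Putting the pieces together, here is a minimal sketch of a training loop, continuing from the example above. `train_dataloader` is a hypothetical loader yielding `(text, images)` batches shaped like the mock tensors; it is not part of this commit.

```python
# hypothetical training loop around DiffusionPriorTrainer, continuing the example above
# `train_dataloader` is assumed to yield (text, images) batches like the mock data

num_epochs = 10  # assumption, tune to taste

for epoch in range(num_epochs):
    for text, images in train_dataloader:
        text, images = text.cuda(), images.cuda()

        loss = diffusion_prior_trainer(text, images)  # autocast is applied internally if amp = True
        loss.backward()

        diffusion_prior_trainer.update()  # steps the optimizer, zeroes grads, updates the EMA prior

# once trained, sample from the EMA prior as before

image_embeds = diffusion_prior_trainer.sample(text)  # (batch, 512)
```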

### Decoder Dataloaders

In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
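The dataloader implementations themselves sit outside this diff. Purely as an illustrative stand-in (the class below is hypothetical, not part of the library), a plain PyTorch `Dataset`/`DataLoader` yielding batches shaped like the mock data above would slot into the trainers:

```python
import torch
from torch.utils.data import Dataset, DataLoader

# hypothetical stand-in dataset, not part of dalle2_pytorch -
# yields (text, image) pairs shaped like the mock tensors in the examples above

class MockTextImageDataset(Dataset):
    def __init__(self, length = 1024, text_seq_len = 256, num_text_tokens = 49408, image_size = 256):
        self.length = length
        self.text_seq_len = text_seq_len
        self.num_text_tokens = num_text_tokens
        self.image_size = image_size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        text = torch.randint(0, self.num_text_tokens, (self.text_seq_len,))
        image = torch.randn(3, self.image_size, self.image_size)
        return text, image

train_dataloader = DataLoader(MockTextImageDataset(), batch_size = 4, shuffle = True)
```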
2 changes: 1 addition & 1 deletion dalle2_pytorch/__init__.py
@@ -1,6 +1,6 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer

from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP
12 changes: 12 additions & 0 deletions dalle2_pytorch/dalle2_pytorch.py
@@ -845,6 +845,18 @@ def p_losses(self, image_embed, times, text_cond, noise = None):
        loss = self.loss_fn(pred, target)
        return loss

    @torch.inference_mode()
    @eval_decorator
    def sample_batch_size(self, batch_size, text_cond):
        device = self.betas.device
        shape = (batch_size, self.image_embed_dim)

        img = torch.randn(shape, device = device)

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
        return img
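For completeness, a hedged sketch of calling the new method directly. The structure of `text_cond` (a dict of conditioning embeddings) is an assumption drawn from how `sample` conditions the network, and `mock_text_embed` is a stand-in:

```python
# hypothetical usage of sample_batch_size; text_cond structure is an assumption
# (a dict of conditioning embeddings, as built internally by sample)

mock_text_embed = torch.randn(4, 512).cuda()  # stand-in for CLIP text embeddings

image_embeds = diffusion_prior.sample_batch_size(
    4,
    text_cond = dict(text_embed = mock_text_embed)
)  # (4, 512)
```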

    @torch.inference_mode()
    @eval_decorator
    def sample(self, text, num_samples_per_batch = 2):
85 changes: 83 additions & 2 deletions dalle2_pytorch/train.py
@@ -5,7 +5,7 @@
from torch import nn
from torch.cuda.amp import autocast, GradScaler

from dalle2_pytorch.dalle2_pytorch import Decoder
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer

# helper functions
@@ -89,7 +89,88 @@ def calculate_ema(beta, old, new):
    def __call__(self, *args, **kwargs):
        return self.ema_model(*args, **kwargs)

# trainers
# diffusion prior trainer

class DiffusionPriorTrainer(nn.Module):
    def __init__(
        self,
        diffusion_prior,
        use_ema = True,
        lr = 3e-4,
        wd = 1e-2,
        max_grad_norm = None,
        amp = False,
        **kwargs
    ):
        super().__init__()
        assert isinstance(diffusion_prior, DiffusionPrior)
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        self.diffusion_prior = diffusion_prior

        # exponential moving average

        self.use_ema = use_ema

        if use_ema:
            has_lazy_linear = any([type(module) == nn.LazyLinear for module in diffusion_prior.modules()])
            assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'

        if self.use_ema:
            self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)

        # optimizer and mixed precision stuff

        self.amp = amp

        self.scaler = GradScaler(enabled = amp)

        self.optimizer = get_optimizer(
            diffusion_prior.parameters(),
            lr = lr,
            wd = wd,
            **kwargs
        )

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

    def update(self):
        if exists(self.max_grad_norm):
            # gradients must be unscaled before clipping when training with mixed precision
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.use_ema:
            self.ema_diffusion_prior.update()

    # sampling always delegates to the exponential moving averaged copy of the prior

    @torch.inference_mode()
    def p_sample_loop(self, *args, **kwargs):
        return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)

    @torch.inference_mode()
    def sample(self, *args, **kwargs):
        return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)

    @torch.inference_mode()
    def sample_batch_size(self, *args, **kwargs):
        return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)

    def forward(
        self,
        *args,
        divisor = 1,
        **kwargs
    ):
        with autocast(enabled = self.amp):
            loss = self.diffusion_prior(*args, **kwargs)
        return self.scaler.scale(loss / divisor)  # dividing by `divisor` averages the loss over gradient accumulation steps
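The `divisor` keyword lets the loss be pre-divided for gradient accumulation before scaling. A rough sketch of how it might be used; `trainer`, `dataloader`, and `accum_steps` are illustrative names, not part of this commit:

```python
# hypothetical gradient accumulation with DiffusionPriorTrainer
# `trainer` is a DiffusionPriorTrainer, `dataloader` yields (text, images) batches

accum_steps = 4

for step, (text, images) in enumerate(dataloader):
    loss = trainer(text, images, divisor = accum_steps)  # loss averaged over the accumulation window
    loss.backward()                                      # gradients accumulate until update()

    if (step + 1) % accum_steps == 0:
        trainer.update()  # unscale + clip (if set), optimizer step, zero grads, EMA update
```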

# decoder trainer

class DecoderTrainer(nn.Module):
    def __init__(
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@
      'dream = dalle2_pytorch.cli:dream'
    ],
  },
  version = '0.0.107',
  version = '0.0.108',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',
