Skip to content

Commit

Permalink
also offer l2norm clamping in diffusion prior during training, if one…
Browse files Browse the repository at this point in the history
… were using predict x0 objective
  • Loading branch information
lucidrains committed May 6, 2022
1 parent 09e9eaa commit 14e63a3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
5 changes: 5 additions & 0 deletions dalle2_pytorch/dalle2_pytorch.py
Expand Up @@ -805,6 +805,7 @@ def __init__(
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False,
training_clamp_l2norm = False,
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
Expand Down Expand Up @@ -842,6 +843,7 @@ def __init__(

# whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm
self.training_clamp_l2norm = training_clamp_l2norm

def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond)
Expand Down Expand Up @@ -894,6 +896,9 @@ def p_losses(self, image_embed, times, text_cond, noise = None):
**text_cond
)

if self.predict_x_start and self.training_clamp_l2norm:
pred = l2norm(pred) * self.image_embed_scale

target = noise if not self.predict_x_start else image_embed

loss = self.loss_fn(pred, target)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -10,7 +10,7 @@
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.1.1',
version = '0.1.2',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
Expand Down

0 comments on commit 14e63a3

Please sign in to comment.