Skip to content

Commit

Permalink
fix classifier free guidance for image hiddens summed to time hiddens, thanks to @xvjiarui for finding this bug
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 14, 2022
1 parent 0f31980 commit 5d95871
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
25 changes: 18 additions & 7 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,6 +1422,7 @@ def __init__(
# for classifier free guidance

self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))

self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
Expand Down Expand Up @@ -1559,31 +1560,41 @@ def forward(
time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens)

# conditional dropout

image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)

text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')

# image embedding to be summed to time embedding
# discovered by @mhh0318 in the paper

if exists(image_embed) and exists(self.to_image_hiddens):
image_hiddens = self.to_image_hiddens(image_embed)
t = t + image_hiddens
image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)

# conditional dropout

image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
image_hiddens = torch.where(
image_keep_mask_hidden,
image_hiddens,
null_image_hiddens
)

image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
t = t + image_hiddens

# mask out image embedding depending on condition dropout
# for classifier free guidance

image_tokens = None

if self.cond_on_image_embeds:
image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
image_tokens = self.image_to_tokens(image_embed)
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working

image_tokens = torch.where(
image_keep_mask,
image_keep_mask_embed,
image_tokens,
null_image_embed
)
Expand Down
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.6.16'
__version__ = '0.7.0'

0 comments on commit 5d95871

Please sign in to comment.