Skip to content

Commit

Permalink
enforce a maximum text length
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 24, 2022
1 parent 84216fd commit b301b74
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion parti_pytorch/parti_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,16 @@ def __init__(
vae_codebook_size = None,
t5_name = DEFAULT_T5_NAME,
text_embed_dim = None,
cond_drop_prob = 0.25
cond_drop_prob = 0.25,
max_text_len = 128
):
super().__init__()

# text conditioning

text_embed_dim = default(text_embed_dim, get_encoded_dim(t5_name))
self.encode_texts = partial(t5_encode_text, name = t5_name)
self.max_text_len = max_text_len

assert cond_drop_prob > 0.
self.cond_drop_prob = cond_drop_prob # classifier free guidance for transformers - @crowsonkb
Expand Down Expand Up @@ -394,6 +396,10 @@ def forward(
text_token_embeds.to(device)
text_mask.to(device)

# enforce max text len

text_token_embeds, text_mask = map(lambda t: t[:, :self.max_text_len], (text_token_embeds, text_mask))

# classifier free guidance conditional dropout

if cond_drop_prob > 0:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'parti-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.4',
version = '0.0.5',
license='MIT',
description = 'Parti - Pathways Autoregressive Text-to-Image Model - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit b301b74

Please sign in to comment.