diff --git a/parti_pytorch/parti_pytorch.py b/parti_pytorch/parti_pytorch.py index 4b8ca2d..a1bae04 100644 --- a/parti_pytorch/parti_pytorch.py +++ b/parti_pytorch/parti_pytorch.py @@ -225,7 +225,8 @@ def __init__( vae_codebook_size = None, t5_name = DEFAULT_T5_NAME, text_embed_dim = None, - cond_drop_prob = 0.25 + cond_drop_prob = 0.25, + max_text_len = 128 ): super().__init__() @@ -233,6 +234,7 @@ def __init__( text_embed_dim = default(text_embed_dim, get_encoded_dim(t5_name)) self.encode_texts = partial(t5_encode_text, name = t5_name) + self.max_text_len = max_text_len assert cond_drop_prob > 0. self.cond_drop_prob = cond_drop_prob # classifier free guidance for transformers - @crowsonkb @@ -394,6 +396,10 @@ def forward( text_token_embeds.to(device) text_mask.to(device) + # enforce max text len + + text_token_embeds, text_mask = map(lambda t: t[:, :self.max_text_len], (text_token_embeds, text_mask)) + # classifier free guidance conditional dropout if cond_drop_prob > 0: diff --git a/setup.py b/setup.py index fb6b7df..afbb678 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'parti-pytorch', packages = find_packages(exclude=[]), - version = '0.0.4', + version = '0.0.5', license='MIT', description = 'Parti - Pathways Autoregressive Text-to-Image Model - Pytorch', author = 'Phil Wang',