From 7f1c8ee1478cdf7910c4b72d9757a91952a05b07 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Thu, 18 Aug 2022 16:26:21 +0200
Subject: [PATCH 1/2] Support one-string prompts in LDM

---
 .../latent_diffusion/pipeline_latent_diffusion.py | 8 +++++++-
 tests/test_modeling_utils.py                      | 2 +-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index fef211333b23..4ec38970e38e 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -37,7 +37,13 @@ def __call__(
 
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-        batch_size = len(prompt)
+
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)

diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 65597ceb157d..4195f99c66b3 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -854,7 +854,7 @@ def test_ldm_text2img_fast(self):
 
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
-        image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
+        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
 
         image_slice = image[0, -3:, -3:, -1]


From d0f14ec09d4fd906dfd7de2a56c344ee5ca86e64 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Thu, 18 Aug 2022 17:03:14 +0200
Subject: [PATCH 2/2] Add other features from SD too

---
 .../pipeline_latent_diffusion.py | 24 +++++++++++--------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index 4ec38970e38e..1edcdadb227f 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -1,5 +1,5 @@
 import inspect
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -24,14 +24,15 @@ def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
 
     @torch.no_grad()
     def __call__(
         self,
-        prompt,
-        batch_size=1,
-        generator=None,
-        torch_device=None,
-        eta=0.0,
-        guidance_scale=1.0,
-        num_inference_steps=50,
-        output_type="pil",
+        prompt: Union[str, List[str]],
+        height: Optional[int] = 256,
+        width: Optional[int] = 256,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 1.0,
+        eta: Optional[float] = 0.0,
+        generator: Optional[torch.Generator] = None,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        output_type: Optional[str] = "pil",
     ):
         # eta corresponds to η in paper and should be between [0, 1]
@@ -45,6 +46,9 @@ def __call__(
         else:
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)
@@ -59,7 +63,7 @@ def __call__(
         text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]
 
         latents = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+            (batch_size, self.unet.in_channels, height // 8, width // 8),
            generator=generator,
         )
         latents = latents.to(torch_device)
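Usage sketch (not part of the patch series): with both patches applied, the LDM text-to-image pipeline should accept a bare string prompt as well as a list of strings, and expose height/width like the Stable Diffusion pipeline. The dict-style ["sample"] access mirrors the test above; the checkpoint id is the public CompVis LDM weights on the Hub, and the output filename is illustrative.

    import torch
    from diffusers import DiffusionPipeline

    # Load the LDM text-to-image checkpoint (public Hub id).
    ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

    generator = torch.manual_seed(0)

    # A bare string prompt now works; before PATCH 1/2 it had to be wrapped in a list.
    image = ldm(
        "A painting of a squirrel eating a burger",
        height=256,  # height and width must be divisible by 8, or __call__ raises ValueError
        width=256,
        num_inference_steps=50,
        guidance_scale=1.0,  # default per PATCH 2/2
        generator=generator,
        output_type="pil",
    )["sample"][0]
    image.save("squirrel.png")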