From 37e5db2d0a7562d1ef75b0aa5235c329410afc0e Mon Sep 17 00:00:00 2001
From: Penn
Date: Mon, 14 Nov 2022 16:03:18 -0800
Subject: [PATCH 1/3] fix non-square images with UNet2DModel and DDIM/DDPM pipelines

---
 src/diffusers/models/unet_2d.py               | 6 +++---
 src/diffusers/pipelines/ddim/pipeline_ddim.py | 6 +++++-
 src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 6 +++++-
 3 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py
index 641c253c86f8..6f98e769cfee 100644
--- a/src/diffusers/models/unet_2d.py
+++ b/src/diffusers/models/unet_2d.py
@@ -43,8 +43,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
     implements for all the model (such as downloading or saving, etc.)
 
     Parameters:
-        sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
-            Input sample size.
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
         in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
         out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
@@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        sample_size: Optional[int] = None,
+        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
         in_channels: int = 3,
         out_channels: int = 3,
         center_input_sample: bool = False,
diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index 6db6298329a7..b9e590dea646 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -89,7 +89,11 @@ def __call__(
             generator = None
 
         # Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.sample_size, int):
+            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        else:
+            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
             image = torch.randn(image_shape, generator=generator)
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index b7194664f4c4..8a18c1ca2eb8 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -94,7 +94,11 @@ def __call__(
             generator = None
 
         # Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.sample_size, int):
+            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        else:
+            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
+
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
             image = torch.randn(image_shape, generator=generator)
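A minimal usage sketch of what this patch enables (not part of the patch itself): a UNet2DModel built with a (height, width) tuple for `sample_size`, sampled through DDPMPipeline. The block/channel arguments mirror the dummy UNet used in the tests in PATCH 3/3; the (32, 64) size is an illustrative choice.

import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

# Small UNet taking a (height, width) tuple for `sample_size` (the new code path).
unet = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=(32, 64),  # height 32, width 64 -- previously only an int worked here
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
scheduler = DDPMScheduler(predict_epsilon=True)
pipeline = DDPMPipeline(unet, scheduler)

# The pipeline now draws noise of shape (batch, channels, *sample_size).
image = pipeline(generator=torch.manual_seed(0), num_inference_steps=2, output_type="np").images
assert image.shape == (1, 32, 64, 3)  # (batch, height, width, channels)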
From 95587af47347ad712ba7cc89411c5baf13fdf3d5 Mon Sep 17 00:00:00 2001
From: Penn
Date: Wed, 16 Nov 2022 12:13:52 -0800
Subject: [PATCH 2/3] fix unet_2d_condition `sample_size` docstring

---
 src/diffusers/models/unet_2d_condition.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 7f7f3ecd4435..67986dd38c0d 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -56,7 +56,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
     implements for all the models (such as downloading or saving, etc.)
 
     Parameters:
-        sample_size (`int`, *optional*): The size of the input sample.
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
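For completeness, a hedged sketch of the documented behavior on the conditional model: `sample_size` may be an int or a (height, width) tuple. The config below mirrors the dummy conditional UNet from the tests in PATCH 3/3; `up_block_types` and `cross_attention_dim` are assumed filler values, included only so the snippet constructs a valid model.

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=(32, 64),  # int or (height, width) tuple, per the updated docstring
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),  # assumed: mirror of the down blocks
    cross_attention_dim=32,  # assumed dummy value
)
print(unet.config.sample_size)  # (32, 64)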
From 4377e6b1138b4caf3a38813c1609121503f3ce4e Mon Sep 17 00:00:00 2001
From: Penn
Date: Wed, 16 Nov 2022 12:15:52 -0800
Subject: [PATCH 3/3] update pipeline tests for unet uncond

---
 tests/test_pipelines.py | 62 ++++++++++++++++++++++++++++-------------
 1 file changed, 42 insertions(+), 20 deletions(-)

diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 4559d713ed81..a35e116f96c6 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -18,6 +18,7 @@
 import random
 import tempfile
 import unittest
+from functools import partial
 
 import numpy as np
 import torch
@@ -42,6 +43,7 @@
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
+from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -229,7 +231,6 @@ def test_load_pipeline_from_git(self):
 
 
 class PipelineFastTests(unittest.TestCase):
-    @property
     def dummy_image(self):
         batch_size = 1
         num_channels = 3
@@ -238,13 +239,12 @@ def dummy_image(self):
         image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
         return image
 
-    @property
-    def dummy_uncond_unet(self):
+    def dummy_uncond_unet(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=3,
             out_channels=3,
             down_block_types=("DownBlock2D", "AttnDownBlock2D"),
@@ -252,13 +252,12 @@ def dummy_uncond_unet(self):
         )
         return model
 
-    @property
-    def dummy_cond_unet(self):
+    def dummy_cond_unet(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=4,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -267,13 +266,12 @@ def dummy_cond_unet(self):
         )
         return model
 
-    @property
-    def dummy_cond_unet_inpaint(self):
+    def dummy_cond_unet_inpaint(self, sample_size=32):
         torch.manual_seed(0)
         model = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
             in_channels=9,
             out_channels=4,
             down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -282,7 +280,6 @@ def dummy_cond_unet_inpaint(self):
         )
         return model
 
-    @property
     def dummy_vq_model(self):
         torch.manual_seed(0)
         model = VQModel(
@@ -295,7 +292,6 @@ def dummy_vq_model(self):
         )
         return model
 
-    @property
     def dummy_vae(self):
         torch.manual_seed(0)
         model = AutoencoderKL(
@@ -308,7 +304,6 @@ def dummy_vae(self):
         )
         return model
 
-    @property
     def dummy_text_encoder(self):
         torch.manual_seed(0)
         config = CLIPTextConfig(
@@ -324,7 +319,6 @@
         )
         return CLIPTextModel(config)
 
-    @property
     def dummy_extractor(self):
         def extract(*args, **kwargs):
             class Out:
@@ -339,15 +333,43 @@ def to(self, device):
 
         return extract
 
-    def test_components(self):
+    @parameterized.expand(
+        [
+            [DDIMScheduler, DDIMPipeline, 32],
+            [partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, 32],
+            [DDIMScheduler, DDIMPipeline, (32, 64)],
+            [partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, (64, 32)],
+        ]
+    )
+    def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32):
+        unet = self.dummy_uncond_unet(sample_size)
+        # DDIM doesn't take `predict_epsilon`, and DDPM requires it -- so using partial in parameterized decorator
+        scheduler = scheduler_fn()
+        pipeline = pipeline_fn(unet, scheduler).to(torch_device)
+
+        # Device type MPS is not supported for torch.Generator() api.
+        if torch_device == "mps":
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        out_image = pipeline(
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+        ).images
+        sample_size = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size
+        assert out_image.shape == (1, *sample_size, 3)
+
+    def test_stable_diffusion_components(self):
         """Test that components property works correctly"""
-        unet = self.dummy_cond_unet
+        unet = self.dummy_cond_unet()
         scheduler = PNDMScheduler(skip_prk_steps=True)
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        vae = self.dummy_vae()
+        bert = self.dummy_text_encoder()
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
 
-        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0]
         init_image = Image.fromarray(np.uint8(image)).convert("RGB")
         mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
 
@@ -359,7 +381,7 @@ def test_components(self):
             text_encoder=bert,
             tokenizer=tokenizer,
             safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=self.dummy_extractor(),
         ).to(torch_device)
         img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
         text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
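A side note on the `functools.partial` usage in the parameterized cases above: `test_uncond_unet_components` calls `scheduler_fn()` with no arguments, so any scheduler-specific kwargs (here DDPM's `predict_epsilon`) have to be bound ahead of time. A minimal sketch of that pattern:

from functools import partial

from diffusers import DDIMScheduler, DDPMScheduler

make_ddim = DDIMScheduler  # DDIM doesn't take `predict_epsilon`
make_ddpm = partial(DDPMScheduler, predict_epsilon=True)  # bind the DDPM-only kwarg up front

# Both now share the zero-argument call signature the test relies on.
assert isinstance(make_ddim(), DDIMScheduler)
assert isinstance(make_ddpm(), DDPMScheduler)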