6 changes: 3 additions & 3 deletions src/diffusers/models/unet_2d.py
@@ -43,8 +43,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
implements for all the model (such as downloading or saving, etc.)

Parameters:
-        sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
-            Input sample size.
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
@@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
-        sample_size: Optional[int] = None,
+        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 3,
out_channels: int = 3,
center_input_sample: bool = False,
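For context, a minimal usage sketch of what this signature change enables (illustration only, not part of the diff; it assumes the public `UNet2DModel` constructor and its registered config): `sample_size` may now be a single int for square inputs or a `(height, width)` tuple for rectangular ones.

# Illustration only, not part of the diff: the new `sample_size` forms.
from diffusers import UNet2DModel

square_unet = UNet2DModel(sample_size=32)       # square: 32x32
rect_unet = UNet2DModel(sample_size=(32, 64))   # rectangular: height=32, width=64

# Either form is stored on the model config via @register_to_config.
print(square_unet.config.sample_size, rect_unet.config.sample_size)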
3 changes: 2 additions & 1 deletion src/diffusers/models/unet_2d_condition.py
@@ -56,7 +56,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
implements for all the models (such as downloading or saving, etc.)

Parameters:
-        sample_size (`int`, *optional*): The size of the input sample.
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
6 changes: 5 additions & 1 deletion src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -89,7 +89,11 @@ def __call__(
generator = None

# Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.sample_size, int):
+            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        else:
+            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)

if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
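As a quick illustration of the branch added above (a sketch using example values, not part of the diff), the two cases produce these noise shapes:

# Illustration only: the shape logic above for int vs. tuple sample_size.
batch_size, in_channels = 1, 3

sample_size = 32                                  # int -> square latents/images
shape_int = (batch_size, in_channels, sample_size, sample_size)
assert shape_int == (1, 3, 32, 32)

sample_size = (32, 64)                            # tuple -> (height, width)
shape_tuple = (batch_size, in_channels, *sample_size)
assert shape_tuple == (1, 3, 32, 64)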
6 changes: 5 additions & 1 deletion src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -94,7 +94,11 @@ def __call__(
generator = None

# Sample gaussian noise to begin loop
-        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if isinstance(self.unet.sample_size, int):
+            image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        else:
+            image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)

if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
62 changes: 42 additions & 20 deletions tests/test_pipelines.py
@@ -18,6 +18,7 @@
import random
import tempfile
import unittest
+from functools import partial

import numpy as np
import torch
@@ -46,6 +47,7 @@
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
+from parameterized import parameterized
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

@@ -247,7 +249,6 @@ def test_load_pipeline_from_git(self):


class PipelineFastTests(unittest.TestCase):
-    @property
def dummy_image(self):
batch_size = 1
num_channels = 3
@@ -256,27 +257,25 @@ def dummy_image(self):
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
return image

-    @property
-    def dummy_uncond_unet(self):
+    def dummy_uncond_unet(self, sample_size=32):
torch.manual_seed(0)
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
return model

-    @property
-    def dummy_cond_unet(self):
+    def dummy_cond_unet(self, sample_size=32):
torch.manual_seed(0)
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -285,13 +284,12 @@ def dummy_cond_unet(self):
)
return model

-    @property
-    def dummy_cond_unet_inpaint(self):
+    def dummy_cond_unet_inpaint(self, sample_size=32):
torch.manual_seed(0)
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
-            sample_size=32,
+            sample_size=sample_size,
in_channels=9,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -300,7 +298,6 @@ def dummy_cond_unet_inpaint(self):
)
return model

-    @property
def dummy_vq_model(self):
torch.manual_seed(0)
model = VQModel(
@@ -313,7 +310,6 @@ def dummy_vq_model(self):
)
return model

-    @property
def dummy_vae(self):
torch.manual_seed(0)
model = AutoencoderKL(
@@ -326,7 +322,6 @@ def dummy_vae(self):
)
return model

-    @property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
@@ -342,7 +337,6 @@ def dummy_text_encoder(self):
)
return CLIPTextModel(config)

-    @property
def dummy_extractor(self):
def extract(*args, **kwargs):
class Out:
@@ -357,15 +351,43 @@ def to(self, device):

return extract

-    def test_components(self):
+    @parameterized.expand(
+        [
+            [DDIMScheduler, DDIMPipeline, 32],
+            [partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, 32],
+            [DDIMScheduler, DDIMPipeline, (32, 64)],
+            [partial(DDPMScheduler, predict_epsilon=True), DDPMPipeline, (64, 32)],
+        ]
+    )
+    def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32):
+        unet = self.dummy_uncond_unet(sample_size)
+        # DDIM doesn't take `predict_epsilon`, and DDPM requires it -- so using partial in parameterized decorator
+        scheduler = scheduler_fn()
+        pipeline = pipeline_fn(unet, scheduler).to(torch_device)
+
+        # Device type MPS is not supported for torch.Generator() api.
+        if torch_device == "mps":
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        out_image = pipeline(
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+        ).images
+        sample_size = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size
+        assert out_image.shape == (1, *sample_size, 3)
+
+    def test_stable_diffusion_components(self):
"""Test that components property works correctly"""
-        unet = self.dummy_cond_unet
+        unet = self.dummy_cond_unet()
scheduler = PNDMScheduler(skip_prk_steps=True)
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        vae = self.dummy_vae()
+        bert = self.dummy_text_encoder()
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

-        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))

@@ -377,7 +399,7 @@ def test_components(self):
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
-            feature_extractor=self.dummy_extractor,
+            feature_extractor=self.dummy_extractor(),
).to(torch_device)
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
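Putting it together, a hedged end-to-end sketch mirroring the new parameterized test (assumptions: the same dummy UNet configuration used in the test above and CPU execution; not part of the diff):

# End-to-end illustration of a rectangular sample_size, mirroring
# test_uncond_unet_components above. Illustration only, not part of the diff.
import torch
from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel

torch.manual_seed(0)
unet = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=(32, 64),  # rectangular (height, width) sample size
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
pipe = DDIMPipeline(unet, DDIMScheduler())

images = pipe(num_inference_steps=2, output_type="np").images
# Output is NHWC, so a (32, 64) sample size yields 32x64 images.
assert images.shape == (1, 32, 64, 3)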