Make T2I-Adapter downscale padding match the UNet (#5435)
* Update get_dummy_inputs(...) in T2I-Adapter tests to take image height and width as params.

* Update the T2I-Adapter unit tests to run with the standard number of UNet down blocks so that all T2I-Adapter down blocks get exercised.

* Update the T2I-Adapter down blocks to better match the padding behavior of the UNet.

* Revert "Update the T2I-Adapter unit tests to run with the standard number of UNet down blocks so that all T2I-Adapter down blocks get exercised."

This reverts commit 6d4a060.

* Create utility functions for testing the T2I-Adapter downscaling behavior.

* (minor) Improve readability with an intermediate named variable.

* Statically parameterize T2I-Adapter test dimensions rather than generating them dynamically.

* Fix static checks.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
RyanJDick and sayakpaul committed Oct 23, 2023
1 parent bc7a4d4 commit 0eac9cd
Showing 5 changed files with 312 additions and 34 deletions.
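At the heart of the change: the UNet downsamples with stride-2, padding-1 convolutions, which round odd feature-map sizes up, while the adapter's old `Downsample2D` blocks defaulted to a floor-rounding average pool, so the two feature pyramids fell out of alignment whenever an intermediate resolution was odd. A minimal sketch of the mismatch and the fix (illustrative shapes only, not part of the commit):

    import torch
    import torch.nn as nn

    x = torch.randn(1, 4, 33, 33)  # odd spatial size, e.g. a 264px image after x8 pixel unshuffle

    # UNet-style downsampling: a stride-2 conv with padding=1 rounds 33 up to 17.
    unet_down = nn.Conv2d(4, 4, kernel_size=3, stride=2, padding=1)
    print(unet_down(x).shape)  # torch.Size([1, 4, 17, 17])

    # A plain 2x2 average pool rounds down to 16, so adapter and UNet features
    # stop lining up at the next resolution level.
    print(nn.AvgPool2d(kernel_size=2, stride=2)(x).shape)  # torch.Size([1, 4, 16, 16])

    # ceil_mode=True pads the bottom/right edge so the pool also rounds up to 17,
    # which is what this commit switches the adapter down blocks to.
    print(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)(x).shape)  # torch.Size([1, 4, 17, 17])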
42 changes: 26 additions & 16 deletions src/diffusers/models/adapter.py
@@ -20,7 +20,6 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import logging
 from .modeling_utils import ModelMixin
-from .resnet import Downsample2D


 logger = logging.get_logger(__name__)
@@ -51,24 +50,28 @@ def __init__(self, adapters: List["T2IAdapter"]):
        if len(adapters) == 1:
            raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")

-        # The outputs from each adapter are added together with a weight
-        # This means that the change in dimenstions from downsampling must
-        # be the same for all adapters. Inductively, it also means the total
-        # downscale factor must also be the same for all adapters.
-
+        # The outputs from each adapter are added together with a weight.
+        # This means that the change in dimensions from downsampling must
+        # be the same for all adapters. Inductively, it also means the
+        # downscale_factor and total_downscale_factor must be the same for all
+        # adapters.
        first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
-
+        first_adapter_downscale_factor = adapters[0].downscale_factor
        for idx in range(1, len(adapters)):
-            adapter_idx_total_downscale_factor = adapters[idx].total_downscale_factor
-
-            if adapter_idx_total_downscale_factor != first_adapter_total_downscale_factor:
+            if (
+                adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
+                or adapters[idx].downscale_factor != first_adapter_downscale_factor
+            ):
                raise ValueError(
-                    f"Expecting all adapters to have the same total_downscale_factor, "
-                    f"but got adapters[0].total_downscale_factor={first_adapter_total_downscale_factor} and "
-                    f"adapter[`{idx}`]={adapter_idx_total_downscale_factor}"
+                    f"Expecting all adapters to have the same downscaling behavior, but got:\n"
+                    f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
+                    f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
+                    f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
+                    f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
                )

-        self.total_downscale_factor = adapters[0].total_downscale_factor
+        self.total_downscale_factor = first_adapter_total_downscale_factor
+        self.downscale_factor = first_adapter_downscale_factor
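To illustrate the tightened check, a hypothetical sketch (not part of the diff) of two adapters whose factors disagree being rejected at construction time:

    from diffusers import MultiAdapter, T2IAdapter

    # Small channel widths chosen purely for illustration.
    a = T2IAdapter(in_channels=3, channels=[32, 64], num_res_blocks=1, downscale_factor=8)
    b = T2IAdapter(in_channels=3, channels=[32, 64], num_res_blocks=1, downscale_factor=16)

    try:
        MultiAdapter([a, b])  # mismatched downscale factors
    except ValueError as err:
        print(err)  # "Expecting all adapters to have the same downscaling behavior, but got: ..."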

def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
r"""
@@ -274,6 +277,13 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
    def total_downscale_factor(self):
        return self.adapter.total_downscale_factor

+    @property
+    def downscale_factor(self):
+        """The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are
+        not evenly divisible by the downscale_factor then an exception will be raised.
+        """
+        return self.adapter.unshuffle.downscale_factor


# full adapter

Expand Down Expand Up @@ -399,7 +409,7 @@ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, dow

        self.downsample = None
        if down:
-            self.downsample = Downsample2D(in_channels)
+            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.in_conv = None
        if in_channels != out_channels:
@@ -526,7 +536,7 @@ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, dow

        self.downsample = None
        if down:
-            self.downsample = Downsample2D(in_channels)
+            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
        self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
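The new `downscale_factor` property above reports the factor of the adapter's initial `nn.PixelUnshuffle` layer, which is where the divisibility requirement mentioned in its docstring comes from. A standalone illustration (not part of the diff):

    import torch
    import torch.nn as nn

    unshuffle = nn.PixelUnshuffle(downscale_factor=8)
    print(unshuffle(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 192, 8, 8])

    try:
        unshuffle(torch.randn(1, 3, 60, 60))  # 60 is not divisible by 8
    except RuntimeError as err:
        print(err)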
@@ -568,17 +568,17 @@ def _default_height_width(self, height, width, image):
        elif isinstance(image, torch.Tensor):
            height = image.shape[-2]

-            # round down to nearest multiple of `self.adapter.total_downscale_factor`
-            height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
+            # round down to nearest multiple of `self.adapter.downscale_factor`
+            height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor

        if width is None:
            if isinstance(image, PIL.Image.Image):
                width = image.width
            elif isinstance(image, torch.Tensor):
                width = image.shape[-1]

-            # round down to nearest multiple of `self.adapter.total_downscale_factor`
-            width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
+            # round down to nearest multiple of `self.adapter.downscale_factor`
+            width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor

        return height, width

@@ -622,17 +622,17 @@ def _default_height_width(self, height, width, image):
        elif isinstance(image, torch.Tensor):
            height = image.shape[-2]

-            # round down to nearest multiple of `self.adapter.total_downscale_factor`
-            height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
+            # round down to nearest multiple of `self.adapter.downscale_factor`
+            height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor

        if width is None:
            if isinstance(image, PIL.Image.Image):
                width = image.width
            elif isinstance(image, torch.Tensor):
                width = image.shape[-1]

-            # round down to nearest multiple of `self.adapter.total_downscale_factor`
-            width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
+            # round down to nearest multiple of `self.adapter.downscale_factor`
+            width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor

        return height, width
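Both pipeline hunks (this one and the matching one above) relax the default height/width rounding from the coarse `total_downscale_factor` to the adapter's `downscale_factor`; the ceil-mode pooling inside the adapter now absorbs the rest of the downscaling. Worked arithmetic, assuming `downscale_factor=8` and `total_downscale_factor=64`:

    height = 500
    print((height // 64) * 64)  # 448 -- old behavior, crops aggressively
    print((height // 8) * 8)    # 496 -- new behavior, keeps all but 4 pixels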

140 changes: 135 additions & 5 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
@@ -19,6 +19,7 @@

import numpy as np
import torch
+from parameterized import parameterized
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

import diffusers
@@ -137,11 +138,100 @@ def get_dummy_components(self, adapter_type):
}
return components

-    def get_dummy_inputs(self, device, seed=0, num_images=1):
+    def get_dummy_components_with_full_downscaling(self, adapter_type):
+        """Get dummy components with x8 VAE downscaling and 4 UNet down blocks.
+        These dummy components are intended to fully-exercise the T2I-Adapter
+        downscaling behavior.
+        """
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 32, 32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
+            up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+            cross_attention_dim=32,
+        )
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 32, 32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        torch.manual_seed(0)
+
+        if adapter_type == "full_adapter" or adapter_type == "light_adapter":
+            adapter = T2IAdapter(
+                in_channels=3,
+                channels=[32, 32, 32, 64],
+                num_res_blocks=2,
+                downscale_factor=8,
+                adapter_type=adapter_type,
+            )
+        elif adapter_type == "multi_adapter":
+            adapter = MultiAdapter(
+                [
+                    T2IAdapter(
+                        in_channels=3,
+                        channels=[32, 32, 32, 64],
+                        num_res_blocks=2,
+                        downscale_factor=8,
+                        adapter_type="full_adapter",
+                    ),
+                    T2IAdapter(
+                        in_channels=3,
+                        channels=[32, 32, 32, 64],
+                        num_res_blocks=2,
+                        downscale_factor=8,
+                        adapter_type="full_adapter",
+                    ),
+                ]
+            )
+        else:
+            raise ValueError(
+                f"Unknown adapter type: {adapter_type}, must be one of 'full_adapter', 'light_adapter', or 'multi_adapter''"
+            )
+
+        components = {
+            "adapter": adapter,
+            "unet": unet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "safety_checker": None,
+            "feature_extractor": None,
+        }
+        return components
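A quick sanity check of the dummy adapter above (hypothetical, not in the test file): with `downscale_factor=8` and four channel stages (three downsampling transitions), the factors come out to 8 and 8 * 2**3 = 64:

    from diffusers import T2IAdapter

    adapter = T2IAdapter(in_channels=3, channels=[32, 32, 32, 64], num_res_blocks=2, downscale_factor=8)
    assert adapter.downscale_factor == 8        # the initial pixel-unshuffle factor
    assert adapter.total_downscale_factor == 64  # x8 unshuffle, then three x2 down blocks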

+    def get_dummy_inputs(self, device, seed=0, height=64, width=64, num_images=1):
        if num_images == 1:
-            image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
+            image = floats_tensor((1, 3, height, width), rng=random.Random(seed)).to(device)
        else:
-            image = [floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) for _ in range(num_images)]
+            image = [
+                floats_tensor((1, 3, height, width), rng=random.Random(seed)).to(device) for _ in range(num_images)
+            ]

        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
@@ -170,11 +260,45 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)

+    @parameterized.expand(
+        [
+            # (dim=264) The internal feature map will be 33x33 after initial pixel unshuffling (downscaled x8).
+            ((4 * 8 + 1) * 8),
+            # (dim=272) The internal feature map will be 17x17 after the first T2I down block (downscaled x16).
+            ((4 * 4 + 1) * 16),
+            # (dim=288) The internal feature map will be 9x9 after the second T2I down block (downscaled x32).
+            ((4 * 2 + 1) * 32),
+            # (dim=320) The internal feature map will be 5x5 after the third T2I down block (downscaled x64).
+            ((4 * 1 + 1) * 64),
+        ]
+    )
+    def test_multiple_image_dimensions(self, dim):
+        """Test that the T2I-Adapter pipeline supports any input dimension that
+        is divisible by the adapter's `downscale_factor`. This test was added in
+        response to an issue where the T2I Adapter's downscaling padding
+        behavior did not match the UNet's behavior.
+
+        Note that we have selected `dim` values to produce odd resolutions at
+        each downscaling level.
+        """
+        components = self.get_dummy_components_with_full_downscaling()
+        sd_pipe = StableDiffusionAdapterPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device, height=dim, width=dim)
+        image = sd_pipe(**inputs).images
+
+        assert image.shape == (1, dim, dim, 3)
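The `dim` values above are constructed so that every downscaling stage sees an odd feature map. A hypothetical helper (not in the test file) tracing the sizes, assuming the x8 pixel unshuffle followed by three ceil-mode x2 poolings:

    import math

    def trace_feature_sizes(dim, downscale_factor=8, num_down_blocks=3):
        sizes = [dim // downscale_factor]  # initial pixel unshuffle
        for _ in range(num_down_blocks):
            sizes.append(math.ceil(sizes[-1] / 2))  # AvgPool2d(..., ceil_mode=True)
        return sizes

    print(trace_feature_sizes(264))  # [33, 17, 9, 5]
    print(trace_feature_sizes(272))  # [34, 17, 9, 5]
    print(trace_feature_sizes(288))  # [36, 18, 9, 5]
    print(trace_feature_sizes(320))  # [40, 20, 10, 5]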


class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self):
return super().get_dummy_components("full_adapter")

+    def get_dummy_components_with_full_downscaling(self):
+        return super().get_dummy_components_with_full_downscaling("full_adapter")

def test_stable_diffusion_adapter_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -195,6 +319,9 @@ class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterM
def get_dummy_components(self):
return super().get_dummy_components("light_adapter")

+    def get_dummy_components_with_full_downscaling(self):
+        return super().get_dummy_components_with_full_downscaling("light_adapter")

def test_stable_diffusion_adapter_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -215,8 +342,11 @@ class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterM
def get_dummy_components(self):
return super().get_dummy_components("multi_adapter")

-    def get_dummy_inputs(self, device, seed=0):
-        inputs = super().get_dummy_inputs(device, seed, num_images=2)
+    def get_dummy_components_with_full_downscaling(self):
+        return super().get_dummy_components_with_full_downscaling("multi_adapter")
+
+    def get_dummy_inputs(self, device, height=64, width=64, seed=0):
+        inputs = super().get_dummy_inputs(device, seed, height=height, width=width, num_images=2)
inputs["adapter_conditioning_scale"] = [0.5, 0.5]
return inputs

Expand Down
Loading

0 comments on commit 0eac9cd

Please sign in to comment.