
Make T2I-Adapter downscale padding match the UNet #5435

Merged · 9 commits · Oct 23, 2023
42 changes: 26 additions & 16 deletions src/diffusers/models/adapter.py
@@ -20,7 +20,6 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .modeling_utils import ModelMixin
from .resnet import Downsample2D


logger = logging.get_logger(__name__)
@@ -51,24 +50,28 @@ def __init__(self, adapters: List["T2IAdapter"]):
if len(adapters) == 1:
raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")

# The outputs from each adapter are added together with a weight
# This means that the change in dimenstions from downsampling must
# be the same for all adapters. Inductively, it also means the total
# downscale factor must also be the same for all adapters.

# The outputs from each adapter are added together with a weight.
# This means that the change in dimensions from downsampling must
# be the same for all adapters. Inductively, it also means the
# downscale_factor and total_downscale_factor must be the same for all
# adapters.
first_adapter_total_downscale_factor = adapters[0].total_downscale_factor

first_adapter_downscale_factor = adapters[0].downscale_factor
for idx in range(1, len(adapters)):
adapter_idx_total_downscale_factor = adapters[idx].total_downscale_factor

if adapter_idx_total_downscale_factor != first_adapter_total_downscale_factor:
if (
adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
or adapters[idx].downscale_factor != first_adapter_downscale_factor
):
raise ValueError(
f"Expecting all adapters to have the same total_downscale_factor, "
f"but got adapters[0].total_downscale_factor={first_adapter_total_downscale_factor} and "
f"adapter[`{idx}`]={adapter_idx_total_downscale_factor}"
f"Expecting all adapters to have the same downscaling behavior, but got:\n"
f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
)

self.total_downscale_factor = adapters[0].total_downscale_factor
self.total_downscale_factor = first_adapter_total_downscale_factor
self.downscale_factor = first_adapter_downscale_factor

def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
r"""
@@ -274,6 +277,13 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
def total_downscale_factor(self):
return self.adapter.total_downscale_factor

@property
def downscale_factor(self):
"""The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are
not evenly divisible by the downscale_factor then an exception will be raised.
"""
return self.adapter.unshuffle.downscale_factor
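
For illustration, a minimal sketch (not part of the PR) of the divisibility requirement this property documents, assuming the adapter's `unshuffle` layer is a standard `torch.nn.PixelUnshuffle`:

```python
import torch
import torch.nn as nn

# Standalone sketch: PixelUnshuffle trades spatial resolution for channels.
unshuffle = nn.PixelUnshuffle(downscale_factor=8)

x = torch.randn(1, 3, 264, 264)  # 264 is divisible by 8
print(unshuffle(x).shape)        # torch.Size([1, 192, 33, 33])

# A 260x260 input is not divisible by 8, so unshuffle would raise a
# RuntimeError rather than silently padding or cropping.
```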


# full adapter

@@ -399,7 +409,7 @@ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, dow

self.downsample = None
if down:
self.downsample = Downsample2D(in_channels)
self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)

Note to reviewers:

Downsample2D(in_channels) is really just a nn.AvgPool2d(kernel_size=2, stride=2) layer. The only real change here is setting ceil_mode=True.
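
To make the effect concrete, a minimal sketch (not part of the PR) comparing the two pooling modes against a UNet-style stride-2 conv on an odd-sized feature map:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 32, 33, 33)  # odd spatial size, as the new tests exercise

floor_pool = nn.AvgPool2d(kernel_size=2, stride=2)                 # old behavior
ceil_pool = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)  # this PR
unet_conv = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)  # UNet-style downsample

print(floor_pool(x).shape)  # torch.Size([1, 32, 16, 16]) -- one short of the UNet
print(ceil_pool(x).shape)   # torch.Size([1, 32, 17, 17]) -- matches the UNet
print(unet_conv(x).shape)   # torch.Size([1, 32, 17, 17])
```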


self.in_conv = None
if in_channels != out_channels:
@@ -526,7 +536,7 @@ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, dow

self.downsample = None
if down:
self.downsample = Downsample2D(in_channels)
self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)

self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
@@ -568,17 +568,17 @@ def _default_height_width(self, height, width, image):
elif isinstance(image, torch.Tensor):
height = image.shape[-2]

# round down to nearest multiple of `self.adapter.total_downscale_factor`
height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
# round down to nearest multiple of `self.adapter.downscale_factor`
height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor

if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[-1]

# round down to nearest multiple of `self.adapter.total_downscale_factor`
width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
# round down to nearest multiple of `self.adapter.downscale_factor`
width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor

return height, width
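
For context on the rounding change above, a small numeric sketch with illustrative values (downscale_factor=8 and total_downscale_factor=64, i.e. an x8 pixel unshuffle followed by three x2 down blocks):

```python
height = 520

# New behavior: only require divisibility by the pixel-unshuffle factor; the
# ceil-mode down blocks now handle odd intermediate sizes like the UNet does.
print((height // 8) * 8)    # 520

# Old behavior: rounding to the total downscale factor discarded more resolution.
print((height // 64) * 64)  # 512
```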

@@ -622,17 +622,17 @@ def _default_height_width(self, height, width, image):
elif isinstance(image, torch.Tensor):
height = image.shape[-2]

# round down to nearest multiple of `self.adapter.total_downscale_factor`
height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
# round down to nearest multiple of `self.adapter.downscale_factor`
height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor

if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[-1]

# round down to nearest multiple of `self.adapter.total_downscale_factor`
width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
# round down to nearest multiple of `self.adapter.downscale_factor`
width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor

return height, width

140 changes: 135 additions & 5 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
@@ -19,6 +19,7 @@

import numpy as np
import torch
from parameterized import parameterized
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

import diffusers
@@ -137,11 +138,100 @@ def get_dummy_components(self, adapter_type):
}
return components

def get_dummy_inputs(self, device, seed=0, num_images=1):
def get_dummy_components_with_full_downscaling(self, adapter_type):
"""Get dummy components with x8 VAE downscaling and 4 UNet down blocks.
These dummy components are intended to fully exercise the T2I-Adapter
downscaling behavior.
"""
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 32, 32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
cross_attention_dim=32,
)
scheduler = PNDMScheduler(skip_prk_steps=True)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 32, 32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

torch.manual_seed(0)

if adapter_type == "full_adapter" or adapter_type == "light_adapter":
adapter = T2IAdapter(
in_channels=3,
channels=[32, 32, 32, 64],
num_res_blocks=2,
downscale_factor=8,
adapter_type=adapter_type,
)
elif adapter_type == "multi_adapter":
adapter = MultiAdapter(
[
T2IAdapter(
in_channels=3,
channels=[32, 32, 32, 64],
num_res_blocks=2,
downscale_factor=8,
adapter_type="full_adapter",
),
T2IAdapter(
in_channels=3,
channels=[32, 32, 32, 64],
num_res_blocks=2,
downscale_factor=8,
adapter_type="full_adapter",
),
]
)
else:
raise ValueError(
f"Unknown adapter type: {adapter_type}, must be one of 'full_adapter', 'light_adapter', or 'multi_adapter''"
)

components = {
"adapter": adapter,
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
}
return components

def get_dummy_inputs(self, device, seed=0, height=64, width=64, num_images=1):
if num_images == 1:
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
image = floats_tensor((1, 3, height, width), rng=random.Random(seed)).to(device)
else:
image = [floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) for _ in range(num_images)]
image = [
floats_tensor((1, 3, height, width), rng=random.Random(seed)).to(device) for _ in range(num_images)
]

if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
@@ -170,11 +260,45 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)

@parameterized.expand(
[
# (dim=264) The internal feature map will be 33x33 after initial pixel unshuffling (downscaled x8).
((4 * 8 + 1) * 8),
# (dim=272) The internal feature map will be 17x17 after the first T2I down block (downscaled x16).
((4 * 4 + 1) * 16),
# (dim=288) The internal feature map will be 9x9 after the second T2I down block (downscaled x32).
((4 * 2 + 1) * 32),
# (dim=320) The internal feature map will be 5x5 after the third T2I down block (downscaled x64).
((4 * 1 + 1) * 64),
]
)
def test_multiple_image_dimensions(self, dim):
"""Test that the T2I-Adapter pipeline supports any input dimension that
is divisible by the adapter's `downscale_factor`. This test was added in
response to an issue where the T2I-Adapter's downscaling padding
behavior did not match the UNet's behavior.

Note that we have selected `dim` values to produce odd resolutions at
each downscaling level.
"""
components = self.get_dummy_components_with_full_downscaling()
sd_pipe = StableDiffusionAdapterPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(torch_device, height=dim, width=dim)
image = sd_pipe(**inputs).images

assert image.shape == (1, dim, dim, 3)
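
As a quick arithmetic check of the `dim` choices above (a sketch assuming ceil-mode x2 halving at each of the three downsampling T2I blocks after the initial x8 unshuffle):

```python
import math

for dim in [264, 272, 288, 320]:
    sizes = [dim // 8]  # after the initial pixel unshuffle
    for _ in range(3):  # three ceil-mode x2 down blocks
        sizes.append(math.ceil(sizes[-1] / 2))
    print(dim, sizes)

# 264 [33, 17, 9, 5]   <- odd immediately after unshuffling
# 272 [34, 17, 9, 5]   <- odd after the first down block
# 288 [36, 18, 9, 5]   <- odd after the second down block
# 320 [40, 20, 10, 5]  <- odd after the third down block
```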


class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self):
return super().get_dummy_components("full_adapter")

def get_dummy_components_with_full_downscaling(self):
return super().get_dummy_components_with_full_downscaling("full_adapter")

def test_stable_diffusion_adapter_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -195,6 +319,9 @@ class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterM
def get_dummy_components(self):
return super().get_dummy_components("light_adapter")

def get_dummy_components_with_full_downscaling(self):
return super().get_dummy_components_with_full_downscaling("light_adapter")

def test_stable_diffusion_adapter_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -215,8 +342,11 @@ class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterM
def get_dummy_components(self):
return super().get_dummy_components("multi_adapter")

def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed, num_images=2)
def get_dummy_components_with_full_downscaling(self):
return super().get_dummy_components_with_full_downscaling("multi_adapter")

def get_dummy_inputs(self, device, height=64, width=64, seed=0):
inputs = super().get_dummy_inputs(device, seed, height=height, width=width, num_images=2)
inputs["adapter_conditioning_scale"] = [0.5, 0.5]
return inputs
