
Make T2I-Adapter downscale padding match the UNet #5377

Closed
RyanJDick opened this issue Oct 12, 2023 · 7 comments · Fixed by #5435

@RyanJDick (Contributor)

Background

T2I-Adapter models produce residual maps that must match the corresponding UNet hidden state dimensions at each downscaling level. This works as expected when the dimensions of the T2I-Adapter's input image are evenly divisible by the T2I-Adapter's total_downscale_factor.

For example, a FullAdapter used with an SD1.5 UNet and a 1280 x 640 image produces feature maps with the following dimensions:

  • [1, 320, 80, 160]
  • [1, 640, 40, 80]
  • [1, 1280, 20, 40]
  • [1, 1280, 10, 20]
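
As a quick sanity check (my own arithmetic, not code from the adapter), these shapes follow from an initial /8 pixel-unshuffle plus a further /2 per level, and they only come out as whole numbers because 1280 and 640 are divisible by the total downscale factor of 64:

h, w = 640, 1280
for level, channels in enumerate([320, 640, 1280, 1280]):
    factor = 8 * 2 ** level  # /8 pixel-unshuffle, then /2 per additional level
    print([1, channels, h // factor, w // factor])
# [1, 320, 80, 160], [1, 640, 40, 80], [1, 1280, 20, 40], [1, 1280, 10, 20]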

Problem

When an input image's dimensions are not evenly divisible by the T2I-Adapter's total_downscale_factor, the T2I-Adapter produces residual maps with different dimensions than the UNet's hidden states.

For example, consider a FullAdapter used with an SD1.5 UNet and a 680 x 384 image. The UNet produces hidden states with the following dimensions:

  • [1, 320, 48, 85]
  • [1, 640, 24, 43]
  • [1, 1280, 12, 22]
  • [1, 1280, 6, 11]

Meanwhile, the T2I-Adapter produces residual maps with the following dimensions:

  • [1, 320, 48, 85]
  • [1, 640, 24, 42]
  • [1, 1280, 12, 21]
  • [1, 1280, 6, 10]

The UNet and T2I-Adapter dimensions do not match, resulting in a torch exception.

Why should we solve this?

It would be nice for T2I-Adapters to work with any input latent noise dimension that is supported by the UNet. For context, I am an InvokeAI contributor and would rather not restrict output image dimensions to multiples of 64 when a T2I-Adapter is applied.

Root cause

The differences in downscaled feature map dimensions between the UNet and T2I-Adapter are caused by differences in padding behavior between their downscaling operations.

The UNet downscales with a torch.nn.Conv2d layer (inside Downsample2D), initialized as follows:

torch.nn.Conv2d(..., kernel_size=3, stride=2, padding=1)

The T2I-Adapter downscales with a torch.nn.AvgPool2d layer (also inside Downsample2D), initialized as follows:

torch.nn.AvgPool2d(kernel_size=2, stride=2)
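
A minimal way to see the resulting shape mismatch on an odd dimension (an illustrative check of mine, not code from diffusers):

import torch

x = torch.randn(1, 320, 48, 85)  # the first-level feature map from the 680 x 384 example

unet_down = torch.nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)  # UNet-style downscale
adapter_down = torch.nn.AvgPool2d(kernel_size=2, stride=2)                 # T2I-Adapter-style downscale

print(unet_down(x).shape)     # torch.Size([1, 320, 24, 43])  -- ceil(85 / 2)
print(adapter_down(x).shape)  # torch.Size([1, 320, 24, 42])  -- floor(85 / 2)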

Proposed Solution

The T2I-Adapter can be modified to match the UNet's downscaling padding behavior by using ceil_mode=True:

torch.nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)

This change would have no impact on input images whose dimensions are already evenly divisible by the total downscale factor.
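
For example (again an illustrative check, not from the issue), the fixed pooling layer matches the conv output shape on the odd width and leaves even dimensions untouched:

import torch

pool_fixed = torch.nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)

print(pool_fixed(torch.randn(1, 320, 48, 85)).shape)   # torch.Size([1, 320, 24, 43]) -- now matches the UNet
print(pool_fixed(torch.randn(1, 320, 80, 160)).shape)  # torch.Size([1, 320, 40, 80]) -- unchanged for even dims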

Questions for the diffusers team:

  1. Do you agree that this is worth fixing? If yes, I'm happy to put up a PR.
  2. Are there any objections to the proposed solution?
  3. Should this be fixed within Downsample2D or just for T2I-Adapters? I'm leaning towards just fixing it for T2I-Adapters, because Downsample2D is used in many places and I'm wary of unintended side effects from changing it.

Aside: Alignment offset

(This issue is more relevant to the original model creators.)

The padding discrepancy points to an underlying alignment issue in the model that might be worth fixing.

In the current configuration, the T2I-Adapter feature maps are always offset from the UNet2D feature maps by half a pixel, measured in the coordinates of the previous feature map. It's difficult to estimate how much impact this actually has in the trained model, but it may be hindering the T2I-Adapter's ability to strictly adhere to the conditioning image's structure. As a rough calculation, in an SD1.5 T2I-Adapter the 0.5-pixel offset introduced by the last downscale operation corresponds to a 16-pixel offset in the original image (0.5-pixel offset x 32, the downscale factor at the second-lowest resolution). Not huge, but it might have an impact.

This alignment issue could be fixed in the T2I-Adapter by downscaling with a torch.nn.AvgPool2d layer initialized as follows (or by downscaling with a conv layer):

torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

This change would alter the characteristics of the T2I-Adapter feature maps, so it would warrant re-training the T2I-Adapters. To be clear, I'm just saying that this might be an interesting experiment: this configuration would have a blurring effect on the feature maps and may actually worsen performance.
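
For completeness, a quick shape check of this variant (again just my own sketch): AvgPool2d with kernel_size=3, stride=2, padding=1 produces ceil(n / 2) outputs exactly like the UNet's strided conv, so it would fix the dimension mismatch as well as the sampling-grid alignment:

import torch

pool_aligned = torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

print(pool_aligned(torch.randn(1, 320, 48, 85)).shape)  # torch.Size([1, 320, 24, 43])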

@sayakpaul (Member)

Cc @MC-E here as well.

@patrickvonplaten (Contributor)

@sayakpaul @yiyixuxu can we try to look into this ourselves and try to solve it?

@sayakpaul (Member)

Sorry for getting to this a bit late.

Excellent thread here, @RyanJDick! Thanks for detailing the problem!

From my side:

  • It's definitely worth fixing! It would then be interesting to see the impact of this fix on the generation quality of the outputs -- for example, is making the input image dimensions divisible (as currently required) better or worse than having the proposed fix?
  • Re:

Should this be fixed within Downsample2D or just for T2I-Adapters? I'm leaning towards just fixing it for T2I-Adapters, because Downsample2D is used in many places and I'm wary of unintended side effects from changing it.

💯 agree here.

Regarding "Aside: Alignment offset", you probably could build on top of our SDXL T2I Adapter training script and study the consequences. Let me know if that makes sense.

@RyanJDick (Contributor, Author)

It's definitely worth fixing! It would then be interesting to see the impact of this fix on the generation quality of the outputs -- for example, is making the input image dimensions divisible (as currently required) better or worse than having the proposed fix?

Sounds good! I'll put up a PR, including some experiment results to show how it affects generation.

Regarding "Aside: Alignment offset", you probably could build on top of our SDXL T2I Adapter training script and study the consequences. Let me know if that makes sense.

Cool. I won't get to this right away, but would be interested in giving it a try at some point. (I'd also be happy to see someone else give it a shot!)

@MC-E (Contributor)

MC-E commented Oct 16, 2023

Hi @RyanJDick, I tried your input but couldn't reproduce the error in diffusers (T2I-Adapter 1.4/1.5). Can you provide more details?

(Attached images: the saved sketch conditioning image and the generated output.)

import torch
from PIL import Image
from controlnet_aux import PidiNetDetector
import numpy as np
from diffusers import (
    T2IAdapter,
    StableDiffusionAdapterPipeline
)
# Load the source image and extract a binarized sketch conditioning image from it
image = Image.open('dog.png')
processor = PidiNetDetector.from_pretrained('lllyasviel/Annotators')
sketch_image = processor(image).resize((680, 384)).convert('L')
tensor_img = (torch.from_numpy(np.array(sketch_image).astype(np.float32) / 255.) > 0.5).float()
sketch_image = tensor_img.numpy()
sketch_image = (sketch_image * 255).astype(np.uint8)
sketch_image = Image.fromarray(sketch_image)
sketch_image.save('sketch.png')

# Run the SD1.5 sketch adapter pipeline on the 680 x 384 conditioning image
adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_sketch_sd15v2", torch_dtype=torch.float16)
pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", adapter=adapter, safety_checker=None, torch_dtype=torch.float16, variant="fp16"
)
pipe.to('cuda')
generator = torch.Generator().manual_seed(0)
sketch_image_out = pipe(prompt="a dog", image=sketch_image, generator=generator).images[0]
sketch_image_out.save('sketch_image_out.png')

@RyanJDick (Contributor, Author)

Hi @RyanJDick, I tried your input but couldn't reproduce the error in diffusers (T2I-Adapter 1.4/1.5). Can you provide more details?

The StableDiffusionAdapterPipeline has a workaround that rounds the height/width down so that they work:

def _default_height_width(self, height, width, image):
    # NOTE: It is possible that a list of images have different
    # dimensions for each image, so just checking the first image
    # is not _exactly_ correct, but it is simple.
    while isinstance(image, list):
        image = image[0]

    if height is None:
        if isinstance(image, PIL.Image.Image):
            height = image.height
        elif isinstance(image, torch.Tensor):
            height = image.shape[-2]

        # round down to nearest multiple of `self.adapter.total_downscale_factor`
        height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor

    if width is None:
        if isinstance(image, PIL.Image.Image):
            width = image.width
        elif isinstance(image, torch.Tensor):
            width = image.shape[-1]

        # round down to nearest multiple of `self.adapter.total_downscale_factor`
        width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor

    return height, width
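
Concretely (my own arithmetic, assuming a total downscale factor of 64 for the SD1.5 adapter), the 680 x 384 input is silently rounded down before it reaches the adapter, which is why the pipeline call above doesn't hit the mismatch:

total_downscale_factor = 64  # assumed value for the SD1.5 T2I-Adapter
width, height = 680, 384
print((width // total_downscale_factor) * total_downscale_factor)   # 640
print((height // total_downscale_factor) * total_downscale_factor)  # 384
# the adapter effectively sees a 640 x 384 image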

So, this is only a problem when calling the model directly. I'll include a unit test that triggers this case in my PR.

@MC-E (Contributor)

MC-E commented Oct 17, 2023

Okay:) thanks for your contribution!
