Unet sample_size not fully considered in pipeline init #10139

@Foundsheep

Description

Describe the bug

Background

  • I'm trying to train a small unet from scratch
  • For this, I used a modified version of train_text_to_image.py
  • The unet model's sample_size is not symmetric: I passed [60, 80] when instantiating UNet2DConditionModel, whereas it is normally a single int such as 512

What happens

  • An error occurs in the validation step, in log_validation()
  • It happens because the pipeline loaded in that method (here, StableDiffusionPipeline) expects unet.config.sample_size to be an int, not a list, when evaluating the deprecation_message check.
  • However, I think it should accept a list as well, since a list is already one of the documented ways to pass a value for UNet2DConditionModel's sample_size parameter.
  • The exact line where the error occurs is the code below
# in `pipeline_stable_diffusion.py` line 258
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64

Other things to mention

  • It seems that the same check is duplicated across many other files, such as the other pipeline_stable...py modules.
  • If you want, I can spend some time on this and make a PR
  • If so, I'm thinking of extracting that check and the related lines into a helper function and placing it somewhere like utils (the exact location would need to be decided by HF staff)

Reproduction

from diffusers import StableDiffusionPipeline, UNet2DConditionModel

# load pre-trained model
sd_repo = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(sd_repo, use_safetensors=True)

# prepare components
small_unet = UNet2DConditionModel(sample_size=[60, 80], block_out_channels=[32, 64, 128, 256])
vae = pipe.components["vae"]
text_encoder = pipe.components["text_encoder"]
tokenizer = pipe.components["tokenizer"]
scheduler = pipe.components["scheduler"]
safety_checker = None
feature_extractor = pipe.components["feature_extractor"]
image_encoder = pipe.components["image_encoder"]

# error occurs (the same happens when loading with .from_pretrained())
new_pipe = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=small_unet,
    scheduler=scheduler,
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
    image_encoder=image_encoder
)

Logs

[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/workspace/03_gen/generative_02/diffusers/examples/text_to_image/train_text_to_image_with_new_unet_from_scratch.py", line 1208, in <module>
[rank0]:     main()
[rank0]:   File "/root/workspace/03_gen/generative_02/diffusers/examples/text_to_image/train_text_to_image_with_new_unet_from_scratch.py", line 1143, in main
[rank0]:     log_validation(
[rank0]:   File "/root/workspace/03_gen/generative_02/diffusers/examples/text_to_image/train_text_to_image_with_new_unet_from_scratch.py", line 144, in log_validation
[rank0]:     pipeline = StableDiffusionPipeline.from_pretrained(
[rank0]:   File "/root/workspace/03_gen/generative_02/.venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/root/workspace/03_gen/generative_02/.venv/lib/python3.10/site-packages/diffusers/pipelines/pipeline_utils.py", line 948, in from_pretrained
[rank0]:     model = pipeline_class(**init_kwargs)
[rank0]:   File "/root/workspace/03_gen/generative_02/.venv/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 258, in __init__
[rank0]:     is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
[rank0]: TypeError: '<' not supported between instances of 'list' and 'int'


### System Info

- 🤗 Diffusers version: 0.32.0.dev0
- Platform: Linux-5.4.0-167-generic-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.10
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.26.2
- Transformers version: 4.46.3
- Accelerate version: 1.1.1
- PEFT version: 0.7.0
- Bitsandbytes version: not installed
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: GRID A100X-10C, 10240 MiB
GRID A100X-10C, 10240 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

### Who can help?

@yiyixuxu 
