Skip to content

Conversation

@99991
Copy link

@99991 99991 commented Feb 29, 2024

What does this PR do?

This PR changes

image_embeds.repeat(batch_size * num_images_per_prompt)

to

image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1)

PyTorch expects that the length of the shape of the tensor to be repeated (in this case: len((1, 1, 768)) = 3) is equal to the number of parameters passed to the repeat function.

This PR might be totally wrong, but there is no issue tab in this repository, so this is the next best thing. Anyway, embedding images with the prior network works now. Previously, it failed with the following error message.

File "~/lib/python3.10/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py", line 238, in encode_image
    image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt)
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor

Example for image embedding:

import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers.utils import load_image

device = "cuda"
num_images_per_prompt = 2

# Adjust those to your liking
prompt = "TODO_positive_prompt"
negative_prompt = "TODO_negative_prompt"
filename = "TODO_image.png"

image = load_image(filename).resize((1024, 1024))

prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.float32).to(device)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade",  torch_dtype=torch.float16).to(device)

encoded_images = prior(
    images=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=1024,
    width=1024,
    num_images_per_prompt=2,
    num_inference_steps=20
)

decoder_output = decoder(
    image_embeddings=encoded_images.image_embeddings.half(),
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    output_type="pil",
    num_inference_steps=10
).images

Before submitting

  • [no] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [no] Did you read the contributor guideline?
  • [no] Did you read our philosophy doc (important for complex PRs)?
  • [no] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [no] Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • [no] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kashif

@kashif
Copy link
Owner

kashif commented Feb 29, 2024

thanks!

@kashif kashif merged commit 1b171b6 into kashif:wuerstchen-v3 Feb 29, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants