@yiyixuxu
Collaborator

@yiyixuxu yiyixuxu commented Jul 4, 2023

🚨🚨🚨 Note: The main author of this PR is @cene555 and the Kandinsky team. For simplicity the original PR was continued here. Thanks a mille for the contribution @cene555 🚨🚨🚨

Authors of this PR:
Arseniy Shakhmatov
Anton Razzhigaev
Aleksandr Nikolich
Igor Pavlov
Andrey Kuznetsov
Denis Dimitrov

finishing up #3903

To-do:

  • add tests for text2img, img2img, inpaint, prior
  • test controlnet + prior_emb2emb
  • add doc

import numpy as np
import torch
from transformers import pipeline

from diffusers import KandinskyV22ControlnetImg2ImgPipeline, KandinskyV22PriorEmb2EmbPipeline
from diffusers.utils import load_image

def make_hint(image, depth_estimator):
    # Estimate depth, replicate it to 3 channels, and return a (C, H, W) tensor in [0, 1].
    image = depth_estimator(image)["depth"]
    image = np.array(image)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    detected_map = torch.from_numpy(image).float() / 255.0
    hint = detected_map.permute(2, 0, 1)
    return hint

depth_estimator = pipeline('depth-estimation')

pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
)
pipe_prior = pipe_prior.to("cuda")

pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

img = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/cat.png"
).resize((768, 768))

hint = make_hint(img, depth_estimator).unsqueeze(0).half().to('cuda')

prompt = 'A robot, 4k photo'
negative_prior_prompt = 'lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature'

generator = torch.Generator(device='cuda').manual_seed(43)

# run prior pipeline: one pass for the prompt, one for the negative prompt
img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator)
negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator)

# run controlnet img2img pipeline
images = pipe(
    image=img,
    strength=0.5,
    image_embeds=img_emb.image_embeds,
    negative_image_embeds=negative_emb.image_embeds,
    hint=hint,
    num_inference_steps=50,
    generator=generator,
    height=768,
    width=768,
).images

images[0].save("robot_cat.png")
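
The hint layout produced by `make_hint` can be checked without downloading the depth model. The following is a minimal sketch in NumPy only, assuming the depth estimator returns a single-channel 8-bit depth map; `make_hint_array` and the synthetic gradient input are hypothetical stand-ins, not part of this PR.

```python
import numpy as np


def make_hint_array(depth_map: np.ndarray) -> np.ndarray:
    """Turn an (H, W) uint8 depth map into a (1, 3, H, W) float32 hint in [0, 1]."""
    stacked = np.stack([depth_map] * 3, axis=-1)  # (H, W, 3): depth repeated per channel
    scaled = stacked.astype(np.float32) / 255.0   # map 0..255 to 0..1
    chw = scaled.transpose(2, 0, 1)               # (3, H, W): channels first
    return chw[None, ...]                         # (1, 3, H, W): add a batch dimension


# Synthetic horizontal gradient standing in for depth_estimator(image)["depth"] (assumption).
depth = np.tile(np.arange(256, dtype=np.uint8), (256, 1))
hint = make_hint_array(depth)
print(hint.shape, hint.dtype)
```

The batch dimension here plays the role of `.unsqueeze(0)` in the example above; the real pipeline call additionally casts the hint to half precision and moves it to the GPU.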

Input image: cat

Output image: robot_cat
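
The `strength` and `num_inference_steps` arguments in the example interact. The sketch below assumes the rule commonly used by diffusers img2img-style pipelines, where roughly `int(num_inference_steps * strength)` denoising steps are actually run; `effective_steps` is a hypothetical helper name, and the pipelines in this PR may differ in detail.

```python
def effective_steps(num_inference_steps: int, strength: float) -> int:
    """Number of denoising steps run under the assumed img2img convention."""
    # Steps to run, capped at the full schedule length.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    # Index of the first scheduled timestep that is actually used.
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start


# Values from the controlnet img2img call above: 50 steps at strength 0.5.
print(effective_steps(50, 0.5))
```

Under this assumption, a strength of 1 runs the full schedule (the input image or embedding is fully re-noised), while lower values keep more of the input.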

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jul 4, 2023

The documentation is not available anymore as the PR was closed or merged.

@yiyixuxu
Collaborator Author

yiyixuxu commented Jul 6, 2023

@patrickvonplaten

finished all the to-dos from you

  • Remove the _decoder suffix from all file names to shorten the file name
  • Rename self.vae to self.movq since a MoVQ is used here again
  • Add copied from statements whenever it makes sense and rename the get_new_h_w better as explained above
  • Give Kandinsky 2.2 its own section in the docs

will send PR to the repo to change vae -> movq and then model cards maybe

Member

@pcuenca pcuenca left a comment

Very cool!

@patrickvonplaten patrickvonplaten merged commit 7462156 into main Jul 6, 2023
@patrickvonplaten patrickvonplaten deleted the kandinsky22-yiyi branch July 6, 2023 13:17
@Lime-Cakes
Contributor

Thanks for the great work! Though, it seems that the model weights for v2.2 aren't released yet? "kandinsky-community/kandinsky-2-2-controlnet-depth" can't be found, so the examples can't be run atm.

@patrickvonplaten
Contributor

See: https://huggingface.co/docs/diffusers/v0.18.2/en/api/pipelines/kandinsky#kandinsky-22 - it was open-sourced today!

@Lime-Cakes
Contributor

Thanks! I see it now! Looks great.

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Kandinsky2_2

* fix init kandinsky2_2

* kandinsky2_2 fix inpainting

* rename pipelines: remove decoder + 2_2 -> V22

* Update scheduling_unclip.py

* remove text_encoder and tokenizer arguments from doc string

* add test for text2img

* add tests for text2img & img2img

* fix

* add test for inpaint

* add prior tests

* style

* copies

* add controlnet test

* style

* add a test for controlnet_img2img

* update prior_emb2emb api to accept image_embedding or image

* add a test for prior_emb2emb

* style

* remove try except

* example

* fix

* add doc string examples to all kandinsky pipelines

* style

* update doc

* style

* add a top about 2.2

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* vae -> movq

* vae -> movq

* style

* fix the #copied from

* remove decoder from file name

* update doc: add a section for kandinsky 2.2

* fix

* fix-copies

* add coped from

* add copies from for prior

* add copies from for prior emb2emb

* copy from for img2img

* copied from for inpaint

* more copied from

* more copies from

* more copies

* remove the yiyi comments

* Apply suggestions from code review

* Self-contained example, pipeline order

* Import prior output instead of redefining.

* Style

* Make VQModel compatible with model offload.

* Fix copies

---------

Co-authored-by: Shahmatov Arseniy <62886550+cene555@users.noreply.github.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024