Conversation

@patrickvonplaten (Contributor) commented Nov 23, 2023

What does this PR do?

@yiyixuxu given the limited time, I've made sure that the public API is correct and that the weights are correctly named. I've left a lot of TODOs in the code that should ideally be completed next week.

@patrickvonplaten changed the title from "finalize" to "[WIP][Kadinsky 3.0]" on Nov 23, 2023
@HuggingFaceDocBuilderDev commented Nov 23, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten changed the title from "[WIP][Kadinsky 3.0]" to "[Kandinsky 3.0] Add Kandinsky 3.0" on Nov 24, 2023
@patrickvonplaten (Contributor, Author)

This PR supersedes #5899. All credit goes to the original author @cene555. The PR was created because I couldn't adapt #5899, and given that the model has already been published, I wanted to make sure things were merged quickly. Hopefully that was ok 🙏

@patrickvonplaten changed the title from "[Kandinsky 3.0] Add Kandinsky 3.0" to "[@ Contributor cene555][Kandinsky 3.0] Add Kandinsky 3.0" on Nov 24, 2023
@patrickvonplaten changed the title from "[@ Contributor cene555][Kandinsky 3.0] Add Kandinsky 3.0" to "[@cene555][Kandinsky 3.0] Add Kandinsky 3.0" on Nov 24, 2023
@patrickvonplaten (Contributor, Author)

TODOs for next week (cc @yiyixuxu):

  • Address all the TODO statements in the code
  • Add better docs
  • Publish on discord etc...
  • Add tests for img2img
  • Rename the pipeline and model files
  • Clean up all the unet blocks and get rid of hard to read code
  • ...

@patrickvonplaten merged commit b978334 into main on Nov 24, 2023
@BenjaminBossan (Member) left a comment

Thanks for moving quickly on this. I don't have the appropriate context knowledge to judge the implementation and how well it conforms to existing implementations, so my comments are more about details.

Comment on lines +21 to +22
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:

Division result and remainder can be obtained in a single operation by using divmod.
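A minimal sketch of that refactor (the rounding-up step and the helper name are assumptions based on the quoted `//` and `%` lines, not the pipeline's actual code):

```python
def downscaled_height(height, scale_factor=8):
    # divmod returns (quotient, remainder) in one call, replacing the
    # separate // and % computations from the quoted snippet.
    new_height, remainder = divmod(height, scale_factor**2)
    if remainder != 0:
        new_height += 1  # round up when height is not an exact multiple
    return new_height
```

For example, `downscaled_height(1024)` gives 16 exactly, while `downscaled_height(1000)` rounds 15 remainder 40 up to 16.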


def process_embeds(self, embeddings, attention_mask, cut_context):
    if cut_context:
        embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0])

Could be made more efficient with torch.where.
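As a sketch of the `torch.where` variant (the tensor shapes here are illustrative, not the pipeline's real ones):

```python
import torch

embeddings = torch.randn(2, 4, 8)  # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 0, 0, 0]])

# Instead of the in-place masked assignment
#   embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0])
# torch.where selects per position in a single vectorized pass:
mask = attention_mask.unsqueeze(-1).bool()  # broadcast over the hidden dim
embeddings = torch.where(mask, embeddings, torch.zeros_like(embeddings))
```

Positions where the mask is 0 come out zeroed; the rest are kept unchanged.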

Encodes the prompt into text encoder hidden states.

Args:
    prompt (`str` or `List[str]`, *optional*):

One space too many

Args:
    prompt (`str` or `List[str]`, *optional*):
        prompt to be encoded
    device: (`torch.device`, *optional*):

The order here doesn't correspond to the order of the arguments.

if prompt is not None and negative_prompt is not None:
    if type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="

"type to" => "type as", remove "="

num_inference_steps (`int`, *optional*, defaults to 50):
    The number of denoising steps. More denoising steps usually lead to a higher quality image at the
    expense of slower inference.
timesteps (`List[int]`, *optional*):

Argument doesn't exist.

    The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
    The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):

Argument doesn't exist.

callback_steps (`int`, *optional*, defaults to 1):
    The frequency at which the `callback` function will be called. If not specified, the callback will be
    called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):

This argument and the next don't exist, but `latents` is missing.

Args:
    prompt (`str` or `List[str]`, *optional*):
        prompt to be encoded
    device: (`torch.device`, *optional*):

Again, not the same order as arguments.

    image = [image]
if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
    raise ValueError(
        f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"

There is a double space. `[type(i) for i in image]` could be a set instead to avoid a potentially very long error message.
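For illustration, the deduplicated message could look like this (plain Python values stand in for the PIL/tensor inputs; the variable names are hypothetical):

```python
image = [1, 2.0, "a", 3, "b"]  # stand-ins for unsupported input types

# A set of type names collapses repeats, so the error stays short even
# for long input lists:
unsupported = sorted({type(i).__name__ for i in image})
message = (
    f"Input is in incorrect format: {unsupported}. "
    "Currently, we only support PIL image and pytorch tensor"
)
```

The list comprehension would report one entry per element; the set reports each offending type once.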

@shauray8 mentioned this pull request on Nov 25, 2023
@sayakpaul mentioned this pull request on Nov 27, 2023
@@ -0,0 +1,98 @@
#!/usr/bin/env python3
import argparse
import fnmatch

This one's new!

@@ -0,0 +1,589 @@
import math

The naming should (read must) be changed: unet_2d_model_for_kandinsky3.py.

@@ -0,0 +1,589 @@
import math
from dataclasses import dataclass

TODO: missing license header.

Comment on lines +1 to +7
#!/usr/bin/env python3
import argparse
import fnmatch

from safetensors.torch import load_file

from diffusers import Kandinsky3UNet

Should be removed from here?

    attention_mask = attention_mask[:, :max_seq_length]
    return embeddings, attention_mask

@torch.no_grad()

Why are we introducing this decorator here? Apart from IF, no other pipeline does it.

        negative_attention_mask = None
    return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask

def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):

No "copy"?

    return new_height * scale_factor, new_width * scale_factor


class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):

TODO: documentation.

`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
"""
cut_context = True

Should be an arg and be defaulted to True.

    latents = latents * scheduler.init_noise_sigma
    return latents

def check_inputs(

No "copy"?

Comment on lines +436 to +450
if output_type not in ["pt", "np", "pil"]:
    raise ValueError(
        f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
    )

if output_type in ["np", "pil"]:
    image = image * 0.5 + 0.5
    image = image.clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()

if output_type == "pil":
    image = self.numpy_to_pil(image)

if not return_dict:
    return (image,)

It might make sense to use the image processor module here. Why are we not using it?

@yiyixuxu mentioned this pull request on Nov 27, 2023
@kashif deleted the kandinsky_30 branch on December 5, 2023 at 08:59
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* finalize

* finalize

* finalize

* add slow test

* add slow test

* add slow test

* Fix more

* add slow test

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* Better

* Fix more

* Fix more

* add slow test

* Add auto pipelines

* add slow test

* Add all

* add slow test

* add slow test

* add slow test

* add slow test

* add slow test

* Apply suggestions from code review

* add slow test

* add slow test
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024