[@cene555][Kandinsky 3.0] Add Kandinsky 3.0 #5913
Conversation
The documentation is not available anymore as the PR was closed or merged.
TODOs for next week (cc @yiyixuxu):
Thanks for moving quickly on this. I don't have the context to judge the implementation or how well it conforms to existing implementations, so my comments are mostly about details.
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
Division result and remainder can be obtained in a single operation by using `divmod`.
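The reviewer's suggestion could look roughly like this; a sketch assuming the same round-up intent as the snippet above (the `scale_factor` value is a placeholder, not taken from the PR):

```python
scale_factor = 8  # hypothetical value; the pipeline derives this from its config
height = 1027

# divmod returns quotient and remainder in one call,
# replacing the separate // and % operations
new_height, remainder = divmod(height, scale_factor**2)
if remainder != 0:
    new_height += 1  # round up when height isn't an exact multiple
```

Besides being one operation instead of two, `divmod` makes it explicit that the quotient and the divisibility check refer to the same division.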
def process_embeds(self, embeddings, attention_mask, cut_context):
    if cut_context:
        embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0])
Could be made more efficient with `torch.where`.
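For illustration, the masked assignment above could be expressed with `torch.where`, which avoids materializing the boolean-indexed intermediate tensors; a sketch with assumed shapes (`embeddings` as `[batch, seq, dim]`, `attention_mask` as `[batch, seq]` — not taken from the PR):

```python
import torch

embeddings = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]])

# Broadcast the mask over the embedding dimension, then select per element:
# keep the embedding where the mask is 1, zero where it is 0
mask = attention_mask.unsqueeze(-1).bool()
embeddings = torch.where(mask, embeddings, torch.zeros_like(embeddings))
```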
Encodes the prompt into text encoder hidden states.

Args:
    prompt (`str` or `List[str]`, *optional*):
One space too many
Args:
    prompt (`str` or `List[str]`, *optional*):
        prompt to be encoded
    device: (`torch.device`, *optional*):
The order here doesn't correspond to the order of the arguments.
if prompt is not None and negative_prompt is not None:
    if type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
"type to" => "type as", remove "="
num_inference_steps (`int`, *optional*, defaults to 50):
    The number of denoising steps. More denoising steps usually lead to a higher quality image at the
    expense of slower inference.
timesteps (`List[int]`, *optional*):
Argument doesn't exist.
    The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
    The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Argument doesn't exist.
callback_steps (`int`, *optional*, defaults to 1):
    The frequency at which the `callback` function will be called. If not specified, the callback will be
    called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
This and the next argument don't exist, but `latents` is missing.
Args:
    prompt (`str` or `List[str]`, *optional*):
        prompt to be encoded
    device: (`torch.device`, *optional*):
Again, not the same order as arguments.
image = [image]
if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
    raise ValueError(
        f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
There is a double space. `[type(i) for i in image]` could be a set instead to avoid a potentially very long error message.
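A set collapses repeated types, so the message stays short however long the input list is; a sketch using plain built-in types as stand-ins for the PIL/tensor inputs:

```python
# Stand-ins for mixed invalid inputs; in the pipeline these would be
# non-PIL, non-tensor objects
image = [1, "a", 2, "b", 3, "c"]

# A list comprehension would repeat <class 'int'> and <class 'str'> many times;
# a set comprehension reports each offending type once
bad_types = {type(i).__name__ for i in image}
message = (
    f"Input is in incorrect format: {sorted(bad_types)}. "
    "Currently, we only support PIL image and pytorch tensor"
)
```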
@@ -0,0 +1,98 @@
#!/usr/bin/env python3
import argparse
import fnmatch
This one's new!
@@ -0,0 +1,589 @@
import math
The naming should (read: must) be changed: `unet_2d_model_for_kandinsky3.py`.
@@ -0,0 +1,589 @@
import math
from dataclasses import dataclass
TODO: missing license header.
#!/usr/bin/env python3
import argparse
import fnmatch

from safetensors.torch import load_file

from diffusers import Kandinsky3UNet
Should be removed from here?
attention_mask = attention_mask[:, :max_seq_length]
return embeddings, attention_mask

@torch.no_grad()
Why are we introducing this decorator here? Apart from IF, no other pipeline does it.
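For context on what the decorator does: it disables autograd graph construction for the whole call, which saves memory during inference but is usually handled at the call site in diffusers rather than baked into the pipeline. A minimal demonstration with a toy function (not the pipeline code):

```python
import torch

@torch.no_grad()
def forward(x):
    # Everything in here runs without recording an autograd graph
    return x * 2

x = torch.ones(3, requires_grad=True)
y = forward(x)
# y.requires_grad is False: no gradients can flow back through this call
```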
negative_attention_mask = None
return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask

def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
No `# Copied from` comment?
return new_height * scale_factor, new_width * scale_factor


class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
TODO: documentation.
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
"""
cut_context = True
Should be an argument, defaulting to `True`.
latents = latents * scheduler.init_noise_sigma
return latents

def check_inputs(
No `# Copied from` comment?
if output_type not in ["pt", "np", "pil"]:
    raise ValueError(
        f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
    )

if output_type in ["np", "pil"]:
    image = image * 0.5 + 0.5
    image = image.clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()

if output_type == "pil":
    image = self.numpy_to_pil(image)

if not return_dict:
    return (image,)
It might make sense to use the image processor module here. Why are we not using it?
* finalize * finalize * finalize * add slow test * add slow test * add slow test * Fix more * add slow test * fix more * fix more * fix more * fix more * fix more * fix more * fix more * fix more * fix more * Better * Fix more * Fix more * add slow test * Add auto pipelines * add slow test * Add all * add slow test * add slow test * add slow test * add slow test * add slow test * Apply suggestions from code review * add slow test * add slow test
What does this PR do?
@yiyixuxu given the limited time, I've made sure that the public API and the weights are correctly named. I've left a lot of TODOs in the code that should ideally be completed next week.