Merged
Changes from all commits
32 commits
3e2ff83
finalize
patrickvonplaten Nov 23, 2023
69702a6
finalize
patrickvonplaten Nov 23, 2023
233970e
finalize
patrickvonplaten Nov 23, 2023
ed509c8
add slow test
patrickvonplaten Nov 23, 2023
b75ca86
add slow test
patrickvonplaten Nov 23, 2023
2c2dac8
add slow test
patrickvonplaten Nov 23, 2023
368e70e
Fix more
patrickvonplaten Nov 23, 2023
446a9d4
add slow test
patrickvonplaten Nov 23, 2023
b89a387
fix more
patrickvonplaten Nov 23, 2023
97f621e
fix more
patrickvonplaten Nov 23, 2023
fe0a4ed
fix more
patrickvonplaten Nov 23, 2023
92711d5
fix more
patrickvonplaten Nov 23, 2023
22dfa36
fix more
patrickvonplaten Nov 23, 2023
5781f73
fix more
patrickvonplaten Nov 23, 2023
20c78cf
fix more
patrickvonplaten Nov 23, 2023
efe9a7e
fix more
patrickvonplaten Nov 23, 2023
fb37208
fix more
patrickvonplaten Nov 23, 2023
37684a9
Better
patrickvonplaten Nov 24, 2023
f16f1c3
Fix more
patrickvonplaten Nov 24, 2023
15c6e85
Fix more
patrickvonplaten Nov 24, 2023
33febd4
add slow test
patrickvonplaten Nov 24, 2023
d810bb8
Add auto pipelines
patrickvonplaten Nov 24, 2023
fa33cce
add slow test
patrickvonplaten Nov 24, 2023
3fef9ea
Add all
patrickvonplaten Nov 24, 2023
18a542c
add slow test
patrickvonplaten Nov 24, 2023
f561034
add slow test
patrickvonplaten Nov 24, 2023
c3417ac
add slow test
patrickvonplaten Nov 24, 2023
b8c0c13
add slow test
patrickvonplaten Nov 24, 2023
52ec729
add slow test
patrickvonplaten Nov 24, 2023
6a6fd2a
Apply suggestions from code review
patrickvonplaten Nov 24, 2023
b1ca9ac
add slow test
patrickvonplaten Nov 24, 2023
0e276a4
add slow test
patrickvonplaten Nov 24, 2023
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -278,6 +278,8 @@
title: Kandinsky 2.1
- local: api/pipelines/kandinsky_v22
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/latent_consistency_models
title: Latent Consistency Models
- local: api/pipelines/latent_diffusion
24 changes: 24 additions & 0 deletions docs/source/en/api/pipelines/kandinsky3.md
@@ -0,0 +1,24 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Kandinsky 3

Kandinsky 3 is the latest text-to-image model in the Kandinsky series. Diffusers provides a text-to-image pipeline ([`Kandinsky3Pipeline`]) and an image-to-image pipeline ([`Kandinsky3Img2ImgPipeline`]), both built on the [`Kandinsky3UNet`] model.
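
A minimal usage sketch (the checkpoint id below is an assumption and may differ from the officially published Kandinsky 3 weights):

```py
import torch
from diffusers import Kandinsky3Pipeline

# Hypothetical checkpoint id, replace with the published Kandinsky 3 weights.
pipe = Kandinsky3Pipeline.from_pretrained("kandinsky-community/kandinsky-3", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats."
image = pipe(prompt, num_inference_steps=25).images[0]
```

[`Kandinsky3Img2ImgPipeline`] is loaded the same way and additionally takes an input image.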

## Kandinsky3Pipeline

[[autodoc]] Kandinsky3Pipeline
- all
- __call__

## Kandinsky3Img2ImgPipeline

[[autodoc]] Kandinsky3Img2ImgPipeline
- all
- __call__
98 changes: 98 additions & 0 deletions scripts/convert_kandinsky3_unet.py
@@ -0,0 +1,98 @@
#!/usr/bin/env python3
import argparse
import fnmatch
Member
This one's new!
from safetensors.torch import load_file

from diffusers import Kandinsky3UNet


MAPPING = {
"to_time_embed.1": "time_embedding.linear_1",
"to_time_embed.3": "time_embedding.linear_2",
"in_layer": "conv_in",
"out_layer.0": "conv_norm_out",
"out_layer.2": "conv_out",
"down_samples": "down_blocks",
"up_samples": "up_blocks",
"projection_lin": "encoder_hid_proj.projection_linear",
"projection_ln": "encoder_hid_proj.projection_norm",
"feature_pooling": "add_time_condition",
"to_query": "to_q",
"to_key": "to_k",
"to_value": "to_v",
"output_layer": "to_out.0",
"self_attention_block": "attentions.0",
}

DYNAMIC_MAP = {
"resnet_attn_blocks.*.0": "resnets_in.*",
"resnet_attn_blocks.*.1": ("attentions.*", 1),
"resnet_attn_blocks.*.2": "resnets_out.*",
}
# MAPPING = {}


def convert_state_dict(unet_state_dict):
"""
Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model.
Args:
unet_model (torch.nn.Module): The original U-Net model.
unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with.

Returns:
OrderedDict: The converted state dictionary.
"""
# Example of renaming logic (this will vary based on your model's architecture)
converted_state_dict = {}
for key in unet_state_dict:
new_key = key
for pattern, new_pattern in MAPPING.items():
new_key = new_key.replace(pattern, new_pattern)

        for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items():
            if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*"):
                star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1])

                if isinstance(dyn_new_pattern, tuple):
                    # The tuple form additionally shifts the wildcard block index in the target name.
                    new_star = star + dyn_new_pattern[-1]
                    dyn_new_pattern = dyn_new_pattern[0]
                else:
                    new_star = star

                pattern = dyn_pattern.replace("*", str(star))
                new_pattern = dyn_new_pattern.replace("*", str(new_star))

                new_key = new_key.replace(pattern, new_pattern)

converted_state_dict[new_key] = unet_state_dict[key]

return converted_state_dict


def main(model_path, output_path):
    # Load the original U-Net state dict (safetensors format)
    unet_state_dict = load_file(model_path)

    # Configuration passed to Kandinsky3UNet (left empty here)
    config = {}

# Convert the state dict
converted_state_dict = convert_state_dict(unet_state_dict)

unet = Kandinsky3UNet(config)
unet.load_state_dict(converted_state_dict)

unet.save_pretrained(output_path)
print(f"Converted model saved to {output_path}")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format")
parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")

args = parser.parse_args()
main(args.model_path, args.output_path)
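
To illustrate the renaming that convert_state_dict performs, here is a small self-contained sketch applied to two hypothetical checkpoint keys (the real Kandinsky 3 key names may differ; only a subset of MAPPING is reproduced):

import fnmatch

MAPPING = {"down_samples": "down_blocks", "in_layer": "conv_in", "to_query": "to_q"}
DYNAMIC_MAP = {
    "resnet_attn_blocks.*.0": "resnets_in.*",
    "resnet_attn_blocks.*.1": ("attentions.*", 1),  # the tuple also shifts the block index by 1
}


def rename(key):
    # Static renames first, then the position-dependent renames with wildcard block indices.
    for pattern, new_pattern in MAPPING.items():
        key = key.replace(pattern, new_pattern)
    for dyn_pattern, dyn_new in DYNAMIC_MAP.items():
        if fnmatch.fnmatch(key, f"*.{dyn_pattern}.*"):
            star = int(key.split(dyn_pattern.split(".")[0])[-1].split(".")[1])
            offset = dyn_new[-1] if isinstance(dyn_new, tuple) else 0
            target = dyn_new[0] if isinstance(dyn_new, tuple) else dyn_new
            key = key.replace(dyn_pattern.replace("*", str(star)), target.replace("*", str(star + offset)))
    return key


print(rename("down_samples.1.resnet_attn_blocks.2.0.in_layer.weight"))  # down_blocks.1.resnets_in.2.conv_in.weight
print(rename("down_samples.1.resnet_attn_blocks.2.1.to_query.weight"))  # down_blocks.1.attentions.3.to_q.weight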
6 changes: 6 additions & 0 deletions src/diffusers/__init__.py
@@ -79,6 +79,7 @@
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
"Kandinsky3UNet",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
@@ -214,6 +215,8 @@
"IFPipeline",
"IFSuperResolutionPipeline",
"ImageTextPipelineOutput",
"Kandinsky3Img2ImgPipeline",
"Kandinsky3Pipeline",
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
"KandinskyImg2ImgPipeline",
@@ -446,6 +449,7 @@
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
Kandinsky3UNet,
ModelMixin,
MotionAdapter,
MultiAdapter,
@@ -560,6 +564,8 @@
IFPipeline,
IFSuperResolutionPipeline,
ImageTextPipelineOutput,
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -36,6 +36,7 @@
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["vq_model"] = ["VQModel"]

@@ -63,6 +64,7 @@
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandi3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .vq_model import VQModel

41 changes: 40 additions & 1 deletion src/diffusers/models/attention_processor.py
@@ -16,7 +16,7 @@

import torch
import torch.nn.functional as F
from torch import nn
from torch import einsum, nn

from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
@@ -2219,6 +2219,44 @@ def __call__(
return hidden_states


# TODO(Yiyi): This class should not exist; we can likely replace it with a standard attention processor
# so that torch.compile and related tooling work as well.
class Kandi3AttnProcessor:
r"""
    Default Kandinsky 3 processor for performing attention-related computations.
"""

@staticmethod
def _reshape(hid_states, h):
b, n, f = hid_states.shape
d = f // h
return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)

def __call__(
self,
attn,
x,
context,
context_mask=None,
):
query = self._reshape(attn.to_q(x), h=attn.num_heads)
key = self._reshape(attn.to_k(context), h=attn.num_heads)
value = self._reshape(attn.to_v(context), h=attn.num_heads)

attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)

if context_mask is not None:
max_neg_value = -torch.finfo(attention_matrix.dtype).max
context_mask = context_mask.unsqueeze(1).unsqueeze(1)
attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)

out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
out = attn.to_out[0](out)
return out
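
# Note (assumption, not part of this PR): the einsum attention above is roughly equivalent to standard
# scaled dot-product attention, e.g.
#     attn_mask = None if context_mask is None else (context_mask[:, None, None, :] != 0)
#     out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, scale=attn.scale)
# which is what the TODO above refers to: an AttnProcessor2_0-style implementation would keep the module
# compatible with torch.compile.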


LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
@@ -2244,6 +2282,7 @@ def __call__(
LoRAXFormersAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
Kandi3AttnProcessor,
)

AttentionProcessor = Union[