
kandinsky model/pipeline review #13597

@hlky


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Duplicate search performed with gh search issues and gh search prs on huggingface/diffusers for kandinsky, affected class names, image_embeds, num_images_per_prompt, combined PIL inputs, transformer_kandinsky, bfloat16, device_map, _no_split_modules, and slow-test coverage. No likely duplicates were found.
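
For reference, the searches followed this shape (a sketch of the commands, not a verbatim transcript; the exact query strings and flags may have differed):

gh search issues --repo huggingface/diffusers "kandinsky image_embeds num_images_per_prompt"
gh search prs --repo huggingface/diffusers "kandinsky combined PIL num_images_per_prompt"
gh search prs --repo huggingface/diffusers "transformer_kandinsky bfloat16 device_map _no_split_modules"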

Issue 1: Decoder image embeddings are only expanded in the CFG path

Affected code:

if do_classifier_free_guidance:
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
        dtype=prompt_embeds.dtype, device=device
    )

if do_classifier_free_guidance:
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
        dtype=prompt_embeds.dtype, device=device
    )

Problem:
image_embeds is repeated num_images_per_prompt times and moved to the execution dtype/device only when guidance_scale > 1. With guidance_scale <= 1, prompt_embeds and latents are expanded but image_embeds stays at the original batch size. KandinskyPipeline and KandinskyInpaintPipeline then fail for num_images_per_prompt > 1 unless the caller manually pre-expands image_embeds.

Impact:
Valid no-CFG calls cannot generate multiple images per prompt from one prior embedding. The no-CFG path also skips dtype/device normalization for image_embeds.

Reproduction:

import torch
from types import SimpleNamespace
from diffusers import DDIMScheduler, KandinskyPipeline

class Module(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dummy = torch.nn.Parameter(torch.zeros(()))
    @property
    def dtype(self): return self.dummy.dtype
    @property
    def device(self): return self.dummy.device

class Tokenizer:
    model_max_length = 77
    def __call__(self, prompt, **kwargs):
        b = len(prompt) if isinstance(prompt, list) else 1
        return SimpleNamespace(input_ids=torch.ones(b, 4, dtype=torch.long), attention_mask=torch.ones(b, 4, dtype=torch.long))
    def batch_decode(self, ids): return [""]

class TextEncoder(Module):
    def forward(self, input_ids, attention_mask):
        b, s = input_ids.shape
        return torch.zeros(b, 32), torch.zeros(b, s, 32)

class UNet(Module):
    config = SimpleNamespace(in_channels=4)
    def forward(self, sample, timestep, encoder_hidden_states, added_cond_kwargs, return_dict=False):
        assert added_cond_kwargs["image_embeds"].shape[0] == sample.shape[0], (added_cond_kwargs["image_embeds"].shape, sample.shape)
        return (torch.zeros(sample.shape[0], 8, sample.shape[2], sample.shape[3]),)

class Movq(Module):
    config = SimpleNamespace(block_out_channels=[1])
    def decode(self, latents, force_not_quantize=True): return {"sample": latents[:, :3]}

pipe = KandinskyPipeline(TextEncoder(), Tokenizer(), UNet(), DDIMScheduler(num_train_timesteps=2), Movq()).to("cpu")
pipe.set_progress_bar_config(disable=True)
pipe(prompt="p", image_embeds=torch.zeros(1, 32), negative_image_embeds=torch.zeros(1, 32), guidance_scale=1.0, num_images_per_prompt=2, num_inference_steps=1, output_type="pt")

Relevant precedent:
prompt_embeds is expanded before the CFG branch in the same methods:

prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

Suggested fix:

image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0).to(
    dtype=prompt_embeds.dtype, device=device
)

if do_classifier_free_guidance:
    negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0).to(
        dtype=prompt_embeds.dtype, device=device
    )
    image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)

Issue 2: Combined img2img/inpaint wrappers check prompt instead of image for PIL wrapping

Affected code:

prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
image = [image] if isinstance(prompt, PIL.Image.Image) else image
if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
    prompt = (image_embeds.shape[0] // len(prompt)) * prompt
if (
    isinstance(image, (list, tuple))
    and len(image) < image_embeds.shape[0]
    and image_embeds.shape[0] % len(image) == 0
):
    image = (image_embeds.shape[0] // len(image)) * image

prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
image = [image] if isinstance(prompt, PIL.Image.Image) else image
mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image
if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
    prompt = (image_embeds.shape[0] // len(prompt)) * prompt
if (
    isinstance(image, (list, tuple))
    and len(image) < image_embeds.shape[0]
    and image_embeds.shape[0] % len(image) == 0
):
    image = (image_embeds.shape[0] // len(image)) * image
if (
    isinstance(mask_image, (list, tuple))
    and len(mask_image) < image_embeds.shape[0]
    and image_embeds.shape[0] % len(mask_image) == 0
):
    mask_image = (image_embeds.shape[0] // len(mask_image)) * mask_image

Problem:
KandinskyImg2ImgCombinedPipeline and KandinskyInpaintCombinedPipeline use isinstance(prompt, PIL.Image.Image) where they mean to check image. A single PIL image is therefore never wrapped into a list, so the subsequent length-based expansion to match image_embeds.shape[0] is skipped.

Impact:
Combined img2img/inpaint calls with a single PIL image and num_images_per_prompt > 1 fail in the decoder with latent batch shape mismatches. Passing [image] works, which makes this a wrapper bug rather than a decoder limitation.

Reproduction:

import numpy as np
import torch
from PIL import Image
from types import SimpleNamespace
from diffusers import KandinskyImg2ImgCombinedPipeline

class Prior:
    def __call__(self, **kwargs):
        return (torch.zeros(2, 32), torch.zeros(2, 32))

class Decoder:
    def __call__(self, prompt, image, **kwargs):
        assert isinstance(image, list) and len(image) == 2, (type(image), image if isinstance(image, list) else None)
        return SimpleNamespace(images="ok")

pipe = object.__new__(KandinskyImg2ImgCombinedPipeline)
pipe.prior_pipe = Prior()
pipe.decoder_pipe = Decoder()
pipe.maybe_free_model_hooks = lambda: None

image = Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8))
pipe(prompt="horse", image=image, num_images_per_prompt=2)

Relevant precedent:
The branch immediately below already handles list/tuple inputs by length; the missing step is wrapping the single PIL image first.

Suggested fix:

prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
image = [image] if isinstance(image, PIL.Image.Image) else image

For inpaint, apply the same correction to image and keep the existing mask_image PIL wrapping.
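
A minimal sketch of the corrected inpaint wrapping, assuming the surrounding length-based expansion below it stays unchanged:

prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
image = [image] if isinstance(image, PIL.Image.Image) else image
mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image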

Issue 3: Kandinsky5 transformer fails after direct half/bfloat16 casting

Affected code:

def forward(self, time):
    args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
    time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))

def apply_rotary(x, rope):
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
    x_out = (rope * x_).sum(dim=-1)
    return x_out.reshape(*x.shape).to(torch.bfloat16)

Problem:
Kandinsky5TimeEmbeddings.forward() always builds a float32 embedding and feeds it into nn.Linear. If users call model.to(torch.float16) or model.to(torch.bfloat16), the linear weights are converted but the input remains float32, causing a dtype mismatch. The rotary helper also hard-casts through torch.bfloat16, which silently quantizes float32/float16 attention states.

Impact:
Direct dtype conversion, a common Diffusers usage pattern, breaks Kandinsky5 transformer inference. The hard-coded rotary bfloat16 cast also prevents clean dtype behavior and can degrade fp32 parity.

Reproduction:

import torch
from diffusers import Kandinsky5Transformer3DModel

model = Kandinsky5Transformer3DModel(
    in_visual_dim=4, in_text_dim=8, in_text_dim2=4, time_dim=8, out_visual_dim=4,
    patch_size=(1, 1, 1), model_dim=8, ff_dim=16, num_text_blocks=1,
    num_visual_blocks=1, axes_dims=(2, 2, 4), visual_cond=False,
).eval().to(dtype=torch.bfloat16)

model(
    hidden_states=torch.randn(1, 1, 2, 2, 4, dtype=torch.bfloat16),
    encoder_hidden_states=torch.randn(1, 3, 8, dtype=torch.bfloat16),
    timestep=torch.tensor([1], dtype=torch.bfloat16),
    pooled_projections=torch.randn(1, 4, dtype=torch.bfloat16),
    visual_rope_pos=[torch.arange(1), torch.arange(2), torch.arange(2)],
    text_rope_pos=torch.arange(3),
    return_dict=False,
)

Relevant precedent:
Wan casts timestep projections to the time embedder parameter dtype before the linear and casts the result back to the hidden-state dtype:

time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
    timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)

Suggested fix:

time_embedder_dtype = self.in_layer.weight.dtype
if time_embed.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
    time_embed = time_embed.to(time_embedder_dtype)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))

And for rotary:

orig_dtype = x.dtype
x_ = x.reshape(*x.shape[:-1], -1, 1, 2).float()
x_out = (rope * x_).sum(dim=-1)
return x_out.reshape(*x.shape).to(orig_dtype)

Issue 4: Kandinsky5 transformer declares repeated blocks but no _no_split_modules

Affected code:

_repeated_blocks = [
    "Kandinsky5TransformerEncoderBlock",
    "Kandinsky5TransformerDecoderBlock",
]
_keep_in_fp32_modules = ["time_embeddings", "modulation", "visual_modulation", "text_modulation"]
_supports_gradient_checkpointing = True

Problem:
Kandinsky5Transformer3DModel sets _repeated_blocks but does not set _no_split_modules. Related transformer families mark block classes as no-split for device-map/offload placement.

Impact:
Automatic device maps can split transformer blocks across devices. That is inconsistent with related transformer models and risks poor placement, extra transfers, or correctness issues in block-local attention/modulation paths.

Reproduction:

from diffusers import Kandinsky5Transformer3DModel

print(getattr(Kandinsky5Transformer3DModel, "_no_split_modules", None))
print(Kandinsky5Transformer3DModel._repeated_blocks)
assert Kandinsky5Transformer3DModel._no_split_modules is not None

Relevant precedent:

_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

_no_split_modules = ["QwenImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["QwenImageTransformerBlock"]

_no_split_modules = ["WanTransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock"]

Suggested fix:

_no_split_modules = ["Kandinsky5TransformerEncoderBlock", "Kandinsky5TransformerDecoderBlock"]
_repeated_blocks = ["Kandinsky5TransformerEncoderBlock", "Kandinsky5TransformerDecoderBlock"]
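
To sanity-check placement locally, accelerate's device-map inference can be pointed at the same block classes (a sketch, assuming an already-instantiated model such as the small config from the Issue 3 reproduction and accelerate installed; _no_split_modules is the list intended to reach accelerate as no_split_module_classes when device_map="auto" is resolved):

from accelerate import infer_auto_device_map

# `model` is assumed to be an instantiated Kandinsky5Transformer3DModel
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=getattr(model, "_no_split_modules", None) or [],
)
print(device_map)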

Issue 5: Slow and enforced fast coverage is incomplete

Affected code:

@pytest.mark.xfail(
    condition=is_transformers_version(">=", "4.56.2"),
    reason="Latest transformers changes the slices",
    strict=False,
)
def test_kandinsky(self):

@slow
@require_torch_accelerator
class KandinskyPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_kandinsky_text2img(self):
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/kandinsky/kandinsky_text2img_cat_fp16.npy"
        )
        pipe_prior = KandinskyPriorPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
        )
        pipe_prior.to(torch_device)
        pipeline = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
        pipeline.to(torch_device)

@pytest.mark.xfail(
    condition=is_transformers_version(">=", "4.56.2"),
    reason="Latest transformers changes the slices",
    strict=False,
)
def test_kandinsky_inpaint(self):

@nightly
@require_torch_accelerator
class KandinskyInpaintPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_kandinsky_inpaint(self):
        expected_image = load_numpy(

@pytest.mark.xfail(
    condition=is_transformers_version(">=", "4.56.2"),
    reason="Latest transformers changes the slices",
    strict=False,
)
def test_kandinsky(self):

class KandinskyPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = KandinskyPriorPipeline
    params = ["prompt"]
    batch_params = ["prompt", "negative_prompt"]
    required_optional_params = [
        "num_images_per_prompt",
        "generator",
        "num_inference_steps",
        "latents",
        "negative_prompt",
        "guidance_scale",
        "output_type",
        "return_dict",
    ]
    test_xformers_attention = False
    supports_dduf = False

    def get_dummy_components(self):
        dummy = Dummies()
        return dummy.get_dummy_components()

    def get_dummy_inputs(self, device, seed=0):
        dummy = Dummies()
        return dummy.get_dummy_inputs(device=device, seed=seed)

    def test_kandinsky_prior(self):
        device = "cpu"
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        output = pipe(**self.get_dummy_inputs(device))
        image = output.image_embeds
        image_from_tuple = pipe(
            **self.get_dummy_inputs(device),
            return_dict=False,
        )[0]

        image_slice = image[0, -10:]
        image_from_tuple_slice = image_from_tuple[0, -10:]

        assert image.shape == (1, 32)
        expected_slice = np.array(
            [-0.5948, 0.1875, -0.1523, -1.1995, -1.4061, -0.6367, -1.4607, -0.6406, 0.8793, -0.3891]
        )
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    @skip_mps
    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(expected_max_diff=1e-2)

    @skip_mps
    def test_attention_slicing_forward_pass(self):
        test_max_difference = torch_device == "cpu"
        test_mean_pixel_difference = False
        self._test_attention_slicing_forward_pass(
            test_max_difference=test_max_difference,
            test_mean_pixel_difference=test_mean_pixel_difference,
        )

Problem:
The core fast image-slice tests are xfailed for Transformers >=4.56.2, so they are not enforced on the current test stack. Slow coverage exists for text2img and img2img, but not for KandinskyPriorPipeline or the combined pipelines. Inpaint has a nightly integration test but no @slow test. There is no dedicated tests/models coverage for Kandinsky5Transformer3DModel.

Impact:
The regressions above are not caught: no-CFG multi-image decoder behavior, combined PIL expansion, Kandinsky5 dtype conversion, and device-map metadata all slip through the current suite.

Reproduction:

from pathlib import Path

for path in sorted(Path("tests/pipelines/kandinsky").glob("test_*.py")):
    text = path.read_text()
    print(path.as_posix(), "slow=", "@slow" in text, "nightly=", "@nightly" in text, "xfail=", "xfail" in text)

has_model_test = any(
    "Kandinsky5Transformer3DModel" in path.read_text(errors="ignore")
    for path in Path("tests/models").rglob("test_*.py")
)
print("Kandinsky5 model test:", has_model_test)

Relevant precedent:
The existing text2img and img2img files show the expected slow integration-test shape:

@slow
@require_torch_accelerator
class KandinskyPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_kandinsky_text2img(self):
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/kandinsky/kandinsky_text2img_cat_fp16.npy"
        )
        pipe_prior = KandinskyPriorPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
        )
        pipe_prior.to(torch_device)
        pipeline = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
        pipeline.to(torch_device)

@slow
@require_torch_accelerator
class KandinskyImg2ImgPipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_kandinsky_img2img(self):
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/kandinsky/kandinsky_img2img_frog.npy"
        )
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
        )
        prompt = "A red cartoon frog, 4k"

        pipe_prior = KandinskyPriorPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
        )
        pipe_prior.to(torch_device)

        pipeline = KandinskyImg2ImgPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
        )
        pipeline = pipeline.to(torch_device)
        pipeline.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)
        image_emb, zero_image_emb = pipe_prior(
            prompt,
            generator=generator,
            num_inference_steps=5,
            negative_prompt="",
        ).to_tuple()

        output = pipeline(
            prompt,
            image=init_image,
            image_embeds=image_emb,
            negative_image_embeds=zero_image_emb,
            generator=generator,
            num_inference_steps=100,
            height=768,
            width=768,
            strength=0.2,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (768, 768, 3)
        assert_mean_pixel_difference(image, expected_image)

Suggested fix:
Add or re-enable coverage for:

# Fast
# - no-CFG num_images_per_prompt > 1 for KandinskyPipeline and KandinskyInpaintPipeline
# - single PIL image + num_images_per_prompt > 1 for combined img2img/inpaint
# - Kandinsky5Transformer3DModel .to(torch.bfloat16) forward

# Slow
# - KandinskyPriorPipeline
# - KandinskyCombinedPipeline
# - KandinskyImg2ImgCombinedPipeline
# - KandinskyInpaintCombinedPipeline
# - @slow inpaint coverage, or mark the existing nightly integration as slow too

Update the Transformers >=4.56.2 expected slices instead of leaving the main fast output tests xfailed.
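
As one concrete example of the first fast item above, a no-CFG multi-image check could sit next to the existing decoder fast tests (a sketch that reuses the current dummy-component pattern; the get_dummy_components/get_dummy_inputs helper names are assumed to match the existing test files):

def test_kandinsky_no_cfg_num_images_per_prompt(self):
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components).to("cpu")
    pipe.set_progress_bar_config(disable=None)

    inputs = self.get_dummy_inputs("cpu")
    inputs["guidance_scale"] = 1.0  # guidance_scale <= 1 skips the CFG branch
    inputs["num_images_per_prompt"] = 2

    images = pipe(**inputs).images
    assert len(images) == 2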
