visualcloze model/pipeline review #13623

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Duplicate check: searched GitHub issues and PRs for VisualCloze, VisualClozeProcessor, prompt_embeds, latents, generator, resolution, return_dict, get_layout_prompt, _resize_and_crop, upsampling_strength, and slow-test coverage. No open duplicate found. Related merged PR: #12121 fixed a prior multi-image VisualClozeProcessor AttributeError, but not the remaining resize issue below.

Test status: attempted .venv\Scripts\python.exe -m pytest tests/pipelines/visualcloze -q; collection failed because the torch build in this .venv lacks torch._C._distributed_c10d, so the suite could not be run locally.

Issue 1: Default generator=None crashes

Affected code:

for i in range(len(input_image)):
    _image_latent, _masked_image_latent, _mask, _latent_image_ids = self._prepare_latents(
        input_image[i],
        input_mask[i],
        generator if isinstance(generator, torch.Generator) else generator[i],
        vae_scale_factor,
        device,
        dtype,
    )
    masked_image_latents.append(_masked_image_latent)
    image_latents.append(_image_latent)
    masks.append(_mask)
    latent_image_ids.append(_latent_image_ids)

# Concatenate all batches
masked_image_latents = torch.cat(masked_image_latents, dim=0)
image_latents = torch.cat(image_latents, dim=0)
masks = torch.cat(masks, dim=0)

# Handle batch size expansion
if batch_size > masked_image_latents.shape[0]:
    if batch_size % masked_image_latents.shape[0] == 0:
        # Expand batches by repeating
        additional_image_per_prompt = batch_size // masked_image_latents.shape[0]
        masked_image_latents = torch.cat([masked_image_latents] * additional_image_per_prompt, dim=0)
        image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
        masks = torch.cat([masks] * additional_image_per_prompt, dim=0)
    else:
        raise ValueError(
            f"Cannot expand batch size from {masked_image_latents.shape[0]} to {batch_size}. "
            "Batch sizes must be multiples of each other."
        )

# Add noise to latents
noises = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noises).to(dtype=dtype)

# Combine masked latents with masks
masked_image_latents = torch.cat((masked_image_latents, masks), dim=-1).to(dtype=dtype)

return latents, masked_image_latents, latent_image_ids[0]

Problem:
generator defaults to None, but prepare_latents() treats every value that is not a torch.Generator as an indexable list and evaluates generator[i]. Calling the pipeline without an explicit generator therefore raises a TypeError before latent preparation begins.

Impact:
The documented default API is broken. Users must pass a generator even though the signature says it is optional.

Reproduction:

import torch
from diffusers import VisualClozeGenerationPipeline

pipe = object.__new__(VisualClozeGenerationPipeline)
try:
    pipe.prepare_latents([[torch.zeros(1, 3, 16, 16)]], [[torch.zeros(1, 1, 16, 16)]],
                         torch.tensor([1.0]), 1, torch.float32, torch.device("cpu"),
                         None, vae_scale_factor=8)
except Exception as e:
    print(type(e).__name__, e)

Relevant precedent:
FluxFillPipeline.prepare_latents accepts generator=None.

def prepare_latents(
    self,
    image,
    timestep,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    # VAE applies 8x compression on images but we must also account for packing which requires
    # latent height and width to be divisible by 2.
    height = 2 * (int(height) // (self.vae_scale_factor * 2))
    width = 2 * (int(width) // (self.vae_scale_factor * 2))
    shape = (batch_size, num_channels_latents, height, width)
    latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

    if latents is not None:
        return latents.to(device=device, dtype=dtype), latent_image_ids

    image = image.to(device=device, dtype=dtype)
    if image.shape[1] != self.latent_channels:
        image_latents = self._encode_vae_image(image=image, generator=generator)
    else:
        image_latents = image

    if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
        # expand init_latents for batch_size
        additional_image_per_prompt = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
    elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        image_latents = torch.cat([image_latents], dim=0)

    noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    latents = self.scheduler.scale_noise(image_latents, timestep, noise)
    latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

Suggested fix:

sample_generator = generator[i] if isinstance(generator, list) else generator
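
A runnable sketch of the guard in isolation; select_generator is a hypothetical helper used only to illustrate the branch, not code from the pipeline:

import torch

def select_generator(generator, i):
    # Only a list is indexed; None and a shared torch.Generator pass through unchanged.
    return generator[i] if isinstance(generator, list) else generator

print(select_generator(None, 0))                                    # None -> no crash
print(select_generator(torch.Generator().manual_seed(0), 1))        # shared generator
print(select_generator([torch.Generator(), torch.Generator()], 1))  # per-sample generator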

Issue 2: Precomputed embeddings cannot be used, and latents is ignored

Affected code:

if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None:
    raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ")
if task_prompt is None and content_prompt is None and prompt_embeds is None:
    raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ")

# Validate prompt types and consistency
if task_prompt is None:
    raise ValueError("`task_prompt` is missing.")

prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
    layout_prompt=processor_output["layout_prompt"],
    task_prompt=processor_output["task_prompt"],
    content_prompt=processor_output["content_prompt"],
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    device=device,
    num_images_per_prompt=num_images_per_prompt,
    max_sequence_length=max_sequence_length,
    lora_scale=lora_scale,
)

# 4. Prepare timesteps
# Calculate sequence length and shift factor
image_seq_len = sum(
    (size[0] // self.vae_scale_factor // 2) * (size[1] // self.vae_scale_factor // 2)
    for sample in processor_output["image_size"][0]
    for size in sample
)

# Calculate noise schedule parameters
mu = calculate_shift(
    image_seq_len,
    self.scheduler.config.get("base_image_seq_len", 256),
    self.scheduler.config.get("max_image_seq_len", 4096),
    self.scheduler.config.get("base_shift", 0.5),
    self.scheduler.config.get("max_shift", 1.15),
)

# Get timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler,
    num_inference_steps,
    device,
    sigmas=sigmas,
    mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)

# 5. Prepare latent variables
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
latents, masked_image_latents, latent_image_ids = self.prepare_latents(
    processor_output["init_image"],
    processor_output["mask"],
    latent_timestep,
    batch_size * num_images_per_prompt,
    prompt_embeds.dtype,
    device,
    generator,
    vae_scale_factor=self.vae_scale_factor,
)

Problem:
check_inputs() rejects passing text prompts together with prompt_embeds, but then unconditionally raises when task_prompt is missing, so prompt_embeds alone can never pass validation. Separately, __call__ accepts latents but never forwards it to prepare_latents(), which has no latents parameter at all.

Impact:
Documented pipeline controls for prompt reuse and deterministic latent reuse do not work.

Reproduction:

import inspect
import torch
from diffusers import VisualClozeGenerationPipeline

pipe = object.__new__(VisualClozeGenerationPipeline)
pipe._callback_tensor_inputs = ["latents", "prompt_embeds"]

try:
    pipe.check_inputs(None, None, None,
                      prompt_embeds=torch.zeros(1, 4, 8),
                      pooled_prompt_embeds=torch.zeros(1, 8))
except Exception as e:
    print(type(e).__name__, e)

print("prepare_latents accepts latents:",
      "latents" in inspect.signature(VisualClozeGenerationPipeline.prepare_latents).parameters)

Relevant precedent:
FluxPipeline.encode_prompt only encodes text when prompt_embeds is None.

if prompt_embeds is None:
    prompt_2 = prompt_2 or prompt
    prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

    # We only use the pooled prompt output from the CLIPTextModel
    pooled_prompt_embeds = self._get_clip_prompt_embeds(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
    )
    prompt_embeds = self._get_t5_prompt_embeds(
        prompt=prompt_2,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
    )

if self.text_encoder is not None:
    if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

if self.text_encoder_2 is not None:
    if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder_2, lora_scale)

dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

return prompt_embeds, pooled_prompt_embeds, text_ids

Suggested fix:
Implement the Flux-style prompt-embed branch, decouple image preprocessing from required text prompts, and add a latents=None parameter to prepare_latents() that returns supplied latents after dtype/device conversion.
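
A minimal, self-contained sketch of the suggested validation order. check_inputs here is a standalone stand-in, not the pipeline method, and the pooled_prompt_embeds pairing rule mirrors other diffusers pipelines; the latents passthrough would follow the FluxFillPipeline snippet above (return the supplied latents after dtype/device conversion):

import torch

def check_inputs(task_prompt, content_prompt, prompt_embeds, pooled_prompt_embeds):
    # Embeddings and text prompts stay mutually exclusive.
    if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None:
        raise ValueError("Cannot provide both text prompts and `prompt_embeds`.")
    if prompt_embeds is None:
        # Text path: only now is task_prompt required.
        if task_prompt is None:
            raise ValueError("`task_prompt` is missing.")
    elif pooled_prompt_embeds is None:
        raise ValueError("`pooled_prompt_embeds` must be provided with `prompt_embeds`.")

# Precomputed embeddings now pass validation without any text prompt.
check_inputs(None, None, torch.zeros(1, 4, 8), torch.zeros(1, 8))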

Issue 3: resolution is not serialized

Affected code:

def __init__(
    self,
    scheduler: FlowMatchEulerDiscreteScheduler,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    text_encoder_2: T5EncoderModel,
    tokenizer_2: T5TokenizerFast,
    transformer: FluxTransformer2DModel,
    resolution: int = 384,
):
    super().__init__()
    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        transformer=transformer,
        scheduler=scheduler,
    )
    self.resolution = resolution
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
    # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
    # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
    self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
    self.image_processor = VisualClozeProcessor(
        vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels, resolution=resolution
    )
    self.tokenizer_max_length = (
        self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
    )
    self.default_sample_size = 128

def __init__(
    self,
    scheduler: FlowMatchEulerDiscreteScheduler,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    text_encoder_2: T5EncoderModel,
    tokenizer_2: T5TokenizerFast,
    transformer: FluxTransformer2DModel,
    resolution: int = 384,
):
    super().__init__()
    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        transformer=transformer,
        scheduler=scheduler,
    )
    self.generation_pipe = VisualClozeGenerationPipeline(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        transformer=transformer,
        scheduler=scheduler,
        resolution=resolution,
    )
    self.upsampling_pipe = VisualClozeUpsamplingPipeline(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        transformer=transformer,
        scheduler=scheduler,
    )

Problem:
resolution controls preprocessing, but neither pipeline registers it with register_to_config, so the value is never written to model_index.json. The fast tests work around this by manually passing resolution=32 on every reload.

Impact:
Saved 512-resolution or tiny-test pipelines reload with the default 384 preprocessing resolution, changing behavior and potentially increasing memory use.

Reproduction:

import json, tempfile
from diffusers import VisualClozeGenerationPipeline

pipe = VisualClozeGenerationPipeline(None, None, None, None, None, None, None, resolution=32)
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as d:
    pipe.save_pretrained(d)
    print(json.load(open(f"{d}/model_index.json")).get("resolution"))

Relevant precedent:
Scalar pipeline config values are registered with register_to_config.

self.register_to_config(is_distilled=is_distilled)

Suggested fix:

self.register_to_config(resolution=resolution)
self.resolution = resolution
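
With that change applied, the Issue 3 reproduction should flip to a passing check. A sketch of the expected round-trip, assuming the registered scalar is serialized the way other pipeline config values are:

import json, tempfile
from diffusers import VisualClozeGenerationPipeline

pipe = VisualClozeGenerationPipeline(None, None, None, None, None, None, None, resolution=32)
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as d:
    pipe.save_pretrained(d)
    # After register_to_config(resolution=resolution), model_index.json carries the value.
    assert json.load(open(f"{d}/model_index.json")).get("resolution") == 32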

Issue 4: Layout prompt is a tuple, not a string

Affected code:

def get_layout_prompt(self, size: tuple[int, int]) -> str:
    layout_instruction = (
        f"A grid layout with {size[0]} rows and {size[1]} columns, displaying {size[0] * size[1]} images arranged side by side.",
    )
    return layout_instruction

Problem:
A trailing comma after the f-string makes layout_instruction a one-element tuple rather than a string. When it is later interpolated into the text prompt, the tuple's repr appears: ("A grid layout ...",).

Impact:
The text encoder receives the tuple's repr, parentheses and quotes included, instead of the intended layout instruction string.

Reproduction:

from diffusers.pipelines.visualcloze.visualcloze_utils import VisualClozeProcessor

processor = VisualClozeProcessor(resolution=64)
layout_prompt = processor.get_layout_prompt((2, 3))
print(type(layout_prompt).__name__, layout_prompt)

Relevant precedent:
Normal prompt assembly expects plain strings.

Suggested fix:

layout_instruction = (
    f"A grid layout with {size[0]} rows and {size[1]} columns, "
    f"displaying {size[0] * size[1]} images arranged side by side."
)
return layout_instruction
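
The root cause is a generic Python pitfall, demonstrated in two lines:

s = ("some text",)  # trailing comma inside parentheses -> one-element tuple
t = ("some text")   # parentheses alone do not create a tuple -> str
print(type(s).__name__, type(t).__name__)  # tuple str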

Issue 5: Multi-target preprocessing swaps resize width and height

Affected code:

# Ensure consistent width for multiple target images when there are multiple target images
if len(target_position) > 1 and sum(target_position) > 1:
    new_w = resize_size[n_samples - 1][0] or 384
    for i in range(len(processed_images)):
        for j in range(len(processed_images[i])):
            if processed_images[i][j] is not None:
                new_h = int(processed_images[i][j].height * (new_w / processed_images[i][j].width))
                new_w = int(new_w / 16) * 16
                new_h = int(new_h / 16) * 16
                processed_images[i][j] = self._resize_and_crop(processed_images[i][j], new_h, new_w)

Problem:
The multi-target branch computes new_w and new_h, then calls _resize_and_crop(image, new_h, new_w), but _resize_and_crop expects (image, width, height), so the two dimensions are transposed.

Impact:
For non-square inputs with more than one target, target crops are transposed to the wrong aspect/size. PR #12121 touched this block for a previous AttributeError but did not fix this width/height swap.

Reproduction:

from PIL import Image
from diffusers.pipelines.visualcloze.visualcloze_utils import VisualClozeProcessor

p = VisualClozeProcessor(resolution=64)
imgs = [
    [Image.new("RGB", (128, 64)) for _ in range(3)],
    [None, None, Image.new("RGB", (128, 64))],
]
_, sizes, pos = p.preprocess_image(imgs, vae_scale_factor=8)
print(sizes, pos)

Relevant precedent:
VaeImageProcessor._resize_and_crop(image, width, height) defines the required order.

def _resize_and_crop(
    self,
    image: PIL.Image.Image,
    width: int,
    height: int,
) -> PIL.Image.Image:
    r"""
    Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
    the image within the dimensions, cropping the excess.

    Args:
        image (`PIL.Image.Image`):
            The image to resize and crop.
        width (`int`):
            The width to resize the image to.
        height (`int`):
            The height to resize the image to.

    Returns:
        `PIL.Image.Image`:
            The resized and cropped image.
    """
    ratio = width / height
    src_ratio = image.width / image.height

    src_w = width if ratio > src_ratio else image.width * height // image.height
    src_h = height if ratio <= src_ratio else image.height * width // image.width

    resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
    res = Image.new("RGB", (width, height))
    res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

Suggested fix:

processed_images[i][j] = self._resize_and_crop(processed_images[i][j], new_w, new_h)
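
To see why the order matters, a standalone demonstration; resize_stub is a hypothetical stand-in that only preserves the (width, height) contract of _resize_and_crop:

from PIL import Image

def resize_stub(image, width, height):
    # Same (width, height) argument order as VaeImageProcessor._resize_and_crop.
    return image.resize((width, height))

image = Image.new("RGB", (128, 64))  # landscape input
new_w, new_h = 64, 32
print(resize_stub(image, new_h, new_w).size)  # swapped call -> (32, 64), transposed
print(resize_stub(image, new_w, new_h).size)  # corrected call -> (64, 32)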

Issue 6: Combined pipeline returns the wrong tuple when upsampling is disabled

Affected code:

generation_output = self.generation_pipe(
    task_prompt=task_prompt,
    content_prompt=content_prompt,
    image=image,
    num_inference_steps=num_inference_steps,
    sigmas=sigmas,
    guidance_scale=guidance_scale,
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
    latents=latents,
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    joint_attention_kwargs=joint_attention_kwargs,
    callback_on_step_end=callback_on_step_end,
    callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    max_sequence_length=max_sequence_length,
    output_type=output_type if upsampling_strength == 0 else "pil",
)
if upsampling_strength == 0:
    if not return_dict:
        return (generation_output,)
    return FluxPipelineOutput(images=generation_output)

Problem:
For upsampling_strength == 0 and return_dict=False, the method returns (generation_output,), a tuple whose payload is the inner pipeline's FluxPipelineOutput, instead of (generation_output.images,). The return_dict=True branch has the mirror-image flaw: it wraps the already-constructed FluxPipelineOutput in a second FluxPipelineOutput.

Impact:
Tuple-output users receive a nested FluxPipelineOutput, unlike other pipelines and unlike the method docstring.

Reproduction:

from diffusers import VisualClozePipeline
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput

class FakeGenerationPipe:
    def __call__(self, **kwargs):
        return FluxPipelineOutput(images=[["generated"]])

pipe = object.__new__(VisualClozePipeline)
pipe.generation_pipe = FakeGenerationPipe()
out = VisualClozePipeline.__call__(pipe, "task", "content", [[None]],
                                   upsampling_strength=0, return_dict=False)
print(type(out[0]).__name__)

Relevant precedent:
Pipeline tuple returns normally expose the payload field directly.

Suggested fix:

if upsampling_strength == 0:
    if not return_dict:
        return (generation_output.images,)
    return generation_output
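
The difference in tuple payloads, shown with the public output container only:

from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput

generation_output = FluxPipelineOutput(images=[["generated"]])
current = (generation_output,)        # today: a FluxPipelineOutput nested inside the tuple
fixed = (generation_output.images,)   # suggested: the images payload itself
print(type(current[0]).__name__, type(fixed[0]).__name__)  # FluxPipelineOutput list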

Issue 7: No slow tests exist for VisualCloze

Affected code:

class VisualClozeGenerationPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    pipeline_class = VisualClozeGenerationPipeline
    params = frozenset(
        [
            "task_prompt",
            "content_prompt",
            "guidance_scale",
            "prompt_embeds",
            "pooled_prompt_embeds",
        ]
    )
    batch_params = frozenset(["task_prompt", "content_prompt", "image"])
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = FluxTransformer2DModel(
            patch_size=1,
            in_channels=12,
            out_channels=4,
            num_layers=1,
            num_single_layers=1,
            attention_head_dim=6,
            num_attention_heads=2,
            joint_attention_dim=32,
            pooled_projection_dim=32,
            axes_dims_rope=[2, 2, 2],
        )
        clip_text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )
        torch.manual_seed(0)
        text_encoder = CLIPTextModel(clip_text_encoder_config)
        torch.manual_seed(0)
        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
        text_encoder_2 = T5EncoderModel(config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )
        scheduler = FlowMatchEulerDiscreteScheduler()
        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "transformer": transformer,
            "vae": vae,
            "resolution": 32,
        }

    def get_dummy_inputs(self, device, seed=0):
        # Create example images to simulate the input format required by VisualCloze
        context_image = [
            Image.fromarray(floats_tensor((32, 32, 3), rng=random.Random(seed), scale=255).numpy().astype(np.uint8))
            for _ in range(2)
        ]
        query_image = [
            Image.fromarray(
                floats_tensor((32, 32, 3), rng=random.Random(seed + 1), scale=255).numpy().astype(np.uint8)
            ),
            None,
        ]
        # Create an image list that conforms to the VisualCloze input format
        image = [
            context_image,  # In-Context example
            query_image,  # Query image
        ]
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)
        inputs = {
            "task_prompt": "Each row outlines a logical process, starting from [IMAGE1] gray-based depth map with detailed object contours, to achieve [IMAGE2] an image with flawless clarity.",
            "content_prompt": "A beautiful landscape with mountains and a lake",
            "image": image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "max_sequence_length": 77,
            "output_type": "np",
        }
        return inputs

    def test_visualcloze_different_prompts(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)
        output_same_prompt = pipe(**inputs).images[0]
        inputs = self.get_dummy_inputs(torch_device)
        inputs["task_prompt"] = "A different task to perform."
        output_different_prompts = pipe(**inputs).images[0]
        max_diff = np.abs(output_same_prompt - output_different_prompts).max()
        # Outputs should be different
        assert max_diff > 1e-6

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(expected_max_diff=1e-3)

    def test_different_task_prompts(self, expected_min_diff=1e-1):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)
        output_original = pipe(**inputs).images[0]
        inputs["task_prompt"] = "A different task description for image generation"
        output_different_task = pipe(**inputs).images[0]
        # Different task prompts should produce different outputs
        max_diff = np.abs(output_original - output_different_task).max()
        assert max_diff > expected_min_diff

    def test_save_load_local(self, expected_max_difference=5e-4):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0]

        logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
        logger.setLevel(diffusers.logging.INFO)

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir, safe_serialization=False)
            with CaptureLogger(logger) as cap_logger:
                # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
                # This attribute is not serialized in the config of the pipeline
                pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32)
            for component in pipe_loaded.components.values():
                if hasattr(component, "set_default_attn_processor"):
                    component.set_default_attn_processor()
            for name in pipe_loaded.components.keys():
                if name not in pipe_loaded._optional_components:
                    assert name in str(cap_logger)

        pipe_loaded.to(torch_device)
        pipe_loaded.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_loaded = pipe_loaded(**inputs)[0]
        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
        self.assertLess(max_diff, expected_max_difference)

    def test_save_load_optional_components(self, expected_max_difference=1e-4):
        if not hasattr(self.pipeline_class, "_optional_components"):
            return
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # set all optional components to None
        for optional_component in pipe._optional_components:
            setattr(pipe, optional_component, None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        torch.manual_seed(0)
        output = pipe(**inputs)[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir, safe_serialization=False)
            # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
            # This attribute is not serialized in the config of the pipeline
            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32)
        for component in pipe_loaded.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe_loaded.to(torch_device)
        pipe_loaded.set_progress_bar_config(disable=None)

        for optional_component in pipe._optional_components:
            self.assertTrue(
                getattr(pipe_loaded, optional_component) is None,
                f"`{optional_component}` did not stay set to None after loading.",
            )

        inputs = self.get_dummy_inputs(generator_device)
        torch.manual_seed(0)
        output_loaded = pipe_loaded(**inputs)[0]
        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
        self.assertLess(max_diff, expected_max_difference)

    @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
    @require_accelerator
    def test_save_load_float16(self, expected_max_diff=1e-2):
        components = self.get_dummy_components()
        for name, module in components.items():
            if hasattr(module, "half"):
                components[name] = module.to(torch_device).half()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
            # This attribute is not serialized in the config of the pipeline
            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16, resolution=32)
        for component in pipe_loaded.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe_loaded.to(torch_device)
        pipe_loaded.set_progress_bar_config(disable=None)

        for name, component in pipe_loaded.components.items():
            if hasattr(component, "dtype"):
                self.assertTrue(
                    component.dtype == torch.float16,
                    f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
                )

        inputs = self.get_dummy_inputs(torch_device)
        output_loaded = pipe_loaded(**inputs)[0]
        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
        self.assertLess(
            max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
        )

    @unittest.skip("Skipped due to missing layout_prompt. Needs further investigation.")
    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=0.0001, rtol=0.0001):
        pass

    @unittest.skip("Needs to be revisited later.")
    def test_pipeline_with_accelerator_device_map(self, expected_max_difference=0.0001):
        pass

class VisualClozePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    pipeline_class = VisualClozePipeline
    params = frozenset(
        [
            "task_prompt",
            "content_prompt",
            "upsampling_height",
            "upsampling_width",
            "guidance_scale",
            "prompt_embeds",
            "pooled_prompt_embeds",
            "upsampling_strength",
        ]
    )
    batch_params = frozenset(["task_prompt", "content_prompt", "image"])
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = FluxTransformer2DModel(
            patch_size=1,
            in_channels=12,
            out_channels=4,
            num_layers=1,
            num_single_layers=1,
            attention_head_dim=6,
            num_attention_heads=2,
            joint_attention_dim=32,
            pooled_projection_dim=32,
            axes_dims_rope=[2, 2, 2],
        )
        clip_text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )
        torch.manual_seed(0)
        text_encoder = CLIPTextModel(clip_text_encoder_config)
        torch.manual_seed(0)
        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
        text_encoder_2 = T5EncoderModel(config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )
        scheduler = FlowMatchEulerDiscreteScheduler()
        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "transformer": transformer,
            "vae": vae,
            "resolution": 32,
        }

    def get_dummy_inputs(self, device, seed=0):
        # Create example images to simulate the input format required by VisualCloze
        context_image = [
            Image.fromarray(floats_tensor((32, 32, 3), rng=random.Random(seed), scale=255).numpy().astype(np.uint8))
            for _ in range(2)
        ]
        query_image = [
            Image.fromarray(
                floats_tensor((32, 32, 3), rng=random.Random(seed + 1), scale=255).numpy().astype(np.uint8)
            ),
            None,
        ]
        # Create an image list that conforms to the VisualCloze input format
        image = [
            context_image,  # In-Context example
            query_image,  # Query image
        ]
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)
        inputs = {
            "task_prompt": "Each row outlines a logical process, starting from [IMAGE1] gray-based depth map with detailed object contours, to achieve [IMAGE2] an image with flawless clarity.",
            "content_prompt": "A beautiful landscape with mountains and a lake",
            "image": image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "upsampling_height": 32,
            "upsampling_width": 32,
            "max_sequence_length": 77,
            "output_type": "np",
            "upsampling_strength": 0.4,
        }
        return inputs

    def test_visualcloze_different_prompts(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)
        output_same_prompt = pipe(**inputs).images[0]
        inputs = self.get_dummy_inputs(torch_device)
        inputs["task_prompt"] = "A different task to perform."
        output_different_prompts = pipe(**inputs).images[0]
        max_diff = np.abs(output_same_prompt - output_different_prompts).max()
        # Outputs should be different
        assert max_diff > 1e-6

    def test_visualcloze_image_output_shape(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)
        height_width_pairs = [(32, 32), (72, 57)]
        for height, width in height_width_pairs:
            expected_height = height - height % (pipe.generation_pipe.vae_scale_factor * 2)
            expected_width = width - width % (pipe.generation_pipe.vae_scale_factor * 2)
            inputs.update({"upsampling_height": height, "upsampling_width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            assert (output_height, output_width) == (expected_height, expected_width)

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(expected_max_diff=1e-3)

    def test_upsampling_strength(self, expected_min_diff=1e-1):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)
        # Test different upsampling strengths
        inputs["upsampling_strength"] = 0.2
        output_no_upsampling = pipe(**inputs).images[0]
        inputs["upsampling_strength"] = 0.8
        output_full_upsampling = pipe(**inputs).images[0]
        # Different upsampling strengths should produce different outputs
        max_diff = np.abs(output_no_upsampling - output_full_upsampling).max()
        assert max_diff > expected_min_diff

    def test_different_task_prompts(self, expected_min_diff=1e-1):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)
        output_original = pipe(**inputs).images[0]
        inputs["task_prompt"] = "A different task description for image generation"
        output_different_task = pipe(**inputs).images[0]
        # Different task prompts should produce different outputs
        max_diff = np.abs(output_original - output_different_task).max()
        assert max_diff > expected_min_diff

    @unittest.skip(
        "Test not applicable because the pipeline being tested is a wrapper pipeline. CFG tests should be done on the inner pipelines."
    )
    def test_callback_cfg(self):
        pass

    def test_save_load_local(self, expected_max_difference=1e-3):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0]

        logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
        logger.setLevel(diffusers.logging.INFO)

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir, safe_serialization=False)
            with CaptureLogger(logger) as cap_logger:
                # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
                # This attribute is not serialized in the config of the pipeline
                pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32)
            for component in pipe_loaded.components.values():
                if hasattr(component, "set_default_attn_processor"):
                    component.set_default_attn_processor()
            for name in pipe_loaded.components.keys():
                if name not in pipe_loaded._optional_components:
                    assert name in str(cap_logger)

        pipe_loaded.to(torch_device)
        pipe_loaded.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_loaded = pipe_loaded(**inputs)[0]
        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
        self.assertLess(max_diff, expected_max_difference)

    def test_save_load_optional_components(self, expected_max_difference=1e-4):
        if not hasattr(self.pipeline_class, "_optional_components"):
            return
        components = self.get_dummy_components()
        for key in components:
            if "text_encoder" in key and hasattr(components[key], "eval"):
                components[key].eval()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # set all optional components to None
        for optional_component in pipe._optional_components:
            setattr(pipe, optional_component, None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        torch.manual_seed(0)
        output = pipe(**inputs)[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir, safe_serialization=False)
            # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
            # This attribute is not serialized in the config of the pipeline
            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32)
        for component in pipe_loaded.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe_loaded.to(torch_device)
        pipe_loaded.set_progress_bar_config(disable=None)

        for optional_component in pipe._optional_components:
            self.assertTrue(
                getattr(pipe_loaded, optional_component) is None,
                f"`{optional_component}` did not stay set to None after loading.",
            )

        inputs = self.get_dummy_inputs(generator_device)
        torch.manual_seed(0)
        output_loaded = pipe_loaded(**inputs)[0]
        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
        self.assertLess(max_diff, expected_max_difference)

    @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
    @require_accelerator
    def test_save_load_float16(self, expected_max_diff=1e-2):
        components = self.get_dummy_components()
        for name, module in components.items():
            if hasattr(module, "half"):
                components[name] = module.to(torch_device).half()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
            # This attribute is not serialized in the config of the pipeline
            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16, resolution=32)
        for component in pipe_loaded.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe_loaded.to(torch_device)
        pipe_loaded.set_progress_bar_config(disable=None)

        for name, component in pipe_loaded.components.items():
            if hasattr(component, "dtype"):
                self.assertTrue(
                    component.dtype == torch.float16,
                    f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
                )

        inputs = self.get_dummy_inputs(torch_device)
        output_loaded = pipe_loaded(**inputs)[0]
        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
        self.assertLess(
            max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
        )

    @unittest.skip("Test not supported.")
    def test_pipeline_with_accelerator_device_map(self):
        pass

Problem:
Only fast tests are present; there is no @slow coverage for the published checkpoints.

Impact:
Checkpoint-specific behavior is untested, including real 384/512 resolution handling, default generator behavior, multi-target tasks, and the two-stage combined pipeline.

Reproduction:

from pathlib import Path

text = "\n".join(p.read_text(encoding="utf-8") for p in Path("tests/pipelines/visualcloze").glob("test_*.py"))
print("@slow" in text)

Relevant precedent:
Flux has slow pipeline tests for real checkpoints.

@slow
@require_big_accelerator
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
    pipeline_class = FluxPipeline
    repo_id = "black-forest-labs/FLUX.1-dev"
    image_encoder_pretrained_model_name_or_path = "openai/clip-vit-large-patch14"
    weight_name = "ip_adapter.safetensors"
    ip_adapter_repo_id = "XLabs-AI/flux-ip-adapter"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

Suggested fix:
Add slow tests for VisualClozeGenerationPipeline and VisualClozePipeline using VisualCloze/VisualClozePipeline-384, covering generator=None, upsampling_strength=0, upsampling_strength>0, multi-target inputs, and save/load resolution preservation.
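
A hedged skeleton of what such a slow test could look like, modeled on the Flux precedent above. The repo id comes from the VisualCloze docs; the class name, input images, step count, and assertion are illustrative assumptions, and real tests would pin expected output slices against the checkpoint:

import gc
import unittest

import torch
from PIL import Image

from diffusers import VisualClozePipeline
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    require_big_accelerator,
    slow,
    torch_device,
)


@slow
@require_big_accelerator
class VisualClozePipelineSlowTests(unittest.TestCase):
    repo_id = "VisualCloze/VisualClozePipeline-384"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_default_generator_without_upsampling(self):
        pipe = VisualClozePipeline.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16).to(torch_device)
        context = [Image.new("RGB", (384, 384)) for _ in range(2)]  # in-context example row
        query = [Image.new("RGB", (384, 384)), None]  # query row; None marks the target slot
        output = pipe(
            task_prompt="Each row outlines a logical process, starting from [IMAGE1] gray-based depth map with detailed object contours, to achieve [IMAGE2] an image with flawless clarity.",
            content_prompt="A beautiful landscape with mountains and a lake",
            image=[context, query],
            num_inference_steps=2,
            upsampling_strength=0,  # exercises the Issue 6 return path
            generator=None,  # exercises the Issue 1 default
            output_type="np",
        )
        # A real test would compare a pinned slice; this only smoke-checks the shape.
        assert output.images[0].ndim == 3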

Issue 8: Docs link the 512 checkpoint to the 384 repo

Affected code:

VisualCloze is a two-stage cascade pipeline, containing `VisualClozeGenerationPipeline` and `VisualClozeUpsamplingPipeline`.
- In `VisualClozeGenerationPipeline`, each image is downsampled before concatenating images into a grid layout, avoiding excessively high resolutions. VisualCloze releases two models suitable for diffusers, i.e., [VisualClozePipeline-384](https://huggingface.co/VisualCloze/VisualClozePipeline-384) and [VisualClozePipeline-512](https://huggingface.co/VisualCloze/VisualClozePipeline-384), which downsample images to resolutions of 384 and 512, respectively.
- `VisualClozeUpsamplingPipeline` uses [SDEdit](https://huggingface.co/papers/2108.01073) to enable high-resolution image synthesis.

Problem:
The VisualClozePipeline-512 link points to VisualClozePipeline-384.

Impact:
Users trying to load the 512-resolution checkpoint are sent to the wrong model page.

Reproduction:

from pathlib import Path

for line in Path("docs/source/en/api/pipelines/visualcloze.md").read_text(encoding="utf-8").splitlines():
    if "VisualClozePipeline-512" in line:
        print(line)

Relevant precedent:
N/A.

Suggested fix:

[VisualClozePipeline-512](https://huggingface.co/VisualCloze/VisualClozePipeline-512)
