diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 3b56981e5290..925cfbb23594 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -68,7 +68,7 @@ def get_qwen_prompt_embeds( split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = tokenizer_max_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -88,6 +88,7 @@ def get_qwen_prompt_embeds_edit( image: Optional[torch.Tensor] = None, prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n", prompt_template_encode_start_idx: int = 64, + tokenizer_max_length: int = 1024, device: Optional[torch.device] = None, ): prompt = [prompt] if isinstance(prompt, str) else prompt @@ -115,7 +116,7 @@ def get_qwen_prompt_embeds_edit( split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = tokenizer_max_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -136,6 +137,7 @@ def get_qwen_prompt_embeds_edit_plus( prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>", prompt_template_encode_start_idx: int = 64, + tokenizer_max_length: int = 1024, device: Optional[torch.device] = None, ): prompt = [prompt] if isinstance(prompt, str) else prompt @@ -171,7 +173,7 @@ def get_qwen_prompt_embeds_edit_plus( split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = tokenizer_max_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 33dc2039b986..54b09ca6aed7 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -190,6 +190,7 @@ def _get_qwen_prompt_embeds( prompt: Union[str, List[str]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -200,7 +201,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt" ).to(device) encoder_hidden_states = 
self.text_encoder( input_ids=txt_tokens.input_ids, @@ -211,7 +212,7 @@ def _get_qwen_prompt_embeds( split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -251,7 +252,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] @@ -672,9 +675,11 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + txt_seq_lens = [prompt_embeds.shape[1]] * prompt_embeds.shape[0] if prompt_embeds is not None else None negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + [negative_prompt_embeds.shape[1]] * negative_prompt_embeds.shape[0] + if negative_prompt_embeds is not None + else None ) # 6. 
Denoising loop diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index 5111096d93c1..00c50d45f70f 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -254,6 +254,7 @@ def _get_qwen_prompt_embeds( prompt: Union[str, List[str]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -264,7 +265,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt" ).to(device) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, @@ -275,7 +276,7 @@ def _get_qwen_prompt_embeds( split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -316,7 +317,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) _, 
seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -909,7 +912,7 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), + txt_seq_lens=[prompt_embeds.shape[1]] * prompt_embeds.shape[0], return_dict=False, ) @@ -920,7 +923,7 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), + txt_seq_lens=[prompt_embeds.shape[1]] * prompt_embeds.shape[0], controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -935,7 +938,7 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), + txt_seq_lens=[negative_prompt_embeds.shape[1]] * negative_prompt_embeds.shape[0], controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 102a813ab582..ed40874fae47 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -236,6 +236,7 @@ def _get_qwen_prompt_embeds( prompt: Union[str, List[str]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -246,7 +247,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, 
max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt" ).to(self.device) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, @@ -257,7 +258,7 @@ def _get_qwen_prompt_embeds( split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -297,7 +298,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -852,7 +855,7 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), + txt_seq_lens=[prompt_embeds.shape[1]] * prompt_embeds.shape[0], return_dict=False, ) @@ -863,7 +866,7 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), + txt_seq_lens=[prompt_embeds.shape[1]] * prompt_embeds.shape[0], controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -878,7 +881,7 @@ def __call__( 
encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), + txt_seq_lens=[negative_prompt_embeds.shape[1]] * negative_prompt_embeds.shape[0], controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index ed37b238c8c9..df9d64e65ec4 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -229,6 +229,7 @@ def _get_qwen_prompt_embeds( image: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -257,8 +258,10 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + # Truncate if longer than max_sequence_length + split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -301,7 +304,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + prompt_embeds, prompt_embeds_mask = 
self._get_qwen_prompt_embeds( + prompt, image, device, max_sequence_length=max_sequence_length + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -793,7 +798,7 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + txt_seq_lens = [prompt_embeds.shape[1]] * prompt_embeds.shape[0] if prompt_embeds is not None else None negative_txt_seq_lens = ( negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index d54d1881fa4e..9b060c365195 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -240,6 +240,7 @@ def _get_qwen_prompt_embeds( image: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -268,8 +269,10 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + # Truncate if longer than max_sequence_length + split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in 
split_hidden_states] ) @@ -313,7 +316,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, image, device, max_sequence_length=max_sequence_length + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -1008,7 +1013,7 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + txt_seq_lens = [prompt_embeds.shape[1]] * prompt_embeds.shape[0] if prompt_embeds is not None else None negative_txt_seq_lens = ( negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index ec203edf166c..0f789270da08 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -232,6 +232,7 @@ def _get_qwen_prompt_embeds( image: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -270,8 +271,10 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + # Truncate if longer than max_sequence_length + split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e for e in split_hidden_states] attn_mask_list = 
[torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -315,7 +318,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, image, device, max_sequence_length=max_sequence_length + ) _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -777,7 +782,7 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + txt_seq_lens = [prompt_embeds.shape[1]] * prompt_embeds.shape[0] if prompt_embeds is not None else None negative_txt_seq_lens = ( negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index cb4c5d8016bb..a2fea8fbb966 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -197,6 +197,7 @@ def _get_qwen_prompt_embeds( prompt: Union[str, List[str]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -207,7 +208,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, 
max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt" ).to(device) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, @@ -218,7 +219,7 @@ def _get_qwen_prompt_embeds( split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -294,7 +295,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] @@ -775,7 +778,7 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + txt_seq_lens = [prompt_embeds.shape[1]] * prompt_embeds.shape[0] if prompt_embeds is not None else None negative_txt_seq_lens = ( negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 1915c27eb2bb..9169b4551b02 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ 
b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -207,6 +207,7 @@ def _get_qwen_prompt_embeds( prompt: Union[str, List[str]] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -217,7 +218,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( - txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt" ).to(device) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, @@ -228,7 +229,7 @@ def _get_qwen_prompt_embeds( split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) + max_seq_len = max_sequence_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) @@ -305,7 +306,9 @@ def encode_prompt( batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, device, max_sequence_length=max_sequence_length + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] @@ -944,7 +947,7 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if 
prompt_embeds_mask is not None else None + txt_seq_lens = [prompt_embeds.shape[1]] * prompt_embeds.shape[0] if prompt_embeds is not None else None negative_txt_seq_lens = ( negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) diff --git a/tests/pipelines/qwenimage/test_qwenimage.py b/tests/pipelines/qwenimage/test_qwenimage.py index 8ebfe7d08bc1..bbc95072f5be 100644 --- a/tests/pipelines/qwenimage/test_qwenimage.py +++ b/tests/pipelines/qwenimage/test_qwenimage.py @@ -160,7 +160,7 @@ def test_inference(self): self.assertEqual(generated_image.shape, (3, 32, 32)) # fmt: off - expected_slice = torch.tensor([0.56331, 0.63677, 0.6015, 0.56369, 0.58166, 0.55277, 0.57176, 0.63261, 0.41466, 0.35561, 0.56229, 0.48334, 0.49714, 0.52622, 0.40872, 0.50208]) + expected_slice = torch.tensor([0.5633, 0.6416, 0.6035, 0.5617, 0.5813, 0.5502, 0.5718, 0.6345, 0.4164, 0.3563, 0.5630, 0.4849, 0.4979, 0.5269, 0.4096, 0.5020]) # fmt: on generated_slice = generated_image.flatten() @@ -234,3 +234,61 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2): expected_diff_max, "VAE tiling should not affect the inference results", ) + + def test_prompt_embeds_padding(self): + """Test that prompt embeddings are padded to tokenizer_max_length (1024) instead of batch max.""" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + + # Test 1: Short prompt should be padded to 1024, not to its actual length + short_prompt = "test" + prompt_embeds, prompt_embeds_mask = pipe.encode_prompt( + prompt=short_prompt, + device=torch_device, + num_images_per_prompt=1, + max_sequence_length=1024, + ) + + # Should be padded to 1024 (tokenizer_max_length), not to the actual token count + self.assertEqual( + prompt_embeds.shape[1], + 1024, + f"Short prompt should be padded to 1024, got {prompt_embeds.shape[1]}", + ) + self.assertEqual( + prompt_embeds_mask.shape[1], + 1024, + f"Mask should be 1024 
length, got {prompt_embeds_mask.shape[1]}", + ) + + # Test 2: Batch with different lengths should all be padded to same length (1024) + batch_prompts = ["short", "a much longer prompt here"] + prompt_embeds_batch, mask_batch = pipe.encode_prompt( + prompt=batch_prompts, + device=torch_device, + num_images_per_prompt=1, + max_sequence_length=1024, + ) + + self.assertEqual(prompt_embeds_batch.shape[0], 2, "Batch size should be 2") + self.assertEqual( + prompt_embeds_batch.shape[1], + 1024, + f"All prompts in batch should be padded to 1024, got {prompt_embeds_batch.shape[1]}", + ) + + # Test 3: With max_sequence_length=512, _get_qwen_prompt_embeds pads directly to 512 + # (the subsequent slice in encode_prompt is then a no-op) + prompt_embeds_512, mask_512 = pipe.encode_prompt( + prompt=short_prompt, + device=torch_device, + num_images_per_prompt=1, + max_sequence_length=512, + ) + + self.assertEqual( + prompt_embeds_512.shape[1], + 512, + f"With max_sequence_length=512, should truncate to 512, got {prompt_embeds_512.shape[1]}", + )