8 changes: 5 additions & 3 deletions src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -68,7 +68,7 @@ def get_qwen_prompt_embeds(
     split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
     split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
     attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-    max_seq_len = max([e.size(0) for e in split_hidden_states])
+    max_seq_len = tokenizer_max_length
     prompt_embeds = torch.stack(
         [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
     )
@@ -88,6 +88,7 @@ def get_qwen_prompt_embeds_edit(
     image: Optional[torch.Tensor] = None,
     prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
     prompt_template_encode_start_idx: int = 64,
+    tokenizer_max_length: int = 1024,
     device: Optional[torch.device] = None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -115,7 +116,7 @@ def get_qwen_prompt_embeds_edit(
     split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
     split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
     attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-    max_seq_len = max([e.size(0) for e in split_hidden_states])
+    max_seq_len = tokenizer_max_length
     prompt_embeds = torch.stack(
         [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
     )
@@ -136,6 +137,7 @@ def get_qwen_prompt_embeds_edit_plus(
     prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
     img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
     prompt_template_encode_start_idx: int = 64,
+    tokenizer_max_length: int = 1024,
     device: Optional[torch.device] = None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -171,7 +173,7 @@ def get_qwen_prompt_embeds_edit_plus(
     split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
     split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
     attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-    max_seq_len = max([e.size(0) for e in split_hidden_states])
+    max_seq_len = tokenizer_max_length
     prompt_embeds = torch.stack(
         [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
     )
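
Net effect of the encoders.py hunks: prompt embeddings are now padded out to the fixed tokenizer_max_length instead of to the longest sequence in the current batch, so the returned tensor has a static sequence dimension. Below is a minimal sketch of the before/after padding behavior, using toy tensors rather than the pipeline's real hidden states:

import torch

# Toy per-prompt hidden states of unequal lengths: (seq_len, hidden_dim)
split_hidden_states = [torch.randn(5, 8), torch.randn(9, 8)]
tokenizer_max_length = 16  # stand-in for the fixed padding target (1024 in the diff)

# Old behavior: pad to the batch maximum -> output width varies per call
old_max_seq_len = max(e.size(0) for e in split_hidden_states)  # 9
# New behavior: pad to the fixed maximum -> output width is always the same
max_seq_len = tokenizer_max_length  # 16

prompt_embeds = torch.stack(
    [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
print(prompt_embeds.shape)  # torch.Size([2, 16, 8]), independent of batch contents

A static sequence dimension is what shape-specializing compilers such as torch.compile need in order to avoid recompiling whenever the prompt length changes, which appears to be the motivation for the change.
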
15 changes: 10 additions & 5 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -190,6 +190,7 @@ def _get_qwen_prompt_embeds(
         prompt: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 1024,
     ):
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
@@ -200,7 +201,7 @@ def _get_qwen_prompt_embeds(
         drop_idx = self.prompt_template_encode_start_idx
         txt = [template.format(e) for e in prompt]
         txt_tokens = self.tokenizer(
-            txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+            txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
         ).to(device)
         encoder_hidden_states = self.text_encoder(
             input_ids=txt_tokens.input_ids,
@@ -211,7 +212,7 @@ def _get_qwen_prompt_embeds(
         split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
         attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
+        max_seq_len = max_sequence_length
         prompt_embeds = torch.stack(
             [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
         )
@@ -251,7 +252,9 @@ def encode_prompt(
         batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

         if prompt_embeds is None:
-            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+                prompt, device, max_sequence_length=max_sequence_length
+            )

         prompt_embeds = prompt_embeds[:, :max_sequence_length]
         prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
@@ -672,9 +675,11 @@ def __call__(
         if self.attention_kwargs is None:
             self._attention_kwargs = {}

-        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+        txt_seq_lens = [prompt_embeds.shape[1]] * prompt_embeds.shape[0] if prompt_embeds is not None else None
         negative_txt_seq_lens = (
-            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+            [negative_prompt_embeds.shape[1]] * negative_prompt_embeds.shape[0]
+            if negative_prompt_embeds is not None
+            else None
         )

         # 6. Denoising loop
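
The __call__ hunk above also changes how txt_seq_lens is derived: rather than recovering each sample's true token count from the attention mask, every sample now reports the full padded width, consistent with the fixed-shape embeddings. A toy comparison of the two computations (hypothetical tensors, not the pipeline's own):

import torch

# Padded embeddings (batch=2, padded_len=16, dim=8) and their attention mask
prompt_embeds = torch.randn(2, 16, 8)
prompt_embeds_mask = torch.zeros(2, 16, dtype=torch.long)
prompt_embeds_mask[0, :5] = 1  # first prompt really has 5 tokens
prompt_embeds_mask[1, :9] = 1  # second prompt really has 9 tokens

# Old: data-dependent lengths, different for every batch
old_txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()  # [5, 9]
# New: uniform lengths equal to the padded width
txt_seq_lens = [prompt_embeds.shape[1]] * prompt_embeds.shape[0]  # [16, 16]

Note that the transformer still receives encoder_hidden_states_mask, so padding tokens can still be distinguished from real ones even though txt_seq_lens no longer encodes that.
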
(Diff for an additional QwenImage pipeline file follows; the filename was not captured in this view.)
@@ -254,6 +254,7 @@ def _get_qwen_prompt_embeds(
         prompt: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 1024,
     ):
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
@@ -264,7 +265,7 @@ def _get_qwen_prompt_embeds(
         drop_idx = self.prompt_template_encode_start_idx
         txt = [template.format(e) for e in prompt]
         txt_tokens = self.tokenizer(
-            txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+            txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
         ).to(device)
         encoder_hidden_states = self.text_encoder(
             input_ids=txt_tokens.input_ids,
@@ -275,7 +276,7 @@ def _get_qwen_prompt_embeds(
         split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
         attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
+        max_seq_len = max_sequence_length
         prompt_embeds = torch.stack(
             [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
         )
@@ -316,7 +317,9 @@ def encode_prompt(
         batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

         if prompt_embeds is None:
-            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+                prompt, device, max_sequence_length=max_sequence_length
+            )

         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -909,7 +912,7 @@ def __call__(
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
-                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+                    txt_seq_lens=[prompt_embeds.shape[1]] * prompt_embeds.shape[0],
                    return_dict=False,
                )

@@ -920,7 +923,7 @@ def __call__(
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
-                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+                    txt_seq_lens=[prompt_embeds.shape[1]] * prompt_embeds.shape[0],
                    controlnet_block_samples=controlnet_block_samples,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
@@ -935,7 +938,7 @@ def __call__(
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
-                        txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
+                        txt_seq_lens=[negative_prompt_embeds.shape[1]] * negative_prompt_embeds.shape[0],
                        controlnet_block_samples=controlnet_block_samples,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
(Diff for an additional QwenImage pipeline file follows; the filename was not captured in this view.)
@@ -236,6 +236,7 @@ def _get_qwen_prompt_embeds(
         prompt: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 1024,
     ):
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
@@ -246,7 +247,7 @@ def _get_qwen_prompt_embeds(
         drop_idx = self.prompt_template_encode_start_idx
         txt = [template.format(e) for e in prompt]
         txt_tokens = self.tokenizer(
-            txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+            txt, max_length=max_sequence_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
         ).to(self.device)
         encoder_hidden_states = self.text_encoder(
             input_ids=txt_tokens.input_ids,
@@ -257,7 +258,7 @@ def _get_qwen_prompt_embeds(
         split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
         attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
+        max_seq_len = max_sequence_length
         prompt_embeds = torch.stack(
             [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
         )
@@ -297,7 +298,9 @@ def encode_prompt(
         batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

         if prompt_embeds is None:
-            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+                prompt, device, max_sequence_length=max_sequence_length
+            )

         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -852,7 +855,7 @@ def __call__(
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
-                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+                    txt_seq_lens=[prompt_embeds.shape[1]] * prompt_embeds.shape[0],
                    return_dict=False,
                )

@@ -863,7 +866,7 @@ def __call__(
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
-                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+                    txt_seq_lens=[prompt_embeds.shape[1]] * prompt_embeds.shape[0],
                    controlnet_block_samples=controlnet_block_samples,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
@@ -878,7 +881,7 @@ def __call__(
                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
                        encoder_hidden_states=negative_prompt_embeds,
                        img_shapes=img_shapes,
-                        txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
+                        txt_seq_lens=[negative_prompt_embeds.shape[1]] * negative_prompt_embeds.shape[0],
                        controlnet_block_samples=controlnet_block_samples,
                        attention_kwargs=self.attention_kwargs,
                        return_dict=False,
11 changes: 8 additions & 3 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
@@ -229,6 +229,7 @@ def _get_qwen_prompt_embeds(
         image: Optional[torch.Tensor] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 1024,
     ):
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
@@ -257,8 +258,10 @@ def _get_qwen_prompt_embeds(
         hidden_states = outputs.hidden_states[-1]
         split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+        # Truncate if longer than max_sequence_length
+        split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e for e in split_hidden_states]
         attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
+        max_seq_len = max_sequence_length
         prompt_embeds = torch.stack(
             [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
         )
@@ -301,7 +304,9 @@ def encode_prompt(
         batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

         if prompt_embeds is None:
-            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+                prompt, image, device, max_sequence_length=max_sequence_length
+            )

         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -793,7 +798,7 @@ def __call__(
         if self.attention_kwargs is None:
             self._attention_kwargs = {}

-        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+        txt_seq_lens = [prompt_embeds.shape[1]] * prompt_embeds.shape[0] if prompt_embeds is not None else None
         negative_txt_seq_lens = (
             negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
         )
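
In the edit pipeline, the sequence left after dropping the first drop_idx tokens can exceed max_sequence_length (the vision tokens inflate it), so the new code truncates before padding; otherwise u.new_zeros(max_seq_len - u.size(0), ...) would be asked for a negative number of rows. A self-contained sketch of the truncate-then-pad step, with toy tensors and names mirroring the diff:

import torch

max_sequence_length = 8
# Toy hidden states: one short prompt, one exceeding the cap
split_hidden_states = [torch.randn(5, 4), torch.randn(11, 4)]

# Truncate anything longer than the cap (a plain slice is a no-op for short
# sequences, so it is equivalent to the diff's explicit size check)
split_hidden_states = [e[:max_sequence_length] for e in split_hidden_states]
# Masks mark only the kept, real tokens
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long) for e in split_hidden_states]
# Pad everything to the same fixed width
prompt_embeds = torch.stack(
    [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in split_hidden_states]
)
print(prompt_embeds.shape)  # torch.Size([2, 8, 4])
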
(Diff for an additional QwenImage pipeline file follows; the filename was not captured in this view.)
@@ -240,6 +240,7 @@ def _get_qwen_prompt_embeds(
         image: Optional[torch.Tensor] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 1024,
     ):
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
@@ -268,8 +269,10 @@ def _get_qwen_prompt_embeds(
         hidden_states = outputs.hidden_states[-1]
         split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+        # Truncate if longer than max_sequence_length
+        split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e for e in split_hidden_states]
         attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
+        max_seq_len = max_sequence_length
         prompt_embeds = torch.stack(
             [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
         )
@@ -313,7 +316,9 @@ def encode_prompt(
         batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

         if prompt_embeds is None:
-            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+                prompt, image, device, max_sequence_length=max_sequence_length
+            )

         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1008,7 +1013,7 @@ def __call__(
         if self.attention_kwargs is None:
             self._attention_kwargs = {}

-        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+        txt_seq_lens = [prompt_embeds.shape[1]] * prompt_embeds.shape[0] if prompt_embeds is not None else None
         negative_txt_seq_lens = (
             negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
         )
(Diff for an additional QwenImage pipeline file follows; the filename was not captured in this view.)
@@ -232,6 +232,7 @@ def _get_qwen_prompt_embeds(
         image: Optional[torch.Tensor] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 1024,
     ):
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
@@ -270,8 +271,10 @@ def _get_qwen_prompt_embeds(
         hidden_states = outputs.hidden_states[-1]
         split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+        # Truncate if longer than max_sequence_length
+        split_hidden_states = [e[:max_sequence_length] if e.size(0) > max_sequence_length else e for e in split_hidden_states]
         attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
+        max_seq_len = max_sequence_length
         prompt_embeds = torch.stack(
             [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
         )
@@ -315,7 +318,9 @@ def encode_prompt(
         batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

         if prompt_embeds is None:
-            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+                prompt, image, device, max_sequence_length=max_sequence_length
+            )

         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -777,7 +782,7 @@ def __call__(
         if self.attention_kwargs is None:
             self._attention_kwargs = {}

-        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+        txt_seq_lens = [prompt_embeds.shape[1]] * prompt_embeds.shape[0] if prompt_embeds is not None else None
         negative_txt_seq_lens = (
             negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
         )