diff --git a/.gitignore b/.gitignore index 1774f101..99230b14 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ dist/ *.egg-info/ .DS_Store/ .pytest_cache/ -.ruff_cache/ \ No newline at end of file +.ruff_cache/ +CLAUDE.md \ No newline at end of file diff --git a/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json b/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json index 9f8e4ff0..cccbcb6b 100644 --- a/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json +++ b/diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json @@ -21,5 +21,6 @@ "vision_start_token_id": 151652, "vision_end_token_id": 151653, "image_token_id": 151655, - "video_token_id": 151656 + "video_token_id": 151656, + "attn_impl": "sdpa" } \ No newline at end of file diff --git a/diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json b/diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json new file mode 100644 index 00000000..40adabb9 --- /dev/null +++ b/diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json @@ -0,0 +1,29 @@ +{ + "do_convert_rgb": true, + "do_normalize": true, + "do_rescale": true, + "do_resize": true, + "image_mean": [ + 0.48145466, + 0.4578275, + 0.40821073 + ], + "image_processor_type": "Qwen2VLImageProcessor", + "image_std": [ + 0.26862954, + 0.26130258, + 0.27577711 + ], + "max_pixels": 12845056, + "merge_size": 2, + "min_pixels": 3136, + "patch_size": 14, + "processor_class": "Qwen2_5_VLProcessor", + "resample": 3, + "rescale_factor": 0.00392156862745098, + "size": { + "longest_edge": 12845056, + "shortest_edge": 3136 + }, + "temporal_patch_size": 2 +} \ No newline at end of file diff --git a/diffsynth_engine/models/basic/attention.py b/diffsynth_engine/models/basic/attention.py index 25ddc72c..a18ce6e2 100644 --- a/diffsynth_engine/models/basic/attention.py +++ b/diffsynth_engine/models/basic/attention.py @@ -1,9 +1,9 @@ import torch import torch.nn as nn +import torch.nn.functional as F from einops import rearrange, repeat from typing import Optional -import torch.nn.functional as F from diffsynth_engine.utils import logging from diffsynth_engine.utils.flag import ( FLASH_ATTN_3_AVAILABLE, @@ -42,11 +42,11 @@ def xformers_attn(q, k, v, attn_mask=None, scale=None): if SDPA_AVAILABLE: - def sdpa_attn(q, k, v, attn_mask=None, scale=None): + def sdpa_attn(q, k, v, attn_mask=None, is_causal=False, scale=None): q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale) + out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal, scale=scale) return out.transpose(1, 2) diff --git a/diffsynth_engine/models/qwen_image/qwen2_5_vl.py b/diffsynth_engine/models/qwen_image/qwen2_5_vl.py index d1e12562..15a830ff 100644 --- a/diffsynth_engine/models/qwen_image/qwen2_5_vl.py +++ b/diffsynth_engine/models/qwen_image/qwen2_5_vl.py @@ -7,7 +7,7 @@ from diffsynth_engine.models.base import PreTrainedModel from diffsynth_engine.models.basic.transformer_helper import RMSNorm -from diffsynth_engine.models.basic.attention import attention +from diffsynth_engine.models.basic import attention as attention_ops from diffsynth_engine.models.utils import no_init_weights from diffsynth_engine.utils.cache import Cache, DynamicCache from diffsynth_engine.utils import logging @@ -152,17 +152,15 @@ def __init__( self, dim: int = 80, theta: float = 10000.0, - device: str = "cuda:0", - dtype: torch.dtype = 
torch.bfloat16, ): super().__init__() - with torch.device(device): - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) + with torch.device("cpu"): + self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) + def forward(self, seqlen: int, device: str) -> torch.Tensor: + inv_freq = self.inv_freq.to(device=device) + seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.outer(seq, inv_freq) return freqs @@ -222,7 +220,7 @@ def forward( q = rearrange(q, "s n d -> 1 s n d") k = rearrange(k, "s n d -> 1 s n d") v = rearrange(v, "s n d -> 1 s n d") - out = attention(q, k, v, attn_impl=self.attn_impl, attn_mask=attention_mask) + out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl, attn_mask=attention_mask) out = rearrange(out, "1 s n d -> s (n d)") out = self.proj(out) return out @@ -301,7 +299,7 @@ def __init__(self, config: Qwen2_5_VLVisionConfig, device: str = "cuda:0", dtype dtype=dtype, ) head_dim = config.hidden_size // config.num_heads - self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2, device=device, dtype=dtype) + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList( [ Qwen2_5_VisionBlock( @@ -348,7 +346,7 @@ def rot_pos_emb(self, grid_thw): pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device=grid_thw.device) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb @@ -488,7 +486,6 @@ def __init__( hidden_size: int = 3584, num_attention_heads: int = 28, num_key_value_heads: int = 4, - # dropout: float = 0.0, mrope_section: List[int] = [16, 24, 24], attn_impl: Optional[str] = None, device: str = "cuda:0", @@ -501,7 +498,6 @@ def __init__( self.head_dim = hidden_size // num_attention_heads self.num_key_value_heads = num_key_value_heads self.num_key_value_groups = num_attention_heads // num_key_value_heads - # self.dropout = dropout self.mrope_section = mrope_section self.attn_impl = attn_impl @@ -521,8 +517,6 @@ def __init__( self.num_attention_heads * self.head_dim, self.hidden_size, bias=False, device=device, dtype=dtype ) - self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=self.head_dim, device=device, dtype=dtype) - def forward( self, hidden_states: torch.Tensor, @@ -556,14 +550,18 @@ def forward( if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[1]] - # TODO: attention_mask for flash attention 2 - out = attention( - query_states, - key_states, - value_states, - attn_impl=self.attn_impl, - attn_mask=causal_mask, - ) + # TODO: use is_causal when attention mask is causal + if self.attn_impl == "sdpa": + out = attention_ops.sdpa_attn(query_states, key_states, value_states, is_causal=True) + else: + # TODO: attention_mask for flash attention 2 + out = attention_ops.attention( + query_states, + key_states, + value_states, + attn_impl=self.attn_impl, + attn_mask=causal_mask, + ) out = rearrange(out, "b s n d -> b s (n d)") out = self.o_proj(out) return out, past_key_values @@ -647,29 +645,29 @@ def 
forward( class Qwen2_5_VLRotaryEmbedding(nn.Module): - def __init__(self, dim: int = 128, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16): + def __init__(self, dim: int = 128): super().__init__() - with torch.device(device): - inv_freq = self.compute_rope(dim) # default rope without dynamic frequency - self.register_buffer("inv_freq", inv_freq, persistent=False) + with torch.device("cpu"): + self.inv_freq = self.compute_rope(dim) # default rope without dynamic frequency def compute_rope(self, dim: int, theta: float = 1000000.0): inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) return inv_freq @torch.no_grad() - def forward(self, x, position_ids): + def forward(self, position_ids: torch.LongTensor, device: str, dtype: torch.dtype): # In contrast to other models, Qwen2_5_VL has different position ids for the grids # So we expand the inv_freq to shape (3, ...) - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + inv_freq = self.inv_freq.to(device=device) + inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() - return cos.to(device=x.device, dtype=x.dtype), sin.to(device=x.device, dtype=x.dtype) + return cos.to(device=device, dtype=dtype), sin.to(device=device, dtype=dtype) class Qwen2_5_VLModel(nn.Module): @@ -702,7 +700,7 @@ def __init__(self, config: Qwen2_5_VLConfig, device: str = "cuda:0", dtype: torc ) self.norm = Qwen2_5_RMSNorm(config.hidden_size, config.rms_norm_eps, device=device, dtype=dtype) head_dim = config.hidden_size // config.num_attention_heads - self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=head_dim, device=device, dtype=dtype) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=head_dim) def get_input_embeddings(self): return self.embed_tokens @@ -749,7 +747,7 @@ def forward( hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = self.rotary_emb(position_ids, device=hidden_states.device, dtype=hidden_states.dtype) # decoder layers for decoder_layer in self.layers: @@ -940,8 +938,7 @@ def from_state_dict( with torch.device("meta"), no_init_weights(): model = cls(vision_config=vision_config, config=config, device=device, dtype=dtype) model.load_state_dict(state_dict, assign=True) - for param in model.parameters(): # skip buffers - param.data = param.data.to(device=device, dtype=dtype, non_blocking=True) + model.to(device=device, dtype=dtype, non_blocking=True) return model def get_input_embeddings(self): @@ -1202,27 +1199,14 @@ def forward( if position_ids is None: assert attention_mask is None or attention_mask.ndim == 2, "attention mask must be 2D" # calculate RoPE index once per generation in the pre-fill stage only - if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: - position_ids, rope_deltas = self.get_rope_index( - input_ids, - image_grid_thw, - video_grid_thw, - second_per_grid_ts, - attention_mask, - ) - self.rope_deltas = rope_deltas - # then use the prev pre-calculated rope-deltas to get the correct position ids - else: - batch_size, seq_length, _ = 
inputs_embeds.shape - delta = ( - (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0 - ) - position_ids = torch.arange(seq_length, device=inputs_embeds.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas hidden_states, present_key_values = self.model( input_ids=None, diff --git a/diffsynth_engine/models/qwen_image/qwen_image_dit.py b/diffsynth_engine/models/qwen_image/qwen_image_dit.py index 2c0d2df4..0c1dd2cd 100644 --- a/diffsynth_engine/models/qwen_image/qwen_image_dit.py +++ b/diffsynth_engine/models/qwen_image/qwen_image_dit.py @@ -81,41 +81,47 @@ def rope_params(self, index, dim, theta=10000): def forward(self, video_fhw, txt_length, device): """ - Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video - Args: txt_length: an integer representing the length of text + Args: + video_fhw (List[Tuple[int, int, int]]): A list of (frame, height, width) tuples for each video/image + txt_length (int): The maximum length of the text sequences """ if self.pos_freqs.device != device: self.pos_freqs = self.pos_freqs.to(device) self.neg_freqs = self.neg_freqs.to(device) - frame, height, width = video_fhw - rope_key = f"{frame}_{height}_{width}" - - if rope_key not in self.rope_cache: - seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + rope_key = f"{idx}_{height}_{width}" + + if rope_key not in self.rope_cache: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + self.rope_cache[rope_key] = freqs.clone().contiguous() + vid_freqs.append(self.rope_cache[rope_key]) if self.scale_rope: - freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) - freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], 
freqs_pos[2][: width // 2]], dim=0) - freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) - + max_vid_index = max(height // 2, width // 2, max_vid_index) else: - freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) - - freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) - self.rope_cache[rope_key] = freqs.clone().contiguous() - vid_freqs = self.rope_cache[rope_key] - - if self.scale_rope: - max_vid_index = max(height // 2, width // 2) - else: - max_vid_index = max(height, width) + max_vid_index = max(height, width, max_vid_index) txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + txt_length, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @@ -364,6 +370,7 @@ def unpatchify(self, hidden_states, height, width): def forward( self, image: torch.Tensor, + edit: torch.Tensor = None, text: torch.Tensor = None, timestep: torch.LongTensor = None, txt_seq_lens: torch.LongTensor = None, @@ -377,6 +384,7 @@ def forward( cfg_parallel( ( image, + edit, text, timestep, txt_seq_lens, @@ -385,11 +393,18 @@ def forward( ), ): conditioning = self.time_text_embed(timestep, image.dtype) - video_fhw = (1, h // 2, w // 2) # frame, height, width + video_fhw = [(1, h // 2, w // 2)] # frame, height, width max_length = txt_seq_lens.max().item() + image = self.patchify(image) + image_seq_len = image.shape[1] + if edit is not None: + edit = edit.to(dtype=image.dtype) + edit = self.patchify(edit) + image = torch.cat([image, edit], dim=1) + video_fhw += video_fhw + image_rotary_emb = self.pos_embed(video_fhw, max_length, image.device) - image = self.patchify(image) image = self.img_in(image) text = self.txt_in(self.txt_norm(text[:, :max_length])) @@ -397,6 +412,8 @@ def forward( text, image = block(image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb) image = self.norm_out(image, conditioning) image = self.proj_out(image) + if edit is not None: + image = image[:, :image_seq_len] image = self.unpatchify(image, h, w) diff --git a/diffsynth_engine/pipelines/base.py b/diffsynth_engine/pipelines/base.py index 43d16f31..37fbccbf 100644 --- a/diffsynth_engine/pipelines/base.py +++ b/diffsynth_engine/pipelines/base.py @@ -164,7 +164,7 @@ def vae_output_to_image(vae_output: torch.Tensor) -> Image.Image | List[Image.Im @staticmethod def generate_noise(shape, seed=None, device="cpu", dtype=torch.float16): generator = None if seed is None else torch.Generator(device).manual_seed(seed) - noise = torch.randn(shape, generator=generator, device=device).to(dtype) + noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) return noise def encode_image( diff --git a/diffsynth_engine/pipelines/qwen_image.py b/diffsynth_engine/pipelines/qwen_image.py index 9eb1762a..d2217c8f 100644 --- a/diffsynth_engine/pipelines/qwen_image.py +++ b/diffsynth_engine/pipelines/qwen_image.py @@ -1,10 +1,11 @@ import json import torch +import torch.distributed as dist import math from typing import Callable, List, Tuple, Optional, Union, Dict from tqdm import tqdm from einops import rearrange -import torch.distributed as dist +from PIL import Image from diffsynth_engine.configs import QwenImagePipelineConfig, QwenImageStateDicts from diffsynth_engine.models.basic.lora import LoRAContext @@ -16,13 +17,14 @@ Qwen2_5_VLConfig, ) from diffsynth_engine.models.qwen_image import QwenImageVAE 
-from diffsynth_engine.tokenizers import Qwen2TokenizerFast +from diffsynth_engine.tokenizers import Qwen2TokenizerFast, Qwen2VLProcessor from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter from diffsynth_engine.pipelines.utils import calculate_shift from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler from diffsynth_engine.utils.constants import ( QWEN_IMAGE_TOKENIZER_CONF_PATH, + QWEN_IMAGE_PROCESSOR_CONFIG_FILE, QWEN_IMAGE_CONFIG_FILE, QWEN_IMAGE_VISION_CONFIG_FILE, QWEN_IMAGE_VAE_CONFIG_FILE, @@ -44,20 +46,23 @@ def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, lora_a_suffix = None if "lora_A.default.weight" in key: lora_a_suffix = "lora_A.default.weight" + lora_b_suffix = "lora_B.default.weight" elif "lora_A.weight" in key: lora_a_suffix = "lora_A.weight" + lora_b_suffix = "lora_B.weight" + elif "lora_down.weight" in key: + lora_a_suffix = "lora_down.weight" + lora_b_suffix = "lora_up.weight" if lora_a_suffix is None: continue lora_args = {} lora_args["down"] = param - - lora_b_suffix = lora_a_suffix.replace("lora_A", "lora_B") lora_args["up"] = lora_state_dict[origin_key.replace(lora_a_suffix, lora_b_suffix)] lora_args["rank"] = lora_args["up"].shape[1] - alpha_key = origin_key.replace("lora_up", "lora_A").replace(lora_a_suffix, "alpha") + alpha_key = origin_key.replace(lora_a_suffix, "alpha") if alpha_key in lora_state_dict: alpha = lora_state_dict[alpha_key] @@ -83,6 +88,7 @@ def __init__( self, config: QwenImagePipelineConfig, tokenizer: Qwen2TokenizerFast, + processor: Qwen2VLProcessor, encoder: Qwen2_5_VLForConditionalGeneration, dit: QwenImageDiT, vae: QwenImageVAE, @@ -97,11 +103,15 @@ def __init__( self.config = config self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" self.prompt_template_encode_start_idx = 34 + + self.edit_prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.edit_prompt_template_encode_start_idx = 64 # sampler self.noise_scheduler = RecifitedFlowScheduler(shift=3.0, use_dynamic_shifting=True) self.sampler = FlowMatchEulerSampler() # models self.tokenizer = tokenizer + self.processor = processor self.encoder = encoder self.dit = dit self.vae = vae @@ -155,6 +165,10 @@ def from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipe init_device = "cpu" if config.parallelism > 1 or config.offload_mode is not None else config.device tokenizer = Qwen2TokenizerFast.from_pretrained(QWEN_IMAGE_TOKENIZER_CONF_PATH) + processor = Qwen2VLProcessor.from_pretrained( + tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH, + image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE, + ) with open(QWEN_IMAGE_VISION_CONFIG_FILE, "r") as f: vision_config = Qwen2_5_VLVisionConfig(**json.load(f)) with open(QWEN_IMAGE_CONFIG_FILE, "r") as f: @@ -201,6 +215,7 @@ def from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipe pipe = cls( config=config, tokenizer=tokenizer, + processor=processor, encoder=encoder, dit=dit, vae=vae, @@ -209,7 +224,7 @@ def from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipe if config.offload_mode is not None: pipe.enable_cpu_offload(config.offload_mode, config.offload_to_disk) - + if config.model_dtype == torch.float8_e4m3fn: pipe.dtype = torch.bfloat16 # compute dtype pipe.enable_fp8_autocast( @@ -302,9 +317,51 @@ def encode_prompt( return prompt_embeds, prompt_embeds_mask + def encode_prompt_with_image( + self, + prompt: Union[str, List[str]], + image: torch.Tensor, + num_images_per_prompt: int = 1, + max_sequence_length: int = 1024, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + batch_size = len(prompt) + template = self.edit_prompt_template_encode + drop_idx = self.edit_prompt_template_encode_start_idx + texts = [template.format(txt) for txt in prompt] + + model_inputs = self.processor(text=texts, images=image, max_length=max_sequence_length + drop_idx) + input_ids, attention_mask, pixel_values, image_grid_thw = ( + model_inputs["input_ids"].to(self.device), + model_inputs["attention_mask"].to(self.device), + model_inputs["pixel_values"].to(self.device), + model_inputs["image_grid_thw"].to(self.device), + ) + outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + hidden_states = outputs["hidden_states"] + prompt_embeds = hidden_states[:, drop_idx:] + prompt_embeds_mask = attention_mask[:, drop_idx:] + seq_len = prompt_embeds.shape[1] + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + def predict_noise_with_cfg( self, latents: torch.Tensor, + image_latents: torch.Tensor, timestep: torch.Tensor, prompt_emb: torch.Tensor, negative_prompt_emb: torch.Tensor, @@ -316,6 +373,7 @@ def 
predict_noise_with_cfg( if cfg_scale <= 1.0 or negative_prompt_emb is None: return self.predict_noise( latents, + image_latents, timestep, prompt_emb, prompt_embeds_mask, @@ -325,12 +383,14 @@ def predict_noise_with_cfg( h, w = latents.shape[-2:] positive_noise_pred = self.predict_noise( latents, + image_latents, timestep, prompt_emb, prompt_embeds_mask, ) negative_noise_pred = self.predict_noise( latents, + image_latents, timestep, negative_prompt_emb, negative_prompt_embeds_mask, @@ -346,9 +406,11 @@ def predict_noise_with_cfg( prompt_emb = torch.cat([prompt_emb, negative_prompt_emb], dim=0) prompt_embeds_mask = torch.cat([prompt_embeds_mask, negative_prompt_embeds_mask], dim=0) latents = torch.cat([latents, latents], dim=0) + image_latents = torch.cat([image_latents, image_latents], dim=0) timestep = torch.cat([timestep, timestep], dim=0) noise_pred = self.predict_noise( latents, + image_latents, timestep, prompt_emb, prompt_embeds_mask, @@ -363,25 +425,49 @@ def predict_noise_with_cfg( def predict_noise( self, latents: torch.Tensor, + image_latents: torch.Tensor, timestep: torch.Tensor, prompt_emb: torch.Tensor, prompt_embeds_mask: torch.Tensor, ): self.load_models_to_device(["dit"]) - noise_pred = self.dit( image=latents, + edit=image_latents, text=prompt_emb, timestep=timestep, txt_seq_lens=prompt_embeds_mask.sum(dim=1), ) return noise_pred + def prepare_image_latents(self, input_image: Image.Image): + image = self.preprocess_image(input_image).to( + device=self.device, dtype=self.vae.model.encoder.conv1.weight.dtype + ) + image = image.unsqueeze(2) + image_latents = self.vae.encode( + image, + device=self.device, + tiled=self.vae_tiled, + tile_size=self.vae_tile_size, + tile_stride=self.vae_tile_stride, + ) + image_latents = image_latents.squeeze(2) + return image_latents + + def calculate_dimensions(self, target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + width = round(width / 32) * 32 + height = round(height / 32) * 32 + return width, height + @torch.no_grad() def __call__( self, prompt: str, negative_prompt: str = "", + input_image: Image.Image | None = None, # use for img2img cfg_scale: float = 4.0, # true cfg height: int = 1328, width: int = 1328, @@ -389,29 +475,51 @@ def __call__( seed: int | None = None, progress_callback: Optional[Callable] = None, # def progress_callback(current, total, status) ): + if input_image is not None: + width, height = input_image.size + width, height = self.calculate_dimensions(1024 * 1024, width / height) + input_image = input_image.resize((width, height), Image.LANCZOS) + + self.validate_image_size(height, width, minimum=64, multiple_of=16) + noise = self.generate_noise((1, 16, height // 8, width // 8), seed=seed, device="cpu", dtype=self.dtype).to( device=self.device ) # dynamic shift image_seq_len = math.ceil(height // 16) * math.ceil(width // 16) mu = calculate_shift(image_seq_len, max_shift=0.9, max_seq_len=8192) + if input_image: + image_latents = self.prepare_image_latents(input_image) + else: + image_latents = None init_latents, latents, sigmas, timesteps = self.prepare_latents(noise, num_inference_steps, mu) # Initialize sampler self.sampler.initialize(init_latents=init_latents, timesteps=timesteps, sigmas=sigmas) self.load_models_to_device(["encoder"]) - prompt_embeds, prompt_embeds_mask = self.encode_prompt(prompt, 1, 4096) - if cfg_scale > 1.0 and negative_prompt != "": - negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(negative_prompt, 1, 4096) + if image_latents is not 
None: + prompt_embeds, prompt_embeds_mask = self.encode_prompt_with_image(prompt, input_image, 1, 4096) + if cfg_scale > 1.0 and negative_prompt != "": + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt_with_image( + negative_prompt, input_image, 1, 4096 + ) + else: + negative_prompt_embeds, negative_prompt_embeds_mask = None, None else: - negative_prompt_embeds, negative_prompt_embeds_mask = None, None + prompt_embeds, prompt_embeds_mask = self.encode_prompt(prompt, 1, 4096) + if cfg_scale > 1.0 and negative_prompt != "": + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(negative_prompt, 1, 4096) + else: + negative_prompt_embeds, negative_prompt_embeds_mask = None, None self.model_lifecycle_finish(["encoder"]) hide_progress = dist.is_initialized() and dist.get_rank() != 0 + for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)): timestep = timestep.unsqueeze(0).to(dtype=self.dtype) noise_pred = self.predict_noise_with_cfg( latents=latents, + image_latents=image_latents, timestep=timestep, prompt_emb=prompt_embeds, negative_prompt_emb=negative_prompt_embeds, @@ -431,12 +539,16 @@ def __call__( latents = rearrange(latents, "B C H W -> B C 1 H W") vae_output = rearrange( self.vae.decode( - latents.to(self.vae.model.encoder.conv1.weight.dtype), device=self.vae.model.encoder.conv1.weight.device + latents.to(self.vae.model.encoder.conv1.weight.dtype), + device=self.vae.model.encoder.conv1.weight.device, + tiled=self.vae_tiled, + tile_size=self.vae_tile_size, + tile_stride=self.vae_tile_stride, )[0], "C B H W -> B C H W", ) image = self.vae_output_to_image(vae_output) # Offload all models - self.model_lifecycle_finish(["vae"]) + self.model_lifecycle_finish(["vae"]) self.load_models_to_device([]) return image diff --git a/diffsynth_engine/pipelines/sd_image.py b/diffsynth_engine/pipelines/sd_image.py index e1a79ed0..ded8ce77 100644 --- a/diffsynth_engine/pipelines/sd_image.py +++ b/diffsynth_engine/pipelines/sd_image.py @@ -181,21 +181,21 @@ def from_state_dict(cls, state_dicts: SDStateDicts, config: SDPipelineConfig) -> raise ValueError("`model_path` cannot be empty") logger.info(f"loading state dict from {config.model_path} ...") state_dicts.model = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype) - + if state_dicts.vae is None: if config.vae_path is None: state_dicts.vae = state_dicts.model else: logger.info(f"loading state dict from {config.vae_path} ...") state_dicts.vae = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype) - + if state_dicts.clip is None: if config.clip_path is None: state_dicts.clip = state_dicts.model else: logger.info(f"loading state dict from {config.clip_path} ...") state_dicts.clip = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype) - + init_device = "cpu" if config.offload_mode is not None else config.device tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH) with LoRAContext(): diff --git a/diffsynth_engine/pipelines/sdxl_image.py b/diffsynth_engine/pipelines/sdxl_image.py index ab233c31..48ecfb4d 100644 --- a/diffsynth_engine/pipelines/sdxl_image.py +++ b/diffsynth_engine/pipelines/sdxl_image.py @@ -159,28 +159,32 @@ def from_state_dict(cls, state_dicts: SDXLStateDicts, config: SDXLPipelineConfig raise ValueError("`model_path` cannot be empty") logger.info(f"loading state dict from {config.model_path} ...") state_dicts.model = cls.load_model_checkpoint(config.model_path, device="cpu", 
dtype=config.model_dtype) - + if state_dicts.vae is None: if config.vae_path is None: state_dicts.vae = state_dicts.model else: logger.info(f"loading state dict from {config.vae_path} ...") state_dicts.vae = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype) - + if state_dicts.clip_l is None: if config.clip_l_path is None: state_dicts.clip_l = state_dicts.model else: logger.info(f"loading state dict from {config.clip_l_path} ...") - state_dicts.clip_l = cls.load_model_checkpoint(config.clip_l_path, device="cpu", dtype=config.clip_l_dtype) - + state_dicts.clip_l = cls.load_model_checkpoint( + config.clip_l_path, device="cpu", dtype=config.clip_l_dtype + ) + if state_dicts.clip_g is None: if config.clip_g_path is None: state_dicts.clip_g = state_dicts.model else: logger.info(f"loading state dict from {config.clip_g_path} ...") - state_dicts.clip_g = cls.load_model_checkpoint(config.clip_g_path, device="cpu", dtype=config.clip_g_dtype) - + state_dicts.clip_g = cls.load_model_checkpoint( + config.clip_g_path, device="cpu", dtype=config.clip_g_dtype + ) + init_device = "cpu" if config.offload_mode else config.device tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH) tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH) diff --git a/diffsynth_engine/tokenizers/__init__.py b/diffsynth_engine/tokenizers/__init__.py index 6ad86b10..5e8e3d44 100644 --- a/diffsynth_engine/tokenizers/__init__.py +++ b/diffsynth_engine/tokenizers/__init__.py @@ -3,6 +3,8 @@ from .t5 import T5TokenizerFast from .wan import WanT5Tokenizer from .qwen2 import Qwen2TokenizerFast +from .qwen2_vl_image_processor import Qwen2VLImageProcessor +from .qwen2_vl_processor import Qwen2VLProcessor __all__ = [ "BaseTokenizer", @@ -10,4 +12,6 @@ "T5TokenizerFast", "WanT5Tokenizer", "Qwen2TokenizerFast", + "Qwen2VLImageProcessor", + "Qwen2VLProcessor", ] diff --git a/diffsynth_engine/tokenizers/qwen2_vl_image_processor.py b/diffsynth_engine/tokenizers/qwen2_vl_image_processor.py new file mode 100644 index 00000000..57eab6a7 --- /dev/null +++ b/diffsynth_engine/tokenizers/qwen2_vl_image_processor.py @@ -0,0 +1,157 @@ +# modified from transformers.models.qwen2_vl.image_processing_qwen2_vl +import os +import json +import logging +import numpy as np +from typing import List, Optional +from PIL import Image + +from diffsynth_engine.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from diffsynth_engine.utils.image import ( + ChannelDimension, + convert_to_rgb, + get_image_size, + infer_channel_dimension_format, + rescale_image, + resize_image, + smart_resize, + normalize_image, + to_channel_dimension_format, +) + +logger = logging.getLogger(__name__) + + +class Qwen2VLImageProcessor: + def __init__( + self, + do_resize: bool = True, + resample: Image.Resampling = Image.Resampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: float = 1.0 / 255, + do_normalize: bool = True, + image_mean: List[float] = OPENAI_CLIP_MEAN, + image_std: List[float] = OPENAI_CLIP_STD, + do_convert_rgb: bool = True, + min_pixels: int = 56 * 56, + max_pixels: int = 28 * 28 * 1280, + patch_size: int = 14, + temporal_patch_size: int = 2, + merge_size: int = 2, + **kwargs, + ): + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.do_convert_rgb = do_convert_rgb + self.size = {"shortest_edge": min_pixels, "longest_edge": max_pixels} + self.image_mean = image_mean + self.image_std 
= image_std + self.patch_size = patch_size + self.merge_size = merge_size + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.temporal_patch_size = temporal_patch_size + + @classmethod + def from_pretrained(cls, config_file_path: str | os.PathLike, **kwargs): + init_kwargs = {} + if not os.path.exists(config_file_path): + logger.warning(f"Cannot find {config_file_path}, init processor with default parameters") + else: + with open(config_file_path, "r", encoding="utf-8") as kwargs_handler: + init_kwargs = json.load(kwargs_handler) + + init_kwargs.update(**kwargs) + return cls(**init_kwargs) + + def __call__( + self, + images: Image.Image | List[Image.Image], + videos: Optional[List[List[Image.Image]]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + ): + pixel_values, image_grid_thws = None, None + if images is not None: + if isinstance(images, Image.Image): + images = [images] + pixel_values, image_grid_thws = [], [] + for image in images: + flatten_patches, image_grid_thw = self._preprocess([image], data_format) + pixel_values.extend(flatten_patches) + image_grid_thws.append(image_grid_thw) + pixel_values = np.array(pixel_values) + image_grid_thws = np.array(image_grid_thws) + + vision_pixel_values, vision_grid_thws = None, None + if videos is not None: + vision_pixel_values, vision_grid_thws = [], [] + for images in videos: + flatten_patches, video_grid_thw = self._preprocess(images, data_format) + vision_pixel_values.append(flatten_patches) + vision_grid_thws.append(video_grid_thw) + vision_pixel_values = np.array(vision_pixel_values) + vision_grid_thws = np.array(vision_grid_thws) + + return pixel_values, image_grid_thws, vision_pixel_values, vision_grid_thws + + def _preprocess(self, images: List[Image.Image], data_format: Optional[ChannelDimension] = ChannelDimension.FIRST): + images = [convert_to_rgb(image) for image in images] + image_nps = [np.array(image) for image in images] + input_data_format = infer_channel_dimension_format(image_nps[0]) + height, width = get_image_size(image_nps[0], input_data_format) + resized_height, resized_width = height, width + + processed_image_nps = [] + for image_np in image_nps: + if self.do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image_np = resize_image( + image_np, resized_height, resized_width, self.resample, input_data_format=input_data_format + ) + + if self.do_rescale: + image_np = rescale_image(image_np, self.rescale_factor) + + if self.do_normalize: + image_np = normalize_image( + image_np, self.image_mean, self.image_std, input_data_format=input_data_format + ) + image_np = to_channel_dimension_format(image_np, data_format, input_data_format) + processed_image_nps.append(image_np) + + patches = np.array(processed_image_nps) + if data_format == ChannelDimension.LAST: + patches = patches.transpose(0, 3, 1, 2) + if patches.shape[0] % self.temporal_patch_size != 0: + repeats = np.repeat(patches[-1][np.newaxis], self.temporal_patch_size - 1, axis=0) + patches = np.concatenate([patches, repeats], axis=0) + num_channel = patches.shape[1] + grid_t = patches.shape[0] // self.temporal_patch_size + grid_h = resized_height // self.patch_size + grid_w = resized_width // self.patch_size + patches = patches.reshape( + grid_t, + self.temporal_patch_size, + num_channel, + grid_h // self.merge_size, + self.merge_size, + self.patch_size, + grid_w // self.merge_size, + 
self.merge_size, + self.patch_size, + ) + patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, num_channel * self.temporal_patch_size * self.patch_size * self.patch_size + ) + + return flatten_patches, (grid_t, grid_h, grid_w) diff --git a/diffsynth_engine/tokenizers/qwen2_vl_processor.py b/diffsynth_engine/tokenizers/qwen2_vl_processor.py new file mode 100644 index 00000000..eba88239 --- /dev/null +++ b/diffsynth_engine/tokenizers/qwen2_vl_processor.py @@ -0,0 +1,100 @@ +import os +import re +import torch +import logging +from PIL import Image +from typing import List, Dict, Optional + +from diffsynth_engine.tokenizers.qwen2_vl_image_processor import Qwen2VLImageProcessor +from diffsynth_engine.tokenizers.qwen2 import Qwen2TokenizerFast + +logger = logging.getLogger(__name__) + + +class Qwen2VLProcessor: + def __init__( + self, + tokenizer: Qwen2TokenizerFast, + image_processor: Qwen2VLImageProcessor, + image_token: str = "<|image_pad|>", + **kwargs, + ): + self.tokenizer = tokenizer + self.image_processor = image_processor + self.image_token = image_token + + @classmethod + def from_pretrained( + cls, + tokenizer_config_path: str | os.PathLike, + image_processor_config_path: str | os.PathLike, + **kwargs, + ): + tokenizer = Qwen2TokenizerFast.from_pretrained(tokenizer_config_path) + image_processor = Qwen2VLImageProcessor.from_pretrained(image_processor_config_path) + return cls(tokenizer=tokenizer, image_processor=image_processor, **kwargs) + + def batch_decode( + self, + ids: List[List[int]] | List[torch.Tensor], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + ): + if isinstance(ids[0], torch.Tensor): + ids = [id_.tolist() for id_ in ids] + decoded = self.tokenizer.batch_decode(ids, skip_special_tokens, clean_up_tokenization_spaces) + pattern = r"<\|vision_start\|>.*?<\|vision_end\|>" + decoded_with_image_tag = [re.sub(pattern, "", d, flags=re.DOTALL) for d in decoded] + decoded_with_image_tag = [re.sub(r"<\|im_end\|>", "", d) for d in decoded_with_image_tag] + return decoded_with_image_tag + + def __call__( + self, + text: str | List[str], + images: Optional[List[Image.Image]] = None, + videos: Optional[List[List[Image.Image]]] = None, + max_length: Optional[int] = None, + ) -> Dict[str, torch.Tensor]: + """ + Main method to prepare one or several text sequence(s) and image(s) for the model. This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `images` and `videos` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if they are not `None`. + + Args: + text (`List[str]`): + The sequence or batch of sequences to be encoded. + images (`List[PIL.Image.Image]`): + The batch of images to be prepared. + videos (`List[List[PIL.Image.Image]]`): + The batch of videos to be prepared.
+ """ + images_pixel_values, images_grid_thws, video_pixels_values, video_grid_thws = self.image_processor( + images, videos + ) + + if not isinstance(text, list): + text = [text] + if images_grid_thws is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, "<|placeholder|>" * (images_grid_thws[index].prod() // merge_length), 1 + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + text_inputs = self.tokenizer(text, max_length=max_length) + + processed_inputs = text_inputs + if images_pixel_values is not None: + processed_inputs["pixel_values"] = torch.from_numpy(images_pixel_values) + if images_grid_thws is not None: + processed_inputs["image_grid_thw"] = torch.from_numpy(images_grid_thws) + if video_pixels_values is not None: + processed_inputs["pixel_values_videos"] = video_pixels_values + if video_grid_thws is not None: + processed_inputs["video_grid_thw"] = video_grid_thws + + return processed_inputs diff --git a/diffsynth_engine/utils/constants.py b/diffsynth_engine/utils/constants.py index 80bc3ca9..3c67449e 100644 --- a/diffsynth_engine/utils/constants.py +++ b/diffsynth_engine/utils/constants.py @@ -5,6 +5,7 @@ # conf CONF_PATH = os.path.join(PACKAGE_ROOT, "conf") + # tokenizers FLUX_TOKENIZER_1_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "flux", "tokenizer_1") FLUX_TOKENIZER_2_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "flux", "tokenizer_2") @@ -12,6 +13,8 @@ SDXL_TOKENIZER_2_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "sdxl", "tokenizer_2") WAN_TOKENIZER_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "wan", "umt5-xxl") QWEN_IMAGE_TOKENIZER_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "qwen_image", "tokenizer") +QWEN_IMAGE_PROCESSOR_CONFIG_FILE = os.path.join(CONF_PATH, "tokenizers", "qwen_image", "qwen2_vl_image_processor.json") + # models VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "components", "vae.json") FLUX_DIT_CONFIG_FILE = os.path.join(CONF_PATH, "models", "flux", "flux_dit.json") @@ -46,3 +49,6 @@ MB = 1024 * KB GB = 1024 * MB TB = 1024 * GB + +OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] +OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711] diff --git a/diffsynth_engine/utils/image.py b/diffsynth_engine/utils/image.py index be39b22b..e2185241 100644 --- a/diffsynth_engine/utils/image.py +++ b/diffsynth_engine/utils/image.py @@ -1,6 +1,13 @@ import torch import numpy as np +import math from PIL import Image +from enum import Enum +from typing import List, Tuple, Optional + +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) def tensor_to_image(t: torch.Tensor, denormalize: bool = True) -> Image.Image: @@ -23,3 +30,209 @@ def tensor_to_image(t: torch.Tensor, denormalize: bool = True) -> Image.Image: else: mode = "RGB" return Image.fromarray(t, mode=mode) + + +class ChannelDimension(Enum): + FIRST = "channels_first" + LAST = "channels_last" + + +def convert_to_rgb(image: Image.Image) -> Image.Image: + if not isinstance(image, Image.Image): + raise TypeError(f"image must be a PIL.Image.Image, but got {type(image)}") + if image.mode == "RGB": + return image + image = image.convert(mode="RGB") + return image + + +def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension: + num_channels = (1, 3) + if image.ndim == 3: + first_dim, last_dim = 0, 2 + elif image.ndim == 4: + first_dim, last_dim = 1, 3 + else: + raise 
ValueError(f"Unsupported number of image dimensions: {image.ndim}") + + if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels: + logger.warning("Image has both first and last dimensions as channels. This may lead to unexpected behavior.") + return ChannelDimension.FIRST + elif image.shape[first_dim] in num_channels: + return ChannelDimension.FIRST + elif image.shape[last_dim] in num_channels: + return ChannelDimension.LAST + raise ValueError("Unable to infer channel dimension format") + + +def get_image_size(image: np.ndarray, channel_dim: Optional[ChannelDimension] = None) -> Tuple[int, int]: + """ + Returns the (height, width) dimensions of the image. + """ + if channel_dim is None: + channel_dim = infer_channel_dimension_format(image) + if channel_dim == ChannelDimension.FIRST: + return image.shape[-2], image.shape[-1] + elif channel_dim == ChannelDimension.LAST: + return image.shape[-3], image.shape[-2] + else: + raise ValueError(f"Unsupported channel dimension format: {channel_dim}") + + +def smart_resize( + height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280 +) -> Tuple[int, int]: + """Rescales the image so that the following conditions are met: + 1. Both dimensions (height and width) are divisible by 'factor'. + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. + """ + abs_aspect_ratio = max(height, width) / min(height, width) + if height < factor or width < factor: + raise ValueError(f"Image height: {height} and width: {width} must be greater than or equal to factor: {factor}") + elif abs_aspect_ratio > 200: + raise ValueError(f"absolute aspect ratio must be smaller than 200, got {abs_aspect_ratio}") + + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt(height * width / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +def to_channel_dimension_format( + image: np.ndarray, channel_dim: ChannelDimension, input_channel_dim: Optional[ChannelDimension] = None +) -> np.ndarray: + if not isinstance(image, np.ndarray): + raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}") + if input_channel_dim is None: + input_channel_dim = infer_channel_dimension_format(image) + if input_channel_dim == channel_dim: + return image + if channel_dim == ChannelDimension.FIRST: + image = image.transpose((2, 0, 1)) + elif channel_dim == ChannelDimension.LAST: + image = image.transpose((1, 2, 0)) + else: + raise ValueError(f"Unsupported channel dimension format: {channel_dim}") + return image + + +def get_channel_dimension_axis(image: np.ndarray, input_data_format: Optional[ChannelDimension] = None) -> int: + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + if input_data_format == ChannelDimension.FIRST: + return image.ndim - 3 + elif input_data_format == ChannelDimension.LAST: + return image.ndim - 1 + raise ValueError(f"Unsupported channel dimension format: {input_data_format}") + + +def rescale_image( + image: np.ndarray, + rescale_factor: float, + data_format: Optional[ChannelDimension] = 
None, + input_data_format: Optional[ChannelDimension] = None, +) -> np.ndarray: + rescaled_image = image.astype(np.float64) * rescale_factor + if data_format is not None: + rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format) + rescaled_image = rescaled_image.astype(np.float32) + return rescaled_image + + +def normalize_image( + image: np.ndarray, + mean: List[float], + std: List[float], + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[ChannelDimension] = None, +) -> np.ndarray: + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + channel_axis = get_channel_dimension_axis(image, input_data_format) + num_channels = image.shape[channel_axis] + if len(mean) != num_channels: + raise ValueError(f"mean must have {num_channels} elements, but got {len(mean)}") + if len(std) != num_channels: + raise ValueError(f"std must have {num_channels} elements, but got {len(std)}") + if not np.issubdtype(image.dtype, np.floating): + image = image.astype(np.float32) + mean = np.array(mean, dtype=image.dtype) + std = np.array(std, dtype=image.dtype) + if input_data_format == ChannelDimension.LAST: + image = (image - mean) / std + else: + image = ((image.T - mean) / std).T + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_data_format) + return image + + +def to_pil_image( + image: np.ndarray, + do_rescale: Optional[bool] = None, + input_data_format: Optional[ChannelDimension] = None, + image_mode: Optional[str] = None, +) -> Image.Image: + image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format) + image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image + do_rescale = do_rescale if do_rescale is not None else _need_rescale_pil_conversion(image) + if do_rescale: + image = rescale_image(image, 255) + image = image.astype(np.uint8) + return Image.fromarray(image, mode=image_mode) + + +def resize_image( + image: np.ndarray, + height: int, + width: int, + resample: Image.Resampling = Image.Resampling.BILINEAR, + reducing_gap: Optional[int] = None, + input_data_format: Optional[ChannelDimension] = None, + data_format: Optional[ChannelDimension] = None, +) -> np.ndarray: + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + data_format = data_format if data_format is not None else input_data_format + do_rescale = _need_rescale_pil_conversion(image) + pil_image = to_pil_image(image, do_rescale, input_data_format) + resized_image = pil_image.resize((width, height), resample=resample, reducing_gap=reducing_gap) + resized_image = np.array(resized_image) + resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image + resized_image = to_channel_dimension_format(resized_image, data_format, ChannelDimension.LAST) + resized_image = rescale_image(resized_image, 1 / 255) if do_rescale else resized_image + return resized_image + + +def _need_rescale_pil_conversion(image: np.ndarray) -> bool: + """ + Detects whether or not the image needs to be rescaled before being converted to a PIL image. + The assumption is that if the image is of type `np.float` and all values are between 0 and 1, it needs to be + rescaled. 
+ """ + if image.dtype == np.uint8: + do_rescale = False + elif np.allclose(image, image.astype(int)): + if np.all(0 <= image) and np.all(image <= 255): + do_rescale = False + else: + raise ValueError( + "The image to be converted to a PIL image contains value outside the range [0, 255], " + f"got [{image.min()}, {image.max()}] which cannot be converted to uint8." + ) + elif np.all(0 <= image) and np.all(image <= 1): + do_rescale = True + else: + raise ValueError( + "The image to be converted to PIL image contains values outside the range [0, 1]" + f"got [{image.min()}, {image.max()}] which cannot be converted to uint8." + ) + return do_rescale diff --git a/diffsynth_engine/utils/offload.py b/diffsynth_engine/utils/offload.py index c379311a..dde9cb15 100644 --- a/diffsynth_engine/utils/offload.py +++ b/diffsynth_engine/utils/offload.py @@ -3,6 +3,7 @@ from typing import Dict import platform + def enable_sequential_cpu_offload(module: nn.Module, device: str = "cuda"): module = module.to("cpu") if len(list(module.children())) == 0: @@ -26,13 +27,13 @@ def _forward_pre_hook(module: nn.Module, input_): for name, buffer in module.named_buffers(recurse=recurse): buffer.data = buffer.data.to(device=device) return tuple(x.to(device=device) if isinstance(x, torch.Tensor) else x for x in input_) - for name, param in module.named_parameters(recurse=recurse): - if platform.system() == 'Linux': + for name, param in module.named_parameters(recurse=recurse): + if platform.system() == "Linux": param.data = param.data.pin_memory() offload_param_dict[name] = param.data param.data = param.data.to(device=device) for name, buffer in module.named_buffers(recurse=recurse): - if platform.system() == 'Linux': + if platform.system() == "Linux": buffer.data = buffer.data.pin_memory() offload_param_dict[name] = buffer.data buffer.data = buffer.data.to(device=device) @@ -59,11 +60,11 @@ def offload_model_to_dict(module: nn.Module) -> Dict[str, torch.Tensor]: module = module.to("cpu") offload_param_dict = {} for name, param in module.named_parameters(recurse=True): - if platform.system() == 'Linux': + if platform.system() == "Linux": param.data = param.data.pin_memory() offload_param_dict[name] = param.data for name, buffer in module.named_buffers(recurse=True): - if platform.system() == 'Linux': + if platform.system() == "Linux": buffer.data = buffer.data.pin_memory() offload_param_dict[name] = buffer.data return offload_param_dict diff --git a/examples/input/qwen_image_edit_input.png b/examples/input/qwen_image_edit_input.png new file mode 100644 index 00000000..dbb94c86 Binary files /dev/null and b/examples/input/qwen_image_edit_input.png differ diff --git a/examples/qwen_image_edit.py b/examples/qwen_image_edit.py new file mode 100644 index 00000000..6940926f --- /dev/null +++ b/examples/qwen_image_edit.py @@ -0,0 +1,21 @@ +from diffsynth_engine import QwenImagePipeline, QwenImagePipelineConfig, fetch_model +from PIL import Image + +if __name__ == "__main__": + config = QwenImagePipelineConfig.basic_config( + model_path=fetch_model("Qwen/Qwen-Image-Edit", revision="v1", path="transformer/*.safetensors"), + encoder_path=fetch_model("Qwen/Qwen-Image-Edit", revision="v1", path="text_encoder/*.safetensors"), + vae_path=fetch_model("Qwen/Qwen-Image-Edit", revision="v1", path="vae/*.safetensors"), + parallelism=1, + ) + + pipe = QwenImagePipeline.from_pretrained(config) + + prompt = "把'通义千问'替换成'muse平台'" + image = pipe( + prompt=prompt, + input_image=Image.open("input/qwen_image_edit_inpput.png"), + seed=42, + ) + 
image.save("image.png") + del pipe diff --git a/pyproject.toml b/pyproject.toml index 6bd8af72..d12b39de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ [project.optional-dependencies] dev = [ "diffusers == 0.31.0", - "transformers == 4.45.2", + "transformers == 4.52.4", "accelerate", "build", "ruff", diff --git a/tests/data/input/capybara.jpg b/tests/data/input/capybara.jpg new file mode 100644 index 00000000..831312dc Binary files /dev/null and b/tests/data/input/capybara.jpg differ diff --git a/tests/test_models/qwen_image/test_qwen2_5_vl.py b/tests/test_models/qwen_image/test_qwen2_5_vl.py index fc1f6638..accc7075 100644 --- a/tests/test_models/qwen_image/test_qwen2_5_vl.py +++ b/tests/test_models/qwen_image/test_qwen2_5_vl.py @@ -8,22 +8,27 @@ Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig, ) -from diffsynth_engine.tokenizers import Qwen2TokenizerFast +from diffsynth_engine.tokenizers import Qwen2TokenizerFast, Qwen2VLProcessor from diffsynth_engine.utils.constants import ( QWEN_IMAGE_TOKENIZER_CONF_PATH, + QWEN_IMAGE_PROCESSOR_CONFIG_FILE, QWEN_IMAGE_CONFIG_FILE, QWEN_IMAGE_VISION_CONFIG_FILE, ) from diffsynth_engine.utils.download import ensure_directory_exists, fetch_model from diffsynth_engine.utils.loader import save_file -from tests.common.test_case import TestCase, RUN_EXTRA_TEST +from tests.common.test_case import ImageTestCase, RUN_EXTRA_TEST from tests.common.utils import load_model_checkpoint -class TestQwen2_5_VL(TestCase): +class TestQwen2_5_VL(ImageTestCase): @classmethod def setUpClass(cls): cls.tokenizer = Qwen2TokenizerFast.from_pretrained(QWEN_IMAGE_TOKENIZER_CONF_PATH) + cls.processor = Qwen2VLProcessor.from_pretrained( + tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH, + image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE, + ) cls._model_path = fetch_model("Qwen/Qwen2.5-VL-7B-Instruct", fetch_safetensors=False) ckpt_path = [ @@ -42,14 +47,34 @@ def setUpClass(cls): dtype=torch.bfloat16, ).eval() cls.texts = ["Hello, World!", "DiffSynth-Engine developed by Muse AI+Modelscope"] + cls.prompt = "<|vision_start|><|image_pad|><|vision_end|> the capybara is swimming in the pool" + cls.input_image = cls.get_input_image("capybara.jpg").convert("RGB") - def test_encoder(self): + def test_encode_text(self): outputs = self.tokenizer(self.texts) text_ids, attention_mask = outputs["input_ids"].to("cuda:0"), outputs["attention_mask"].to("cuda:0") with torch.no_grad(): logits = self.encoder(input_ids=text_ids, attention_mask=attention_mask)["logits"].cpu() expected_tensors = self.get_expect_tensor("qwen_image/qwen2_5_vl.safetensors") - self.assertTensorEqual(logits, expected_tensors["logits"]) + self.assertTensorEqual(logits, expected_tensors["text_logits"]) + + def test_encode_text_image(self): + outputs = self.processor(text=self.prompt, images=[self.input_image]) + input_ids, attention_mask, pixel_values, image_grid_thw = ( + outputs["input_ids"].to("cuda:0"), + outputs["attention_mask"].to("cuda:0"), + outputs["pixel_values"].to("cuda:0"), + outputs["image_grid_thw"].to("cuda:0"), + ) + with torch.no_grad(): + logits = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + )["logits"].cpu() + expected_tensors = self.get_expect_tensor("qwen_image/qwen2_5_vl.safetensors") + self.assertTensorEqual(logits, expected_tensors["text_image_logits"]) @unittest.skipUnless(RUN_EXTRA_TEST, "RUN_EXTRA_TEST is not set") def test_and_save_tensors(self): @@ -58,14 +83,37 
@@ def test_and_save_tensors(self): vlm = Qwen2_5_VLForConditionalGeneration.from_pretrained( self._model_path, device_map="cuda:0", torch_dtype=torch.bfloat16 ).eval() + outputs = self.tokenizer(self.texts) - text_ids, attention_mask = outputs["input_ids"].to("cuda:0"), outputs["attention_mask"].to("cuda:0") + input_ids, attention_mask = outputs["input_ids"].to("cuda:0"), outputs["attention_mask"].to("cuda:0") with torch.no_grad(): - expected = vlm(input_ids=text_ids, attention_mask=attention_mask).logits.cpu() - logits = self.encoder(input_ids=text_ids, attention_mask=attention_mask)["logits"].cpu() - self.assertTensorEqual(logits, expected) + expected = vlm(input_ids=input_ids, attention_mask=attention_mask).logits.cpu() + text_logits = self.encoder(input_ids=input_ids, attention_mask=attention_mask)["logits"].cpu() + self.assertTensorEqual(text_logits, expected) + + outputs = self.processor(text=self.prompt, images=[self.input_image]) + input_ids, attention_mask, pixel_values, image_grid_thw = ( + outputs["input_ids"].to("cuda:0"), + outputs["attention_mask"].to("cuda:0"), + outputs["pixel_values"].to("cuda:0"), + outputs["image_grid_thw"].to("cuda:0"), + ) + with torch.no_grad(): + expected = vlm( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ).logits.cpu() + text_image_logits = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + )["logits"].cpu() + self.assertTensorEqual(text_image_logits, expected) - excepted_tensors = {"logits": logits} + excepted_tensors = {"text_logits": text_logits, "text_image_logits": text_image_logits} save_path = self.testdata_dir / "expect/qwen_image/qwen2_5_vl.safetensors" ensure_directory_exists(save_path) save_file(excepted_tensors, save_path)