Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion diffsynth_engine/tools/qwen_image_upscaler_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from diffsynth_engine.pipelines.qwen_image import QwenImagePipeline
from diffsynth_engine.models.qwen_image import QwenImageVAE
from diffsynth_engine.models.basic.lora import LoRALinear
from diffsynth_engine.models.qwen_image.qwen_image_dit import QwenImageTransformerBlock
from diffsynth_engine.models.qwen_image.qwen_image_dit import QwenImageTransformerBlock, QwenEmbedRope
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.loader import load_file
from diffsynth_engine.utils.download import fetch_model
Expand All @@ -32,6 +32,7 @@ def odtsr_forward():
"""
original_lora_forward = LoRALinear.forward
original_modulate = QwenImageTransformerBlock._modulate
original_rope_forward = QwenEmbedRope.forward

def lora_batch_cfg_forward(self, x):
y = nn.Linear.forward(self, x)
Expand All @@ -50,6 +51,49 @@ def lora_batch_cfg_forward(self, x):
y[:, L:] += lora(x2)
return y

def optimized_rope_forward(self, video_fhw, txt_length, device):
    """Return rotary-embedding frequency tables for visual and text tokens.

    For every ``(frame, height, width)`` entry in ``video_fhw`` a table with
    ``frame * height * width`` rows is assembled from the frame, height and
    width slices of ``self.pos_freqs`` / ``self.neg_freqs`` and memoised in
    ``self.rope_cache``.

    Args:
        video_fhw: Iterable of ``(frame, height, width)`` triples.
        txt_length: Number of rows to take from ``self.pos_freqs`` for text.
        device: Device that ``self.pos_freqs`` / ``self.neg_freqs`` are moved
            to when they currently live elsewhere.

    Returns:
        ``(vid_freqs, txt_freqs)``: the per-entry tables concatenated along
        dim 0, and ``self.pos_freqs[max_vid_index:max_vid_index + txt_length]``.
    """
    if self.pos_freqs.device != device:
        self.pos_freqs = self.pos_freqs.to(device)
        self.neg_freqs = self.neg_freqs.to(device)

    vid_freqs = []
    max_vid_index = 0
    # NOTE(review): idx is never advanced, so every entry starts at frame
    # offset 0 -- presumably intentional for this tool; confirm.
    idx = 0
    for fhw in video_fhw:
        frame, height, width = fhw
        # The frame count is part of the key: the cached table has
        # frame * height * width rows, so keying on height/width alone would
        # hand back a wrongly sized table when only the frame count differs.
        rope_key = f"{idx}_{frame}_{height}_{width}"

        if rope_key not in self.rope_cache:
            seq_lens = frame * height * width
            split_sizes = [x // 2 for x in self.axes_dim]
            freqs_pos = self.pos_freqs.split(split_sizes, dim=1)
            freqs_neg = self.neg_freqs.split(split_sizes, dim=1)
            freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
            if self.scale_rope:
                # Centered layout: trailing rows of the negative table
                # followed by leading rows of the positive table.
                freqs_height = torch.cat(
                    [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
                )
                freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
                freqs_width = torch.cat(
                    [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0
                )
                freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
            else:
                freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
                freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
            self.rope_cache[rope_key] = freqs.clone().contiguous()

        cached = self.rope_cache[rope_key]
        if cached.device != self.pos_freqs.device:
            # The source tables were moved after this entry was cached; keep
            # the cached copy on the same device as txt_freqs below.
            cached = cached.to(self.pos_freqs.device)
            self.rope_cache[rope_key] = cached
        vid_freqs.append(cached)

        if self.scale_rope:
            max_vid_index = max(height // 2, width // 2, max_vid_index)
        else:
            max_vid_index = max(height, width, max_vid_index)

    # Text rows start right after the largest spatial index used above.
    txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + txt_length, ...]
    vid_freqs = torch.cat(vid_freqs, dim=0)

    return vid_freqs, txt_freqs


def optimized_modulate(self, x, mod_params, index=None):
if mod_params.ndim == 2:
shift, scale, gate = mod_params.chunk(3, dim=-1)
Expand All @@ -72,12 +116,14 @@ def optimized_modulate(self, x, mod_params, index=None):

LoRALinear.forward = lora_batch_cfg_forward
QwenImageTransformerBlock._modulate = optimized_modulate
QwenEmbedRope.forward = optimized_rope_forward

try:
yield
finally:
LoRALinear.forward = original_lora_forward
QwenImageTransformerBlock._modulate = original_modulate
QwenEmbedRope.forward = original_rope_forward


class QwenImageUpscalerTool:
Expand Down