
Commit 0f2b852

Qwen image model refactor. (#9375)
1 parent 20a8416

File tree: 1 file changed, +20 -16 lines


comfy/ldm/qwen_image/model.py

Lines changed: 20 additions & 16 deletions
@@ -333,21 +333,25 @@ def __init__(
         self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
         self.gradient_checkpointing = False
 
-    def pos_embeds(self, x, context):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
+        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
+        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
 
-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
+        w_offset = ((w_offset + (patch_size // 2)) // patch_size)
 
-        txt_start = round(max(h_len, w_len))
-        txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
+        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[:, :, 0] = img_ids[:, :, 1] + index
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
 
     def forward(
         self,
@@ -363,13 +367,13 @@ def forward(
         encoder_hidden_states = context
         encoder_hidden_states_mask = attention_mask
 
-        image_rotary_emb = self.pos_embeds(x, context)
+        hidden_states, img_ids, orig_shape = self.process_img(x)
+        num_embeds = hidden_states.shape[1]
 
-        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
-        orig_shape = hidden_states.shape
-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
-        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
-        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
+        txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
 
         hidden_states = self.img_in(hidden_states)
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@@ -408,6 +412,6 @@ def block_wrap(args):
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
+        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
         hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
         return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
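
The reshuffle moved into process_img is a 2x2 pixel-unshuffle style patchify, and the tail of forward() is its inverse (after slicing back to num_embeds tokens and cropping the padding). A minimal standalone sketch, not ComfyUI code, assuming a 4D toy tensor (the view in the diff folds the latent time dimension away, which works when t == 1) and mirroring the view/permute/reshape order from the diff:

import torch

bs, c, h, w = 2, 16, 6, 8   # toy latent; h and w already divisible by 2
x = torch.randn(bs, c, h, w)

# patchify as in process_img: (bs, c, h, w) -> (bs, (h//2)*(w//2), c*4)
p = x.view(bs, c, h // 2, 2, w // 2, 2)
p = p.permute(0, 2, 4, 1, 3, 5)
tokens = p.reshape(bs, (h // 2) * (w // 2), c * 4)

# un-patchify as at the end of forward(): tokens -> (bs, c, h, w)
u = tokens.view(bs, h // 2, w // 2, c, 2, 2)
u = u.permute(0, 3, 1, 4, 2, 5)
x_rec = u.reshape(bs, c, h, w)

assert torch.equal(x, x_rec)  # exact round trip: a pure permutation of elements

The [:, :num_embeds] slice added before the un-patchify is what keeps the inverse valid if tokens beyond the main image are ever appended to the sequence.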
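process_img also returns 3-channel rotary position ids: channel 0 carries the new index argument and channels 1/2 carry row/column coordinates shifted by h_offset/w_offset, which is what would let additional images be placed at distinct positions. (The diff writes channel 0 as img_ids[:, :, 1] + index, but channel 1 is still all zeros at that point, so the result is simply index.) Text tokens then get ids starting at max(h_len, w_len), past the image grid. A small sketch with hypothetical values (the grid size, offsets, and 7-token context length are made up):

import torch

h_len, w_len = 3, 4                  # token grid after 2x2 patchify
index, h_offset, w_offset = 1, 2, 5  # hypothetical placement of an extra image

img_ids = torch.zeros((h_len, w_len, 3))
img_ids[:, :, 0] += index
img_ids[:, :, 1] += torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len).unsqueeze(1)
img_ids[:, :, 2] += torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len).unsqueeze(0)
print(img_ids[0, 0], img_ids[-1, -1])  # tensor([1., 2., 5.]) tensor([1., 4., 8.])

# text ids start past the larger image axis, as in forward()
txt_len = 7
txt_start = max(h_len, w_len)
txt_ids = torch.linspace(txt_start, txt_start + txt_len, steps=txt_len).reshape(1, -1, 1).repeat(1, 1, 3)
print(txt_ids[0, 0], txt_ids[0, -1])  # tensor([4., 4., 4.]) tensor([11., 11., 11.])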

0 commit comments
