@@ -333,21 +333,25 @@ def __init__(
         self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
         self.gradient_checkpointing = False
 
-    def pos_embeds(self, x, context):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
+        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
+        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
 
-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
+        w_offset = ((w_offset + (patch_size // 2)) // patch_size)
 
-        txt_start = round(max(h_len, w_len))
-        txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
+        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[:, :, 0] = img_ids[:, :, 1] + index
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
 
     def forward(
         self,
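For readers skimming the diff: the new `process_img` helper pads the latent so H and W are divisible by the patch size, folds every 2x2 spatial patch into the channel dimension (one token per patch), and builds an `(index, row, col)` position id per token, where `index` and the row/column offsets exist so additional reference images can be placed at distinct rotary positions. Below is a rough standalone sketch of that logic. Everything in it (the `process_img_sketch` name, the dummy latent, the `F.pad` stand-in for `comfy.ldm.common_dit.pad_to_patch_size`, and the assumption that `patch_size == 2` and `t == 1`) is illustrative, not the actual ComfyUI code.

```python
import torch
import torch.nn.functional as F

def process_img_sketch(x, patch_size=2, index=0, h_offset=0, w_offset=0):
    # x: (bs, c, t, h, w) latent; t is assumed to be 1 (single-frame image case)
    bs, c, t, h, w = x.shape
    # pad H/W up to a multiple of patch_size (stand-in for comfy's pad_to_patch_size)
    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size
    hs = F.pad(x, (0, pad_w, 0, pad_h))
    orig_shape = hs.shape
    # fold each 2x2 spatial patch into the channel dim -> one token of size c*4 per patch
    hs = hs.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
    hs = hs.permute(0, 2, 4, 1, 3, 5)
    hs = hs.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)

    # token grid size and optional grid offsets, both measured in patches
    h_len = (h + patch_size // 2) // patch_size
    w_len = (w + patch_size // 2) // patch_size
    h_off = (h_offset + patch_size // 2) // patch_size
    w_off = (w_offset + patch_size // 2) // patch_size

    # one (index, row, col) id per image token; index distinguishes extra reference images
    img_ids = torch.zeros((h_len, w_len, 3), dtype=x.dtype)
    img_ids[:, :, 0] = index
    img_ids[:, :, 1] = torch.arange(h_off, h_off + h_len, dtype=x.dtype).unsqueeze(1)
    img_ids[:, :, 2] = torch.arange(w_off, w_off + w_len, dtype=x.dtype).unsqueeze(0)
    img_ids = img_ids.reshape(1, h_len * w_len, 3).repeat(bs, 1, 1)
    return hs, img_ids, orig_shape

x = torch.randn(1, 16, 1, 97, 129)        # odd H/W to exercise the padding path
tokens, ids, shape = process_img_sketch(x)
print(tokens.shape, ids.shape, shape)     # (1, 3185, 64), (1, 3185, 3), (1, 16, 1, 98, 130)
```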
@@ -363,13 +367,13 @@ def forward(
         encoder_hidden_states = context
         encoder_hidden_states_mask = attention_mask
 
-        image_rotary_emb = self.pos_embeds(x, context)
+        hidden_states, img_ids, orig_shape = self.process_img(x)
+        num_embeds = hidden_states.shape[1]
 
-        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
-        orig_shape = hidden_states.shape
-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
-        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
-        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
+        txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
 
         hidden_states = self.img_in(hidden_states)
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
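The rotary-embedding wiring that used to live in `pos_embeds` is now inlined in `forward`: text tokens get ids that start just past the longer axis of the image token grid, so their rotary positions never collide with the image rows/columns, and the concatenated text+image ids are fed to `pe_embedder`. A small illustration of that id layout follows; the shapes and sequence length are made up, and `pe_embedder` itself is not reproduced.

```python
import torch

patch_size = 2
h, w = 98, 130          # hypothetical latent height/width (x.shape[-2], x.shape[-1])
seq_len = 77            # hypothetical context (prompt) length
bs = 1

h_len = (h + patch_size // 2) // patch_size
w_len = (w + patch_size // 2) // patch_size

# text ids begin one step past the larger image axis, then increase per token
txt_start = round(max(h_len, w_len))
txt_ids = torch.linspace(txt_start, txt_start + seq_len, steps=seq_len).reshape(1, -1, 1).repeat(bs, 1, 3)

print(h_len, w_len, txt_start)          # 49 65 65
print(txt_ids.shape)                    # torch.Size([1, 77, 3])
print(txt_ids[0, 0], txt_ids[0, -1])    # ids run from txt_start up to txt_start + seq_len
```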
@@ -408,6 +412,6 @@ def block_wrap(args):
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
+        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
         hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
         return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
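Because extra reference-image tokens can now be appended to the sequence, the output is first sliced back to the `num_embeds` tokens belonging to the main latent before being unpatchified; the final crop then removes the padding added by `process_img`. Continuing the hypothetical `process_img_sketch` example above (so `x`, `tokens`, and `shape` are the same made-up tensors, and the transformer is pretended to be the identity), the inverse reshape looks roughly like this:

```python
import torch

num_embeds = tokens.shape[1]          # tokens belonging to the main image
out = tokens                          # pretend the transformer returned its input unchanged

bs, c, t, ph, pw = shape              # padded shape recorded by process_img_sketch
out = out[:, :num_embeds].view(bs, ph // 2, pw // 2, c, 2, 2)
out = out.permute(0, 3, 1, 4, 2, 5)
out = out.reshape(shape)[:, :, :, :x.shape[-2], :x.shape[-1]]   # drop the padding rows/cols

print(out.shape)                      # back to (1, 16, 1, 97, 129)
print(torch.allclose(out, x))         # True: patchify + unpatchify round-trips exactly
```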