Commit 71e8049

committed: Replace padding with pad_sequence; Add gradient checkpointing.
1 parent 2bb39f4 · commit 71e8049

File tree

1 file changed (+38, -69)

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 38 additions & 69 deletions
@@ -19,6 +19,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+from torch.nn.utils.rnn import pad_sequence

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
@@ -355,6 +356,7 @@ def __init__(

         self.rope_theta = rope_theta
         self.t_scale = t_scale
+        self.gradient_checkpointing = False

         assert len(all_patch_size) == len(all_f_patch_size)

@@ -579,29 +581,18 @@ def forward(
         x = list(x.split(x_item_seqlens, dim=0))
         x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))

-        pad_tensor = torch.zeros((1, self.dim), dtype=x[0].dtype, device=device)
-        freqs_pad_tensor = torch.zeros(
-            (1, self.dim // self.n_heads // 2),
-            dtype=x_freqs_cis[0].dtype,
-            device=device,
-        )
-        x_attn_mask = torch.ones((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
-        for i, (item, freqs_item) in enumerate(zip(x, x_freqs_cis)):
-            seq_len = x_item_seqlens[i]
-            pad_len = x_max_item_seqlen - seq_len
-            x[i] = torch.cat([item, pad_tensor.repeat(pad_len, 1)])
-            x_freqs_cis[i] = torch.cat([freqs_item, freqs_pad_tensor.repeat(pad_len, 1)])
-            x_attn_mask[i, seq_len:] = 0
-        x = torch.stack(x)
-        x_freqs_cis = torch.stack(x_freqs_cis)
-
-        for layer in self.noise_refiner:
-            x = layer(
-                x,
-                x_attn_mask,
-                x_freqs_cis,
-                adaln_input,
-            )
+        x = pad_sequence(x, batch_first=True, padding_value=0.0)
+        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
+        for i, seq_len in enumerate(x_item_seqlens):
+            x_attn_mask[i, :seq_len] = 1
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for layer in self.noise_refiner:
+                x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
+        else:
+            for layer in self.noise_refiner:
+                x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)

         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
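The same swap is repeated below for the caption and unified sequences. As a side note (not part of the diff), here is a minimal standalone sketch of why the change is behavior-preserving; the names dim, seqlens, and seqs are illustrative stand-ins for the model's per-item token tensors and x_item_seqlens:

import torch
from torch.nn.utils.rnn import pad_sequence

dim = 8
seqlens = [3, 5, 2]
seqs = [torch.randn(n, dim) for n in seqlens]

# Old approach: right-pad every item with repeated zero rows, then stack.
pad_tensor = torch.zeros((1, dim), dtype=seqs[0].dtype)
max_len = max(seqlens)
manual = torch.stack([torch.cat([s, pad_tensor.repeat(max_len - len(s), 1)]) for s in seqs])

# New approach: pad_sequence batches and zero-pads in a single call.
batched = pad_sequence(seqs, batch_first=True, padding_value=0.0)
assert torch.equal(manual, batched)  # same (bsz, max_len, dim) result

# The boolean attention mask marks the valid (non-padded) positions, as in the diff.
attn_mask = torch.zeros((len(seqs), max_len), dtype=torch.bool)
for i, n in enumerate(seqlens):
    attn_mask[i, :n] = 1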
@@ -614,29 +605,18 @@ def forward(
         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
         cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))

-        # Reuse padding tensors (convert dtype if needed)
-        cap_pad_tensor = pad_tensor.to(cap_feats[0].dtype) if pad_tensor.dtype != cap_feats[0].dtype else pad_tensor
-        cap_freqs_pad_tensor = (
-            freqs_pad_tensor.to(cap_freqs_cis[0].dtype)
-            if freqs_pad_tensor.dtype != cap_freqs_cis[0].dtype
-            else freqs_pad_tensor
-        )
-        cap_attn_mask = torch.ones((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
-        for i, (item, freqs_item) in enumerate(zip(cap_feats, cap_freqs_cis)):
-            seq_len = cap_item_seqlens[i]
-            pad_len = cap_max_item_seqlen - seq_len
-            cap_feats[i] = torch.cat([item, cap_pad_tensor.repeat(pad_len, 1)])
-            cap_freqs_cis[i] = torch.cat([freqs_item, cap_freqs_pad_tensor.repeat(pad_len, 1)])
-            cap_attn_mask[i, seq_len:] = 0
-        cap_feats = torch.stack(cap_feats)
-        cap_freqs_cis = torch.stack(cap_freqs_cis)
-
-        for layer in self.context_refiner:
-            cap_feats = layer(
-                cap_feats,
-                cap_attn_mask,
-                cap_freqs_cis,
-            )
+        cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
+        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
+        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
+        for i, seq_len in enumerate(cap_item_seqlens):
+            cap_attn_mask[i, :seq_len] = 1
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for layer in self.context_refiner:
+                cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
+        else:
+            for layer in self.context_refiner:
+                cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)

         # unified
         unified = []
@@ -650,29 +630,18 @@ def forward(
         assert unified_item_seqlens == [len(_) for _ in unified]
         unified_max_item_seqlen = max(unified_item_seqlens)

-        unified_pad_tensor = pad_tensor.to(unified[0].dtype) if pad_tensor.dtype != unified[0].dtype else pad_tensor
-        unified_freqs_pad_tensor = (
-            freqs_pad_tensor.to(unified_freqs_cis[0].dtype)
-            if freqs_pad_tensor.dtype != unified_freqs_cis[0].dtype
-            else freqs_pad_tensor
-        )
-        unified_attn_mask = torch.ones((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
-        for i, (item, freqs_item) in enumerate(zip(unified, unified_freqs_cis)):
-            seq_len = unified_item_seqlens[i]
-            pad_len = unified_max_item_seqlen - seq_len
-            unified[i] = torch.cat([item, unified_pad_tensor.repeat(pad_len, 1)])
-            unified_freqs_cis[i] = torch.cat([freqs_item, unified_freqs_pad_tensor.repeat(pad_len, 1)])
-            unified_attn_mask[i, seq_len:] = 0
-        unified = torch.stack(unified)
-        unified_freqs_cis = torch.stack(unified_freqs_cis)
-
-        for layer in self.layers:
-            unified = layer(
-                unified,
-                unified_attn_mask,
-                unified_freqs_cis,
-                adaln_input,
-            )
+        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
+        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
+        unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
+        for i, seq_len in enumerate(unified_item_seqlens):
+            unified_attn_mask[i, :seq_len] = 1
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for layer in self.layers:
+                unified = self._gradient_checkpointing_func(layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input)
+        else:
+            for layer in self.layers:
+                unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)

         unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
         unified = list(unified.unbind(dim=0))
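The new checkpointing branches are inert by default: gradient_checkpointing is initialized to False, and inference usually runs under no_grad. A standalone sketch of the pattern (not part of the diff): TinyRefiner is a toy stand-in for the refiner stacks above, and torch.utils.checkpoint.checkpoint plays the role of the _gradient_checkpointing_func that diffusers installs when checkpointing is enabled on the model, e.g. via ModelMixin's enable_gradient_checkpointing().

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyRefiner(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])
        self.gradient_checkpointing = False  # mirrors the attribute added in __init__

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer in self.layers:
                # Activations are dropped now and recomputed during backward.
                x = checkpoint(layer, x, use_reentrant=False)
        else:
            for layer in self.layers:
                x = layer(x)
        return x

model = TinyRefiner()
model.gradient_checkpointing = True      # roughly what enable_gradient_checkpointing() does
x = torch.randn(2, 16, requires_grad=True)
model(x).sum().backward()                # checkpointed layers are recomputed here

with torch.no_grad():                    # at inference the checkpointing branch is skipped
    _ = model(x)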
