Commit 71f24e3

Commit message: up
1 parent 8abc7ae commit 71f24e3

15 files changed: +63 additions, -350 deletions

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 2 additions & 23 deletions

@@ -27,7 +27,7 @@
 from ..modeling_utils import ModelMixin
 from ..normalization import RMSNorm, get_normalization
 from ..transformers.sana_transformer import GLUMBConv
-from .vae import DecoderOutput, EncoderOutput
+from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput


 class ResBlock(nn.Module):
@@ -378,7 +378,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in
     [SANA](https://huggingface.co/papers/2410.10629).
@@ -536,27 +536,6 @@ def enable_tiling(
         self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
         self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

-    def disable_tiling(self) -> None:
-        r"""
-        Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_tiling = False
-
-    def enable_slicing(self) -> None:
-        r"""
-        Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
-        decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self) -> None:
-        r"""
-        Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, height, width = x.shape

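
The enable/disable helpers deleted throughout this commit are replaced by the shared `AutoencoderMixin` now imported from `.vae` and added to each class's bases. The mixin's definition is not part of the excerpt shown here; the following is only a minimal sketch of what it presumably provides, inferred from the duplicated methods being removed (method bodies and docstrings are assumptions):

# Hypothetical sketch of AutoencoderMixin, assumed to live in
# src/diffusers/models/autoencoders/vae.py (not shown in this excerpt).
# It centralizes the tiling/slicing toggles that each autoencoder class
# previously re-implemented verbatim.
class AutoencoderMixin:
    def enable_tiling(self) -> None:
        # Split inputs into tiles so large images/videos can be processed piecewise.
        self.use_tiling = True

    def disable_tiling(self) -> None:
        # Fall back to encoding/decoding in a single pass.
        self.use_tiling = False

    def enable_slicing(self) -> None:
        # Process the batch one slice at a time to lower peak memory.
        self.use_slicing = True

    def disable_slicing(self) -> None:
        # Fall back to decoding the whole batch in one step.
        self.use_slicing = False

Classes whose `enable_tiling` takes model-specific tile sizes (for example AutoencoderDC above and AutoencoderKLCogVideoX below) keep their own override, while the remaining toggles are dropped in favor of the mixin.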

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 2 additions & 31 deletions

@@ -32,10 +32,10 @@
 )
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
-from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
+from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder


-class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
+class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

@@ -138,35 +138,6 @@ def __init__(
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25

-    def enable_tiling(self, use_tiling: bool = True):
-        r"""
-        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
-        processing larger images.
-        """
-        self.use_tiling = use_tiling
-
-    def disable_tiling(self):
-        r"""
-        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.enable_tiling(False)
-
-    def enable_slicing(self):
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
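
Since the toggles move into a shared mixin rather than disappearing, the caller-facing API stays the same. A small usage sketch (the checkpoint id is illustrative; any `AutoencoderKL` checkpoint behaves the same way):

import torch
from diffusers import AutoencoderKL

# Load a KL autoencoder; after this commit the tiling/slicing toggles are
# inherited from AutoencoderMixin instead of being defined on AutoencoderKL.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

vae.enable_tiling()   # decode large images tile by tile
vae.enable_slicing()  # decode the batch one sample at a time

latents = torch.randn(1, 4, 64, 64)
with torch.no_grad():
    image = vae.decode(latents).sample  # DecoderOutput.sample

vae.disable_slicing()
vae.disable_tiling()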

src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 2 additions & 30 deletions

@@ -28,6 +28,7 @@
 from ..modeling_utils import ModelMixin
 from ..resnet import ResnetBlock2D
 from ..upsampling import Upsample2D
+from .vae import AutoencoderMixin


 class AllegroTemporalConvLayer(nn.Module):
@@ -673,7 +674,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         return sample


-class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
+class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
     r"""
     A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
     [Allegro](https://github.com/rhymes-ai/Allegro).
@@ -795,35 +796,6 @@ def __init__(
             sample_size - self.tile_overlap_w,
         )

-    def enable_tiling(self) -> None:
-        r"""
-        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
-        processing larger images.
-        """
-        self.use_tiling = True
-
-    def disable_tiling(self) -> None:
-        r"""
-        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_tiling = False
-
-    def enable_slicing(self) -> None:
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self) -> None:
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         # TODO(aryan)
         # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 2 additions & 23 deletions

@@ -29,7 +29,7 @@
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
 from ..upsampling import CogVideoXUpsample3D
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -955,7 +955,7 @@ def forward(
         return hidden_states, new_conv_cache


-class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
     [CogVideoX](https://github.com/THUDM/CogVideo).
@@ -1124,27 +1124,6 @@ def enable_tiling(
         self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
         self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width

-    def disable_tiling(self) -> None:
-        r"""
-        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_tiling = False
-
-    def enable_slicing(self) -> None:
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self) -> None:
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape


src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py

Lines changed: 2 additions & 23 deletions

@@ -24,7 +24,7 @@
 from ...utils.accelerate_utils import apply_forward_hook
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, IdentityDistribution
+from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution


 logger = get_logger(__name__)
@@ -875,7 +875,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
+class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
     r"""
     Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).

@@ -1031,27 +1031,6 @@ def enable_tiling(
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
         self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

-    def disable_tiling(self) -> None:
-        r"""
-        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_tiling = False
-
-    def enable_slicing(self) -> None:
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self) -> None:
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         x = self.encoder(x)
         enc = self.quant_conv(x)

src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 2 additions & 24 deletions

@@ -18,7 +18,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
@@ -27,7 +26,7 @@
 from ..attention_processor import Attention
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -625,7 +624,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
+class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
     r"""
     A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
     Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
@@ -764,27 +763,6 @@ def enable_tiling(
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
         self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

-    def disable_tiling(self) -> None:
-        r"""
-        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_tiling = False
-
-    def enable_slicing(self) -> None:
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self) -> None:
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape


src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 2 additions & 23 deletions

@@ -26,7 +26,7 @@
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import RMSNorm
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


 class LTXVideoCausalConv3d(nn.Module):
@@ -1034,7 +1034,7 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No
         return hidden_states


-class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
     [LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -1219,27 +1219,6 @@ def enable_tiling(
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
         self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

-    def disable_tiling(self) -> None:
-        r"""
-        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_tiling = False
-
-    def enable_slicing(self) -> None:
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self) -> None:
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
     def _encode(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = x.shape


src/diffusers/models/autoencoders/autoencoder_kl_magvit.py

Lines changed: 2 additions & 23 deletions

@@ -26,7 +26,7 @@
 from ..activations import get_activation
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -663,7 +663,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states


-class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
+class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
     model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
@@ -805,27 +805,6 @@ def enable_tiling(
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
         self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

-    def disable_tiling(self) -> None:
-        r"""
-        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_tiling = False
-
-    def enable_slicing(self) -> None:
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self) -> None:
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
     @apply_forward_hook
     def _encode(
         self, x: torch.Tensor, return_dict: bool = True
