Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 47 additions & 36 deletions tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import pytest
import torch

from diffusers import AutoencoderKLLTX2Audio
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import (
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ...testing_utils import is_flaky, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin


class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTX2Audio
main_input_name = "sample"
base_precision = 1e-2
class AutoencoderKLLTX2AudioTesterConfig(BaseModelTesterConfig):
@property
def main_input_name(self):
return "sample"

@property
def model_class(self):
return AutoencoderKLLTX2Audio

def get_autoencoder_kl_ltx_video_config(self):
@property
def output_shape(self):
return (2, 5, 16)

@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_init_dict(self):
return {
"in_channels": 2, # stereo,
"output_channels": 2,
Expand All @@ -50,39 +61,39 @@ def get_autoencoder_kl_ltx_video_config(self):
"double_z": True,
}

@property
def dummy_input(self):
def get_dummy_inputs(self):
batch_size = 2
num_channels = 2
num_frames = 8
num_mel_bins = 16
spectrogram = randn_tensor(
(batch_size, num_channels, num_frames, num_mel_bins),
generator=self.generator,
device=torch_device,
)
return {"sample": spectrogram}

spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device)

input_dict = {"sample": spectrogram}
return input_dict
class TestAutoencoderKLLTX2Audio(AutoencoderKLLTX2AudioTesterConfig, ModelTesterMixin):
base_precision = 1e-2

@property
def input_shape(self):
return (2, 5, 16)
def test_outputs_equivalence(self):
pytest.skip("Unsupported test.")

@property
def output_shape(self):
return (2, 5, 16)

def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
class TestAutoencoderKLLTX2AudioTraining(AutoencoderKLLTX2AudioTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLLTX2Audio."""

# Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE
def test_output(self):
super().test_output(expected_output_shape=(2, 2, 5, 16))

@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
class TestAutoencoderKLLTX2AudioMemory(AutoencoderKLLTX2AudioTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKLLTX2Audio."""

@is_flaky()
@pytest.mark.parametrize("record_stream", [False, True])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
def test_group_offloading_with_disk(self, tmp_path, record_stream, offload_type, atol=1e-5, rtol=0):
super().test_group_offloading_with_disk(tmp_path, record_stream, offload_type, atol=atol, rtol=rtol)


@unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
class TestAutoencoderKLLTX2AudioSlicingTiling(AutoencoderKLLTX2AudioTesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKLLTX2Audio."""
74 changes: 40 additions & 34 deletions tests/models/autoencoders/test_models_autoencoder_ltx2_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import pytest
import torch

from diffusers import AutoencoderKLLTX2Video
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTX2Video
main_input_name = "sample"
base_precision = 1e-2
class AutoencoderKLLTX2VideoTesterConfig(BaseModelTesterConfig):
@property
def main_input_name(self):
return "sample"

@property
def model_class(self):
return AutoencoderKLLTX2Video

def get_autoencoder_kl_ltx_video_config(self):
@property
def output_shape(self):
return (3, 9, 16, 16)

@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_init_dict(self):
return {
"in_channels": 3,
"out_channels": 3,
Expand All @@ -59,30 +69,26 @@ def get_autoencoder_kl_ltx_video_config(self):
"decoder_spatial_padding_mode": "zeros",
}

@property
def dummy_input(self):
def get_dummy_inputs(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = randn_tensor(
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
)
return {"sample": image}

image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)

input_dict = {"sample": image}
return input_dict
class TestAutoencoderKLLTX2Video(AutoencoderKLLTX2VideoTesterConfig, ModelTesterMixin):
base_precision = 1e-2

@property
def input_shape(self):
return (3, 9, 16, 16)
def test_outputs_equivalence(self):
pytest.skip("Unsupported test.")

@property
def output_shape(self):
return (3, 9, 16, 16)

def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
class TestAutoencoderKLLTX2VideoTraining(AutoencoderKLLTX2VideoTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLLTX2Video."""

def test_gradient_checkpointing_is_applied(self):
expected_set = {
Expand All @@ -94,10 +100,10 @@ def test_gradient_checkpointing_is_applied(self):
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass

@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
class TestAutoencoderKLLTX2VideoMemory(AutoencoderKLLTX2VideoTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKLLTX2Video."""


class TestAutoencoderKLLTX2VideoSlicingTiling(AutoencoderKLLTX2VideoTesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKLLTX2Video."""
Loading