91 changes: 91 additions & 0 deletions tests/single_file/single_file_testing_utils.py
@@ -1,3 +1,4 @@
import gc
import tempfile
from io import BytesIO

@@ -9,7 +10,10 @@
from diffusers.models.attention_processor import AttnProcessor

from ..testing_utils import (
    backend_empty_cache,
    nightly,
    numpy_cosine_similarity_distance,
    require_torch_accelerator,
    torch_device,
)

@@ -47,6 +51,93 @@ def download_diffusers_config(repo_id, tmpdir):
    return path


@nightly
@require_torch_accelerator
class SingleFileModelTesterMixin:
    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_model_config(self):
        pretrained_kwargs = {}
        single_file_kwargs = {}

        if hasattr(self, "subfolder") and self.subfolder:
            pretrained_kwargs["subfolder"] = self.subfolder

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            pretrained_kwargs["torch_dtype"] = self.torch_dtype
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert model.config[param_name] == param_value, (
                f"{param_name} differs between pretrained loading and single file loading"
            )

    def test_single_file_model_parameters(self):
        pretrained_kwargs = {}
        single_file_kwargs = {}

        if hasattr(self, "subfolder") and self.subfolder:
            pretrained_kwargs["subfolder"] = self.subfolder

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            pretrained_kwargs["torch_dtype"] = self.torch_dtype
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        state_dict = model.state_dict()
        state_dict_single_file = model_single_file.state_dict()

        assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
            "Model parameter keys differ between pretrained and single file loading"
        )

        for key in state_dict.keys():
            param = state_dict[key]
            param_single_file = state_dict_single_file[key]

            assert param.shape == param_single_file.shape, (
                f"Parameter shape mismatch for {key}: "
                f"pretrained {param.shape} vs single file {param_single_file.shape}"
            )

            assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
                f"Parameter values differ for {key}: "
                f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
            )

    def test_checkpoint_altered_keys_loading(self):
        # Test loading with checkpoints that have altered keys
        if not hasattr(self, "alternate_keys_ckpt_paths") or not self.alternate_keys_ckpt_paths:
            return

        for ckpt_path in self.alternate_keys_ckpt_paths:
            backend_empty_cache(torch_device)

            single_file_kwargs = {}
            if hasattr(self, "torch_dtype") and self.torch_dtype:
                single_file_kwargs["torch_dtype"] = self.torch_dtype

            model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)

            del model
            gc.collect()
            backend_empty_cache(torch_device)

class SDSingleFileTesterMixin:
    single_file_kwargs = {}

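For context on how the new mixin is consumed: a subclass only declares class attributes and inherits the config, parameter, and altered-keys tests. A minimal sketch of a hypothetical adopter (the class name, checkpoint URL, and repo id below are illustrative placeholders, not part of this PR):

import torch

from diffusers import FluxTransformer2DModel

from .single_file_testing_utils import SingleFileModelTesterMixin


# Hypothetical adopter of the mixin; only the class attributes are test-specific.
class TestExampleModelSingleFile(SingleFileModelTesterMixin):
    model_class = FluxTransformer2DModel
    ckpt_path = "https://huggingface.co/org/repo/blob/main/model.safetensors"  # placeholder URL
    repo_id = "org/example-diffusers-repo"  # placeholder diffusers-format repo
    subfolder = "transformer"  # optional; read via hasattr and forwarded to from_pretrained
    torch_dtype = torch.bfloat16  # optional; forwarded to both loaders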
41 changes: 3 additions & 38 deletions tests/single_file/test_lumina2_transformer.py
@@ -13,61 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

from diffusers import (
    Lumina2Transformer2DModel,
)

from ..testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    require_torch_accelerator,
    torch_device,
)
from .single_file_testing_utils import SingleFileModelTesterMixin


enable_full_determinism()


@require_torch_accelerator
class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
class TestLumina2Transformer2DModelSingleFile(SingleFileModelTesterMixin):
    model_class = Lumina2Transformer2DModel
    ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
    alternate_keys_ckpt_paths = [
        "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
    ]

    repo_id = "Alpha-VLLM/Lumina-Image-2.0"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert model.config[param_name] == param_value, (
                f"{param_name} differs between single file loading and pretrained loading"
            )

    def test_checkpoint_loading(self):
        for ckpt_path in self.alternate_keys_ckpt_paths:
            backend_empty_cache(torch_device)
            model = self.model_class.from_single_file(ckpt_path)

            del model
            gc.collect()
            backend_empty_cache(torch_device)
    subfolder = "transformer"
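With the boilerplate gone, the inherited tests compare the two loading paths side by side. Spelled out for this file (both paths taken verbatim from the class attributes above), they reduce to roughly:

from diffusers import Lumina2Transformer2DModel

# Multi-folder Hub repo in diffusers format.
model = Lumina2Transformer2DModel.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", subfolder="transformer"
)

# Repackaged single-file checkpoint.
model_single_file = Lumina2Transformer2DModel.from_single_file(
    "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/"
    "split_files/diffusion_models/lumina_2_model_bf16.safetensors"
)

# The mixin then asserts that configs and state dicts match.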
32 changes: 2 additions & 30 deletions tests/single_file/test_model_autoencoder_dc_single_file.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch

@@ -23,38 +21,24 @@
)

from ..testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    load_hf_numpy,
    numpy_cosine_similarity_distance,
    require_torch_accelerator,
    slow,
    torch_device,
)
from .single_file_testing_utils import SingleFileModelTesterMixin


enable_full_determinism()


@slow
@require_torch_accelerator
class AutoencoderDCSingleFileTests(unittest.TestCase):
class TestAutoencoderDCSingleFile(SingleFileModelTesterMixin):
    model_class = AutoencoderDC
    ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
    repo_id = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"
    main_input_name = "sample"
    base_precision = 1e-2

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

@@ -80,18 +64,6 @@ def test_single_file_inference_same_as_pretrained(self):

        assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id)
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert model.config[param_name] == param_value, (
                f"{param_name} differs between pretrained loading and single file loading"
            )

    def test_single_file_in_type_variant_components(self):
        # `in` variant checkpoints require passing in a `config` parameter
        # in order to set the scaling factor correctly.
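The `in`-variant test body is truncated in this view; the pattern it exercises is an explicit `config` override in `from_single_file`. A sketch of that pattern, where both ids are assumptions inferred from the f32c32 naming above rather than values visible in this diff:

from diffusers import AutoencoderDC

# `in` variant checkpoints don't carry enough metadata to set the scaling
# factor, so a diffusers-format config repo is passed explicitly.
# Both ids below are assumed, mirroring the sana-1.0 ids in this file.
model = AutoencoderDC.from_single_file(
    "https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0/blob/main/model.safetensors",
    config="mit-han-lab/dc-ae-f32c32-in-1.0-diffusers",
)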
33 changes: 2 additions & 31 deletions tests/single_file/test_model_controlnet_single_file.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch

@@ -23,46 +21,19 @@
)

from ..testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    require_torch_accelerator,
    slow,
    torch_device,
)
from .single_file_testing_utils import SingleFileModelTesterMixin


enable_full_determinism()


@slow
@require_torch_accelerator
class ControlNetModelSingleFileTests(unittest.TestCase):
class TestControlNetModelSingleFile(SingleFileModelTesterMixin):
    model_class = ControlNetModel
    ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
    repo_id = "lllyasviel/control_v11p_sd15_canny"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id)
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert model.config[param_name] == param_value, (
                f"{param_name} differs between single file loading and pretrained loading"
            )

    def test_single_file_arguments(self):
        model_default = self.model_class.from_single_file(self.ckpt_path)

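`test_single_file_arguments` is also truncated here; it verifies that kwargs passed to `from_single_file` override the config inferred from the checkpoint. A sketch of that check (the specific `upcast_attention` kwarg is an assumed example, not confirmed by this diff):

from diffusers import ControlNetModel

ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"

model_default = ControlNetModel.from_single_file(ckpt_path)

# A loader kwarg should win over the config inferred from the checkpoint;
# `upcast_attention` is an assumed example of such an override.
model = ControlNetModel.from_single_file(ckpt_path, upcast_attention=True)
assert model.config.upcast_attention is True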
38 changes: 3 additions & 35 deletions tests/single_file/test_model_flux_transformer_single_file.py
@@ -14,7 +14,6 @@
# limitations under the License.

import gc
import unittest

from diffusers import (
    FluxTransformer2DModel,
@@ -23,52 +22,21 @@
from ..testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    require_torch_accelerator,
    torch_device,
)
from .single_file_testing_utils import SingleFileModelTesterMixin


enable_full_determinism()


@require_torch_accelerator
class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
class TestFluxTransformer2DModelSingleFile(SingleFileModelTesterMixin):
    model_class = FluxTransformer2DModel
    ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
    alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]

    repo_id = "black-forest-labs/FLUX.1-dev"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert model.config[param_name] == param_value, (
                f"{param_name} differs between single file loading and pretrained loading"
            )

    def test_checkpoint_loading(self):
        for ckpt_path in self.alternate_keys_ckpt_paths:
            backend_empty_cache(torch_device)
            model = self.model_class.from_single_file(ckpt_path)

            del model
            gc.collect()
            backend_empty_cache(torch_device)
    subfolder = "transformer"

    def test_device_map_cuda(self):
        backend_empty_cache(torch_device)
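`test_device_map_cuda` is truncated as well; by its name it loads the single-file checkpoint directly onto the accelerator via `device_map`. A sketch under that assumption (the assertion is a guess at what the test verifies):

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors",
    device_map="cuda",
)

# Every parameter should end up on the accelerator.
assert all(p.device.type == "cuda" for p in model.parameters())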
3 changes: 1 addition & 2 deletions tests/single_file/test_model_motion_adapter_single_file.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from diffusers import (
    MotionAdapter,
@@ -27,7 +26,7 @@
enable_full_determinism()


class MotionAdapterSingleFileTests(unittest.TestCase):
class MotionAdapterSingleFileTests:
    model_class = MotionAdapter

    def test_single_file_components_version_v1_5(self):
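The conversion applied throughout this PR is the standard unittest-to-pytest move: drop the `unittest.TestCase` base and replace `setUp`/`tearDown` with `setup_method`/`teardown_method`, as the new mixin does. A condensed before/after sketch:

import gc
import unittest


# Before: unittest-style lifecycle hooks, as removed throughout this PR.
class OldStyleTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()


# After: a plain class with pytest-style hooks, as in SingleFileModelTesterMixin.
class NewStyleTests:
    def setup_method(self):
        gc.collect()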