From 49332b72ace7212cde2ba25766781492b5a968ac Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 31 Oct 2024 21:51:42 +0530
Subject: [PATCH 01/10] fix device_map check behaviour.

---
 src/diffusers/pipelines/pipeline_loading_utils.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 7d42ed5bcba8..3d4004db984f 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -951,6 +951,15 @@ def _get_ignore_patterns(


 def model_has_device_map(model):
-    if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
+    if not (is_accelerate_available() and not is_accelerate_version("<", "0.14.0")):
         return False
-    return getattr(model, "hf_device_map", None) is not None
+
+    # Check if the model has a device map that is not exclusively CPU
+    # `device_map` can only contain CPU when a model has sharded checkpoints.
+    # See here: https://github.com/huggingface/diffusers/blob/41e4779d988ead99e7acd78dc8e752de88777d0f/src/diffusers/models/modeling_utils.py#L883
+    device_map = getattr(model, "hf_device_map", None)
+    if device_map is not None:
+        unique_devices = set(device_map.values())
+        return len(unique_devices) > 1 or unique_devices != {"cpu"}
+
+    return False
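Before moving on, it helps to pin down the predicate this patch introduces. Below is a minimal, self-contained sketch, not part of the patch itself: the helper name and the literal device maps are illustrative, and in the real code the map comes from `model.hf_device_map` behind the accelerate-version guard.

def has_non_cpu_device_map(device_map):
    # No map at all: the model was never device-mapped.
    if device_map is None:
        return False
    # A map whose values are exclusively "cpu" is what reloading a sharded
    # checkpoint produces, so it must not count as "device mapped".
    unique_devices = set(device_map.values())
    return len(unique_devices) > 1 or unique_devices != {"cpu"}

assert not has_non_cpu_device_map(None)
assert not has_non_cpu_device_map({"": "cpu"})  # sharded reload, CPU-only
assert has_non_cpu_device_map({"unet": "cuda:0"})  # genuinely mapped
assert has_non_cpu_device_map({"unet": "cuda:0", "vae": "cpu"})  # mixed placement

Note that the `len(unique_devices) > 1` clause is subsumed by `unique_devices != {"cpu"}`: any set with more than one element already differs from `{"cpu"}`.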
From 014bdfc414c15e7ef177d6e0f4eb39766dda15fb Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 31 Oct 2024 22:13:20 +0530
Subject: [PATCH 02/10] add tests

---
 src/diffusers/loaders/lora_base.py       |  7 +------
 src/diffusers/loaders/unet.py            |  7 +------
 tests/pipelines/test_pipelines_common.py | 25 +++++++++++++++++++++++-
 3 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index a13f8c20112a..10e04883f5ec 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -25,13 +25,13 @@
 from huggingface_hub.constants import HF_HUB_OFFLINE

 from ..models.modeling_utils import ModelMixin, load_state_dict
+from ..pipelines.pipeline_loading_utils import model_has_device_map
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
     delete_adapter_layers,
     deprecate,
     is_accelerate_available,
-    is_accelerate_version,
     is_peft_available,
     is_transformers_available,
     logging,
@@ -215,11 +215,6 @@ def _optionally_disable_offloading(cls, _pipeline):
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False

-        def model_has_device_map(model):
-            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
-                return False
-            return getattr(model, "hf_device_map", None) is not None
-
         if _pipeline is not None and _pipeline.hf_device_map is None:
             for _, component in _pipeline.components.items():
                 if (
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 55b1a24e60db..16b2579a4039 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -32,6 +32,7 @@
     MultiIPAdapterImageProjection,
 )
 from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
+from ..pipelines.pipeline_loading_utils import model_has_device_map
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
@@ -39,7 +40,6 @@
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
-    is_accelerate_version,
     is_peft_version,
     is_torch_version,
     logging,
@@ -399,11 +399,6 @@ def _optionally_disable_offloading(cls, _pipeline):
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False

-        def model_has_device_map(model):
-            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
-                return False
-            return getattr(model, "hf_device_map", None) is not None
-
         if _pipeline is not None and _pipeline.hf_device_map is None:
             for _, component in _pipeline.components.items():
                 if (
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index f5ceda8f2703..efcfe80a42ca 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -32,12 +32,13 @@
 from diffusers.loaders import IPAdapterMixin
 from diffusers.models.attention_processor import AttnProcessor
 from diffusers.models.controlnet_xs import UNetControlNetXSModel
+from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
 from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
 from diffusers.models.unets.unet_motion_model import UNetMotionModel
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import logging
+from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, logging
 from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -1987,6 +1988,28 @@ def test_calling_sco_raises_error_device_mapped_components(self, safe_serializat
             and "This is incompatible with `enable_sequential_cpu_offload()`" in str(err_context.exception)
         )

+    def test_sharded_components_can_be_device_placed(self):
+        components = self.get_dummy_components()
+
+        component_selected = None
+        for component_name in components:
+            if isinstance(components[component_name], ModelMixin):
+                component_to_be_sharded = components[component_name]
+                component_cls = component_to_be_sharded.__class__
+                component_selected = component_name
+                break
+        model_size = compute_module_sizes(component_to_be_sharded)[""]
+        max_shard_size = int((model_size * 0.75) / (2**10))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            component_to_be_sharded.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+            loaded_sharded_component = component_cls.from_pretrained(tmp_dir)
+            _ = components.pop(component_selected)
+            components.update({component_selected: loaded_sharded_component})
+            _ = self.pipeline_class(**components).to(torch_device)
+

 @is_staging_test
 class PipelinePushToHubTester(unittest.TestCase):

From 55dc9368619d7d649c5dba65bb6396504d9f78df Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 31 Oct 2024 22:14:35 +0530
Subject: [PATCH 03/10] updates

---
 src/diffusers/pipelines/pipeline_loading_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 3d4004db984f..769bc84216f4 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -951,7 +951,7 @@ def _get_ignore_patterns(


 def model_has_device_map(model):
-    if not (is_accelerate_available() and not is_accelerate_version("<", "0.14.0")):
+    if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
         return False

     # Check if the model has a device map that is not exclusively CPU
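The test added in PATCH 02 forces `save_pretrained` to actually shard: `compute_module_sizes(component)[""]` is the component's total size in bytes, and capping shards at 75% of that (floored to whole KB) should produce at least two files plus the `SAFE_WEIGHTS_INDEX_NAME` index the test asserts on. A worked sketch of the arithmetic, with a made-up size:

import math

model_size = 1_000_000  # bytes; stand-in for compute_module_sizes(component)[""]
max_shard_size_kb = int((model_size * 0.75) / (2**10))
assert max_shard_size_kb == 732  # passed to save_pretrained as "732KB"

# No shard may exceed 75% of the total, so at least two shards are written
# (up to tensor-boundary effects; see the 0.45 override later in the series).
num_shards = math.ceil(model_size / (max_shard_size_kb * 2**10))
assert num_shards == 2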
From 222e70c4a95b5fa9524839513aed7cf762cbc7a1 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 31 Oct 2024 22:19:08 +0530
Subject: [PATCH 04/10] more robust condition.

---
 tests/pipelines/test_pipelines_common.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index efcfe80a42ca..9bfe18a8f48e 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1993,7 +1993,9 @@ def test_sharded_components_can_be_device_placed(self):

         component_selected = None
         for component_name in components:
-            if isinstance(components[component_name], ModelMixin):
+            if isinstance(components[component_name], ModelMixin) and hasattr(
+                components[component_name], "load_config"
+            ):
                 component_to_be_sharded = components[component_name]
                 component_cls = component_to_be_sharded.__class__
                 component_selected = component_name

From 94069ffabf11592184b944cbb397d97eb504c707 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 31 Oct 2024 22:19:49 +0530
Subject: [PATCH 05/10] assertion.

---
 tests/pipelines/test_pipelines_common.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 9bfe18a8f48e..bdd502aee2b6 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -2000,6 +2000,9 @@ def test_sharded_components_can_be_device_placed(self):
                 component_cls = component_to_be_sharded.__class__
                 component_selected = component_name
                 break
+        
+        assert component_selected, "No component selected that can be sharded."
+
         model_size = compute_module_sizes(component_to_be_sharded)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))

From 36722c321b25e396f8c5dae8d48cb897e1381c94 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 31 Oct 2024 22:21:14 +0530
Subject: [PATCH 06/10] quality

---
 tests/pipelines/test_pipelines_common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index bdd502aee2b6..359799a4cba9 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -2000,7 +2000,7 @@ def test_sharded_components_can_be_device_placed(self):
                 component_cls = component_to_be_sharded.__class__
                 component_selected = component_name
                 break
-        
+
         assert component_selected, "No component selected that can be sharded."

         model_size = compute_module_sizes(component_to_be_sharded)[""]
""" + from ..pipelines.pipeline_loading_utils import model_has_device_map + is_model_cpu_offload = False is_sequential_cpu_offload = False diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 16b2579a4039..03f6b6de7c7e 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -32,7 +32,6 @@ MultiIPAdapterImageProjection, ) from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict -from ..pipelines.pipeline_loading_utils import model_has_device_map from ..utils import ( USE_PEFT_BACKEND, _get_model_file, @@ -396,6 +395,8 @@ def _optionally_disable_offloading(cls, _pipeline): tuple: A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. """ + from ..pipelines.pipeline_loading_utils import model_has_device_map + is_model_cpu_offload = False is_sequential_cpu_offload = False From 7113453974ab3e1addab24be4d242ae3decb1278 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Oct 2024 22:47:09 +0530 Subject: [PATCH 08/10] fix --- .../kandinsky/test_kandinsky_prior.py | 38 ++++++++++++++- .../kandinsky2_2/test_kandinsky_prior.py | 38 ++++++++++++++- .../test_kandinsky_prior_emb2emb.py | 37 +++++++++++++++ .../stable_unclip/test_stable_unclip.py | 46 ++++++++++++++++++- tests/pipelines/unclip/test_unclip.py | 38 +++++++++++++++ 5 files changed, 194 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/kandinsky/test_kandinsky_prior.py b/tests/pipelines/kandinsky/test_kandinsky_prior.py index 5f42447bd9d5..628014e23ec8 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_prior.py +++ b/tests/pipelines/kandinsky/test_kandinsky_prior.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import tempfile import unittest import numpy as np @@ -28,11 +30,16 @@ ) from diffusers import KandinskyPriorPipeline, PriorTransformer, UnCLIPScheduler -from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME +from diffusers.utils.testing_utils import enable_full_determinism, is_accelerate_available, skip_mps, torch_device from ..test_pipelines_common import PipelineTesterMixin +if is_accelerate_available(): + from accelerate.utils import compute_module_sizes + enable_full_determinism() @@ -236,3 +243,32 @@ def test_attention_slicing_forward_pass(self): test_max_difference=test_max_difference, test_mean_pixel_difference=test_mean_pixel_difference, ) + + # It needs a different sharding ratio than the standard 0.75. So, we override it. + def test_sharded_components_can_be_device_placed(self): + components = self.get_dummy_components() + + component_selected = None + for component_name in components: + if isinstance(components[component_name], ModelMixin) and hasattr( + components[component_name], "load_config" + ): + component_to_be_sharded = components[component_name] + component_cls = component_to_be_sharded.__class__ + component_selected = component_name + break + + assert component_selected, "No component selected that can be sharded." 
From 7113453974ab3e1addab24be4d242ae3decb1278 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 31 Oct 2024 22:47:09 +0530
Subject: [PATCH 08/10] fix

---
 .../kandinsky/test_kandinsky_prior.py         | 38 ++++++++++++++-
 .../kandinsky2_2/test_kandinsky_prior.py      | 38 ++++++++++++++-
 .../test_kandinsky_prior_emb2emb.py           | 37 +++++++++++++++
 .../stable_unclip/test_stable_unclip.py       | 46 ++++++++++++++++++-
 tests/pipelines/unclip/test_unclip.py         | 38 +++++++++++++++
 5 files changed, 194 insertions(+), 3 deletions(-)

diff --git a/tests/pipelines/kandinsky/test_kandinsky_prior.py b/tests/pipelines/kandinsky/test_kandinsky_prior.py
index 5f42447bd9d5..628014e23ec8 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_prior.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_prior.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+import tempfile
 import unittest

 import numpy as np
@@ -28,11 +30,16 @@
 )

 from diffusers import KandinskyPriorPipeline, PriorTransformer, UnCLIPScheduler
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME
+from diffusers.utils.testing_utils import enable_full_determinism, is_accelerate_available, skip_mps, torch_device

 from ..test_pipelines_common import PipelineTesterMixin


+if is_accelerate_available():
+    from accelerate.utils import compute_module_sizes
+
 enable_full_determinism()
@@ -236,3 +243,32 @@ def test_attention_slicing_forward_pass(self):
             test_max_difference=test_max_difference,
             test_mean_pixel_difference=test_mean_pixel_difference,
         )
+
+    # It needs a different sharding ratio than the standard 0.75. So, we override it.
+    def test_sharded_components_can_be_device_placed(self):
+        components = self.get_dummy_components()
+
+        component_selected = None
+        for component_name in components:
+            if isinstance(components[component_name], ModelMixin) and hasattr(
+                components[component_name], "load_config"
+            ):
+                component_to_be_sharded = components[component_name]
+                component_cls = component_to_be_sharded.__class__
+                component_selected = component_name
+                break
+
+        assert component_selected, "No component selected that can be sharded."
+
+        model_size = compute_module_sizes(component_to_be_sharded)[""]
+        max_shard_size = int((model_size * 0.45) / (2**10))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            component_to_be_sharded.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
+            print(f"{os.listdir(tmp_dir)}")
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+            loaded_sharded_component = component_cls.from_pretrained(tmp_dir)
+            _ = components.pop(component_selected)
+            components.update({component_selected: loaded_sharded_component})
+            _ = self.pipeline_class(**components).to(torch_device)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
index be0bc238d4da..349538d5f5de 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
@@ -14,6 +14,8 @@
 # limitations under the License.

 import inspect
+import os
+import tempfile
 import unittest

 import numpy as np
@@ -29,11 +31,16 @@
 )

 from diffusers import KandinskyV22PriorPipeline, PriorTransformer, UnCLIPScheduler
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME
+from diffusers.utils.testing_utils import enable_full_determinism, is_accelerate_available, skip_mps, torch_device

 from ..test_pipelines_common import PipelineTesterMixin


+if is_accelerate_available():
+    from accelerate.utils import compute_module_sizes
+
 enable_full_determinism()
@@ -277,3 +284,32 @@ def callback_inputs_test(pipe, i, t, callback_kwargs):
         output = pipe(**inputs)[0]

         assert output.abs().sum() == 0
+
+    # It needs a different sharding ratio than the standard 0.75. So, we override it.
+    def test_sharded_components_can_be_device_placed(self):
+        components = self.get_dummy_components()
+
+        component_selected = None
+        for component_name in components:
+            if isinstance(components[component_name], ModelMixin) and hasattr(
+                components[component_name], "load_config"
+            ):
+                component_to_be_sharded = components[component_name]
+                component_cls = component_to_be_sharded.__class__
+                component_selected = component_name
+                break
+
+        assert component_selected, "No component selected that can be sharded."
+
+        model_size = compute_module_sizes(component_to_be_sharded)[""]
+        max_shard_size = int((model_size * 0.45) / (2**10))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            component_to_be_sharded.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
+            print(f"{os.listdir(tmp_dir)}")
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+            loaded_sharded_component = component_cls.from_pretrained(tmp_dir)
+            _ = components.pop(component_selected)
+            components.update({component_selected: loaded_sharded_component})
+            _ = self.pipeline_class(**components).to(torch_device)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
index e898824e2d17..2412f652884a 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import random
+import tempfile
 import unittest

 import numpy as np
@@ -30,9 +32,12 @@
 )

 from diffusers import KandinskyV22PriorEmb2EmbPipeline, PriorTransformer, UnCLIPScheduler
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
+    is_accelerate_available,
     skip_mps,
     torch_device,
 )

 from ..test_pipelines_common import PipelineTesterMixin


+if is_accelerate_available():
+    from accelerate.utils import compute_module_sizes
+
 enable_full_determinism()
@@ -240,3 +248,32 @@ def test_attention_slicing_forward_pass(self):
             test_max_difference=test_max_difference,
             test_mean_pixel_difference=test_mean_pixel_difference,
         )
+
+    # It needs a different sharding ratio than the standard 0.75. So, we override it.
+    def test_sharded_components_can_be_device_placed(self):
+        components = self.get_dummy_components()
+
+        component_selected = None
+        for component_name in components:
+            if isinstance(components[component_name], ModelMixin) and hasattr(
+                components[component_name], "load_config"
+            ):
+                component_to_be_sharded = components[component_name]
+                component_cls = component_to_be_sharded.__class__
+                component_selected = component_name
+                break
+
+        assert component_selected, "No component selected that can be sharded."
+
+        model_size = compute_module_sizes(component_to_be_sharded)[""]
+        max_shard_size = int((model_size * 0.45) / (2**10))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            component_to_be_sharded.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
+            print(f"{os.listdir(tmp_dir)}")
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+            loaded_sharded_component = component_cls.from_pretrained(tmp_dir)
+            _ = components.pop(component_selected)
+            components.update({component_selected: loaded_sharded_component})
+            _ = self.pipeline_class(**components).to(torch_device)
diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py
index be5e3783ff5c..91cc03a1adac 100644
--- a/tests/pipelines/stable_unclip/test_stable_unclip.py
+++ b/tests/pipelines/stable_unclip/test_stable_unclip.py
@@ -1,4 +1,6 @@
 import gc
+import os
+import tempfile
 import unittest

 import torch
@@ -12,8 +14,17 @@
     StableUnCLIPPipeline,
     UNet2DConditionModel,
 )
+from diffusers.models.modeling_utils import ModelMixin
 from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
-from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, nightly, require_torch_gpu, torch_device
+from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    is_accelerate_available,
+    load_numpy,
+    nightly,
+    require_torch_gpu,
+    torch_device,
+)

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import (
     ...
 )


+if is_accelerate_available():
+    from accelerate.utils import compute_module_sizes
+
+
 enable_full_determinism()
@@ -196,6 +211,35 @@ def test_calling_to_raises_error_device_mapped_components(self):
     def test_calling_sco_raises_error_device_mapped_components(self):
         pass

+    # It needs a different sharding ratio than the standard 0.75. So, we override it.
+    def test_sharded_components_can_be_device_placed(self):
+        components = self.get_dummy_components()
+
+        component_selected = None
+        for component_name in components:
+            if isinstance(components[component_name], ModelMixin) and hasattr(
+                components[component_name], "load_config"
+            ):
+                component_to_be_sharded = components[component_name]
+                component_cls = component_to_be_sharded.__class__
+                component_selected = component_name
+                break
+
+        assert component_selected, "No component selected that can be sharded."
+
+        model_size = compute_module_sizes(component_to_be_sharded)[""]
+        max_shard_size = int((model_size * 0.45) / (2**10))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            component_to_be_sharded.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
+            print(f"{os.listdir(tmp_dir)}")
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+            loaded_sharded_component = component_cls.from_pretrained(tmp_dir)
+            _ = components.pop(component_selected)
+            components.update({component_selected: loaded_sharded_component})
+            _ = self.pipeline_class(**components).to(torch_device)
+

 @nightly
 @require_torch_gpu
diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py
index 07590c9db458..f7e32a01a595 100644
--- a/tests/pipelines/unclip/test_unclip.py
+++ b/tests/pipelines/unclip/test_unclip.py
@@ -14,6 +14,8 @@
 # limitations under the License.

 import gc
+import os
+import tempfile
 import unittest

 import numpy as np
@@ -21,9 +23,12 @@
 from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer

 from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
+from diffusers.models.modeling_utils import ModelMixin
 from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
+from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
+    is_accelerate_available,
     load_numpy,
     nightly,
     require_torch_gpu,

 from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference


+if is_accelerate_available():
+    from accelerate.utils import compute_module_sizes
+
+
 enable_full_determinism()
@@ -418,6 +427,35 @@ def test_save_load_optional_components(self):
     def test_float16_inference(self):
         super().test_float16_inference(expected_max_diff=1.0)

+    # It needs a different sharding ratio than the standard 0.75. So, we override it.
+    def test_sharded_components_can_be_device_placed(self):
+        components = self.get_dummy_components()
+
+        component_selected = None
+        for component_name in components:
+            if isinstance(components[component_name], ModelMixin) and hasattr(
+                components[component_name], "load_config"
+            ):
+                component_to_be_sharded = components[component_name]
+                component_cls = component_to_be_sharded.__class__
+                component_selected = component_name
+                break
+
+        assert component_selected, "No component selected that can be sharded."
+
+        model_size = compute_module_sizes(component_to_be_sharded)[""]
+        max_shard_size = int((model_size * 0.45) / (2**10))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            component_to_be_sharded.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
+            print(f"{os.listdir(tmp_dir)}")
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+            loaded_sharded_component = component_cls.from_pretrained(tmp_dir)
+            _ = components.pop(component_selected)
+            components.update({component_selected: loaded_sharded_component})
+            _ = self.pipeline_class(**components).to(torch_device)
+

 @nightly
 class UnCLIPPipelineCPUIntegrationTests(unittest.TestCase):
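The prior and unCLIP tests override the shared test only to lower the sharding ratio from 0.75 to 0.45. The patch does not spell out why; a plausible reading is that shards split at tensor boundaries, so for these small dummy models a 75% cap can still end up as a single file. A rough sketch of the arithmetic (the byte figure is made up and `approx_num_shards` is an illustrative helper, not a diffusers API):

import math

def approx_num_shards(model_size_bytes, ratio):
    # Mirrors the tests: cap each shard at `ratio` times the total size,
    # floored to whole KB. Real sharding packs whole tensors, so this is
    # only a lower bound on the actual file count.
    max_shard_bytes = int((model_size_bytes * ratio) / (2**10)) * (2**10)
    return math.ceil(model_size_bytes / max_shard_bytes)

assert approx_num_shards(1_000_000, 0.75) == 2  # the default ratio still shards
assert approx_num_shards(1_000_000, 0.45) == 3  # the override shards more aggressively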
From 08b9533d6560a9bf036ec5b48b6a2866b3f792e2 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 1 Nov 2024 07:37:30 +0530
Subject: [PATCH 09/10] remove gaps.

---
 src/diffusers/loaders/lora_base.py | 1 -
 src/diffusers/loaders/unet.py      | 1 -
 2 files changed, 2 deletions(-)

diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 715196eec705..e124b6eeacf3 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -211,7 +211,6 @@ def _optionally_disable_offloading(cls, _pipeline):
             tuple:
                 A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
         """
-
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False

diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 29c2a74f345a..2fa7732a6a3b 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -395,7 +395,6 @@ def _optionally_disable_offloading(cls, _pipeline):
             tuple:
                 A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
         """
-
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False

From aeb969fc685ac15f1cea3098b4247dd2ec30c541 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 1 Nov 2024 07:38:54 +0530
Subject: [PATCH 10/10] remove prints.

---
 tests/pipelines/kandinsky/test_kandinsky_prior.py            | 1 -
 tests/pipelines/kandinsky2_2/test_kandinsky_prior.py         | 1 -
 tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py | 1 -
 tests/pipelines/stable_unclip/test_stable_unclip.py          | 1 -
 tests/pipelines/unclip/test_unclip.py                        | 1 -
 5 files changed, 5 deletions(-)

diff --git a/tests/pipelines/kandinsky/test_kandinsky_prior.py b/tests/pipelines/kandinsky/test_kandinsky_prior.py
index 628014e23ec8..7545ec5bb5d3 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_prior.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_prior.py
@@ -265,7 +265,6 @@ def test_sharded_components_can_be_device_placed(self):

         with tempfile.TemporaryDirectory() as tmp_dir:
             component_to_be_sharded.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
-            print(f"{os.listdir(tmp_dir)}")
             self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))

             loaded_sharded_component = component_cls.from_pretrained(tmp_dir)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
index 349538d5f5de..55dbb302b274 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
@@ -306,7 +306,6 @@ def test_sharded_components_can_be_device_placed(self):

         with tempfile.TemporaryDirectory() as tmp_dir:
             component_to_be_sharded.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
-            print(f"{os.listdir(tmp_dir)}")
             self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))

             loaded_sharded_component = component_cls.from_pretrained(tmp_dir)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
index 2412f652884a..751a667e19f9 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
@@ -270,7 +270,6 @@ def test_sharded_components_can_be_device_placed(self):

         with tempfile.TemporaryDirectory() as tmp_dir:
             component_to_be_sharded.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
-            print(f"{os.listdir(tmp_dir)}")
             self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))

             loaded_sharded_component = component_cls.from_pretrained(tmp_dir)
diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py
index 91cc03a1adac..9740d28b0b14 100644
--- a/tests/pipelines/stable_unclip/test_stable_unclip.py
+++ b/tests/pipelines/stable_unclip/test_stable_unclip.py
@@ -232,7 +232,6 @@ def test_sharded_components_can_be_device_placed(self):

         with tempfile.TemporaryDirectory() as tmp_dir:
             component_to_be_sharded.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
-            print(f"{os.listdir(tmp_dir)}")
             self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))

             loaded_sharded_component = component_cls.from_pretrained(tmp_dir)
diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py
index f7e32a01a595..a5fe670105a1 100644
--- a/tests/pipelines/unclip/test_unclip.py
+++ b/tests/pipelines/unclip/test_unclip.py
@@ -448,7 +448,6 @@ def test_sharded_components_can_be_device_placed(self):

         with tempfile.TemporaryDirectory() as tmp_dir:
             component_to_be_sharded.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
-            print(f"{os.listdir(tmp_dir)}")
             self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))

             loaded_sharded_component = component_cls.from_pretrained(tmp_dir)
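Taken together, the series makes a sharded, CPU-only reload indistinguishable from a plain load as far as the device-map guards are concerned. A hypothetical usage sketch (the repo id is illustrative; the exact device-map layout for sharded loads follows the modeling_utils reference linked in PATCH 01):

from diffusers import DiffusionPipeline

# A component reloaded from a sharded checkpoint carries a device map whose
# only value is "cpu". With these patches that no longer counts as "device
# mapped", so device placement and the offloading/LoRA paths that consult
# model_has_device_map() work again.
pipe = DiffusionPipeline.from_pretrained("some-org/some-pipeline")  # illustrative id
pipe.to("cuda")  # previously rejected when a component had a CPU-only device map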