diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index fb120ebc7d3b..b357fe6a8b29 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -342,7 +342,12 @@ def _get_pipeline_class( return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) - pipeline_cls = getattr(diffusers_module, config["_class_name"]) + class_name = config["_class_name"] + + if class_name.startswith("Flax"): + class_name = class_name[4:] + + pipeline_cls = getattr(diffusers_module, class_name) if load_connected_pipeline: from .auto_pipeline import _get_connected_pipeline diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 927c5ec28518..5a0c300c60c4 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -57,7 +57,7 @@ UniPCMultistepScheduler, logging, ) -from diffusers.pipelines.pipeline_utils import variant_compatible_siblings +from diffusers.pipelines.pipeline_utils import _get_pipeline_class, variant_compatible_siblings from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import ( CONFIG_NAME, @@ -805,6 +805,14 @@ def test_download_ignore_files(self): assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files) assert len(files) == 14 + def test_get_pipeline_class_from_flax(self): + flax_config = {"_class_name": "FlaxStableDiffusionPipeline"} + config = {"_class_name": "StableDiffusionPipeline"} + + # when loading a PyTorch Pipeline from a FlaxPipeline `model_index.json`, e.g.: https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-lms-pipe/blob/7a9063578b325779f0f1967874a6771caa973cad/model_index.json#L2 + # we need to make sure that we don't load the Flax Pipeline class, but instead the PyTorch pipeline class + assert _get_pipeline_class(DiffusionPipeline, flax_config) == _get_pipeline_class(DiffusionPipeline, config) + class CustomPipelineTests(unittest.TestCase): def test_load_custom_pipeline(self):