Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,12 @@ def _get_pipeline_class(
return class_obj

diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
pipeline_cls = getattr(diffusers_module, config["_class_name"])
class_name = config["_class_name"]

if class_name.startswith("Flax"):
class_name = class_name[4:]

pipeline_cls = getattr(diffusers_module, class_name)

if load_connected_pipeline:
from .auto_pipeline import _get_connected_pipeline
Expand Down
10 changes: 9 additions & 1 deletion tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
UniPCMultistepScheduler,
logging,
)
from diffusers.pipelines.pipeline_utils import variant_compatible_siblings
from diffusers.pipelines.pipeline_utils import _get_pipeline_class, variant_compatible_siblings
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import (
CONFIG_NAME,
Expand Down Expand Up @@ -805,6 +805,14 @@ def test_download_ignore_files(self):
assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files)
assert len(files) == 14

def test_get_pipeline_class_from_flax(self):
flax_config = {"_class_name": "FlaxStableDiffusionPipeline"}
config = {"_class_name": "StableDiffusionPipeline"}

# when loading a PyTorch Pipeline from a FlaxPipeline `model_index.json`, e.g.: https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-lms-pipe/blob/7a9063578b325779f0f1967874a6771caa973cad/model_index.json#L2
# we need to make sure that we don't load the Flax Pipeline class, but instead the PyTorch pipeline class
assert _get_pipeline_class(DiffusionPipeline, flax_config) == _get_pipeline_class(DiffusionPipeline, config)


class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self):
Expand Down