10 changes: 9 additions & 1 deletion src/diffusers/configuration_utils.py
@@ -485,10 +485,18 @@ def extract_init_dict(cls, config_dict, **kwargs):

# remove attributes from the original class that are not expected
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
if (
isinstance(orig_cls_name, str)
and orig_cls_name != cls.__name__
and hasattr(diffusers_library, orig_cls_name)
):
orig_cls = getattr(diffusers_library, orig_cls_name)
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
raise ValueError(
"Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
)

# remove private attributes
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
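
For reference, the validation above accepts two shapes for `_class_name`; a minimal sketch with hypothetical values (the list form matches the custom-component tests at the end of this diff):

# `_class_name` as a plain string (the common case):
#     {"_class_name": "StableDiffusionXLPipeline", ...}
# `_class_name` as a [module, class] pair for custom pipelines, e.g. a
# hypothetical MyPipeline defined in my_pipeline.py at the repo root:
#     {"_class_name": ["my_pipeline", "MyPipeline"], ...}
# Any other type (e.g. an int) now raises the ValueError above.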
124 changes: 111 additions & 13 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -33,8 +33,6 @@
from requests.exceptions import HTTPError
from tqdm.auto import tqdm

import diffusers

from .. import __version__
from ..configuration_utils import ConfigMixin
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
@@ -305,13 +303,23 @@ def maybe_raise_or_warn(
)


def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module):
def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
# guard: component_name/cache_dir may be None when no custom component is involved
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None

if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)

class_obj = getattr(pipeline_module, class_name)
class_candidates = {c: class_obj for c in importable_classes.keys()}
elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# load custom component
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name
)
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
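
To make the new `elif` branch concrete, here is an illustrative repo layout (component and file names borrowed from the tests at the end of this diff) that routes a component through `get_class_from_dynamic_module` instead of a plain library import:

# hf-internal-testing/tiny-sdxl-custom-components (illustrative layout)
# ├── model_index.json             # "unet": ["my_unet_model", "MyUNetModel"]
# ├── unet/
# │   ├── my_unet_model.py         # defines MyUNetModel
# │   └── config.json
# └── scheduler/
#     ├── my_scheduler.py          # defines MyScheduler
#     └── scheduler_config.json
#
# For the "unet" component, library_name == "my_unet_model" and
# component_name == "unet", so <cached_folder>/unet/my_unet_model.py exists
# and the elif branch loads MyUNetModel from that file.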
@@ -323,19 +331,35 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module):


def _get_pipeline_class(
class_obj, config, load_connected_pipeline=False, custom_pipeline=None, cache_dir=None, revision=None
class_obj,
config,
load_connected_pipeline=False,
custom_pipeline=None,
repo_id=None,
hub_revision=None,
class_name=None,
cache_dir=None,
revision=None,
):
if custom_pipeline is not None:
if custom_pipeline.endswith(".py"):
path = Path(custom_pipeline)
# decompose into folder & file
file_name = path.name
custom_pipeline = path.parent.absolute()
elif repo_id is not None:
file_name = f"{custom_pipeline}.py"
custom_pipeline = repo_id
else:
file_name = CUSTOM_PIPELINE_FILE_NAME

return get_class_from_dynamic_module(
custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision
custom_pipeline,
module_file=file_name,
class_name=class_name,
repo_id=repo_id,
cache_dir=cache_dir,
revision=revision if hub_revision is None else hub_revision,
)

if class_obj != DiffusionPipeline:
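
A sketch of how the three `custom_pipeline` cases above are reached from user code; the repo name comes from the new tests, the rest is illustrative:

# 1. A local path ending in ".py" -> decomposed into parent folder + file name:
#    DiffusionPipeline.from_pretrained(repo, custom_pipeline="./my_pipeline.py")
# 2. A bare module name with repo_id set -> "<name>.py" inside the Hub repo
#    itself (repo_id is only passed when that file actually exists there):
#    DiffusionPipeline.from_pretrained(
#        "hf-internal-testing/tiny-sdxl-custom-components",
#        custom_pipeline="my_pipeline",
#        trust_remote_code=True,
#    )
# 3. Otherwise -> fall back to CUSTOM_PIPELINE_FILE_NAME ("pipeline.py").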
@@ -383,11 +407,18 @@ def load_sub_model(
variant: str,
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
revision: str = None,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
# retrieve class candidates
class_obj, class_candidates = get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module
library_name,
class_name,
importable_classes,
pipelines,
is_pipeline_module,
component_name=name,
cache_dir=cached_folder,
)

load_method_name = None
@@ -414,14 +445,15 @@ def load_sub_model(
load_method = getattr(class_obj, load_method_name)

# add kwargs to loading method
diffusers_module = importlib.import_module(__name__.split(".")[0])
loading_kwargs = {}
if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options

is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
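
The `importlib.import_module(__name__.split(".")[0])` idiom above replaces the removed top-level `import diffusers`:

# Inside diffusers/pipelines/pipeline_utils.py, __name__ is
# "diffusers.pipelines.pipeline_utils", so __name__.split(".")[0] yields
# "diffusers". Resolving the package lazily at call time avoids a circular
# import: the top-level diffusers package imports this module during its own
# initialization, so importing diffusers at module scope here is fragile.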
@@ -501,7 +533,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

def register_modules(self, **kwargs):
# import it here to avoid circular import
from diffusers import pipelines
diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

for name, module in kwargs.items():
# retrieve library
@@ -1080,11 +1113,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):

# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
custom_class_name = None
if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
):
custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
custom_class_name = config_dict["_class_name"][1]

pipeline_class = _get_pipeline_class(
cls,
config_dict,
load_connected_pipeline=load_connected_pipeline,
custom_pipeline=custom_pipeline,
class_name=custom_class_name,
cache_dir=cache_dir,
revision=custom_revision,
)
@@ -1223,6 +1266,7 @@ def load_module(name, value):
variant=variant,
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
revision=revision,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1542,6 +1586,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
option should only be set to `True` for repositories you trust and in which you have read the code, as
it will execute code present on the Hub on your local machine.

Returns:
`os.PathLike`:
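
A short usage sketch of the new flag (mirroring the new tests at the end of this diff):

from diffusers import DiffusionPipeline

# Without the flag, a repo containing custom .py components now raises a
# ValueError asking for `trust_remote_code=True`.
# After inspecting the repository contents, opt in explicitly:
path = DiffusionPipeline.download(
    "hf-internal-testing/tiny-sdxl-custom-components", trust_remote_code=True
)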
@@ -1569,6 +1617,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
trust_remote_code = kwargs.pop("trust_remote_code", False)

allow_pickle = False
if use_safetensors is None:
@@ -1604,15 +1653,34 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)

config_dict = cls._dict_from_json_file(config_file)

ignore_filenames = config_dict.pop("_ignore_files", [])

# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

# optionally create a custom component <> custom file mapping
custom_components = {}
for component in folder_names:
module_candidate = config_dict[component][0]

if module_candidate is None:
continue

candidate_file = os.path.join(component, module_candidate + ".py")

if candidate_file in filenames:
custom_components[component] = module_candidate
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
raise ValueError(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)

if len(variant_filenames) == 0 and variant is not None:
deprecation_message = (
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
@@ -1636,12 +1704,21 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:

model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}

custom_class_name = None
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
custom_pipeline = config_dict["_class_name"][0]
custom_class_name = config_dict["_class_name"][1]

# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)

# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
# add custom component files
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
# add custom pipeline file
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
# also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]

@@ -1652,12 +1729,32 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
CUSTOM_PIPELINE_FILE_NAME,
]

load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
load_components_from_hub = len(custom_components) > 0

if load_pipe_from_hub and not trust_remote_code:
raise ValueError(
f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)

if load_components_from_hub and not trust_remote_code:
raise ValueError(
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)

# retrieve passed components that should not be downloaded
pipeline_class = _get_pipeline_class(
cls,
config_dict,
load_connected_pipeline=load_connected_pipeline,
custom_pipeline=custom_pipeline,
repo_id=pretrained_model_name if load_pipe_from_hub else None,
hub_revision=revision,
class_name=custom_class_name,
cache_dir=cache_dir,
revision=custom_revision,
)
@@ -1754,9 +1851,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:

# retrieve pipeline class from local file
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
cls_name = cls_name[4:] if cls_name.startswith("Flax") else cls_name
cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name

pipeline_class = getattr(diffusers, cls_name, None)
diffusers_module = importlib.import_module(__name__.split(".")[0])
pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None

if pipeline_class is not None and pipeline_class._load_connected_pipes:
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
52 changes: 52 additions & 0 deletions tests/pipelines/test_pipelines.py
@@ -862,6 +862,58 @@ def test_run_custom_pipeline(self):
# compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
assert output_str == "This is a test"

def test_remote_components(self):
# make sure `trust_remote_code` is required
with self.assertRaises(ValueError):
pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-components")

# Check that loading only the custom components "my_unet" and "my_scheduler" works
pipeline = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-sdxl-custom-components", trust_remote_code=True
)

assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
assert pipeline.__class__.__name__ == "StableDiffusionXLPipeline"

pipeline = pipeline.to(torch_device)
images = pipeline("test", num_inference_steps=2, output_type="np")[0]

assert images.shape == (1, 64, 64, 3)

# Check that loading the custom components "my_unet" and "my_scheduler" plus an explicit custom pipeline works
pipeline = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-sdxl-custom-components", custom_pipeline="my_pipeline", trust_remote_code=True
)

assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
assert pipeline.__class__.__name__ == "MyPipeline"

pipeline = pipeline.to(torch_device)
images = pipeline("test", num_inference_steps=2, output_type="np")[0]

assert images.shape == (1, 64, 64, 3)

def test_remote_auto_custom_pipe(self):
# make sure `trust_remote_code` is required
with self.assertRaises(ValueError):
pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-all")

# Check that loading the custom components "my_unet" and "my_scheduler" plus an auto-detected custom pipeline works
pipeline = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-sdxl-custom-all", trust_remote_code=True
)

assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
assert pipeline.__class__.__name__ == "MyPipeline"

pipeline = pipeline.to(torch_device)
images = pipeline("test", num_inference_steps=2, output_type="np")[0]

assert images.shape == (1, 64, 64, 3)

def test_local_custom_pipeline_repo(self):
local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
pipeline = DiffusionPipeline.from_pretrained(