diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 9bc25155a0b6..a67fa9d41ca5 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -485,10 +485,18 @@ def extract_init_dict(cls, config_dict, **kwargs): # remove attributes from orig class that cannot be expected orig_cls_name = config_dict.pop("_class_name", cls.__name__) - if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name): + if ( + isinstance(orig_cls_name, str) + and orig_cls_name != cls.__name__ + and hasattr(diffusers_library, orig_cls_name) + ): orig_cls = getattr(diffusers_library, orig_cls_name) unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig} + elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)): + raise ValueError( + "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)." + ) # remove private attributes config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index bad23a60293f..512cf8d56718 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -33,8 +33,6 @@ from requests.exceptions import HTTPError from tqdm.auto import tqdm -import diffusers - from .. 
import __version__ from ..configuration_utils import ConfigMixin from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT @@ -305,13 +303,23 @@ def maybe_raise_or_warn( ) -def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module): +def get_class_obj_and_candidates( + library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None +): """Simple helper method to retrieve class object of module as well as potential parent class objects""" + component_folder = os.path.join(cache_dir, component_name) + if is_pipeline_module: pipeline_module = getattr(pipelines, library_name) class_obj = getattr(pipeline_module, class_name) class_candidates = {c: class_obj for c in importable_classes.keys()} + elif os.path.isfile(os.path.join(component_folder, library_name + ".py")): + # load custom component + class_obj = get_class_from_dynamic_module( + component_folder, module_file=library_name + ".py", class_name=class_name + ) + class_candidates = {c: class_obj for c in importable_classes.keys()} else: # else we just import it from the library. 
library = importlib.import_module(library_name) @@ -323,7 +331,15 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p def _get_pipeline_class( - class_obj, config, load_connected_pipeline=False, custom_pipeline=None, cache_dir=None, revision=None + class_obj, + config, + load_connected_pipeline=False, + custom_pipeline=None, + repo_id=None, + hub_revision=None, + class_name=None, + cache_dir=None, + revision=None, ): if custom_pipeline is not None: if custom_pipeline.endswith(".py"): @@ -331,11 +347,19 @@ def _get_pipeline_class( # decompose into folder & file file_name = path.name custom_pipeline = path.parent.absolute() + elif repo_id is not None: + file_name = f"{custom_pipeline}.py" + custom_pipeline = repo_id else: file_name = CUSTOM_PIPELINE_FILE_NAME return get_class_from_dynamic_module( - custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision + custom_pipeline, + module_file=file_name, + class_name=class_name, + repo_id=repo_id, + cache_dir=cache_dir, + revision=revision if hub_revision is None else hub_revision, ) if class_obj != DiffusionPipeline: @@ -383,11 +407,18 @@ def load_sub_model( variant: str, low_cpu_mem_usage: bool, cached_folder: Union[str, os.PathLike], + revision: str = None, ): """Helper method to load the module `name` from `library_name` and `class_name`""" # retrieve class candidates class_obj, class_candidates = get_class_obj_and_candidates( - library_name, class_name, importable_classes, pipelines, is_pipeline_module + library_name, + class_name, + importable_classes, + pipelines, + is_pipeline_module, + component_name=name, + cache_dir=cached_folder, ) load_method_name = None @@ -414,14 +445,15 @@ def load_sub_model( load_method = getattr(class_obj, load_method_name) # add kwargs to loading method + diffusers_module = importlib.import_module(__name__.split(".")[0]) loading_kwargs = {} if issubclass(class_obj, torch.nn.Module): loading_kwargs["torch_dtype"] = torch_dtype - if 
issubclass(class_obj, diffusers.OnnxRuntimeModel): + if issubclass(class_obj, diffusers_module.OnnxRuntimeModel): loading_kwargs["provider"] = provider loading_kwargs["sess_options"] = sess_options - is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) + is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) if is_transformers_available(): transformers_version = version.parse(version.parse(transformers.__version__).base_version) @@ -501,7 +533,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): def register_modules(self, **kwargs): # import it here to avoid circular import - from diffusers import pipelines + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipelines = getattr(diffusers_module, "pipelines") for name, module in kwargs.items(): # retrieve library @@ -1080,11 +1113,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # 3. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it + custom_class_name = None + if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")): + custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py") + elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile( + os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py") + ): + custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py") + custom_class_name = config_dict["_class_name"][1] + pipeline_class = _get_pipeline_class( cls, config_dict, load_connected_pipeline=load_connected_pipeline, custom_pipeline=custom_pipeline, + class_name=custom_class_name, cache_dir=cache_dir, revision=custom_revision, ) @@ -1223,6 +1266,7 @@ def load_module(name, value): variant=variant, low_cpu_mem_usage=low_cpu_mem_usage, cached_folder=cached_folder, + revision=revision, ) logger.info( f"Loaded {name} as {class_name} from `{name}` subfolder of 
{pretrained_model_name_or_path}." @@ -1542,6 +1586,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending with `.onnx` and `.pb`. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This + option should only be set to `True` for repositories you trust and in which you have read the code, as + it will execute code present on the Hub on your local machine. Returns: `os.PathLike`: @@ -1569,6 +1617,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: use_safetensors = kwargs.pop("use_safetensors", None) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) + trust_remote_code = kwargs.pop("trust_remote_code", False) allow_pickle = False if use_safetensors is None: @@ -1604,15 +1653,34 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: ) config_dict = cls._dict_from_json_file(config_file) - ignore_filenames = config_dict.pop("_ignore_files", []) # retrieve all folder_names that contain relevant files - folder_names = [k for k, v in config_dict.items() if isinstance(v, list)] + folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] filenames = {sibling.rfilename for sibling in info.siblings} model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipelines = getattr(diffusers_module, "pipelines") + + # optionally create a custom component <> custom file mapping + custom_components = {} + for component in folder_names: + module_candidate = config_dict[component][0] + + if 
module_candidate is None: + continue + + candidate_file = os.path.join(component, module_candidate + ".py") + + if candidate_file in filenames: + custom_components[component] = module_candidate + elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate): + raise ValueError( + f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." + ) + if len(variant_filenames) == 0 and variant is not None: deprecation_message = ( f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." @@ -1636,12 +1704,21 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} + custom_class_name = None + if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)): + custom_pipeline = config_dict["_class_name"][0] + custom_class_name = config_dict["_class_name"][1] + # all filenames compatible with variant will be added allow_patterns = list(model_filenames) # allow all patterns from non-model folders # this enables downloading schedulers, tokenizers, ... 
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] + # add custom component files + allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()] + # add custom pipeline file + allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else [] # also allow downloading config.json files with the model allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] @@ -1652,12 +1729,32 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: CUSTOM_PIPELINE_FILE_NAME, ] + load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames + load_components_from_hub = len(custom_components) > 0 + + if load_pipe_from_hub and not trust_remote_code: + raise ValueError( + f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + ) + + if load_components_from_hub and not trust_remote_code: + raise ValueError( + f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly " + f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." 
+ ) + # retrieve passed components that should not be downloaded pipeline_class = _get_pipeline_class( cls, config_dict, load_connected_pipeline=load_connected_pipeline, custom_pipeline=custom_pipeline, + repo_id=pretrained_model_name if load_pipe_from_hub else None, + hub_revision=revision, + class_name=custom_class_name, cache_dir=cache_dir, revision=custom_revision, ) @@ -1754,9 +1851,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: # retrieve pipeline class from local file cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None) - cls_name = cls_name[4:] if cls_name.startswith("Flax") else cls_name + cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name - pipeline_class = getattr(diffusers, cls_name, None) + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None if pipeline_class is not None and pipeline_class._load_connected_pipes: modelcard = ModelCard.load(os.path.join(cached_folder, "README.md")) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 13861b581c9b..875fd787c8b0 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -862,6 +862,58 @@ def test_run_custom_pipeline(self): # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102 assert output_str == "This is a test" + def test_remote_components(self): + # make sure that trust remote code has to be passed + with self.assertRaises(ValueError): + pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-components") + + # Check that loading only the custom components "my_unet", "my_scheduler" works + pipeline = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-sdxl-custom-components", trust_remote_code=True + 
) + + assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel") + assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler") + assert pipeline.__class__.__name__ == "StableDiffusionXLPipeline" + + pipeline = pipeline.to(torch_device) + images = pipeline("test", num_inference_steps=2, output_type="np")[0] + + assert images.shape == (1, 64, 64, 3) + + # Check that loading only the custom components "my_unet", "my_scheduler" and an explicit custom pipeline works + pipeline = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-sdxl-custom-components", custom_pipeline="my_pipeline", trust_remote_code=True + ) + + assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel") + assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler") + assert pipeline.__class__.__name__ == "MyPipeline" + + pipeline = pipeline.to(torch_device) + images = pipeline("test", num_inference_steps=2, output_type="np")[0] + + assert images.shape == (1, 64, 64, 3) + + def test_remote_auto_custom_pipe(self): + # make sure that trust remote code has to be passed + with self.assertRaises(ValueError): + pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-all") + + # Check that loading only the custom components "my_unet", "my_scheduler" and the auto custom pipeline works + pipeline = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-sdxl-custom-all", trust_remote_code=True + ) + + assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel") + assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler") + assert pipeline.__class__.__name__ == "MyPipeline" + + pipeline = pipeline.to(torch_device) + images = pipeline("test", num_inference_steps=2, output_type="np")[0] + + assert images.shape == (1, 64, 64, 3) + def test_local_custom_pipeline_repo(self): local_custom_pipeline_path = 
get_tests_dir("fixtures/custom_pipeline") pipeline = DiffusionPipeline.from_pretrained(