Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin):
_optional_components = []
_exclude_from_cpu_offload = []
_load_connected_pipes = False
_is_onnx = False

def register_modules(self, **kwargs):
# import it here to avoid circular import
Expand Down Expand Up @@ -839,6 +840,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
Expand Down Expand Up @@ -1268,6 +1274,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
variant (`str`, *optional*):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_onnx (`bool`, *optional*, defaults to `False`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.

Returns:
`os.PathLike`:
Expand All @@ -1293,6 +1308,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
custom_revision = kwargs.pop("custom_revision", None)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

if use_safetensors and not is_safetensors_available():
Expand Down Expand Up @@ -1364,7 +1380,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
pretrained_model_name, use_auth_token, variant, revision, model_filenames
)

model_folder_names = {os.path.split(f)[0] for f in model_filenames}
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}

# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
Expand Down Expand Up @@ -1411,6 +1427,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
):
ignore_patterns = ["*.bin", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
Copy link
Copy Markdown
Contributor

@echarlaix echarlaix Jul 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be set by default for all the OnnxStableDiffusionXxxPipeline ?

ignore_patterns += ["*.onnx", "*.pb"]

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if (
Expand All @@ -1423,6 +1443,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor

_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor

_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor

_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True

vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def preprocess(image):


class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
_is_onnx = True

def __init__(
self,
vae: OnnxRuntimeModel,
Expand Down
43 changes: 43 additions & 0 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,49 @@ def test_download_bin_index(self):
assert len([f for f in files if ".bin" in f]) == 8
assert not any(".safetensors" in f for f in files)

def test_download_no_openvino_by_default(self):
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-open-vino",
cache_dir=tmpdirname,
)

all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]

# make sure that by default no openvino weights are downloaded
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert not any("openvino_" in f for f in files)

def test_download_no_onnx_by_default(self):
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
cache_dir=tmpdirname,
)

all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]

# make sure that by default no onnx weights are downloaded
assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files)

with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline",
cache_dir=tmpdirname,
use_onnx=True,
)

all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]

# if `use_onnx` is specified make sure weights are downloaded
assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files)
assert any((f.endswith(".onnx")) for f in files)
assert any((f.endswith(".pb")) for f in files)

def test_download_no_safety_checker(self):
prompt = "hello"
pipe = StableDiffusionPipeline.from_pretrained(
Expand Down