Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pipeline download] Improve pipeline download for index and passed co… #2980

Merged
merged 4 commits into from
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 96 additions & 35 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class AudioPipelineOutput(BaseOutput):
audios: np.ndarray


def is_safetensors_compatible(filenames, variant=None) -> bool:
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
"""
Checking for safetensors compatibility:
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
Expand All @@ -150,9 +150,14 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:

sf_filenames = set()

passed_components = passed_components or []

for filename in filenames:
_, extension = os.path.splitext(filename)

if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
continue

if extension == ".bin":
pt_filenames.append(filename)
elif extension == ".safetensors":
Expand All @@ -163,10 +168,8 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
path, filename = os.path.split(filename)
filename, extension = os.path.splitext(filename)

if filename == "pytorch_model":
filename = "model"
elif filename == f"pytorch_model.{variant}":
filename = f"model.{variant}"
if filename.startswith("pytorch_model"):
filename = filename.replace("pytorch_model", "model")
else:
filename = filename

Expand Down Expand Up @@ -196,24 +199,51 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
weight_prefixes = [w.split(".")[0] for w in weight_names]
# .bin, .safetensors, ...
weight_suffixs = [w.split(".")[-1] for w in weight_names]
# -00001-of-00002
transformers_index_format = "\d{5}-of-\d{5}"

if variant is not None:
# `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetenstors`
variant_file_re = re.compile(
f"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
)
# `text_encoder/pytorch_model.bin.index.fp16.json`
variant_index_re = re.compile(
f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
)

variant_file_regex = (
re.compile(f"({'|'.join(weight_prefixes)})(.{variant}.)({'|'.join(weight_suffixs)})")
if variant is not None
else None
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetenstors`
non_variant_file_re = re.compile(
f"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
)
non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}")
# `text_encoder/pytorch_model.bin.index.json`
non_variant_index_re = re.compile(f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")

if variant is not None:
variant_filenames = {f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None}
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
variant_filenames = variant_weights | variant_indexes
else:
variant_filenames = set()

non_variant_filenames = {f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None}
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
non_variant_filenames = non_variant_weights | non_variant_indexes

# all variant filenames will be used by default
usable_filenames = set(variant_filenames)

def convert_to_variant(filename):
if "index" in filename:
variant_filename = filename.replace("index", f"index.{variant}")
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
else:
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
return variant_filename

for f in non_variant_filenames:
variant_filename = f"{f.split('.')[0]}.{variant}.{f.split('.')[1]}"
variant_filename = convert_to_variant(f)
if variant_filename not in usable_filenames:
usable_filenames.add(f)

Expand Down Expand Up @@ -292,6 +322,27 @@ def get_class_obj_and_candidates(library_name, class_name, importable_classes, p
return class_obj, class_candidates


def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, revision=None):
if custom_pipeline is not None:
if custom_pipeline.endswith(".py"):
path = Path(custom_pipeline)
# decompose into folder & file
file_name = path.name
custom_pipeline = path.parent.absolute()
else:
file_name = CUSTOM_PIPELINE_FILE_NAME

return get_class_from_dynamic_module(
custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision
)

if class_obj != DiffusionPipeline:
return class_obj

diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
return getattr(diffusers_module, config["_class_name"])


def load_sub_model(
library_name: str,
class_name: str,
Expand Down Expand Up @@ -779,7 +830,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
Expand All @@ -794,8 +845,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_auth_token=use_auth_token,
revision=revision,
from_flax=from_flax,
use_safetensors=use_safetensors,
custom_pipeline=custom_pipeline,
custom_revision=custom_revision,
variant=variant,
**kwargs,
)
else:
cached_folder = pretrained_model_name_or_path
Expand All @@ -810,29 +864,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
for folder in os.listdir(cached_folder):
folder_path = os.path.join(cached_folder, folder)
is_folder = os.path.isdir(folder_path) and folder in config_dict
variant_exists = is_folder and any(path.split(".")[1] == variant for path in os.listdir(folder_path))
variant_exists = is_folder and any(
p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
)
if variant_exists:
model_variants[folder] = variant

# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if custom_pipeline is not None:
if custom_pipeline.endswith(".py"):
path = Path(custom_pipeline)
# decompose into folder & file
file_name = path.name
custom_pipeline = path.parent.absolute()
else:
file_name = CUSTOM_PIPELINE_FILE_NAME

pipeline_class = get_class_from_dynamic_module(
custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision
)
elif cls != DiffusionPipeline:
pipeline_class = cls
else:
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
pipeline_class = _get_pipeline_class(
cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
)

# DEPRECATED: To be removed in 1.0.0
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
Expand Down Expand Up @@ -1095,6 +1137,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
revision = kwargs.pop("revision", None)
from_flax = kwargs.pop("from_flax", False)
custom_pipeline = kwargs.pop("custom_pipeline", None)
custom_revision = kwargs.pop("custom_revision", None)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)

Expand Down Expand Up @@ -1153,7 +1196,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names]
# also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names]
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]

allow_patterns += [
SCHEDULER_CONFIG_NAME,
Expand All @@ -1162,17 +1205,28 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
CUSTOM_PIPELINE_FILE_NAME,
]

# retrieve passed components that should not be downloaded
pipeline_class = _get_pipeline_class(
cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
)
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]

if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(model_filenames, variant=variant)
and not is_safetensors_compatible(
model_filenames, variant=variant, passed_components=passed_components
)
):
raise EnvironmentError(
f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible(model_filenames, variant=variant):
elif use_safetensors and is_safetensors_compatible(
model_filenames, variant=variant, passed_components=passed_components
):
ignore_patterns = ["*.bin", "*.msgpack"]

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
Expand All @@ -1194,6 +1248,13 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)

# Don't download any objects that are passed
allow_patterns = [
p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
]
# Don't download index files of forbidden patterns either
ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]

re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]

Expand Down
128 changes: 125 additions & 3 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def test_one_request_upon_cached(self):

with tempfile.TemporaryDirectory() as tmpdirname:
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)
DiffusionPipeline.download("hf-internal-testing/tiny-stable-diffusion-pipe", cache_dir=tmpdirname)

download_requests = [r.method for r in m.request_history]
assert download_requests.count("HEAD") == 15, "15 calls to files"
Expand All @@ -101,6 +99,55 @@ def test_one_request_upon_cached(self):
len(cache_requests) == 2
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"

def test_less_downloads_passed_object(self):
with tempfile.TemporaryDirectory() as tmpdirname:
cached_folder = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)

# make sure safety checker is not downloaded
assert "safety_checker" not in os.listdir(cached_folder)

# make sure rest is downloaded
assert "unet" in os.listdir(cached_folder)
assert "tokenizer" in os.listdir(cached_folder)
assert "vae" in os.listdir(cached_folder)
assert "model_index.json" in os.listdir(cached_folder)
assert "scheduler" in os.listdir(cached_folder)
assert "feature_extractor" in os.listdir(cached_folder)

def test_less_downloads_passed_object_calls(self):
# TODO: For some reason this test fails on MPS where no HEAD call is made.
if torch_device == "mps":
return

with tempfile.TemporaryDirectory() as tmpdirname:
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)

download_requests = [r.method for r in m.request_history]
# 15 - 2 because no call to config or model file for `safety_checker`
assert download_requests.count("HEAD") == 13, "13 calls to files"
# 17 - 2 because no call to config or model file for `safety_checker`
assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json"
assert (
len(download_requests) == 28
), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"

with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)

cache_requests = [r.method for r in m.request_history]
assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
assert cache_requests.count("GET") == 1, "model info is only GET"
assert (
len(cache_requests) == 2
), "We should call only `model_info` to check for _commit hash and `send_telemetry`"

def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights
Expand Down Expand Up @@ -165,6 +212,54 @@ def test_download_safetensors(self):
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
assert not any(f.endswith(".bin") for f in files)

def test_download_safetensors_index(self):
for variant in ["fp16", None]:
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe-indexes",
cache_dir=tmpdirname,
use_safetensors=True,
variant=variant,
)

all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]

# None of the downloaded files should be a safetensors file even if we have some here:
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-indexes/tree/main/text_encoder
if variant is None:
assert not any("fp16" in f for f in files)
else:
model_files = [f for f in files if "safetensors" in f]
assert all("fp16" in f for f in model_files)

assert len([f for f in files if ".safetensors" in f]) == 8
assert not any(".bin" in f for f in files)

def test_download_bin_index(self):
for variant in ["fp16", None]:
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = DiffusionPipeline.download(
"hf-internal-testing/tiny-stable-diffusion-pipe-indexes",
cache_dir=tmpdirname,
use_safetensors=False,
variant=variant,
)

all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))]
files = [item for sublist in all_root_files for item in sublist]

# None of the downloaded files should be a safetensors file even if we have some here:
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-indexes/tree/main/text_encoder
if variant is None:
assert not any("fp16" in f for f in files)
else:
model_files = [f for f in files if "bin" in f]
assert all("fp16" in f for f in model_files)

assert len([f for f in files if ".bin" in f]) == 8
assert not any(".safetensors" in f for f in files)

def test_download_no_safety_checker(self):
prompt = "hello"
pipe = StableDiffusionPipeline.from_pretrained(
Expand Down Expand Up @@ -362,6 +457,33 @@ def test_download_broken_variant(self):

diffusers.utils.import_utils._safetensors_available = True

def test_local_save_load_index(self):
prompt = "hello"
for variant in [None, "fp16"]:
for use_safe in [True, False]:
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe-indexes",
variant=variant,
use_safetensors=use_safe,
safety_checker=None,
)
pipe = pipe.to(torch_device)
generator = torch.manual_seed(0)
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe_2 = StableDiffusionPipeline.from_pretrained(
tmpdirname, safe_serialization=use_safe, variant=variant
)
pipe_2 = pipe_2.to(torch_device)

generator = torch.manual_seed(0)

out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

assert np.max(np.abs(out - out_2)) < 1e-3

def test_text_inversion_download(self):
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
Expand Down