
Allow passing a CACHE_MAPPING of previously downloaded components and their sha256 hashes to avoid duplicate downloads #1984

@patrickvonplaten


A common problem with diffusers models/pipelines is that many components of different pipelines share exactly the same underlying weights, yet it's hard to avoid downloading them twice. We could solve this by letting users pass a cache_mapping: Dict[str, path], which maps a file's sha256 hash to its local path, to DiffusionPipeline.from_pretrained(...). Before downloading a file, from_pretrained would check whether its hash is already in cache_mapping; if so, the file is not downloaded again. If not, it is downloaded and added to cache_mapping.
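A minimal sketch of the proposed lookup, assuming `cache_mapping` maps sha256 hashes to local file paths. `resolve_file` and the `download` callable are hypothetical names used only for illustration; they are not part of diffusers:

```python
from typing import Callable, Dict


def resolve_file(
    sha256: str,
    filename: str,
    cache_mapping: Dict[str, str],
    download: Callable[[str], str],
) -> str:
    """Return a local path for `filename`, downloading it only if its hash is unknown.

    `download` is a hypothetical callable that fetches `filename` and returns its local path.
    """
    cached = cache_mapping.get(sha256)
    if cached is not None:
        # The same weights were already downloaded (possibly from another repo): reuse them.
        return cached
    # Unknown hash: download once and remember where the file landed.
    path = download(filename)
    cache_mapping[sha256] = path
    return path
```

Because the key is the content hash rather than the repo id or filename, two pipelines whose components share identical weights would resolve to the same local path.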

It's straightforward to look up the sha256 hashes of files on the Hub before downloading them, e.g.:

```python
from huggingface_hub import model_info

info = model_info("runwayml/stable-diffusion-v1-5", files_metadata=True)
files = info.siblings
shas = {f.rfilename: f.lfs["sha256"] for f in files if f.lfs is not None}
shas
```

gives

```
{'safety_checker/pytorch_model.bin': '193490b58ef62739077262e833bf091c66c29488058681ac25cf7df3d8190974',
 'text_encoder/pytorch_model.bin': '770a47a9ffdcfda0b05506a7888ed714d06131d60267e6cf52765d61cf59fd67',
 'unet/diffusion_pytorch_model.bin': 'c7da0e21ba7ea50637bee26e81c220844defdf01aafca02b2c42ecdadb813de4',
 'v1-5-pruned-emaonly.ckpt': 'cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516',
 'v1-5-pruned.ckpt': 'e1441589a6f3c5a53f5f54d0975a18a7feb7cdf0b0dee276dfc3331ae376a053',
 'vae/diffusion_pytorch_model.bin': '1b134cded8eb78b184aefb8805b6b572f36fa77b255c483665dda931fa0130c5'}
```
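These values are the Git LFS object ids, i.e. the sha256 of each file's contents, so an existing cache entry can be checked against the Hub's hash locally before it is reused. A sketch using only `hashlib`; `file_sha256` and `is_valid_cache_entry` are hypothetical helper names:

```python
import hashlib


def file_sha256(path: str, chunk_size: int = 1 << 20) -> str:
    """Compute the hex sha256 of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel stops when read() returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def is_valid_cache_entry(expected_sha256: str, path: str) -> bool:
    """True if the file at `path` still matches the hash recorded in cache_mapping."""
    try:
        return file_sha256(path) == expected_sha256
    except FileNotFoundError:
        # A stale entry whose file was deleted should trigger a re-download.
        return False
```

Hashing multi-GB checkpoints takes noticeable time, so whether to verify on every load or only trust the mapping is a design choice.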

See colab here: https://colab.research.google.com/drive/1WGLdcgnzbIf_dn9QF51TRO_6ogEqVsea?usp=sharing

Since from_pretrained(...) already makes a call to the Hub anyway, this lookup could be integrated into it quite easily. From the user's side, the API could look as follows:

```python
from diffusers import DiffusionPipeline

cache_mapping = {}

# Proposed API: `cache_mapping` is not an existing from_pretrained argument yet.
pipeline, cache_mapping = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", cache_mapping=cache_mapping)

# then cache_mapping would look as follows:
# {"193490b58ef62739077262e833bf091c66c29488058681ac25cf7df3d8190974": "./cache/.... <path/to/file>", ...}

pipeline, cache_mapping = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", cache_mapping=cache_mapping)  # now the safety checker won't be downloaded again.
```
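A rough sketch of how from_pretrained(...) might populate the mapping per repo, assuming `huggingface_hub`'s `model_info(..., files_metadata=True)` and `hf_hub_download(repo_id=..., filename=...)`. `download_repo_files` and its injectable `list_lfs_files`/`download` parameters are hypothetical; they exist so the logic can be exercised without network access:

```python
from typing import Callable, Dict


def _hub_lfs_files(repo_id: str) -> Dict[str, str]:
    """Map each LFS file in `repo_id` to its sha256, as in the snippet above."""
    from huggingface_hub import model_info  # imported lazily so fakes need no install

    info = model_info(repo_id, files_metadata=True)
    return {s.rfilename: s.lfs["sha256"] for s in info.siblings if s.lfs is not None}


def _hub_download(repo_id: str, filename: str) -> str:
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=repo_id, filename=filename)


def download_repo_files(
    repo_id: str,
    cache_mapping: Dict[str, str],
    list_lfs_files: Callable[[str], Dict[str, str]] = _hub_lfs_files,
    download: Callable[[str, str], str] = _hub_download,
) -> Dict[str, str]:
    """Return {filename: local_path} for the LFS files of `repo_id`, skipping known hashes."""
    local_paths = {}
    for filename, sha in list_lfs_files(repo_id).items():
        if sha not in cache_mapping:
            # Only files whose content hash has never been seen are fetched.
            cache_mapping[sha] = download(repo_id, filename)
        local_paths[filename] = cache_mapping[sha]
    return local_paths
```

This sketch downloads every LFS file in the repo; a real integration would also need to restrict itself to the files the pipeline actually loads (e.g. skip the full `.ckpt` checkpoints) and would likely need the reused file to appear in the second repo's folder layout, for example via a symlink.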

cc @pcuenca @keturn @patil-suraj @anton-l what do you think?
