Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/diffusers/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,16 +1894,24 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")

# remove huggingface url
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
has_valid_url_prefix = False
valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
for prefix in valid_url_prefixes:
if pretrained_model_link_or_path.startswith(prefix):
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
has_valid_url_prefix = True

# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
ckpt_path = Path(pretrained_model_link_or_path)
if not ckpt_path.is_file():
if not has_valid_url_prefix:
raise ValueError(
f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}"
)

# get repo_id and (potentially nested) file path of ckpt in repo
repo_id = os.path.join(*ckpt_path.parts[:2])
file_path = os.path.join(*ckpt_path.parts[2:])
repo_id = "/".join(ckpt_path.parts[:2])
file_path = "/".join(ckpt_path.parts[2:])

if file_path.startswith("blob/"):
file_path = file_path[len("blob/") :]
Expand Down