Skip to content

Commit

Permalink
runner: Use HF token to access gated models (i.e. SVD1.1)
Browse files Browse the repository at this point in the history
  • Loading branch information
yondonfu committed Feb 19, 2024
1 parent 3d35574 commit 91eecd7
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
3 changes: 2 additions & 1 deletion runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def __init__(self, model_id: str):
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
model_data = model_info(model_id)
# TODO: Move check offline so token is unnecessary when model is cached
model_data = model_info(model_id, token=os.environ.get("HF_TOKEN"))
has_fp16_variant = any(
".fp16.safetensors" in file.rfilename for file in model_data.siblings
)
Expand Down
3 changes: 2 additions & 1 deletion runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def __init__(self, model_id: str):
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
model_data = model_info(model_id)
# TODO: Move check offline so token is unnecessary when model is cached
model_data = model_info(model_id, token=os.environ.get("HF_TOKEN"))
has_fp16_variant = any(
".fp16.safetensors" in file.rfilename for file in model_data.siblings
)
Expand Down
3 changes: 2 additions & 1 deletion runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(self, model_id: str):
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
model_data = model_info(model_id)
# TODO: Move check offline so token is unnecessary when model is cached
model_data = model_info(model_id, token=os.environ.get("HF_TOKEN"))
has_fp16_variant = any(
".fp16.safetensors" in file.rfilename for file in model_data.siblings
)
Expand Down
3 changes: 2 additions & 1 deletion runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ huggingface-cli download runwayml/stable-diffusion-v1-5 --include "*.fp16.safete
huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download prompthero/openjourney-v4 --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
# image-to-video
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt-1-1 --include "*.fp16.safetensors" "*.json" --token=$HF_TOKEN --cache-dir models

0 comments on commit 91eecd7

Please sign in to comment.