From 91eecd76451b0b7513a8ac49c966ac063759f212 Mon Sep 17 00:00:00 2001 From: Yondon Fu Date: Mon, 19 Feb 2024 21:42:29 +0000 Subject: [PATCH] runner: Use HF token to access gated models (i.e. SVD1.1) --- runner/app/pipelines/image_to_image.py | 3 ++- runner/app/pipelines/image_to_video.py | 3 ++- runner/app/pipelines/text_to_image.py | 3 ++- runner/dl_checkpoints.sh | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index a60e07ea..7f1450fb 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -21,7 +21,8 @@ def __init__(self, model_id: str): kwargs = {"cache_dir": get_model_dir()} torch_device = get_torch_device() - model_data = model_info(model_id) + # TODO: Move check offline so token is unnecessary when model is cached + model_data = model_info(model_id, token=os.environ.get("HF_TOKEN")) has_fp16_variant = any( ".fp16.safetensors" in file.rfilename for file in model_data.siblings ) diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py index 33d7f579..2e4e5122 100644 --- a/runner/app/pipelines/image_to_video.py +++ b/runner/app/pipelines/image_to_video.py @@ -21,7 +21,8 @@ def __init__(self, model_id: str): kwargs = {"cache_dir": get_model_dir()} torch_device = get_torch_device() - model_data = model_info(model_id) + # TODO: Move check offline so token is unnecessary when model is cached + model_data = model_info(model_id, token=os.environ.get("HF_TOKEN")) has_fp16_variant = any( ".fp16.safetensors" in file.rfilename for file in model_data.siblings ) diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 427eec84..09de3431 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -17,7 +17,8 @@ def __init__(self, model_id: str): kwargs = {"cache_dir": get_model_dir()} torch_device = get_torch_device() - model_data = model_info(model_id) + # TODO: Move check offline so token is unnecessary when model is cached + model_data = model_info(model_id, token=os.environ.get("HF_TOKEN")) has_fp16_variant = any( ".fp16.safetensors" in file.rfilename for file in model_data.siblings ) diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index d0672de6..3fafcca7 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -9,4 +9,5 @@ huggingface-cli download runwayml/stable-diffusion-v1-5 --include "*.fp16.safete huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models huggingface-cli download prompthero/openjourney-v4 --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models # image-to-video -huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models \ No newline at end of file +huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models +huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt-1-1 --include "*.fp16.safetensors" "*.json" --token=$HF_TOKEN --cache-dir models \ No newline at end of file