Skip to content

Commit

Permalink
Add support for the timbrooks/instruct-pix2pix model
Browse files Browse the repository at this point in the history
  • Loading branch information
stronk-dev committed Apr 11, 2024
1 parent 74e31fc commit 77c1c14
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
6 changes: 6 additions & 0 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from diffusers import (
AutoPipelineForImage2Image,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
EulerDiscreteScheduler,
Expand All @@ -22,6 +23,7 @@
logger = logging.getLogger(__name__)

SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning"
PIX2PIX_MODEL_ID = "timbrooks/instruct-pix2pix"


class ImageToImagePipeline(Pipeline):
Expand Down Expand Up @@ -87,6 +89,10 @@ def __init__(self, model_id: str):
self.ldm.scheduler = EulerDiscreteScheduler.from_config(
self.ldm.scheduler.config, timestep_spacing="trailing"
)
elif PIX2PIX_MODEL_ID in model_id:
self.ldm = StableDiffusionInstructPix2PixPipeline.from_pretrained(
model_id, **kwargs
).to(torch_device)
else:
self.ldm = AutoPipelineForImage2Image.from_pretrained(
model_id, **kwargs
Expand Down
4 changes: 4 additions & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ if [ "$MODE" = "alpha" ]; then

# Download text-to-image and image-to-image models.
huggingface-cli download ByteDance/SDXL-Lightning --include "*unet.safetensors" --exclude "*lora.safetensors*" --cache-dir models
huggingface-cli download timbrooks/instruct-pix2pix --include "*fp16.safetensors" --exclude "*lora.safetensors*" --cache-dir models

# Download image-to-video models (token-gated).
printf "\nDownloading token-gated models...\n"
Expand All @@ -80,6 +81,9 @@ else
# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models

# Download text-to-video models.
huggingface-cli download ali-vilab/text-to-video-ms-1.7b --include "*.fp16.safetensors" "*.json" --cache-dir models

# Download image-to-video models (token-gated).
printf "\nDownloading token-gated models...\n"
check_hf_auth
Expand Down

0 comments on commit 77c1c14

Please sign in to comment.