Skip to content

Commit

Permalink
Add up to four inputs and outputs in cog demo
Browse files Browse the repository at this point in the history
And clean up predict.py based on review feedback
  • Loading branch information
jd7h committed Jan 17, 2024
1 parent db00dff commit 36dadb3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Expand Up @@ -18,6 +18,7 @@ __pycache__

# Exclude output files
/outputs
output*.png

# Exclude models cache
/models
36 changes: 23 additions & 13 deletions predict.py
Expand Up @@ -9,7 +9,6 @@
import os
import shutil
import subprocess
from typing import List, Union
import time

os.environ["HF_HUB_CACHE"] = "models"
Expand All @@ -22,9 +21,7 @@
)

from huggingface_hub import hf_hub_download
# import spaces

# import gradio as gr
from transformers import CLIPImageProcessor

from photomaker import PhotoMakerStableDiffusionXLPipeline
Expand Down Expand Up @@ -63,7 +60,6 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str
class Predictor(BasePredictor):
def setup(self) -> None:
"""Load the model into memory to make running multiple predictions efficient"""
# self.model = torch.load("./weights.pth")

self.device = "cuda" if torch.cuda.is_available() else "cpu"

Expand Down Expand Up @@ -101,15 +97,25 @@ def setup(self) -> None:
self.pipe.scheduler = EulerDiscreteScheduler.from_config(
self.pipe.scheduler.config
)
# pipe.set_adapters(["photomaker"], adapter_weights=[1.0])
self.pipe.fuse_lora()

@torch.inference_mode()
def predict(
self,
# from ChatGPT
input_image: Path = Input(
description="The input image, a photo of your face"
description="The input image, for example a photo of your face."
),
input_image2: Path = Input(
description="Additional input image (optional)",
default=None
),
input_image3: Path = Input(
description="Additional input image (optional)",
default=None
),
input_image4: Path = Input(
description="Additional input image (optional)",
default=None
),
prompt: str = Input(
description="Prompt. Example: 'a photo of a man/woman img'. The phrase 'img' is the trigger word.",
Expand All @@ -130,9 +136,9 @@ def predict(
style_strength_ratio: float = Input(
description="Style strength (%)", default=20, ge=15, le=50
),
#num_outputs: int = Input(
# description="Number of output images", default=1, ge=1, le=4
#),
num_outputs: int = Input(
description="Number of output images", default=1, ge=1, le=4
),
guidance_scale: float = Input(
description="Guidance scale", default=5, ge=0.1, le=10.0
),
Expand All @@ -141,7 +147,7 @@ def predict(
description="Disable safety checker for generated images. This feature is only available through the API. See [https://replicate.com/docs/how-does-replicate-work#safety](https://replicate.com/docs/how-does-replicate-work#safety)",
default=False
)
) -> List[Path]:
) -> list[Path]:
"""Run a single prediction on the model"""
# remove old outputs
output_folder = Path('outputs')
Expand Down Expand Up @@ -169,7 +175,11 @@ def predict(
# apply the style template
prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)

input_id_images = load_image(str(input_image))
# load the input images
input_id_images = []
for maybe_image in [input_image, input_image2, input_image3, input_image4]:
if maybe_image:
input_id_images.append(load_image(str(maybe_image)))

generator = torch.Generator(device=self.device).manual_seed(seed)

Expand All @@ -184,7 +194,7 @@ def predict(
prompt=prompt,
input_id_images=input_id_images,
negative_prompt=negative_prompt,
num_images_per_prompt=1, # used to be: num_outputs but currently we accept only one input
num_images_per_prompt=num_outputs,
num_inference_steps=num_steps,
start_merge_step=start_merge_step,
generator=generator,
Expand Down

0 comments on commit 36dadb3

Please sign in to comment.