diff --git a/.dockerignore b/.dockerignore
index 2e6c19b..65a3fe0 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -18,6 +18,7 @@ __pycache__
 
 # Exclude output files
 /outputs
+output*.png
 
 # Exclude models cache
 /models
diff --git a/predict.py b/predict.py
index b73f3ca..51acb4c 100644
--- a/predict.py
+++ b/predict.py
@@ -9,7 +9,6 @@ import os
 import shutil
 import subprocess
-from typing import List, Union
 import time
 
 os.environ["HF_HUB_CACHE"] = "models"
 
@@ -22,9 +21,7 @@ )
 
 from huggingface_hub import hf_hub_download
 
-# import spaces
-# import gradio as gr
 from transformers import CLIPImageProcessor
 
 from photomaker import PhotoMakerStableDiffusionXLPipeline
@@ -63,7 +60,6 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str
 class Predictor(BasePredictor):
     def setup(self) -> None:
         """Load the model into memory to make running multiple predictions efficient"""
-        # self.model = torch.load("./weights.pth")
 
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -101,15 +97,25 @@ def setup(self) -> None:
         self.pipe.scheduler = EulerDiscreteScheduler.from_config(
             self.pipe.scheduler.config
         )
-        # pipe.set_adapters(["photomaker"], adapter_weights=[1.0])
         self.pipe.fuse_lora()
 
     @torch.inference_mode()
     def predict(
         self,
-        # from ChatGPT
         input_image: Path = Input(
-            description="The input image, a photo of your face"
+            description="The input image, for example a photo of your face."
+        ),
+        input_image2: Path = Input(
+            description="Additional input image (optional)",
+            default=None
+        ),
+        input_image3: Path = Input(
+            description="Additional input image (optional)",
+            default=None
+        ),
+        input_image4: Path = Input(
+            description="Additional input image (optional)",
+            default=None
         ),
         prompt: str = Input(
             description="Prompt. Example: 'a photo of a man/woman img'. The phrase 'img' is the trigger word.",
@@ -130,9 +136,9 @@ def predict(
         style_strength_ratio: float = Input(
             description="Style strength (%)", default=20, ge=15, le=50
         ),
-        #num_outputs: int = Input(
-        #    description="Number of output images", default=1, ge=1, le=4
-        #),
+        num_outputs: int = Input(
+            description="Number of output images", default=1, ge=1, le=4
+        ),
         guidance_scale: float = Input(
             description="Guidance scale", default=5, ge=0.1, le=10.0
         ),
@@ -141,7 +147,7 @@ def predict(
             description="Disable safety checker for generated images. This feature is only available through the API. See [https://replicate.com/docs/how-does-replicate-work#safety](https://replicate.com/docs/how-does-replicate-work#safety)",
             default=False
         )
-    ) -> List[Path]:
+    ) -> list[Path]:
         """Run a single prediction on the model"""
         # remove old outputs
         output_folder = Path('outputs')
@@ -169,7 +175,11 @@ def predict(
         # apply the style template
         prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
 
-        input_id_images = load_image(str(input_image))
+        # load the input images
+        input_id_images = []
+        for maybe_image in [input_image, input_image2, input_image3, input_image4]:
+            if maybe_image:
+                input_id_images.append(load_image(str(maybe_image)))
 
         generator = torch.Generator(device=self.device).manual_seed(seed)
 
@@ -184,7 +194,7 @@ def predict(
             prompt=prompt,
             input_id_images=input_id_images,
             negative_prompt=negative_prompt,
-            num_images_per_prompt=1, # used to be: num_outputs but currently we accept only one input
+            num_images_per_prompt=num_outputs,
             num_inference_steps=num_steps,
             start_merge_step=start_merge_step,
             generator=generator,
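
For context, a rough usage sketch of the new optional image inputs, assuming the predictor is deployed on Replicate and called through the replicate Python client; the model reference, version, and file names below are placeholders, not part of this change:

    import replicate

    # input_image2..input_image4 default to None; predict() skips any that are
    # not supplied, so callers may pass between one and four ID photos.
    outputs = replicate.run(
        "owner/photomaker:version-id",  # placeholder model reference
        input={
            "input_image": open("face1.jpg", "rb"),
            "input_image2": open("face2.jpg", "rb"),
            "prompt": "a photo of a man img",
            "num_outputs": 2,
        },
    )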