Skip to content

Commit

Permalink
Merge pull request #1 from Pau-Lozano/main
Browse files Browse the repository at this point in the history
Adds batched prompting option
  • Loading branch information
lucataco committed Nov 8, 2023
2 parents 418193e + 132a967 commit 51c0018
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ __pycache__
.cog
safety-cache
sdxl-cache
.idea
16 changes: 14 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def predict(
description="Negative Input prompt",
default="scary, cartoon, painting"
),
batched_prompt: bool = Input(
description="When active, your prompt will be split by newlines and images will be generated for each individual line",
default=False
),
image: Path = Input(
description="Input image for img2img or inpaint mode",
default=None,
Expand Down Expand Up @@ -230,13 +234,21 @@ def predict(
generator = torch.Generator("cuda").manual_seed(seed)

common_args = {
"prompt": [prompt] * num_outputs,
"negative_prompt": [negative_prompt] * num_outputs,
"guidance_scale": guidance_scale,
"generator": generator,
"num_inference_steps": num_inference_steps,
}

if batched_prompt:
print("Batch of prompts mode")
sdxl_kwargs["prompt"] = prompt.strip().splitlines() * num_outputs
sdxl_kwargs["negative_prompt"] = negative_prompt.strip().splitlines() * num_outputs
while (len(sdxl_kwargs["prompt"]) > len(sdxl_kwargs["negative_prompt"])) :
sdxl_kwargs["negative_prompt"].append("")
else:
sdxl_kwargs["prompt"] = [prompt] * num_outputs
sdxl_kwargs["negative_prompt"] = [negative_prompt] * num_outputs

if self.is_lora:
sdxl_kwargs["cross_attention_kwargs"] = {"scale": lora_scale}

Expand Down

0 comments on commit 51c0018

Please sign in to comment.