
update Replicate demo
chenxwh committed Mar 11, 2024
1 parent 9b94e01 commit 3701a18
Showing 1 changed file with 84 additions and 70 deletions.
154 changes: 84 additions & 70 deletions predict.py
@@ -7,18 +7,24 @@
import time
from cog import BasePredictor, Input, Path, BaseModel
import torch
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, EulerDiscreteScheduler, UNet2DConditionModel, StableDiffusionXLPipeline
from diffusers import (
AutoPipelineForText2Image,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
UNet2DConditionModel,
StableDiffusionXLPipeline,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

MODEL_URL = "https://weights.replicate.delivery/default/res-adapter/Lykon/dreamshaper-xl-1-0.tar"
MODEL_WEIGHTS = "pretrained/Lykon/dreamshaper-xl-1-0"
SDXL_MODEL_URL = "https://weights.replicate.delivery/default/res-adapter/Lykon/dreamshaper-xl-1-0.tar"
SDXL_MODEL_WEIGHTS = "pretrained/Lykon/dreamshaper-xl-1-0"
SD15_MODEL_URL = "https://weights.replicate.delivery/default/res-adapter/dreamlike-art/dreamlike-diffusion-1.0.tar"
SD15_MODEL_WEIGHTS = "pretrained/dreamlike-art/dreamlike-diffusion-1.0"

# For SDXL, SDXL-Lightning, dreamshaper-xl-1-0,
# For SDv1.5, dreamlike-diffusion-1.0

class ModelOutput(BaseModel):
without_res_adapter: Optional[Path]
@@ -39,30 +45,60 @@ def download_weights(url, dest, extract=True):
class Predictor(BasePredictor):
def setup(self) -> None:
"""Load the model into memory to make running multiple predictions efficient"""
if not os.path.exists(MODEL_WEIGHTS):
download_weights(MODEL_URL, MODEL_WEIGHTS)
self.default_pipe = AutoPipelineForText2Image.from_pretrained(
MODEL_WEIGHTS, torch_dtype=torch.float16, variant="fp16"
if not os.path.exists(SDXL_MODEL_WEIGHTS):
download_weights(SDXL_MODEL_URL, SDXL_MODEL_WEIGHTS)
if not os.path.exists(SD15_MODEL_WEIGHTS):
download_weights(SD15_MODEL_URL, SD15_MODEL_WEIGHTS)

# load "Lykon/dreamshaper-xl-1-0"
self.sdxl_pipe = AutoPipelineForText2Image.from_pretrained(
SDXL_MODEL_WEIGHTS, torch_dtype=torch.float16, variant="fp16"
)
self.default_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.default_pipe.scheduler.config,
self.sdxl_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.sdxl_pipe.scheduler.config,
use_karras_sigmas=True,
algorithm_type="sde-dpmsolver++",
)
self.default_pipe = self.default_pipe.to("cuda")
self.sdxl_pipe = self.sdxl_pipe.to("cuda")

# load "ByteDance/SDXL-Lightning"
self.sdxl_lightning_pipe = AutoPipelineForText2Image.from_pretrained(
SDXL_MODEL_WEIGHTS, torch_dtype=torch.float16, variant="fp16"
)
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"
# Load SDXL-Lightning to UNet
unet = self.sdxl_lightning_pipe.unet
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
# Change UNet to pipeline
self.sdxl_lightning_pipe.unet = unet
self.sdxl_lightning_pipe.scheduler = EulerDiscreteScheduler.from_config(
self.sdxl_lightning_pipe.scheduler.config, timestep_spacing="trailing"
)
self.sdxl_lightning_pipe = self.sdxl_lightning_pipe.to("cuda")

# load "dreamlike-art/dreamlike-diffusion-1.0"
self.sd15_pipe = AutoPipelineForText2Image.from_pretrained(
SD15_MODEL_WEIGHTS
) # fp16 not available for "dreamlike-art/dreamlike-diffusion-1.0"
self.sd15_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.sd15_pipe.scheduler.config,
use_karras_sigmas=True,
algorithm_type="sde-dpmsolver++",
)
self.sd15_pipe = self.sd15_pipe.to("cuda")

@torch.inference_mode()
def predict(
self,
base_model: str = Input(
description="Choose a stable diffusion architecture, supporint sd1.5 and sdxl.",
default="sdxl",
choices=["sd1.5", "sdxl"],
),
model_name: str = Input(
description="Name of a stable diffusion model, should have either sd1.5 or sdxl architecture.",
description="Choose a stable diffusion model.",
default="ByteDance/SDXL-Lightning",
choice=["Lykon/dreamshaper-xl-1-0", "ByteDance/SDXL-Lightning", "dreamlike-art/dreamlike-diffusion-1.0"]
choices=[
"Lykon/dreamshaper-xl-1-0",
"ByteDance/SDXL-Lightning",
"dreamlike-art/dreamlike-diffusion-1.0",
],
),
prompt: str = Input(
description="Input prompt",
@@ -72,14 +108,8 @@ def predict(
description="Specify things to not see in the output",
default="ugly, deformed, noisy, blurry, nsfw, low contrast, text, BadDream, 3d, cgi, render, fake, anime, open mouth, big forehead, long neck",
),
width: int = Input(
description="Width of output image",
default=512,
),
height: int = Input(
description="Height of output image",
default=512,
),
width: int = Input(description="Width of output image", default=512),
height: int = Input(description="Height of output image", default=512),
num_inference_steps: int = Input(
description="Number of denoising steps", default=4
),
@@ -101,44 +131,25 @@ def predict(

generator = torch.Generator("cuda").manual_seed(seed)

base_model = (
"sd1.5" if model_name == "dreamlike-art/dreamlike-diffusion-1.0" else "sdxl"
)

if model_name == "Lykon/dreamshaper-xl-1-0":
self.pipe = self.default_pipe
pipe = self.sdxl_pipe

elif model_name == "ByteDance/SDXL-Lightning":
self.pipe = self.default_pipe
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting!

# Load SDXL-Lightning to UNet
unet = self.default_pipe.unet
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))

# Change UNet to pipeline
self.pipe.unet = unet
self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config, timestep_spacing="trailing")
pipe = self.sdxl_lightning_pipe
else:
try:
self.pipe = AutoPipelineForText2Image.from_pretrained(
model_name, torch_dtype=torch.float16, variant="fp16"
)
except:
print("fp16 not available.")
self.pipe = AutoPipelineForText2Image.from_pretrained(model_name)

self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.pipe.scheduler.config,
use_karras_sigmas=True,
algorithm_type="sde-dpmsolver++",
)
self.pipe = self.pipe.to("cuda")
pipe = self.sd15_pipe

if show_baseline:
if len(self.pipe.get_active_adapters()) > 0:
if len(pipe.get_active_adapters()) > 0:
print("Unloading LoRA weights...")
self.pipe.unload_lora_weights()
pipe.unload_lora_weights()

print("Generating images without res_adapter...")
baseline_image = self.pipe(
baseline_image = pipe(
prompt,
negative_prompt=negative_prompt,
width=width,
@@ -150,40 +161,43 @@ def predict(
baseline_path = "/tmp/baseline.png"
baseline_image.save(baseline_path)

if len(self.pipe.get_active_adapters()) == 0:
if len(pipe.get_active_adapters()) == 0:
if base_model == "sd1.5":
print("Loading Resolution LoRA weights...")
self.pipe.load_lora_weights(
pipe.load_lora_weights(
hf_hub_download(
repo_id="jiaxiangc/res-adapter",
subfolder=f"sd1.5",
subfolder="sd1.5",
filename="resolution_lora.safetensors",
),
adapter_name="res_adapter",
)
self.pipe.set_adapters(["res_adapter"], adapter_weights=[1.0])
pipe.set_adapters(["res_adapter"], adapter_weights=[1.0])
print("Load Resolution Norm weights")
self.pipe.unet.load_state_dict(load_file(
hf_hub_download(
repo_id="jiaxiangc/res-adapter",
subfolder="sd1.5",
filename="resolution_normalization.safetensors"
pipe.unet.load_state_dict(
load_file(
hf_hub_download(
repo_id="jiaxiangc/res-adapter",
subfolder="sd1.5",
filename="resolution_normalization.safetensors",
)
),
), strict=False)
strict=False,
)
elif base_model == "sdxl":
print("Loading Resolution LoRA weights...")
self.pipe.load_lora_weights(
pipe.load_lora_weights(
hf_hub_download(
repo_id="jiaxiangc/res-adapter",
subfolder=f"sdxl-i",
subfolder="sdxl-i",
filename="resolution_lora.safetensors",
),
adapter_name="res_adapter",
)
self.pipe.set_adapters(["res_adapter"], adapter_weights=[1.0])
pipe.set_adapters(["res_adapter"], adapter_weights=[1.0])

print("Generating images with res_adapter...")
image = self.pipe(
image = pipe(
prompt,
negative_prompt=negative_prompt,
width=width,
(remainder of predict.py not shown)

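For reference, here is a minimal standalone sketch of the demo's default path (model_name="ByteDance/SDXL-Lightning") outside of Cog, condensing what setup() and predict() above wire together: load the DreamShaper XL base pipeline, swap in the 4-step Lightning UNet, switch to a trailing-timestep Euler scheduler, and attach the jiaxiangc/res-adapter resolution LoRA. This is a sketch under assumptions rather than part of the commit: it assumes a CUDA GPU, a recent diffusers with PEFT support, and that the base weights are fetched directly from the Hugging Face Hub rather than from the Replicate weight mirror used above; the prompt, negative prompt, and guidance_scale=0.0 are illustrative choices.

import torch
from diffusers import AutoPipelineForText2Image, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Base SDXL pipeline. Pulling "Lykon/dreamshaper-xl-1-0" (with its fp16 variant) straight
# from the Hugging Face Hub is an assumption; the demo above loads the same weights from a
# Replicate mirror instead.
pipe = AutoPipelineForText2Image.from_pretrained(
    "Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16, variant="fp16"
)

# Swap in the 4-step SDXL-Lightning UNet and the trailing-timestep Euler scheduler it expects.
pipe.unet.load_state_dict(
    load_file(
        hf_hub_download("ByteDance/SDXL-Lightning", "sdxl_lightning_4step_unet.safetensors"),
        device="cuda",
    )
)
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
pipe = pipe.to("cuda")

# Attach the res-adapter resolution LoRA from the same "sdxl-i" subfolder used above
# (needs a diffusers install with PEFT support).
pipe.load_lora_weights(
    hf_hub_download(
        repo_id="jiaxiangc/res-adapter",
        subfolder="sdxl-i",
        filename="resolution_lora.safetensors",
    ),
    adapter_name="res_adapter",
)
pipe.set_adapters(["res_adapter"], adapter_weights=[1.0])

# Illustrative settings: 4 steps to match the Lightning checkpoint, guidance disabled as
# SDXL-Lightning recommends, and a 512x512 target so the resolution adapter has work to do.
image = pipe(
    "a cinematic photo of an astronaut riding a horse",
    negative_prompt="ugly, deformed, noisy, blurry",
    width=512,
    height=512,
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
image.save("output.png")

One design point the commit itself makes: building all three pipelines once in setup() trades GPU memory for per-request latency, so predict() only selects a preloaded pipeline and toggles the res_adapter LoRA instead of re-downloading weights or redoing the SDXL-Lightning UNet swap on every call, as the previous version did.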