Skip to content

Commit

Permalink
Remove refiner, fix watermark
Browse files Browse the repository at this point in the history
  • Loading branch information
lucataco committed Nov 8, 2023
1 parent d10411c commit 65fcacb
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 164 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
__pycache__
.cog
model-cache
trained-model
trained_model.tar
212 changes: 49 additions & 163 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import shutil
import subprocess
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from typing import Any, Callable, Dict, List, Optional
import numpy as np
import torch
from cog import BasePredictor, Input, Path
Expand All @@ -22,24 +21,15 @@
StableDiffusionXLInpaintPipeline,
)
from diffusers.models.attention_processor import LoRAAttnProcessor2_0
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.utils import load_image
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import CLIPImageProcessor

import json
import requests
from io import BytesIO
import tarfile
import math
import torch
from PIL import Image
import shutil
import math


import tarfile
import requests
from PIL import Image
from io import BytesIO
from dataset_and_utils import TokenEmbeddingsHandler

MODEL_NAME = "segmind/SSD-1B"
Expand Down Expand Up @@ -100,7 +90,6 @@ def load_trained_weights(self, weights_url, pipe):
print("Loading Unet LoRA")

unet = pipe.unet

tensors = load_file(os.path.join(local_weights_cache, "lora.safetensors"))

unet = pipe.unet
Expand Down Expand Up @@ -150,7 +139,6 @@ def load_trained_weights(self, weights_url, pipe):
with open(os.path.join(local_weights_cache, "special_params.json"), "r") as f:
params = json.load(f)
self.token_map = params

self.tuned_model = True


Expand Down Expand Up @@ -186,7 +174,7 @@ def predict(
),
prompt: str = Input(
description="Input prompt",
default="An TOK riding a rainbow unicorn",
default="A photo of TOK",
),
negative_prompt: str = Input(
description="Input Negative Prompt",
Expand Down Expand Up @@ -234,21 +222,6 @@ def predict(
seed: int = Input(
description="Random seed. Leave blank to randomize the seed", default=None
),
refine: str = Input(
description="Which refine style to use",
choices=["no_refiner", "expert_ensemble_refiner", "base_image_refiner"],
default="no_refiner",
),
high_noise_frac: float = Input(
description="For expert_ensemble_refiner, the fraction of noise to use",
default=0.8,
le=1.0,
ge=0.0,
),
refine_steps: int = Input(
description="For base_image_refiner, the number of steps to refine, defaults to num_inference_steps",
default=None,
),
apply_watermark: bool = Input(
description="Applies a watermark to enable determining if an image is generated in downstream applications. If you have other provisions for generating or deploying images safely, you can use this to disable watermarking.",
default=True,
Expand All @@ -262,112 +235,48 @@ def predict(
) -> List[Path]:
# Check if there is a lora_url
if lora_url == None:
raise Exception(
f"Missing Lora_url parameter"
)

lora = True
if lora == True :
self.is_lora = True
print("LORA")
print("Loading ssd txt2img pipeline...")
self.txt2img_pipe = StableDiffusionXLPipeline.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
cache_dir=MODEL_CACHE,
)
print("Loading ssd lora weights...")
self.load_trained_weights(lora_url, self.txt2img_pipe)
self.txt2img_pipe.to("cuda")
self.is_lora = True

# print("Loading SDXL img2img pipeline...")
# self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
# vae=self.txt2img_pipe.vae,
# text_encoder=self.txt2img_pipe.text_encoder,
# text_encoder_2=self.txt2img_pipe.text_encoder_2,
# tokenizer=self.txt2img_pipe.tokenizer,
# tokenizer_2=self.txt2img_pipe.tokenizer_2,
# unet=self.txt2img_pipe.unet,
# scheduler=self.txt2img_pipe.scheduler,
# )
# self.img2img_pipe.to("cuda")

# print("Loading SDXL inpaint pipeline...")
# self.inpaint_pipe = StableDiffusionXLInpaintPipeline(
# vae=self.txt2img_pipe.vae,
# text_encoder=self.txt2img_pipe.text_encoder,
# text_encoder_2=self.txt2img_pipe.text_encoder_2,
# tokenizer=self.txt2img_pipe.tokenizer,
# tokenizer_2=self.txt2img_pipe.tokenizer_2,
# unet=self.txt2img_pipe.unet,
# scheduler=self.txt2img_pipe.scheduler,
# )
# self.inpaint_pipe.to("cuda")

# print("Loading SDXL refiner pipeline...")

# print("Loading refiner pipeline...")
# self.refiner = DiffusionPipeline.from_pretrained(
# "refiner-cache",
# text_encoder_2=self.txt2img_pipe.text_encoder_2,
# vae=self.txt2img_pipe.vae,
# torch_dtype=torch.float16,
# use_safetensors=True,
# variant="fp16",
# )
# self.refiner.to("cuda")


else :
print("Loading sdxl txt2img pipeline...")
self.txt2img_pipe = DiffusionPipeline.from_pretrained(
MODEL_CACHE,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
)
self.is_lora = False

self.txt2img_pipe.to("cuda")

print("Loading SDXL img2img pipeline...")
self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
vae=self.txt2img_pipe.vae,
text_encoder=self.txt2img_pipe.text_encoder,
text_encoder_2=self.txt2img_pipe.text_encoder_2,
tokenizer=self.txt2img_pipe.tokenizer,
tokenizer_2=self.txt2img_pipe.tokenizer_2,
unet=self.txt2img_pipe.unet,
scheduler=self.txt2img_pipe.scheduler,
)
self.img2img_pipe.to("cuda")

print("Loading SDXL inpaint pipeline...")
self.inpaint_pipe = StableDiffusionXLInpaintPipeline(
vae=self.txt2img_pipe.vae,
text_encoder=self.txt2img_pipe.text_encoder,
text_encoder_2=self.txt2img_pipe.text_encoder_2,
tokenizer=self.txt2img_pipe.tokenizer,
tokenizer_2=self.txt2img_pipe.tokenizer_2,
unet=self.txt2img_pipe.unet,
scheduler=self.txt2img_pipe.scheduler,
)
self.inpaint_pipe.to("cuda")
print("Loading refiner pipeline...")
self.refiner = DiffusionPipeline.from_pretrained(
"refiner-cache",
text_encoder_2=self.txt2img_pipe.text_encoder_2,
vae=self.txt2img_pipe.vae,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
)
self.refiner.to("cuda")
raise Exception(f"Missing Lora_url parameter")

self.is_lora = True
print("LORA")
print("Loading ssd txt2img pipeline...")
self.txt2img_pipe = StableDiffusionXLPipeline.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
cache_dir=MODEL_CACHE,
)
print("Loading ssd lora weights...")
self.load_trained_weights(lora_url, self.txt2img_pipe)
self.txt2img_pipe.to("cuda")
self.is_lora = True

print("Loading SDXL img2img pipeline...")
self.img2img_pipe = StableDiffusionXLImg2ImgPipeline(
vae=self.txt2img_pipe.vae,
text_encoder=self.txt2img_pipe.text_encoder,
text_encoder_2=self.txt2img_pipe.text_encoder_2,
tokenizer=self.txt2img_pipe.tokenizer,
tokenizer_2=self.txt2img_pipe.tokenizer_2,
unet=self.txt2img_pipe.unet,
scheduler=self.txt2img_pipe.scheduler,
)
self.img2img_pipe.to("cuda")

print("Loading SDXL inpaint pipeline...")
self.inpaint_pipe = StableDiffusionXLInpaintPipeline(
vae=self.txt2img_pipe.vae,
text_encoder=self.txt2img_pipe.text_encoder,
text_encoder_2=self.txt2img_pipe.text_encoder_2,
tokenizer=self.txt2img_pipe.tokenizer,
tokenizer_2=self.txt2img_pipe.tokenizer_2,
unet=self.txt2img_pipe.unet,
scheduler=self.txt2img_pipe.scheduler,
)
self.inpaint_pipe.to("cuda")

print("Loading SDXL refiner pipeline...")

"""Run a single prediction on the model"""
if seed is None:
Expand Down Expand Up @@ -405,17 +314,9 @@ def predict(
sdxl_kwargs["height"] = height
pipe = self.txt2img_pipe

if refine == "expert_ensemble_refiner":
sdxl_kwargs["output_type"] = "latent"
sdxl_kwargs["denoising_end"] = high_noise_frac
elif refine == "base_image_refiner":
sdxl_kwargs["output_type"] = "latent"

# toggles watermark for this prediction
if not apply_watermark:
# toggles watermark for this prediction
watermark_cache = pipe.watermark
pipe.watermark = None
self.refiner.watermark = None

pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
generator = torch.Generator("cuda").manual_seed(seed)
Expand All @@ -433,23 +334,8 @@ def predict(

output = pipe(**common_args, **sdxl_kwargs)

if refine in ["expert_ensemble_refiner", "base_image_refiner"]:
refiner_kwargs = {
"image": output.images,
}

if refine == "expert_ensemble_refiner":
refiner_kwargs["denoising_start"] = high_noise_frac
if refine == "base_image_refiner" and refine_steps:
common_args["num_inference_steps"] = refine_steps

output = self.refiner(**common_args, **refiner_kwargs)

output_paths = []
for i, _ in enumerate(output.images):
# if nsfw:
# print(f"NSFW content detected in image {i}")
# continue
output_path = f"/tmp/out-{i}.png"
output.images[i].save(output_path)
output_paths.append(Path(output_path))
Expand Down

0 comments on commit 65fcacb

Please sign in to comment.