Commit c756a2e

Add files via upload

natsunoyuki authored Dec 6, 2023
1 parent 22f0ee8 commit c756a2e
Showing 1 changed file with 74 additions and 28 deletions.
102 changes: 74 additions & 28 deletions src/diffuser_tools/text2img.py
@@ -4,33 +4,75 @@
import time  # needed by run_pipe's timing code below

import diffusers
import torch

# TODO

# %%
class Text2ImagePipe(object):
    # %%
    def __init__(
        self,
        model_dir,
        prompt = None,
        negative_prompt = None,
        scheduler = None,
        lora_dir = None,
        lora_dirs = [],
        lora_scales = [],
        clip_skip = 1,
        safety_checker = None,
        use_prompt_embeddings = True,
        split_character = ",",
        torch_dtype = torch.float32,
        device = torch.device("cpu"),
    ):
"""
Text2Image stable diffusion pipeline capable of handling:
1. Prompt and negative prompt embeddings.
2. Loading LoRAs.
3. CLIP skips.
4. Safety checker.
Inputs:
model_dir: str
Path to the model checkpoint safetensors file.
prompt: str
Prompt.
negative_prompt: str
Negative prompt.
scheduler: str
Scheduler to use. Choose from: EADS, EDS or DPMSMS.
lora_dir: str
Path to a single LoRA safetensors file.
lora_dirs: list of str
Paths to multiple LoRA safetensors files.
lora_scales: list of floats
Corresponding scaling factors for the LoRAs in lora_dirs.
clip_skip: int
Number of CLIP layers to skip. 0 means no CLIP skipping.
safety_checker: None
Set to None to remove turn safety checker off.
Can also use customized safety checkers.
use_prompt_embeddings: bool
If True, prompt embeddings and negative prompt embeddings will be
used instead. Overcomes CLIP's 77 token limit.
split_character: str
Character used to split the prompt and negative prompt into tokens.
"," by default.
torch_dtype: torch.float32 or torch.float16.
Use torch.float32 if using torch.device("cpu"), and
use torch.float16 if using torch.device("cuda").
device: torch.device("cpu") or torch.device("cuda")
Use CUDA if you have access to a GPU! Makes life easier.
"""
        # Hardware related parameters.
        # These will be used directly internally.
        self.torch_dtype = torch_dtype
        self.device = device

        # Diffusers pipeline.
        self.pipe = None

        # Load CivitAI model weights in the form of safetensors.
        # TODO add support for other file types or for downloading models from HuggingFace.
        self.pipe = diffusers.StableDiffusionPipeline.from_single_file(
            model_dir,
            torch_dtype = torch_dtype,
@@ -40,10 +82,14 @@ def __init__(
        self.pipe.safety_checker = safety_checker

        # Load LoRA weights.
        # https://huggingface.co/docs/diffusers/using-diffusers/loading_adapters#lora
        # TODO clean this up!
        if len(lora_dirs) == 0:
            if lora_dir is not None:
                self.pipe.load_lora_weights(lora_dir)
        else:
            for ldir, lsc in zip(lora_dirs, lora_scales):
                self.pipe.load_lora_weights(ldir)
                self.pipe.fuse_lora(lora_scale = lsc)
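        # Note on the multi-LoRA branch above: fuse_lora() folds the most
        # recently loaded LoRA into the base weights at the given scale, so
        # loading and fusing one LoRA at a time lets each LoRA keep its own
        # scale factor.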

        # CLIP skip.
        clip_layers = self.pipe.text_encoder.text_model.encoder.layers
@@ -73,16 +119,16 @@ def __init__(
        self.prompt_embeddings = None
        self.negative_prompt_embeddings = None
        self.set_prompts(
            prompt,
            negative_prompt,
            use_prompt_embeddings,
            split_character
        )

    # %%
    def set_prompts(
        self,
        prompt = None,
        negative_prompt = None,
        use_prompt_embeddings = True,
        split_character = ",",
@@ -103,16 +149,16 @@ def set_prompts(

    # %%
    def get_prompt_embeddings(
        self,
        split_character = ",",
        return_embeddings = False,
    ):
        """Builds prompt embeddings to overcome CLIP's 77 token limit.
        https://github.com/huggingface/diffusers/issues/2136
        """
        max_length = self.pipe.tokenizer.model_max_length

        # Simple method of checking if the prompt is longer than the negative
        # prompt - split the input strings using `split_character`.
        count_prompt = len(self.prompt.split(split_character))
@@ -125,9 +171,9 @@ def get_prompt_embeddings(
        ).input_ids.to(self.device)
        shape_max_length = input_ids.shape[-1]
        negative_ids = self.pipe.tokenizer(
            self.negative_prompt,
            truncation = False,
            padding = "max_length",
            max_length = shape_max_length,
            return_tensors = "pt",
        ).input_ids.to(self.device)
@@ -155,7 +201,7 @@ def get_prompt_embeddings(
            neg_embeds.append(
                self.pipe.text_encoder(negative_ids[:, i: i + max_length])[0]
            )

        self.prompt_embeddings = torch.cat(concat_embeds, dim = 1)
        self.negative_prompt_embeddings = torch.cat(neg_embeds, dim = 1)
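        # The concatenation above yields embeddings of shape
        # (1, n_chunks * 77, hidden_dim); StableDiffusionPipeline accepts
        # such longer-than-77-token embeddings through its prompt_embeds and
        # negative_prompt_embeds arguments.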

@@ -166,8 +212,8 @@
    def run_pipe(
        self,
        steps = 50,
        width = 512,
        height = 768,
        scale = 7.0,
        seed = 0,
        use_prompt_embeddings = False,
@@ -176,11 +222,11 @@ def run_pipe(
"""Runs the loaded model.
"""
if self.prompt is None and self.prompt_embeddings is None:
return
return

if self.pipe is None:
return

start_time = time.time()

if use_prompt_embeddings is True:
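# %%
# Usage sketch, not part of the commit itself: a minimal example of driving
# Text2ImagePipe as defined above. The checkpoint and LoRA paths, the "EADS"
# scheduler string, and the assumption that run_pipe() returns the generated
# images are illustrative only; the tail of run_pipe is collapsed in this
# diff.
if __name__ == "__main__":
    pipe = Text2ImagePipe(
        model_dir = "model.safetensors",  # hypothetical checkpoint path
        prompt = "a watercolor fox in an autumn forest",
        negative_prompt = "lowres, bad anatomy",
        scheduler = "EADS",
        lora_dirs = ["style_lora.safetensors", "detail_lora.safetensors"],  # hypothetical
        lora_scales = [0.8, 0.5],
        clip_skip = 1,
        use_prompt_embeddings = True,
        torch_dtype = torch.float16,
        device = torch.device("cuda"),
    )
    images = pipe.run_pipe(
        steps = 30,
        width = 512,
        height = 768,
        scale = 7.0,
        seed = 42,
        use_prompt_embeddings = True,
    )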
