v 0.0.6 #54

Merged
merged 16 commits on Dec 16, 2022
104 changes: 102 additions & 2 deletions lora_diffusion/lora.py
@@ -34,6 +34,7 @@ def inject_trainable_lora(
model: nn.Module,
target_replace_module: List[str] = ["CrossAttention", "Attention"],
r: int = 4,
loras=None,  # path to lora .pt
):
"""
Inject LoRA into the model and return the LoRA parameter groups.
@@ -42,6 +43,9 @@
require_grad_params = []
names = []

if loras != None:
loras = torch.load(loras)

for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:

@@ -62,18 +66,21 @@

# switch the module
_module._modules[name] = _tmp

require_grad_params.append(
_module._modules[name].lora_up.parameters()
)
require_grad_params.append(
_module._modules[name].lora_down.parameters()
)

if loras != None:
_module._modules[name].lora_up.weight = loras.pop(0)
_module._modules[name].lora_down.weight = loras.pop(0)

_module._modules[name].lora_up.weight.requires_grad = True
_module._modules[name].lora_down.weight.requires_grad = True
names.append(name)

return require_grad_params, names
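# Usage sketch for the new `loras` argument (assumptions: `unet` and the
# checkpoint path below are illustrative placeholders). Passing a previously
# saved LoRA .pt file re-injects its up/down weights, so training can resume
# from an existing adapter instead of a fresh initialization:
#
#     require_grad_params, names = inject_trainable_lora(
#         unet,
#         target_replace_module=["CrossAttention", "Attention"],
#         r=4,
#         loras="./output_example_lorpt/lora_weight.pt",
#     )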


@@ -245,3 +252,96 @@ def tune_lora_scale(model, alpha: float = 1.0):
for _module in model.modules():
if _module.__class__.__name__ == "LoraInjectedLinear":
_module.scale = alpha


def _text_lora_path(path: str) -> str:
assert path.endswith(".pt"), "Only .pt files are supported"
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def _ti_lora_path(path: str) -> str:
assert path.endswith(".pt"), "Only .pt files are supported"
return ".".join(path.split(".")[:-1] + ["ti", "pt"])


def load_learned_embed_in_clip(
learned_embeds_path, text_encoder, tokenizer, token=None
):
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

# separate token and the embeds
trained_token = list(loaded_learned_embeds.keys())[0]
embeds = loaded_learned_embeds[trained_token]

# cast to dtype of text_encoder
dtype = text_encoder.get_input_embeddings().weight.dtype

# add the token in tokenizer
token = token if token is not None else trained_token
num_added_tokens = tokenizer.add_tokens(token)
i = 1
while num_added_tokens == 0:
print(f"The tokenizer already contains the token {token}.")
token = f"{token[:-1]}-{i}>"
print(f"Attempting to add the token {token}.")
num_added_tokens = tokenizer.add_tokens(token)
i += 1

# resize the token embeddings
text_encoder.resize_token_embeddings(len(tokenizer))

# get the id for the token and assign the embeds
token_id = tokenizer.convert_tokens_to_ids(token)
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
return token


def patch_pipe(
pipe,
unet_path,
token,
alpha: float = 1.0,
r: int = 4,
patch_text=False,
patch_ti=False,
):

ti_path = _ti_lora_path(unet_path)
text_path = _text_lora_path(unet_path)

unet_has_lora = False
text_encoder_has_lora = False

for _module in pipe.unet.modules():
if _module.__class__.__name__ == "LoraInjectedLinear":
unet_has_lora = True

for _module in pipe.text_encoder.modules():
if _module.__class__.__name__ == "LoraInjectedLinear":
text_encoder_has_lora = True

if not unet_has_lora:
monkeypatch_lora(pipe.unet, torch.load(unet_path), r=r)
else:
monkeypatch_replace_lora(pipe.unet, torch.load(unet_path), r=r)

if patch_text:
if not text_encoder_has_lora:
monkeypatch_lora(
pipe.text_encoder,
torch.load(text_path),
target_replace_module=["CLIPAttention"],
r=r,
)
else:

monkeypatch_replace_lora(
pipe.text_encoder,
torch.load(text_path),
target_replace_module=["CLIPAttention"],
r=r,
)
if patch_ti:
token = load_learned_embed_in_clip(
ti_path, pipe.text_encoder, pipe.tokenizer, token
)
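A minimal usage sketch of the new patch_pipe helper, assuming a diffusers StableDiffusionPipeline and a LoRA checkpoint produced by the training script below; the model ID, checkpoint path, token, and prompt are illustrative, and the helpers are imported from lora_diffusion.lora in case they are not re-exported by the package:

import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion.lora import patch_pipe, tune_lora_scale

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"
).to("cuda")

# Injects (or replaces) LoRA in the UNet from lora_weight.pt; with patch_text
# and patch_ti it also loads lora_weight.text_encoder.pt and lora_weight.ti.pt,
# following the _text_lora_path / _ti_lora_path naming convention above.
patch_pipe(
    pipe,
    "./output_example_lorpt/lora_weight.pt",  # illustrative checkpoint path
    token="<krk>",
    r=4,
    patch_text=True,
    patch_ti=True,
)

# The LoRA contribution is scaled per sub-model via tune_lora_scale.
tune_lora_scale(pipe.unet, 1.0)
tune_lora_scale(pipe.text_encoder, 1.0)

image = pipe("a photo of <krk> woman", num_inference_steps=50).images[0]
image.save("sample.png")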
25 changes: 25 additions & 0 deletions run_lorpt.sh
@@ -0,0 +1,25 @@
# https://github.com/huggingface/diffusers/tree/main/examples/dreambooth
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="./data_example_text"
export OUTPUT_DIR="./output_example_lorpt"

accelerate launch train_lora_w_ti.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--train_text_encoder \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-5 \
--learning_rate_text=1e-5 \
--learning_rate_ti=5e-4 \
--color_jitter \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=5000 \
--placeholder_token="<krk>" \
--learnable_property="object" \
--initializer_token="woman" \
--save_steps=500 \
--unfreeze_lora_step=1500
246 changes: 246 additions & 0 deletions scripts/run_lorpt.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
setup(
name="lora_diffusion",
py_modules=["lora_diffusion"],
version="0.0.5",
version="0.0.6",
description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
author="Simo Ryu",
packages=find_packages(),