if dreambooth lora #3360

Merged: 15 commits, May 9, 2023
Changes from 10 commits
191 changes: 159 additions & 32 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and

import argparse
import gc
import hashlib
import itertools
import logging
@@ -30,7 +31,7 @@
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub import create_repo, model_info, upload_folder
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
@@ -48,7 +49,13 @@
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -108,6 +115,10 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

return RobertaSeriesModelWithTransformation
elif model_class == "T5EncoderModel":
from transformers import T5EncoderModel

return T5EncoderModel
else:
raise ValueError(f"{model_class} is not supported.")

@@ -387,6 +398,11 @@ def parse_args(input_args=None):
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--pre_compute_text_embeddings",
action="store_true",
help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
)

if input_args is not None:
args = parser.parse_args(input_args)
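To make the new flag concrete, here is a hedged sketch of the pattern it enables: encode the instance prompt once, reuse the resulting hidden states on every step, and release the text encoder to free GPU memory. The Stable Diffusion checkpoint below is only a stand-in; the script applies the same idea to whatever text encoder the base model ships. Because the encoder is discarded, the flag cannot be combined with `--train_text_encoder`.

```python
# Sketch of pre-computing prompt embeddings and dropping the text encoder.
import gc

import torch
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "runwayml/stable-diffusion-v1-5"  # placeholder base model
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to("cuda")

with torch.no_grad():
    input_ids = tokenizer(
        "a photo of sks dog",  # placeholder instance prompt
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(text_encoder.device)
    prompt_embeds = text_encoder(input_ids)[0]  # reused for every training step

# The encoder is no longer needed during training; free its memory.
del text_encoder
gc.collect()
torch.cuda.empty_cache()
```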
@@ -428,10 +444,12 @@ def __init__(
class_num=None,
size=512,
center_crop=False,
encoder_hidden_states=None,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.encoder_hidden_states = encoder_hidden_states

self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
@@ -473,13 +491,17 @@ def __getitem__(self, index):
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
example["instance_prompt_ids"] = self.tokenizer(
self.instance_prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids

if self.encoder_hidden_states is not None:
example["instance_prompt_ids"] = self.encoder_hidden_states
else:
example["instance_prompt_ids"] = self.tokenizer(
self.instance_prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids

if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
@@ -536,6 +558,16 @@ def __getitem__(self, index):
return example


def model_has_vae(args):
config_file_name = os.path.join("vae", AutoencoderKL.config_name)
if os.path.isdir(args.pretrained_model_name_or_path):
config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
return os.path.isfile(config_file_name)
else:
files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
return any(file.rfilename == config_file_name for file in files_in_repo)
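`model_has_vae` is what lets the same script serve both latent-diffusion checkpoints (Stable Diffusion, which bundles a VAE) and pixel-space checkpoints (DeepFloyd IF, which does not): it simply checks for `vae/config.json` locally or among the Hub repo's files. A standalone sketch of the same check; the model ids in the comments are examples only:

```python
# Sketch: does a checkpoint (local folder or Hub repo id) ship a VAE subfolder?
import os

from huggingface_hub import model_info


def has_vae(model_path_or_id: str, revision: str = None) -> bool:
    config_file = os.path.join("vae", "config.json")  # AutoencoderKL.config_name is "config.json"
    if os.path.isdir(model_path_or_id):
        return os.path.isfile(os.path.join(model_path_or_id, config_file))
    siblings = model_info(model_path_or_id, revision=revision).siblings
    return any(f.rfilename == config_file for f in siblings)


# has_vae("runwayml/stable-diffusion-v1-5")  -> True  (latent diffusion)
# has_vae("DeepFloyd/IF-I-XL-v1.0")          -> False (pixel-space model, gated repo)
```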


def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)

@@ -656,13 +688,20 @@ def main(args):
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
if model_has_vae(args):
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
)
else:
vae = None

unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

# We only train the additional adapter LoRA layers
vae.requires_grad_(False)
if vae is not None:
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)

@@ -676,7 +715,8 @@

# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
if vae is not None:
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

if args.enable_xformers_memory_efficient_attention:
@@ -707,7 +747,7 @@ def main(args):

# Set correct lora layers
unet_lora_attn_procs = {}
for name in unet.attn_processors.keys():
for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
@@ -718,7 +758,12 @@
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

unet_lora_attn_procs[name] = LoRAAttnProcessor(
if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
lora_attn_processor_class = LoRAAttnAddedKVProcessor
else:
lora_attn_processor_class = LoRAAttnProcessor

unet_lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
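The loop above now looks at the attention processor already attached to each block: the "added KV" processors (used by UNets such as DeepFloyd IF's, which project extra key/value states) get `LoRAAttnAddedKVProcessor`, while everything else keeps the plain `LoRAAttnProcessor`. A condensed sketch of just that selection step:

```python
# Condensed sketch of the processor selection above (not a drop-in replacement).
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    SlicedAttnAddedKVProcessor,
)

ADDED_KV_PROCESSORS = (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)


def pick_lora_class(current_processor):
    # Added-KV attention carries extra add_k_proj / add_v_proj weights, so it
    # needs the matching LoRA variant to adapt those projections as well.
    if isinstance(current_processor, ADDED_KV_PROCESSORS):
        return LoRAAttnAddedKVProcessor
    return LoRAAttnProcessor
```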

@@ -783,6 +828,44 @@ def main(args):
eps=args.adam_epsilon,
)

if args.pre_compute_text_embeddings:

def compute_text_embeddings(prompt):
with torch.no_grad():
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)

text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask.to(text_encoder.device)

prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]

return prompt_embeds

pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
validation_prompt_negative_prompt_embeds = compute_text_embeddings("")

text_encoder = None
tokenizer = None

gc.collect()
torch.cuda.empty_cache()
else:
pre_computed_encoder_hidden_states = None
validation_prompt_encoder_hidden_states = None
validation_prompt_negative_prompt_embeds = None

# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
@@ -793,6 +876,7 @@ def main(args):
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
encoder_hidden_states=pre_computed_encoder_hidden_states,
)

train_dataloader = torch.utils.data.DataLoader(
@@ -896,32 +980,48 @@ def main(args):
continue

with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

if vae is not None:
# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
else:
model_input = pixel_values

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
)
timesteps = timesteps.long()

# Add noise to the latents according to the noise magnitude at each timestep
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
if args.pre_compute_text_embeddings:
encoder_hidden_states = batch["input_ids"]
else:
encoder_hidden_states = text_encoder(batch["input_ids"])[0]

# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample

# if model predicts variance, throw away the prediction. we will only train on the
# simplified training objective. This means that all schedulers using the fine-tuned
# model must be configured to use one of the fixed variance types.
if model_pred.shape[1] == 6:
model_pred, _ = torch.chunk(model_pred, 2, dim=1)

# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

@@ -988,19 +1088,35 @@ def main(args):
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
variance_type = pipeline.scheduler.config.variance_type

if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, variance_type=variance_type
)

pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
if args.pre_compute_text_embeddings:
pipeline_args = {
"prompt_embeds": validation_prompt_encoder_hidden_states,
"negative_prompt_embeds": validation_prompt_negative_prompt_embeds,
}
else:
pipeline_args = {"prompt": args.validation_prompt}
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images)
pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)
]

for tracker in accelerator.trackers:
@@ -1024,7 +1140,8 @@ def main(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unet.to(torch.float32)
text_encoder = text_encoder.to(torch.float32)
if text_encoder is not None:
text_encoder = text_encoder.to(torch.float32)
LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_layers,
@@ -1036,7 +1153,17 @@ def main(args):
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
variance_type = pipeline.scheduler.config.variance_type

if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, variance_type=variance_type
)

pipeline = pipeline.to(accelerator.device)

# load attention processors
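The remainder of the script (truncated here) loads the trained attention processors back into this pipeline and runs a final round of inference. Doing the same outside the script looks roughly as follows; the model id, output directory and prompt are placeholders, and `load_lora_weights` is assumed to be the loading counterpart of the `save_lora_weights` call above. Note the same `variance_type` override, so a scheduler that was configured for a learned variance falls back to a fixed one, matching the simplified objective the LoRA was trained with:

```python
# Sketch: reload the trained LoRA weights and sample from the base model.
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # placeholder base model
).to("cuda")

# If the base scheduler predicts a learned variance, pin it to a fixed one.
variance_type = getattr(pipeline.scheduler.config, "variance_type", None)
if variance_type in ["learned", "learned_range"]:
    variance_type = "fixed_small"
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config, variance_type=variance_type
)

pipeline.load_lora_weights("path/to/output_dir")  # directory written by save_lora_weights
image = pipeline("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("lora_sample.png")
```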
20 changes: 18 additions & 2 deletions src/diffusers/loaders.py
@@ -21,9 +21,13 @@
from huggingface_hub import hf_hub_download

from .models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
SlicedAttnAddedKVProcessor,
)
from .utils import (
DIFFUSERS_CACHE,
@@ -250,10 +254,22 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

for key, value_dict in lora_grouped_dict.items():
rank = value_dict["to_k_lora.down.weight"].shape[0]
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

attn_processors[key] = LoRAAttnProcessor(
attn_processor = self
for sub_key in key.split("."):
attn_processor = getattr(attn_processor, sub_key)

if isinstance(
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
):
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
attn_processor_class = LoRAAttnAddedKVProcessor
else:
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
attn_processor_class = LoRAAttnProcessor

attn_processors[key] = attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
)
attn_processors[key].load_state_dict(value_dict)
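The loader change mirrors the training-side selection: `load_attn_procs` now walks the dotted processor key down to the processor instance currently attached to the model and uses its type to decide between the plain and the added-KV LoRA classes (the added-KV variant also reads its `cross_attention_dim` from the `add_k_proj` LoRA weight). A minimal sketch of the attribute walk; the key below is only an example of the naming scheme:

```python
# Sketch: resolve a dotted attention-processor key to the live object on a model.
def get_by_dotted_key(root, dotted_key: str):
    obj = root
    for part in dotted_key.split("."):
        obj = getattr(obj, part)  # works for ModuleList children too ("0", "1", ...)
    return obj


# e.g. processor = get_by_dotted_key(unet, "mid_block.attentions.0.transformer_blocks.0.attn1.processor")
# isinstance(processor, AttnAddedKVProcessor) -> use LoRAAttnAddedKVProcessor
```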