diff --git a/README.md b/README.md
index bace478..e922b12 100644
--- a/README.md
+++ b/README.md
@@ -50,11 +50,11 @@
 
 # UPDATES & Notes
 
-### 2022/02/01
+### 2023/02/01
 
 - LoRA Joining is now available with the `--mode=ljl` flag. Only three parameters are required: `path_to_lora1`, `path_to_lora2`, and `path_to_save`.
 
-### 2022/01/29
+### 2023/01/29
 
 - Dataset pipelines
 - LoRA applied to ResNet as well; use `--use_extended_lora` to enable it.
@@ -62,7 +62,7 @@
 - The CompVis-format conversion script now works with safetensors, and for PTI it will also return a Textual Inversion format, so you can use it in the embeddings folder.
 - 🥳🥳, LoRA is now officially integrated into the amazing Huggingface 🤗 `diffusers` library! Check out the [Blog](https://huggingface.co/blog/lora) and [examples](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora)! (NOTE: it currently uses a different file format.)
 
-### 2022/01/09
+### 2023/01/09
 
 - Pivotal Tuning Inversion with extended latent
 - Better Textual Inversion with a norm prior
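For reference, LoRA joining as described in the changelog entry above can be driven from the package's `lora_add` console script. The invocation below is a hedged sketch: the positional-argument order and the `.safetensors` filenames follow the changelog's parameter names (`path_to_lora1`, `path_to_lora2`, `path_to_save`) and are assumptions, not verified against `cli_lora_add.py`.

```bash
# Hypothetical sketch: argument order mirrors the changelog's parameter
# names and may not match the actual CLI signature.
lora_add \
  path_to_lora1.safetensors \
  path_to_lora2.safetensors \
  path_to_save.safetensors \
  --mode=ljl
```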
str = "new_pti_project", wandb_entity: str = "new_pti_entity", + proxy_token: str = "person", enable_xformers_memory_efficient_attention: bool = False, + out_name: str = "final_lora", ): torch.manual_seed(seed) @@ -566,7 +625,6 @@ def train( name=f"steps_{max_train_steps_ti}_lr_{learning_rate_ti}_{instance_data_dir.split('/')[-1]}", reinit=True, config={ - "lr": learning_rate_ti, **(extra_args if extra_args is not None else {}), }, ) @@ -594,6 +652,8 @@ def train( placeholder_tokens ), "Unequal Initializer token for Placeholder tokens." + if proxy_token is not None: + class_token = proxy_token class_token = "".join(initializer_tokens) if placeholder_token_at_data is not None: @@ -817,6 +877,12 @@ def train( lora_unet_target_modules=lora_unet_target_modules, lora_clip_target_modules=lora_clip_target_modules, mask_temperature=mask_temperature, + tokenizer=tokenizer, + out_name=out_name, + test_image_path=instance_data_dir, + log_wandb=log_wandb, + wandb_log_prompt_cnt=wandb_log_prompt_cnt, + class_token=class_token, ) diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py index 99a2217..a8e81b0 100644 --- a/lora_diffusion/dataset.py +++ b/lora_diffusion/dataset.py @@ -223,6 +223,7 @@ def __init__( transforms.ColorJitter(0.1, 0.1) if color_jitter else transforms.Lambda(lambda x: x), + transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] diff --git a/setup.py b/setup.py index 550a309..b27850c 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="lora_diffusion", py_modules=["lora_diffusion"], - version="0.1.3", + version="0.1.4", description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.", author="Simo Ryu", packages=find_packages(),