From 269ccf8803bc16814299cfb51d64c9ce6e49b3e9 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 15 Sep 2023 12:24:16 +0200 Subject: [PATCH 01/63] initial script --- examples/wuerstchen/text_to_image/README.md | 28 + .../wuerstchen/text_to_image/requirements.txt | 5 + .../text_to_image/train_text_to_image.py | 509 ++++++++++++++++++ 3 files changed, 542 insertions(+) create mode 100644 examples/wuerstchen/text_to_image/README.md create mode 100644 examples/wuerstchen/text_to_image/requirements.txt create mode 100644 examples/wuerstchen/text_to_image/train_text_to_image.py diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md new file mode 100644 index 000000000000..177eea4e7662 --- /dev/null +++ b/examples/wuerstchen/text_to_image/README.md @@ -0,0 +1,28 @@ +# Würstchen text-to-image fine-tuning + +## Running locally with PyTorch + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install . +``` + +Then cd in the example folder and run +```bash +cd examples/wuerstchen/text_to_image +pip install -r requirements.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` +For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the --push_to_hub flag. 
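You can log in from the command line before launching training (a minimal sketch, assuming the `huggingface-cli` tool that ships with `huggingface_hub` is available in your environment):

```bash
# authenticate with your Hugging Face account so --push_to_hub can upload checkpoints
huggingface-cli login
```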
+ diff --git a/examples/wuerstchen/text_to_image/requirements.txt b/examples/wuerstchen/text_to_image/requirements.txt new file mode 100644 index 000000000000..55f8f9f25e21 --- /dev/null +++ b/examples/wuerstchen/text_to_image/requirements.txt @@ -0,0 +1,5 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +wandb +huggingface-cli diff --git a/examples/wuerstchen/text_to_image/train_text_to_image.py b/examples/wuerstchen/text_to_image/train_text_to_image.py new file mode 100644 index 000000000000..ed1957fbdea2 --- /dev/null +++ b/examples/wuerstchen/text_to_image/train_text_to_image.py @@ -0,0 +1,509 @@ +import os +import time +import torch +import torchvision +from torch import nn, optim +from torch.utils.data import DataLoader +from warmup_scheduler import GradualWarmupScheduler +from tqdm import tqdm +import numpy as np +import wandb +import shutil +from transformers import AutoTokenizer, CLIPTextModel +import webdataset as wds +from webdataset.handlers import warn_and_continue +from torch.distributed import init_process_group, destroy_process_group +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.multiprocessing as mp +from torchtools.utils import Diffuzz +from diffnext_v2 import Prior +from diffnext_v2_ldm import DiffNeXt, EfficientNetEncoder +from vqgan import VQModel +from utils import WebdatasetFilter +import transformers + +transformers.utils.logging.set_verbosity_error() + +# PARAMETERS +updates = 1000000 +warmup_updates = 10000 +ema_start = 5000 +ema_every = 100 +ema_beta = 0.9 +# batch_size = 20 * 8 * 8 # 2048 20 * 64 +batch_size = 20 * 8 * 8 # 2048 20 * 64 +grad_accum_steps = 1 +max_iters = updates * grad_accum_steps +print_every = 2000 * grad_accum_steps +extra_ckpt_every = 50000 * grad_accum_steps +lr = 1e-4 # 1e-4 +generate_new_wandb_id = False +consistency_weight = 0.1 + +dataset_path = "pipe:aws s3 cp s3://stability-west/laion-a-native-high-res/{part-0/{00000..18000}.tar,part-1/{00000..13500}.tar,part-2/{00000..13500}.tar,part-3/{00000..13500}.tar,part-4/{00000..14100}.tar} -" +run_name = "Würstchen-Prior-LDM-Consistency-Scale-EffNet-CLIP-G" +dist_file = "dist_file8" +output_path = f"results/{run_name}" +os.makedirs(output_path, exist_ok=True) +checkpoint_dir = "models" +checkpoint_path = os.path.join(checkpoint_dir, run_name, "model.pt") +os.makedirs(os.path.join(checkpoint_dir, run_name), exist_ok=True) + +wandv_project = "Paella DiffNeXt" +wandv_entity = "babbleberns" +wandb_run_name = run_name + +transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(512), + torchvision.transforms.RandomCrop(512), + ] +) + +effnet_preprocess = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize( + 384, + interpolation=torchvision.transforms.InterpolationMode.BILINEAR, + antialias=True, + ), + torchvision.transforms.CenterCrop(384), + torchvision.transforms.Normalize( + mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) + ), + ] +) + + +def identity(x): + return x + + +def ddp_setup(rank, world_size, n_node, node_id): # <--- DDP + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "33751" + torch.cuda.set_device(rank) + init_process_group( + backend="nccl", + rank=rank + node_id * world_size, + world_size=world_size * n_node, + init_method="dist_file", + ) + print(f"[GPU {rank+node_id*world_size}] READY") + + +def train(gpu_id, world_size, n_nodes): + node_id = int(os.environ["SLURM_PROCID"]) + main_node = gpu_id == 0 and node_id == 0 + ddp_setup(gpu_id, 
world_size, n_nodes, node_id) # <--- DDP + device = torch.device(gpu_id) + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # --- PREPARE DATASET --- + dataset = ( + wds.WebDataset(dataset_path, resampled=True, handler=warn_and_continue) + .select( + WebdatasetFilter( + min_size=512, + max_pwatermark=0.5, + aesthetic_threshold=5.1, + unsafe_threshold=0.99, + ) + ) + .shuffle(44, handler=warn_and_continue) + .decode("pilrgb", handler=warn_and_continue) + .to_tuple("jpg", "txt", handler=warn_and_continue) + .map_tuple(transforms, identity, handler=warn_and_continue) + ) + + real_batch_size = batch_size // (world_size * n_nodes * grad_accum_steps) + dataloader = DataLoader( + dataset, batch_size=real_batch_size, num_workers=8, pin_memory=False + ) + + if main_node: + print("REAL BATCH SIZE / DEVICE:", real_batch_size) + + # - EfficientNet - + pretrained_checkpoint = torch.load( + "models/text2img_wurstchen_b_v1_457k.pt", map_location=device + ) + + effnet = EfficientNetEncoder().to(device) + effnet.load_state_dict(pretrained_checkpoint["effnet_state_dict"]) + effnet.eval().requires_grad_(False) + + # # - vqmodel - + # if main_node: cp /fsx/home-pablo/models/risotto/text2img_wurstchen_b_v1_457k.pt ../models + # vqmodel = VQModel().to(device) + # vqmodel.load_state_dict(torch.load(f"models/vqgan_f4_v1_500k.pt", map_location=device)['state_dict']) + # vqmodel.eval().requires_grad_(False) + + # # - LDM Model as generator - + # generator = DiffNeXt().to(device) + # generator.load_state_dict(pretrained_checkpoint['state_dict']) + # generator.eval().requires_grad_(False).to(torch.bfloat16) + + del pretrained_checkpoint + torch.cuda.empty_cache() + + # --- PREPARE MODELS --- + try: + checkpoint = ( + torch.load(checkpoint_path, map_location=device) + if os.path.exists(checkpoint_path) + else None + ) + except RuntimeError as e: + if os.path.exists(f"{checkpoint_path}.bak"): + os.remove(checkpoint_path) + shutil.copyfile(f"{checkpoint_path}.bak", checkpoint_path) + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + raise e + + diffuzz = Diffuzz(device=device) + + # - CLIP text encoder + clip_model = ( + CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + .to(device) + .eval() + .requires_grad_(False) + ) + clip_tokenizer = AutoTokenizer.from_pretrained( + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + ) + + # - Diffusive Imagination Combinatrainer, a.k.a. 
Risotto - + model = Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) + # model = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=42, nhead=24).to(device) + if checkpoint is not None: + model.load_state_dict(checkpoint["state_dict"]) + + if main_node: # <--- DDP + model_ema = ( + Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24) + .eval() + .requires_grad_(False) + ) + # model_ema = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=42, nhead=24).eval().requires_grad_(False) + + # load checkpoints & prepare ddp + if checkpoint is not None: + if main_node: # <--- DDP + if "ema_state_dict" in checkpoint: + model_ema.load_state_dict(checkpoint["ema_state_dict"]) + else: + model_ema.load_state_dict(model.state_dict()) + + # - SETUP WANDB - + if main_node: # <--- DDP + if checkpoint is not None and not generate_new_wandb_id: + run_id = checkpoint["wandb_run_id"] + else: + run_id = wandb.util.generate_id() + wandb.init( + project=wandv_project, + name=wandb_run_name, + entity=wandv_entity, + id=run_id, + resume="allow", + ) + + model = DDP(model, device_ids=[gpu_id], output_device=device) # <--- DDP + + if main_node: # <--- DDP + print( + "Num trainable params:", + sum(p.numel() for p in model.parameters() if p.requires_grad), + ) + + # SETUP OPTIMIZER, SCHEDULER & CRITERION + optimizer = optim.AdamW(model.parameters(), lr=lr) # eps=1e-4 + # optimizer = StableAdamW(model.parameters(), lr=lr) # eps=1e-4 + # optimizer = Lion(model.parameters(), lr=lr / 3) # eps=1e-4 + scheduler = GradualWarmupScheduler( + optimizer, multiplier=1, total_epoch=warmup_updates + ) + if checkpoint is not None: + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + scheduler.last_epoch = checkpoint["scheduler_last_step"] + + start_iter = 1 + grad_norm = torch.tensor(0, device=device) + if checkpoint is not None: + start_iter = checkpoint["scheduler_last_step"] * grad_accum_steps + 1 + if main_node: # <--- DDP + print("RESUMING TRAINING FROM ITER ", start_iter) + + loss_adjusted = 0.0 + ema_loss = None + if checkpoint is not None: + ema_loss = checkpoint["metrics"]["ema_loss"] + + if checkpoint is not None: + del checkpoint # cleanup memory + torch.cuda.empty_cache() + + # -------------- START TRAINING -------------- + if main_node: + print("Everything prepared, starting training now....") + dataloader_iterator = iter(dataloader) + pbar = ( + tqdm(range(start_iter, max_iters + 1)) + if (main_node) + else range(start_iter, max_iters + 1) + ) # <--- DDP + model.train() + for it in pbar: + bls = time.time() + images, captions = next(dataloader_iterator) + ble = time.time() - bls + images = images.to(device) + + with torch.no_grad(): + effnet_features = effnet(effnet_preprocess(images)) + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + if ( + np.random.rand() < 0.05 + ): # 90% of the time, drop the CLIP text embeddings (indepentently) + clip_captions = [""] * len( + captions + ) # 5% of the time drop all the captions + else: + clip_captions = captions + clip_tokens = clip_tokenizer( + clip_captions, + truncation=True, + padding="max_length", + max_length=clip_tokenizer.model_max_length, + return_tensors="pt", + ).to(device) + clip_text_embeddings = clip_model(**clip_tokens).last_hidden_state + + t = ( + (1 - torch.rand(images.size(0), device=device)) + .mul(1.08) + .add(0.001) + .clamp(0.001, 1.0) + ) + noised_embeddings, noise = diffuzz.diffuse(effnet_features, t) + + t_consistency = t - t * (1 - torch.rand(images.size(0), device=device)).mul( + 1.08 + 
).add(0.001).clamp(0.001, 1.0) + noised_embeddings_consistency, _ = diffuzz.diffuse( + effnet_features, t_consistency, noise=noise + ) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + pred_noise = model(noised_embeddings, t, clip_text_embeddings) + + model.eval().requires_grad_(False) + with torch.no_grad(): + with model.no_sync(): + pred_noise_consistency = model.module( + noised_embeddings_consistency, + t_consistency, + clip_text_embeddings, + ) + model.train().requires_grad_(True) + + loss = nn.functional.mse_loss(pred_noise, noise, reduction="none").mean( + dim=[1, 2, 3] + ) + consistency_loss = nn.functional.mse_loss( + pred_noise, pred_noise_consistency, reduction="none" + ).mean(dim=[1, 2, 3]) + loss_adjusted = ( + (loss + consistency_loss * consistency_weight) * diffuzz.p2_weight(t) + ).mean() / grad_accum_steps + + if it % grad_accum_steps == 0 or it == max_iters: + loss_adjusted.backward() + grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + scheduler.step() + optimizer.zero_grad(set_to_none=True) + if main_node and (it % ema_every == 0 or it == max_iters): + if it < ema_start: + model_ema.load_state_dict(model.module.state_dict()) + else: + model_ema.update_weights_ema(model.module, beta=ema_beta) + else: + with model.no_sync(): + loss_adjusted.backward() + + ema_loss = ( + loss.mean().item() + if ema_loss is None + else ema_loss * 0.99 + loss.mean().item() * 0.01 + ) + + if main_node: + pbar.set_postfix( + { + "bs": images.size(0), + "loss": loss.mean().item(), + "c_loss": consistency_loss.mean().item(), + "loss_adjusted": loss_adjusted.item(), + "ema_loss": ema_loss, + "grad_norm": grad_norm.item(), + "lr": optimizer.param_groups[0]["lr"], + "total_steps": scheduler.last_epoch, + } + ) + + # ble = torch.Tensor([ble]).to(device) + # gathered_values = [torch.zeros(1, dtype=torch.float32).to(device) for _ in range(8)] + # dist.all_gather(gathered_values, ble) + if main_node: + # ble_dict = {f'batch_loading_{i}': b[0].item() for i, b in enumerate(gathered_values)} + wandb.log( + { + "loss": loss.mean().item(), + "c_loss": consistency_loss.mean().item(), + "loss_adjusted": loss_adjusted.item(), + "ema_loss": ema_loss, + "grad_norm": grad_norm.item(), + "lr": optimizer.param_groups[0]["lr"], + "total_steps": scheduler.last_epoch, + } + ) + + if main_node and ( + it == 1 or it % print_every == 0 or it == max_iters + ): # <--- DDP + tqdm.write(f"ITER {it}/{max_iters} - loss {ema_loss}") + + try: + os.remove(f"{checkpoint_path}.bak") + except OSError: + pass + + try: + os.rename(checkpoint_path, f"{checkpoint_path}.bak") + except OSError: + pass + + if it % extra_ckpt_every == 0: + torch.save( + { + "state_dict": model.module.state_dict(), + "ema_state_dict": model_ema.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_last_step": scheduler.last_epoch, + "iter": it, + "metrics": { + "ema_loss": ema_loss, + }, + "wandb_run_id": run_id, + }, + os.path.join(checkpoint_dir, run_name, f"model_{it}.pt"), + ) + + torch.save( + { + "state_dict": model.module.state_dict(), + "ema_state_dict": model_ema.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_last_step": scheduler.last_epoch, + "iter": it, + "metrics": { + "ema_loss": ema_loss, + }, + "wandb_run_id": run_id, + }, + checkpoint_path, + ) + + # model.eval() + # images, captions = next(dataloader_iterator) + # # while images.size(0) < 8: + # # _images, _captions = next(dataloader_iterator) + # # images = torch.cat([images, _images], dim=0) + # # 
captions += _captions + # images, captions = images.to(device), captions + # images = images[:8] + # captions = captions[:8] + + # prior_steps = 60 + # prior_cfg = 6 + # prior_sampler = "ddpm" + + # generator_steps = 12 + # generator_cfg = 2.0 + # generator_sampler = "ddpm" + # generator_latent_shape = (batch_size, 4, 128, 128) + + # with torch.cuda.amp.autocast(dtype=torch.bfloat16), torch.no_grad(): + # clip_tokens = clip_tokenizer(captions, truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) + # clip_text_embeddings = clip_model(**clip_tokens).last_hidden_state + + # clip_tokens_uncond = clip_tokenizer([''] * len(captions), truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) + # clip_text_embeddings_uncond = clip_model(**clip_tokens_uncond).last_hidden_state + + # t = (1-torch.rand(images.size(0), device=device)).add(0.001).clamp(0.001, 1.0) + # effnet_features = effnet(effnet_preprocess(images)) + # effnet_embeddings_uncond = torch.zeros_like(effnet_features) + # noised_embeddings, noise = diffuzz.diffuse(effnet_features, t) + + # pred_noise = model(noised_embeddings, t, clip_text_embeddings) + # pred = diffuzz.undiffuse(noised_embeddings, t, torch.zeros_like(t), pred_noise) + # sampled = diffuzz.sample(model.module, {'c': clip_text_embeddings}, unconditional_inputs={"c": clip_text_embeddings_uncond}, + # shape=effnet_features.shape, timesteps=prior_steps, cfg=prior_cfg, sampler=prior_sampler)[-1] + # # sampled_ema = diffuzz.sample(model_ema, {'c': clip_text_embeddings}, unconditional_inputs={"c": clip_text_embeddings_uncond}, + # # shape=effnet_features.shape, timesteps=prior_steps, cfg=prior_cfg, sampler=prior_sampler)[-1] + + # sampled_images = diffuzz.sample(generator, {'effnet': sampled, 'clip': clip_text_embeddings}, + # generator_latent_shape, timesteps=generator_steps, cfg=generator_cfg, sampler=generator_sampler, + # unconditional_inputs = {'effnet': effnet_embeddings_uncond, 'clip': clip_text_embeddings_uncond})[-1] + # # sampled_images_ema = diffuzz.sample(generator, {'effnet': sampled_ema, 'clip': clip_text_embeddings}, + # # generator_latent_shape, timesteps=generator_steps, cfg=generator_cfg, sampler=generator_sampler, + # # unconditional_inputs = {'effnet': effnet_embeddings_uncond, 'clip': clip_text_embeddings_uncond})[-1] + # sampled_images_original = diffuzz.sample(generator, {'effnet': effnet_features, 'clip': clip_text_embeddings}, + # generator_latent_shape, timesteps=generator_steps, cfg=generator_cfg, sampler=generator_sampler, + # unconditional_inputs = {'effnet': effnet_embeddings_uncond, 'clip': clip_text_embeddings_uncond})[-1] + # sampled_pred = diffuzz.sample(generator, {'effnet': pred, 'clip': clip_text_embeddings}, + # generator_latent_shape, timesteps=generator_steps, cfg=generator_cfg, sampler=generator_sampler, + # unconditional_inputs = {'effnet': effnet_embeddings_uncond, 'clip': clip_text_embeddings_uncond})[-1] + # sampled_noised = diffuzz.sample(generator, {'effnet': noised_embeddings, 'clip': clip_text_embeddings}, + # generator_latent_shape, timesteps=generator_steps, cfg=generator_cfg, sampler=generator_sampler, + # unconditional_inputs = {'effnet': effnet_embeddings_uncond, 'clip': clip_text_embeddings_uncond})[-1] + + # noised_images = vqmodel.decode(sampled_noised) + # pred_images = vqmodel.decode(sampled_pred) + # sampled_images_original = vqmodel.decode(sampled_images_original) + # sampled_images = 
vqmodel.decode(sampled_images) + # # sampled_images_ema = vqmodel.decode(sampled_images_ema) + # model.train() + + # torchvision.utils.save_image(torch.cat([ + # torch.cat([i for i in images.cpu()], dim=-1), + # torch.cat([i for i in noised_images.cpu()], dim=-1), + # torch.cat([i for i in pred_images.cpu()], dim=-1), + # torch.cat([i for i in sampled_images.cpu()], dim=-1), + # # torch.cat([i for i in sampled_images_ema.cpu()], dim=-1), + # torch.cat([i for i in sampled_images_original.cpu()], dim=-1), + # ], dim=-2), f'{output_path}/{it:06d}.jpg') + + # # log_data = [ [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [wandb.Image(sampled_images_original[i])] + [wandb.Image(images[i])] for i in range(len(images))] + # log_data = [ [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_original[i])] + [wandb.Image(images[i])] for i in range(len(images))] + # log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Sampled Original", "Orig"]) + # wandb.log({"Log": log_table}) + # # torch.cuda.empty_cache() + # del clip_tokens, clip_text_embeddings, clip_tokens_uncond, clip_text_embeddings_uncond, t, effnet_features, effnet_embeddings_uncond + # del noised_embeddings, noise, pred_noise, pred, sampled, sampled_ema, sampled_images, sampled_images_ema, sampled_images_original + # del sampled_pred, sampled_noised, noised_images, pred_images, log_data, log_table + + destroy_process_group() # <--- DDP + + +if __name__ == "__main__": + world_size = torch.cuda.device_count() + n_node = 8 # [3,8,11,14-15,17-18,49] + mp.spawn(train, args=(world_size, n_node), nprocs=world_size) # <--- DDP ;) From 67d734d80711086d2e3e885d7dd950642597ef0d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 15 Sep 2023 12:27:42 +0200 Subject: [PATCH 02/63] formatting --- .../text_to_image/train_text_to_image.py | 107 ++++++------------ 1 file changed, 34 insertions(+), 73 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image.py b/examples/wuerstchen/text_to_image/train_text_to_image.py index ed1957fbdea2..0ba1abda08e6 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image.py @@ -1,26 +1,28 @@ import os +import shutil import time + +import numpy as np import torch +import torch.multiprocessing as mp import torchvision +import transformers +import wandb +import webdataset as wds +from diffnext_v2 import Prior +from diffnext_v2_ldm import EfficientNetEncoder from torch import nn, optim +from torch.distributed import destroy_process_group, init_process_group +from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader -from warmup_scheduler import GradualWarmupScheduler +from torchtools.utils import Diffuzz from tqdm import tqdm -import numpy as np -import wandb -import shutil from transformers import AutoTokenizer, CLIPTextModel -import webdataset as wds +from warmup_scheduler import GradualWarmupScheduler from webdataset.handlers import warn_and_continue -from torch.distributed import init_process_group, destroy_process_group -from torch.nn.parallel import DistributedDataParallel as DDP -import torch.multiprocessing as mp -from torchtools.utils import Diffuzz -from diffnext_v2 import Prior -from diffnext_v2_ldm import DiffNeXt, EfficientNetEncoder -from vqgan import VQModel + from utils import WebdatasetFilter -import transformers + transformers.utils.logging.set_verbosity_error() @@ 
-69,9 +71,7 @@ antialias=True, ), torchvision.transforms.CenterCrop(384), - torchvision.transforms.Normalize( - mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) - ), + torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ] ) @@ -120,17 +120,13 @@ def train(gpu_id, world_size, n_nodes): ) real_batch_size = batch_size // (world_size * n_nodes * grad_accum_steps) - dataloader = DataLoader( - dataset, batch_size=real_batch_size, num_workers=8, pin_memory=False - ) + dataloader = DataLoader(dataset, batch_size=real_batch_size, num_workers=8, pin_memory=False) if main_node: print("REAL BATCH SIZE / DEVICE:", real_batch_size) # - EfficientNet - - pretrained_checkpoint = torch.load( - "models/text2img_wurstchen_b_v1_457k.pt", map_location=device - ) + pretrained_checkpoint = torch.load("models/text2img_wurstchen_b_v1_457k.pt", map_location=device) effnet = EfficientNetEncoder().to(device) effnet.load_state_dict(pretrained_checkpoint["effnet_state_dict"]) @@ -152,11 +148,7 @@ def train(gpu_id, world_size, n_nodes): # --- PREPARE MODELS --- try: - checkpoint = ( - torch.load(checkpoint_path, map_location=device) - if os.path.exists(checkpoint_path) - else None - ) + checkpoint = torch.load(checkpoint_path, map_location=device) if os.path.exists(checkpoint_path) else None except RuntimeError as e: if os.path.exists(f"{checkpoint_path}.bak"): os.remove(checkpoint_path) @@ -174,9 +166,7 @@ def train(gpu_id, world_size, n_nodes): .eval() .requires_grad_(False) ) - clip_tokenizer = AutoTokenizer.from_pretrained( - "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - ) + clip_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") # - Diffusive Imagination Combinatrainer, a.k.a. Risotto - model = Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) @@ -185,11 +175,7 @@ def train(gpu_id, world_size, n_nodes): model.load_state_dict(checkpoint["state_dict"]) if main_node: # <--- DDP - model_ema = ( - Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24) - .eval() - .requires_grad_(False) - ) + model_ema = Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).eval().requires_grad_(False) # model_ema = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=42, nhead=24).eval().requires_grad_(False) # load checkpoints & prepare ddp @@ -226,9 +212,7 @@ def train(gpu_id, world_size, n_nodes): optimizer = optim.AdamW(model.parameters(), lr=lr) # eps=1e-4 # optimizer = StableAdamW(model.parameters(), lr=lr) # eps=1e-4 # optimizer = Lion(model.parameters(), lr=lr / 3) # eps=1e-4 - scheduler = GradualWarmupScheduler( - optimizer, multiplier=1, total_epoch=warmup_updates - ) + scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_updates) if checkpoint is not None: optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) scheduler.last_epoch = checkpoint["scheduler_last_step"] @@ -253,27 +237,19 @@ def train(gpu_id, world_size, n_nodes): if main_node: print("Everything prepared, starting training now....") dataloader_iterator = iter(dataloader) - pbar = ( - tqdm(range(start_iter, max_iters + 1)) - if (main_node) - else range(start_iter, max_iters + 1) - ) # <--- DDP + pbar = tqdm(range(start_iter, max_iters + 1)) if (main_node) else range(start_iter, max_iters + 1) # <--- DDP model.train() for it in pbar: bls = time.time() images, captions = next(dataloader_iterator) - ble = time.time() - bls + time.time() - bls images = images.to(device) with torch.no_grad(): effnet_features = 
effnet(effnet_preprocess(images)) with torch.cuda.amp.autocast(dtype=torch.bfloat16): - if ( - np.random.rand() < 0.05 - ): # 90% of the time, drop the CLIP text embeddings (indepentently) - clip_captions = [""] * len( - captions - ) # 5% of the time drop all the captions + if np.random.rand() < 0.05: # 90% of the time, drop the CLIP text embeddings (indepentently) + clip_captions = [""] * len(captions) # 5% of the time drop all the captions else: clip_captions = captions clip_tokens = clip_tokenizer( @@ -285,20 +261,13 @@ def train(gpu_id, world_size, n_nodes): ).to(device) clip_text_embeddings = clip_model(**clip_tokens).last_hidden_state - t = ( - (1 - torch.rand(images.size(0), device=device)) - .mul(1.08) - .add(0.001) - .clamp(0.001, 1.0) - ) + t = (1 - torch.rand(images.size(0), device=device)).mul(1.08).add(0.001).clamp(0.001, 1.0) noised_embeddings, noise = diffuzz.diffuse(effnet_features, t) - t_consistency = t - t * (1 - torch.rand(images.size(0), device=device)).mul( - 1.08 - ).add(0.001).clamp(0.001, 1.0) - noised_embeddings_consistency, _ = diffuzz.diffuse( - effnet_features, t_consistency, noise=noise + t_consistency = t - t * (1 - torch.rand(images.size(0), device=device)).mul(1.08).add(0.001).clamp( + 0.001, 1.0 ) + noised_embeddings_consistency, _ = diffuzz.diffuse(effnet_features, t_consistency, noise=noise) with torch.cuda.amp.autocast(dtype=torch.bfloat16): pred_noise = model(noised_embeddings, t, clip_text_embeddings) @@ -313,12 +282,10 @@ def train(gpu_id, world_size, n_nodes): ) model.train().requires_grad_(True) - loss = nn.functional.mse_loss(pred_noise, noise, reduction="none").mean( + loss = nn.functional.mse_loss(pred_noise, noise, reduction="none").mean(dim=[1, 2, 3]) + consistency_loss = nn.functional.mse_loss(pred_noise, pred_noise_consistency, reduction="none").mean( dim=[1, 2, 3] ) - consistency_loss = nn.functional.mse_loss( - pred_noise, pred_noise_consistency, reduction="none" - ).mean(dim=[1, 2, 3]) loss_adjusted = ( (loss + consistency_loss * consistency_weight) * diffuzz.p2_weight(t) ).mean() / grad_accum_steps @@ -338,11 +305,7 @@ def train(gpu_id, world_size, n_nodes): with model.no_sync(): loss_adjusted.backward() - ema_loss = ( - loss.mean().item() - if ema_loss is None - else ema_loss * 0.99 + loss.mean().item() * 0.01 - ) + ema_loss = loss.mean().item() if ema_loss is None else ema_loss * 0.99 + loss.mean().item() * 0.01 if main_node: pbar.set_postfix( @@ -375,9 +338,7 @@ def train(gpu_id, world_size, n_nodes): } ) - if main_node and ( - it == 1 or it % print_every == 0 or it == max_iters - ): # <--- DDP + if main_node and (it == 1 or it % print_every == 0 or it == max_iters): # <--- DDP tqdm.write(f"ITER {it}/{max_iters} - loss {ema_loss}") try: From 3c7ac6ffe85b5f2172fc139764bf0bbcd0fa3894 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 18 Sep 2023 11:40:05 +0200 Subject: [PATCH 03/63] prior trainer wip --- .../train_text_to_image_prior.py | 343 ++++++++++++++++++ 1 file changed, 343 insertions(+) create mode 100644 examples/wuerstchen/text_to_image/train_text_to_image_prior.py diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py new file mode 100644 index 000000000000..c28dce395550 --- /dev/null +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -0,0 +1,343 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import logging +import os +from pathlib import Path + +import datasets +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo + +from diffusers import DDPMWuerstchenScheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.logging import set_verbosity_info, set_verbosity_error + + +if is_wandb_available(): + pass + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.21.0") + +logger = get_logger(__name__, log_level="INFO") + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.") + parser.add_argument( + "--pretrained_decoder_model_name_or_path", + type=str, + default="warp-ai/wuerstchen", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_prior_model_name_or_path", + type=str, + default="warp-ai/wuerstchen-prior", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." 
+ ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="kandi_2_2-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="learning rate", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
+ ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument( + "--adam_weight_decay", + type=float, + default=0.0, + required=False, + help="weight decay_to_use", + ) + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + return args + + +def main(): + args = parse_args() + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration( + total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir + ) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load scheduler, effnet, tokenizer and models. 
+ noise_scheduler = DDPMWuerstchenScheduler() + tokenizer = + +if __name__ == "__main__": + main() From b41282801dfb831fc771a380c96d68824a3db0c2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 18 Sep 2023 15:55:29 +0200 Subject: [PATCH 04/63] add efficient_net_encoder --- examples/wuerstchen/text_to_image/__init__.py | 0 .../modeling_efficient_net_encoder.py | 38 +++++++++++++++++++ .../train_text_to_image_prior.py | 28 +++++++++++++- 3 files changed, 64 insertions(+), 2 deletions(-) create mode 100644 examples/wuerstchen/text_to_image/__init__.py create mode 100644 examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py diff --git a/examples/wuerstchen/text_to_image/__init__.py b/examples/wuerstchen/text_to_image/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py new file mode 100644 index 000000000000..f3b9669288bf --- /dev/null +++ b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py @@ -0,0 +1,38 @@ +import os + +from torchvision.models import efficientnet_v2_l, efficientnet_v2_s +from torchvision.transforms import Compose, Normalize, Resize, CenterCrop, InterpolationMode +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin + +EFFNET_PREPROCESS = Compose( + [ + Resize( + 384, + interpolation=InterpolationMode.BILINEAR, + antialias=True, + ), + CenterCrop(384), + Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ] +) + + +class EfficientNetEncoder(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"): + super().__init__() + + if effnet == "efficientnet_v2_s": + self.backbone = efficientnet_v2_s(weights="DEFAULT").features.eval() + else: + self.backbone = efficientnet_v2_l(weights="DEFAULT").features.eval() + self.mapper = nn.Sequential( + nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + ) + + def forward(self, x): + return self.mapper(self.backbone(x)) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index c28dce395550..2474eb688983 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -16,9 +16,12 @@ import os from pathlib import Path +import torch import datasets -import transformers +from transformers import CLIPTokenizer +from transformers.utils import ContextManagers from accelerate import Accelerator +from accelerate.state import AcceleratorState, is_initialized from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from huggingface_hub import create_repo @@ -27,6 +30,8 @@ from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.logging import set_verbosity_info, set_verbosity_error +from .modeling_efficient_net_encoder import EfficientNetEncoder + if is_wandb_available(): pass @@ -337,7 +342,26 @@ def main(): # Load scheduler, effnet, tokenizer and models. 
noise_scheduler = DDPMWuerstchenScheduler() - tokenizer = + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="tokenizer") + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + image_encoder = EfficientNetEncoder.from_pretrained("warp-ai/EfficientNetEncoder").eval() + + if __name__ == "__main__": main() From a24131a7e7774f5bb2afe4a0219a8d0998644b3f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 18 Sep 2023 16:11:37 +0200 Subject: [PATCH 05/63] add CLIPTextModel --- .../train_text_to_image_prior.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 2474eb688983..a12b3f93d338 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -18,7 +18,7 @@ import torch import datasets -from transformers import CLIPTokenizer +from transformers import CLIPTokenizer, CLIPTextModel from transformers.utils import ContextManagers from accelerate import Accelerator from accelerate.state import AcceleratorState, is_initialized @@ -342,7 +342,7 @@ def main(): # Load scheduler, effnet, tokenizer and models. 
noise_scheduler = DDPMWuerstchenScheduler() - tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="tokenizer") + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer") def deepspeed_zero_init_disabled_context_manager(): """ @@ -353,15 +353,28 @@ def deepspeed_zero_init_disabled_context_manager(): return [] return [deepspeed_plugin.zero3_init_context_manager(enable=False)] - + weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 with ContextManagers(deepspeed_zero_init_disabled_context_manager()): - image_encoder = EfficientNetEncoder.from_pretrained("warp-ai/EfficientNetEncoder").eval() - + image_encoder = EfficientNetEncoder.from_pretrained( + "warp-ai/EfficientNetEncoder", torch_dtype=weight_dtype + ).eval() + clip_model = CLIPTextModel.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype + ).eval() + + # Freeze text_encoder and image_encoder + clip_model.requires_grad_(False) + image_encoder.requires_grad_(False) + + # prior + + # create EMA for the prior + if __name__ == "__main__": main() From b4f2cdb600999d23fd2df6f18eb21afa1926332c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 19 Sep 2023 10:28:37 +0200 Subject: [PATCH 06/63] add prior ema support --- .../modeling_efficient_net_encoder.py | 7 +- .../train_text_to_image_prior.py | 64 ++++++++++++++++--- 2 files changed, 57 insertions(+), 14 deletions(-) diff --git a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py index f3b9669288bf..93a5acfff793 100644 --- a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py +++ b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py @@ -1,12 +1,11 @@ -import os - -from torchvision.models import efficientnet_v2_l, efficientnet_v2_s -from torchvision.transforms import Compose, Normalize, Resize, CenterCrop, InterpolationMode import torch.nn as nn +from torchvision.models import efficientnet_v2_l, efficientnet_v2_s +from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Normalize, Resize from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin + EFFNET_PREPROCESS = Compose( [ Resize( diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index a12b3f93d338..fa5a4486b6d5 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -16,19 +16,24 @@ import os from pathlib import Path -import torch +import accelerate import datasets -from transformers import CLIPTokenizer, CLIPTextModel -from transformers.utils import ContextManagers +import torch +import transformers from accelerate import Accelerator -from accelerate.state import AcceleratorState, is_initialized from accelerate.logging import get_logger +from accelerate.state import AcceleratorState, is_initialized from accelerate.utils import ProjectConfiguration, set_seed from huggingface_hub import create_repo +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer +from transformers.utils import ContextManagers from diffusers import DDPMWuerstchenScheduler +from 
diffusers.pipelines.wuerstchen import WuerstchenPrior +from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, is_wandb_available -from diffusers.utils.logging import set_verbosity_info, set_verbosity_error +from diffusers.utils.logging import set_verbosity_error, set_verbosity_info from .modeling_efficient_net_encoder import EfficientNetEncoder @@ -340,9 +345,9 @@ def main(): repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id - # Load scheduler, effnet, tokenizer and models. - noise_scheduler = DDPMWuerstchenScheduler() - tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer") + # Load scheduler, effnet, tokenizer, clip_model + DDPMWuerstchenScheduler() + CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer") def deepspeed_zero_init_disabled_context_manager(): """ @@ -371,9 +376,48 @@ def deepspeed_zero_init_disabled_context_manager(): clip_model.requires_grad_(False) image_encoder.requires_grad_(False) - # prior + # load prior model + WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") + + # Create EMA for the prior + if args.use_ema: + ema_prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") + ema_prior = EMAModel(ema_prior.parameters(), model_cls=WuerstchenPrior, model_config=ema_prior.config) + ema_prior.to(accelerator.device) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_prior.save_pretrained(os.path.join(output_dir, "prior_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "prior")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "prior_ema"), WuerstchenPrior) + ema_prior.load_state_dict(load_model.state_dict()) + ema_prior.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = WuerstchenPrior.from_pretrained(input_dir, subfolder="prior") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model - # create EMA for the prior + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) if __name__ == "__main__": From 3c8f6ed833afe5213c6c7ec03626c81f9e4d2931 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 19 Sep 2023 10:31:07 +0200 Subject: [PATCH 07/63] optimizer --- .../train_text_to_image_prior.py | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index fa5a4486b6d5..607faa3c843a 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -377,7 +377,7 @@ def deepspeed_zero_init_disabled_context_manager(): 
image_encoder.requires_grad_(False) # load prior model - WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") + prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") # Create EMA for the prior if args.use_ema: @@ -419,6 +419,27 @@ def load_model_hook(models, input_dir): accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + optimizer = optimizer_cls( + prior.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) if __name__ == "__main__": main() From 34aab3ec6f914f0417e9c7ef811cc4eb2e142347 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 19 Sep 2023 10:40:06 +0200 Subject: [PATCH 08/63] fix typo --- examples/wuerstchen/text_to_image/README.md | 30 +++++++++++++++++-- .../train_text_to_image_prior.py | 6 ++-- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index 177eea4e7662..067ac812f2ed 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -13,7 +13,7 @@ cd diffusers pip install . ``` -Then cd in the example folder and run +Then cd into the example folder and run ```bash cd examples/wuerstchen/text_to_image pip install -r requirements.txt @@ -24,5 +24,31 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e ```bash accelerate config ``` -For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the --push_to_hub flag. +For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag. +## Prior training + +You can fine-tune the Würstchen prior model with `train_text_to_image_prior.py` script. Note that we currently do not support `--gradient_checkpointing` for prior model fine-tuning. + +
+ + +```bash +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + +accelerate launch --mixed_precision="fp16" train_text_to_image_prior.py \ + --dataset_name=$DATASET_NAME \ + --resolution=768 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --checkpoints_total_limit=3 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --validation_prompts="A robot pokemon, 4k photo" \ + --report_to="wandb" \ + --push_to_hub \ + --output_dir="wuerstchen-prior-pokemon-model" +``` + diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 607faa3c843a..ece9aa201d35 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -35,7 +35,7 @@ from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.logging import set_verbosity_error, set_verbosity_info -from .modeling_efficient_net_encoder import EfficientNetEncoder +from modeling_efficient_net_encoder import EfficientNetEncoder if is_wandb_available(): @@ -346,8 +346,8 @@ def main(): ).repo_id # Load scheduler, effnet, tokenizer, clip_model - DDPMWuerstchenScheduler() - CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer") + noise_scheduler = DDPMWuerstchenScheduler() + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer") def deepspeed_zero_init_disabled_context_manager(): """ From 9def4b5f760bc61e4913ed558f7c92643b7e0698 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 19 Sep 2023 13:38:12 +0200 Subject: [PATCH 09/63] add dataloader --- .../train_text_to_image_prior.py | 130 +++++++++++++++++- 1 file changed, 125 insertions(+), 5 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index ece9aa201d35..ac15303db895 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -13,30 +13,34 @@ import argparse import logging +import math import os +import random from pathlib import Path import accelerate import datasets +import numpy as np import torch import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.state import AcceleratorState, is_initialized from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset from huggingface_hub import create_repo +from modeling_efficient_net_encoder import EFFNET_PREPROCESS, EfficientNetEncoder from packaging import version from transformers import CLIPTextModel, CLIPTokenizer from transformers.utils import ContextManagers from diffusers import DDPMWuerstchenScheduler +from diffusers.optimization import get_scheduler from diffusers.pipelines.wuerstchen import WuerstchenPrior from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.logging import set_verbosity_error, set_verbosity_info -from modeling_efficient_net_encoder import EfficientNetEncoder - if is_wandb_available(): pass @@ -346,7 +350,7 @@ def main(): ).repo_id # Load scheduler, effnet, tokenizer, clip_model - noise_scheduler = DDPMWuerstchenScheduler() + DDPMWuerstchenScheduler() tokenizer = 
CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer") def deepspeed_zero_init_disabled_context_manager(): @@ -368,12 +372,12 @@ def deepspeed_zero_init_disabled_context_manager(): image_encoder = EfficientNetEncoder.from_pretrained( "warp-ai/EfficientNetEncoder", torch_dtype=weight_dtype ).eval() - clip_model = CLIPTextModel.from_pretrained( + text_encoder = CLIPTextModel.from_pretrained( args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype ).eval() # Freeze text_encoder and image_encoder - clip_model.requires_grad_(False) + text_encoder.requires_grad_(False) image_encoder.requires_grad_(False) # load prior model @@ -441,5 +445,121 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." 
+ ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + text_input_ids = inputs.input_ids + text_mask = inputs.attention_mask.bool() + return text_input_ids, text_mask + + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["effnet_pixel_values"] = EFFNET_PREPROCESS(images) + examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples]) + effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float() + text_input_ids = torch.stack([example["text_input_ids"] for example in examples]) + text_mask = torch.stack([example["text_mask"] for example in examples]) + return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask} + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + prior, optimizer, train_dataloader, lr_scheduler + ) + image_encoder.to(accelerator.device) + text_encoder.to(accelerator.device) + + if __name__ == "__main__": main() From d8fb19cf76ae448eba6e70223978aefaadbd48f7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 19 Sep 2023 14:02:42 +0200 Subject: [PATCH 10/63] prompt_embeds and image_embeds --- .../modeling_efficient_net_encoder.py | 19 ++-- .../train_text_to_image_prior.py | 86 ++++++++++++++++++- 2 files changed, 97 insertions(+), 8 deletions(-) diff --git a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py index 93a5acfff793..f961ddd37a79 100644 --- a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py +++ b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py @@ -1,6 +1,15 @@ +import torch import torch.nn as nn from torchvision.models import efficientnet_v2_l, efficientnet_v2_s -from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Normalize, Resize +from torchvision.transforms import ( + CenterCrop, + Compose, + InterpolationMode, + Normalize, + Resize, + PILToTensor, + ConvertImageDtype, +) from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin @@ -8,12 +17,10 @@ EFFNET_PREPROCESS = Compose( [ - Resize( - 384, - interpolation=InterpolationMode.BILINEAR, - 
antialias=True, - ), + Resize(384, interpolation=InterpolationMode.BILINEAR, antialias=True), CenterCrop(384), + PILToTensor(), + ConvertImageDtype(torch.float), Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ] ) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index ac15303db895..56afec850a5b 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -31,6 +31,7 @@ from huggingface_hub import create_repo from modeling_efficient_net_encoder import EFFNET_PREPROCESS, EfficientNetEncoder from packaging import version +from tqdm import tqdm from transformers import CLIPTextModel, CLIPTokenizer from transformers.utils import ContextManagers @@ -493,7 +494,7 @@ def load_model_hook(models, input_dir): ) # Preprocessing the datasets. - # We need to tokenize input captions and transform the images. + # We need to tokenize input captions and transform the images def tokenize_captions(examples, is_train=True): captions = [] for caption in examples[caption_column]: @@ -516,7 +517,7 @@ def tokenize_captions(examples, is_train=True): def preprocess_train(examples): images = [image.convert("RGB") for image in examples[image_column]] - examples["effnet_pixel_values"] = EFFNET_PREPROCESS(images) + examples["effnet_pixel_values"] = [EFFNET_PREPROCESS(image) for image in images] examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples) return examples @@ -543,9 +544,11 @@ def collate_fn(examples): ) # Scheduler and math around the number of training steps. + overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True lr_scheduler = get_scheduler( args.lr_scheduler, @@ -560,6 +563,85 @@ def collate_fn(examples): image_encoder.to(accelerator.device) text_encoder.to(accelerator.device) + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompts") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + prior.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(prior): + # Convert images to latent space + text_input_ids, text_mask, effnet_images = ( + batch["text_input_ids"], + batch["text_mask"], + batch["effnet_pixel_values"].to(weight_dtype), + ) + + with torch.no_grad(): + text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask) + prompt_embeds = text_encoder_output.last_hidden_state + image_embeds = image_encoder(effnet_images) + if __name__ == "__main__": main() From 3fe9079cd42dce34c218525903f2c7bec329c5ff Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 19 Sep 2023 14:31:23 +0200 Subject: [PATCH 11/63] intial training loop --- .../modeling_efficient_net_encoder.py | 4 +-- .../train_text_to_image_prior.py | 32 ++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py index f961ddd37a79..09905a65e43b 100644 --- a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py +++ b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py @@ -4,11 +4,11 @@ from torchvision.transforms import ( CenterCrop, Compose, + ConvertImageDtype, InterpolationMode, Normalize, - Resize, PILToTensor, - ConvertImageDtype, + Resize, ) from diffusers.configuration_utils import ConfigMixin, register_to_config diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 56afec850a5b..468ccdd8779d 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ 
-22,6 +22,7 @@ import datasets import numpy as np import torch +import torch.nn.functional as F import transformers from accelerate import Accelerator from accelerate.logging import get_logger @@ -351,7 +352,7 @@ def main(): ).repo_id # Load scheduler, effnet, tokenizer, clip_model - DDPMWuerstchenScheduler() + noise_scheduler = DDPMWuerstchenScheduler() tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer") def deepspeed_zero_init_disabled_context_manager(): @@ -642,6 +643,35 @@ def collate_fn(examples): prompt_embeds = text_encoder_output.last_hidden_state image_embeds = image_encoder(effnet_images) + # Sample noise that we'll add to the image_embeds + noise = torch.randn_like(image_embeds) + bsz = image_embeds.shape[0] + + # Sample a random timestep for each image + timesteps = torch.rand((bsz,), device=image_embeds.device) + + # add noise to latent + noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps) + + # Predict the noise residual and compute loss + pred_noise = prior(noisy_latents, timesteps, prompt_embeds) + + # TODO snr_gamma + loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if __name__ == "__main__": main() From 3a22be00241ebff6918c6741cecf00c2a6d75060 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 19 Sep 2023 14:47:39 +0200 Subject: [PATCH 12/63] fix output_dir --- examples/wuerstchen/text_to_image/train_text_to_image_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 468ccdd8779d..ef18a3f78b37 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -127,7 +127,7 @@ def parse_args(): parser.add_argument( "--output_dir", type=str, - default="kandi_2_2-model-finetuned", + default="wuerstchen-model-finetuned", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( From 6b5d2e7c8f7e237f24fd26b0066256298b1c2d6a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 19 Sep 2023 15:07:47 +0200 Subject: [PATCH 13/63] fix add_noise --- .../schedulers/scheduling_ddpm_wuerstchen.py | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 781efb12b18b..4605a8eda5df 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -211,23 +211,13 @@ def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = 
timesteps.to(original_samples.device) - - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + device = original_samples.device + alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view( + timesteps.size(0), *[1 for _ in original_samples.shape[1:]] + ) + noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise return noisy_samples def __len__(self): From 8f9a6837bdcb9ae0a1c53cd7ee80ef4a6773a3f3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 19 Sep 2023 15:16:11 +0200 Subject: [PATCH 14/63] accelerator check --- .../train_text_to_image_prior.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index ef18a3f78b37..d94ecc90e4b1 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -16,6 +16,7 @@ import math import os import random +import shutil from pathlib import Path import accelerate @@ -670,6 +671,47 @@ def collate_fn(examples): optimizer.step() lr_scheduler.step() optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_prior.step(prior.parameters()) + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break From 8d93fe528be3b12cda5d4624e029ae761794c57f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 20 Sep 2023 09:28:39 +0200 
Subject: [PATCH 15/63] make effnet_transforms dynamic --- .../modeling_efficient_net_encoder.py | 21 ------------------- .../train_text_to_image_prior.py | 18 +++++++++++----- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py index 09905a65e43b..a15c52d3fb26 100644 --- a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py +++ b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py @@ -1,31 +1,10 @@ -import torch import torch.nn as nn from torchvision.models import efficientnet_v2_l, efficientnet_v2_s -from torchvision.transforms import ( - CenterCrop, - Compose, - ConvertImageDtype, - InterpolationMode, - Normalize, - PILToTensor, - Resize, -) from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin -EFFNET_PREPROCESS = Compose( - [ - Resize(384, interpolation=InterpolationMode.BILINEAR, antialias=True), - CenterCrop(384), - PILToTensor(), - ConvertImageDtype(torch.float), - Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - ] -) - - class EfficientNetEncoder(ModelMixin, ConfigMixin): @register_to_config def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"): diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index d94ecc90e4b1..fe33380a125a 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -31,8 +31,9 @@ from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset from huggingface_hub import create_repo -from modeling_efficient_net_encoder import EFFNET_PREPROCESS, EfficientNetEncoder +from modeling_efficient_net_encoder import EfficientNetEncoder from packaging import version +from torchvision import transforms from tqdm import tqdm from transformers import CLIPTextModel, CLIPTokenizer from transformers.utils import ContextManagers @@ -516,10 +517,18 @@ def tokenize_captions(examples, is_train=True): text_mask = inputs.attention_mask.bool() return text_input_ids, text_mask + effnet_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ] + ) def preprocess_train(examples): images = [image.convert("RGB") for image in examples[image_column]] - examples["effnet_pixel_values"] = [EFFNET_PREPROCESS(image) for image in images] + examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images] examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples) return examples @@ -657,7 +666,7 @@ def collate_fn(examples): # Predict the noise residual and compute loss pred_noise = prior(noisy_latents, timesteps, prompt_embeds) - # TODO snr_gamma + # TODO snr_gamma and consistency loss loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean") # Gather the losses across all processes for logging (if we use distributed training). 
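The `# TODO snr_gamma and consistency loss` marker above points at the min-SNR loss weighting used by several other diffusers training examples. Below is a minimal sketch of that weighting, assuming the scheduler's internal `_alpha_cumprod` helper introduced in the `add_noise` change earlier in this series; the helper name `min_snr_mse_loss` and the default `snr_gamma=5.0` are only illustrative, not part of the patch:

```python
import torch
import torch.nn.functional as F


def min_snr_mse_loss(pred_noise, noise, timesteps, noise_scheduler, snr_gamma=5.0):
    # alpha_bar(t) for each sample, flattened to shape (batch_size,)
    alpha_cumprod = noise_scheduler._alpha_cumprod(timesteps, device=timesteps.device).flatten()
    snr = alpha_cumprod / (1.0 - alpha_cumprod)
    # For epsilon-prediction the min-SNR weight is min(SNR, gamma) / SNR
    weights = torch.clamp(snr, max=snr_gamma) / snr
    # Per-sample MSE over the latent dimensions, then apply the per-sample weight
    loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, loss.ndim)))
    return (loss * weights).mean()
```

Such a helper could be called in place of the plain `F.mse_loss(..., reduction="mean")` line; the consistency-loss half of the TODO corresponds to the `consistency_weight` term visible in the standalone training script that is deleted later in this series.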
@@ -671,7 +680,7 @@ def collate_fn(examples): optimizer.step() lr_scheduler.step() optimizer.zero_grad() - + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: if args.use_ema: @@ -714,6 +723,5 @@ def collate_fn(examples): break - if __name__ == "__main__": main() From 7a46b1e603634714bcfc4f8a0b2a81aa48fcec9c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 20 Sep 2023 22:24:32 +0200 Subject: [PATCH 16/63] fix training loop --- .../wuerstchen/text_to_image/requirements.txt | 1 + .../train_text_to_image_prior.py | 112 +++++++++--------- 2 files changed, 57 insertions(+), 56 deletions(-) diff --git a/examples/wuerstchen/text_to_image/requirements.txt b/examples/wuerstchen/text_to_image/requirements.txt index 55f8f9f25e21..951225d4da53 100644 --- a/examples/wuerstchen/text_to_image/requirements.txt +++ b/examples/wuerstchen/text_to_image/requirements.txt @@ -3,3 +3,4 @@ torchvision transformers>=4.25.1 wandb huggingface-cli +bitsandbytes diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index fe33380a125a..5ba0fefd8340 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -571,8 +571,8 @@ def collate_fn(examples): prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( prior, optimizer, train_dataloader, lr_scheduler ) - image_encoder.to(accelerator.device) - text_encoder.to(accelerator.device) + image_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -663,64 +663,64 @@ def collate_fn(examples): # add noise to latent noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps) - # Predict the noise residual and compute loss - pred_noise = prior(noisy_latents, timesteps, prompt_embeds) + # Predict the noise residual and compute losscd + pred_noise = prior(noisy_latents, timesteps, prompt_embeds) - # TODO snr_gamma and consistency loss - loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean") + # TODO snr_gamma and consistency loss + loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean") - # Gather the losses across all processes for logging (if we use distributed training). - avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() - train_loss += avg_loss.item() / args.gradient_accumulation_steps + # Gather the losses across all processes for logging (if we use distributed training). 
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps - # Backpropagate - accelerator.backward(loss) - if accelerator.sync_gradients: - accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - if args.use_ema: - ema_prior.step(prior.parameters()) - progress_bar.update(1) - global_step += 1 - accelerator.log({"train_loss": train_loss}, step=global_step) - train_loss = 0.0 - - if global_step % args.checkpointing_steps == 0: - if accelerator.is_main_process: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) - - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") - - logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_prior.step(prior.parameters()) + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = 
os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break if __name__ == "__main__": From 61c845cbee3272d6dcd07fafd255fdbd5657abad Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 21 Sep 2023 09:29:32 +0200 Subject: [PATCH 17/63] add validation logging --- .../train_text_to_image_prior.py | 102 +++++++++++++++++- 1 file changed, 98 insertions(+), 4 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 5ba0fefd8340..bc64423ce0f3 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -38,16 +38,17 @@ from transformers import CLIPTextModel, CLIPTokenizer from transformers.utils import ContextManagers -from diffusers import DDPMWuerstchenScheduler +from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler from diffusers.optimization import get_scheduler -from diffusers.pipelines.wuerstchen import WuerstchenPrior +from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.logging import set_verbosity_error, set_verbosity_info if is_wandb_available(): - pass + import wandb + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.21.0") @@ -59,8 +60,55 @@ } +def log_validation(prior, args, accelerator, weight_dtype, epoch): + logger.info("Running validation... 
") + + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + prior_prior=accelerator.unwrap_model(prior), + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + for i in range(len(args.validation_prompts)): + with torch.autocast("cuda"): + image = pipeline( + args.validation_prompts[i], prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, generator=generator + ).images[0] + + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + elif tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") + for i, image in enumerate(images) + ] + } + ) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + torch.cuda.empty_cache() + + return images + + def parse_args(): - parser = argparse.ArgumentParser(description="Simple example of finetuning Kandinsky 2.2.") + parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.") parser.add_argument( "--pretrained_decoder_model_name_or_path", type=str, @@ -722,6 +770,52 @@ def collate_fn(examples): if global_step >= args.max_train_steps: break + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_prior.store(prior.parameters()) + ema_prior.copy_to(prior.parameters()) + log_validation(prior, args, accelerator, weight_dtype, global_step) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_prior.restore(prior.parameters()) + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + prior = accelerator.unwrap_model(prior) + if args.use_ema: + ema_prior.copy_to(prior.parameters()) + + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, prior_prior=prior + ) + pipeline.prior_pipe.save_pretrained(args.output_dir) + + # Run a final round of inference. 
+ images = [] + if args.validation_prompts is not None: + logger.info("Running inference for collecting generated images...") + pipeline = pipeline.to(accelerator.device) + pipeline.torch_dtype = weight_dtype + pipeline.set_progress_bar_config(disable=True) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + for i in range(len(args.validation_prompts)): + with torch.autocast("cuda"): + image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] + images.append(image) + + if args.push_to_hub: + pass + + accelerator.end_training() + if __name__ == "__main__": main() From fdc2c92c74c633034c2eb7d420da490f921febb8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 21 Sep 2023 10:40:48 +0200 Subject: [PATCH 18/63] use loaded text_encoder --- examples/wuerstchen/text_to_image/README.md | 7 ++++--- .../wuerstchen/text_to_image/requirements.txt | 1 + .../train_text_to_image_prior.py | 20 ++++++++++++------- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index 067ac812f2ed..c848c8a2bc1f 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -36,11 +36,12 @@ You can fine-tune the Würstchen prior model with `train_text_to_image_prior.py` ```bash export DATASET_NAME="lambdalabs/pokemon-blip-captions" -accelerate launch --mixed_precision="fp16" train_text_to_image_prior.py \ +accelerate launch train_text_to_image_prior.py \ + --mixed_precision="fp16" --dataset_name=$DATASET_NAME \ --resolution=768 \ - --train_batch_size=1 \ - --gradient_accumulation_steps=4 \ + --train_batch_size=4 \ + --dataloader_num_workers=4 \ --max_train_steps=15000 \ --learning_rate=1e-05 \ --max_grad_norm=1 \ diff --git a/examples/wuerstchen/text_to_image/requirements.txt b/examples/wuerstchen/text_to_image/requirements.txt index 951225d4da53..a58ad09eca55 100644 --- a/examples/wuerstchen/text_to_image/requirements.txt +++ b/examples/wuerstchen/text_to_image/requirements.txt @@ -4,3 +4,4 @@ transformers>=4.25.1 wandb huggingface-cli bitsandbytes +deepspeed diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index bc64423ce0f3..5404e3ec81d8 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -60,12 +60,14 @@ } -def log_validation(prior, args, accelerator, weight_dtype, epoch): +def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch): logger.info("Running validation... ") pipeline = AutoPipelineForText2Image.from_pretrained( args.pretrained_decoder_model_name_or_path, prior_prior=accelerator.unwrap_model(prior), + prior_text_encoder=accelerator.unwrap_model(text_encoder), + prior_tokenizer=tokenizer, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) @@ -693,7 +695,7 @@ def collate_fn(examples): text_input_ids, text_mask, effnet_images = ( batch["text_input_ids"], batch["text_mask"], - batch["effnet_pixel_values"].to(weight_dtype), + batch["effnet_pixel_values"].to(weight_dtype) ) with torch.no_grad(): @@ -776,7 +778,7 @@ def collate_fn(examples): # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 
ema_prior.store(prior.parameters()) ema_prior.copy_to(prior.parameters()) - log_validation(prior, args, accelerator, weight_dtype, global_step) + log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step) if args.use_ema: # Switch back to the original UNet parameters. ema_prior.restore(prior.parameters()) @@ -789,7 +791,10 @@ def collate_fn(examples): ema_prior.copy_to(prior.parameters()) pipeline = AutoPipelineForText2Image.from_pretrained( - args.pretrained_decoder_model_name_or_path, prior_prior=prior + args.pretrained_decoder_model_name_or_path, + prior_prior=prior, + prior_text_encoder=accelerator.unwrap_model(text_encoder), + prior_tokenizer=tokenizer, ) pipeline.prior_pipe.save_pretrained(args.output_dir) @@ -797,8 +802,7 @@ def collate_fn(examples): images = [] if args.validation_prompts is not None: logger.info("Running inference for collecting generated images...") - pipeline = pipeline.to(accelerator.device) - pipeline.torch_dtype = weight_dtype + pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype) pipeline.set_progress_bar_config(disable=True) if args.seed is None: @@ -808,7 +812,9 @@ def collate_fn(examples): for i in range(len(args.validation_prompts)): with torch.autocast("cuda"): - image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] + image = pipeline( + args.validation_prompts[i], prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, generator=generator + ).images[0] images.append(image) if args.push_to_hub: From 749f977a0ef83feea6b74eef83c34ba9d25b58e7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 21 Sep 2023 10:58:58 +0200 Subject: [PATCH 19/63] use PreTrainedTokenizerFast --- .../wuerstchen/text_to_image/train_text_to_image_prior.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 5404e3ec81d8..cedf18dba48f 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -35,7 +35,7 @@ from packaging import version from torchvision import transforms from tqdm import tqdm -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel, PreTrainedTokenizerFast from transformers.utils import ContextManagers from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler @@ -405,7 +405,7 @@ def main(): # Load scheduler, effnet, tokenizer, clip_model noise_scheduler = DDPMWuerstchenScheduler() - tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer") + tokenizer = PreTrainedTokenizerFast.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer") def deepspeed_zero_init_disabled_context_manager(): """ @@ -695,7 +695,7 @@ def collate_fn(examples): text_input_ids, text_mask, effnet_images = ( batch["text_input_ids"], batch["text_mask"], - batch["effnet_pixel_values"].to(weight_dtype) + batch["effnet_pixel_values"].to(weight_dtype), ) with torch.no_grad(): From a2a9b97870944e3e5aff7621c0e2667d1b092e79 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 23 Sep 2023 17:57:54 +0200 Subject: [PATCH 20/63] load weigth from pickle --- .../text_to_image/modeling_efficient_net_encoder.py | 4 ++-- .../text_to_image/train_text_to_image_prior.py | 11 +++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git 
a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py index a15c52d3fb26..bd551ebf1623 100644 --- a/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py +++ b/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py @@ -11,9 +11,9 @@ def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"): super().__init__() if effnet == "efficientnet_v2_s": - self.backbone = efficientnet_v2_s(weights="DEFAULT").features.eval() + self.backbone = efficientnet_v2_s(weights="DEFAULT").features else: - self.backbone = efficientnet_v2_l(weights="DEFAULT").features.eval() + self.backbone = efficientnet_v2_l(weights="DEFAULT").features self.mapper = nn.Sequential( nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False), nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index cedf18dba48f..c50a75138e67 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -30,7 +30,7 @@ from accelerate.state import AcceleratorState, is_initialized from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset -from huggingface_hub import create_repo +from huggingface_hub import create_repo, hf_hub_download from modeling_efficient_net_encoder import EfficientNetEncoder from packaging import version from torchvision import transforms @@ -423,9 +423,12 @@ def deepspeed_zero_init_disabled_context_manager(): elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 with ContextManagers(deepspeed_zero_init_disabled_context_manager()): - image_encoder = EfficientNetEncoder.from_pretrained( - "warp-ai/EfficientNetEncoder", torch_dtype=weight_dtype - ).eval() + pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt") + state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu") + image_encoder = EfficientNetEncoder() + image_encoder.load_state_dict(state_dict["effnet_state_dict"]) + image_encoder.eval() + text_encoder = CLIPTextModel.from_pretrained( args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype ).eval() From 81384fbd392e40995de1948fbedcae5f641331b0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 23 Sep 2023 18:00:58 +0200 Subject: [PATCH 21/63] save_model_card --- .../train_text_to_image_prior.py | 90 ++++++++++++++++++- 1 file changed, 87 insertions(+), 3 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index c50a75138e67..8197af67d010 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -43,7 +43,7 @@ from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, is_wandb_available -from diffusers.utils.logging import set_verbosity_error, set_verbosity_info +from diffusers.utils.logging import make_image_grid, set_verbosity_error, set_verbosity_info if is_wandb_available(): @@ -60,6 +60,82 @@ } +def save_model_card( + args, + repo_id: str, + images=None, + repo_folder=None, +): + img_str = "" + 
if len(images) > 0: + image_grid = make_image_grid(images, 1, len(args.validation_prompts)) + image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png")) + img_str += "![val_imgs_grid](./val_imgs_grid.png)\n" + + yaml = f""" +--- +license: mit +base_model: {args.pretrained_prior_model_name_or_path} +datasets: +- {args.dataset_name} +tags: +- wuerstchen +- text-to-image +- diffusers +inference: true +--- + """ + model_card = f""" +# Finetuning - {repo_id} + +This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n +{img_str} + +## Pipeline usage + +You can use the pipeline like so: + +```python +from diffusers import DiffusionPipeline +import torch + +pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16) +pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype=torch.float16) +prompt = "{args.validation_prompts[0]}" +image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple() +image = pipe_t2i(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds).images[0] +image.save("my_image.png") +``` + +## Training info + +These are the key hyperparameters used during training: + +* Epochs: {args.num_train_epochs} +* Learning rate: {args.learning_rate} +* Batch size: {args.train_batch_size} +* Gradient accumulation steps: {args.gradient_accumulation_steps} +* Image resolution: {args.resolution} +* Mixed-precision: {args.mixed_precision} + +""" + wandb_info = "" + if is_wandb_available(): + wandb_run_url = None + if wandb.run is not None: + wandb_run_url = wandb.run.url + + if wandb_run_url is not None: + wandb_info = f""" +More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}). +""" + + model_card += wandb_info + + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch): logger.info("Running validation... 
") @@ -405,7 +481,9 @@ def main(): # Load scheduler, effnet, tokenizer, clip_model noise_scheduler = DDPMWuerstchenScheduler() - tokenizer = PreTrainedTokenizerFast.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer") + tokenizer = PreTrainedTokenizerFast.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="tokenizer" + ) def deepspeed_zero_init_disabled_context_manager(): """ @@ -821,7 +899,13 @@ def collate_fn(examples): images.append(image) if args.push_to_hub: - pass + save_model_card(args, repo_id, images, repo_folder=args.output_dir) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) accelerator.end_training() From 64b3d300b0190b855cac8ad29627abcd56cc0434 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 23 Sep 2023 18:01:28 +0200 Subject: [PATCH 22/63] remove unused file --- .../text_to_image/train_text_to_image.py | 470 ------------------ 1 file changed, 470 deletions(-) delete mode 100644 examples/wuerstchen/text_to_image/train_text_to_image.py diff --git a/examples/wuerstchen/text_to_image/train_text_to_image.py b/examples/wuerstchen/text_to_image/train_text_to_image.py deleted file mode 100644 index 0ba1abda08e6..000000000000 --- a/examples/wuerstchen/text_to_image/train_text_to_image.py +++ /dev/null @@ -1,470 +0,0 @@ -import os -import shutil -import time - -import numpy as np -import torch -import torch.multiprocessing as mp -import torchvision -import transformers -import wandb -import webdataset as wds -from diffnext_v2 import Prior -from diffnext_v2_ldm import EfficientNetEncoder -from torch import nn, optim -from torch.distributed import destroy_process_group, init_process_group -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.data import DataLoader -from torchtools.utils import Diffuzz -from tqdm import tqdm -from transformers import AutoTokenizer, CLIPTextModel -from warmup_scheduler import GradualWarmupScheduler -from webdataset.handlers import warn_and_continue - -from utils import WebdatasetFilter - - -transformers.utils.logging.set_verbosity_error() - -# PARAMETERS -updates = 1000000 -warmup_updates = 10000 -ema_start = 5000 -ema_every = 100 -ema_beta = 0.9 -# batch_size = 20 * 8 * 8 # 2048 20 * 64 -batch_size = 20 * 8 * 8 # 2048 20 * 64 -grad_accum_steps = 1 -max_iters = updates * grad_accum_steps -print_every = 2000 * grad_accum_steps -extra_ckpt_every = 50000 * grad_accum_steps -lr = 1e-4 # 1e-4 -generate_new_wandb_id = False -consistency_weight = 0.1 - -dataset_path = "pipe:aws s3 cp s3://stability-west/laion-a-native-high-res/{part-0/{00000..18000}.tar,part-1/{00000..13500}.tar,part-2/{00000..13500}.tar,part-3/{00000..13500}.tar,part-4/{00000..14100}.tar} -" -run_name = "Würstchen-Prior-LDM-Consistency-Scale-EffNet-CLIP-G" -dist_file = "dist_file8" -output_path = f"results/{run_name}" -os.makedirs(output_path, exist_ok=True) -checkpoint_dir = "models" -checkpoint_path = os.path.join(checkpoint_dir, run_name, "model.pt") -os.makedirs(os.path.join(checkpoint_dir, run_name), exist_ok=True) - -wandv_project = "Paella DiffNeXt" -wandv_entity = "babbleberns" -wandb_run_name = run_name - -transforms = torchvision.transforms.Compose( - [ - torchvision.transforms.ToTensor(), - torchvision.transforms.Resize(512), - torchvision.transforms.RandomCrop(512), - ] -) - -effnet_preprocess = torchvision.transforms.Compose( - [ - torchvision.transforms.Resize( - 384, - 
interpolation=torchvision.transforms.InterpolationMode.BILINEAR, - antialias=True, - ), - torchvision.transforms.CenterCrop(384), - torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - ] -) - - -def identity(x): - return x - - -def ddp_setup(rank, world_size, n_node, node_id): # <--- DDP - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "33751" - torch.cuda.set_device(rank) - init_process_group( - backend="nccl", - rank=rank + node_id * world_size, - world_size=world_size * n_node, - init_method="dist_file", - ) - print(f"[GPU {rank+node_id*world_size}] READY") - - -def train(gpu_id, world_size, n_nodes): - node_id = int(os.environ["SLURM_PROCID"]) - main_node = gpu_id == 0 and node_id == 0 - ddp_setup(gpu_id, world_size, n_nodes, node_id) # <--- DDP - device = torch.device(gpu_id) - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - - # --- PREPARE DATASET --- - dataset = ( - wds.WebDataset(dataset_path, resampled=True, handler=warn_and_continue) - .select( - WebdatasetFilter( - min_size=512, - max_pwatermark=0.5, - aesthetic_threshold=5.1, - unsafe_threshold=0.99, - ) - ) - .shuffle(44, handler=warn_and_continue) - .decode("pilrgb", handler=warn_and_continue) - .to_tuple("jpg", "txt", handler=warn_and_continue) - .map_tuple(transforms, identity, handler=warn_and_continue) - ) - - real_batch_size = batch_size // (world_size * n_nodes * grad_accum_steps) - dataloader = DataLoader(dataset, batch_size=real_batch_size, num_workers=8, pin_memory=False) - - if main_node: - print("REAL BATCH SIZE / DEVICE:", real_batch_size) - - # - EfficientNet - - pretrained_checkpoint = torch.load("models/text2img_wurstchen_b_v1_457k.pt", map_location=device) - - effnet = EfficientNetEncoder().to(device) - effnet.load_state_dict(pretrained_checkpoint["effnet_state_dict"]) - effnet.eval().requires_grad_(False) - - # # - vqmodel - - # if main_node: cp /fsx/home-pablo/models/risotto/text2img_wurstchen_b_v1_457k.pt ../models - # vqmodel = VQModel().to(device) - # vqmodel.load_state_dict(torch.load(f"models/vqgan_f4_v1_500k.pt", map_location=device)['state_dict']) - # vqmodel.eval().requires_grad_(False) - - # # - LDM Model as generator - - # generator = DiffNeXt().to(device) - # generator.load_state_dict(pretrained_checkpoint['state_dict']) - # generator.eval().requires_grad_(False).to(torch.bfloat16) - - del pretrained_checkpoint - torch.cuda.empty_cache() - - # --- PREPARE MODELS --- - try: - checkpoint = torch.load(checkpoint_path, map_location=device) if os.path.exists(checkpoint_path) else None - except RuntimeError as e: - if os.path.exists(f"{checkpoint_path}.bak"): - os.remove(checkpoint_path) - shutil.copyfile(f"{checkpoint_path}.bak", checkpoint_path) - checkpoint = torch.load(checkpoint_path, map_location=device) - else: - raise e - - diffuzz = Diffuzz(device=device) - - # - CLIP text encoder - clip_model = ( - CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") - .to(device) - .eval() - .requires_grad_(False) - ) - clip_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") - - # - Diffusive Imagination Combinatrainer, a.k.a. 
Risotto - - model = Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) - # model = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=42, nhead=24).to(device) - if checkpoint is not None: - model.load_state_dict(checkpoint["state_dict"]) - - if main_node: # <--- DDP - model_ema = Prior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).eval().requires_grad_(False) - # model_ema = Prior(c_in=16, c=1536, c_cond=1024, c_r=64, depth=42, nhead=24).eval().requires_grad_(False) - - # load checkpoints & prepare ddp - if checkpoint is not None: - if main_node: # <--- DDP - if "ema_state_dict" in checkpoint: - model_ema.load_state_dict(checkpoint["ema_state_dict"]) - else: - model_ema.load_state_dict(model.state_dict()) - - # - SETUP WANDB - - if main_node: # <--- DDP - if checkpoint is not None and not generate_new_wandb_id: - run_id = checkpoint["wandb_run_id"] - else: - run_id = wandb.util.generate_id() - wandb.init( - project=wandv_project, - name=wandb_run_name, - entity=wandv_entity, - id=run_id, - resume="allow", - ) - - model = DDP(model, device_ids=[gpu_id], output_device=device) # <--- DDP - - if main_node: # <--- DDP - print( - "Num trainable params:", - sum(p.numel() for p in model.parameters() if p.requires_grad), - ) - - # SETUP OPTIMIZER, SCHEDULER & CRITERION - optimizer = optim.AdamW(model.parameters(), lr=lr) # eps=1e-4 - # optimizer = StableAdamW(model.parameters(), lr=lr) # eps=1e-4 - # optimizer = Lion(model.parameters(), lr=lr / 3) # eps=1e-4 - scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_updates) - if checkpoint is not None: - optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) - scheduler.last_epoch = checkpoint["scheduler_last_step"] - - start_iter = 1 - grad_norm = torch.tensor(0, device=device) - if checkpoint is not None: - start_iter = checkpoint["scheduler_last_step"] * grad_accum_steps + 1 - if main_node: # <--- DDP - print("RESUMING TRAINING FROM ITER ", start_iter) - - loss_adjusted = 0.0 - ema_loss = None - if checkpoint is not None: - ema_loss = checkpoint["metrics"]["ema_loss"] - - if checkpoint is not None: - del checkpoint # cleanup memory - torch.cuda.empty_cache() - - # -------------- START TRAINING -------------- - if main_node: - print("Everything prepared, starting training now....") - dataloader_iterator = iter(dataloader) - pbar = tqdm(range(start_iter, max_iters + 1)) if (main_node) else range(start_iter, max_iters + 1) # <--- DDP - model.train() - for it in pbar: - bls = time.time() - images, captions = next(dataloader_iterator) - time.time() - bls - images = images.to(device) - - with torch.no_grad(): - effnet_features = effnet(effnet_preprocess(images)) - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - if np.random.rand() < 0.05: # 90% of the time, drop the CLIP text embeddings (indepentently) - clip_captions = [""] * len(captions) # 5% of the time drop all the captions - else: - clip_captions = captions - clip_tokens = clip_tokenizer( - clip_captions, - truncation=True, - padding="max_length", - max_length=clip_tokenizer.model_max_length, - return_tensors="pt", - ).to(device) - clip_text_embeddings = clip_model(**clip_tokens).last_hidden_state - - t = (1 - torch.rand(images.size(0), device=device)).mul(1.08).add(0.001).clamp(0.001, 1.0) - noised_embeddings, noise = diffuzz.diffuse(effnet_features, t) - - t_consistency = t - t * (1 - torch.rand(images.size(0), device=device)).mul(1.08).add(0.001).clamp( - 0.001, 1.0 - ) - noised_embeddings_consistency, _ = 
diffuzz.diffuse(effnet_features, t_consistency, noise=noise) - - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - pred_noise = model(noised_embeddings, t, clip_text_embeddings) - - model.eval().requires_grad_(False) - with torch.no_grad(): - with model.no_sync(): - pred_noise_consistency = model.module( - noised_embeddings_consistency, - t_consistency, - clip_text_embeddings, - ) - model.train().requires_grad_(True) - - loss = nn.functional.mse_loss(pred_noise, noise, reduction="none").mean(dim=[1, 2, 3]) - consistency_loss = nn.functional.mse_loss(pred_noise, pred_noise_consistency, reduction="none").mean( - dim=[1, 2, 3] - ) - loss_adjusted = ( - (loss + consistency_loss * consistency_weight) * diffuzz.p2_weight(t) - ).mean() / grad_accum_steps - - if it % grad_accum_steps == 0 or it == max_iters: - loss_adjusted.backward() - grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0) - optimizer.step() - scheduler.step() - optimizer.zero_grad(set_to_none=True) - if main_node and (it % ema_every == 0 or it == max_iters): - if it < ema_start: - model_ema.load_state_dict(model.module.state_dict()) - else: - model_ema.update_weights_ema(model.module, beta=ema_beta) - else: - with model.no_sync(): - loss_adjusted.backward() - - ema_loss = loss.mean().item() if ema_loss is None else ema_loss * 0.99 + loss.mean().item() * 0.01 - - if main_node: - pbar.set_postfix( - { - "bs": images.size(0), - "loss": loss.mean().item(), - "c_loss": consistency_loss.mean().item(), - "loss_adjusted": loss_adjusted.item(), - "ema_loss": ema_loss, - "grad_norm": grad_norm.item(), - "lr": optimizer.param_groups[0]["lr"], - "total_steps": scheduler.last_epoch, - } - ) - - # ble = torch.Tensor([ble]).to(device) - # gathered_values = [torch.zeros(1, dtype=torch.float32).to(device) for _ in range(8)] - # dist.all_gather(gathered_values, ble) - if main_node: - # ble_dict = {f'batch_loading_{i}': b[0].item() for i, b in enumerate(gathered_values)} - wandb.log( - { - "loss": loss.mean().item(), - "c_loss": consistency_loss.mean().item(), - "loss_adjusted": loss_adjusted.item(), - "ema_loss": ema_loss, - "grad_norm": grad_norm.item(), - "lr": optimizer.param_groups[0]["lr"], - "total_steps": scheduler.last_epoch, - } - ) - - if main_node and (it == 1 or it % print_every == 0 or it == max_iters): # <--- DDP - tqdm.write(f"ITER {it}/{max_iters} - loss {ema_loss}") - - try: - os.remove(f"{checkpoint_path}.bak") - except OSError: - pass - - try: - os.rename(checkpoint_path, f"{checkpoint_path}.bak") - except OSError: - pass - - if it % extra_ckpt_every == 0: - torch.save( - { - "state_dict": model.module.state_dict(), - "ema_state_dict": model_ema.state_dict(), - "optimizer_state_dict": optimizer.state_dict(), - "scheduler_last_step": scheduler.last_epoch, - "iter": it, - "metrics": { - "ema_loss": ema_loss, - }, - "wandb_run_id": run_id, - }, - os.path.join(checkpoint_dir, run_name, f"model_{it}.pt"), - ) - - torch.save( - { - "state_dict": model.module.state_dict(), - "ema_state_dict": model_ema.state_dict(), - "optimizer_state_dict": optimizer.state_dict(), - "scheduler_last_step": scheduler.last_epoch, - "iter": it, - "metrics": { - "ema_loss": ema_loss, - }, - "wandb_run_id": run_id, - }, - checkpoint_path, - ) - - # model.eval() - # images, captions = next(dataloader_iterator) - # # while images.size(0) < 8: - # # _images, _captions = next(dataloader_iterator) - # # images = torch.cat([images, _images], dim=0) - # # captions += _captions - # images, captions = images.to(device), captions - # images = images[:8] - 
# captions = captions[:8] - - # prior_steps = 60 - # prior_cfg = 6 - # prior_sampler = "ddpm" - - # generator_steps = 12 - # generator_cfg = 2.0 - # generator_sampler = "ddpm" - # generator_latent_shape = (batch_size, 4, 128, 128) - - # with torch.cuda.amp.autocast(dtype=torch.bfloat16), torch.no_grad(): - # clip_tokens = clip_tokenizer(captions, truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) - # clip_text_embeddings = clip_model(**clip_tokens).last_hidden_state - - # clip_tokens_uncond = clip_tokenizer([''] * len(captions), truncation=True, padding="max_length", max_length=clip_tokenizer.model_max_length, return_tensors="pt").to(device) - # clip_text_embeddings_uncond = clip_model(**clip_tokens_uncond).last_hidden_state - - # t = (1-torch.rand(images.size(0), device=device)).add(0.001).clamp(0.001, 1.0) - # effnet_features = effnet(effnet_preprocess(images)) - # effnet_embeddings_uncond = torch.zeros_like(effnet_features) - # noised_embeddings, noise = diffuzz.diffuse(effnet_features, t) - - # pred_noise = model(noised_embeddings, t, clip_text_embeddings) - # pred = diffuzz.undiffuse(noised_embeddings, t, torch.zeros_like(t), pred_noise) - # sampled = diffuzz.sample(model.module, {'c': clip_text_embeddings}, unconditional_inputs={"c": clip_text_embeddings_uncond}, - # shape=effnet_features.shape, timesteps=prior_steps, cfg=prior_cfg, sampler=prior_sampler)[-1] - # # sampled_ema = diffuzz.sample(model_ema, {'c': clip_text_embeddings}, unconditional_inputs={"c": clip_text_embeddings_uncond}, - # # shape=effnet_features.shape, timesteps=prior_steps, cfg=prior_cfg, sampler=prior_sampler)[-1] - - # sampled_images = diffuzz.sample(generator, {'effnet': sampled, 'clip': clip_text_embeddings}, - # generator_latent_shape, timesteps=generator_steps, cfg=generator_cfg, sampler=generator_sampler, - # unconditional_inputs = {'effnet': effnet_embeddings_uncond, 'clip': clip_text_embeddings_uncond})[-1] - # # sampled_images_ema = diffuzz.sample(generator, {'effnet': sampled_ema, 'clip': clip_text_embeddings}, - # # generator_latent_shape, timesteps=generator_steps, cfg=generator_cfg, sampler=generator_sampler, - # # unconditional_inputs = {'effnet': effnet_embeddings_uncond, 'clip': clip_text_embeddings_uncond})[-1] - # sampled_images_original = diffuzz.sample(generator, {'effnet': effnet_features, 'clip': clip_text_embeddings}, - # generator_latent_shape, timesteps=generator_steps, cfg=generator_cfg, sampler=generator_sampler, - # unconditional_inputs = {'effnet': effnet_embeddings_uncond, 'clip': clip_text_embeddings_uncond})[-1] - # sampled_pred = diffuzz.sample(generator, {'effnet': pred, 'clip': clip_text_embeddings}, - # generator_latent_shape, timesteps=generator_steps, cfg=generator_cfg, sampler=generator_sampler, - # unconditional_inputs = {'effnet': effnet_embeddings_uncond, 'clip': clip_text_embeddings_uncond})[-1] - # sampled_noised = diffuzz.sample(generator, {'effnet': noised_embeddings, 'clip': clip_text_embeddings}, - # generator_latent_shape, timesteps=generator_steps, cfg=generator_cfg, sampler=generator_sampler, - # unconditional_inputs = {'effnet': effnet_embeddings_uncond, 'clip': clip_text_embeddings_uncond})[-1] - - # noised_images = vqmodel.decode(sampled_noised) - # pred_images = vqmodel.decode(sampled_pred) - # sampled_images_original = vqmodel.decode(sampled_images_original) - # sampled_images = vqmodel.decode(sampled_images) - # # sampled_images_ema = vqmodel.decode(sampled_images_ema) - # model.train() - - # 
torchvision.utils.save_image(torch.cat([ - # torch.cat([i for i in images.cpu()], dim=-1), - # torch.cat([i for i in noised_images.cpu()], dim=-1), - # torch.cat([i for i in pred_images.cpu()], dim=-1), - # torch.cat([i for i in sampled_images.cpu()], dim=-1), - # # torch.cat([i for i in sampled_images_ema.cpu()], dim=-1), - # torch.cat([i for i in sampled_images_original.cpu()], dim=-1), - # ], dim=-2), f'{output_path}/{it:06d}.jpg') - - # # log_data = [ [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [wandb.Image(sampled_images_original[i])] + [wandb.Image(images[i])] for i in range(len(images))] - # log_data = [ [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_original[i])] + [wandb.Image(images[i])] for i in range(len(images))] - # log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Sampled Original", "Orig"]) - # wandb.log({"Log": log_table}) - # # torch.cuda.empty_cache() - # del clip_tokens, clip_text_embeddings, clip_tokens_uncond, clip_text_embeddings_uncond, t, effnet_features, effnet_embeddings_uncond - # del noised_embeddings, noise, pred_noise, pred, sampled, sampled_ema, sampled_images, sampled_images_ema, sampled_images_original - # del sampled_pred, sampled_noised, noised_images, pred_images, log_data, log_table - - destroy_process_group() # <--- DDP - - -if __name__ == "__main__": - world_size = torch.cuda.device_count() - n_node = 8 # [3,8,11,14-15,17-18,49] - mp.spawn(train, args=(world_size, n_node), nprocs=world_size) # <--- DDP ;) From f20a6fc71391eca130e826b48a1e2fe5158c32d7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 23 Sep 2023 18:02:51 +0200 Subject: [PATCH 23/63] fix typos --- .../wuerstchen/text_to_image/train_text_to_image_prior.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 8197af67d010..07cb9344b2f3 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -30,7 +30,7 @@ from accelerate.state import AcceleratorState, is_initialized from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset -from huggingface_hub import create_repo, hf_hub_download +from huggingface_hub import create_repo, hf_hub_download, upload_folder from modeling_efficient_net_encoder import EfficientNetEncoder from packaging import version from torchvision import transforms @@ -475,7 +475,7 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: - create_repo( + repo_id = create_repo( repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id From d9e1d47f2048a0ab31d0d1575c9888756d1b80c2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 23 Sep 2023 18:06:32 +0200 Subject: [PATCH 24/63] save prior pipeilne in its own folder --- examples/wuerstchen/text_to_image/train_text_to_image_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 07cb9344b2f3..0f0ddaa8a2a4 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -877,7 +877,7 @@ def collate_fn(examples): 
prior_text_encoder=accelerator.unwrap_model(text_encoder), prior_tokenizer=tokenizer, ) - pipeline.prior_pipe.save_pretrained(args.output_dir) + pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, "prior_pipeline")) # Run a final round of inference. images = [] From 67c37e3002569f7a12b4bc28bcff50c6bf23451f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 23 Sep 2023 18:10:18 +0200 Subject: [PATCH 25/63] fix imports --- .../text_to_image/train_text_to_image_prior.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 0f0ddaa8a2a4..3fd923713d34 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -42,8 +42,8 @@ from diffusers.optimization import get_scheduler from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, is_wandb_available -from diffusers.utils.logging import make_image_grid, set_verbosity_error, set_verbosity_info +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid +from diffusers.utils.logging import set_verbosity_error, set_verbosity_info if is_wandb_available(): @@ -158,7 +158,11 @@ def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dty for i in range(len(args.validation_prompts)): with torch.autocast("cuda"): image = pipeline( - args.validation_prompts[i], prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, generator=generator + args.validation_prompts[i], + prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, + generator=generator, + height=args.resolution, + width=args.resolution, ).images[0] images.append(image) @@ -894,7 +898,11 @@ def collate_fn(examples): for i in range(len(args.validation_prompts)): with torch.autocast("cuda"): image = pipeline( - args.validation_prompts[i], prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, generator=generator + args.validation_prompts[i], + prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, + generator=generator, + width=args.resolution, + height=args.resolution, ).images[0] images.append(image) From 021b0a45e895a6fb56cd56d986136c59c0b28da6 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 24 Sep 2023 18:47:18 +0200 Subject: [PATCH 26/63] fix pipe_t2i --- .../wuerstchen/text_to_image/train_text_to_image_prior.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 3fd923713d34..74e9be2d9603 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -99,11 +99,11 @@ def save_model_card( from diffusers import DiffusionPipeline import torch -pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16) -pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype=torch.float16) +pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype={args.weight_dtype}) +pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype}) prompt = "{args.validation_prompts[0]}" -image_embeds, negative_image_embeds = pipe_prior(prompt, 
guidance_scale=1.0).to_tuple() -image = pipe_t2i(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds).images[0] +(image_embeds,) = pipe_prior(prompt).to_tuple() +image = pipe_t2i(image_embeddings=image_embeds, prompt=prompt).images[0] image.save("my_image.png") ``` From c2faf11478155a10fe5ddb5142d93f87da2f6918 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 25 Sep 2023 10:03:30 +0200 Subject: [PATCH 27/63] scale image_embeds --- examples/wuerstchen/text_to_image/train_text_to_image_prior.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 74e9be2d9603..591f1c77e36d 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -787,6 +787,8 @@ def collate_fn(examples): text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask) prompt_embeds = text_encoder_output.last_hidden_state image_embeds = image_encoder(effnet_images) + # scale + image_embeds = image_embeds.add(1.).div(42.) # Sample noise that we'll add to the image_embeds noise = torch.randn_like(image_embeds) From 77924ea3b477d80d6fb2746fffb2345b032c5401 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 25 Sep 2023 10:33:38 +0200 Subject: [PATCH 28/63] remove snr_gamma --- examples/wuerstchen/text_to_image/README.md | 2 +- .../text_to_image/train_text_to_image_prior.py | 9 +-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index c848c8a2bc1f..bd9954cb3d1e 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -50,6 +50,6 @@ accelerate launch train_text_to_image_prior.py \ --validation_prompts="A robot pokemon, 4k photo" \ --report_to="wandb" \ --push_to_hub \ - --output_dir="wuerstchen-prior-pokemon-model" + --output_dir="wuerstchen-prior-pokemon-model" ``` diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 591f1c77e36d..3613a81d3191 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -312,13 +312,6 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) - parser.add_argument( - "--snr_gamma", - type=float, - default=None, - help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " - "More details here: https://arxiv.org/abs/2303.09556.", - ) parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) @@ -803,7 +796,7 @@ def collate_fn(examples): # Predict the noise residual and compute losscd pred_noise = prior(noisy_latents, timesteps, prompt_embeds) - # TODO snr_gamma and consistency loss + # vanilla loss loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean") # Gather the losses across all processes for logging (if we use distributed training). 
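Taken together, the "scale image_embeds" and "remove snr_gamma" patches leave the prior with a plain noise-prediction objective on the scaled EfficientNet latents. Below is a minimal sketch (not part of any patch) that restates that training step in isolation; `prior`, `prompt_embeds`, and `image_embeds` are placeholders for the model and tensors built earlier in `train_text_to_image_prior.py`, and only the calls visible in the diffs above are assumed.

```python
# Minimal sketch of the prior objective after "scale image_embeds" and
# "remove snr_gamma"; shapes and model handles are placeholders.
import torch
import torch.nn.functional as F

from diffusers import DDPMWuerstchenScheduler

noise_scheduler = DDPMWuerstchenScheduler()


def prior_loss(prior, image_embeds, prompt_embeds):
    # Scale the EfficientNet image embeddings as introduced above.
    image_embeds = image_embeds.add(1.0).div(42.0)

    # Sample noise and one continuous timestep in [0, 1) per example.
    noise = torch.randn_like(image_embeds)
    timesteps = torch.rand((image_embeds.shape[0],), device=image_embeds.device)

    # Diffuse the embeddings and predict the added noise with the prior.
    noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
    pred_noise = prior(noisy_latents, timesteps, prompt_embeds)

    # With the SNR weighting removed, the vanilla MSE loss remains.
    return F.mse_loss(pred_noise.float(), noise.float(), reduction="mean")
```
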
From 85efacda1755d1c42ca227dda0a541b928a79709 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 25 Sep 2023 20:53:33 +0200 Subject: [PATCH 29/63] format --- examples/wuerstchen/text_to_image/train_text_to_image_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 3613a81d3191..f47a9e7a9774 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -781,7 +781,7 @@ def collate_fn(examples): prompt_embeds = text_encoder_output.last_hidden_state image_embeds = image_encoder(effnet_images) # scale - image_embeds = image_embeds.add(1.).div(42.) + image_embeds = image_embeds.add(1.0).div(42.0) # Sample noise that we'll add to the image_embeds noise = torch.randn_like(image_embeds) From 3433ebb820378a178d06ffdb6752c65cef4642ed Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 25 Sep 2023 20:54:13 +0200 Subject: [PATCH 30/63] initial lora prior training --- examples/wuerstchen/text_to_image/README.md | 13 + .../train_text_to_image_lora_prior.py | 905 ++++++++++++++++++ .../wuerstchen/modeling_wuerstchen_prior.py | 84 ++ 3 files changed, 1002 insertions(+) create mode 100644 examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index bd9954cb3d1e..780963042f15 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -53,3 +53,16 @@ accelerate launch train_text_to_image_prior.py \ --output_dir="wuerstchen-prior-pokemon-model" ``` + +## Training with LoRA + +Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. + +In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages: + +- Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114). +- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable. +- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter. + +[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. + diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py new file mode 100644 index 000000000000..a7149e1600b3 --- /dev/null +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -0,0 +1,905 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState, is_initialized +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, hf_hub_download, upload_folder +from modeling_efficient_net_encoder import EfficientNetEncoder +from packaging import version +from torchvision import transforms +from tqdm import tqdm +from transformers import CLIPTextModel, PreTrainedTokenizerFast +from transformers.utils import ContextManagers + +from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.optimization import get_scheduler +from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid +from diffusers.utils.logging import set_verbosity_error, set_verbosity_info + + +if is_wandb_available(): + import wandb + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.21.0") + +logger = get_logger(__name__, log_level="INFO") + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def save_model_card( + args, + repo_id: str, + images=None, + repo_folder=None, +): + img_str = "" + if len(images) > 0: + image_grid = make_image_grid(images, 1, len(args.validation_prompts)) + image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png")) + img_str += "![val_imgs_grid](./val_imgs_grid.png)\n" + + yaml = f""" +--- +license: mit +base_model: {args.pretrained_prior_model_name_or_path} +datasets: +- {args.dataset_name} +tags: +- wuerstchen +- text-to-image +- diffusers +- lora +inference: true +--- + """ + model_card = f""" +# LoRA Finetuning - {repo_id} + +This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. 
Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n +{img_str} + +## Pipeline usage + +You can use the pipeline like so: + +```python +from diffusers import DiffusionPipeline +import torch + +pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype={args.weight_dtype}) +pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype}) +prompt = "{args.validation_prompts[0]}" +(image_embeds,) = pipe_prior(prompt).to_tuple() +image = pipe_t2i(image_embeddings=image_embeds, prompt=prompt).images[0] +image.save("my_image.png") +``` + +## Training info + +These are the key hyperparameters used during training: + +* Epochs: {args.num_train_epochs} +* Learning rate: {args.learning_rate} +* Batch size: {args.train_batch_size} +* Gradient accumulation steps: {args.gradient_accumulation_steps} +* Image resolution: {args.resolution} +* Mixed-precision: {args.mixed_precision} + +""" + wandb_info = "" + if is_wandb_available(): + wandb_run_url = None + if wandb.run is not None: + wandb_run_url = wandb.run.url + + if wandb_run_url is not None: + wandb_info = f""" +More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}). +""" + + model_card += wandb_info + + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch): + logger.info("Running validation... ") + + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + prior_prior=accelerator.unwrap_model(prior), + prior_text_encoder=accelerator.unwrap_model(text_encoder), + prior_tokenizer=tokenizer, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + for i in range(len(args.validation_prompts)): + with torch.autocast("cuda"): + image = pipeline( + args.validation_prompts[i], + prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, + generator=generator, + height=args.resolution, + width=args.resolution, + ).images[0] + + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + elif tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") + for i, image in enumerate(images) + ] + } + ) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + torch.cuda.empty_cache() + + return images + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.") + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--pretrained_decoder_model_name_or_path", + type=str, + default="warp-ai/wuerstchen", + required=False, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_prior_model_name_or_path", + type=str, + default="warp-ai/wuerstchen-prior", + required=False, + help="Path to 
pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="wuerstchen-model-finetuned-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="learning rate", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 
+ ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument( + "--adam_weight_decay", + type=float, + default=0.0, + required=False, + help="weight decay_to_use", + ) + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + return args + + +def main(): + args = parse_args() + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration( + total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir + ) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + set_verbosity_error() + + # If passed along, set the training seed now. 
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load scheduler, effnet, tokenizer, clip_model + noise_scheduler = DDPMWuerstchenScheduler() + tokenizer = PreTrainedTokenizerFast.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="tokenizer" + ) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt") + state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu") + image_encoder = EfficientNetEncoder() + image_encoder.load_state_dict(state_dict["effnet_state_dict"]) + image_encoder.eval() + + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype + ).eval() + + # Freeze text_encoder and image_encoder + text_encoder.requires_grad_(False) + image_encoder.requires_grad_(False) + + # load prior model + prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") + # lora attn processor + lora_attn_procs = {} + for name in prior.attn_processors.keys(): + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=prior.config["c"], rank=args.rank) + prior.set_attn_processor(lora_attn_procs) + lora_layers = AttnProcsLayers(prior.attn_processors) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "prior")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = WuerstchenPrior.from_pretrained(input_dir, subfolder="prior") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + optimizer = optimizer_cls( + lora_layers.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." 
+ ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + text_input_ids = inputs.input_ids + text_mask = inputs.attention_mask.bool() + return text_input_ids, text_mask + + effnet_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images] + examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples]) + effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float() + text_input_ids = torch.stack([example["text_input_ids"] for example in examples]) + text_mask = torch.stack([example["text_mask"] for example in examples]) + return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask} + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + lora_layers, optimizer, train_dataloader, lr_scheduler + ) + image_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompts") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Train! 
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + prior.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(prior): + # Convert images to latent space + text_input_ids, text_mask, effnet_images = ( + batch["text_input_ids"], + batch["text_mask"], + batch["effnet_pixel_values"].to(weight_dtype), + ) + + with torch.no_grad(): + text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask) + prompt_embeds = text_encoder_output.last_hidden_state + image_embeds = image_encoder(effnet_images) + # scale + image_embeds = image_embeds.add(1.0).div(42.0) + + # Sample noise that we'll add to the image_embeds + noise = torch.randn_like(image_embeds) + bsz = image_embeds.shape[0] + + # Sample a random timestep for each image + timesteps = torch.rand((bsz,), device=image_embeds.device) + + # add noise to latent + noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps) + + # Predict the noise residual and compute losscd + pred_noise = prior(noisy_latents, timesteps, prompt_embeds) + + # vanilla loss + loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). 
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step) + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + prior = accelerator.unwrap_model(prior) + + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + prior_prior=prior, + prior_text_encoder=accelerator.unwrap_model(text_encoder), + prior_tokenizer=tokenizer, + ) + pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, "prior_pipeline")) + + # Run a final round of inference. 
+ images = [] + if args.validation_prompts is not None: + logger.info("Running inference for collecting generated images...") + pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype) + pipeline.set_progress_bar_config(disable=True) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + for i in range(len(args.validation_prompts)): + with torch.autocast("cuda"): + image = pipeline( + args.validation_prompts[i], + prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, + generator=generator, + width=args.resolution, + height=args.resolution, + ).images[0] + images.append(image) + + if args.push_to_hub: + save_model_card(args, repo_id, images, repo_folder=args.output_dir) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index 9bd29b59b3af..514d2c997248 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -14,11 +14,19 @@ # limitations under the License. import math +from typing import Dict, Union import torch import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config +from ...models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) from ...models.modeling_utils import ModelMixin from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm @@ -45,6 +53,82 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro nn.Conv2d(c, c_in * 2, kernel_size=1), ) + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. 
+ + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + def gen_r_embedding(self, r, max_positions=10000): r = r * max_positions half_dim = self.c_r // 2 From 10fb63537c3b08445db9ec9b1f3b87045fba2359 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 25 Sep 2023 21:27:11 +0200 Subject: [PATCH 31/63] log_validation and save --- .../train_text_to_image_lora_prior.py | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index a7149e1600b3..12031f07d21c 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -43,7 +43,6 @@ from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.optimization import get_scheduler from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior -from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, is_wandb_available, make_image_grid from diffusers.utils.logging import set_verbosity_error, set_verbosity_info @@ -139,17 +138,17 @@ def save_model_card( f.write(yaml + model_card) -def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch): +def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator, weight_dtype, epoch): logger.info("Running validation... 
") pipeline = AutoPipelineForText2Image.from_pretrained( args.pretrained_decoder_model_name_or_path, - prior_prior=accelerator.unwrap_model(prior), prior_text_encoder=accelerator.unwrap_model(text_encoder), prior_tokenizer=tokenizer, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) + pipeline.set_attn_processor(attn_processors) pipeline.set_progress_bar_config(disable=True) if args.seed is None: @@ -521,7 +520,12 @@ def deepspeed_zero_init_disabled_context_manager(): image_encoder.requires_grad_(False) # load prior model - prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") + prior = WuerstchenPrior.from_pretrained( + args.pretrained_prior_model_name_or_path, + subfolder="prior", + torch_dtype=weight_dtype, + device=accelerator.device, + ) # lora attn processor lora_attn_procs = {} for name in prior.attn_processors.keys(): @@ -805,7 +809,7 @@ def collate_fn(examples): # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: - accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm) + accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() @@ -851,26 +855,29 @@ def collate_fn(examples): if accelerator.is_main_process: if args.validation_prompts is not None and epoch % args.validation_epochs == 0: - log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step) + log_validation( + text_encoder, tokenizer, prior.attn_processors, args, accelerator, weight_dtype, global_step + ) # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: - prior = accelerator.unwrap_model(prior) - - pipeline = AutoPipelineForText2Image.from_pretrained( - args.pretrained_decoder_model_name_or_path, - prior_prior=prior, - prior_text_encoder=accelerator.unwrap_model(text_encoder), - prior_tokenizer=tokenizer, - ) - pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, "prior_pipeline")) + prior = prior.to(torch.float32) + prior.save_attn_procs(os.path.join(args.output_dir, "prior_lora")) # Run a final round of inference. 
images = [] if args.validation_prompts is not None: logger.info("Running inference for collecting generated images...") + pipeline = AutoPipelineForText2Image.from_pretrained( + args.pretrained_decoder_model_name_or_path, + prior_text_encoder=accelerator.unwrap_model(text_encoder), + prior_tokenizer=tokenizer, + ) pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype) + # load attention processors + pipeline.prior_prior.load_attn_procs(os.path.join(args.output_dir, "prior_lora")) + pipeline.set_progress_bar_config(disable=True) if args.seed is None: From 0a7ffa94740488e3cb3583a1763d71819db3f279 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 26 Sep 2023 10:30:10 +0200 Subject: [PATCH 32/63] initial gradient working --- .../text_to_image/train_text_to_image_lora_prior.py | 11 ++++------- .../text_to_image/train_text_to_image_prior.py | 2 +- .../pipelines/wuerstchen/modeling_wuerstchen_prior.py | 3 ++- .../schedulers/scheduling_ddpm_wuerstchen.py | 3 ++- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index 12031f07d21c..b55d60c68c75 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -520,12 +520,9 @@ def deepspeed_zero_init_disabled_context_manager(): image_encoder.requires_grad_(False) # load prior model - prior = WuerstchenPrior.from_pretrained( - args.pretrained_prior_model_name_or_path, - subfolder="prior", - torch_dtype=weight_dtype, - device=accelerator.device, - ) + prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") + prior.to(accelerator.device, dtype=weight_dtype) + # lora attn processor lora_attn_procs = {} for name in prior.attn_processors.keys(): @@ -791,7 +788,7 @@ def collate_fn(examples): bsz = image_embeds.shape[0] # Sample a random timestep for each image - timesteps = torch.rand((bsz,), device=image_embeds.device) + timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype) # add noise to latent noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index f47a9e7a9774..05323ac54738 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -788,7 +788,7 @@ def collate_fn(examples): bsz = image_embeds.shape[0] # Sample a random timestep for each image - timesteps = torch.rand((bsz,), device=image_embeds.device) + timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype) # add noise to latent noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps) diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index 514d2c997248..c1802c0fbc20 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -27,11 +27,12 @@ AttnAddedKVProcessor, AttnProcessor, ) +from ...loaders import UNet2DConditionLoadersMixin from ...models.modeling_utils import ModelMixin from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm -class WuerstchenPrior(ModelMixin, 
ConfigMixin): +class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): @register_to_config def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): super().__init__() diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py index 4605a8eda5df..bafa6d7f1b87 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py +++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py @@ -214,11 +214,12 @@ def add_noise( timesteps: torch.FloatTensor, ) -> torch.FloatTensor: device = original_samples.device + dtype = original_samples.dtype alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view( timesteps.size(0), *[1 for _ in original_samples.shape[1:]] ) noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise - return noisy_samples + return noisy_samples.to(dtype=dtype) def __len__(self): return self.config.num_train_timesteps From d9b6b48c2d9c3a3817903f8fb050895893fc6b6c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 26 Sep 2023 10:36:30 +0200 Subject: [PATCH 33/63] remove save/load hooks --- .../train_text_to_image_lora_prior.py | 33 +++---------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index b55d60c68c75..4fe5aef5dcd6 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -515,11 +515,13 @@ def deepspeed_zero_init_disabled_context_manager(): args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype ).eval() - # Freeze text_encoder and image_encoder + # Freeze text_encoder, cast to weight_dtype and image_encoder and move to device text_encoder.requires_grad_(False) image_encoder.requires_grad_(False) + image_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) - # load prior model + # load prior model, cast to weight_dtype and move to device prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") prior.to(accelerator.device, dtype=weight_dtype) @@ -530,31 +532,6 @@ def deepspeed_zero_init_disabled_context_manager(): prior.set_attn_processor(lora_attn_procs) lora_layers = AttnProcsLayers(prior.attn_processors) - # `accelerate` 0.16.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.16.0"): - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - for i, model in enumerate(models): - model.save_pretrained(os.path.join(output_dir, "prior")) - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - def load_model_hook(models, input_dir): - for i in range(len(models)): - # pop models so that they are not loaded again - model = models.pop() - - # load diffusers style into model - load_model = WuerstchenPrior.from_pretrained(input_dir, subfolder="prior") - model.register_to_config(**load_model.config) - - model.load_state_dict(load_model.state_dict()) - del load_model - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - if args.allow_tf32: 
torch.backends.cuda.matmul.allow_tf32 = True @@ -699,8 +676,6 @@ def collate_fn(examples): lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( lora_layers, optimizer, train_dataloader, lr_scheduler ) - image_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) From dbc238b8d94f0cf78519a35b95092e946ee3fe75 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 26 Sep 2023 10:45:32 +0200 Subject: [PATCH 34/63] set set_attn_processor on prior_prior --- .../wuerstchen/text_to_image/train_text_to_image_lora_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index 4fe5aef5dcd6..15a628179cfb 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -148,7 +148,7 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) - pipeline.set_attn_processor(attn_processors) + pipeline.prior_prior.set_attn_processor(attn_processors) pipeline.set_progress_bar_config(disable=True) if args.seed is None: From af4dcae1f31d754c76598486119ff9d068a6d46d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 27 Sep 2023 11:32:15 +0200 Subject: [PATCH 35/63] add lora script --- examples/wuerstchen/text_to_image/README.md | 28 ++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index 780963042f15..1b6d528ec53f 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -24,7 +24,10 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e ```bash accelerate config ``` -For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag. +For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training script. To do so, run: +```bash +huggingface-cli login +``` ## Prior training @@ -56,7 +59,7 @@ accelerate launch train_text_to_image_prior.py \ ## Training with LoRA -Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. +Low-Rank Adaption of Large Language Models (or LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. 
This has a couple of advantages: @@ -64,5 +67,24 @@ In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-de - Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable. - LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter. -[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. + +### Prior Training + +First, you need to set up your development environment as explained in the [installation](#installing-the-dependencies). Make sure to set the `DATASET_NAME` environment variables. Here, we will use the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions). + +```bash +export DATASET_NAME="lambdalabs/pokemon-blip-captions" + +accelerate launch --mixed_precision="fp16" train_text_to_image_prior_lora.py \ + --dataset_name=$DATASET_NAME --caption_column="text" \ + --resolution=768 \ + --train_batch_size=8 \ + --num_train_epochs=100 --checkpointing_steps=5000 \ + --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ + --seed=42 \ + --rank=4 \ + --output_dir="wuerstchen-prior-pokemon-lora" \ + --validation_prompt="cute dragon creature" --report_to="wandb" \ + --push_to_hub \ +``` From bc776dcdd409f68599d04e470d3f9629b9f4c816 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 27 Sep 2023 11:34:31 +0200 Subject: [PATCH 36/63] typos --- examples/wuerstchen/text_to_image/README.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index 1b6d528ec53f..40ddb1025d35 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -40,7 +40,7 @@ You can fine-tune the Würstchen prior model with `train_text_to_image_prior.py` export DATASET_NAME="lambdalabs/pokemon-blip-captions" accelerate launch train_text_to_image_prior.py \ - --mixed_precision="fp16" + --mixed_precision="fp16" \ --dataset_name=$DATASET_NAME \ --resolution=768 \ --train_batch_size=4 \ @@ -75,7 +75,8 @@ First, you need to set up your development environment as explained in the [inst ```bash export DATASET_NAME="lambdalabs/pokemon-blip-captions" -accelerate launch --mixed_precision="fp16" train_text_to_image_prior_lora.py \ +accelerate launch train_text_to_image_prior_lora.py \ + --mixed_precision="fp16" \ --dataset_name=$DATASET_NAME --caption_column="text" \ --resolution=768 \ --train_batch_size=8 \ @@ -83,8 +84,8 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_prior_lora.py \ --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \ --seed=42 \ --rank=4 \ - --output_dir="wuerstchen-prior-pokemon-lora" \ - --validation_prompt="cute dragon creature" --report_to="wandb" \ + --validation_prompt="cute dragon creature" \ + --report_to="wandb" \ --push_to_hub \ + --output_dir="wuerstchen-prior-pokemon-lora" ``` - From 7989eae34d3fc26c6be4f8eef3e12f2f62aa8479 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 27 Sep 2023 19:32:52 +0200 Subject: [PATCH 37/63] use LoraLoaderMixin for prior pipeline --- examples/wuerstchen/text_to_image/README.md | 2 +- .../train_text_to_image_lora_prior.py | 15 +++++++------- src/diffusers/loaders.py | 20 ++++++++++--------- 
.../wuerstchen/modeling_wuerstchen_prior.py | 4 +++- .../wuerstchen/pipeline_wuerstchen_prior.py | 5 ++++- 5 files changed, 27 insertions(+), 19 deletions(-) diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index 40ddb1025d35..7ae90a1256f9 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -70,7 +70,7 @@ In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-de ### Prior Training -First, you need to set up your development environment as explained in the [installation](#installing-the-dependencies). Make sure to set the `DATASET_NAME` environment variables. Here, we will use the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions). +First, you need to set up your development environment as explained in the [installation](#Running-locally-with-PyTorch) section. Make sure to set the `DATASET_NAME` environment variables. Here, we will use the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions). ```bash export DATASET_NAME="lambdalabs/pokemon-blip-captions" diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index 15a628179cfb..a811c7392f99 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -19,7 +19,6 @@ import shutil from pathlib import Path -import accelerate import datasets import numpy as np import torch @@ -32,13 +31,12 @@ from datasets import load_dataset from huggingface_hub import create_repo, hf_hub_download, upload_folder from modeling_efficient_net_encoder import EfficientNetEncoder -from packaging import version from torchvision import transforms from tqdm import tqdm from transformers import CLIPTextModel, PreTrainedTokenizerFast from transformers.utils import ContextManagers -from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler +from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline from diffusers.loaders import AttnProcsLayers from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.optimization import get_scheduler @@ -524,7 +522,7 @@ def deepspeed_zero_init_disabled_context_manager(): # load prior model, cast to weight_dtype and move to device prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior") prior.to(accelerator.device, dtype=weight_dtype) - + # lora attn processor lora_attn_procs = {} for name in prior.attn_processors.keys(): @@ -835,7 +833,10 @@ def collate_fn(examples): accelerator.wait_for_everyone() if accelerator.is_main_process: prior = prior.to(torch.float32) - prior.save_attn_procs(os.path.join(args.output_dir, "prior_lora")) + WuerstchenPriorPipeline.save_lora_weights( + os.path.join(args.output_dir, "prior_lora"), + unet_lora_layers=lora_layers, + ) # Run a final round of inference. 
images = [] @@ -847,8 +848,8 @@ def collate_fn(examples): prior_tokenizer=tokenizer, ) pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype) - # load attention processors - pipeline.prior_prior.load_attn_procs(os.path.join(args.output_dir, "prior_lora")) + # load lora weights + pipeline.prior_pipe.load_lora_weights(os.path.join(args.output_dir, "prior_lora")) pipeline.set_progress_bar_config(disable=True) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index f63532b84e7c..8f745c7b3a7c 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1133,7 +1133,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di self.load_lora_into_unet( state_dict, network_alphas=network_alphas, - unet=self.unet, + unet=getattr(self, self.UNET_NAME), low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=self, ) @@ -1464,7 +1464,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage """ low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) @@ -1784,7 +1784,7 @@ def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameter @classmethod def save_lora_weights( - self, + cls, save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, @@ -1824,7 +1824,7 @@ def save_lora_weights( unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers ) - unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()} + unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()} state_dict.update(unet_lora_state_dict) if text_encoder_lora_layers is not None: @@ -1835,12 +1835,12 @@ def save_lora_weights( ) text_encoder_lora_state_dict = { - f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items() + f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items() } state_dict.update(text_encoder_lora_state_dict) # Save the model - self.write_lora_layers( + cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, is_main_process=is_main_process, @@ -1849,7 +1849,9 @@ def save_lora_weights( safe_serialization=safe_serialization, ) + @classmethod def write_lora_layers( + cls, state_dict: Dict[str, torch.Tensor], save_directory: str, is_main_process: bool, @@ -2147,7 +2149,7 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True self.unet.unfuse_lora() if self.use_peft_backend: - from peft.tuners.tuner_utils import BaseTunerLayer + from peft.tuners.tuners_utils import BaseTunerLayer def unfuse_text_encoder_lora(text_encoder): for module in text_encoder.modules(): @@ -2818,7 +2820,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di @classmethod def save_lora_weights( - self, + cls, save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, 
Union[torch.nn.Module, torch.Tensor]] = None, @@ -2869,7 +2871,7 @@ def pack_weights(layers, prefix): state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) - self.write_lora_layers( + cls.write_lora_layers( state_dict=state_dict, save_directory=save_directory, is_main_process=is_main_process, diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index c1802c0fbc20..43fa425a507b 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -20,6 +20,7 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin from ...models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -27,12 +28,13 @@ AttnAddedKVProcessor, AttnProcessor, ) -from ...loaders import UNet2DConditionLoadersMixin from ...models.modeling_utils import ModelMixin from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + UNET_NAME = "prior" + @register_to_config def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): super().__init__() diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 8e737a74bbfe..e4d83a2b4452 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -20,6 +20,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...loaders import LoraLoaderMixin from ...schedulers import DDPMWuerstchenScheduler from ...utils import ( BaseOutput, @@ -65,7 +66,7 @@ class WuerstchenPriorPipelineOutput(BaseOutput): image_embeddings: Union[torch.FloatTensor, np.ndarray] -class WuerstchenPriorPipeline(DiffusionPipeline): +class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): """ Pipeline for generating image prior for Wuerstchen. @@ -84,6 +85,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline): A scheduler to be used in combination with `prior` to generate image embedding. 
""" + UNET_NAME = "prior" + TEXT_ENCODER_NAME = "text_encoder" model_cpu_offload_seq = "text_encoder->prior" def __init__( From 70cd979494b8ba506731ea8abd6af572634ec38d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 27 Sep 2023 19:47:30 +0200 Subject: [PATCH 38/63] fix usage --- .../text_to_image/train_text_to_image_lora_prior.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index a811c7392f99..f2c0061a14de 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -99,11 +99,13 @@ def save_model_card( from diffusers import DiffusionPipeline import torch -pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype={args.weight_dtype}) -pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype}) -prompt = "{args.validation_prompts[0]}" -(image_embeds,) = pipe_prior(prompt).to_tuple() -image = pipe_t2i(image_embeddings=image_embeds, prompt=prompt).images[0] +pipeline = AutoPipelineForText2Image.from_pretrained( + "{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype} + ) +# load lora weights from folder: +pipeline.prior_pipe.load_lora_weights("{os.path.join(args.output_dir, "prior_lora")}") + +image = pipeline(prompt=prompt).images[0] image.save("my_image.png") ``` @@ -111,6 +113,7 @@ def save_model_card( These are the key hyperparameters used during training: +* LoRA rank: {args.rank} * Epochs: {args.num_train_epochs} * Learning rate: {args.learning_rate} * Batch size: {args.train_batch_size} From 0454a878bd17a4c9c55a92d0ed710fcc6af3af28 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 27 Sep 2023 20:02:19 +0200 Subject: [PATCH 39/63] make fix-copies --- .../pipelines/wuerstchen/modeling_wuerstchen_prior.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index 43fa425a507b..055619b9f98a 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -82,7 +82,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): r""" Sets the attention processor to use to compute attention. 
@@ -106,9 +108,9 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, Atte def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): - module.set_processor(processor) + module.set_processor(processor, _remove_lora=_remove_lora) else: - module.set_processor(processor.pop(f"{name}.processor")) + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) @@ -130,7 +132,7 @@ def set_default_attn_processor(self): f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) - self.set_attn_processor(processor) + self.set_attn_processor(processor, _remove_lora=True) def gen_r_embedding(self, r, max_positions=10000): r = r * max_positions From 7435c7058b49291a6225937a1d317bde052b0268 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 27 Sep 2023 20:13:05 +0200 Subject: [PATCH 40/63] yse repo_id --- .../wuerstchen/text_to_image/train_text_to_image_lora_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index f2c0061a14de..ce77a0b1d06d 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -103,7 +103,7 @@ def save_model_card( "{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype} ) # load lora weights from folder: -pipeline.prior_pipe.load_lora_weights("{os.path.join(args.output_dir, "prior_lora")}") +pipeline.prior_pipe.load_lora_weights("{repo_id}", torch_dtype={args.weight_dtype}) image = pipeline(prompt=prompt).images[0] image.save("my_image.png") From 2eb5d9c8c8979b082e526cbe3c8398055015dc7a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 27 Sep 2023 20:20:32 +0200 Subject: [PATCH 41/63] write_lora_layers is a staitcmethod --- src/diffusers/loaders.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 29aa630397c9..58b80b1a82c0 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1867,9 +1867,8 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - @classmethod + @staticmethod def write_lora_layers( - cls, state_dict: Dict[str, torch.Tensor], save_directory: str, is_main_process: bool, From 234bebbc02c43648e2c6f6582a9c7226b05afd69 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 27 Sep 2023 20:53:51 +0200 Subject: [PATCH 42/63] use defualts --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 58b80b1a82c0..80665643a0d5 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1143,14 +1143,14 @@ def load_lora_weights( self.load_lora_into_unet( state_dict, network_alphas=network_alphas, - unet=getattr(self, self.UNET_NAME), + unet=getattr(self, "unet", self.UNET_NAME), low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=self, ) self.load_lora_into_text_encoder( state_dict, network_alphas=network_alphas, - text_encoder=self.text_encoder, + text_encoder=getattr(self, "text_encoder", self.TEXT_ENCODER_NAME), lora_scale=self.lora_scale, 
low_cpu_mem_usage=low_cpu_mem_usage, adapter_name=adapter_name, From afb001c9092332072991ccc0a75e0b8db2ede95f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 27 Sep 2023 21:17:36 +0200 Subject: [PATCH 43/63] fix defaults --- src/diffusers/loaders.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 80665643a0d5..c53a813bfe8a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1143,14 +1143,16 @@ def load_lora_weights( self.load_lora_into_unet( state_dict, network_alphas=network_alphas, - unet=getattr(self, "unet", self.UNET_NAME), + unet=getattr(self, self.UNET_NAME if hasattr(self, "UNET_NAME") else "unet"), low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=self, ) self.load_lora_into_text_encoder( state_dict, network_alphas=network_alphas, - text_encoder=getattr(self, "text_encoder", self.TEXT_ENCODER_NAME), + text_encoder=getattr( + self, self.TEXT_ENCODER_NAME if hasattr(self, "TEXT_ENCODER_NAME") else "text_encoder" + ), lora_scale=self.lora_scale, low_cpu_mem_usage=low_cpu_mem_usage, adapter_name=adapter_name, From 47a31ab13a90b0236d5778abd0e17326b903e363 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Sep 2023 14:07:02 +0200 Subject: [PATCH 44/63] undo --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index c53a813bfe8a..0fcd830f2d39 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2168,7 +2168,7 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True self.unet.unfuse_lora() if self.use_peft_backend: - from peft.tuners.tuners_utils import BaseTunerLayer + from peft.tuners.tuner_utils import BaseTunerLayer def unfuse_text_encoder_lora(text_encoder): for module in text_encoder.modules(): From 682f30e5dfe01b85ed89e60d43ed0b3c2baf8f36 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Sep 2023 15:10:29 +0200 Subject: [PATCH 45/63] Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py Co-authored-by: Patrick von Platen --- .../pipelines/wuerstchen/pipeline_wuerstchen_prior.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index e4d83a2b4452..11a73e6d6b5f 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -85,8 +85,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): A scheduler to be used in combination with `prior` to generate image embedding. 
""" - UNET_NAME = "prior" - TEXT_ENCODER_NAME = "text_encoder" + unet_name = "prior" + text_encoder_name = "text_encoder" model_cpu_offload_seq = "text_encoder->prior" def __init__( From f0638ffa19dfe8e944abed7a5f1cc1e8e8eef6a8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Sep 2023 15:11:34 +0200 Subject: [PATCH 46/63] Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index c53a813bfe8a..5321dccc0cf9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1143,7 +1143,7 @@ def load_lora_weights( self.load_lora_into_unet( state_dict, network_alphas=network_alphas, - unet=getattr(self, self.UNET_NAME if hasattr(self, "UNET_NAME") else "unet"), + unet=getattr(self, self.unet_name), low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=self, ) From 8957bf868d6653ad64c889634ab601f008948d17 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Sep 2023 15:11:43 +0200 Subject: [PATCH 47/63] Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen --- src/diffusers/loaders.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5321dccc0cf9..00cd6a1cf685 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1150,9 +1150,7 @@ def load_lora_weights( self.load_lora_into_text_encoder( state_dict, network_alphas=network_alphas, - text_encoder=getattr( - self, self.TEXT_ENCODER_NAME if hasattr(self, "TEXT_ENCODER_NAME") else "text_encoder" - ), + text_encoder=getattr(self, self.text_encoder_name), lora_scale=self.lora_scale, low_cpu_mem_usage=low_cpu_mem_usage, adapter_name=adapter_name, From dddd55396597909589d599a3122ebce74d298a87 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 28 Sep 2023 15:28:33 +0200 Subject: [PATCH 48/63] Update src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py --- src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index 055619b9f98a..e53bf08a3640 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -33,7 +33,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): - UNET_NAME = "prior" + unet_name = "prior" @register_to_config def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): From 1ab236fa291048aa1c75d6f6d6dd419eb4013bb0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 4 Oct 2023 11:12:36 +0200 Subject: [PATCH 49/63] Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 494d68e365b8..3d2b6bb392bc 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1151,7 +1151,7 @@ def load_lora_weights( self.load_lora_into_text_encoder( state_dict, network_alphas=network_alphas, - text_encoder=getattr(self, self.text_encoder_name), + text_encoder=getattr(self, self.text_encoder_name, self.text_encoder_name), lora_scale=self.lora_scale, low_cpu_mem_usage=low_cpu_mem_usage, adapter_name=adapter_name, From 
402f305849b48de34a6c96227816e5d62fd1cfe1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 4 Oct 2023 11:12:48 +0200 Subject: [PATCH 50/63] Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 3d2b6bb392bc..f9dea7f7f2ca 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1144,7 +1144,7 @@ def load_lora_weights( self.load_lora_into_unet( state_dict, network_alphas=network_alphas, - unet=getattr(self, self.unet_name), + unet=getattr(self, self.unet_name, self.unet), low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=self, ) From 72e755f185a749d0eee4b52d5795ae2cf2233a4b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Oct 2023 12:44:29 +0200 Subject: [PATCH 51/63] add graident checkpoint support to prior --- examples/wuerstchen/text_to_image/README.md | 4 +- .../wuerstchen/modeling_wuerstchen_prior.py | 52 ++++++++++++++++--- 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index 7ae90a1256f9..d449e87e794e 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -31,7 +31,7 @@ huggingface-cli login ## Prior training -You can fine-tune the Würstchen prior model with `train_text_to_image_prior.py` script. Note that we currently do not support `--gradient_checkpointing` for prior model fine-tuning. +You can fine-tune the Würstchen prior model with `train_text_to_image_prior.py` script. Note that we currently support `--gradient_checkpointing` for prior model fine-tuning so one can utilize it for more GPU memory constrained setups.
@@ -44,6 +44,8 @@ accelerate launch train_text_to_image_prior.py \ --dataset_name=$DATASET_NAME \ --resolution=768 \ --train_batch_size=4 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ --dataloader_num_workers=4 \ --max_train_steps=15000 \ --learning_rate=1e-05 \ diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index e53bf08a3640..b89c314472c3 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -19,6 +19,7 @@ import torch import torch.nn as nn +from ...utils import is_torch_version from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import UNet2DConditionLoadersMixin from ...models.attention_processor import ( @@ -34,7 +35,8 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): unet_name = "prior" - + _supports_gradient_checkpointing = True + @register_to_config def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): super().__init__() @@ -56,6 +58,8 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro nn.Conv2d(c, c_in * 2, kernel_size=1), ) + self.gradient_checkpointing = False + @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -134,6 +138,9 @@ def set_default_attn_processor(self): self.set_attn_processor(processor, _remove_lora=True) + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + def gen_r_embedding(self, r, max_positions=10000): r = r * max_positions half_dim = self.c_r // 2 @@ -150,12 +157,43 @@ def forward(self, x, r, c): x = self.projection(x) c_embed = self.cond_mapper(c) r_embed = self.gen_r_embedding(r) - for block in self.blocks: - if isinstance(block, AttnBlock): - x = block(x, c_embed) - elif isinstance(block, TimestepBlock): - x = block(x, r_embed) + + if self.training and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + for block in self.blocks: + if isinstance(block, AttnBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, c_embed, use_reentrant=False) + elif isinstance(block, TimestepBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, r_embed, use_reentrant=False) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, use_reentrant=False) else: - x = block(x) + for block in self.blocks: + if isinstance(block, AttnBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, c_embed) + elif isinstance(block, TimestepBlock): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x, r_embed) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), x) + else: + for block in self.blocks: + if isinstance(block, AttnBlock): + x = block(x, c_embed) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) a, b = self.out(x).chunk(2, dim=1) return (x_in - a) / ((1 - b).abs() + 1e-5) From 43343c6b24770ae5e802893eb10d82b5140e3e5b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Oct 2023 12:44:57 +0200 Subject: [PATCH 52/63] gradient_checkpointing --- 
.../wuerstchen/text_to_image/train_text_to_image_prior.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index 05323ac54738..f1337efb0577 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -294,6 +294,11 @@ def parse_args(): default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) parser.add_argument( "--learning_rate", type=float, @@ -555,6 +560,9 @@ def load_model_hook(models, input_dir): accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) + if args.gradient_checkpointing: + prior.enable_gradient_checkpointing() + if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True From 15b2d117594a059f01b390de4955d4eaae49d1ad Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Oct 2023 13:15:00 +0200 Subject: [PATCH 53/63] formatting --- .../wuerstchen/modeling_wuerstchen_prior.py | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index b89c314472c3..6c0c5d3ac4f9 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -19,7 +19,6 @@ import torch import torch.nn as nn -from ...utils import is_torch_version from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import UNet2DConditionLoadersMixin from ...models.attention_processor import ( @@ -30,13 +29,14 @@ AttnProcessor, ) from ...models.modeling_utils import ModelMixin +from ...utils import is_torch_version from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): unet_name = "prior" _supports_gradient_checkpointing = True - + @register_to_config def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): super().__init__() @@ -59,7 +59,7 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro ) self.gradient_checkpointing = False - + @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -159,34 +159,33 @@ def forward(self, x, r, c): r_embed = self.gen_r_embedding(r) if self.training and self.gradient_checkpointing: + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward - + if is_torch_version(">=", "1.11.0"): for block in self.blocks: if isinstance(block, AttnBlock): x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, c_embed, use_reentrant=False) + create_custom_forward(block), x, c_embed, use_reentrant=False + ) elif isinstance(block, TimestepBlock): x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, r_embed, use_reentrant=False) + create_custom_forward(block), x, r_embed, use_reentrant=False + ) else: - x = torch.utils.checkpoint.checkpoint( - 
create_custom_forward(block), x, use_reentrant=False) + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False) else: for block in self.blocks: if isinstance(block, AttnBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, c_embed) + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed) elif isinstance(block, TimestepBlock): - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x, r_embed) + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed) else: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), x) + x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x) else: for block in self.blocks: if isinstance(block, AttnBlock): From 4de3fbe185474d322faa8f6df971117a2d134ac0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Oct 2023 13:46:23 +0200 Subject: [PATCH 54/63] Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca --- examples/wuerstchen/text_to_image/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index d449e87e794e..696aac1fe3dc 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -24,7 +24,7 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e ```bash accelerate config ``` -For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training script. To do so, run: +For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training script. To log in, run: ```bash huggingface-cli login ``` From 162500e836830b83b91eb93ba9b4f3977850764d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Oct 2023 13:46:32 +0200 Subject: [PATCH 55/63] Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca --- examples/wuerstchen/text_to_image/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index 696aac1fe3dc..e4c184001a45 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -31,7 +31,7 @@ huggingface-cli login ## Prior training -You can fine-tune the Würstchen prior model with `train_text_to_image_prior.py` script. Note that we currently support `--gradient_checkpointing` for prior model fine-tuning so one can utilize it for more GPU memory constrained setups. +You can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. Note that we currently support `--gradient_checkpointing` for prior model fine-tuning so you can use it for more GPU memory constrained setups.
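
A minimal sketch of how the gradient-checkpointing support added in the two patches above is meant to be used outside the training script (the checkpoint name is an assumption; the scripts take it via `--pretrained_prior_model_name_or_path` and read the `prior` subfolder):

```python
import torch

from diffusers.pipelines.wuerstchen import WuerstchenPrior

# Assumed Hub checkpoint; substitute the prior you are fine-tuning.
prior = WuerstchenPrior.from_pretrained("warp-ai/wuerstchen-prior", subfolder="prior")

# Checkpointing only takes effect in training mode: the patched forward
# checks `self.training and self.gradient_checkpointing`.
prior.train()
prior.enable_gradient_checkpointing()

# Illustrative shapes only: 16-channel latents, timesteps sampled in [0, 1),
# and text conditioning of width `c_cond`.
x = torch.randn(1, 16, 24, 24)
r = torch.rand(1)
c = torch.randn(1, 77, prior.config.c_cond)

prior(x, r, c).mean().backward()  # blocks are re-computed during the backward pass
```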
From b3e54cbae39ecd1607b44cc9c9544631b19d0ff1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Oct 2023 13:46:43 +0200 Subject: [PATCH 56/63] Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca --- examples/wuerstchen/text_to_image/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index e4c184001a45..c45d0bdbdf59 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -65,7 +65,7 @@ Low-Rank Adaption of Large Language Models (or LoRA) was first introduced by Mic In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages: -- Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114). +- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114). - Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable. - LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter. From 12209efb019ad50532db5970bdfcf49246ead264 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Oct 2023 13:46:54 +0200 Subject: [PATCH 57/63] Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca --- examples/wuerstchen/text_to_image/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index c45d0bdbdf59..0c8e506cda55 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -72,7 +72,7 @@ In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-de ### Prior Training -First, you need to set up your development environment as explained in the [installation](#Running-locally-with-PyTorch) section. Make sure to set the `DATASET_NAME` environment variables. Here, we will use the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions). +First, you need to set up your development environment as explained in the [installation](#Running-locally-with-PyTorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Pokemon captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions). 
```bash export DATASET_NAME="lambdalabs/pokemon-blip-captions" From a1527b20882867d547ed62c20f03747fa4ab3a6f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Oct 2023 13:47:02 +0200 Subject: [PATCH 58/63] Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca --- examples/wuerstchen/text_to_image/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/README.md b/examples/wuerstchen/text_to_image/README.md index 0c8e506cda55..5378e3ef5253 100644 --- a/examples/wuerstchen/text_to_image/README.md +++ b/examples/wuerstchen/text_to_image/README.md @@ -6,7 +6,7 @@ Before running the scripts, make sure to install the library's training dependen **Important** -To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment: ```bash git clone https://github.com/huggingface/diffusers cd diffusers From a28f5c0e0b6216d2db90e57a019fa7ae9e250569 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Oct 2023 13:48:01 +0200 Subject: [PATCH 59/63] Update examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py Co-authored-by: Pedro Cuenca --- .../wuerstchen/text_to_image/train_text_to_image_lora_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index ce77a0b1d06d..5235fa99cfdd 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -50,7 +50,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.21.0") +check_min_version("0.22.0") logger = get_logger(__name__, log_level="INFO") From cda5de40c54857c62e6f09ceafe0c5ee3b71a289 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Oct 2023 13:48:25 +0200 Subject: [PATCH 60/63] Update src/diffusers/loaders.py Co-authored-by: Pedro Cuenca --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d16f9188f9de..f9cb811e66c5 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1151,7 +1151,7 @@ def load_lora_weights( self.load_lora_into_text_encoder( state_dict, network_alphas=network_alphas, - text_encoder=getattr(self, self.text_encoder_name, self.text_encoder_name), + text_encoder=getattr(self, self.text_encoder_name, self.text_encoder), lora_scale=self.lora_scale, low_cpu_mem_usage=low_cpu_mem_usage, adapter_name=adapter_name, From d9964e2a5ade71105f8c97a7c0b16a3f3042190c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 9 Oct 2023 13:48:48 +0200 Subject: [PATCH 61/63] Update examples/wuerstchen/text_to_image/train_text_to_image_prior.py Co-authored-by: Pedro Cuenca --- examples/wuerstchen/text_to_image/train_text_to_image_prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index f1337efb0577..92f63c93fc1a 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -51,7 +51,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.21.0") +check_min_version("0.22.0") logger = get_logger(__name__, log_level="INFO") From 89fa22fbe6bbbcc30e0e810df83b35e1985428ae Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 10 Oct 2023 09:48:00 +0200 Subject: [PATCH 62/63] use default unet and text_encoder --- src/diffusers/loaders.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index af7420e62058..1010344546b7 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1153,14 +1153,16 @@ def load_lora_weights( self.load_lora_into_unet( state_dict, network_alphas=network_alphas, - unet=getattr(self, self.unet_name, self.unet), + unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=self, ) self.load_lora_into_text_encoder( state_dict, network_alphas=network_alphas, - text_encoder=getattr(self, self.text_encoder_name, self.text_encoder), + text_encoder=getattr(self, self.text_encoder_name) + if not hasattr(self, "text_encoder") + else self.text_encoder, lora_scale=self.lora_scale, low_cpu_mem_usage=low_cpu_mem_usage, adapter_name=adapter_name, From cc3adb5c80806ac4cd30b073ad418e4612c98ca3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 14 Oct 2023 12:08:07 +0200 Subject: [PATCH 63/63] fix test --- src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py index 6c0c5d3ac4f9..ca72ce581fcc 100644 --- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -59,6 +59,7 @@ def 
__init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro ) self.gradient_checkpointing = False + self.set_default_attn_processor() @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
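
With the full series applied, `WuerstchenPriorPipeline` gains `LoraLoaderMixin`, so LoRA weights written by `train_text_to_image_lora_prior.py` (saved under `<output_dir>/prior_lora` or pushed to the Hub) can be loaded back into the prior at inference time. A minimal sketch, assuming the `warp-ai/wuerstchen` decoder checkpoint and a placeholder LoRA repo id:

```python
import torch

from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

# Assumed checkpoint; use the decoder model the prior was trained against.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
).to("cuda")

# Placeholder repo id / local path: point this at the `prior_lora` folder or
# Hub repo produced by the LoRA training script.
pipeline.prior_pipe.load_lora_weights("your-username/wuerstchen-prior-pokemon-lora")

image = pipeline(
    "cute dragon creature",
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    width=768,
    height=768,
).images[0]
image.save("pokemon-lora.png")
```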