From 7501b280fbfcea7ad70baca3a56c24290d4b3761 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 27 Nov 2023 15:39:09 +0000 Subject: [PATCH 1/3] fix --- .../stable_diffusion/convert_from_ckpt.py | 187 ++++++++++++++---- 1 file changed, 151 insertions(+), 36 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 35466f008f54..4382f030d6ae 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -166,7 +166,12 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + config=None, ): """ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits @@ -381,7 +386,12 @@ def create_ldm_bert_config(original_config): def convert_ldm_unet_checkpoint( - checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False + checkpoint, + config, + path=None, + extract_ema=False, + controlnet=False, + skip_extract_state_dict=False, ): """ Takes a state dict and a config, and returns a converted checkpoint. @@ -446,7 +456,7 @@ def convert_ldm_unet_checkpoint( new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] # Relevant to StableDiffusionUpscalePipeline - if "num_class_embeds" in config: + if config["num_class_embeds"] is not None: new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] @@ -497,17 +507,31 @@ def convert_ldm_unet_checkpoint( ) paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}", + } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) if len(attentions): paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + meta_path = { + "old": f"input_blocks.{i}.1", + "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}", + } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) resnet_0 = middle_blocks[0] @@ -523,7 +547,11 @@ def convert_ldm_unet_checkpoint( attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + attentions_paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) for i in range(num_output_blocks): @@ -546,9 +574,16 @@ def convert_ldm_unet_checkpoint( resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}", + } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) output_block_list = {k: sorted(v) for k, v in output_block_list.items()} @@ -572,13 +607,25 @@ def convert_ldm_unet_checkpoint( "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) else: resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) for path in resnet_0_paths: old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + new_path = ".".join( + [ + "up_blocks", + str(block_id), + "resnets", + str(layer_in_block_id), + path["new"], + ] + ) new_checkpoint[new_path] = unet_state_dict[old_path] @@ -682,7 +729,13 @@ def convert_ldm_vae_checkpoint(checkpoint, config): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] num_mid_res_blocks = 2 @@ -691,12 +744,24 @@ def convert_ldm_vae_checkpoint(checkpoint, config): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) conv_attn_to_linear(new_checkpoint) for i in range(num_up_blocks): @@ -715,7 +780,13 @@ def convert_ldm_vae_checkpoint(checkpoint, config): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] num_mid_res_blocks = 2 @@ -724,12 +795,24 @@ def convert_ldm_vae_checkpoint(checkpoint, config): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) conv_attn_to_linear(new_checkpoint) return new_checkpoint @@ -804,7 +887,10 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder text_model_dict = {} - remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] + remove_prefixes = [ + "cond_stage_model.transformer", + "conditioner.embedders.0.transformer", + ] for key in keys: for prefix in remove_prefixes: @@ -841,8 +927,14 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder (".c_proj.", ".fc2."), (".attn", ".self_attn"), ("ln_final.", "transformer.text_model.final_layer_norm."), - ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), - ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), + ( + "token_embedding.weight", + "transformer.text_model.embeddings.token_embedding.weight", + ), + ( + "positional_embedding", + "transformer.text_model.embeddings.position_embedding.weight", + ), ] protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} textenc_pattern = re.compile("|".join(protected.keys())) @@ -1486,7 +1578,9 @@ def download_from_original_stable_diffusion_ckpt( try: tokenizer = CLIPTokenizer.from_pretrained( - "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only + "stabilityai/stable-diffusion-2", + subfolder="tokenizer", + local_files_only=local_files_only, ) except Exception: raise ValueError( @@ -1513,7 +1607,8 @@ def download_from_original_stable_diffusion_ckpt( "stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler" ) low_res_scheduler = DDPMScheduler.from_pretrained( - "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler" + "stabilityai/stable-diffusion-x4-upscaler", + subfolder="low_res_scheduler", ) pipe = pipeline_class( @@ -1541,9 +1636,10 @@ def download_from_original_stable_diffusion_ckpt( pipe.requires_safety_checker = False else: - image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( - original_config, clip_stats_path=clip_stats_path, device=device - ) + ( + image_normalizer, + image_noising_scheduler, + ) = stable_unclip_image_noising_components(original_config, clip_stats_path=clip_stats_path, device=device) if stable_unclip == "img2img": feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) @@ -1567,23 +1663,29 @@ def download_from_original_stable_diffusion_ckpt( if stable_unclip_prior is None or stable_unclip_prior == "karlo": karlo_model = "kakaobrain/karlo-v1-alpha" prior = PriorTransformer.from_pretrained( - karlo_model, subfolder="prior", local_files_only=local_files_only + karlo_model, + subfolder="prior", + local_files_only=local_files_only, ) try: prior_tokenizer = CLIPTokenizer.from_pretrained( - "openai/clip-vit-large-patch14", local_files_only=local_files_only + "openai/clip-vit-large-patch14", + local_files_only=local_files_only, ) except Exception: raise ValueError( f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." ) prior_text_model = CLIPTextModelWithProjection.from_pretrained( - "openai/clip-vit-large-patch14", local_files_only=local_files_only + "openai/clip-vit-large-patch14", + local_files_only=local_files_only, ) prior_scheduler = UnCLIPScheduler.from_pretrained( - karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only + karlo_model, + subfolder="prior_scheduler", + local_files_only=local_files_only, ) prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) else: @@ -1620,7 +1722,8 @@ def download_from_original_stable_diffusion_ckpt( ) try: feature_extractor = AutoFeatureExtractor.from_pretrained( - "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + "CompVis/stable-diffusion-safety-checker", + local_files_only=local_files_only, ) except Exception: raise ValueError( @@ -1651,10 +1754,12 @@ def download_from_original_stable_diffusion_ckpt( if load_safety_checker: safety_checker = StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + "CompVis/stable-diffusion-safety-checker", + local_files_only=local_files_only, ) feature_extractor = AutoFeatureExtractor.from_pretrained( - "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + "CompVis/stable-diffusion-safety-checker", + local_files_only=local_files_only, ) else: safety_checker = None @@ -1694,7 +1799,9 @@ def download_from_original_stable_diffusion_ckpt( text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) try: tokenizer_2 = CLIPTokenizer.from_pretrained( - "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", + pad_token="!", + local_files_only=local_files_only, ) except Exception: raise ValueError( @@ -1757,7 +1864,9 @@ def download_from_original_stable_diffusion_ckpt( text_encoder = None try: tokenizer_2 = CLIPTokenizer.from_pretrained( - "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", + pad_token="!", + local_files_only=local_files_only, ) except Exception: raise ValueError( @@ -1794,7 +1903,13 @@ def download_from_original_stable_diffusion_ckpt( text_config = create_ldm_bert_config(original_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", local_files_only=local_files_only) - pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + pipe = LDMTextToImagePipeline( + vqvae=vae, + bert=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) return pipe From ce19a090511df95c507c43750ccbb8df570bb552 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 4 Dec 2023 06:52:29 +0000 Subject: [PATCH 2/3] fix ldm conversion --- src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index f07fed2b6202..f8d1134c699b 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -456,7 +456,7 @@ def convert_ldm_unet_checkpoint( new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] # Relevant to StableDiffusionUpscalePipeline - if config["num_class_embeds"] is not None: + if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] From 97fc161b13ba0280ed49bea1f4630733960369ba Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 4 Dec 2023 15:44:55 +0000 Subject: [PATCH 3/3] fix linting --- .../stable_diffusion/convert_from_ckpt.py | 185 ++++-------------- 1 file changed, 35 insertions(+), 150 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index f8d1134c699b..6960ba6c4516 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -166,12 +166,7 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): def assign_to_checkpoint( - paths, - checkpoint, - old_checkpoint, - attention_paths_to_split=None, - additional_replacements=None, - config=None, + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None ): """ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits @@ -386,12 +381,7 @@ def create_ldm_bert_config(original_config): def convert_ldm_unet_checkpoint( - checkpoint, - config, - path=None, - extract_ema=False, - controlnet=False, - skip_extract_state_dict=False, + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False ): """ Takes a state dict and a config, and returns a converted checkpoint. @@ -507,31 +497,17 @@ def convert_ldm_unet_checkpoint( ) paths = renew_resnet_paths(resnets) - meta_path = { - "old": f"input_blocks.{i}.0", - "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}", - } + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) if len(attentions): paths = renew_attention_paths(attentions) - meta_path = { - "old": f"input_blocks.{i}.1", - "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}", - } + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) resnet_0 = middle_blocks[0] @@ -547,11 +523,7 @@ def convert_ldm_unet_checkpoint( attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} assign_to_checkpoint( - attentions_paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) for i in range(num_output_blocks): @@ -574,16 +546,9 @@ def convert_ldm_unet_checkpoint( resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) - meta_path = { - "old": f"output_blocks.{i}.0", - "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}", - } + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) output_block_list = {k: sorted(v) for k, v in output_block_list.items()} @@ -607,25 +572,13 @@ def convert_ldm_unet_checkpoint( "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) else: resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) for path in resnet_0_paths: old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join( - [ - "up_blocks", - str(block_id), - "resnets", - str(layer_in_block_id), - path["new"], - ] - ) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) new_checkpoint[new_path] = unet_state_dict[old_path] @@ -729,13 +682,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] num_mid_res_blocks = 2 @@ -744,24 +691,12 @@ def convert_ldm_vae_checkpoint(checkpoint, config): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) conv_attn_to_linear(new_checkpoint) for i in range(num_up_blocks): @@ -780,13 +715,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] num_mid_res_blocks = 2 @@ -795,24 +724,12 @@ def convert_ldm_vae_checkpoint(checkpoint, config): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) conv_attn_to_linear(new_checkpoint) return new_checkpoint @@ -887,10 +804,7 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder text_model_dict = {} - remove_prefixes = [ - "cond_stage_model.transformer", - "conditioner.embedders.0.transformer", - ] + remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] for key in keys: for prefix in remove_prefixes: @@ -927,14 +841,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder (".c_proj.", ".fc2."), (".attn", ".self_attn"), ("ln_final.", "transformer.text_model.final_layer_norm."), - ( - "token_embedding.weight", - "transformer.text_model.embeddings.token_embedding.weight", - ), - ( - "positional_embedding", - "transformer.text_model.embeddings.position_embedding.weight", - ), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), ] protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} textenc_pattern = re.compile("|".join(protected.keys())) @@ -1581,9 +1489,7 @@ def download_from_original_stable_diffusion_ckpt( try: tokenizer = CLIPTokenizer.from_pretrained( - "stabilityai/stable-diffusion-2", - subfolder="tokenizer", - local_files_only=local_files_only, + "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only ) except Exception: raise ValueError( @@ -1610,8 +1516,7 @@ def download_from_original_stable_diffusion_ckpt( "stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler" ) low_res_scheduler = DDPMScheduler.from_pretrained( - "stabilityai/stable-diffusion-x4-upscaler", - subfolder="low_res_scheduler", + "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler" ) pipe = pipeline_class( @@ -1639,10 +1544,9 @@ def download_from_original_stable_diffusion_ckpt( pipe.requires_safety_checker = False else: - ( - image_normalizer, - image_noising_scheduler, - ) = stable_unclip_image_noising_components(original_config, clip_stats_path=clip_stats_path, device=device) + image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( + original_config, clip_stats_path=clip_stats_path, device=device + ) if stable_unclip == "img2img": feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) @@ -1666,29 +1570,23 @@ def download_from_original_stable_diffusion_ckpt( if stable_unclip_prior is None or stable_unclip_prior == "karlo": karlo_model = "kakaobrain/karlo-v1-alpha" prior = PriorTransformer.from_pretrained( - karlo_model, - subfolder="prior", - local_files_only=local_files_only, + karlo_model, subfolder="prior", local_files_only=local_files_only ) try: prior_tokenizer = CLIPTokenizer.from_pretrained( - "openai/clip-vit-large-patch14", - local_files_only=local_files_only, + "openai/clip-vit-large-patch14", local_files_only=local_files_only ) except Exception: raise ValueError( f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." ) prior_text_model = CLIPTextModelWithProjection.from_pretrained( - "openai/clip-vit-large-patch14", - local_files_only=local_files_only, + "openai/clip-vit-large-patch14", local_files_only=local_files_only ) prior_scheduler = UnCLIPScheduler.from_pretrained( - karlo_model, - subfolder="prior_scheduler", - local_files_only=local_files_only, + karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only ) prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) else: @@ -1725,8 +1623,7 @@ def download_from_original_stable_diffusion_ckpt( ) try: feature_extractor = AutoFeatureExtractor.from_pretrained( - "CompVis/stable-diffusion-safety-checker", - local_files_only=local_files_only, + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only ) except Exception: raise ValueError( @@ -1757,12 +1654,10 @@ def download_from_original_stable_diffusion_ckpt( if load_safety_checker: safety_checker = StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker", - local_files_only=local_files_only, + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only ) feature_extractor = AutoFeatureExtractor.from_pretrained( - "CompVis/stable-diffusion-safety-checker", - local_files_only=local_files_only, + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only ) else: safety_checker = None @@ -1802,9 +1697,7 @@ def download_from_original_stable_diffusion_ckpt( text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) try: tokenizer_2 = CLIPTokenizer.from_pretrained( - "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", - pad_token="!", - local_files_only=local_files_only, + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only ) except Exception: raise ValueError( @@ -1867,9 +1760,7 @@ def download_from_original_stable_diffusion_ckpt( text_encoder = None try: tokenizer_2 = CLIPTokenizer.from_pretrained( - "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", - pad_token="!", - local_files_only=local_files_only, + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only ) except Exception: raise ValueError( @@ -1906,13 +1797,7 @@ def download_from_original_stable_diffusion_ckpt( text_config = create_ldm_bert_config(original_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", local_files_only=local_files_only) - pipe = LDMTextToImagePipeline( - vqvae=vae, - bert=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - ) + pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) return pipe