From eaf725ac577c587f28a5edcfc35080bba1680a69 Mon Sep 17 00:00:00 2001 From: Hamish Friedlander Date: Fri, 30 Dec 2022 21:09:30 +1300 Subject: [PATCH] Make safetensors properly optional, and support storing textural inversion embeddings (#101) --- lora_diffusion/cli_pt_to_safetensors.py | 85 ++++++++++++++ lora_diffusion/lora.py | 146 +++++++++++++++++++----- lora_diffusion/safe_open.py | 68 +++++++++++ train_lora_dreambooth.py | 54 ++++++--- train_lora_w_ti.py | 52 +++++++-- 5 files changed, 352 insertions(+), 53 deletions(-) create mode 100644 lora_diffusion/cli_pt_to_safetensors.py create mode 100644 lora_diffusion/safe_open.py diff --git a/lora_diffusion/cli_pt_to_safetensors.py b/lora_diffusion/cli_pt_to_safetensors.py new file mode 100644 index 0000000..e9047b5 --- /dev/null +++ b/lora_diffusion/cli_pt_to_safetensors.py @@ -0,0 +1,85 @@ +import os + +import fire +import torch +from lora_diffusion import ( + DEFAULT_TARGET_REPLACE, + TEXT_ENCODER_DEFAULT_TARGET_REPLACE, + UNET_DEFAULT_TARGET_REPLACE, + convert_loras_to_safeloras_with_embeds, + safetensors_available, +) + +_target_by_name = { + "unet": UNET_DEFAULT_TARGET_REPLACE, + "text_encoder": TEXT_ENCODER_DEFAULT_TARGET_REPLACE, +} + + +def convert(*paths, outpath, overwrite=False, **settings): + """ + Converts one or more pytorch Lora and/or Textual Embedding pytorch files + into a safetensor file. + + Pass all the input paths as arguments. Whether they are Textual Embedding + or Lora models will be auto-detected. + + For Lora models, their name will be taken from the path, i.e. + "lora_weight.pt" => unet + "lora_weight.text_encoder.pt" => text_encoder + + You can also set target_modules and/or rank by providing an argument prefixed + by the name. + + So a complete example might be something like: + + ``` + python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8 + ``` + """ + modelmap = {} + embeds = {} + + if os.path.exists(outpath) and not overwrite: + raise ValueError( + f"Output path {outpath} already exists, and overwrite is not True" + ) + + for path in paths: + data = torch.load(path) + + if isinstance(data, dict): + print(f"Loading textual inversion embeds {data.keys()} from {path}") + embeds.update(data) + + else: + name_parts = os.path.split(path)[1].split(".") + name = name_parts[-2] if len(name_parts) > 2 else "unet" + + model_settings = { + "target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE), + "rank": 4, + } + + prefix = f"{name}." 
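+            # Merge in any command-line overrides given with this model's prefix:
+            # per the docstring above, "--unet.rank 8" is expected to arrive in
+            # `settings` as {"unet.rank": 8} and is stripped down to {"rank": 8} here.
+            # (The dict "|=" merge below requires Python 3.9+.)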
+ model_settings |= { + k[len(prefix) :]: v for k, v in settings.items() if k.startswith(prefix) + } + + print(f"Loading Lora for {name} from {path} with settings {model_settings}") + + modelmap[name] = ( + path, + model_settings["target_modules"], + model_settings["rank"], + ) + + convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath) + + +def main(): + fire.Fire(convert) + + +if __name__ == "__main__": + main() diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py index dc5a944..8bbd970 100644 --- a/lora_diffusion/lora.py +++ b/lora_diffusion/lora.py @@ -1,3 +1,4 @@ +import json import math from itertools import groupby from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union @@ -8,6 +9,25 @@ import torch.nn as nn import torch.nn.functional as F +try: + from safetensors.torch import safe_open + from safetensors.torch import save_file as safe_save + + safetensors_available = True +except ImportError: + from .safe_open import safe_open + + def safe_save( + tensors: Dict[str, torch.Tensor], + filename: str, + metadata: Optional[Dict[str, str]] = None, + ) -> None: + raise EnvironmentError( + "Saving safetensors requires the safetensors library. Please install with pip or similar." + ) + + safetensors_available = False + class LoraInjectedLinear(nn.Module): def __init__(self, in_features, out_features, bias=False, r=4): @@ -35,6 +55,8 @@ def forward(self, input): DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE +EMBED_FLAG = "" + def _find_children( model, @@ -203,8 +225,9 @@ def save_lora_as_json(model, path="./lora.json"): json.dump(weights, f) -def save_safeloras( +def save_safeloras_with_embeds( modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {}, + embeds: Dict[str, torch.Tensor] = {}, outpath="./lora.safetensors", ): """ @@ -217,10 +240,6 @@ def save_safeloras( weights = {} metadata = {} - import json - - from safetensors.torch import save_file - for name, (model, target_replace_module) in modelmap.items(): metadata[name] = json.dumps(list(target_replace_module)) @@ -231,12 +250,24 @@ def save_safeloras( weights[f"{name}:{i}:up"] = _up.weight weights[f"{name}:{i}:down"] = _down.weight - print(f"Saving weights to {outpath} with metadata", metadata) - save_file(weights, outpath, metadata) + for token, tensor in embeds.items(): + metadata[token] = EMBED_FLAG + weights[token] = tensor + print(f"Saving weights to {outpath}") + safe_save(weights, outpath, metadata) -def convert_loras_to_safeloras( + +def save_safeloras( + modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {}, + outpath="./lora.safetensors", +): + return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) + + +def convert_loras_to_safeloras_with_embeds( modelmap: Dict[str, Tuple[str, Set[str], int]] = {}, + embeds: Dict[str, torch.Tensor] = {}, outpath="./lora.safetensors", ): """ @@ -250,10 +281,6 @@ def convert_loras_to_safeloras( weights = {} metadata = {} - import json - - from safetensors.torch import save_file - for name, (path, target_replace_module, r) in modelmap.items(): metadata[name] = json.dumps(list(target_replace_module)) @@ -268,8 +295,19 @@ def convert_loras_to_safeloras( else: weights[f"{name}:{i}:down"] = weight - print(f"Saving weights to {outpath} with metadata", metadata) - save_file(weights, outpath, metadata) + for token, tensor in embeds.items(): + metadata[token] = EMBED_FLAG + weights[token] = tensor + + print(f"Saving weights to {outpath}") + safe_save(weights, outpath, metadata) + + +def convert_loras_to_safeloras( + modelmap: Dict[str, Tuple[str, 
Set[str], int]] = {}, + outpath="./lora.safetensors", +): + convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) def parse_safeloras( @@ -288,9 +326,6 @@ def parse_safeloras( } """ loras = {} - - import json - metadata = safeloras.metadata() get_name = lambda k: k.split(":")[0] @@ -299,12 +334,24 @@ def parse_safeloras( keys.sort(key=get_name) for name, module_keys in groupby(keys, get_name): + info = metadata.get(name) + + if not info: + raise ValueError( + f"Tensor {name} has no metadata - is this a Lora safetensor?" + ) + + # Skip Textual Inversion embeds + if info == EMBED_FLAG: + continue + + # Handle Loras # Extract the targets - target = json.loads(metadata[name]) + target = json.loads(info) # Build the result lists - Python needs us to preallocate lists to insert into them module_keys = list(module_keys) - ranks = [None] * (len(module_keys) // 2) + ranks = [4] * (len(module_keys) // 2) weights = [None] * len(module_keys) for key in module_keys: @@ -313,7 +360,7 @@ def parse_safeloras( idx = int(idx) # Add the rank - ranks[idx] = json.loads(metadata[f"{name}:{idx}:rank"]) + ranks[idx] = int(metadata[f"{name}:{idx}:rank"]) # Insert the weight into the list idx = idx * 2 + (1 if direction == "down" else 0) @@ -324,14 +371,42 @@ def parse_safeloras( return loras -def load_safeloras(path, device="cpu"): +def parse_safeloras_embeds( + safeloras, +) -> Dict[str, torch.Tensor]: + """ + Converts a loaded safetensor file that contains Textual Inversion embeds into + a dictionary of embed_token: Tensor + """ + embeds = {} + metadata = safeloras.metadata() - from safetensors.torch import safe_open + for key in safeloras.keys(): + # Only handle Textual Inversion embeds + meta = metadata.get(key) + if not meta or meta != EMBED_FLAG: + continue + + embeds[key] = safeloras.get_tensor(key) + + return embeds + +def load_safeloras(path, device="cpu"): safeloras = safe_open(path, framework="pt", device=device) return parse_safeloras(safeloras) +def load_safeloras_embeds(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras_embeds(safeloras) + + +def load_safeloras_both(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras) + + def weight_apply_lora( model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, alpha=1.0 ): @@ -535,28 +610,26 @@ def _ti_lora_path(path: str) -> str: return ".".join(path.split(".")[:-1] + ["ti", "pt"]) -def load_learned_embed_in_clip( - learned_embeds_path, +def apply_learned_embed_in_clip( + learned_embeds, text_encoder, tokenizer, - token: Union[str, List[str]] = None, + token: Optional[Union[str, List[str]]] = None, idempotent=False, ): - loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") - if isinstance(token, str): trained_tokens = [token] elif isinstance(token, list): - assert len(loaded_learned_embeds.keys()) == len( + assert len(learned_embeds.keys()) == len( token ), "The number of tokens and the number of embeds should be the same" trained_tokens = token else: - trained_tokens = list(loaded_learned_embeds.keys()) + trained_tokens = list(learned_embeds.keys()) for token in trained_tokens: print(token) - embeds = loaded_learned_embeds[token] + embeds = learned_embeds[token] # cast to dtype of text_encoder dtype = text_encoder.get_input_embeddings().weight.dtype @@ -583,6 +656,19 @@ def load_learned_embed_in_clip( return token +def load_learned_embed_in_clip( + 
learned_embeds_path, + text_encoder, + tokenizer, + token: Optional[Union[str, List[str]]] = None, + idempotent=False, +): + learned_embeds = torch.load(learned_embeds_path) + apply_learned_embed_in_clip( + learned_embeds, text_encoder, tokenizer, token, idempotent + ) + + def patch_pipe( pipe, unet_path, diff --git a/lora_diffusion/safe_open.py b/lora_diffusion/safe_open.py new file mode 100644 index 0000000..77ada82 --- /dev/null +++ b/lora_diffusion/safe_open.py @@ -0,0 +1,68 @@ +""" +Pure python version of Safetensors safe_open +From https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282 +""" + +import json +import mmap +import os + +import torch + + +class SafetensorsWrapper: + def __init__(self, metadata, tensors): + self._metadata = metadata + self._tensors = tensors + + def metadata(self): + return self._metadata + + def keys(self): + return self._tensors.keys() + + def get_tensor(self, k): + return self._tensors[k] + + +DTYPES = { + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, +} + + +def create_tensor(storage, info, offset): + dtype = DTYPES[info["dtype"]] + shape = info["shape"] + start, stop = info["data_offsets"] + return ( + torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8) + .view(dtype=dtype) + .reshape(shape) + ) + + +def safe_open(filename, framework="pt", device="cpu"): + if framework != "pt": + raise ValueError("`framework` must be 'pt'") + + with open(filename, mode="r", encoding="utf8") as file_obj: + with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m: + header = m.read(8) + n = int.from_bytes(header, "little") + metadata_bytes = m.read(n) + metadata = json.loads(metadata_bytes) + + size = os.stat(filename).st_size + storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped() + offset = n + 8 + + return SafetensorsWrapper( + metadata=metadata.get("__metadata__", {}), + tensors={ + name: create_tensor(storage, info, offset).to(device) + for name, info in metadata.items() + if name != "__metadata__" + }, + ) diff --git a/train_lora_dreambooth.py b/train_lora_dreambooth.py index 6a29a5c..aadfdaf 100644 --- a/train_lora_dreambooth.py +++ b/train_lora_dreambooth.py @@ -33,6 +33,7 @@ from lora_diffusion import ( extract_lora_ups_down, inject_trainable_lora, + safetensors_available, save_lora_weight, save_safeloras, ) @@ -263,6 +264,13 @@ def parse_args(input_args=None): default="text-inversion-model", help="The output directory where the model predictions and checkpoints will be written.", ) + parser.add_argument( + "--output_format", + type=str, + choices=["pt", "safe", "both"], + default="both", + help="The output format of the model predicitions and checkpoints.", + ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) @@ -478,6 +486,17 @@ def parse_args(input_args=None): "You need not use --class_prompt without --with_prior_preservation." ) + if not safetensors_available: + if args.output_format == "both": + print( + "Safetensors is not available - changing output format to just output PyTorch files" + ) + args.output_format = "pt" + elif args.output_format == "safe": + raise ValueError( + "Safetensors is not available - either install it, or change output_format." 
+ ) + return args @@ -974,27 +993,30 @@ def collate_fn(examples): print("\n\nLora TRAINING DONE!\n\n") - save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt") - if args.train_text_encoder: - save_lora_weight( - pipeline.text_encoder, - args.output_dir + "/lora_weight.text_encoder.pt", - target_replace_module=["CLIPAttention"], - ) + if args.output_format == "pt" or args.output_format == "both": + save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt") + if args.train_text_encoder: + save_lora_weight( + pipeline.text_encoder, + args.output_dir + "/lora_weight.text_encoder.pt", + target_replace_module=["CLIPAttention"], + ) + + if args.output_format == "safe" or args.output_format == "both": + loras = {} + loras["unet"] = (pipeline.unet, {"CrossAttention", "Attention", "GEGLU"}) + if args.train_text_encoder: + loras["text_encoder"] = (pipeline.text_encoder, {"CLIPAttention"}) + + save_safeloras(loras, args.output_dir + "/lora_weight.safetensors") if args.push_to_hub: repo.push_to_hub( - commit_message="End of training", blocking=False, auto_lfs_prune=True + commit_message="End of training", + blocking=False, + auto_lfs_prune=True, ) - save_safeloras( - { - "unet": (pipeline.unet, {"CrossAttention", "Attention", "GEGLU"}), - "text_encoder": (pipeline.text_encoder, {"CLIPAttention"}), - }, - args.output_dir + "/lora_weight.safetensors", - ) - accelerator.end_training() diff --git a/train_lora_w_ti.py b/train_lora_w_ti.py index 7e9eb81..8f63280 100644 --- a/train_lora_w_ti.py +++ b/train_lora_w_ti.py @@ -32,9 +32,11 @@ from transformers import CLIPTextModel, CLIPTokenizer from lora_diffusion import ( + extract_lora_ups_down, inject_trainable_lora, + safetensors_available, save_lora_weight, - extract_lora_ups_down, + save_safeloras_with_embeds, ) from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers from PIL import Image @@ -369,6 +371,13 @@ def parse_args(input_args=None): default="text-inversion-model", help="The output directory where the model predictions and checkpoints will be written.", ) + parser.add_argument( + "--output_format", + type=str, + choices=["pt", "safe", "both"], + default="both", + help="The output format of the model predicitions and checkpoints.", + ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) @@ -603,6 +612,17 @@ def parse_args(input_args=None): "You need not use --class_prompt without --with_prior_preservation." ) + if not safetensors_available: + if args.output_format == "both": + print( + "Safetensors is not available - changing output format to just output PyTorch files" + ) + args.output_format = "pt" + elif args.output_format == "safe": + raise ValueError( + "Safetensors is not available - either install it, or change output_format." 
+ ) + return args @@ -1155,13 +1175,31 @@ def collate_fn(examples): print("\n\nLora TRAINING DONE!\n\n") - save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt") + if args.output_format == "pt" or args.output_format == "both": + save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt") - save_lora_weight( - text_encoder, - args.output_dir + "/lora_weight.text_encoder.pt", - target_replace_module=["CLIPAttention"], - ) + save_lora_weight( + text_encoder, + args.output_dir + "/lora_weight.text_encoder.pt", + target_replace_module=["CLIPAttention"], + ) + + if args.output_format == "safe" or args.output_format == "both": + loras = {} + loras["unet"] = (pipeline.unet, {"CrossAttention", "Attention", "GEGLU"}) + loras["text_encoder"] = (pipeline.text_encoder, {"CLIPAttention"}) + + learned_embeds = ( + accelerator.unwrap_model(text_encoder) + .get_input_embeddings() + .weight[placeholder_token_id] + ) + + embeds = {args.placeholder_token: learned_embeds.detach().cpu()} + + save_safeloras_with_embeds( + loras, embeds, args.output_dir + "/lora_weight.safetensors" + ) accelerator.end_training()
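A minimal, untested sketch of how the combined .safetensors output written above could be consumed downstream. The base checkpoint name, device, and the direct `lora_diffusion.lora` import path are placeholders/assumptions, and the final monkeypatching of the Lora weights into the pipeline (unchanged by this patch) is left out.

```python
# Hedged usage sketch only - loads the Loras and Textual Inversion embeds that
# train_lora_w_ti.py (or cli_pt_to_safetensors) stores in a single safetensors file.
import torch
from diffusers import StableDiffusionPipeline

# Imported from the module directly in case these helpers are not re-exported
# at package level.
from lora_diffusion.lora import apply_learned_embed_in_clip, load_safeloras_both

# Placeholder base model and device - use whatever the Lora was trained against.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

loras, embeds = load_safeloras_both("lora_weight.safetensors")

# Register the stored Textual Inversion tokens with the tokenizer / text encoder.
apply_learned_embed_in_clip(embeds, pipe.text_encoder, pipe.tokenizer, idempotent=True)

# `loras` maps model name ("unet", "text_encoder") to (weights, ranks, target modules),
# ready to hand to the existing Lora monkeypatching helpers.
for name, (weights, ranks, target_modules) in loras.items():
    print(name, ranks, sorted(target_modules))
```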