From 03f101e906aa5eaf1e32aeb9a56d70d6cc0ce41a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 24 Apr 2024 07:42:12 +0000 Subject: [PATCH 1/2] update --- scripts/convert_animatediff_motion_lora_to_diffusers.py | 8 ++++++-- scripts/convert_animatediff_motion_module_to_diffusers.py | 8 +++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/scripts/convert_animatediff_motion_lora_to_diffusers.py b/scripts/convert_animatediff_motion_lora_to_diffusers.py index 509a7345793c..90d62126c76c 100644 --- a/scripts/convert_animatediff_motion_lora_to_diffusers.py +++ b/scripts/convert_animatediff_motion_lora_to_diffusers.py @@ -1,7 +1,7 @@ import argparse import torch -from safetensors.torch import save_file +from safetensors.torch import load_file, save_file def convert_motion_module(original_state_dict): @@ -27,6 +27,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--ckpt_path", type=str, required=True) parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--use_safetensors", type=str, action="store_true") return parser.parse_args() @@ -34,7 +35,10 @@ def get_args(): if __name__ == "__main__": args = get_args() - state_dict = torch.load(args.ckpt_path, map_location="cpu") + if args.use_safetensors: + state_dict = load_file(args.ckpt_path) + else: + state_dict = torch.load(args.ckpt_path, map_location="cpu") if "state_dict" in state_dict.keys(): state_dict = state_dict["state_dict"] diff --git a/scripts/convert_animatediff_motion_module_to_diffusers.py b/scripts/convert_animatediff_motion_module_to_diffusers.py index ceb967acd3d6..9a589db17aa6 100644 --- a/scripts/convert_animatediff_motion_module_to_diffusers.py +++ b/scripts/convert_animatediff_motion_module_to_diffusers.py @@ -1,6 +1,7 @@ import argparse import torch +from safetensors.torch import load_file from diffusers import MotionAdapter @@ -31,6 +32,7 @@ def get_args(): parser.add_argument("--use_motion_mid_block", action="store_true") parser.add_argument("--motion_max_seq_length", type=int, default=32) parser.add_argument("--save_fp16", action="store_true") + parser.add_argument("--use_safetensors", action="store_true") return parser.parse_args() @@ -38,7 +40,11 @@ def get_args(): if __name__ == "__main__": args = get_args() - state_dict = torch.load(args.ckpt_path, map_location="cpu") + if args.use_safetensors: + state_dict = load_file(args.ckpt_path) + else: + state_dict = torch.load(args.ckpt_path, map_location="cpu") + if "state_dict" in state_dict.keys(): state_dict = state_dict["state_dict"] From f20ff26950d5965a78bb4258bc7ce35b88e4108a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 29 Apr 2024 11:50:54 +0000 Subject: [PATCH 2/2] update --- scripts/convert_animatediff_motion_lora_to_diffusers.py | 3 +-- scripts/convert_animatediff_motion_module_to_diffusers.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/scripts/convert_animatediff_motion_lora_to_diffusers.py b/scripts/convert_animatediff_motion_lora_to_diffusers.py index 90d62126c76c..c680fdc68462 100644 --- a/scripts/convert_animatediff_motion_lora_to_diffusers.py +++ b/scripts/convert_animatediff_motion_lora_to_diffusers.py @@ -27,7 +27,6 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--ckpt_path", type=str, required=True) parser.add_argument("--output_path", type=str, required=True) - parser.add_argument("--use_safetensors", type=str, action="store_true") return parser.parse_args() @@ -35,7 +34,7 @@ def get_args(): if __name__ == "__main__": args = get_args() - if args.use_safetensors: + if args.ckpt_path.endswith(".safetensors"): state_dict = load_file(args.ckpt_path) else: state_dict = torch.load(args.ckpt_path, map_location="cpu") diff --git a/scripts/convert_animatediff_motion_module_to_diffusers.py b/scripts/convert_animatediff_motion_module_to_diffusers.py index 9a589db17aa6..e8fb007243fd 100644 --- a/scripts/convert_animatediff_motion_module_to_diffusers.py +++ b/scripts/convert_animatediff_motion_module_to_diffusers.py @@ -32,7 +32,6 @@ def get_args(): parser.add_argument("--use_motion_mid_block", action="store_true") parser.add_argument("--motion_max_seq_length", type=int, default=32) parser.add_argument("--save_fp16", action="store_true") - parser.add_argument("--use_safetensors", action="store_true") return parser.parse_args() @@ -40,7 +39,7 @@ def get_args(): if __name__ == "__main__": args = get_args() - if args.use_safetensors: + if args.ckpt_path.endswith(".safetensors"): state_dict = load_file(args.ckpt_path) else: state_dict = torch.load(args.ckpt_path, map_location="cpu")