From 8ba623dd49914f1c9706c1a143abbee03479bdec Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Wed, 28 Feb 2024 06:22:19 +0000
Subject: [PATCH 1/2] update

---
 scripts/convert_animatediff_motion_module_to_diffusers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/convert_animatediff_motion_module_to_diffusers.py b/scripts/convert_animatediff_motion_module_to_diffusers.py
index 9c5d236fd713..60344c4417f0 100644
--- a/scripts/convert_animatediff_motion_module_to_diffusers.py
+++ b/scripts/convert_animatediff_motion_module_to_diffusers.py
@@ -48,4 +48,4 @@ def get_args():
     # skip loading position embeddings
     adapter.load_state_dict(conv_state_dict, strict=False)
     adapter.save_pretrained(args.output_path)
-    adapter.save_pretrained(args.output_path, variant="fp16", torch_dtype=torch.float16)
+    adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")


From 2edf45a01e07178a24ad727807c9a193fe226ea0 Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Wed, 28 Feb 2024 08:02:06 +0000
Subject: [PATCH 2/2] update

---
 scripts/convert_animatediff_motion_module_to_diffusers.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/scripts/convert_animatediff_motion_module_to_diffusers.py b/scripts/convert_animatediff_motion_module_to_diffusers.py
index 60344c4417f0..ceb967acd3d6 100644
--- a/scripts/convert_animatediff_motion_module_to_diffusers.py
+++ b/scripts/convert_animatediff_motion_module_to_diffusers.py
@@ -30,6 +30,7 @@ def get_args():
     parser.add_argument("--output_path", type=str, required=True)
     parser.add_argument("--use_motion_mid_block", action="store_true")
     parser.add_argument("--motion_max_seq_length", type=int, default=32)
+    parser.add_argument("--save_fp16", action="store_true")
 
     return parser.parse_args()
 
@@ -48,4 +49,6 @@ def get_args():
     # skip loading position embeddings
     adapter.load_state_dict(conv_state_dict, strict=False)
     adapter.save_pretrained(args.output_path)
-    adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")
+
+    if args.save_fp16:
+        adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")
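
Usage notes: patch 1 fixes the fp16 export by casting the adapter with .to(torch.float16) before saving, rather than passing torch_dtype to save_pretrained (which is not a valid argument for that method); patch 2 then gates the fp16 export behind a new --save_fp16 flag, so full-precision weights are always written and the fp16 variant only on request. A minimal invocation sketch follows; the checkpoint flag name (--ckpt_path) is assumed here, since that argument is defined outside the visible hunks, and the paths are hypothetical placeholders.

    # Writes fp32 weights to --output_path, plus an fp16 variant because --save_fp16 is passed.
    python scripts/convert_animatediff_motion_module_to_diffusers.py \
        --ckpt_path ./mm_motion_module.ckpt \
        --output_path ./converted-motion-adapter \
        --save_fp16

The fp16 weights saved this way should be loadable through the standard diffusers variant mechanism, roughly as sketched below (the output directory is again a placeholder):

    import torch
    from diffusers import MotionAdapter

    # Loads the fp16 variant written by save_pretrained(..., variant="fp16")
    adapter = MotionAdapter.from_pretrained(
        "./converted-motion-adapter", variant="fp16", torch_dtype=torch.float16
    )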