Skip to content

Commit f20ff26

Browse files
committed
update
1 parent 03f101e commit f20ff26

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

scripts/convert_animatediff_motion_lora_to_diffusers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@ def get_args():
2727
parser = argparse.ArgumentParser()
2828
parser.add_argument("--ckpt_path", type=str, required=True)
2929
parser.add_argument("--output_path", type=str, required=True)
30-
parser.add_argument("--use_safetensors", type=str, action="store_true")
3130

3231
return parser.parse_args()
3332

3433

3534
if __name__ == "__main__":
3635
args = get_args()
3736

38-
if args.use_safetensors:
37+
if args.ckpt_path.endswith(".safetensors"):
3938
state_dict = load_file(args.ckpt_path)
4039
else:
4140
state_dict = torch.load(args.ckpt_path, map_location="cpu")

scripts/convert_animatediff_motion_module_to_diffusers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,14 @@ def get_args():
3232
parser.add_argument("--use_motion_mid_block", action="store_true")
3333
parser.add_argument("--motion_max_seq_length", type=int, default=32)
3434
parser.add_argument("--save_fp16", action="store_true")
35-
parser.add_argument("--use_safetensors", action="store_true")
3635

3736
return parser.parse_args()
3837

3938

4039
if __name__ == "__main__":
4140
args = get_args()
4241

43-
if args.use_safetensors:
42+
if args.ckpt_path.endswith(".safetensors"):
4443
state_dict = load_file(args.ckpt_path)
4544
else:
4645
state_dict = torch.load(args.ckpt_path, map_location="cpu")

0 commit comments

Comments
 (0)