|
74 | 74 | "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
|
75 | 75 | "stable_cascade_stage_c": "clip_txt_mapper.weight",
|
76 | 76 | "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
| 77 | + "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.norm.weight", |
| 78 | + "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias", |
| 79 | + "animatediff_sdxl_beta": "down_blocks.3.motion_modules.0.temporal_transformer.norm.bias", |
77 | 80 | }
|
78 | 81 |
|
79 | 82 | DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
|
103 | 106 | "sd3": {
|
104 | 107 | "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
|
105 | 108 | },
|
| 109 | + "animatediff_v2": "guoyww/animatediff-motion-adapter-v1-5-2", |
| 110 | + "animatediff_v3": "guoyww/animatediff-motion-adapter-v1-5-3", |
| 111 | + "animatediff_sdxl": "guoyww/animatediff-motion-adapter-sdxl-beta", |
106 | 112 | }
|
107 | 113 |
|
108 | 114 | # Use to configure model sample size when original config is provided
|
@@ -485,6 +491,13 @@ def infer_diffusers_model_type(checkpoint):
|
485 | 491 | elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
|
486 | 492 | model_type = "sd3"
|
487 | 493 |
|
| 494 | + elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint: |
| 495 | + if CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint: |
| 496 | + model_type = "animatediff_v2" |
| 497 | + elif CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"] in checkpoint: |
| 498 | + model_type = "animatediff_sdxl_beta" |
| 499 | + else: |
| 500 | + model_type = "animatediff_v3" |
488 | 501 | else:
|
489 | 502 | model_type = "v1"
|
490 | 503 |
|
@@ -1822,3 +1835,22 @@ def create_diffusers_t5_model_from_checkpoint(
|
1822 | 1835 | param.data = param.data.to(torch.float32)
|
1823 | 1836 |
|
1824 | 1837 | return model
|
| 1838 | + |
| 1839 | + |
def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Return a copy of ``checkpoint`` with its keys renamed.

    Every key that contains ``"pos_encoder"`` is left out of the result. All
    other keys are rewritten by applying the substring substitutions listed
    below, in order; the associated values are carried over unchanged.

    Args:
        checkpoint: Mapping from string keys to values.
        **kwargs: Accepted but not used by this function.

    Returns:
        A new ``dict`` holding the renamed keys. ``checkpoint`` itself is not
        modified.
    """
    # Order matters: substitutions are applied one after another, each on the
    # output of the previous one.
    substitutions = (
        (".norms.0", ".norm1"),
        (".norms.1", ".norm2"),
        (".ff_norm", ".norm3"),
        (".attention_blocks.0", ".attn1"),
        (".attention_blocks.1", ".attn2"),
        (".temporal_transformer", ""),
    )

    renamed = {}
    for source_key, value in checkpoint.items():
        # Entries whose key mentions "pos_encoder" are dropped entirely.
        if "pos_encoder" in source_key:
            continue

        target_key = source_key
        for old_fragment, new_fragment in substitutions:
            target_key = target_key.replace(old_fragment, new_fragment)
        renamed[target_key] = value

    return renamed
0 commit comments