From 7b6c0b6c6bf1e8968db99edcf34dba96620aec80 Mon Sep 17 00:00:00 2001 From: Ernie Chu <51432514+ernestchu@users.noreply.github.com> Date: Thu, 6 Jun 2024 00:26:34 +0800 Subject: [PATCH 1/2] add conv_in_channels in def save_motion_modules current implementation doesn't pass conv_in_channels to the MotionAdapter constructor. This would incorrectly result into "conv_in_channels": null in the saved config.json, disregarding the actual conv_in_channels. --- src/diffusers/models/unets/unet_motion_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index b224d9d73317..7c1aedd23ebf 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -613,6 +613,7 @@ def save_motion_modules( motion_num_attention_heads=self.config["motion_num_attention_heads"], motion_max_seq_length=self.config["motion_max_seq_length"], use_motion_mid_block=self.config["use_motion_mid_block"], + conv_in_channels=self.config["conv_in_channels"], ) adapter.load_state_dict(motion_state_dict) adapter.save_pretrained( From af4b3921fcf95cfde8d00ef297d21de9ccb6040f Mon Sep 17 00:00:00 2001 From: Ernie Chu <51432514+ernestchu@users.noreply.github.com> Date: Thu, 6 Jun 2024 00:47:27 +0800 Subject: [PATCH 2/2] fix typo --- src/diffusers/models/unets/unet_motion_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 7c1aedd23ebf..c8631dcafe97 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -613,7 +613,7 @@ def save_motion_modules( motion_num_attention_heads=self.config["motion_num_attention_heads"], motion_max_seq_length=self.config["motion_max_seq_length"], use_motion_mid_block=self.config["use_motion_mid_block"], - conv_in_channels=self.config["conv_in_channels"], + conv_in_channels=self.config["in_channels"], ) adapter.load_state_dict(motion_state_dict) adapter.save_pretrained(