diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 9cb0f42c85ef..246a4b8124d8 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -217,6 +217,7 @@ def __init__( use_motion_mid_block: int = True, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, ): super().__init__() @@ -252,9 +253,7 @@ def __init__( timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, + timestep_input_dim, time_embed_dim, act_fn=act_fn, cond_proj_dim=time_cond_proj_dim ) if encoder_hid_dim_type is None: @@ -306,6 +305,7 @@ def __init__( num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, + use_linear_projection=use_linear_projection, temporal_num_attention_heads=motion_num_attention_heads, temporal_max_seq_length=motion_max_seq_length, ) @@ -321,6 +321,7 @@ def __init__( num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, + use_linear_projection=use_linear_projection, ) # count how many layers upsample the images