diff --git a/src/diffusers/models/unets/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py
index 39b483b61c00..a5ec2875ca0e 100644
--- a/src/diffusers/models/unets/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unets/unet_2d_condition_flax.py
@@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             The tuple of downsample blocks to use.
         up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
             The tuple of upsample blocks to use.
+        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+            Block type for the middle of the UNet; it can only be `"UNetMidBlock2DCrossAttn"`. If `None`, the mid block layer is skipped.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         "DownBlock2D",
     )
     up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
     only_cross_attention: Union[bool, Tuple[bool]] = False
     block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
@@ -252,16 +255,21 @@ def setup(self) -> None:
         self.down_blocks = down_blocks
 
         # mid
-        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
-            in_channels=block_out_channels[-1],
-            dropout=self.dropout,
-            num_attention_heads=num_attention_heads[-1],
-            transformer_layers_per_block=transformer_layers_per_block[-1],
-            use_linear_projection=self.use_linear_projection,
-            use_memory_efficient_attention=self.use_memory_efficient_attention,
-            split_head_dim=self.split_head_dim,
-            dtype=self.dtype,
-        )
+        if self.config.mid_block_type == "UNetMidBlock2DCrossAttn":
+            self.mid_block = FlaxUNetMidBlock2DCrossAttn(
+                in_channels=block_out_channels[-1],
+                dropout=self.dropout,
+                num_attention_heads=num_attention_heads[-1],
+                transformer_layers_per_block=transformer_layers_per_block[-1],
+                use_linear_projection=self.use_linear_projection,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
+                split_head_dim=self.split_head_dim,
+                dtype=self.dtype,
+            )
+        elif self.config.mid_block_type is None:
+            self.mid_block = None
+        else:
+            raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}")
 
         # up
         up_blocks = []
@@ -412,7 +420,8 @@ def __call__(
         down_block_res_samples = new_down_block_res_samples
 
         # 4. mid
-        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+        if self.mid_block is not None:
+            sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
 
         if mid_block_additional_residual is not None:
             sample += mid_block_additional_residual
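A minimal sketch of how the new `mid_block_type=None` path could be exercised end to end. All sizes, block choices, and the 77-token sequence length below are illustrative assumptions, not values taken from this patch; only the `mid_block_type` keyword relates to the change above.

```python
import jax
import jax.numpy as jnp

from diffusers import FlaxUNet2DConditionModel

# Tiny, illustrative configuration; only mid_block_type=None is relevant to this PR.
unet = FlaxUNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    attention_head_dim=8,
    cross_attention_dim=64,
    mid_block_type=None,  # new config field: skip the mid block entirely
)

# With mid_block_type=None, setup() assigns self.mid_block = None, so the
# parameter tree should contain no FlaxUNetMidBlock2DCrossAttn weights.
params = unet.init_weights(jax.random.PRNGKey(0))
assert "mid_block" not in params

sample = jnp.zeros((1, 4, 32, 32), dtype=jnp.float32)
timesteps = jnp.ones((1,), dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 64), dtype=jnp.float32)

# __call__ now skips the mid block and goes straight from the last down block
# to the up blocks (plus the optional mid_block_additional_residual).
out = unet.apply({"params": params}, sample, timesteps, encoder_hidden_states)
print(out.sample.shape)  # (1, 4, 32, 32)
```

The keyword name and the `None` semantics match the existing `mid_block_type` config field of the PyTorch `UNet2DConditionModel`, though the Flax model only implements the `"UNetMidBlock2DCrossAttn"` variant.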