From ec4d3d6f72a2c37d26f4f1b38016ca30113a8f26 Mon Sep 17 00:00:00 2001
From: Martin Muller
Date: Sat, 24 Feb 2024 22:17:54 -0500
Subject: [PATCH 1/2] make mid block optional for flax UNet

---
 .../models/unets/unet_2d_condition_flax.py | 31 ++++++++++++-------
 1 file changed, 20 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/models/unets/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py
index 39b483b61c00..4382228c57a3 100644
--- a/src/diffusers/models/unets/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unets/unet_2d_condition_flax.py
@@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             The tuple of downsample blocks to use.
         up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
             The tuple of upsample blocks to use.
+        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+            Block type for the middle of the UNet; it can only be `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         "DownBlock2D",
     )
     up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
     only_cross_attention: Union[bool, Tuple[bool]] = False
     block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
@@ -252,16 +255,21 @@ def setup(self) -> None:
         self.down_blocks = down_blocks
 
         # mid
-        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
-            in_channels=block_out_channels[-1],
-            dropout=self.dropout,
-            num_attention_heads=num_attention_heads[-1],
-            transformer_layers_per_block=transformer_layers_per_block[-1],
-            use_linear_projection=self.use_linear_projection,
-            use_memory_efficient_attention=self.use_memory_efficient_attention,
-            split_head_dim=self.split_head_dim,
-            dtype=self.dtype,
-        )
+        if self.config.mid_block_type == "UNetMidBlock2DCrossAttn":
+            self.mid_block = FlaxUNetMidBlock2DCrossAttn(
+                in_channels=block_out_channels[-1],
+                dropout=self.dropout,
+                num_attention_heads=num_attention_heads[-1],
+                transformer_layers_per_block=transformer_layers_per_block[-1],
+                use_linear_projection=self.use_linear_projection,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
+                split_head_dim=self.split_head_dim,
+                dtype=self.dtype,
+            )
+        elif self.config.mid_block_type is None:
+            self.mid_block = None
+        else:
+            raise ValueError(f'Unexpected mid_block_type {self.config.mid_block_type}')
 
         # up
         up_blocks = []
@@ -412,7 +420,8 @@ def __call__(
             down_block_res_samples = new_down_block_res_samples
 
         # 4. mid
-        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+        if self.mid_block is not None:
+            sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
 
         if mid_block_additional_residual is not None:
             sample += mid_block_additional_residual

From d8b904ab859aff76b229d8414053889f7c143d1a Mon Sep 17 00:00:00 2001
From: Martin Muller
Date: Thu, 7 Mar 2024 09:31:49 -0500
Subject: [PATCH 2/2] make style

---
 src/diffusers/models/unets/unet_2d_condition_flax.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/unets/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py
index 4382228c57a3..a5ec2875ca0e 100644
--- a/src/diffusers/models/unets/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unets/unet_2d_condition_flax.py
@@ -269,7 +269,7 @@ def setup(self) -> None:
         elif self.config.mid_block_type is None:
             self.mid_block = None
         else:
-            raise ValueError(f'Unexpected mid_block_type {self.config.mid_block_type}')
+            raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}")
 
         # up
         up_blocks = []
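Note (not part of the patches above): a minimal sketch of how the new `mid_block_type=None` option could be exercised once the change is applied. The tiny two-block config, dummy tensor shapes, and the plain Flax `init`/`apply` calls below are illustrative assumptions, not taken from the patch or from any test in the repository.

```python
import jax
import jax.numpy as jnp

from diffusers import FlaxUNet2DConditionModel

# Tiny, illustrative config; only `mid_block_type=None` comes from this patch.
unet = FlaxUNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    mid_block_type=None,  # skip the mid block entirely
)

# Dummy inputs: NCHW latents, integer timesteps, and text-encoder hidden states.
sample = jnp.zeros((1, 4, 32, 32), dtype=jnp.float32)
timesteps = jnp.array([1], dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 32), dtype=jnp.float32)

variables = unet.init(jax.random.PRNGKey(0), sample, timesteps, encoder_hidden_states)
out = unet.apply(variables, sample, timesteps, encoder_hidden_states)
print(out.sample.shape)  # (1, 4, 32, 32), produced without a mid block
```

Because `mid_block_type` is a regular config field, a `None` value should also round-trip through `save_pretrained`/`from_pretrained` like any other option.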