Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions src/diffusers/models/unets/unet_2d_condition_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
The tuple of downsample blocks to use.
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
Expand Down Expand Up @@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
"DownBlock2D",
)
up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
only_cross_attention: Union[bool, Tuple[bool]] = False
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
layers_per_block: int = 2
Expand Down Expand Up @@ -252,16 +255,21 @@ def setup(self) -> None:
self.down_blocks = down_blocks

# mid
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
dropout=self.dropout,
num_attention_heads=num_attention_heads[-1],
transformer_layers_per_block=transformer_layers_per_block[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
if self.config.mid_block_type == "UNetMidBlock2DCrossAttn":
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
dropout=self.dropout,
num_attention_heads=num_attention_heads[-1],
transformer_layers_per_block=transformer_layers_per_block[-1],
use_linear_projection=self.use_linear_projection,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
dtype=self.dtype,
)
elif self.config.mid_block_type is None:
self.mid_block = None
else:
raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}")

# up
up_blocks = []
Expand Down Expand Up @@ -412,7 +420,8 @@ def __call__(
down_block_res_samples = new_down_block_res_samples

# 4. mid
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
if self.mid_block is not None:
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

if mid_block_additional_residual is not None:
sample += mid_block_additional_residual
Expand Down