diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 57dd424aa4c6..2fdf70ce76b9 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -200,7 +200,7 @@ def __init__( if self.is_input_continuous: # TODO: should use out_channels for continous projections if use_linear_projection: - self.proj_out = nn.Linear(in_channels, inner_dim) + self.proj_out = nn.Linear(inner_dim, in_channels) else: self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: