
Commit e35de4b

add group norm type to attention processor cross attention norm

This lets the cross attention norm use either a group norm block or a layer norm block. The group norm operates along the channels dimension and requires input of shape (batch size, channels, *), whereas a layer norm with a single `normalized_shape` dimension only operates over the last dimension, i.e. (*, channels). The channels we want to normalize are the hidden dimension of the encoder hidden states. By convention, the encoder hidden states are always passed as (batch size, sequence length, hidden dim), so the layer norm can operate on the tensor without modification, but the group norm requires flipping the last two dimensions to operate on (batch size, hidden dim, sequence length). All existing attention processors need the same logic, so it is consolidated in a helper function `prepare_encoder_hidden_states`.

Follow-up fixups squashed into this commit:
- rename `prepare_encoder_hidden_states` to `norm_encoder_hidden_states` (re: @patrickvonplaten)
- move the "norm_cross is defined" check outside `norm_encoder_hidden_states`
- add a missing `attn.norm_cross` check
1 parent c6180a3 commit e35de4b
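
A minimal standalone sketch (plain PyTorch, not part of the commit; the tensor sizes are arbitrary example values) of the shape handling described in the commit message: `nn.LayerNorm` normalizes the last dimension directly, while `nn.GroupNorm` needs the hidden dimension moved into the channel position and back.

import torch
import torch.nn as nn

batch_size, seq_len, hidden_dim = 2, 77, 1024  # example sizes only

encoder_hidden_states = torch.randn(batch_size, seq_len, hidden_dim)

# LayerNorm with a single normalized_shape normalizes the last dimension,
# so it can be applied to (batch, seq_len, hidden_dim) directly.
layer_norm = nn.LayerNorm(hidden_dim)
out_ln = layer_norm(encoder_hidden_states)  # (2, 77, 1024)

# GroupNorm normalizes over the channels dimension and expects (N, C, *),
# so the hidden dimension has to be moved into the channel position first.
group_norm = nn.GroupNorm(num_groups=32, num_channels=hidden_dim, eps=1e-5, affine=True)
out_gn = group_norm(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)  # back to (2, 77, 1024)

assert out_ln.shape == out_gn.shape == encoder_hidden_states.shape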

Showing 6 changed files with 96 additions and 21 deletions.

src/diffusers/models/attention_processor.py

Lines changed: 68 additions & 13 deletions
@@ -56,7 +56,8 @@ def __init__(
         bias=False,
         upcast_attention: bool = False,
         upcast_softmax: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
         out_bias: bool = True,
@@ -69,7 +70,6 @@ def __init__(
         cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
-        self.cross_attention_norm = cross_attention_norm

         self.scale = dim_head**-0.5 if scale_qk else 1.0

@@ -92,8 +92,28 @@ def __init__(
         else:
             self.group_norm = None

-        if cross_attention_norm:
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
             self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )

         self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)

@@ -304,6 +324,25 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None)
         attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
         return attention_mask

+    def norm_encoder_hidden_states(self, encoder_hidden_states):
+        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+

 class AttnProcessor:
     def __call__(
@@ -321,8 +360,8 @@ def __call__(

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -388,7 +427,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -416,6 +458,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

         query = attn.to_q(hidden_states)
@@ -467,8 +514,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -511,8 +558,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -561,7 +608,10 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -598,8 +648,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -647,6 +697,11 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

         query = attn.to_q(hidden_states)
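
For context, a hedged usage sketch of the new constructor arguments. Only `cross_attention_norm` and `cross_attention_norm_num_groups` come from this diff; the `heads`/`dim_head` values and the other numbers below are illustrative assumptions.

import torch
from diffusers.models.attention_processor import Attention

# "group_norm" normalizes the encoder hidden states with an nn.GroupNorm over the
# hidden dimension; "layer_norm" keeps the previous behaviour; None disables the norm.
attn = Attention(
    query_dim=320,
    cross_attention_dim=1024,
    heads=8,
    dim_head=40,
    cross_attention_norm="group_norm",
    cross_attention_norm_num_groups=32,
)

encoder_hidden_states = torch.randn(1, 77, 1024)
normed = attn.norm_encoder_hidden_states(encoder_hidden_states)  # same shape as the input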

src/diffusers/models/unet_2d_blocks.py

Lines changed: 14 additions & 4 deletions
@@ -44,6 +44,7 @@ def get_down_block(
     resnet_time_scale_shift="default",
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
+    cross_attention_norm=None,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -126,6 +127,7 @@ def get_down_block(
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
             only_cross_attention=only_cross_attention,
+            cross_attention_norm=cross_attention_norm,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -223,6 +225,7 @@ def get_up_block(
     resnet_time_scale_shift="default",
     resnet_skip_time_act=False,
     resnet_out_scale_factor=1.0,
+    cross_attention_norm=None,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -293,6 +296,7 @@ def get_up_block(
             skip_time_act=resnet_skip_time_act,
             output_scale_factor=resnet_out_scale_factor,
             only_cross_attention=only_cross_attention,
+            cross_attention_norm=cross_attention_norm,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -578,6 +582,7 @@ def __init__(
         cross_attention_dim=1280,
         skip_time_act=False,
         only_cross_attention=False,
+        cross_attention_norm=None,
     ):
         super().__init__()

@@ -618,6 +623,7 @@ def __init__(
                     bias=True,
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1361,6 +1367,7 @@ def __init__(
         add_downsample=True,
         skip_time_act=False,
         only_cross_attention=False,
+        cross_attention_norm=None,
     ):
         super().__init__()

@@ -1400,6 +1407,7 @@ def __init__(
                     bias=True,
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -1580,7 +1588,7 @@ def __init__(
                     temb_channels=temb_channels,
                     attention_bias=True,
                     add_self_attention=add_self_attention,
-                    cross_attention_norm=True,
+                    cross_attention_norm="layer_norm",
                     group_size=resnet_group_size,
                 )
             )
@@ -2361,6 +2369,7 @@ def __init__(
         add_upsample=True,
         skip_time_act=False,
         only_cross_attention=False,
+        cross_attention_norm=None,
     ):
         super().__init__()
         resnets = []
@@ -2401,6 +2410,7 @@ def __init__(
                     bias=True,
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
@@ -2608,7 +2618,7 @@ def __init__(
                     temb_channels=temb_channels,
                     attention_bias=True,
                     add_self_attention=add_self_attention,
-                    cross_attention_norm=True,
+                    cross_attention_norm="layer_norm",
                     upcast_attention=upcast_attention,
                 )
             )
@@ -2703,7 +2713,7 @@ def __init__(
         upcast_attention: bool = False,
         temb_channels: int = 768,  # for ada_group_norm
         add_self_attention: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
         group_size: int = 32,
     ):
         super().__init__()
@@ -2719,7 +2729,7 @@ def __init__(
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=None,
-            cross_attention_norm=False,
+            cross_attention_norm=None,
        )

         # 2. Cross-Attn
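
Existing call sites switch from the old boolean flag to the new string values (True becomes "layer_norm", False becomes None). A tiny hypothetical helper, not part of the commit, that expresses this mapping for legacy config values:

from typing import Optional, Union

def _to_cross_attention_norm(value: Union[bool, None, str]) -> Optional[str]:
    # Hypothetical helper: translate the old boolean flag into the new string argument.
    # True used to mean "apply nn.LayerNorm"; False meant "no cross attention norm".
    if value is True:
        return "layer_norm"
    if value is False or value is None:
        return None
    return value  # already "layer_norm" or "group_norm"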

src/diffusers/models/unet_2d_condition.py

Lines changed: 4 additions & 0 deletions
@@ -164,6 +164,7 @@ def __init__(
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
     ):
         super().__init__()

@@ -323,6 +324,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.down_blocks.append(down_block)

@@ -355,6 +357,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
                 only_cross_attention=mid_block_only_cross_attention,
+                cross_attention_norm=cross_attention_norm,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -406,6 +409,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py

Lines changed: 2 additions & 2 deletions
@@ -243,8 +243,8 @@ def __call__(

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

Lines changed: 2 additions & 2 deletions
@@ -65,8 +65,8 @@ def __call__(

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
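
All of the processor call sites above follow the same two-line pattern. For illustration, a simplified custom processor sketch (not the pipelines' actual processor; attention-mask preparation and the spatial reshaping some processors do are omitted) showing where the `attn.norm_cross` check and the `attn.norm_encoder_hidden_states` helper slot in:

import torch

class MyAttnProcessor:
    # Minimal sketch of a custom processor following the updated convention.
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        query = attn.to_q(hidden_states)

        # Updated convention: check attn.norm_cross and call the shared helper instead of
        # reading the removed attn.cross_attention_norm flag and applying attn.norm_cross directly.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear projection + dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states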

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 6 additions & 0 deletions
@@ -250,6 +250,7 @@ def __init__(
         projection_class_embeddings_input_dim: Optional[int] = None,
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
     ):
         super().__init__()

@@ -415,6 +416,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.down_blocks.append(down_block)

@@ -447,6 +449,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 skip_time_act=resnet_skip_time_act,
                 only_cross_attention=mid_block_only_cross_attention,
+                cross_attention_norm=cross_attention_norm,
             )
         elif mid_block_type is None:
             self.mid_block = None
@@ -498,6 +501,7 @@ def __init__(
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -1490,6 +1494,7 @@ def __init__(
         cross_attention_dim=1280,
         skip_time_act=False,
         only_cross_attention=False,
+        cross_attention_norm=None,
     ):
         super().__init__()

@@ -1530,6 +1535,7 @@ def __init__(
                     bias=True,
                     upcast_softmax=True,
                     only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                     processor=AttnAddedKVProcessor(),
                 )
             )
