-
Notifications
You must be signed in to change notification settings - Fork 599
/
attention.py
840 lines (760 loc) · 31.9 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention core modules for Flax."""
from __future__ import annotations
import functools
import warnings
from typing import Any, Callable, Optional, Union, overload
import jax
import jax.numpy as jnp
from jax import lax, random
from flax.linen import initializers
from flax.linen.dtypes import promote_dtype
from flax.linen.linear import (
DenseGeneral,
default_kernel_init,
)
from flax.linen.module import Module, compact, merge_param
from flax.linen.normalization import LayerNorm
from flax.typing import (
Array,
PRNGKey,
Dtype,
Shape as Shape,
Initializer,
PrecisionLike,
DotGeneralT,
)
def dot_product_attention_weights(
  query: Array,
  key: Array,
  bias: Optional[Array] = None,
  mask: Optional[Array] = None,
  broadcast_dropout: bool = True,
  dropout_rng: Optional[PRNGKey] = None,
  dropout_rate: float = 0.0,
  deterministic: bool = False,
  dtype: Optional[Dtype] = None,
  precision: PrecisionLike = None,
  module: Optional[Module] = None,
  force_fp32_for_softmax: bool = False,
  einsum_dot_general: Callable[..., Array] = jax.lax.dot_general,
):
  """Computes dot-product attention weights given query and key.

  Used by :func:`dot_product_attention`, which is what you'll most likely use.
  But if you want access to the attention weights for introspection, then
  you can directly call this function and call einsum yourself.

  The weights are sowed (when ``module`` is given) after the softmax and
  before dropout is applied.

  Args:
    query: queries for calculating attention with shape of ``[batch...,
      q_length, num_heads, qk_depth_per_head]``.
    key: keys for calculating attention with shape of ``[batch..., kv_length,
      num_heads, qk_depth_per_head]``.
    bias: bias for the attention weights. This should be broadcastable to the
      shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
      incorporating causal masks, padding masks, proximity bias, etc.
    mask: mask for the attention weights. This should be broadcastable to the
      shape ``[batch..., num_heads, q_length, kv_length]``. This can be used for
      incorporating causal masks. Attention weights are masked out if their
      corresponding mask value is ``False``.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout. Required whenever
      dropout is active (``deterministic`` is false and ``dropout_rate > 0``).
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    dtype: the dtype of the computation (default: infer from inputs and params)
    precision: numerical precision of the computation see ``jax.lax.Precision``
      for details.
    module: the Module that will sow the attention weights into the
      'intermediates' collection. Remember to mark 'intermediates' as mutable
      via ``mutable=['intermediates']`` in order to have that collection
      returned. If ``module`` is None, the attention weights will not be sowed.
    force_fp32_for_softmax: bool, whether to force the softmax to be computed in
      fp32. This is useful for mixed-precision training where higher precision
      is desired for numerical stability. Note that on this path the softmax
      output is not cast back to ``dtype``.
    einsum_dot_general: the dot_general to use in einsum.

  Returns:
    Output of shape ``[batch..., num_heads, q_length, kv_length]``.

  Raises:
    AssertionError: if ``query`` and ``key`` disagree on rank, batch dims,
      number of heads or depth.
    ValueError: if dropout is active but ``dropout_rng`` is None.
  """
  query, key = promote_dtype(query, key, dtype=dtype)
  dtype = query.dtype

  assert query.ndim == key.ndim, 'q, k must have same rank.'
  assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
  assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
  assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

  # calculate attention matrix
  depth = query.shape[-1]
  query = query / jnp.sqrt(depth).astype(dtype)
  # attn weight shape is (batch..., num_heads, q_length, kv_length)
  attn_weights = jnp.einsum(
    '...qhd,...khd->...hqk',
    query,
    key,
    precision=precision,
    _dot_general=einsum_dot_general,
  )

  # apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias
  # apply attention mask: masked-out logits are replaced by the most negative
  # finite value of ``dtype`` before the softmax.
  if mask is not None:
    big_neg = jnp.finfo(dtype).min
    attn_weights = jnp.where(mask, attn_weights, big_neg)

  # normalize the attention weights
  if force_fp32_for_softmax and dtype != jnp.float32:
    attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32))
  else:
    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

  # Compare against None explicitly: the contract is "no sowing when module is
  # None", which must not depend on the module's truthiness.
  if module is not None:
    module.sow('intermediates', 'attention_weights', attn_weights)

  # apply attention dropout
  if not deterministic and dropout_rate > 0.0:
    if dropout_rng is None:
      raise ValueError(
        '`dropout_rng` must be provided when `deterministic` is False and '
        '`dropout_rate` > 0.'
      )
    keep_prob = 1.0 - dropout_rate
    if broadcast_dropout:
      # dropout is broadcast across the batch + head dimensions
      dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
      keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)  # type: ignore
    else:
      keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)  # type: ignore
    # Rescale kept entries by 1 / keep_prob.
    multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
    attn_weights = attn_weights * multiplier

  return attn_weights
def dot_product_attention(
  query: Array,
  key: Array,
  value: Array,
  bias: Optional[Array] = None,
  mask: Optional[Array] = None,
  broadcast_dropout: bool = True,
  dropout_rng: Optional[PRNGKey] = None,
  dropout_rate: float = 0.0,
  deterministic: bool = False,
  dtype: Optional[Dtype] = None,
  precision: PrecisionLike = None,
  module: Optional[Module] = None,
  force_fp32_for_softmax: bool = False,
  einsum_dot_general: Callable[..., Array] = jax.lax.dot_general,
):
  """Applies dot-product attention to ``value`` using ``query`` and ``key``.

  Implements the attention operation of https://arxiv.org/abs/1706.03762:
  attention weights are derived from ``query`` and ``key`` (see
  :func:`dot_product_attention_weights`) and then used to take a weighted sum
  of ``value`` for every query position.

  .. note::
    ``query``, ``key``, ``value`` needn't have any batch dimensions.

  Args:
    query: queries with shape ``[batch..., q_length, num_heads,
      qk_depth_per_head]``.
    key: keys with shape ``[batch..., kv_length, num_heads,
      qk_depth_per_head]``.
    value: values with shape ``[batch..., kv_length, num_heads,
      v_depth_per_head]``.
    bias: optional additive bias for the attention weights, broadcastable to
      ``[batch..., num_heads, q_length, kv_length]``. Can carry causal masks,
      padding masks, proximity bias, etc.
    mask: optional mask for the attention weights, broadcastable to
      ``[batch..., num_heads, q_length, kv_length]``. Weights whose mask value
      is ``False`` are masked out.
    broadcast_dropout: whether to use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey used for dropout.
    dropout_rate: dropout rate.
    deterministic: whether to run deterministically (i.e. skip dropout).
    dtype: the dtype of the computation (default: infer from inputs).
    precision: numerical precision of the computation, see
      ``jax.lax.Precision`` for details.
    module: the Module that will sow the attention weights into the
      'intermediates' collection. Remember to mark 'intermediates' as mutable
      via ``mutable=['intermediates']`` in order to have that collection
      returned. If ``module`` is None, the attention weights will not be sowed.
    force_fp32_for_softmax: whether to force the softmax to be computed in
      fp32, which helps numerical stability in mixed-precision training.
    einsum_dot_general: the dot_general to use in einsum.

  Returns:
    Output of shape ``[batch..., q_length, num_heads, v_depth_per_head]``.
  """
  query, key, value = promote_dtype(query, key, value, dtype=dtype)
  dtype = query.dtype

  # Structural checks shared by all three operands.
  same_rank = key.ndim == query.ndim and query.ndim == value.ndim
  assert same_rank, 'q, k, v must have same rank.'
  same_batch = (
    query.shape[:-3] == key.shape[:-3] and key.shape[:-3] == value.shape[:-3]
  )
  assert same_batch, 'q, k, v batch dims must match.'
  same_heads = (
    query.shape[-2] == key.shape[-2] and key.shape[-2] == value.shape[-2]
  )
  assert same_heads, 'q, k, v num_heads must match.'
  assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'

  # Attention weights, shape (batch..., num_heads, q_length, kv_length).
  weights = dot_product_attention_weights(
    query,
    key,
    bias=bias,
    mask=mask,
    broadcast_dropout=broadcast_dropout,
    dropout_rng=dropout_rng,
    dropout_rate=dropout_rate,
    deterministic=deterministic,
    dtype=dtype,
    precision=precision,
    module=module,
    force_fp32_for_softmax=force_fp32_for_softmax,
    einsum_dot_general=einsum_dot_general,
  )

  # Weighted sum over the values for each query position.
  weighted_sum_equation = '...hqk,...khd->...qhd'
  return jnp.einsum(
    weighted_sum_equation,
    weights,
    value,
    precision=precision,
    _dot_general=einsum_dot_general,
  )
class MultiHeadDotProductAttention(Module):
  """Multi-head dot-product attention.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax

    >>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
    >>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
    >>> shape = (4, 3, 2, 5)
    >>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
    >>> variables = layer.init(jax.random.key(0), q)

    >>> # different inputs for inputs_q, inputs_k and inputs_v
    >>> out = layer.apply(variables, q, k, v)
    >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
    >>> out = layer.apply(variables, q, k)
    >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
    >>> out = layer.apply(variables, q)

    >>> attention_kwargs = dict(
    ...     num_heads=8,
    ...     qkv_features=16,
    ...     kernel_init=nn.initializers.ones,
    ...     bias_init=nn.initializers.zeros,
    ...     dropout_rate=0.5,
    ...     deterministic=False,
    ...     )
    >>> class Module(nn.Module):
    ...   attention_kwargs: dict
    ...
    ...   @nn.compact
    ...   def __call__(self, x, dropout_rng=None):
    ...     out1 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
    ...     out2 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
    ...     return out1, out2
    >>> module = Module(attention_kwargs)
    >>> variables = module.init({'params': key1, 'dropout': key2}, q)

    >>> # out1 and out2 are different.
    >>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
    >>> # out3 and out4 are different.
    >>> # out1 and out3 are different. out2 and out4 are different.
    >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
    >>> # out1 and out2 are the same.
    >>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
    >>> # out1 and out2 are the same as out3 and out4.
    >>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
    >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)

  Attributes:
    num_heads: Number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: The dtype of the computation (default: infer from inputs and params)
    param_dtype: The dtype passed to parameter initializers (default: float32)
    qkv_features: Dimension of the key, query, and value.
    out_features: Dimension of the last projection
    broadcast_dropout: Use a broadcasted dropout along batch dims.
    dropout_rate: Dropout rate.
    deterministic: If False, the attention weight is masked randomly using
      dropout, whereas if True, the attention weights are deterministic.
    precision: Numerical precision of the computation see ``jax.lax.Precision``
      for details.
    kernel_init: Initializer for the kernel of the Dense layers.
    out_kernel_init: Optional Initializer for the kernel of the output Dense layer,
      if None, ``kernel_init`` will be used.
    bias_init: Initializer for the bias of the Dense layers.
    out_bias_init: Optional Initializer for the bias of the output Dense layer,
      if None, ``bias_init`` will be used.
    use_bias: Whether pointwise QKVO dense transforms use bias.
    attention_fn: dot_product_attention or compatible function. Accepts query,
      key, value, and returns output of shape ``[bs, dim1, dim2, ..., dimN,
      num_heads, value_channels]``
    decode: Whether to prepare and use an autoregressive cache.
    normalize_qk: Should QK normalization be applied (arxiv.org/abs/2302.05442).
    force_fp32_for_softmax: Whether to force the softmax to be computed in
      fp32. When True, ``force_fp32_for_softmax=True`` is forwarded to
      ``attention_fn``, so a custom ``attention_fn`` must accept that keyword
      if this is enabled. When False (the default) the keyword is not passed.
  """

  num_heads: int
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  qkv_features: Optional[int] = None
  out_features: Optional[int] = None
  broadcast_dropout: bool = True
  dropout_rate: float = 0.0
  deterministic: Optional[bool] = None
  precision: PrecisionLike = None
  kernel_init: Initializer = default_kernel_init
  out_kernel_init: Initializer | None = None
  bias_init: Initializer = initializers.zeros_init()
  out_bias_init: Initializer | None = None
  use_bias: bool = True
  attention_fn: Callable[..., Array] = dot_product_attention
  decode: bool = False
  normalize_qk: bool = False
  force_fp32_for_softmax: bool = False
  # Deprecated, will be removed.
  qkv_dot_general: Optional[DotGeneralT] = None
  out_dot_general: Optional[DotGeneralT] = None
  qkv_dot_general_cls: Any = None
  out_dot_general_cls: Any = None

  @overload
  def __call__(
    self,
    inputs_q: Array,
    inputs_k: Optional[Array] = None,
    inputs_v: Optional[Array] = None,
    *,
    mask: Optional[Array] = None,
    deterministic: Optional[bool] = None,
    dropout_rng: Optional[PRNGKey] = None,
    sow_weights: bool = False,
  ):
    ...

  @overload
  def __call__(
    self,
    inputs_q: Array,
    *,
    inputs_kv: Optional[Array] = None,
    mask: Optional[Array] = None,
    deterministic: Optional[bool] = None,
    dropout_rng: Optional[PRNGKey] = None,
    sow_weights: bool = False,
  ):
    ...

  @compact
  def __call__(
    self,
    inputs_q: Array,
    inputs_k: Optional[Array] = None,
    inputs_v: Optional[Array] = None,
    *,
    inputs_kv: Optional[Array] = None,
    mask: Optional[Array] = None,
    deterministic: Optional[bool] = None,
    dropout_rng: Optional[PRNGKey] = None,
    sow_weights: bool = False,
  ):
    """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and project the results to an output vector.

    If both inputs_k and inputs_v are None, they will both copy the value of
    inputs_q (self attention).
    If only inputs_v is None, it will copy the value of inputs_k.

    Args:
      inputs_q: input queries of shape ``[batch_sizes..., length, features]``.
      inputs_k: key of shape ``[batch_sizes..., length, features]``. If None,
        inputs_k will copy the value of inputs_q.
      inputs_v: values of shape ``[batch_sizes..., length, features]``. If None,
        inputs_v will copy the value of inputs_k.
      inputs_kv: key/values of shape ``[batch_sizes..., length, features]``. If
        None, inputs_kv will copy the value of inputs_q. This arg will be
        deprecated soon. Use inputs_k and inputs_v instead.
      mask: attention mask of shape ``[batch_sizes..., num_heads, query_length,
        key/value_length]``. Attention weights are masked out if their
        corresponding mask value is ``False``.
      deterministic: if false, the attention weight is masked randomly using
        dropout, whereas if true, the attention weights are deterministic.
      dropout_rng: optional rng key to pass to the attention layer's dropout
        mask. Otherwise, self.make_rng('dropout') is used instead.
      sow_weights: if ``True``, the attention weights are sowed into the
        'intermediates' collection. Remember to mark 'intermediates' as
        mutable via ``mutable=['intermediates']`` in order to have that
        collection returned.

    Returns:
      output of shape ``[batch_sizes..., length, features]``.

    Raises:
      ValueError: if ``inputs_kv`` is combined with ``inputs_k``/``inputs_v``,
        if ``inputs_v`` is given without ``inputs_k``, or if, while decoding
        with an initialized cache, the query shape does not match the cache.
    """
    if inputs_kv is not None:
      if inputs_k is not None or inputs_v is not None:
        raise ValueError(
          'If either `inputs_k` or `inputs_v` is not None, '
          '`inputs_kv` must be None. If `inputs_kv` is not None, both `inputs_k` '
          'and `inputs_v` must be None. We recommend using `inputs_k` and '
          '`inputs_v` args, since `inputs_kv` will be deprecated soon. See '
          'https://github.com/google/flax/discussions/3389 for more '
          'information.'
        )
      inputs_k = inputs_v = inputs_kv
      warnings.warn(
        'The inputs_kv arg will be deprecated soon. '
        'Use inputs_k and inputs_v instead. See '
        'https://github.com/google/flax/discussions/3389 '
        'for more information.',
        DeprecationWarning,
      )
    else:
      if inputs_k is None:
        if inputs_v is not None:
          raise ValueError(
            '`inputs_k` cannot be None if `inputs_v` is not None. '
            'To have both `inputs_k` and `inputs_v` be the same value, pass in the '
            'value to `inputs_k` and leave `inputs_v` as None.'
          )
        inputs_k = inputs_q
      if inputs_v is None:
        inputs_v = inputs_k
      elif inputs_v.shape[-1] == inputs_v.shape[-2]:
        # A square trailing shape suggests a mask was passed positionally in
        # the slot that is now `inputs_v`; warn rather than fail.
        warnings.warn(
          f'You are passing an array of shape {inputs_v.shape} '
          'to the `inputs_v` arg, when you may have intended '
          'to pass it to the `mask` arg. As of Flax version '
          '0.7.4, the function signature of '
          "MultiHeadDotProductAttention's `__call__` method "
          'has changed to `__call__(inputs_q, inputs_k=None, '
          'inputs_v=None, *, inputs_kv=None, mask=None, '
          'deterministic=None)`. Use the kwarg `mask` instead. '
          'See https://github.com/google/flax/discussions/3389 '
          'and read the docstring for more information.',
          DeprecationWarning,
        )

    features = self.out_features or inputs_q.shape[-1]
    qkv_features = self.qkv_features or inputs_q.shape[-1]
    assert qkv_features % self.num_heads == 0, (
      f'Memory dimension ({qkv_features}) must be divisible by number of'
      f' heads ({self.num_heads}).'
    )
    head_dim = qkv_features // self.num_heads

    dense = functools.partial(
      DenseGeneral,
      axis=-1,
      dtype=self.dtype,
      param_dtype=self.param_dtype,
      features=(self.num_heads, head_dim),
      kernel_init=self.kernel_init,
      bias_init=self.bias_init,
      use_bias=self.use_bias,
      precision=self.precision,
      dot_general=self.qkv_dot_general,
      dot_general_cls=self.qkv_dot_general_cls,
    )
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [batch..., length, n_heads, n_features_per_head]
    query, key, value = (
      dense(name='query')(inputs_q),
      dense(name='key')(inputs_k),
      dense(name='value')(inputs_v),
    )

    if self.normalize_qk:
      # Normalizing query and key projections stabilizes training with higher
      # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
      query = LayerNorm(
        name='query_ln',
        use_bias=False,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
      )(query)  # type: ignore[call-arg]
      key = LayerNorm(
        name='key_ln',
        use_bias=False,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
      )(key)  # type: ignore[call-arg]

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    if self.decode:
      # detect if we're initializing by absence of existing cache data.
      is_initialized = self.has_variable('cache', 'cached_key')
      cached_key = self.variable(
        'cache', 'cached_key', jnp.zeros, key.shape, key.dtype
      )
      cached_value = self.variable(
        'cache', 'cached_value', jnp.zeros, value.shape, value.dtype
      )
      cache_index = self.variable(
        'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32)
      )
      if is_initialized:
        (
          *batch_dims,
          max_length,
          num_heads,
          depth_per_head,
        ) = cached_key.value.shape
        # shape check of cached keys against query input
        expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
        if expected_shape != query.shape:
          raise ValueError(
            'Autoregressive cache shape error, '
            'expected query shape %s instead got %s.'
            % (expected_shape, query.shape)
          )
        # update key, value caches with our new 1d spatial slices
        cur_index = cache_index.value
        zero = jnp.array(0, dtype=lax.dtype(cur_index.dtype))
        indices: tuple[Union[int, jax.Array], ...] = (zero,) * len(
          batch_dims
        ) + (
          cur_index,
          zero,
          zero,
        )
        key = lax.dynamic_update_slice(cached_key.value, key, indices)
        value = lax.dynamic_update_slice(cached_value.value, value, indices)
        cached_key.value = key
        cached_value.value = value
        cache_index.value = cache_index.value + 1
        # causal mask for cached decoder self-attention:
        # our single query position should only attend to those key
        # positions that have already been generated and cached,
        # not the remaining zero elements.
        mask = combine_masks(
          mask,
          jnp.broadcast_to(
            jnp.arange(max_length) <= cur_index,
            tuple(batch_dims) + (1, 1, max_length),
          ),
        )

    if (
      self.dropout_rate > 0.0
    ):  # Require `deterministic` only if using dropout.
      m_deterministic = merge_param(
        'deterministic', self.deterministic, deterministic
      )
      if not m_deterministic and dropout_rng is None:
        dropout_rng = self.make_rng('dropout')
    else:
      m_deterministic = True

    # apply attention
    attn_kwargs: dict[str, Any] = dict(
      mask=mask,
      dropout_rng=dropout_rng,
      dropout_rate=self.dropout_rate,
      broadcast_dropout=self.broadcast_dropout,
      deterministic=m_deterministic,
      dtype=self.dtype,
      precision=self.precision,
    )
    # Forward the fp32-softmax request only when it is enabled, so custom
    # attention functions that do not know this keyword keep working with the
    # default configuration.
    if self.force_fp32_for_softmax:
      attn_kwargs['force_fp32_for_softmax'] = True
    # `module` is passed only when the weights should be sowed.
    if sow_weights:
      attn_kwargs['module'] = self
    x = self.attention_fn(
      query, key, value, **attn_kwargs
    )  # pytype: disable=wrong-keyword-args

    # back to the original inputs dimensions
    out = DenseGeneral(
      features=features,
      axis=(-2, -1),
      kernel_init=self.out_kernel_init or self.kernel_init,
      bias_init=self.out_bias_init or self.bias_init,
      use_bias=self.use_bias,
      dtype=self.dtype,
      param_dtype=self.param_dtype,
      precision=self.precision,
      dot_general=self.out_dot_general,
      dot_general_cls=self.out_dot_general_cls,
      name='out',  # type: ignore[call-arg]
    )(x)
    return out
class MultiHeadAttention(MultiHeadDotProductAttention):
  """Multi-head dot-product attention.
  Alias for ``MultiHeadDotProductAttention``.

  **NOTE**: ``MultiHeadAttention`` is a wrapper of ``MultiHeadDotProductAttention``,
  and so their implementations are identical. However ``MultiHeadAttention`` layers
  will, by default, be named ``MultiHeadAttention_{index}``, whereas ``MultiHeadDotProductAttention``
  will be named ``MultiHeadDotProductAttention_{index}``. Therefore, this could affect
  checkpointing, param collection names and RNG threading (since the layer name is
  used when generating new RNG's) within the module.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax

    >>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16)
    >>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
    >>> shape = (4, 3, 2, 5)
    >>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
    >>> variables = layer.init(jax.random.key(0), q)

    >>> # different inputs for inputs_q, inputs_k and inputs_v
    >>> out = layer.apply(variables, q, k, v)
    >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
    >>> out = layer.apply(variables, q, k)
    >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
    >>> out = layer.apply(variables, q)

    >>> attention_kwargs = dict(
    ...     num_heads=8,
    ...     qkv_features=16,
    ...     kernel_init=nn.initializers.ones,
    ...     bias_init=nn.initializers.zeros,
    ...     dropout_rate=0.5,
    ...     deterministic=False,
    ...     )
    >>> class Module(nn.Module):
    ...   attention_kwargs: dict
    ...
    ...   @nn.compact
    ...   def __call__(self, x, dropout_rng=None):
    ...     out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
    ...     out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
    ...     return out1, out2
    >>> module = Module(attention_kwargs)
    >>> variables = module.init({'params': key1, 'dropout': key2}, q)

    >>> # out1 and out2 are different.
    >>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
    >>> # out3 and out4 are different.
    >>> # out1 and out3 are different. out2 and out4 are different.
    >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
    >>> # out1 and out2 are the same.
    >>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
    >>> # out1 and out2 are the same as out3 and out4.
    >>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
    >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)

  Attributes:
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: infer from inputs and params)
    param_dtype: the dtype passed to parameter initializers (default: float32)
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rate: dropout rate
    deterministic: if false, the attention weight is masked randomly using
      dropout, whereas if true, the attention weights are deterministic.
    precision: numerical precision of the computation see ``jax.lax.Precision``
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    out_kernel_init: optional initializer for the kernel of the output Dense
      layer, if None, ``kernel_init`` will be used.
    bias_init: initializer for the bias of the Dense layers.
    out_bias_init: optional initializer for the bias of the output Dense layer,
      if None, ``bias_init`` will be used.
    use_bias: bool: whether pointwise QKVO dense transforms use bias.
    attention_fn: dot_product_attention or compatible function. Accepts query,
      key, value, and returns output of shape ``[bs, dim1, dim2, ..., dimN,
      num_heads, value_channels]``
    decode: whether to prepare and use an autoregressive cache.
    normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442).
    force_fp32_for_softmax: bool: inherited flag requesting that the softmax
      be computed in fp32.
  """
class SelfAttention(MultiHeadDotProductAttention):
  """Multi-head dot-product attention restricted to self-attention.

  Deprecated: prefer ``MultiHeadDotProductAttention`` called with only
  ``inputs_q``.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp
    >>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
    >>> variables = layer.init(jax.random.key(0), jnp.ones((4, 3, 2, 5)))
  """

  @compact
  def __call__(  # type: ignore
    self,
    inputs_q: Array,
    mask: Optional[Array] = None,
    deterministic: Optional[bool] = None,
    dropout_rng: Optional[PRNGKey] = None,
    sow_weights: bool = False,
  ):
    """Runs multi-head dot-product self-attention over ``inputs_q``.

    Emits a ``DeprecationWarning`` and then delegates to the parent class's
    ``__call__`` with ``inputs_q`` as the only input array, so keys and values
    are derived from the queries.

    Args:
      inputs_q: input queries of shape ``[batch_sizes..., length, features]``.
      mask: attention mask of shape ``[batch_sizes..., num_heads, query_length,
        key/value_length]``. Attention weights are masked out if their
        corresponding mask value is ``False``.
      deterministic: if false, the attention weight is masked randomly using
        dropout, whereas if true, the attention weights are deterministic.
      dropout_rng: optional rng key forwarded to the parent call for the
        dropout mask.
      sow_weights: forwarded to the parent call; if ``True`` the attention
        weights are sowed into the 'intermediates' collection.

    Returns:
      output of shape ``[batch_sizes..., length, features]``.
    """
    deprecation_message = (
      'SelfAttention will be deprecated soon. Use '
      '`MultiHeadDotProductAttention.__call__(inputs_q)` instead. '
      'See https://github.com/google/flax/discussions/3389 '
      'for more information.'
    )
    warnings.warn(deprecation_message, DeprecationWarning)

    forwarded = dict(
      mask=mask,
      deterministic=deterministic,
      dropout_rng=dropout_rng,
      sow_weights=sow_weights,
    )
    return super().__call__(inputs_q, **forwarded)
# mask-making utility functions
def make_attention_mask(
  query_input: Array,
  key_input: Array,
  pairwise_fn: Callable[..., Any] = jnp.multiply,
  extra_batch_dims: int = 0,
  dtype: Dtype = jnp.float32,
):
  """Builds an attention mask by pairing query and key entries.

  For 1d inputs (i.e., ``[batch..., len_q]`` and ``[batch..., len_kv]``) the
  attention weights have shape ``[batch..., heads, len_q, len_kv]`` and this
  helper yields a mask of shape ``[batch..., 1, len_q, len_kv]``.

  Args:
    query_input: a batched, flat input of query_length size
    key_input: a batched, flat input of key_length size
    pairwise_fn: broadcasting elementwise comparison function
    extra_batch_dims: number of extra batch dims to add singleton axes for, none
      by default
    dtype: mask return dtype

  Returns:
    A ``[batch..., 1, len_q, len_kv]`` shaped mask for 1d attention.
  """
  # Queries become a trailing column and keys a trailing row; the pairwise
  # function is then applied to the two.
  query_column = jnp.expand_dims(query_input, axis=-1)
  key_row = jnp.expand_dims(key_input, axis=-2)
  pairwise = pairwise_fn(query_column, key_row)

  # Singleton axis in the position of the heads dimension.
  with_heads_axis = jnp.expand_dims(pairwise, axis=-3)

  # Optional leading singleton axes for extra batch dimensions.
  leading_axes = tuple(range(extra_batch_dims))
  expanded = jnp.expand_dims(with_heads_axis, axis=leading_axes)
  return expanded.astype(dtype)
def make_causal_mask(
  x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32
) -> Array:
  """Builds a causal (lower-triangular) mask for self-attention.

  For 1d inputs (i.e., ``[batch..., len]``) the self-attention weights have
  shape ``[batch..., heads, len, len]`` and this helper yields a causal mask of
  shape ``[batch..., 1, len, len]``.

  Args:
    x: input array of shape ``[batch..., len]``
    extra_batch_dims: number of batch dims to add singleton axes for, none by
      default
    dtype: mask return dtype

  Returns:
    A ``[batch..., 1, len, len]`` shaped causal mask for 1d attention.
  """
  # Position index along the last axis, replicated to the shape of ``x``.
  seq_len = x.shape[-1]
  positions = jnp.arange(seq_len, dtype=jnp.int32)
  positions = jnp.broadcast_to(positions, x.shape)
  # A query position is paired with every key position whose index is not
  # greater than its own.
  return make_attention_mask(
    query_input=positions,
    key_input=positions,
    pairwise_fn=jnp.greater_equal,
    extra_batch_dims=extra_batch_dims,
    dtype=dtype,
  )
def combine_masks(
  *masks: Optional[Array], dtype: Dtype = jnp.float32
) -> Optional[Array]:
  """Merges several attention masks into one.

  ``None`` entries are ignored. The remaining masks must all have the same
  rank and are folded together with ``jnp.logical_and``; a single remaining
  mask is only cast to ``dtype``.

  Args:
    *masks: set of attention mask arguments to combine, some can be None.
    dtype: dtype for the returned mask.

  Returns:
    Combined mask, reduced by logical and, returns None if no masks given.
  """
  present = [candidate for candidate in masks if candidate is not None]
  if len(present) == 0:
    return None

  ranks = tuple(candidate.ndim for candidate in present)
  assert all(
    rank == ranks[0] for rank in ranks
  ), f'masks must have same rank: {ranks}'

  combined = functools.reduce(jnp.logical_and, present)
  return combined.astype(dtype)