/
dense_attention.py
2376 lines (2153 loc) · 90.4 KB
/
dense_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dense attention classes and mask/weighting functions."""
# pylint: disable=attribute-defined-outside-init,g-bare-generic
import abc
import functools
from typing import Callable, Optional, Tuple, Union
from aqt.jax_legacy.jax import flax_layers as aqt_flax_layers
from aqt.jax_legacy.jax import quant_config as aqt_config
from aqt.jax_legacy.jax import quantization as aqt
import chex
from flax import linen as nn
from flax.core import variables
from flax.linen import initializers
from flax.linen import partitioning as flax_partitioning
from flax.linen.linear import default_kernel_init
from flax.training import common_utils
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
from flaxformer import activation_partitioning
from flaxformer.components import dense
from flaxformer.components import embedding
from flaxformer.types import Array
from flaxformer.types import DType
from flaxformer.types import Initializer
from flaxformer.types import PRNGKey
# Module-level alias for flax's `partitioning.RulesFallback`. It is not
# referenced in the part of this file reviewed here; presumably it is used
# further down or re-exported for callers — confirm before removing.
RulesFallback = flax_partitioning.RulesFallback
def _softmax_with_extra_logit(
    x: Array,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
) -> Array:
  """Softmax with an implicit extra logit fixed at zero.

  Kept for compatibility with some previously trained models. Mathematically
  the result is `exp(x) / (1 + sum(exp(x)))` along `axis`: an ordinary softmax
  with one added to the denominator. In attention this gives every query a
  virtual "attend to nothing" slot, so the returned weights sum to strictly
  less than one.

  Args:
    x: logits to normalize.
    axis: the axis, or tuple of axes, along which to normalize.

  Returns:
    An array with the same shape as `x`.
  """
  # Shift by the max logit, clamped at 0 so the implicit zero logit is covered
  # too. The shift cancels out analytically, so no gradient flows through it.
  shift = jnp.maximum(
      lax.stop_gradient(jnp.max(x, axis=axis, keepdims=True)), 0
  )
  exps = jnp.exp(x - shift)
  # After shifting, the implicit zero logit has become -shift.
  normalizer = jnp.sum(exps, axis=axis, keepdims=True) + jnp.exp(-shift)
  return exps / normalizer
# ------------------------------------------------------------------------------
# Fast attention layers.
# ------------------------------------------------------------------------------
def dot_product_attention_weights(
    query: Array,
    key: Array,
    bias: Optional[Array] = None,
    broadcast_dropout: bool = True,
    rescale_logits: bool = False,
    rescale_weights: bool = False,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    enable_dropout: bool = True,
    dtype: DType = jnp.float32,
    precision: Optional[lax.Precision] = None,
    use_extra_logit: bool = False,
    float32_logits: bool = False,
) -> Array:
  """Computes dot-product attention weights from queries and keys.

  Core weight computation of https://arxiv.org/abs/1706.03762: logits are the
  per-head dot products of queries and keys, optionally scaled and biased,
  then normalized with a softmax over the key axis and optionally dropped out.
  Query and key need not have any batch dimensions.

  Args:
    query: queries of shape `[batch..., q_length, num_heads,
      qk_depth_per_head]`.
    key: keys of shape `[batch..., kv_length, num_heads, qk_depth_per_head]`.
    bias: optional additive bias, broadcastable to `[batch..., num_heads,
      q_length, kv_length]`; used for causal/padding masks, proximity bias,
      etc. It is cast to the logits' dtype before being added.
    broadcast_dropout: if True, a single dropout mask is sampled per key
      position and shared across all query positions (the mask has size 1 on
      the `q_length` axis). If False, every weight is dropped independently.
    rescale_logits: divide `query` by sqrt(qk_depth_per_head) before the dot
      product.
    rescale_weights: divide the raw logits by sqrt(qk_depth_per_head), before
      bias and softmax. Mutually exclusive with `rescale_logits`.
    dropout_rng: JAX PRNGKey used to sample the dropout mask.
    dropout_rate: probability of dropping an attention weight.
    enable_dropout: whether dropout is applied at all.
    dtype: dtype of the returned weights (default: float32).
    precision: numerical precision of the einsum; see `jax.lax.Precision`.
    use_extra_logit: normalize with a virtual extra zero logit, so weights may
      sum to less than one.
    float32_logits: cast query and key to float32 before the dot product, so
      the logits and softmax are computed in float32 (useful with bfloat16).

  Returns:
    Attention weights of shape `[batch..., num_heads, q_length, kv_length]`.

  Raises:
    ValueError: if both `rescale_logits` and `rescale_weights` are True.
  """
  assert query.ndim == key.ndim, 'q, k must have same rank.'
  assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
  assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
  assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
  if rescale_logits and rescale_weights:
    raise ValueError(
        'Only one of rescale_logits or rescale_weights may be True.'
    )

  head_depth = query.shape[-1]

  # NOTE: T5 does not explicitly rescale the attention logits by
  # 1/sqrt(depth_kq)! This is folded into the initializers of the linear
  # transformations, which is equivalent under Adafactor.
  if rescale_logits:
    query = query / jnp.sqrt(head_depth).astype(dtype)

  if float32_logits:
    # Keep the logits (and therefore the softmax) in float32 for stability.
    query, key = query.astype(jnp.float32), key.astype(jnp.float32)

  # Per-head similarity scores: (batch..., num_heads, q_length, kv_length).
  logits = jnp.einsum(
      '...qhd,...khd->...hqk', query, key, precision=precision
  )
  if rescale_weights:
    logits = logits / jnp.sqrt(head_depth).astype(dtype)

  # Additive bias: masking, proximity bias, etc.
  if bias is not None:
    logits = logits + bias.astype(logits.dtype)

  normalize = _softmax_with_extra_logit if use_extra_logit else jax.nn.softmax
  weights = normalize(logits).astype(dtype)

  if not (enable_dropout and dropout_rate > 0.0):
    return weights

  keep_prob = 1.0 - dropout_rate
  if broadcast_dropout:
    # T5 broadcasts along the "length" dim, but it is unclear which positional
    # dimension that corresponds to here; the query axis is assumed.
    mask_shape = weights.shape[:-2] + (1,) + weights.shape[-1:]
    keep = jnp.broadcast_to(
        random.bernoulli(dropout_rng, keep_prob, mask_shape), weights.shape
    )
  else:
    keep = random.bernoulli(dropout_rng, keep_prob, weights.shape)

  # Inverted dropout: rescale the survivors so the expectation is unchanged.
  scale = keep.astype(weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)
  return weights * scale
def apply_dot_product_attention_weights_to_values(
    attention_weights: Array,
    value: Array,
    precision: Optional[lax.Precision] = None,
) -> Array:
  """Mixes value vectors according to precomputed attention weights.

  Args:
    attention_weights: weights laid out as `[batch..., num_heads, q_length,
      kv_length]`, e.g. the output of `dot_product_attention_weights`.
    value: values laid out as `[batch..., kv_length, num_heads,
      v_depth_per_head]`.
    precision: numerical precision of the einsum; see `jax.lax.Precision`.

  Returns:
    An array of shape `[batch..., q_length, num_heads, v_depth_per_head]`
    holding, for each query position and head, the weighted sum of values.
  """
  # Contract over the key axis; the head axis is carried through.
  contraction = '...hqk,...khd->...qhd'
  mixed = jnp.einsum(contraction, attention_weights, value, precision=precision)
  return mixed
def dot_product_attention(
    query: Array,
    key: Array,
    value: Array,
    bias: Optional[Array] = None,
    broadcast_dropout: bool = True,
    rescale_logits: bool = False,
    rescale_weights: bool = False,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    enable_dropout: bool = True,
    dtype: DType = jnp.float32,
    precision: Optional[lax.Precision] = None,
    use_extra_logit: bool = False,
    float32_logits: bool = False,
):
  """Computes multi-head dot-product attention over query, key and value.

  Core attention of https://arxiv.org/abs/1706.03762. This is a thin
  composition: the weights come from `dot_product_attention_weights` and are
  then applied to `value` with `apply_dot_product_attention_weights_to_values`.
  Query, key and value need not have any batch dimensions.

  Args:
    query: queries of shape `[batch..., q_length, num_heads,
      qk_depth_per_head]`.
    key: keys of shape `[batch..., kv_length, num_heads, qk_depth_per_head]`.
    value: values of shape `[batch..., kv_length, num_heads,
      v_depth_per_head]`.
    bias: optional additive bias, broadcastable to `[batch..., num_heads,
      q_length, kv_length]` (causal/padding masks, proximity bias, etc.).
    broadcast_dropout: if True, one dropout mask is shared across all query
      positions (the mask has size 1 on the `q_length` axis).
    rescale_logits: divide `query` by sqrt(qk_depth_per_head) before the dot
      product.
    rescale_weights: divide the raw logits by sqrt(qk_depth_per_head), before
      bias and softmax. Mutually exclusive with `rescale_logits`.
    dropout_rng: JAX PRNGKey used for dropout.
    dropout_rate: probability of dropping an attention weight.
    enable_dropout: whether dropout is applied at all.
    dtype: dtype of the attention weights (default: float32).
    precision: numerical precision of the einsums; see `jax.lax.Precision`.
    use_extra_logit: normalize with a virtual extra zero logit.
    float32_logits: compute logits and softmax in float32 (useful with
      bfloat16 inputs).

  Returns:
    Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
  """
  q_shape, k_shape, v_shape = query.shape, key.shape, value.shape
  assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
  assert (
      q_shape[:-3] == k_shape[:-3] == v_shape[:-3]
  ), 'q, k, v batch dims must match.'
  assert (
      q_shape[-2] == k_shape[-2] == v_shape[-2]
  ), 'q, k, v num_heads must match.'
  assert k_shape[-3] == v_shape[-3], 'k, v lengths must match.'
  assert q_shape[-1] == k_shape[-1], 'q, k depths must match.'

  weights = dot_product_attention_weights(
      query,
      key,
      bias=bias,
      dtype=dtype,
      precision=precision,
      rescale_logits=rescale_logits,
      rescale_weights=rescale_weights,
      use_extra_logit=use_extra_logit,
      float32_logits=float32_logits,
      broadcast_dropout=broadcast_dropout,
      dropout_rng=dropout_rng,
      dropout_rate=dropout_rate,
      enable_dropout=enable_dropout,
  )
  return apply_dot_product_attention_weights_to_values(
      weights, value, precision=precision
  )
def dot_product_attention_multiquery(
    query: Array,
    key: Array,
    value: Array,
    bias: Optional[Array] = None,
    broadcast_dropout: bool = True,
    rescale_logits: bool = False,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    enable_dropout: bool = True,
    dtype: DType = jnp.float32,
    precision: Optional[lax.Precision] = None,
    use_extra_logit: bool = False,
    float32_logits: bool = False,
    *,
    rescale_weights: bool = False,
) -> Array:
  """Computes dot-product multi-query attention given query, key, and value.

  A variant of the multi-head dot-product attention of
  https://arxiv.org/abs/1706.03762 (see `dot_product_attention`) in which key
  and value have a single head shared by all query heads ("multi-query"
  attention). It computes attention weights from query and key and uses them
  to combine the values. Query, key and value need not have batch dimensions.

  Args:
    query: queries of shape `[batch..., q_length, num_heads,
      qk_depth_per_head]`.
    key: keys of shape `[batch..., kv_length, qk_depth_per_head]`.
    value: values of shape `[batch..., kv_length, v_depth_per_head]`.
    bias: optional additive bias, broadcastable to `[batch..., num_heads,
      q_length, kv_length]` (causal/padding masks, proximity bias, etc.).
    broadcast_dropout: if True, one dropout mask is shared across all query
      positions (the mask has size 1 on the `q_length` axis).
    rescale_logits: divide `query` by sqrt(qk_depth_per_head) before the dot
      product.
    dropout_rng: JAX PRNGKey used for dropout.
    dropout_rate: probability of dropping an attention weight.
    enable_dropout: whether dropout is applied at all.
    dtype: dtype of the attention weights (default: float32).
    precision: numerical precision of the einsums; see `jax.lax.Precision`.
    use_extra_logit: normalize with a virtual extra zero logit.
    float32_logits: compute logits and softmax in float32 (useful with
      bfloat16 inputs).
    rescale_weights: keyword-only. Divide the raw logits by
      sqrt(qk_depth_per_head), before bias and softmax, matching the option of
      the same name on `dot_product_attention`. Mutually exclusive with
      `rescale_logits`.

  Returns:
    Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.

  Raises:
    ValueError: if both `rescale_logits` and `rescale_weights` are True.
  """
  assert (
      key.ndim == value.ndim
  ), f'k, v must have same rank. key: {key.shape}, value: {value.shape}'
  assert (
      query.shape[:-3] == key.shape[:-2] == value.shape[:-2]
  ), f'q, k, v batch dims must match. query: {query.shape}'
  assert key.shape[-2] == value.shape[-2], 'k, v lengths must match.'
  assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
  if rescale_logits and rescale_weights:
    raise ValueError(
        'Only one of rescale_logits or rescale_weights may be True.'
    )

  depth = query.shape[-1]
  # NOTE: T5 does not explicitly rescale the attention logits by
  # 1/sqrt(depth_kq)! This is folded into the initializers of the
  # linear transformations, which is equivalent under Adafactor.
  if rescale_logits:
    query = query / jnp.sqrt(depth).astype(dtype)
  # Casting logits and softmax computation for float32 for model stability.
  if float32_logits:
    query = query.astype(jnp.float32)
    key = key.astype(jnp.float32)
  # The single key head is shared by every query head:
  # attn weight shape is (batch..., num_heads, q_length, kv_length).
  attn_weights = jnp.einsum(
      '...qhd,...kd->...hqk', query, key, precision=precision
  )
  if rescale_weights:
    attn_weights = attn_weights / jnp.sqrt(depth).astype(dtype)
  # Apply attention bias: masking, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias.astype(attn_weights.dtype)
  # Normalize the attention weights over the key axis.
  attn_weights = (
      _softmax_with_extra_logit if use_extra_logit else jax.nn.softmax
  )(attn_weights).astype(dtype)
  # Apply attention dropout.
  if enable_dropout and dropout_rate > 0.0:
    keep_prob = 1.0 - dropout_rate
    if broadcast_dropout:
      # T5 broadcasts along the "length" dim, but unclear which one that
      # corresponds to in positional dimensions here, assuming query dim.
      dropout_shape = list(attn_weights.shape)
      dropout_shape[-2] = 1
      keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
      keep = jnp.broadcast_to(keep, attn_weights.shape)
    else:
      keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
    # Inverted dropout: rescale survivors so the expectation is unchanged.
    multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(
        keep_prob, dtype=dtype
    )
    attn_weights = attn_weights * multiplier
  # Return the weighted sum over the shared values for each query position.
  return jnp.einsum(
      '...hqk,...kd->...qhd', attn_weights, value, precision=precision
  )
class DenseAttention(metaclass=abc.ABCMeta):
  """API for attention classes that compute a full key-query attention matrix.

  This allows 2D matrices to mask or re-weight the attention between specific
  key/query pairs. Concrete subclasses must implement `__call__` (it is
  declared with `abc.abstractmethod`).
  """

  @abc.abstractmethod
  def __call__(
      self,
      inputs_q: Array,
      inputs_kv: Array,
      mask: Optional[Array] = None,
      bias: Optional[Array] = None,
      *,
      precomputed_qkv: Optional[Array] = None,
      decode: bool = False,
      enable_dropout: bool = True,
  ) -> Array:
    """Applies attention on the input data.

    This abstract declaration has no implementation of its own (the body is
    only this docstring).

    Args:
      inputs_q: input queries of shape `[batch_sizes..., q_length,
        q_features]`.
      inputs_kv: key/values of shape `[batch_sizes..., kv_length,
        kv_features]`.
      mask: attention mask of shape `[batch_sizes..., num_heads, q_length,
        kv_length]`.
      bias: attention bias of shape `[batch_sizes..., num_heads, q_length,
        kv_length]`.
      precomputed_qkv: when using fused implementations QKVO are defined
        outside this module and we only use the module to run computations.
      decode: Whether to prepare and use an autoregressive cache.
      enable_dropout: Enables dropout if set to True.

    Returns:
      output of shape `[batch_sizes..., length, features]` (presumably
      `length` is the query length — confirm against implementations).
    """
class MultiHeadDotProductAttention(nn.Module, DenseAttention):
  """Multi-head dot-product attention.

  Attributes:
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    use_bias: bool: whether pointwise QKVO dense transforms use bias.
    dtype: the dtype of the computation (default: float32)
    qkv_features: dimension of the key, query, and value.
    head_dim: dimension of each head. If unspecified, it defaults to
      qkv_features // num_heads.
    out_features: dimension of the last projection
    broadcast_dropout: bool: whether to use a broadcast dropout mask.
      Presumably forwarded to `attention_fn`; the attention functions in this
      module share one mask across the query axis when this is True.
    dropout_rate: dropout rate
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    qkv_kernel_init: optional initializer for the fused qkv kernel. If None,
      kernel_init will be used instead.
    kv_kernel_init: optional initializer for the fused kv kernel. If None,
      kernel_init will be used instead.
    q_kernel_init: optional initializer for the query (q) kernel. If None,
      kernel_init will be used instead.
    bias_init: initializer for the bias of the Dense layers.
    rescale_logits: bool: whether queries are rescaled by 1/sqrt(depth) before
      computing logits; presumably forwarded to `attention_fn` (its use lives
      in `__call__`).
    attention_fn: dot_product_attention or compatible function. Accepts query,
      key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,
      num_heads, value_channels]`.
    use_extra_logit: whether to include a virtual extra logit equal to zero.
    float32_logits: bool, if True then compute logits in float32 to avoid
      numerical issues with bfloat16.
    output_projection: Project the output of `attention_fn` to `out_features`.
      If False, returns the output of `attention_fn` without a projection.
    sow_intermediates: whether to track intermediates using Module.sow.
    split_head_kernel: whether to store QKVO variables with a split head
      dimension.
    kernels_to_fuse: Which kernels to fuse, if any.
    use_rotary_embedding: whether to use rotary embeddings.
    rotary_embedding_max_timescale: maximum timescale for the rotary
      embeddings (presumably only used when `use_rotary_embedding` is True).
    sharding_over_head_dimension: whether to shard activations over the head
      dimension. Set this to False when the number of heads is not divisible
      by your activation num_partitions.
    q_conv: optional module; presumably applied to the projected queries
      (confirm in `__call__`).
    k_conv: optional module; presumably applied to the projected keys.
    v_conv: optional module; presumably applied to the projected values.
    dense_general_factory: factory for this module's dense layers, defaulting
      to `dense.DenseGeneral` (presumably used for the QKV/output projections).
  """

  num_heads: int
  use_bias: bool
  dtype: DType = jnp.float32
  qkv_features: Optional[int] = None
  head_dim: Optional[int] = None
  out_features: Optional[int] = None
  broadcast_dropout: bool = True
  dropout_rate: float = 0.0
  precision: Optional[lax.Precision] = None
  kernel_init: Initializer = (
      default_kernel_init  # pytype: disable=annotation-type-mismatch  # jax-types
  )
  qkv_kernel_init: Optional[Initializer] = None
  kv_kernel_init: Optional[Initializer] = None
  q_kernel_init: Optional[Initializer] = None
  bias_init: Initializer = initializers.zeros
  rescale_logits: bool = False
  attention_fn: Callable[[Array, Array, Array], Array] = staticmethod(
      dot_product_attention
  )
  use_extra_logit: bool = False
  float32_logits: bool = False
  output_projection: bool = True
  # TODO: Remove out_features and output_projection.
  sow_intermediates: bool = False
  split_head_kernel: bool = False
  kernels_to_fuse: Optional[str] = None
  use_rotary_embedding: bool = False
  rotary_embedding_max_timescale: float = 1e4
  # Whether to shard over the head dimension. Set this to False when the
  # number of heads is not divisible by your activation num_partitions.
  sharding_over_head_dimension: bool = True
  # Optional per-projection modules (queries, keys, values); their use lives
  # in `__call__` and is not documented further here.
  q_conv: Optional[nn.Module] = None
  k_conv: Optional[nn.Module] = None
  v_conv: Optional[nn.Module] = None
  # Factory for the dense layers built by this module.
  dense_general_factory: Callable[..., nn.Module] = dense.DenseGeneral
def update_cache_prefill(
    self,
    key: Array,
    value: Array,
    cached_key: variables.Variable,
    cached_value: variables.Variable,
    cache_index: variables.Variable,
    prefill_lengths: Array,
) -> Tuple[Array, Array, Array, Array, Array, Array]:
  """Fills the autoregressive cache with a whole prefix in one shot.

  Useful for e.g. a prefix-LM, where the encoder-like prefix is visible
  bidirectionally: its keys and values must be computed together, since a
  step-by-step fill would make that section causal.

  Side effect: `cache_index.value` is set to `prefill_lengths`.

  Args:
    key: computed keys, `[batch..., length, num_heads, features_per_head]`.
    value: computed values, `[batch..., length, num_heads,
      features_per_head]`.
    cached_key: key cache, `[batch..., num_heads, features_per_head, length]`.
    cached_value: value cache, `[batch..., num_heads, features_per_head,
      length]`.
    cache_index: variable holding the next position to fill; becomes a
      `[batch]` vector here.
    prefill_lengths: number of positions to fill per example, `[batch]`.

  Returns:
    A 6-tuple: the unmodified `key` and `value`; the new cache index (the
    value just stored into `cache_index`, i.e. the prefill lengths); the new
    key cache and value cache contents, with every position at or past each
    example's prefix length zeroed; and `prefill_lengths`. The new cache
    contents are returned because assigning to a variable inside this method
    does not work; these returns are meant to be removed eventually.
  """
  # The index becomes a [batch] vector so each example resumes right after
  # its own prefix, which may differ in length across the batch.
  cache_index.value = prefill_lengths
  next_positions = cache_index.value

  # The caches keep the sequence axis last; move it there to match.
  key_seq_last = jnp.moveaxis(key, -3, -1)
  value_seq_last = jnp.moveaxis(value, -3, -1)

  # Shape the index as [batch, 1, 1, ...]: default broadcasting would prepend
  # the singleton axes, but the batch axis must lead.
  trailing_ones = (1,) * (cached_key.value.ndim - 1)
  index_column = jnp.reshape(next_positions, (-1,) + trailing_ones)

  def _within_prefix(cache_shape):
    # True where the position along the last (sequence) axis < prefix length.
    positions = lax.broadcasted_iota(
        jnp.int32, cache_shape, len(cache_shape) - 1
    )
    return positions < index_column

  key_keep = _within_prefix(cached_key.value.shape)
  value_keep = _within_prefix(cached_value.value.shape)

  return (
      key,
      value,
      next_positions,
      key_seq_last * key_keep,
      value_seq_last * value_keep,
      prefill_lengths,
  )
def update_cache_decode(
    self,
    key: Array,
    value: Array,
    cached_key: variables.Variable,
    cached_value: variables.Variable,
    cache_index: variables.Variable,
) -> Tuple[Array, Array, Array, Array, Array, Array]:
  """Writes a single new timestep into the autoregressive cache.

  Used during step-by-step decoding, where each call supplies the key and
  value of exactly one (the next) position.

  Args:
    key: key of the new step, `[batch..., 1, num_heads, features_per_head]`.
    value: value of the new step, `[batch..., 1, num_heads,
      features_per_head]`.
    cached_key: key cache, `[batch..., num_heads, features_per_head, length]`.
    cached_value: value cache, `[batch..., num_heads, features_per_head,
      length]`.
    cache_index: position to write: `[batch]` after a prefill, or `[1]` when
      decoding step-by-step from the start.

  Returns:
    A 6-tuple: the full updated keys and values moved back to `[batch...,
    length, num_heads, features_per_head]`; the (flattened) index that was
    just written; the new key and value cache contents (sequence axis last);
    and `cache_index.value + 1`, the next position to fill. The cache contents
    are returned because assigning to a variable inside this method does not
    work; these returns are meant to be removed eventually.
  """
  max_length = cached_key.value.shape[-1]

  # Flatten the index so a scalar index and a per-example vector share one
  # code path; one_hot then yields [n, max_length] with n = 1 or batch.
  write_positions = jnp.reshape(cache_index.value, (-1,))
  position_one_hot = jax.nn.one_hot(write_positions, max_length, dtype=key.dtype)
  # Add singleton head/feature axes -> [n, 1, 1, max_length]; plain
  # broadcasting would not line up the per-example (batched) case.
  # NOTE(review): with a per-example index this layout assumes a single
  # leading batch axis on the cache.
  position_one_hot = position_one_hot[:, jnp.newaxis, jnp.newaxis, :]

  # The new step has length 1; move its sequence axis last like the caches.
  step_key = jnp.moveaxis(key, -3, -1)
  step_value = jnp.moveaxis(value, -3, -1)

  # Scatter by broadcast-and-add: the length-1 step spreads over every
  # position and the one-hot zeroes all but the target slot. This presumably
  # relies on that slot still being zero in the cache.
  new_key_cache = cached_key.value + step_key * position_one_hot
  new_value_cache = cached_value.value + step_value * position_one_hot

  return (
      jnp.moveaxis(new_key_cache, -1, -3),
      jnp.moveaxis(new_value_cache, -1, -3),
      write_positions,
      new_key_cache,
      new_value_cache,
      cache_index.value + 1,
  )
@nn.compact
def __call__(
self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
*,
precomputed_qkv: Optional[Array] = None,
decode: bool = False,
enable_dropout: bool = True,
prefill: bool = False,
prefill_lengths: Optional[Array] = None,
) -> Array:
"""Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and project the results to an output vector.
There are two modes: decoding and non-decoding (e.g., training). The mode is
determined by `decode`.
During decoding mode, this method is called twice, by `init` and
`apply`. In the former, inputs_q: [batch..., length, qkv_features] and
inputs_kv: [batch..., length, qkv_features]
During apply, query, key and value all have the shape: [batch * beam, 1,
qkv_features] where the batch dimension is added to include multiple beams.
Note that the batch dimension is different during the init and apply calls.
This is because the cached variables are directly passed-in during `apply`
method. In other words, the cache variables such as `cached_key` are
initialized with `batch` dim, expanded by tiling in the beam search function
to `batch * beam` dimension, and passed to the `apply` method as part of a
variable dict.
Args:
inputs_q: input queries of shape `[batch_sizes..., q_length, q_features]`.
inputs_kv: key/values of shape `[batch_sizes..., kv_length, kv_features]`.
mask: attention mask of shape `[batch_sizes..., num_heads, q_length,
kv_length]`.
bias: attention bias of shape `[batch_sizes..., num_heads, q_length,
kv_length]`.
precomputed_qkv: when using fused implementations QKVO are defined outside
this module and we only use the module to run computations.
decode: Whether to prepare and use an autoregressive cache.
enable_dropout: Enables dropout if set to True.
prefill: Whether to run a partial sequence to prefill the cache.
prefill_lengths: The length of each partial sequence we are filling in the
cache, lengths are inferred from the mask if not provided.
Returns:
If output_projection is True, then output of shape
`[batch_sizes..., length, out_features]`, where out_features is set to
features if not provided. If output_projection is False, then output of
shape `[batch_sizes..., length, num_heads, head_dim]`.
"""
validate_dense_attention_call_parameter_shapes(
inputs_q, inputs_kv, mask, bias, self.num_heads
)
qkv_kernel_init = (
self.qkv_kernel_init
if self.qkv_kernel_init is not None
else self.kernel_init
)
kv_kernel_init = (
self.kv_kernel_init
if self.kv_kernel_init is not None
else self.kernel_init
)
q_kernel_init = (
self.q_kernel_init
if self.q_kernel_init is not None
else self.kernel_init
)
if precomputed_qkv is not None:
raise ValueError('Support for precomputed QKVO not implemented.')
rotary_index = None
features = self.out_features or inputs_q.shape[-1]
qkv_features = self.qkv_features or inputs_q.shape[-1]
if self.head_dim is None:
head_dim = qkv_features // self.num_heads
else:
head_dim = self.head_dim
if self.kernels_to_fuse and not self.split_head_kernel:
raise ValueError(
'Un-reshaped kernels are required when using QKV fused '
'kernel optimization.'
)
# Is attention logit rescaling explicit or folded into initializer?
if self.rescale_logits:
query_init = q_kernel_init
else:
if self.kernels_to_fuse:
raise ValueError(
'Cannot fold in logit normalization to query '
'initializer when using fused kernels.'
)
depth_scaling = jnp.sqrt(head_dim).astype(self.dtype)
query_init = lambda *args: q_kernel_init(*args) / depth_scaling
make_dense = functools.partial(
self.dense_general_factory,
axis=-1,
bias_init=self.bias_init,
use_bias=self.use_bias,
dtype=self.dtype,
precision=self.precision,
reshape_kernel=not self.split_head_kernel,
)
# Project inputs_q to multi-headed q/k/v
# dimensions are then [batch..., length, num_heads, features_per_head]
if self.kernels_to_fuse is None:
query = make_dense(
kernel_init=query_init,
features=(self.num_heads, head_dim),
kernel_axis_names=['embed', 'heads', 'kv'],
name='query',
)(inputs_q)
key = make_dense(
kernel_init=self.kernel_init,
features=(self.num_heads, head_dim),
kernel_axis_names=['embed', 'heads', 'kv'],
name='key',
)(inputs_kv)
value = make_dense(
kernel_init=self.kernel_init,
features=(self.num_heads, head_dim),
kernel_axis_names=['embed', 'heads', 'kv'],
name='value',
)(inputs_kv)
# TODO: should we fuse/slice along depth or head dim?
elif self.kernels_to_fuse == 'qkv':
if inputs_q is not inputs_kv:
raise ValueError(
'qkv fusion is only supported in self-attention mode '
'(when inputs_q is inputs_kv).'
)
# 'qkv' fusion mode implies self-attention
qkv = make_dense(
kernel_init=qkv_kernel_init,
features=(3, self.num_heads, head_dim),
kernel_axis_names=['embed', 'stack', 'heads', 'kv'],
name='qkv_fused',
)(inputs_q)
query = jnp.squeeze(lax.dynamic_slice_in_dim(qkv, 0, 1, -3), -3)
key = jnp.squeeze(lax.dynamic_slice_in_dim(qkv, 1, 1, -3), -3)
value = jnp.squeeze(lax.dynamic_slice_in_dim(qkv, 2, 1, -3), -3)
elif self.kernels_to_fuse == 'kv':
query = make_dense(
kernel_init=query_init,
features=(self.num_heads, head_dim),
kernel_axis_names=['embed', 'heads', 'kv'],
name='query',
)(inputs_q)
kv = make_dense(
kernel_init=kv_kernel_init,
features=(2, self.num_heads, head_dim),
kernel_axis_names=['embed', 'stack', 'heads', 'kv'],
name='kv_fused',
)(inputs_kv)
key = jnp.squeeze(lax.dynamic_slice_in_dim(kv, 0, 1, -3), -3)
value = jnp.squeeze(lax.dynamic_slice_in_dim(kv, 1, 1, -3), -3)
else:
raise ValueError('Incorrect kernel fusion mode specified.')
# Multi Dconv Head Attention options:
if self.q_conv is not None:
query = self.q_conv( # pylint: disable=not-callable
query, decode=decode, prefill=prefill, prefill_lengths=prefill_lengths
)
if self.k_conv is not None:
key = self.k_conv( # pylint: disable=not-callable
key, decode=decode, prefill=prefill, prefill_lengths=prefill_lengths
)
if self.v_conv is not None:
value = self.v_conv( # pylint: disable=not-callable
value, decode=decode, prefill=prefill, prefill_lengths=prefill_lengths
)
if self.sharding_over_head_dimension:
# Note: We don't use `activation_partitioning.with_sharding_migration`
# here because we do often want this 2D sharded. However, if rules are
# valid, they should result in 2D sharding. We don't need to raise errors
# if both result in 2D sharding (which with_sharding_migration does).
if flax_partitioning.get_axis_rules():
query = flax_partitioning.with_sharding_constraint(
query, ('batch', 'length', 'heads', 'kv')
)
key = flax_partitioning.with_sharding_constraint(
key, ('batch', 'length', 'heads', 'kv')
)
value = flax_partitioning.with_sharding_constraint(
value, ('batch', 'length', 'heads', 'kv')
)
else:
query = activation_partitioning.with_sharding(query, 2)
key = activation_partitioning.with_sharding(key, 2)
value = activation_partitioning.with_sharding(value, 2)
query: Array = query # hint to quiet pytype.
key: Array = key
value: Array = value
if prefill and decode:
raise ValueError(
'prefill and decode cannot both be true at the same'
'time. If you are using a prefix LM with bidirectional '
'attention on the inputs, please make a call with '
'prefill=True that includes an attention mask that '
'covers your inputs first and then make your decoding '
'calls.'
)
if prefill or decode:
# Detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable('cache', 'cached_key')
# The key and value have dimension
# [batch..., length, num_heads, features_per_head], but we cache them as
# [batch..., num_heads, features_per_head, length] as a TPU fusion
# optimization. This also enable the "scatter via one-hot broadcast"
# trick, which means we do a one-hot broadcast instead of a scatter/gather
# operations, which gives a 3-4x speedup in practice.
swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
cached_key = self.variable(
'cache', 'cached_key', jnp.zeros, swap_dims(key.shape), key.dtype
)
cached_value = self.variable(
'cache',
'cached_value',
jnp.zeros,
swap_dims(value.shape),
value.dtype,
)
cache_index = self.variable(
'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32)
)
rotary_index = cache_index.value
if is_initialized:
# Here we are in "apply()".
*batch_dims, num_heads, features_per_head, length = (
cached_key.value.shape
)
if prefill:
if prefill_lengths is None:
# Figure out how far each element in the batch fills the cache based
# on the mask. We index each element in the batch, the first head
# dim (because this is always set to one), and the first query
# vector. If there is any prefix at all, the first element in the
# prefix would be part of it.
prefill_lengths = jnp.sum(mask[:, 0, 0, :], axis=-1).astype(
cache_index.value.dtype
)
(
key,
value,
cur_index,
cached_key_value,
cached_value_value,
cache_index_value,
) = self.update_cache_prefill(
key, value, cached_key, cached_value, cache_index, prefill_lengths
)
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
elif decode:
# Check the shape of the cached key against the input query.
expected_shape = tuple(batch_dims) + (1, num_heads, features_per_head)
if expected_shape != query.shape:
raise ValueError(
'Autoregressive cache shape error, '
'expected query shape %s instead got %s.'
% (expected_shape, query.shape)
)
(
key,
value,
cur_index,
cached_key_value,
cached_value_value,
cache_index_value,
) = self.update_cache_decode(
key, value, cached_key, cached_value, cache_index
)
# Enforcing the Causal mask over previous positions and selecting only
# the bias value for the current index is only needed during decode
# mode where a single example is feed at a time. In prefill mode we
# uses these as provided, that same way it is done in a normal forward
# pass, like when computing logits during training.
# Causal mask for cached decoder self-attention: our single query
# position should only attend to those key positions that have already
# been generated and cached, not the remaining zero elements.
# (1, 1, length) represent (head dim, query length, key length)
# query length is 1 because during decoding we deal with one
# index.
# The same mask is applied to all batch elements and heads.
#
# Add trailing dims to the current index so it can either
# broadcast over the batch dim or it can just be batch size.
mask = combine_masks(
mask,
jnp.broadcast_to(
jnp.arange(length), tuple(batch_dims) + (1, 1, length)
)
<= jnp.reshape(cur_index, (-1, 1, 1, 1)),
)
# Grab the correct relative attention bias during decoding. This is
# only required during single step decoding.
if bias is not None:
# The bias is a full attention matrix, but during decoding we only
# have to take a slice of it.
# This is equivalent to bias[..., cur_index:cur_index+1, :].
# If we are doing prefix decoding where cur index is a vector the
# result will be [batch, heads, 1, :]. If cur_index is a scalar
# like in encdec decoding, the result will be [1, heads, 1, :].
# We use a one-hot einsum rather than a slice to avoid introducing
# a Gather op that is currently lowered poorly by SPMD passes,
# adding expensive all-reduce and all-gather operations.