-
Notifications
You must be signed in to change notification settings - Fork 45
/
attention.py
287 lines (244 loc) · 11.4 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
# Copyright 2024 The Penzai Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base attention dataflow combinators.
This module contains basic primitives for attention operations in Transformer
neural networks. These primitives are intentionally as simple as possible,
and do not include the actual initialization logic or attention weight
computation. Instead, they abstract away the core dataflow patterns across
training and kv-cache inference modes.
"""
from __future__ import annotations
import dataclasses
from typing import Any
import jax
import jax.numpy as jnp
from penzai.core import layer as layer_base
from penzai.core import named_axes
from penzai.core import struct
from penzai.data_effects import local_state
from penzai.data_effects import side_input
@struct.pytree_dataclass
class ApplyAttentionMask(layer_base.Layer):
  """Masks attention logits using a mask retrieved from a side input.

  The mask arrives through the ``mask`` side input as a boolean array that
  must broadcast against the input logits. Wherever the mask is False, the
  logit is replaced with ``masked_out_value``, which is usually a large (but
  finite) negative number so that the masked position receives a negligible
  attention weight in a numerically stable way.

  Attributes:
    mask: Side input providing the boolean attention mask to apply to the
      input attention scores; must be broadcastable with the input.
    masked_out_value: Replacement value for masked-out positions.
  """
  mask: side_input.SideInputEffect[named_axes.NamedArray]
  masked_out_value: jax.typing.ArrayLike

  @classmethod
  def from_config(
      cls,
      mask_tag: Any,
      masked_out_value: jax.typing.ArrayLike = -2.3819763e38,
  ) -> ApplyAttentionMask:
    """Builds an ``ApplyAttentionMask`` from a side-input tag and fill value.

    Args:
      mask_tag: Tag for the mask side input; used to identify the side
        inputs that correspond to the same attention mask throughout the
        model.
      masked_out_value: Value substituted for masked-out logits. Usually a
        large (but finite) negative value, so that it maps to a negligible
        attention weight in a numerically stable way.

    Returns:
      A new ``ApplyAttentionMask`` layer with the given configuration.
    """
    mask_request = side_input.SideInputRequest(tag=mask_tag)
    return cls(mask=mask_request, masked_out_value=masked_out_value)

  def __call__(self, x: named_axes.NamedArray) -> named_axes.NamedArray:
    """Applies the attention mask to the input array.

    Args:
      x: The logits to mask, usually the matrix of query-key dot products.

    Returns:
      ``x`` with every position where the mask is False replaced by
      ``masked_out_value``.
    """
    mask_value = self.mask.ask()
    # nmap lifts jnp.where to operate over named axes with broadcasting.
    masked_where = named_axes.nmap(jnp.where)
    return masked_where(mask_value, x, self.masked_out_value)
@struct.pytree_dataclass
class Attention(layer_base.Layer):
  """A basic attention combinator.

  An attention layer is built from five subcomputations: three that project
  the input into queries, keys, and values, one that combines queries and
  keys into attention weights, and one that combines attention weights and
  values into an output. This combinator captures only that dataflow
  pattern; the actual computations are delegated entirely to the sublayers.

  Attributes:
    input_to_query: A layer that maps the input to an array of queries.
    input_to_key: A layer that maps the input to an array of keys.
    input_to_value: A layer that maps the input to an array of values.
    query_key_to_attn: A layer that maps a tuple of (queries, keys) to
      attention weights.
    attn_value_to_output: A layer that maps a tuple of (attention weights,
      values) to a final output.
  """
  input_to_query: layer_base.LayerLike
  input_to_key: layer_base.LayerLike
  input_to_value: layer_base.LayerLike
  query_key_to_attn: layer_base.LayerLike
  attn_value_to_output: layer_base.LayerLike

  def __call__(self, x: named_axes.NamedArray) -> named_axes.NamedArray:
    """Runs the attention computation.

    Args:
      x: The input to the computation, which will be mapped to queries,
        keys, and values by the sublayers.

    Returns:
      The final output of the ``attn_value_to_output`` sublayer.
    """
    # Project the shared input three ways, then fold the pieces together.
    queries = self.input_to_query(x)
    keys = self.input_to_key(x)
    values = self.input_to_value(x)
    weights = self.query_key_to_attn((queries, keys))
    return self.attn_value_to_output((weights, values))
@struct.pytree_dataclass
class KVCachingAttention(layer_base.Layer):
  """Key/value caching variant of `Attention`.

  ``KVCachingAttention`` is a drop-in replacement for `Attention`, but adds
  key/value caching logic using Penzai's effect system. This means that a model
  initially configured for training can be quickly adapted to do inference
  without making the training logic more complicated.

  Attributes:
    input_to_query: A layer that maps the input to an array of queries, usually
      taken from the original `Attention` layer.
    input_to_key: A layer that maps the input to an array of keys, usually taken
      from the original `Attention` layer. The output of this layer will
      additionally be stored in the stateful key/value cache.
    input_to_value: A layer that maps the input to an array of values, usually
      taken from the original `Attention` layer. The output of this layer will
      additionally be stored in the stateful key/value cache.
    query_key_to_attn: A layer that maps a tuple of ``(queries, keys)`` to
      attention weights, usually taken from the original `Attention` layer. The
      key input will contain the full key cache, rather than the slice produced
      for the current token.
    attn_value_to_output: A layer that maps a tuple of ``(attention weights,
      values)`` to a final output, usually taken from the original `Attention`
      layer. The value input will contain the full value cache, rather than the
      slice produced for the current token.
    sequence_axis: The axis along which to do key/value caching. Should be an
      axis name that appears in the output of the ``input_to_key`` and
      ``input_to_value`` sublayers.
    kv_cache_end_index: A side input that identifies the current dynamic size of
      the key/value caches, i.e. the number of elements that have been populated
      with entries. Should be populated by a scalar integer array.
    kv_cache: A state effect variable that stores a tuple of key and value
      caches. This will be initialized when this layer is constructed, and will
      be updated as it runs.
  """
  input_to_query: layer_base.LayerLike
  input_to_key: layer_base.LayerLike
  input_to_value: layer_base.LayerLike
  query_key_to_attn: layer_base.LayerLike
  attn_value_to_output: layer_base.LayerLike
  # Static configuration: "pytree_node": False keeps the axis name out of the
  # pytree leaves, so JAX transformations treat it as metadata, not data.
  sequence_axis: str = dataclasses.field(metadata={"pytree_node": False})
  kv_cache_end_index: side_input.SideInputEffect[jax.Array]
  kv_cache: local_state.LocalStateEffect[
      tuple[named_axes.NamedArray, named_axes.NamedArray]
  ]

  def __call__(self, x: named_axes.NamedArray) -> named_axes.NamedArray:
    """Runs the caching attention computation and update the K/V cache state.

    When called, ``self.kv_cache_end_index`` should be filled with
    a scalar integer identifying the current size of the cache (before inserting
    this token), and ``self.kv_cache`` should be a `LocalState` that contains
    the current state.

    Args:
      x: The input to the computation, which will be mapped to queries, keys,
        and values by the sublayers.

    Returns:
      The final output of the ``attn_value_to_output`` sublayer.
    """
    # Retrieve effectful inputs.
    kvc_end_index = self.kv_cache_end_index.ask()
    key_cache, value_cache = self.kv_cache.get()
    # Compute queries, keys, and values as normal.
    query = self.input_to_query(x)
    key = self.input_to_key(x)
    value = self.input_to_value(x)
    # Update the KV caches. Untagging the sequence axis exposes it as the
    # single positional axis that jax.lax.dynamic_update_slice writes along
    # (at offset kvc_end_index); nmap vectorizes the update over every other
    # named axis, and the result is re-tagged with the sequence axis name.
    new_key_cache = named_axes.nmap(jax.lax.dynamic_update_slice)(
        key_cache.untag(self.sequence_axis),
        key.untag(self.sequence_axis),
        (kvc_end_index,),
    ).tag(self.sequence_axis)
    new_value_cache = named_axes.nmap(jax.lax.dynamic_update_slice)(
        value_cache.untag(self.sequence_axis),
        value.untag(self.sequence_axis),
        (kvc_end_index,),
    ).tag(self.sequence_axis)
    self.kv_cache.set((new_key_cache, new_value_cache))
    # Run the rest on the updated KV caches (so this token can attend to the
    # full populated prefix, including itself).
    attn = self.query_key_to_attn((query, new_key_cache))
    output = self.attn_value_to_output((attn, new_value_cache))
    return output

  @classmethod
  def from_uncached(
      cls,
      original: Attention,
      sequence_axis: str,
      cache_len: int,
      cached_axes: dict[str, int],
      cache_end_index_tag: side_input.Tag,
      state_category: local_state.Category,
      cache_dtype: jax.typing.DTypeLike = jnp.float32,
  ) -> KVCachingAttention:
    """Builds a caching attention from an uncached attention.

    Args:
      original: The original attention layer that this block should replace.
      sequence_axis: The axis along which keys and values should be cached.
        Should be present in the output of the ``input_to_key`` and
        ``input_to_value`` sublayers.
      cache_len: Length of the cache; used to populate the initial state.
      cached_axes: Axis names and sizes for all other axes of the key and value
        arrays (e.g. for batch, heads, and the projected embeddings). These are
        used to initialize the cache.
      cache_end_index_tag: Side input tag for the cache position side input.
        This should be used to identify the side inputs that should receive the
        cache position information, and should (usually) be provided to the
        `pz.de.WithSideInputsFromInputTuple` handler that actually provides this
        side input.
      state_category: Category for the local state. This should be used to
        identify the state variables that correspond to key-value caches in the
        model, and should (usually) be provided to the
        `pz.de.handle_local_states` call that functionalizes the state effect.
      cache_dtype: Dtype for the data to store in the cache. Should match the
        dtype of the key and value arrays.

    Returns:
      A ``KVCachingAttention`` instance that behaves like the original
      `Attention` layer, but updates key-value caches iteratively, using new
      side input and state effect requests.
    """
    def kv_cache_initializer():
      # Key and value caches start as the same zero-filled array; JAX arrays
      # are immutable, so sharing one buffer here is safe.
      empty_cache = named_axes.zeros(
          {**cached_axes, sequence_axis: cache_len},
          dtype=cache_dtype,
      )
      return (empty_cache, empty_cache)

    return cls(
        input_to_query=original.input_to_query,
        input_to_key=original.input_to_key,
        input_to_value=original.input_to_value,
        query_key_to_attn=original.query_key_to_attn,
        attn_value_to_output=original.attn_value_to_output,
        sequence_axis=sequence_axis,
        kv_cache_end_index=side_input.SideInputRequest(cache_end_index_tag),
        kv_cache=local_state.InitialLocalStateRequest(
            kv_cache_initializer, category=state_category
        ),
    )