# refiners/fluxion/layers/attentions.py
import math
import torch
from jaxtyping import Float
from torch import Tensor, device as Device, dtype as DType
from torch.nn.functional import scaled_dot_product_attention as _scaled_dot_product_attention
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers.basics import Identity
from refiners.fluxion.layers.chain import Chain, Distribute, Lambda, Parallel
from refiners.fluxion.layers.linear import Linear
from refiners.fluxion.layers.module import Module


def scaled_dot_product_attention(
    query: Float[Tensor, "batch source_sequence_length dim"],
    key: Float[Tensor, "batch target_sequence_length dim"],
    value: Float[Tensor, "batch target_sequence_length dim"],
    is_causal: bool = False,
) -> Float[Tensor, "batch source_sequence_length dim"]:
    """Scaled Dot Product Attention.

    Note:
        Optimization depends on which PyTorch backend is used.

    See [[arXiv:1706.03762] Attention Is All You Need (Equation 1)](https://arxiv.org/abs/1706.03762) for more details.
    See also [torch.nn.functional.scaled_dot_product_attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
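
    Example: (a minimal usage sketch; the tensor shapes are illustrative, not prescribed by this function)
        ```py
        query = torch.randn(2, 10, 64)  # (batch, source_sequence_length, dim)
        key = torch.randn(2, 12, 64)    # (batch, target_sequence_length, dim)
        value = torch.randn(2, 12, 64)

        output = scaled_dot_product_attention(query=query, key=key, value=value)
        assert output.shape == (2, 10, 64)
        ```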
"""
return _scaled_dot_product_attention(
query=query,
key=key,
value=value,
is_causal=is_causal,
)


def scaled_dot_product_attention_non_optimized(
    query: Float[Tensor, "batch source_sequence_length dim"],
    key: Float[Tensor, "batch target_sequence_length dim"],
    value: Float[Tensor, "batch target_sequence_length dim"],
    is_causal: bool = False,
) -> Float[Tensor, "batch source_sequence_length dim"]:
    """Non-optimized Scaled Dot Product Attention.

    See [[arXiv:1706.03762] Attention Is All You Need (Equation 1)](https://arxiv.org/abs/1706.03762) for more details.
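
    Example: (a hedged sketch; note that the fixed `permute(0, 1, 3, 2)` in the implementation assumes 4-D,
        per-head inputs of shape `(batch, num_heads, sequence_length, head_dim)`, as produced by
        `ScaledDotProductAttention._split_to_multi_head`)
        ```py
        query = torch.randn(2, 8, 10, 64)
        key = torch.randn(2, 8, 12, 64)
        value = torch.randn(2, 8, 12, 64)

        output = scaled_dot_product_attention_non_optimized(query=query, key=key, value=value)
        assert output.shape == (2, 8, 10, 64)
        ```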
"""
if is_causal:
# TODO: implement causal attention
raise NotImplementedError(
"Causal attention for `scaled_dot_product_attention_non_optimized` is not yet implemented"
)
dim = query.shape[-1]
attention = query @ key.permute(0, 1, 3, 2)
attention = attention / math.sqrt(dim)
attention = torch.softmax(input=attention, dim=-1)
return attention @ value


class ScaledDotProductAttention(Module):
    """Scaled Dot Product Attention.

    ??? note "See [[arXiv:1706.03762] Attention Is All You Need (Figure 2)](https://arxiv.org/abs/1706.03762) for more details"

        ![](https://ar5iv.labs.arxiv.org/html/1706.03762/assets/Figures/ModalNet-19.png)

    Note:
        This layer simply wraps `scaled_dot_product_attention` inside an `fl.Module`.

    Receives:
        Query (Float[Tensor, "batch num_queries embedding_dim"]):
        Key (Float[Tensor, "batch num_keys embedding_dim"]):
        Value (Float[Tensor, "batch num_values embedding_dim"]):

    Returns:
        (Float[Tensor, "batch num_queries embedding_dim"]):

    Example:
        ```py
        attention = fl.ScaledDotProductAttention(num_heads=8)

        query = torch.randn(2, 10, 128)
        key = torch.randn(2, 10, 128)
        value = torch.randn(2, 10, 128)

        output = attention(query, key, value)
        assert output.shape == (2, 10, 128)
        ```
    """

    def __init__(
        self,
        num_heads: int = 1,
        is_causal: bool = False,
        is_optimized: bool = True,
        slice_size: int | None = None,
    ) -> None:
        """Initialize the Scaled Dot Product Attention layer.

        Args:
            num_heads: The number of heads to use.
            is_causal: Whether to use causal attention.
            is_optimized: Whether to use optimized attention.
            slice_size: The size of the query slices to use for sliced attention
                (reduces peak memory by processing the queries in chunks).
        """
        super().__init__()
        self.num_heads = num_heads
        self.is_causal = is_causal
        self.is_optimized = is_optimized
        self.slice_size = slice_size
        self.dot_product = (
            scaled_dot_product_attention if self.is_optimized else scaled_dot_product_attention_non_optimized
        )

    def forward(
        self,
        query: Float[Tensor, "batch num_queries embedding_dim"],
        key: Float[Tensor, "batch num_keys embedding_dim"],
        value: Float[Tensor, "batch num_values embedding_dim"],
    ) -> Float[Tensor, "batch num_queries embedding_dim"]:
        if self.slice_size:
            return self._sliced_attention(
                query=query,
                key=key,
                value=value,
                slice_size=self.slice_size,
            )
        else:
            return self._process_attention(
                query=query,
                key=key,
                value=value,
            )

    def _sliced_attention(
        self,
        query: Float[Tensor, "batch num_queries embedding_dim"],
        key: Float[Tensor, "batch num_keys embedding_dim"],
        value: Float[Tensor, "batch num_values embedding_dim"],
        slice_size: int,
    ) -> Float[Tensor, "batch num_queries embedding_dim"]:
        """Compute the scaled dot product attention in slices.

        This is useful when the input tensors are too large to be processed in one go.
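
        Example: (a hedged sketch; slicing is enabled by passing `slice_size` to the constructor, and since each
            query row attends to the keys independently, the result should match the unsliced path up to numerical
            precision)
            ```py
            attention = fl.ScaledDotProductAttention(num_heads=8)
            sliced_attention = fl.ScaledDotProductAttention(num_heads=8, slice_size=4)

            query = torch.randn(2, 10, 128)
            key = torch.randn(2, 10, 128)
            value = torch.randn(2, 10, 128)

            expected = attention(query, key, value)
            output = sliced_attention(query, key, value)
            assert torch.allclose(output, expected, atol=1e-5)
            ```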
"""
_, num_queries, _ = query.shape
output = torch.zeros_like(query)
for start_idx in range(0, num_queries, slice_size):
end_idx = min(start_idx + slice_size, num_queries)
output[:, start_idx:end_idx, :] = self._process_attention(
query=query[:, start_idx:end_idx, :],
key=key,
value=value,
)
return output

    def _process_attention(
        self,
        query: Float[Tensor, "batch num_queries embedding_dim"],
        key: Float[Tensor, "batch num_keys embedding_dim"],
        value: Float[Tensor, "batch num_values embedding_dim"],
    ) -> Float[Tensor, "batch num_queries embedding_dim"]:
        """Compute the scaled dot product attention.

        Split the input tensors (query, key, value) into multiple heads along the embedding dimension,
        then compute the scaled dot product attention for each head, and finally merge the heads back.
        """
        return self._merge_multi_head(
            x=self.dot_product(
                query=self._split_to_multi_head(query),
                key=self._split_to_multi_head(key),
                value=self._split_to_multi_head(value),
                is_causal=self.is_causal,
            )
        )

    def _split_to_multi_head(
        self,
        x: Float[Tensor, "batch_size sequence_length embedding_dim"],
    ) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
        """Split the input tensor into multiple heads along the embedding dimension.

        See also `_merge_multi_head`, which is the inverse operation.
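
        Example: (an illustrative shape sketch of the transform this method performs, here assuming `num_heads=8`)
            ```py
            x = torch.randn(2, 10, 128)  # (batch_size, sequence_length, embedding_dim)
            heads = x.reshape(2, 10, 8, 16).transpose(1, 2)
            assert heads.shape == (2, 8, 10, 16)  # (batch_size, num_heads, sequence_length, embedding_dim // num_heads)
            ```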
"""
assert (
x.ndim == 3
), f"Expected input tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
assert (
x.shape[-1] % self.num_heads == 0
), f"Expected embedding_dim (x.shape[-1]={x.shape[-1]}) to be divisible by num_heads ({self.num_heads})"
return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)

    def _merge_multi_head(
        self,
        x: Float[Tensor, "batch_size num_heads sequence_length heads_dim"],
    ) -> Float[Tensor, "batch_size sequence_length (heads_dim*num_heads)"]:
        """Merge the input tensor from multiple heads along the embedding dimension.

        See also `_split_to_multi_head`, which is the inverse operation.
        """
        return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], self.num_heads * x.shape[-1])


class Attention(Chain):
    """Multi-Head Attention layer.

    ??? note "See [[arXiv:1706.03762] Attention Is All You Need (Figure 2)](https://arxiv.org/abs/1706.03762) for more details"

        ![](https://ar5iv.labs.arxiv.org/html/1706.03762/assets/Figures/ModalNet-20.png)

    Note: This layer simply chains
        - a [`Distribute`][refiners.fluxion.layers.chain.Distribute] layer,
          containing 3 [`Linear`][refiners.fluxion.layers.linear.Linear] layers,
          which transforms the 3 inputs into Query, Key and Value
        - a [`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer
        - a [`Linear`][refiners.fluxion.layers.linear.Linear] layer,
          which projects the output of the
          [`ScaledDotProductAttention`][refiners.fluxion.layers.attentions.ScaledDotProductAttention] layer

    Receives:
        Query (Float[Tensor, "batch sequence_length embedding_dim"]):
        Key (Float[Tensor, "batch sequence_length embedding_dim"]):
        Value (Float[Tensor, "batch sequence_length embedding_dim"]):

    Returns:
        (Float[Tensor, "batch sequence_length embedding_dim"]):

    Example:
        ```py
        attention = fl.Attention(num_heads=8, embedding_dim=128)

        tensor = torch.randn(2, 10, 128)
        output = attention(tensor, tensor, tensor)

        assert output.shape == (2, 10, 128)
        ```
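
    Example: (a hedged sketch of cross-attention with a different key/value embedding size; shapes are illustrative)
        ```py
        cross_attention = fl.Attention(
            embedding_dim=128,
            key_embedding_dim=64,
            value_embedding_dim=64,
            num_heads=8,
        )

        query = torch.randn(2, 10, 128)
        context = torch.randn(2, 77, 64)

        output = cross_attention(query, context, context)
        assert output.shape == (2, 10, 128)
        ```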
"""

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int = 1,
        key_embedding_dim: int | None = None,
        value_embedding_dim: int | None = None,
        inner_dim: int | None = None,
        use_bias: bool = True,
        is_causal: bool = False,
        is_optimized: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        """Initialize the Attention layer.

        Args:
            embedding_dim: The embedding dimension of the input and output tensors.
            num_heads: The number of heads to use.
            key_embedding_dim: The embedding dimension of the key tensor.
            value_embedding_dim: The embedding dimension of the value tensor.
            inner_dim: The inner dimension of the linear layers.
            use_bias: Whether to use bias in the linear layers.
            is_causal: Whether to use causal attention.
            is_optimized: Whether to use optimized attention.
            device: The device to use.
            dtype: The dtype to use.
        """
        assert (
            embedding_dim % num_heads == 0
        ), f"embedding_dim {embedding_dim} must be divisible by num_heads {num_heads}"
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.heads_dim = embedding_dim // num_heads
        self.key_embedding_dim = key_embedding_dim or embedding_dim
        self.value_embedding_dim = value_embedding_dim or embedding_dim
        self.inner_dim = inner_dim or embedding_dim
        self.use_bias = use_bias
        self.is_causal = is_causal
        self.is_optimized = is_optimized
        super().__init__(
            Distribute(
                Linear(  # Query projection
                    in_features=self.embedding_dim,
                    out_features=self.inner_dim,
                    bias=self.use_bias,
                    device=device,
                    dtype=dtype,
                ),
                Linear(  # Key projection
                    in_features=self.key_embedding_dim,
                    out_features=self.inner_dim,
                    bias=self.use_bias,
                    device=device,
                    dtype=dtype,
                ),
                Linear(  # Value projection
                    in_features=self.value_embedding_dim,
                    out_features=self.inner_dim,
                    bias=self.use_bias,
                    device=device,
                    dtype=dtype,
                ),
            ),
            ScaledDotProductAttention(
                num_heads=num_heads,
                is_causal=is_causal,
                is_optimized=is_optimized,
            ),
            Linear(  # Output projection
                in_features=self.inner_dim,
                out_features=self.embedding_dim,
                bias=True,
                device=device,
                dtype=dtype,
            ),
        )


class SelfAttention(Attention):
    """Multi-Head Self-Attention layer.

    Note: This layer simply chains
        - a [`Parallel`][refiners.fluxion.layers.chain.Parallel] layer,
          which duplicates the input Tensor
          (for each Linear layer in the `Attention` layer)
        - an [`Attention`][refiners.fluxion.layers.attentions.Attention] layer

    Receives:
        (Float[Tensor, "batch sequence_length embedding_dim"]):

    Returns:
        (Float[Tensor, "batch sequence_length embedding_dim"]):

    Example:
        ```py
        self_attention = fl.SelfAttention(num_heads=8, embedding_dim=128)

        tensor = torch.randn(2, 10, 128)
        output = self_attention(tensor)

        assert output.shape == (2, 10, 128)
        ```
    """

    def __init__(
        self,
        embedding_dim: int,
        inner_dim: int | None = None,
        num_heads: int = 1,
        use_bias: bool = True,
        is_causal: bool = False,
        is_optimized: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        """Initialize the Self-Attention layer.

        Args:
            embedding_dim: The embedding dimension of the input and output tensors.
            inner_dim: The inner dimension of the linear layers.
            num_heads: The number of heads to use.
            use_bias: Whether to use bias in the linear layers.
            is_causal: Whether to use causal attention.
            is_optimized: Whether to use optimized attention.
            device: The device to use.
            dtype: The dtype to use.
        """
        super().__init__(
            embedding_dim=embedding_dim,
            inner_dim=inner_dim,
            num_heads=num_heads,
            use_bias=use_bias,
            is_causal=is_causal,
            is_optimized=is_optimized,
            device=device,
            dtype=dtype,
        )
        self.insert(
            index=0,
            module=Parallel(
                Identity(),  # Query projection's input
                Identity(),  # Key projection's input
                Identity(),  # Value projection's input
            ),
        )


class SelfAttention2d(SelfAttention):
    """Multi-Head 2D Self-Attention layer.

    Note: This Module simply chains
        - a [`Lambda`][refiners.fluxion.layers.chain.Lambda] layer,
          which transforms the input Tensor into a sequence
        - a [`SelfAttention`][refiners.fluxion.layers.attentions.SelfAttention] layer
        - a [`Lambda`][refiners.fluxion.layers.chain.Lambda] layer,
          which transforms the output sequence into a 2D Tensor

    Receives:
        (Float[Tensor, "batch channels height width"]):

    Returns:
        (Float[Tensor, "batch channels height width"]):

    Example:
        ```py
        self_attention = fl.SelfAttention2d(channels=128, num_heads=8)

        tensor = torch.randn(2, 128, 64, 64)
        output = self_attention(tensor)

        assert output.shape == (2, 128, 64, 64)
        ```
    """

    def __init__(
        self,
        channels: int,
        num_heads: int = 1,
        use_bias: bool = True,
        is_causal: bool = False,
        is_optimized: bool = True,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        """Initialize the 2D Self-Attention layer.

        Args:
            channels: The number of channels of the input and output tensors.
            num_heads: The number of heads to use.
            use_bias: Whether to use bias in the linear layers.
            is_causal: Whether to use causal attention.
            is_optimized: Whether to use optimized attention.
            device: The device to use.
            dtype: The dtype to use.
        """
        assert channels % num_heads == 0, f"channels {channels} must be divisible by num_heads {num_heads}"
        self.channels = channels
        super().__init__(
            embedding_dim=channels,
            num_heads=num_heads,
            use_bias=use_bias,
            is_causal=is_causal,
            is_optimized=is_optimized,
            device=device,
            dtype=dtype,
        )
        self.insert(0, Lambda(self._tensor_2d_to_sequence))
        self.append(Lambda(self._sequence_to_tensor_2d))

    def init_context(self) -> Contexts:
        return {
            "reshape": {
                "height": None,
                "width": None,
            }
        }

    def _tensor_2d_to_sequence(
        self,
        x: Float[Tensor, "batch channels height width"],
    ) -> Float[Tensor, "batch height*width channels"]:
        """Transform a 2D Tensor into a sequence.

        The height and width of the input Tensor are stored in a `"reshape"` context,
        so that the output Tensor can be transformed back into a 2D Tensor in the `_sequence_to_tensor_2d` method.
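
        Example: (an illustrative shape sketch of the transform this method performs; the context handling is omitted)
            ```py
            x = torch.randn(2, 128, 4, 4)  # (batch, channels, height, width)
            sequence = x.reshape(2, 128, 16).transpose(1, 2)
            assert sequence.shape == (2, 16, 128)  # one token per spatial location
            ```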
"""
height, width = x.shape[-2:]
self.set_context(
context="reshape",
value={
"height": height,
"width": width,
},
)
return x.reshape(x.shape[0], x.shape[1], height * width).transpose(1, 2)

    def _sequence_to_tensor_2d(
        self,
        x: Float[Tensor, "batch sequence_length channels"],
    ) -> Float[Tensor, "batch channels height width"]:
        """Transform a sequence into a 2D Tensor.

        The height and width of the output Tensor are retrieved from the `"reshape"` context,
        which was set in the `_tensor_2d_to_sequence` method.
        """
        height, width = self.use_context("reshape").values()
        return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], height, width)