-
Notifications
You must be signed in to change notification settings - Fork 0
/
attention.py
175 lines (136 loc) · 5.03 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
r"""Attention layers.

This module exposes the :class:`MultiheadAttention` layer, which is built on
top of the module-level scaled dot-product :func:`attention` helper.
"""
__all__ = [
    'MultiheadAttention',
]
import jax
import math
from einops import rearrange
from jax import Array
from typing import *
from .linear import Linear
from .module import Module
def attention(
    q: Array,
    k: Array,
    v: Array,
    mask: Array = None,
) -> Array:
    r"""Computes the scaled dot-product attention.

    .. math:: \mathrm{attention}(Q, K, V) =
        \mathrm{softmax}\left( \frac{Q K^T}{\sqrt{C}} \right) V

    Arguments:
        q: The query tensor :math:`Q`, with shape :math:`(*, S, C)`.
        k: The key tensor :math:`K`, with shape :math:`(*, T, C)`.
        v: The value tensor :math:`V`, with shape :math:`(*, T, C')`.
        mask: A boolean attention mask, broadcastable to :math:`(*, S, T)`.
            A :py:`False` value indicates that the corresponding attention weight
            is set to the most negative finite value of its dtype (effectively
            :math:`-\infty`). A query whose keys are all masked attends uniformly
            to every key instead of producing NaNs.

    Returns:
        The output tensor :math:`Y`, with shape :math:`(*, S, C')`.
    """

    C = q.shape[-1]

    weight = jax.numpy.einsum('...ik,...jk->...ij', q, k)
    weight = weight / math.sqrt(C)

    if mask is not None:
        # A fixed constant such as -1e9 overflows to -inf in float16, and a row
        # that is entirely -inf makes the softmax return NaN. The most negative
        # finite value of the dtype is representable in every float dtype and
        # still drives the masked probabilities to zero.
        fill = jax.numpy.finfo(weight.dtype).min
        weight = jax.numpy.where(mask, weight, fill)

    attn = jax.nn.softmax(weight, axis=-1)

    return jax.numpy.einsum('...ij,...jk->...ik', attn, v)
class MultiheadAttention(Module):
    r"""Creates a multihead attention layer.

    .. math:: Y = \sum_i
        \mathrm{attention}(X_q W_q^i + b_q^i, X_k W_k^i + b_k^i, X_v W_v^i) W_y^i

    where

    .. math:: \mathrm{attention}(Q, K, V) =
        \mathrm{softmax}\left( \frac{Q K^T}{\sqrt{H}} \right) V

    denotes the scaled dot-product attention.

    References:
        | Attention Is All You Need (Vaswani et al., 2017)
        | https://arxiv.org/abs/1706.03762

    Arguments:
        key: A PRNG key for initialization.
        heads: The number of attention heads.
        in_features: The number of input features :math:`C`.
        hid_features: The number of hidden features :math:`H` per head.
        out_features: The number of output features :math:`C'`.
            If :py:`None`, :math:`C' = C`.
        bias: Whether the layer learns additive biases :math:`(b_q, b_k)` or not.
        causal: Whether the attention mask is causal or not. If :py:`True`, the
            :math:`i`-th query is only allowed to attend the :math:`j`-th key if
            :math:`j - i \leq T - S`.
        dropout: The dropout rate on attention weights. Dropped weights are
            masked out before the softmax, which renormalizes the kept ones.
    """

    def __init__(
        self,
        key: Array,
        heads: int,
        in_features: int,
        hid_features: int,
        out_features: int = None,
        bias: bool = True,
        causal: bool = False,
        dropout: float = 0.0,
    ):
        keys = jax.random.split(key, 4)

        if out_features is None:
            out_features = in_features

        # Each projection produces `heads` groups of `hid_features` features.
        self.lin_q = Linear(keys[0], in_features, hid_features * heads, bias=bias)
        self.lin_k = Linear(keys[1], in_features, hid_features * heads, bias=bias)
        self.lin_v = Linear(keys[2], in_features, hid_features * heads, bias=False)
        self.lin_y = Linear(keys[3], hid_features * heads, out_features, bias=False)

        self.heads = heads
        self.causal = causal
        self.dropout = dropout

    def __call__(
        self,
        xq: Array,
        xk: Array = None,
        xv: Array = None,
        mask: Array = None,
        key: Array = None,
    ) -> Array:
        r"""
        Arguments:
            xq: The query tensor :math:`X_q`, with shape :math:`(*, S, C)`.
            xk: The key tensor :math:`X_k`, with shape :math:`(*, T, C)`.
                If :py:`None`, :math:`X_k = X_q`.
            xv: The value tensor :math:`X_v`, with shape :math:`(*, T, C)`.
                If :py:`None`, :math:`X_v = X_k`.
            mask: A boolean attention mask. It is applied to the per-head
                attention weights, whose shape is :math:`(*, N, S, T)` with
                :math:`N` the number of heads, so it must broadcast against that
                shape, e.g. :math:`(S, T)` or :math:`(*, 1, S, T)`.
                A :py:`False` value indicates that the corresponding attention
                weight is masked out (set to a very large negative value before
                the softmax).
            key: A PRNG key. If :py:`None`, dropout is not applied.

        Returns:
            The output tensor :math:`Y`, with shape :math:`(*, S, C')`.
        """

        if xk is None:
            xk = xq

        if xv is None:
            xv = xk

        S, T = xq.shape[-2], xk.shape[-2]

        # Project
        q = self.lin_q(xq)
        k = self.lin_k(xk)
        v = self.lin_v(xv)

        # Split the last axis into `heads` groups (N) of `hid_features`
        # features (H) each: (*, L, N * H) -> (*, N, L, H).
        q, k, v = [
            rearrange(x, '... L (N H) -> ... N L H', N=self.heads)
            for x in (q, k, v)
        ]

        # Mask
        if self.causal:
            if mask is None:
                mask = jax.numpy.ones((S, T), dtype=bool)

            # Keep key j for query i only if j - i <= T - S (last two axes).
            mask = jax.numpy.tril(mask, T - S)

        # Dropout is implemented by randomly masking attention weights. With a
        # rate of 0 every weight would be kept, so sampling is skipped.
        if key is not None and self.dropout > 0:
            shape = jax.numpy.broadcast_shapes(
                (*q.shape[:-2], S, 1),
                (*k.shape[:-2], 1, T),
                (S, T) if mask is None else mask.shape,
            )

            keep = jax.random.bernoulli(key, p=1 - self.dropout, shape=shape)

            if mask is None:
                mask = keep
            else:
                mask = jax.numpy.logical_and(mask, keep)

        # Attention, then merge heads back: (*, N, L, H) -> (*, L, N * H).
        y = attention(q, k, v, mask)
        y = rearrange(y, '... N L H -> ... L (N H)')
        y = self.lin_y(y)

        return y