relative_mha.py
"""
---
title: Relative Multi-Headed Attention
summary: >
Documented implementation with explanations of
Relative Multi-Headed Attention from paper Transformer-XL.
---
# Relative Multi-Headed Attention
This is an implementation of relative multi-headed attention from paper
[Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860)
in [PyTorch](https://pytorch.org).
"""

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention


def shift_right(x: torch.Tensor):
    """
    This method shifts the $i^{th}$ row of a matrix by $i$ columns.

    If the input is `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]`, the shifted
    result would be `[[1, 2, 3], [0, 4, 5], [6, 0, 7]]`.
    *Ideally we should mask out the lower triangle, but it's fine for our purpose.*
    """

    # Concatenate a column of zeros
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)

    # Reshape and remove excess elements from the end
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    x = x_padded[:-1].view_as(x)

    #
    return x
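

def _shift_right_naive(x: torch.Tensor) -> torch.Tensor:
    """
    A loop-based sketch of what `shift_right` computes, added for
    illustration; this helper is not part of the original module.

    On and above the diagonal, row $i$ is shifted right by $i$ columns,
    i.e. `out[i, j] == x[i, j - i]` for `j >= i`. Below the diagonal this
    sketch writes zeros, whereas the reshape trick above leaves leftover
    values there; that lower triangle is the part that would ideally be
    masked out.
    """
    # Start from zeros so the lower triangle is explicitly empty
    out = torch.zeros_like(x)
    # Row `i` takes its own values, moved `i` columns to the right
    for i in range(x.shape[0]):
        for j in range(i, x.shape[1]):
            out[i, j] = x[i, j - i]
    return out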


class RelativeMultiHeadAttention(MultiHeadAttention):
    """
    ## Relative Multi-Head Attention Module

    We override the [Multi-Head Attention](mha.html) module so we only need to
    write the `get_scores` method.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        # The linear transformations do not need a bias since we
        # explicitly include it when calculating scores.
        # However, having a bias for `value` might make sense.
        super().__init__(heads, d_model, dropout_prob, bias=False)

        # Number of relative positions
        self.P = 2 ** 12

        # Relative positional embeddings for keys relative to the query.
        # We need $2P$ embeddings because the keys can be before or after the query.
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
        # Relative positional embedding bias for keys relative to the query.
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
        # Positional embeddings for the query are independent of the position of the query
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        r"""
        ### Get relative attention scores

        With absolute attention

        \begin{align}
        A^{abs}_{i,j} &= lin_q(X^q_i + P_i)^\top lin_k(X^k_j + P_j) \\
                      &= \underset{\textcolor{lightgreen}{A}}{Q_i^\top K_j} +
                         \underset{\textcolor{lightgreen}{B}}{Q_i^\top U^K_j} +
                         \underset{\textcolor{lightgreen}{C}}{{U^Q_i}^\top K_j} +
                         \underset{\textcolor{lightgreen}{D}}{{U^Q_i}^\top U^K_j}
        \end{align}

        where $Q_i$ and $K_j$ are linear transformations of the
        original embeddings $X^q_i, X^k_j$,
        and $U^Q_i, U^K_j$ are linear transformations of the
        absolute positional encodings $P_i, P_j$.

        They reason that the attention to a given key should be the same regardless of
        the position of the query.
        Hence they replace $\underset{\textcolor{lightgreen}{C}}{{U^Q_i}^\top K_j}$
        with a constant $\underset{\textcolor{lightgreen}{C}}{\textcolor{orange}{v^\top} K_j}$.

        For the second and fourth terms, relative positional encodings are introduced.
        So $\underset{\textcolor{lightgreen}{B}}{Q_i^\top U^K_j}$ is
        replaced with $\underset{\textcolor{lightgreen}{B}}{Q_i^\top \textcolor{orange}{R_{i - j}}}$
        and $\underset{\textcolor{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$
        with $\underset{\textcolor{lightgreen}{D}}{\textcolor{orange}{S_{i-j}}}$.

        \begin{align}
        A^{rel}_{i,j} &= \underset{\mathbf{\textcolor{lightgreen}{A}}}{Q_i^\top K_j} +
                         \underset{\mathbf{\textcolor{lightgreen}{B}}}{Q_i^\top \textcolor{orange}{R_{i - j}}} +
                         \underset{\mathbf{\textcolor{lightgreen}{C}}}{\textcolor{orange}{v^\top} K_j} +
                         \underset{\mathbf{\textcolor{lightgreen}{D}}}{\textcolor{orange}{S_{i-j}}}
        \end{align}
        """
        # $\textcolor{orange}{R_k}$
        key_pos_emb = self.key_pos_embeddings[self.P - key.shape[0]:self.P + query.shape[0]]
        # $\textcolor{orange}{S_k}$
        key_pos_bias = self.key_pos_bias[self.P - key.shape[0]:self.P + query.shape[0]]
        # $\textcolor{orange}{v^\top}$
        query_pos_bias = self.query_pos_bias[None, None, :, :]

        # ${(\textcolor{lightgreen}{\mathbf{A + C}})}_{i,j} =
        # Q_i^\top K_j +
        # \textcolor{orange}{v^\top} K_j$
        ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
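        # Following the einsum output indices `ijbh`, `ac` has shape
        # `[query_len, key_len, batch_size, heads]`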

        # $\textcolor{lightgreen}{\mathbf{B'}_{i,k}} = Q_i^\top \textcolor{orange}{R_k}$
        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
        # $\textcolor{lightgreen}{\mathbf{D'}_{i,k}} = \textcolor{orange}{S_k}$
        d = key_pos_bias[None, :, None, :]

        # Shift the rows of $\textcolor{lightgreen}{\mathbf{(B' + D')}_{i,k}}$
        # to get $$\textcolor{lightgreen}{\mathbf{(B + D)}_{i,j} = \mathbf{(B' + D')}_{i,i - j}}$$
        bd = shift_right(b + d)
        # Remove extra positions
        bd = bd[:, -key.shape[0]:]
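        # `bd` now also has shape `[query_len, key_len, batch_size, heads]`,
        # so it can be added to `ac` element-wise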

        # Return the sum $$
        # \underset{\mathbf{\textcolor{lightgreen}{A}}}{Q_i^\top K_j} +
        # \underset{\mathbf{\textcolor{lightgreen}{B}}}{Q_i^\top \textcolor{orange}{R_{i - j}}} +
        # \underset{\mathbf{\textcolor{lightgreen}{C}}}{\textcolor{orange}{v^\top} K_j} +
        # \underset{\mathbf{\textcolor{lightgreen}{D}}}{\textcolor{orange}{S_{i-j}}}
        # $$
        return ac + bd
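

def _test_relative_mha():
    """
    A minimal usage sketch, added for illustration (not part of the original
    module). It assumes the parent `MultiHeadAttention.forward` accepts
    `query`, `key` and `value` keyword arguments of shape
    `[seq_len, batch_size, d_model]` and returns a tensor of the same shape.
    """
    seq_len, batch_size, d_model = 8, 2, 64
    # Four heads, so each head works with $d_k = 16$ dimensional vectors
    mha = RelativeMultiHeadAttention(heads=4, d_model=d_model)
    x = torch.randn(seq_len, batch_size, d_model)
    # Self-attention: query, key and value are all the same sequence
    out = mha(query=x, key=x, value=x)
    inspect(out.shape)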


def _test_shift_right():
    x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    inspect(x)
    inspect(shift_right(x))

    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])

    x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
    inspect(x[:, :, 0, 0])
    inspect(shift_right(x)[:, :, 0, 0])


if __name__ == '__main__':
    _test_shift_right()
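    # Also exercise the added usage sketch
    _test_relative_mha()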