# en_transformer.py
# 294 lines (229 loc) · 9.1 KB
# (Web-page residue from the repository viewer -- notification banner and fork
# count -- carried no content and has been reduced to this header.)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helper functions
def exists(val):
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True
def batched_index_select(values, indices, dim = 1):
    """Gather entries of ``values`` along axis ``dim`` using a batch of indices.

    ``indices`` receives one trailing singleton axis per axis of ``values``
    that follows ``dim`` and is broadcast over them; ``values`` receives new
    broadcast axes at position ``dim`` so that both tensors line up for
    ``gather``.

    Args:
        values: tensor to select from.
        indices: integer tensor of positions along ``dim``; presumably its
            leading axes match the leading axes of ``values`` -- confirm
            against callers.
        dim: non-negative axis of ``values`` to index into.

    Returns:
        Tensor of shape ``indices.shape + values.shape[dim + 1:]``.
    """
    trailing_dims = values.shape[(dim + 1):]
    index_rank = indices.dim()

    # give the indices one broadcast axis per trailing value axis
    trailing_nones = (None,) * len(trailing_dims)
    indices = indices[(..., *trailing_nones)]
    indices = indices.expand(*((-1,) * index_rank), *trailing_dims)

    # axes the indices have beyond ``dim`` must be mirrored in ``values``
    extra_axes = index_rank - (dim + 1)
    leading_slices = (slice(None),) * dim
    values = values[(*leading_slices, *((None,) * extra_axes), ...)]

    target_shape = [-1] * values.dim()
    window = slice(dim, dim + extra_axes)
    target_shape[window] = indices.shape[window]
    values = values.expand(*target_shape)

    # the axis being gathered from has shifted right by the inserted axes
    return values.gather(dim + extra_axes, indices)
def fourier_encode_dist(x, num_encodings = 4, include_self = True):
    """Encode each entry of ``x`` with sines and cosines at power-of-two scales.

    A new last axis is added to ``x``. Along it the output holds
    ``sin(x / 2**k)`` for ``k = 0 .. num_encodings - 1``, followed by the
    matching cosines, followed by the raw value itself when ``include_self``
    is true.

    Args:
        x: input tensor.
        num_encodings: number of scales.
        include_self: whether to append the unencoded value.

    Returns:
        Tensor with a new last axis of size ``2 * num_encodings`` (plus one
        when ``include_self`` is true).
    """
    raw = x.unsqueeze(-1)
    exponents = torch.arange(num_encodings, device = raw.device, dtype = raw.dtype)
    scaled = raw / (2 ** exponents)

    pieces = [scaled.sin(), scaled.cos()]
    if include_self:
        pieces.append(raw)
    return torch.cat(pieces, dim = -1)
# classes
class Residual(nn.Module):
    """Wrap ``fn`` so that its outputs are added onto its inputs.

    ``fn`` is called as ``fn(feats, coors, **kwargs)`` and must return a pair
    ``(feature_update, coordinate_update)``; each update is summed with the
    corresponding input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, feats, coors, **kwargs):
        """Return ``(feats + feature_update, coors + coordinate_update)``."""
        delta_feats, delta_coors = self.fn(feats, coors, **kwargs)
        new_feats = feats + delta_feats
        new_coors = coors + delta_coors
        return new_feats, new_coors
class ScaleNorm(nn.Module):
    """Rescale vectors along the last axis to unit L2 length times a learned gain.

    Args:
        dim: accepted for interface symmetry; not used by this module.
        eps: lower bound applied to the norm before dividing, so all-zero
            vectors do not divide by zero.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        # single scalar gain shared by every feature
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Return ``x`` divided by its clamped last-axis norm and scaled by ``g``."""
        length = torch.norm(x, dim = -1, keepdim = True)
        length = length.clamp(min = self.eps)
        unit = x / length
        return unit * self.g
class PreNorm(nn.Module):
    """Normalise the features with ``ScaleNorm`` before handing them to ``fn``.

    Coordinates are passed to ``fn`` unchanged, and ``fn``'s
    ``(feats, coors)`` result is returned as-is.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ScaleNorm(dim)

    def forward(self, feats, coors, **kwargs):
        """Apply the norm to ``feats`` and forward everything to ``fn``."""
        normed_feats = self.norm(feats)
        feats_out, coors_out = self.fn(normed_feats, coors, **kwargs)
        return feats_out, coors_out
class FeedForward(nn.Module):
    """Two-layer GELU feed-forward network applied to the features only.

    Args:
        dim: input and output feature size.
        mult: hidden width as a multiple of ``dim``.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(
        self,
        *,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        # previously hard-coded to ``dim * 4``, which silently ignored ``mult``
        inner_dim = int(dim * mult)
        # the three-layer layout is kept so existing ``net.0`` / ``net.2``
        # state-dict keys continue to load
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Linear(inner_dim, dim)
        )
        # kept outside ``net`` (and parameter-free) so checkpoints are unaffected;
        # with the default ``dropout = 0.`` this is a no-op
        self.dropout = nn.Dropout(dropout)

    def forward(self, feats, coors):
        """Return ``(transformed_feats, 0)``.

        ``coors`` is accepted only to match the ``(feats, coors)`` calling
        convention; the second return value ``0`` is the coordinate update.
        """
        project_in, activation, project_out = self.net
        hidden = self.dropout(activation(project_in(feats)))
        return project_out(hidden), 0
class EquivariantAttention(nn.Module):
    """Attention layer that updates features and produces a coordinate update.

    For every (query, key) pair an edge message is computed by an MLP from the
    query, the key, the (optionally Fourier-encoded) pairwise distance and any
    supplied edge features. The messages yield both the attention logits over
    the values and scalar weights on the relative coordinates.

    Args:
        dim: feature size of the input and output.
        dim_head: size of each attention head.
        heads: number of attention heads.
        edge_dim: size of the optional edge features (0 when none are passed).
        m_dim: size of the per-head edge message.
        fourier_features: number of Fourier scales for encoding distances;
            0 uses the raw distance only.
        norm_rel_coors: L2-normalise relative coordinates before weighting.
        norm_coor_weights: squash the coordinate weights with ``tanh``.
        num_nearest_neighbors: when > 0, attend only to that many nearest
            points (by Euclidean distance) instead of all points.
        init_eps: standard deviation used to initialise all Linear weights.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 4,
        edge_dim = 0,
        m_dim = 16,
        fourier_features = 0,
        norm_rel_coors = False,
        norm_coor_weights = False,
        num_nearest_neighbors = 0,
        init_eps = 1e-3
    ):
        super().__init__()
        self.fourier_features = fourier_features
        self.num_nearest_neighbors = num_nearest_neighbors

        attn_inner_dim = heads * dim_head
        self.heads = heads
        self.to_qkv = nn.Linear(dim, attn_inner_dim * 3, bias = False)
        self.to_out = nn.Linear(attn_inner_dim, dim)

        # query + key + distance encoding (2 * fourier_features sin/cos + raw distance) + edges
        edge_input_dim = (fourier_features * 2) + (dim_head * 2) + edge_dim + 1

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_input_dim * 2),
            nn.ReLU(),
            nn.Linear(edge_input_dim * 2, m_dim),
            nn.ReLU()
        )

        self.to_attn_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim * 4),
            nn.ReLU(),
            nn.Linear(m_dim * 4, 1),
            Rearrange('... () -> ...')
        )

        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim * heads, m_dim * 4),
            nn.ReLU(),
            nn.Linear(m_dim * 4, 1),
            Rearrange('... () -> ...'),
            # the module is spelled ``nn.Tanh``; ``nn.TanH`` does not exist and
            # raised AttributeError whenever ``norm_coor_weights`` was enabled
            nn.Tanh() if norm_coor_weights else nn.Identity()
        )

        self.norm_rel_coors = norm_rel_coors

        self.init_eps = init_eps
        self.apply(self.init_)

    def init_(self, module):
        """Draw the weights of every ``nn.Linear`` from N(0, init_eps); biases are left as-is."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std = self.init_eps)

    def forward(
        self,
        feats,
        coors,
        edges = None,
        mask = None
    ):
        """Compute the feature output and the coordinate update.

        Args:
            feats: rank-3 tensor ``(batch, points, dim)``.
            coors: rank-3 tensor ``(batch, points, coord_dim)``.
            edges: optional rank-4 tensor ``(batch, points, points, edge_dim)``.
            mask: optional ``(batch, points)`` mask; it is inverted with ``~``,
                so it is expected to be boolean with ``True`` marking valid
                points.

        Returns:
            ``(out, coors_out)`` where ``out`` is ``(batch, points, dim)`` and
            ``coors_out`` is the weighted sum of relative coordinates,
            ``(batch, points, coord_dim)``.
        """
        _, n, _ = feats.shape
        h, fourier_features = self.heads, self.fourier_features
        # topk raises if asked for more neighbours than there are points
        num_nn = min(self.num_nearest_neighbors, n)

        rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(coors, 'b j d -> b () j d')
        rel_dist = rel_coors.norm(dim = -1, p = 2)

        nbhd_indices = None
        if num_nn > 0:
            nbhd_indices = rel_dist.topk(num_nn, dim = -1, largest = False).indices

        rel_dist = rearrange(rel_dist, 'b i j -> b i j ()')

        if fourier_features > 0:
            rel_dist = fourier_encode_dist(rel_dist, num_encodings = fourier_features)
            rel_dist = rearrange(rel_dist, 'b i j () d -> b i j d')

        rel_dist = repeat(rel_dist, 'b i j d -> b h i j d', h = h)

        # derive queries keys and values

        q, k, v = self.to_qkv(feats).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # calculate nearest neighbors

        j = n

        if exists(nbhd_indices):
            j = nbhd_indices.shape[-1]
            nbhd_indices_with_heads = repeat(nbhd_indices, 'b n d -> b h n d', h = h)
            k = batched_index_select(k, nbhd_indices_with_heads, dim = 2)
            v = batched_index_select(v, nbhd_indices_with_heads, dim = 2)
            rel_dist = batched_index_select(rel_dist, nbhd_indices_with_heads, dim = 3)
            rel_coors = batched_index_select(rel_coors, nbhd_indices, dim = 2)
        else:
            k = repeat(k, 'b h j d -> b h n j d', n = n)

        # prepare mask

        if exists(mask):
            q_mask = rearrange(mask, 'b i -> b () i ()')
            k_mask = repeat(mask, 'b j -> b i j', i = n)

            if exists(nbhd_indices):
                k_mask = batched_index_select(k_mask, nbhd_indices, dim = 2)

            k_mask = rearrange(k_mask, 'b i j -> b () i j')

            # a pair is valid only when both its query and its key are valid
            mask = q_mask * k_mask

        # expand queries and keys for concatting

        q = repeat(q, 'b h i d -> b h i n d', n = j)

        edge_input = torch.cat((q, k, rel_dist), dim = -1)

        if exists(edges):
            if exists(nbhd_indices):
                edges = batched_index_select(edges, nbhd_indices, dim = 2)

            edges = repeat(edges, 'b i j d -> b h i j d', h = h)
            edge_input = torch.cat((edge_input, edges), dim = -1)

        m_ij = self.edge_mlp(edge_input)

        # coordinate weights are computed from the messages of all heads together
        coor_mlp_input = rearrange(m_ij, 'b h i j d -> b i j (h d)')
        coor_weights = self.coors_mlp(coor_mlp_input)

        if exists(mask):
            coor_mask = rearrange(mask, 'b () i j -> b i j')
            # out-of-place: an in-place fill would overwrite the tanh output
            # that autograd keeps for the backward pass
            coor_weights = coor_weights.masked_fill(~coor_mask, 0.)

        if self.norm_rel_coors:
            rel_coors = F.normalize(rel_coors, dim = -1, p = 2)

        coors_out = einsum('b i j, b i j c -> b i c', coor_weights, rel_coors)

        # derive attention

        sim = self.to_attn_mlp(m_ij)

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            sim = sim.masked_fill(~mask, max_neg_value)

        attn = sim.softmax(dim = -1)

        # weighted sum of values and combine heads

        # with neighbour selection the values carry an extra per-query axis
        aggregate_einsum_note = 'b h i j, b h j d -> b h i d' if not exists(nbhd_indices) else 'b h i j, b h i j d -> b h i d'
        out = einsum(aggregate_einsum_note, attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        return out, coors_out
# transformer
class EnTransformer(nn.Module):
    """Stack of residual, pre-normalised attention and feed-forward blocks.

    Each of the ``depth`` layers applies an ``EquivariantAttention`` block
    followed by a ``FeedForward`` block, both wrapped in
    ``Residual(PreNorm(...))``.

    Args:
        dim: feature size.
        depth: number of (attention, feed-forward) layers.
        num_tokens: when given, inputs are token ids embedded to ``dim``.
        dim_head, heads, edge_dim, m_dim, fourier_features,
        num_nearest_neighbors, norm_rel_coors, norm_coor_weights, init_eps:
            forwarded unchanged to every ``EquivariantAttention``.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        num_tokens = None,
        dim_head = 64,
        heads = 8,
        edge_dim = 0,
        m_dim = 16,
        fourier_features = 4,
        num_nearest_neighbors = 0,
        norm_rel_coors = False,
        norm_coor_weights = False,
        init_eps = 1e-3
    ):
        super().__init__()

        if exists(num_tokens):
            self.token_emb = nn.Embedding(num_tokens, dim)
        else:
            self.token_emb = None

        # identical settings are shared by the attention block of every layer
        attn_kwargs = dict(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            m_dim = m_dim,
            edge_dim = edge_dim,
            fourier_features = fourier_features,
            norm_rel_coors = norm_rel_coors,
            norm_coor_weights = norm_coor_weights,
            num_nearest_neighbors = num_nearest_neighbors,
            init_eps = init_eps
        )

        stacked = []
        for _ in range(depth):
            attn_block = Residual(PreNorm(dim, EquivariantAttention(**attn_kwargs)))
            ff_block = Residual(PreNorm(dim, FeedForward(dim = dim)))
            stacked.append(nn.ModuleList([attn_block, ff_block]))
        self.layers = nn.ModuleList(stacked)

        self.num_nearest_neighbors = num_nearest_neighbors

    def forward(
        self,
        feats,
        coors,
        edges = None,
        mask = None
    ):
        """Run every layer in order and return the final ``(feats, coors)``.

        ``edges`` and ``mask`` are passed to the attention blocks only.
        """
        if exists(self.token_emb):
            feats = self.token_emb(feats)

        # main network

        for layer in self.layers:
            attn, ff = layer
            feats, coors = attn(feats, coors, edges = edges, mask = mask)
            feats, coors = ff(feats, coors)

        return feats, coors