-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
ann.py
366 lines (321 loc) · 15.2 KB
/
ann.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ANN (Approximate Nearest Neighbor) computes top-k with a configurable recall rate.

This package only optimizes the TPU backend. For other device types it falls
back to sort and slice.

Usage::

  import functools
  import jax

  # MIPS := maximal inner product search
  # Inputs:
  #   qy: f32[qy_size, feature_dim]
  #   db: f32[db_size, feature_dim]
  #
  # Returns:
  #   (f32[qy_size, k], i32[qy_size, k])
  @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
  def mips(qy, db, k=10, recall_target=0.95):
    dists = jax.lax.dot(qy, db.transpose())
    # Computes max_k along the last dimension
    # returns (f32[qy_size, k], i32[qy_size, k])
    return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)

  # Multi-core example
  # Inputs:
  #   qy: f32[num_devices, qy_size, feature_dim]
  #   db: f32[num_devices, per_device_db_size, feature_dim]
  #   db_offset: i32[num_devices]
  #   db_size = num_devices * per_device_db_size
  #
  # Returns:
  #   (f32[qy_size, num_devices, k], i32[qy_size, num_devices, k])
  @functools.partial(
      jax.pmap,
      # static args: db_size, k, recall_target
      static_broadcasted_argnums=[3, 4, 5],
      out_axes=(1, 1))
  def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
    dists = jax.lax.dot(qy, db.transpose())
    dists, neighbors = jax.lax.approx_max_k(
        dists, k=k, recall_target=recall_target,
        reduction_input_size_override=db_size)
    return (dists, neighbors + db_offset)

  # i32[qy_size, num_devices, k]
  pmap_neighbors = pmap_mips(qy, db, db_offset, db_size, 10, 0.95)[1]
  # i32[qy_size, num_devices * k]
  neighbors = jax.lax.collapse(pmap_neighbors, start_dimension=1, stop_dimension=3)

Todos::
* On host top-k aggregation
* Inaccurate but fast differentiation
"""
from functools import partial
from typing import (Any, Tuple)
import numpy as np
from jax import core
from jax._src.lax import lax
from jax._src.lib import xla_client as xc
from jax._src import ad_util, dtypes
from jax.interpreters import ad, xla, batching
# Loose alias used only in the type annotations of this module's public
# functions; it performs no runtime type checking.
Array = Any
def approx_max_k(operand: Array,
                 k: int,
                 reduction_dimension: int = -1,
                 recall_target: float = 0.95,
                 reduction_input_size_override: int = -1,
                 aggregate_to_topk: bool = True) -> Tuple[Array, Array]:
  """Returns max ``k`` values and their indices of the ``operand`` in an approximate manner.

  See https://arxiv.org/abs/2206.14286 for the algorithm details.

  Args:
    operand : Array to search for max-k. Must be a floating number type.
    k : Specifies the number of max-k.
    reduction_dimension : Integer dimension along which to search. Default: -1.
    recall_target : Recall target for the approximation.
    reduction_input_size_override : When set to a positive value, it overrides
      the size determined by ``operand.shape[reduction_dimension]`` for
      evaluating the recall. This option is useful when the given ``operand``
      is only a subset of the overall computation in SPMD or distributed
      pipelines, where the true input size cannot be inferred from the
      ``operand`` shape.
    aggregate_to_topk : When true, aggregates approximate results to the top-k
      in sorted order. When false, returns the approximate results unsorted. In
      this case, the number of the approximate results is implementation
      defined and is greater than or equal to the specified ``k``.

  Returns:
    Tuple of two arrays. The arrays are the max ``k`` values and the
    corresponding indices along the ``reduction_dimension`` of the input
    ``operand``. The arrays' dimensions are the same as the input ``operand``
    except for the ``reduction_dimension``: when ``aggregate_to_topk`` is true,
    the reduction dimension is ``k``; otherwise, it is greater than or equal
    to ``k`` where the size is implementation-defined.

  We encourage users to wrap ``approx_max_k`` with jit. See the following
  example for maximal inner product search (MIPS):

  >>> import functools
  >>> import jax
  >>> import numpy as np
  >>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
  ... def mips(qy, db, k=10, recall_target=0.95):
  ...   dists = jax.lax.dot(qy, db.transpose())
  ...   # returns (f32[qy_size, k], i32[qy_size, k])
  ...   return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
  >>>
  >>> qy = jax.numpy.array(np.random.rand(50, 64))
  >>> db = jax.numpy.array(np.random.rand(1024, 64))
  >>> dot_products, neighbors = mips(qy, db, k=10)
  """
  return approx_top_k_p.bind(
      operand,
      k=k,
      reduction_dimension=reduction_dimension,
      recall_target=recall_target,
      is_max_k=True,
      reduction_input_size_override=reduction_input_size_override,
      aggregate_to_topk=aggregate_to_topk)
def approx_min_k(operand: Array,
                 k: int,
                 reduction_dimension: int = -1,
                 recall_target: float = 0.95,
                 reduction_input_size_override: int = -1,
                 aggregate_to_topk: bool = True) -> Tuple[Array, Array]:
  """Returns min ``k`` values and their indices of the ``operand`` in an approximate manner.

  See https://arxiv.org/abs/2206.14286 for the algorithm details.

  Args:
    operand : Array to search for min-k. Must be a floating number type.
    k : Specifies the number of min-k.
    reduction_dimension : Integer dimension along which to search. Default: -1.
    recall_target : Recall target for the approximation.
    reduction_input_size_override : When set to a positive value, it overrides
      the size determined by ``operand.shape[reduction_dimension]`` for
      evaluating the recall. This option is useful when the given ``operand``
      is only a subset of the overall computation in SPMD or distributed
      pipelines, where the true input size cannot be inferred from the
      ``operand`` shape.
    aggregate_to_topk : When true, aggregates approximate results to the top-k
      in sorted order. When false, returns the approximate results unsorted. In
      this case, the number of the approximate results is implementation
      defined and is greater than or equal to the specified ``k``.

  Returns:
    Tuple of two arrays. The arrays are the least ``k`` values and the
    corresponding indices along the ``reduction_dimension`` of the input
    ``operand``. The arrays' dimensions are the same as the input ``operand``
    except for the ``reduction_dimension``: when ``aggregate_to_topk`` is true,
    the reduction dimension is ``k``; otherwise, it is greater than or equal
    to ``k`` where the size is implementation-defined.

  We encourage users to wrap ``approx_min_k`` with jit. See the following
  example for nearest neighbor search over the squared l2 distance:

  >>> import functools
  >>> import jax
  >>> import numpy as np
  >>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
  ... def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95):
  ...   dists = half_db_norms - jax.lax.dot(qy, db.transpose())
  ...   return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target)
  >>>
  >>> qy = jax.numpy.array(np.random.rand(50, 64))
  >>> db = jax.numpy.array(np.random.rand(1024, 64))
  >>> half_db_norms = jax.numpy.linalg.norm(db, axis=1)**2 / 2
  >>> dists, neighbors = l2_ann(qy, db, half_db_norms, k=10)

  In the example above, we compute ``db^2/2 - dot(qy, db^T)`` instead of
  ``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reasons. Dropping the
  per-query constant ``qy^2`` and halving the remaining terms does not change
  the ordering along the database axis, so the former uses less arithmetic
  and produces the same set of neighbors. Note that ``db^2`` is the *squared*
  norm, hence the ``**2`` above.
  """
  return approx_top_k_p.bind(
      operand,
      k=k,
      reduction_dimension=reduction_dimension,
      recall_target=recall_target,
      is_max_k=False,
      reduction_input_size_override=reduction_input_size_override,
      aggregate_to_topk=aggregate_to_topk)
def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
recall_target, is_max_k,
reduction_input_size_override,
aggregate_to_topk):
if k <= 0:
raise ValueError(f'k must be positive, got {k}')
if len(operand.shape) == 0:
raise TypeError('approx_top_k operand must have >= 1 dimension, got {}'.format(
operand.shape))
dims = list(operand.shape)
if dims[reduction_dimension] < k:
raise ValueError(
'k must be smaller than the size of reduction_dim {}, got {}'.format(
dims[reduction_dimension], k))
if not dtypes.issubdtype(operand.dtype, np.floating):
raise ValueError('operand must be a floating type')
reduction_input_size = dims[reduction_dimension]
dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize(
reduction_input_size, len(dims), k, recall_target, aggregate_to_topk,
reduction_input_size_override)[0]
return (operand.update(
shape=dims, dtype=operand.dtype, weak_type=operand.weak_type),
operand.update(shape=dims, dtype=np.dtype(np.int32)))
def _comparator_builder(op_type, is_max_k):
  """Builds the XLA comparator computation used to rank (value, index) pairs.

  The computation declares four scalar parameters: two values of ``op_type``
  (parameters 0 and 1) and two int32 indices (parameters 2 and 3). Only the
  values are compared -- ``p0 > p1`` for max-k, ``p0 < p1`` for min-k; the
  index parameters are declared but unused.
  """
  direction = 'gt' if is_max_k else 'lt'
  builder = xc.XlaBuilder('top_k_{}_comparator'.format(direction))
  value_shape = xc.Shape.scalar_shape(op_type)
  index_shape = xc.Shape.scalar_shape(np.dtype(np.int32))
  lhs = xla.parameter(builder, 0, value_shape)
  rhs = xla.parameter(builder, 1, value_shape)
  for param_number in (2, 3):
    xla.parameter(builder, param_number, index_shape)
  compare = xc.ops.Gt if is_max_k else xc.ops.Lt
  return builder.build(compare(lhs, rhs))
def _get_init_val_literal(op_type, is_max_k):
return np.array(np.NINF if is_max_k else np.Inf, dtype=op_type)
def _approx_top_k_operands(c, operand, reduction_dimension, is_max_k):
  """Builds the XLA inputs shared by the ApproxTopK and fallback lowerings.

  Returns ``(operands, init_values, reduction_dimension, comparator)``:
  ``operands`` pairs ``operand`` with an int32 iota along the reduction
  dimension (so indices travel with the values through the reduction),
  ``init_values`` are the matching initial constants (+/-inf and -1),
  and ``reduction_dimension`` has been normalized to a non-negative axis.

  Raises:
    ValueError: if ``operand`` does not have an array shape.
  """
  op_shape = c.get_shape(operand)
  if not op_shape.is_array():
    raise ValueError(f'operand must be an array, but was {op_shape}')
  op_dims = op_shape.dimensions()
  op_type = op_shape.element_type()
  if reduction_dimension < 0:
    reduction_dimension = len(op_dims) + reduction_dimension
  comparator = _comparator_builder(op_type, is_max_k)
  iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims),
                     reduction_dimension)
  init_val = xc.ops.Constant(c, _get_init_val_literal(op_type, is_max_k))
  init_arg = xc.ops.Constant(c, np.int32(-1))
  return [operand, iota], [init_val, init_arg], reduction_dimension, comparator


def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k,
                                  reduction_dimension, recall_target, is_max_k,
                                  reduction_input_size_override,
                                  aggregate_to_topk):
  """TPU lowering of ``approx_top_k_p`` via ``xc.ops.ApproxTopK``.

  Returns the destructured (values, indices) XLA ops.
  """
  c = ctx.builder
  operands, init_values, reduction_dimension, comparator = (
      _approx_top_k_operands(c, operand, reduction_dimension, is_max_k))
  out = xc.ops.ApproxTopK(c, operands, init_values, k, reduction_dimension,
                          comparator, recall_target, aggregate_to_topk,
                          reduction_input_size_override)
  return xla.xla_destructure(c, out)


def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
                                       reduction_dimension, recall_target,
                                       is_max_k, reduction_input_size_override,
                                       aggregate_to_topk):
  """Non-TPU lowering of ``approx_top_k_p`` via ``xc.ops.ApproxTopKFallback``.

  Returns the destructured (values, indices) XLA ops.
  """
  c = ctx.builder
  operands, init_values, reduction_dimension, comparator = (
      _approx_top_k_operands(c, operand, reduction_dimension, is_max_k))
  out = xc.ops.ApproxTopKFallback(c, operands, init_values, k,
                                  reduction_dimension, comparator,
                                  recall_target, aggregate_to_topk,
                                  reduction_input_size_override)
  return xla.xla_destructure(c, out)
def _approx_top_k_batch_rule(batch_operands, batch_axes, *, k,
                             reduction_dimension, recall_target, is_max_k,
                             reduction_input_size_override, aggregate_to_topk):
  """Batching (vmap) rule for ``approx_top_k_p``.

  ``reduction_dimension`` is given in the coordinates of the unbatched
  operand, so it is remapped onto the batched operand's axes by skipping the
  batch axis. Both outputs report the batch axis at the operand's position.
  """
  assert len(batch_operands) == 1
  assert len(batch_axes) == 1
  operand, = batch_operands
  batch_axis, = batch_axes
  if batch_axis < 0:
    batch_axis += operand.ndim
  # Compare axes by value: `is not` on ints only works through CPython's
  # small-int cache and silently fails for e.g. NumPy integer axes.
  dim_map = [d for d in range(operand.ndim) if d != batch_axis]
  reduction_dimension = dim_map[reduction_dimension]
  return approx_top_k_p.bind(
      operand,
      k=k,
      reduction_dimension=reduction_dimension,
      recall_target=recall_target,
      is_max_k=is_max_k,
      reduction_input_size_override=reduction_input_size_override,
      aggregate_to_topk=aggregate_to_topk), (batch_axis, batch_axis)
# JVP rule implemented with a (slow) gather.
#
# TODO(fchern): Some optimization ideas
# 1. ApproxTopK is internally a variadic reduce, so we can simply call
#    ApproxTopK(operand, tangent, iota) for jvp.
# 2. vjp cannot benefit from the algorithm above. We must run scatter to
#    distribute the output cotangent to input cotangent. A reasonable way to do
#    this is to run it on CPU.
def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension,
                      recall_target, is_max_k, reduction_input_size_override,
                      aggregate_to_topk):
  """JVP rule for ``approx_top_k_p``.

  Recomputes the primal outputs through ``approx_max_k`` / ``approx_min_k``.
  The value tangent is gathered from the input tangent at the selected
  indices; the integer index output always gets a symbolic zero tangent.
  """
  (operand,), (tangent,) = primals, tangents
  select_top_k = approx_max_k if is_max_k else approx_min_k
  val_out, arg_out = select_top_k(operand, k, reduction_dimension,
                                  recall_target,
                                  reduction_input_size_override,
                                  aggregate_to_topk)
  primal_out = (val_out, arg_out)
  index_tangent = ad_util.Zero.from_value(arg_out)
  if type(tangent) is ad_util.Zero:
    return primal_out, (ad_util.Zero.from_value(val_out), index_tangent)

  out_shape = arg_out.shape
  ndim = len(out_shape)
  axis = reduction_dimension + ndim if reduction_dimension < 0 else reduction_dimension
  # Index every axis with an identity iota, except the reduction axis, which is
  # indexed by the selected positions.
  index_parts = [
      lax.broadcasted_iota(arg_out.dtype, out_shape, d) for d in range(ndim)
  ]
  index_parts[axis] = arg_out
  return primal_out, (tangent[tuple(index_parts)], index_tangent)
# The single primitive behind both approx_max_k (is_max_k=True) and
# approx_min_k (is_max_k=False).
approx_top_k_p = core.Primitive('approx_top_k')
# bind() yields a (values, indices) pair, matching the two abstract values
# produced by _approx_top_k_abstract_eval.
approx_top_k_p.multiple_results = True
# Eager (non-jit) calls are dispatched through xla.apply_primitive.
approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p))
approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
# Default lowering for all platforms, then a TPU-specific override that uses
# the native ApproxTopK op.
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,
                         platform='tpu')
# vmap and forward-mode autodiff rules.
batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule
ad.primitive_jvps[approx_top_k_p] = _approx_top_k_jvp