-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
ufunc_api.py
330 lines (288 loc) · 13.4 KB
/
ufunc_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools to create numpy-style ufuncs."""
_AT_INPLACE_WARNING = """\
Because JAX arrays are immutable, jnp.ufunc.at() cannot operate inplace like
np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g.
>>> arr = jnp.add.at(arr, ind, val, inplace=False)
"""
from functools import partial
import operator
from typing import Any, Callable, Optional
import jax
from jax._src.lax import lax as lax_internal
from jax._src.numpy import reductions
from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take
from jax._src.numpy.reductions import _moveaxis
from jax._src.numpy.util import _wraps, check_arraylike, _broadcast_to, _where
from jax._src.numpy.vectorize import vectorize
from jax._src.util import canonicalize_axis
import numpy as np
def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> Optional[jax.core.Primitive]:
  """Return the single primitive that ``fun(*args)`` lowers to, if any.

  If the jaxpr of ``fun(*args)`` consists of exactly one equation whose inputs
  and outputs coincide with the jaxpr's own inputs and outputs, return that
  equation's primitive; otherwise return None. Trivial ``pjit`` wrappers (with
  no sharding specified) are unwrapped and their inner jaxpr inspected instead.

  Args:
    fun: callable to be traced with ``jax.make_jaxpr``.
    *args: example arguments used for tracing.

  Returns:
    The matching ``jax.core.Primitive``, or None.
  """
  try:
    jaxpr = jax.make_jaxpr(fun)(*args)
  except Exception:
    # Tracing can fail for many reasons (non-traceable input, data-dependent
    # control flow, ...); any failure simply means "not a single primitive".
    # A bare `except:` here would also swallow KeyboardInterrupt/SystemExit.
    return None
  while len(jaxpr.eqns) == 1:
    eqn = jaxpr.eqns[0]
    if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars):
      return None
    elif (eqn.primitive == jax._src.pjit.pjit_p and
          all(jax._src.pjit.is_unspecified(sharding) for sharding in
              (*eqn.params['in_shardings'], *eqn.params['out_shardings']))):
      # A pjit with fully unspecified shardings is a transparent wrapper:
      # descend into the jaxpr it closes over and keep looking.
      jaxpr = eqn.params['jaxpr']
    else:
      return eqn.primitive
  return None
# Fast paths used by ufunc.reduce(): when the ufunc lowers to one of these
# primitives, dispatch to the corresponding whole-array reduction instead of
# the generic scan-based fallback.
_primitive_reducers = {
lax_internal.add_p: reductions.sum,
lax_internal.mul_p: reductions.prod,
}
# Same idea for ufunc.accumulate(): primitive -> cumulative reduction.
_primitive_accumulators = {
lax_internal.add_p: reductions.cumsum,
lax_internal.mul_p: reductions.cumprod,
}
class ufunc:
  """Functions that operate element-by-element on whole arrays.

  This is a class for LAX-backed implementations of numpy ufuncs.
  """
  def __init__(self, func, /, nin, nout, *, name=None, nargs=None, identity=None):
    # We want ufunc instances to work properly when marked as static,
    # and for this reason it's important that their properties not be
    # mutated. We prevent this by storing them in a dunder attribute,
    # and accessing them via read-only properties.
    self.__name__ = name or func.__name__
    self.__static_props = {
      'func': func,
      'call': vectorize(func),
      'nin': operator.index(nin),
      'nout': operator.index(nout),
      'nargs': operator.index(nargs or nin),
      'identity': identity
    }

  # Read-only views of the immutable static properties.
  _func = property(lambda self: self.__static_props['func'])
  _call = property(lambda self: self.__static_props['call'])
  nin = property(lambda self: self.__static_props['nin'])
  nout = property(lambda self: self.__static_props['nout'])
  nargs = property(lambda self: self.__static_props['nargs'])
  identity = property(lambda self: self.__static_props['identity'])

  def __hash__(self):
    # Do not include _call, because it is computed from _func.
    return hash((self._func, self.__name__, self.identity,
                 self.nin, self.nout, self.nargs))

  def __eq__(self, other):
    # Do not include _call, because it is computed from _func.
    return isinstance(other, ufunc) and (
      (self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs) ==
      (other._func, other.__name__, other.identity, other.nin, other.nout, other.nargs))

  def __repr__(self):
    return f"<jnp.ufunc '{self.__name__}'>"

  def __call__(self, *args, out=None, where=None, **kwargs):
    # `out` and `where` exist for numpy signature compatibility only.
    if out is not None:
      raise NotImplementedError(f"out argument of {self}")
    if where is not None:
      raise NotImplementedError(f"where argument of {self}")
    return self._call(*args, **kwargs)

  @_wraps(np.ufunc.reduce, module="numpy.ufunc")
  @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
  def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False, initial=None, where=None):
    check_arraylike(f"{self.__name__}.reduce", a)
    if self.nin != 2:
      raise ValueError("reduce only supported for binary ufuncs")
    if self.nout != 1:
      raise ValueError("reduce only supported for functions returning a single value")
    if out is not None:
      raise NotImplementedError(f"out argument of {self.__name__}.reduce()")
    if initial is not None:
      check_arraylike(f"{self.__name__}.reduce", initial)
    if where is not None:
      check_arraylike(f"{self.__name__}.reduce", where)
      if self.identity is None and initial is None:
        raise ValueError(f"reduction operation {self.__name__!r} does not have an identity, "
                         "so to use a where mask one has to specify 'initial'.")
      if lax_internal._dtype(where) != bool:
        raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}")
    # If the ufunc lowers to a single known primitive (e.g. add/mul), use the
    # fast whole-array reduction; otherwise fall back to a scan-based loop.
    primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
    reducer = _primitive_reducers.get(primitive, self._reduce_via_scan)
    return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)

  def _reduce_via_scan(self, arr, axis=0, dtype=None, keepdims=False, initial=None, where=None):
    # Generic fallback reduction: move the reduced axis to the front, then
    # fold it with a fori_loop applying the ufunc one slice at a time.
    assert self.nin == 2 and self.nout == 1
    arr = lax_internal.asarray(arr)
    if initial is None:
      initial = self.identity
    if dtype is None:
      # Infer the output dtype by abstract evaluation of a scalar application.
      dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype
    if where is not None:
      where = _broadcast_to(where, arr.shape)
    if isinstance(axis, tuple):
      axis = tuple(canonicalize_axis(a, arr.ndim) for a in axis)
      raise NotImplementedError("tuple of axes")
    elif axis is None:
      if keepdims:
        final_shape = (1,) * arr.ndim
      else:
        final_shape = ()
      arr = arr.ravel()
      if where is not None:
        where = where.ravel()
      axis = 0
    else:
      axis = canonicalize_axis(axis, arr.ndim)
      if keepdims:
        final_shape = (*arr.shape[:axis], 1, *arr.shape[axis + 1:])
      else:
        final_shape = (*arr.shape[:axis], *arr.shape[axis + 1:])
      # TODO: handle without transpose?
      if axis != 0:
        arr = _moveaxis(arr, axis, 0)
        if where is not None:
          where = _moveaxis(where, axis, 0)

    # `initial` can only still be None here if self.identity is also None.
    if initial is None and arr.shape[0] == 0:
      # BUGFIX: was a plain string (missing f-prefix, so the name never
      # interpolated) and misspelled "ideneity".
      raise ValueError(f"zero-size array to reduction operation {self.__name__} which has no identity")

    def body_fun(i, val):
      if where is None:
        return self._call(val, arr[i].astype(dtype))
      else:
        # Masked-out slices leave the accumulator unchanged.
        return _where(where[i], self._call(val, arr[i].astype(dtype)), val)

    if initial is None:
      # No identity: seed with the first slice and start the loop at 1.
      start_index = 1
      start_value = arr[0]
    else:
      start_index = 0
      start_value = initial
    start_value = _broadcast_to(lax_internal.asarray(start_value).astype(dtype), arr.shape[1:])

    result = jax.lax.fori_loop(start_index, arr.shape[0], body_fun, start_value)
    if keepdims:
      result = result.reshape(final_shape)
    return result

  @_wraps(np.ufunc.accumulate, module="numpy.ufunc")
  @partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
  def accumulate(self, a, axis=0, dtype=None, out=None):
    if self.nin != 2:
      raise ValueError("accumulate only supported for binary ufuncs")
    if self.nout != 1:
      raise ValueError("accumulate only supported for functions returning a single value")
    if out is not None:
      raise NotImplementedError(f"out argument of {self.__name__}.accumulate()")
    # Same fast-path dispatch as reduce(): known primitives get cumsum/cumprod.
    primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
    accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan)
    return accumulator(a, axis=axis, dtype=dtype)

  def _accumulate_via_scan(self, arr, axis=0, dtype=None):
    # Generic fallback accumulation via lax.scan along the chosen axis.
    assert self.nin == 2 and self.nout == 1
    check_arraylike(f"{self.__name__}.accumulate", arr)
    arr = lax_internal.asarray(arr)
    if dtype is None:
      dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype
    if axis is None or isinstance(axis, tuple):
      raise ValueError("accumulate does not allow multiple axes")
    axis = canonicalize_axis(axis, np.ndim(arr))

    arr = _moveaxis(arr, axis, 0)
    def scan_fun(carry, _):
      i, x = carry
      # First step emits arr[0] unchanged; later steps fold in arr[i].
      y = _where(i == 0, arr[0].astype(dtype), self._call(x.astype(dtype), arr[i].astype(dtype)))
      return (i + 1, y), y
    _, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0])
    return _moveaxis(result, 0, axis)

  # BUGFIX: this previously wrapped np.ufunc.accumulate, attaching the wrong
  # numpy docstring to `at`.
  @_wraps(np.ufunc.at, module="numpy.ufunc")
  @partial(jax.jit, static_argnums=[0], static_argnames=['inplace'])
  def at(self, a, indices, b=None, /, *, inplace=True):
    # JAX arrays are immutable, so the numpy-default inplace=True is rejected
    # with guidance toward the functional inplace=False form.
    if inplace:
      raise NotImplementedError(_AT_INPLACE_WARNING)
    if b is None:
      return self._at_via_scan(a, indices)
    else:
      return self._at_via_scan(a, indices, b)

  def _at_via_scan(self, a, indices, *args):
    # Apply the ufunc at (possibly repeated) indices sequentially via
    # lax.scan, so that repeated indices accumulate like np.ufunc.at.
    check_arraylike(f"{self.__name__}.at", a, *args)
    dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
    a = lax_internal.asarray(a).astype(dtype)
    args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args)
    indices = _eliminate_deprecated_list_indexing(indices)
    if not indices:
      return a

    shapes = [np.shape(i) for i in indices if not isinstance(i, slice)]
    shape = shapes and jax.lax.broadcast_shapes(*shapes)
    if not shape:
      # Scalar/empty index: a single update suffices, no scan needed.
      return a.at[indices].set(self._call(a.at[indices].get(), *args))

    args = tuple(_broadcast_to(arg, shape).ravel() for arg in args)
    indices = [idx if isinstance(idx, slice) else _broadcast_to(idx, shape).ravel() for idx in indices]

    def scan_fun(carry, x):
      i, a = carry
      idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices)
      a = a.at[idx].set(self._call(a.at[idx].get(), *(arg[i] for arg in args)))
      return (i + 1, a), x
    carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))
    return carry[1]

  @_wraps(np.ufunc.reduceat, module="numpy.ufunc")
  @partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
  def reduceat(self, a, indices, axis=0, dtype=None, out=None):
    if self.nin != 2:
      raise ValueError("reduceat only supported for binary ufuncs")
    if self.nout != 1:
      raise ValueError("reduceat only supported for functions returning a single value")
    if out is not None:
      raise NotImplementedError(f"out argument of {self.__name__}.reduceat()")
    return self._reduceat_via_scan(a, indices, axis=axis, dtype=dtype)

  def _reduceat_via_scan(self, a, indices, axis=0, dtype=None):
    check_arraylike(f"{self.__name__}.reduceat", a, indices)
    a = lax_internal.asarray(a)
    idx_tuple = _eliminate_deprecated_list_indexing(indices)
    assert len(idx_tuple) == 1
    indices = idx_tuple[0]
    if a.ndim == 0:
      raise ValueError(f"reduceat: a must have 1 or more dimension, got {a.shape=}")
    if indices.ndim != 1:
      raise ValueError(f"reduceat: indices must be one-dimensional, got {indices.shape=}")
    if dtype is None:
      dtype = a.dtype
    if axis is None or isinstance(axis, (tuple, list)):
      raise ValueError("reduceat requires a single integer axis.")
    axis = canonicalize_axis(axis, a.ndim)
    # Seed each output segment with its first element, then fold in the
    # remaining elements of the segment with a fori_loop over positions.
    out = take(a, indices, axis=axis)
    ind = jax.lax.expand_dims(append(indices, a.shape[axis]),
                              np.delete(np.arange(out.ndim), axis))
    ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis)
    ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis)
    def loop_body(i, out):
      # Only positions strictly inside their segment contribute to the fold.
      return _where((i > ind_start) & (i < ind_end),
                    self._call(out, take(a, i.reshape(1), axis=axis)),
                    out)
    return jax.lax.fori_loop(0, a.shape[axis], loop_body, out)

  @_wraps(np.ufunc.outer, module="numpy.ufunc")
  @partial(jax.jit, static_argnums=[0])
  def outer(self, A, B, /, **kwargs):
    if self.nin != 2:
      raise ValueError("outer only supported for binary ufuncs")
    if self.nout != 1:
      raise ValueError("outer only supported for functions returning a single value")
    check_arraylike(f"{self.__name__}.outer", A, B)
    # Flatten both operands, apply the ufunc over the cross product via a
    # doubly-vmapped call, then restore the combined output shape.
    _ravel = lambda A: jax.lax.reshape(A, (np.size(A),))
    result = jax.vmap(jax.vmap(partial(self._call, **kwargs), (None, 0)), (0, None))(_ravel(A), _ravel(B))
    return result.reshape(*np.shape(A), *np.shape(B))
def frompyfunc(func, /, nin, nout, *, identity=None):
  """Wrap an arbitrary JAX-compatible scalar function as a JAX ufunc.

  Args:
    func: callable accepting ``nin`` scalar arguments and returning ``nout``
      outputs.
    nin: integer number of scalar inputs.
    nout: integer number of scalar outputs.
    identity: optional scalar identity element of the operation, if one exists.

  Returns:
    wrapped : jax.numpy.ufunc wrapper of func.
  """
  # TODO(jakevdp): use functools.wraps or similar to wrap the docstring?
  wrapped = ufunc(func, nin, nout, identity=identity)
  return wrapped