-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
lax_numpy.py
5390 lines (4637 loc) · 208 KB
/
lax_numpy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
"""
Implements the NumPy API, using the primitives in :mod:`jax.lax`.
NumPy operations are implemented in Python in terms of the primitive operations
in :mod:`jax.lax`. Since NumPy operations are not primitive and instead are
implemented in terms of :mod:`jax.lax` operations, we do not need to define
transformation rules such as gradient or batching rules. Instead,
transformations for NumPy primitives can be derived from the transformation
rules for the underlying :code:`lax` primitives.
"""
from __future__ import annotations
import builtins
import collections
from collections.abc import Sequence
from functools import partial
import math
import operator
import types
from typing import (overload, Any, Callable, Literal, NamedTuple, Protocol, TypeVar, Union)
from textwrap import dedent as _dedent
import warnings
import numpy as np
import opt_einsum
import jax
from jax import jit
from jax import errors
from jax import lax
from jax.tree_util import tree_leaves, tree_flatten, tree_map
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src.custom_derivatives import custom_jvp
from jax._src import dispatch
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax._src.array import ArrayImpl
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator, PrecisionLike)
from jax._src.lax import lax as lax_internal
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
from jax._src.util import (unzip2, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis,
NumpyComplexWarning)
# Alias for ``None``, mirroring the NumPy constant of the same name.
newaxis = None
# Generic type variable available for annotations in this module.
T = TypeVar('T')
def canonicalize_shape(shape: Any, context: str="") -> core.Shape:
  """Canonicalize ``shape``, also accepting a single int-like value.

  Works like ``core.canonicalize_shape``, except that a non-sequence,
  zero-dimensional ``shape`` is first wrapped into a one-element tuple.

  Args:
    shape: a tuple/list of dimensions, or a single scalar-like dimension.
    context: forwarded unchanged to ``core.canonicalize_shape``.

  Returns:
    Whatever ``core.canonicalize_shape`` returns for the (possibly wrapped)
    shape.
  """
  is_sequence = isinstance(shape, (tuple, list))
  if not is_sequence and (getattr(shape, 'ndim', None) == 0 or ndim(shape) == 0):
    # A lone scalar dimension: treat it as a rank-1 shape.
    shape = (shape,)
  return core.canonicalize_shape(shape, context)  # type: ignore
# Common docstring additions:
# Describes the extra ``precision`` argument; passed as ``lax_description`` to
# ``util._wraps`` by the wrappers in this file that accept ``precision``.
_PRECISION_DOC = """\
In addition to the original NumPy arguments listed below, also supports
``precision`` for extra control over matrix-multiplication precision
on supported devices. ``precision`` may be set to ``None``, which means
default precision for the backend, a :class:`~jax.lax.Precision` enum value
(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple
of two :class:`~jax.lax.Precision` enums indicating separate precision for each argument.
"""
# Some objects below rewrite their __module__ attribute to this name.
_PUBLIC_MODULE_NAME = "jax.numpy"
# NumPy constants (re-exported unchanged from NumPy).
pi = np.pi
e = np.e
euler_gamma = np.euler_gamma
inf = np.inf
nan = np.nan
# NumPy utility functions (re-exported unchanged from NumPy).
get_printoptions = np.get_printoptions
printoptions = np.printoptions
set_printoptions = np.set_printoptions
@util._wraps(np.iscomplexobj)
def iscomplexobj(x: Any) -> bool:
  # ``None`` is reported as not complex without inspecting anything else.
  if x is None:
    return False
  try:
    scalar_type = x.dtype.type
  except AttributeError:
    # No usable ``dtype`` attribute: convert to an array to discover the type.
    scalar_type = asarray(x).dtype.type
  return issubdtype(scalar_type, complexfloating)
# Shape-query helpers are NumPy's own functions. The underscore-prefixed
# aliases refer to the same objects and stay usable inside functions whose
# local variables or parameters shadow the public names.
shape = _shape = np.shape
ndim = _ndim = np.ndim
size = np.size
def _dtype(x: Any) -> DType:
  """Return the dtype of ``x`` via ``dtypes.dtype`` with ``canonicalize=True``."""
  canonical = dtypes.dtype(x, canonicalize=True)
  return canonical
# JAX does not currently distinguish scalars from arrays in its object system,
# and JAX scalars should follow the same type promotion rules as JAX arrays.
# Instead of introducing a dedicated scalar object, the JAX scalar types
# produce JAX arrays when they are instantiated.
class _ScalarMeta(type):
  """Metaclass for JAX scalar types.

  Classes created with this metaclass carry a ``dtype`` attribute. They hash
  and compare like the NumPy scalar type ``dtype.type``, calling them converts
  the argument with ``asarray(..., dtype=self.dtype)``, and ``isinstance``
  checks are delegated to ``dtype.type``.
  """
  dtype: np.dtype

  def __hash__(self) -> int:
    numpy_scalar_type = self.dtype.type
    return hash(numpy_scalar_type)

  def __eq__(self, other: Any) -> bool:
    # Identity is checked first; otherwise compare the NumPy scalar type.
    if self is other:
      return True
    return self.dtype.type == other

  def __ne__(self, other: Any) -> bool:
    is_equal = self == other
    return not is_equal

  def __call__(self, x: Any) -> Array:
    # "Instantiating" a scalar type yields an array of the matching dtype.
    return asarray(x, dtype=self.dtype)

  def __instancecheck__(self, instance: Any) -> bool:
    numpy_scalar_type = self.dtype.type
    return isinstance(instance, numpy_scalar_type)
def _abstractify_scalar_meta(x):
  """Always raise ``TypeError``: a JAX scalar type is not a JAX array."""
  message = f"JAX scalar type {x} cannot be interpreted as a JAX array."
  raise TypeError(message)
# Register the always-raising handler for objects whose type is _ScalarMeta
# (i.e. the JAX scalar type classes themselves).
api_util._shaped_abstractify_handlers[_ScalarMeta] = _abstractify_scalar_meta
def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
  """Build the JAX scalar type corresponding to a NumPy-style scalar type.

  The new class is named after ``np_scalar_type``, stores
  ``np.dtype(np_scalar_type)`` as its ``dtype`` attribute, and reports
  ``_PUBLIC_MODULE_NAME`` as its ``__module__``.
  """
  namespace = {"dtype": np.dtype(np_scalar_type)}
  scalar_type = _ScalarMeta(np_scalar_type.__name__, (object,), namespace)
  scalar_type.__module__ = _PUBLIC_MODULE_NAME
  return scalar_type
# JAX scalar types, one per supported dtype. The 4-bit integer, float8 and
# bfloat16 types come from ``dtypes`` rather than from NumPy.
bool_ = _make_scalar_type(np.bool_)
uint4 = _make_scalar_type(dtypes.uint4)
uint8 = _make_scalar_type(np.uint8)
uint16 = _make_scalar_type(np.uint16)
uint32 = _make_scalar_type(np.uint32)
uint64 = _make_scalar_type(np.uint64)
int4 = _make_scalar_type(dtypes.int4)
int8 = _make_scalar_type(np.int8)
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz)
float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
bfloat16 = _make_scalar_type(dtypes.bfloat16)
float16 = _make_scalar_type(np.float16)
float32 = single = _make_scalar_type(np.float32)
float64 = double = _make_scalar_type(np.float64)
complex64 = csingle = _make_scalar_type(np.complex64)
complex128 = cdouble = _make_scalar_type(np.complex128)
# Default scalar types: pick the 32-bit or 64-bit variant matching the
# defaults exposed by ``dtypes`` at import time.
int_ = int32 if dtypes.int_ == np.int32 else int64
uint = uint32 if dtypes.uint == np.uint32 else uint64
float_: Any = float32 if dtypes.float_ == np.float32 else float64
complex_ = complex64 if dtypes.complex_ == np.complex64 else complex128
# Abstract scalar type hierarchy, re-exported unchanged from NumPy.
generic = np.generic
number = np.number
inexact = np.inexact
complexfloating = np.complexfloating
floating = np.floating
integer = np.integer
signedinteger = np.signedinteger
unsignedinteger = np.unsignedinteger
flexible = np.flexible
character = np.character
object_ = np.object_
# dtype utilities: taken from ``dtypes`` where one is provided, otherwise
# from NumPy.
iinfo = dtypes.iinfo
finfo = dtypes.finfo
dtype = np.dtype
can_cast = dtypes.can_cast
promote_types = dtypes.promote_types
ComplexWarning = NumpyComplexWarning
# Printing and saving helpers, re-exported unchanged from NumPy.
array_str = np.array_str
array_repr = np.array_repr
save = np.save
savez = np.savez
@util._wraps(np.dtype)
def _jnp_dtype(obj: DTypeLike | None, *, align: bool = False,
               copy: bool = False) -> DType:
  """Like ``np.dtype``, but honoring JAX's default dtypes."""
  # Extended dtypes are passed through untouched.
  if dtypes.issubdtype(obj, dtypes.extended):
    return obj  # type: ignore[return-value]
  if obj is None:
    # No dtype requested: use the default float dtype.
    resolved = dtypes.float_
  elif isinstance(obj, type) and obj in dtypes.python_scalar_dtypes:
    # Python scalar classes map to the corresponding default JAX scalar type.
    resolved = _DEFAULT_TYPEMAP[obj]
  else:
    resolved = obj
  return np.dtype(resolved, align=align, copy=copy)
### utility functions
# Maps Python scalar classes to the default JAX scalar type of the same kind.
_DEFAULT_TYPEMAP: dict[type, _ScalarMeta] = {
  bool: bool_,
  int: int_,
  float: float_,
  complex: complex_,
}
# Short local alias for the private constant helper in jax._src.lax.lax.
_lax_const = lax_internal._const
def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
  """Cast an integer value to another integer dtype, saturating on overflow.

  Values outside the range representable by ``dtype`` are clipped to that
  range instead of wrapping around.

  Args:
    val: integer-typed value to convert.
    dtype: integer dtype of the result.

  Returns:
    ``val`` clipped to the representable range and converted to ``dtype``.

  Raises:
    TypeError: if ``dtype`` or the dtype of ``val`` is not an integer dtype.

  Examples
  --------
  A plain integer cast wraps around:

  >>> val = jnp.uint32(0xFFFFFFFF)
  >>> val.astype('int32')
  Array(-1, dtype=int32)

  This function saturates at the bounds of the target type instead:

  >>> _convert_and_clip_integer(val, 'int32')
  Array(2147483647, dtype=int32)
  """
  if not isinstance(val, Array):
    val = asarray(val)
  dtype = dtypes.canonicalize_dtype(dtype)
  if not issubdtype(dtype, integer) or not issubdtype(val.dtype, integer):
    raise TypeError("_convert_and_clip_integer only accepts integer dtypes.")

  val_dtype = dtypes.canonicalize_dtype(val.dtype)
  # TODO(jakevdp): val_dtype may differ from val.dtype here. This is a weird
  # corner case that still needs handling: it happens in X32 mode and can come
  # either from a jax value created in another context, or from a Python
  # integer converted to int64.

  # Clip to the intersection of the ranges of the source and target dtypes.
  target_info, source_info = iinfo(dtype), iinfo(val_dtype)
  lower = _lax_const(val, max(target_info.min, source_info.min))
  upper = _lax_const(val, min(target_info.max, source_info.max))
  clipped = clip(val, lower, upper)
  return clipped.astype(dtype)
@util._wraps(np.load, update_doc=False)
def load(*args: Any, **kwargs: Any) -> Array:
  # This wrapper mainly exists to recover bfloat16 data types.
  # Note: that only works for files written by np.save(), not np.savez().
  loaded = np.load(*args, **kwargs)
  if not isinstance(loaded, np.ndarray):
    # Anything that is not a plain ndarray is handed back as NumPy returned it.
    return loaded
  # numpy does not recognize bfloat16, so such arrays are serialized as void16
  if loaded.dtype == 'V2':
    loaded = loaded.view(bfloat16)
  try:
    return asarray(loaded)
  except (TypeError, AssertionError):  # Unsupported dtype
    return loaded
### implementations of numpy functions in terms of lax
@util._wraps(np.fmin, module='numpy')
@jit
def fmin(x1: ArrayLike, x2: ArrayLike) -> Array:
  # Select x1 where it is smaller than x2 or where x2 is NaN; x2 otherwise.
  pick_first = ufuncs.less(x1, x2) | ufuncs.isnan(x2)
  return where(pick_first, x1, x2)
@util._wraps(np.fmax, module='numpy')
@jit
def fmax(x1: ArrayLike, x2: ArrayLike) -> Array:
  # Select x1 where it is larger than x2 or where x2 is NaN; x2 otherwise.
  pick_first = ufuncs.greater(x1, x2) | ufuncs.isnan(x2)
  return where(pick_first, x1, x2)
@util._wraps(np.issubdtype)
def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool:
  # Thin public wrapper around the implementation in ``dtypes``.
  is_sub = dtypes.issubdtype(arg1, arg2)
  return is_sub
@util._wraps(np.isscalar)
def isscalar(element: Any) -> bool:
  # Objects implementing ``__jax_array__`` are judged by what it returns.
  candidate = element.__jax_array__() if hasattr(element, '__jax_array__') else element
  return dtypes.is_python_scalar(candidate) or np.isscalar(candidate)
# Re-exported unchanged from NumPy.
iterable = np.iterable
@util._wraps(np.result_type)
def result_type(*args: Any) -> DType:
  # Thin public wrapper around the implementation in ``dtypes``.
  promoted = dtypes.result_type(*args)
  return promoted
@util._wraps(np.trapz)
@partial(jit, static_argnames=('axis',))
def trapz(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, axis: int = -1) -> Array:
  if x is None:
    # No sample points given: use the caller-supplied spacing ``dx``.
    util.check_arraylike('trapz', y)
    y_arr, = util.promote_dtypes_inexact(y)
    spacing = dx
  else:
    util.check_arraylike('trapz', y, x)
    y_arr, x_arr = util.promote_dtypes_inexact(y, x)
    if x_arr.ndim == 1:
      spacing = diff(x_arr)
    else:
      # Differences are taken along ``axis`` and then moved last, to line up
      # with ``y`` after its own axis move below.
      spacing = moveaxis(diff(x_arr, axis=axis), axis, -1)
  # Integrate along the last axis.
  y_last = moveaxis(y_arr, axis, -1)
  neighbour_sums = y_last[..., 1:] + y_last[..., :-1]
  return 0.5 * (spacing * neighbour_sums).sum(-1)
@util._wraps(np.trunc, module='numpy')
@jit
def trunc(x: ArrayLike) -> Array:
  util.check_arraylike('trunc', x)
  # Round toward zero: ceil for negative entries, floor for the rest.
  is_negative = lax.lt(x, _lax_const(x, 0))
  return where(is_negative, ufuncs.ceil(x), ufuncs.floor(x))
# Documentation for the extra ``preferred_element_type`` parameter; passed as
# ``extra_params`` to ``util._wraps`` by ``convolve`` and ``correlate``.
_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION = """
preferred_element_type : dtype, optional
If specified, accumulate results and return a result of the given data type.
If not specified, the function instead follows the numpy convention of always
accumulating results and returning an inexact dtype.
"""
@partial(jit, static_argnames=['mode', 'op', 'precision', 'preferred_element_type'])
def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike,
          preferred_element_type: DTypeLike | None = None) -> Array:
  """Shared 1D implementation behind ``convolve`` and ``correlate``.

  Args:
    x: first one-dimensional input.
    y: second one-dimensional input.
    mode: one of ``'full'``, ``'same'`` or ``'valid'``; selects the padding.
    op: ``'convolve'`` or ``'correlate'``.
    precision: forwarded to ``lax.conv_general_dilated``.
    preferred_element_type: forwarded to ``lax.conv_general_dilated``. When
      ``None`` the inputs are promoted to a common inexact dtype, otherwise
      they are only promoted to a common dtype.

  Returns:
    The one-dimensional result array.

  Raises:
    ValueError: if an input is not one-dimensional or is empty, or if ``mode``
      is not recognised.
  """
  if ndim(x) != 1 or ndim(y) != 1:
    raise ValueError(f"{op}() only support 1-dimensional inputs.")
  if preferred_element_type is None:
    # Unspecified: promote to inexact, following NumPy's default for convolutions.
    x, y = util.promote_dtypes_inexact(x, y)
  else:
    # Specified: only bring both inputs to a common dtype.
    x, y = util.promote_dtypes(x, y)
  if len(x) == 0 or len(y) == 0:
    raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.")

  # The longer input acts as the signal (lhs), the shorter one as the kernel.
  out_order = slice(None)
  if op == 'correlate':
    y = ufuncs.conj(y)
    if len(x) < len(y):
      x, y = y, x
      # Operands were swapped, so the output has to be read in reverse.
      out_order = slice(None, None, -1)
  elif op == 'convolve':
    if len(x) < len(y):
      x, y = y, x
    y = flip(y)

  kernel_len = y.shape[0]
  if mode == 'valid':
    padding = [(0, 0)]
  elif mode == 'same':
    padding = [(kernel_len // 2, kernel_len - kernel_len // 2 - 1)]
  elif mode == 'full':
    padding = [(kernel_len - 1, kernel_len - 1)]
  else:
    raise ValueError("mode must be one of ['full', 'same', 'valid']")

  # Add two leading unit dimensions, as both operands are passed in 3D form.
  lhs = x[None, None, :]
  rhs = y[None, None, :]
  result = lax.conv_general_dilated(lhs, rhs, (1,), padding,
                                    precision=precision,
                                    preferred_element_type=preferred_element_type)
  return result[0, 0, out_order]
@util._wraps(np.convolve, lax_description=_PRECISION_DOC,
             extra_params=_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type'))
def convolve(a: ArrayLike, v: ArrayLike, mode: str = 'full', *,
             precision: PrecisionLike = None,
             preferred_element_type: dtype | None = None) -> Array:
  # Validate the inputs, then delegate to the shared 1D kernel.
  util.check_arraylike("convolve", a, v)
  first, second = asarray(a), asarray(v)
  return _conv(first, second, mode=mode, op='convolve',
               precision=precision,
               preferred_element_type=preferred_element_type)
@util._wraps(np.correlate, lax_description=_PRECISION_DOC,
             extra_params=_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
@partial(jit, static_argnames=('mode', 'precision', 'preferred_element_type'))
def correlate(a: ArrayLike, v: ArrayLike, mode: str = 'valid', *,
              precision: PrecisionLike = None,
              preferred_element_type: dtype | None = None) -> Array:
  # Validate the inputs, then delegate to the shared 1D kernel.
  util.check_arraylike("correlate", a, v)
  first, second = asarray(a), asarray(v)
  return _conv(first, second, mode=mode, op='correlate',
               precision=precision,
               preferred_element_type=preferred_element_type)
@util._wraps(np.histogram_bin_edges)
def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
                        range: None | Array | Sequence[ArrayLike] = None,
                        weights: ArrayLike | None = None) -> Array:
  del weights  # unused, because string bins is not supported.
  if isinstance(bins, str):
    raise NotImplementedError("string values for `bins` not implemented.")
  util.check_arraylike("histogram_bin_edges", a, bins)
  arr = asarray(a)
  edge_dtype = dtypes.to_inexact_dtype(arr.dtype)

  # One-dimensional ``bins`` are taken to be the edges themselves.
  if _ndim(bins) == 1:
    return asarray(bins, dtype=edge_dtype)

  # Otherwise ``bins`` is a bin count, which must be a concrete integer.
  bins_int = core.concrete_or_error(operator.index, bins,
                                    "bins argument of histogram_bin_edges")
  if range is None:
    range = [arr.min(), arr.max()]
  range = asarray(range, dtype=edge_dtype)
  if shape(range) != (2,):
    raise ValueError(f"`range` must be either None or a sequence of scalars, got {range}")

  # A zero-width range is widened by 0.5 on each side.
  zero_width = reductions.ptp(range) == 0
  lower = where(zero_width, range[0] - 0.5, range[0])
  upper = where(zero_width, range[1] + 0.5, range[1])
  return linspace(lower, upper, bins_int + 1, dtype=edge_dtype)
@util._wraps(np.histogram)
def histogram(a: ArrayLike, bins: ArrayLike = 10,
              range: Sequence[ArrayLike] | None = None,
              weights: ArrayLike | None = None,
              density: bool | None = None) -> tuple[Array, Array]:
  if weights is None:
    util.check_arraylike("histogram", a, bins)
    a, = util.promote_dtypes_inexact(a)
    # Unweighted histogram: every sample counts as one.
    weights = ones_like(a)
  else:
    util.check_arraylike("histogram", a, bins, weights)
    if shape(a) != shape(weights):
      raise ValueError("weights should have the same shape as a.")
    a, weights = util.promote_dtypes_inexact(a, weights)

  bin_edges = histogram_bin_edges(a, bins, range, weights)
  num_edges = len(bin_edges)

  slot = searchsorted(bin_edges, a, side='right')
  # Samples equal to the last edge are assigned to the final bin.
  slot = where(a == bin_edges[-1], num_edges - 1, slot)

  # Accumulate into one slot per edge; slot 0 is dropped from the result.
  totals = zeros(num_edges, weights.dtype).at[slot].add(weights)
  counts = totals[1:]
  if density:
    widths = diff(bin_edges)
    counts = counts / widths / counts.sum()
  return counts, bin_edges
@util._wraps(np.histogram2d)
def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
                range: Sequence[None | Array | Sequence[ArrayLike]] | None = None,
                weights: ArrayLike | None = None,
                density: bool | None = None) -> tuple[Array, Array, Array]:
  util.check_arraylike("histogram2d", x, y)
  try:
    num_specs = len(bins)  # type: ignore[arg-type]
  except TypeError:
    # ``bins`` has no length (e.g. a plain integer).
    num_specs = 1

  if num_specs not in (1, 2):
    # Any other length: use ``bins`` itself for both dimensions.
    shared_edges = asarray(bins)
    bins = [shared_edges, shared_edges]

  # Stack the coordinates into an (n_samples, 2) array for histogramdd.
  sample = transpose(asarray([x, y]))
  hist, edges = histogramdd(sample, bins, range, weights, density)
  x_edges, y_edges = edges[0], edges[1]
  return hist, x_edges, y_edges
@util._wraps(np.histogramdd)
def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
                range: Sequence[None | Array | Sequence[ArrayLike]] | None = None,
                weights: ArrayLike | None = None,
                density: bool | None = None) -> tuple[Array, list[Array]]:
  if weights is None:
    util.check_arraylike("histogramdd", sample)
    sample, = util.promote_dtypes_inexact(sample)
  else:
    util.check_arraylike("histogramdd", sample, weights)
    if shape(weights) != shape(sample)[:1]:
      raise ValueError("should have one weight for each sample.")
    sample, weights = util.promote_dtypes_inexact(sample, weights)
  # ``sample`` must be two-dimensional: N rows of D coordinates each.
  N, D = shape(sample)

  if range is not None:
    wrong_length = len(range) != D
    if wrong_length or any(r is not None and shape(r)[0] != 2 for r in range):  # type: ignore[arg-type]
      raise ValueError(f"For sample.shape={(N, D)}, range must be a sequence "
                       f"of {D} pairs or Nones; got {range=}")

  try:
    num_bins = len(bins)  # type: ignore[arg-type]
  except TypeError:
    # ``bins`` has no length: the same specification is used for every dimension.
    bins_per_dimension: list[ArrayLike] = D * [bins]  # type: ignore[assignment]
  else:
    if num_bins != D:
      raise ValueError("should be a bin for each dimension.")
    bins_per_dimension = list(bins)  # type: ignore[arg-type]

  bin_idx_by_dim: list[Array] = []
  bin_edges_by_dim: list[Array] = []
  # ``range`` is a parameter here, so the builtin must be reached explicitly.
  for d in builtins.range(D):
    column = sample[:, d]
    column_range = None if range is None else range[d]
    edges = histogram_bin_edges(column, bins_per_dimension[d], column_range, weights)
    idx = searchsorted(edges, column, side='right')
    # Samples equal to the last edge are moved into the final bin.
    idx = where(column == edges[-1], idx - 1, idx)
    bin_idx_by_dim.append(idx)
    bin_edges_by_dim.append(edges)

  # One more slot than edges per dimension; the first and last slot of each
  # dimension are trimmed away again after the reshape below.
  nbins = tuple(len(edges) + 1 for edges in bin_edges_by_dim)
  dedges = [diff(edges) for edges in bin_edges_by_dim]

  flat_idx = ravel_multi_index(tuple(bin_idx_by_dim), nbins, mode='clip')
  hist = bincount(flat_idx, weights, length=math.prod(nbins))
  hist = reshape(hist, nbins)
  interior = D * (slice(1, -1),)
  hist = hist[interior]

  if density:
    hist = hist.astype(sample.dtype)
    hist /= hist.sum()
    # Divide by the bin widths of each dimension in turn.
    for norm in ix_(*dedges):
      hist /= norm

  return hist, bin_edges_by_dim
# Note passed as ``lax_description`` to ``util._wraps`` by several wrappers in
# this file (for example ``transpose``, ``flip`` and ``reshape``).
_ARRAY_VIEW_DOC = """
The JAX version of this function may in some cases return a copy rather than a
view of the input.
"""
@util._wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC)
def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
  util.check_arraylike("transpose", a)
  rank = ndim(a)
  if axes is None:
    # Default: reverse the order of all axes.
    requested = reversed(range(rank))
  else:
    requested = axes
  permutation = [_canonicalize_axis(ax, rank) for ax in requested]
  return lax.transpose(a, permutation)
@util._wraps(getattr(np, 'matrix_transpose', None))
def matrix_transpose(x: ArrayLike, /) -> Array:
  """Transposes the last two dimensions of x.

  Parameters
  ----------
  x : array_like
      Input array. Must have ``x.ndim >= 2``.

  Returns
  -------
  xT : Array
      Transposed array.
  """
  util.check_arraylike("matrix_transpose", x)
  ndim = np.ndim(x)
  if ndim < 2:
    raise ValueError(f"x must be at least two-dimensional for matrix_transpose; got {ndim=}")
  # Leading axes stay in place; only the final two are exchanged.
  leading = tuple(range(ndim - 2))
  permutation = leading + (ndim - 1, ndim - 2)
  return lax.transpose(x, permutation)
@util._wraps(np.rot90, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('k', 'axes'))
def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
  util.check_arraylike("rot90", m)
  if np.ndim(m) < 2:
    raise ValueError("rot90 requires its first argument to have ndim at least "
                     f"two, but got first argument of shape {np.shape(m)}, "
                     f"which has ndim {np.ndim(m)}")
  rank = ndim(m)
  ax1, ax2 = axes
  ax1 = _canonicalize_axis(ax1, rank)
  ax2 = _canonicalize_axis(ax2, rank)
  if ax1 == ax2:
    raise ValueError("Axes must be different")  # same as numpy error

  # Only the number of quarter turns modulo four matters.
  quarter_turns = k % 4
  if quarter_turns == 0:
    return asarray(m)
  if quarter_turns == 2:
    return flip(flip(m, ax1), ax2)

  # Odd number of quarter turns: combine a flip with swapping the two axes.
  perm = list(range(rank))
  perm[ax1], perm[ax2] = perm[ax2], perm[ax1]
  if quarter_turns == 1:
    return transpose(flip(m, ax2), perm)
  return flip(transpose(m, perm), ax2)
@util._wraps(np.flip, lax_description=_ARRAY_VIEW_DOC)
def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
  util.check_arraylike("flip", m)
  # Normalise the inputs here, outside of the jit-compiled helper.
  arr = asarray(m)
  normalized_axis = reductions._ensure_optional_axes(axis)
  return _flip(arr, normalized_axis)
@partial(jit, static_argnames=('axis',))
def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
  """Reverse ``m`` along ``axis``, or along every axis when ``axis`` is None."""
  rank = ndim(m)
  if axis is None:
    dimensions = list(range(rank))
  else:
    dimensions = [_canonicalize_axis(ax, rank) for ax in _ensure_index_tuple(axis)]
  return lax.rev(m, dimensions)
@util._wraps(np.fliplr, lax_description=_ARRAY_VIEW_DOC)
def fliplr(m: ArrayLike) -> Array:
  util.check_arraylike("fliplr", m)
  # Left/right flip: reverse along axis 1.
  arr = asarray(m)
  return _flip(arr, 1)
@util._wraps(np.flipud, lax_description=_ARRAY_VIEW_DOC)
def flipud(m: ArrayLike) -> Array:
  util.check_arraylike("flipud", m)
  # Up/down flip: reverse along axis 0.
  arr = asarray(m)
  return _flip(arr, 0)
@util._wraps(np.iscomplex)
@jit
def iscomplex(x: ArrayLike) -> Array:
  # True exactly where the imaginary part differs from zero.
  imag_part = ufuncs.imag(x)
  zero = _lax_const(imag_part, 0)
  return lax.ne(imag_part, zero)
@util._wraps(np.isreal)
@jit
def isreal(x: ArrayLike) -> Array:
  # True exactly where the imaginary part equals zero.
  imag_part = ufuncs.imag(x)
  zero = _lax_const(imag_part, 0)
  return lax.eq(imag_part, zero)
@util._wraps(np.angle)
@partial(jit, static_argnames=['deg'])
def angle(z: ArrayLike, deg: bool = False) -> Array:
  real_part = ufuncs.real(z)
  imag_part = ufuncs.imag(z)
  part_dtype = _dtype(real_part)

  # Convert to the default float dtype when the parts are not inexact, or
  # when the input is a zero-dimensional floating-point value.
  use_default_float = not issubdtype(part_dtype, inexact)
  if not use_default_float:
    use_default_float = issubdtype(_dtype(z), floating) and ndim(z) == 0
  if use_default_float:
    default_float = dtypes.canonicalize_dtype(float_)
    real_part = lax.convert_element_type(real_part, default_float)
    imag_part = lax.convert_element_type(imag_part, default_float)

  radians = lax.atan2(imag_part, real_part)
  if deg:
    return ufuncs.degrees(radians)
  return radians
@util._wraps(np.diff)
@partial(jit, static_argnames=('n', 'axis'))
def diff(a: ArrayLike, n: int = 1, axis: int = -1,
         prepend: ArrayLike | None = None,
         append: ArrayLike | None = None) -> Array:
  util.check_arraylike("diff", a)
  arr = asarray(a)
  # Both ``n`` and ``axis`` must be concrete integers.
  n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff")
  axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.diff")

  if n == 0:
    return arr
  if n < 0:
    raise ValueError(f"order must be non-negative but got {n}")
  if arr.ndim == 0:
    raise ValueError(f"diff requires input that is at least one dimensional; got {a}")

  nd = arr.ndim
  axis = _canonicalize_axis(axis, nd)

  # Shape given to scalar ``prepend``/``append`` values: like the input, but
  # with a single entry along ``axis``.
  edge_shape = tuple(1 if d == axis else extent for d, extent in enumerate(arr.shape))

  pieces: list[Array] = []
  if prepend is not None:
    util.check_arraylike("diff", prepend)
    if isscalar(prepend):
      prepend = broadcast_to(prepend, edge_shape)
    pieces.append(asarray(prepend))
  pieces.append(arr)
  if append is not None:
    util.check_arraylike("diff", append)
    if isscalar(append):
      append = broadcast_to(append, edge_shape)
    pieces.append(asarray(append))
  if len(pieces) > 1:
    arr = concatenate(pieces, axis)

  # Index tuples selecting "all but the first" and "all but the last" entry
  # along ``axis``.
  upper = tuple(slice(1, None) if d == axis else slice(None) for d in range(nd))
  lower = tuple(slice(None, -1) if d == axis else slice(None) for d in range(nd))

  # Boolean inputs use "not equal" in place of subtraction.
  op = ufuncs.not_equal if arr.dtype == np.bool_ else ufuncs.subtract
  for _ in range(n):
    arr = op(arr[upper], arr[lower])
  return arr
# JAX-specific note passed as ``lax_description`` to ``util._wraps`` by
# ``ediff1d``.
_EDIFF1D_DOC = """\
Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not
issue an error if casting ``to_end`` or ``to_begin`` to the type of ``ary``
loses precision.
"""
@util._wraps(np.ediff1d, lax_description=_EDIFF1D_DOC)
@jit
def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None,
            to_begin: ArrayLike | None = None) -> Array:
  util.check_arraylike("ediff1d", ary)
  flat = ravel(ary)
  # Differences between consecutive elements of the flattened input.
  deltas = lax.sub(flat[1:], flat[:-1])
  if to_begin is not None:
    util.check_arraylike("ediff1d", to_begin)
    # The extra values are cast to the dtype of the flattened input.
    head = ravel(asarray(to_begin, dtype=flat.dtype))
    deltas = concatenate((head, deltas))
  if to_end is not None:
    util.check_arraylike("ediff1d", to_end)
    tail = ravel(asarray(to_end, dtype=flat.dtype))
    deltas = concatenate((deltas, tail))
  return deltas
@util._wraps(np.gradient, skip_params=['edge_order'])
@partial(jit, static_argnames=('axis', 'edge_order'))
def gradient(f: ArrayLike, *varargs: ArrayLike,
             axis: int | Sequence[int] | None = None,
             edge_order: int | None = None) -> Array | list[Array]:
  # Numerical gradient of ``f`` along one or more axes.
  #
  # Args:
  #   f: input array.
  #   *varargs: optional scalar spacings: none (unit spacing), one (shared by
  #     all requested axes), or one per requested axis.
  #   axis: axis or axes to differentiate along; ``None`` means every axis.
  #   edge_order: not supported; must be left as ``None``.
  #
  # Returns:
  #   A single array when exactly one axis is requested, otherwise a list
  #   with one array per requested axis (an empty list for no axes).
  #
  # Raises:
  #   NotImplementedError: if ``edge_order`` is given or a spacing is not a
  #     scalar.
  #   ValueError: if a requested axis has fewer than 2 elements.
  #   TypeError: if the number of spacings is neither 0, 1, nor the number of
  #     requested axes.
  if edge_order is not None:
    raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.")

  a, *spacing = util.promote_args_inexact("gradient", f, *varargs)

  def gradient_along_axis(a, h, axis):
    sliced = partial(lax.slice_in_dim, a, axis=axis)
    # One-sided differences at the two ends, central differences in between.
    a_grad = concatenate((
      (sliced(1, 2) - sliced(0, 1)),  # upper edge
      (sliced(2, None) - sliced(None, -2)) * 0.5,  # inner
      (sliced(-1, None) - sliced(-2, -1)),  # lower edge
    ), axis)
    return a_grad / h

  if axis is None:
    axis_tuple = tuple(range(a.ndim))
  else:
    axis_tuple = tuple(_canonicalize_axis(i, a.ndim) for i in _ensure_index_tuple(axis))
  if len(axis_tuple) == 0:
    return []

  if min(a.shape[i] for i in axis_tuple) < 2:
    raise ValueError("Shape of array too small to calculate "
                     "a numerical gradient, "
                     "at least 2 elements are required.")

  if len(spacing) == 0:
    dx: Sequence[ArrayLike] = [1.0] * len(axis_tuple)
  elif len(spacing) == 1:
    dx = list(spacing) * len(axis_tuple)
  elif len(spacing) == len(axis_tuple):
    dx = list(spacing)
  else:
    # The exception used to be constructed without being raised, so execution
    # fell through to the use of the unbound ``dx`` below.
    raise TypeError(f"Invalid number of spacing arguments {len(spacing)} for {axis=}")

  # Every spacing has to be a scalar: an array spacing in any position would
  # otherwise be silently broadcast as a divisor.
  if any(ndim(h) != 0 for h in dx):
    raise NotImplementedError("Non-constant spacing not implemented")

  a_grad = [gradient_along_axis(a, h, ax) for ax, h in zip(axis_tuple, dx)]
  return a_grad[0] if len(axis_tuple) == 1 else a_grad
@util._wraps(np.isrealobj)
def isrealobj(x: Any) -> bool:
  # Defined as the exact negation of ``iscomplexobj``.
  is_complex = iscomplexobj(x)
  return not is_complex
@util._wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC)
def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
  __tracebackhide__ = True
  util.check_arraylike("reshape", a)
  try:
    # Prefer the object's own ``reshape`` method when it provides one.
    return a.reshape(newshape, order=order)  # type: ignore[call-overload,union-attr]
  except AttributeError:
    pass
  # Otherwise convert to an array first and use its method.
  converted = asarray(a)
  return converted.reshape(newshape, order=order)
@util._wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('order',), inline=True)
def ravel(a: ArrayLike, order: str = "C") -> Array:
  util.check_arraylike("ravel", a)
  if order == "K":
    raise NotImplementedError("Ravel not implemented for order='K'.")
  # Flattening is a reshape to one dimension holding every element.
  flat_shape = (size(a),)
  return reshape(a, flat_shape, order)
@util._wraps(np.ravel_multi_index)
def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int],
                      mode: str = 'raise', order: str = 'C') -> Array:
  assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
  dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
  util.check_arraylike("ravel_multi_index", *multi_index)
  coords = [asarray(i) for i in multi_index]

  # With mode='raise' the values are inspected below, so they must be concrete.
  needs_concrete = mode == 'raise'
  for coord in coords:
    if needs_concrete:
      core.concrete_or_error(array, coord,
        "The error occurred because ravel_multi_index was jit-compiled"
        " with mode='raise'. Use mode='wrap' or mode='clip' instead.")
    if not issubdtype(_dtype(coord), integer):
      raise TypeError("only int indices permitted")

  if mode == "raise":
    out_of_bounds = any(reductions.any((c < 0) | (c >= d)) for c, d in zip(coords, dims))
    if out_of_bounds:
      raise ValueError("invalid entry in coordinates array")
  elif mode == "clip":
    coords = [clip(c, 0, d - 1) for c, d in zip(coords, dims)]
  elif mode == "wrap":
    coords = [c % d for c, d in zip(coords, dims)]
  else:
    raise ValueError(f"invalid mode={mode!r}. Expected 'raise', 'wrap', or 'clip'")

  # Stride of each dimension in the flattened index space.
  if order == "F":
    strides = np.cumprod((1,) + dims[:-1])
  elif order == "C":
    strides = np.cumprod((1,) + dims[1:][::-1])[::-1]
  else:
    raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'")

  # The accumulator takes the dtype of the first coordinate array, or the
  # default integer dtype when there are no coordinates at all.
  acc_dtype = coords[0].dtype if coords else dtypes.canonicalize_dtype(int_)
  flat_index = array(0, dtype=acc_dtype)
  for coord, stride in zip(coords, strides):
    flat_index = flat_index + coord * int(stride)
  return flat_index
# JAX-specific note passed as ``lax_description`` to ``util._wraps`` by
# ``unravel_index``.
_UNRAVEL_INDEX_DOC = """\
Unlike numpy's implementation of unravel_index, negative indices are accepted
and out-of-bounds indices are clipped into the valid range.
"""
@util._wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
  util.check_arraylike("unravel_index", indices)
  remaining = asarray(indices)
  # Note: ``shape`` is deliberately not converted to an array, because it may
  # be a tuple of weakly-typed values and asarray() would strip the weak types.
  try:
    dims = list(shape)
  except TypeError:
    # A single non-iterable value is treated as a one-dimensional shape.
    dims = [shape]
  if any(ndim(d) != 0 for d in dims):
    raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")

  # Peel off one coordinate per dimension, starting with the last dimension.
  coords = [0] * len(dims)
  for pos in range(len(dims) - 1, -1, -1):
    remaining, coords[pos] = ufuncs.divmod(remaining, dims[pos])

  # A leftover quotient marks an index outside the valid range; those are
  # clipped to the last or first position respectively.
  too_large = remaining > 0
  too_small = remaining < -1
  return tuple(where(too_large, d - 1, where(too_small, 0, c))
               for d, c in safe_zip(dims, coords))
@util._wraps(np.resize)
@partial(jit, static_argnames=('new_shape',))
def resize(a: ArrayLike, new_shape: Shape) -> Array:
  # Build an array of shape `new_shape` by repeating the flattened contents of
  # `a` as many times as needed and truncating the excess. If either the input
  # or the requested shape has zero elements, the result comes from zeros_like.
  # Raises ValueError when any requested dimension is negative.
  util.check_arraylike("resize", a)
  new_shape = _ensure_index_tuple(new_shape)

  for dim_length in new_shape:
    if dim_length < 0:
      raise ValueError("all elements of `new_shape` must be non-negative")

  flat = ravel(a)
  target_size = math.prod(new_shape)
  if flat.size == 0 or target_size == 0:
    return zeros_like(flat, shape=new_shape)

  # Tile enough whole copies to cover the target size, then drop the surplus.
  n_copies = ceil_of_ratio(target_size, flat.size)
  tiled = tile(flat, n_copies)
  return reshape(tiled[:target_size], new_shape)
@util._wraps(np.squeeze, lax_description=_ARRAY_VIEW_DOC)
def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
  # Public entry point: validate `a`, normalise `axis` to a tuple (None is
  # passed through unchanged), and delegate to the jitted _squeeze helper.
  util.check_arraylike("squeeze", a)
  arr = asarray(a)
  if axis is None:
    axes = None
  else:
    axes = _ensure_index_tuple(axis)
  return _squeeze(arr, axes)
@partial(jit, static_argnames=('axis',), inline=True)
def _squeeze(a: Array, axis: tuple[int, ...] | None) -> Array:
  """Remove axes from ``a`` through ``lax.squeeze``.

  Args:
    a: the array to squeeze.
    axis: static tuple of axes to remove (any length), or ``None`` to remove
      every axis whose size equals 1.

  Returns:
    The result of ``lax.squeeze(a, axis)``.

  Raises:
    ValueError: if ``axis`` is ``None`` and the shape of ``a`` is not constant
      according to ``core.is_constant_shape``.
  """
  if axis is None:
    a_shape = shape(a)
    if not core.is_constant_shape(a_shape):
      # We do not even know the rank of the output if the input shape is not known
      raise ValueError("jnp.squeeze with axis=None is not supported with shape polymorphism")
    # Select every unit-length dimension.
    axis = tuple(i for i, d in enumerate(a_shape) if d == 1)
  return lax.squeeze(a, axis)
@util._wraps(np.expand_dims)
def expand_dims(a: ArrayLike, axis: int | Sequence[int]) -> Array:
  # Validate `a`, normalise `axis` to a tuple, and defer to lax.expand_dims.
  util.check_arraylike("expand_dims", a)
  return lax.expand_dims(a, _ensure_index_tuple(axis))
@util._wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('axis1', 'axis2'), inline=True)
def swapaxes(a: ArrayLike, axis1: int, axis2: int) -> Array:
  # Exchange two axes of `a`: start from the identity permutation, swap the
  # entries at positions axis1 and axis2, and transpose with the result.
  util.check_arraylike("swapaxes", a)
  permutation = np.arange(ndim(a))
  # Read both entries before writing either, so the swap is not clobbered.
  entry_at_axis2 = permutation[axis2]
  entry_at_axis1 = permutation[axis1]
  permutation[axis1] = entry_at_axis2
  permutation[axis2] = entry_at_axis1
  return lax.transpose(a, list(permutation))
@util._wraps(np.moveaxis, lax_description=_ARRAY_VIEW_DOC)
def moveaxis(a: ArrayLike, source: int | Sequence[int],
             destination: int | Sequence[int]) -> Array:
  # Public entry point: validate `a`, normalise both axis arguments to tuples,
  # and delegate to the jitted _moveaxis helper.
  util.check_arraylike("moveaxis", a)
  arr = asarray(a)
  src_axes = _ensure_index_tuple(source)
  dst_axes = _ensure_index_tuple(destination)
  return _moveaxis(arr, src_axes, dst_axes)
@partial(jit, static_argnames=('source', 'destination'), inline=True)
def _moveaxis(a: Array, source: tuple[int, ...], destination: tuple[int, ...]) -> Array:
  """Transpose ``a`` so each axis in ``source`` lands at its ``destination``.

  Both tuples are canonicalized against the rank of ``a`` first. A ValueError
  is raised when they differ in length. Axes not named in ``source`` keep
  their relative order.
  """
  rank = ndim(a)
  source = tuple(_canonicalize_axis(i, rank) for i in source)
  destination = tuple(_canonicalize_axis(i, rank) for i in destination)
  n_source, n_destination = len(source), len(destination)
  if n_source != n_destination:
    raise ValueError("Inconsistent number of elements: {} vs {}"
                     .format(n_source, n_destination))

  # Start with the axes that stay put, in their existing order.
  perm = []
  for ax in range(rank):
    if ax not in source:
      perm.append(ax)
  # Insert moved axes in ascending destination order so that earlier
  # insertions do not shift the slots of later ones.
  placements = sorted(zip(destination, source))
  for dest, src in placements:
    perm.insert(dest, src)
  return lax.transpose(a, perm)
@util._wraps(np.isclose)
@partial(jit, static_argnames=('equal_nan',))
def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike = 1e-08,
            equal_nan: bool = False) -> Array:
  # Elementwise test of |a - b| <= atol + rtol * |b| on the promoted inputs.
  # Infinities only compare close to an equal infinity; NaNs compare close
  # only to other NaNs, and only when equal_nan is set. For dtypes that are
  # not inexact the tolerances are unused and the result is exact equality.
  a, b = util.promote_args("isclose", a, b)
  dtype = _dtype(a)
  if not issubdtype(dtype, inexact):
    return lax.eq(a, b)

  # Tolerances are real-valued: for complex inputs use the element type.
  if issubdtype(dtype, complexfloating):
    dtype = util._complex_elem_type(dtype)
  rtol = lax.convert_element_type(rtol, dtype)
  atol = lax.convert_element_type(atol, dtype)

  distance = lax.abs(lax.sub(a, b))
  tolerance = lax.add(atol, lax.mul(rtol, lax.abs(b)))
  close = lax.le(distance, tolerance)

  # Infinite values: discard the tolerance result wherever either side is
  # infinite, then re-admit positions where both sides are the same infinity.
  a_is_inf = ufuncs.isinf(a)
  b_is_inf = ufuncs.isinf(b)
  either_inf = ufuncs.logical_or(a_is_inf, b_is_inf)
  close = ufuncs.logical_and(close, ufuncs.logical_not(either_inf))
  matching_inf = ufuncs.logical_and(ufuncs.logical_and(a_is_inf, b_is_inf),
                                    lax.eq(a, b))
  close = ufuncs.logical_or(close, matching_inf)

  # NaN values: any NaN on either side forces False ...
  a_is_nan = ufuncs.isnan(a)
  b_is_nan = ufuncs.isnan(b)
  either_nan = ufuncs.logical_or(a_is_nan, b_is_nan)
  close = ufuncs.logical_and(close, ufuncs.logical_not(either_nan))
  if equal_nan:
    # ... unless the caller asked for NaN == NaN, in which case positions
    # where both sides are NaN become True.
    close = ufuncs.logical_or(close, ufuncs.logical_and(a_is_nan, b_is_nan))
  return close
def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
left: ArrayLike | str | None = None,
right: ArrayLike | str | None = None,
period: ArrayLike | None = None) -> Array:
util.check_arraylike("interp", x, xp, fp)
if shape(xp) != shape(fp) or ndim(xp) != 1:
raise ValueError("xp and fp must be one-dimensional arrays of equal size")
x_arr, xp_arr = util.promote_dtypes_inexact(x, xp)
fp_arr, = util.promote_dtypes_inexact(fp)
del x, xp, fp
if isinstance(left, str):
if left != 'extrapolate':
raise ValueError("the only valid string value of `left` is "
f"'extrapolate', but got: {left!r}")
extrapolate_left = True
else:
extrapolate_left = False