-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
special.py
1681 lines (1403 loc) · 55.5 KB
/
special.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import operator
from typing import cast, Any, Optional
import numpy as np
import scipy.special as osp_special
import jax.numpy as jnp
from jax import jit
from jax import vmap
from jax import lax
from jax._src import core
from jax._src import custom_derivatives
from jax._src import dtypes
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact
from jax._src.numpy.util import _wraps
from jax._src.ops import special as ops_special
from jax._src.third_party.scipy.betaln import betaln as _betaln_impl
from jax._src.typing import Array, ArrayLike
@_wraps(osp_special.gammaln, module='scipy.special')
def gammaln(x: ArrayLike) -> Array:
  # Run the argument through promote_args_inexact, then defer to lax.lgamma.
  (promoted,) = promote_args_inexact("gammaln", x)
  return lax.lgamma(promoted)
@_wraps(osp_special.gamma, module='scipy.special', lax_description="""\
The JAX version only accepts real-valued inputs.""")
def gamma(x: ArrayLike) -> Array:
  x, = promote_args_inexact("gamma", x)
  # lgamma is the log of the *absolute value* of the gamma function, so
  # exp(lgamma(x)) only recovers the magnitude |gamma(x)|.
  magnitude = lax.exp(lax.lgamma(x))
  # For negative non-integer x the gamma function is negative exactly when
  # floor(x) is odd (e.g. gamma(-0.5) < 0 while gamma(-1.5) > 0). Negative
  # integers are poles; they are excluded here so the value computed above is
  # returned for them unchanged.
  floor_x = lax.floor(x)
  is_negative = (x < 0) & (x != floor_x) & (jnp.mod(floor_x, 2) != 0)
  return jnp.where(is_negative, -magnitude, magnitude)
# `betaln` is the vendored implementation `_betaln_impl` (imported from
# jax._src.third_party.scipy.betaln), decorated with `_wraps` against
# scipy.special.betaln with `update_doc=False`.
betaln = _wraps(
    osp_special.betaln,
    module='scipy.special',
    update_doc=False
)(_betaln_impl)
@_wraps(osp_special.betainc, module='scipy.special')
def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
  # All three arguments are promoted together before calling lax.betainc.
  promoted = promote_args_inexact("betainc", a, b, x)
  a_arr, b_arr, x_arr = promoted
  return lax.betainc(a_arr, b_arr, x_arr)
@_wraps(osp_special.digamma, module='scipy.special', lax_description="""\
The JAX version only accepts real-valued inputs.""")
def digamma(x: ArrayLike) -> Array:
  # Run the argument through promote_args_inexact, then defer to lax.digamma.
  (promoted,) = promote_args_inexact("digamma", x)
  return lax.digamma(promoted)
@_wraps(osp_special.gammainc, module='scipy.special', update_doc=False)
def gammainc(a: ArrayLike, x: ArrayLike) -> Array:
  # Both arguments are promoted together, then passed to lax.igamma.
  a_arr, x_arr = promote_args_inexact("gammainc", a, x)
  return lax.igamma(a_arr, x_arr)
@_wraps(osp_special.gammaincc, module='scipy.special', update_doc=False)
def gammaincc(a: ArrayLike, x: ArrayLike) -> Array:
  # Both arguments are promoted together, then passed to lax.igammac.
  a_arr, x_arr = promote_args_inexact("gammaincc", a, x)
  return lax.igammac(a_arr, x_arr)
@_wraps(osp_special.erf, module='scipy.special', skip_params=["out"],
        lax_description="Note that the JAX version does not support complex inputs.")
def erf(x: ArrayLike) -> Array:
  # Run the argument through promote_args_inexact, then defer to lax.erf.
  (promoted,) = promote_args_inexact("erf", x)
  return lax.erf(promoted)
@_wraps(osp_special.erfc, module='scipy.special', update_doc=False)
def erfc(x: ArrayLike) -> Array:
  # Run the argument through promote_args_inexact, then defer to lax.erfc.
  (promoted,) = promote_args_inexact("erfc", x)
  return lax.erfc(promoted)
@_wraps(osp_special.erfinv, module='scipy.special')
def erfinv(x: ArrayLike) -> Array:
  # Run the argument through promote_args_inexact, then defer to lax.erf_inv.
  (promoted,) = promote_args_inexact("erfinv", x)
  return lax.erf_inv(promoted)
@custom_derivatives.custom_jvp
@_wraps(osp_special.logit, module='scipy.special', update_doc=False)
def logit(x: ArrayLike) -> Array:
  # logit(x) = log(x / (1 - x)), spelled out with lax primitives.
  (prob,) = promote_args_inexact("logit", x)
  complement = lax.sub(_lax_const(prob, 1), prob)
  return lax.log(lax.div(prob, complement))
# Custom tangent rule: g / (x * (1 - x)).
logit.defjvps(
    lambda tangent, _ans, primal: lax.div(
        tangent,
        lax.mul(primal, lax.sub(_lax_const(primal, 1), primal))))
@_wraps(osp_special.expit, module='scipy.special', update_doc=False)
def expit(x: ArrayLike) -> Array:
  # Run the argument through promote_args_inexact, then defer to lax.logistic.
  (promoted,) = promote_args_inexact("expit", x)
  return lax.logistic(promoted)
# `logsumexp` is the implementation from jax._src.ops.special, decorated with
# `_wraps` against scipy.special.logsumexp.
logsumexp = _wraps(osp_special.logsumexp, module='scipy.special')(ops_special.logsumexp)
@custom_derivatives.custom_jvp
@_wraps(osp_special.xlogy, module='scipy.special')
def xlogy(x: ArrayLike, y: ArrayLike) -> Array:
  # Note: xlogy(0, 0) should return 0 according to the function documentation,
  # so entries where x == 0 are replaced by zeros regardless of log(y).
  x, y = promote_args_inexact("xlogy", x, y)
  x_is_nonzero = x != 0.
  product = lax.mul(x, lax.log(y))
  return jnp.where(x_is_nonzero, product, jnp.zeros_like(x))


def _xlogy_jvp(primals, tangents):
  # Tangent of x * log(y): x_dot * log(y) + y_dot * x / y, cast to the dtype
  # of the primal output.
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = xlogy(x, y)
  tangent_out = x_dot * lax.log(y) + y_dot * x / y
  return primal_out, tangent_out.astype(primal_out.dtype)
xlogy.defjvp(_xlogy_jvp)
@custom_derivatives.custom_jvp
@_wraps(osp_special.xlog1py, module='scipy.special', update_doc=False)
def xlog1py(x: ArrayLike, y: ArrayLike) -> Array:
  # Note: xlog1py(0, -1) should return 0 according to the function
  # documentation, so entries where x == 0 are replaced by zeros regardless of
  # log1p(y).
  x, y = promote_args_inexact("xlog1py", x, y)
  x_is_nonzero = x != 0.
  product = lax.mul(x, lax.log1p(y))
  return jnp.where(x_is_nonzero, product, jnp.zeros_like(x))


def _xlog1py_jvp(primals, tangents):
  # Tangent of x * log1p(y): x_dot * log1p(y) + y_dot * x / (1 + y), cast to
  # the dtype of the primal output.
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = xlog1py(x, y)
  tangent_out = x_dot * lax.log1p(y) + y_dot * x / (1 + y)
  return primal_out, tangent_out.astype(primal_out.dtype)
xlog1py.defjvp(_xlog1py_jvp)
@custom_derivatives.custom_jvp
def _xlogx(x):
  """Return x * log(x), with a dedicated JVP rule so derivatives are well defined."""
  return xlogy(x, x)


def _xlogx_jvp(primals, tangents):
  # d/dx [x log x] = log(x) + 1.
  (x,) = primals
  (x_dot,) = tangents
  primal_out = _xlogx(x)
  tangent_out = x_dot * (lax.log(x) + 1)
  return primal_out, tangent_out
_xlogx.defjvp(_xlogx_jvp)
@_wraps(osp_special.entr, module='scipy.special')
def entr(x: ArrayLike) -> Array:
  # -x * log(x) for x >= 0, and -inf wherever x is negative.
  (val,) = promote_args_inexact("entr", x)
  is_negative = lax.lt(val, _lax_const(val, 0))
  minus_inf = lax.full_like(val, -np.inf)
  neg_xlogx = lax.neg(_xlogx(val))
  return lax.select(is_negative, minus_inf, neg_xlogx)
@_wraps(osp_special.multigammaln, update_doc=False)
def multigammaln(a: ArrayLike, d: ArrayLike) -> Array:
  # `d` must be concrete: it sets the length of the arange below.
  d = core.concrete_or_error(int, d, "d argument of multigammaln")
  a, d_ = promote_args_inexact("multigammaln", a, d)

  # constant = 0.25 * d * (d - 1) * log(pi)
  quarter_d = lax.mul(_lax_const(a, 0.25), d_)
  d_minus_one = lax.sub(d_, _lax_const(a, 1))
  log_pi = lax.log(_lax_const(a, np.pi))
  constant = lax.mul(lax.mul(quarter_d, d_minus_one), log_pi)

  # Offsets 0, 1/2, ..., (d - 1)/2 are placed on a new trailing axis and
  # subtracted from `a`; gammaln of each shifted value is summed over that axis.
  offsets = lax.div(jnp.arange(d, dtype=d_.dtype), _lax_const(a, 2))
  leading_axes = tuple(range(a.ndim))
  shifted = (jnp.expand_dims(a, axis=-1) -
             jnp.expand_dims(offsets, axis=leading_axes))
  res = jnp.sum(gammaln(shifted), axis=-1)
  return res + constant
@_wraps(osp_special.kl_div, module="scipy.special")
def kl_div(
    p: ArrayLike,
    q: ArrayLike,
) -> Array:
  # Three regions:
  #   p > 0 and q > 0   -> p*log(p) - p*log(q) + q - p
  #   p == 0 and q >= 0 -> q
  #   anything else     -> inf
  p, q = promote_args_inexact("kl_div", p, q)
  zero = _lax_const(p, 0.0)
  interior_mask = lax.bitwise_and(lax.gt(p, zero), lax.gt(q, zero))
  p_zero_mask = lax.bitwise_and(lax.eq(p, zero), lax.ge(q, zero))

  # Outside the interior region the inputs are replaced by 1 so the interior
  # formula is still evaluated on positive values; that result is discarded
  # by the final `where`.
  safe_p = jnp.where(interior_mask, p, 1)
  safe_q = jnp.where(interior_mask, q, 1)
  log_ratio_term = lax.sub(_xlogx(safe_p), xlogy(safe_p, safe_q))
  interior_val = lax.sub(lax.add(log_ratio_term, safe_q), safe_p)

  boundary_val = jnp.where(p_zero_mask, q, np.inf)
  return jnp.where(interior_mask, interior_val, boundary_val)
@_wraps(osp_special.rel_entr, module="scipy.special")
def rel_entr(
    p: ArrayLike,
    q: ArrayLike,
) -> Array:
  # Three regions, following the scipy.special.rel_entr definition:
  #   p > 0 and q > 0   -> p * log(p / q)
  #   p == 0 and q >= 0 -> 0
  #   anything else     -> inf
  p, q = promote_args_inexact("rel_entr", p, q)
  zero = _lax_const(p, 0.0)
  both_gt_zero_mask = lax.bitwise_and(lax.gt(p, zero), lax.gt(q, zero))
  one_zero_mask = lax.bitwise_and(lax.eq(p, zero), lax.ge(q, zero))

  # Outside the interior region the inputs are replaced by 1 so the interior
  # formula is still evaluated on positive values; that result is discarded
  # by the final `where`.
  safe_p = jnp.where(both_gt_zero_mask, p, 1)
  safe_q = jnp.where(both_gt_zero_mask, q, 1)
  log_val = lax.sub(_xlogx(safe_p), xlogy(safe_p, safe_q))
  # rel_entr(0, q) is 0 for q >= 0. (Returning `q` here is the kl_div rule,
  # which has an extra `- p + q` term that rel_entr does not have.)
  result = jnp.where(
      both_gt_zero_mask, log_val, jnp.where(one_zero_mask, zero, jnp.inf)
  )
  return result
# Coefficients (2k)! / B_{2k} for k = 1, ..., 16, where B_{2k} are the
# Bernoulli numbers (for example B_2 = 1/6 gives 2! / B_2 = 12 and
# B_4 = -1/30 gives 4! / B_4 = -720).
# The values were obtained using https://www.wolframalpha.com
_BERNOULLI_COEFS = [
    12,
    -720,
    30240,
    -1209600,
    47900160,
    -1307674368000 / 691,
    74724249600,
    -10670622842880000 / 3617,
    5109094217170944000 / 43867,
    -802857662698291200000 / 174611,
    14101100039391805440000 / 77683,
    -1693824136731743669452800000 / 236364091,
    186134520519971831808000000 / 657931,
    -37893265687455865519472640000000 / 3392780147,
    759790291646040068357842010112000000 / 1723168255201,
    -134196726836183700385281186201600000000 / 7709321041217,
]
@_wraps(osp_special.zeta, module='scipy.special')
def zeta(x: ArrayLike, q: Optional[ArrayLike] = None) -> Array:
  # Only the two-argument (Hurwitz) form is implemented. Reject the
  # one-argument form with an explicit exception rather than an `assert`,
  # which would be stripped when Python runs with -O.
  if q is None:
    raise NotImplementedError("Riemann zeta function is not implemented yet.")
  # Reference: Johansson, Fredrik.
  # "Rigorous high-precision computation of the Hurwitz zeta function and its derivatives."
  # Numerical Algorithms 69.2 (2015): 253-270.
  # https://arxiv.org/abs/1309.2877 - formula (5)
  # here we keep the same notation as in reference
  s, a = promote_args_inexact("zeta", x, q)
  dtype = lax.dtype(a).type
  # Trailing axis of length 1 so the series index can broadcast against it.
  s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1)
  # precision ~ N, M
  N = M = dtype(8) if lax.dtype(a) == jnp.float32 else dtype(16)
  # Internal invariant: the coefficient table must cover all M tail terms.
  assert M <= len(_BERNOULLI_COEFS)
  # S: direct sum of (a + k)^(-s) for k = 0, ..., N - 1.
  k = jnp.expand_dims(np.arange(N, dtype=N.dtype), tuple(range(a.ndim)))
  S = jnp.sum((a_ + k) ** -s_, -1)
  # I: (a + N)^(1 - s) / (s - 1).
  I = lax.div((a + N) ** (dtype(1) - s), s - dtype(1))
  # T: (a + N)^(-s) * (1/2 + sum of correction terms).
  T0 = (a + N) ** -s
  m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
  s_over_a = (s_ + m) / (a_ + N)
  # Cumulative products of (s + m) / (a + N); every other one is kept.
  T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
  # Cap at the largest finite value of the dtype.
  T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
  coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
                         tuple(range(a.ndim)))
  T1 = T1 / coefs
  T = T0 * (dtype(0.5) + T1.sum(-1))
  return S + I + T
@_wraps(osp_special.polygamma, module='scipy.special', update_doc=False)
def polygamma(n: ArrayLike, x: ArrayLike) -> Array:
  # The order `n` must have an integer dtype; it is promoted together with `x`
  # before the call into lax.polygamma.
  assert jnp.issubdtype(lax.dtype(n), jnp.integer)
  order, point = promote_args_inexact("polygamma", n, x)
  return lax.polygamma(order, point)
# Normal distributions
# Functions "ndtr" and "ndtri" are derived from calculations made in:
# https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
# The "spence" function is also based on the Cephes library with
# the corresponding spence.c file located in the tarball:
# https://netlib.org/cephes/misc.tgz
# In the following email exchange, the author gives his consent to redistribute
# derived works under an Apache 2.0 license.
#
# From: Stephen Moshier <steve@moshier.net>
# Date: Sat, Jun 9, 2018 at 2:36 PM
# Subject: Re: Licensing cephes under Apache (BSD-like) license.
# To: rif <rif@google.com>
#
#
#
# Hello Rif,
#
# Yes, Google may distribute Cephes files under the Apache 2 license.
#
# If clarification is needed, I do not favor BSD over other free licenses.
# I would agree that Apache 2 seems to cover the concern you mentioned
# about sublicensees.
#
# Best wishes for good luck with your projects!
# Steve Moshier
#
#
#
# On Thu, 31 May 2018, rif wrote:
#
# > Hello Steve.
# > My name is Rif. I work on machine learning software at Google.
# >
# > Your cephes software continues to be incredibly useful and widely used. I
# > was wondering whether it would be permissible for us to use the Cephes code
# > under the Apache 2.0 license, which is extremely similar in permissions to
# > the BSD license (Wikipedia comparisons). This would be quite helpful to us
# > in terms of avoiding multiple licenses on software.
# >
# > I'm sorry to bother you with this (I can imagine you're sick of hearing
# > about this by now), but I want to be absolutely clear we're on the level and
# > not misusing your important software. In former conversation with Eugene
# > Brevdo (ebrevdo@google.com), you wrote "If your licensing is similar to BSD,
# > the formal way that has been handled is simply to add a statement to the
# > effect that you are incorporating the Cephes software by permission of the
# > author." I wanted to confirm that (a) we could use the Apache license, (b)
# > that we don't need to (and probably you don't want to) keep getting
# > contacted about individual uses, because your intent is generally to allow
# > this software to be reused under "BSD-like" license, and (c) you're OK
# > letting incorporators decide whether a license is sufficiently BSD-like?
# >
# > Best,
# >
# > rif
# >
# >
# >
# log_ndtr uses different functions over the ranges
# (-infty, lower], (lower, upper], (upper, infty).
# Lower bound values were chosen by examining where the support of ndtr
# appears to be zero, relative to scipy's (which is always 64bit). They were
# then made more conservative just to be safe. (Conservative means use the
# expansion more than we probably need to.)
_LOGNDTR_FLOAT64_LOWER = np.array(-20, np.float64)
_LOGNDTR_FLOAT32_LOWER = np.array(-10, np.float32)
# Upper bound values were chosen by examining for which values of 'x'
# Log[cdf(x)] is 0, after which point we need to use the approximation
# Log[cdf(x)] = Log[1 - cdf(-x)] approx -cdf(-x). We chose a value slightly
# conservative, meaning we use the approximation earlier than needed.
_LOGNDTR_FLOAT64_UPPER = np.array(8, np.float64)
_LOGNDTR_FLOAT32_UPPER = np.array(5, np.float32)
def ndtr(x: ArrayLike) -> Array:
  r"""Normal distribution function.

  Returns the area under the Gaussian probability density function, integrated
  from minus infinity to x:

  .. math::
    \begin{align}
    \mathrm{ndtr}(x) =&
      \ \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \\
    =&\ \frac{1}{2} (1 + \mathrm{erf}(\frac{x}{\sqrt{2}})) \\
    =&\ \frac{1}{2} \mathrm{erfc}(-\frac{x}{\sqrt{2}})
    \end{align}

  Args:
    x: An array of type `float32`, `float64`.

  Returns:
    An array with `dtype=x.dtype`.

  Raises:
    TypeError: if the dtype of `x` is neither `float32` nor `float64`.
  """
  arr = jnp.asarray(x)
  arr_dtype = lax.dtype(arr)
  # Guard clause inverted relative to the usual "raise first" shape: supported
  # dtypes return immediately, everything else falls through to the error.
  if arr_dtype in (jnp.float32, jnp.float64):
    return _ndtr(arr)
  raise TypeError(
      f"x.dtype={arr_dtype} is not supported, see docstring for supported types.")
def _ndtr(x: ArrayLike) -> Array:
  """Core of `ndtr`: evaluate the normal CDF from erf / erfc of x / sqrt(2)."""
  dtype = lax.dtype(x).type
  half_sqrt_2 = dtype(0.5) * np.sqrt(2., dtype=dtype)
  w = x * half_sqrt_2
  z = lax.abs(w)

  # |w| < sqrt(2)/2: use 1 + erf(w).
  central = dtype(1.) + lax.erf(w)
  # Otherwise use erfc(|w|): 2 - erfc(|w|) for w > 0, erfc(|w|) for w <= 0.
  erfc_z = lax.erfc(z)
  tails = lax.select(lax.gt(w, dtype(0.)), dtype(2.) - erfc_z, erfc_z)

  y = lax.select(lax.lt(z, half_sqrt_2), central, tails)
  return dtype(0.5) * y
def ndtri(p: ArrayLike) -> Array:
  r"""The inverse of the CDF of the Normal distribution function.

  Returns `x` such that the area under the PDF from :math:`-\infty` to `x` is
  equal to `p`.

  A piece-wise rational approximation is used for the function. This is based
  on the implementation in netlib.

  Args:
    p: an array of type `float32`, `float64`.

  Returns:
    an array with `dtype=p.dtype`.

  Raises:
    TypeError: if the dtype of `p` is neither `float32` nor `float64`.
  """
  p_dtype = lax.dtype(p)
  # Supported dtypes return immediately; everything else falls through to the
  # error.
  if p_dtype in (jnp.float32, jnp.float64):
    return _ndtri(p)
  raise TypeError(
      f"x.dtype={p_dtype} is not supported, see docstring for supported types.")
def _ndtri(p: ArrayLike) -> Array:
  """Core of `ndtri`: piece-wise rational approximation of the inverse CDF."""
  # Constants used in piece-wise rational approximations. Taken from the cephes
  # library:
  # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
  # Each table is listed highest-order coefficient first; `_horner` below
  # consumes the tables in exactly this order.
  p0 = (-5.99633501014107895267E1,
        9.80010754185999661536E1,
        -5.66762857469070293439E1,
        1.39312609387279679503E1,
        -1.23916583867381258016E0)
  q0 = (1.0,
        1.95448858338141759834E0,
        4.67627912898881538453E0,
        8.63602421390890590575E1,
        -2.25462687854119370527E2,
        2.00260212380060660359E2,
        -8.20372256168333339912E1,
        1.59056225126211695515E1,
        -1.18331621121330003142E0)
  p1 = (4.05544892305962419923E0,
        3.15251094599893866154E1,
        5.71628192246421288162E1,
        4.40805073893200834700E1,
        1.46849561928858024014E1,
        2.18663306850790267539E0,
        -1.40256079171354495875E-1,
        -3.50424626827848203418E-2,
        -8.57456785154685413611E-4)
  q1 = (1.0,
        1.57799883256466749731E1,
        4.53907635128879210584E1,
        4.13172038254672030440E1,
        1.50425385692907503408E1,
        2.50464946208309415979E0,
        -1.42182922854787788574E-1,
        -3.80806407691578277194E-2,
        -9.33259480895457427372E-4)
  p2 = (3.23774891776946035970E0,
        6.91522889068984211695E0,
        3.93881025292474443415E0,
        1.33303460815807542389E0,
        2.01485389549179081538E-1,
        1.23716634817820021358E-2,
        3.01581553508235416007E-4,
        2.65806974686737550832E-6,
        6.23974539184983293730E-9)
  q2 = (1.0,
        6.02427039364742014255E0,
        3.67983563856160859403E0,
        1.37702099489081330271E0,
        2.16236993594496635890E-1,
        1.34204006088543189037E-2,
        3.28014464682127739104E-4,
        2.89247864745380683936E-6,
        6.79019408009981274425E-9)

  dtype = lax.dtype(p).type
  shape = jnp.shape(p)

  def _horner(var, coeffs):
    """Evaluate the polynomial with Horner's method (highest order first)."""
    acc = jnp.zeros_like(var)
    for c in np.array(coeffs, dtype):
      acc = c + acc * var
    return acc

  # For p above 1 - exp(-2), work with the complement 1 - p; the sign of the
  # result is fixed up at the end.
  maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.)), dtype(1.) - p, p)
  # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
  # later on. The result from the computation when p == 0 is not used so any
  # number that doesn't result in NaNs is fine.
  sanitized_mcp = jnp.where(
      maybe_complement_p <= dtype(0.),
      jnp.full(shape, dtype(0.5)),
      maybe_complement_p)

  # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
  w = sanitized_mcp - dtype(0.5)
  ww = lax.square(w)
  ratio_0 = _horner(ww, p0) / _horner(ww, q0)
  x_for_big_p = (w + w * ww * ratio_0) * -dtype(np.sqrt(2. * np.pi))

  # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
  # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
  # arrays based on whether p < exp(-32).
  z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp))
  first_term = z - lax.log(z) / z
  inv_z = dtype(1.) / z
  second_term_small_p = _horner(inv_z, p2) / _horner(inv_z, q2) / z
  second_term_otherwise = _horner(inv_z, p1) / _horner(inv_z, q1) / z
  x_for_small_p = first_term - second_term_small_p
  x_otherwise = first_term - second_term_otherwise

  x_tail = jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise)
  x = jnp.where(sanitized_mcp > dtype(np.exp(-2.)), x_for_big_p, x_tail)
  x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x)

  # p <= 0 maps to -inf and p >= 1 maps to +inf.
  infinity = jnp.full(shape, dtype(np.inf))
  upper_fixed = jnp.where(p >= dtype(1.0), infinity, x)
  return jnp.where(p <= dtype(0.0), -infinity, upper_fixed)
@partial(custom_derivatives.custom_jvp, nondiff_argnums=(1,))
def log_ndtr(x: ArrayLike, series_order: int = 3) -> Array:
  r"""Log Normal distribution function.

  For details of the Normal distribution function see `ndtr`.

  This function calculates :math:`\log(\mathrm{ndtr}(x))` by either calling
  :math:`\log(\mathrm{ndtr}(x))` or using an asymptotic series. Specifically:

  - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
    :math:`\log(1-x) \approx -x, x \ll 1`.
  - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique
    and take a log.
  - For `x <= lower_segment`, we use the series approximation of `erf` to compute
    the log CDF directly.

  The `lower_segment` and `upper_segment` are set based on the precision of the
  input:

  .. math::
    \begin{align}
    \mathit{lower\_segment} =&
      \ \begin{cases}
        -20 &  x.\mathrm{dtype}=\mathit{float64} \\
        -10 &  x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases} \\
    \mathit{upper\_segment} =&
      \ \begin{cases}
        8&  x.\mathrm{dtype}=\mathit{float64} \\
        5&  x.\mathrm{dtype}=\mathit{float32} \\
        \end{cases}
    \end{align}

  When `x < lower_segment`, the `ndtr` asymptotic series approximation is:

  .. math::
    \begin{align}
     \mathrm{ndtr}(x) =&\  \mathit{scale} * (1 + \mathit{sum}) + R_N \\
     \mathit{scale}   =&\  \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\
     \mathit{sum}     =&\  \sum_{n=1}^N (-1)^n (2n-1)!! / (x^2)^n \\
     R_N     =&\  O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3})
    \end{align}

  where :math:`(2n-1)!! = (2n-1) (2n-3) (2n-5) ...  (3) (1)` is a
  `double-factorial
  <https://en.wikipedia.org/wiki/Double_factorial>`_ operator.

  Args:
    x: an array of type `float32`, `float64`.
    series_order: Non-negative Python integer. Maximum depth to
      evaluate the asymptotic expansion. This is the `N` above.

  Returns:
    an array with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is not handled.
    TypeError: if `series_order` is not a Python `integer`.
    ValueError: if `series_order` is not in `[0, 30]`.
  """
  if not isinstance(series_order, int):
    raise TypeError("series_order must be a Python integer.")
  if series_order < 0:
    raise ValueError("series_order must be non-negative.")
  if series_order > 30:
    raise ValueError("series_order must be <= 30.")

  x = jnp.asarray(x)
  dtype = lax.dtype(x)

  # Pick the segment boundaries that match the input precision.
  lower_segment: np.ndarray
  upper_segment: np.ndarray
  if dtype == jnp.float64:
    lower_segment, upper_segment = _LOGNDTR_FLOAT64_LOWER, _LOGNDTR_FLOAT64_UPPER
  elif dtype == jnp.float32:
    lower_segment, upper_segment = _LOGNDTR_FLOAT32_LOWER, _LOGNDTR_FLOAT32_UPPER
  else:
    raise TypeError(f"x.dtype={np.dtype(dtype)} is not supported.")

  # The basic idea here was ported from:
  #   https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
  # We copy the main idea, with a few changes
  # * For x >> 1, and X ~ Normal(0, 1),
  #     Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
  #     which extends the range of validity of this function.
  # * We use one fixed series_order for all of 'x', rather than adaptive.
  # * Our docstring properly reflects that this is an asymptotic series, not a
  #   Taylor series. We also provided a correct bound on the remainder.
  # * We need to use the max/min in the _log_ndtr_lower arg to avoid nan when
  #   x=0. This happens even though the branch is unchosen because when x=0
  #   the gradient of a select involves the calculation 1*dy+0*(-inf)=nan
  #   regardless of whether dy is finite. Note that the minimum is a NOP if
  #   the branch is chosen.
  above_upper = -_ndtr(-x)  # log(1-x) ~= -x, x << 1
  between = lax.log(_ndtr(lax.max(x, lower_segment)))
  below_lower = _log_ndtr_lower(lax.min(x, lower_segment), series_order)
  not_above = jnp.where(lax.gt(x, lower_segment), between, below_lower)
  return jnp.where(lax.gt(x, upper_segment), above_upper, not_above)


def _log_ndtr_jvp(series_order, primals, tangents):
  # Tangent: t * exp(norm_logpdf(x) - log_ndtr(x)).
  (x,) = primals
  (t,) = tangents
  primal_out = log_ndtr(x, series_order=series_order)
  log_ratio = lax.sub(_norm_logpdf(x), primal_out)
  tangent_out = lax.mul(t, lax.exp(log_ratio))
  return primal_out, tangent_out
log_ndtr.defjvp(_log_ndtr_jvp)
def _log_ndtr_lower(x, series_order):
  """Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`."""
  dtype = lax.dtype(x).type
  x_squared = lax.square(x)
  # Log of the term multiplying (1 + sum):
  #   -x^2 / 2 - log(-x) - log(2 pi) / 2
  quadratic = -dtype(0.5) * x_squared
  log_neg_x = lax.log(-x)
  half_log_two_pi = dtype(0.5 * np.log(2. * np.pi))
  log_scale = quadratic - log_neg_x - half_log_two_pi
  series = _log_ndtr_asymptotic_series(x, series_order)
  return log_scale + lax.log(series)
def _log_ndtr_asymptotic_series(x, series_order):
  """Calculates the asymptotic series used in log_ndtr."""
  dtype = lax.dtype(x).type
  if series_order <= 0:
    return np.array(1, dtype)
  x_squared = lax.square(x)
  # Terms (2n-1)!! / x^(2n) are accumulated separately by the parity of n:
  # index 0 collects even n, index 1 collects odd n.
  partial_sums = [jnp.zeros_like(x), jnp.zeros_like(x)]
  power = x_squared  # Start with x^{2*1} = x^{2*n} with n = 1.
  for n in range(1, series_order + 1):
    term = np.array(_double_factorial(2 * n - 1), dtype) / power
    parity = 1 if n % 2 else 0
    partial_sums[parity] = partial_sums[parity] + term
    power = power * x_squared
  even_sum, odd_sum = partial_sums
  return dtype(1.) + even_sum - odd_sum
def _double_factorial(n: int) -> int:
  """The double factorial function for small Python integer `n`.

  Computes n * (n - 2) * (n - 4) * ... down to (but excluding) 1, returning 1
  when `n <= 1`.

  The product is accumulated with Python integers, which have arbitrary
  precision. A fixed-width integer product (e.g. `np.prod` over an int64
  `arange`) silently wraps around from 35!! onwards, which is reached for
  `series_order >= 18` in `log_ndtr`.

  Args:
    n: Python integer.

  Returns:
    The exact double factorial as a Python `int`.
  """
  result = 1
  for factor in range(n, 1, -2):
    result *= factor
  return result
# log(sqrt(2 * pi)): the log of the normalizer of the standard normal density.
_norm_logpdf_constant = np.log(np.sqrt(2 * np.pi))


def _norm_logpdf(x):
  # -x^2 / 2 - log(sqrt(2 * pi))
  half_neg_square = lax.mul(_lax_const(x, -0.5), lax.square(x))
  return lax.sub(half_neg_square, _lax_const(x, _norm_logpdf_constant))
@_wraps(osp_special.i0e, module='scipy.special')
def i0e(x: ArrayLike) -> Array:
  # Run the argument through promote_args_inexact, then defer to lax.bessel_i0e.
  (promoted,) = promote_args_inexact("i0e", x)
  return lax.bessel_i0e(promoted)
@_wraps(osp_special.i0, module='scipy.special')
def i0(x: ArrayLike) -> Array:
  # exp(|x|) * bessel_i0e(x)
  (promoted,) = promote_args_inexact("i0", x)
  growth = lax.exp(lax.abs(promoted))
  return lax.mul(growth, lax.bessel_i0e(promoted))
@_wraps(osp_special.i1e, module='scipy.special')
def i1e(x: ArrayLike) -> Array:
  # Run the argument through promote_args_inexact, then defer to lax.bessel_i1e.
  (promoted,) = promote_args_inexact("i1e", x)
  return lax.bessel_i1e(promoted)
@_wraps(osp_special.i1, module='scipy.special')
def i1(x: ArrayLike) -> Array:
  # exp(|x|) * bessel_i1e(x)
  (promoted,) = promote_args_inexact("i1", x)
  growth = lax.exp(lax.abs(promoted))
  return lax.mul(growth, lax.bessel_i1e(promoted))
def _bessel_jn_scan_body_fun(carry, k):
  # One step of the recurrence used by `_bessel_jn`. The carry holds the two
  # previous values (f0, f1), the running sum `bs` and the evaluation point z.
  prev, curr, running_sum, z = carry
  nxt = 2.0 * (k + 1.0) * curr / z - prev

  # The running sum only picks up 2 * f on even k.
  def _accumulate(operand):
    total, value = operand
    return total + 2.0 * value

  def _keep(operand):
    total, _ = operand
    return total

  k_is_even = jnp.mod(k, 2) == 0
  running_sum = lax.cond(k_is_even, _accumulate, _keep, (running_sum, nxt))
  return (curr, nxt, running_sum, z), nxt
def _bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array:
  # Seed values for the recurrence: (f0, f1) = (0, 1e-16) and a zero sum.
  seed_prev = _lax_const(z, 0.0)
  seed_curr = _lax_const(z, 1E-16)
  seed_sum = _lax_const(z, 0.0)
  # k runs over iota(n_iter + 1) in reverse, so index 0 of the stacked
  # outputs is the value produced by the final step (k = 0).
  steps = lax.iota(lax.dtype(z), n_iter+1)
  final_carry, stacked = lax.scan(
      f=_bessel_jn_scan_body_fun,
      init=(seed_prev, seed_curr, seed_sum, z),
      xs=steps,
      reverse=True)
  _, _, total, _ = final_carry
  last = stacked[0]  # Use the value at the last iteration.
  # Keep orders 0..v and normalize by (bs - f).
  return stacked[:v+1] / (total - last)
@partial(jit, static_argnames=["v", "n_iter"])
def bessel_jn(z: ArrayLike, *, v: int, n_iter: int=50) -> Array:
  """Bessel function of the first kind of integer order and real argument.

  Reference:
  Shanjie Zhang and Jian-Ming Jin. Computation of special functions.
  Wiley-Interscience, 1996.

  Args:
    z: The sampling point(s) at which the Bessel function of the first kind are
      computed.
    v: The order (int) of the Bessel function.
    n_iter: The number of iterations required for updating the function
      values. As a rule of thumb, `n_iter` is the smallest nonnegative integer
      that satisfies the condition
      `int(0.5 * log10(6.28 + n_iter) - n_iter * log10(1.36 + abs(z) / n_iter)) > 20`.
      Details in `BJNDD` (https://people.sc.fsu.edu/~jburkardt/f77_src/special_functions/special_functions.f)

  Returns:
    An array of shape `(v+1, *z.shape)` containing the values of the Bessel
    function of orders 0, 1, ..., v. The return type matches the type of `z`.

  Raises:
    TypeError if `v` is not integer.
    ValueError if elements of array `z` are not float.
  """
  z = jnp.asarray(z)
  (z,) = promote_dtypes_inexact(z)
  if dtypes.issubdtype(lax.dtype(z), complex):
    raise ValueError("complex input not supported.")

  # Both `v` and `n_iter` must be concrete Python values.
  v = core.concrete_or_error(operator.index, v, 'Argument v of bessel_jn.')
  n_iter = core.concrete_or_error(int, n_iter, 'Argument n_iter of bessel_jn.')

  # The scalar kernel is vmapped once per dimension of `z`; the order axis
  # ends up last and is moved to the front.
  kernel = partial(_bessel_jn, v=v, n_iter=n_iter)
  for _axis in range(z.ndim):
    kernel = vmap(kernel)
  orders_last = kernel(z)
  return jnp.moveaxis(orders_last, -1, 0)
def _gen_recurrence_mask(
    l_max: int, is_normalized: bool, dtype: Any
) -> tuple[Array, Array]:
  """Generates mask for recurrence relation on the remaining entries.

  The remaining entries are with respect to the diagonal and offdiagonal
  entries.

  Args:
    l_max: see `gen_normalized_legendre`.
    is_normalized: True if the recurrence mask is used by normalized associated
      Legendre functions.
    dtype: dtype used for the index grids and the returned masks.

  Returns:
    Arrays representing the mask used by the recurrence relations.
  """
  size = l_max + 1

  # Computes all coefficients. m varies along axis 0, l along axis 1.
  index_range = jnp.arange(size, dtype=dtype)
  m_mat, l_mat = jnp.meshgrid(index_range, index_range, indexing='ij')
  if is_normalized:
    c0 = l_mat * l_mat
    c1 = m_mat * m_mat
    c2 = 2.0 * l_mat
    c3 = (l_mat - 1.0) * (l_mat - 1.0)
    d0 = jnp.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
    d1 = jnp.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
  else:
    d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
    d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)

  # Only entries above the first (d0) / second (d1) superdiagonal are kept;
  # everything else is zero.
  blank = jnp.zeros((size, size), dtype=dtype)
  d0_keep = jnp.triu_indices(size, 1)
  d1_keep = jnp.triu_indices(size, 2)
  d0_mask = blank.at[d0_keep].set(d0[d0_keep])
  d1_mask = blank.at[d1_keep].set(d1[d1_keep])

  # Creates a 3D mask that contains 1s on the diagonal plane (i + j == k) and
  # 0s elsewhere.
  i, j, k = jnp.ogrid[:size, :size, :size]
  plane = (i + j - k == 0).astype(dtype)

  # Broadcast the 2D (j, k) coefficients across the leading axis of the plane.
  d0_mask_3d = d0_mask[None, :, :] * plane
  d1_mask_3d = d1_mask[None, :, :] * plane
  return (d0_mask_3d, d1_mask_3d)
@partial(jit, static_argnums=(2,))
def _gen_derivatives(p: Array,
                     x: Array,
                     is_normalized: bool) -> Array:
  """Generates derivatives of associated Legendre functions of the first kind.

  Args:
    p: The 3D array containing the values of associated Legendre functions; the
      dimensions are in the sequence of order (m), degree (l), and evaluation
      points.
    x: A vector of type `float32` or `float64` containing the sampled points.
    is_normalized: True if the associated Legendre functions are normalized.

  Returns:
    The 3D array, with the same shape as `p`, representing the derivatives of
    associated Legendre functions of the first kind.

  Raises:
    NotImplementedError: if `is_normalized` is True.
  """
  n_order, n_degree, _ = p.shape

  # Derivative computation requires negative orders.
  if is_normalized:
    raise NotImplementedError(
        'Negative orders for normalization is not implemented yet.')

  real = x.dtype

  # p_{l-1}^m: shift by one along the degree axis, zero-filling l = 0.
  prev_deg = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :n_degree, :]
  # p_{l-1}^{m+2}: shift down by two along the order axis.
  prev_deg_up2 = jnp.pad(
      prev_deg, ((0, 2), (0, 0), (0, 0)))[2:n_order + 2, :, :]
  # p_{l-1}^{m-2}: shift up by two along the order axis.
  prev_deg_dn2 = jnp.pad(prev_deg, ((2, 0), (0, 0), (0, 0)))[:n_order, :, :]

  # Rows 1 and 0 of p_{l-1}^{m-2} correspond to orders -1 and -2; fill them
  # from the order 1 and order 2 rows of `p`, rescaled per degree.
  if n_degree > 1:
    deg = jnp.arange(1, n_degree - 1, dtype=real)
    factor = -1.0 / ((deg + 1) * deg)
    patch = jnp.einsum('i,ij->ij', factor, p[1, 1:n_degree - 1, :])
    prev_deg_dn2 = prev_deg_dn2.at[1, 2:n_degree, :].set(patch)
  if n_degree > 2:
    deg = jnp.arange(2, n_degree - 1, dtype=real)
    factor = 1.0 / ((deg + 2) * (deg + 1) * deg * (deg - 1))
    patch = jnp.einsum('i,ij->ij', factor, p[2, 2:n_degree - 1, :])
    prev_deg_dn2 = prev_deg_dn2.at[0, 3:n_degree, :].set(patch)

  orders, degrees = jnp.meshgrid(
      jnp.arange(n_order, dtype=real),
      jnp.arange(n_degree, dtype=real),
      indexing='ij')

  blank = jnp.zeros((n_order, n_degree), dtype=real)
  upper = jnp.triu_indices(n_order, 0, n_degree)
  zero_row = jnp.zeros((n_degree,), dtype=real)

  def on_upper(coeffs):
    # Keeps only the entries with degree >= order; everything else is zero.
    return blank.at[upper].set(coeffs[upper])

  def weighted(coeffs, table):
    # Scales every (order, degree) slice of `table` by its coefficient.
    return jnp.einsum('ij,ijk->ijk', coeffs, table)

  # -0.5 / (m - 1) is singular at m = 1, so row 1 of both coefficient
  # matrices is zeroed here and that row of the result is rebuilt below.
  a = -0.5 / (orders - 1.0)
  a_masked = on_upper(a).at[1, :].set(zero_row)

  lm_sum = degrees + orders
  c = a * (lm_sum - 2.0) * (lm_sum - 1.0)
  c_masked = on_upper(c).at[1, :].set(zero_row)

  # p_l^{m-1}.
  lowered = weighted(a_masked, prev_deg) + weighted(c_masked, prev_deg_dn2)

  d = -0.5 / (orders + 1.0)
  e = d * lm_sum * (lm_sum + 1.0)

  # p_l^{m+1}.
  raised = weighted(on_upper(d), prev_deg_up2) + weighted(on_upper(e), prev_deg)

  f = lm_sum * (degrees - orders + 1.0) / 2.0
  derivative = weighted(on_upper(f), lowered) - 0.5 * raised

  # Special treatment of the singularity at m = 1.
  if n_order > 1:
    deg = jnp.arange(n_degree, dtype=p.dtype)
    g = jnp.einsum('i,ij->ij', (deg + 1) * deg, p[0, :, :])
    if n_degree > 2:
      g = g - p[2, :, :]
    inv_root = 0.5 / jnp.sqrt(1 - x * x)
    order_one = jnp.einsum('j,ij->ij', inv_root, g)
    derivative = derivative.at[1, :, :].set(order_one)
    derivative = derivative.at[1, 0, :].set(0)

  return derivative
@partial(jit, static_argnums=(0, 2))
def _gen_associated_legendre(l_max: int,
x: Array,
is_normalized: bool) -> Array:
r"""Computes associated Legendre functions (ALFs) of the first kind.
The ALFs of the first kind are used in spherical harmonics. The spherical
harmonic of degree `l` and order `m` can be written as
`Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
normalization factor and θ and φ are the colatitude and longitude,
repectively. `N_l^m` is chosen in the way that the spherical harmonics form
a set of orthonormal basis function of L^2(S^2). For the computational
efficiency of spherical harmonics transform, the normalization factor is
used in the computation of the ALFs. In addition, normalizing `P_l^m`
avoids overflow/underflow and achieves better numerical stability. Three
recurrence relations are used in the computation.
Args:
l_max: The maximum degree of the associated Legendre function. Both the
degrees and orders are `[0, 1, 2, ..., l_max]`.
x: A vector of type `float32`, `float64` containing the sampled points in
spherical coordinates, at which the ALFs are computed; `x` is essentially
`cos(θ)`. For the numerical integration used by the spherical harmonics
transforms, `x` contains the quadrature points in the interval of
`[-1, 1]`. There are several approaches to provide the quadrature points:
Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
method (`scipy.special.roots_chebyu`), and Driscoll & Healy
method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
transforms and convolutions on the 2-sphere." Advances in applied
mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
points are nearly equal-spaced along θ and provide exact discrete
orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
operation, `W` is a diagonal matrix containing the quadrature weights,
and `I` is the identity matrix. The Gauss-Chebyshev points are equally
spaced, which only provide approximate discrete orthogonality. The
Driscoll & Healy qudarture points are equally spaced and provide the
exact discrete orthogonality. The number of sampling points is required to
be twice as the number of frequency points (modes) in the Driscoll & Healy
approach, which enables FFT and achieves a fast spherical harmonics
transform.
is_normalized: True if the associated Legendre functions are normalized.
With normalization, `N_l^m` is applied such that the spherical harmonics
form a set of orthonormal basis functions of L^2(S^2).
Returns:
The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
of the ALFs at `x`; the dimensions in the sequence of order, degree, and
evalution points.
"""
p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]), dtype=x.dtype)
a_idx = jnp.arange(1, l_max + 1, dtype=x.dtype)
b_idx = jnp.arange(l_max, dtype=x.dtype)
if is_normalized:
initial_value: ArrayLike = 0.5 / jnp.sqrt(jnp.pi) # The initial value p(0,0).
f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx))
f_b = jnp.sqrt(2.0 * b_idx + 3.0)
else:
initial_value = 1.0 # The initial value p(0,0).
f_a = jnp.cumprod(1.0 - 2.0 * a_idx)
f_b = 2.0 * b_idx + 1.0
p = p.at[(0, 0)].set(initial_value)