Skip to content

Commit

Permalink
Support value computation of associated Legendre functions.
Browse files Browse the repository at this point in the history
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
  • Loading branch information
tlu7 and jakevdp committed Jun 14, 2021
1 parent 1e4d28a commit 095e650
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/jax.scipy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ jax.scipy.special
logit
logsumexp
lpmn
lpmn_values
multigammaln
ndtr
ndtri
Expand Down
56 changes: 55 additions & 1 deletion jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ def body_fun(i, p_val):
return p


def lpmn(m, n, z):
def lpmn(m: int, n: int, z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""The associated Legendre functions (ALFs) of the first kind.
Args:
Expand Down Expand Up @@ -956,3 +956,57 @@ def lpmn(m, n, z):
p_derivatives = _gen_derivatives(p_vals, z, is_normalized)

return (p_vals, p_derivatives)


def lpmn_values(m: int, n: int, z: jnp.ndarray, is_normalized: bool) -> jnp.ndarray:
r"""The associated Legendre functions (ALFs) of the first kind.
Unlike `lpmn`, this function only computes the values of ALFs.
The ALFs of the first kind can be used in spherical harmonics. The
spherical harmonic of degree `l` and order `m` can be written as
:math:`Y_l^m(\theta, \phi) = N_l^m * P_l^m(\cos \theta) * \exp(i m \phi)`,
where :math:`N_l^m` is the normalization factor and θ and φ are the
colatitude and longitude, repectively. :math:`N_l^m` is chosen in the
way that the spherical harmonics form a set of orthonormal basis function
of :math:`L^2(S^2)`. Normalizing :math:`P_l^m` avoids overflow/underflow
and achieves better numerical stability.
Args:
m: The maximum order of the associated Legendre functions.
n: The maximum degree of the associated Legendre function, often called
`l` in describing ALFs. Both the degrees and orders are
`[0, 1, 2, ..., l_max]`, where `l_max` denotes the maximum degree.
z: A vector of type `float32` or `float64` containing the sampling
points at which the ALFs are computed.
is_normalized: True if the associated Legendre functions are normalized.
With normalization, :math:`N_l^m` is applied such that the spherical
harmonics form a set of orthonormal basis functions of :math:`L^2(S^2)`.
Returns:
A 3D array of shape `(l_max + 1, l_max + 1, len(z))` containing
the values of the associated Legendre functions of the first kind. The
return type matches the type of `z`.
Raises:
TypeError if elements of array `z` are not in (float32, float64).
ValueError if array `z` is not 1D.
NotImplementedError if `m!=n`.
"""
dtype = lax.dtype(z)
if dtype not in (jnp.float32, jnp.float64):
raise TypeError(
'z.dtype={} is not supported, see docstring for supported types.'
.format(dtype))

if z.ndim != 1:
raise ValueError('z must be a 1D array.')

m = core.concrete_or_error(int, m, 'Argument m of lpmn.')
n = core.concrete_or_error(int, n, 'Argument n of lpmn.')

if m != n:
raise NotImplementedError('Computations for m!=n are not yet supported.')

l_max = n

return _gen_associated_legendre(l_max, z, is_normalized)
1 change: 1 addition & 0 deletions jax/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
logit,
logsumexp,
lpmn,
lpmn_values,
multigammaln,
log_ndtr,
ndtr,
Expand Down
46 changes: 46 additions & 0 deletions tests/lax_scipy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,52 @@ def testLpmn(self, l_max, num_z):
self.assertAllClose(actual_p_derivatives,expected_p_derivatives,
rtol=1e-6, atol=8.4e-4)

with self.subTest('Test JIT compatibility'):
args_maker = lambda: [z]
lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z)
self._CompileAndCheck(lsp_special_fn, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_maxdegree={}_inputsize={}".format(l_max, num_z),
"l_max": l_max,
"num_z": num_z}
for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64])))
def testNormalizedLpmnValues(self, l_max, num_z):
# Points on which the associated Legendre functions areevaluated.
z = np.linspace(-0.2, 0.9, num_z)
is_normalized = True
actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized)

# The expected results are obtained from scipy.
expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
for i in range(num_z):
expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z[i])[0]

def apply_normalization(a):
"""Applies normalization to the associated Legendre functions."""
num_m, num_l, _ = a.shape
a_normalized = np.zeros_like(a)
for m in range(num_m):
for l in range(num_l):
c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m)
c1 = (4.0 * np.pi) * osp_special.factorial(l + m)
c2 = np.sqrt(c0 / c1)
a_normalized[m, l] = c2 * a[m, l]
return a_normalized

# The results from scipy are not normalized and the comparison requires
# normalizing the results.
expected_p_vals_normalized = apply_normalization(expected_p_vals)

with self.subTest('Test accuracy.'):
self.assertAllClose(actual_p_vals, expected_p_vals_normalized, rtol=1e-6, atol=3.2e-6)

with self.subTest('Test JIT compatibility'):
args_maker = lambda: [z]
lsp_special_fn = lambda z: lsp_special.lpmn_values(l_max, l_max, z, is_normalized)
self._CompileAndCheck(lsp_special_fn, args_maker)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 095e650

Please sign in to comment.