Merge pull request #440 from hawkinsp/master
Expose logsumexp as scipy.special.logsumexp.
hawkinsp committed Feb 24, 2019
2 parents daf3e3f + 95483c7 commit dd5b2a6
Showing 5 changed files with 23 additions and 17 deletions.
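For orientation, here is a minimal, illustrative sketch (not part of the commit) of the API change: `logsumexp` is now importable from `jax.scipy.special`, matching where SciPy itself exposes it.

```python
# Illustrative usage of the relocated function; assumes a JAX build that
# includes this commit. The repo's own examples alias jax.numpy as np.
import jax.numpy as np
from jax.scipy.special import logsumexp  # new canonical location

x = np.array([1.0, 2.0, 3.0])
print(logsumexp(x))                 # log(e^1 + e^2 + e^3)
print(logsumexp(x, keepdims=True))  # same value, shape (1,)
```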
2 changes: 1 addition & 1 deletion examples/mnist_classifier_fromscratch.py
@@ -27,7 +27,7 @@

from jax.api import jit, grad
from jax.config import config
-from jax.scipy.misc import logsumexp
+from jax.scipy.special import logsumexp
import jax.numpy as np
from examples import datasets

2 changes: 1 addition & 1 deletion jax/experimental/stax.py
@@ -31,7 +31,7 @@

from jax import lax
from jax import random
-from jax.scipy.misc import logsumexp
+from jax.scipy.special import logsumexp
import jax.numpy as np


15 changes: 3 additions & 12 deletions jax/scipy/misc.py
@@ -20,18 +20,9 @@
import scipy.misc as osp_misc

from .. import lax
+from ..scipy import special
from ..numpy.lax_numpy import _wraps, _reduction_dims, _constant_like


-@_wraps(osp_misc.logsumexp)
-def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
-  if b is not None or return_sign:
-    raise NotImplementedError("Only implemented for b=None, return_sign=False")
-  dims = _reduction_dims(a, axis)
-  shape = lax.subvals(onp.shape(a), zip(dims, (1,) * len(dims)))
-  dimadd = lambda x: lax.reshape(x, shape)
-  amax = lax.reduce(a, _constant_like(a, -onp.inf), lax.max, dims)
-  amax_singletons = dimadd(amax)
-  out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
-                                   _constant_like(a, 0), lax.add, dims)), amax)
-  return dimadd(out) if keepdims else out
+if hasattr(osp_misc, 'logsumexp'):
+  logsumexp = special.logsumexp
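The `hasattr` guard keeps `jax.scipy.misc.logsumexp` importable only when the host SciPy still ships `scipy.misc.logsumexp` (newer SciPy releases provide it only under `scipy.special`); when the guard fires, the old path is just an alias for the new implementation. A sketch of that behaviour, assuming such a SciPy version:

```python
# Illustrative check, not part of the commit; assumes a SciPy version in
# which scipy.misc.logsumexp still exists, so the hasattr guard above is True.
import scipy.misc as osp_misc

if hasattr(osp_misc, 'logsumexp'):
    from jax.scipy.misc import logsumexp as via_misc
    from jax.scipy.special import logsumexp as via_special
    assert via_misc is via_special  # the misc name is now only an alias
```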
17 changes: 16 additions & 1 deletion jax/scipy/special.py
@@ -16,10 +16,11 @@
from __future__ import division
from __future__ import print_function

+import numpy as onp
import scipy.special as osp_special

from .. import lax
-from ..numpy.lax_numpy import _wraps, asarray
+from ..numpy.lax_numpy import _wraps, asarray, _reduction_dims, _constant_like


# need to create new functions because _wraps sets the __name__ attribute
@@ -41,3 +42,17 @@ def expit(x):
  x = asarray(x)
  one = lax._const(x, 1)
  return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))


+@_wraps(osp_special.logsumexp)
+def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
+  if b is not None or return_sign:
+    raise NotImplementedError("Only implemented for b=None, return_sign=False")
+  dims = _reduction_dims(a, axis)
+  shape = lax.subvals(onp.shape(a), zip(dims, (1,) * len(dims)))
+  dimadd = lambda x: lax.reshape(x, shape)
+  amax = lax.reduce(a, _constant_like(a, -onp.inf), lax.max, dims)
+  amax_singletons = dimadd(amax)
+  out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
+                                   _constant_like(a, 0), lax.add, dims)), amax)
+  return dimadd(out) if keepdims else out
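The moved implementation uses the standard stability trick: it subtracts the per-reduction maximum before exponentiating, since log sum(exp(a_i)) = max(a) + log sum(exp(a_i - max(a))), so no operand of `exp` is positive and nothing overflows. A plain-NumPy sketch of the same identity (illustrative only; the committed code uses `lax` reductions):

```python
# Why the max is subtracted: the naive form overflows, the shifted form does not.
import numpy as onp

def naive_logsumexp(a):
  return onp.log(onp.sum(onp.exp(a)))           # exp(1000.) overflows to inf

def shifted_logsumexp(a):
  amax = onp.max(a)
  return amax + onp.log(onp.sum(onp.exp(a - amax)))

a = onp.array([1000.0, 1000.0])
print(naive_logsumexp(a))    # inf (with an overflow warning)
print(shifted_logsumexp(a))  # ~1000.693 = 1000 + log(2)
```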
4 changes: 2 additions & 2 deletions tests/lax_scipy_test.py
@@ -87,10 +87,10 @@ def _GetArgsMaker(self, rng, shapes, dtypes):
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
-      return osp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)
+      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    def lax_fun(array_to_reduce):
-      return lsp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)
+      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
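The renamed helpers keep the test's check the same: JAX's `logsumexp` is compared against SciPy's, with both now taken from the `special` namespaces. A standalone sketch of the equivalence being asserted (illustrative; the real test goes through the harness's `_CheckAgainstNumpy` with parameterized shapes and axes):

```python
# Hypothetical standalone check mirroring what the test asserts.
import numpy as onp
import scipy.special as osp_special
from jax.scipy.special import logsumexp

a = onp.random.RandomState(0).randn(3, 4)
expected = osp_special.logsumexp(a, axis=1, keepdims=True)
actual = logsumexp(a, axis=1, keepdims=True)
onp.testing.assert_allclose(actual, expected, rtol=1e-4)
```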
