Skip to content

Commit

Permalink
fix for scalar arr
Browse files Browse the repository at this point in the history
  • Loading branch information
inailuig committed Jun 12, 2023
1 parent e96e00e commit d55d1a1
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions netket/utils/mpi/_logsumexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
import jax.numpy as jnp
from jax import lax
from netket.utils.mpi import mpi_max_jax, mpi_sum_jax, mpi_allgather_jax
from netket.utils.mpi import mpi_max_jax, mpi_sum_jax, mpi_allgather_jax, n_nodes


def _promote_args_inexact(_, *args):
Expand Down Expand Up @@ -75,8 +75,11 @@ def mpi_logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, token
b_arr = a_arr # for type checking
pos_dims, dims = _reduction_dims(a_arr, axis)
amax = jnp.max(a_arr, axis=dims, keepdims=keepdims)
if 0 in dims:
if jnp.iscomplexobj(amax):
print("dims", dims)
if (
0 in dims or dims == ()
) and n_nodes > 1: # skip if not on mpi as jnp.max(..., axis=0) fails on scalar
if np.issubdtype(amax.dtype, np.complexfloating):
# TODO mpi_max_jax does not work with complex numbers
# We would need lexicographic ordering just like jax.lax.max
# (consider first real part then imag part if equal)
Expand All @@ -94,7 +97,9 @@ def mpi_logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, token
if b is None and not np.issubdtype(a_arr.dtype, np.complexfloating):
tmp1 = lax.exp(lax.sub(a_arr, amax_with_dims)) # TODO MPI
tmp2 = jnp.sum(tmp1, axis=dims, keepdims=keepdims)
if 0 in dims:
if 0 in dims or (
dims == () and n_nodes > 1
): # if scalar but on mpi we still need to reduce
tmp2, token = mpi_sum_jax(tmp2, token=token)
out = lax.add(lax.log(tmp2), amax)

Expand All @@ -105,7 +110,9 @@ def mpi_logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, token
if b is not None:
expsub = lax.mul(expsub, b_arr)
sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)
if 0 in dims:
if 0 in dims or (
dims == () and n_nodes > 1
): # if scalar but on mpi we still need to reduce
sumexp, token = mpi_sum_jax(sumexp, token=token)

sign = lax.stop_gradient(jnp.sign(sumexp))
Expand Down

0 comments on commit d55d1a1

Please sign in to comment.