Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipVinc committed Apr 11, 2024
1 parent 6a4d8c3 commit 893f7e4
Show file tree
Hide file tree
Showing 12 changed files with 56 additions and 18 deletions.
4 changes: 3 additions & 1 deletion netket/jax/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ def imag_fun(*args, **kwargs):
)(*args, **kwargs)

out = out_r + 1j * out_j
grad = jax.tree_util.tree_map(lambda re, im: re + 1j * im, grad_r, grad_j)
grad = jax.tree_util.tree_map(
lambda re, im: re + 1j * im, grad_r, grad_j
)

if has_aux:
return out, grad, aux
Expand Down
8 changes: 6 additions & 2 deletions netket/jax/_jacobian/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,9 @@ def jacobian(

if pdf is None:
if center:
jacobians = jax.tree_util.tree_map(lambda x: subtract_mean(x, axis=0), jacobians)
jacobians = jax.tree_util.tree_map(
lambda x: subtract_mean(x, axis=0), jacobians
)

if _sqrt_rescale:
sqrt_n_samp = math.sqrt(
Expand All @@ -366,7 +368,9 @@ def jacobian(
jacobians_avg = jax.tree_util.tree_map(
partial(sum_mpi, axis=0), _multiply_by_pdf(jacobians, pdf)
)
jacobians = jax.tree_util.tree_map(lambda x, y: x - y, jacobians, jacobians_avg)
jacobians = jax.tree_util.tree_map(
lambda x, y: x - y, jacobians, jacobians_avg
)

if _sqrt_rescale:
jacobians = _multiply_by_pdf(jacobians, jnp.sqrt(pdf))
Expand Down
4 changes: 3 additions & 1 deletion netket/jax/_scanmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from jax import linear_util as lu

_tree_add = partial(jax.tree_util.tree_map, jax.lax.add)
_tree_zeros_like = partial(jax.tree_util.tree_map, lambda x: jnp.zeros(x.shape, dtype=x.dtype))
_tree_zeros_like = partial(
jax.tree_util.tree_map, lambda x: jnp.zeros(x.shape, dtype=x.dtype)
)


# TODO put it somewhere
Expand Down
4 changes: 3 additions & 1 deletion netket/jax/_utils_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def PRNGKey(
else:
key = seed

key = jax.tree_util.tree_map(lambda k: mpi.mpi_bcast_jax(k, root=root, comm=comm)[0], key)
key = jax.tree_util.tree_map(
lambda k: mpi.mpi_bcast_jax(k, root=root, comm=comm)[0], key
)

return key

Expand Down
16 changes: 12 additions & 4 deletions netket/jax/_utils_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def tree_leaf_iscomplex(pars: PyTree) -> bool:
"""
Returns true if at least one leaf in the tree has complex dtype.
"""
return any(jax.tree_util.tree_leaves(jax.tree_util.tree_map(jnp.iscomplexobj, pars)))
return any(
jax.tree_util.tree_leaves(jax.tree_util.tree_map(jnp.iscomplexobj, pars))
)


def tree_leaf_isreal(pars: PyTree) -> bool:
Expand All @@ -100,7 +102,9 @@ def tree_conj(t: PyTree) -> PyTree:
Args:
t: pytree
"""
return jax.tree_util.tree_map(lambda x: jax.lax.conj(x) if jnp.iscomplexobj(x) else x, t)
return jax.tree_util.tree_map(
lambda x: jax.lax.conj(x) if jnp.iscomplexobj(x) else x, t
)


@jax.jit
Expand All @@ -116,7 +120,9 @@ def tree_dot(a: PyTree, b: PyTree) -> Scalar:
"""
return jax.tree_util.tree_reduce(
jax.numpy.add,
jax.tree_util.tree_map(jax.numpy.sum, jax.tree_util.tree_map(jax.numpy.multiply, a, b)),
jax.tree_util.tree_map(
jax.numpy.sum, jax.tree_util.tree_map(jax.numpy.multiply, a, b)
),
)


Expand Down Expand Up @@ -198,7 +204,9 @@ def _tree_to_real(x):
def _tree_to_real_inverse(x):
if isinstance(x, RealImagTuple):
# not using jax.lax.complex because it would convert scalars to arrays
return jax.tree_util.tree_map(lambda re, im: re + 1j * im if im is not None else re, *x)
return jax.tree_util.tree_map(
lambda re, im: re + 1j * im if im is not None else re, *x
)
else:
return x

Expand Down
12 changes: 9 additions & 3 deletions netket/jax/_vmap_chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@ def _eval_fun_in_chunks(vmapped_fun, chunk_size, argnums, *args, **kwargs):
else:
# split inputs
def _get_chunks(x):
x_chunks = jax.tree_util.tree_map(lambda x_: x_[: n_elements - n_rest, ...], x)
x_chunks = jax.tree_util.tree_map(
lambda x_: x_[: n_elements - n_rest, ...], x
)
x_chunks = _chunk(x_chunks, chunk_size)
return x_chunks

def _get_rest(x):
x_rest = jax.tree_util.tree_map(lambda x_: x_[n_elements - n_rest :, ...], x)
x_rest = jax.tree_util.tree_map(
lambda x_: x_[n_elements - n_rest :, ...], x
)
return x_rest

args_chunks = [
Expand All @@ -41,7 +45,9 @@ def _get_rest(x):
y = y_chunks
else:
y_rest = vmapped_fun(*args_rest, **kwargs)
y = jax.tree_util.tree_map(lambda y1, y2: jnp.concatenate((y1, y2)), y_chunks, y_rest)
y = jax.tree_util.tree_map(
lambda y1, y2: jnp.concatenate((y1, y2)), y_chunks, y_rest
)
return y


Expand Down
4 changes: 3 additions & 1 deletion netket/jax/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ def _sele_op(o):
if o is True:
return lambda x: Partial(partial(lambda x: x, x))
else:
return partial(jax.tree_util.tree_map, partial(o, axis_name="i"))
return partial(
jax.tree_util.tree_map, partial(o, axis_name="i")
)

reductions = [_sele_op(o) for o in reduction_op]
res = out_treedef.flatten_up_to(res)
Expand Down
4 changes: 3 additions & 1 deletion netket/optimizer/qgt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def check(x, target, target_im=None):
jax.tree_util.tree_map(check, x, target)
except ValueError:
# catches jax tree map errors
pars_struct = jax.tree_util.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), x)
pars_struct = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), x
)
vec_struct = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), target
)
Expand Down
6 changes: 5 additions & 1 deletion netket/vqs/mc/mc_state/expect_forces.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,8 @@ def forces_expect_hermitian(

new_model_state = new_model_state[0] if is_mutable else None

return , jax.tree_util.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
return (
,
jax.tree_util.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad),
new_model_state,
)
4 changes: 3 additions & 1 deletion test/optimizer/test_qgt_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def assert_allclose(x, y, rtol, atol):
rtol = 1e-6
np.testing.assert_allclose(x, y, rtol, atol)

jax.tree_util.tree_map(lambda x, y: assert_allclose(x, y, rtol=rtol, atol=atol), t1, t2)
jax.tree_util.tree_map(
lambda x, y: assert_allclose(x, y, rtol=rtol, atol=atol), t1, t2
)


def tree_samedtypes(t1, t2):
Expand Down
4 changes: 3 additions & 1 deletion test/variational/test_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def test_variables_from_file(vstate, tmp_path):
vstate2.variables = nkx.vqs.variables_from_file(name, vstate2.variables)

# check
jax.tree_util.tree_map(np.testing.assert_allclose, vstate.parameters, vstate2.parameters)
jax.tree_util.tree_map(
np.testing.assert_allclose, vstate.parameters, vstate2.parameters
)


def test_variables_from_tar(vstate, tmp_path):
Expand Down
4 changes: 3 additions & 1 deletion test/variational/test_variational_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ def test_serialization(vstate):

vstate_new = serialization.from_bytes(vstate_new, bdata)

jax.tree_util.tree_map(np.testing.assert_allclose, vstate.parameters, vstate_new.parameters)
jax.tree_util.tree_map(
np.testing.assert_allclose, vstate.parameters, vstate_new.parameters
)
np.testing.assert_allclose(vstate.samples, vstate_new.samples)
np.testing.assert_allclose(vstate.diagonal.samples, vstate_new.diagonal.samples)
assert vstate.n_samples == vstate_new.n_samples
Expand Down

0 comments on commit 893f7e4

Please sign in to comment.