Update to solve jax deprecations (netket#1766)
Jax 0.4.25 deprecated `jax.tree_map` in favour of `jax.tree.map`.

However, `jax.tree.map` is only available in recent jax versions, so to
avoid breaking older versions I replaced all usages of `jax.tree_map`
with `jax.tree_util.tree_map`; these should be updated to `jax.tree.map`
in a few months, once we drop support for older jax versions.

The documentation is updated to use the latest jax API directly, that is,
`jax.tree.map` (and the new-style `jax.random.key`).
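
For reference, the three spellings are interchangeable on a parameter pytree. A minimal sketch (the `params` dict is only an illustration; the `jax.tree.map` line assumes jax >= 0.4.25):

```python
import jax
import jax.numpy as jnp

params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)}}

# new API, available from jax 0.4.25 onwards
shifted = jax.tree.map(lambda x: x + 0.1, params)

# portable spelling used throughout this commit, works on older jax too
shifted_compat = jax.tree_util.tree_map(lambda x: x + 0.1, params)

# deprecated spelling that this commit removes
# shifted_old = jax.tree_map(lambda x: x + 0.1, params)
```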
PhilipVinc authored and jwnys committed Apr 23, 2024
1 parent 6b4433d commit 8407d92
Showing 56 changed files with 223 additions and 182 deletions.
4 changes: 2 additions & 2 deletions Benchmarks/qgt.py
@@ -57,7 +57,7 @@ def QGT(*args, **kwargs):

def construct_and_solve(vstate, rhs):
qgt_ = QGT(vstate=vstate, diag_shift=0.01)
-return jax.tree_map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs))
+return jax.tree.map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs))


# Benchmark starts here
@@ -111,7 +111,7 @@ def _benchmark(n_nodes, n_samples, n_layers, width):
qgt_ = QGT(vstate=vstate, diag_shift=0.01)
Tsolve = (
timeit_gc(
-lambda: jax.tree_map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs2)),
+lambda: jax.tree.map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs2)),
number=5,
)
/ 5
4 changes: 2 additions & 2 deletions Benchmarks/qgt_gcnn.py
@@ -23,7 +23,7 @@ def QGT(*args, **kwargs):

def construct_and_solve(vstate, rhs):
qgt_ = QGT(vstate=vstate, diag_shift=0.01)
-return jax.tree_map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs))
+return jax.tree.map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs))


# Benchmark starts here
@@ -82,7 +82,7 @@ def benchmark(side, n_samples, layers, features):

Tsolve = (
timeit_gc(
-lambda: jax.tree_map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs2)),
+lambda: jax.tree.map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs2)),
number=5,
)
/ 5
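In the two benchmarks above, the tree map only exists to call `block_until_ready()` on every leaf of the solution, so that jax's asynchronous dispatch does not make the solve look instantaneous. A minimal sketch of that timing pattern (`solve_fn` and `rhs` are placeholder names, not part of the benchmark code):

```python
import timeit
import jax

def time_solve(solve_fn, rhs, repeats=5):
    # Block on every leaf of the output so the measured time includes the
    # actual computation and not just the asynchronous dispatch.
    run = lambda: jax.tree_util.tree_map(lambda x: x.block_until_ready(), solve_fn(rhs))
    return timeit.timeit(run, number=repeats) / repeats
```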
2 changes: 1 addition & 1 deletion docs/advanced/custom_operators.ipynb
@@ -293,7 +293,7 @@
" # this is the expectation value, as before\n",
" expval = np.array(0, dtype=op.dtype)\n",
" # this is the gradient, which of course is zero.\n",
" grad = jax.tree_map(jnp.zeros_like, vstate.parameters)\n",
" grad = jax.tree.map(jnp.zeros_like, vstate.parameters)\n",
" return expval, grad"
]
},
2 changes: 1 addition & 1 deletion docs/docs/hilbert.md
@@ -260,7 +260,7 @@ If you then want to sample this space, you'll encounter the following error:

```python
>>> import jax
->>> hi.random_state(jax.random.PRNGKey(3), 3)
+>>> hi.random_state(jax.random.key(3), 3)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../netket/hilbert/abstract_hilbert.py", line 84, in random_state
8 changes: 4 additions & 4 deletions docs/docs/sampler.ipynb
@@ -187,7 +187,7 @@
"log_pdf = nk.models.RBM(param_dtype=float)\n",
"\n",
"# and we initialize it's parameters\n",
"param_seed = jax.random.PRNGKey(0)\n",
"param_seed = jax.random.key(0)\n",
"\n",
"pars = log_pdf.init(param_seed, hilbert.random_state(param_seed, 3))"
]
@@ -268,7 +268,7 @@
"sampler = nk.sampler.ExactSampler(hilbert, dtype=jnp.int8)\n",
"\n",
"# We create the state of the sampler\n",
"sampler_state = sampler.init_state(log_pdf, pars, jax.random.PRNGKey(1))\n",
"sampler_state = sampler.init_state(log_pdf, pars, jax.random.key(1))\n",
"\n",
"# We call reset (this will pre-compute the log_pdf on the whole hilbert space)\n",
"sampler_state = sampler.reset(log_pdf, pars, sampler_state)\n",
@@ -415,7 +415,7 @@
],
"source": [
"# We create the state of the sampler\n",
"sampler_state = sampler.init_state(log_pdf, pars, jax.random.PRNGKey(1))\n",
"sampler_state = sampler.init_state(log_pdf, pars, jax.random.key(1))\n",
"\n",
"# We call reset (this will pre-compute the log_pdf on the whole hilbert space)\n",
"sampler_state = sampler.reset(log_pdf, pars, sampler_state)\n",
@@ -520,7 +520,7 @@
"outputs": [],
"source": [
"# We create the state of the sampler\n",
"sampler_state = sampler.init_state(log_pdf, pars, jax.random.PRNGKey(1))\n",
"sampler_state = sampler.init_state(log_pdf, pars, jax.random.key(1))\n",
"\n",
"# We call reset (this will pre-compute the log_pdf on the whole hilbert space)\n",
"sampler_state = sampler.reset(log_pdf, pars, sampler_state)\n",
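Besides the tree-map change, the docs above now build RNG state with `jax.random.key` instead of `jax.random.PRNGKey`. A short sketch of the difference (both forms remain valid in recent jax versions):

```python
import jax

legacy_key = jax.random.PRNGKey(0)  # legacy key: uint32 array of shape (2,)
typed_key = jax.random.key(0)       # new-style typed key (scalar with an opaque key dtype)

# both are consumed the same way downstream
k1, k2 = jax.random.split(typed_key)
x = jax.random.normal(k1, (3,))
```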
4 changes: 2 additions & 2 deletions docs/docs/varstate.md
@@ -139,7 +139,7 @@ vstate.parameters['visible_bias']
0.06278384, 0.00275547, 0.05843748, 0.07516951,
0.21897993, -0.01632223], dtype=float64)

-vstate.parameters = jax.tree_map(lambda x: x+0.1, vstate.parameters)
+vstate.parameters = jax.tree.map(lambda x: x+0.1, vstate.parameters)

# Look at the new values
vstate.parameters['visible_bias']
@@ -219,7 +219,7 @@ the {py:attr}`~netket.vqs.VariationalState.model_state` attribute.

Parameters are stored as a set of nested dictionaries.
In Jax jargon, Parameters are a PyTree (see [PyTree documentation](https://jax.readthedocs.io/en/latest/pytrees.html)) and they
-can be operated upon with functions like [jax.tree_map](https://jax.readthedocs.io/en/latest/jax.tree_util.html?highlight=tree_map#jax.tree_util.tree_map).
+can be operated upon with functions like [jax.tree.map](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree.map.html).

Before modifying the parameters or what is stored inside of a dictionary of parameters, it is a good idea to make a copy using {func}`flax.core.copy`, which calls recursively `{}.copy()` in all nested dictionaries.
If you do not do this, you will only copy the outermost dictionary but the inner ones will be referenced, and so modifying them will modify the original parameters as well.
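A minimal sketch of the pattern the page above describes, assuming `vstate` is the variational state from the snippet: copy the nested parameter dictionary first, then map over its leaves.

```python
import jax
import flax

# recursive copy of the nested parameter dict (see the note above on why a
# shallow copy of the outermost dict is not enough)
pars = flax.core.copy(vstate.parameters, {})

# build a new pytree of the same structure with every leaf shifted by 0.1
pars = jax.tree.map(lambda x: x + 0.1, pars)

vstate.parameters = pars
```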
2 changes: 1 addition & 1 deletion docs/tutorials/gs-continuous-space.ipynb
@@ -172,7 +172,7 @@
],
"source": [
"import jax\n",
"states = hi.random_state(jax.random.PRNGKey(0), 1)\n",
"states = hi.random_state(jax.random.key(0), 1)\n",
"\n",
"# You can always reshape those configurations to NxD matrix\n",
"states.reshape(N, 3)"
6 changes: 3 additions & 3 deletions docs/tutorials/gs-ising.ipynb
@@ -127,7 +127,7 @@
"NetKet's Hilbert spaces define the computational basis of the calculation, and are used to label and generate elements from it. \n",
"The standard Spin-basis implicitly selects the `z` basis and elements of that basis will be elements $ v\\in\\{\\pm 1\\}^N $.\n",
"\n",
"It is possible to generate random basis elements through the function `random_state(rng, shape, dtype)`, where the first argument must be a jax RNG state (usually built with `jax.random.PRNGKey(seed)`, second is an integer or a tuple giving the shape of the samples and the last is the dtype of the generated states."
"It is possible to generate random basis elements through the function `random_state(rng, shape, dtype)`, where the first argument must be a jax RNG state (usually built with `jax.random.key(seed)`, second is an integer or a tuple giving the shape of the samples and the last is the dtype of the generated states."
]
},
{
@@ -154,7 +154,7 @@
],
"source": [
"import jax\n",
"hi.random_state(jax.random.PRNGKey(0), 3)"
"hi.random_state(jax.random.key(0), 3)"
]
},
{
@@ -592,7 +592,7 @@
" energy_history.append(E.mean.real)\n",
" # equivalent to vstate.parameters - 0.05*E_grad , but it performs this\n",
" # function on every leaf of the dictionaries containing the set of parameters\n",
" new_pars = jax.tree_map(lambda x,y: x-0.05*y, vstate.parameters, E_grad)\n",
" new_pars = jax.tree.map(lambda x,y: x-0.05*y, vstate.parameters, E_grad)\n",
" # actually update the parameters\n",
" vstate.parameters = new_pars"
]
57 changes: 30 additions & 27 deletions docs/tutorials/vmc-from-scratch.ipynb
@@ -666,7 +666,7 @@
"model = MF()\n",
"\n",
"# pick a RNG key to initialise the random parameters\n",
"key = jax.random.PRNGKey(0)\n",
"key = jax.random.key(0)\n",
"\n",
"# initialise the weights\n",
"parameters = model.init(key, np.random.rand(hi.size))"
@@ -709,7 +709,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, parameters are vectors, even though they are stored in memory differently. You can apply mathematical operations to those vectors using `jax.tree_map(function, trees...)`, which calls the function on every element of the pytrees."
"Conceptually, parameters are vectors, even though they are stored in memory differently. You can apply mathematical operations to those vectors using `jax.tree.map(function, trees...)`, which calls the function on every element of the pytrees."
]
},
{
@@ -738,19 +738,19 @@
"def multiply_by_10(x):\n",
" return 10*x\n",
"\n",
"print(\"multiply_by_10: \", jax.tree_map(multiply_by_10, dict1))\n",
"print(\"multiply_by_10: \", jax.tree.map(multiply_by_10, dict1))\n",
"# this can also be done by defining the function as a lambda function, which is more compact\n",
"print(\"multiply_by_10, with lambda:\", jax.tree_map(lambda x: 10*x, dict1))\n",
"print(\"multiply_by_10, with lambda:\", jax.tree.map(lambda x: 10*x, dict1))\n",
"\n",
"def add(x,y):\n",
" return x+y\n",
"print(\"add dict1 and 2 :\", jax.tree_map(add, dict1, dict2))\n",
"print(\"add dict1 and 2 :\", jax.tree.map(add, dict1, dict2))\n",
"\n",
"\n",
"def sub(x,y):\n",
" return x-y\n",
"print(\"subtract dict1 and 2 :\", jax.tree_map(sub, dict1, dict2))\n",
"print(\"subtract dict1 and 2, lambda:\", jax.tree_map(lambda x,y:x-y, dict1, dict2))\n"
"print(\"subtract dict1 and 2 :\", jax.tree.map(sub, dict1, dict2))\n",
"print(\"subtract dict1 and 2, lambda:\", jax.tree.map(lambda x,y:x-y, dict1, dict2))\n"
]
},
{
@@ -777,7 +777,7 @@
],
"source": [
"# generate 4 random inputs\n",
"inputs = hi.random_state(jax.random.PRNGKey(1), (4,))\n",
"inputs = hi.random_state(jax.random.key(1), (4,))\n",
"\n",
"log_psi = model.apply(parameters, inputs)\n",
"# notice that logpsi has shape (4,) because we fed it 4 random configurations.\n",
@@ -1115,7 +1115,7 @@
"\n",
"# initialise \n",
"model = MF()\n",
"parameters = model.init(jax.random.PRNGKey(0), np.ones((hi.size, )))\n",
"parameters = model.init(jax.random.key(0), np.ones((hi.size, )))\n",
"\n",
"# logging: you can (if you want) use netket loggers to avoid writing a lot of boilerplate...\n",
"# they accumulate data you throw at them\n",
@@ -1127,7 +1127,7 @@
" \n",
" # update parameters. Try using a learning rate of 0.01\n",
" # to update the parameters, which are stored as a dictionary (or pytree)\n",
" # you can use jax.tree_map as shown above.\n",
" # you can use jax.tree.map as shown above.\n",
" #...\n",
" \n",
" # log energy: the logger takes a step argument and a dictionary of variables to be logged\n",
@@ -1144,7 +1144,7 @@
"```python\n",
"# initialise \n",
"model = MF()\n",
"parameters = model.init(jax.random.PRNGKey(0), np.ones((hi.size, )))\n",
"parameters = model.init(jax.random.key(0), np.ones((hi.size, )))\n",
"\n",
"# logging: you can (if you want) use netket loggers to avoid writing a lot of boilerplate...\n",
"# they accumulate data you throw at them\n",
@@ -1155,7 +1155,7 @@
" energy, gradient = compute_energy_and_gradient(model, parameters, hamiltonian_jax_sparse)\n",
" \n",
" # update parameters\n",
" parameters = jax.tree_map(lambda x,y:x-0.01*y, parameters, gradient)\n",
" parameters = jax.tree.map(lambda x,y:x-0.01*y, parameters, gradient)\n",
" \n",
" # log energy: the logger takes a step argument and a dictionary of variables to be logged\n",
" logger(step=i, item={'Energy':energy})\n",
@@ -1291,7 +1291,10 @@
" \"J\", nn.initializers.normal(), (n_sites,n_sites), float\n",
" )\n",
" # ensure same data types\n",
" J, input_x = nn.dtypes.kernel, x_in = promote_dtype(J, input_x, dtype=None)\n",
" dtype = jax.numpy.promote_types(J.dtype, input_x.dtype)\n",
" J = J.astype(dtype)\n",
" input_x = input_x.astype(dtype)\n",
" \n",
" # note that J_ij is not symmetric. So we symmetrize it by hand\n",
" J_symm = J.T + J\n",
" \n",
@@ -1334,11 +1337,11 @@
"# if the code above is correct, this should run\n",
"model_jastrow = Jastrow()\n",
"\n",
"one_sample = hi.random_state(jax.random.PRNGKey(0))\n",
"batch_samples = hi.random_state(jax.random.PRNGKey(0), (5,))\n",
"multibatch_samples = hi.random_state(jax.random.PRNGKey(0), (5,4,))\n",
"one_sample = hi.random_state(jax.random.key(0))\n",
"batch_samples = hi.random_state(jax.random.key(0), (5,))\n",
"multibatch_samples = hi.random_state(jax.random.key(0), (5,4,))\n",
"\n",
"parameters_jastrow = model_jastrow.init(jax.random.PRNGKey(0), one_sample)\n",
"parameters_jastrow = model_jastrow.init(jax.random.key(0), one_sample)\n",
"assert parameters_jastrow['params']['J'].shape == (hi.size, hi.size)\n",
"assert model_jastrow.apply(parameters_jastrow, one_sample).shape == ()\n",
"assert model_jastrow.apply(parameters_jastrow, batch_samples).shape == batch_samples.shape[:-1]\n",
@@ -1484,7 +1487,7 @@
"outputs": [],
"source": [
"# given sigma\n",
"sigma = hi.random_state(jax.random.PRNGKey(1))\n",
"sigma = hi.random_state(jax.random.key(1))\n",
"\n",
"eta, H_sigmaeta = hamiltonian_jax.get_conn_padded(sigma)"
]
@@ -1542,7 +1545,7 @@
],
"source": [
"# given sigma\n",
"sigma = hi.random_state(jax.random.PRNGKey(1), (4,5))\n",
"sigma = hi.random_state(jax.random.key(1), (4,5))\n",
"\n",
"eta, H_sigmaeta = hamiltonian_jax.get_conn_padded(sigma)\n",
"\n",
@@ -1844,16 +1847,16 @@
"jacobian = jax.jacrev(logpsi_sigma_fun)(parameters_jastrow)\n",
"\n",
"#\n",
"print(\"The parameters of jastrow have shape:\\n\" , jax.tree_map(lambda x: x.shape, parameters_jastrow))\n",
"print(\"The parameters of jastrow have shape:\\n\" , jax.tree.map(lambda x: x.shape, parameters_jastrow))\n",
"\n",
"print(\"The jacobian of jastrow have shape:\\n\" , jax.tree_map(lambda x: x.shape, jacobian))"
"print(\"The jacobian of jastrow have shape:\\n\" , jax.tree.map(lambda x: x.shape, jacobian))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now implement a function that computes the jacobian-vector product in order to estimate the gradient of the energy. You can either do this vector-Jacobian-transpose product manually by using `jax.jacrev` and `jax.tree_map`, but you can also have a look at `jax.vjp` which does it automatically for you."
"Now implement a function that computes the jacobian-vector product in order to estimate the gradient of the energy. You can either do this vector-Jacobian-transpose product manually by using `jax.jacrev` and `jax.tree.map`, but you can also have a look at `jax.vjp` which does it automatically for you."
]
},
{
@@ -1875,7 +1878,7 @@
" # first define the function to be differentiated\n",
" logpsi_sigma_fun = lambda pars : model_jastrow.apply(parameters_jastrow, sigma_vector)\n",
" ...\n",
" # use jacrev with jax.tree_map, or even better, jax.vjp\n",
" # use jacrev with jax.tree.map, or even better, jax.vjp\n",
" E_grad = ...\n",
" \n",
" # compute the energy as well\n",
@@ -1912,7 +1915,7 @@
" # first define the function to be differentiated\n",
" logpsi_sigma_fun = lambda pars : model.apply(pars, sigma)\n",
"\n",
" # use jacrev with jax.tree_map, or even better, jax.vjp\n",
" # use jacrev with jax.tree.map, or even better, jax.vjp\n",
" _, vjpfun = jax.vjp(logpsi_sigma_fun, parameters)\n",
" E_grad = vjpfun((E_loc - E_average)/E_loc.size)\n",
"\n",
@@ -1974,7 +1977,7 @@
"chain_length = 1000//sampler.n_chains\n",
"\n",
"# initialise\n",
"parameters = model.init(jax.random.PRNGKey(0), np.ones((hi.size, )))\n",
"parameters = model.init(jax.random.key(0), np.ones((hi.size, )))\n",
"sampler_state = sampler.init_state(model, parameters, seed=1)\n",
"\n",
"# logging: you can (if you want) use netket loggers to avoid writing a lot of boilerplate...\n",
@@ -2016,7 +2019,7 @@
"chain_length = 1000//sampler.n_chains\n",
"\n",
"# initialise\n",
"parameters = model.init(jax.random.PRNGKey(0), np.ones((hi.size, )))\n",
"parameters = model.init(jax.random.key(0), np.ones((hi.size, )))\n",
"sampler_state = sampler.init_state(model, parameters, seed=1)\n",
"\n",
"# logging: you can (if you want) use netket loggers to avoid writing a lot of boilerplate...\n",
@@ -2032,7 +2035,7 @@
" E, E_grad = estimate_energy_and_gradient(model, parameters, hamiltonian_jax, samples)\n",
" \n",
" # update parameters. Try using a learning rate of 0.01\n",
" parameters = jax.tree_map(lambda x,y: x-0.005*y, parameters, E_grad)\n",
" parameters = jax.tree.map(lambda x,y: x-0.005*y, parameters, E_grad)\n",
" \n",
" # log energy: the logger takes a step argument and a dictionary of variables to be logged\n",
" logger(step=i, item={'Energy':E})\n",
2 changes: 1 addition & 1 deletion netket/driver/abstract_variational_driver.py
@@ -389,7 +389,7 @@ def estimate(self, observables):

# Do not unpack operators, even if they are pytrees!
# this is necessary to support jax operators.
-return jax.tree_map(
+return jax.tree_util.tree_map(
self._estimate_stats,
observables,
is_leaf=lambda x: isinstance(x, AbstractObservable),
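The `is_leaf` argument in the hunk above is what keeps jax operators, which are themselves pytrees, from being unpacked by the map. A toy sketch of the mechanism, with tuples standing in for operator pytrees:

```python
import jax

observables = {"sx": ("op", 1), "sz": ("op", 2)}

# default behaviour: tuples are pytrees, so the function sees their entries
print(jax.tree_util.tree_map(lambda x: x, observables))
# {'sx': ('op', 1), 'sz': ('op', 2)}

# with is_leaf, each tuple is passed to the function as a single object
print(jax.tree_util.tree_map(
    lambda op: f"estimate({op})",
    observables,
    is_leaf=lambda x: isinstance(x, tuple),
))
# {'sx': "estimate(('op', 1))", 'sz': "estimate(('op', 2))"}
```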
8 changes: 4 additions & 4 deletions netket/errors.py
@@ -587,7 +587,7 @@ class RealQGTComplexDomainError(NetketError):
>>> _, vec = vstate.expect_and_grad(nk.operator.spin.sigmax(vstate.hilbert, 1))
>>> G = nk.optimizer.qgt.QGTOnTheFly(vstate, holomorphic=False)
>>>
->>> vec_real = jax.tree_map(lambda x: x.real, vec)
+>>> vec_real = jax.tree.map(lambda x: x.real, vec)
>>> sol = G@vec_real
Or, if you used the QGT in a linear solver, try using:
@@ -601,7 +601,7 @@ class RealQGTComplexDomainError(NetketError):
>>> _, vec = vstate.expect_and_grad(nk.operator.spin.sigmax(vstate.hilbert, 1))
>>>
>>> G = nk.optimizer.qgt.QGTOnTheFly(vstate, holomorphic=False)
->>> vec_real = jax.tree_map(lambda x: x.real, vec)
+>>> vec_real = jax.tree.map(lambda x: x.real, vec)
>>>
>>> linear_solver = jax.scipy.sparse.linalg.cg
>>> solution, info = G.solve(linear_solver, vec_real)
@@ -624,14 +624,14 @@ def __init__(self):
.. code:: python
->>> vec_real = jax.tree_map(lambda x: x.real, vec)
+>>> vec_real = jax.tree_util.tree_map(lambda x: x.real, vec)
>>> G@vec_real
If you used the QGT in a linear solver, try using:
.. code:: python
->>> vec_real = jax.tree_map(lambda x: x.real, vec)
+>>> vec_real = jax.tree_util.tree_map(lambda x: x.real, vec)
>>> G.solve(linear_solver, vec_real)
to fix this error.
2 changes: 1 addition & 1 deletion netket/experimental/driver/tdvp.py
@@ -162,7 +162,7 @@ def odefun_tdvp( # noqa: F811

@partial(jax.jit, static_argnums=(3, 4))
def _map_parameters(forces, parameters, loss_grad_factor, propagation_type, state_T):
-forces = jax.tree_map(
+forces = jax.tree_util.tree_map(
lambda x, target: loss_grad_factor * x,
forces,
parameters,