Update to solve jax deprecations #1766

Merged: 7 commits, Apr 11, 2024
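The diff below migrates NetKet's code, benchmarks and documentation away from two deprecated JAX entry points: `jax.tree_map` becomes `jax.tree.map` in the docs and benchmarks (and the long-stable `jax.tree_util.tree_map` inside the library), and legacy `jax.random.PRNGKey` seeds become new-style typed keys from `jax.random.key`. A minimal illustration of the two renames, assuming a JAX release recent enough to ship the `jax.tree` module:

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# old spelling, deprecated in recent JAX:  jax.tree_map(lambda x: 2 * x, params)
# new spelling (jax.tree_util.tree_map remains the stable alias for older releases):
doubled = jax.tree.map(lambda x: 2 * x, params)

# old-style raw uint32 seed:  key = jax.random.PRNGKey(0)
# new-style typed key:
key = jax.random.key(0)
sample = jax.random.normal(key, (3,))
```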
4 changes: 2 additions & 2 deletions Benchmarks/qgt.py
@@ -57,7 +57,7 @@ def QGT(*args, **kwargs):

def construct_and_solve(vstate, rhs):
qgt_ = QGT(vstate=vstate, diag_shift=0.01)
return jax.tree_map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs))
return jax.tree.map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs))


# Benchmark starts here
@@ -111,7 +111,7 @@ def _benchmark(n_nodes, n_samples, n_layers, width):
qgt_ = QGT(vstate=vstate, diag_shift=0.01)
Tsolve = (
timeit_gc(
lambda: jax.tree_map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs2)),
lambda: jax.tree.map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs2)),
number=5,
)
/ 5
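The benchmark script above times `qgt_.solve` by mapping `block_until_ready` over every leaf of the result, so JAX's asynchronous dispatch cannot hide the device work. A self-contained sketch of that idiom with plain `timeit` (the custom `timeit_gc` helper and the QGT solve are replaced here by hypothetical stand-ins):

```python
import jax
import jax.numpy as jnp
from timeit import timeit

params = {"w": jnp.ones((256, 256)), "b": jnp.zeros(256)}

def fake_solve(p):
    # hypothetical stand-in for qgt_.solve(cg, rhs): any pytree-valued computation
    return jax.tree.map(lambda x: jnp.sin(x) * 2.0, p)

# block on every leaf so asynchronous dispatch does not hide the real runtime
t_solve = timeit(
    lambda: jax.tree.map(lambda x: x.block_until_ready(), fake_solve(params)),
    number=5,
) / 5
print(f"mean time per solve: {t_solve:.6f} s")
```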
4 changes: 2 additions & 2 deletions Benchmarks/qgt_gcnn.py
@@ -23,7 +23,7 @@ def QGT(*args, **kwargs):

def construct_and_solve(vstate, rhs):
qgt_ = QGT(vstate=vstate, diag_shift=0.01)
return jax.tree_map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs))
return jax.tree.map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs))


# Benchmark starts here
@@ -82,7 +82,7 @@ def benchmark(side, n_samples, layers, features):

Tsolve = (
timeit_gc(
lambda: jax.tree_map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs2)),
lambda: jax.tree.map(lambda x: x.block_until_ready(), qgt_.solve(cg, rhs2)),
number=5,
)
/ 5
2 changes: 1 addition & 1 deletion docs/advanced/custom_operators.ipynb
@@ -293,7 +293,7 @@
" # this is the expectation value, as before\n",
" expval = np.array(0, dtype=op.dtype)\n",
" # this is the gradient, which of course is zero.\n",
" grad = jax.tree_map(jnp.zeros_like, vstate.parameters)\n",
" grad = jax.tree.map(jnp.zeros_like, vstate.parameters)\n",
" return expval, grad"
]
},
2 changes: 1 addition & 1 deletion docs/docs/hilbert.md
@@ -260,7 +260,7 @@ If you then want to sample this space, you'll encounter the following error:

```python
>>> import jax
>>> hi.random_state(jax.random.PRNGKey(3), 3)
>>> hi.random_state(jax.random.key(3), 3)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../netket/hilbert/abstract_hilbert.py", line 84, in random_state
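The hilbert.md example above now builds its RNG state with `jax.random.key`, the new-style typed key that replaces `jax.random.PRNGKey`. A minimal sketch of how such a key is created, split and consumed in plain JAX:

```python
import jax

key = jax.random.key(3)              # typed key, replaces jax.random.PRNGKey(3)
key, subkey = jax.random.split(key)  # splitting works exactly as before
x = jax.random.uniform(subkey, (3, 4))
```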
8 changes: 4 additions & 4 deletions docs/docs/sampler.ipynb
@@ -187,7 +187,7 @@
"log_pdf = nk.models.RBM(param_dtype=float)\n",
"\n",
"# and we initialize it's parameters\n",
"param_seed = jax.random.PRNGKey(0)\n",
"param_seed = jax.random.key(0)\n",
"\n",
"pars = log_pdf.init(param_seed, hilbert.random_state(param_seed, 3))"
]
@@ -268,7 +268,7 @@
"sampler = nk.sampler.ExactSampler(hilbert, dtype=jnp.int8)\n",
"\n",
"# We create the state of the sampler\n",
"sampler_state = sampler.init_state(log_pdf, pars, jax.random.PRNGKey(1))\n",
"sampler_state = sampler.init_state(log_pdf, pars, jax.random.key(1))\n",
"\n",
"# We call reset (this will pre-compute the log_pdf on the whole hilbert space)\n",
"sampler_state = sampler.reset(log_pdf, pars, sampler_state)\n",
@@ -415,7 +415,7 @@
],
"source": [
"# We create the state of the sampler\n",
"sampler_state = sampler.init_state(log_pdf, pars, jax.random.PRNGKey(1))\n",
"sampler_state = sampler.init_state(log_pdf, pars, jax.random.key(1))\n",
"\n",
"# We call reset (this will pre-compute the log_pdf on the whole hilbert space)\n",
"sampler_state = sampler.reset(log_pdf, pars, sampler_state)\n",
@@ -520,7 +520,7 @@
"outputs": [],
"source": [
"# We create the state of the sampler\n",
"sampler_state = sampler.init_state(log_pdf, pars, jax.random.PRNGKey(1))\n",
"sampler_state = sampler.init_state(log_pdf, pars, jax.random.key(1))\n",
"\n",
"# We call reset (this will pre-compute the log_pdf on the whole hilbert space)\n",
"sampler_state = sampler.reset(log_pdf, pars, sampler_state)\n",
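The sampler notebook feeds the same kind of typed key to both `init` and `init_state`. A small, self-contained sketch of initialising a Flax module that way, with `flax.linen.Dense` as a hypothetical stand-in for the RBM used in the notebook:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

log_pdf = nn.Dense(features=1)     # stand-in for nk.models.RBM
param_seed = jax.random.key(0)     # typed key instead of jax.random.PRNGKey(0)
pars = log_pdf.init(param_seed, jnp.ones((3,)))
print(jax.tree.map(jnp.shape, pars))   # inspect the parameter shapes leaf by leaf
```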
4 changes: 2 additions & 2 deletions docs/docs/varstate.md
@@ -139,7 +139,7 @@ vstate.parameters['visible_bias']
0.06278384, 0.00275547, 0.05843748, 0.07516951,
0.21897993, -0.01632223], dtype=float64)

vstate.parameters = jax.tree_map(lambda x: x+0.1, vstate.parameters)
vstate.parameters = jax.tree.map(lambda x: x+0.1, vstate.parameters)

# Look at the new values
vstate.parameters['visible_bias']
@@ -219,7 +219,7 @@ the {py:attr}`~netket.vqs.VariationalState.model_state` attribute.

Parameters are stored as a set of nested dictionaries.
In Jax jargon, Parameters are a PyTree (see [PyTree documentation](https://jax.readthedocs.io/en/latest/pytrees.html)) and they
can be operated upon with functions like [jax.tree_map](https://jax.readthedocs.io/en/latest/jax.tree_util.html?highlight=tree_map#jax.tree_util.tree_map).
can be operated upon with functions like [jax.tree.map](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree.map.html).

Before modifying the parameters or what is stored inside of a dictionary of parameters, it is a good idea to make a copy using {func}`flax.core.copy`, which calls recursively `{}.copy()` in all nested dictionaries.
If you do not do this, you will only copy the outermost dictionary but the inner ones will be referenced, and so modifying them will modify the original parameters as well.
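The varstate.md hunk above ends on the advice to copy parameter PyTrees with `flax.core.copy` before modifying them. A short sketch of why the recursive copy matters, assuming the two-argument `flax.core.copy(tree, add_or_replace)` form and that it returns a plain nested dict when given one; a shallow `dict.copy()` would leave the inner dictionaries aliased:

```python
import jax.numpy as jnp
import flax.core

parameters = {"params": {"visible_bias": jnp.zeros(4)}}

# recursive copy: the nested dicts are duplicated, not aliased
parameters_copy = flax.core.copy(parameters, {})
parameters_copy["params"]["visible_bias"] = jnp.ones(4)

assert (parameters["params"]["visible_bias"] == 0).all()   # original is untouched
```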
2 changes: 1 addition & 1 deletion docs/tutorials/gs-continuous-space.ipynb
@@ -172,7 +172,7 @@
],
"source": [
"import jax\n",
"states = hi.random_state(jax.random.PRNGKey(0), 1)\n",
"states = hi.random_state(jax.random.key(0), 1)\n",
"\n",
"# You can always reshape those configurations to NxD matrix\n",
"states.reshape(N, 3)"
6 changes: 3 additions & 3 deletions docs/tutorials/gs-ising.ipynb
@@ -127,7 +127,7 @@
"NetKet's Hilbert spaces define the computational basis of the calculation, and are used to label and generate elements from it. \n",
"The standard Spin-basis implicitly selects the `z` basis and elements of that basis will be elements $ v\\in\\{\\pm 1\\}^N $.\n",
"\n",
"It is possible to generate random basis elements through the function `random_state(rng, shape, dtype)`, where the first argument must be a jax RNG state (usually built with `jax.random.PRNGKey(seed)`, second is an integer or a tuple giving the shape of the samples and the last is the dtype of the generated states."
"It is possible to generate random basis elements through the function `random_state(rng, shape, dtype)`, where the first argument must be a jax RNG state (usually built with `jax.random.key(seed)`, second is an integer or a tuple giving the shape of the samples and the last is the dtype of the generated states."
]
},
{
@@ -154,7 +154,7 @@
],
"source": [
"import jax\n",
"hi.random_state(jax.random.PRNGKey(0), 3)"
"hi.random_state(jax.random.key(0), 3)"
]
},
{
@@ -592,7 +592,7 @@
" energy_history.append(E.mean.real)\n",
" # equivalent to vstate.parameters - 0.05*E_grad , but it performs this\n",
" # function on every leaf of the dictionaries containing the set of parameters\n",
" new_pars = jax.tree_map(lambda x,y: x-0.05*y, vstate.parameters, E_grad)\n",
" new_pars = jax.tree.map(lambda x,y: x-0.05*y, vstate.parameters, E_grad)\n",
" # actually update the parameters\n",
" vstate.parameters = new_pars"
]
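The last cell of the Ising tutorial above is the canonical two-tree use of `jax.tree.map`: the mapped function receives one leaf from each tree, so a gradient-descent step touches every parameter array in a single call. A standalone sketch with toy values:

```python
import jax
import jax.numpy as jnp

parameters = {"Dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)}}
E_grad = {"Dense": {"kernel": jnp.full((2, 2), 0.3), "bias": jnp.full(2, -0.1)}}

# x - 0.05*y is applied leaf by leaf; both trees must share the same structure
new_pars = jax.tree.map(lambda x, y: x - 0.05 * y, parameters, E_grad)
```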
57 changes: 30 additions & 27 deletions docs/tutorials/vmc-from-scratch.ipynb
@@ -666,7 +666,7 @@
"model = MF()\n",
"\n",
"# pick a RNG key to initialise the random parameters\n",
"key = jax.random.PRNGKey(0)\n",
"key = jax.random.key(0)\n",
"\n",
"# initialise the weights\n",
"parameters = model.init(key, np.random.rand(hi.size))"
@@ -709,7 +709,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, parameters are vectors, even though they are stored in memory differently. You can apply mathematical operations to those vectors using `jax.tree_map(function, trees...)`, which calls the function on every element of the pytrees."
"Conceptually, parameters are vectors, even though they are stored in memory differently. You can apply mathematical operations to those vectors using `jax.tree.map(function, trees...)`, which calls the function on every element of the pytrees."
]
},
{
@@ -738,19 +738,19 @@
"def multiply_by_10(x):\n",
" return 10*x\n",
"\n",
"print(\"multiply_by_10: \", jax.tree_map(multiply_by_10, dict1))\n",
"print(\"multiply_by_10: \", jax.tree.map(multiply_by_10, dict1))\n",
"# this can also be done by defining the function as a lambda function, which is more compact\n",
"print(\"multiply_by_10, with lambda:\", jax.tree_map(lambda x: 10*x, dict1))\n",
"print(\"multiply_by_10, with lambda:\", jax.tree.map(lambda x: 10*x, dict1))\n",
"\n",
"def add(x,y):\n",
" return x+y\n",
"print(\"add dict1 and 2 :\", jax.tree_map(add, dict1, dict2))\n",
"print(\"add dict1 and 2 :\", jax.tree.map(add, dict1, dict2))\n",
"\n",
"\n",
"def sub(x,y):\n",
" return x-y\n",
"print(\"subtract dict1 and 2 :\", jax.tree_map(sub, dict1, dict2))\n",
"print(\"subtract dict1 and 2, lambda:\", jax.tree_map(lambda x,y:x-y, dict1, dict2))\n"
"print(\"subtract dict1 and 2 :\", jax.tree.map(sub, dict1, dict2))\n",
"print(\"subtract dict1 and 2, lambda:\", jax.tree.map(lambda x,y:x-y, dict1, dict2))\n"
]
},
{
@@ -777,7 +777,7 @@
],
"source": [
"# generate 4 random inputs\n",
"inputs = hi.random_state(jax.random.PRNGKey(1), (4,))\n",
"inputs = hi.random_state(jax.random.key(1), (4,))\n",
"\n",
"log_psi = model.apply(parameters, inputs)\n",
"# notice that logpsi has shape (4,) because we fed it 4 random configurations.\n",
@@ -1115,7 +1115,7 @@
"\n",
"# initialise \n",
"model = MF()\n",
"parameters = model.init(jax.random.PRNGKey(0), np.ones((hi.size, )))\n",
"parameters = model.init(jax.random.key(0), np.ones((hi.size, )))\n",
"\n",
"# logging: you can (if you want) use netket loggers to avoid writing a lot of boilerplate...\n",
"# they accumulate data you throw at them\n",
@@ -1127,7 +1127,7 @@
" \n",
" # update parameters. Try using a learning rate of 0.01\n",
" # to update the parameters, which are stored as a dictionary (or pytree)\n",
" # you can use jax.tree_map as shown above.\n",
" # you can use jax.tree.map as shown above.\n",
" #...\n",
" \n",
" # log energy: the logger takes a step argument and a dictionary of variables to be logged\n",
@@ -1144,7 +1144,7 @@
"```python\n",
"# initialise \n",
"model = MF()\n",
"parameters = model.init(jax.random.PRNGKey(0), np.ones((hi.size, )))\n",
"parameters = model.init(jax.random.key(0), np.ones((hi.size, )))\n",
"\n",
"# logging: you can (if you want) use netket loggers to avoid writing a lot of boilerplate...\n",
"# they accumulate data you throw at them\n",
@@ -1155,7 +1155,7 @@
" energy, gradient = compute_energy_and_gradient(model, parameters, hamiltonian_jax_sparse)\n",
" \n",
" # update parameters\n",
" parameters = jax.tree_map(lambda x,y:x-0.01*y, parameters, gradient)\n",
" parameters = jax.tree.map(lambda x,y:x-0.01*y, parameters, gradient)\n",
" \n",
" # log energy: the logger takes a step argument and a dictionary of variables to be logged\n",
" logger(step=i, item={'Energy':energy})\n",
@@ -1291,7 +1291,10 @@
" \"J\", nn.initializers.normal(), (n_sites,n_sites), float\n",
" )\n",
" # ensure same data types\n",
" J, input_x = nn.dtypes.kernel, x_in = promote_dtype(J, input_x, dtype=None)\n",
" dtype = jax.numpy.promote_types(J.dtype, input_x.dtype)\n",
" J = J.astype(dtype)\n",
" input_x = input_x.astype(dtype)\n",
" \n",
" # note that J_ij is not symmetric. So we symmetrize it by hand\n",
" J_symm = J.T + J\n",
" \n",
@@ -1334,11 +1337,11 @@
"# if the code above is correct, this should run\n",
"model_jastrow = Jastrow()\n",
"\n",
"one_sample = hi.random_state(jax.random.PRNGKey(0))\n",
"batch_samples = hi.random_state(jax.random.PRNGKey(0), (5,))\n",
"multibatch_samples = hi.random_state(jax.random.PRNGKey(0), (5,4,))\n",
"one_sample = hi.random_state(jax.random.key(0))\n",
"batch_samples = hi.random_state(jax.random.key(0), (5,))\n",
"multibatch_samples = hi.random_state(jax.random.key(0), (5,4,))\n",
"\n",
"parameters_jastrow = model_jastrow.init(jax.random.PRNGKey(0), one_sample)\n",
"parameters_jastrow = model_jastrow.init(jax.random.key(0), one_sample)\n",
"assert parameters_jastrow['params']['J'].shape == (hi.size, hi.size)\n",
"assert model_jastrow.apply(parameters_jastrow, one_sample).shape == ()\n",
"assert model_jastrow.apply(parameters_jastrow, batch_samples).shape == batch_samples.shape[:-1]\n",
@@ -1484,7 +1487,7 @@
"outputs": [],
"source": [
"# given sigma\n",
"sigma = hi.random_state(jax.random.PRNGKey(1))\n",
"sigma = hi.random_state(jax.random.key(1))\n",
"\n",
"eta, H_sigmaeta = hamiltonian_jax.get_conn_padded(sigma)"
]
@@ -1542,7 +1545,7 @@
],
"source": [
"# given sigma\n",
"sigma = hi.random_state(jax.random.PRNGKey(1), (4,5))\n",
"sigma = hi.random_state(jax.random.key(1), (4,5))\n",
"\n",
"eta, H_sigmaeta = hamiltonian_jax.get_conn_padded(sigma)\n",
"\n",
@@ -1844,16 +1847,16 @@
"jacobian = jax.jacrev(logpsi_sigma_fun)(parameters_jastrow)\n",
"\n",
"#\n",
"print(\"The parameters of jastrow have shape:\\n\" , jax.tree_map(lambda x: x.shape, parameters_jastrow))\n",
"print(\"The parameters of jastrow have shape:\\n\" , jax.tree.map(lambda x: x.shape, parameters_jastrow))\n",
"\n",
"print(\"The jacobian of jastrow have shape:\\n\" , jax.tree_map(lambda x: x.shape, jacobian))"
"print(\"The jacobian of jastrow have shape:\\n\" , jax.tree.map(lambda x: x.shape, jacobian))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now implement a function that computes the jacobian-vector product in order to estimate the gradient of the energy. You can either do this vector-Jacobian-transpose product manually by using `jax.jacrev` and `jax.tree_map`, but you can also have a look at `jax.vjp` which does it automatically for you."
"Now implement a function that computes the jacobian-vector product in order to estimate the gradient of the energy. You can either do this vector-Jacobian-transpose product manually by using `jax.jacrev` and `jax.tree.map`, but you can also have a look at `jax.vjp` which does it automatically for you."
]
},
{
@@ -1875,7 +1878,7 @@
" # first define the function to be differentiated\n",
" logpsi_sigma_fun = lambda pars : model_jastrow.apply(parameters_jastrow, sigma_vector)\n",
" ...\n",
" # use jacrev with jax.tree_map, or even better, jax.vjp\n",
" # use jacrev with jax.tree.map, or even better, jax.vjp\n",
" E_grad = ...\n",
" \n",
" # compute the energy as well\n",
@@ -1912,7 +1915,7 @@
" # first define the function to be differentiated\n",
" logpsi_sigma_fun = lambda pars : model.apply(pars, sigma)\n",
"\n",
" # use jacrev with jax.tree_map, or even better, jax.vjp\n",
" # use jacrev with jax.tree.map, or even better, jax.vjp\n",
" _, vjpfun = jax.vjp(logpsi_sigma_fun, parameters)\n",
" E_grad = vjpfun((E_loc - E_average)/E_loc.size)\n",
"\n",
@@ -1974,7 +1977,7 @@
"chain_length = 1000//sampler.n_chains\n",
"\n",
"# initialise\n",
"parameters = model.init(jax.random.PRNGKey(0), np.ones((hi.size, )))\n",
"parameters = model.init(jax.random.key(0), np.ones((hi.size, )))\n",
"sampler_state = sampler.init_state(model, parameters, seed=1)\n",
"\n",
"# logging: you can (if you want) use netket loggers to avoid writing a lot of boilerplate...\n",
@@ -2016,7 +2019,7 @@
"chain_length = 1000//sampler.n_chains\n",
"\n",
"# initialise\n",
"parameters = model.init(jax.random.PRNGKey(0), np.ones((hi.size, )))\n",
"parameters = model.init(jax.random.key(0), np.ones((hi.size, )))\n",
"sampler_state = sampler.init_state(model, parameters, seed=1)\n",
"\n",
"# logging: you can (if you want) use netket loggers to avoid writing a lot of boilerplate...\n",
@@ -2032,7 +2035,7 @@
" E, E_grad = estimate_energy_and_gradient(model, parameters, hamiltonian_jax, samples)\n",
" \n",
" # update parameters. Try using a learning rate of 0.01\n",
" parameters = jax.tree_map(lambda x,y: x-0.005*y, parameters, E_grad)\n",
" parameters = jax.tree.map(lambda x,y: x-0.005*y, parameters, E_grad)\n",
" \n",
" # log energy: the logger takes a step argument and a dictionary of variables to be logged\n",
" logger(step=i, item={'Energy':E})\n",
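The from-scratch tutorial's solution cells estimate the energy gradient with `jax.vjp`, which forms the vector-Jacobian product without ever materialising the Jacobian. A toy sketch of that step; the log-amplitude, samples and local energies below are made-up placeholders, not NetKet's:

```python
import jax
import jax.numpy as jnp

def logpsi(pars, sigma):
    return sigma @ pars["J"]        # toy log-amplitude: one scalar per sample

pars = {"J": jnp.ones(4)}
sigma = jnp.sign(jax.random.normal(jax.random.key(1), (8, 4)))   # fake samples
E_loc = jnp.linspace(-1.0, 1.0, 8)                               # fake local energies

logpsi_sigma_fun = lambda p: logpsi(p, sigma)

# jax.vjp returns the primal output and a function computing v^T J for any cotangent v
_, vjpfun = jax.vjp(logpsi_sigma_fun, pars)
(E_grad,) = vjpfun((E_loc - E_loc.mean()) / E_loc.size)
```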
2 changes: 1 addition & 1 deletion netket/driver/abstract_variational_driver.py
@@ -389,7 +389,7 @@ def estimate(self, observables):

# Do not unpack operators, even if they are pytrees!
# this is necessary to support jax operators.
return jax.tree_map(
return jax.tree_util.tree_map(
self._estimate_stats,
observables,
is_leaf=lambda x: isinstance(x, AbstractObservable),
8 changes: 4 additions & 4 deletions netket/errors.py
@@ -587,7 +587,7 @@ class RealQGTComplexDomainError(NetketError):
>>> _, vec = vstate.expect_and_grad(nk.operator.spin.sigmax(vstate.hilbert, 1))
>>> G = nk.optimizer.qgt.QGTOnTheFly(vstate, holomorphic=False)
>>>
>>> vec_real = jax.tree_map(lambda x: x.real, vec)
>>> vec_real = jax.tree.map(lambda x: x.real, vec)
>>> sol = G@vec_real

Or, if you used the QGT in a linear solver, try using:
@@ -601,7 +601,7 @@
>>> _, vec = vstate.expect_and_grad(nk.operator.spin.sigmax(vstate.hilbert, 1))
>>>
>>> G = nk.optimizer.qgt.QGTOnTheFly(vstate, holomorphic=False)
>>> vec_real = jax.tree_map(lambda x: x.real, vec)
>>> vec_real = jax.tree.map(lambda x: x.real, vec)
>>>
>>> linear_solver = jax.scipy.sparse.linalg.cg
>>> solution, info = G.solve(linear_solver, vec_real)
@@ -624,14 +624,14 @@ def __init__(self):

.. code:: python

>>> vec_real = jax.tree_map(lambda x: x.real, vec)
>>> vec_real = jax.tree_util.tree_map(lambda x: x.real, vec)
>>> G@vec_real

If you used the QGT in a linear solver, try using:

.. code:: python

>>> vec_real = jax.tree_map(lambda x: x.real, vec)
>>> vec_real = jax.tree_util.tree_map(lambda x: x.real, vec)
>>> G.solve(linear_solver, vec_real)

to fix this error.
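The updated error text in `netket/errors.py` describes the same fix in both spellings: drop the imaginary part of every gradient leaf before handing the vector to a real-parameter QGT. A tiny sketch of that step with a made-up gradient tree:

```python
import jax
import jax.numpy as jnp

# hypothetical complex-valued gradient pytree, e.g. from expect_and_grad
vec = {"kernel": jnp.array([1.0 + 2.0j, -0.5 + 0.1j])}

# with real parameters the QGT acts only on the real part, so discard Im(.) leaf by leaf
vec_real = jax.tree.map(lambda x: x.real, vec)
```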
2 changes: 1 addition & 1 deletion netket/experimental/driver/tdvp.py
@@ -162,7 +162,7 @@ def odefun_tdvp( # noqa: F811

@partial(jax.jit, static_argnums=(3, 4))
def _map_parameters(forces, parameters, loss_grad_factor, propagation_type, state_T):
forces = jax.tree_map(
forces = jax.tree_util.tree_map(
lambda x, target: loss_grad_factor * x,
forces,
parameters,