Customize Hamiltonian Rule #755
-
Hi guys, I am currently trying to implement custom sampling rules which allow to set the initial spin state of the sampler such that I can ensure particle number conversation after converting second quantized chemical Hamiltonians to spin basis and training an RBM for the ground state. In the case of the exchange rule things work as expected when I just write a new the random_state() function. However, in the case of the Hamiltonian rule I get a lengthy error message. You have any intuition where this is coming from? Below a minimal working example which however relies on the packages openfermion and openfermionpyscf which can be installed with this names from pypi. Best from flax import struct
import netket as nk
from netket.sampler.metropolis import MetropolisSampler
from netket.operator import AbstractOperator
from netket.jax import njit4jax
from netket.sampler.metropolis import MetropolisRule
from numba import jit
import numpy as np
import math
from openfermion.chem import MolecularData
from openfermionpyscf import run_pyscf
from openfermion.transforms import get_fermion_operator, jordan_wigner
import jax
def of_2_nk(q_ham, n_qubits):
ops = []
ws = []
for k, v in q_ham.terms.items():
pl_str = list('I') * n_qubits
for t in k:
pl_str[t[0]] = t[1]
ops.append("".join(pl_str))
ws.append(v)
return ws, ops
@jit(nopython=True)
def _choose(vp, sections, rand_vec, out, w):
low_range = 0
for i, s in enumerate(sections):
n_rand = low_range + int(np.floor(rand_vec[i] * (s - low_range)))
out[i] = vp[n_rand]
w[i] = math.log(s - low_range)
low_range = s
@jit(nopython=True)
def _choose(vp, sections, rand_vec, out, w):
low_range = 0
for i, s in enumerate(sections):
n_rand = low_range + int(np.floor(rand_vec[i] * (s - low_range)))
out[i] = vp[n_rand]
w[i] = math.log(s - low_range)
low_range = s
@struct.dataclass
class HamiltonianRuleCustom(MetropolisRule):
operator: AbstractOperator = struct.field(pytree_node=False)
"""The (hermitian) operator giving the transition amplitudes."""
n_chains: int
initial_state: list
def random_state(rule, sampler, machine, parameters, state, key, ):
return jax.numpy.asarray(rule.initial_state * rule.n_chains).reshape(rule.n_chains, len(rule.initial_state))
def init_state(rule, sampler, machine, params, key):
if sampler.hilbert != rule.operator.hilbert:
raise ValueError(
f"""
The hilbert space of the sampler ({sampler.hilbert}) and the hilbert space
of the operator ({rule.operator.hilbert}) for HamiltonianRule must be the same.
"""
)
return super().init_state(sampler, machine, params, key)
def __post_init__(self):
# Raise errors if hilbert is not an Hilbert
if not isinstance(self.operator, AbstractOperator):
raise TypeError(
"Argument to HamiltonianRule must be a valid operator.".format(
type(self.operator)
)
)
def transition(rule, sampler, machine, parameters, state, key, σ):
hilbert = sampler.hilbert
get_conn_flattened = rule.operator._get_conn_flattened_closure()
n_conn_from_sections = rule.operator._n_conn_from_sections
@njit4jax(
(
jax.abstract_arrays.ShapedArray(σ.shape, σ.dtype),
jax.abstract_arrays.ShapedArray((σ.shape[0],), σ.dtype),
)
)
def _transition(args):
# unpack arguments
v_proposed, log_prob_corr, v, rand_vec = args
print(v)
log_prob_corr.fill(0)
sections = np.empty(v.shape[0], dtype=np.int32)
vp, _ = get_conn_flattened(v, sections)
_choose(vp, sections, rand_vec, v_proposed, log_prob_corr)
# TODO: n_conn(v_proposed, sections) implemented below, but
# might be slower than fast implementations like ising
get_conn_flattened(v_proposed, sections)
n_conn_from_sections(sections)
log_prob_corr -= np.log(sections)
# ideally we would pass the key to python/numba in _choose, initialise a
# np.random.default_rng(key) and use it to generatee random uniform integers.
# However, numba dose not support np states, and reseeding it's MT1998 implementation
# would be slow so we generate floats in the [0,1] range in jax and pass those
# to python
rand_vec = jax.random.uniform(key, shape=(σ.shape[0],))
σp, log_prob_correction = _transition(σ, rand_vec)
return σp, log_prob_correction
def __repr__(self):
return f"HamiltonianRule({self.operator})"
geometry = [['H', [0, 0, 0]], ['H', [0, 0, .734]]]
basis = "sto-3g"
multiplicity = 1
charge = 0
h2_molecule = MolecularData(geometry, basis, multiplicity, charge)
h2_molecule = run_pyscf(h2_molecule,
run_mp2=True,
run_cisd=True,
run_ccsd=True,
run_fci=True)
qubit_hamiltonian = jordan_wigner(get_fermion_operator(h2_molecule.get_molecular_hamiltonian()))
weights,operators = of_2_nk(qubit_hamiltonian,h2_molecule.n_qubits)
ham = nk.operator.PauliStrings(operators=operators, weights=weights)
n_chains = 16
initial_state = [1, 1, 0, 0]
# here construct the custom rule
rule = HamiltonianRuleCustom(ham, n_chains=n_chains, initial_state=initial_state)
# fetch the Hilbert space from the PauliString Hamiltonian
hi = ham.hilbert
# reset_chains has to be set to True such that the initial state is set at the beginning of each sampling procedure
sa = MetropolisSampler(hi, rule, reset_chains=True, n_chains=n_chains, n_sweeps=40)
ma = nk.models.RBM(alpha=1, dtype=float)
# Optimizer
op = nk.optimizer.Sgd(learning_rate=0.1)
# stochastic reconfiguration
sr = nk.optimizer.SR(diag_shift=0.05, iterative=True)
gs = nk.VMC(ham, op, sa, ma, preconditioner=sr, n_samples=1000, n_discard_per_chain=25)
gs.run(n_iter=1000, out=None) EDIT: Error trace in this gist. An extract is pasted below: File "/Users/thorbenfrank/Documents/venvs/netket/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 289, in _runPass
mutated |= check(pss.run_pass, internal_state)
File "/Users/thorbenfrank/Documents/venvs/netket/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 262, in check
mangled = func(compiler_state)
File "/Users/thorbenfrank/Documents/venvs/netket/lib/python3.8/site-packages/numba/core/typed_passes.py", line 94, in run_pass
typemap, return_type, calltypes = type_inference_stage(
File "/Users/thorbenfrank/Documents/venvs/netket/lib/python3.8/site-packages/numba/core/typed_passes.py", line 72, in type_inference_stage
infer.propagate(raise_errors=raise_errors)
File "/Users/thorbenfrank/Documents/venvs/netket/lib/python3.8/site-packages/numba/core/typeinfer.py", line 1071, in propagate
raise errors[0]
jax._src.traceback_util.FilteredStackTrace: numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function isub>) found for signature:
>>> isub(array(int64, 1d, C), array(float64, 1d, C))
There are 10 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'NumpyRulesInplaceArrayOperator.generic': File: numba/core/typing/npydecl.py: Line 213.
With argument(s): '(array(int64, 1d, C), array(float64, 1d, C))':
Rejected as the implementation raised a specific error:
AttributeError: 'NoneType' object has no attribute 'args'
raised from /Users/thorbenfrank/Documents/venvs/netket/lib/python3.8/site-packages/numba/core/typing/npydecl.py:224
- Of which 6 did not match due to:
Overload of function 'isub': File: <numerous>: Line N/A.
With argument(s): '(array(int64, 1d, C), array(float64, 1d, C))':
No match.
- Of which 2 did not match due to:
Operator Overload in function 'isub': File: unknown: Line unknown.
With argument(s): '(array(int64, 1d, C), array(float64, 1d, C))':
No match for registered cases:
* (int64, int64) -> int64
* (int64, uint64) -> int64
* (uint64, int64) -> int64
* (uint64, uint64) -> uint64
* (float32, float32) -> float32
* (float64, float64) -> float64
* (complex64, complex64) -> complex64
* (complex128, complex128) -> complex128
During: typing of intrinsic-call at /Users/thorbenfrank/Documents/phd/projects/netket/scripts/minimal-example.py (112)
File "minimal-example.py", line 112:
def _transition(args):
<source elided>
log_prob_corr -= np.log(sections)
^
During: resolving callee type: type(CPUDispatcher(<function HamiltonianRuleCustom.transition.<locals>._transition at 0x14f834ee0>))
During: typing of call at /Users/thorbenfrank/Documents/venvs/netket/lib/python3.8/site-packages/netket/jax/numba4jax.py (228)
During: resolving callee type: type(CPUDispatcher(<function HamiltonianRuleCustom.transition.<locals>._transition at 0x14f834ee0>))
During: typing of call at /Users/thorbenfrank/Documents/venvs/netket/lib/python3.8/site-packages/netket/jax/numba4jax.py (228)
File "../../../../venvs/netket/lib/python3.8/site-packages/netket/jax/numba4jax.py", line 228:
def xla_custom_call_target(output_ptrs, input_ptrs):
<source elided>
numba_fn(args_out + args_in)
^
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified. |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 3 replies
-
Thanks for sharing it here! TLDR: the error is in your You can find an hint on the problem by working it backwards: the numba Jit compiler (which is called by the So we know that In the code you posted, So So the issue here is that you have an integer-dtype state. Enforcing the dtype of the resulting state val = jax.numpy.asarray(
rule.initial_state * rule.n_chains, dtype=sampler.dtype
).reshape(rule.n_chains, len(rule.initial_state))
return val will fix your issue. — from netket.utils import struct
from netket.sampler.rule import HamiltonianRule
@struct.dataclass
class HamiltonianRuleCustom(HamiltonianRule):
initial_state : list
def random_state(rule, sampler, machine, parameters, state, key):
assert len(rule.initial_state) == sampler.hilbert.size
val = jax.numpy.asarray(
rule.initial_state * sampler.n_chains, dtype=sampler.dtype
).reshape(sampler.n_chains, -1)
return val |
Beta Was this translation helpful? Give feedback.
Thanks for sharing it here!
Jax stack traces are terrible, and coupled with the callback into numba they definitely are gross....
TLDR: the error is in your
random_state
function. Below you can find my reasoning to identify the issue in case this helps in the future.You can find an hint on the problem by working it backwards: the numba Jit compiler (which is called by the
@njit4jax
decorator) is complaining that it cannot compilelog_prob_corr -= np.log(sections)
because he cannot find an implementation forisub
(inplace-subtraction, or-=
) where the first argument is a vector of integers and the latter is a float vector.So we know that
log_prob_corr
is a vector of integers, for some re…