This PR adds a Python/JAX port of CR-FM-NES (Fast Moving Natural Evolution Strategy for High-Dimensional Problems), see https://arxiv.org/abs/2201.11422. It is derived from https://github.com/nomuramasahir0/crfmnes.
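For readers new to CR-FM-NES, here is a minimal JAX sketch of its sampling step as I understand it from the paper. Candidates are drawn from N(m, σ²·D(I + vvᵀ)D), where D is diagonal and v is a single vector, so sampling costs O(dim) per candidate instead of O(dim²). The function name `sample_population`, its argument layout and the use of mirrored sampling are assumptions of this sketch, not code taken from this PR.

```python
import jax
import jax.numpy as jnp


def sample_population(key, m, sigma, D, v, lamb):
    """Sketch: draw lamb candidates from N(m, sigma^2 * D (I + v v^T) D).

    m, D, v: vectors of shape (dim,); D holds the diagonal of the scaling matrix.
    Assumes v != 0 and an even lamb (mirrored sampling).
    Returns (x, z), both of shape (dim, lamb).
    """
    dim = m.shape[0]
    z_half = jax.random.normal(key, (dim, lamb // 2))
    z = jnp.concatenate([z_half, -z_half], axis=1)  # mirrored (antithetic) samples
    norm_v = jnp.linalg.norm(v)
    v_bar = v / norm_v
    # y = (I + (sqrt(1 + |v|^2) - 1) v_bar v_bar^T) z, hence Cov(y) = I + v v^T
    y = z + (jnp.sqrt(1.0 + norm_v**2) - 1.0) * jnp.outer(v_bar, v_bar @ z)
    # Scale by the diagonal D and step size sigma, then shift by the mean m.
    x = m[:, None] + sigma * D[:, None] * y
    return x, z


# Usage: 16 candidates in 10 dimensions.
dim = 10
x, z = sample_population(
    jax.random.PRNGKey(0), jnp.zeros(dim), 0.3, jnp.ones(dim), jnp.full(dim, 0.1), lamb=16
)
```

Because `y` is linear in `z`, each mirrored pair of candidates is symmetric around `m`, and everything is plain array code without Python loops.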
This variant is slightly faster than FCRFMC (the C++ port) on fast GPUs/TPUs, but slower on CPUs and for smaller dimensions.
It uses 32-bit floating-point precision (FCRFMC uses 64-bit), which mostly doesn't hurt convergence; the exception is Waterworld MA at very high iteration counts.
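One plausible mechanism for why 32-bit precision only matters late in a run (an illustration, not something verified on Waterworld MA): as the step size shrinks, an update can fall below float32 resolution relative to the parameter value and is silently rounded away. The helper `step_survives` is hypothetical.

```python
import numpy as np
import jax.numpy as jnp

# JAX creates 32-bit floats by default unless jax_enable_x64 is switched on.
default_dtype = jnp.zeros(3).dtype


def step_survives(base, step, dtype) -> bool:
    """True if adding `step` to `base` actually changes the value in `dtype`."""
    b = dtype(base)
    return bool(b + dtype(step) != b)


print(default_dtype)                          # float32 (default config)
print(np.finfo(np.float32).eps, np.finfo(np.float64).eps)
print(step_survives(1.0, 1e-8, np.float32))   # False: step is below float32 resolution
print(step_survives(1.0, 1e-8, np.float64))   # True
```

If 64-bit is needed, JAX can be switched over globally with `jax.config.update("jax_enable_x64", True)` at startup, typically at a noticeable speed cost on GPUs/TPUs.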
As with FCRFMC, wall time and convergence on the benchmarks are mostly comparable to PGPE: slower in the beginning, but improving at higher iteration counts.
Since there are no for-loops, I found no beneficial use for `jax.jit`; I just converted most `np.array`s into `jnp.array`s placed on the GPUs/TPUs.
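A minimal way to re-check the `jax.jit` observation on other hardware is to time the same loop-free function eagerly and jitted, warming up first so compilation is excluded and calling `block_until_ready()` because JAX dispatches work asynchronously. `weighted_recombination` is a hypothetical stand-in, not a function from this PR, and timings will vary by device.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp


def weighted_recombination(m, y, w, lr):
    # Hypothetical loop-free update: move the mean along a weighted sum of the columns of y.
    return m + lr * (y @ w)


weighted_recombination_jit = jax.jit(weighted_recombination)


def time_fn(fn, *args, repeats=20):
    """Average wall time per call in seconds, excluding the warm-up call."""
    fn(*args).block_until_ready()  # warm-up; triggers compilation for the jitted version
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    out.block_until_ready()  # dispatch is asynchronous; wait for the device to finish
    return (time.perf_counter() - start) / repeats


dim, lamb = 1000, 64
rng = np.random.default_rng(0)
# Host NumPy arrays are moved to the default device (GPU/TPU if present, else CPU).
m = jax.device_put(np.zeros(dim, dtype=np.float32))
y = jax.device_put(rng.standard_normal((dim, lamb)).astype(np.float32))
w = jax.device_put(np.full(lamb, 1.0 / lamb, dtype=np.float32))

t_eager = time_fn(weighted_recombination, m, y, w, 0.5)
t_jit = time_fn(weighted_recombination_jit, m, y, w, 0.5)
print(f"eager: {t_eager:.2e} s/call, jit: {t_jit:.2e} s/call")
```

For a function this small, XLA has little to fuse, so a gap in either direction mostly reflects dispatch overhead on the given device.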
`def sort_indices_by(evals: np.ndarray, z: jnp.ndarray) -> jnp.ndarray:` takes `evals` as `np.ndarray` rather than `jnp.ndarray`, because the latter slowed things down on my NVIDIA 3090.
Since this is Python code, there are no missing shared libraries on Ubuntu 18 this time.
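Regarding the `np.ndarray` type for `evals` in `sort_indices_by`: below is a minimal sketch (not the PR's actual code) of how such a function can keep the argsort on the host, so only the final index array is transferred to the device. Ranking infeasible candidates (`inf` evaluations) last, ordered by the squared norm of their `z` column, is an assumption of this sketch.

```python
import numpy as np
import jax.numpy as jnp


def sort_indices_by(evals: np.ndarray, z: jnp.ndarray) -> jnp.ndarray:
    """Sketch: column indices of z ordered best-first by evals (minimization).

    evals stays a host NumPy array, so the argsort runs on the CPU.
    Assumption: infeasible candidates (eval == inf) go last, ordered by
    the squared norm of their column in z.
    """
    evals = np.asarray(evals)
    order = np.argsort(evals, kind="stable")  # inf values sort to the end
    n_feasible = int(np.sum(evals != np.inf))
    if n_feasible < evals.size:
        infeasible = order[n_feasible:]
        # Pull only the infeasible columns back to the host to compute distances.
        dist = np.sum(np.asarray(z[:, infeasible]) ** 2, axis=0)
        order[n_feasible:] = infeasible[np.argsort(dist, kind="stable")]
    return jnp.asarray(order)  # single host-to-device transfer of the result


evals = np.array([3.0, np.inf, 1.0, np.inf])
z = jnp.array([[0.0, 0.0, 0.0, 1.0],
               [0.0, 5.0, 0.0, 0.0]])
print(sort_indices_by(evals, z))
```

Here index 2 (eval 1.0) ranks first, then index 0, then the two infeasible candidates, with index 3 (squared distance 1) ahead of index 1 (squared distance 25).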
Added test results for CRFMNES (this Python implementation) to EvoJax.adoc.