Add Diversifier QD Meta Algorithm - JAX backend #52

Open
wants to merge 8 commits into
base: main
from

Conversation

dietmarwo
Contributor

@dietmarwo dietmarwo commented Nov 11, 2022

This PR adds a new JAX-based QD meta algorithm called Diversifier. It is a generalization of CMA-ME.

It uses a MAP-Elites archive not for solution candidate generation, but only to modify the fitness values told (via tell) to the wrapped algorithm. This modification changes the fitness ranking of the population to favor exploration over exploitation. Tested with CR-FM-NES and CMA-ES, but other wrapped algorithms may work as well. Based on fcmaes diversifier.py (see MapElites.adoc).

The generalization over CMA-ME is necessary in the EvoJAX context, because CMA-ES struggles with a very high number of decision variables. Therefore CR-FM-NES-ME is superior here, as other, not yet tested alternatives may be as well.

https://doi.org/10.1145/2739480.2754664 proposes the QD score (the sum of the fitness values of all elites in the map) as a metric for comparison.
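As a minimal sketch of how the QD score and the other numbers reported in the tables of this PR can be derived from an archive: the function name `qd_metrics` is hypothetical, and the sketch assumes the archive is available as a fitness array plus a boolean occupancy mask (the names echo `fitness_lattice` and `occupancy_lattice` from the code further down, but their exact layout is an assumption).

```python
import jax.numpy as jnp


def qd_metrics(fitness_lattice, occupancy_lattice):
    """Summarize a MAP-Elites archive (hypothetical helper).

    fitness_lattice:   fitness of the elite in every niche (any shape)
    occupancy_lattice: boolean mask, True where a niche holds an elite
    """
    occupied = jnp.sum(occupancy_lattice)
    # Only occupied niches contribute to the QD score; empty ones are masked out.
    qd_score = jnp.sum(jnp.where(occupancy_lattice, fitness_lattice, 0.0))
    max_score = jnp.max(jnp.where(occupancy_lattice, fitness_lattice, -jnp.inf))
    # Guard against division by zero for an empty archive.
    mean_score = qd_score / jnp.maximum(occupied, 1)
    return {
        "qd_score": float(qd_score),
        "occupied": int(occupied),
        "max_score": float(max_score),
        "mean_score": float(mean_score),
    }


# Toy 2x2 archive: one niche is empty (assumed here to be marked with -inf).
fitness = jnp.array([[3.0, -jnp.inf], [5.0, 1.0]])
occupancy = jnp.isfinite(fitness)
metrics = qd_metrics(fitness, occupancy)
print(metrics)
```

The mean score is simply the QD score divided by the number of occupied niches, so a solver can raise the QD score either by filling more niches or by improving the elites it already has.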

For Brax-Ant CR-FM-NES-ME (Diversifier applied to CR-FM-NES), compared with MAP-Elites, reaches a higher QD-score for high iteration numbers (see details below). So MAP-Elites should only be preferred for a low evaluation budget or if you want to maximize the number of occupied niches.

On an NVIDIA 3090 + AMD 5950, Linux Mint with optimized configurations we measured:

  • MAP-Elites has the same optimizer overhead (evaluation/sec rate for the same popsize).
  • MAP-Elites has a higher number of occupied niches.

but

  • CR-FM-NES-ME has a much higher QD score and found a better global optimum for a high
    evaluation budget.

Detailed measurements for the Brax-Ant example (NVIDIA 3090 + AMD 5950, Linux Mint):

After 20 minutes MAP-Elites is in the lead, but it slows down from there. CR_FM_NES-ME continues to
improve until 500 minutes / 8 million evaluations. CR_FM_NES-ME can even produce a good
global optimum (4107) while still occupying 6138 niches with a mean score of 1208.
After 500 minutes MAP-Elites continues to improve where CR_FM_NES-ME has nearly stalled, but at that
time CR_FM_NES-ME has a >70% lead in QD score.

CR_FM_NES-ME with init-std = 0.159, popsize = 512, fitness_weight 0.0

20 min QD score: 1692282 occupied: 4936 max score: 558 mean score: 342 evaluations: 263680
50 min QD score: 2724260 occupied: 5628 max score: 918 mean score: 484 evaluations: 704512
100 min QD score: 4289807 occupied: 6087 max score: 1442 mean score: 704 evaluations: 1496576
200 min QD score: 5928753 occupied: 6138 max score: 2363 mean score: 965 evaluations: 3072000
300 min QD score: 6524518 occupied: 6138 max score: 2862 mean score: 1063 evaluations: 4710400
400 min QD score: 7353257 occupied: 6138 max score: 3889 mean score: 1198 evaluations: 6348800
500 min QD score: 7418018 occupied: 6138 max score: 4107 mean score: 1208 evaluations: 7884800
600 min QD score: 7444092 occupied: 6138 max score: 4211 mean score: 1212 evaluations: 9523200

MAP-Elites with iso-sigma = 0.05, line-sigma = 0.2, popsize = 1024 (line-sigma = 0.3 is worse)

20 min QD score: 2509773 occupied: 5621 max score: 643 mean score: 446 evaluations: 346112
50 min QD score: 3022521 occupied: 6375 max score: 724 mean score: 474 evaluations: 915456
100 min QD score: 3383041 occupied: 6786 max score: 769 mean score: 498 evaluations: 1941504
200 min QD score: 3713977 occupied: 7107 max score: 825 mean score: 522 evaluations: 3936256
300 min QD score: 3915492 occupied: 7265 max score: 927 mean score: 538 evaluations: 5922816
400 min QD score: 4065677 occupied: 7400 max score: 927 mean score: 549 evaluations: 7941120
500 min QD score: 4179020 occupied: 7498 max score: 927 mean score: 557 evaluations: 9958400
600 min QD score: 4272665 occupied: 7566 max score: 927 mean score: 564 evaluations: 12083200
700 min QD score: 4351397 occupied: 7632 max score: 941 mean score: 570 evaluations: 14094336
800 min QD score: 4415351 occupied: 7675 max score: 1003 mean score: 575 evaluations: 16040960
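The ">70% lead" quoted above can be checked from the 500 minute rows of the two tables. The snippet below is only a worked arithmetic example; the helper name `relative_lead` is made up for illustration.

```python
# 500 minute rows copied from the CR_FM_NES-ME (popsize = 512) and MAP-Elites tables.
cr_fm_nes_me = {"qd_score": 7418018, "occupied": 6138, "mean_score": 1208}
map_elites = {"qd_score": 4179020, "occupied": 7498, "mean_score": 557}


def relative_lead(a, b):
    """Fraction by which a exceeds b."""
    return a / b - 1.0


lead = relative_lead(cr_fm_nes_me["qd_score"], map_elites["qd_score"])
print(f"QD score lead at 500 min: {lead:.1%}")  # prints "QD score lead at 500 min: 77.5%"

# In these rows the reported mean score equals the QD score divided by the
# number of occupied niches, truncated to an integer.
for row in (cr_fm_nes_me, map_elites):
    assert int(row["qd_score"] / row["occupied"]) == row["mean_score"]
```

MAP-Elites occupies about 22% more niches at that point (7498 vs. 6138), but the much higher mean score per niche of CR_FM_NES-ME more than compensates for it in the QD score.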

These results indicate that it should be possible to apply MAP-Elites to the resulting CR_FM_NES-ME archive
to further improve occupancy and score. As the algorithm wrapped by Diversifier.py, CRFMNES can be replaced by FCRFMC (the same algorithm, but implemented in C++). We got the same results, but this may reduce the GPU load for smaller
GPUs/TPUs and is definitely advantageous for CPU-only execution. On the NVIDIA 3090, CRFMNES is slightly faster.

Note that 'fitness_weight' is a concept used neither in CMA-ME nor in the fcmaes diversifier. Both implicitly use fitness_weight=0. For fcmaes the reason is that there are other means to improve the elites of a given map, so the focus is on exploration here. We use fitness_weight=0 as the default, because for Brax Ant the final QD score is higher, although the final global optimum found is lower.
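To illustrate how a fitness weight can shift the ranking between exploration and exploitation, here is a minimal sketch. It is an assumption, not the PR's actual implementation: the function name `blend_for_tell`, the linear blend, and the convention that empty niches hold a low finite sentinel value are all hypothetical.

```python
import jax.numpy as jnp


def blend_for_tell(fitness, lattice_fitness, fitness_weight=0.0):
    """Hypothetical value to pass to the wrapped solver's tell().

    fitness:         fitness of each candidate, shape (popsize,)
    lattice_fitness: fitness of the elite currently stored in each candidate's
                     niche (assumed: a low finite sentinel for empty niches)
    fitness_weight:  0.0 ranks purely by improvement over the archive,
                     1.0 ranks purely by raw fitness
    """
    improvement = fitness - lattice_fitness
    return (1.0 - fitness_weight) * improvement + fitness_weight * fitness


fitness = jnp.array([5.0, 3.0, 4.0])
lattice_fitness = jnp.array([4.5, 1.0, 6.0])

# With fitness_weight = 0 the second candidate ranks best: it improves its niche most.
best_explore = int(jnp.argmax(blend_for_tell(fitness, lattice_fitness, 0.0)))
# With fitness_weight = 1 the first candidate ranks best: it has the highest raw fitness.
best_exploit = int(jnp.argmax(blend_for_tell(fitness, lattice_fitness, 1.0)))
print(best_explore, best_exploit)
```

Under these assumptions, fitness_weight=0 makes the wrapped solver chase archive improvement only, which matches the observation above that the QD score ends up higher while the best single solution ends up lower.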

fcmaes even supports sequences of wrapped algorithms, something probably not relevant for EvoJAX.

Increasing popsize to 1024 closes the evaluations/sec gap to MAP-Elites; the rate is 34% higher than with popsize = 512. But popsize = 1024 seems to produce lower occupancy, which is quite surprising:

CR_FM_NES-ME with init-std = 0.159, popsize = 1024, fitness_weight 0.0

20 min QD score: 1864258 occupied: 4856 max score: 538 mean score: 383 evaluations: 350208
50 min QD score: 2702674 occupied: 5402 max score: 848 mean score: 500 evaluations: 905216
100 min QD score: 3853807 occupied: 5873 max score: 1288 mean score: 656 evaluations: 1879040
200 min QD score: 5005292 occupied: 5947 max score: 1781 mean score: 841 evaluations: 3891200
300 min QD score: 6425120 occupied: 5963 max score: 2936 mean score: 1077 evaluations: 5963776
400 min QD score: 7103424 occupied: 5976 max score: 3783 mean score: 1188 evaluations: 8192000
500 min QD score: 7282457 occupied: 5980 max score: 4111 mean score: 1217 evaluations: 10240000
600 min QD score: 7371868 occupied: 5982 max score: 4227 mean score: 1232 evaluations: 12288000
700 min QD score: 7405531 occupied: 5982 max score: 4276 mean score: 1237 evaluations: 14336000
800 min QD score: 7464371 occupied: 5983 max score: 4307 mean score: 1247 evaluations: 16384000
900 min QD score: 7509290 occupied: 5989 max score: 4325 mean score: 1253 evaluations: 18432000
1000 min QD score: 7514184 occupied: 5989 max score: 4342 mean score: 1254 evaluations: 20480000

But why can't we have our cake and eat it too?

This is not part of the PR but discusses what could be done in the future:

Both Diversifier and MAP-Elites share the same archive management. They differ only in
population generation. In the future both could be unified into a single QD solver, still called MAP-Elites. This new implementation could randomly choose the way "ask" works.
We define the probability that a wrapped solver is used instead of the standard mechanism. If this
probability is 0, we have the old MAP-Elites. If it is 1.0, we have Diversifier.
The interesting question is: What happens for values in between? Let's try 0.5. This can
easily be implemented as:

    from typing import Union

    import jax
    import jax.numpy as jnp
    import numpy as np

    # Methods of the unified solver class (fragment).

    def ask(self) -> jnp.ndarray:
        self.key, key = jax.random.split(self.key)
        if jax.random.uniform(key) > 0.5:  # a parameter to play with
            self.population = self.solver.ask()  # population from wrapped solver
            self.solver_asked = True
        else:  # population from MAP-Elites generator
            self.key, mutate_key, parents = self._sample_parents(
                key=self.key,
                occupancy=self.occupancy_lattice,
                params=self.params_lattice)
            self.population = self._gen_pop(parents, mutate_key)
            self.solver_asked = False
        return self.population

    def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None:
        if self.solver_asked:
            # Only the wrapped solver needs feedback: tell it the modified fitness.
            lattice_fitness = self.fitness_lattice[self.bin_idx]
            to_tell = self._get_to_tell(fitness, lattice_fitness, self.fitness_weight)
            self.solver.tell(to_tell)
        # update lattice (same archive management for both branches)
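The hard-coded 0.5 could be turned into a parameter. A minimal sketch, assuming a hypothetical helper `use_wrapped_solver` and a hypothetical parameter name `solver_prob` (neither is part of the PR), using `jax.random.bernoulli` for the draw:

```python
import jax


def use_wrapped_solver(key, solver_prob):
    """Return True with probability solver_prob (hypothetical helper).

    solver_prob = 0.0 always selects the MAP-Elites generator,
    solver_prob = 1.0 always selects the wrapped solver (Diversifier).
    """
    return bool(jax.random.bernoulli(key, p=solver_prob))


keys = jax.random.split(jax.random.PRNGKey(0), 20)
always_map_elites = not any(use_wrapped_solver(k, 0.0) for k in keys)
always_diversifier = all(use_wrapped_solver(k, 1.0) for k in keys)
print(always_map_elites, always_diversifier)
```

Converting the draw to a Python bool keeps the branch outside of any jitted code, which mirrors the plain `if` used in the ask method of this PR.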

MAP-Elites + CR_FM_NES-ME with iso-sigma = 0.05, line-sigma = 0.2, init-std = 0.159, popsize = 1024, fitness_weight 0.0

20 min QD score: 2201387 occupied: 5266 max score: 672 mean score: 418 evaluations: 344064
50 min QD score: 2738691 occupied: 6105 max score: 672 mean score: 448 evaluations: 892928
100 min QD score: 3348423 occupied: 6656 max score: 857 mean score: 503 evaluations: 1859584
200 min QD score: 4233393 occupied: 7135 max score: 1103 mean score: 593 evaluations: 3851264
300 min QD score: 5139277 occupied: 7334 max score: 1586 mean score: 700 evaluations: 5893120
400 min QD score: 5776929 occupied: 7457 max score: 1884 mean score: 774 evaluations: 7943168
500 min QD score: 6098261 occupied: 7537 max score: 2104 mean score: 809 evaluations: 9947136
600 min QD score: 7240351 occupied: 7603 max score: 2890 mean score: 952 evaluations: 11999232
700 min QD score: 7800357 occupied: 7660 max score: 3421 mean score: 1018 evaluations: 14023680
800 min QD score: 8004109 occupied: 7699 max score: 3735 mean score: 1039 evaluations: 16000000

This is an 81% QD score increase compared to MAP-Elites alone at 800 minutes, while also improving occupancy (7699 vs. 7675 niches).

900 min QD score: 8115904 occupied: 7744 max score: 3917 mean score: 1048 evaluations: 18140160
1000 min QD score: 8195842 occupied: 7772 max score: 4019 mean score: 1054 evaluations: 20133888
1100 min QD score: 8259249 occupied: 7799 max score: 4090 mean score: 1059 evaluations: 22155264
1200 min QD score: 8319027 occupied: 7826 max score: 4130 mean score: 1062 evaluations: 24177664
1300 min QD score: 8362034 occupied: 7847 max score: 4156 mean score: 1065 evaluations: 26213376

A QD score of 8362034 is probably a challenge for any algorithm, regardless of the evaluation budget.

@dietmarwo
Contributor Author

dietmarwo commented Nov 12, 2022

Added a small fix for CRFMNES: applied jax.jit to generate_population. On my hardware it is only a small improvement.
I tried that for tell as well, but got no improvement there, so I left it as it is. I also did some experiments with MAP-Elites
(applying simulated binary crossover or adding some random solutions to sample_parents) inspired by fcmaes mapelites, but Iso+LineDD proved to be optimal for Brax Ant. Only a small occupancy improvement could be achieved, which is not worth complicating things there.

2 participants