# NEAT Training on Slime Volley with EvoJax
This notebook demonstrates how to run PrettyNEAT on the Slime Volley game.
EvoJax provides a JAX based implementation of the environment for fast vectorised simulation.
We'll install the required packages, load the NEAT code from this repository and run a small training loop.

In [1]:
# Install EvoJax and the SlimeVolleyGym environment
!pip install -q evojax slimevolleygym opencv-python-headless

[0m

In [4]:
import sys, json, numpy as np, matplotlib.pyplot as plt
sys.path.append('prettyNEAT')
from neat_src import Neat
from vis.viewInd import viewInd
from domain.config import games
from domain.task_gym import GymTask
from slimevolleygym import SlimeVolley-v0
from evojax.task.slimevolley import SlimeVolley-v0

SyntaxError: invalid syntax (2998850439.py, line 7)

In [None]:
# Load the default NEAT settings and the Slime Volley overrides
with open('prettyNEAT/p/default_neat.json') as f:
    hyp = json.load(f)
with open('prettyNEAT/p/slime_volley.json') as f:
    hyp.update(json.load(f))
# Reduce population and generations for a quick demo
hyp['popSize'] = 16
hyp['maxGen'] = 5
hyp['alg_nReps'] = 1
hyp['task'] = 'slime'
hyp['save_mod'] = 2
hyp['bestReps'] = 5
display(hyp)

In [None]:
# Create the NEAT algorithm and a Gym based task for Slime Volley
task = GymTask(games[hyp['task']], nReps=hyp['alg_nReps'])
neat = Neat(hyp)

In [None]:
history = []
for gen in range(hyp['maxGen']):
    pop = neat.ask()
    fit = []
    for ind in pop:
        f = task.getFitness(ind.wMat.flatten(), ind.aVec)
        fit.append(f)
    neat.tell(np.array(fit))
    history.append([gen, np.mean(fit), np.max(fit)])
    print(f'gen {gen}: mean {history[-1][1]:.2f} top {history[-1][2]:.2f}')

In [None]:
history = np.array(history)
plt.figure(figsize=(6,4))
plt.plot(history[:,0], history[:,1], label='mean fitness')
plt.plot(history[:,0], history[:,2], label='top fitness')
plt.xlabel('generation')
plt.ylabel('fitness')
plt.legend()
plt.show()

In [None]:
# Visualise the best individual of the last generation
best = neat.pop[ np.argmax([ind.fitness for ind in neat.pop]) ]
viewInd(best, hyp['task'])
plt.show()

In [None]:
# Play a short match against the built in AI using the trained network
env = SlimeVolleyEnv(difficulty=1)
state = env.reset()
total = 0
for step in range(1000):
    a = task.getFitness(best.wMat.flatten(), best.aVec, view=True)
    # environment automatically handles opponent
    state, reward, done, _ = env.step(a)
    total += reward
    env.render()
    if done:
        break
env.close()
print('Total reward', total)