In [2]:
from NEAT import EliteSelection, Crossover, Mutation, NEAT, FeedForwardNetwork

# Hyperparameters for CartPole-v1
# ---------------------------------------------------
# POP_SIZE = 256
# N_INPUTS = 4
# N_OUPUTS = 2
# THRESHOLD = 475
# INITIAL_CONNECTIONS = 0


# Hyperparameters for LunarLander-v3
# remember to change the fitness function accordingly
# ---------------------------------------------------
# POP_SIZE is the first argument handed to neat.start() below.
POP_SIZE = 100
# Network input/output sizes, passed to neat.start() as a (inputs, outputs) tuple.
# NOTE(review): these should match the environment's observation size and number
# of discrete actions (8 and 4 for LunarLander) - confirm when switching env.
N_INPUTS = 8
# NOTE(review): "OUPUTS" is a typo for "OUTPUTS"; the name is left as-is because
# it is referenced further down in this cell.
N_OUPUTS = 4
# Fourth argument to neat.start(); presumably the fitness at which evolution
# stops - TODO confirm against the NEAT module.
THRESHOLD = 200
# Fifth argument to neat.start(). NOTE(review): -1 appears to request a fully
# connected initial genome (the logged best shape (12, 32) equals 8 + 4 nodes and
# 8 * 4 connections) - confirm against the NEAT module.
INITIAL_CONNECTIONS = -1


# The fitness function lives in its own module (my_fitness_function) and is
# looked up from there; it must be swapped together with the hyperparameters.
import my_fitness_function

fitness_function = my_fitness_function.fitness_function

# NOTE(review): 0.2 is presumably the fraction of each population kept as
# elites - confirm against EliteSelection.
selection = EliteSelection(fitness_function, 0.2)
crossover = Crossover()
mutation = Mutation()

# NOTE(review): distance_threshold is presumably the speciation distance and
# parallel=True presumably evaluates genomes concurrently - confirm.
neat = NEAT(selection, crossover, mutation, distance_threshold=3.0, parallel=True)
# The literal 1000 is the third positional argument; presumably the maximum
# number of generations - TODO confirm. The result is used by the render/save
# cells below as the genome to visualise.
winner = neat.start(POP_SIZE, (N_INPUTS, N_OUPUTS), 1000, THRESHOLD, INITIAL_CONNECTIONS)

Generation 1  -  Fit: -14.59
| spec | #mem | avg fit | best fit | best shape |
|------|------|---------|----------|------------|
| 1    | 100  | -485.5  | -14.6    | (12, 32)   |
'------'------'---------'----------'------------'

Generation 2  -  Fit: -32.27
| spec | #mem | avg fit | best fit | best shape |
|------|------|---------|----------|------------|
| 1    | 100  | -128.1  | -32.3    | (12, 32)   |
'------'------'---------'----------'------------'

Generation 3  -  Fit: -38.84
| spec | #mem | avg fit | best fit | best shape |
|------|------|---------|----------|------------|
| 1    | 100  | -112.3  | -38.8    | (12, 32)   |
'------'------'---------'----------'------------'

Generation 4  -  Fit: -78.52
| spec | #mem | avg fit | best fit | best shape |
|------|------|---------|----------|------------|
| 1    | 97   | -97.8   | -78.5    | (12, 32)   |
| 2    | 2    | -97.9   | -97.9    | (14, 35)   |
| 3    | 1    | -97.9   | -97.9    | (13, 35)   |
'------'------'---------'------

In [None]:
from gymnasium.utils.save_video import save_video
from IPython import display
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np

def render_genome(genome, env_name, n_frames, frames=10):
    """Roll out one episode driven by ``genome`` and draw it inline.

    Args:
        genome: Genome passed to ``FeedForwardNetwork.create`` to build the
            policy network.
        env_name: Gymnasium environment id handed to ``gym.make``.
        n_frames: Maximum number of environment steps to simulate.
        frames: Display stride; a frame is drawn every ``frames`` steps.

    The rollout stops as soon as the episode is terminated *or* truncated,
    and the environment is always closed, even if a step raises.
    """
    network = FeedForwardNetwork.create(genome)
    env = gym.make(env_name, render_mode="rgb_array_list")
    try:
        state = env.reset()[0]

        for f in range(n_frames):
            if f % frames == 0:
                # In "rgb_array_list" mode render() returns every frame
                # collected since the previous render()/reset(); the last
                # entry is the one that corresponds to the current state.
                plt.imshow(env.render()[-1])
                display.display(plt.gcf())
                display.clear_output(wait=True)

            pred = network.activate(state)
            action = np.argmax(pred)

            state, _, terminated, truncated, _ = env.step(action)

            # A truncated episode (e.g. time limit) is over just like a
            # terminated one; stepping further would be undefined behaviour.
            if terminated or truncated:
                break
    finally:
        env.close()
        
def save_vid(env_name, genome, n_frames):
    """Run ``genome`` for ``n_frames`` steps and save finished episodes as video.

    Args:
        env_name: Gymnasium environment id handed to ``gym.make``.
        genome: Genome passed to ``FeedForwardNetwork.create`` to build the
            policy network.
        n_frames: Total number of environment steps to simulate, possibly
            spanning several episodes.

    Each time an episode terminates or is truncated, the frames collected for
    that episode are passed to ``save_video`` (output folder ``"video"``).
    Frames of a final episode that has not ended when ``n_frames`` is reached
    are not written. The environment is always closed on exit.
    """
    network = FeedForwardNetwork.create(genome)
    env = gym.make(env_name, render_mode="rgb_array")
    try:
        state = env.reset()[0]
        step_starting_index = 0
        episode_index = 0
        frames = []

        for step_index in range(n_frames):
            pred = network.activate(state)
            action = np.argmax(pred)
            state, _, terminated, truncated, _ = env.step(action)
            frames.append(env.render())

            if terminated or truncated:
                save_video(
                    frames,
                    "video",
                    fps=env.metadata["render_fps"],
                    step_starting_index=step_starting_index,
                    episode_index=episode_index
                )
                step_starting_index = step_index + 1
                episode_index += 1
                # Start the next episode with an empty buffer so its video
                # does not contain the previous episodes' frames.
                frames = []
                # Feed the network the fresh initial observation instead of
                # the stale terminal state of the episode that just ended.
                state = env.reset()[0]
    finally:
        env.close()

In [None]:
render_genome(winner, "LunarLander-v3", 1000)

In [None]:
save_vid("LunarLander-v3", winner, 1000)