In [1]:
%load_ext autoreload
%autoreload 2

from neat_jax.config import NEATConfig
from neat_jax.neat import NEAT
import jax
import evojax
from evojax.task import flocking
from evojax.util import create_logger

# Directory handed to create_logger for this run's log output.
log_dir = "./logs"

# Named logger for the flocking experiment.
logger = create_logger(name="flocking", log_dir=log_dir)
logger.info("Starting!")
# Record which devices JAX sees locally so the run log shows the backend used.
logger.info(f"Jax backend: {jax.local_devices()}")




flocking: 2024-10-12 23:01:53,741 [INFO] Starting!
jax._src.xla_bridge: 2024-10-12 23:01:54,138 [INFO] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
jax._src.xla_bridge: 2024-10-12 23:01:54,140 [INFO] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
flocking: 2024-10-12 23:01:54,141 [INFO] Jax backend: [CudaDevice(id=0)]


In [2]:
# Seed from which the PRNG keys in this cell are derived.
seed = 42
# NOTE(review): not read by any cell visible here (get_input_label uses its own
# local of the same name) — confirm whether it is still needed.
neighbor_num = 5

# Build the root key from the seed and immediately split it in two:
# the first half becomes reset_key, the second half stays as rollout_key.
reset_key, rollout_key = jax.random.split(jax.random.PRNGKey(seed), 2)

# Prepend a length-1 axis to reset_key.
# NOTE(review): presumably a batch axis expected by the task's reset — confirm.
reset_key = reset_key[None]

from neat_jax.fitness import make_fitness_fn

# Two separately constructed FlockingTask instances with identical arguments:
# one used for training fitness, one for the baseline/test fitness.
train_task, test_task = (
    flocking.FlockingTask(100, action_type=1) for _ in ("train", "test")
)

# One fitness function per task, each configured with num_steps=100.
train_fitness_fn, test_fitness_fn = (
    make_fitness_fn(task=split_task, num_steps=100)
    for split_task in (train_task, test_task)
)

In [3]:
import neat_jax.activations as act

# Network input/output widths are read from the task's observation and action
# shapes so the genome always matches the environment.
output_size = train_task.act_shape[0]
input_size = train_task.obs_shape[0]

# Activation ids used below are positions in this list.
activation_fns = [act.iden, act.sigmoid, act.tanh, act.relu]

# Look the ids up instead of hard-coding 0 / 2, so they stay correct if the
# list above is reordered or extended.
_iden_id = activation_fns.index(act.iden)
_tanh_id = activation_fns.index(act.tanh)

# Every output node uses tanh; one id per output node, so the list length
# follows output_size rather than being fixed at two entries.
output_activation_ids = [_tanh_id] * output_size
# Every input node uses the identity activation; one id per input node.
input_activation_ids = [_iden_id] * input_size

def get_input_label(idx):
    """Return the display label for input node ``idx``.

    Inputs are labelled in consecutive groups of three: the group number is
    ``idx // 3`` and the position inside the group selects the suffix
    ``X``, ``Y`` or ``Theta`` (for remainders 0, 1 and 2 respectively).
    """
    group, slot = divmod(idx, 3)
    suffix = ("X", "Y", "Theta")[slot]
    return f"Neighbor {group}: {suffix}"
    
def get_output_label(idx):
    """Return the display label for output node ``idx``.

    Index 0 is labelled ``d_theta``; every other index is labelled ``d_speed``.
    """
    return "d_theta" if idx == 0 else "d_speed"


from neat_jax.config import GenomeConfig, SelectionConfig, MutationConfig, NEATConfig
from neat_jax.species import make_improvement_stagnation_fn, make_remove_last_if_stagnant_and_full_stagnation_fn
# Mutation hyper-parameters. Arguments are grouped by name only; see
# neat_jax.config.MutationConfig for their authoritative meaning.
mutation_config = MutationConfig(
    # Connection add/disable probabilities.
    add_connection_prob=0.05,
    disable_connection_prob=0.05,
    # Node add/disable probabilities.
    add_node_prob=0.03,
    disable_node_prob=0.03,
    # Weight mutation settings.
    mutate_weight_prob=0.8,
    mutate_weight_std=0.1,
    # Activation and bias mutation are given zero probability / zero std here.
    mutate_activation_prob=0.0,
    mutate_bias_prob=0.0,
    mutate_bias_std=0.0,
)

# Genome layout. Sizes, activation tables and labels all come from the values
# computed earlier in this cell.
genome_config = GenomeConfig(
    # Input/output widths and one human-readable label per node.
    input_size=input_size,
    output_size=output_size,
    input_labels=list(map(get_input_label, range(input_size))),
    output_labels=list(map(get_output_label, range(output_size))),
    # Available activations and the per-node ids into that list.
    activation_fns=activation_fns,
    input_activation_ids=input_activation_ids,
    output_activation_ids=output_activation_ids,
    # Initialisation settings.
    init_mode="full",
    init_weight_mean=0.0,
    init_weight_std=1.0,
    # Capacity settings.
    initial_capacity=100,
    capacity_growth_strategy="linear",
)

# Selection / speciation hyper-parameters. Arguments are grouped by name only;
# see neat_jax.config.SelectionConfig for their authoritative meaning.
selection_config = SelectionConfig(
    # Population and parent selection.
    population_size=100,
    elitism=0.1,
    cutoff_pct=0.1,
    selection_tournament_size=8,
    # Speciation.
    speciation_threshold=1.0,
    compatibility_coefficients=(1.0, 0.9),
    maximum_species=5,
    min_species_size=0,
    species_warmup_threshold=1,
    max_transfer_age=1,
    # Stagnation function built with a limit of 20 stagnated steps.
    stagnation_fn=make_improvement_stagnation_fn(max_stagnated_steps=20),
)

# Bundle the three sub-configurations into the single object passed to NEAT.
neat_config = NEATConfig(
    genome_config=genome_config,
    mutation_config=mutation_config,
    selection_config=selection_config,
)

In [4]:
import concurrent.futures
import os
import tqdm
import wandb
from evojax.task.flocking import render_single

logdir = "./logs"


def render_fn(state):
    """Render ``state`` to an animated GIF and return it wrapped for wandb.

    One frame is produced by ``render_single(state.state[i])`` for every
    ``i`` in ``range(state.obs.shape[0])``. The frames are saved to
    ``flocking.gif`` inside ``logdir`` and the file path is handed to
    ``wandb.Video``.

    NOTE(review): assumes the first axis of ``state.obs`` / ``state.state``
    indexes rollout steps and that ``render_single`` returns PIL-style images
    (``.save(..., save_all=True, append_images=...)``) — confirm.
    """
    gif_path = os.path.join(logdir, 'flocking.gif')
    cpu_device = jax.devices("cpu")[0]
    # All rendering and file writing happens with the CPU as JAX's default device.
    with jax.default_device(cpu_device):
        frames = []
        for step in range(state.obs.shape[0]):
            frames.append(render_single(state.state[step]))
        # The first frame owns the file; the remaining frames are appended to it.
        first_frame = frames[0]
        first_frame.save(
            gif_path,
            save_all=True,
            append_images=frames[1:],
            duration=40,
            loop=0,
        )
    print("Render complete!")
    return wandb.Video(gif_path, fps=6)



# Assemble the NEAT driver: training fitness for evolution, the test fitness
# function as the baseline evaluator, and the wandb project name.
neat = NEAT(
    config=neat_config,
    fitness_fn=train_fitness_fn,
    baseline_test_fn=test_fitness_fn,
    wandb_project="turbo-neat",
)

# Run 5 generations, passing render_fn so the run can render rollouts.
# NOTE(review): this uses seed=0 rather than the notebook-level `seed` — confirm intended.
population = neat.run(seed=0, num_generations=5, render_fn=render_fn)



[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33mlowrollr[0m. Use [1m`wandb login --relogin`[0m to force relogin


Rendering and logging generation 0
{'generation': '0.00', 'mean_fitness': '-120.09', 'max_fitness': '-83.84', 'min_fitness': '-138.48', 'mean_hidden_nodes': '0.00', 'mean_condensed_hidden_nodes': '0.00', 'mean_connections': '30.00', 'mean_condensed_connections': '30.00', 'num_species': '1.00', 'fitness_against_baseline': '-94.23'}
{'generation': '1.00', 'mean_fitness': '-117.45', 'max_fitness': '-85.87', 'min_fitness': '-139.23', 'mean_hidden_nodes': '0.02', 'mean_condensed_hidden_nodes': '0.02', 'mean_connections': '29.99', 'mean_condensed_connections': '29.99', 'num_species': '5.00', 'species_dominance/s0': '0.96', 'species_fitness/s0': '-21.77', 'species_dominance/s1': '0.01', 'species_fitness/s1': '-22.16', 'species_dominance/s2': '0.01', 'species_fitness/s2': '-23.77', 'species_dominance/s3': '0.01', 'species_fitness/s3': '-23.01', 'species_dominance/s4': '0.01', 'species_fitness/s4': '-24.23'}
{'generation': '2.00', 'mean_fitness': '-113.20', 'max_fitness': '-83.27', 'min_fitness