In [4]:

import argparse
from dataclasses import dataclass
import os
import shutil
import jax
import jax.numpy as jnp

from evojax.policy.base import PolicyNetwork, PolicyState
from evojax.task.base import TaskState
from evojax.task.slimevolley import SlimeVolley
from evojax.policy.mlp import MLPPolicy
from evojax.algo import CMA
from evojax import Trainer
from evojax import util

from flax import linen as nn

In [2]:

class NEATModel(nn.Module):
    input_dim: int
    output_dim: int
    max_nodes: int
    max_edges: int

    @nn.compact
    def __call__(self, x):
        nodes = self.param('nodes', nn.initializers.zeros, (self.max_nodes, 3)) # (id, activattion)
        edges = self.param('edges', nn.initializers.zeros, (self.max_edges, 3)) # (from, to, weight, enabled)
        values = jnp.zeros((self.max_nodes, ))
        return jax.slice(values, -self.output_dim)
class NEATPolicy(PolicyNetwork):
  def __init__(self, input_dim: int, output_dim: int):
    self._input_dim = input_dim
    self._output_dim = output_dim
    # TODO: Implement NEAT policy.
    self.num_params = 1
    model = NEATModel(input_dim, output_dim, 100, 100)
    params = model.init(jax.random.PRNGKey(0), jnp.ones([1, input_dim]))
    print(params)
    self.num_params, format_params_fn = util.get_params_format_fn(params)
    self._format_params_fn = jax.vmap(format_params_fn)
    self._forward_fn = jax.vmap(model.apply)

  def get_actions(self, t_states, params, p_states: PolicyState) -> tuple[jax.Array, PolicyState]:
    return self._forward_fn(params, t_states.obs), p_states


In [25]:
@jax.tree_util.register_dataclass
@dataclass
class Genome:
    """Stores node and connection info in JAX arrays."""

    node_ids: jax.Array  # shape [num_nodes]
    # 0: relu
    node_activation: jax.Array  # shape [num_nodes]

    # Connections.
    conn_id: jax.Array  # shape [num_connections]
    conn_in: jax.Array  # shape [num_connections]
    conn_out: jax.Array  # shape [num_connections]
    conn_weights: jax.Array  # shape [num_connections]
    conn_enabled: jax.Array  # shape [num_connections]

In [43]:
max_nodes = 5
input_dim = 2
node_ids = jnp.array([0, 1, 2, 3, 4])
node_activation = jnp.array([0, 0, 0, 0, 2])

# 0 1
# | |
# 2 3
#  \|
#   4
conn_id = jnp.array([0, 1, 2, 3, 4])
conn_in = jnp.array([0, 1, 2, 3, -1])
conn_out = jnp.array([2, 3, 4, 4, 0])
conn_weights = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0])
conn_enabled = jnp.array([True, True, True, True, False])
display(node_ids)
display(conn_out)

genome = Genome(node_ids, node_activation, conn_id, conn_in, conn_out, conn_weights, conn_enabled)

Array([0, 1, 2, 3, 4], dtype=int32)

Array([2, 3, 4, 4, 0], dtype=int32)

In [138]:
def get_adjacency_matrix_fn(max_nodes: int, input_dim: int, genome: Genome):
    matrix = jnp.zeros((max_nodes, max_nodes))
    matrix = matrix + jnp.pad(jnp.identity(input_dim), ((0, max_nodes - input_dim), (0, max_nodes - input_dim)))
    matrix = matrix.at[genome.conn_out, genome.conn_in].set(
        jnp.where(genome.conn_enabled, genome.conn_weights, 0.0)
    )
    return matrix

def create_genome_fn(node_ids, node_activation, conn_id, conn_in, conn_out, conn_weights, conn_enabled):
    return Genome(node_ids, node_activation, conn_id, conn_in, conn_out, conn_weights, conn_enabled)

get_adjacency_matrix = jax.jit(jax.vmap(get_adjacency_matrix_fn, in_axes=(None, None, 0)), static_argnums=(0, 1))
create_genome = jax.jit(jax.vmap(create_genome_fn, in_axes=0))

In [54]:
display(get_adjacency_matrix)

<PjitFunction of <function get_adjacency_matrix_fn at 0x3050f6840>>

In [140]:
genome = create_genome(node_ids, node_activation, conn_id, conn_in, conn_out, conn_weights, conn_enabled)
genome = jax.tree.map(lambda x: jnp.array([x]), genome)
display(genome)
adj = get_adjacency_matrix(max_nodes, input_dim, genome)
display(adj)

def matmul_fn(adj, x):
    return jnp.matmul(adj, x)

matmul = jax.vmap(matmul_fn)
x = jnp.array([[1, 1, 0, 0, 0]])
display(matmul(adj, x))
display(matmul(adj, matmul(adj, x)))
display(matmul(adj, matmul(adj, matmul(adj, x))))

Genome(node_ids=Array([[0, 1, 2, 3, 4]], dtype=int32), node_activation=Array([[0, 0, 0, 0, 0]], dtype=int32), conn_id=Array([[0, 1, 2, 3, 4]], dtype=int32), conn_in=Array([[ 0,  1,  2,  3, -1]], dtype=int32), conn_out=Array([[2, 3, 4, 4, 0]], dtype=int32), conn_weights=Array([[1., 1., 1., 1., 1.]], dtype=float32), conn_enabled=Array([[ True,  True,  True,  True, False]], dtype=bool))

Array([[[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 1., 0.]]], dtype=float32)

Array([[1., 1., 1., 1., 0.]], dtype=float32)

Array([[1., 1., 1., 1., 2.]], dtype=float32)

Array([[1., 1., 1., 1., 2.]], dtype=float32)

In [170]:
# TODO: Check NEAT-python for activation functions.
ACTIVATION_FUNCTIONS = [lambda x: x, jax.nn.relu, jax.nn.sigmoid, jax.nn.tanh]

def forward_fn(max_depth: int, genome: Genome, x: jax.Array):
    print(genome.node_ids.shape)
    adj = get_adjacency_matrix_fn(max_nodes, input_dim, genome)
    for _ in range(max_depth):
        x = matmul_fn(adj, x)

        print(adj.shape)
        print(x.shape)
        print(genome.node_activation.shape)

        # Apply the activation function.
        def activate_fn(activation, x):
            return jax.lax.switch(activation, ACTIVATION_FUNCTIONS, x)
        activate = jax.vmap(activate_fn)
        x = activate(genome.node_activation, x)
    return x

forward = jax.jit(jax.vmap(forward_fn, in_axes=(None, 0, 0)), static_argnums=(0))
# forward = jax.vmap(forward_fn, in_axes=(None, 0, 0))

forward(3, genome, x)

(5,)
(5, 5)
(5,)
(5,)
(5, 5)
(5,)
(5,)
(5, 5)
(5,)
(5,)


Array([[1., 1., 1., 1., 2.]], dtype=float32)

In [169]:
genome2 = Genome(node_ids, jnp.array([1, 1, 1, 1, 1]), conn_id, conn_in, conn_out, conn_weights, conn_enabled)
genome2 = jax.tree.map(lambda x: jnp.array([x]), genome2)

forward(3, genome2, x)

Array([[0.66263026, 0.66263026, 0.66263026, 0.66263026, 0.7941419 ]],      dtype=float32)

In [129]:
%env JAX_TRACEBACK_FILTERING=off
ACTIVATION_FUNCTIONS = [jax.nn.relu, jax.nn.sigmoid, jax.nn.tanh, lambda x: x]
# ACTIVATION_FUNCTIONS = [jax.nn.relu]

jax.lax.switch(0, ACTIVATION_FUNCTIONS, jnp.array([0.]))

env: JAX_TRACEBACK_FILTERING=off


Array([0.], dtype=float32)