In [1]:
import argparse
from dataclasses import dataclass
import os
import shutil
import jax
import jax.numpy as jnp

from evojax.policy.base import PolicyNetwork, PolicyState
from evojax.task.base import TaskState
from evojax.task.slimevolley import SlimeVolley
from evojax.policy.mlp import MLPPolicy
from evojax.algo import CMA
from evojax import Trainer
from evojax import util

from flax import linen as nn

In [2]:
class NEATModel(nn.Module):
    input_dim: int
    output_dim: int
    max_nodes: int
    max_edges: int

    @nn.compact
    def __call__(self, x):
        nodes = self.param(
            "nodes", nn.initializers.zeros, (self.max_nodes, 3)
        )  # (id, activattion)
        edges = self.param(
            "edges", nn.initializers.zeros, (self.max_edges, 3)
        )  # (from, to, weight, enabled)
        values = jnp.zeros((self.max_nodes,))
        return jax.slice(values, -self.output_dim)


class NEATPolicy(PolicyNetwork):
    def __init__(self, input_dim: int, output_dim: int):
        self._input_dim = input_dim
        self._output_dim = output_dim
        # TODO: Implement NEAT policy.
        self.num_params = 1
        model = NEATModel(input_dim, output_dim, 100, 100)
        params = model.init(jax.random.PRNGKey(0), jnp.ones([1, input_dim]))
        print(params)
        self.num_params, format_params_fn = util.get_params_format_fn(params)
        self._format_params_fn = jax.vmap(format_params_fn)
        self._forward_fn = jax.vmap(model.apply)

    def get_actions(
        self, t_states, params, p_states: PolicyState
    ) -> tuple[jax.Array, PolicyState]:
        return self._forward_fn(params, t_states.obs), p_states

In [3]:
@jax.tree_util.register_dataclass
@dataclass
class Genome:
    """Stores node and connection info in JAX arrays."""

    node_ids: jax.Array  # shape [num_nodes]
    # 0: relu
    node_activation: jax.Array  # shape [num_nodes]

    # Connections.
    conn_id: jax.Array  # shape [num_connections]
    conn_in: jax.Array  # shape [num_connections]
    conn_out: jax.Array  # shape [num_connections]
    conn_weights: jax.Array  # shape [num_connections]
    conn_enabled: jax.Array  # shape [num_connections]


@dataclass(frozen=True)
class Config:
    """Stores static info for the NEAT algorithm."""

    input_dim: int
    output_dim: int
    max_nodes: int
    max_edges: int
    max_depth: int


config = Config(input_dim=2, output_dim=1, max_nodes=10, max_edges=20, max_depth=10)

In [4]:
node_ids = jnp.array([0, 1, 2, 3, 4])
node_activation = jnp.array([0, 0, 0, 0, 2])

# 0 1
# | |
# 2 3
#  \|
#   4
conn_id = jnp.array([0, 1, 2, 3])
conn_in = jnp.array([0, 1, 2, 3])
conn_out = jnp.array([2, 3, 4, 4])
conn_weights = jnp.array([1.0, 1.0, 1.0, 1.0])
conn_enabled = jnp.array([True, True, True, True])


def create_genome(
    config: Config,
    node_ids,
    node_activation,
    conn_id,
    conn_in,
    conn_out,
    conn_weights,
    conn_enabled,
) -> Genome:
    node_pad_config = [(0, config.max_nodes - node_ids.shape[0], 0)]
    node_ids = jax.lax.pad(node_ids, -1, node_pad_config)
    node_activation = jax.lax.pad(node_activation, -1, node_pad_config)

    conn_pad_config = [(0, config.max_edges - conn_id.shape[0], 0)]
    conn_id = jax.lax.pad(conn_id, -1, conn_pad_config)
    conn_in = jax.lax.pad(conn_in, -1, conn_pad_config)
    conn_out = jax.lax.pad(conn_out, -1, conn_pad_config)
    conn_weights = jax.lax.pad(conn_weights, jnp.nan, conn_pad_config)
    conn_enabled = jax.lax.pad(conn_enabled, False, conn_pad_config)
    return Genome(
        node_ids,
        node_activation,
        conn_id,
        conn_in,
        conn_out,
        conn_weights,
        conn_enabled,
    )


genome = create_genome(
    config,
    node_ids,
    node_activation,
    conn_id,
    conn_in,
    conn_out,
    conn_weights,
    conn_enabled,
)
genome = jax.tree.map(lambda x: jnp.array([x]), genome)
genome

Genome(node_ids=Array([[ 0,  1,  2,  3,  4, -1, -1, -1, -1, -1]], dtype=int32), node_activation=Array([[ 0,  0,  0,  0,  2, -1, -1, -1, -1, -1]], dtype=int32), conn_id=Array([[ 0,  1,  2,  3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1]], dtype=int32), conn_in=Array([[ 0,  1,  2,  3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1]], dtype=int32), conn_out=Array([[ 2,  3,  4,  4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1]], dtype=int32), conn_weights=Array([[ 1.,  1.,  1.,  1., nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan]], dtype=float32), conn_enabled=Array([[ True,  True,  True,  True, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False]], dtype=bool))

In [5]:
def get_adjacency_matrix_fn(config: Config, genome: Genome):
    matrix = jnp.zeros((config.max_nodes, config.max_nodes))
    matrix = matrix + jnp.pad(
        jnp.identity(config.input_dim),
        (
            (0, config.max_nodes - config.input_dim),
            (0, config.max_nodes - config.input_dim),
        ),
    )
    matrix = matrix.at[genome.conn_out, genome.conn_in].set(
        jnp.where(genome.conn_enabled, genome.conn_weights, 0.0)
    )
    return matrix


get_adjacency_matrix = jax.jit(
    jax.vmap(get_adjacency_matrix_fn, in_axes=(None, 0)), static_argnums=(0)
)

In [6]:
display(get_adjacency_matrix)

<PjitFunction of <function get_adjacency_matrix_fn at 0x147efe200>>

In [7]:
adj = get_adjacency_matrix(config, genome)
display(adj)


def matmul_fn(adj, x):
    return jnp.matmul(adj, x)


matmul = jax.vmap(matmul_fn)
x = jnp.stack([jnp.pad(jnp.array([1, 1, 0, 0, 0]), (0, 5))])
display(matmul(adj, x))
display(matmul(adj, matmul(adj, x)))
display(matmul(adj, matmul(adj, matmul(adj, x))))

Array([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]], dtype=float32)

Array([[1., 1., 1., 1., 0., 0., 0., 0., 0., 0.]], dtype=float32)

Array([[1., 1., 1., 1., 2., 0., 0., 0., 0., 0.]], dtype=float32)

Array([[1., 1., 1., 1., 2., 0., 0., 0., 0., 0.]], dtype=float32)

In [8]:
# TODO: Check NEAT-python for activation functions.
ACTIVATION_FUNCTIONS = [lambda x: x, jax.nn.relu, jax.nn.sigmoid, jax.nn.tanh]


def forward_fn(config: Config, genome: Genome, x: jax.Array):
    print(genome.node_ids.shape)
    adj = get_adjacency_matrix_fn(config, genome)
    for _ in range(config.max_depth):
        x = matmul_fn(adj, x)

        print(adj.shape)
        print(x.shape)
        print(genome.node_activation.shape)

        # Apply the activation function.
        def activate_fn(activation, x):
            return jax.lax.switch(activation, ACTIVATION_FUNCTIONS, x)

        activate = jax.vmap(activate_fn)
        x = activate(genome.node_activation, x)
    return x


forward = jax.jit(jax.vmap(forward_fn, in_axes=(None, 0, 0)), static_argnums=(0))
# forward = jax.vmap(forward_fn, in_axes=(None, 0, 0))

forward(config, genome, x)

(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)


Array([[1.      , 1.      , 1.      , 1.      , 0.880797, 0.      ,
        0.      , 0.      , 0.      , 0.      ]], dtype=float32)

In [9]:
genome2 = create_genome(
    config,
    node_ids,
    jnp.array([1, 1, 1, 1, 1]),
    conn_id,
    conn_in,
    conn_out,
    conn_weights,
    conn_enabled,
)
genome2 = jax.tree.map(lambda x: jnp.array([x]), genome2)

forward(config, genome2, x)

Array([[1., 1., 1., 1., 2., 0., 0., 0., 0., 0.]], dtype=float32)

In [10]:
import neat

In [11]:
from evojax.task.base import TaskState

In [12]:
def pad(x):
    return jnp.pad(x, (0, 10))


pad = jax.vmap(pad)
pad(jnp.arange(10)[None, :])

Array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],      dtype=int32)

In [13]:
idx = jnp.array([0, 1, 2, 3, 4])
enabled = jnp.array([True, False, True, False, True])


@jax.jit
def chose_enabled(enabled):
    idx = jnp.arange(enabled.shape[0])
    filtered = jnp.sort(jnp.where(enabled, idx, -1), descending=True)
    chosen = jax.random.randint(jax.random.PRNGKey(1), 10, 0, enabled.sum())
    return filtered[chosen]


chose_enabled(enabled)

Array([4, 4, 2, 4, 4, 4, 0, 2, 0, 0], dtype=int32)

In [22]:
def display_genome(genome: Genome, config: Config, show_disabled: bool = False):
    from IPython.display import display, Markdown
    mermaid = genome_to_mermaid(genome, config, show_disabled)
    print(mermaid)
    display(Markdown(f"```mermaid\n{mermaid}\n```"))

display_genome(pickup_genome(genome, 0), config)

AttributeError: 'Genome' object has no attribute 'conn_ids'

In [18]:
genome = create_empty_genome_fn(config)
x = mutate_add_node(genome, jax.random.PRNGKey(0), 30)[0]
print(x)
display_genome(x, config, show_disabled=True)

NameError: name 'create_empty_genome_fn' is not defined

In [19]:
%run neat.py

In [23]:
jnp.argmin(jnp.array([1, 2, 3, 3, 1, 1]))


Array(0, dtype=int32)

In [34]:
max_conns = genome.conn_ids.shape[0]
jnp.min(jnp.where(genome.conn_ids < 0, jnp.arange(max_conns), max_conns))
genome.conn_ids

DEBUG:2024-12-31 11:29:47,306:jax._src.interpreters.pxla:1906: Compiling less with global shapes and types [ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (UnspecifiedValue, UnspecifiedValue).
jax._src.interpreters.pxla: 2024-12-31 11:29:47,306 [DEBUG] Compiling less with global shapes and types [ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (UnspecifiedValue, UnspecifiedValue).
DEBUG:2024-12-31 11:29:47,310:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(less) in 0.002184153 sec
jax._src.dispatch: 2024-12-31 11:29:47,310 [DEBUG] Finished jaxpr to MLIR module conversion jit(less) in 0.002184153 sec
DEBUG:2024-12-31 11:29:47,311:jax._src.compiler:167: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
jax._src.compiler: 2024-12-31 11:29:47,311 [DEBUG] get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:2024-12-31 11

Array([ 0,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1], dtype=int32)

In [35]:
key, key_src, key_dst, key_weight = jax.random.split(key, 4)

node_ids = genome.node_ids
node_activation = genome.node_activation
conn_ids = genome.conn_ids
conn_in = genome.conn_in
conn_out = genome.conn_out
conn_weights = genome.conn_weights
conn_enabled = genome.conn_enabled

src = pick_one(key_src, node_ids >= 0)
dst = pick_one(key_dst, node_ids >= 0)
(src, dest)
src == dst


NameError: name 'key' is not defined

In [26]:
%run neat.py

config = Config(input_dim=2, output_dim=1, max_nodes=10, max_edges=20, max_depth=10)
neat_config = NEATConfig(config, pop_size=30, prob_add_connection=0.25, prob_add_node=0.5)

genome = create_empty_genome_fn(config)
edge = jnp.array(30)
key = jax.random.key(10)
for _ in range(20):
    key, key_mutate = jax.random.split(key, 2)
    genome, edge = mutate_genome(neat_config, genome, key_mutate, edge)
    display_genome(genome, config)

# TypeError: branch 0 and 2 outputs must have same type structure, got
# PyTreeDef((CustomNode(Genome[()], [*, *, *, *, *, *, *]), *)) and
# PyTreeDef((CustomNode(Genome[()], [*, *, *, *, *, *, *]), *)).


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    end
    i0 -->|0.00| o2
    i1 -->|1.00| h3
    h3 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    end
    i0 -->|0.00| o2
    i1 -->|1.00| h3
    h3 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    end
    i0 -->|0.00| o2
    i1 -->|1.00| h3
    h3 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    end
    i0 -->|0.00| o2
    i1 -->|1.00| h3
    h3 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    end
    i0 -->|0.00| o2
    i1 -->|1.00| h3
    h3 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    end
    i0 -->|0.00| o2
    i1 -->|1.00| h3
    h3 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    end
    h3 -->|0.00| o2
    i0 -->|1.00| h4
    h4 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    end
    h3 -->|0.00| o2
    i0 -->|1.00| h4
    h4 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    end
    h3 -->|0.00| o2
    i0 -->|1.00| h4
    h4 -->|1.00| h5
    h5 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    end
    h3 -->|0.00| o2
    i0 -->|1.00| h4
    h4 -->|1.00| h5
    h5 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    end
    h3 -->|0.00| o2
    i0 -->|1.00| h4
    h4 -->|0.00| o2
    h4 -->|1.00| h5
    h5 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    end
    h3 -->|0.00| o2
    i0 -->|1.00| h4
    h4 -->|0.00| o2
    h4 -->|1.00| h5
    h5 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    end
    h3 -->|0.00| o2
    i0 -->|1.00| h4
    h4 -->|0.00| o2
    h4 -->|1.00| h5
    h5 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    end
    h3 -->|0.00| o2
    i0 -->|1.00| h4
    h4 -->|0.00| o2
    h4 -->|1.00| h5
    h5 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    end
    h3 -->|0.00| o2
    i0 -->|1.00| h4
    h4 -->|0.00| o2
    h5 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    end
    h3 -->|0.00| o2
    i0 -->|1.00| h4
    h4 -->|0.00| o2
    h5 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    end
    h3 -->|0.00| o2
    i0 -->|1.00| h4
    h4 -->|0.00| o2
    h5 -->|0.00| o2
    h3 -->|1.00| h6
    h6 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    end
    h3 -->|0.00| o2
    i0 -->|1.00| h4
    h4 -->|0.00| o2
    h5 -->|0.00| o2
    h3 -->|1.00| h6
    h6 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    end
    i1 -->|-0.23| h3
    h4 -->|0.40| o2
    h5 -->|0.00| o2
    h3 -->|-0.01| h6
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    end
    i1 -->|-0.23| h3
    h4 -->|0.40| o2
    h5 -->|0.00| o2
    h3 -->|-0.01| h6
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    end
    i1 -->|-0.23| h3
    h4 -->|0.16| o2
    h3 -->|-0.01| h6
    h5 -->|1.00| h7
    h7 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    end
    i1 -->|-0.23| h3
    h4 -->|0.16| o2
    h3 -->|-0.01| h6
    h5 -->|1.00| h7
    h7 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    end
    i1 -->|-0.23| h3
    h4 -->|0.16| o2
    h3 -->|-0.01| h6
    h5 -->|1.00| h7
    h7 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    end
    i1 -->|-0.23| h3
    h4 -->|0.16| o2
    h3 -->|-0.01| h6
    h5 -->|1.00| h7
    h7 -->|0.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    end
    i1 -->|-0.23| h3
    h4 -->|0.16| o2
    h3 -->|-0.01| h6
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    end
    i1 -->|-0.23| h3
    h4 -->|0.16| o2
    h3 -->|-0.01| h6
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    end
    i1 -->|-0.23| h3
    h4 -->|0.16| o2
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|1.00| h8
    h8 -->|0.80| h6
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    end
    i1 -->|-0.23| h3
    h4 -->|0.16| o2
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|1.00| h8
    h8 -->|0.80| h6
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    end
    i1 -->|0.88| o2
    i1 -->|-0.23| h3
    h4 -->|-0.34| o2
    h5 -->|0.11| h7
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|1.00| h8
    h8 -->|0.80| h6
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    end
    i1 -->|0.88| o2
    i1 -->|-0.23| h3
    h4 -->|-0.34| o2
    h5 -->|0.11| h7
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|1.00| h8
    h8 -->|0.80| h6
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    h9(["Hidden 9"]):::hidden_node
    end
    i1 -->|0.88| o2
    i1 -->|-0.23| h3
    h4 -->|-0.34| o2
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|1.00| h8
    h8 -->|0.80| h6
    h5 -->|1.00| h9
    h9 -->|0.11| h7
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    h9(["Hidden 9"]):::hidden_node
    end
    i1 -->|0.88| o2
    i1 -->|-0.23| h3
    h4 -->|-0.34| o2
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|1.00| h8
    h8 -->|0.80| h6
    h5 -->|1.00| h9
    h9 -->|0.11| h7
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    h9(["Hidden 9"]):::hidden_node
    end
    i0 -->|-0.46| h4
    h4 -->|-0.34| o2
    h5 -->|0.11| h7
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|1.00| h8
    h8 -->|0.80| h6
    h5 -->|1.00| h9
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    h9(["Hidden 9"]):::hidden_node
    end
    i0 -->|-0.46| h4
    h4 -->|-0.34| o2
    h5 -->|0.11| h7
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|1.00| h8
    h8 -->|0.80| h6
    h5 -->|1.00| h9
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    h9(["Hidden 9"]):::hidden_node
    end
    i0 -->|0.38| o2
    i0 -->|-0.46| h4
    h4 -->|-0.34| o2
    h4 -->|-0.54| h5
    h5 -->|0.11| h7
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|1.00| h8
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    h9(["Hidden 9"]):::hidden_node
    end
    i0 -->|0.38| o2
    i0 -->|-0.46| h4
    h4 -->|-0.34| o2
    h4 -->|-0.54| h5
    h5 -->|0.11| h7
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|1.00| h8
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    h9(["Hidden 9"]):::hidden_node
    end
    i0 -->|0.38| o2
    i1 -->|0.88| o2
    i1 -->|0.76| h3
    i0 -->|-0.46| h4
    h4 -->|-0.34| o2
    h4 -->|-0.54| h5
    h6 -->|0.00| o2
    h5 -->|0.11| h7
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|-0.43| h8
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    h9(["Hidden 9"]):::hidden_node
    end
    i0 -->|0.38| o2
    i1 -->|0.88| o2
    i1 -->|0.76| h3
    i0 -->|-0.46| h4
    h4 -->|-0.34| o2
    h4 -->|-0.54| h5
    h6 -->|0.00| o2
    h5 -->|0.11| h7
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|-0.43| h8
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    h9(["Hidden 9"]):::hidden_node
    end
    i0 -->|0.38| o2
    i1 -->|0.88| o2
    i1 -->|0.76| h3
    i0 -->|-0.46| h4
    h4 -->|-0.34| o2
    h6 -->|0.00| o2
    h5 -->|0.11| h7
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|-0.43| h8
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden_node
    h4(["Hidden 4"]):::hidden_node
    h5(["Hidden 5"]):::hidden_node
    h6(["Hidden 6"]):::hidden_node
    h7(["Hidden 7"]):::hidden_node
    h8(["Hidden 8"]):::hidden_node
    h9(["Hidden 9"]):::hidden_node
    end
    i0 -->|0.38| o2
    i1 -->|0.88| o2
    i1 -->|0.76| h3
    i0 -->|-0.46| h4
    h4 -->|-0.34| o2
    h6 -->|0.00| o2
    h5 -->|0.11| h7
    h7 -->|0.00| o2
    i0 -->|0.37| h3
    h3 -->|-0.43| h8
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden_node fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

In [102]:
jnp.max(jnp.array([[1, 2], [3, 4]]))

DEBUG:2024-12-31 12:41:23,287:jax._src.dispatch:182: Finished tracing + transforming convert_element_type for pjit in 0.001977205 sec
jax._src.dispatch: 2024-12-31 12:41:23,287 [DEBUG] Finished tracing + transforming convert_element_type for pjit in 0.001977205 sec
DEBUG:2024-12-31 12:41:23,291:jax._src.interpreters.pxla:1906: Compiling convert_element_type with global shapes and types [ShapedArray(int32[2,2])]. Argument mapping: (UnspecifiedValue,).
jax._src.interpreters.pxla: 2024-12-31 12:41:23,291 [DEBUG] Compiling convert_element_type with global shapes and types [ShapedArray(int32[2,2])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2024-12-31 12:41:23,304:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.011681795 sec
jax._src.dispatch: 2024-12-31 12:41:23,304 [DEBUG] Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.011681795 sec
DEBUG:2024-12-31 12:41:23,384:jax._src.compiler:167: get_compile_options: num_repli

Array(4, dtype=int32)

In [150]:
%run neat.py

In [27]:
%run neat.py

test3()

[dtype('int32'), dtype('int32'), dtype('int32'), dtype('int32'), dtype('int32'), dtype('float32'), dtype('bool')]
[ 50 100 300 500 700 900]


SlimeVolley: 2025-01-02 14:13:18,149 [INFO] use_for_loop=False
SlimeVolley: 2025-01-02 14:13:18,177 [INFO] Start to train for 30 iterations.
SlimeVolley: 2025-01-02 14:13:27,659 [INFO] Iter=1, size=300, max=-22.0000, avg=-35.1967, min=-42.0000, std=2.3829
SlimeVolley: 2025-01-02 14:13:30,770 [INFO] Iter=2, size=300, max=-22.0000, avg=-33.8467, min=-41.0000, std=3.3161
SlimeVolley: 2025-01-02 14:13:33,865 [INFO] Iter=3, size=300, max=-23.0000, avg=-33.7733, min=-41.0000, std=3.3875
SlimeVolley: 2025-01-02 14:13:37,043 [INFO] Iter=4, size=300, max=-21.0000, avg=-31.8433, min=-38.0000, std=3.1218
SlimeVolley: 2025-01-02 14:13:40,232 [INFO] Iter=5, size=300, max=-20.0000, avg=-32.5533, min=-41.0000, std=4.2865
SlimeVolley: 2025-01-02 14:13:43,578 [INFO] Iter=6, size=300, max=-16.0000, avg=-33.2133, min=-42.0000, std=4.0193
SlimeVolley: 2025-01-02 14:13:46,880 [INFO] Iter=7, size=300, max=-25.0000, avg=-33.9033, min=-43.0000, std=3.3328
SlimeVolley: 2025-01-02 14:13:50,037 [INFO] Iter=8, si

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    i2(["Input 2"]):::input
    i3(["Input 3"]):::input
    i4(["Input 4"]):::input
    i5(["Input 5"]):::input
    i6(["Input 6"]):::input
    i7(["Input 7"]):::input
    i8(["Input 8"]):::input
    i9(["Input 9"]):::input
    i10(["Input 10"]):::input
    i11(["Input 11"]):::input
    end
    subgraph Output
    o12(["Output 12"]):::output
    o13(["Output 13"]):::output
    o14(["Output 14"]):::output
    end
    subgraph Hidden
    h15(["Hidden 15"]):::hidden_node
    h16(["Hidden 16"]):::hidden_node
    end
    i0 -.->|0.00| o12
    i1 -.->|0.68| o12
    i2 -->|-0.96| o12
    i3 -->|0.71| o12
    i4 -->|-0.91| o12
    i5 -->|0.77| o12
    i6 -->|0.24| o12
    i7 -.->|0.36| o12
    i8 -->|0.11| o12
    i9 -->|-0.79| o12
    i10 -->|0.23| o12
    i11 -.->|-0.88| o12
    i0 -->|0.43| o13
    i1 -.->|0.60| o13
    i2 -->|-0.46| o13
    i3 -.->|0.62| o13
    i4 -.->|0.12| o13
    i5 -.->|0.00| o13
    

SlimeVolley: 2025-01-02 14:15:23,895 [INFO] GIF saved to ./log/slimevolley_20250102-141317/slimevolley.gif.


In [175]:
jnp.nanmean(jnp.array([1, 2, 3, jnp.nan]))

DEBUG:2025-01-02 13:58:36,780:jax._src.dispatch:182: Finished tracing + transforming _reduce_sum for pjit in 0.021831989 sec
jax._src.dispatch: 2025-01-02 13:58:36,780 [DEBUG] Finished tracing + transforming _reduce_sum for pjit in 0.021831989 sec
DEBUG:2025-01-02 13:58:36,889:jax._src.dispatch:182: Finished tracing + transforming _broadcast_arrays for pjit in 0.049833298 sec
jax._src.dispatch: 2025-01-02 13:58:36,889 [DEBUG] Finished tracing + transforming _broadcast_arrays for pjit in 0.049833298 sec
DEBUG:2025-01-02 13:58:36,904:jax._src.dispatch:182: Finished tracing + transforming _where for pjit in 0.092578888 sec
jax._src.dispatch: 2025-01-02 13:58:36,904 [DEBUG] Finished tracing + transforming _where for pjit in 0.092578888 sec
DEBUG:2025-01-02 13:58:36,910:jax._src.dispatch:182: Finished tracing + transforming nansum for pjit in 0.103650093 sec
jax._src.dispatch: 2025-01-02 13:58:36,910 [DEBUG] Finished tracing + transforming nansum for pjit in 0.103650093 sec
DEBUG:2025-01-02

Array(2., dtype=float32)

In [30]:
%run neat.py
test3()

SlimeVolley: 2025-01-02 14:17:56,690 [INFO] use_for_loop=False
SlimeVolley: 2025-01-02 14:17:56,745 [INFO] Start to train for 30 iterations.


[dtype('int32'), dtype('int32'), dtype('int32'), dtype('int32'), dtype('int32'), dtype('float32'), dtype('bool')]
[ 50 100 300 500 700 900]
size: [300   0   0   0   0], pop_size: [0 0 0 0 0]


SlimeVolley: 2025-01-02 14:18:06,260 [INFO] Iter=1, size=300, max=-22.0000, avg=-35.1967, min=-42.0000, std=2.3829


size: [219  18  23  29  11], pop_size: [63 59 56 59 66]


SlimeVolley: 2025-01-02 14:18:09,518 [INFO] Iter=2, size=300, max=-22.0000, avg=-33.8467, min=-41.0000, std=3.3161


size: [287   3   8   1   1], pop_size: [52 81 52 66 51]


SlimeVolley: 2025-01-02 14:18:12,686 [INFO] Iter=3, size=300, max=-23.0000, avg=-33.7733, min=-41.0000, std=3.3875


size: [104  55  49  66  26], pop_size: [69 55 70 48 61]


SlimeVolley: 2025-01-02 14:18:15,892 [INFO] Iter=4, size=300, max=-21.0000, avg=-31.8433, min=-38.0000, std=3.1218


size: [175  32  48   7  38], pop_size: [62 70 57 53 61]


SlimeVolley: 2025-01-02 14:18:19,078 [INFO] Iter=5, size=300, max=-20.0000, avg=-32.5533, min=-41.0000, std=4.2865


size: [151  29  67  15  38], pop_size: [59 55 63 64 62]


SlimeVolley: 2025-01-02 14:18:22,283 [INFO] Iter=6, size=300, max=-16.0000, avg=-33.2133, min=-42.0000, std=4.0193


size: [131  45  49  63  12], pop_size: [61 67 60 60 54]


SlimeVolley: 2025-01-02 14:18:25,769 [INFO] Iter=7, size=300, max=-25.0000, avg=-33.9033, min=-43.0000, std=3.3328


size: [110  35  55  16  84], pop_size: [61 68 58 53 62]


SlimeVolley: 2025-01-02 14:18:29,062 [INFO] Iter=8, size=300, max=-20.0000, avg=-31.4667, min=-40.0000, std=3.5640


size: [86 80 59 27 48], pop_size: [66 62 59 53 62]


SlimeVolley: 2025-01-02 14:18:32,382 [INFO] Iter=9, size=300, max=-21.0000, avg=-32.5233, min=-39.0000, std=3.0913


size: [100  51  56  75  18], pop_size: [60 61 60 58 64]


SlimeVolley: 2025-01-02 14:18:35,637 [INFO] Iter=10, size=300, max=-17.0000, avg=-32.1800, min=-40.0000, std=3.5101


size: [77 25 57 91 50], pop_size: [64 62 58 62 56]


SlimeVolley: 2025-01-02 14:18:36,702 [INFO] [TEST] Iter=10, #tests=1, max=-5.0000, avg=-5.0000, min=-5.0000, std=0.0000
SlimeVolley: 2025-01-02 14:18:39,816 [INFO] Iter=11, size=300, max=-21.0000, avg=-32.1633, min=-41.0000, std=4.1821


size: [101  79   8  27  85], pop_size: [55 64 48 65 71]


SlimeVolley: 2025-01-02 14:18:43,021 [INFO] Iter=12, size=300, max=-19.0000, avg=-31.5867, min=-41.0000, std=3.8543


size: [71 57 59 47 66], pop_size: [61 58 57 71 55]


SlimeVolley: 2025-01-02 14:18:46,411 [INFO] Iter=13, size=300, max=-21.0000, avg=-32.4100, min=-39.0000, std=3.5490


size: [53 64 89 54 40], pop_size: [49 58 62 66 67]


SlimeVolley: 2025-01-02 14:18:49,617 [INFO] Iter=14, size=300, max=-22.0000, avg=-33.5300, min=-43.0000, std=3.1127


size: [87 46 56 76 35], pop_size: [62 55 60 59 66]


SlimeVolley: 2025-01-02 14:18:52,760 [INFO] Iter=15, size=300, max=-24.0000, avg=-33.5533, min=-41.0000, std=3.4400


size: [100  61  38  17  84], pop_size: [65 60 63 53 61]


SlimeVolley: 2025-01-02 14:18:56,123 [INFO] Iter=16, size=300, max=-22.0000, avg=-31.9733, min=-40.0000, std=3.3195


size: [ 36  59 116  56  33], pop_size: [61 60 59 56 67]


SlimeVolley: 2025-01-02 14:18:59,375 [INFO] Iter=17, size=300, max=-21.0000, avg=-33.0200, min=-44.0000, std=4.3688


size: [42 67 72 69 50], pop_size: [57 54 65 65 61]


SlimeVolley: 2025-01-02 14:19:02,595 [INFO] Iter=18, size=300, max=-24.0000, avg=-35.6333, min=-43.0000, std=3.3841


size: [54 53 59 72 62], pop_size: [66 56 55 64 61]


SlimeVolley: 2025-01-02 14:19:05,834 [INFO] Iter=19, size=300, max=-20.0000, avg=-31.4867, min=-40.0000, std=3.2634


size: [82 56 63 69 30], pop_size: [61 60 58 65 59]


SlimeVolley: 2025-01-02 14:19:08,980 [INFO] Iter=20, size=300, max=-22.0000, avg=-32.0800, min=-41.0000, std=3.7754
SlimeVolley: 2025-01-02 14:19:09,114 [INFO] [TEST] Iter=20, #tests=1, max=-5.0000, avg=-5.0000, min=-5.0000, std=0.0000


size: [55 66 56 63 60], pop_size: [64 64 62 52 60]


SlimeVolley: 2025-01-02 14:19:12,305 [INFO] Iter=21, size=300, max=-18.0000, avg=-31.5067, min=-39.0000, std=3.4501


size: [39 65 49 97 50], pop_size: [59 64 60 60 59]


SlimeVolley: 2025-01-02 14:19:15,554 [INFO] Iter=22, size=300, max=-20.0000, avg=-32.2067, min=-40.0000, std=3.3182


size: [74 59 64 66 37], pop_size: [57 62 56 63 64]


SlimeVolley: 2025-01-02 14:19:18,740 [INFO] Iter=23, size=300, max=-26.0000, avg=-34.0667, min=-44.0000, std=3.6980


size: [83 88 57  9 63], pop_size: [66 62 64 47 64]


SlimeVolley: 2025-01-02 14:19:21,929 [INFO] Iter=24, size=300, max=-20.0000, avg=-30.8700, min=-42.0000, std=3.3892


size: [68 75 68 43 46], pop_size: [60 60 61 58 63]


SlimeVolley: 2025-01-02 14:19:25,171 [INFO] Iter=25, size=300, max=-20.0000, avg=-31.7200, min=-45.0000, std=3.6499


size: [61 58 60 52 69], pop_size: [58 60 61 58 65]


SlimeVolley: 2025-01-02 14:19:28,459 [INFO] Iter=26, size=300, max=-22.0000, avg=-32.2767, min=-42.0000, std=3.7612


size: [62 57 62 56 63], pop_size: [50 65 68 56 63]


SlimeVolley: 2025-01-02 14:19:31,767 [INFO] Iter=27, size=300, max=-20.0000, avg=-30.7800, min=-39.0000, std=3.7280


size: [57 61 63 69 50], pop_size: [58 54 68 64 58]


SlimeVolley: 2025-01-02 14:19:34,923 [INFO] Iter=28, size=300, max=-23.0000, avg=-30.5933, min=-40.0000, std=3.6880


size: [56 68 66 52 58], pop_size: [56 66 59 64 59]


SlimeVolley: 2025-01-02 14:19:38,131 [INFO] Iter=29, size=300, max=-21.0000, avg=-30.0733, min=-39.0000, std=3.3894
SlimeVolley: 2025-01-02 14:19:38,266 [INFO] [TEST] Iter=30, #tests=1, max=-5.0000, avg=-5.0000, min=-5.0000, std=0.0000
SlimeVolley: 2025-01-02 14:19:38,269 [INFO] Training done, best_score=-5.0000


size: [57 66 63 59 55], pop_size: [64 67 59 55 57]
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    i2(["Input 2"]):::input
    i3(["Input 3"]):::input
    i4(["Input 4"]):::input
    i5(["Input 5"]):::input
    i6(["Input 6"]):::input
    i7(["Input 7"]):::input
    i8(["Input 8"]):::input
    i9(["Input 9"]):::input
    i10(["Input 10"]):::input
    i11(["Input 11"]):::input
    end
    subgraph Output
    o12(["Output 12"]):::output
    o13(["Output 13"]):::output
    o14(["Output 14"]):::output
    end
    subgraph Hidden
    h15(["Hidden 15"]):::hidden_node
    h16(["Hidden 16"]):::hidden_node
    end
    i0 -.->|0.00| o12
    i1 -.->|0.68| o12
    i2 -->|-0.96| o12
    i3 -->|0.71| o12
    i4 -->|-0.91| o12
    i5 -->|0.77| o12
    i6 -->|0.24| o12
    i7 -.->|0.36| o12
    i8 -->|0.11| o12
    i9 -->|-0.79| o12
    i10 -->|0.23| o12
    i11 -.->|-0.88| o12
    i0 -->|0.43| o13
    i1 -.->|0.60| o13
    i2 -->|-0.46| o13
    i3 -.->|0.62| o

SlimeVolley: 2025-01-02 14:20:04,142 [INFO] GIF saved to ./log/slimevolley_20250102-141756/slimevolley.gif.
