In [16]:
import argparse
from dataclasses import dataclass
import os
import shutil
import jax
import jax.numpy as jnp

from evojax.policy.base import PolicyNetwork, PolicyState
from evojax.task.base import TaskState
from evojax.task.slimevolley import SlimeVolley
from evojax.policy.mlp import MLPPolicy
from evojax.algo import CMA
from evojax import Trainer
from evojax import util

from flax import linen as nn

In [17]:
class NEATModel(nn.Module):
    input_dim: int
    output_dim: int
    max_nodes: int
    max_edges: int

    @nn.compact
    def __call__(self, x):
        nodes = self.param(
            "nodes", nn.initializers.zeros, (self.max_nodes, 3)
        )  # (id, activattion)
        edges = self.param(
            "edges", nn.initializers.zeros, (self.max_edges, 3)
        )  # (from, to, weight, enabled)
        values = jnp.zeros((self.max_nodes,))
        return jax.slice(values, -self.output_dim)


class NEATPolicy(PolicyNetwork):
    def __init__(self, input_dim: int, output_dim: int):
        self._input_dim = input_dim
        self._output_dim = output_dim
        # TODO: Implement NEAT policy.
        self.num_params = 1
        model = NEATModel(input_dim, output_dim, 100, 100)
        params = model.init(jax.random.PRNGKey(0), jnp.ones([1, input_dim]))
        print(params)
        self.num_params, format_params_fn = util.get_params_format_fn(params)
        self._format_params_fn = jax.vmap(format_params_fn)
        self._forward_fn = jax.vmap(model.apply)

    def get_actions(
        self, t_states, params, p_states: PolicyState
    ) -> tuple[jax.Array, PolicyState]:
        return self._forward_fn(params, t_states.obs), p_states

In [18]:
@jax.tree_util.register_dataclass
@dataclass
class Genome:
    """Stores node and connection info in JAX arrays."""

    node_ids: jax.Array  # shape [num_nodes]
    # 0: relu
    node_activation: jax.Array  # shape [num_nodes]

    # Connections.
    conn_id: jax.Array  # shape [num_connections]
    conn_in: jax.Array  # shape [num_connections]
    conn_out: jax.Array  # shape [num_connections]
    conn_weights: jax.Array  # shape [num_connections]
    conn_enabled: jax.Array  # shape [num_connections]


@dataclass(frozen=True)
class Config:
    """Stores static info for the NEAT algorithm."""

    input_dim: int
    output_dim: int
    max_nodes: int
    max_edges: int
    max_depth: int


config = Config(input_dim=2, output_dim=1, max_nodes=10, max_edges=20, max_depth=10)

In [19]:
node_ids = jnp.array([0, 1, 2, 3, 4])
node_activation = jnp.array([0, 0, 0, 0, 2])

# 0 1
# | |
# 2 3
#  \|
#   4
conn_id = jnp.array([0, 1, 2, 3])
conn_in = jnp.array([0, 1, 2, 3])
conn_out = jnp.array([2, 3, 4, 4])
conn_weights = jnp.array([1.0, 1.0, 1.0, 1.0])
conn_enabled = jnp.array([True, True, True, True])


def create_genome(
    config: Config,
    node_ids,
    node_activation,
    conn_id,
    conn_in,
    conn_out,
    conn_weights,
    conn_enabled,
) -> Genome:
    node_pad_config = [(0, config.max_nodes - node_ids.shape[0], 0)]
    node_ids = jax.lax.pad(node_ids, -1, node_pad_config)
    node_activation = jax.lax.pad(node_activation, -1, node_pad_config)

    conn_pad_config = [(0, config.max_edges - conn_id.shape[0], 0)]
    conn_id = jax.lax.pad(conn_id, -1, conn_pad_config)
    conn_in = jax.lax.pad(conn_in, -1, conn_pad_config)
    conn_out = jax.lax.pad(conn_out, -1, conn_pad_config)
    conn_weights = jax.lax.pad(conn_weights, jnp.nan, conn_pad_config)
    conn_enabled = jax.lax.pad(conn_enabled, False, conn_pad_config)
    return Genome(
        node_ids,
        node_activation,
        conn_id,
        conn_in,
        conn_out,
        conn_weights,
        conn_enabled,
    )


genome = create_genome(
    config,
    node_ids,
    node_activation,
    conn_id,
    conn_in,
    conn_out,
    conn_weights,
    conn_enabled,
)
genome = jax.tree.map(lambda x: jnp.array([x]), genome)
genome

Genome(node_ids=Array([[ 0,  1,  2,  3,  4, -1, -1, -1, -1, -1]], dtype=int32), node_activation=Array([[ 0,  0,  0,  0,  2, -1, -1, -1, -1, -1]], dtype=int32), conn_id=Array([[ 0,  1,  2,  3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1]], dtype=int32), conn_in=Array([[ 0,  1,  2,  3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1]], dtype=int32), conn_out=Array([[ 2,  3,  4,  4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1]], dtype=int32), conn_weights=Array([[ 1.,  1.,  1.,  1., nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan]], dtype=float32), conn_enabled=Array([[ True,  True,  True,  True, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False]], dtype=bool))

In [20]:
def get_adjacency_matrix_fn(config: Config, genome: Genome):
    matrix = jnp.zeros((config.max_nodes, config.max_nodes))
    matrix = matrix + jnp.pad(
        jnp.identity(config.input_dim),
        (
            (0, config.max_nodes - config.input_dim),
            (0, config.max_nodes - config.input_dim),
        ),
    )
    matrix = matrix.at[genome.conn_out, genome.conn_in].set(
        jnp.where(genome.conn_enabled, genome.conn_weights, 0.0)
    )
    return matrix


get_adjacency_matrix = jax.jit(
    jax.vmap(get_adjacency_matrix_fn, in_axes=(None, 0)), static_argnums=(0)
)

In [21]:
display(get_adjacency_matrix)

<PjitFunction of <function get_adjacency_matrix_fn at 0x305edd3a0>>

In [22]:
adj = get_adjacency_matrix(config, genome)
display(adj)


def matmul_fn(adj, x):
    return jnp.matmul(adj, x)


matmul = jax.vmap(matmul_fn)
x = jnp.stack([jnp.pad(jnp.array([1, 1, 0, 0, 0]), (0, 5))])
display(matmul(adj, x))
display(matmul(adj, matmul(adj, x)))
display(matmul(adj, matmul(adj, matmul(adj, x))))

DEBUG:2024-12-31 11:28:56,326:jax._src.dispatch:182: Finished tracing + transforming get_adjacency_matrix_fn for pjit in 0.002624989 sec
jax._src.dispatch: 2024-12-31 11:28:56,326 [DEBUG] Finished tracing + transforming get_adjacency_matrix_fn for pjit in 0.002624989 sec
DEBUG:2024-12-31 11:28:56,327:jax._src.interpreters.pxla:1906: Compiling get_adjacency_matrix_fn with global shapes and types [ShapedArray(int32[1,20]), ShapedArray(int32[1,20]), ShapedArray(float32[1,20]), ShapedArray(bool[1,20])]. Argument mapping: (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue).
jax._src.interpreters.pxla: 2024-12-31 11:28:56,327 [DEBUG] Compiling get_adjacency_matrix_fn with global shapes and types [ShapedArray(int32[1,20]), ShapedArray(int32[1,20]), ShapedArray(float32[1,20]), ShapedArray(bool[1,20])]. Argument mapping: (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue).
DEBUG:2024-12-31 11:28:56,332:jax._src.dispatch:182: Finished jaxpr to MLIR modul

Array([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]], dtype=float32)

Array([[1., 1., 1., 1., 0., 0., 0., 0., 0., 0.]], dtype=float32)

Array([[1., 1., 1., 1., 2., 0., 0., 0., 0., 0.]], dtype=float32)

Array([[1., 1., 1., 1., 2., 0., 0., 0., 0., 0.]], dtype=float32)

In [23]:
# TODO: Check NEAT-python for activation functions.
ACTIVATION_FUNCTIONS = [lambda x: x, jax.nn.relu, jax.nn.sigmoid, jax.nn.tanh]


def forward_fn(config: Config, genome: Genome, x: jax.Array):
    print(genome.node_ids.shape)
    adj = get_adjacency_matrix_fn(config, genome)
    for _ in range(config.max_depth):
        x = matmul_fn(adj, x)

        print(adj.shape)
        print(x.shape)
        print(genome.node_activation.shape)

        # Apply the activation function.
        def activate_fn(activation, x):
            return jax.lax.switch(activation, ACTIVATION_FUNCTIONS, x)

        activate = jax.vmap(activate_fn)
        x = activate(genome.node_activation, x)
    return x


forward = jax.jit(jax.vmap(forward_fn, in_axes=(None, 0, 0)), static_argnums=(0))
# forward = jax.vmap(forward_fn, in_axes=(None, 0, 0))

forward(config, genome, x)

DEBUG:2024-12-31 11:28:56,615:jax._src.dispatch:182: Finished tracing + transforming forward_fn for pjit in 0.034103870 sec
jax._src.dispatch: 2024-12-31 11:28:56,615 [DEBUG] Finished tracing + transforming forward_fn for pjit in 0.034103870 sec
DEBUG:2024-12-31 11:28:56,618:jax._src.interpreters.pxla:1906: Compiling forward_fn with global shapes and types [ShapedArray(int32[1,10]), ShapedArray(int32[1,20]), ShapedArray(int32[1,20]), ShapedArray(float32[1,20]), ShapedArray(bool[1,20]), ShapedArray(int32[1,10])]. Argument mapping: (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue).
jax._src.interpreters.pxla: 2024-12-31 11:28:56,618 [DEBUG] Compiling forward_fn with global shapes and types [ShapedArray(int32[1,10]), ShapedArray(int32[1,20]), ShapedArray(int32[1,20]), ShapedArray(float32[1,20]), ShapedArray(bool[1,20]), ShapedArray(int32[1,10])]. Argument mapping: (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedV

(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)
(10, 10)
(10,)
(10,)


Array([[1.      , 1.      , 1.      , 1.      , 0.880797, 0.      ,
        0.      , 0.      , 0.      , 0.      ]], dtype=float32)

In [24]:
genome2 = create_genome(
    config,
    node_ids,
    jnp.array([1, 1, 1, 1, 1]),
    conn_id,
    conn_in,
    conn_out,
    conn_weights,
    conn_enabled,
)
genome2 = jax.tree.map(lambda x: jnp.array([x]), genome2)

forward(config, genome2, x)

Array([[1., 1., 1., 1., 2., 0., 0., 0., 0., 0.]], dtype=float32)

In [25]:
import neat

In [26]:
from evojax.task.base import TaskState

In [27]:
def pad(x):
    return jnp.pad(x, (0, 10))


pad = jax.vmap(pad)
pad(jnp.arange(10)[None, :])

Array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],      dtype=int32)

In [28]:
idx = jnp.array([0, 1, 2, 3, 4])
enabled = jnp.array([True, False, True, False, True])


@jax.jit
def chose_enabled(enabled):
    idx = jnp.arange(enabled.shape[0])
    filtered = jnp.sort(jnp.where(enabled, idx, -1), descending=True)
    chosen = jax.random.randint(jax.random.PRNGKey(1), 10, 0, enabled.sum())
    return filtered[chosen]


chose_enabled(enabled)

DEBUG:2024-12-31 11:28:56,768:jax._src.dispatch:182: Finished tracing + transforming chose_enabled for pjit in 0.001190186 sec
jax._src.dispatch: 2024-12-31 11:28:56,768 [DEBUG] Finished tracing + transforming chose_enabled for pjit in 0.001190186 sec
DEBUG:2024-12-31 11:28:56,769:jax._src.interpreters.pxla:1906: Compiling chose_enabled with global shapes and types [ShapedArray(bool[5])]. Argument mapping: (UnspecifiedValue,).
jax._src.interpreters.pxla: 2024-12-31 11:28:56,769 [DEBUG] Compiling chose_enabled with global shapes and types [ShapedArray(bool[5])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2024-12-31 11:28:56,790:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(chose_enabled) in 0.020753145 sec
jax._src.dispatch: 2024-12-31 11:28:56,790 [DEBUG] Finished jaxpr to MLIR module conversion jit(chose_enabled) in 0.020753145 sec
DEBUG:2024-12-31 11:28:56,792:jax._src.compiler:167: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDe

Array([4, 4, 2, 4, 4, 4, 0, 2, 0, 0], dtype=int32)

In [30]:
def genome_to_mermaid(genome: Genome, config: Config, show_disabled: bool) -> str:
    """Convert genome to mermaid graph definition."""
    mermaid = [
        "graph TD;",
    ]

    # Add nodes with styling
    def node_type(node_id):
        if node_id < config.input_dim:
            return "i"
        elif node_id < config.input_dim + config.output_dim:
            return "o"
        else:
            return "h"

    for node_id in genome.node_ids:
        if node_id >= 0:
            type = node_type(node_id)
            if node_id == 0:
                mermaid.append(f'    subgraph Input')
            elif node_id == config.input_dim:
                mermaid.append(f'    end')
                mermaid.append(f'    subgraph Output')
            elif node_id == config.input_dim + config.output_dim:
                mermaid.append(f'    end')
                mermaid.append(f'    subgraph Hidden')
            if type == "i":
                mermaid.append(f'    {type}{node_id}(["Input {node_id}"]):::input')
            elif type == "o":
                mermaid.append(f'    {type}{node_id}(["Output {node_id}"]):::output')
            else:
                mermaid.append(f'    {type}{node_id}(["Hidden {node_id}"]):::hidden')
    mermaid.append(f'    end')

    # Add connections
    for id, src, dst, w, enabled in zip(
        genome.conn_ids,
        genome.conn_in,
        genome.conn_out,
        genome.conn_weights,
        genome.conn_enabled,
    ):
        if id >= 0 and (show_disabled or enabled):
            src_prefix = node_type(src)
            dst_prefix = node_type(dst)
            weight = f"{w:.2f}"
            arrow = "-->" if enabled else "-.->"
            mermaid.append(f"    {src_prefix}{src} {arrow}|{weight}| {dst_prefix}{dst}")

    # Add styles
    mermaid.extend(
        [
            "    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;",
            "    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;",
            "    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;",
            "    linkStyle default stroke:#E5E7EB,stroke-width:2px;",
        ]
    )

    return "\n".join(mermaid)


def display_genome(genome: Genome, config: Config, show_disabled: bool = False):
    from IPython.display import display, Markdown
    mermaid = genome_to_mermaid(genome, config, show_disabled)
    print(mermaid)
    display(Markdown(f"```mermaid\n{mermaid}\n```"))

display_genome(genome, config)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [99]:
genome = create_empty_genome_fn(config)
x = mutate_add_node(genome, jax.random.PRNGKey(0), 30)[0]
print(x)
display_genome(x, config, show_disabled=True)

TypeError: mutate_add_node() missing 1 required positional argument: 'next_edge_id'

In [32]:
%run neat.py

In [33]:
jnp.argmin(jnp.array([1, 2, 3, 3, 1, 1]))


DEBUG:2024-12-31 11:29:43,733:jax._src.dispatch:182: Finished tracing + transforming convert_element_type for pjit in 0.000128031 sec
jax._src.dispatch: 2024-12-31 11:29:43,733 [DEBUG] Finished tracing + transforming convert_element_type for pjit in 0.000128031 sec
DEBUG:2024-12-31 11:29:43,734:jax._src.interpreters.pxla:1906: Compiling convert_element_type with global shapes and types [ShapedArray(int32[6])]. Argument mapping: (UnspecifiedValue,).
jax._src.interpreters.pxla: 2024-12-31 11:29:43,734 [DEBUG] Compiling convert_element_type with global shapes and types [ShapedArray(int32[6])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2024-12-31 11:29:43,737:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.001785755 sec
jax._src.dispatch: 2024-12-31 11:29:43,737 [DEBUG] Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.001785755 sec
DEBUG:2024-12-31 11:29:43,738:jax._src.compiler:167: get_compile_options: num_replicas=

Array(0, dtype=int32)

In [34]:
max_conns = genome.conn_ids.shape[0]
jnp.min(jnp.where(genome.conn_ids < 0, jnp.arange(max_conns), max_conns))
genome.conn_ids

DEBUG:2024-12-31 11:29:47,306:jax._src.interpreters.pxla:1906: Compiling less with global shapes and types [ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (UnspecifiedValue, UnspecifiedValue).
jax._src.interpreters.pxla: 2024-12-31 11:29:47,306 [DEBUG] Compiling less with global shapes and types [ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (UnspecifiedValue, UnspecifiedValue).
DEBUG:2024-12-31 11:29:47,310:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(less) in 0.002184153 sec
jax._src.dispatch: 2024-12-31 11:29:47,310 [DEBUG] Finished jaxpr to MLIR module conversion jit(less) in 0.002184153 sec
DEBUG:2024-12-31 11:29:47,311:jax._src.compiler:167: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
jax._src.compiler: 2024-12-31 11:29:47,311 [DEBUG] get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:2024-12-31 11

Array([ 0,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1], dtype=int32)

In [35]:
key, key_src, key_dst, key_weight = jax.random.split(key, 4)

node_ids = genome.node_ids
node_activation = genome.node_activation
conn_ids = genome.conn_ids
conn_in = genome.conn_in
conn_out = genome.conn_out
conn_weights = genome.conn_weights
conn_enabled = genome.conn_enabled

src = pick_one(key_src, node_ids >= 0)
dst = pick_one(key_dst, node_ids >= 0)
(src, dest)
src == dst


NameError: name 'key' is not defined

In [36]:
%run neat.py

In [37]:
%run neat.py
genome = create_empty_genome_fn(config)
(genome, edge) = mutate_add_node(genome, jax.random.PRNGKey(0), 30)
(genome, edge) = mutate_add_connection(config, genome, jax.random.PRNGKey(5), edge)
(genome, edge) = mutate_add_node(genome, jax.random.PRNGKey(6), edge)
(genome, edge) = mutate_add_connection(config, genome, jax.random.PRNGKey(7), edge)
print(genome, edge)
display_genome(genome, config, show_disabled=True)

DEBUG:2024-12-31 11:29:52,589:jax._src.dispatch:182: Finished tracing + transforming pick_one for pjit in 0.001029253 sec
jax._src.dispatch: 2024-12-31 11:29:52,589 [DEBUG] Finished tracing + transforming pick_one for pjit in 0.001029253 sec
DEBUG:2024-12-31 11:29:52,590:jax._src.interpreters.pxla:1906: Compiling pick_one with global shapes and types [ShapedArray(uint32[2]), ShapedArray(bool[20])]. Argument mapping: (UnspecifiedValue, UnspecifiedValue).
jax._src.interpreters.pxla: 2024-12-31 11:29:52,590 [DEBUG] Compiling pick_one with global shapes and types [ShapedArray(uint32[2]), ShapedArray(bool[20])]. Argument mapping: (UnspecifiedValue, UnspecifiedValue).
DEBUG:2024-12-31 11:29:52,614:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(pick_one) in 0.023218870 sec
jax._src.dispatch: 2024-12-31 11:29:52,614 [DEBUG] Finished jaxpr to MLIR module conversion jit(pick_one) in 0.023218870 sec
DEBUG:2024-12-31 11:29:52,615:jax._src.compiler:167: get_compile_options: num

src: 1, dst: 2
valid: False
conn_in: [ 0  1  1  3  1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1], 
conn_out: [ 2  2  3  2  2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
[1 1 0 0 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
depth: 2
conn_in: [ 0  1  1  3  1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1], 
conn_out: [ 2  2  3  2  2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
[1 1 0 0 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
[1 1 2 2 0 0 0 0 0 0]
valid: False
src: 1, dst: 2
valid: False
conn_in: [ 0  1  1  3  0  4  1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1], 
conn_out: [ 2  2  3  2  4  2  2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
[1 

```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    end
    i0 -.->|1.00| o2
    i1 -.->|1.00| o2
    i1 -->|1.00| h3
    h3 -->|1.00| o2
    i0 -->|1.00| h4
    h4 -->|1.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

In [38]:
jax.lax.switch(jnp.array(0), [lambda: jnp.array([1, 2, 3]), lambda: jnp.array([4, 5, 6])])

DEBUG:2024-12-31 11:29:53,201:jax._src.dispatch:182: Finished tracing + transforming convert_element_type for pjit in 0.000122309 sec
jax._src.dispatch: 2024-12-31 11:29:53,201 [DEBUG] Finished tracing + transforming convert_element_type for pjit in 0.000122309 sec
DEBUG:2024-12-31 11:29:53,203:jax._src.interpreters.pxla:1906: Compiling convert_element_type with global shapes and types [ShapedArray(int32[])]. Argument mapping: (UnspecifiedValue,).
jax._src.interpreters.pxla: 2024-12-31 11:29:53,203 [DEBUG] Compiling convert_element_type with global shapes and types [ShapedArray(int32[])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2024-12-31 11:29:53,208:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.003187895 sec
jax._src.dispatch: 2024-12-31 11:29:53,208 [DEBUG] Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.003187895 sec
DEBUG:2024-12-31 11:29:53,209:jax._src.compiler:167: get_compile_options: num_replicas=1 

Array([1, 2, 3], dtype=int32)

In [101]:
%run neat.py

config = Config(input_dim=2, output_dim=1, max_nodes=10, max_edges=20, max_depth=10)
neat_config = NEATConfig(config, pop_size=30, prob_add_connection=0.25, prob_add_node=0.5)

genome = create_empty_genome_fn(config)
edge = jnp.array(30)
key = jax.random.key(10)
for _ in range(20):
    key, key_mutate = jax.random.split(key, 2)
    genome, edge = mutate_genome(neat_config, genome, key_mutate, edge)
    display_genome(genome, config)

# TypeError: branch 0 and 2 outputs must have same type structure, got
# PyTreeDef((CustomNode(Genome[()], [*, *, *, *, *, *, *]), *)) and
# PyTreeDef((CustomNode(Genome[()], [*, *, *, *, *, *, *]), *)).


DEBUG:2024-12-31 12:35:06,297:jax._src.dispatch:182: Finished tracing + transforming pick_one for pjit in 0.027107000 sec
jax._src.dispatch: 2024-12-31 12:35:06,297 [DEBUG] Finished tracing + transforming pick_one for pjit in 0.027107000 sec
DEBUG:2024-12-31 12:35:06,326:jax._src.dispatch:182: Finished tracing + transforming pick_one for pjit in 0.001321077 sec
jax._src.dispatch: 2024-12-31 12:35:06,326 [DEBUG] Finished tracing + transforming pick_one for pjit in 0.001321077 sec
DEBUG:2024-12-31 12:35:06,369:jax._src.dispatch:182: Finished tracing + transforming mutate_add_connection for pjit in 0.046633244 sec
jax._src.dispatch: 2024-12-31 12:35:06,369 [DEBUG] Finished tracing + transforming mutate_add_connection for pjit in 0.046633244 sec
DEBUG:2024-12-31 12:35:06,373:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000210047 sec
jax._src.dispatch: 2024-12-31 12:35:06,373 [DEBUG] Finished tracing + transforming cond for pjit in 0.000210047 sec
DEBUG:2024-12-

p: 0.4880664348602295, index: 0


DEBUG:2024-12-31 12:35:06,471:jax._src.cache_key:152: get_cache_key hash of serialized computation: f2958dc6b6fa6c9a426f5c7df791522eae61ff4a0bace6827c12f61b7d0a7104
jax._src.cache_key: 2024-12-31 12:35:06,471 [DEBUG] get_cache_key hash of serialized computation: f2958dc6b6fa6c9a426f5c7df791522eae61ff4a0bace6827c12f61b7d0a7104
DEBUG:2024-12-31 12:35:06,472:jax._src.cache_key:158: get_cache_key hash after serializing computation: f2958dc6b6fa6c9a426f5c7df791522eae61ff4a0bace6827c12f61b7d0a7104
jax._src.cache_key: 2024-12-31 12:35:06,472 [DEBUG] get_cache_key hash after serializing computation: f2958dc6b6fa6c9a426f5c7df791522eae61ff4a0bace6827c12f61b7d0a7104
DEBUG:2024-12-31 12:35:06,472:jax._src.cache_key:152: get_cache_key hash of serialized jax_lib version: c8601d1831072872293c1f9c58282e40273dd0289eaea98e369c2037dc4231ae
jax._src.cache_key: 2024-12-31 12:35:06,472 [DEBUG] get_cache_key hash of serialized jax_lib version: c8601d1831072872293c1f9c58282e40273dd0289eaea98e369c2037dc4231ae


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|1.00| h3
    h3 -->|0.18| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|1.00| h3
    h3 -->|0.18| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:07,833:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000216007 sec
jax._src.dispatch: 2024-12-31 12:35:07,833 [DEBUG] Finished tracing + transforming cond for pjit in 0.000216007 sec
DEBUG:2024-12-31 12:35:07,838:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.8386558294296265, index: 2


DEBUG:2024-12-31 12:35:08,260:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:08,260 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:08,261:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.335956097 sec
jax._src.dispatch: 2024-12-31 12:35:08,261 [DEBUG] Finished XLA compilation of jit(cond) in 0.335956097 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|1.00| h3
    h3 -->|0.18| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|1.00| h3
    h3 -->|0.18| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:08,279:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000251293 sec
jax._src.dispatch: 2024-12-31 12:35:08,279 [DEBUG] Finished tracing + transforming cond for pjit in 0.000251293 sec
DEBUG:2024-12-31 12:35:08,283:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.7360057830810547, index: 1


DEBUG:2024-12-31 12:35:08,726:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:08,726 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:08,727:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.366386890 sec
jax._src.dispatch: 2024-12-31 12:35:08,727 [DEBUG] Finished XLA compilation of jit(cond) in 0.366386890 sec


src: 3, dst: 2
valid: False
depth: 3
valid: False
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    end
    i0 -->|1.00| o2
    h3 -->|0.18| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    end
    i0 -->|1.00| o2
    h3 -->|0.18| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:08,750:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000400782 sec
jax._src.dispatch: 2024-12-31 12:35:08,750 [DEBUG] Finished tracing + transforming cond for pjit in 0.000400782 sec
DEBUG:2024-12-31 12:35:08,753:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.378238320350647, index: 0


DEBUG:2024-12-31 12:35:09,180:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:09,180 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:09,180:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.338096142 sec
jax._src.dispatch: 2024-12-31 12:35:09,180 [DEBUG] Finished XLA compilation of jit(cond) in 0.338096142 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|1.00| o2
    h3 -->|1.00| h4
    h4 -->|0.18| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|1.00| o2
    h3 -->|1.00| h4
    h4 -->|0.18| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:09,200:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000292063 sec
jax._src.dispatch: 2024-12-31 12:35:09,200 [DEBUG] Finished tracing + transforming cond for pjit in 0.000292063 sec
DEBUG:2024-12-31 12:35:09,203:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.37176311016082764, index: 0


DEBUG:2024-12-31 12:35:09,613:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:09,613 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:09,614:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.340053797 sec
jax._src.dispatch: 2024-12-31 12:35:09,614 [DEBUG] Finished XLA compilation of jit(cond) in 0.340053797 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    end
    i0 -->|1.00| o2
    h3 -->|1.00| h4
    h4 -->|0.18| o2
    h5 -->|1.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    end
    i0 -->|1.00| o2
    h3 -->|1.00| h4
    h4 -->|0.18| o2
    h5 -->|1.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:09,630:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000358105 sec
jax._src.dispatch: 2024-12-31 12:35:09,630 [DEBUG] Finished tracing + transforming cond for pjit in 0.000358105 sec
DEBUG:2024-12-31 12:35:09,633:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.6920064687728882, index: 1


DEBUG:2024-12-31 12:35:10,036:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:10,036 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:10,037:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.331847191 sec
jax._src.dispatch: 2024-12-31 12:35:10,037 [DEBUG] Finished XLA compilation of jit(cond) in 0.331847191 sec


src: 0, dst: 4
valid: True
depth: 3
valid: True
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    end
    i0 -->|1.00| o2
    h3 -->|1.00| h4
    h4 -->|0.18| o2
    h5 -->|1.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    end
    i0 -->|1.00| o2
    h3 -->|1.00| h4
    h4 -->|0.18| o2
    h5 -->|1.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:10,056:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000226974 sec
jax._src.dispatch: 2024-12-31 12:35:10,056 [DEBUG] Finished tracing + transforming cond for pjit in 0.000226974 sec
DEBUG:2024-12-31 12:35:10,059:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.8567180633544922, index: 2


DEBUG:2024-12-31 12:35:10,502:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:10,502 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:10,503:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.336244822 sec
jax._src.dispatch: 2024-12-31 12:35:10,503 [DEBUG] Finished XLA compilation of jit(cond) in 0.336244822 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h3 -->|1.00| h4
    h4 -->|0.18| o2
    h5 -->|1.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h3 -->|1.00| h4
    h4 -->|0.18| o2
    h5 -->|1.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:10,527:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000317097 sec
jax._src.dispatch: 2024-12-31 12:35:10,527 [DEBUG] Finished tracing + transforming cond for pjit in 0.000317097 sec
DEBUG:2024-12-31 12:35:10,530:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.604138970375061, index: 1


DEBUG:2024-12-31 12:35:10,938:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:10,938 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:10,939:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.333029270 sec
jax._src.dispatch: 2024-12-31 12:35:10,939 [DEBUG] Finished XLA compilation of jit(cond) in 0.333029270 sec


src: 4, dst: 4
valid: False
depth: -1
valid: False
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h3 -->|0.25| h4
    h4 -->|0.18| o2
    h5 -->|1.00| o2
    i0 -->|0.64| h4
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h3 -->|0.25| h4
    h4 -->|0.18| o2
    h5 -->|1.00| o2
    i0 -->|0.64| h4
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:10,956:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000937223 sec
jax._src.dispatch: 2024-12-31 12:35:10,956 [DEBUG] Finished tracing + transforming cond for pjit in 0.000937223 sec
DEBUG:2024-12-31 12:35:10,959:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.14337170124053955, index: 0


DEBUG:2024-12-31 12:35:11,367:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:11,367 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:11,367:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.333714008 sec
jax._src.dispatch: 2024-12-31 12:35:11,367 [DEBUG] Finished XLA compilation of jit(cond) in 0.333714008 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|1.00| o2
    i0 -->|1.00| h6
    h6 -->|1.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|1.00| o2
    i0 -->|1.00| h6
    h6 -->|1.00| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:11,387:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000272751 sec
jax._src.dispatch: 2024-12-31 12:35:11,387 [DEBUG] Finished tracing + transforming cond for pjit in 0.000272751 sec
DEBUG:2024-12-31 12:35:11,391:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.807952880859375, index: 2


DEBUG:2024-12-31 12:35:12,099:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:12,099 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:12,100:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.636163950 sec
jax._src.dispatch: 2024-12-31 12:35:12,100 [DEBUG] Finished XLA compilation of jit(cond) in 0.636163950 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|1.00| o2
    i0 -->|1.00| h6
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|1.00| o2
    i0 -->|1.00| h6
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:12,119:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000213861 sec
jax._src.dispatch: 2024-12-31 12:35:12,119 [DEBUG] Finished tracing + transforming cond for pjit in 0.000213861 sec
DEBUG:2024-12-31 12:35:12,124:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.30236244201660156, index: 0


DEBUG:2024-12-31 12:35:12,534:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:12,534 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:12,535:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.331507921 sec
jax._src.dispatch: 2024-12-31 12:35:12,535 [DEBUG] Finished XLA compilation of jit(cond) in 0.331507921 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    i0 -->|1.00| h6
    h5 -->|1.00| h7
    h7 -->|-0.98| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    i0 -->|1.00| h6
    h5 -->|1.00| h7
    h7 -->|-0.98| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:12,552:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000216007 sec
jax._src.dispatch: 2024-12-31 12:35:12,552 [DEBUG] Finished tracing + transforming cond for pjit in 0.000216007 sec
DEBUG:2024-12-31 12:35:12,556:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.8239665031433105, index: 2


DEBUG:2024-12-31 12:35:12,978:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:12,978 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:12,979:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.331182003 sec
jax._src.dispatch: 2024-12-31 12:35:12,979 [DEBUG] Finished XLA compilation of jit(cond) in 0.331182003 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    i0 -->|1.00| h6
    h5 -->|1.00| h7
    h7 -->|-0.98| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    i0 -->|1.00| h6
    h5 -->|1.00| h7
    h7 -->|-0.98| o2
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:12,997:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000221014 sec
jax._src.dispatch: 2024-12-31 12:35:12,997 [DEBUG] Finished tracing + transforming cond for pjit in 0.000221014 sec
DEBUG:2024-12-31 12:35:13,001:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.5250184535980225, index: 1


DEBUG:2024-12-31 12:35:13,397:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:13,397 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:13,398:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.321543932 sec
jax._src.dispatch: 2024-12-31 12:35:13,398 [DEBUG] Finished XLA compilation of jit(cond) in 0.321543932 sec


src: 0, dst: 3
valid: True
depth: 4
valid: True
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|1.00| h7
    h7 -->|-0.98| o2
    i0 -->|0.37| h3
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|1.00| h7
    h7 -->|-0.98| o2
    i0 -->|0.37| h3
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:13,417:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000242949 sec
jax._src.dispatch: 2024-12-31 12:35:13,417 [DEBUG] Finished tracing + transforming cond for pjit in 0.000242949 sec
DEBUG:2024-12-31 12:35:13,419:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.4073692560195923, index: 0


DEBUG:2024-12-31 12:35:13,827:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:13,827 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:13,828:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.336697102 sec
jax._src.dispatch: 2024-12-31 12:35:13,828 [DEBUG] Finished XLA compilation of jit(cond) in 0.336697102 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h7 -->|-0.98| o2
    i0 -->|0.64| h3
    h5 -->|1.00| h8
    h8 -->|1.00| h7
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h7 -->|-0.98| o2
    i0 -->|0.64| h3
    h5 -->|1.00| h8
    h8 -->|1.00| h7
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:13,846:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000872135 sec
jax._src.dispatch: 2024-12-31 12:35:13,846 [DEBUG] Finished tracing + transforming cond for pjit in 0.000872135 sec
DEBUG:2024-12-31 12:35:13,849:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.9054309129714966, index: 2


DEBUG:2024-12-31 12:35:14,253:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:14,253 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:14,254:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.323952198 sec
jax._src.dispatch: 2024-12-31 12:35:14,254 [DEBUG] Finished XLA compilation of jit(cond) in 0.323952198 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h7 -->|-0.98| o2
    i0 -->|0.64| h3
    h5 -->|1.00| h8
    h8 -->|1.00| h7
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    end
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h7 -->|-0.98| o2
    i0 -->|0.64| h3
    h5 -->|1.00| h8
    h8 -->|1.00| h7
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:14,275:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000255108 sec
jax._src.dispatch: 2024-12-31 12:35:14,275 [DEBUG] Finished tracing + transforming cond for pjit in 0.000255108 sec
DEBUG:2024-12-31 12:35:14,278:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.04066503047943115, index: 0


DEBUG:2024-12-31 12:35:14,696:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:14,696 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:14,697:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.334053040 sec
jax._src.dispatch: 2024-12-31 12:35:14,697 [DEBUG] Finished XLA compilation of jit(cond) in 0.334053040 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    h9(["Hidden 9"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h7 -->|-0.98| o2
    h5 -->|-0.99| h8
    h8 -->|0.03| h7
    i0 -->|1.00| h9
    h9 -->|0.64| h3
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    h9(["Hidden 9"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h7 -->|-0.98| o2
    h5 -->|-0.99| h8
    h8 -->|0.03| h7
    i0 -->|1.00| h9
    h9 -->|0.64| h3
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:14,731:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000307798 sec
jax._src.dispatch: 2024-12-31 12:35:14,731 [DEBUG] Finished tracing + transforming cond for pjit in 0.000307798 sec
DEBUG:2024-12-31 12:35:14,844:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.3087378740310669, index: 0


DEBUG:2024-12-31 12:35:15,048:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(cond) in 0.173150063 sec
jax._src.dispatch: 2024-12-31 12:35:15,048 [DEBUG] Finished jaxpr to MLIR module conversion jit(cond) in 0.173150063 sec
DEBUG:2024-12-31 12:35:15,049:jax._src.compiler:167: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
jax._src.compiler: 2024-12-31 12:35:15,049 [DEBUG] get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:2024-12-31 12:35:15,050:jax._src.compiler:239: get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
jax._src.compiler: 2024-12-31 12:35:15,050 [DEBUG] get_compile_options XLA-AutoFDO profile: using XLA-AutoFDO profile version -1
DEBUG:2024-12-31 12:35:15,054:jax._src.cache_key:152: get_cache_key hash of serialized computation: cb665437dfb0e3b0f7f2024b948dc0f59ce153f1455c634c38dc30f03f6c231e
jax._src.cache_key: 2024-12-31 12:35:15,054 [

graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    h9(["Hidden 9"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|-0.20| o2
    h7 -->|-0.98| o2
    h5 -->|-0.99| h8
    h8 -->|0.03| h7
    i0 -->|-0.95| h9
    h9 -->|0.64| h3
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    h9(["Hidden 9"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|-0.20| o2
    h7 -->|-0.98| o2
    h5 -->|-0.99| h8
    h8 -->|0.03| h7
    i0 -->|-0.95| h9
    h9 -->|0.64| h3
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:15,410:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000422001 sec
jax._src.dispatch: 2024-12-31 12:35:15,410 [DEBUG] Finished tracing + transforming cond for pjit in 0.000422001 sec
DEBUG:2024-12-31 12:35:15,414:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.1653897762298584, index: 0


DEBUG:2024-12-31 12:35:15,831:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:15,831 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:15,832:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.341335058 sec
jax._src.dispatch: 2024-12-31 12:35:15,832 [DEBUG] Finished XLA compilation of jit(cond) in 0.341335058 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    h9(["Hidden 9"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|-0.20| o2
    h7 -->|0.01| o2
    h5 -->|-0.99| h8
    h8 -->|0.03| h7
    i0 -->|-0.95| h9
    h9 -->|0.64| h3
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    h9(["Hidden 9"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|-0.20| o2
    h7 -->|0.01| o2
    h5 -->|-0.99| h8
    h8 -->|0.03| h7
    i0 -->|-0.95| h9
    h9 -->|0.64| h3
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:15,851:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000200987 sec
jax._src.dispatch: 2024-12-31 12:35:15,851 [DEBUG] Finished tracing + transforming cond for pjit in 0.000200987 sec
DEBUG:2024-12-31 12:35:15,856:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.06252622604370117, index: 0


DEBUG:2024-12-31 12:35:16,304:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:16,304 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:16,305:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.364853859 sec
jax._src.dispatch: 2024-12-31 12:35:16,305 [DEBUG] Finished XLA compilation of jit(cond) in 0.364853859 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    h9(["Hidden 9"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|-0.20| o2
    h5 -->|0.56| h7
    h7 -->|0.01| o2
    h5 -->|-0.99| h8
    h8 -->|0.03| h7
    i0 -->|-0.95| h9
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    h9(["Hidden 9"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|-0.20| o2
    h5 -->|0.56| h7
    h7 -->|0.01| o2
    h5 -->|-0.99| h8
    h8 -->|0.03| h7
    i0 -->|-0.95| h9
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

DEBUG:2024-12-31 12:35:16,326:jax._src.dispatch:182: Finished tracing + transforming cond for pjit in 0.000211000 sec
jax._src.dispatch: 2024-12-31 12:35:16,326 [DEBUG] Finished tracing + transforming cond for pjit in 0.000211000 sec
DEBUG:2024-12-31 12:35:16,330:jax._src.interpreters.pxla:1906: Compiling cond with global shapes and types [ShapedArray(int32[]), ShapedArray(key<fry>[]), ShapedArray(bool[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[], weak_type=True), ShapedArray(bool[2]), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), ShapedArray(bool[20]), ShapedArray(key<fry>[]), ShapedArray(int32[], weak_type=True), ShapedArray(int32[10]), ShapedArray(int32[10]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(int32[20]), ShapedArray(float32[20]), Shap

p: 0.4093972444534302, index: 0


DEBUG:2024-12-31 12:35:16,743:jax._src.compiler:715: Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
jax._src.compiler: 2024-12-31 12:35:16,743 [DEBUG] Not writing persistent cache entry for 'jit_cond' because it uses host callbacks (e.g. from jax.debug.print or breakpoint)
DEBUG:2024-12-31 12:35:16,744:jax._src.dispatch:182: Finished XLA compilation of jit(cond) in 0.325572014 sec
jax._src.dispatch: 2024-12-31 12:35:16,744 [DEBUG] Finished XLA compilation of jit(cond) in 0.325572014 sec


graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    h9(["Hidden 9"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|-0.20| o2
    i0 -->|0.64| h4
    h5 -->|0.56| h7
    h7 -->|0.01| o2
    h5 -->|-0.99| h8
    h8 -->|-0.99| h7
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;


```mermaid
graph TD;
    subgraph Input
    i0(["Input 0"]):::input
    i1(["Input 1"]):::input
    end
    subgraph Output
    o2(["Output 2"]):::output
    end
    subgraph Hidden
    h3(["Hidden 3"]):::hidden
    h4(["Hidden 4"]):::hidden
    h5(["Hidden 5"]):::hidden
    h6(["Hidden 6"]):::hidden
    h7(["Hidden 7"]):::hidden
    h8(["Hidden 8"]):::hidden
    h9(["Hidden 9"]):::hidden
    end
    i0 -->|1.00| o2
    i1 -->|-0.32| h3
    h4 -->|0.18| o2
    h5 -->|-0.20| o2
    i0 -->|0.64| h4
    h5 -->|0.56| h7
    h7 -->|0.01| o2
    h5 -->|-0.99| h8
    h8 -->|-0.99| h7
    classDef input fill:#61DAFB,stroke:#333,stroke-width:2px;
    classDef output fill:#4EC9B0,stroke:#333,stroke-width:2px;
    classDef hidden fill:#9CA3AF,stroke:#333,stroke-width:2px;
    linkStyle default stroke:#E5E7EB,stroke-width:2px;
```

In [81]:
a[1].children()[1].num_leaves

1