In [2]:
from fineNeat import Ind 
import jax.numpy as jnp
import numpy as np 

ind = Ind.from_shapes([[3,2], [2, 2]])
ind.express()

# numpy node & conn 
node = ind.node
conn = ind.conn
# jax node & conn
node_jax = jnp.array(node.tolist())
conn_jax = jnp.array(conn.tolist())
# util function
def check_equal(a: np.array, b: jnp.array):
    return np.isclose(a, b.tolist()).all()

ann functoinal checks

In [9]:
# Test ann jax implementation 
from fineNeat.sneat_jax.ann import getNodeOrder as getNodeOrder_jax
from fineNeat.sneat_jax.ann import getLayer as getLayer_jax
from fineNeat.sneat_jax.ann import getNodeKey as getNodeKey_jax
from fineNeat.sneat_jax.ann import getNodeMap as getNodeMap_jax
from fineNeat.sneat_jax.ann import getNodeInfo as getNodeInfo_jax
from fineNeat.sneat_jax.ann import act as act_jax
from fineNeat.sneat_jax.ann import obtainOutgoingConnections as obtainOutgoingConnections_jax
from fineNeat import getNodeOrder, getLayer, getNodeKey, getNodeMap, getNodeInfo, act, obtainOutgoingConnections

# getNodeOrder 
Q_jax, wMat_jax = getNodeOrder_jax(node_jax, conn_jax)
Q, wMat = getNodeOrder(node, conn)
print("check_equal(Q, Q_jax): ", check_equal(Q, Q_jax))
print("check_equal(wMat, wMat_jax): ", check_equal(wMat, wMat_jax))

# getLayer
layer_jax = getLayer_jax(wMat_jax)
layer = getLayer(wMat)
print("check_equal(layer, layer_jax): ", check_equal(layer, layer_jax))

# getNodeKey 
nodeKey_jax = getNodeKey_jax(node_jax, conn_jax)
nodeKey = getNodeKey(node, conn)
print("check_equal(nodeKey, nodeKey_jax): ", check_equal(nodeKey, nodeKey_jax))

# getNodeMap 
nodeMap_jax = getNodeMap_jax(node_jax, conn_jax) # gives bug
nodeMap = getNodeMap(node, conn)
print("check_equal(nodeMap, nodeMap_jax): ", nodeMap == nodeMap_jax)

# # getNodeInfo
nodeInfo_jax = getNodeInfo_jax(node_jax, conn_jax)
nodeInfo = getNodeInfo(node, conn)
is_equal = True 
for v, v_jax in zip(nodeInfo, nodeInfo_jax):
    if isinstance(v, dict): 
        is_equal = is_equal and (v == v_jax)
    else: 
        is_equal = is_equal and check_equal(v, v_jax)
print("check_equal(nodeInfo, nodeInfo_jax): ", is_equal)

# act 
wMat = ind.wMat
aVec = ind.aVec
nInput, nOutput = 3, 2
input_pattern = np.array([1, 1, 1])
input_pattern_jax = jnp.array(input_pattern.tolist())
output_jax = act_jax(wMat, aVec, nInput, nOutput, input_pattern_jax)
output = act(wMat, aVec, nInput, nOutput, input_pattern)
print("check_equal(act, act_jax): ", check_equal(output, output_jax))

# obtainOutgoingConnections
node_id = 1
out_conn_jax = obtainOutgoingConnections_jax(wMat, node_id)
out_conn = obtainOutgoingConnections(wMat, node_id)
print("check_equal(out_conn, out_conn_jax): ", check_equal(out_conn, out_conn_jax))




check_equal(Q, Q_jax):  True
check_equal(wMat, wMat_jax):  True
check_equal(layer, layer_jax):  True
check_equal(nodeKey, nodeKey_jax):  True
check_equal(nodeMap, nodeMap_jax):  True
check_equal(nodeInfo, nodeInfo_jax):  True
check_equal(act, act_jax):  True
check_equal(out_conn, out_conn_jax):  True


individual functional checks

In [50]:
from fineNeat.sneat_jax.ind import initIndiv as initIndiv_jax
from fineNeat.sneat_jax.ind import Ind as Ind_jax
from fineNeat.neat_src.ind import initIndiv, Ind

import numpy as np 
import jax.numpy as jnp
def check_equal(a: np.array, b: jnp.array):
    return np.isclose(a, b.tolist()).all()

# initIndiv
node_jax, conn_jax = initIndiv_jax([[3,2], [2,2], [2, 2]])
node, conn = initIndiv([[3,2], [2,2], [2, 2]])
print("check_equal(node_jax, node): ", check_equal(node_jax, node))
print("check_equal(conn_jax,conn): ", check_equal(conn_jax[[0,1,2,4],:],conn[[0,1,2,4],:]))

# Ind initialization
ind_jax = Ind_jax(conn_jax, node_jax)
conn = np.array(conn_jax.tolist())
ind = Ind(conn, node)
print("check_equal(ind_jax.nInput, ind.nInput): ", ind_jax.nInput == ind.nInput)
print("check_equal(ind_jax.nOutput, ind.nOutput): ", ind_jax.nOutput == ind.nOutput)
print("check_equal(ind_jax.nBias, ind.nBias): ", ind_jax.nBias == ind.nBias)
print("check_equal(ind_jax.nHidden, ind.nHidden): ", ind_jax.nHidden == ind.nHidden)

# to_params 
params_jax = ind_jax.to_params()
params = ind.to_params()
param_equal = all([check_equal(w, w_jax) and check_equal(b, b_jax) for (w, b), (w_jax, b_jax) in zip(params, params_jax)])
print("check_equal(params, params_jax): ", param_equal)

# nConns 
nConns_jax = ind_jax.nConns()
nConns = ind.nConns()
print("check_equal(nConns, nConns_jax): ", nConns == nConns_jax)

# express 
ind_jax.express()
ind.express()
print("check_equal(ind_jax.wMat, ind.wMat): ", check_equal(ind_jax.wMat, ind.wMat))
print("check_equal(ind_jax.aVec, ind.aVec): ", check_equal(ind_jax.aVec, ind.aVec))
print("check_equal(ind_jax.wVec, ind.wVec): ", check_equal(ind_jax.wVec, ind.wVec))
print("check_equal(ind_jax.nConn, ind.nConn): ", ind_jax.nConn == ind.nConn)
print("check_equal(ind_jax.max_layer, ind.max_layer): ", ind_jax.max_layer == ind.max_layer)
print("check_equal(ind_jax.node_map, ind.node_map): ", ind_jax.node_map == ind.node_map)

# safe mutate :: check if it can be run suffices 
seed = 1
ind_jax.safe_mutate(seed)
print(":: Safe Mutate Success")

# mutAddNode 
p = {'ann_actRange': [0, 1, 2, 3]}
child_conn, child_node, child_innov = ind_jax.mutAddNode(conn_jax, node_jax, None, 0, p, seed)
print(":: mutAddNode Success")

# mutAddConn 
child_conn, child_node, child_innov = ind_jax.mutAddConn(conn_jax, node_jax, None, 0, p, seed)
print(":: mutAddConn Success")

# mutate :: Issue -- it does not 'disable' connections? 
p = {'ann_actRange': [0, 1, 2, 3], 'prob_mutConn': 0.5, 'ann_mutSigma': 0.1, 'prob_addNode': 0.5, 'prob_addConn': 0.5, 'prob_enable': 0.5, 'ann_absWCap': 10}
ind_jax.mutate(p, seed=seed)
print(":: mutate Success")

check_equal(node_jax, node):  True
check_equal(conn_jax,conn):  True
check_equal(ind_jax.nInput, ind.nInput):  True
check_equal(ind_jax.nOutput, ind.nOutput):  True
check_equal(ind_jax.nBias, ind.nBias):  True
check_equal(ind_jax.nHidden, ind.nHidden):  True
check_equal(params, params_jax):  False
check_equal(nConns, nConns_jax):  True
check_equal(ind_jax.wMat, ind.wMat):  False
check_equal(ind_jax.aVec, ind.aVec):  True
check_equal(ind_jax.wVec, ind.wVec):  False
check_equal(ind_jax.nConn, ind.nConn):  True
check_equal(ind_jax.max_layer, ind.max_layer):  True
check_equal(ind_jax.node_map, ind.node_map):  True
:: Safe Mutate Success
:: mutAddNode Success
:: mutAddConn Success
:: mutate Success


In [36]:
# Add 'random disable' option for topology mutation 
# ind.express()
# ind.conn
# ind.node_map
ind.order

AttributeError: 'Ind' object has no attribute 'order'

In [51]:
ind.node[:,4]
# recheck on the ann functionals ... 
from fineNeat.neat_src.ann import getNodeOrder
order, wMat = getNodeOrder(ind.node, ind.conn) # order is [input, bias, output, hidden (topological sorted)]
node_idx = -1
node_layer = ind.node_map[order[node_idx]][0]



In [55]:

change_ratio = 0.1
sparsity_ratio = 0.8

def mutTurnConn(self, connG, nodeG, innov, gen, p = {"ann_turn_ratio": 0.1, 'sparsity_ratio': 0.8}): 
    """Turn off/on 'non-essential' connections with probability"""
    nodeMap = getNodeMap(nodeG, connG)
    if nodeMap is False:
        # print(":: Failed to get node order")
        return connG, nodeG, innov
    
    conn = self.conn
    node = self.node
    
    # pick non-essential connections and pick ratio of them to randomize 'on/off' status 
    start_hidden_node_idx = self.nInput + self.nBias + self.nOutput
    non_essential_conn_ids = (conn[1,:] >= start_hidden_node_idx) & (conn[2, :] >= start_hidden_node_idx)

    # Randomly select connections to modify based on change_ratio
    n_conns = np.sum(non_essential_conn_ids)
    n_change = int(n_conns * p['ann_turn_ratio'])
    change_mask = np.random.choice(n_conns, size=n_change, replace=False)

    # Create array of 1s and 0s based on sparsity ratio
    new_states = np.random.binomial(1, 1-p['sparsity_ratio'], size=n_change)

    # Update selected connections
    update_indices = np.where(non_essential_conn_ids)[0][change_mask]
    conn[4, update_indices] = new_states
    return conn, node, innov

In [54]:
ind.conn

array([[ 0.   ,  1.   ,  2.   , ..., 17.   , 18.   , 19.   ],
       [ 0.   ,  1.   ,  2.   , ...,  3.   ,  3.   ,  3.   ],
       [ 4.   ,  4.   ,  4.   , ...,  7.   ,  8.   ,  9.   ],
       [ 0.527, -0.485, -0.297, ...,  0.398,  1.328, -0.291],
       [ 1.   ,  1.   ,  1.   , ...,  1.   ,  1.   ,  1.   ]])