In [5]:
from datagen import DataGenerator

# Example usage
generator = DataGenerator()

# Generate training and test data for a specific type
train_data, test_data = generator.generate_random_dataset(choice=0)  # 0 for circle dataset

# Generate a mini-batch from training data
batch = generator.generate_batch(train_data)
batch_input, batch_output = batch[:, :2], batch[:, 2:]

In [21]:
# use 'act' to backprop on a single wMat 
# -- what's input and output? input 2D point, output 2D logit for each class (2 class)
nInput = 2
nOutput = 2 

from fineNeat import Ind 

# initialize individual with fixed shape 
ind = Ind.from_shapes([(nInput, 5), (5, nOutput)])
ind.express()

from jax import numpy as jnp 
wMat = jnp.copy(ind.wMat)
aVec = jnp.array(ind.aVec)

# convert 'act' into jax and conduct backprop 
from fineNeat.sneat_jax.ann import act
from loss import cross_entropy_loss
from jax import value_and_grad

# Define a function that computes loss given weights
def loss_fn(weights, aVec, nInput, nOutput, inputs, targets):
    logits = act(weights, aVec, nInput, nOutput, inputs)
    return cross_entropy_loss(logits, targets)

# Jax backprop step function
def step(wMat, aVec, nInput, nOutput, batch_input, batch_output, learning_rate=0.01):
    # Compute gradient
    loss_value, grads = value_and_grad(loss_fn)(wMat, aVec, nInput, nOutput, batch_input, batch_output)
    
    # Gradient descent step 
    wMat_updated = wMat - learning_rate * grads

    return wMat_updated, loss_value

def train(wMat, aVec, nInput, nOutput, batch_input, batch_output, learning_rate=0.01, num_steps=100):
    for _ in range(num_steps):
        wMat, loss_value = step(wMat, aVec, nInput, nOutput, batch_input, batch_output, learning_rate)
        print(f"Step {_ + 1}, Loss: {loss_value}")
    return wMat

train(wMat, aVec, nInput, nOutput, batch_input, batch_output, learning_rate=0.01, num_steps=500)

Step 1, Loss: 0.842182993888855
Step 2, Loss: 0.8362444043159485
Step 3, Loss: 0.8305174112319946
Step 4, Loss: 0.8249969482421875
Step 5, Loss: 0.8196778297424316
Step 6, Loss: 0.8145545721054077
Step 7, Loss: 0.8096219301223755
Step 8, Loss: 0.80487459897995
Step 9, Loss: 0.8003073930740356
Step 10, Loss: 0.7959151268005371
Step 11, Loss: 0.7916927337646484
Step 12, Loss: 0.7876350283622742
Step 13, Loss: 0.7837368249893188
Step 14, Loss: 0.7799933552742004
Step 15, Loss: 0.776399552822113
Step 16, Loss: 0.7729504108428955
Step 17, Loss: 0.7696412205696106
Step 18, Loss: 0.7664671540260315
Step 19, Loss: 0.7634236216545105
Step 20, Loss: 0.7605057954788208
Step 21, Loss: 0.7577095031738281
Step 22, Loss: 0.7550300359725952
Step 23, Loss: 0.7524630427360535
Step 24, Loss: 0.7500045895576477
Step 25, Loss: 0.7476502656936646
Step 26, Loss: 0.7453961968421936
Step 27, Loss: 0.7432383298873901
Step 28, Loss: 0.7411730289459229
Step 29, Loss: 0.7391964197158813
Step 30, Loss: 0.7373050451

Array([[ 0.   ,  0.   ,  0.   , ..., -0.015, -0.096,  0.149],
       [ 0.   ,  0.   ,  0.   , ..., -0.266,  0.004, -0.006],
       [ 0.   ,  0.   ,  0.   , ..., -0.111, -0.171,  0.062],
       ...,
       [ 0.   ,  0.   ,  0.   , ...,  0.   ,  0.152, -0.791],
       [ 0.   ,  0.   ,  0.   , ...,  0.   ,  0.   ,  0.178],
       [ 0.   ,  0.   ,  0.   , ...,  0.   ,  0.   ,  0.   ]], dtype=float32)