In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, vmap, lax, value_and_grad
from jax.tree_util import tree_map
from typing import NamedTuple
import time
import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

# Population / network-size settings. Each genome has MAX_NODES fixed node
# slots: the first INPUT_SIZE slots are inputs, the last OUTPUT_SIZE are outputs.
POP_SIZE = 1024
MAX_NODES = 32
INPUT_SIZE = 2
OUTPUT_SIZE = 1

# Per-genome gradient training. BATCH_SIZE is also the number of samples
# generated for the dataset; every training step uses the whole dataset.
BATCH_SIZE = 1024
LEARNING_RATE = 0.05
TRAIN_EPOCHS = 20  # number of full-batch SGD steps per genome per generation

# NEAT Constants
# Compatibility distance = C1 * (#links enabled in exactly one genome)
#                        + C3 * (mean |weight difference| over shared links).
# NOTE(review): C2 is not referenced anywhere in the code shown.
# Genomes closer than COMPATIBILITY_THRESH share a niche for fitness sharing.
C1 = 1.0; C2 = 1.0; C3 = 0.4; COMPATIBILITY_THRESH = 1.0
ELITE_RATIO = 0.05  # fraction of the population copied unchanged (top raw fitness)
# Per-offspring probability of applying each mutation.
MUTATION_RATE_ADD_NODE = 0.7
MUTATION_RATE_ADD_LINK = 0.4
MUTATION_RATE_WEIGHT = 0.5
SIGMA = 0.01  # perturbation scale; mutate_weight multiplies it by a further 0.1

# Complexity penalty subtracted from accuracy in the fitness: per enabled link
# and per active non-input node.
PENALTY_NODE = 0.01
PENALTY_LINK = 0.01

# Node-slot indices of the input and output nodes.
IDX_INPUT = jnp.arange(INPUT_SIZE)
IDX_OUTPUT = jnp.arange(MAX_NODES - OUTPUT_SIZE, MAX_NODES)

# Task passed to get_dataset: "xor", "circle" or "spiral".
DATASET_NAME = "xor"

def get_dataset(task_name="spiral", n_samples=1024, key=None):
    """Generate a 2-D binary classification dataset.

    Args:
        task_name: "xor", "circle" or "spiral".
        n_samples: requested number of samples. "xor" returns exactly
            n_samples; "spiral" returns 2 * (n_samples // 2); "circle" returns
            at most n_samples (fewer if one class has too few candidate points).
        key: JAX PRNG key. Defaults to ``random.PRNGKey(0)``.

    Returns:
        (x, y): x has shape (N, 2), y has shape (N, 1) with labels 0.0 / 1.0.

    Raises:
        ValueError: if task_name is not one of the supported tasks.
    """
    if key is None:
        key = random.PRNGKey(0)

    if task_name == "xor":
        # Use separate keys: drawing the noise from the same key as the points
        # reuses the same random bits, making the "noise" a deterministic
        # function of x instead of independent jitter.
        k_x, k_noise = random.split(key)
        x = random.uniform(k_x, (n_samples, 2), minval=-1, maxval=1)
        # Labels come from the clean points; the noise below can move points
        # across an axis, which acts as label noise.
        y = jnp.logical_xor(x[:, 0] > 0, x[:, 1] > 0).astype(jnp.float32)
        noise = random.normal(k_noise, x.shape) * 0.1
        x = x + noise
        return x, y[:, None]

    elif task_name == "circle":
        # Oversample, then keep up to n_samples // 2 points per class
        # (inside radius 0.6 -> label 0, outside -> label 1).
        x = random.uniform(key, (n_samples * 2, 2), minval=-1, maxval=1)
        radius = jnp.linalg.norm(x, axis=1)

        idx_inner = jnp.where(radius <= 0.6)[0][:n_samples // 2]
        idx_outer = jnp.where(radius > 0.6)[0][:n_samples // 2]

        x_inner = x[idx_inner]
        y_inner = jnp.zeros((len(x_inner), 1))

        x_outer = x[idx_outer]
        y_outer = jnp.ones((len(x_outer), 1))

        x_final = jnp.concatenate([x_inner, x_outer], axis=0)
        y_final = jnp.concatenate([y_inner, y_outer], axis=0)

        return x_final, y_final

    elif task_name == "spiral":
        n = n_samples // 2
        k1, k2 = random.split(key)

        def make_spiral(delta_t, label, key):
            """One noisy spiral arm of n points, rotated by delta_t radians."""
            r = jnp.linspace(0.1, 1.0, n)
            t = jnp.linspace(0, 2.5 * jnp.pi, n) + delta_t
            noise = random.normal(key, (n, 2)) * 0.02
            x = r * jnp.sin(t)
            y = r * jnp.cos(t)
            data = jnp.stack([x, y], axis=1) + noise
            labels = jnp.ones((n, 1)) * label
            return data, labels

        # Two interleaved arms, the second rotated by half a turn.
        x1, y1 = make_spiral(0, 0.0, k1)
        x2, y2 = make_spiral(jnp.pi, 1.0, k2)

        x = jnp.concatenate([x1, x2], axis=0)
        y = jnp.concatenate([y1, y2], axis=0)
        return x, y

    else:
        raise ValueError("Unknown task")

# Fixed training set shared by every genome and every generation. Training is
# full-batch, so BATCH_SIZE doubles as the number of samples requested.
DATA_X, DATA_Y = get_dataset(DATASET_NAME, BATCH_SIZE)

# One network encoded over MAX_NODES fixed node slots. Under vmap every field
# gains a leading population axis.
#   weights     -- (MAX_NODES, MAX_NODES) link weights indexed [src, dst]
#   mask        -- (MAX_NODES, MAX_NODES) 1.0 where link src -> dst is enabled
#   bias        -- (MAX_NODES,) per-node bias
#   node_active -- (MAX_NODES,) 1.0 for node slots that are part of the network
Genome = NamedTuple(
    "Genome",
    [
        ("weights", jnp.ndarray),
        ("mask", jnp.ndarray),
        ("bias", jnp.ndarray),
        ("node_active", jnp.ndarray),
    ],
)

@jit
def mutate_weight(key, genome):
    """Jitter the parameters of existing structure with small Gaussian noise.

    Noise has scale SIGMA * 0.1 and is applied only to enabled links (mask)
    and active nodes (node_active); the topology itself is left unchanged.
    """
    w_key, b_key = random.split(key)
    weight_jitter = random.normal(w_key, genome.weights.shape) * SIGMA * 0.1
    bias_jitter = random.normal(b_key, genome.bias.shape) * SIGMA * 0.1
    return genome._replace(
        weights=genome.weights + weight_jitter * genome.mask,
        bias=genome.bias + bias_jitter * genome.node_active,
    )

@jit
def mutate_add_link(key, genome):
    """Enable one new random link between two active nodes.

    Only "forward" links (lower slot index -> higher slot index) that are not
    already enabled and do not target an input node are candidates, so node
    index order stays a valid evaluation order. The new link gets weight
    N(0, 0.5^2). Returns the genome unchanged when there is no candidate.
    """
    # Independent keys for "which link" and "what weight"; reusing one key for
    # both draws would correlate the chosen position with the new weight.
    k_pick, k_weight = random.split(key)

    r_idx = jnp.arange(MAX_NODES)[:, None]
    c_idx = jnp.arange(MAX_NODES)[None, :]
    is_forward = r_idx < c_idx
    active_matrix = genome.node_active[:, None] * genome.node_active[None, :]
    candidate_mask = is_forward * active_matrix * (1.0 - genome.mask)
    candidate_mask = candidate_mask.at[:, IDX_INPUT].set(0.0)
    has_candidates = jnp.sum(candidate_mask) > 0.5

    # Uniform random choice among candidates: argmax of noise, non-candidates masked out.
    noise = random.uniform(k_pick, candidate_mask.shape)
    logits = jnp.where(candidate_mask > 0.5, noise, -1e9)
    flat_idx = jnp.argmax(logits.ravel())
    row = flat_idx // MAX_NODES
    col = flat_idx % MAX_NODES

    new_mask = genome.mask.at[row, col].set(1.0)
    new_weights = genome.weights.at[row, col].set(random.normal(k_weight) * 0.5)
    return tree_map(lambda x, y: jnp.where(has_candidates, x, y),
                    Genome(new_weights, new_mask, genome.bias, genome.node_active), genome)

@jit
def mutate_add_node(key, genome):
    """Split one randomly chosen enabled link src -> dst with a new node.

    The new node is a free (inactive) slot strictly between src and dst, so
    node index order remains a valid evaluation order. The old link is
    disabled and zeroed; src -> new gets weight 1.0, new -> dst gets the old
    weight, and the new node's bias is 0. The genome is returned unchanged if
    the chosen link is not enabled (the genome has no links) or there is no
    free slot between its endpoints.
    """
    link_key, slot_key = random.split(key)
    links = genome.mask
    slots = jnp.arange(MAX_NODES)

    # Random enabled link: argmax of uniform noise over enabled entries only.
    link_scores = jnp.where(links > 0.5, random.uniform(link_key, links.shape), -1e9)
    flat = jnp.argmax(link_scores.ravel())
    src = flat // MAX_NODES
    dst = flat % MAX_NODES
    carried_weight = genome.weights[src, dst]

    # Random free slot strictly between the link's endpoints.
    free_between = (slots > src) & (slots < dst) & ((1.0 - genome.node_active) > 0.5)
    slot_scores = jnp.where(free_between, random.uniform(slot_key, free_between.shape), -1e9)
    new_node = jnp.argmax(slot_scores)

    should_split = (links[src, dst] > 0.5) & (jnp.sum(free_between) > 0.5)

    split_mask = (
        genome.mask.at[src, dst].set(0.0)
        .at[src, new_node].set(1.0)
        .at[new_node, dst].set(1.0)
    )
    split_weights = (
        genome.weights.at[src, dst].set(0.0)
        .at[src, new_node].set(1.0)
        .at[new_node, dst].set(carried_weight)
    )
    split = Genome(
        split_weights,
        split_mask,
        genome.bias.at[new_node].set(0.0),
        genome.node_active.at[new_node].set(1.0),
    )
    return tree_map(lambda a, b: jnp.where(should_split, a, b), split, genome)

@jit
def apply_mutations(key, genome):
    """Independently apply each mutation operator with its configured probability.

    Applied in order: weight jitter (MUTATION_RATE_WEIGHT), add link
    (MUTATION_RATE_ADD_LINK), add node (MUTATION_RATE_ADD_NODE).
    """
    # Separate keys for each gate and each mutation. Reusing a gate's key inside
    # the mutation it gates correlates whether the mutation fires with how it
    # mutates (e.g. add_link's uniform draw shares bits with its gate draw).
    k_gate_w, k_w, k_gate_link, k_link, k_gate_node, k_node = random.split(key, 6)
    g = jax.lax.cond(random.uniform(k_gate_w) < MUTATION_RATE_WEIGHT,
                     lambda x: mutate_weight(k_w, x), lambda x: x, genome)
    g = jax.lax.cond(random.uniform(k_gate_link) < MUTATION_RATE_ADD_LINK,
                     lambda x: mutate_add_link(k_link, x), lambda x: x, g)
    g = jax.lax.cond(random.uniform(k_gate_node) < MUTATION_RATE_ADD_NODE,
                     lambda x: mutate_add_node(k_node, x), lambda x: x, g)
    return g

@jit
def calculate_compatibility_distance(g1, g2):
    """NEAT-style distance between two genomes.

    C1 * (number of links enabled in exactly one genome)
    + C3 * (mean absolute weight difference over links enabled in both;
    0 when they share no links).
    """
    shared = g1.mask * g2.mask
    n_shared = shared.sum()
    mismatched = jnp.abs(g1.mask - g2.mask).sum()
    gap_total = (jnp.abs(g1.weights - g2.weights) * shared).sum()
    mean_gap = jnp.where(n_shared > 0, gap_total / n_shared, 0.0)
    return C1 * mismatched + C3 * mean_gap

@jit
def crossover_pair(key, p1, f1, p2, f2):
    """NEAT-style crossover of two parents into one child.

    The fitter parent (p1 on ties) is the primary; the child inherits its
    topology (mask and node_active). Each link enabled in both parents, and
    each node active in both parents, takes its weight / bias from either
    parent with probability 0.5. Everything else comes from the primary.
    """
    p1_is_fitter = f1 >= f2
    primary = tree_map(lambda a, b: jnp.where(p1_is_fitter, a, b), p1, p2)
    secondary = tree_map(lambda a, b: jnp.where(p1_is_fitter, b, a), p1, p2)

    k_w, k_b = random.split(key, 2)

    matching_links = primary.mask * secondary.mask
    swap_prob_w = random.uniform(k_w, shape=primary.weights.shape)
    take_secondary_w = (matching_links > 0.5) & (swap_prob_w < 0.5)
    child_weights = jnp.where(take_secondary_w, secondary.weights, primary.weights)

    # Only mix biases of nodes both parents use. For a node that is inactive in
    # the secondary, its bias is never updated by training or mutation (both
    # scale updates by node_active), so taking it would replace the primary's
    # trained bias with an untouched initial value.
    matching_nodes = primary.node_active * secondary.node_active
    swap_prob_b = random.uniform(k_b, shape=primary.bias.shape)
    take_secondary_b = (matching_nodes > 0.5) & (swap_prob_b < 0.5)
    child_bias = jnp.where(take_secondary_b, secondary.bias, primary.bias)

    return Genome(weights=child_weights, mask=primary.mask, bias=child_bias,
                  node_active=primary.node_active)

@jit
def evolve_step(key, population, fitness_scores):
    """Build the next generation from a trained, evaluated population.

    1. Fitness sharing: each genome's (shifted, positive) fitness is divided by
       the number of genomes within COMPATIBILITY_THRESH of it, itself included.
    2. Elitism: the top ELITE_RATIO fraction by raw fitness is copied unchanged.
    3. The rest are produced by binary tournament selection on shared fitness,
       crossover_pair, and apply_mutations.

    The output population has the same size as the input.
    """
    pop_size = fitness_scores.shape[0]

    distance_matrix = vmap(
        vmap(calculate_compatibility_distance, in_axes=(None, 0)), in_axes=(0, None)
    )(population, population)
    sh = jnp.where(distance_matrix < COMPATIBILITY_THRESH, 1.0, 0.0)
    niche_counts = jnp.sum(sh, axis=1) + 1e-6

    # Shift so all fitness values are >= 1 before dividing (fitness can be
    # negative because of the complexity penalty).
    min_fit = jnp.min(fitness_scores)
    shift_val = jnp.abs(min_fit) + 1.0
    shifted_fitness = fitness_scores + shift_val
    adj_fitness = shifted_fitness / niche_counts

    elite_count = int(pop_size * ELITE_RATIO)
    sorted_indices = jnp.argsort(fitness_scores)
    # Slice from an explicit start: `sorted_indices[-0:]` would select the
    # whole population when elite_count == 0.
    elite_indices = sorted_indices[pop_size - elite_count:]
    elites = tree_map(lambda x: x[elite_indices], population)

    num_offspring = pop_size - elite_count
    k_sel, k_cross, k_mut = random.split(key, 3)

    def get_parents(k, n):
        """Binary tournament on shared fitness; returns n winners and their fitness."""
        # Each contestant needs its own key; with a shared key idx1 == idx2 and
        # the tournament degenerates into uniform random selection.
        k_a, k_b = random.split(k)
        idx1 = random.randint(k_a, (n,), 0, pop_size)
        idx2 = random.randint(k_b, (n,), 0, pop_size)
        winner = jnp.where(adj_fitness[idx1] > adj_fitness[idx2], idx1, idx2)
        return tree_map(lambda x: x[winner], population), adj_fitness[winner]

    k_p1, k_p2 = random.split(k_sel)
    p1, f1 = get_parents(k_p1, num_offspring)
    p2, f2 = get_parents(k_p2, num_offspring)

    offspring = vmap(crossover_pair)(random.split(k_cross, num_offspring), p1, f1, p2, f2)
    offspring_mutated = vmap(apply_mutations)(random.split(k_mut, num_offspring), offspring)

    return tree_map(lambda e, o: jnp.concatenate([e, o], axis=0), elites, offspring_mutated)

@jit
def forward_pass(weights, mask, bias, node_active, x):
    """Evaluate one genome on a batch of inputs.

    Nodes are computed in slot-index order, which is a valid evaluation order
    because links are only ever created from a lower to a higher index.
    Hidden nodes use tanh. Output nodes are linear, so the final sigmoid spans
    the full (0, 1) range; sigmoid(tanh(z)) would be confined to about
    (0.27, 0.73), capping confidence and damping the BCE gradient.

    Args:
        weights, mask: (MAX_NODES, MAX_NODES) link weights and enabled-link mask.
        bias, node_active: (MAX_NODES,) node biases and active flags.
        x: inputs; assumed shape (batch, INPUT_SIZE).

    Returns:
        Sigmoid of the first output node, shape (batch,).
    """
    first_output = MAX_NODES - OUTPUT_SIZE
    activations = jnp.zeros((x.shape[0], MAX_NODES))
    activations = activations.at[:, IDX_INPUT].set(x)

    def update_node(i, acts):
        w_col = weights[:, i] * mask[:, i]
        pre_act = jnp.dot(acts, w_col) + bias[i]
        val = jnp.where(i >= first_output, pre_act, jnp.tanh(pre_act))
        # Inactive slots are forced to 0 so they contribute nothing downstream.
        return acts.at[:, i].set(val * node_active[i])

    activations = lax.fori_loop(INPUT_SIZE, MAX_NODES, update_node, activations)

    out_node_idx = IDX_OUTPUT[0]
    return jax.nn.sigmoid(activations[:, out_node_idx])

def loss_fn(weights, mask, bias, node_active, x, y):
    """Mean binary cross-entropy of the genome's predictions against y[:, 0].

    Predictions are clipped to [1e-7, 1 - 1e-7] to keep the logs finite.
    """
    target = y[:, 0]
    p = jnp.clip(forward_pass(weights, mask, bias, node_active, x), 1e-7, 1.0 - 1e-7)
    log_likelihood = target * jnp.log(p) + (1 - target) * jnp.log(1 - p)
    return -jnp.mean(log_likelihood)

@jit
def train_genome_sgd(genome_tuple, x, y):
    """Fine-tune one genome's weights and biases by full-batch gradient descent.

    Runs TRAIN_EPOCHS steps of size LEARNING_RATE on loss_fn. Only enabled
    links (mask) and active nodes (node_active) are updated; topology is
    unchanged.

    Returns:
        (new_genome, fitness, accuracy): accuracy is measured on (x, y) after
        training; fitness = accuracy - (PENALTY_LINK * enabled links
        + PENALTY_NODE * active non-input nodes, outputs included).
    """
    weights, mask, bias, node_active = genome_tuple
    grad_fn = jax.grad(loss_fn, argnums=(0, 2))

    def train_step(carry, _):
        w, b = carry
        grads_w, grads_b = grad_fn(w, mask, b, node_active, x, y)
        w_new = w - LEARNING_RATE * grads_w * mask
        b_new = b - LEARNING_RATE * grads_b * node_active
        return (w_new, b_new), None

    (trained_w, trained_b), _ = lax.scan(train_step, (weights, bias), None, length=TRAIN_EPOCHS)

    preds = forward_pass(trained_w, mask, trained_b, node_active, x)
    accuracy = jnp.mean((preds > 0.5) == y[:, 0])

    num_links = jnp.sum(mask)
    num_nodes = jnp.sum(node_active) - INPUT_SIZE
    complexity = PENALTY_LINK * num_links + PENALTY_NODE * num_nodes
    fitness = accuracy - complexity

    new_genome = Genome(trained_w, mask, trained_b, node_active)
    return new_genome, fitness, accuracy

# Population-level trainer: genomes are batched along axis 0, data is shared.
train_population = vmap(train_genome_sgd, in_axes=(0, None, None))
@jit
def initialize_population(key):
    """Create POP_SIZE minimal genomes.

    Each genome has only its input and output nodes active, with every input
    linked directly to every output. Weights and biases are drawn from
    N(0, 0.1^2) for all slots, including currently unused ones.
    """
    def _make_genome(subkey):
        k_w, k_b = random.split(subkey)
        weights = random.normal(k_w, (MAX_NODES, MAX_NODES)) * 0.1
        biases = random.normal(k_b, (MAX_NODES,)) * 0.1
        active = jnp.zeros((MAX_NODES,)).at[IDX_INPUT].set(1.0).at[IDX_OUTPUT].set(1.0)
        # Fully connect inputs -> outputs in one scatter.
        links = jnp.zeros((MAX_NODES, MAX_NODES)).at[IDX_INPUT[:, None], IDX_OUTPUT[None, :]].set(1.0)
        return Genome(weights, links, biases, active)

    return vmap(_make_genome)(random.split(key, POP_SIZE))

def visualize_decision_boundary(genome, idx, x, y, gen):
    """Plot one genome's decision surface and network topology side by side.

    Args:
        genome: population-batched Genome; every field is indexed with ``idx``.
        idx: index of the genome to draw.
        x: dataset inputs, shape (N, 2).
        y: dataset labels, shape (N, 1).
        gen: generation number, used in the title and the output filename.

    Side effects: saves ``best_genome_gen_{gen}.png``, shows the figure, then
    closes it so figures do not accumulate across generations.
    """
    w = np.array(genome.weights[idx])
    m = np.array(genome.mask[idx])
    na = np.array(genome.node_active[idx])
    output_start = MAX_NODES - OUTPUT_SIZE

    # Left panel: model output on a regular grid, with the data on top.
    xx, yy = np.meshgrid(np.linspace(-1.2, 1.2, 100), np.linspace(-1.2, 1.2, 100))
    grid_x = np.c_[xx.ravel(), yy.ravel()]
    preds_jax = forward_pass(genome.weights[idx], genome.mask[idx], genome.bias[idx],
                             genome.node_active[idx], jnp.array(grid_x))
    preds = np.array(preds_jax).reshape(xx.shape)

    fig = plt.figure(figsize=(14, 6))

    plt.subplot(1, 2, 1)
    plt.contourf(xx, yy, preds, cmap='RdBu', alpha=0.8)
    x_np = np.array(x)
    y_np = np.array(y)
    plt.scatter(x_np[:, 0], x_np[:, 1], c=y_np[:, 0], cmap='RdBu_r', edgecolors='k', s=20)
    plt.title(f"Decision Boundary (Gen {gen})")

    # Right panel: layered drawing of the active nodes and enabled links.
    plt.subplot(1, 2, 2)
    G = nx.DiGraph()
    active_nodes = [i for i in range(MAX_NODES) if na[i] > 0.5]

    # Depth = 1 + deepest active predecessor (1 if none; inputs are 0). Links
    # always go from a lower to a higher index, so one ascending pass suffices.
    node_depths = {}
    for i in active_nodes:
        if i < INPUT_SIZE:
            node_depths[i] = 0
            continue
        incoming_depths = [node_depths[src] for src in range(i)
                           if m[src, i] > 0.5 and src in node_depths]
        node_depths[i] = max(incoming_depths) + 1 if incoming_depths else 1

    # Outputs go one column right of the deepest non-output node; including the
    # outputs' own depth here would leave an empty column before them.
    output_layer = max((d for n, d in node_depths.items() if n < output_start), default=0) + 1

    for i in active_nodes:
        if i < INPUT_SIZE:
            color, subset, label = 'lightgreen', 0, f"In{i}"
        elif i >= output_start:
            color, subset, label = 'salmon', output_layer, "Out"
        else:
            color, subset, label = 'skyblue', node_depths.get(i, 1), str(i)
        G.add_node(i, color=color, subset=subset, label=label)

    # Build the edge list explicitly so colors/widths are guaranteed to line up
    # with the edges passed to draw_networkx_edges.
    edge_list, edge_colors, edge_widths = [], [], []
    for i in active_nodes:
        for j in active_nodes:
            if m[i, j] > 0.5:
                w_val = float(w[i, j])
                G.add_edge(i, j)
                edge_list.append((i, j))
                edge_colors.append('red' if w_val > 0 else 'blue')
                edge_widths.append(min(3.0, abs(w_val)) * 0.8 + 0.2)

    pos = nx.multipartite_layout(G, subset_key="subset")
    node_colors = [G.nodes[n]['color'] for n in G.nodes()]
    labels = nx.get_node_attributes(G, 'label')

    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=500, alpha=0.9, edgecolors='gray')
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=10)
    nx.draw_networkx_edges(G, pos, edgelist=edge_list, edge_color=edge_colors, width=edge_widths,
                           arrows=True, arrowstyle='-|>', arrowsize=15,
                           connectionstyle="arc3,rad=0.1")

    plt.title("Network Topology (Layered)")
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(f"best_genome_gen_{gen}.png")
    plt.show()
    plt.close(fig)

def main(num_generations=200):
    """Run the Backprop-NEAT loop: train, report, plot, evolve.

    Each generation trains every genome on (DATA_X, DATA_Y) and prints fitness
    and accuracy statistics. On generation 1 and every 10th generation it plots
    the fittest genome, then builds the next population with evolve_step.

    Args:
        num_generations: number of generations to run (default 200).
    """
    key = random.PRNGKey(21)
    key, subkey = random.split(key)

    population = initialize_population(subkey)
    print(f"--- Backprop NEAT Started (Pop: {POP_SIZE}) ---")
    print(f"Task: {DATASET_NAME.capitalize()} Classification (Samples: {DATA_X.shape[0]})")

    for gen in range(1, num_generations + 1):
        start_time = time.time()

        trained_population, fitness, accuracy = train_population(population, DATA_X, DATA_Y)
        # JAX dispatches work asynchronously; wait for the results so the
        # timing measures the training itself, not just its dispatch.
        fitness.block_until_ready()
        accuracy.block_until_ready()
        elapsed = time.time() - start_time

        max_fit = jnp.max(fitness)
        max_acc = jnp.max(accuracy)
        mean_acc = jnp.mean(accuracy)
        best_idx = jnp.argmax(fitness)

        print(f"Gen {gen:3d} | Max Fit: {max_fit:.4f} | Max Acc: {max_acc:.4f} | Mean Acc: {mean_acc:.4f} | Time: {elapsed:.2f}s")

        if gen % 10 == 0 or gen == 1:
            visualize_decision_boundary(trained_population, best_idx, DATA_X, DATA_Y, gen)

        key, subkey = random.split(key)
        population = evolve_step(subkey, trained_population, fitness)

In [None]:
# Run only when executed as a script or notebook cell, not on import.
if __name__ == "__main__":
    main()