# In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, vmap, lax
from jax.tree_util import tree_map
from typing import NamedTuple
import time
import networkx as nx
import matplotlib.pyplot as plt

from evojax.task.slimevolley import SlimeVolley

# Population size and fixed genome capacity (number of node slots per genome).
POP_SIZE = 4096
MAX_NODES = 50
INPUT_SIZE = 12
OUTPUT_SIZE = 3

# Node-slot layout: inputs occupy the first INPUT_SIZE slots and outputs the
# last OUTPUT_SIZE slots; everything in between is available for hidden nodes.
IDX_INPUT = jnp.arange(0, INPUT_SIZE)
IDX_OUTPUT = jnp.arange(MAX_NODES)[-OUTPUT_SIZE:]

# NEAT Constants
# Compatibility-distance coefficients and speciation threshold.
# NOTE(review): C2 is not referenced anywhere in the visible code.
C1 = 1.0
C2 = 1.0
C3 = 0.4
COMPATIBILITY_THRESH = 1.0
# Fraction of the population copied unchanged into the next generation.
ELITE_RATIO = 0.05
# Per-offspring probabilities of each mutation operator.
MUTATION_RATE_ADD_NODE = 0.1
MUTATION_RATE_WEIGHT = 0.9
MUTATION_RATE_ADD_LINK = 0.4
# Standard deviation of the Gaussian noise used by the mutation operators.
SIGMA = 1


class Genome(NamedTuple):
    """Fixed-capacity, dense-array encoding of one network (or a batch of them).

    Being a NamedTuple, a Genome is a JAX pytree, so it can be passed through
    jit/vmap/tree_map directly. A population is a Genome whose arrays carry a
    leading batch axis (see initialize_population).

    Field order matters: several call sites construct Genome positionally.
    """
    # Connection weights, created with shape (MAX_NODES, MAX_NODES); entry
    # [src, dst] is the weight of the link feeding node `dst` from node `src`.
    weights: jnp.ndarray
    # 0/1 matrix of the same shape; 1.0 marks an enabled connection.
    mask: jnp.ndarray
    # Per-node bias, created with shape (MAX_NODES,).
    bias: jnp.ndarray
    # 0/1 vector of shape (MAX_NODES,); 1.0 marks a node slot that is in use.
    node_active: jnp.ndarray

@jit
def mutate_weight(key, genome):
    """Perturb every enabled weight and every active node's bias.

    Gaussian noise scaled by SIGMA is added to the weights where ``mask`` is
    set and to the biases where ``node_active`` is set; all other entries
    are left untouched. The topology (mask, node_active) is not modified.

    Args:
        key: JAX PRNG key.
        genome: the Genome to perturb.

    Returns:
        A new Genome with perturbed weights and biases.
    """
    weight_key, bias_key = random.split(key)

    # Noise is drawn for every slot and then zeroed where the slot is unused.
    weight_noise = random.normal(weight_key, genome.weights.shape) * SIGMA
    bias_noise = random.normal(bias_key, genome.bias.shape) * SIGMA

    perturbed_weights = genome.weights + weight_noise * genome.mask
    perturbed_bias = genome.bias + bias_noise * genome.node_active

    return Genome(perturbed_weights, genome.mask, perturbed_bias, genome.node_active)

@jit
def mutate_add_link(key, genome):
    """Enable one new forward connection between two active nodes.

    A candidate is any (row, col) pair with ``row < col`` where both nodes
    are active, the connection is not yet enabled, and ``col`` is not an
    input node. One candidate is chosen uniformly at random (argmax over
    i.i.d. uniform noise restricted to candidates) and is given a fresh
    weight drawn from a normal distribution scaled by SIGMA.

    Args:
        key: JAX PRNG key; split internally so that link selection and the
            new weight use independent random streams.
        genome: a single (unbatched) Genome.

    Returns:
        A Genome with one additional enabled link, or the input genome
        unchanged when no candidate connection exists.
    """
    # Never reuse a PRNG key for two draws: the selection noise and the new
    # weight would otherwise be derived from the same random bits.
    k_pick, k_weight = random.split(key)

    r_idx = jnp.arange(MAX_NODES)[:, None]
    c_idx = jnp.arange(MAX_NODES)[None, :]
    # Only src < dst links are allowed, which keeps the graph feed-forward.
    is_forward = r_idx < c_idx
    active_matrix = genome.node_active[:, None] * genome.node_active[None, :]
    candidate_mask = is_forward * active_matrix * (1.0 - genome.mask)
    # Input nodes never receive incoming connections.
    candidate_mask = candidate_mask.at[:, IDX_INPUT].set(0.0)
    # NOTE(review): output nodes are not excluded as sources, so
    # output -> output links can be created. Confirm this is intended.
    has_candidates = jnp.sum(candidate_mask) > 0.5

    # Uniform random choice among candidates via argmax of masked noise.
    noise = random.uniform(k_pick, candidate_mask.shape)
    logits = jnp.where(candidate_mask > 0.5, noise, -1e9)
    flat_idx = jnp.argmax(logits.ravel())
    row = flat_idx // MAX_NODES
    col = flat_idx % MAX_NODES

    new_mask = genome.mask.at[row, col].set(1.0)
    new_weight_val = random.normal(k_weight) * SIGMA
    new_weights = genome.weights.at[row, col].set(new_weight_val)

    # With no candidates the argmax is meaningless, so fall back to the input.
    return tree_map(lambda x, y: jnp.where(has_candidates, x, y),
                    Genome(new_weights, new_mask, genome.bias, genome.node_active), genome)

@jit
def mutate_add_node(key, genome):
    """Split one enabled link by inserting a new hidden node on it.

    A random enabled link ``src -> dst`` is picked, then a random inactive
    node slot strictly between ``src`` and ``dst`` is activated. The original
    link is disabled and replaced by ``src -> new`` (weight 1.0) and
    ``new -> dst`` (the old weight); the new node's bias is reset to 0.

    Args:
        key: JAX PRNG key.
        genome: a single (unbatched) Genome.

    Returns:
        The mutated Genome, or the input genome unchanged when the genome
        has no enabled link or the chosen link has no free slot between its
        endpoints.
    """
    link_key, slot_key = random.split(key)
    enabled = genome.mask

    # --- choose the link to split (uniform over enabled links) ---
    link_scores = jnp.where(
        enabled > 0.5, random.uniform(link_key, enabled.shape), -1e9
    )
    chosen = jnp.argmax(link_scores.ravel())
    src = chosen // MAX_NODES
    dst = chosen % MAX_NODES
    link_exists = enabled[src, dst] > 0.5
    carried_weight = genome.weights[src, dst]

    # --- choose a free slot strictly between src and dst ---
    slots = jnp.arange(MAX_NODES)
    is_free = (1.0 - genome.node_active) > 0.5
    eligible = (slots > src) & (slots < dst) & is_free
    slot_available = jnp.sum(eligible) > 0.5
    slot_scores = jnp.where(
        eligible > 0.5, random.uniform(slot_key, eligible.shape), -1e9
    )
    new_idx = jnp.argmax(slot_scores)

    # --- rewire: drop src->dst, add src->new (1.0) and new->dst (old weight) ---
    mask_out = (
        genome.mask.at[src, dst].set(0.0)
        .at[src, new_idx].set(1.0)
        .at[new_idx, dst].set(1.0)
    )
    weights_out = (
        genome.weights.at[src, dst].set(0.0)
        .at[src, new_idx].set(1.0)
        .at[new_idx, dst].set(carried_weight)
    )
    active_out = genome.node_active.at[new_idx].set(1.0)
    bias_out = genome.bias.at[new_idx].set(0.0)
    candidate = Genome(weights_out, mask_out, bias_out, active_out)

    # The edits above are computed unconditionally; discard them when the
    # mutation is not applicable.
    apply = link_exists & slot_available
    return tree_map(lambda new, old: jnp.where(apply, new, old), candidate, genome)

@jit
def apply_mutations(key, genome):
    """Apply the three mutation operators, each with its own probability.

    In order: weight perturbation (MUTATION_RATE_WEIGHT), add-link
    (MUTATION_RATE_ADD_LINK) and add-node (MUTATION_RATE_ADD_NODE). Each
    operator sees the result of the previous one.

    Args:
        key: JAX PRNG key.
        genome: a single (unbatched) Genome.

    Returns:
        The (possibly) mutated Genome.
    """
    # Six independent keys: one to decide whether each operator fires and a
    # separate one for the operator itself. Reusing the gate key inside the
    # operator would correlate "did it fire" with "what it did".
    gate_w, key_w, gate_l, key_l, gate_n, key_n = random.split(key, 6)

    def _keep(g):
        return g

    g = jax.lax.cond(random.uniform(gate_w) < MUTATION_RATE_WEIGHT,
                     lambda x: mutate_weight(key_w, x), _keep, genome)
    g = jax.lax.cond(random.uniform(gate_l) < MUTATION_RATE_ADD_LINK,
                     lambda x: mutate_add_link(key_l, x), _keep, g)
    g = jax.lax.cond(random.uniform(gate_n) < MUTATION_RATE_ADD_NODE,
                     lambda x: mutate_add_node(key_n, x), _keep, g)
    return g

@jit
def calculate_compatibility_distance(g1, g2):
    """Return a scalar topology/weight distance between two genomes.

    The distance is ``C1 * (#links enabled in exactly one genome) +
    C3 * (mean |weight difference| over links enabled in both)``. The mean
    term is 0 when the genomes share no enabled link.

    NOTE(review): C2 is unused and the mismatch count is not normalised by
    genome size, unlike the textbook NEAT formula. Confirm this is intended.
    Biases and node activity do not contribute.
    """
    shared = g1.mask * g2.mask
    shared_count = jnp.sum(shared)

    # Links present in one mask but not in the other.
    mismatch_count = jnp.sum(jnp.abs(g1.mask - g2.mask))

    abs_diff_total = jnp.sum(jnp.abs(g1.weights - g2.weights) * shared)
    # Guard the division for genome pairs without any shared link.
    mean_abs_diff = jnp.where(shared_count > 0, abs_diff_total / shared_count, 0.0)

    return C1 * mismatch_count + C3 * mean_abs_diff

@jit
def crossover_pair(key, p1, f1, p2, f2):
    """Produce one child from two parents.

    The fitter parent (p1 on ties) is the "primary": the child inherits its
    topology (mask and node_active) and every gene that the other parent
    does not share. Genes present in both parents are taken from either
    parent with probability 0.5:

    * weights: where the link is enabled in both parents;
    * biases: where the node is active in both parents.

    Args:
        key: JAX PRNG key.
        p1, p2: parent Genomes (single, unbatched).
        f1, f2: scalar fitness of p1 and p2.

    Returns:
        The child Genome.
    """
    p1_is_fitter = f1 >= f2
    primary = tree_map(lambda a, b: jnp.where(p1_is_fitter, a, b), p1, p2)
    secondary = tree_map(lambda a, b: jnp.where(p1_is_fitter, b, a), p1, p2)

    child_mask = primary.mask
    child_node_active = primary.node_active

    k_w, k_b = random.split(key, 2)

    matching_links = primary.mask * secondary.mask
    swap_prob_w = random.uniform(k_w, shape=primary.weights.shape)
    take_secondary_w = (matching_links > 0.5) & (swap_prob_w < 0.5)
    child_weights = jnp.where(take_secondary_w, secondary.weights, primary.weights)

    # Only swap biases of nodes that are active in BOTH parents. A node that
    # is inactive in the secondary parent carries a bias that was never
    # trained there, and must not overwrite the primary's bias.
    matching_nodes = primary.node_active * secondary.node_active
    swap_prob_b = random.uniform(k_b, shape=primary.bias.shape)
    take_secondary_b = (matching_nodes > 0.5) & (swap_prob_b < 0.5)
    child_bias = jnp.where(take_secondary_b, secondary.bias, primary.bias)

    return Genome(weights=child_weights, mask=child_mask, bias=child_bias, node_active=child_node_active)

@jit
def adjust_fitness_by_speciation(population, raw_fitness):
    """Apply explicit fitness sharing based on pairwise genome distance.

    Every genome's fitness is first shifted so that all values are positive
    (by ``|min| + 1``), then divided by the number of genomes, itself
    included, whose compatibility distance to it is below
    COMPATIBILITY_THRESH.

    Args:
        population: batched Genome (leading axis = individuals).
        raw_fitness: per-individual fitness, one value per genome.

    Returns:
        Per-individual adjusted fitness.
    """
    # All-pairs distances: outer vmap over the first genome, inner over the second.
    one_vs_all = vmap(calculate_compatibility_distance, in_axes=(None, 0))
    pairwise = vmap(one_vs_all, in_axes=(0, None))(population, population)

    same_niche = jnp.where(pairwise < COMPATIBILITY_THRESH, 1.0, 0.0)
    niche_size = jnp.sum(same_niche, axis=1)

    # Shift to strictly positive values so that sharing always penalises crowding.
    offset = jnp.abs(jnp.min(raw_fitness)) + 1.0
    positive_fitness = raw_fitness + offset

    return positive_fitness / niche_size

@jit
def evolve_step(key, population, fitness_scores):
    """Build the next generation from an evaluated population.

    * Elites: the top ``int(POP_SIZE * ELITE_RATIO)`` individuals by raw
      fitness are copied unchanged.
    * Offspring: the remaining slots are filled by crossover of two parents,
      each chosen by a size-2 tournament on speciation-adjusted fitness,
      followed by mutation.

    Args:
        key: JAX PRNG key.
        population: batched Genome with POP_SIZE individuals.
        fitness_scores: raw fitness, one value per individual.

    Returns:
        Batched Genome of POP_SIZE individuals; elites come first (in
        ascending fitness order), then the mutated offspring.
    """
    adj_fitness = adjust_fitness_by_speciation(population, fitness_scores)

    elite_count = int(POP_SIZE * ELITE_RATIO)
    num_offspring = POP_SIZE - elite_count
    sorted_indices = jnp.argsort(fitness_scores)
    # Slice from an explicit start index: `[-elite_count:]` would select the
    # WHOLE population when elite_count is 0.
    elite_indices = sorted_indices[POP_SIZE - elite_count:]
    elites = tree_map(lambda x: x[elite_indices], population)

    k_sel, k_cross, k_mut = random.split(key, 3)

    def get_parents(k, n):
        # The two contestants must come from independent keys. With a shared
        # key both index arrays are identical and the tournament degenerates
        # into uniform random selection (no selection pressure).
        k_a, k_b = random.split(k)
        idx1 = random.randint(k_a, (n,), 0, POP_SIZE)
        idx2 = random.randint(k_b, (n,), 0, POP_SIZE)
        winner = jnp.where(adj_fitness[idx1] > adj_fitness[idx2], idx1, idx2)
        return tree_map(lambda x: x[winner], population), adj_fitness[winner]

    k_p1, k_p2 = random.split(k_sel)
    p1, f1 = get_parents(k_p1, num_offspring)
    p2, f2 = get_parents(k_p2, num_offspring)

    offspring = vmap(crossover_pair)(random.split(k_cross, num_offspring), p1, f1, p2, f2)
    offspring_mutated = vmap(apply_mutations)(random.split(k_mut, num_offspring), offspring)

    next_population = tree_map(lambda e, o: jnp.concatenate([e, o], axis=0), elites, offspring_mutated)
    return next_population

@jit
def initialize_population(key):
    """Create POP_SIZE minimal genomes: inputs fully connected to outputs.

    Every genome starts with only input and output nodes active and with
    every input -> output link enabled. Weights and biases for all slots
    (including currently unused ones) are drawn from a normal distribution
    scaled by 0.5.

    Args:
        key: JAX PRNG key.

    Returns:
        A batched Genome whose arrays have a leading axis of length POP_SIZE.
    """
    first_output = MAX_NODES - OUTPUT_SIZE

    def make_genome(genome_key):
        weight_key, bias_key = random.split(genome_key)
        weights = 0.5 * random.normal(weight_key, (MAX_NODES, MAX_NODES))
        bias = 0.5 * random.normal(bias_key, (MAX_NODES,))

        active = (
            jnp.zeros((MAX_NODES,))
            .at[IDX_INPUT].set(1.0)
            .at[IDX_OUTPUT].set(1.0)
        )

        # Enable the whole input x output block in a single update.
        links = jnp.zeros((MAX_NODES, MAX_NODES))
        links = links.at[:INPUT_SIZE, first_output:].set(1.0)

        return Genome(weights, links, bias, active)

    return vmap(make_genome)(random.split(key, POP_SIZE))

@jit
def jax_forward_dag(weights, mask, bias, node_active, x):
    """Evaluate one genome's feed-forward network on one observation.

    Nodes are visited in increasing index order starting after the inputs.
    Each node's value is ``tanh(sum of enabled incoming activations * weight
    + bias)``; inactive nodes are forced to 0. Because a node only reads the
    values written so far, only links from lower to higher indices have an
    effect.

    Args:
        weights, mask, bias, node_active: the arrays of a single Genome.
        x: observation vector written into the input node slots.

    Returns:
        Activations of the output nodes (the slots listed in IDX_OUTPUT).
    """
    def visit(node, acts):
        incoming = weights[:, node] * mask[:, node]
        value = jnp.tanh(jnp.dot(acts, incoming) + bias[node])
        return acts.at[node].set(value * node_active[node])

    initial = jnp.zeros(MAX_NODES).at[IDX_INPUT].set(x)
    final = lax.fori_loop(INPUT_SIZE, MAX_NODES, visit, initial)
    return final[IDX_OUTPUT]


# Forward pass for a whole population: every argument (the four genome arrays
# and the observations) is mapped over its leading axis.
batch_forward = vmap(jax_forward_dag, in_axes=(0, 0, 0, 0, 0))

# Single shared environment instance, created at import time and captured by
# rollout_batch. max_steps equals the scan length used in rollout_batch.
GLOBAL_ENV = SlimeVolley(test=False, max_steps=3000)

@jit
def rollout_batch(key, population):
    """Play one fixed-length rollout for every genome and sum the rewards.

    Each genome controls one environment instance for 3000 steps. Outputs
    greater than 0 are turned into binary actions. Two totals are returned:
    the environment reward plus shaping bonuses, and the environment reward
    alone.

    Args:
        key: JAX PRNG key used to seed the POP_SIZE environment resets.
        population: batched Genome with POP_SIZE individuals.

    Returns:
        ``(total_with_bonus, total_pure)``, each with one value per genome.
    """
    reset_keys = random.split(key, POP_SIZE)
    start_state = GLOBAL_ENV.reset(reset_keys)

    def step_fn(state, _):
        obs = state.obs
        # Per the original variable names, column 0 is the agent's x position
        # and column 4 the ball's x position -- TODO confirm against the env.
        agent_x = obs[:, 0]
        ball_x = obs[:, 4]

        net_out = batch_forward(population.weights, population.mask,
                                population.bias, population.node_active, obs)
        action = (net_out > 0.0).astype(jnp.float32)

        # NOTE(review): the `done` flag is ignored, so rewards keep
        # accumulating for all 3000 steps even after an episode reports
        # termination. Verify this matches the env's semantics.
        next_state, reward, _done = GLOBAL_ENV.step(state, action)

        # Shaping terms (computed from the observation BEFORE the step).
        # R_ball: staying horizontally close to the ball.
        bonus_ball = jnp.maximum(0.0, (2.0 - jnp.abs(agent_x - ball_x)) * 0.005)
        # R_forward: grows linearly with agent_x.
        bonus_forward = (agent_x + 1.0) * 0.002
        # R_return: ball_x is negative.
        bonus_opponent_court = jnp.where(ball_x < 0.0, 0.005, 0.0)
        # Epsilon: constant per-step bonus.
        survival_bonus = 0.0001

        shaped = reward + bonus_ball + bonus_forward + survival_bonus + bonus_opponent_court
        return next_state, (shaped, reward)

    _, (shaped_per_step, pure_per_step) = lax.scan(step_fn, start_state, None, length=3000)

    # Per-step arrays are stacked along axis 0 (time); sum it away.
    return jnp.sum(shaped_per_step, axis=0), jnp.sum(pure_per_step, axis=0)

def draw_genome_network(genome, idx=0, filename=None, gen=None):
    """Draw the network of individual ``idx`` as a layered graph.

    Only active nodes and enabled links between them are drawn. Inputs are
    placed in layer 0, outputs in the last layer, and hidden nodes in a layer
    derived from their predecessors. Positive weights are red, non-positive
    weights blue, and edge width grows with |weight| (capped at 3).

    Args:
        genome: batched Genome (population); individual ``idx`` is drawn.
        idx: index of the individual within the population.
        filename: if given, the figure is also saved to this path.
        gen: generation number shown in the title.
    """
    w = np.asarray(genome.weights[idx])
    m = np.asarray(genome.mask[idx])
    na = np.asarray(genome.node_active[idx])

    first_output = MAX_NODES - OUTPUT_SIZE
    active_nodes = [i for i in range(MAX_NODES) if na[i] > 0.5]

    # --- assign a layer (depth) to every active node ---
    node_depths = {i: 0 for i in active_nodes if i < INPUT_SIZE}

    # A node gets depth = 1 + max depth of its already-placed predecessors,
    # the first time at least one such predecessor exists.
    changed = True
    while changed:
        changed = False
        for j in active_nodes:
            if j < INPUT_SIZE or j in node_depths:
                continue
            incoming = [node_depths[src] for src in active_nodes
                        if src < j and m[src, j] > 0.5 and src in node_depths]
            if incoming:
                node_depths[j] = int(max(incoming)) + 1
                changed = True

    # Hidden nodes that nothing reaches are parked in layer 1.
    for j in active_nodes:
        if j not in node_depths and j < first_output:
            node_depths[j] = 1

    # All outputs share one final layer to the right of everything else.
    max_depth = max(node_depths.values()) if node_depths else 1
    for j in active_nodes:
        if j >= first_output:
            node_depths[j] = max_depth + 1

    # --- build the graph: nodes first, so edges never add nodes implicitly ---
    G = nx.DiGraph()
    for i in active_nodes:
        if i < INPUT_SIZE:
            color, label = 'lightgreen', f"In{i}"
        elif i >= first_output:
            color, label = 'salmon', "Out"
        else:
            color, label = 'skyblue', str(i)
        # Every active node has a depth at this point, so `layer` is always
        # an int and no fallback pass is needed before the layout.
        G.add_node(i, color=color, layer=node_depths[i], label=label)

    edge_colors = []
    edge_widths = []
    for i in active_nodes:
        for j in active_nodes:
            if m[i, j] > 0.5:
                w_ij = float(w[i, j])
                G.add_edge(i, j)
                edge_colors.append('red' if w_ij > 0 else 'blue')
                edge_widths.append(min(3.0, abs(w_ij)) * 0.8 + 0.2)

    pos = nx.multipartite_layout(G, subset_key='layer')

    # Look attributes up directly instead of rebuilding the dict per node.
    node_colors = [G.nodes[n]['color'] for n in G.nodes()]
    labels = nx.get_node_attributes(G, 'label')

    fig = plt.figure(figsize=(10, 6))
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=600, edgecolors='gray', alpha=0.9)
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=9)
    nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=edge_widths,
                           arrows=True, arrowstyle='-|>', arrowsize=14,
                           connectionstyle="arc3,rad=0.08")
    # NOTE(review): the "/200" total is hard-coded and does not match the
    # number of generations run by main().
    plt.title(f"Best Genome Network Structure (Gen {gen}/200)")
    plt.axis('off')
    if filename:
        plt.savefig(filename, dpi=160)
    plt.show()
    # Release the figure; otherwise repeated calls accumulate open figures
    # on non-inline backends.
    plt.close(fig)

def main(num_generations=109, seed=42):
    """Run the evolutionary training loop.

    Each generation evaluates the population, logs statistics, and evolves
    the next population. Every 10th generation the best evaluated genome is
    plotted and saved to ``best_genome_gen_<gen>.png``.

    Args:
        num_generations: number of generations to run (must be >= 1).
        seed: seed for the JAX PRNG.

    Returns:
        ``(population, fitness_np_pure, max_pure_hist, mean_pure_hist)``:
        the population produced by the last evolve step, the pure scores of
        the last EVALUATED population (i.e. the parents of the returned
        one), and the per-generation max / mean pure-score histories.

    Raises:
        ValueError: if ``num_generations`` is smaller than 1.
    """
    if num_generations < 1:
        raise ValueError("num_generations must be at least 1")

    key = random.PRNGKey(seed)
    key, subkey = random.split(key)

    population = initialize_population(subkey)
    print(f"--- Slime Volley NEAT (Full GPU / EvoJAX) Started (Pop: {POP_SIZE}) ---")

    max_pure_hist = []
    mean_pure_hist = []

    for gen in range(1, num_generations + 1):
        start_time = time.time()

        # 1. Evaluation
        key, subkey = random.split(key)
        fitness_scores_with_bonus, fitness_scores_pure = rollout_batch(subkey, population)

        fitness_np_with_bonus = jax.device_get(fitness_scores_with_bonus)
        fitness_np_pure = jax.device_get(fitness_scores_pure)

        max_fit_bonus = np.max(fitness_np_with_bonus)
        mean_fit_bonus = np.mean(fitness_np_with_bonus)

        max_fit_pure = np.max(fitness_np_pure)
        max_pure_hist.append(max_fit_pure)
        mean_fit_pure = np.mean(fitness_np_pure)
        mean_pure_hist.append(mean_fit_pure)

        # 2. Reproduction (elitism + crossover + mutation).
        # Keep a handle on the population that was just evaluated: the
        # fitness arrays (and therefore best_idx below) index into IT, not
        # into the next generation returned by evolve_step.
        evaluated_population = population
        key, subkey = random.split(key)
        population = evolve_step(subkey, evaluated_population, fitness_scores_with_bonus)

        elapsed = time.time() - start_time
        print(
            f"Gen {gen:3d} | BonusFit: {max_fit_bonus:6.2f} (Mean: {mean_fit_bonus:6.2f}) "
            f"| PureScore: {max_fit_pure:6.2f} (Mean: {mean_fit_pure:6.2f}) | Time: {elapsed:.3f}s"
        )

        if gen % 10 == 0:
            best_idx = int(np.argmax(fitness_np_with_bonus))
            draw_genome_network(evaluated_population, best_idx, f"best_genome_gen_{gen}.png", gen)

    return population, fitness_np_pure, max_pure_hist, mean_pure_hist


# In [None]:
# Run the training loop and keep its results as module-level (notebook) variables.
population, fitness_np_pure, max_pure_hist, mean_pure_hist = main()
