# MAP-Elites for procedural track generation

In [None]:
import sys
import numpy as np
import requests
import random
import matplotlib.pyplot as plt
from dask.distributed import Client, LocalCluster
from ribs.archives import GridArchive
from ribs.emitters import EmitterBase
from ribs.schedulers import Scheduler
from ribs.visualize import grid_archive_heatmap

In [None]:
# Base URL of the track service (/generate, /evaluate, /mutate, /crossover).
BASE_URL = 'http://localhost:4242'

# Genome layout: POINTS_COUNT (x, y) points, then MAX_SELECTED_CELLS (x, y)
# cells, then one trailing slot holding the solution ID.
POINTS_COUNT = 50
MAX_SELECTED_CELLS = 10
SOLUTION_DIM = 2 * (POINTS_COUNT + MAX_SELECTED_CELLS) + 1

# Inclusive bounds passed to random.randint when requesting a new track.
TRACK_SIZE_RANGE = (2, 6)
# Not used by the code in this notebook as shown.
LENGTH_RANGE = (400, 2000)

# Search budget and archive resolution.
ITERATIONS = 500                    # scheduler ask/tell cycles
ARCHIVE_DIM = 3                     # cells per archive axis
INIT_POPULATION = ARCHIVE_DIM ** 2  # random-generation rounds; also emitter batch size

DEBUG_CROSSOVER = True
DEBUG_MUTATION = True

In [None]:
# Local Dask cluster used by run_map_elites to evaluate solutions in parallel:
# five worker processes with one thread each.
cluster = LocalCluster(
    n_workers=5,
    threads_per_worker=1,
    processes=True,
)
client = Client(cluster)

### Helper functions

In [None]:
def generate_solution(iteration):
    """Request a fresh random Voronoi track from the ``/generate`` endpoint.

    The request ID is ``iteration`` plus a random fraction in [0, 1): the
    integer part tags the generation and the fraction tells siblings apart.
    The track size is drawn with ``random.randint`` over ``TRACK_SIZE_RANGE``.

    Returns the decoded JSON response, or ``None`` if the request fails.
    """
    print(f"Generating solution for iteration {iteration}")
    low, high = TRACK_SIZE_RANGE
    payload = {
        "id": iteration + random.random(),
        "mode": "voronoi",
        "trackSize": random.randint(low, high),
    }
    try:
        reply = requests.post(f"{BASE_URL}/generate", json=payload, timeout=60)
        reply.raise_for_status()
        return reply.json()
    except requests.RequestException as err:
        print(f"Error generating solution for iteration {iteration}: {err}")
        return None

def solution_to_array(sol):
    """Flatten a track dict into the fixed-length genome stored in the archive.

    Layout of the returned array (length ``SOLUTION_DIM``):
      * first ``2 * POINTS_COUNT`` slots: ``dataSet`` points as interleaved x, y;
      * next ``2 * MAX_SELECTED_CELLS`` slots: ``selectedCells`` as x, y;
      * last slot: the solution ``id``.
    Unused slots stay 0. Points/cells beyond their section's capacity are
    dropped so they can never spill into the next section, overwrite the ID
    slot, or index past the end of the array.

    Returns ``None`` when ``sol`` is ``None``.
    """
    if sol is None:
        return None
    arr = np.zeros(SOLUTION_DIM)

    # `or []` also covers an explicit null coming from the service.
    points = sol.get("dataSet") or []
    for i, p in enumerate(points[:POINTS_COUNT]):
        arr[2 * i] = p.get("x", 0)
        arr[2 * i + 1] = p.get("y", 0)

    cell_offset = 2 * POINTS_COUNT
    cells = sol.get("selectedCells") or []
    for i, c in enumerate(cells[:MAX_SELECTED_CELLS]):
        arr[cell_offset + 2 * i] = c.get("x", 0)
        arr[cell_offset + 2 * i + 1] = c.get("y", 0)

    arr[-1] = sol.get("id", 0)
    return arr

def array_to_solution(arr):
    """Rebuild the track dict from a genome (inverse of ``solution_to_array``).

    Every ``dataSet`` slot pair is emitted, zeros included. A
    ``selectedCells`` slot pair is kept only when it is not exactly (0, 0),
    because zero marks an unused slot. As a consequence, a genuine cell at
    the origin is dropped.
    """
    cells_start = 2 * POINTS_COUNT
    data_set = [
        {"x": float(arr[j]), "y": float(arr[j + 1])}
        for j in range(0, cells_start, 2)
    ]
    selected = [
        {"x": float(arr[j]), "y": float(arr[j + 1])}
        for j in range(cells_start, SOLUTION_DIM - 1, 2)
        if arr[j] != 0 or arr[j + 1] != 0
    ]
    return {
        "id": float(arr[-1]),
        "mode": "voronoi",
        "dataSet": data_set,
        "selectedCells": selected,
    }

def get_fractional_part(x):
    """Return ``x`` minus its integer part, truncated toward zero.

    For the non-negative solution IDs this is the fractional lineage tag in
    [0, 1). For negative inputs the result is <= 0 (e.g. -1.5 -> -0.5).
    """
    whole = int(x)
    return x - whole

def evaluate_solution(sol):
    """Score one track via the ``/evaluate`` endpoint.

    Failures are reported in the returned tuple rather than raised. An
    exception here would surface from ``client.gather`` in the driver loop
    and abort the whole iteration.

    Returns:
        ``(solution_id, success, message, score, measures)``. On success,
        ``score = speed_entropy + curvature_entropy - 0.01 * gaps_mean`` and
        ``measures = [speed_entropy, gaps_mean]``. On failure, ``score`` is
        -9999 and ``measures`` is ``[0, 0]``.
    """
    solution_id = sol.get("id", 0)

    def _failure(message):
        return solution_id, False, message, -9999, [0, 0]

    try:
        response = requests.post(f"{BASE_URL}/evaluate", json=sol, timeout=60)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        return _failure(str(e))

    # A non-object body or a missing/null "fitness" is a failed evaluation,
    # not a zero score; calling .get() on it would otherwise raise here.
    fit = data.get("fitness") if isinstance(data, dict) else None
    if not isinstance(fit, dict):
        return _failure("Missing or malformed fitness in response")

    s = fit.get("speed_entropy", 0)
    c = fit.get("curvature_entropy", 0)
    g = fit.get("gaps_mean", 0)
    # Reject NaN/inf as well as non-numbers: they would poison both the score
    # and the archive measures.
    if not all(isinstance(x, (int, float)) and np.isfinite(x) for x in (s, c, g)):
        return _failure("Invalid fitness values")

    score = s + c - (0.01 * g)
    return solution_id, True, "", score, [s, g]

## Genetic operators

In [None]:
class CustomEmitter(EmitterBase):
    """Emitter that delegates variation to the track service.

    For the first ``INIT_POPULATION`` calls to :meth:`ask` it requests fresh
    random tracks from ``/generate``. After that, each call either mutates
    (``/mutate``) or crosses over (``/crossover``) elites sampled from the
    archive, with equal probability.

    Solution IDs encode lineage. The integer part is ``iteration - 1`` of the
    call that produced the solution; the fractional part is inherited from
    the parent(s). A solution that could not be produced is replaced by a
    placeholder row filled with -9999.
    """

    # Bound on resampling when looking for two parents with different IDs.
    # Without it, an archive holding a single elite (or only elites sharing
    # one ID) makes crossover loop forever.
    _MAX_PARENT_DRAWS = 50

    def __init__(self, archive, solution_dim, batch_size=ARCHIVE_DIM, bounds=None):
        super().__init__(archive, solution_dim=solution_dim, bounds=bounds)
        self.batch_size = batch_size
        # Number of ask() calls so far; also the generation tag used in IDs.
        self.iteration = 0

    @staticmethod
    def _placeholder():
        """Return the sentinel genome (all -9999) used for failed variations."""
        return np.full(SOLUTION_DIM, -9999)

    @staticmethod
    def _post_json(endpoint, payload, result_key):
        """POST ``payload`` to ``BASE_URL/endpoint`` and extract ``result_key``.

        Returns ``(value, None)`` when the response body is a JSON object whose
        ``result_key`` is itself an object. Otherwise returns
        ``(None, error_message)``: request failure, non-JSON body, or a
        missing/null/non-object ``result_key``.
        """
        try:
            response = requests.post(f"{BASE_URL}/{endpoint}", json=payload, timeout=60)
            response.raise_for_status()
            body = response.json()
        except (requests.RequestException, ValueError) as e:
            return None, str(e)
        value = body.get(result_key) if isinstance(body, dict) else None
        if not isinstance(value, dict):
            return None, f"response has no '{result_key}' object"
        return value, None

    def ask(self):
        """Return the next batch of candidate genomes as a 2-D array.

        The first ``INIT_POPULATION`` calls each generate ``batch_size`` random
        tracks. Later calls pick mutation or crossover with probability 0.5.
        """
        self.iteration += 1
        print(f"Emitter.ask() called for iteration {self.iteration}")
        if self.iteration <= INIT_POPULATION:
            out = []
            for _ in range(self.batch_size):
                arr = solution_to_array(generate_solution(self.iteration - 1))
                out.append(arr if arr is not None else self._placeholder())
            return np.array(out)
        if random.random() < 0.5:
            return self.mutate_solutions()
        return self.crossover_solutions()

    def mutate_solutions(self):
        """Mutate ``batch_size`` elites sampled from the archive via ``/mutate``.

        Each child keeps its parent's fractional ID tag; its integer part is
        the current generation (``iteration - 1``).
        """
        print(f"Mutating solutions for iteration {self.iteration}")
        parents = self.archive.sample_elites(self.batch_size)
        out = []
        for i in range(self.batch_size):
            sol = array_to_solution(parents["solution"][i])
            mutated, err = self._post_json(
                "mutate",
                {"individual": sol, "intensityMutation": 10},
                "mutated",
            )
            if mutated is None:
                # Previously a missing key produced an all-zero genome and a
                # null value crashed; both are now treated as failures.
                print(f"Error mutating solution ID={sol['id']}: {err}")
                out.append(self._placeholder())
                continue
            mutated["id"] = self.iteration - 1 + get_fractional_part(sol["id"])
            out.append(solution_to_array(mutated))
            print(f"Mutated ID={sol['id']} to ID={mutated['id']}")
        return np.array(out)

    def _sample_parent_pair(self):
        """Sample two elites, resampling (bounded) until their IDs differ.

        If no distinct pair is found within ``_MAX_PARENT_DRAWS`` attempts,
        the last pair drawn is used anyway.
        """
        for _ in range(self._MAX_PARENT_DRAWS):
            parents = self.archive.sample_elites(2)
            sol1 = array_to_solution(parents["solution"][0])
            sol2 = array_to_solution(parents["solution"][1])
            if sol1["id"] != sol2["id"]:
                return sol1, sol2
        print("Warning: no parent pair with distinct IDs found; crossing over anyway")
        return sol1, sol2

    def crossover_solutions(self):
        """Produce ``max(1, batch_size // 2)`` offspring via ``/crossover``.

        The child's fractional ID tag is the sum of its parents' tags mod 1;
        its integer part is the current generation (``iteration - 1``). At
        least one child is produced so the batch is never an empty 1-D array.
        """
        print(f"Crossover solutions for iteration {self.iteration}")
        out = []
        for _ in range(max(1, self.batch_size // 2)):
            sol1, sol2 = self._sample_parent_pair()
            offspring, err = self._post_json(
                "crossover",
                {"mode": "voronoi", "parent1": sol1, "parent2": sol2},
                "offspring",
            )
            if offspring is None:
                print(f"Error during crossover: {err}")
                out.append(self._placeholder())
                continue
            frac = (get_fractional_part(sol1["id"]) + get_fractional_part(sol2["id"])) % 1
            child_id = self.iteration - 1 + frac
            cells = offspring.get("sel") or []
            child_sol = {
                "id": child_id,
                "mode": "voronoi",
                "trackSize": len(cells),
                "dataSet": offspring.get("ds") or [],
                "selectedCells": cells,
            }
            out.append(solution_to_array(child_sol))
            print(f"Crossover Parent1 ID={sol1['id']}, Parent2 ID={sol2['id']} => Child ID={child_id}")
        return np.array(out)

## Illuminating search spaces by mapping elites


In [None]:
# ARCHIVE_DIM x ARCHIVE_DIM grid. The first measure (speed entropy) is binned
# over [1, 4] and the second (mean gaps) over [0, 30].
archive = GridArchive(
    solution_dim=SOLUTION_DIM,
    dims=[ARCHIVE_DIM] * 2,
    ranges=[(1, 4), (0, 30)],
)

# Declared bounds: every coordinate slot in [0, 600]; the trailing ID slot is
# bounded only from below. CustomEmitter itself does not clip offspring to
# these bounds.
emitter = CustomEmitter(
    archive,
    solution_dim=SOLUTION_DIM,
    batch_size=INIT_POPULATION,
    bounds=[(0, 600) for _ in range(SOLUTION_DIM - 1)] + [(0, float("inf"))],
)
scheduler = Scheduler(archive, [emitter])

def _evaluate_batch(sols):
    """Evaluate genomes in parallel on the Dask client, preserving input order."""
    sol_dicts = [array_to_solution(s) for s in sols]
    return client.gather(client.map(evaluate_solution, sol_dicts))


def _log_archive_stats():
    """Print size, coverage, and objective mean/max of the global archive."""
    # archive.data() is a mapping of field name -> array (indexed by
    # "objective" below), so its len() counts fields, not elites. Check the
    # archive itself to avoid np.max() on an empty array.
    if len(archive) == 0:
        print("Archive empty so far")
        return
    arch_obj = archive.data()["objective"]
    print(
        f"Archive size={len(archive)}, Coverage={archive.stats.coverage:.3f}, "
        f"Mean={np.mean(arch_obj):.2f}, Best={np.max(arch_obj):.2f}"
    )


def _save_heatmap(iteration):
    """Save the archive heatmap for ``iteration`` to a PNG and close the figure."""
    plt.figure(figsize=(8, 6))
    grid_archive_heatmap(archive)
    plt.title(f"Archive Heatmap - Iteration {iteration}")
    plt.xlabel("Speed Entropy")
    plt.ylabel("Mean Gaps")
    plt.savefig(f"archive_heatmap_iter_{iteration}.png")
    plt.close()


def run_map_elites(iters, heatmap_every=5):
    """Run ``iters`` MAP-Elites ask/evaluate/tell cycles on the global scheduler.

    Each cycle asks the scheduler for a batch, evaluates it on the Dask
    cluster, and tells back objectives and measures. Failed evaluations and
    evaluations with non-finite values are recorded with objective -9999 and
    measures [0, 0], so ``tell`` never receives NaN/inf. Archive statistics
    are printed every cycle. A heatmap is saved every ``heatmap_every``
    cycles; a value <= 0 disables it. Any exception is logged and re-raised.
    """
    global_best_score = -9999
    global_best_id = None
    for i in range(iters):
        step = i + 1
        print(f"=== Starting iteration {step} ===")
        try:
            gathered = _evaluate_batch(scheduler.ask())

            obj_list, meas_list, failed_ids = [], [], []
            for sol_id, success, msg, score, measures in gathered:
                if success and np.isfinite(score) and np.all(np.isfinite(measures)):
                    print(f"Solution ID={sol_id} evaluated with score={score:.2f}")
                    if score > global_best_score:
                        global_best_score = score
                        global_best_id = sol_id
                else:
                    reason = msg or "non-finite score or measures"
                    print(f"Warning: clamping {score} to -9999 for solution ID={sol_id}, reason: {reason}")
                    score = -9999
                    # Reset measures too: non-finite measures would be
                    # rejected by the archive when the batch is told.
                    measures = [0, 0]
                    failed_ids.append(sol_id)
                obj_list.append(score)
                meas_list.append(measures)

            scheduler.tell(obj_list, meas_list)

            batch_best = max(obj_list) if obj_list else -9999
            print(f"Iteration {step} ended. Best in batch = {batch_best:.2f}")
            if global_best_id is not None:
                print(f"Global Best Score so far: {global_best_score:.2f} (ID={global_best_id})")

            _log_archive_stats()

            if failed_ids:
                print(f"Failed evaluations for solution IDs: {failed_ids}")

            if heatmap_every > 0 and step % heatmap_every == 0:
                _save_heatmap(step)

        except Exception as e:
            print(f"Error in iteration {step}: {e}")
            raise

In [None]:
# Launch the full search with the configured iteration budget.
run_map_elites(iters=ITERATIONS)

## Visualize Results

In [None]:
# Summarize the final archive, then save and display its heatmap.
print("All iterations complete.")
print(f"Final archive size={len(archive)}, Coverage={archive.stats.coverage:.3f}")

plt.figure(figsize=(6, 5))
grid_archive_heatmap(archive)
plt.gca().set(
    title="Final Archive Heatmap",
    xlabel="Speed Entropy",
    ylabel="Mean Gaps",
)
plt.savefig("final_archive_heatmap.png")
plt.show()