# Biomaker CA: End-to-end meta-evolution

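In this Colab we show how to perform end-to-end meta-evolution on an environment configuration: PGPE searches over the initial agent parameters, and each candidate is scored by running a full simulation in which agents reproduce and their programs mutate.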

Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

## Imports

In [None]:
#@title install selforg package
# install the package locally
# NOTE: `!pip` is IPython/Colab shell syntax, so this cell only runs inside a notebook.
!pip install --upgrade -e git+https://github.com/google-research/self-organising-systems.git#egg=self_organising_systems&subdirectory=biomakerca
# activate the locally installed package (otherwise a runtime restart is required)
# NOTE(review): pkg_resources is deprecated in recent setuptools releases; this
# activation trick may need updating if the runtime's setuptools drops it.
import pkg_resources
import importlib
# Reload the resources because we uninstalled and reinstalled some packages.
importlib.reload(pkg_resources)
pkg_resources.get_distribution("self_organising_systems").activate()

In [None]:
#@title imports & notebook utilities
from self_organising_systems.biomakerca import environments as evm
from self_organising_systems.biomakerca.agent_logic import BasicAgentLogic
from self_organising_systems.biomakerca.mutators import BasicMutator
from self_organising_systems.biomakerca.mutators import RandomlyAdaptiveMutator
from self_organising_systems.biomakerca.step_maker import step_env
from self_organising_systems.biomakerca.display_utils import zoom
from self_organising_systems.biomakerca.custom_ipython_display import display

import cv2
import numpy as np
import jax.random as jr
import jax.numpy as jp
from jax import vmap
from jax import jit
import jax
import time

import tqdm
import mediapy as media
from functools import partial
from IPython.display import clear_output
import matplotlib.pyplot as plt


def pad_text(img, text, *, origin=(5, 15), font_scale=0.5, color=(0, 0, 0),
             thickness=1):
  """Prepend a header strip to `img` and write `text` on it.

  The strip is img.shape[0] // 15 rows tall, rounded up to the next even
  number. An input image with an even height therefore keeps an even height,
  which some video renderers require.

  Args:
    img: an (H, W, C) image array. The strip is filled with ones, so this
      presumably expects float images in [0, 1] (TODO confirm with callers).
    text: the caption to draw.
    origin: (x, y) bottom-left corner of the text in the padded image.
      NOTE: when the strip is shorter than origin[1], the caption overlaps the
      top of the original image.
    font_scale, color, thickness: forwarded to cv2.putText.

  Returns:
    The padded image with the caption drawn on it.
  """
  font = cv2.FONT_HERSHEY_SIMPLEX

  # Keep the output height even (this assumes the input height was even).
  pad_h = img.shape[0] // 15
  if pad_h % 2:
    pad_h += 1
  header = np.ones([pad_h, img.shape[1], img.shape[2]])
  img = np.concatenate([header, img], 0)
  return cv2.putText(img, text, origin, font, font_scale, color, thickness,
                     cv2.LINE_AA)

## Select the configuration, the agent logic and the mutator

In [None]:
ec_id = "pestilence" #@param ['persistence', 'pestilence', 'collaboration', 'sideways']
env_width_type = "landscape" #@param ['wide', 'landscape', 'square', 'petri']
env_and_config = evm.get_env_and_config(ec_id, width_type=env_width_type)
st_env, config = env_and_config

agent_model = "extended" #@param ['minimal', 'extended']
agent_logic = BasicAgentLogic(config, minimal_net=agent_model=="minimal")

mutator_type = "basic" #@param ['basic', 'randomly_adaptive']
# Initial mutation standard deviation: 1e-2 for the basic mutator on the
# minimal agent model, 1e-3 otherwise.
# NOTE(review): agent_model can only be 'minimal' or 'extended', so 'minimal'
# is the only model the larger sd can be meant for; confirm the intent.
sd = 1e-2 if mutator_type == "basic" and agent_model == "minimal" else 1e-3
mutator = (BasicMutator(sd=sd, change_perc=0.2) if mutator_type == "basic"
           else RandomlyAdaptiveMutator(init_sd=sd, change_perc=0.2))

## End-to-end meta-evolution

In [None]:
def count_agents_f(env, etd):
  """Count the cells of `env.type_grid` that `etd.is_agent_fn` flags as agents."""
  agent_mask = etd.is_agent_fn(env.type_grid)
  return agent_mask.sum()

@partial(jit, static_argnames=["config", "agent_logic", "mutator", "n_steps", "n_max_programs"])
def evaluate_biome(key, st_env, config, agent_logic, mutator, n_steps,
                   init_program=None, n_max_programs=128):
  """Simulate one biome for `n_steps` steps and measure how populated it stays.

  Each step calls `step_env` with reproduction and program mutation enabled.
  The number of agent cells after that step is added to a running total.

  Args:
    key: JAX PRNG key. It is split once per step, and twice more when the
      programs are initialized from scratch.
    st_env: the environment the simulation starts from.
    config: environment config, passed to `step_env` and read for
      `config.etd`. It is static under jit, so it must be hashable.
    agent_logic: agent logic passed to `step_env` (static).
    mutator: mutator passed to `step_env` (static).
    n_steps: number of simulation steps (static).
    init_program: optional parameter vector.
      - If given, it is copied into all `n_max_programs` program slots.
      - If None, each slot gets its own program from
        `agent_logic.initialize`, followed by `mutator.initialize`.
    n_max_programs: number of program slots (static).

  Returns:
    A pair `(tot_agents_n, is_extinct)`:
      - tot_agents_n: the agent-cell count summed over all steps.
      - is_extinct: an int32 that is 1 if no agents remain after the last
        step, otherwise 0.
  """
  def body_f(i, carry):
    # One simulation step; the loop index `i` is unused.
    key, env, programs, tot_agents_n = carry
    ku, key = jr.split(key)

    env, programs = step_env(
        ku, env, config, agent_logic, programs, do_reproduction=True,
          mutate_programs=True, mutator=mutator)

    tot_agents_n += count_agents_f(env, config.etd)
    return key, env, programs, tot_agents_n

  # This is a plain Python branch, so it is resolved when jit traces the function.
  if init_program is None:
    ku, key = jr.split(key)
    programs = vmap(agent_logic.initialize)(jr.split(ku, n_max_programs))
    ku, key = jr.split(key)
    # Pass every freshly initialized program through mutator.initialize too.
    programs = vmap(mutator.initialize)(jr.split(ku, programs.shape[0]), programs)
  else:
    # Add a leading axis and copy the single program into every slot.
    programs = jp.repeat(init_program[None, :], n_max_programs, axis=0)

  key, env, programs, tot_agents_n = jax.lax.fori_loop(
      0, n_steps, body_f, (key, st_env, programs, 0))

  # check whether they got extinct:
  is_extinct = (count_agents_f(env, config.etd) == 0).astype(jp.int32)
  return tot_agents_n, is_extinct

In [None]:
from evojax.algo import PGPE

key = jr.PRNGKey(137)

n_steps = 1000

# Initialize a single agent program, then pass it through the mutator's
# initializer. The result is the starting center of the PGPE search.
ku, key = jr.split(key)
init_program = agent_logic.initialize(ku)
ku, key = jr.split(key)
init_program = mutator.initialize(ku, init_program)


pop_size = 32
ku, key = jr.split(key)
solver = PGPE(
    pop_size=pop_size,
    param_size=init_program.shape[0],
    optimizer='adam',
    center_learning_rate=0.001,
    stdev_learning_rate=0.001,
    stdev_max_change=0.002,
    seed=ku[0],
    init_params=init_program,
    init_stdev=0.001,
)

# Program slots per simulated biome during meta-evolution. This is lower than
# evaluate_biome's default of 128, presumably to make each evaluation cheaper.
n_max_programs = 64
# Subtracted from the fitness of any biome that ends up extinct.
death_penalty = 1e6

@jit
def v_fitness_f(key, v_params):
  """Score a population of initial programs; higher is better.

  Each row of `v_params` seeds its own biome, which is simulated for
  `n_steps` steps with an independent key. Fitness is the number of agent
  cells summed over all steps, minus `death_penalty` if the biome went
  extinct.
  """
  # vmap maps keyword arguments over their leading axis, so each biome gets
  # its own row of v_params as `init_program`. Arguments bound by partial are
  # closed over and shared by all biomes.
  tot_agents_n, is_extinct = vmap(partial(
      evaluate_biome, st_env=st_env, config=config, agent_logic=agent_logic,
      mutator=mutator, n_steps=n_steps, n_max_programs=n_max_programs))(
          jr.split(key, pop_size), init_program=v_params)
  fitness = tot_agents_n - is_extinct * death_penalty
  return fitness

mean_fit_log = []
max_fit_log = []

In [None]:

for generation in range(30):
  # Ask PGPE for a new population of candidate initial programs.
  programs = solver.ask()
  # Score every candidate with a full simulation, using a fresh key.
  key, eval_key = jr.split(key)
  fitness = v_fitness_f(eval_key, programs)
  # Report the scores so PGPE can update its search distribution.
  solver.tell(fitness)

  mean_fitness, max_fitness = fitness.mean(), fitness.max()
  mean_fit_log.append(mean_fitness)
  max_fit_log.append(max_fitness)
  print(mean_fitness, max_fitness)

  # Redraw the learning curves only every 10 generations.
  if len(mean_fit_log) % 10 != 0:
    continue
  clear_output()
  for series, series_label in ((mean_fit_log, "mean_fitness"),
                               (max_fit_log, "max_fitness")):
    plt.plot(series, label=series_label)
  plt.legend()
  plt.show()

## Evaluate the result

In [None]:
n_reps = 16
key = jr.PRNGKey(123)

# Evaluate the best parameters found by PGPE over n_reps independent seeds.
# NOTE: n_max_programs is left at evaluate_biome's default (128). This differs
# from the 64 used during meta-evolution.
evaluate_best = jit(vmap(partial(
    evaluate_biome, st_env=st_env, config=config, agent_logic=agent_logic,
    mutator=mutator, n_steps=n_steps, init_program=solver.best_params)))

# The measured time includes jit compilation on the first call.
t_st = time.time()
key, ku = jr.split(key)
b_tot_agents_n, b_is_extinct = evaluate_best(jr.split(ku, n_reps))
# JAX dispatches work asynchronously. Wait for the results before reading the
# clock, otherwise the timing can stop before the computation has finished.
b_tot_agents_n.block_until_ready()
b_is_extinct.block_until_ready()
print("Took", time.time()-t_st, "seconds")
print("Total number of agents", b_tot_agents_n, b_tot_agents_n.mean(), b_tot_agents_n.std())
print("Extinction events", b_is_extinct, b_is_extinct.mean(), b_is_extinct.std())

## Show an example run of the result

Modify the parameters in the cell below to change the length of the simulation and the video settings.

In [None]:

key = jr.PRNGKey(43)

# How many unique programs (organisms) are allowed in the simulation.
N_MAX_PROGRAMS = 128

# The number of frames of the video. This is NOT the number of steps.
# The total number of steps depends on the number of steps per frame, which
# can vary over time.
# In the article, we generally use 500 or 750 frames.
n_frames = 500

# On which FRAME indices to double the speed. Frames are indexed
# 0..n_frames-1, so entries >= n_frames never trigger. For example, 500 has no
# effect when n_frames = 500 and only matters for longer videos.
when_to_double_speed = [100, 200, 300, 400, 500]
# On which FRAME indices to reset the speed to 1 step per frame.
when_to_reset_speed = []
fps = 20
# This affects the size of the image. If this number is not even, the resulting
# video *may* not be supported by all renderers.
zoom_sz = 4

# Fill every program slot with the best parameters found by the solver.
programs = jp.repeat(solver.best_params[None,:], N_MAX_PROGRAMS, axis=0)

env = st_env

def make_frame(env, step, speed):
  """Render `env` as a zoomed image captioned with the current step and speed."""
  scaled = zoom(evm.grab_image_from_env(env, config), zoom_sz)
  caption = "Step {:<7} Speed: {}x".format(step, speed)
  return pad_text(scaled, caption)

step = 0
# Steps simulated per video frame at the start. This is usually doubled
# several times during the run, following when_to_double_speed.
# In the article, we usually use 2 or 4 as the starting value, sometimes 1.
steps_per_frame = 2

frame = make_frame(env, step, steps_per_frame)

out_file = "video.mp4"
with media.VideoWriter(out_file, shape=frame.shape[:2], fps=fps) as video:
  # The first frame shows the initial state, before any step is taken.
  video.add_image(frame)
  for frame_idx in tqdm.trange(n_frames):
    # Apply the speed schedule. A reset on the same frame wins over a doubling.
    if frame_idx in when_to_double_speed:
      steps_per_frame *= 2
    if frame_idx in when_to_reset_speed:
      steps_per_frame = 1
    for _ in range(steps_per_frame):
      key, ku = jr.split(key)
      env, programs = step_env(
          ku, env, config, agent_logic, programs, do_reproduction=True,
          mutate_programs=True, mutator=mutator)
    step += steps_per_frame
    video.add_image(make_frame(env, step, steps_per_frame))

media.show_video(media.read_video(out_file))