# Biomaker CA: interactive evolution

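In this Colab, we show how to perform interactive evolution on a configuration: at every generation you pick the candidate you prefer, and the next generation is made of mutated copies of it.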

Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

## Imports

In [None]:
#@title install selforg package
# Install the biomakerca subdirectory of self-organising-systems from GitHub in
# editable (-e) mode. `!pip` is IPython/Colab shell syntax, so this cell only
# runs inside a notebook.
!pip install --upgrade -e git+https://github.com/google-research/self-organising-systems.git#egg=self_organising_systems&subdirectory=biomakerca
# Activate the locally installed package (otherwise a runtime restart is required).
# NOTE(review): pkg_resources is deprecated in recent setuptools releases.
import pkg_resources
import importlib
# Reload pkg_resources because packages were (re)installed after it was first
# imported, so its view of the installed distributions may be stale.
importlib.reload(pkg_resources)
pkg_resources.get_distribution("self_organising_systems").activate()

In [None]:
#@title imports & notebook utilities
from self_organising_systems.biomakerca import environments as evm
from self_organising_systems.biomakerca.agent_logic import BasicAgentLogic
from self_organising_systems.biomakerca.mutators import BasicMutator
from self_organising_systems.biomakerca.mutators import RandomlyAdaptiveMutator
from self_organising_systems.biomakerca.step_maker import step_env
from self_organising_systems.biomakerca.display_utils import zoom
from self_organising_systems.biomakerca.display_utils import tile2d
from self_organising_systems.biomakerca.custom_ipython_display import display

import cv2
import numpy as np
import jax.random as jr
import jax.numpy as jp
from jax import vmap
from jax import jit
import jax
import time

import tqdm
import mediapy as media
from IPython.display import clear_output
from functools import partial
from functools import wraps
import ipywidgets as widgets
from jax.tree_util import tree_map
import matplotlib.pyplot as plt


def pad_text(img, text):
  """Return a copy of `img` with a banner on top containing `text`.

  The banner height is img.shape[0] // 15, rounded up to an even number, so an
  image with an even height keeps an even height (some video encoders require
  even frame sizes). The banner is filled with 1.0, which is white for images
  with values in [0, 1], and the text is drawn in black.

  Args:
    img: image array of shape [H, W] or [H, W, C].
    text: string to draw on the banner.

  Returns:
    The padded image, with `text` drawn on it. The banner is float64, so the
    result dtype is float64 promoted with img.dtype.
  """
  font = cv2.FONT_HERSHEY_SIMPLEX
  origin = (5, 15)
  font_scale = 0.5
  color = (0, 0, 0)
  thickness = 1

  # Preserve an even height (assumes the input height was even).
  new_h = img.shape[0] // 15
  new_h = new_h if new_h % 2 == 0 else new_h + 1
  # img.shape[1:] keeps the width and any channel axis, so both [H, W] and
  # [H, W, C] images are supported.
  banner = np.ones((new_h,) + tuple(img.shape[1:]))
  img = np.concatenate([banner, img], 0)
  img = cv2.putText(img, text, origin, font, font_scale, color, thickness,
                    cv2.LINE_AA)
  return img

## Select the configuration, the agent logic and the mutator

In [None]:
ec_id = "pestilence" #@param ['persistence', 'pestilence', 'collaboration', 'sideways']
env_width_type = "petri"
env_and_config = evm.get_env_and_config(ec_id, width_type=env_width_type)
st_env, config = env_and_config

agent_model = "extended" #@param ['minimal', 'extended']
agent_logic = BasicAgentLogic(config, minimal_net=agent_model=="minimal")

mutator_type = "basic" #@param ['basic', 'randomly_adaptive']
# Initial mutation standard deviation: 1e-2 for the basic mutator on the
# minimal network, 1e-3 for every other combination. agent_model only takes
# the values 'minimal' or 'extended', so the comparison must use one of them.
# NOTE(review): 'minimal' is assumed to be the model meant to get the larger
# sd — confirm.
sd = 1e-2 if mutator_type == "basic" and agent_model == "minimal" else 1e-3
mutator = (BasicMutator(sd=sd, change_perc=0.2) if mutator_type == "basic"
           else RandomlyAdaptiveMutator(init_sd=sd, change_perc=0.2))

## Interactive evolution

In [None]:
from evojax.algo import PGPE

# Root PRNG key for the interactive evolution session.
key = jr.PRNGKey(137)

# Number of candidate programs simulated and shown per generation.
N_CANDIDATES = 8

# Energy a reproduction must reach to be counted as successful. The
# coefficients are taken from the reproduce op initialization values of the
# basic agent logic.
min_repr_energy_requirement = (
    config.dissipation_per_step * 4 + config.specialize_cost * 2)

# Simulation steps per generation. This value works well for pestilence but
# likely not for other configurations, so consider tweaking it.
n_steps = config.max_lifetime - 100

# Frames per second of the displayed videos.
fps = 20


def yield_for_change(widget, attribute):
  """Drive a generator function with events coming from a list of widgets.

  Decorator factory. The decorated generator function becomes a zero-argument
  function that, when called:
    1. creates the generator;
    2. registers the same handler with `on_click` and with
       `observe(..., attribute)` on every widget in `widget`;
    3. advances the generator to its first `yield`.
  Every event received by the handler is sent into the generator as the value
  of the pending `yield`. For an `ipywidgets.Button` click, that value is the
  clicked button.

  When the generator finishes, the handler is removed from both the click and
  the trait callbacks of every widget. This stops later events from resuming
  an exhausted generator.

  Args:
    widget: sequence of widgets providing `on_click(cb, remove=False)`,
      `observe(cb, names)` and `unobserve(cb, names)`.
    attribute: trait name passed to `observe` / `unobserve`.

  Returns:
    A decorator for zero-argument generator functions.
  """
  def decorator(gen_fn):
    @wraps(gen_fn)
    def start():
      gen = gen_fn()

      def on_event(change):
        try:
          gen.send(change)
        except StopIteration:
          # Detach every handler that was registered below, not only the
          # observer, so clicks on an exhausted generator become no-ops.
          for w in widget:
            w.on_click(on_event, remove=True)
            w.unobserve(on_event, attribute)

      for w in widget:
        w.on_click(on_event)
        w.observe(on_event, attribute)
      # Run the generator up to its first `yield`.
      next(gen)
    return start
  return decorator


# Initial population: N_CANDIDATES freshly initialized agent programs, each
# extended by the mutator with its own parameters.
ku, key = jr.split(key)
init_keys = jr.split(ku, N_CANDIDATES)
programs = vmap(agent_logic.initialize)(init_keys)
ku, key = jr.split(key)
mutator_keys = jr.split(ku, N_CANDIDATES)
programs = vmap(mutator.initialize)(mutator_keys, programs)
chosen_program = programs[0]

# One copy of the starting environment per candidate, stacked along a new
# leading axis so that all candidates can be stepped together with vmap.
v_st_env = tree_map(
    lambda leaf: jp.repeat(leaf[None, :], N_CANDIDATES, axis=0), st_env)


@jit
def v_grab_images(v_env):
  """Render every environment of the batch `v_env` to an image.

  A jit-compiled, vmapped `evm.grab_image_from_env` using the global `config`.
  The images are returned stacked along the leading (batch) axis.
  """
  return vmap(lambda single_env: evm.grab_image_from_env(single_env, config))(
      v_env)

@jit
def v_step_env(key, v_env, programs):
  """Advance every environment of the batch by one step, without mutation.

  Each environment runs only its own program. step_env expects programs of
  shape [n, pars], so every program gets a leading singleton axis ([1, pars])
  before vmapping. Reproduction is intercepted and measured against the global
  `min_repr_energy_requirement`.

  Returns whatever the vmapped step_env returns. The caller unpacks it as
  (new batched envs, per-candidate number of successful reproductions).
  """
  step_one = partial(
      step_env,
      config=config,
      agent_logic=agent_logic,
      do_reproduction=True,
      mutate_programs=False,
      intercept_reproduction=True,
      min_repr_energy_requirement=min_repr_energy_requirement)
  per_env_keys = jr.split(key, N_CANDIDATES)
  per_env_programs = programs[:, None]
  return vmap(step_one)(per_env_keys, v_env, programs=per_env_programs)

def compose_mosaic(v_imgs, v_tot_n_successful_repr, selected_idx=None):
  """Tile a batch of candidate images into one annotated mosaic.

  Each image gets a 1-pixel border of value 1.0, or green for the candidate at
  `selected_idx`. It is then zoomed 3x and labelled with its number of
  successful reproductions. The tiles are laid out with len // 2 columns (at
  least 1), so an even-sized batch fills two rows.

  Args:
    v_imgs: batch of images of shape [N, H, W, C].
    v_tot_n_successful_repr: length-N sequence of reproduction counts.
    selected_idx: optional index of the candidate to highlight, or None.

  Returns:
    The tiled mosaic image.
  """
  padded = np.pad(
      v_imgs, ((0, 0), (1, 1), (1, 1), (0, 0)), mode="constant",
      constant_values=1.)
  green = np.array([0., 1., 0.])
  labelled = []
  for i, (img, n_repr) in enumerate(zip(padded, v_tot_n_successful_repr)):
    if i == selected_idx:
      # Highlight the selected candidate by coloring its border green.
      img[[0, -1], :] = green
      img[:, [0, -1]] = green
    labelled.append(pad_text(zoom(img, 3), "N repr: %d" % n_repr))
  # Derive the column count from the batch itself instead of the global
  # N_CANDIDATES, and never request 0 columns (e.g. for a single image).
  n_cols = max(1, len(labelled) // 2)
  return tile2d(labelled, w=n_cols)

# Per-generation history, kept so that a video of the whole session can be
# made later, highlighting the selected offspring while it grows.
l_imgs_series = []
l_v_tot_n_successful_repr_series = []
selected_idx_series = []

# One selection button per candidate, labelled "agent 1" ... "agent N"
# (1-based). The trailing number identifies the candidate.
button_list = [widgets.Button(description="agent %d" % (idx + 1))
               for idx in range(N_CANDIDATES)]
half = N_CANDIDATES // 2
buttons = widgets.VBox([
    widgets.HBox(button_list[:half]),
    widgets.HBox(button_list[half:]),
])

@yield_for_change(button_list, 'description')
def f():
  """Interactive evolution loop, resumed by clicks on the candidate buttons.

  Each generation:
    1. simulates all candidates side by side for `n_steps` steps from the
       initial environment, recording a frame every 2 steps;
    2. shows the resulting video followed by the selection buttons;
    3. waits at `yield` for a click, and parses the clicked button's label
       to get the chosen candidate;
    4. stores the recorded frames, counts and selection for later videos;
    5. builds the next generation as N_CANDIDATES mutated copies of the
       chosen program.
  Updates the globals `v_env`, `v_tot_n_successful_repr`, `key`, `programs`
  and `chosen_program`.
  """
  global v_env
  global v_tot_n_successful_repr
  global key
  global programs
  global chosen_program
  # Build the jitted, vmapped mutation once. jit keys its cache on the
  # function object, and vmap(...) returns a new one on every call, so
  # building it inside the loop would retrace/recompile every generation.
  v_mutate = jit(vmap(mutator.mutate))
  while True:
    # Reset the environments and the per-candidate reproduction counters.
    v_env = v_st_env
    v_tot_n_successful_repr = jp.zeros(N_CANDIDATES, dtype=jp.int32)
    imgs = v_grab_images(v_env)
    frame = compose_mosaic(imgs, v_tot_n_successful_repr)
    l_imgs = [imgs]
    l_v_tot_n_successful_repr = [v_tot_n_successful_repr]
    video = [frame]

    # Important: step_env is set up to not require any mutation. It therefore
    # expects programs without mutation parameters, so those are split off.
    # `programs` does not change during the simulation, so split once per
    # generation instead of at every step.
    agent_params, _ = vmap(mutator.split_params)(programs)
    for step in tqdm.trange(1, n_steps+1):
      key, ku = jr.split(key)
      v_env, v_n_successful_repr = v_step_env(ku, v_env, agent_params)
      v_tot_n_successful_repr += v_n_successful_repr
      # Record a frame (and the data needed to rebuild it) every 2 steps.
      if step % 2 == 0:
        imgs = v_grab_images(v_env)
        frame = compose_mosaic(imgs, v_tot_n_successful_repr)
        video.append(frame)
        l_imgs.append(imgs)
        l_v_tot_n_successful_repr.append(v_tot_n_successful_repr)

    # Show the new video, then the selection buttons.
    clear_output()
    media.show_video(video, fps=fps)

    display(buttons)

    # Wait for a click: the clicked button is sent back in here.
    x = yield

    # Buttons are labelled "agent <k>" with a 1-based k.
    chosen_idx = int(x.description.split(" ")[-1])-1
    chosen_program = programs[chosen_idx]

    # Add information to the lists for future video making.
    l_imgs_series.append(l_imgs)
    l_v_tot_n_successful_repr_series.append(l_v_tot_n_successful_repr)
    selected_idx_series.append(chosen_idx)

    # Mutate: the next generation is N_CANDIDATES mutated copies of the
    # chosen program.
    ku, key = jr.split(key)
    programs = v_mutate(
        jr.split(ku, N_CANDIDATES),
        jp.repeat(chosen_program[None,:], N_CANDIDATES, axis=0))

f()


## Create a video showing your selections

In [None]:
# Replay every generation in order, highlighting the candidate selected at the
# end of it. Only every other recorded frame is used, to keep the video short.
video = []
for l_imgs, l_v_tot_n_successful_repr, selected_idx in zip(
    l_imgs_series, l_v_tot_n_successful_repr_series, selected_idx_series):
  for imgs, v_tot_n_successful_repr in zip(
      l_imgs[::2], l_v_tot_n_successful_repr[::2]):
    frame = compose_mosaic(imgs, v_tot_n_successful_repr, selected_idx)
    video.append(frame)

# With no selection made yet there are no frames, and an empty video cannot
# be shown.
if video:
  media.show_video(video, fps=fps)
else:
  print("No selections recorded yet: run the interactive evolution cell and "
        "pick at least one agent first.")

## Evaluate the result

In [None]:
def count_agents_f(env, etd):
  """Return how many cells of `env.type_grid` are agents according to `etd`."""
  agent_mask = etd.is_agent_fn(env.type_grid)
  return agent_mask.sum()

@partial(jit, static_argnames=["config", "agent_logic", "mutator", "n_steps", "n_max_programs"])
def evaluate_biome(key, st_env, config, agent_logic, mutator, n_steps,
                   init_program=None, n_max_programs=128):
  """Simulate a biome with reproduction and mutation, and score it.

  Args:
    key: PRNG key.
    st_env: initial environment.
    config: environment config (static).
    agent_logic: agent logic used to initialize and run programs (static).
    mutator: mutator used to initialize and mutate programs (static).
    n_steps: number of simulation steps (static).
    init_program: if given, every one of the n_max_programs slots starts as a
      copy of it. Otherwise programs are freshly initialized by agent_logic
      and then mutator.
    n_max_programs: number of program slots (static).

  Returns:
    (tot_agents_n, is_extinct): the agent count summed over all steps, and an
    int32 flag that is 1 when no agent is left at the end and 0 otherwise.
  """
  if init_program is not None:
    programs = jp.repeat(init_program[None, :], n_max_programs, axis=0)
  else:
    ku, key = jr.split(key)
    programs = vmap(agent_logic.initialize)(jr.split(ku, n_max_programs))
    ku, key = jr.split(key)
    programs = vmap(mutator.initialize)(
        jr.split(ku, programs.shape[0]), programs)

  def body_f(i, carry):
    step_key, env, progs, agents_sum = carry
    ku, step_key = jr.split(step_key)
    env, progs = step_env(
        ku, env, config, agent_logic, progs, do_reproduction=True,
        mutate_programs=True, mutator=mutator)
    agents_sum = agents_sum + count_agents_f(env, config.etd)
    return step_key, env, progs, agents_sum

  _, final_env, _, tot_agents_n = jax.lax.fori_loop(
      0, n_steps, body_f, (key, st_env, programs, 0))

  # The run went extinct if no agent is left in the final environment.
  is_extinct = (count_agents_f(final_env, config.etd) == 0).astype(jp.int32)
  return tot_agents_n, is_extinct

In [None]:
# Evaluate the chosen program on the larger "landscape" environment over
# n_reps independent runs of eval_n_steps steps each.
n_reps = 16
key = jr.PRNGKey(123)

eval_env_width_type = "landscape"
eval_env_and_config = evm.get_env_and_config(ec_id, width_type=eval_env_width_type)
# Only the environment is taken from here; the evaluation keeps `config`.
eval_st_env, _ = eval_env_and_config

eval_n_steps = 1000

batched_evaluate = jit(vmap(partial(
    evaluate_biome, st_env=eval_st_env, config=config, agent_logic=agent_logic,
    mutator=mutator, n_steps=eval_n_steps, init_program=chosen_program)))

t_st = time.time()
key, ku = jr.split(key)
b_tot_agents_n, b_is_extinct = batched_evaluate(jr.split(ku, n_reps))
print("Took", time.time()-t_st, "seconds")
print("Total number of agents", b_tot_agents_n, b_tot_agents_n.mean(), b_tot_agents_n.std())
print("Extinction events", b_is_extinct, b_is_extinct.mean(), b_is_extinct.std())

## Show an example run of the result

Consider modifying the code to change the length of the simulation and the video settings.

In [None]:

key = jr.PRNGKey(43)

# Maximum number of unique programs (organisms) allowed in the simulation.
N_MAX_PROGRAMS = 128

# Number of video frames, NOT simulation steps. The total number of steps
# depends on the steps per frame, which can change over time.
# In the article, we generally use 500 or 750 frames.
n_frames = 500

# Frame indices at which the steps per frame double. Frames are indexed from
# 0 to n_frames - 1, so an entry equal to n_frames is never reached.
when_to_double_speed = [100, 200, 300, 400, 500]
# Frame indices at which the steps per frame are reset to 1.
when_to_reset_speed = []
fps = 20
# Zoom factor, which sets the size of the image. If this number is not even,
# the resulting video *may* not be supported by all renderers.
zoom_sz = 4

# Every program slot starts as a copy of the chosen program.
programs = jp.repeat(chosen_program[None,:], N_MAX_PROGRAMS, axis=0)

# Run the final simulation on the "landscape" environment, not on petri.
eval_env_width_type = "landscape"
eval_env_and_config = evm.get_env_and_config(
    ec_id, width_type=eval_env_width_type)
eval_st_env, _ = eval_env_and_config

env = eval_st_env

def make_frame(env, step, speed):
  """Render `env` zoomed by `zoom_sz`, with a banner showing step and speed."""
  label = "Step {:<7} Speed: {}x".format(step, speed)
  image = zoom(evm.grab_image_from_env(env, config), zoom_sz)
  return pad_text(image, label)

# Simulation step counter, shown in every frame.
step = 0
# Steps per frame at the start; the speed schedule above changes it over time.
# In the article, we usually use 2 or 4 as the starting value, sometimes 1.
steps_per_frame = 2

frame = make_frame(env, step, steps_per_frame)

out_file = "video.mp4"
with media.VideoWriter(out_file, shape=frame.shape[:2], fps=fps) as video:
  video.add_image(frame)
  for frame_idx in tqdm.trange(n_frames):
    # Apply the speed schedule; when both lists match, the reset wins.
    if frame_idx in when_to_double_speed:
      steps_per_frame *= 2
    if frame_idx in when_to_reset_speed:
      steps_per_frame = 1
    for _ in range(steps_per_frame):
      step += 1
      key, ku = jr.split(key)
      env, programs = step_env(
          ku, env, config, agent_logic, programs,
          do_reproduction=True, mutate_programs=True, mutator=mutator)

    video.add_image(make_frame(env, step, steps_per_frame))

# Play back the video that was just written.
media.show_video(media.read_video(out_file))