In [1]:
# imports
import os
from functools import partial

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from tree_ops import *
from util import *

from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms
import numpy as np

# Evolutionary algorithms
Several types of evolution-inspired algorithms can be distinguished:
- **Genetic algorithms:** Iteratively evolve a population of "individuals", each represented as a (binary) string <br>
    optimization search on strings or binary arrays
- **Evolution strategies:** Evolve populations of real vectors <br>
    optimization search on real-valued vectors
- **Genetic programming:** Evolve computer programs or network architectures <br>
    optimization of computer programs represented as trees
- **Evolutionary programming:** Evolve the parameters of a program <br>
    optimization search on real vectors representing parameters of computer programs

Generally:
1. Create offspring by mutation, crossover
2. Evaluate fitness
3. Select fittest individuals

## Evolution strategies
ES consist of exploring the search space using populations of real vectors.
Populations are generated using a covariance matrix $C$, whose value is iteratively adapted, based on which individuals perform best.

### Example
Given a mean $m$ (the current state), generate offspring $x_i = m + z_i$ with Gaussian steps $z_i \sim \mathcal{N}(0, C)$.

- Generate $x_i = m + z_i$
- Select the $\lambda$ best individuals $x_i$ as offspring
- Obtain the new state by averaging, $$m_{t+1} = \frac{1}{\lambda} \sum_i x_i = m_t + \frac{1}{\lambda} \sum_i z_i\,,$$ or by taking the best $x_i$
- Update $C \mapsto (1-\epsilon)C + \epsilon Z$, where $Z$ is the empirical covariance of the selected steps $z_i$, taken about zero (i.e. about the old mean $m_t$):
  $$Z = \frac{1}{\lambda} \sum_i z_i z_i^T \,.$$
  
Various [other algorithms](https://en.wikipedia.org/wiki/CMA-ES) are possible to adapt the magnitude and correlation of the mutation noise.
  
### Nested evolution strategies
One can construct more complex strategies for evolving the population.
For example, one may have a collection of sub-populations, for each of which a child population is sampled and independently evolved for a number of generations.
From these, new founding populations may be selected.

In [2]:
def evo_strat_populate(key, mean, cov, fitness_fn, num_children, num_select, eps=0.1):
    """Sample one generation around `mean`, keep the fittest, and adapt `cov`.

    Children are drawn as ``mean + z_i`` with ``z_i ~ N(0, cov)``. The
    `num_select` children with the highest fitness survive, and the covariance
    is updated as ``(1 - eps) * cov + eps * Z`` with
    ``Z = (1 / num_select) * sum_i z_i z_i^T`` over the surviving steps.

    Args:
        key: JAX PRNG key used to sample the steps ``z_i``.
        mean: current state, a vector of length ``len(cov)``.
        cov: current step covariance, square matrix.
        fitness_fn: maps one child vector to a scalar; larger is better.
        num_children: number of children sampled, i.e. a (1, num_children) strategy.
        num_select: number of fittest children kept.
        eps: learning rate of the covariance update.

    Returns:
        ``(children, cov)``: the `num_select` selected children (one per row,
        ordered by increasing fitness) and the updated covariance.
    """
    # generate z_i, one row per child
    delta = jax.random.multivariate_normal(key, jnp.zeros(len(cov)), cov, shape=(num_children,))

    # children
    children = mean[None, :] + delta

    # evaluate fitness
    fitness = jax.vmap(fitness_fn)(children)

    # select top: indices of the `num_select` largest fitness values
    best = jnp.argsort(fitness)[-num_select:]
    selected_steps = delta[best]

    # Uncentred second moment of the selected steps, taken about the old mean.
    # jnp.cov would subtract the steps' own mean (discarding the successful
    # direction) and divide by num_select - 1, giving NaN for num_select == 1.
    step_cov = selected_steps.T @ selected_steps / num_select
    cov = (1 - eps) * cov + eps * step_cov

    return children[best], cov

In [3]:
# Seed the random stream once; every later cell draws its PRNG key with `next(rns)`.
# NOTE: PRNGSequence comes from a star import (presumably `util`) — verify there.
SEED = 0
rns = PRNGSequence(SEED)



In [4]:
def quadratic_fit(x):
    """Simple concave paraboloid ``-(x0^2 + x1^2)``; maximum 0 at the origin.

    Alternative, easier test landscape to `sin_fit`.
    """
    return -x[0]**2 - x[1]**2


def sin_fit(x):
    """Fitness with a curved ridge along ``x1 = sin(2 * x0)``; maximum 0 at the origin.

    Only ``x[0]`` and ``x[1]`` are read and the arithmetic is elementwise, so `x`
    may be a single 2-vector or a stacked grid of shape ``(2, ...)``.
    """
    return -(jnp.sin(x[0]*2) - x[1])**2 - x[0]**2

In [5]:
# Bind the fitness landscape and population sizes, then compile the ES step once.
step = jax.jit(
    partial(evo_strat_populate, fitness_fn=sin_fit, num_children=20, num_select=10)
)

In [6]:
number_steps = 100

# initial values; float mean so the jitted `step` sees the same dtype as the
# float means produced by later iterations (avoids a second compilation)
cov = jnp.eye(2) * 0.2
mean = jnp.array([-6.0, 5.0])

# set up plotting
fig = plt.figure(figsize=(8, 8))
ax = plt.axes(xlim=(-7, 7), ylim=(-7, 7))
# `mgrid` indexes the landscape as [x0, x1], but `imshow` draws the first array
# axis vertically (top to bottom by default). Transpose and use origin='lower'
# so x0 runs horizontally and x1 vertically, matching the plotted population.
landscape = sin_fit(jnp.stack(jnp.mgrid[-7:7:.1, -7:7:.1]))
ax.imshow(landscape.T, origin='lower', alpha=0.6, extent=(-7, 7, -7, 7))
ax.set_xticks([]); ax.set_yticks([])
fig.tight_layout()
ell = ax.add_patch(Ellipse((0, 0), 1, 1, edgecolor='red', facecolor='none'))
pop_plt, = ax.plot([], [], '.')
mean_plt, = ax.plot([], [], 'o')

def update_ellipse(ax, ell, mean, cov):
    """Reshape `ell` into the one-standard-deviation ellipse of a 2-D Gaussian.

    Follows the matplotlib "confidence ellipse" recipe. The half-axes of a unit
    ellipse encode the Pearson correlation. The ellipse is then rotated by 45
    degrees, scaled by the marginal standard deviations and moved to `mean`.

    Returns a 1-tuple containing `ell`.
    """
    var_x, var_y, cov_xy = cov[0, 0], cov[1, 1], cov[0, 1]
    rho = cov_xy / np.sqrt(var_x * var_y)

    # correlation ellipse, before rotation/scaling
    ell.set_width(np.sqrt(1 + rho) * 2)
    ell.set_height(np.sqrt(1 - rho) * 2)

    placement = (
        transforms.Affine2D()
        .rotate_deg(45)
        .scale(np.sqrt(var_x), np.sqrt(var_y))
        .translate(mean[0], mean[1])
    )
    ell.set_transform(placement + ax.transData)
    return (ell,)

def init():
    """Draw the first frame: current mean, empty population, covariance ellipse."""
    # Line2D.set_data needs sequences; bare scalars are deprecated since
    # matplotlib 3.7 and rejected by later releases.
    mean_plt.set_data([mean[0]], [mean[1]])
    pop_plt.set_data([], [])
    update_ellipse(ax, ell, mean, cov)
    return pop_plt, mean_plt, ell

def animate(i):
    """Advance the evolution strategy by one generation and redraw (frame index `i` unused)."""
    # perform evolution step; the state lives in module globals between frames
    global mean, cov
    population, cov = step(next(rns), mean, cov)
    mean = jnp.mean(population, axis=0)

    pop_plt.set_data(*population.T)
    # wrap the coordinates in lists: set_data no longer accepts bare scalars
    mean_plt.set_data([mean[0]], [mean[1]])
    update_ellipse(ax, ell, mean, cov)
    return pop_plt, mean_plt, ell

anim = FuncAnimation(fig, animate, init_func=init, frames=number_steps, interval=100, blit=True)
# `save` does not create missing directories
os.makedirs('figs', exist_ok=True)
# the Pillow writer needs no external ffmpeg binary
anim.save('figs/evo-strategy.gif', writer='pillow')
plt.close(fig)

![evo](figs/evo-strategy.gif)

## Genetic Algorithms
Similar to evolution strategies, genetic algorithms solve optimization problems by artificially evolving populations of candidate solutions.
This relies again on mutation and selection, but in addition may incorporate crossover.
Individuals (chromosomes) are commonly represented as bit arrays or strings.

General algorithm:
1. Evaluate the fitness of the population
2. Rank individuals and select from best (may be sampling with replacement)
3. Mix individuals / "breed"; crossover (cut strings of chosen individuals and glue back together)
4. Mutate (random noise)

![crossover](https://upload.wikimedia.org/wikipedia/commons/d/dd/Computational.science.Genetic.algorithm.Crossover.One.Point.svg)

Source: [wikimedia](https://upload.wikimedia.org/wikipedia/commons/d/dd/Computational.science.Genetic.algorithm.Crossover.One.Point.svg)

### Examples
- Any discrete problem can be encoded as a bit (or byte) string
- Selection need not act on the chromosome (genotype) directly, but rather on its "phenotype"
    - E.g. evolve the architecture of a neural network: to evaluate an individual's fitness, run the machine learning algorithm on the network it encodes; the resulting network behaviour constitutes the phenotype.

In [7]:
@jax.jit
def roulette_wheel_positions(sizes, positions, epsilon=1e-6):
    """Map pointer positions on a roulette wheel to the slot each one lands in.

    Slot ``i`` covers the half-open interval ``[sum(sizes[:i]), sum(sizes[:i+1]))``.
    Positions at or beyond the total width (e.g. when `sizes` sums to slightly
    less than 1 through rounding) are assigned to the last slot.

    Args:
        sizes: 1-D array of non-negative slot widths.
        positions: array of pointer positions.
        epsilon: unused; kept so existing calls that pass it keep working.

    Returns:
        int32 array of slot indices, same shape as `positions`.
    """
    edges = jnp.cumsum(sizes)  # right edge of each slot
    # the number of right edges <= position is the index of the slot containing it
    idx = jnp.searchsorted(edges, positions, side='right')
    # the last slot absorbs positions past the final edge (rounding in `sizes`)
    return jnp.clip(idx, 0, len(sizes) - 1).astype(jnp.int32)


def select_roulette(key, population, fitness):
    """Make a selection corresponding to a single spin of a roulette wheel.

    Stochastic universal sampling: each individual owns a slice of the unit
    interval proportional to its fitness, and ``len(population)`` pointers spaced
    ``1 / len(population)`` apart, shifted by one random offset, pick the
    survivors. Fitness is assumed non-negative (negative values would give
    negative slice widths).

    Args:
        key: JAX PRNG key for the random offset.
        population: array indexed along axis 0, or a Python list of individuals.
        fitness: per-individual fitness, same length as `population`.

    Returns:
        The selected individuals, same container type and length as `population`.
    """
    n = len(population)
    total = jnp.sum(fitness)
    # all-zero fitness: fall back to uniform selection instead of dividing by zero
    sizes = jnp.where(total > 0, fitness / total, 1.0 / n)

    bin_width = 1 / n
    # Build pointers from an integer range: arange with a float step can return
    # one element too many because of floating-point rounding.
    pos = jnp.arange(n) * bin_width + jax.random.uniform(key, maxval=bin_width)

    idc = roulette_wheel_positions(sizes, pos)
    if isinstance(population, list):
        return [population[i] for i in idc]
    return population[idc]


def _cross(args):
    key, chrom1, chrom2 = args
    pos = jax.random.uniform(key, minval=0, maxval=len(chrom1))
    idc = jnp.arange(len(chrom1)) < pos
    c1 = jnp.where(idc, chrom1, chrom2)
    c2 = jnp.where(idc, chrom2, chrom1)
    return c1, c2
    

def _reprod(args):
    key, chrom1, chrom2 = args
    return chrom1, chrom2


def cross(key, chrom1, chrom2, p_cross):
    """With probability `p_cross` recombine two parents, otherwise copy them.

    The choice goes through `jax.lax.cond`, so the function stays traceable
    (it is vmapped in `ga_step`). Returns a pair of children.
    """
    coin_key, cut_key = jax.random.split(key)
    do_cross = jax.random.bernoulli(coin_key, p_cross)
    operands = (cut_key, chrom1, chrom2)
    return jax.lax.cond(do_cross, _cross, _reprod, operands)


def mutate(key, chrom, p_mut):
    """Flip each gene of `chrom` independently with probability `p_mut`.

    One Bernoulli draw per entry along the first axis; flips are applied with
    bitwise XOR, so boolean chromosomes stay boolean.
    """
    n_genes = len(chrom)
    flip_mask = jax.random.bernoulli(key, p_mut, (n_genes,))
    flipped = chrom ^ flip_mask
    return flipped


def ga_step(key, pop, p_mut, p_cross, fitness_fn):
    """Run one generation of the genetic algorithm.

    Steps: evaluate fitness, roulette-wheel selection, pairwise one-point
    crossover (row ``i`` paired with row ``i + len(pop) // 2``), then per-gene
    mutation.

    Args:
        key: JAX PRNG key.
        pop: 2-D array, one chromosome per row.
        p_mut: per-gene flip probability.
        p_cross: probability that a pair undergoes crossover.
        fitness_fn: maps one chromosome to a scalar fitness (assumed non-negative,
            as required by roulette selection).

    Returns:
        The next population, same shape as `pop`.
    """
    k1, k2, k3 = jax.random.split(key, 3)

    # evaluate
    fitness = jax.vmap(fitness_fn)(pop)

    # select
    pop = select_roulette(k1, pop, fitness)

    # Crossover: pair row i with row i + half. With an odd population size the
    # last row has no partner and is carried over unchanged (a reshape into two
    # equal halves would fail). Note: since selection returns individuals in
    # index order, shuffling first may give more randomness.
    half = len(pop) // 2
    if half > 0:
        keys = jax.random.split(k2, half)
        pop1, pop2 = jax.vmap(cross, (0, 0, 0, None))(
            keys, pop[:half], pop[half:2 * half], p_cross
        )
        pop = jnp.concatenate((pop1, pop2, pop[2 * half:]))

    # mutate
    keys = jax.random.split(k3, len(pop))
    pop = jax.vmap(mutate, (0, 0, None))(keys, pop, p_mut)

    return pop

In [8]:
def fit_ones(chrom):
    """Fitness of a chromosome: the sum of its entries (number of ones for bit arrays)."""
    n_ones = jnp.sum(chrom)
    return n_ones

In [9]:
# Bind fitness and GA rates, then compile: step(key, pop) -> next population.
step = jax.jit(
    partial(ga_step, fitness_fn=fit_ones, p_mut=0.005, p_cross=0.3)
)

In [10]:
number_steps = 350

# initial values: 8 chromosomes of 20 bits, each bit set with probability 0.1
pop = jax.random.bernoulli(next(rns), 0.1, (8, 20))

# set up plotting; pin the colour range to the bit values {0, 1} so the
# mapping does not depend on which values the initial population contains
fig = plt.figure(figsize=(10, 5))
ax = plt.axes()
pop_plt = ax.imshow(pop, vmin=0, vmax=1)
ax.set_xticks([]); ax.set_yticks([])
fig.tight_layout()

def init():
    """Render the starting population before the first GA step."""
    pop_plt.set_data(pop)
    artists = (pop_plt,)
    return artists

def animate(i):
    """Run one GA generation and display it (frame index `i` is unused)."""
    # the population is carried between frames in a module global
    global pop
    next_pop = step(next(rns), pop)
    pop = next_pop
    pop_plt.set_data(next_pop)
    return (pop_plt,)

anim = FuncAnimation(fig, animate, init_func=init, frames=number_steps, interval=40, blit=True)
# `save` does not create missing directories
os.makedirs('figs', exist_ok=True)
# the Pillow writer needs no external ffmpeg binary
anim.save('figs/genetic-alg.gif', writer='pillow')
# close this specific figure rather than whatever pyplot considers "current"
plt.close(fig)

![genetic-algorithms](figs/genetic-alg.gif)