In [1]:
# imports
import os
from functools import partial

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from tree_ops import *
from util import *

# Simulated annealing
Inspired by the annealing process in physics, simulated annealing aims to improve on basic hill climbing (random mutation, accept only if it is an improvement) by introducing temperature-dependent randomness. It can be considered a stochastic variant of hill climbing.

1. Find a candidate solution $\tilde{x}$, for example in an $\epsilon$-ball around $x_t$.
2. If $E(\tilde{x}) < E(x_t)$, accept: $x_{t+1} = \tilde{x}$, <br>
   else accept with probability $e^{-\Delta E / T}$, where $\Delta E = E(\tilde{x}) - E(x_t) \ge 0$, or retain the old position $x_{t+1} = x_t$.
3. Repeat and adapt temperature $T$.

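With the logarithmic schedule $T(t) = \frac{T_0}{\log(1 + t)}$ for $t \ge 1$ (at $t = 0$ the temperature would be infinite) and a sufficiently large $T_0$, the optimal solution is found with probability $1$ as $t \rightarrow \infty$.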

In [2]:
# Local random change
def _epsilon_ball(key, position, temperature, eps=0.5):
    return position + jax.random.uniform(key, minval=-eps, maxval=eps)

# One step of simulated annealing
def sim_annealing_step(key, position, energy, temperature, energy_fn, local_mutation=_epsilon_ball):
    """Perform one Metropolis accept/reject step of simulated annealing.

    Args:
        key: JAX PRNG key; split into one key for the proposal and one for
            the accept/reject draw.
        position: current position, handed to ``tree_multimap_rand``
            together with ``local_mutation``.
            NOTE(review): presumably any pytree whose leaves are mutated
            independently - confirm against tree_ops.
        energy: energy of ``position`` as produced by ``energy_fn``.
        temperature: current temperature ``T``.
        energy_fn: maps a position to its energy.
        local_mutation: proposal function, passed to ``tree_multimap_rand``
            along with ``temperature=temperature``.

    Returns:
        ``(position, energy)`` after the step: the candidate and its energy
        if accepted, otherwise the inputs unchanged.
    """
    k_propose, k_accept = jax.random.split(key)

    # choose candidate new position
    new_position = tree_multimap_rand(local_mutation, k_propose, position, temperature=temperature)

    # evaluate new energy
    new_energy = energy_fn(new_position)

    # Metropolis acceptance probability min(1, exp(-(E_new - E_old) / T)).
    # Clamping in log space keeps the probability in [0, 1] (the domain
    # jax.random.bernoulli documents) and avoids exp() overflowing to inf
    # for large improvements or small temperatures.
    log_p = jnp.minimum(0.0, (energy - new_energy) / temperature)
    accept = jax.random.bernoulli(k_accept, jnp.exp(log_p))

    return jax.lax.cond(
        # condition
        accept,
        # accept if true
        lambda _: (new_position, new_energy),
        # else reject
        lambda _: (position, energy),
        None)

In [3]:
def sin_energy(x):
    """Toy 1-D energy landscape used to demonstrate annealing.

    A wide quadratic bowl (``x**2 / 1.5``) sets the global trend, and a fast
    sinusoid (``sin(7 x)``) superimposes many local minima on top of it.
    Works element-wise on scalars and arrays.
    """
    bowl = x ** 2 / 1.5
    ripple = jnp.sin(x * 7)
    return bowl + ripple

In [4]:
# Specialise the generic annealing step to the toy landscape, vectorise it
# over a population of walkers, then JIT-compile the result for speed.
#   in_axes: key, position and energy are batched along axis 0 (one per
#            walker); the temperature (None) is a single value shared by all.
step = jax.jit(
    jax.vmap(
        partial(sim_annealing_step, energy_fn=sin_energy),
        in_axes=(0, 0, 0, None),
    )
)

In [5]:
# fix random number generator seed so the run is reproducible.
# `rns` is consumed with next(rns) in the animation cell (initial positions
# and one key per step), so re-run this cell before re-running that one to
# reproduce the same animation.
# NOTE(review): PRNGSequence comes from a star import (util / tree_ops);
# presumably each next() yields a fresh PRNG key - confirm there.
rns = PRNGSequence(42)



In [6]:
temp0 = 2
number_steps = 200

# initial values: 20 walkers spread uniformly over the plotted interval [-5, 5]
xs = jax.random.uniform(next(rns), (20,), minval=-5, maxval=5)
energy = sin_energy(xs)

# Run the whole annealing schedule up front and record every state, so the
# animation callback below is a pure function of the frame index: extra or
# repeated draws by matplotlib cannot advance (or desynchronise) the walk.
trajectory = []
for i in range(number_steps):
    # Logarithmic cooling T(t) = T0 / log(1 + t) evaluated at t = i + 1.
    # Starting at t = 1 avoids log(1) = 0, i.e. an infinite first temperature
    # at which every proposal would be accepted regardless of its energy.
    temp = temp0 / jnp.log(2 + i)
    keys = jax.random.split(next(rns), len(xs))
    xs, energy = step(keys, xs, energy, temp)
    trajectory.append((xs, energy))

# set up plotting
fig, ax = plt.subplots()
ax.set(xlim=(-5, 5), ylim=(-1, 17.5))
_xs = jnp.linspace(-5, 5, 200)
ax.plot(_xs, sin_energy(_xs))
ax.set_xticks([]); ax.set_yticks([])
fig.tight_layout()
pos, = ax.plot([], [], 'o')

def init():
    """Start the animation with no walkers drawn."""
    pos.set_data([], [])
    return pos,

def animate(i):
    """Draw the walkers' positions and energies after annealing step ``i``."""
    frame_xs, frame_energy = trajectory[i]
    pos.set_data(frame_xs, frame_energy)
    return pos,

anim = FuncAnimation(fig, animate, init_func=init, frames=number_steps, interval=100, blit=True)
# anim.save does not create missing directories
os.makedirs('figs', exist_ok=True)
anim.save('figs/sim-annealing.gif')
plt.close(fig)

![annealing](figs/sim-annealing.gif)