In [1]:
from functools import partial
from pathlib import Path

import flax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

from util import *
from tree_ops import *

# Particle swarm optimization
Particle swarm optimization (PSO) is a gradient-free optimization method inspired by the collective behaviour of a swarm of animals, such as a flock of birds.

PSO maintains and updates the following state variables:
- particle positions $x_i$ and velocities $v_i$
- personal best per particle $\hat{x}_i$
- global best $\hat{g}$

The update procedure aims to create dynamics that balance exploration of the search space (driven by randomness and the personal bests) against exploitation around the global best. For each particle:
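1. Sample random vectors $r_1$, $r_2$ uniformly in $[0, 1]^N$, where $N$ is the dimension of the search space
2. Update velocities $$v_i \mapsto \omega v_i + \alpha_1 r_1 \odot (\hat{x}_i - x_i) + \alpha_2 r_2 \odot (\hat{g} - x_i),$$ where $\odot$ denotes the element-wise product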
3. Update positions $x_i \mapsto x_i + v_i$
4. Evaluate the fitness and update the personal and global bests (higher fitness is better)

In [2]:
@flax.struct.dataclass
class PSOState:
    """Represent the state of PSO.

    ``pos``, ``vel`` and ``pbest`` are pytrees whose leaves are stacked along a
    leading particle axis (``from_list`` stacks one object per particle).
    ``gbest`` has the structure of a single particle. All fields except
    ``pos``/``vel`` stay ``None`` until ``init_state`` fills them in.
    """
    # stacked particle positions and velocities
    pos: object
    vel: object

    # fitness of position (one value per particle)
    fit: jnp.ndarray = None

    # personal best
    pbest: object = None
    pbest_fit: jnp.ndarray = None  # vector, one entry per particle

    # global best
    gbest: object = None
    gbest_fit: jnp.ndarray = None  # scalar

    @property
    def pos_list(self):
        """Positions as a list of per-particle objects."""
        return unstack_tree(self.pos)

    @property
    def pbest_list(self):
        """Personal bests as a list of per-particle objects."""
        # ``pbest`` has the same stacked layout as ``pos`` (``init_state``
        # initialises it from ``pos``), so it is unstacked the same way.
        return unstack_tree(self.pbest)

    @classmethod
    def from_stack(cls, positions, velocities=None):
        """Initialize using stacked objects.

        Args:
            positions: pytree whose leaves carry a leading particle axis.
            velocities: pytree with the same structure as ``positions``.
                If None, every leaf is set to zeros of matching shape/dtype.
        """
        if velocities is None:
            # zeros_like (rather than 0 * x) gives true zeros even if a
            # position leaf contains inf/nan.
            velocities = jax.tree_util.tree_map(jnp.zeros_like, positions)
        return cls(positions, velocities)

    @classmethod
    def from_list(cls, positions, velocities=None):
        """Initialize using lists of objects (one object per particle).

        If velocities is None, set all to 0.
        """
        positions = stack_trees(positions)
        if velocities is not None:
            velocities = stack_trees(velocities)
        return cls.from_stack(positions, velocities)
    

def _pick_fitter(f1, f2, o1, o2):
    """Return the larger fitness and corresponding object."""
    return jax.lax.cond(f1 > f2, lambda _: (f1, o1), lambda _: (f2, o2), None)

def _vel_basic(key, pos, vel, pbest, gbest, omega, alpha1, alpha2):
    """Update velocitiy vectors."""
    r1, r2 = jax.random.uniform(key, (2, *pos.shape))
    
    return omega * vel \
         + alpha1 * r1 * (pbest - pos) \
         + alpha2 * r2 * (gbest[None, :] - pos)

def update_vel_basic(key, state, omega=0.7, alpha1=2, alpha2=2):
    """Return ``state`` with its velocities replaced by one basic-PSO update.

    ``omega`` weights the previous velocity, ``alpha1`` the pull towards each
    particle's personal best and ``alpha2`` the pull towards the global best.
    The per-leaf rule is ``_vel_basic``. NOTE(review): ``tree_multimap_rand`` is
    defined outside this notebook; presumably it hands each leaf its own key.
    TODO confirm.
    """
    leaf_rule = partial(_vel_basic, omega=omega, alpha1=alpha1, alpha2=alpha2)
    trees = (state.pos, state.vel, state.pbest, state.gbest)
    return state.replace(vel=tree_multimap_rand(leaf_rule, key, *trees))

def update_state(state: PSOState, fitness_fn):
    """Re-evaluate fitness at the current positions and refresh the bests.

    Personal bests are replaced wherever the new fitness is greater than or
    equal to the stored one. The global best is replaced when the best
    particle of this iteration is at least as fit as the stored global best.
    """
    new_fit = jax.vmap(fitness_fn)(state.pos)

    # per particle: choose between the stored personal best and the current position
    pick_per_particle = jax.vmap(_pick_fitter)
    pbest_fit, pbest = pick_per_particle(state.pbest_fit, new_fit, state.pbest, state.pos)

    # compare this iteration's leader against the stored global best
    leader = jnp.argmax(new_fit)
    leader_pos = tree_index(state.pos, leader)
    gbest_fit, gbest = _pick_fitter(state.gbest_fit, new_fit[leader], state.gbest, leader_pos)

    return state.replace(
        fit=new_fit,
        pbest=pbest,
        pbest_fit=pbest_fit,
        gbest=gbest,
        gbest_fit=gbest_fit,
    )
    
    
def init_state(state: PSOState, fitness_fn):
    """Fill in fitnesses and bests for a freshly constructed state.

    Every particle's personal best is its starting position. The global best
    is the starting position with the largest fitness.
    """
    fitness = jax.vmap(fitness_fn)(state.pos)
    leader = jnp.argmax(fitness)
    leader_fit = fitness[leader]
    leader_pos = tree_index(state.pos, leader)

    return state.replace(
        fit=fitness,
        pbest=state.pos,
        pbest_fit=fitness,
        gbest=leader_pos,
        gbest_fit=leader_fit,
    )
    
def pso_step(key, state: PSOState, fitness_fn, update_vel=update_vel_basic):
    """Advance the swarm by one PSO iteration.

    Args:
        key: JAX PRNG key forwarded to ``update_vel``.
        state: current ``PSOState``. If it has not been initialised yet
            (``state.fit is None``), ``init_state`` is applied first.
        fitness_fn: maps one particle to a scalar fitness (higher is better);
            it is vmapped over the particle axis.
        update_vel: callable ``(key, state) -> state`` that refreshes the
            velocities; defaults to ``update_vel_basic``.

    Returns:
        A new ``PSOState`` with positions moved by the updated velocities and
        with fitnesses and personal/global bests refreshed.
    """
    # This is a Python-level check on the state's structure, so under jit it is
    # resolved once while tracing rather than on traced values.
    if state.fit is None:
        state = init_state(state, fitness_fn)

    state = update_vel(key, state)
    # ``jax.tree_util.tree_map`` maps over matching leaves of several trees;
    # the old ``jax.tree_multimap`` alias no longer exists in current JAX.
    pos = jax.tree_util.tree_map(jnp.add, state.pos, state.vel)
    state = state.replace(pos=pos)

    return update_state(state, fitness_fn)

In [3]:
# Stream of PRNG keys seeded with 0. ``PRNGSequence`` is not defined in this
# notebook; presumably it comes from the ``util`` / ``tree_ops`` star imports.
# Each ``next(rns)`` below is used as a fresh ``jax.random`` key.
rns = PRNGSequence(0)



In [4]:
def sin_fit(x):
    """Toy 2-D fitness landscape to be maximised.

    Rewards ``x[1]`` lying close to ``sin(2 * x[0])`` and ``x[0]`` lying close
    to 0. The maximum value 0 is reached at the origin. The computation is
    element-wise, so ``x`` may also be a stacked grid of shape ``(2, ...)``.
    """
    ridge_gap = jnp.sin(2 * x[0]) - x[1]
    return -(ridge_gap ** 2) - x[0] ** 2

In [5]:
# One JIT-compiled PSO iteration on the ``sin_fit`` landscape; call as ``step(key, state)``.
step = jax.jit(partial(pso_step, fitness_fn=sin_fit))

In [6]:
# Smoke test: 10 particles with 2-D Gaussian positions scaled by 5, then initialised.
init_keys = jax.random.split(next(rns), 10)
initial = [jax.random.normal(subkey, (2,)) * 5 for subkey in init_keys]
state = init_state(PSOState.from_list(initial), sin_fit)

In [7]:
number_steps = 100
n_particles = 10   # swarm size
init_scale = 5.0   # initial positions are standard normal samples times this
lim = 7            # the plot shows the square [-lim, lim]^2


def random_state(key, n=n_particles, scale=init_scale):
    """Draw ``n`` random 2-D positions (normal * ``scale``) and return an initialised PSOState."""
    initial = [jax.random.normal(k, (2,)) * scale for k in jax.random.split(key, n)]
    return init_state(PSOState.from_list(initial), sin_fit)


# Initial values. They are kept so that ``init`` can reset the animation to exactly
# the swarm that is drawn first.
initial_state = random_state(next(rns))
state = initial_state

# set up plotting
fig, ax = plt.subplots(figsize=(8, 8))
ax.set(xlim=(-lim, lim), ylim=(-lim, lim))

# Fitness landscape. With default ('xy') meshgrid indexing, ``gx`` varies along
# columns and ``gy`` along rows. With ``origin='lower'`` the image's horizontal
# axis is x[0] and its vertical axis is x[1], which matches the particles plotted
# below as (pos[:, 0], pos[:, 1]).
grid = jnp.arange(-lim, lim, 0.1)
gx, gy = jnp.meshgrid(grid, grid)
ax.imshow(sin_fit(jnp.stack([gx, gy])), origin='lower', alpha=0.6,
          extent=(-lim, lim, -lim, lim))
ax.set_xticks([]); ax.set_yticks([])
fig.tight_layout()

pos_plt, = ax.plot(*state.pos.T, 'o')


def init():
    """Reset the swarm to ``initial_state``. Matplotlib may call this more than once."""
    global state
    state = initial_state
    pos_plt.set_data(*state.pos.T)
    return pos_plt,


def animate(i):
    """Advance the swarm by one PSO step and redraw the particles (``i`` is unused)."""
    global state
    state = step(next(rns), state)
    pos_plt.set_data(*state.pos.T)
    return pos_plt,


anim = FuncAnimation(fig, animate, init_func=init, frames=number_steps, interval=100, blit=True)
Path('figs').mkdir(exist_ok=True)   # so a fresh checkout can Run All
anim.save('figs/pso.gif', writer='pillow')   # Pillow writes GIFs without needing ffmpeg
plt.close(fig)

![pso](figs/pso.gif)