This notebook serves as a testing ground for a simple SAM-type optimizer implementation in JAX, built on existing APIs.

In [None]:
import jax
import jax.numpy as np
import matplotlib.pyplot as plt
import optax
import flax
import chex
from optax.contrib import sam

One way to describe what SAM does is that it does some number of steps (usually 1) of adversarial updates, followed by an outer gradient update.

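In pseudocode, each iteration performs the following steps:


    # adversarial step: perturb the parameters uphill, toward higher loss
    params = params + sam_rho * normalize(gradient)

    # outer update step: the gradient is now evaluated at the perturbed params,
    # but the update is applied to the cached, unperturbed params
    params = cache - learning_rate * gradient
    cache = params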
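
The pseudocode above can be sketched as a single function in plain JAX. This is a minimal illustration, not how `optax.contrib` implements SAM: `toy_loss`, `sam_step`, and the default values of `learning_rate` and `sam_rho` are hypothetical names and numbers chosen only for this example.

```python
import jax
import jax.numpy as jnp


def toy_loss(params):
    # Hypothetical quadratic bowl, used only to exercise the update rule.
    return jnp.sum(params ** 2)


def sam_step(params, learning_rate=0.1, sam_rho=0.05):
    # Adversarial step: move uphill by sam_rho along the normalized gradient.
    grad = jax.grad(toy_loss)(params)
    adv_params = params + sam_rho * grad / (jnp.linalg.norm(grad) + 1e-12)
    # Outer step: take the gradient at the perturbed point,
    # but apply it to the original (cached) parameters.
    adv_grad = jax.grad(toy_loss)(adv_params)
    return params - learning_rate * adv_grad


params = jnp.array([1.0, -2.0])
new_params = sam_step(params)
print(toy_loss(params), toy_loss(new_params))
```

Here the original `params` plays the role of `cache`: the perturbed point is only used to compute the gradient and is then discarded.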


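To use SAM, you create an adversarial optimizer (here SGD with normalized gradients) and an outer optimizer, and then wrap both with SAM itself.

In [None]:
lr = 0.001
rho = 0.1
# Adversarial optimizer: SGD on normalized gradients, with step size rho.
adv_opt = optax.chain(sam.normalize(), optax.sgd(rho))
# sam.sam expects the outer optimizer (not a bare learning rate) as its first argument.
opt = sam.sam(optax.sgd(lr), adv_opt, sync_period=2)   # This is the drop-in SAM optimizer.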

In [None]:
sgd_opt = optax.sgd(lr)  # Plain SGD baseline for comparison.

We'll set up a simple test problem below. The loss is the negative sum of two Gaussian bumps, so it has two minima, one near (0,0) and another near (2,0). We then compare the performance of SAM and ordinary SGD on it.

In [None]:
# An example 2D loss function. It has two minima, near (0,0) and (2,0).
# Both points attain nearly the same loss value (about -1), but the first one is much sharper.

def loss(params):
  x, y = params
  return -np.exp(-(x - 2)**2 - y**2) - np.exp(-(x**2 + 100 * y**2))
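
The claim that both minima have about the same loss but very different sharpness can be checked numerically. The sketch below redefines the loss under the hypothetical name `toy_loss` so that it is self-contained, and uses the largest Hessian eigenvalue as one possible measure of sharpness (an assumption made for this example; other measures exist).

```python
import jax
import jax.numpy as jnp


def toy_loss(params):
    # Same function as the notebook's loss, repeated here to be self-contained.
    x, y = params
    return -jnp.exp(-(x - 2) ** 2 - y ** 2) - jnp.exp(-(x ** 2 + 100 * y ** 2))


sharp = jnp.array([0.0, 0.0])
wide = jnp.array([2.0, 0.0])

loss_sharp, loss_wide = toy_loss(sharp), toy_loss(wide)

# Largest Hessian eigenvalue at each point, used here as a proxy for sharpness.
curv_sharp = jnp.linalg.eigvalsh(jax.hessian(toy_loss)(sharp)).max()
curv_wide = jnp.linalg.eigvalsh(jax.hessian(toy_loss)(wide)).max()

print("loss:", loss_sharp, loss_wide)
print("max curvature:", curv_sharp, curv_wide)
```

The two loss values should agree closely, while the curvature at (0,0) should be far larger because of the factor of 100 on the `y` term.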

In [None]:
params = np.array([-0.4, -0.4])

@chex.dataclass
class Store:
  params: chex.Array
  state: optax.OptState
  step: int = 0

store = Store(params=params, state=opt.init(params))
sgd_store = Store(params=params, state=sgd_opt.init(params))

In [None]:
def make_step(opt):
  @jax.jit
  def step(store):
    value, grads = jax.value_and_grad(loss)(store.params)
    updates, state = opt.update(grads, store.state, store.params)
    params = optax.apply_updates(store.params, updates)
    return store.replace(
        params=params,
        state=state,
        step=store.step+1), value
  return step

In [None]:
step = make_step(opt)
sgd_step = make_step(sgd_opt)

In [None]:
vals = []
params = []
sgd_vals = []
sgd_params = []

In [None]:
T = 1000
for i in range(T):
  # Run 100 optimizer steps, then record the latest loss and parameters.
  for _ in range(100):
    store, val = step(store)
    sgd_store, sgd_val = sgd_step(sgd_store)
  vals.append(val)
  sgd_vals.append(sgd_val)
  params.append(store.params)
  sgd_params.append(sgd_store.params)

In [None]:
ts = np.arange(T)
fig, axs = plt.subplots(2)
axs[0].plot(ts, vals, label='SAM', lw=3)
axs[0].plot(ts, sgd_vals, label='SGD')
axs[0].legend();
# Halve SAM's time axis: with sync_period=2 it makes one outer update every two steps.
axs[1].plot(ts / 2, vals, label='1/2 SAM', lw=3)
axs[1].plot(ts, sgd_vals, label='SGD')
axs[1].legend();

In [None]:
plt.plot(*np.array(params).T, label='SAM')
plt.plot(*np.array(sgd_params).T, label='SGD')
plt.legend(loc=4);

As you can see, the SAM optimizer finds the wide minimum near (2,0), while SGD gets stuck in the sharp minimum near (0,0).
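
One way to check this numerically rather than by eye is to report which minimum a set of final parameters is closest to. The helper below is a hypothetical sketch: `nearest_minimum` and `MINIMA` are not part of the notebook or of optax, and the minima are taken to be at exactly (0,0) and (2,0) for simplicity.

```python
import jax.numpy as jnp

# Approximate locations of the two minima of the test loss.
MINIMA = {
    "sharp (0, 0)": jnp.array([0.0, 0.0]),
    "wide (2, 0)": jnp.array([2.0, 0.0]),
}


def nearest_minimum(p):
    # Return the label of the minimum with the smallest Euclidean distance to p.
    dists = {name: float(jnp.linalg.norm(p - m)) for name, m in MINIMA.items()}
    return min(dists, key=dists.get)


print(nearest_minimum(jnp.array([1.9, 0.1])))
print(nearest_minimum(jnp.array([0.05, -0.01])))
```

In the notebook this could be called as `nearest_minimum(store.params)` and `nearest_minimum(sgd_store.params)` after the training loop.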