# Test the Adaptive Normal proposal with the Parallel Tempered Sampler

In this notebook we test the adaptive normal proposal on a 2D Gaussian distribution with mean $\bar{x} = 2, \bar{y} = 5$, variances $\sigma_x^2 = 1$, $\sigma_y^2 = 2$, and zero covariance ($\sigma^2_{xy} = \sigma^2_{yx} = 0$), using a prior that is uniform over $x, y \in [-20, 20)$. For this we use the parallel tempered sampler.

In [1]:
%matplotlib notebook
from __future__ import print_function
from matplotlib import pyplot
import numpy
import randomgen

import epsie
from epsie import make_betas_ladder
from epsie.samplers import ParallelTemperedSampler
from epsie.proposals import AdaptiveNormal
import multiprocessing

## Create the model to sample

***Note:*** Below we create a class with several functions to draw samples from the prior and to evaluate the log posterior. This isn't strictly necessary. The only thing the Sampler really requires is a function that it can pass keyword arguments to and get back a tuple of (log likelihood, log prior). However, setting things up as a class makes it convenient to, e.g., draw random samples from the prior for the starting positions, as well as plot the model later on.

In [2]:
from scipy import stats
class Model(object):
    def __init__(self):
        # we'll use a 2D Gaussian for the likelihood distribution
        self.params = ['x', 'y']
        self.mean = [2., 5.]
        self.cov = [[1., 0.], [0., 2.]]
        self.likelihood_dist = stats.multivariate_normal(mean=self.mean,
                                                         cov=self.cov)

        # we'll just use a uniform prior
        self.prior_bounds = {'x': (-20., 20.),
                             'y': (-20., 20.)}
        xmin = self.prior_bounds['x'][0]
        dx = self.prior_bounds['x'][1] - xmin
        ymin = self.prior_bounds['y'][0]
        dy = self.prior_bounds['y'][1] - ymin
        self.prior_dist = {'x': stats.uniform(xmin, dx),
                           'y': stats.uniform(ymin, dy)}

    def prior_rvs(self, size=None, shape=None):
        return {p: self.prior_dist[p].rvs(size=size).reshape(shape)
                for p in self.params}
    
    def logprior(self, **kwargs):
        return sum([self.prior_dist[p].logpdf(kwargs[p]) for p in self.params])
    
    def loglikelihood(self, **kwargs):
        return self.likelihood_dist.logpdf([kwargs[p] for p in self.params])
    
    def __call__(self, **kwargs):
        logp = self.logprior(**kwargs)
        if logp == -numpy.inf:
            logl = None
        else:
            logl = self.loglikelihood(**kwargs)
        return logl, logp

In [3]:
model = Model()
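
For comparison, here is a minimal sketch of the function-only alternative mentioned in the note above: a plain function that takes `x` and `y` as keyword arguments and returns `(log likelihood, log prior)`. The name `log_posterior_terms` and the module-level constants are illustrative and not part of the notebook; the Gaussian and uniform terms are written out by hand so that only the standard library is needed.

```python
import math

MEAN = (2.0, 5.0)          # same mean as the Model class
VARIANCES = (1.0, 2.0)     # diagonal covariance, so the likelihood factorizes
BOUNDS = (-20.0, 20.0)     # uniform prior bounds, applied to both x and y


def log_posterior_terms(x, y):
    """Return (log likelihood, log prior) for the 2D Gaussian test problem."""
    lo, hi = BOUNDS
    if not (lo <= x < hi and lo <= y < hi):
        # outside the prior: skip the likelihood, as Model.__call__ does
        return None, -math.inf
    logp = -2.0 * math.log(hi - lo)
    logl = 0.0
    for value, mu, var in zip((x, y), MEAN, VARIANCES):
        logl += -0.5 * (value - mu) ** 2 / var - 0.5 * math.log(2 * math.pi * var)
    return logl, logp


print(log_posterior_terms(x=2.0, y=5.0))
```

The class is still the more convenient choice for this notebook, since it also provides `prior_rvs` for drawing starting positions.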

## Setup the proposal

We'll set up the adaptive normal proposal so that it adapts for the first 100 iterations.

In [4]:
adaptation_duration = 100
prior_widths = {p: abs(bnds[1] - bnds[0]) for p, bnds in model.prior_bounds.items()}
proposal = AdaptiveNormal(model.params, prior_widths, adaptation_duration=adaptation_duration)
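
To illustrate what "adapting" a proposal means in general, the sketch below tunes the width of a 1D random-walk proposal toward a target acceptance rate during an adaptation phase, then freezes it. This is a generic diminishing-step scheme written for illustration only; it is an assumption, not the algorithm `AdaptiveNormal` implements, and the function name, the step-size exponent and the longer adaptation phase are all hypothetical choices.

```python
import math
import random


def adaptive_metropolis(nsteps, adaptation_duration, target_rate=0.234, seed=0):
    """Sample a 1D standard normal with a random-walk proposal whose width is
    tuned toward ``target_rate`` for ``adaptation_duration`` steps, then frozen."""
    rng = random.Random(seed)
    x, log_sigma = 0.0, 0.0
    accepted = []
    for n in range(1, nsteps + 1):
        proposed = x + rng.gauss(0.0, math.exp(log_sigma))
        # Metropolis ratio for a standard normal target, capped at 1
        ratio = math.exp(min(0.0, 0.5 * (x * x - proposed * proposed)))
        accept = rng.random() < ratio
        if accept:
            x = proposed
        accepted.append(accept)
        if n <= adaptation_duration:
            # widen after an acceptance, narrow after a rejection,
            # with a step size that shrinks as n grows
            log_sigma += n ** -0.6 * (float(accept) - target_rate)
    return accepted, math.exp(log_sigma)


accepted, sigma = adaptive_metropolis(nsteps=12000, adaptation_duration=2000)
post_adapt_rate = sum(accepted[2000:]) / len(accepted[2000:])
print(post_adapt_rate, sigma)
```

Freezing the width after the adaptation phase matters: once the proposal stops changing, the chain is an ordinary Metropolis chain again.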

## Setup and run the sampler

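Set the number of chains and temperatures, then initialize the sampler using the model we created above. To run in parallel, replace `pool = None` with `multiprocessing.Pool(nprocs)` to create a pool of 4 processes; as written, `pool` is `None`, so the chains are evolved serially.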

In [5]:
nchains = 12
ntemps = 3
nprocs = 4
pool = None #multiprocessing.Pool(nprocs)

betas = make_betas_ladder(ntemps, 1e5)
sampler = ParallelTemperedSampler(model.params, model, nchains, proposals={('x', 'y'): proposal},
                                  betas=betas, pool=pool)
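
As a rough picture of what a temperature ladder looks like, the sketch below spaces the inverse temperatures geometrically between 1 and `1/max_temperature`. Geometric spacing is a common choice, but it is an assumption here; `make_betas_ladder` may space its ladder differently, and `geometric_betas` is a hypothetical helper, not an epsie function.

```python
def geometric_betas(ntemps, max_temperature):
    """Inverse temperatures spaced geometrically from 1 to 1/max_temperature."""
    if ntemps == 1:
        return [1.0]
    return [max_temperature ** (-k / (ntemps - 1)) for k in range(ntemps)]


betas_sketch = geometric_betas(3, 1e5)
print(betas_sketch)
```

With only three levels spanning five orders of magnitude, neighbouring temperatures differ by a factor of about 316, which is a very coarse ladder.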

Now set the starting positions of the chains by drawing random variates from the model's prior.

In [6]:
sampler.start_position = model.prior_rvs(size=nchains*ntemps, shape=(nchains, ntemps))

### Let's run it!

This will evolve each chain in the collection by 200 steps. If a pool was provided, the work is parallelized over the pool's processes.

In [7]:
sampler.run(200)

Let's check how the covariance matrix has changed. We started with a covariance matrix of:

In [8]:
print(proposal.cov)

[[2.7576 0.    ]
 [0.     2.7576]]


Now we have:

In [9]:
# the current proposal covariance of each chain at each temperature
for ci, ptchain in enumerate(sampler.chains):
    print('==== Chain {} ===='.format(ci))
    for tk, c in enumerate(ptchain.chains):
        print("Temp {}".format(tk))
        print(c.proposal_dist.proposals[sampler.parameters].cov)

==== Chain 0 ====
Temp 0
[[9.34816076 0.        ]
 [0.         9.34816076]]
Temp 1
[[24.99254884  0.        ]
 [ 0.         24.99254884]]
Temp 2
[[26.25613502  0.        ]
 [ 0.         26.25613502]]
==== Chain 1 ====
Temp 0
[[10.33276945  0.        ]
 [ 0.         10.33276945]]
Temp 1
[[15.86801324  0.        ]
 [ 0.         15.86801324]]
Temp 2
[[24.15541635  0.        ]
 [ 0.         24.15541635]]
==== Chain 2 ====
Temp 0
[[7.74528665 0.        ]
 [0.         7.74528665]]
Temp 1
[[26.39783548  0.        ]
 [ 0.         26.39783548]]
Temp 2
[[24.62014241  0.        ]
 [ 0.         24.62014241]]
==== Chain 3 ====
Temp 0
[[7.94067528 0.        ]
 [0.         7.94067528]]
Temp 1
[[21.5858315  0.       ]
 [ 0.        21.5858315]]
Temp 2
[[21.43959603  0.        ]
 [ 0.         21.43959603]]
==== Chain 4 ====
Temp 0
[[4.9144089 0.       ]
 [0.        4.9144089]]
Temp 1
[[21.591275  0.      ]
 [ 0.       21.591275]]
Temp 2
[[19.81446126  0.        ]
 [ 0.         19.81446126]]
==== Chain 5

## Resume from a state

The sampler can be checkpointed by getting its current state with `sampler.state`. Let's check that this still works with the `AdaptiveNormal` proposal. To demonstrate this, we'll get the current state of the sampler, then run it for another set of iterations. We'll then create a new sampler and set its state to the state we obtained from the first sampler. Running the new sampler for the same number of iterations should produce the same results.

In [10]:
# get the current state
state = sampler.state

In [11]:
# now advance the sampler for another 250 iterations
sampler.run(250)

In [12]:
# create a new sampler, but set its state to what the original sampler's was after the first 200 iterations
sampler2 = ParallelTemperedSampler(model.params, model, nchains, proposals={('x', 'y'): proposal},
                                   betas=betas, pool=pool)
sampler2.set_state(state)

In [13]:
# now advance the new sampler for 250 iterations
# note that we don't have to set start_position first, since the positions have been set by set_state
sampler2.run(250)

In [14]:
# compare the current results; they should be the same between sampler2 and sampler
print('x:', (sampler.current_positions['x'] == sampler2.current_positions['x']).all())
print('y:', (sampler.current_positions['y'] == sampler2.current_positions['y']).all())
print('logl:', (sampler.current_stats['logl'] == sampler2.current_stats['logl']).all())
print('logp:', (sampler.current_stats['logp'] == sampler2.current_stats['logp']).all())
print('acceptance ratio:',
      (sampler.acceptance['acceptance_ratio'][:,:,-1] == sampler2.acceptance['acceptance_ratio'][:,:,-1]).all())
print('accepted:',
      (sampler.acceptance['accepted'][:,:,-1] == sampler2.acceptance['accepted'][:,:,-1]).all())

x: True
y: True
logl: True
logp: True
acceptance ratio: True
accepted: True
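
Exact agreement like this is only possible if the saved state includes the state of the random number generator, not just the chain positions. The sketch below shows the same idea with Python's standard `random` module. It is an analogy under that assumption and says nothing about how epsie stores its state internally.

```python
import random

rng = random.Random(42)
warmup = [rng.random() for _ in range(200)]       # "run" for 200 iterations

state = rng.getstate()                            # checkpoint
continued = [rng.random() for _ in range(250)]    # keep going

rng2 = random.Random()                            # fresh generator, unrelated seed
rng2.setstate(state)                              # restore the checkpoint
resumed = [rng2.random() for _ in range(250)]

print(continued == resumed)
```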


## Clearing memory and continuing

The history of results in memory can be cleared using `.clear()`. Running the sampler after a clear yields the same results as if no clear had been done. This is useful for keeping memory usage down: you can dump results to a file after some number of iterations, clear, then continue.

To demonstrate this, we'll clear `sampler2`, then run both `sampler` and `sampler2` for another 250 iterations. We'll then compare the current results; they should be the same.

In [15]:
sampler2.clear()

In [16]:
sampler.run(250)
sampler2.run(250)

In [17]:
# compare the current results; they should be the same between sampler2 and sampler
print('x:', (sampler.current_positions['x'] == sampler2.current_positions['x']).all())
print('y:', (sampler.current_positions['y'] == sampler2.current_positions['y']).all())
print('logl:', (sampler.current_stats['logl'] == sampler2.current_stats['logl']).all())
print('logp:', (sampler.current_stats['logp'] == sampler2.current_stats['logp']).all())
print('acceptance ratio:',
      (sampler.acceptance['acceptance_ratio'][:,:,-1] == sampler2.acceptance['acceptance_ratio'][:,:,-1]).all())
print('accepted:',
      (sampler.acceptance['accepted'][:,:,-1] == sampler2.acceptance['accepted'][:,:,-1]).all())

x: True
y: True
logl: True
logp: True
acceptance ratio: True
accepted: True
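
The dump, clear, continue pattern described above can be sketched as a loop. `ToySampler` below is a hypothetical stand-in that only mimics the three pieces used in this notebook (`run`, `positions` and `clear`), and a Python list stands in for the output file; with a real sampler you would write `sampler.positions` to disk before clearing.

```python
import random


class ToySampler:
    """Hypothetical stand-in exposing run(), positions and clear()."""

    def __init__(self, seed=0):
        self._rng = random.Random(seed)
        self._current = 0.0
        self.positions = []          # history kept in memory

    def run(self, niterations):
        for _ in range(niterations):
            self._current += self._rng.gauss(0.0, 1.0)
            self.positions.append(self._current)

    def clear(self):
        # drop the history but keep the current position and RNG state
        self.positions = []


def run_with_dumps(sampler, nblocks, block_size):
    """Run in blocks, dumping and clearing the history after each block."""
    dumped = []                      # stands in for a file on disk
    for _ in range(nblocks):
        sampler.run(block_size)
        dumped.extend(sampler.positions)
        sampler.clear()
    return dumped


uninterrupted = ToySampler(seed=1)
uninterrupted.run(750)
dumped = run_with_dumps(ToySampler(seed=1), nblocks=3, block_size=250)
print(dumped == uninterrupted.positions)
```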


## Plot acceptance rates

We'll plot the acceptance rate for each chain, which we define here as the number of times a proposal was accepted divided by the total number of iterations. We expect this to be close to 0.23 for the coldest chain, as this was the target rate of the `AdaptiveNormal` proposal that we used.

In [18]:
acceptance = sampler.acceptance
arate = acceptance['accepted'].sum(axis=2)/float(acceptance.shape[-1])
aratio = acceptance['acceptance_ratio']
# limit to 1
aratio[aratio > 1] = 1.
aratio = aratio.mean(axis=2)

In [19]:
# plot
fig, ax = pyplot.subplots()
for tk in range(ntemps):
    ax.scatter(range(nchains), arate[tk,:], label='temp {}'.format(tk))
    ax.axhline(arate[tk,:].mean(), color='C{}'.format(tk), linestyle='--')
ax.legend()
ax.set_ylabel('mean acceptance rate')
ax.set_xlabel('chain index')
fig.show()


In [20]:
print("Average acceptance rate over all chains:", arate.mean(axis=1))

Average acceptance rate over all chains: [0.25357143 0.80488095 0.81619048]


Indeed, the average acceptance rate over all of the coldest chains is close to 0.23.

Let's also plot the average acceptance ratio. This should be approximately the same as the acceptance rate.

In [21]:
# plot
fig, ax = pyplot.subplots()
for tk in range(ntemps):
    ax.scatter(range(nchains), aratio[tk,:], label='temp {}'.format(tk))
    ax.axhline(aratio[tk,:].mean(), color='C{}'.format(tk), linestyle='--')
ax.set_ylabel('mean acceptance ratio')
ax.set_xlabel('chain index')
fig.show()


In [22]:
print("Average acceptance ratio over all chains:", aratio.mean(axis=1))

Average acceptance ratio over all chains: [0.25240023 0.80323891 0.81627927]
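
The two averages agree because a Metropolis proposal is accepted with probability equal to the acceptance ratio capped at 1, so the mean capped ratio is the expected acceptance rate. The toy check below uses a 1D standard normal target and a fixed proposal width, both chosen only for illustration.

```python
import math
import random

rng = random.Random(3)
x, sigma = 0.0, 5.0
accepted, capped_ratios = [], []
for _ in range(20000):
    proposed = x + rng.gauss(0.0, sigma)
    # Metropolis ratio for a 1D standard normal target, capped at 1
    ratio = math.exp(min(0.0, 0.5 * (x * x - proposed * proposed)))
    accept = rng.random() < ratio
    accepted.append(accept)
    capped_ratios.append(ratio)
    if accept:
        x = proposed

rate = sum(accepted) / len(accepted)
mean_ratio = sum(capped_ratios) / len(capped_ratios)
print(rate, mean_ratio)
```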


## Plot temperature swaps

Information about the temperature swaps is stored in the `temperature_swaps` array. This consists of two fields, `acceptance_ratio` and `swap_index`. The former gives the acceptance ratio that was used to determine whether or not to swap the positions of the temperatures. The latter gives the indices of the chains that were swapped. Both have shape `ntemps-1 x nchains x niterations`. The first dimension is `ntemps-1` because each acceptance ratio involves two temperatures. For the `swap_index`, only the indices of the coldest `ntemps-1` swaps are given; the last one can be inferred from the rest.

For example, say `ntemps = 3` and the `swap_index` for one chain at one iteration was:

`[2, 0]`

This means that the position of the hottest chain (= 2) was swapped all the way down to the coldest chain (= 0), bumping up the colder positions; i.e., $\mathbf{x}_2 \rightarrow \mathbf{x}_0$; $\mathbf{x}_0 \rightarrow \mathbf{x}_1$; $\mathbf{x}_1 \rightarrow \mathbf{x}_2$. Likewise, a swap index `[0, 2]` means $\mathbf{x}_2 \leftrightarrow \mathbf{x}_1$ while $\mathbf{x}_0$ remained in place, and `[1,0]` means that $\mathbf{x}_1 \leftrightarrow \mathbf{x}_0$, while $\mathbf{x}_2$ remained in place.

Note that you will never see a swap index like `[2, 1]` or `[1, 2]`; i.e., `0` will always be in either the first or second slot. This is because the swaps always progress from the hottest chain down to the coldest. So while a hotter position may move down more than one level in a single swap, a colder position can move up at most one level.
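
The rules above can be checked with a short sketch. The helper names are hypothetical, and the sweep in `reachable_swap_indices` assumes what the text describes: pairwise swaps between neighbouring levels, attempted from the hottest pair down to the coldest.

```python
from itertools import product


def complete_swap_index(partial):
    """Append the index that is missing from the coldest ntemps-1 entries."""
    ntemps = len(partial) + 1
    missing = sum(range(ntemps)) - sum(partial)
    return list(partial) + [missing]


def apply_swap(positions, swap_index):
    """Level tk receives the position previously held by level swap_index[tk]."""
    return [positions[src] for src in swap_index]


def reachable_swap_indices(ntemps):
    """All outcomes of sweeping pairwise swaps from hottest to coldest."""
    outcomes = set()
    for accepts in product([False, True], repeat=ntemps - 1):
        index = list(range(ntemps))
        for tk, accept in zip(range(ntemps - 1, 0, -1), accepts):
            if accept:
                index[tk], index[tk - 1] = index[tk - 1], index[tk]
        outcomes.add(tuple(index))
    return outcomes


print(apply_swap(['x0', 'x1', 'x2'], complete_swap_index([2, 0])))
print(sorted(reachable_swap_indices(3)))
```

Every reachable outcome has `0` in one of the first two slots, which is why `[2, 1]` and `[1, 2]` never appear.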

In [23]:
temperature_swaps = sampler.temperature_swaps

In [24]:
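# we'll plot the swap index of the chain with index 1;
# first, add the missing swap index back in for plotting
swap_index = temperature_swaps['swap_index'][:, 1, :]
# at each iteration the three indices are a permutation of (0, 1, 2), which
# sums to 3, so the missing (hottest) entry is whatever remains; this also
# covers the [1, 0] case, where the hottest level stays in place
lastindex = 3 - swap_index[0, :] - swap_index[1, :]
swap_index = numpy.stack([swap_index[0,:], swap_index[1,:], lastindex])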

fig, ax = pyplot.subplots()
# we'll just plot the first 40 iterations
pltiters = 40
x = numpy.array([[ii, ii+0.5] for ii in range(pltiters)]).flatten()
for tk in range(sampler.ntemps):
    y = numpy.array([[swap_index[tk, ii], tk] for ii in range(pltiters)]).flatten()
    ax.plot(x, y, color='C{}'.format(tk))
#ax.legend()
ax.set_ylim(-0.25, 2.25)
ax.set_yticks([0, 1, 2])
ax.set_xlabel('iteration')
ax.set_ylabel('temperature level')
fig.show()


We see that the two hotter levels (1 and 2) swap on nearly every iteration, but the coldest level barely swaps at all. This suggests that the gap between the coldest temperature and the next one in our ladder was too large. Let's check the acceptance ratios.

In [25]:
fig, ax = pyplot.subplots()
for tk in range(sampler.ntemps-1):
    # cap acceptance ratio at 1
    ar = temperature_swaps['acceptance_ratio'][tk, 1, :].copy()
    ar[ar > 1] = 1.
    ax.plot(ar, label='$A_{%i %i}$'%(tk, tk+1))
ax.legend()
ax.set_xlabel('iteration')
ax.set_ylabel('acceptance ratio')
fig.show()
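
To see why a wide gap next to the coldest level suppresses swaps, recall the standard parallel-tempering acceptance ratio for exchanging two levels, $A_{ij} = \min\left(1, \exp\left[(\beta_i - \beta_j)(\log L_j - \log L_i)\right]\right)$. The sketch below evaluates it for a geometric three-level ladder with made-up log likelihoods. The ladder spacing and the numbers are assumptions for illustration, and epsie's implementation may differ in detail.

```python
import math


def swap_acceptance_ratio(beta_i, beta_j, logl_i, logl_j):
    """Standard parallel-tempering ratio for exchanging the positions of two
    levels with inverse temperatures beta_i and beta_j, capped at 1."""
    return math.exp(min(0.0, (beta_i - beta_j) * (logl_j - logl_i)))


# hypothetical values: the cold level sits near the likelihood peak, while
# the hot levels wander over the prior where the likelihood is much lower
betas_example = [1.0, 1e5 ** -0.5, 1e-5]
logls_example = [-3.0, -60.0, -80.0]

a01 = swap_acceptance_ratio(betas_example[0], betas_example[1],
                            logls_example[0], logls_example[1])
a12 = swap_acceptance_ratio(betas_example[1], betas_example[2],
                            logls_example[1], logls_example[2])
print(a01, a12)
```

Both hot levels have inverse temperatures close to zero, so their ratio stays near 1 regardless of the likelihood difference, whereas the cold-to-hot ratio is dominated by the large drop in log likelihood.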


## Create an animation of the results

To visualize the results, we'll create an animation showing how the chains evolved. We'll do this by plotting one point for each chain, with each frame in the animation representing a single iteration.

***Note: To keep file size down, the animation has not been created for the version of this notebook uploaded to the repository.***

In [None]:
from matplotlib import animation

In [None]:
# Prepare an array to create a density map showing the shape of the model posterior
npts = 100
xmean, ymean = model.likelihood_dist.mean
xsig = model.likelihood_dist.cov[0,0]**0.5
ysig = model.likelihood_dist.cov[1,1]**0.5
X, Y = numpy.mgrid[xmean-3*xsig:xmean+3*xsig:complex(0, npts),
                   ymean-3*ysig:ymean+3*ysig:complex(0, npts)]
Z = numpy.zeros(X.shape)
for ii in range(Z.shape[0]):
    for jj in range(Z.shape[1]):
        logl, logp = model(x=X[ii,jj], y=Y[ii,jj])
        Z[ii, jj] = numpy.exp(logl+logp)

In [None]:
# we'll just animate the first 200 iterations; change this to
# nframes = xdata.shape[1] if you want to see all iterations
nframes = 200

In [None]:
fig, ax = pyplot.subplots()

positions = sampler.positions[0,...]
xdata = positions['x']
ydata = positions['y']

# Plot contours showing the shape of the true posterior density
#ax.contour(X, Y, Z, 2, colors='k', linewidths=1, linestyles='dashed', zorder=-2)
ax.imshow(numpy.rot90(Z), extent=[X.min(), X.max(), Y.min(), Y.max()],
          aspect='auto', cmap='binary', zorder=-3)

# Put an x at the maximum posterior point
ax.scatter(model.mean[0], model.mean[1], marker='x', color='w', s=10, zorder=-2)
ax.set_xlabel('x')
ax.set_ylabel('y')
# create the scatter points
ptsize = 60

# we'll include the last bufferlen number of steps a chain visited, having the size and transparency
# exponentially damped with each new frame
bufferlen = 16
alphas = numpy.exp(-4*(numpy.arange(bufferlen))/float(bufferlen))
sizes = ptsize * alphas
#colors = numpy.array(['C{}'.format(ii) for ii in range(nchains)])
colors = numpy.arange(nchains)
plts = [ax.scatter(xdata[:, bufferlen-ii-1], ydata[:, bufferlen-ii-1], c=colors, s=sizes[ii],
                   edgecolors='w', linewidths=0.5,
                   alpha=alphas[ii], zorder=bufferlen-ii, marker='s' if ii==0 else 'o', cmap='jet')
        for ii in range(bufferlen)]
# put a + showing the average of the chain positions at the current iteration
meanplt = ax.scatter(xdata[:,0].mean(), ydata[:,0].mean(), marker='P', c='w', edgecolors='k', linewidths=0.5,
                     zorder=bufferlen+1)

# add some text giving the iteration
itertxt = 'Iteration {}'
txt = ax.annotate(itertxt.format(1), (0.03, 0.94), xycoords='axes fraction')

def animate(ii):
    txt.set_text(itertxt.format(ii+1))
    for jj,plt in enumerate(plts):
        plt.set_offsets(numpy.array([xdata[:, max(ii-jj, 0)], ydata[:, max(ii-jj, 0)]]).T)
    meanplt.set_offsets([xdata[:,ii].mean(), ydata[:,ii].mean()])
    # zoom in as it narrows on the result
    istart = max(ii-bufferlen, 0)
    # smooth it out a bit
    xmin = numpy.array([xdata[:, max(istart-kk, 0):].min() for kk in range(50)]).mean()
    xmax = numpy.array([xdata[:, max(istart-kk, 0):].max() for kk in range(50)]).mean()
    ymin = numpy.array([ydata[:, max(istart-kk, 0):].min() for kk in range(50)]).mean()
    ymax = numpy.array([ydata[:, max(istart-kk, 0):].max() for kk in range(50)]).mean()
    ax.set_xlim((1.1 if xmin < 1 else 0.9)*xmin, (0.9 if xmax < 1 else 1.1)*xmax)
    ax.set_ylim((1.1 if ymin < 1 else 0.9)*ymin, (0.9 if ymax < 1 else 1.1)*ymax)


# blit=False because animate() returns no artists and also rescales the axes
ani = animation.FuncAnimation(fig, animate, frames=nframes, interval=160, blit=False)

Save the animation:

In [None]:
ani.save('pt_chain_animation.mp4')

The result:

In [None]:
%%HTML
<video width="640" height="480" controls>
  <source src="pt_chain_animation.mp4" type="video/mp4">
</video>