## Bayesian optimisation (GPyOpt) applied to the Hodgkin-Huxley model

In [None]:
%matplotlib inline
from __future__ import division
from GPyOpt.methods import BayesianOptimization
from lfmods.hh import HHSimulator

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

## Load observed data, pilot means and pilot stds

In [None]:
# load SNPE data
%run -i NIPS2017.ipynb
prefix ='hh_1000_iwloss_svi_cyth_seed1'

# --- stored SNPE run -----------------------------------------------------------
# `io`, `dirs` and the plotting constants are expected to come from NIPS2017.ipynb.
dists, infos, losses, nets, posteriors, sims = io.load_prefix(dirs['dir_nets_hh'], prefix)
sim = io.last(sims)
posterior = io.nth(posteriors, -1)

# --- observation ---------------------------------------------------------------
obs_stats = sim.obs
y_obs = sim.obs_trace.reshape(-1, 1)

# --- parameters and their transform ----------------------------------------------
gt = sim.true_params
labels_params = sim.labels_params
param_invtransform = sim.param_invtransform
param_transform = sim.param_transform
gt_transf = param_transform(gt)
n_params = sim.n_params

# --- prior box, in the transformed space and mapped back through the inverse ----
prior_min = sim.prior_min
prior_max = sim.prior_max
unlog_prior_min = sim.param_invtransform(prior_min)
unlog_prior_max = sim.param_invtransform(prior_max)
# One row per parameter: [lower, upper].
prior_lims = np.hstack((prior_min.reshape(-1, 1), prior_max.reshape(-1, 1)))
unlog_prior_lims = sim.param_invtransform(prior_lims)

# --- simulation protocol ---------------------------------------------------------
t = sim.t
I = sim.I_obs
A_soma = sim.A_soma
dt = sim.dt
duration = np.max(t)
bm = sim.bm
init = sim.init

n_summary_stats = sim.n_summary_stats

# --- posterior mean with +/- one standard deviation, relative to the ground truth -
m, S = posterior.calc_mean_and_cov()
_post_std = np.sqrt(np.diag(S))
_post_mean_inv = param_invtransform(m)
diff_params = np.divide(_post_mean_inv - gt, gt)
err_params_up_norm = np.divide(param_invtransform(m + _post_std) - _post_mean_inv, gt)
err_params_down_norm = np.divide(_post_mean_inv - param_invtransform(m - _post_std), gt)

# --- size of the SNPE run --------------------------------------------------------
sim_step = io.last(infos)['n_samples']
num_round = len(infos)

# --- pilot means and standard deviations of the summary statistics ---------------
pilot_means = sim.pilot_means
pilot_stds = sim.pilot_stds

## Set up the HH simulator

In [None]:
# Simulator used inside the optimisation objective; created without pilot samples
# and with both caches switched off.
sim_1 = HHSimulator(
    seed=101,
    pilot_samples=0.,
    cached_sims=False,
    cached_pilot=False,
)

## Set up the objective to minimise and respective parameter bounds

In [None]:
# One (lower, upper) pair per parameter, taken from the inverse-transformed prior limits.
bounds = [(unlog_prior_min[k], unlog_prior_max[k]) for k in range(n_params)]

def f_hh(params):
    """Objective minimised by GPyOpt.

    Simulates the Hodgkin-Huxley model for one parameter set, computes its summary
    statistics, normalises them with the pilot means/stds and returns their distance
    to the observed statistics.

    Parameters
    ----------
    params : array-like
        Parameter values; reshaped to a single row before being passed to the model.

    Returns
    -------
    Whatever `sim_1.calc_dist` returns for the normalised statistics and `obs_stats`.
    """
    # simulation
    theta = np.array(params).reshape(1, -1)
    # NOTE(review): -70. sits where `init` is passed elsewhere in this notebook,
    # so it is presumably the initial state -- confirm.
    model = sim_1.bm.HH(-70., theta)
    trace = model.sim_time(dt, t, I).reshape(1, -1, 1)

    # summary statistics, normalised in place
    stats = sim_1.calc_summary_stats(trace)
    stats -= pilot_means
    stats /= pilot_stds

    return sim_1.calc_dist(stats, obs_stats)

# GPyOpt object with model and acquisition function
Bopt_hh = BayesianOptimization(
    f=f_hh,                    # function to optimize
    bounds=bounds,             # box-constraints of problem
    model_type='GP',
    acquisition_type='EI',     # selects the Expected improvement
    exact_feval=True,
)

## Run GPyOpt

In [None]:
# Evaluation budget.
# Alternative kept for reference: int(num_round*sim_step).
max_iter = 300

# Optional extra stopping criteria, currently unused:
#   max_time = 60      (time budget)
#   eps      = 10e-6   (minimum allowed distance between the last two observations)
#   Bopt_hh.run_optimization(max_iter, max_time, eps, verbosity=True)
Bopt_hh.run_optimization(max_iter)

## Plot original and fitted traces

In [None]:
fig = plt.figure()
gs = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])

# --- top panel: observed trace plus one simulation per parameter estimate ---
ax = plt.subplot(gs[0])
ax.plot(t, y_obs, color=COL['GT'], lw=2, label='observation')

# Row 0: SNPE posterior mean mapped through the inverse transform; row 1: GPyOpt optimum.
params = np.concatenate((
    np.array([param_invtransform(m)]),
    np.array([Bopt_hh.x_opt]),
))

V = np.zeros((len(t), 2))
for i in range(2):
    hh = bm.HH(init, params[i, :].reshape(1, -1), seed=230 + i)
    V[:, i] = hh.sim_time(dt, t, I)[:, 0]

# plotting simulation
ax.plot(t, V[:, 0], color=COL['SNPE'], lw=2, label='SNPE')
ax.plot(t, V[:, 1], color='g', lw=2, label='GPyOpt')

ax.set_ylabel('voltage (mV)')
ax.legend(bbox_to_anchor=(1.15, 1), loc='upper right')

ax.set_xticks([])
ax.set_yticks([-80, -20, 40])

# --- bottom panel: input current ---
ax = plt.subplot(gs[1])
input_scaled = I * A_soma * 1e3
ax.plot(t, input_scaled, color=COL['EFREE'], lw=2)
ax.set_xlabel('time (ms)')
ax.set_ylabel('input (nA)')

ax.set_xticks([0, duration / 2, duration])
ax.set_yticks([0, 1.1 * np.max(input_scaled)])
ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.2f'))

plt.show()

## Plot the mean and variance of the GP model of the loss function (TODO: plot the acquisition function)

In [None]:
def model_eval(model, x, N, ii=(0,)):
    """Evaluates the surrogate model along selected parameter directions.

    Builds one full parameter vector per row of `x`: the parameters listed in `ii`
    take their values from the columns of `x`, every other parameter is held at the
    optimum found so far (`Bopt_hh.x_opt`).

    Parameters
    ----------
    model : object
        Surrogate model exposing `predict(y)`.
    x : array, shape (n_points, n_varied)
        Rows are inputs to evaluate at. Columns are assigned to the varied
        parameters in increasing order of parameter index, not in the order of `ii`.
    N : int
        Only parameter indices below `N` can be varied; entries of `ii` that are
        `>= N` are ignored.
    ii : int or sequence of int
        Indices of the parameters to vary.

    Returns
    -------
    The result of `model.predict` on the assembled (n_points, n_params) inputs.
    """
    x = np.asarray(x)
    # Accept a bare index as well as a list/tuple of indices.
    varied = set(int(k) for k in np.atleast_1d(ii))

    # Start from the optimum in every column so that parameters with index >= N
    # are held at the optimum too, instead of being silently evaluated at zero.
    y = np.empty((len(x), n_params))
    y[:] = np.ravel(Bopt_hh.x_opt)

    col = 0
    for i in range(N):
        if i in varied:
            y[:, i] = x[:, col]
            col += 1

    return model.predict(y)

def _draw_lfun_marginal(axis, lfun, idx, n_params, lims, resolution):
    """Draws the predicted loss along parameter `idx` on `axis` and returns the means.

    The mean is drawn as a line and mean +/- 2*sqrt(variance) as a shaded band. Both
    go to `axis` explicitly rather than to pyplot's current axes.
    """
    xx = np.linspace(lims[idx, 0], lims[idx, 1], resolution)
    # n_params (not the number of grid rows) is passed so that model_eval fills every
    # non-varied parameter from the optimum, including those not shown in the grid.
    pp, pp_v = model_eval(lfun, xx.reshape(-1, 1), n_params, ii=[idx])
    mean = pp[:, 0]
    spread = 2 * np.sqrt(pp_v[:, 0])

    axis.plot(xx, mean, color=COL['SNPE'])
    axis.fill_between(xx, mean - spread, mean + spread,
                      facecolor=COL['SNPE'],
                      alpha=0.3)
    axis.set_xlim(lims[idx])
    axis.set_ylim([0, axis.get_ylim()[1]])
    return mean


def _style_lfun_marginal(axis, idx, mean, lims, gt, ticks, labels_params, fontscale):
    """Adds ground truth, ticks, label and a square aspect to a marginal panel.

    Returns the final axis limits as (x0, x1, y0, y1).
    """
    if gt is not None:
        axis.vlines(gt[idx], 0, axis.get_ylim()[1], color='r')

    if ticks:
        axis.get_yaxis().set_tick_params(which='both', direction='out',
                                         labelsize=fontscale*20)
        axis.get_xaxis().set_tick_params(which='both', direction='out',
                                         labelsize=fontscale*20)
        axis.set_xticks(np.linspace(lims[idx, 0], lims[idx, 1], 2))
        # Scalar extrema: builtin min/max on a 2-d array would give array endpoints.
        axis.set_yticks(np.linspace(np.min(mean), np.max(mean), 2))
        axis.xaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
        axis.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
    else:
        axis.get_xaxis().set_ticks([])
        axis.get_yaxis().set_ticks([])

    if labels_params is not None:
        axis.set_xlabel(labels_params[idx], fontsize=fontscale*15)
    else:
        # An empty string; passing [] would be rendered as the literal text "[]".
        axis.set_xlabel('')

    x0, x1 = axis.get_xlim()
    y0, y1 = axis.get_ylim()
    axis.set_aspect((x1-x0)/(y1-y0))
    return x0, x1, y0, y1


def plot_lfun(lfun, acqu, n_params, lims, gt=None,
        resolution=500, labels_params=None, ticks = False, diag_only=False, diag_only_cols=4,
        diag_only_rows=4, figsize=(5,5), fontscale=1, partial=False):
    """Plots loss function, for each variable and pair of variables.

    Each 1-d panel shows the surrogate's predicted mean with a +/- 2 standard
    deviation band along one parameter; each 2-d panel shows the predicted mean over
    a pair of parameters. All other parameters are held fixed by `model_eval`.

    Parameters
    ----------
    lfun : object
        Surrogate model, passed to `model_eval` (must expose `predict`).
    acqu : object
        Acquisition function. Currently unused (plotting it is still a TODO).
    n_params: int
        Total number of parameters of the model.
    lims : array, shape (n_params, 2)
        Lower and upper plotting limit of each parameter.
    gt : array
        Ground-truth parameters, marked in red. Must be in the same space as `lims`.
    resolution: int
        Number of grid points per axis.
    labels_params : array of strings
    ticks: bool
        If True, includes ticks in plots
    diag_only : bool
        If True, only the 1-d panels are drawn, laid out row by row on a
        `diag_only_rows` x `diag_only_cols` grid (at most rows*cols parameters).
    diag_only_cols : int
        Number of grid columns if only the diagonal is plotted
    diag_only_rows : int
        Number of grid rows if only the diagonal is plotted
    figsize : tuple
    fontscale: int
    partial: bool
        If True, plots partial posterior with at the most 3 parameters.
        Only available if `diag_only` is False

    Returns
    -------
    None
    """

    lims = np.asarray(lims)

    if n_params == 1:

        fig, ax = plt.subplots(1, 1, facecolor='white', figsize=figsize)

        pp = _draw_lfun_marginal(ax, lfun, 0, n_params, lims, resolution)
        if gt is not None: ax.vlines(gt, 0, ax.get_ylim()[1], color='r')

        if ticks:
            ax.get_yaxis().set_tick_params(which='both', direction='out')
            ax.get_xaxis().set_tick_params(which='both', direction='out')
            ax.set_xticks(np.linspace(lims[0, 0], lims[0, 1],2))
            ax.set_yticks(np.linspace(np.min(pp), np.max(pp), 2))
            ax.xaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
            ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
        else:
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])

        return

    if diag_only:
        cols = diag_only_cols
        rows = diag_only_rows
    elif partial:
        rows = min(3, n_params)
        cols = min(3, n_params)
    else:
        rows = n_params
        cols = n_params

    # squeeze=False always yields a 2-d array of axes, even for a 1x1 grid.
    fig, ax = plt.subplots(rows, cols, facecolor='white', figsize=figsize, squeeze=False)

    if diag_only:
        # Lay the marginals out row by row; hide the grid cells that stay unused.
        n_shown = min(n_params, rows*cols)
        for k in range(rows*cols):
            r, c = divmod(k, cols)
            if k >= n_shown:
                ax[r, c].set_axis_off()
                continue
            pp = _draw_lfun_marginal(ax[r, c], lfun, k, n_params, lims, resolution)
            _style_lfun_marginal(ax[r, c], k, pp, lims, gt, ticks, labels_params, fontscale)
        return

    for i in range(rows):
        for j in range(cols):
            panel = ax[i, j]

            if i == j:
                pp = _draw_lfun_marginal(panel, lfun, i, n_params, lims, resolution)
                x0, x1, y0, y1 = _style_lfun_marginal(panel, i, pp, lims, gt, ticks,
                                                      labels_params, fontscale)

                # Ellipses hinting at the parameters left out of a partial plot.
                if partial and i == rows-1:
                    panel.text(x1+(x1-x0)/6., (y0+y1)/2., '...', fontsize=fontscale*25)
                    panel.text(x1+(x1-x0)/8.4, y0-(y1-y0)/6., '...', fontsize=fontscale*25,rotation=-45)
                continue

            if i > j:
                # The lower triangle is left empty.
                panel.get_yaxis().set_visible(False)
                panel.get_xaxis().set_visible(False)
                panel.set_axis_off()
                continue

            # Upper triangle: predicted mean over the (i, j) pair of parameters.
            xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
            yy = np.linspace(lims[j ,0], lims[j, 1], resolution)
            X, Y = np.meshgrid(xx, yy)
            xy = np.concatenate([X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)
            pp, pp_v = model_eval(lfun, xy, n_params, ii=[i, j])
            pp = pp.reshape(list(X.shape))
            panel.imshow(pp, origin='lower', cmap=cmaps.parula,
                         extent=[lims[i, 0], lims[i, 1], lims[j ,0], lims[j ,1]],
                         aspect='auto', interpolation='none')
            panel.set_xlim(lims[i])
            panel.set_ylim(lims[j])

            if gt is not None: panel.plot(gt[i], gt[j], 'r.', ms=10, markeredgewidth=0.0)

            panel.get_xaxis().set_ticks([])
            panel.get_yaxis().set_ticks([])
            panel.set_axis_off()

            x0, x1 = panel.get_xlim()
            y0, y1 = panel.get_ylim()
            panel.set_aspect((x1-x0)/(y1-y0))

            if partial and j == cols-1:
                panel.text(x1+(x1-x0)/6., (y0+y1)/2., '...', fontsize=fontscale*25)

    return

In [None]:
# plot_lfun creates its own figure, so no separate (empty) figure is opened here.
# `bounds` is in the untransformed parameter space, so the ground truth marker has to
# be the untransformed `gt`, not `gt_transf`.
plot_lfun(Bopt_hh.model.model,
          Bopt_hh.acquisition.acquisition_function,
          n_params,
          lims=bounds,
          gt=gt,
          resolution=100,
          labels_params=labels_params,
          partial=True)

## Plot differences between true and estimated parameters

In [None]:
# Absolute errors in the transformed parameter space, in units of the SNPE posterior std.
m, S = posterior.calc_mean_and_cov()
post_std = np.sqrt(np.diag(S))

# SNPE
diff_params = np.abs(np.divide(m - gt_transf, post_std))

# GPyopt
diff_params_gpyopt = np.abs(np.divide(param_transform(Bopt_hh.x_opt) - gt_transf, post_std))

# Largest GPyOpt error first.
sort_indices = np.argsort(diff_params_gpyopt)[::-1]
labels_params_sort = [LABELS_HH[idx] for idx in sort_indices]

min_m = np.min([diff_params.min(), diff_params_gpyopt.min()])
max_m = np.max([diff_params.max(), diff_params_gpyopt.max()])

fig = plt.figure(figsize=(10, 5))
ax = plt.gca()

# histogram of differences: two bars side by side per parameter
width = 0.3
bar_pos = np.linspace(0, n_params - 1, n_params)
ax.bar(bar_pos, diff_params[sort_indices], width,
       color=COL['SNPE'], align='center', edgecolor='none', label='SNPE mean')
ax.bar(bar_pos + width, diff_params_gpyopt[sort_indices], width,
       color=COL['IBEA'], align='center', edgecolor='none', label='GPyopt')
ax.set_ylabel(r'||$\theta$ - $\theta^*$|| / $\sigma_{\theta}$')

ax.legend(bbox_to_anchor=(1.15, 1), loc='upper right')
ax.set_xlim(-1.5 * width, n_params + width / 2)
ax.set_xticks(bar_pos + width / 2)
ax.set_ylim([0, max_m])
ax.set_yticks(np.linspace(0, max_m, 3))
ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
ax.set_xticklabels(labels_params_sort)

## Plot differences between observed features and features simulated from the SNPE posterior mean and the GPyOpt optimum

In [None]:
# Observed summary statistics that both estimates are compared against.
obs_features = obs_stats[0, :]

# SNPE: simulate at the posterior mean (mapped through the inverse transform)
hh = bm.HH(init, param_invtransform(m).reshape(1, -1))
states = hh.sim_time(dt, t, I)[:, 0]
mode_sum_stats = sim.calc_summary_stats(states.reshape(1, -1, 1))[0, :]
diff_sum_stats = np.abs(mode_sum_stats - obs_features)

# GPyopt: simulate at the optimum found
hh = bm.HH(init, Bopt_hh.x_opt.reshape(1, -1))
states = hh.sim_time(dt, t, I)[:, 0]
sum_stats_gpyopt = sim.calc_summary_stats(states.reshape(1, -1, 1))[0, :]
diff_sum_stats_gpyopt = np.abs(sum_stats_gpyopt - obs_features)

# re-order: largest GPyOpt difference first
sort_indices = np.argsort(diff_sum_stats_gpyopt)[::-1]

min_m = np.min([diff_sum_stats.min(), diff_sum_stats_gpyopt.min()])
max_m = np.max([diff_sum_stats.max(), diff_sum_stats_gpyopt.max()])

fig = plt.figure(figsize=(10, 5))
ax = plt.gca()

# histogram of differences: two bars side by side per summary statistic
width = 0.3
bar_pos = np.linspace(0, n_summary_stats - 1, n_summary_stats)
ax.bar(bar_pos, diff_sum_stats[sort_indices], width,
       color=COL['SNPE'], align='center', edgecolor='none', label='SNPE mean')
ax.bar(bar_pos + width, diff_sum_stats_gpyopt[sort_indices], width,
       color=COL['IBEA'], align='center', edgecolor='none', label='GPyopt')
ax.set_ylabel(r'||f - $f^*$||')

ax.legend(bbox_to_anchor=(1.15, 1), loc='upper right')
ax.set_xlim(-1.5 * width, n_summary_stats + width / 2)
ax.set_xticks(bar_pos + width / 2)
ax.set_ylim([0, max_m])
ax.set_yticks(np.linspace(0, max_m, 3))
ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
ax.set_xticklabels(np.linspace(1, n_summary_stats, n_summary_stats).astype(int))