In [1]:
%run 'GPDM_direct_fixedpoints.ipynb'


Matplotlib is building the font cache using fc-list. This may take a moment.



In [2]:
import scipy.optimize
import pickle
import pickle, datetime, time

# 2D k-wells example

In [3]:
def kwells_2D_true_tr_fnc(xprev_vec, well_locs, Sigma_eps, well_width):
    """Deterministic part of the k-wells transition, applied column by column.

    A column of ``xprev_vec`` that lies strictly within Euclidean distance
    ``sqrt(Sigma_eps) * well_width`` of a well centre is replaced by that centre.
    If it lies within range of several wells, the last one in ``well_locs`` is
    used. Columns outside every well are copied through unchanged.

    Args:
        xprev_vec: array with one state per column.
        well_locs: sequence of well-centre vectors.
        Sigma_eps: transition-noise variance, which sets the capture radius.
        well_width: capture radius in units of sqrt(Sigma_eps).

    Returns:
        Array shaped (and typed) like ``xprev_vec`` with the mapped states.
    """
    radius = np.sqrt(Sigma_eps) * well_width
    out = np.zeros_like(xprev_vec)
    for col in range(xprev_vec.shape[1]):
        point = xprev_vec[:, col]
        target = point
        # No early exit: a later matching well overrides an earlier one
        for centre in well_locs:
            if np.sqrt(np.sum((point - centre) ** 2)) < radius:
                target = centre
        out[:, col] = target
    return out

In [4]:
def kwells_2D_draw_trial(well_locs, T, Sigma_eps, Sigma_nu, mu_0_0, Sigma_0_0, well_width):
    """Simulate one trial of the 2-D k-wells system.

    An initial state x0 = mu_0_0 + sqrt(Sigma_0_0) * N(0, I) is drawn but not
    returned. At every step the previous state is passed through
    ``kwells_2D_true_tr_fnc`` and Gaussian noise with variance ``Sigma_eps`` is
    added. The observation is that latent state plus Gaussian noise with
    variance ``Sigma_nu``.

    Returns:
        (x, y): latent states and observations, each of shape (2, T).
    """
    prev = mu_0_0 + np.sqrt(Sigma_0_0) * np.random.randn(2,)
    latent = np.zeros((2, T))
    observed = np.zeros((2, T))
    for t in range(T):
        prev_col = np.atleast_2d(np.squeeze(prev)).T  # (2, 1) column
        drift = kwells_2D_true_tr_fnc(prev_col, well_locs, Sigma_eps, well_width)
        # Random draws happen in the order: transition noise, then observation noise
        latent[:, t:t + 1] = drift + np.sqrt(Sigma_eps) * np.random.randn(2, 1)
        observed[:, t:t + 1] = latent[:, t:t + 1] + np.sqrt(Sigma_nu) * np.random.randn(2, 1)
        prev = latent[:, t]
    return (latent, observed)

In [5]:
# Simulate Ny noisy trials of the 2-well system, then plot the observations.
np.random.seed(1234)
Ny = 10
well_locs = [np.array([1.0, 1.0]), np.array([-1.0, -1.0])]
T = 50
Sigma_eps = 1e-1   # transition-noise variance
Sigma_nu = 1e-2    # observation-noise variance
mu_0_0 = np.array([0.0, 0.0])
Sigma_0_0 = 1e-1*np.array([0.5e-0, 0.5e-0])
well_width = 3

# Record the generating parameters (they are pickled with the optimisation results)
kwells_params = OrderedDict([
    ("Ny", Ny),
    ("well_locs", well_locs),
    ("T", T),
    ("Sigma_eps", Sigma_eps),
    ("Sigma_nu", Sigma_nu),
    ("mu_0_0", mu_0_0),
    ("Sigma_0_0", Sigma_0_0),
    ("well_width", well_width),
])

# Draw the trials one after another (keeps the seeded random sequence fixed)
all_trials_x = []
all_trials_y = []
for trial in range(Ny):
    x_trial, y_trial = kwells_2D_draw_trial(well_locs, T, Sigma_eps, Sigma_nu, mu_0_0, Sigma_0_0, well_width)
    all_trials_x.append(x_trial[:, :, None])
    all_trials_y.append(y_trial[:, :, None])

# Stack trials along a third axis: (dimension, time step, trial)
x = np.concatenate(all_trials_x, axis=2)
y = np.concatenate(all_trials_y, axis=2)

plots_by_run = [
    plt_type.Scatter(x=np.squeeze(y[0, :, v]), y=np.squeeze(y[1, :, v]), mode='lines')
    for v in range(Ny)
]
plt(plots_by_run)

In [6]:
# Observations are stacked as (dimension, time step, trial), i.e. (2, T, Ny)
y.shape

(2, 50, 10)

In [9]:
# NOTE(review): this cell ran after In [8], whose init_params call rebinds
# Sigma_eps. The value shown is therefore the one returned by init_params
# (array([[0.1]])), not the scalar used to simulate the data; that scalar is
# kept in kwells_params["Sigma_eps"].
Sigma_eps

array([[ 0.1]])

In [8]:
# Try to solve the full problem (or some parts of it)
D = 2    # passed to init_params; z and s are later plotted using rows 0 and 1
Nz = 64  # presumably the number of inducing points -- TODO confirm with init_params
Ns = 2   # presumably the number of fixed points (s) -- TODO confirm with init_params

# NOTE: this rebinds Sigma_eps, mu_0_0, Sigma_0_0 and Sigma_nu, overwriting the
# values used to simulate the data. Those values remain available in kwells_params.
(Sigma_eps, mu_0_0, Sigma_0_0, C, Sigma_nu, z, u, Sigma_u, lengthscales, kernel_variance, s, J) = \
    init_params(y, D, Nz, Ns)

(init_paramvec, dict_ind, dict_shape) = params_to_vec(Sigma_eps, mu_0_0, Sigma_0_0, C, Sigma_nu, z, u, Sigma_u,
                                                      lengthscales, kernel_variance, s, J)

# Optimise every element of the parameter vector except J. J is excluded by its
# recorded indices rather than by assuming it occupies the last Ns*D*D slots,
# so this stays correct if params_to_vec changes its packing order.
opt_params = np.setdiff1d(np.arange(init_paramvec.shape[0]), dict_ind['J'])
cur_pvec = init_paramvec[opt_params]
# # To optimise only the fixed points:
# opt_params = np.concatenate([dict_ind['s']])
# # To optimise everything:
# opt_params = np.arange(init_paramvec.shape[0])
# cur_pvec = init_paramvec

def replace_params(pvec, opt_params, paramvec):
    """Rebuild a full parameter vector from an optimised subset.

    Args:
        pvec: new values, one per index in ``opt_params`` (same order).
        opt_params: indices of ``paramvec`` that ``pvec`` replaces.
        paramvec: full reference vector supplying all remaining entries.

    Returns:
        Vector of length ``paramvec.shape[0]`` with ``pvec`` placed at
        ``opt_params`` and ``paramvec`` elsewhere. Built only from
        concatenation and integer indexing, so it can be differentiated
        through with respect to ``pvec``.
    """
    fixed = np.setdiff1d(np.arange(paramvec.shape[0]), opt_params)
    # Position of each element of [pvec, paramvec[fixed]] in the full vector
    order = np.argsort(np.concatenate([opt_params, fixed]))
    stacked = np.concatenate([pvec, paramvec[fixed]])
    return stacked[order]

In [None]:
# Plot the true latent trajectory of one trial against the trajectories
# returned by time_full_iter at the initial parameters.
# NOTE(review): x_t1 / x_t used to come only from a commented-out call that
# passed (D, Nz, Ns), so this cell raised NameError on a fresh kernel. The call
# below uses the (pvec, y, dict_ind, dict_shape) signature of the optimisation
# cell. It assumes time_full_iter still accepts ret_smoothed=True and then
# returns (obj, x_t1, x_t, sig_t1, sig_t) -- confirm.
obj, x_t1, x_t, sig_t1, sig_t = time_full_iter(init_paramvec, y, dict_ind, dict_shape, ret_smoothed=True)

plot_trial = 2  # index of the trial to display
plots_by_run = [
    plt_type.Scatter(x=np.squeeze(traj[0, :, plot_trial]),
                     y=np.squeeze(traj[1, :, plot_trial]),
                     mode='lines', name=label)
    for label, traj in [("true x", x), ("x_t", x_t), ("x_t1", x_t1)]
]
plt(plots_by_run)

In [None]:
# Define a plotting function for callback that shows the current transition function estimate
import plotly.figure_factory

def kwells_2D_callback_plot(pvec_partial, grid_step=0.2, grid_margin=0.2):
    """Draw the current transition-function estimate as a 2-D quiver plot.

    The full parameter vector is rebuilt from the optimised subset
    ``pvec_partial``, with the remaining entries taken from ``init_paramvec``.
    The model's prediction ``mu_star`` is then evaluated on a regular grid
    covering the inducing points. Each arrow points from a grid point x along
    mu_star(x) - x. Markers show the true well locations, the inducing-point
    locations and the estimated fixed points.

    Args:
        pvec_partial: values for the parameter-vector entries listed in ``opt_params``.
        grid_step: spacing of the evaluation grid (default 0.2).
        grid_margin: padding added around the inducing-point range (default 0.2).

    Relies on the notebook globals opt_params, init_paramvec, dict_ind,
    dict_shape and kwells_params.
    """
    pvec = replace_params(pvec_partial, opt_params, init_paramvec)
    (Sigma_eps, mu_0_0, Sigma_0_0, C, Sigma_nu, z, u, Sigma_u, lengthscales, kernel_variance, s, J) = \
        vec_to_params(pvec, dict_ind, dict_shape)

    # Regular evaluation grid covering the inducing points plus a margin
    xtmp, ytmp = np.meshgrid(
        np.arange(np.min(z[0, :]) - grid_margin, np.max(z[0, :]) + grid_margin, grid_step),
        np.arange(np.min(z[1, :]) - grid_margin, np.max(z[1, :]) + grid_margin, grid_step))
    xstar = np.concatenate([xtmp.flatten()[:, None], ytmp.flatten()[:, None]], axis=1).T

    # mu_star: predicted next state at each grid point (presumably the GP predictive mean)
    fp_get_static_K, fp_predict = create_fp_gp_funcs()
    L, targets, params = fp_get_static_K(eta=kernel_variance, lengthscales=lengthscales, z=z, u=u, s=s, J=J,
                                         sig_eps=Sigma_eps)
    mu_star, _, _ = fp_predict(xstar, L, targets, params)

    quiver_fig = plotly.figure_factory.create_quiver(np.squeeze(xstar[0, :]),
                                                     np.squeeze(xstar[1, :]),
                                                     np.squeeze(mu_star[0, :] - xstar[0, :]),
                                                     np.squeeze(mu_star[1, :] - xstar[1, :]),
                                                     scale=.25,
                                                     arrow_scale=.4,)

    true_wells = np.array(kwells_params["well_locs"]).T  # one column per well
    overlay = [
        # True well locations
        plt_type.Scatter(x=np.atleast_1d(np.squeeze(true_wells[0, :])),
                         y=np.atleast_1d(np.squeeze(true_wells[1, :])),
                         mode='markers', name="Well loc", marker=dict(size=14)),
        # Inducing point locations
        plt_type.Scatter(x=np.atleast_1d(np.squeeze(z[0, :])), y=np.atleast_1d(np.squeeze(z[1, :])),
                         mode='markers', name="Inducing loc", marker=dict(size=10)),
        # Estimated fixed points
        plt_type.Scatter(x=np.atleast_1d(np.squeeze(s[0, :])), y=np.atleast_1d(np.squeeze(s[1, :])),
                         mode='markers', name="Fixed loc", marker=dict(size=14)),
    ]
    for trace in overlay:
        # plotly >= 3 exposes Figure.data as an immutable tuple (use add_trace).
        # Older versions return a mutable list that supports append.
        if hasattr(quiver_fig, 'add_trace'):
            quiver_fig.add_trace(trace)
        else:
            quiver_fig['data'].append(trace)

    plt(quiver_fig)

In [None]:
# Objective as a function of the optimised subset of parameters. Entries not in
# opt_params are held at their init_paramvec values. The first element of
# time_full_iter's output is the value being minimised.
def tmp_func(pvec_partial):
    full_pvec = replace_params(pvec_partial, opt_params, init_paramvec)
    return time_full_iter(full_pvec, y, dict_ind, dict_shape)[0]

objective_with_grad = value_and_grad(tmp_func, argnum=0)

# Bounds: variance parameters are kept >= MIN_VARIANCE; all others are unbounded.
# Only the entries listed in opt_params are passed to the optimiser.
MIN_VARIANCE = 1e-3
all_bnds = [(None, None)] * init_paramvec.shape[0]
variance_inds = np.concatenate([dict_ind['Sigma_0_0'], dict_ind['Sigma_nu'],
                                dict_ind['Sigma_eps'], dict_ind['Sigma_u']])
for i in variance_inds:
    all_bnds[i] = (MIN_VARIANCE, None)
bnds = tuple(all_bnds[i] for i in opt_params)

# Run L-BFGS-B one iteration per minimize call so that every intermediate result
# is saved and the run can stop at a wall-clock limit. Each call starts a fresh
# L-BFGS-B, so its curvature history is discarded between iterations.
max_iters = 100      # maximum number of single-iteration calls
max_time = 4.0*3600  # maximum wall-clock time in seconds; checked after each iteration
save_fname = "well_2d_k2_" + datetime.datetime.now().strftime("%Y%m%dT%H%M%S") + ".pkl"

# Start from the initial values, not whatever a previous run of this cell left
# in cur_pvec. A re-run then reproduces the same trajectory, and the saved
# init_paramvec matches the starting point of all_results.
cur_pvec = init_paramvec[opt_params]
init_time = time.time()
all_results = []
for it in range(max_iters):
    result = scipy.optimize.minimize(objective_with_grad, cur_pvec, jac=True, method='L-BFGS-B', bounds=bnds,
                                     callback=None, options={'maxiter': 1, 'disp': True})
    all_results.append(result)
    # Checkpoint the full history after every iteration
    with open(save_fname, 'wb') as f:
        pickle.dump([kwells_params, x, y, all_results, init_paramvec, dict_ind, dict_shape, opt_params], f)

    cur_pvec = result.x
    cur_time = time.time()
    print([it, cur_time - init_time, result.fun])

    kwells_2D_callback_plot(cur_pvec)

    # Exit if maximum time is reached
    if (cur_time - init_time) > max_time:
        print(["Maximum iteration time reached at iter", it])
        break

In [None]:
# Observed trajectories of every trial (first vs second observation dimension)
plots_by_run = [
    plt_type.Scatter(x=np.squeeze(y[0, :, trial]), y=np.squeeze(y[1, :, trial]), mode='lines')
    for trial in range(Ny)
]
plt(plots_by_run)

In [None]:
# Transition-field estimate after optimisation iteration 4 (all_results[3]).
# NOTE(review): raises IndexError if the optimisation loop stopped before 4 iterations.
kwells_2D_callback_plot(all_results[3].x)

In [None]:
# Transition-field estimate after optimisation iteration 26 (all_results[25]).
# NOTE(review): raises IndexError if the optimisation loop stopped (e.g. at max_time) before 26 iterations.
kwells_2D_callback_plot(all_results[25].x)