In [None]:
import nbimporter
# Import main algorithm
from GPDM_direct_fixedpoints import *

# %run GPDM_direct_fixedpoints.ipynb

In [None]:
# Saving outputs and timing
import pickle, datetime, time

# Generate example data

# 1D, two-well example (drift-diffusion process)


In [None]:
def kwells_true_tr_fnc(xprev_vec, well_locs, Sigma_eps, well_width, slope=1.0, transform_func=(lambda x: x**3), **kwargs):
    """Evaluate the noise-free k-wells transition function.

    Outside every well the previous state is mapped to ``xprev * slope``.
    Strictly inside a well, i.e. within ``sqrt(Sigma_eps) * well_width`` of a
    well location, the output is a blend of the well location and
    ``xprev * slope``: the relative distance to the well centre (0 at the
    centre, 1 at the edge) is passed through ``transform_func`` and then
    squared to give the weight put on ``xprev * slope``.

    Parameters
    ----------
    xprev_vec : scalar or array_like
        Previous state(s). Promoted with ``np.atleast_2d``; the loop indexes
        columns of the result, so a single row (shape ``(1, N)``) is expected.
    well_locs : iterable of float
        Well centre locations. If a state lies inside several wells, the last
        matching well in iteration order determines the output.
    Sigma_eps : float
        Value whose square root scales the well half-width.
    well_width : float
        Well half-width in units of ``sqrt(Sigma_eps)``.
    slope : float, optional
        Linear gain applied to the previous state (default 1.0).
    transform_func : callable, optional
        Applied to the relative distance before the squared well shape.
    **kwargs
        Ignored; lets callers pass a larger parameter dict with ``**``.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``np.atleast_2d(xprev_vec)``.
    """
    def well_shape(a):
        """ Function from [0-1] to [0-1], shape of the well"""
        return a**2

    xprev_vec = np.atleast_2d(xprev_vec)
    # Force a float output so integer-valued inputs are not silently truncated.
    out = np.zeros_like(xprev_vec, dtype=float)
    # Loop-invariant half-width of every well.
    half_width = np.sqrt(Sigma_eps)*well_width

    for i in range(int(np.prod(xprev_vec.shape))):
        # ndarray.item() replaces np.asscalar, which was removed in NumPy 1.23.
        xprev = xprev_vec[:, i].item()

        # Default: plain linear map; overwritten below if xprev is in a well.
        out[:, i] = xprev*slope
        for well_loc in well_locs:
            if (xprev > (well_loc - half_width)) and (xprev < (well_loc + half_width)):
                # Between 0 and 1, interpolate
                rel_dist = transform_func(np.abs(xprev - well_loc)/half_width)
                weight = well_shape(rel_dist)
                out[:, i] = xprev*slope*weight + well_loc*(1 - weight)

    return out

In [None]:
def kwells_draw_trial(well_locs, T, Sigma_eps, Sigma_nu, mu_0_0, Sigma_0_0, well_width, **kwargs):
    """Simulate one trial of the k-wells process with noisy observations.

    The initial state is ``mu_0_0 + sqrt(Sigma_0_0) * randn``. At each of the
    ``T`` steps the previous state is pushed through ``kwells_true_tr_fnc``
    and perturbed by noise scaled by ``sqrt(Sigma_eps)``; the observation is
    that new state plus noise scaled by ``sqrt(Sigma_nu)``. All draws come
    from the global ``np.random`` state.

    Extra keyword arguments are forwarded to ``kwells_true_tr_fnc``.

    Returns a tuple ``(x, y)`` of ``(1, T)`` arrays: latent states and
    observations. The initial state itself is not stored.
    """
    state = mu_0_0 + np.sqrt(Sigma_0_0) * np.random.randn(1)
    latent = np.zeros((1, T))
    observed = np.zeros((1, T))

    for step in range(T):
        # Deterministic part of the transition from the previous state.
        drift = kwells_true_tr_fnc(np.atleast_1d(state), well_locs, Sigma_eps, well_width, **kwargs)
        # Process noise is drawn before observation noise at every step.
        latent[:, step] = drift + np.sqrt(Sigma_eps) * np.random.randn(1)
        observed[:, step] = latent[:, step] + np.sqrt(Sigma_nu) * np.random.randn(1)
        # The state just written becomes the input for the next step.
        state = latent[:, step]

    return (latent, observed)

In [None]:
# rseed = 127
# Ny = 30
# well_locs = np.array([1.0, -1.0])
# T = 10
# Sigma_eps = 5e-2
# Sigma_nu = 1e-2
# mu_0_0 = 0.0
# Sigma_0_0 = 1.3e-0
# well_width = 3
# np.random.seed(rseed)

# kwells_params = OrderedDict()
# kwells_params["rseed"] = rseed
# kwells_params["Ny"] = Ny
# kwells_params["well_locs"] = well_locs
# kwells_params["T"] = T
# kwells_params["Sigma_eps"] = Sigma_eps
# kwells_params["Sigma_nu"] = Sigma_nu
# kwells_params["mu_0_0"] = mu_0_0
# kwells_params["Sigma_0_0"] = Sigma_0_0
# kwells_params["well_width"] = well_width

# # Collect trials into runnable object
# all_trials_x = []
# all_trials_y = []
# for n in range(Ny):
#     x,y = kwells_draw_trial(well_locs, T, Sigma_eps, Sigma_nu, mu_0_0, Sigma_0_0, well_width)
    
#     all_trials_x.append(x[:,:,None])
#     all_trials_y.append(y[:,:,None])


# x = np.concatenate(all_trials_x, axis=2)
# y = np.concatenate(all_trials_y, axis=2)

# plots_by_run = []
# for v in range(Ny):
#     plots_by_run.append(
#         plt_type.Scatter(x=np.squeeze(np.arange(T)), 
#                       y=np.squeeze(y[:,:,v]), 
#                       mode='lines')
#     )
    
# plt(plots_by_run)

In [None]:
# # Try to solve the full problem (or some parts of it)
# D = 1
# Nz = 16
# Ns = 2

# (Sigma_eps, mu_0_0, Sigma_0_0, C, Sigma_nu, z, u, Sigma_u, lengthscales, kernel_variance, s, J) = \
#     init_params(y, D, Nz, Ns)

# # Fix noise to true value
# Sigma_nu = kwells_params['Sigma_nu'] * np.ones_like(Sigma_nu)

# Sigma_u = 1e-2*np.ones((Nz,1))
# Sigma_s = 1e-3*np.ones((Ns,1))
# Sigma_J = 1e-3*np.ones((Ns*D,1))

    
# (init_paramvec, dict_ind, dict_shape) = params_to_vec(Sigma_eps, mu_0_0, Sigma_0_0, C, Sigma_nu, z, u, Sigma_u, 
#                                                       lengthscales, kernel_variance, s, J, Sigma_s=Sigma_s, Sigma_J=Sigma_J)

# # Transform certain elements of the parameter vector to optimise in log space
# # log_transformed=None
# log_transformed = np.concatenate([dict_ind['Sigma_0_0'], dict_ind['Sigma_u'],
#                                   dict_ind['Sigma_s'], dict_ind['Sigma_J'],
#                                   dict_ind['lengthscales'], dict_ind['Sigma_eps'], 
#                                   dict_ind['Sigma_nu'],
#                                   dict_ind['kernel_variance']
#                                  ])

# init_paramvec = log_transform(init_paramvec, log_transformed)

# # # Optimise only certain elements of paramvec
# opt_params = np.arange(init_paramvec.shape[0])
# opt_params = np.delete(opt_params, np.hstack([dict_ind['C'], dict_ind['Sigma_nu']])) # All except the ones listed here
# cur_pvec = init_paramvec[opt_params]

# # # Optimise only certain elements of paramvec
# # opt_params = np.concatenate([dict_ind['s']])
# # cur_pvec = init_paramvec[opt_params]
# # If want to do all:
# # opt_params = np.arange(init_paramvec.shape[0])
# # cur_pvec = init_paramvec

In [None]:
# Define a plotting function for callback that shows the current transition function estimate
def kwells_callback_plot_external(pvec_partial, 
                                  opt_params, init_paramvec, transforms, dict_ind, dict_shape,
                                  kwells_params
                                 ):
    """Plot the current transition-function estimate against the true one.

    Intended as an optimiser callback: the partial parameter vector is merged
    back into the full vector with ``replace_params``, unpacked with
    ``vec_to_params``, and the resulting prediction over a grid spanning the
    inducing-point locations ``z`` (padded by 0.5 on each side, step 0.05) is
    plotted together with ``kwells_true_tr_fnc`` on the same grid.

    Parameters
    ----------
    pvec_partial : array_like
        Values of the parameters currently being optimised.
    opt_params, init_paramvec :
        Passed to ``replace_params`` to rebuild the full parameter vector.
    transforms, dict_ind, dict_shape :
        Passed to ``vec_to_params`` to unpack the full vector into a dict.
    kwells_params : mapping
        Expanded with ``**`` into ``kwells_true_tr_fnc``; must also provide
        the key ``"Sigma_eps"`` used for the true noise band.

    Returns
    -------
    None
        The function only produces a plot via ``plt``.
    """
    paramvec = replace_params(pvec_partial, opt_params, init_paramvec)
    paramdict = vec_to_params(paramvec, dict_ind, dict_shape, transforms)

    # Unpack the usual parameters. This relies on paramdict preserving
    # insertion order (presumably an OrderedDict -- TODO confirm in
    # vec_to_params). list(...) is needed because dict views are not
    # subscriptable on Python 3; it is a no-op copy on Python 2.
    (Sigma_eps, mu_0_0, Sigma_0_0, C, Sigma_nu, z, u, Sigma_u, lengthscales, kernel_variance, s, J)  = \
        list(paramdict.values())[:12]

    # Debugging hook: drop into the debugger when lengthscales contain NaN.
    # NOTE(review): set_trace is not imported in this notebook's visible
    # cells; presumably it comes from the star import -- confirm.
    if np.any(np.isnan(lengthscales)):
        set_trace()

    # Deal with the extra possible parameters
    Sigma_s = paramdict['Sigma_s'] if 'Sigma_s' in paramdict.keys() else None
    Sigma_J = paramdict['Sigma_J'] if 'Sigma_J' in paramdict.keys() else None

    # Plot transition function
    xstar = np.atleast_2d(np.arange(np.min(z)-0.5,np.max(z)+0.5,0.05))

    L, targets, params = fp_get_static_K(eta=kernel_variance, lengthscales=lengthscales, z=z, u=u, s=s, J=J, 
                                         sig_eps=Sigma_eps, sig_u=Sigma_u, sig_s=Sigma_s, sig_J = Sigma_J)
    # sig_star is treated as a variance below (its square root sets the band).
    mu_star, sig_star, K_pred = fp_predict(xstar, L, targets, params)

    # Get true function values
    true_tr_vals = kwells_true_tr_fnc(xstar, **kwells_params)

    # Flattened series shared by several traces.
    grid = np.squeeze(xstar)
    post_mean = np.squeeze(mu_star)
    post_std = np.squeeze(np.sqrt(sig_star))
    true_std = np.sqrt(kwells_params["Sigma_eps"])
    fixed_pts = np.atleast_1d(np.squeeze(s))

    plt([plt_type.Scatter(x=grid, y=post_mean, mode='markers', name='GP post. mean',
                         marker=dict(color='blue')),
         # Band of one square-root-of-sig_star around the predicted mean.
         plt_type.Scatter(x=grid, y=post_mean+post_std, mode='markers', 
                         marker=dict(size=2, color='blue')),
         plt_type.Scatter(x=grid, y=post_mean-post_std, mode='markers', 
                         marker=dict(size=2, color='blue')),      
         plt_type.Scatter(x=grid, 
                          y=np.squeeze(true_tr_vals), mode='markers', name = 'True trans. f.',
                          marker=dict(color='orange')),
         # True transition function shifted by +/- sqrt(Sigma_eps).
         plt_type.Scatter(x=grid, 
                          y=np.squeeze(true_tr_vals+true_std), mode='markers', name = 'True trans. f.',
                          marker=dict(color='orange', size=2)),
         plt_type.Scatter(x=grid, 
                          y=np.squeeze(true_tr_vals-true_std), mode='markers', name = 'True trans. f.',
                          marker=dict(color='orange', size=2)),
         # Inducing-point locations, drawn on a fixed baseline at y = -2.
         plt_type.Scatter(x=np.squeeze(z), y=np.squeeze(-2.0*np.ones_like(z)), mode='markers', marker=dict(size=10),
                         name = 'Ind point loc'),
         plt_type.Scatter(x=np.squeeze(z), y=np.squeeze(u), mode='markers', marker=dict(size=10),
                         name = 'Ind point val'),
         # Fixed points are plotted at (s, s).
         plt_type.Scatter(x=fixed_pts, y=fixed_pts, mode='markers', marker=dict(size=10),
                         name = 'Fixed point')
        ])
    
    
def kwells_callback_plot(pvec_partial): 
    """Single-argument callback wrapper around ``kwells_callback_plot_external``.

    Every argument other than ``pvec_partial`` is read from notebook-level
    globals at call time, so those names must exist before this is invoked.
    NOTE(review): ``transforms`` is not assigned in any visible cell --
    confirm where it is defined.
    """
    kwells_callback_plot_external(
        pvec_partial=pvec_partial,
        opt_params=opt_params,
        init_paramvec=init_paramvec,
        transforms=transforms,
        dict_ind=dict_ind,
        dict_shape=dict_shape,
        kwells_params=kwells_params,
    )

## Set up the optimisation


In [None]:
# # Add bounds for parameters
# bnds = list(((None, None),) * init_paramvec.shape[0])
# for i in np.concatenate([dict_ind['Sigma_0_0'], dict_ind['Sigma_nu'], dict_ind['Sigma_eps'],
#                         dict_ind['Sigma_u'], dict_ind['Sigma_s'], dict_ind['Sigma_J']]):
#     lb = 1e-6; ub = 1e2
#     if i in log_transformed:
#         lb = np.log(lb)
#         ub = np.log(ub)
#     bnds[i] = (lb, ub)
# for i in np.concatenate([dict_ind['lengthscales'], dict_ind['kernel_variance']]):
#     if i in log_transformed:
#         lb = init_paramvec[i] + np.log(0.3)
#         ub = init_paramvec[i] + np.log(3.0)
#     else:
#         lb = init_paramvec[i]*0.3 
#         ub = init_paramvec[i]*3.
#     bnds[i] = (lb, ub)
# # cur_dim = 0
# # cur_z = 0
# # cur_tot = 0
# # z_mins = np.min(z, axis=1)
# # z_maxs = np.max(z, axis=1)
# # for i in np.concatenate([dict_ind['z'], dict_ind['s']]): # Note the idiotic python reshape order for setting bounds per dim
# #     z_min = z_mins[cur_dim]
# #     z_max = z_maxs[cur_dim]
# #     bnds[i] = (z_min-0.05*(z_max-z_min), z_max+0.05*(z_max-z_min))
# #     cur_z = cur_z+1
# #     if cur_tot < D*Nz:
# #         cur_z = np.mod(cur_z, Nz)
# #     else:
# #         cur_z = np.mod(cur_z, Ns)
# #     cur_tot = cur_tot+1
# #     if cur_z==0:
# #         cur_dim = cur_dim+1
# #     if cur_tot==D*Nz:
# #         cur_dim = 0
# # for i in np.concatenate([dict_ind['J']]):
# #     bnds[i] = (-1., 1.)
# bnds_final = []
# for i in opt_params:
#     bnds_final.append(bnds[i])
# bnds = tuple(bnds_final)

In [None]:
# # Add priors (to span at least the bounds)
# prior_funcs = list(((None),) * init_paramvec.shape[0])

# # # Add a strong prior to learn actual fixed points
# # logGamma_prior = create_prior("LogGamma", [2., 0.5, -6.])
# # for i in np.concatenate([dict_ind['Sigma_s'], dict_ind['Sigma_J']]):
# #     prior_funcs[i] = logGamma_prior
    
# # tmp_x = np.logspace(-6.0,2,100)    
# # plt(plt_type.Figure(data=[plt_type.Scatter(x=tmp_x, y=np.exp(-logGamma_prior(tmp_x)))], layout=plt_type.Layout(xaxis=dict(type= "log"))))
# # plt(plt_type.Figure(data=[plt_type.Scatter(x=tmp_x, y=logGamma_prior(tmp_x))], layout=plt_type.Layout(xaxis=dict(type= "log"))))
    

In [None]:
# tmp_func = lambda pvec_partial: (time_full_iter(replace_params(pvec_partial, opt_params, init_paramvec), 
#                                                 y, dict_ind, dict_shape, 
#                                                 log_transformed=log_transformed,
#                                                 prior_funcs=prior_funcs)[0])
# objective_with_grad = value_and_grad(tmp_func, argnum=0)

    
    
# # By iterating minimize within a for cycle, we can save all intermediate results and set ending times
# save_fname = "well_1d_k2_" + datetime.datetime.now().strftime("%Y%m%dT%H%M%S") + ".pkl"
# init_time = time.time()
# max_time = 0.45*3600 # Maximum iteration time in seconds, break if reached
# all_results = []

In [None]:
# from autograd.tracer import trace, Node

# def make_varname_generator():
#     for i in range(65, 91):
#         for j in range(3000):
#             yield (chr(i) + str(j))
#     raise Exception("Ran out of alphabet!")

# class PrintNode(Node):
#     def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
#         self.varname_generator = parents[0].varname_generator
#         self.varname = self.varname_generator.next()
#         args_or_vars = list(args)
#         for argnum, parent in zip(parent_argnums, parents):
#             args_or_vars[argnum] = parent.varname
#         print '{} = {}({}) = {}'.format(
#             self.varname, fun.__name__, ','.join(map(str, args_or_vars)), value)

#     def initialize_root(self, x):
#         self.varname_generator = make_varname_generator()
#         self.varname = self.varname_generator.next()
#         print '{} = {}'.format(self.varname, x)

# def print_trace(f, x):
#     start_node = PrintNode.new_root(x)
#     trace(start_node, f, x)
#     print
    
# val_orig = tmp_func(cur_pvec)[0,0]

In [None]:
# for it in range(100):
#     result = scipy.optimize.minimize(objective_with_grad, cur_pvec, jac=True, method='L-BFGS-B', bounds=bnds, callback=None,
#                           options={'maxiter':1, 'disp':True})
#     all_results.append(result)
#     # Save the results
#     with open(save_fname, 'wb') as f:
#         pickle.dump([y, x, kwells_params,
#                      all_results, 
#                      init_paramvec, dict_ind, dict_shape, opt_params, 
#                      bnds, log_transformed], f)
#     cur_pvec = result.x
#     cur_time = time.time()
#     print([it, cur_time - init_time, result.fun])
#     kwells_callback_plot(cur_pvec)
    
#     # Exit if maximum time is reached
    
#     if ((cur_time - init_time) > max_time):
#         print(["Maximum iteration time reached at iter", it])
#         break
        
#     if len(all_results)>=2:
#         if (all_results[-1].fun - all_results[-2].fun) >= -1e-1:
#             print(["Update did not improve objective function, stopping"])
#             break

# Analyse performance

In [None]:
# # Generate test data

# kwells_params_test = kwells_params
# kwells_params_test['Sigma_0_0']= 0.5e-0

# np.random.seed(kwells_params['rseed']+1)

# Ny_test = 350
# test_trials_x = []
# test_trials_y = []
# for n in range(Ny_test):
#     x1,y1 = kwells_draw_trial(**kwells_params_test)
    
#     test_trials_x.append(x1[:,:,None])
#     test_trials_y.append(y1[:,:,None])

# x_test = np.concatenate(test_trials_x, axis=2)
# y_test = np.concatenate(test_trials_y, axis=2)

In [None]:
# # Get final GP parameter vector
# final_paramvec = log_transform_inv(replace_params(all_results[-1].x, opt_params, init_paramvec), log_transformed)

# cutoff = 4

# # Get GP predictions on test data
# results_GP = pred_GP(y_test, final_paramvec, dict_ind, dict_shape, cutoff = None)
# results_GP_cutoff = pred_GP(y_test, final_paramvec, dict_ind, dict_shape, cutoff = cutoff)

In [None]:
# # Get AR prediction as baseline
# results_AR = pred_lin_AR1(y_test, y)
# results_AR_cutoff = pred_lin_AR1(y_test, y, cutoff=cutoff)

In [None]:
# # Get RMSE values
# RMSE_AR = rmse(results_AR, y_test, axis=(0,2))
# RMSE_AR_cutoff = rmse(results_AR_cutoff, y_test, axis=(0,2))
# RMSE_GP = rmse(results_GP[1], y_test, axis=(0,2))
# RMSE_GP_cutoff = rmse(results_GP_cutoff[1], y_test, axis=(0,2))

# plt([
#         plt_type.Scatter(y=np.squeeze(RMSE_AR), name='RMSE_AR'),
#         plt_type.Scatter(y=np.squeeze(RMSE_GP), name='RMSE_GP'),
#         plt_type.Scatter(y=np.squeeze(RMSE_AR_cutoff), name='RMSE_AR_cutoff'),
#         plt_type.Scatter(y=np.squeeze(RMSE_GP_cutoff), name='RMSE_GP_cutoff')
#     ])