Skip to content

Commit

Permalink
Fully functional MCMC training
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusMNoack committed Feb 17, 2023
1 parent bafe85a commit f139c66
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 17 deletions.
11 changes: 4 additions & 7 deletions fvgp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ class GP():
and return a 2-D numpy array of shape V x V.
If ram_economy=False, the function should be of the form f(points1, points2, hyperparameters) and return a numpy array of shape
H x V x V, where H is the number of hyperparameters. V is the number of points. CAUTION: This array will be stored and is very large.
args : user defined, optional
These optional arguments will be available as attribute in kernel and mean function definitions.
Expand Down Expand Up @@ -125,7 +123,6 @@ def __init__(
self.y_data = np.array(values)
self.compute_device = compute_device
self.ram_economy = ram_economy
if args: self.args = args

self.use_inv = use_inv
self.K_inv = None
Expand Down Expand Up @@ -301,7 +298,6 @@ def train(self,
dask_client : distributed.client.Client, optional
A Dask Distributed Client instance for distributed training if HGDL is used. If None is provided, a new
`dask.distributed.Client` instance is constructed.
"""
############################################
if init_hyperparameters is None:
Expand Down Expand Up @@ -537,14 +533,15 @@ def _optimize_log_likelihood(self,starting_hps,
constraints = constraints)
obj = opt.optimize(dask_client = dask_client, x0 = np.array(starting_hps).reshape(1,-1))
res = opt.get_final()
print("fbfs ",res)
hyperparameters = res["x"][0]
opt.kill_client()
elif method == "mcmc":
logger.debug("MCMC started in fvGP")
logger.debug('bounds are {}', hp_bounds)
res = mcmc(self.log_likelihood,hp_bounds)
hyperparameters = np.array(res["x"])
def likelihood_wrapper(x):
return -self.log_likelihood(x)
res = mcmc(likelihood_wrapper,hp_bounds, x0 = starting_hps, max_iter = max_iter)
hyperparameters = np.array(res["distribution mean"])
elif callable(method):
hyperparameters = method(self)
else:
Expand Down
13 changes: 3 additions & 10 deletions fvgp/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,36 +22,29 @@ def mcmc(func,bounds, x0 = None, distr = None, max_iter = 1000, ):
f = []
if x0 is None: x.append(np.random.uniform(low = bounds[:,0],high = bounds[:,1],size = len(bounds)))
else: x.append(x0)
if distr is None: l = np.diag((np.abs(np.subtract(bounds[:,0],bounds[:,1])))/10.0)**2
if distr is None: l = np.diag((np.abs(np.subtract(bounds[:,0],bounds[:,1])))/100.0)**2
counter = 0
current_func = func(x0)
current_func = func(x[0])
f.append(current_func)
run = True
while run:
x_proposal = np.random.multivariate_normal(x[-1], l)
x_proposal = project_onto_bounds(x_proposal,bounds)
##check for constraints?
proposal_func = func(x_proposal) ####call function
acceptance_prob = proposal_func - current_func ##these are already logs
uu = np.random.rand()
u = np.log(uu)
print("iteration: ", counter,"current f: ",current_func, " prop f: ",proposal_func,flush = True)
#print("acceptance prob: ",acceptance_prob,"u: ", u,flush = True)
print("proposed x: ", x_proposal, "current x: ", x[-1], flush = True)
if u < acceptance_prob:
x.append(x_proposal)
f.append(proposal_func)
current_func = proposal_func
print("accepted", flush = True)
else:
x.append(x[-1])
f.append(current_func)
print(f[-1], flush = True)
logger.debug("mcmc f(x):{}",f[-1])
print("")
counter += 1
if counter >= max_iter: run = False
#if len(x)>201 and np.linalg.norm(np.mean(x[-100:],axis = 0)-np.mean(x[-200:-100],axis = 0)) < 0.01 * np.linalg.norm(np.mean(x[-100:],axis = 0)): run = False
if len(x)>201 and np.linalg.norm(np.mean(x[-100:],axis = 0)-np.mean(x[-200:-100],axis = 0)) < 0.01 * np.linalg.norm(np.mean(x[-100:],axis = 0)): run = False
arg_max = np.argmax(f)
x = np.asarray(x)
logger.debug(f"mcmc res: {f[arg_max]} at {x[arg_max]}")
Expand Down

0 comments on commit f139c66

Please sign in to comment.