Skip to content

Commit

Permalink
UPDATE: matern32 and matern52
Browse files Browse the repository at this point in the history
  • Loading branch information
jungtaekkim committed Jul 18, 2018
1 parent 284bdcf commit ac90903
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions bayeso/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,32 @@ def cov_se(bx, bxp, lengthscales, signal):
assert bx.shape[0] == bxp.shape[0]
return signal**2 * np.exp(-0.5 * np.linalg.norm((bx - bxp) / lengthscales, ord=2)**2)

def cov_matern32(bx, bxp, lengthscales, signal):
assert isinstance(bx, np.ndarray)
assert isinstance(bxp, np.ndarray)
assert isinstance(lengthscales, np.ndarray) or isinstance(lengthscales, float)
if isinstance(lengthscales, np.ndarray):
assert bx.shape[0] == bxp.shape[0] == lengthscales.shape[0]
else:
assert bx.shape[0] == bxp.shape[0]
assert isinstance(signal, float)

dist = np.linalg.norm((bx - bxp) / lengthscales, ord=2)
return signal**2 * (1.0 + np.sqrt(3.0) * dist) * np.exp(-1.0 * np.sqrt(3.0) * dist)

def cov_matern52(bx, bxp, lengthscales, signal):
assert isinstance(bx, np.ndarray)
assert isinstance(bxp, np.ndarray)
assert isinstance(lengthscales, np.ndarray) or isinstance(lengthscales, float)
if isinstance(lengthscales, np.ndarray):
assert bx.shape[0] == bxp.shape[0] == lengthscales.shape[0]
else:
assert bx.shape[0] == bxp.shape[0]
assert isinstance(signal, float)

dist = np.linalg.norm((bx - bxp) / lengthscales, ord=2)
return signal**2 * (1.0 + np.sqrt(5.0) * dist + 5.0 / 3.0 * dist**2) * np.exp(-1.0 * np.sqrt(5.0) * dist)

def cov_main(str_cov, X, Xs, hyps, jitter=constants.JITTER_COV):
assert isinstance(str_cov, str)
assert isinstance(X, np.ndarray)
Expand All @@ -43,8 +69,18 @@ def cov_main(str_cov, X, Xs, hyps, jitter=constants.JITTER_COV):
for ind_X in range(0, num_X):
for ind_Xs in range(0, num_Xs):
cov_[ind_X, ind_Xs] += cov_se(X[ind_X], Xs[ind_Xs], hyps['lengthscales'], hyps['signal'])
elif str_cov == 'matern52' or str_cov == 'matern32':
raise NotImplementedError('cov_main: matern52 or matern32.')
elif str_cov == 'matern32':
if hyps.get('lengthscales') is None or hyps.get('signal') is None:
raise ValueError('cov_main: insufficient hyperparameters.')
for ind_X in range(0, num_X):
for ind_Xs in range(0, num_Xs):
cov_[ind_X, ind_Xs] += cov_matern32(X[ind_X], Xs[ind_Xs], hyps['lengthscales'], hyps['signal'])
elif str_cov == 'matern52':
if hyps.get('lengthscales') is None or hyps.get('signal') is None:
raise ValueError('cov_main: insufficient hyperparameters.')
for ind_X in range(0, num_X):
for ind_Xs in range(0, num_Xs):
cov_[ind_X, ind_Xs] += cov_matern52(X[ind_X], Xs[ind_Xs], hyps['lengthscales'], hyps['signal'])
else:
raise NotImplementedError('cov_main: allowed str_cov, but it is not implemented.')
return cov_

0 comments on commit ac90903

Please sign in to comment.