From 80cbdfecfa21019753fe0f6eb25d0f89089f4ae0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Faria?= Date: Wed, 9 Sep 2020 15:26:01 +0100 Subject: [PATCH] return std for GP prediction --- pykima/GP.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/pykima/GP.py b/pykima/GP.py index 171165af..7887d5f1 100644 --- a/pykima/GP.py +++ b/pykima/GP.py @@ -17,7 +17,7 @@ def __init__(self, eta1, eta2, eta3, eta4): self.eta2 = eta2 self.eta3 = eta3 self.eta4 = eta4 - + def setpars(self, eta1=None, eta2=None, eta3=None, eta4=None): self.eta1 = eta1 if eta1 else self.eta1 self.eta2 = eta2 if eta2 else self.eta2 @@ -90,9 +90,9 @@ def predict(self, y, t=None, return_std=False, return_cov=False): Conditional predictive distribution of the GP model given observations y, evaluated at coordinates t. """ - if t is None: + if t is None: t = self.t - + self.K = self._cov(self.t) self.L_ = cholesky(self.K, lower=True) self.alpha_ = cho_solve((self.L_, True), y) @@ -123,7 +123,8 @@ def predict(self, y, t=None, return_std=False, return_cov=False): else: return mean - def predict_with_hyperpars(self, results, sample, t=None, add_parts=True): + def predict_with_hyperpars(self, results, sample, t=None, add_parts=True, + return_std=False): """ Given the parameters in `sample`, return the GP predictive mean. If `t` is None, the prediction is done at the observed times and will contain @@ -149,10 +150,13 @@ def predict_with_hyperpars(self, results, sample, t=None, add_parts=True): y -= sample[results.indices['trend']] * (results.t - results.tmiddle) if t is not None: - mu = self.predict(y, t, return_cov=False) - return mu - - mu = self.predict(y, results.t, return_cov=False) + pred = self.predict(y, t, return_cov=False, return_std=return_std) + return pred + + mu = self.predict(y, results.t, return_cov=False, + return_std=return_std) + if return_std: + mu, std = mu if add_parts: # add the trend back @@ -163,7 +167,7 @@ def predict_with_hyperpars(self, results, sample, t=None, add_parts=True): for i in range(1, results.n_instruments): mu[results.obs == i] += sample[results.indices['inst_offsets']][i-1] - return mu + return mu, std def sample_conditional(self, y, t, size=1): @@ -262,8 +266,8 @@ def sample_conditional_with_hyperpars(self, results, sample, t, size=1): class QPkernel_celerite(terms.Term): # This implements a quasi-periodic kernel (QPK) devised by Andrew Collier - # Cameron, which mimics a standard QP kernel with a roughness parameter 0.5, - # and has zero derivative at the origin: k(tau=0)=amp and k'(tau=0)=0 + # Cameron, which mimics a standard QP kernel with a roughness parameter 0.5, + # and has zero derivative at the origin: k(tau=0)=amp and k'(tau=0)=0 # The kernel defined in the celerite paper (Eq 56 in Foreman-Mackey et al. 2017) # does not satisfy k'(tau=0)=0 #