Skip to content

Commit

Permalink
docs for mean
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Jan 3, 2020
1 parent b46edce commit 09cb647
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions src/exoplanet/gp/celerite.py
Expand Up @@ -34,11 +34,15 @@ class GP:
GP kernel.
x: The input coordinates. This should be a 1D array and the elements
must be sorted. Otherwise the results are undefined.
diag: The extra diagonal to add to the covariance matrix. This should
have the same length as ``x`` and correspond to the excess
*variance* for each data point. **Note:** this is different from
the usage in the ``celerite`` package where the standard deviation
(instead of variance) is provided.
diag (Optional): The extra diagonal to add to the covariance matrix.
This should have the same length as ``x`` and correspond to the
excess *variance* for each data point. **Note:** this is different
from the usage in the ``celerite`` package where the standard
deviation (instead of variance) is provided.
mean (Optional): The mean function for the GP. This can be a constant
scalar value or a callable that will be called with a single, one
dimensional tensor argument specifying the input coordinates where
the mean should be evaluated.
J (Optional): The width of the system. This is the ``J`` parameter
from `Foreman-Mackey (2018) <https://arxiv.org/abs/1801.10156>`_
(not the original paper) so a real term contributes ``J += 1`` and
Expand Down Expand Up @@ -70,7 +74,7 @@ def __init__(self, kernel, x, diag=None, mean=Zero(), J=-1, model=None):
self.mean = mean
else:
self.mean = Constant(mean)
self.mean_val = self.mean(x)
self.mean_val = self.mean(self.x)

self.a, self.U, self.V, self.P = self.kernel.get_celerite_matrices(
self.x, self.diag
Expand Down Expand Up @@ -118,6 +122,14 @@ def dot_l(self, n):
return self.dot_l_op(self.U, self.P, self.d, self.W, n)

def apply_inverse(self, rhs):
rhs = tt.as_tensor_variable(rhs)
if rhs.ndim == 1:
return tt.reshape(
self.vector_solve_op(
self.U, self.P, self.d, self.W, tt.reshape(rhs, (rhs, 1))
),
(rhs.size,),
)
return self.general_solve_op(self.U, self.P, self.d, self.W, rhs)[0]

def predict(
Expand All @@ -144,11 +156,13 @@ def predict(

if t is None:
t = self.x
mean_val = self.mean_val
Kxs = kernel.value(self.x[:, None] - self.x[None, :])
KxsT = Kxs
Kss = Kxs
else:
t = tt.as_tensor_variable(t)
mean_val = self.mean(t)
KxsT = kernel.value(t[None, :] - self.x[:, None])
Kxs = tt.transpose(KxsT)
Kss = kernel.value(t[:, None] - t[None, :])
Expand All @@ -163,7 +177,7 @@ def predict(
)
else:
mu = tt.dot(Kxs, self.z)
mu = mu + self.mean(t)
mu = mu + mean_val

if not (return_var or return_cov):
return mu
Expand Down

0 comments on commit 09cb647

Please sign in to comment.