Skip to content

Commit

Permalink
predict for different terms
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Nov 24, 2018
1 parent 31086bf commit 1ce913f
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ transit/cuda
demo
.pytest_cache
dist
for_ed
24 changes: 16 additions & 8 deletions docs/_static/notebooks/gp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"logp = -1.651, ||grad|| = 0.057729: 100%|██████████| 48/48 [00:00<00:00, 746.05it/s]\n"
]
}
],
"source": [
"from exoplanet.sampling import TuningSchedule\n",
"\n",
Expand All @@ -165,7 +173,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -178,7 +186,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -187,7 +195,7 @@
"((5000,), (5000,))"
]
},
"execution_count": 9,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -198,16 +206,16 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x1c32829748>]"
"[<matplotlib.lines.Line2D at 0x1c4574d198>]"
]
},
"execution_count": 10,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
Expand Down
11 changes: 11 additions & 0 deletions docs/_static/notebooks/notebook_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
get_ipython().magic('matplotlib inline')
get_ipython().magic('config InlineBackend.figure_format = "retina"')

from matplotlib import rcParams
rcParams["savefig.dpi"] = 100
rcParams["figure.dpi"] = 100
rcParams["font.size"] = 16
rcParams["text.usetex"] = False
rcParams["font.family"] = ["sans-serif"]
rcParams["font.sans-serif"] = ["cmss10"]
rcParams["axes.unicode_minus"] = False
20 changes: 14 additions & 6 deletions exoplanet/gp/celerite.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,25 @@ def log_likelihood(self, y):
def apply_inverse(self, rhs):
return self.general_solve_op(self.U, self.P, self.d, self.W, rhs)[0]

def predict(self, t=None, return_var=False, return_cov=False):
if t is None:
def predict(self, t=None, return_var=False, return_cov=False, kernel=None):
mu = None
if t is None and kernel is None:
mu = self.y - self.diag * self.z[:, 0]

if kernel is None:
kernel = self.kernel

if t is None:
t = self.x
Kxs = self.kernel.value(self.x[:, None] - self.x[None, :])
Kxs = kernel.value(self.x[:, None] - self.x[None, :])
KxsT = Kxs
Kss = Kxs
else:
KxsT = self.kernel.value(t[None, :] - self.x[:, None])
KxsT = kernel.value(t[None, :] - self.x[:, None])
Kxs = tt.transpose(KxsT)
Kss = self.kernel.value(t[:, None] - t[None, :])
Kss = kernel.value(t[:, None] - t[None, :])

if mu is None:
mu = tt.dot(Kxs, self.z)[:, 0]

if not (return_var or return_cov):
Expand All @@ -67,7 +75,7 @@ def predict(self, t=None, return_var=False, return_cov=False):
KinvKxsT = self.apply_inverse(KxsT)
if return_var:
var = -diag_dot(Kxs, KinvKxsT) # tt.sum(KxsT*KinvKxsT, axis=0)
var += self.kernel.value(0)
var += kernel.value(0)
return mu, var

cov = Kss - tt.dot(Kxs, KinvKxsT)
Expand Down

0 comments on commit 1ce913f

Please sign in to comment.