Skip to content

Commit

Permalink
Adding check for unsorted input coordinates when using QuasisepSolver (
Browse files Browse the repository at this point in the history
…#123)

* adding exception and tests

* adding news

* fixing handling of coord_to_sortable for composite kernels
  • Loading branch information
dfm committed Oct 30, 2022
1 parent 2bdccea commit fada5dd
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 15 deletions.
9 changes: 7 additions & 2 deletions docs/tutorials/quasisep-custom.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.10.6 ('tinygp')",
"language": "python",
"name": "python3"
},
Expand All @@ -684,7 +684,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
"version": "3.10.6"
},
"vscode": {
"interpreter": {
"hash": "d20ea8a315da34b3e8fab0dbd7b542a0ef3c8cf12937343660e6bc10a20768e3"
}
}
},
"nbformat": 4,
Expand Down
2 changes: 2 additions & 0 deletions news/123.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added check for sorted input coordinates when using the ``QuasisepSolver``;
a ``ValueError`` is thrown if they are not.
7 changes: 6 additions & 1 deletion src/tinygp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
solver: Optional[Any] = None,
mean_value: Optional[JAXArray] = None,
covariance_value: Optional[Any] = None,
**solver_kwargs: Any,
):
self.kernel = kernel
self.X = X
Expand Down Expand Up @@ -101,7 +102,11 @@ def __init__(
else:
solver = DirectSolver
self.solver = solver.init(
kernel, self.X, self.noise, covariance=covariance_value
kernel,
self.X,
self.noise,
covariance=covariance_value,
**solver_kwargs,
)

@property
Expand Down
27 changes: 19 additions & 8 deletions src/tinygp/kernels/quasisep.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
]

from abc import ABCMeta, abstractmethod
from typing import Optional, Union
from typing import Any, Optional, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -151,7 +151,7 @@ def __add__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
)
return Sum(self, other)

def __radd__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
def __radd__(self, other: Any) -> "Kernel":
# We'll hit this first branch when using the `sum` function
if other == 0:
return self
Expand All @@ -171,7 +171,7 @@ def __mul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
)
return Scale(kernel=self, scale=other)

def __rmul__(self, other: Union["Kernel", JAXArray]) -> "Kernel":
def __rmul__(self, other: Any) -> "Kernel":
if isinstance(other, Quasisep):
return Product(other, self)
if isinstance(other, Kernel) or jnp.ndim(other) != 0:
Expand Down Expand Up @@ -204,6 +204,9 @@ class Wrapper(Quasisep, metaclass=ABCMeta):

kernel: Quasisep

def coord_to_sortable(self, X: JAXArray) -> JAXArray:
return self.kernel.coord_to_sortable(X)

def design_matrix(self) -> JAXArray:
return self.kernel.design_matrix()

Expand All @@ -226,6 +229,10 @@ class Sum(Quasisep):
kernel1: Quasisep
kernel2: Quasisep

def coord_to_sortable(self, X: JAXArray) -> JAXArray:
"""We assume that both kernels use the same coordinates"""
return self.kernel1.coord_to_sortable(X)

def design_matrix(self) -> JAXArray:
return jsp.linalg.block_diag(
self.kernel1.design_matrix(), self.kernel2.design_matrix()
Expand Down Expand Up @@ -259,6 +266,10 @@ class Product(Quasisep):
kernel1: Quasisep
kernel2: Quasisep

def coord_to_sortable(self, X: JAXArray) -> JAXArray:
"""We assume that both kernels use the same coordinates"""
return self.kernel1.coord_to_sortable(X)

def design_matrix(self) -> JAXArray:
F1 = self.kernel1.design_matrix()
F2 = self.kernel2.design_matrix()
Expand Down Expand Up @@ -699,14 +710,14 @@ def init(
params = jnp.linalg.solve(
params, 0.5 * sigma**2 * jnp.eye(p, 1, k=-p + 1)
)[:, 0]
stn = []
stn_ = []
for j in range(p):
stn.append([jnp.zeros(()) for _ in range(p)])
stn_.append([jnp.zeros(()) for _ in range(p)])
for n, k in enumerate(range(j - 2, -1, -2)):
stn[-1][k] = (2 * (n % 2) - 1) * params[j - n - 1]
stn_[-1][k] = (2 * (n % 2) - 1) * params[j - n - 1]
for n, k in enumerate(range(j, p, 2)):
stn[-1][k] = (1 - 2 * (n % 2)) * params[n + j]
stn = jnp.array(list(map(jnp.stack, stn)))
stn_[-1][k] = (1 - 2 * (n % 2)) * params[n + j]
stn = jnp.array(list(map(jnp.stack, stn_)))

return cls(
sigma=sigma,
Expand Down
23 changes: 20 additions & 3 deletions src/tinygp/solvers/quasisep/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

__all__ = ["QuasisepSolver"]

from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -41,6 +41,7 @@ def init(
noise: Noise,
*,
covariance: Optional[Any] = None,
assume_sorted: bool = False,
) -> "QuasisepSolver":
"""Build a :class:`QuasisepSolver` for a given kernel and coordinates
Expand All @@ -52,15 +53,24 @@ def init(
covariance: Optionally, a pre-computed
:class:`tinygp.solvers.quasisep.core.QSM` with the covariance
matrix.
assume_sorted: If ``True``, assume that the input coordinates are
sorted. If ``False``, check that they are sorted and throw an
error if they are not. This can introduce a runtime overhead,
and you can pass ``assume_sorted=True`` to get the best
performance.
"""
from tinygp.kernels.quasisep import Quasisep

if covariance is None:
assert isinstance(kernel, Quasisep)
if TYPE_CHECKING:
assert isinstance(kernel, Quasisep)
if not assume_sorted:
jax.debug.callback(_check_sorted, kernel.coord_to_sortable(X))
matrix = kernel.to_symm_qsm(X)
matrix += noise.to_qsm()
else:
assert isinstance(covariance, SymmQSM)
if TYPE_CHECKING:
assert isinstance(covariance, SymmQSM)
matrix = covariance
factor = matrix.cholesky()
return cls(X=X, matrix=matrix, factor=factor)
Expand Down Expand Up @@ -125,3 +135,10 @@ def condition(

A = self.solve_triangular(Ks)
return Kss - A.transpose() @ A


def _check_sorted(X: JAXArray) -> None:
if np.any(np.diff(X) < 0.0):
raise ValueError(
"Input coordinates must be sorted in order to use the QuasisepSolver"
)
21 changes: 20 additions & 1 deletion tests/test_solvers/test_quasisep/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_consistent_with_direct(kernel_pair, data):

@pytest.mark.skipif(celerite is None, reason="'celerite' must be installed")
def test_celerite(data):
x, y, t = data
x, y, _ = data
yerr = 0.1

a, b, c, d = 1.1, 0.8, 0.9, 0.1
Expand All @@ -125,3 +125,22 @@ def test_celerite(data):
calc = gp.log_probability(y)

np.testing.assert_allclose(calc, expected)


def test_unsorted(data):
random = np.random.default_rng(0)
inds = random.permutation(len(data[0]))
x_ = data[0][inds]
y_ = data[1][inds]

kernel = quasisep.Matern32(sigma=1.8, scale=1.5)
with pytest.raises(ValueError):
GaussianProcess(kernel, x_, diag=0.1)

@jax.jit
def impl(X, y):
return GaussianProcess(kernel, X, diag=0.1).log_probability(y)

with pytest.raises(jax.lib.xla_extension.XlaRuntimeError) as exc_info:
impl(x_, y_).block_until_ready()
assert exc_info.match(r"Input coordinates must be sorted")

0 comments on commit fada5dd

Please sign in to comment.