diff --git a/docs/tutorials/quasisep-custom.ipynb b/docs/tutorials/quasisep-custom.ipynb index a3813605..78a047a7 100644 --- a/docs/tutorials/quasisep-custom.ipynb +++ b/docs/tutorials/quasisep-custom.ipynb @@ -670,7 +670,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.10.6 ('tinygp')", "language": "python", "name": "python3" }, @@ -684,7 +684,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.9" + "version": "3.10.6" + }, + "vscode": { + "interpreter": { + "hash": "d20ea8a315da34b3e8fab0dbd7b542a0ef3c8cf12937343660e6bc10a20768e3" + } } }, "nbformat": 4, diff --git a/news/123.feature b/news/123.feature new file mode 100644 index 00000000..8cc568b0 --- /dev/null +++ b/news/123.feature @@ -0,0 +1,2 @@ +Added check for sorted input coordinates when using the ``QuasisepSolver``; +a ``ValueError`` is thrown if they are not. diff --git a/src/tinygp/gp.py b/src/tinygp/gp.py index 08970162..6578617b 100644 --- a/src/tinygp/gp.py +++ b/src/tinygp/gp.py @@ -66,6 +66,7 @@ def __init__( solver: Optional[Any] = None, mean_value: Optional[JAXArray] = None, covariance_value: Optional[Any] = None, + **solver_kwargs: Any, ): self.kernel = kernel self.X = X @@ -101,7 +102,11 @@ def __init__( else: solver = DirectSolver self.solver = solver.init( - kernel, self.X, self.noise, covariance=covariance_value + kernel, + self.X, + self.noise, + covariance=covariance_value, + **solver_kwargs, ) @property diff --git a/src/tinygp/kernels/quasisep.py b/src/tinygp/kernels/quasisep.py index 78a67749..71d91794 100644 --- a/src/tinygp/kernels/quasisep.py +++ b/src/tinygp/kernels/quasisep.py @@ -27,7 +27,7 @@ ] from abc import ABCMeta, abstractmethod -from typing import Optional, Union +from typing import Any, Optional, Union import jax import jax.numpy as jnp @@ -151,7 +151,7 @@ def __add__(self, other: Union["Kernel", JAXArray]) -> "Kernel": ) return Sum(self, other) - def __radd__(self, other: Union["Kernel", JAXArray]) -> "Kernel": + def __radd__(self, other: Any) -> "Kernel": # We'll hit this first branch when using the `sum` function if other == 0: return self @@ -171,7 +171,7 @@ def __mul__(self, other: Union["Kernel", JAXArray]) -> "Kernel": ) return Scale(kernel=self, scale=other) - def __rmul__(self, other: Union["Kernel", JAXArray]) -> "Kernel": + def __rmul__(self, other: Any) -> "Kernel": if isinstance(other, Quasisep): return Product(other, self) if isinstance(other, Kernel) or jnp.ndim(other) != 0: @@ -204,6 +204,9 @@ class Wrapper(Quasisep, metaclass=ABCMeta): kernel: Quasisep + def coord_to_sortable(self, X: JAXArray) -> JAXArray: + return self.kernel.coord_to_sortable(X) + def design_matrix(self) -> JAXArray: return self.kernel.design_matrix() @@ -226,6 +229,10 @@ class Sum(Quasisep): kernel1: Quasisep kernel2: Quasisep + def coord_to_sortable(self, X: JAXArray) -> JAXArray: + """We assume that both kernels use the same coordinates""" + return self.kernel1.coord_to_sortable(X) + def design_matrix(self) -> JAXArray: return jsp.linalg.block_diag( self.kernel1.design_matrix(), self.kernel2.design_matrix() @@ -259,6 +266,10 @@ class Product(Quasisep): kernel1: Quasisep kernel2: Quasisep + def coord_to_sortable(self, X: JAXArray) -> JAXArray: + """We assume that both kernels use the same coordinates""" + return self.kernel1.coord_to_sortable(X) + def design_matrix(self) -> JAXArray: F1 = self.kernel1.design_matrix() F2 = self.kernel2.design_matrix() @@ -699,14 +710,14 @@ def init( params = jnp.linalg.solve( params, 0.5 * sigma**2 * jnp.eye(p, 1, k=-p + 1) )[:, 0] - stn = [] + stn_ = [] for j in range(p): - stn.append([jnp.zeros(()) for _ in range(p)]) + stn_.append([jnp.zeros(()) for _ in range(p)]) for n, k in enumerate(range(j - 2, -1, -2)): - stn[-1][k] = (2 * (n % 2) - 1) * params[j - n - 1] + stn_[-1][k] = (2 * (n % 2) - 1) * params[j - n - 1] for n, k in enumerate(range(j, p, 2)): - stn[-1][k] = (1 - 2 * (n % 2)) * params[n + j] - stn = jnp.array(list(map(jnp.stack, stn))) + stn_[-1][k] = (1 - 2 * (n % 2)) * params[n + j] + stn = jnp.array(list(map(jnp.stack, stn_))) return cls( sigma=sigma, diff --git a/src/tinygp/solvers/quasisep/solver.py b/src/tinygp/solvers/quasisep/solver.py index 01c23442..f6441d23 100644 --- a/src/tinygp/solvers/quasisep/solver.py +++ b/src/tinygp/solvers/quasisep/solver.py @@ -4,7 +4,7 @@ __all__ = ["QuasisepSolver"] -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import jax import jax.numpy as jnp @@ -41,6 +41,7 @@ def init( noise: Noise, *, covariance: Optional[Any] = None, + assume_sorted: bool = False, ) -> "QuasisepSolver": """Build a :class:`QuasisepSolver` for a given kernel and coordinates @@ -52,15 +53,24 @@ def init( covariance: Optionally, a pre-computed :class:`tinygp.solvers.quasisep.core.QSM` with the covariance matrix. + assume_sorted: If ``True``, assume that the input coordinates are + sorted. If ``False``, check that they are sorted and throw an + error if they are not. This can introduce a runtime overhead, + and you can pass ``assume_sorted=True`` to get the best + performance. """ from tinygp.kernels.quasisep import Quasisep if covariance is None: - assert isinstance(kernel, Quasisep) + if TYPE_CHECKING: + assert isinstance(kernel, Quasisep) + if not assume_sorted: + jax.debug.callback(_check_sorted, kernel.coord_to_sortable(X)) matrix = kernel.to_symm_qsm(X) matrix += noise.to_qsm() else: - assert isinstance(covariance, SymmQSM) + if TYPE_CHECKING: + assert isinstance(covariance, SymmQSM) matrix = covariance factor = matrix.cholesky() return cls(X=X, matrix=matrix, factor=factor) @@ -125,3 +135,10 @@ def condition( A = self.solve_triangular(Ks) return Kss - A.transpose() @ A + + +def _check_sorted(X: JAXArray) -> None: + if np.any(np.diff(X) < 0.0): + raise ValueError( + "Input coordinates must be sorted in order to use the QuasisepSolver" + ) diff --git a/tests/test_solvers/test_quasisep/test_solver.py b/tests/test_solvers/test_quasisep/test_solver.py index 4b3c3527..17443ebb 100644 --- a/tests/test_solvers/test_quasisep/test_solver.py +++ b/tests/test_solvers/test_quasisep/test_solver.py @@ -109,7 +109,7 @@ def test_consistent_with_direct(kernel_pair, data): @pytest.mark.skipif(celerite is None, reason="'celerite' must be installed") def test_celerite(data): - x, y, t = data + x, y, _ = data yerr = 0.1 a, b, c, d = 1.1, 0.8, 0.9, 0.1 @@ -125,3 +125,22 @@ def test_celerite(data): calc = gp.log_probability(y) np.testing.assert_allclose(calc, expected) + + +def test_unsorted(data): + random = np.random.default_rng(0) + inds = random.permutation(len(data[0])) + x_ = data[0][inds] + y_ = data[1][inds] + + kernel = quasisep.Matern32(sigma=1.8, scale=1.5) + with pytest.raises(ValueError): + GaussianProcess(kernel, x_, diag=0.1) + + @jax.jit + def impl(X, y): + return GaussianProcess(kernel, X, diag=0.1).log_probability(y) + + with pytest.raises(jax.lib.xla_extension.XlaRuntimeError) as exc_info: + impl(x_, y_).block_until_ready() + assert exc_info.match(r"Input coordinates must be sorted")