Skip to content

Commit

Permalink
Adding equinox as dependency and switching all baseclasses to modules (
Browse files Browse the repository at this point in the history
…#200)

* Adding equinox as dependency

* adding news

* Fixing tutorial syntax

* Add comment about old init function

* Refactoring tests to handle float32 precision

* reverting CARMA init

* reverting CARMA init

* Xfail polynomial kernel tests

* adding tests for kernel pytrees
  • Loading branch information
dfm committed Jan 5, 2024
1 parent d17ae68 commit 1e798b7
Show file tree
Hide file tree
Showing 35 changed files with 588 additions and 563 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,17 @@ jobs:
matrix:
python-version: ["3.9", "3.10", "3.11"]
nox-session: ["test"]
x64: ["1"]
include:
- python-version: "3.10"
nox-session: "test"
x64: "0"
- python-version: "3.10"
nox-session: "comparison"
x64: "1"
- python-version: "3.10"
nox-session: "doctest"
x64: "1"

steps:
- name: Checkout
Expand All @@ -36,6 +44,8 @@ jobs:
run: |
python -m nox --non-interactive --error-on-missing-interpreter \
--session ${{ matrix.nox-session }} --python ${{ matrix.python-version }}
env:
JAX_ENABLE_X64: ${{ matrix.x64 }}

build:
runs-on: ubuntu-latest
Expand Down
13 changes: 8 additions & 5 deletions docs/tutorials/derivative.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@
"\n",
"\n",
"class DerivativeKernel(tinygp.kernels.Kernel):\n",
" def __init__(self, kernel):\n",
" self.kernel = kernel\n",
" kernel: tinygp.kernels.Kernel\n",
"\n",
" def evaluate(self, X1, X2):\n",
" t1, d1 = X1\n",
Expand Down Expand Up @@ -301,6 +300,10 @@
" shape as ``coeff_prim``.\n",
" \"\"\"\n",
"\n",
" kernel: tinygp.kernels.Kernel\n",
" coeff_prim: jax.Array\n",
" coeff_deriv: jax.Array\n",
"\n",
" def __init__(self, kernel, coeff_prim, coeff_deriv):\n",
" self.kernel = kernel\n",
" self.coeff_prim, self.coeff_deriv = jnp.broadcast_arrays(\n",
Expand Down Expand Up @@ -497,7 +500,7 @@
"hash": "d20ea8a315da34b3e8fab0dbd7b542a0ef3c8cf12937343660e6bc10a20768e3"
},
"kernelspec": {
"display_name": "Python 3.9.9 ('tinygp')",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -511,9 +514,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
9 changes: 4 additions & 5 deletions docs/tutorials/kernels.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@
"\n",
"\n",
"class SpectralMixture(tinygp.kernels.Kernel):\n",
" def __init__(self, weight, scale, freq):\n",
" self.weight = jnp.atleast_1d(weight)\n",
" self.scale = jnp.atleast_1d(scale)\n",
" self.freq = jnp.atleast_1d(freq)\n",
" weight: jax.Array\n",
" scale: jax.Array\n",
" freq: jax.Array\n",
"\n",
" def evaluate(self, X1, X2):\n",
" tau = jnp.atleast_1d(jnp.abs(X1 - X2))[..., None]\n",
Expand Down Expand Up @@ -210,7 +209,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "tinygp",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand Down
2 changes: 1 addition & 1 deletion news/188.bugfix
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Fixed use of `jnp.roots` and `np.roll` to make CARMA kernel jit-compliant
Fixed use of `jnp.roots` and `np.roll` to make CARMA kernel jit-compliant.
1 change: 1 addition & 0 deletions news/200.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Switched all base classes to `equinox.Module <https://docs.kidger.site/equinox/api/module/module/>`_ objects to simplify dataclass handling.
6 changes: 6 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
@nox.session(python=PYTHON_VERSIONS)
def test(session: nox.Session) -> None:
session.install(".[test]")
session.run("pytest", *session.posargs)


@nox.session(python=PYTHON_VERSIONS)
def comparison(session: nox.Session) -> None:
session.install(".[test,comparison]")
session.run("pytest", *session.posargs, env={"JAX_ENABLE_X64": "1"})


Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ classifiers = [
"Programming Language :: Python :: 3",
]
dynamic = ["version"]
dependencies = ["jax", "jaxlib"]
dependencies = ["jax", "jaxlib", "equinox"]

[project.optional-dependencies]
test = ["pytest", "george", "celerite"]
test = ["pytest"]
comparison = ["george", "celerite"]
docs = [
"sphinx-book-theme",
"myst-nb",
Expand Down
35 changes: 25 additions & 10 deletions src/tinygp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,24 @@
NamedTuple,
)

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np

from tinygp import kernels, means
from tinygp.helpers import JAXArray
from tinygp.kernels.quasisep import Quasisep
from tinygp.noise import Diagonal, Noise
from tinygp.solvers import DirectSolver, QuasisepSolver
from tinygp.solvers.quasisep.core import SymmQSM
from tinygp.solvers.solver import Solver

if TYPE_CHECKING:
from tinygp.numpyro_support import TinyDistribution


class GaussianProcess:
class GaussianProcess(eqx.Module):
"""An interface for designing a Gaussian Process regression model
Args:
Expand All @@ -50,14 +53,23 @@ class GaussianProcess:
algebra.
"""

num_data: int = eqx.field(static=True)
dtype: np.dtype = eqx.field(static=True)
kernel: kernels.Kernel
X: JAXArray
mean_function: means.MeanBase
mean: JAXArray
noise: Noise
solver: Solver

def __init__(
self,
kernel: kernels.Kernel,
X: JAXArray,
*,
diag: JAXArray | None = None,
noise: Noise | None = None,
mean: Callable[[JAXArray], JAXArray] | JAXArray | None = None,
mean: means.MeanBase | Callable[[JAXArray], JAXArray] | JAXArray | None = None,
solver: Any | None = None,
mean_value: JAXArray | None = None,
covariance_value: Any | None = None,
Expand All @@ -66,7 +78,7 @@ def __init__(
self.kernel = kernel
self.X = X

if callable(mean):
if isinstance(mean, means.MeanBase):
self.mean_function = mean
elif mean is None:
self.mean_function = means.Mean(jnp.zeros(()))
Expand All @@ -76,7 +88,7 @@ def __init__(
mean_value = jax.vmap(self.mean_function)(self.X)
self.num_data = mean_value.shape[0]
self.dtype = mean_value.dtype
self.loc = self.mean = mean_value
self.mean = mean_value
if self.mean.ndim != 1:
raise ValueError(
"Invalid mean shape: " f"expected ndim = 1, got ndim={self.mean.ndim}"
Expand All @@ -92,14 +104,18 @@ def __init__(
solver = QuasisepSolver
else:
solver = DirectSolver
self.solver = solver.init(
self.solver = solver(
kernel,
self.X,
self.noise,
covariance=covariance_value,
**solver_kwargs,
)

@property
def loc(self) -> JAXArray:
return self.mean

@property
def variance(self) -> JAXArray:
return self.solver.variance()
Expand Down Expand Up @@ -209,7 +225,6 @@ def condition(

@partial(
jax.jit,
static_argnums=(0,),
static_argnames=("include_mean", "return_var", "return_cov"),
)
def predict(
Expand Down Expand Up @@ -281,7 +296,7 @@ def numpyro_dist(self, **kwargs: Any) -> TinyDistribution:

return TinyDistribution(self, **kwargs)

@partial(jax.jit, static_argnums=(0, 2))
@partial(jax.jit, static_argnums=(2,))
def _sample(
self,
key: jax.random.KeyArray,
Expand All @@ -296,16 +311,16 @@ def _sample(
self.solver.dot_triangular(normal_samples), 0, -1
)

@partial(jax.jit, static_argnums=0)
@jax.jit
def _compute_log_prob(self, alpha: JAXArray) -> JAXArray:
loglike = -0.5 * jnp.sum(jnp.square(alpha)) - self.solver.normalization()
return jnp.where(jnp.isfinite(loglike), loglike, -jnp.inf)

@partial(jax.jit, static_argnums=0)
@jax.jit
def _get_alpha(self, y: JAXArray) -> JAXArray:
return self.solver.solve_triangular(y - self.loc)

@partial(jax.jit, static_argnums=(0, 3))
@partial(jax.jit, static_argnums=(3,))
def _condition(
self,
y: JAXArray,
Expand Down
69 changes: 7 additions & 62 deletions src/tinygp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,73 +2,18 @@

__all__ = ["JAXArray", "dataclass", "field"]

import dataclasses
from typing import Any, Callable, TypeVar, Union
from typing import Any

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np

JAXArray = Union[np.ndarray, jnp.ndarray]
JAXArray = jax.Array

# This section is based closely on the implementation in flax:
#
# https://github.com/google/flax/blob/b60f7f45b90f8fc42a88b1639c9cc88a40b298d3/flax/struct.py
#
# This decorator is interpreted by static analysis tools as a hint
# that a decorator or metaclass causes dataclass-like behavior.
# See https://github.com/microsoft/pyright/blob/main/specs/dataclass_transforms.md
# for more information about the __dataclass_transform__ magic.
_T = TypeVar("_T")

# The following is just for backwards compatibility since tinygp used to provide a
# custom dataclass implementation
field = eqx.field

def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: tuple[type | Callable[..., Any], ...] = (()),
) -> Callable[[_T], _T]:
# If used within a stub file, the following implementation can be
# replaced with "...".
return lambda a: a


@__dataclass_transform__()
def dataclass(clz: type[Any]) -> type[Any]:
data_clz: Any = dataclasses.dataclass(frozen=True)(clz)
meta_fields = []
data_fields = []
for name, field_info in data_clz.__dataclass_fields__.items():
is_pytree_node = field_info.metadata.get("pytree_node", True)
if is_pytree_node:
data_fields.append(name)
else:
meta_fields.append(name)

def replace(self: Any, **updates: _T) -> _T:
return dataclasses.replace(self, **updates)

data_clz.replace = replace

def iterate_clz(x: Any) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
meta = tuple(getattr(x, name) for name in meta_fields)
data = tuple(getattr(x, name) for name in data_fields)
return data, meta

def clz_from_iterable(meta: tuple[Any, ...], data: tuple[Any, ...]) -> Any:
meta_args = tuple(zip(meta_fields, meta))
data_args = tuple(zip(data_fields, data))
kwargs = dict(meta_args + data_args)
return data_clz(**kwargs)

jax.tree_util.register_pytree_node(data_clz, iterate_clz, clz_from_iterable)

# Hack to make this class act as a tuple when unpacked
data_clz.iter_elems = lambda self: iterate_clz(self)[0].__iter__()

return data_clz


def field(pytree_node: bool = True, **kwargs: Any) -> Any:
return dataclasses.field(metadata={"pytree_node": pytree_node}, **kwargs)
return clz

0 comments on commit 1e798b7

Please sign in to comment.