23 changes: 15 additions & 8 deletions src/tinygp/solvers/quasisep/block.py
@@ -47,9 +47,16 @@ def transpose(self) -> "Block":
     def T(self) -> "Block":
         return self.transpose()
 
+    @property
+    def mT(self) -> "Block":
+        return Block(*(jnp.swapaxes(b, -1, -2) for b in self.blocks))
+
     def to_dense(self) -> JAXArray:
-        assert all(np.ndim(b) == 2 for b in self.blocks)
-        return block_diag(*self.blocks)
+        ndim = self.ndim
+        assert ndim >= 2
+        if ndim == 2:
+            return block_diag(*self.blocks)
+        return jax.vmap(lambda *bs: Block(*bs).to_dense())(*self.blocks)
 
     @jax.jit
     def __mul__(self, other: Any) -> "Block":
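Reviewer note: a quick sketch of the batched behavior this hunk adds, assuming only what the diff shows (`Block` built from positional dense blocks, `to_dense` vmapping over a leading axis, and the new `mT` swapping each block's trailing axes):

```python
import jax.numpy as jnp
from tinygp.solvers.quasisep.block import Block

# 2D blocks: to_dense() is the usual block-diagonal matrix.
m = Block(jnp.ones((2, 3)), jnp.ones((4, 1)))
assert m.to_dense().shape == (6, 4)

# Batched (3D) blocks: to_dense() now vmaps over the leading axis,
# and mT transposes the trailing two axes of every block.
mb = Block(jnp.ones((5, 2, 3)), jnp.ones((5, 4, 1)))
assert mb.to_dense().shape == (5, 6, 4)
assert mb.mT.to_dense().shape == (5, 4, 6)
```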
@@ -98,16 +105,17 @@ def __matmul__(self, other: Any) -> Any:
             assert len(self.blocks) == len(other.blocks)
             assert all(
                 np.shape(b1) == np.shape(b2)
-                for b1, b2 in zip(self.blocks, other.blocks)
+                for b1, b2 in zip(self.blocks, other.blocks, strict=True)
             )
-            return Block(*(b1 @ b2 for b1, b2 in zip(self.blocks, other.blocks)))
-        assert all(np.ndim(b) == 2 for b in self.blocks)
+            return Block(
+                *(b1 @ b2 for b1, b2 in zip(self.blocks, other.blocks, strict=True))
+            )
         ndim = np.ndim(other)
         assert ndim >= 1
         idx = 0
         ys = []
         for b in self.blocks:
-            size = len(b)
+            size = np.shape(b)[-1]
             x = (
                 other[idx : idx + size]
                 if ndim == 1
@@ -119,11 +127,10 @@ def __matmul__(self, other: Any) -> Any:
 
     @jax.jit
     def __rmatmul__(self, other: Any) -> Any:
-        assert all(np.ndim(b) == 2 for b in self.blocks)
         idx = 0
         ys = []
         for b in self.blocks:
-            size = len(b)
+            size = np.shape(b)[-2]
             x = other[..., idx : idx + size]
             ys.append(x @ b)
             idx += size
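Reviewer note: the `len(b)` to `np.shape(b)[-1]` / `np.shape(b)[-2]` change makes the slice bookkeeping explicit about which axis it measures. A small sketch, assuming `Block` supports `@` against dense arrays as these hunks imply:

```python
import jax.numpy as jnp
from tinygp.solvers.quasisep.block import Block

m = Block(jnp.ones((2, 3)), jnp.ones((4, 1)))
u = m @ jnp.arange(4.0)  # __matmul__ slices the input by column counts
w = jnp.ones(6) @ m      # __rmatmul__ slices the input by row counts
assert u.shape == (6,) and w.shape == (4,)
```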
157 changes: 62 additions & 95 deletions src/tinygp/solvers/quasisep/core.py
@@ -33,12 +33,12 @@
 
 
 def handle_matvec_shapes(
-    func: Callable[[Any, JAXArray], JAXArray],
-) -> Callable[[Any, JAXArray], JAXArray]:
+    func: Callable[..., JAXArray],
+) -> Callable[..., JAXArray]:
     @wraps(func)
-    def wrapped(self: Any, x: JAXArray) -> JAXArray:
+    def wrapped(self: Any, x: JAXArray, **kwargs: Any) -> JAXArray:
         output_shape = x.shape
-        result = func(self, jnp.reshape(x, (output_shape[0], -1)))
+        result = func(self, jnp.reshape(x, (output_shape[0], -1)), **kwargs)
         return jnp.reshape(result, output_shape)
 
     return wrapped
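Reviewer note: a minimal, self-contained sketch of what the updated decorator does (the `Demo` class is hypothetical). Trailing dimensions are flattened into a single column axis, keyword arguments such as `parallel` are now forwarded untouched, and the caller's shape is restored on the way out:

```python
import jax.numpy as jnp
from tinygp.solvers.quasisep.core import handle_matvec_shapes

class Demo:
    @handle_matvec_shapes
    def matmul(self, x, *, parallel=False):
        # Stand-in for a real matvec kernel; x arrives with shape (n, k).
        return 2.0 * x

y = Demo().matmul(jnp.ones((4,)), parallel=True)       # vector in, vector out
Y = Demo().matmul(jnp.ones((4, 2, 3)), parallel=True)  # ND in, ND out
assert y.shape == (4,) and Y.shape == (4, 2, 3)
```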
@@ -61,12 +61,14 @@ def transpose(self) -> Any:
         raise NotImplementedError
 
     @abstractmethod
-    def matmul(self, x: JAXArray) -> JAXArray:
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
         """The dot product of this matrix with a dense vector or matrix
 
         Args:
             x (n, ...): A matrix or vector with leading dimension matching this
                 matrix.
+            parallel: If ``True``, use a parallel associative-scan algorithm
+                instead of the default sequential scan.
         """
         raise NotImplementedError

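Reviewer note: since the `parallel` flag threads through the rest of this diff, a generic illustration (not tinygp code) of why an associative scan can replace a sequential one. Any first-order linear recurrence f_k = a_k * f_{k-1} + b_k is a composition of affine maps, which is associative, so `jax.lax.associative_scan` evaluates it in logarithmic depth while agreeing with `jax.lax.scan`:

```python
import jax
import jax.numpy as jnp

def combine(left, right):
    # Compose two affine steps f -> a * f + b; composition is associative.
    a1, b1 = left
    a2, b2 = right
    return a2 * a1, a2 * b1 + b2

a = jnp.full((8,), 0.5)
b = jnp.arange(8.0)
_, f_parallel = jax.lax.associative_scan(combine, (a, b))

def step(f, ab):
    ak, bk = ab
    fk = ak * f + bk
    return fk, fk

_, f_sequential = jax.lax.scan(step, 0.0, (a, b))
assert jnp.allclose(f_parallel, f_sequential)
```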
@@ -149,7 +151,8 @@ def transpose(self) -> DiagQSM:
         return self
 
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
+        del parallel
         return self.d[:, None] * x
 
     def scale(self, other: JAXArray) -> DiagQSM:
@@ -188,16 +191,12 @@ def shape(self) -> tuple[int, int]:
     def transpose(self) -> StrictUpperTriQSM:
         return StrictUpperTriQSM(p=self.p, q=self.q, a=self.a)
 
-    @jax.jit
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
-        def impl(f, data):  # type: ignore
-            q, a, x = data
-            return a @ f + jnp.outer(q, x), f
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
+        from tinygp.solvers.quasisep.ops import lower_matmul, lower_matmul_parallel
 
-        init = jnp.zeros_like(jnp.outer(self.q[0], x[0]))
-        _, f = jax.lax.scan(impl, init, (self.q, self.a, x))
-        return jax.vmap(jnp.dot)(self.p, f)
+        impl = lower_matmul_parallel if parallel else lower_matmul
+        return impl(self.p, self.q, self.a, x)
 
     def scale(self, other: JAXArray) -> StrictLowerTriQSM:
         return StrictLowerTriQSM(p=self.p * other, q=self.q, a=self.a)
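Reviewer note: the sequential kernel that moves into `ops.lower_matmul` is the scan deleted above. The diff only names `lower_matmul_parallel`, so here is a hedged sketch of the standard associative-scan formulation it presumably uses: the carry update f_k = a_k @ f_{k-1} + outer(q_k, x_k) is affine in f, so transition pairs compose exactly as in the scalar example earlier:

```python
import jax
import jax.numpy as jnp

def lower_matmul_parallel_sketch(p, q, a, x):
    # Each pair (A, b) represents the affine carry update f -> A @ f + b.
    def combine(left, right):
        A1, b1 = left
        A2, b2 = right
        return A2 @ A1, A2 @ b1 + b2

    b = jax.vmap(jnp.outer)(q, x)  # per-step source terms outer(q_k, x_k)
    _, f_inclusive = jax.lax.associative_scan(combine, (a, b))
    # Row k of the output uses f_{k-1}, so shift the inclusive prefix by one.
    f = jnp.concatenate([jnp.zeros_like(f_inclusive[:1]), f_inclusive[:-1]])
    return jax.vmap(jnp.dot)(p, f)
```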
@@ -265,16 +264,12 @@ def shape(self) -> tuple[int, int]:
     def transpose(self) -> StrictLowerTriQSM:
         return StrictLowerTriQSM(p=self.p, q=self.q, a=self.a)
 
-    @jax.jit
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
-        def impl(f, data):  # type: ignore
-            p, a, x = data
-            return a.T @ f + jnp.outer(p, x), f
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
+        from tinygp.solvers.quasisep.ops import upper_matmul, upper_matmul_parallel
 
-        init = jnp.zeros_like(jnp.outer(self.p[-1], x[-1]))
-        _, f = jax.lax.scan(impl, init, (self.p, self.a, x), reverse=True)
-        return jax.vmap(jnp.dot)(self.q, f)
+        impl = upper_matmul_parallel if parallel else upper_matmul
+        return impl(self.p, self.q, self.a, x)
 
     def scale(self, other: JAXArray) -> StrictUpperTriQSM:
         return StrictUpperTriQSM(p=self.p, q=self.q * other, a=self.a)
@@ -306,8 +301,8 @@ def transpose(self) -> UpperTriQSM:
         return UpperTriQSM(diag=self.diag, upper=self.lower.transpose())
 
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
-        return self.diag.matmul(x) + self.lower.matmul(x)
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
+        return self.diag.matmul(x) + self.lower.matmul(x, parallel=parallel)
 
     def scale(self, other: JAXArray) -> LowerTriQSM:
         return LowerTriQSM(diag=self.diag.scale(other), lower=self.lower.scale(other))
@@ -321,9 +316,8 @@ def inv(self) -> LowerTriQSM:
         b = a - jax.vmap(jnp.outer)(v, p)
         return LowerTriQSM(diag=DiagQSM(g), lower=StrictLowerTriQSM(p=u, q=v, a=b))
 
-    @jax.jit
     @handle_matvec_shapes
-    def solve(self, y: JAXArray) -> JAXArray:
+    def solve(self, y: JAXArray, *, parallel: bool = False) -> JAXArray:
         """Solve a linear system with this matrix
 
         If this matrix is called ``L``, this solves ``L @ x = y`` for ``x``
@@ -332,16 +326,14 @@ def solve(self, y: JAXArray) -> JAXArray:
         Args:
             y (n, ...): A matrix or vector with leading dimension matching this
                 matrix.
+            parallel: If ``True``, use a parallel associative-scan algorithm.
         """
+        from tinygp.solvers.quasisep.ops import lower_solve, lower_solve_parallel
 
-        def impl(fn, data):  # type: ignore
-            ((cn,), (pn, wn, an)), yn = data
-            xn = (yn - pn @ fn) / cn
-            return an @ fn + jnp.outer(wn, xn), xn
-
-        init = jnp.zeros_like(jnp.outer(self.lower.q[0], y[0]))
-        _, x = jax.lax.scan(impl, init, (self, y))
-        return x
+        (d,) = self.diag
+        p, q, a = self.lower
+        impl = lower_solve_parallel if parallel else lower_solve
+        return impl(d, p, q, a, y)
 
     def __neg__(self) -> LowerTriQSM:
         return LowerTriQSM(diag=-self.diag, lower=-self.lower)
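Reviewer note: for readers comparing against the deleted scan, the same forward substitution rewritten against the flattened `lower_solve(d, p, q, a, y)` signature the new code calls — a sketch of what the sequential kernel in `ops` presumably computes:

```python
import jax
import jax.numpy as jnp

def lower_solve_sketch(d, p, q, a, y):
    # Forward substitution through the quasiseparable generators; the
    # carry f is the same state the deleted scan body threaded through.
    def step(f, data):
        dk, pk, qk, ak, yk = data
        xk = (yk - pk @ f) / dk
        return ak @ f + jnp.outer(qk, xk), xk

    init = jnp.zeros_like(jnp.outer(q[0], y[0]))
    _, x = jax.lax.scan(step, init, (d, p, q, a, y))
    return x
```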
@@ -362,18 +354,17 @@ def transpose(self) -> LowerTriQSM:
         return LowerTriQSM(diag=self.diag, lower=self.upper.transpose())
 
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
-        return self.diag.matmul(x) + self.upper.matmul(x)
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
+        return self.diag.matmul(x) + self.upper.matmul(x, parallel=parallel)
 
     def scale(self, other: JAXArray) -> UpperTriQSM:
         return UpperTriQSM(diag=self.diag.scale(other), upper=self.upper.scale(other))
 
     def inv(self) -> UpperTriQSM:
         return self.transpose().inv().transpose()
 
-    @jax.jit
     @handle_matvec_shapes
-    def solve(self, y: JAXArray) -> JAXArray:
+    def solve(self, y: JAXArray, *, parallel: bool = False) -> JAXArray:
         """Solve a linear system with this matrix
 
         If this matrix is called ``U``, this solves ``U @ x = y`` for ``x``
@@ -382,16 +373,14 @@ def solve(self, y: JAXArray) -> JAXArray:
         Args:
             y (n, ...): A matrix or vector with leading dimension matching this
                 matrix.
+            parallel: If ``True``, use a parallel associative-scan algorithm.
         """
+        from tinygp.solvers.quasisep.ops import upper_solve, upper_solve_parallel
 
-        def impl(fn, data):  # type: ignore
-            ((cn,), (pn, wn, an)), yn = data
-            xn = (yn - wn @ fn) / cn
-            return an.T @ fn + jnp.outer(pn, xn), xn
-
-        init = jnp.zeros_like(jnp.outer(self.upper.p[-1], y[-1]))
-        _, x = jax.lax.scan(impl, init, (self, y), reverse=True)
-        return x
+        (d,) = self.diag
+        p, q, a = self.upper
+        impl = upper_solve_parallel if parallel else upper_solve
+        return impl(d, p, q, a, y)
 
     def __neg__(self) -> UpperTriQSM:
         return UpperTriQSM(diag=-self.diag, upper=-self.upper)
@@ -418,8 +407,12 @@ def transpose(self) -> SquareQSM:
         )
 
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
-        return self.diag.matmul(x) + self.lower.matmul(x) + self.upper.matmul(x)
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
+        return (
+            self.diag.matmul(x)
+            + self.lower.matmul(x, parallel=parallel)
+            + self.upper.matmul(x, parallel=parallel)
+        )
 
     def scale(self, other: JAXArray) -> SquareQSM:
         return SquareQSM(
@@ -504,71 +497,45 @@ def transpose(self) -> SymmQSM:
         return self
 
     @handle_matvec_shapes
-    def matmul(self, x: JAXArray) -> JAXArray:
+    def matmul(self, x: JAXArray, *, parallel: bool = False) -> JAXArray:
         return (
             self.diag.matmul(x)
-            + self.lower.matmul(x)
-            + self.lower.transpose().matmul(x)
+            + self.lower.matmul(x, parallel=parallel)
+            + self.lower.transpose().matmul(x, parallel=parallel)
         )
 
     def scale(self, other: JAXArray) -> SymmQSM:
         return SymmQSM(diag=self.diag.scale(other), lower=self.lower.scale(other))
 
-    @jax.jit
-    def inv(self) -> SymmQSM:
-        """The inverse of this matrix"""
-        (d,) = self.diag
-        p, q, a = self.lower
-
-        def forward(carry, data):  # type: ignore
-            f = carry
-            dk, pk, qk, ak = data
-            fpk = f @ pk
-            left = qk - ak @ fpk
-            igk = 1 / (dk - pk @ fpk)
-            sk = igk * left
-            ellk = ak - jnp.outer(sk, pk)
-            fk = ak @ f @ ak.T + igk * jnp.outer(left, left.T)
-            return fk, (igk, sk, ellk)
-
-        init = jnp.zeros_like(jnp.outer(q[0], q[0]))
-        ig, s, ell = jax.lax.scan(forward, init, (d, p, q, a))[1]
-
-        def backward(carry, data):  # type: ignore
-            z = carry
-            igk, pk, ak, sk = data
-            zak = z @ ak
-            skzak = sk @ zak
-            lk = igk + sk @ z @ sk
-            tk = skzak - lk * pk
-            zk = ak.T @ zak - jnp.outer(skzak, pk) - jnp.outer(pk, tk)
-            return zk, (lk, tk)
-
-        init = jnp.zeros_like(jnp.outer(p[-1], p[-1]))
-        lam, t = jax.lax.scan(backward, init, (ig, p, a, s), reverse=True)[1]
+    def inv(self, *, parallel: bool = False) -> SymmQSM:
+        """The inverse of this matrix
+
+        Args:
+            parallel: If ``True``, use a parallel associative-scan algorithm.
+        """
+        from tinygp.solvers.quasisep.ops import symm_inv, symm_inv_parallel
+
+        (d,) = self.diag
+        p, q, a = self.lower
+        impl = symm_inv_parallel if parallel else symm_inv
+        lam, t, s, ell = impl(d, p, q, a)
         return SymmQSM(diag=DiagQSM(d=lam), lower=StrictLowerTriQSM(p=t, q=s, a=ell))
 
-    @jax.jit
-    def cholesky(self) -> LowerTriQSM:
+    def cholesky(self, *, parallel: bool = False) -> LowerTriQSM:
         """The Cholesky decomposition of this matrix
 
         If this matrix is called ``A``, this method returns the
         :class:`LowerTriQSM` ``L`` such that ``L @ L.T = A``.
+
+        Args:
+            parallel: If ``True``, use a parallel associative-scan algorithm.
         """
+        from tinygp.solvers.quasisep.ops import cholesky, cholesky_parallel
+
         (d,) = self.diag
         p, q, a = self.lower
-
-        def impl(carry, data):  # type: ignore
-            fp = carry
-            dk, pk, qk, ak = data
-            ck = jnp.sqrt(dk - pk @ fp @ pk)
-            tmp = fp @ ak.T
-            wk = (qk - pk @ tmp) / ck
-            fk = ak @ tmp + jnp.outer(wk, wk)
-            return fk, (ck, wk)
-
-        init = jnp.zeros_like(jnp.outer(q[0], q[0]))
-        _, (c, w) = jax.lax.scan(impl, init, (d, p, q, a))
+        impl = cholesky_parallel if parallel else cholesky
+        c, w = impl(d, p, q, a)
         return LowerTriQSM(diag=DiagQSM(c), lower=StrictLowerTriQSM(p=p, q=w, a=a))
 
     def __neg__(self) -> SymmQSM:
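Reviewer note: a usage sketch tying the refactored entry points together. It is illustrative only: `A` (a well-formed `SymmQSM`) and the right-hand side `y` are assumed, since this diff never constructs them:

```python
import numpy as np

def check_round_trip(A, y):
    # A: a SymmQSM instance (assumed); y: a dense right-hand side (assumed).
    L = A.cholesky(parallel=True)  # LowerTriQSM with L @ L.T = A
    x = L.transpose().solve(L.solve(y, parallel=True), parallel=True)
    # The explicit inverse takes the same flag and should agree closely.
    np.testing.assert_allclose(x, A.inv(parallel=True).matmul(y), rtol=1e-5)
```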