From e721ce0c03b82d062a802e551b752e7b338d66a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Tue, 8 Nov 2022 21:54:53 +0100 Subject: [PATCH 01/12] The prototype of the spiraling function. --- src/bmi/transforms/__init__.py | 0 src/bmi/transforms/rotate.py | 67 ++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) create mode 100644 src/bmi/transforms/__init__.py create mode 100644 src/bmi/transforms/rotate.py diff --git a/src/bmi/transforms/__init__.py b/src/bmi/transforms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bmi/transforms/rotate.py b/src/bmi/transforms/rotate.py new file mode 100644 index 00000000..fd4e207d --- /dev/null +++ b/src/bmi/transforms/rotate.py @@ -0,0 +1,67 @@ +from typing import Optional + +import equinox as eqx +import jax.numpy as jnp +import numpy as np +from jax.scipy.linalg import expm +from numpy.typing import ArrayLike + + +class Spiral(eqx.Module): + """Represents the "spiraling" function + x -> R(x) x, + where R(x) is a matrix given by a product `initial` @ `rotation(x)`. + `initial` can be an arbitrary invertible matrix + and `rotation(x)` is an SO(n) element given by + exp(generator * ||x||^2), + where `generator` is an element of the so(n) Lie algebra, i.e., a skew-symmetric matrix. + + Example: + >>> a = np.array([[0, -1], [1, 0]]) + >>> spiral = Spiral(a, speed=np.pi/2) + >>> x = np.array([1, 0]) + >>> spiral(x) + DeviceArray([0., 1.]) + """ + + initial: jnp.ndarray + generator: jnp.ndarray + + def __init__( + self, generator: ArrayLike, initial: Optional[ArrayLike] = None, speed: float = 1.0 + ) -> None: + """ + + Args: + generator: a skew-symmetric matrix, an element of so(n) Lie algebra. Shape (n, n) + initial: an (n, n) matrix used to left-multiply the spiral. + Default (None) corresponds to the identity. 
+ speed: for convenience, the passed `generator` will be scaled up by `speed` constant, + which (for a given `generator`) controls how quickly the spiral will wind + """ + self.generator = jnp.asarray(generator * speed) + + if len(self.generator.shape) != 2 or self.generator.shape[0] != self.generator.shape[1]: + raise ValueError(f"Generator has wrong shape {self.generator.shape}.") + + if initial is None: + self.initial = jnp.eye(self.generator.shape[0]) + else: + initial = np.asarray(initial) + if self.generator.shape != initial.shape: + raise ValueError( + f"Initial point has shape {initial.shape} while " + f"the generator has {self.generator.shape}." + ) + self.initial = jnp.asarray(initial) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """ + Args: + x: point in the Euclidean space, shape (n,) + + Returns: + transformation applied to `x`, shape (n,) + """ + r = jnp.einsum("i, i", x, x) # We have r = ||x||^2 + return self.initial @ expm(self.generator * r) @ x From f57bc3b33a583ba4b5dbb1bfc7ace1c18d5ef81f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Thu, 17 Nov 2022 21:10:14 +0100 Subject: [PATCH 02/12] Apply KSG estimator to a spiraled example. --- scripts/ksg_spiral.py | 60 +++++++++++++++++++++++++++++++++ src/bmi/samplers/api.py | 2 ++ src/bmi/samplers/transformed.py | 48 ++++++++++++++++++++++++++ src/bmi/transforms/rotate.py | 40 ++++++++++++++++++++++ 4 files changed, 150 insertions(+) create mode 100644 scripts/ksg_spiral.py create mode 100644 src/bmi/samplers/transformed.py diff --git a/scripts/ksg_spiral.py b/scripts/ksg_spiral.py new file mode 100644 index 00000000..11100b6d --- /dev/null +++ b/scripts/ksg_spiral.py @@ -0,0 +1,60 @@ +import numpy as np + +import bmi.api as bmi +import bmi.transforms.rotate as rot + + +def generate_covariance(correlation: float, dim_x: int, dim_y: int) -> np.ndarray: + """The correlation between the first dimension of X and the first dimension of Y is fixed. 
+ + The rest of the covariance entries are zero. + """ + covariance = np.eye(dim_x + dim_y) + covariance[0, dim_x] = correlation + covariance[dim_x, 0] = correlation + return covariance + + +def main() -> None: + dim_x = 5 + dim_y = 2 + rho = 0.5 + n_points = 5000 + seed = 42 + + assert 0 <= rho < 1 + + covariance = generate_covariance(dim_x=dim_x, dim_y=dim_y, correlation=rho) + + base_sampler = bmi.samplers.SplitMultinormal( + dim_x=dim_x, + dim_y=dim_y, + covariance=covariance, + ) + + x_normal, y_normal = base_sampler.sample(n_points, rng=seed) + mi_true = base_sampler.mutual_information() + + mi_estimate_normal = bmi.estimators.KSGEnsembleFirstEstimator().estimate(x_normal, y_normal) + + print(f"{mi_true = :.3f}\t{mi_estimate_normal = :.3f}") + + generator = rot.so_generator(dim_x, i=0, j=1) + + for speed in [0.0, 0.02, 0.1, 0.5, 1.0, 10.0]: + transform_x = rot.Spiral(generator=generator, speed=speed) + transformed_sampler = bmi.samplers.TransformedSampler( + base_sampler=base_sampler, transform_x=transform_x + ) + + x_transformed, y_transformed = transformed_sampler.transform(x_normal, y_normal) + + mi_estimate_transformed = bmi.estimators.KSGEnsembleFirstEstimator().estimate( + x_transformed, y_transformed + ) + + print(f"{speed = :.2f}\t {mi_estimate_transformed = :.3f}") + + +if __name__ == "__main__": + main() diff --git a/src/bmi/samplers/api.py b/src/bmi/samplers/api.py index 21af2c7a..bbfa2fbd 100644 --- a/src/bmi/samplers/api.py +++ b/src/bmi/samplers/api.py @@ -1,7 +1,9 @@ from bmi.samplers.split_student_t import SplitStudentT from bmi.samplers.splitmultinormal import SplitMultinormal +from bmi.samplers.transformed import TransformedSampler __all__ = [ "SplitMultinormal", "SplitStudentT", + "TransformedSampler", ] diff --git a/src/bmi/samplers/transformed.py b/src/bmi/samplers/transformed.py new file mode 100644 index 00000000..6b38c4fb --- /dev/null +++ b/src/bmi/samplers/transformed.py @@ -0,0 +1,48 @@ +from typing import Callable, Optional, Union 
+ +import jax +import jax.numpy as jnp +import numpy as np + +import bmi.samplers.base as base +from bmi.interface import ISampler, KeyArray + +SomeArray = Union[jnp.ndarray, np.ndarray] +Transform = Callable[[SomeArray], jnp.ndarray] + + +def identity(x: SomeArray) -> SomeArray: + return x + + +class TransformedSampler(base.BaseSampler): + def __init__( + self, + base_sampler: ISampler, + transform_x: Optional[Callable] = None, + transform_y: Optional[Callable] = None, + add_dim_x: int = 0, + add_dim_y: int = 0, + ) -> None: + super().__init__( + dim_x=base_sampler.dim_x + add_dim_x, dim_y=base_sampler.dim_y + add_dim_y + ) + + if transform_x is None: + transform_x = identity + if transform_y is None: + transform_y = identity + + self._vectorized_transform_x = jax.vmap(transform_x) + self._vectorized_transform_y = jax.vmap(transform_y) + self._base_sampler = base_sampler + + def transform(self, x: SomeArray, y: SomeArray) -> tuple[jnp.ndarray, jnp.ndarray]: + return self._vectorized_transform_x(x), self._vectorized_transform_y(y) + + def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[jnp.ndarray, jnp.ndarray]: + x, y = self._base_sampler.sample(n_points=n_points, rng=rng) + return self.transform(x, y) + + def mutual_information(self) -> float: + return self._base_sampler.mutual_information() diff --git a/src/bmi/transforms/rotate.py b/src/bmi/transforms/rotate.py index fd4e207d..70aea9ef 100644 --- a/src/bmi/transforms/rotate.py +++ b/src/bmi/transforms/rotate.py @@ -65,3 +65,43 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: """ r = jnp.einsum("i, i", x, x) # We have r = ||x||^2 return self.initial @ expm(self.generator * r) @ x + + +def so_generator(n: int, i: int = 0, j: int = 1) -> np.ndarray: + """The (i,j)-th canonical generator of the so(n) Lie algebra. 
+ + As so(n) Lie algebra is the vector space of all n x n + skew-symmetric matrices, we have a canonical basis + such that its (i,j)th vector is a matrix A such that + A[i, j] = 1, A[j, i] = -1, i < j + and all the other entries are 0. + + Note that there exist n(n-1)/2 such matrices. + + Args: + n: we use the Lie algebra so(n) + i: index in range {0, 1, ..., j-1} + j: index in range {i+1, i+2, ..., n-1} + + Returns: + array (n, n) + """ + assert n >= 2 + assert 0 <= i < j < n + + a = np.zeros((n, n)) + a[i, j] = 1 + a[j, i] = -1 + return a + + +def skew_symmetrize(a: np.ndarray) -> np.ndarray: + """The skew-symmetric part of a given matrix `a`. + + Args: + a: array, shape (n, n) + + Returns: + skew-symmetric part of `a`, shape (n, n) + """ + return 0.5 * (a - a.T) From 3a2934d974c97fc444e9392f6caceb8ca5a0c843 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Thu, 17 Nov 2022 21:24:21 +0100 Subject: [PATCH 03/12] Add CLI to the script. --- scripts/ksg_spiral.py | 42 ++++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/scripts/ksg_spiral.py b/scripts/ksg_spiral.py index 11100b6d..9519c38b 100644 --- a/scripts/ksg_spiral.py +++ b/scripts/ksg_spiral.py @@ -1,13 +1,30 @@ +import argparse + import numpy as np import bmi.api as bmi import bmi.transforms.rotate as rot +def create_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Experiment with applying the spiral diffeomorphism." + ) + parser.add_argument("--dim-x", type=int, default=3, help="Dimension of the X variable.") + parser.add_argument("--dim-y", type=int, default=2, help="Dimension of the Y variable.") + parser.add_argument("--rho", type=float, default=0.8, help="Correlation, between -1 and 1.") + parser.add_argument( + "--n-points", type=int, default=5000, help="Number of points to be generated." 
+ ) + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + return parser + + def generate_covariance(correlation: float, dim_x: int, dim_y: int) -> np.ndarray: """The correlation between the first dimension of X and the first dimension of Y is fixed. - The rest of the covariance entries are zero. + The rest of the covariance entries are zero, + of course except for variance of each dimension (the diagonal), which is 1. """ covariance = np.eye(dim_x + dim_y) covariance[0, dim_x] = correlation @@ -15,31 +32,32 @@ def generate_covariance(correlation: float, dim_x: int, dim_y: int) -> np.ndarra return covariance -def main() -> None: - dim_x = 5 - dim_y = 2 - rho = 0.5 - n_points = 5000 - seed = 42 - - assert 0 <= rho < 1 +def create_base_sampler(dim_x: int, dim_y: int, rho: float) -> bmi.samplers.SplitMultinormal: + assert -1 <= rho < 1 covariance = generate_covariance(dim_x=dim_x, dim_y=dim_y, correlation=rho) - base_sampler = bmi.samplers.SplitMultinormal( + return bmi.samplers.SplitMultinormal( dim_x=dim_x, dim_y=dim_y, covariance=covariance, ) - x_normal, y_normal = base_sampler.sample(n_points, rng=seed) + +def main() -> None: + parser = create_parser() + args = parser.parse_args() + + base_sampler = create_base_sampler(dim_x=args.dim_x, dim_y=args.dim_y, rho=args.rho) + + x_normal, y_normal = base_sampler.sample(args.n_points, rng=args.seed) mi_true = base_sampler.mutual_information() mi_estimate_normal = bmi.estimators.KSGEnsembleFirstEstimator().estimate(x_normal, y_normal) print(f"{mi_true = :.3f}\t{mi_estimate_normal = :.3f}") - generator = rot.so_generator(dim_x, i=0, j=1) + generator = rot.so_generator(args.dim_x, i=0, j=1) for speed in [0.0, 0.02, 0.1, 0.5, 1.0, 10.0]: transform_x = rot.Spiral(generator=generator, speed=speed) From afa4535e3fd9493d9371a3130fdaa1af89a6ed68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Thu, 17 Nov 2022 22:20:58 +0100 Subject: [PATCH 04/12] Make the script output better 
formatted --- scripts/ksg_spiral.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/scripts/ksg_spiral.py b/scripts/ksg_spiral.py index 9519c38b..e220b046 100644 --- a/scripts/ksg_spiral.py +++ b/scripts/ksg_spiral.py @@ -48,6 +48,8 @@ def main() -> None: parser = create_parser() args = parser.parse_args() + print(f"Settings:\n{args}") + base_sampler = create_base_sampler(dim_x=args.dim_x, dim_y=args.dim_y, rho=args.rho) x_normal, y_normal = base_sampler.sample(args.n_points, rng=args.seed) @@ -55,7 +57,11 @@ def main() -> None: mi_estimate_normal = bmi.estimators.KSGEnsembleFirstEstimator().estimate(x_normal, y_normal) - print(f"{mi_true = :.3f}\t{mi_estimate_normal = :.3f}") + print(f"True MI: {mi_true:.3f}") + print(f"KSG(X; Y) without distortion: {mi_estimate_normal:.3f}") + + print("-------------------") + print("speed\tKSG(spiral(X); Y)") generator = rot.so_generator(args.dim_x, i=0, j=1) @@ -71,7 +77,7 @@ def main() -> None: x_transformed, y_transformed ) - print(f"{speed = :.2f}\t {mi_estimate_transformed = :.3f}") + print(f"{speed:.2f}\t {mi_estimate_transformed:.3f}") if __name__ == "__main__": From 9766a162317e9a5594e0797739780d2fdf3684e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Sat, 19 Nov 2022 11:36:21 +0100 Subject: [PATCH 05/12] Add the spiral diffeomorphism to the public API --- scripts/ksg_spiral.py | 6 +++--- src/bmi/api.py | 2 ++ src/bmi/transforms/api.py | 7 +++++++ tests/test_api.py | 9 ++++++++- 4 files changed, 20 insertions(+), 4 deletions(-) create mode 100644 src/bmi/transforms/api.py diff --git a/scripts/ksg_spiral.py b/scripts/ksg_spiral.py index e220b046..70a7f01a 100644 --- a/scripts/ksg_spiral.py +++ b/scripts/ksg_spiral.py @@ -1,9 +1,9 @@ +"""Tests whether KSG estimator is invariant to the "spiral" diffeomorphism.""" import argparse import numpy as np import bmi.api as bmi -import bmi.transforms.rotate as rot def create_parser() -> argparse.ArgumentParser: @@ -63,10 +63,10 
@@ def main() -> None: print("-------------------") print("speed\tKSG(spiral(X); Y)") - generator = rot.so_generator(args.dim_x, i=0, j=1) + generator = bmi.transforms.so_generator(args.dim_x, i=0, j=1) for speed in [0.0, 0.02, 0.1, 0.5, 1.0, 10.0]: - transform_x = rot.Spiral(generator=generator, speed=speed) + transform_x = bmi.transforms.Spiral(generator=generator, speed=speed) transformed_sampler = bmi.samplers.TransformedSampler( base_sampler=base_sampler, transform_x=transform_x ) diff --git a/src/bmi/api.py b/src/bmi/api.py index a5d8c91c..a5767922 100644 --- a/src/bmi/api.py +++ b/src/bmi/api.py @@ -2,9 +2,11 @@ import bmi.benchmark.api as benchmark import bmi.estimators.api as estimators import bmi.samplers.api as samplers +import bmi.transforms.api as transforms __all__ = [ "benchmark", "estimators", "samplers", + "transforms", ] diff --git a/src/bmi/transforms/api.py b/src/bmi/transforms/api.py new file mode 100644 index 00000000..209de2db --- /dev/null +++ b/src/bmi/transforms/api.py @@ -0,0 +1,7 @@ +from bmi.transforms.rotate import Spiral, skew_symmetrize, so_generator + +__all__ = [ + "Spiral", + "so_generator", + "skew_symmetrize", +] diff --git a/tests/test_api.py b/tests/test_api.py index a4946841..20c52c95 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -5,7 +5,14 @@ def test_api_imports() -> None: import bmi.api # noqa: F401 import not at the top of the file -@pytest.mark.parametrize("submodule", ["estimators", "samplers", "benchmark"]) +SUBMODULES = [ + "estimators", + "samplers", + "transforms", +] + + +@pytest.mark.parametrize("submodule", SUBMODULES) def test_api_exports_submodules(submodule: str) -> None: import bmi.api as bmi # noqa: F401 import not at the top of the file From 25996762500352ed9e8790bdd025e723686f3f54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Sat, 19 Nov 2022 11:37:42 +0100 Subject: [PATCH 06/12] Add Equinox to the requirements. 
--- requirements.txt | 1 + setup.cfg | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 32c32741..95bab48a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ black +equinox flake8 isort jax diff --git a/setup.cfg b/setup.cfg index 5c6d1ef7..7d21ac91 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,7 @@ package_dir= packages=find: python requires = >= 3.9 install_requires = + equinox jax jaxlib numpy From 31b4d7508790e8aa8b285ecbad6cc575315b2e6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Sat, 19 Nov 2022 12:15:29 +0100 Subject: [PATCH 07/12] Improve type annotations. --- src/bmi/transforms/rotate.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/bmi/transforms/rotate.py b/src/bmi/transforms/rotate.py index 70aea9ef..e4a10717 100644 --- a/src/bmi/transforms/rotate.py +++ b/src/bmi/transforms/rotate.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, overload import equinox as eqx import jax.numpy as jnp @@ -84,7 +84,10 @@ def so_generator(n: int, i: int = 0, j: int = 1) -> np.ndarray: j: index in range {i+1, i+2, ..., n-1} Returns: - array (n, n) + NumPy array (n, n) + + Note: + This function is NumPy based and is *not* JITtable. """ assert n >= 2 assert 0 <= i < j < n @@ -95,7 +98,17 @@ def so_generator(n: int, i: int = 0, j: int = 1) -> np.ndarray: return a +@overload def skew_symmetrize(a: np.ndarray) -> np.ndarray: + pass + + +@overload +def skew_symmetrize(a: jnp.ndarray) -> jnp.ndarray: + pass + + +def skew_symmetrize(a): """The skew-symmetric part of a given matrix `a`. Args: @@ -103,5 +116,8 @@ def skew_symmetrize(a: np.ndarray) -> np.ndarray: Returns: skew-symmetric part of `a`, shape (n, n) + + Note: + This function is compatible with both NumPy and JAX NumPy. 
""" return 0.5 * (a - a.T) From fc4d9e958cb512abf5ecf28db808c5b937eb5514 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Sat, 19 Nov 2022 16:27:22 +0100 Subject: [PATCH 08/12] Add type annotations. --- src/bmi/samplers/transformed.py | 63 +++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/src/bmi/samplers/transformed.py b/src/bmi/samplers/transformed.py index 6b38c4fb..7534fa82 100644 --- a/src/bmi/samplers/transformed.py +++ b/src/bmi/samplers/transformed.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Union +from typing import Callable, Optional, TypeVar, Union import jax import jax.numpy as jnp @@ -10,12 +10,30 @@ SomeArray = Union[jnp.ndarray, np.ndarray] Transform = Callable[[SomeArray], jnp.ndarray] +_T = TypeVar("_T") -def identity(x: SomeArray) -> SomeArray: + +def identity(x: _T) -> _T: + """The identity mapping.""" return x class TransformedSampler(base.BaseSampler): + """Pushforward of a distribution P(X, Y) + via a product mapping + f x g. + + In other words, we have mutual information between f(X) and g(Y) + for some mappings f and g. + + Note: + By default we assume that f and g are diffeomorphisms, so that + I(f(X); g(Y)) = I(X; Y). + If you don't use diffeomorphisms (in particular, + non-default `add_dim_x` or `add_dim_y`), overwrite the + `mutual_information()` method + """ + def __init__( self, base_sampler: ISampler, @@ -24,6 +42,22 @@ def __init__( add_dim_x: int = 0, add_dim_y: int = 0, ) -> None: + """ + Args: + base_sampler: allows sampling from P(X, Y) + transform_x: diffeomorphism f, so that we have variable f(X). + By default the identity mapping. + transform_y: diffeomorphism g, so that we have variable g(Y). + By default the identity mapping. + add_dim_x: the difference in dimensions of the output of `f` and its input. 
+ Note that for any diffeomorphism it must be zero + add_dim_y: similarly as `add_dim_x`, but for `g` + + Note: + If you don't use diffeomorphisms (in particular, + non-default `add_dim_x` or `add_dim_y`), overwrite the + `mutual_information()` method + """ super().__init__( dim_x=base_sampler.dim_x + add_dim_x, dim_y=base_sampler.dim_y + add_dim_y ) @@ -37,12 +71,37 @@ def __init__( self._vectorized_transform_y = jax.vmap(transform_y) self._base_sampler = base_sampler + # Boolean flag checking whether the dimension of each variable + # is preserved + self._dimensions_preserved: bool = (add_dim_x == 0) and (add_dim_y == 0) + def transform(self, x: SomeArray, y: SomeArray) -> tuple[jnp.ndarray, jnp.ndarray]: + """Transforms given samples by `f x g`. + + Args: + x: samples, (n_points, dim(X)) + y: samples, (n_points, dim(Y)) + + Returns: + f(x), shape (n_points, dim(X) + add_dim_x) + g(y), shape (n_points, dim(Y) + add_dim_y) + """ return self._vectorized_transform_x(x), self._vectorized_transform_y(y) def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[jnp.ndarray, jnp.ndarray]: + """Samples from P(f(X), g(Y)). + + Returns: + paired samples + from f(X), shape (n_points, dim(X) + add_dim_x) and + from g(Y), shape (n_points, dim(Y) + add_dim_y) + """ x, y = self._base_sampler.sample(n_points=n_points, rng=rng) return self.transform(x, y) def mutual_information(self) -> float: + if not self._dimensions_preserved: + raise ValueError( + "The dimensions are not preserved." "The mutual information may be different." + ) return self._base_sampler.mutual_information() From c02a03079ad49b7938cd79b3e62c9b4787e7d68c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Sat, 19 Nov 2022 17:41:55 +0100 Subject: [PATCH 09/12] Add tests for the transformed sampler. 
--- src/bmi/samplers/transformed.py | 2 +- tests/samplers/test_transformed.py | 118 +++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 tests/samplers/test_transformed.py diff --git a/src/bmi/samplers/transformed.py b/src/bmi/samplers/transformed.py index 7534fa82..a3e5a631 100644 --- a/src/bmi/samplers/transformed.py +++ b/src/bmi/samplers/transformed.py @@ -102,6 +102,6 @@ def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[jnp.ndarray, def mutual_information(self) -> float: if not self._dimensions_preserved: raise ValueError( - "The dimensions are not preserved." "The mutual information may be different." + "The dimensions are not preserved. The mutual information may be different." ) return self._base_sampler.mutual_information() diff --git a/tests/samplers/test_transformed.py b/tests/samplers/test_transformed.py new file mode 100644 index 00000000..ae3026e0 --- /dev/null +++ b/tests/samplers/test_transformed.py @@ -0,0 +1,118 @@ +import jax.numpy as jnp +import numpy as np +import pytest + +import bmi.samplers.api as samplers +import bmi.samplers.transformed as tr + + +@pytest.mark.parametrize("x", [2, 120, 5.0, "something"]) +def test_identity_various_objects(x) -> None: + assert x == tr.identity(x) + + +@pytest.mark.parametrize("x", [np.asarray([1.0, 4.0]), np.eye(4)]) +def test_identity_array(x) -> None: + assert np.allclose(x, tr.identity(x)) + + +def get_gaussian_sampler( + dim_x: int = 2, dim_y: int = 3, corr: float = 0.5 +) -> samplers.SplitMultinormal: + """Auxiliary function creating a base sampler.""" + cov = np.eye(dim_x + dim_y) + cov[0, dim_x] = corr + cov[dim_x, 0] = corr + return samplers.SplitMultinormal(dim_x=dim_x, dim_y=dim_y, covariance=cov) + + +@pytest.mark.parametrize("dim_x", [3]) +@pytest.mark.parametrize("dim_y", [2]) +@pytest.mark.parametrize("corr", [0.3, 0.5]) +@pytest.mark.parametrize("n_points", [10, 20]) +def test_transformed_identity( + dim_x: int, dim_y: int, corr: 
float, n_points: int, random_seed: int = 0 +) -> None: + base_sampler = get_gaussian_sampler(dim_x=dim_x, dim_y=dim_y, corr=corr) + + transformed_sampler = tr.TransformedSampler(base_sampler=base_sampler) + + assert transformed_sampler.dim_x == base_sampler.dim_x + assert transformed_sampler.dim_y == base_sampler.dim_y + assert transformed_sampler.mutual_information() == pytest.approx( + base_sampler.mutual_information() + ) + + x_base, y_base = base_sampler.sample(n_points, rng=random_seed) + x_transformed, y_transformed = transformed_sampler.sample(n_points, rng=random_seed) + + assert np.allclose(x_base, x_transformed) + assert np.allclose(y_base, y_transformed) + + x_transformed_new, y_transformed_new = transformed_sampler.transform(x_base, y_base) + assert np.allclose(x_transformed, x_transformed_new) + assert np.allclose(y_transformed, y_transformed_new) + + +def cubic(x: np.ndarray) -> np.ndarray: + return x**3 + + +def test_transformed_cubic( + dim_x: int = 5, dim_y: int = 3, corr: float = 0.4, n_points: int = 100, random_seed: int = 12 +) -> None: + base_sampler = get_gaussian_sampler(dim_x=dim_x, dim_y=dim_y, corr=corr) + transformed_sampler = tr.TransformedSampler(base_sampler, transform_x=cubic) + + x_base, y_base = base_sampler.sample(n_points=n_points, rng=random_seed) + x_transformed, y_transformed = transformed_sampler.sample(n_points, rng=random_seed) + + assert transformed_sampler.dim_x == base_sampler.dim_x + assert transformed_sampler.dim_y == base_sampler.dim_y + + assert np.allclose(x_transformed, np.asarray([cubic(x) for x in x_base])) + assert np.allclose(y_base, y_transformed) + + x_transformed_new, y_transformed_new = transformed_sampler.transform(x_base, y_base) + assert np.allclose(x_transformed, x_transformed_new) + assert np.allclose(y_transformed, y_transformed_new) + + +def embed(n: int): + return lambda x: jnp.concatenate([x, jnp.zeros(n)]) + + +def test_change_dimension( + add_dim_y: int = 3, + dim_x: int = 2, + dim_y: int = 3, + 
corr: float = 0.4, + n_points: int = 50, + random_seed: int = 12, +) -> None: + base_sampler = get_gaussian_sampler(dim_x=dim_x, dim_y=dim_y, corr=corr) + transformed_sampler = tr.TransformedSampler( + base_sampler=base_sampler, + transform_x=cubic, + transform_y=embed(add_dim_y), + add_dim_y=add_dim_y, + ) + + with pytest.raises(ValueError): + transformed_sampler.mutual_information() + + assert transformed_sampler.dim_x == base_sampler.dim_x + assert transformed_sampler.dim_y == base_sampler.dim_y + add_dim_y + + x_base, y_base = base_sampler.sample(n_points, rng=random_seed) + x_transformed, y_transformed = transformed_sampler.transform(x_base, y_base) + + x_transformed_new, y_transformed_new = transformed_sampler.sample(n_points, rng=random_seed) + assert np.allclose(x_transformed, x_transformed_new) + assert np.allclose(y_transformed, y_transformed_new) + + assert x_transformed.shape == (n_points, transformed_sampler.dim_x) + assert y_transformed.shape == (n_points, transformed_sampler.dim_y) + + assert np.allclose(x_transformed, np.asarray([cubic(x) for x in x_base])) + assert np.allclose(y_transformed, np.asarray([embed(add_dim_y)(y) for y in y_base])) From 1dc8e24d47bf6f59265529663d6add5c00434f7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Sat, 19 Nov 2022 17:54:21 +0100 Subject: [PATCH 10/12] Tests for skew-symmetrization. 
--- tests/transforms/test_rotate.py | 46 +++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 tests/transforms/test_rotate.py diff --git a/tests/transforms/test_rotate.py b/tests/transforms/test_rotate.py new file mode 100644 index 00000000..cd7bb02f --- /dev/null +++ b/tests/transforms/test_rotate.py @@ -0,0 +1,46 @@ +import numpy as np +import pytest + +import bmi.transforms.rotate as rt + + +class TestSpiral: + pass + + +class TestSkewSymmetrize: + @pytest.mark.parametrize("n", (3, 4, 5)) + @pytest.mark.parametrize("k", (3, 8)) + @pytest.mark.parametrize("seed", range(3)) + def test_symmetric_zero(self, n: int, k: int, seed: int) -> None: + rng = np.random.default_rng(seed) + w = rng.normal(size=(n, k)) + a = w @ w.T + + s = rt.skew_symmetrize(a) + assert s.shape == a.shape + assert np.allclose(s, np.zeros_like(s)) + + @pytest.mark.parametrize("n", (3, 4, 5)) + @pytest.mark.parametrize("seed", range(3)) + def test_fixed_point(self, n: int, seed: int) -> None: + """Check whether s(s(A)) = s(A) for any matrix A""" + rng = np.random.default_rng(seed) + a = rng.normal(size=(n, n)) + + s_a = rt.skew_symmetrize(a) + s_s_a = rt.skew_symmetrize(s_a) + assert np.allclose(s_a, s_s_a) + + @pytest.mark.parametrize("n", (2, 3, 7)) + @pytest.mark.parametrize("seed", range(2)) + def test_skew_symmetric(self, n: int, seed: int) -> None: + rng = np.random.default_rng(seed) + a = rng.normal(size=(n, n)) + s_a = rt.skew_symmetrize(a) + + assert np.allclose(s_a.T, -s_a) + + +class TestSOGenerator: + pass From 704d66dd6f6301bf49d7e6cfd280ef002a844d0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Sat, 19 Nov 2022 18:13:34 +0100 Subject: [PATCH 11/12] Add tests for the so(n) generators. 
--- src/bmi/transforms/rotate.py | 6 +++-- tests/transforms/test_rotate.py | 47 ++++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/src/bmi/transforms/rotate.py b/src/bmi/transforms/rotate.py index e4a10717..0f482e6c 100644 --- a/src/bmi/transforms/rotate.py +++ b/src/bmi/transforms/rotate.py @@ -89,8 +89,10 @@ def so_generator(n: int, i: int = 0, j: int = 1) -> np.ndarray: Note: This function is NumPy based and is *not* JITtable. """ - assert n >= 2 - assert 0 <= i < j < n + if n < 2: + raise ValueError(f"{n = } needs to be at least 2.") + if not (0 <= i < j < n): + raise ValueError(f"Index is wrong: {n = } {i = } {j = }.") a = np.zeros((n, n)) a[i, j] = 1 diff --git a/tests/transforms/test_rotate.py b/tests/transforms/test_rotate.py index cd7bb02f..cc9ad7d6 100644 --- a/tests/transforms/test_rotate.py +++ b/tests/transforms/test_rotate.py @@ -1,3 +1,4 @@ +import jax import numpy as np import pytest @@ -41,6 +42,50 @@ def test_skew_symmetric(self, n: int, seed: int) -> None: assert np.allclose(s_a.T, -s_a) + def test_jittable(self) -> None: + jax.jit(rt.skew_symmetrize) + + @pytest.mark.parametrize("dim", (2, 5)) + @pytest.mark.parametrize("n_matrices", (3, 4)) + def test_vmappable(self, dim: int, n_matrices: int, seed: int = 0) -> None: + rng = np.random.default_rng(seed) + matrices = rng.normal(size=(n_matrices, dim, dim)) + s_matrices = jax.vmap(rt.skew_symmetrize)(matrices) + assert np.allclose(s_matrices, np.asarray([rt.skew_symmetrize(m) for m in matrices])) + class TestSOGenerator: - pass + def test_2d_example(self) -> None: + expected = np.asarray([[0, 1], [-1, 0]]) + assert np.allclose(rt.so_generator(2, 0, 1), expected) + + def test_3d_example(self) -> None: + expected = np.asarray([[0, 1, 0], [-1, 0, 0], [0, 0, 0]]) + assert np.allclose(rt.so_generator(3, 0, 1), expected) + + @pytest.mark.parametrize("n", (2, 4)) + def test_indices_right_order(self, n: int) -> None: + for i in range(n): + for j in range(i + 
1): + with pytest.raises(ValueError): + rt.so_generator(n, i, j) + + @pytest.mark.parametrize("n", (2, 3, 5)) + def test_basic_statistic(self, n: int) -> None: + """Checks skew-symmetry and that the entries have min -1, max 1 and sum 0.""" + for i in range(n): + for j in range(i + 1, n): + a = rt.so_generator(n, i, j) + assert np.allclose(a, -a.T) + + assert np.min(a) == pytest.approx(-1) + assert np.sum(a) == pytest.approx(0) + assert np.max(a) == pytest.approx(1) + + def test_raises(self) -> None: + with pytest.raises(ValueError): + rt.so_generator(3, 1, 1) + with pytest.raises(ValueError): + rt.so_generator(3, 4, 1) + with pytest.raises(ValueError): + rt.so_generator(2, 0, 0) From 170e590e12e037db22251f23c2144cfbf99130fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Sat, 19 Nov 2022 18:19:16 +0100 Subject: [PATCH 12/12] Mark tests for spiral as not implemented. --- tests/transforms/test_rotate.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/transforms/test_rotate.py b/tests/transforms/test_rotate.py index cc9ad7d6..35cb24eb 100644 --- a/tests/transforms/test_rotate.py +++ b/tests/transforms/test_rotate.py @@ -1,3 +1,5 @@ +import warnings + import jax import numpy as np import pytest @@ -6,7 +8,11 @@ class TestSpiral: - pass + warnings.warn("Tests for the Spiral are not ready.") + + @pytest.mark.skip("Tests for the Spiral are not ready.") + def test_spiral(self) -> None: + raise NotImplementedError class TestSkewSymmetrize: