diff --git a/requirements.txt b/requirements.txt index 32c32741..95bab48a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ black +equinox flake8 isort jax diff --git a/scripts/ksg_spiral.py b/scripts/ksg_spiral.py new file mode 100644 index 00000000..70a7f01a --- /dev/null +++ b/scripts/ksg_spiral.py @@ -0,0 +1,84 @@ +"""Tests whether KSG estimator is invariant to the "spiral" diffeomorphism.""" +import argparse + +import numpy as np + +import bmi.api as bmi + + +def create_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Experiment with applying the spiral diffeomorphism." + ) + parser.add_argument("--dim-x", type=int, default=3, help="Dimension of the X variable.") + parser.add_argument("--dim-y", type=int, default=2, help="Dimension of the Y variable.") + parser.add_argument("--rho", type=float, default=0.8, help="Correlation, between -1 and 1.") + parser.add_argument( + "--n-points", type=int, default=5000, help="Number of points to be generated." + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + return parser + + +def generate_covariance(correlation: float, dim_x: int, dim_y: int) -> np.ndarray: + """The correlation between the first dimension of X and the first dimension of Y is fixed. + + The rest of the covariance entries are zero, + of course except for variance of each dimension (the diagonal), which is 1. 
+ """ + covariance = np.eye(dim_x + dim_y) + covariance[0, dim_x] = correlation + covariance[dim_x, 0] = correlation + return covariance + + +def create_base_sampler(dim_x: int, dim_y: int, rho: float) -> bmi.samplers.SplitMultinormal: + assert -1 <= rho < 1 + + covariance = generate_covariance(dim_x=dim_x, dim_y=dim_y, correlation=rho) + + return bmi.samplers.SplitMultinormal( + dim_x=dim_x, + dim_y=dim_y, + covariance=covariance, + ) + + +def main() -> None: + parser = create_parser() + args = parser.parse_args() + + print(f"Settings:\n{args}") + + base_sampler = create_base_sampler(dim_x=args.dim_x, dim_y=args.dim_y, rho=args.rho) + + x_normal, y_normal = base_sampler.sample(args.n_points, rng=args.seed) + mi_true = base_sampler.mutual_information() + + mi_estimate_normal = bmi.estimators.KSGEnsembleFirstEstimator().estimate(x_normal, y_normal) + + print(f"True MI: {mi_true:.3f}") + print(f"KSG(X; Y) without distortion: {mi_estimate_normal:.3f}") + + print("-------------------") + print("speed\tKSG(spiral(X); Y)") + + generator = bmi.transforms.so_generator(args.dim_x, i=0, j=1) + + for speed in [0.0, 0.02, 0.1, 0.5, 1.0, 10.0]: + transform_x = bmi.transforms.Spiral(generator=generator, speed=speed) + transformed_sampler = bmi.samplers.TransformedSampler( + base_sampler=base_sampler, transform_x=transform_x + ) + + x_transformed, y_transformed = transformed_sampler.transform(x_normal, y_normal) + + mi_estimate_transformed = bmi.estimators.KSGEnsembleFirstEstimator().estimate( + x_transformed, y_transformed + ) + + print(f"{speed:.2f}\t {mi_estimate_transformed:.3f}") + + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg index 5c6d1ef7..7d21ac91 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,7 @@ package_dir= packages=find: python requires = >= 3.9 install_requires = + equinox jax jaxlib numpy diff --git a/src/bmi/api.py b/src/bmi/api.py index a5d8c91c..a5767922 100644 --- a/src/bmi/api.py +++ b/src/bmi/api.py @@ -2,9 +2,11 @@ import 
bmi.benchmark.api as benchmark import bmi.estimators.api as estimators import bmi.samplers.api as samplers +import bmi.transforms.api as transforms __all__ = [ "benchmark", "estimators", "samplers", + "transforms", ] diff --git a/src/bmi/samplers/api.py b/src/bmi/samplers/api.py index 21af2c7a..bbfa2fbd 100644 --- a/src/bmi/samplers/api.py +++ b/src/bmi/samplers/api.py @@ -1,7 +1,9 @@ from bmi.samplers.split_student_t import SplitStudentT from bmi.samplers.splitmultinormal import SplitMultinormal +from bmi.samplers.transformed import TransformedSampler __all__ = [ "SplitMultinormal", "SplitStudentT", + "TransformedSampler", ] diff --git a/src/bmi/samplers/transformed.py b/src/bmi/samplers/transformed.py new file mode 100644 index 00000000..a3e5a631 --- /dev/null +++ b/src/bmi/samplers/transformed.py @@ -0,0 +1,107 @@ +from typing import Callable, Optional, TypeVar, Union + +import jax +import jax.numpy as jnp +import numpy as np + +import bmi.samplers.base as base +from bmi.interface import ISampler, KeyArray + +SomeArray = Union[jnp.ndarray, np.ndarray] +Transform = Callable[[SomeArray], jnp.ndarray] + +_T = TypeVar("_T") + + +def identity(x: _T) -> _T: + """The identity mapping.""" + return x + + +class TransformedSampler(base.BaseSampler): + """Pushforward of a distribution P(X, Y) + via a product mapping + f x g. + + In other words, we have mutual information between f(X) and g(Y) + for some mappings f and g. + + Note: + By default we assume that f and g are diffeomorphisms, so that + I(f(X); g(Y)) = I(X; Y). 
+ If you don't use diffeomorphisms (in particular, + non-default `add_dim_x` or `add_dim_y`), override the + `mutual_information()` method. + """ + + def __init__( + self, + base_sampler: ISampler, + transform_x: Optional[Callable] = None, + transform_y: Optional[Callable] = None, + add_dim_x: int = 0, + add_dim_y: int = 0, + ) -> None: + """ + Args: + base_sampler: allows sampling from P(X, Y) + transform_x: diffeomorphism f, so that we have variable f(X). + By default the identity mapping. + transform_y: diffeomorphism g, so that we have variable g(Y). + By default the identity mapping. + add_dim_x: the difference in dimensions of the output of `f` and its input. + Note that for any diffeomorphism it must be zero. + add_dim_y: similarly as `add_dim_x`, but for `g` + + Note: + If you don't use diffeomorphisms (in particular, + non-default `add_dim_x` or `add_dim_y`), override the + `mutual_information()` method. + """ + super().__init__( + dim_x=base_sampler.dim_x + add_dim_x, dim_y=base_sampler.dim_y + add_dim_y + ) + + if transform_x is None: + transform_x = identity + if transform_y is None: + transform_y = identity + + self._vectorized_transform_x = jax.vmap(transform_x) + self._vectorized_transform_y = jax.vmap(transform_y) + self._base_sampler = base_sampler + + # Boolean flag checking whether the dimension of each variable + # is preserved + self._dimensions_preserved: bool = (add_dim_x == 0) and (add_dim_y == 0) + + def transform(self, x: SomeArray, y: SomeArray) -> tuple[jnp.ndarray, jnp.ndarray]: + """Transforms given samples by `f x g`. + + Args: + x: samples, (n_points, dim(X)) + y: samples, (n_points, dim(Y)) + + Returns: + f(x), shape (n_points, dim(X) + add_dim_x) + g(y), shape (n_points, dim(Y) + add_dim_y) + """ + return self._vectorized_transform_x(x), self._vectorized_transform_y(y) + + def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[jnp.ndarray, jnp.ndarray]: + """Samples from P(f(X), g(Y)). 
+ + Returns: + paired samples + from f(X), shape (n_points, dim(X) + add_dim_x) and + from g(Y), shape (n_points, dim(Y) + add_dim_y) + """ + x, y = self._base_sampler.sample(n_points=n_points, rng=rng) + return self.transform(x, y) + + def mutual_information(self) -> float: + if not self._dimensions_preserved: + raise ValueError( + "The dimensions are not preserved. The mutual information may be different." + ) + return self._base_sampler.mutual_information() diff --git a/src/bmi/transforms/__init__.py b/src/bmi/transforms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bmi/transforms/api.py b/src/bmi/transforms/api.py new file mode 100644 index 00000000..209de2db --- /dev/null +++ b/src/bmi/transforms/api.py @@ -0,0 +1,7 @@ +from bmi.transforms.rotate import Spiral, skew_symmetrize, so_generator + +__all__ = [ + "Spiral", + "so_generator", + "skew_symmetrize", +] diff --git a/src/bmi/transforms/rotate.py b/src/bmi/transforms/rotate.py new file mode 100644 index 00000000..0f482e6c --- /dev/null +++ b/src/bmi/transforms/rotate.py @@ -0,0 +1,125 @@ +from typing import Optional, overload + +import equinox as eqx +import jax.numpy as jnp +import numpy as np +from jax.scipy.linalg import expm +from numpy.typing import ArrayLike + + +class Spiral(eqx.Module): + """Represents the "spiraling" function + x -> R(x) x, + where R(x) is a matrix given by a product `initial` @ `rotation(x)`. + `initial` can be an arbitrary invertible matrix + and `rotation(x)` is an SO(n) element given by + exp(generator * ||x||^2), + where `generator` is an element of the so(n) Lie algebra, i.e., a skew-symmetric matrix. 
+ + Example: + >>> a = np.array([[0, -1], [1, 0]]) + >>> spiral = Spiral(a, speed=np.pi/2) + >>> x = np.array([1, 0]) + >>> spiral(x) + DeviceArray([0., 1.]) + """ + + initial: jnp.ndarray + generator: jnp.ndarray + + def __init__( + self, generator: ArrayLike, initial: Optional[ArrayLike] = None, speed: float = 1.0 + ) -> None: + """ + + Args: + generator: a skew-symmetric matrix, an element of so(n) Lie algebra. Shape (n, n) + initial: an (n, n) matrix used to left-multiply the spiral. + Default (None) corresponds to the identity. + speed: for convenience, the passed `generator` will be scaled up by `speed` constant, + which (for a given `generator`) controls how quickly the spiral will wind + """ + self.generator = jnp.asarray(generator * speed) + + if len(self.generator.shape) != 2 or self.generator.shape[0] != self.generator.shape[1]: + raise ValueError(f"Generator has wrong shape {self.generator.shape}.") + + if initial is None: + self.initial = jnp.eye(self.generator.shape[0]) + else: + initial = np.asarray(initial) + if self.generator.shape != initial.shape: + raise ValueError( + f"Initial point has shape {initial.shape} while " + f"the generator has {self.generator.shape}." + ) + self.initial = jnp.asarray(initial) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """ + Args: + x: point in the Euclidean space, shape (n,) + + Returns: + transformation applied to `x`, shape (n,) + """ + r = jnp.einsum("i, i", x, x) # We have r = ||x||^2 + return self.initial @ expm(self.generator * r) @ x + + +def so_generator(n: int, i: int = 0, j: int = 1) -> np.ndarray: + """The (i,j)-th canonical generator of the so(n) Lie algebra. + + As so(n) Lie algebra is the vector space of all n x n + skew-symmetric matrices, we have a canonical basis + such that its (i,j)th vector is a matrix A such that + A[i, j] = 1, A[j, i] = -1, i < j + and all the other entries are 0. + + Note that there exist n(n-1)/2 such matrices. 
+ + Args: + n: we use the Lie algebra so(n) + i: index in range {0, 1, ..., j-1} + j: index in range {i+1, i+2, ..., n-1} + + Returns: + NumPy array (n, n) + + Note: + This function is NumPy based and is *not* JITtable. + """ + if n < 2: + raise ValueError(f"{n = } needs to be at least 2.") + if not (0 <= i < j < n): + raise ValueError(f"Index is wrong: {n = } {i = } {j = }.") + + a = np.zeros((n, n)) + a[i, j] = 1 + a[j, i] = -1 + return a + + +@overload +def skew_symmetrize(a: np.ndarray) -> np.ndarray: + pass + + +@overload +def skew_symmetrize(a: jnp.ndarray) -> jnp.ndarray: + pass + + +def skew_symmetrize(a): + """The skew-symmetric part of a given matrix `a`. + + Args: + a: array, shape (n, n) + + Returns: + skew-symmetric part of `a`, shape (n, n) + + Note: + This function is compatible with both NumPy and JAX NumPy. + """ + return 0.5 * (a - a.T) diff --git a/tests/samplers/test_transformed.py b/tests/samplers/test_transformed.py new file mode 100644 index 00000000..ae3026e0 --- /dev/null +++ b/tests/samplers/test_transformed.py @@ -0,0 +1,118 @@ +import jax.numpy as jnp +import numpy as np +import pytest + +import bmi.samplers.api as samplers +import bmi.samplers.transformed as tr + + +@pytest.mark.parametrize("x", [2, 120, 5.0, "something"]) +def test_identity_various_objects(x) -> None: + assert x == tr.identity(x) + + +@pytest.mark.parametrize("x", [np.asarray([1.0, 4.0]), np.eye(4)]) +def test_identity_array(x) -> None: + assert np.allclose(x, tr.identity(x)) + + +def get_gaussian_sampler( + dim_x: int = 2, dim_y: int = 3, corr: float = 0.5 +) -> samplers.SplitMultinormal: + """Auxiliary function creating a base sampler.""" + cov = np.eye(dim_x + dim_y) + cov[0, dim_x] = corr + cov[dim_x, 0] = corr + return samplers.SplitMultinormal(dim_x=dim_x, dim_y=dim_y, covariance=cov) + + +@pytest.mark.parametrize("dim_x", [3]) +@pytest.mark.parametrize("dim_y", [2]) +@pytest.mark.parametrize("corr", [0.3, 0.5]) +@pytest.mark.parametrize("n_points", [10, 20]) 
+def test_transformed_identity( + dim_x: int, dim_y: int, corr: float, n_points: int, random_seed: int = 0 +) -> None: + base_sampler = get_gaussian_sampler(dim_x=dim_x, dim_y=dim_y, corr=corr) + + transformed_sampler = tr.TransformedSampler(base_sampler=base_sampler) + + assert transformed_sampler.dim_x == base_sampler.dim_x + assert transformed_sampler.dim_y == base_sampler.dim_y + assert transformed_sampler.mutual_information() == pytest.approx( + base_sampler.mutual_information() + ) + + x_base, y_base = base_sampler.sample(n_points, rng=random_seed) + x_transformed, y_transformed = transformed_sampler.sample(n_points, rng=random_seed) + + assert np.allclose(x_base, x_transformed) + assert np.allclose(y_base, y_transformed) + + x_transformed_new, y_transformed_new = transformed_sampler.transform(x_base, y_base) + assert np.allclose(x_transformed, x_transformed_new) + assert np.allclose(y_transformed, y_transformed_new) + + +def cubic(x: np.ndarray) -> np.ndarray: + return x**3 + + +def test_transformed_cubic( + dim_x: int = 5, dim_y: int = 3, corr: float = 0.4, n_points: int = 100, random_seed: int = 12 +) -> None: + base_sampler = get_gaussian_sampler(dim_x=dim_x, dim_y=dim_y, corr=corr) + transformed_sampler = tr.TransformedSampler(base_sampler, transform_x=cubic) + + x_base, y_base = base_sampler.sample(n_points=n_points, rng=random_seed) + x_transformed, y_transformed = transformed_sampler.sample(n_points, rng=random_seed) + + assert transformed_sampler.dim_x == base_sampler.dim_x + assert transformed_sampler.dim_y == base_sampler.dim_y + + assert np.allclose(x_transformed, np.asarray([cubic(x) for x in x_base])) + assert np.allclose(y_base, y_transformed) + + x_transformed_new, y_transformed_new = transformed_sampler.transform(x_base, y_base) + assert np.allclose(x_transformed, x_transformed_new) + assert np.allclose(y_transformed, y_transformed_new) + + +def embed(n: int): + return lambda x: jnp.concatenate([x, jnp.zeros(n)]) + + +def 
test_change_dimension( + add_dim_y: int = 3, + dim_x: int = 2, + dim_y: int = 3, + corr: float = 0.4, + n_points: int = 50, + random_seed: int = 12, +) -> None: + base_sampler = get_gaussian_sampler(dim_x=dim_x, dim_y=dim_y, corr=corr) + transformed_sampler = tr.TransformedSampler( + base_sampler=base_sampler, + transform_x=cubic, + transform_y=embed(add_dim_y), + add_dim_y=add_dim_y, + ) + + with pytest.raises(ValueError): + transformed_sampler.mutual_information() + + assert transformed_sampler.dim_x == base_sampler.dim_x + assert transformed_sampler.dim_y == base_sampler.dim_y + add_dim_y + + x_base, y_base = base_sampler.sample(n_points, rng=random_seed) + x_transformed, y_transformed = transformed_sampler.transform(x_base, y_base) + + x_transformed_new, y_transformed_new = transformed_sampler.sample(n_points, rng=random_seed) + assert np.allclose(x_transformed, x_transformed_new) + assert np.allclose(y_transformed, y_transformed_new) + + assert x_transformed.shape == (n_points, transformed_sampler.dim_x) + assert y_transformed.shape == (n_points, transformed_sampler.dim_y) + + assert np.allclose(x_transformed, np.asarray([cubic(x) for x in x_base])) + assert np.allclose(y_transformed, np.asarray([embed(add_dim_y)(y) for y in y_base])) diff --git a/tests/test_api.py b/tests/test_api.py index a4946841..20c52c95 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -5,7 +5,14 @@ def test_api_imports() -> None: import bmi.api # noqa: F401 import not at the top of the file -@pytest.mark.parametrize("submodule", ["estimators", "samplers", "benchmark"]) +SUBMODULES = [ + "estimators", + "samplers", + "transforms", +] + + +@pytest.mark.parametrize("submodule", SUBMODULES) def test_api_exports_submodules(submodule: str) -> None: import bmi.api as bmi # noqa: F401 import not at the top of the file diff --git a/tests/transforms/test_rotate.py b/tests/transforms/test_rotate.py new file mode 100644 index 00000000..35cb24eb --- /dev/null +++ 
b/tests/transforms/test_rotate.py @@ -0,0 +1,97 @@ +import warnings + +import jax +import numpy as np +import pytest + +import bmi.transforms.rotate as rt + + +class TestSpiral: + warnings.warn("Tests for the Spiral are not ready.") + + @pytest.mark.skip("Tests for the Spiral are not ready.") + def test_spiral(self) -> None: + raise NotImplementedError + + +class TestSkewSymmetrize: + @pytest.mark.parametrize("n", (3, 4, 5)) + @pytest.mark.parametrize("k", (3, 8)) + @pytest.mark.parametrize("seed", range(3)) + def test_symmetric_zero(self, n: int, k: int, seed: int) -> None: + rng = np.random.default_rng(seed) + w = rng.normal(size=(n, k)) + a = w @ w.T + + s = rt.skew_symmetrize(a) + assert s.shape == a.shape + assert np.allclose(s, np.zeros_like(s)) + + @pytest.mark.parametrize("n", (3, 4, 5)) + @pytest.mark.parametrize("seed", range(3)) + def test_fixed_point(self, n: int, seed: int) -> None: + """Check whether s(s(A)) = s(A) for any matrix A""" + rng = np.random.default_rng(seed) + a = rng.normal(size=(n, n)) + + s_a = rt.skew_symmetrize(a) + s_s_a = rt.skew_symmetrize(s_a) + assert np.allclose(s_a, s_s_a) + + @pytest.mark.parametrize("n", (2, 3, 7)) + @pytest.mark.parametrize("seed", range(2)) + def test_skew_symmetric(self, n: int, seed: int) -> None: + rng = np.random.default_rng(seed) + a = rng.normal(size=(n, n)) + s_a = rt.skew_symmetrize(a) + + assert np.allclose(s_a.T, -s_a) + + def test_jittable(self) -> None: + jax.jit(rt.skew_symmetrize) + + @pytest.mark.parametrize("dim", (2, 5)) + @pytest.mark.parametrize("n_matrices", (3, 4)) + def test_vmappable(self, dim: int, n_matrices: int, seed: int = 0) -> None: + rng = np.random.default_rng(seed) + matrices = rng.normal(size=(n_matrices, dim, dim)) + s_matrices = jax.vmap(rt.skew_symmetrize)(matrices) + assert np.allclose(s_matrices, np.asarray([rt.skew_symmetrize(m) for m in matrices])) + + +class TestSOGenerator: + def test_2d_example(self) -> None: + expected = np.asarray([[0, 1], [-1, 0]]) + assert 
np.allclose(rt.so_generator(2, 0, 1), expected) + + def test_3d_example(self) -> None: + expected = np.asarray([[0, 1, 0], [-1, 0, 0], [0, 0, 0]]) + assert np.allclose(rt.so_generator(3, 0, 1), expected) + + @pytest.mark.parametrize("n", (2, 4)) + def test_indices_right_order(self, n: int) -> None: + for i in range(n): + for j in range(i + 1): + with pytest.raises(ValueError): + rt.so_generator(n, i, j) + + @pytest.mark.parametrize("n", (2, 3, 5)) + def test_basic_statistic(self, n: int) -> None: + """Checks that each generator is skew-symmetric, with min -1, max 1 and entries summing to 0.""" + for i in range(n): + for j in range(i + 1, n): + a = rt.so_generator(n, i, j) + assert np.allclose(a, -a.T) + + assert np.min(a) == pytest.approx(-1) + assert np.sum(a) == pytest.approx(0) + assert np.max(a) == pytest.approx(1) + + def test_raises(self) -> None: + with pytest.raises(ValueError): + rt.so_generator(3, 1, 1) + with pytest.raises(ValueError): + rt.so_generator(3, 4, 1) + with pytest.raises(ValueError): + rt.so_generator(2, 0, 0)