Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

The example with spiral #23

Merged
merged 12 commits into from
Nov 19, 2022
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
black
equinox
flake8
isort
jax
Expand Down
84 changes: 84 additions & 0 deletions scripts/ksg_spiral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Tests whether KSG estimator is invariant to the "spiral" diffeomorphism."""
import argparse

import numpy as np

import bmi.api as bmi


def create_parser() -> argparse.ArgumentParser:
    """Build the command-line interface of the spiral experiment.

    Returns:
        parser exposing ``--dim-x``, ``--dim-y``, ``--rho``, ``--n-points`` and ``--seed``
    """
    cli = argparse.ArgumentParser(
        description="Experiment with applying the spiral diffeomorphism."
    )

    # (flag, type, default, help) -- registered in this order.
    option_specs = (
        ("--dim-x", int, 3, "Dimension of the X variable."),
        ("--dim-y", int, 2, "Dimension of the Y variable."),
        ("--rho", float, 0.8, "Correlation, between -1 and 1."),
        ("--n-points", int, 5000, "Number of points to be generated."),
        ("--seed", int, 42, "Random seed."),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    return cli


def generate_covariance(correlation: float, dim_x: int, dim_y: int) -> np.ndarray:
    """Covariance of (X, Y) in which only X[0] and Y[0] are correlated.

    Every coordinate has unit variance (ones on the diagonal). The only non-zero
    off-diagonal entries are the two symmetric ones linking the first coordinate
    of X with the first coordinate of Y; both equal `correlation`.

    Args:
        correlation: value placed at positions (0, dim_x) and (dim_x, 0)
        dim_x: dimension of the X variable
        dim_y: dimension of the Y variable

    Returns:
        array of shape (dim_x + dim_y, dim_x + dim_y)
    """
    total_dim = dim_x + dim_y
    matrix = np.identity(total_dim)
    # Index `dim_x` is the first coordinate of Y in the concatenated vector (X, Y).
    matrix[[0, dim_x], [dim_x, 0]] = correlation
    return matrix


def create_base_sampler(dim_x: int, dim_y: int, rho: float) -> bmi.samplers.SplitMultinormal:
    """Create the multivariate normal sampler used as the undistorted baseline.

    Args:
        dim_x: dimension of the X variable
        dim_y: dimension of the Y variable
        rho: correlation between the first coordinate of X and the first
            coordinate of Y, strictly between -1 and 1

    Returns:
        sampler whose covariance is produced by `generate_covariance`

    Raises:
        ValueError: if `rho` is not strictly between -1 and 1
    """
    # An explicit exception (rather than `assert`) keeps the check alive under `python -O`.
    # Both endpoints are excluded: |rho| = 1 makes the covariance matrix singular.
    if not -1 < rho < 1:
        raise ValueError(f"rho must lie strictly between -1 and 1, got {rho}.")

    covariance = generate_covariance(dim_x=dim_x, dim_y=dim_y, correlation=rho)

    return bmi.samplers.SplitMultinormal(
        dim_x=dim_x,
        dim_y=dim_y,
        covariance=covariance,
    )


def main() -> None:
    """Run the experiment: compare KSG estimates before and after spiraling X.

    Samples are drawn once from the base sampler. The same samples are then
    pushed through spirals of increasing speed, and the KSG estimate for each
    speed is printed next to the true mutual information.
    """
    args = create_parser().parse_args()
    print(f"Settings:\n{args}")

    base_sampler = create_base_sampler(dim_x=args.dim_x, dim_y=args.dim_y, rho=args.rho)
    x_normal, y_normal = base_sampler.sample(args.n_points, rng=args.seed)

    mi_true = base_sampler.mutual_information()
    baseline_estimator = bmi.estimators.KSGEnsembleFirstEstimator()
    mi_estimate_normal = baseline_estimator.estimate(x_normal, y_normal)

    print(f"True MI: {mi_true:.3f}")
    print(f"KSG(X; Y) without distortion: {mi_estimate_normal:.3f}")

    print("-------------------")
    print("speed\tKSG(spiral(X); Y)")

    # The spiral rotates in the plane spanned by the first two coordinates of X.
    generator = bmi.transforms.so_generator(args.dim_x, i=0, j=1)
    speeds = (0.0, 0.02, 0.1, 0.5, 1.0, 10.0)

    for speed in speeds:
        spiral = bmi.transforms.Spiral(generator=generator, speed=speed)
        spiraled_sampler = bmi.samplers.TransformedSampler(
            base_sampler=base_sampler, transform_x=spiral
        )

        # Reuse the already drawn samples so that only the distortion changes.
        x_transformed, y_transformed = spiraled_sampler.transform(x_normal, y_normal)

        estimator = bmi.estimators.KSGEnsembleFirstEstimator()
        mi_estimate_transformed = estimator.estimate(x_transformed, y_transformed)

        print(f"{speed:.2f}\t {mi_estimate_transformed:.3f}")


# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package_dir=
packages=find:
python_requires = >= 3.9
install_requires =
equinox
jax
jaxlib
numpy
Expand Down
2 changes: 2 additions & 0 deletions src/bmi/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import bmi.benchmark.api as benchmark
import bmi.estimators.api as estimators
import bmi.samplers.api as samplers
import bmi.transforms.api as transforms

# Submodule aliases that make up the public interface of this module.
__all__ = [
    "benchmark",
    "estimators",
    "samplers",
    "transforms",
]
2 changes: 2 additions & 0 deletions src/bmi/samplers/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from bmi.samplers.split_student_t import SplitStudentT
from bmi.samplers.splitmultinormal import SplitMultinormal
from bmi.samplers.transformed import TransformedSampler

# Sampler classes that make up the public interface of this module.
__all__ = [
    "SplitMultinormal",
    "SplitStudentT",
    "TransformedSampler",
]
107 changes: 107 additions & 0 deletions src/bmi/samplers/transformed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import Callable, Optional, TypeVar, Union

import jax
import jax.numpy as jnp
import numpy as np

import bmi.samplers.base as base
from bmi.interface import ISampler, KeyArray

# Any array accepted as input: either a JAX array or a NumPy array.
SomeArray = Union[jnp.ndarray, np.ndarray]
# Signature of a transform: maps an input array to a JAX array.
Transform = Callable[[SomeArray], jnp.ndarray]

_T = TypeVar("_T")


def identity(x: _T) -> _T:
    """Return the argument unchanged.

    Args:
        x: any object

    Returns:
        the very same object `x`
    """
    return x


class TransformedSampler(base.BaseSampler):
    """Sampler of the pushforward of P(X, Y) through a product mapping f x g.

    The sampled variables are f(X) and g(Y), where (X, Y) comes from
    the wrapped base sampler.

    Note:
        Both f and g are assumed to be diffeomorphisms by default, in which case
        I(f(X); g(Y)) = I(X; Y).
        If they are not (in particular, whenever `add_dim_x` or `add_dim_y`
        is non-default), overwrite the `mutual_information()` method.
    """

    def __init__(
        self,
        base_sampler: ISampler,
        transform_x: Optional[Callable] = None,
        transform_y: Optional[Callable] = None,
        add_dim_x: int = 0,
        add_dim_y: int = 0,
    ) -> None:
        """
        Args:
            base_sampler: sampler of the original distribution P(X, Y)
            transform_x: the mapping f applied to a single X sample.
              `None` (default) means the identity mapping.
            transform_y: the mapping g applied to a single Y sample.
              `None` (default) means the identity mapping.
            add_dim_x: output dimension of `f` minus its input dimension.
              It has to be zero for any diffeomorphism.
            add_dim_y: the analogue of `add_dim_x` for `g`

        Note:
            For mappings which are not diffeomorphisms (in particular,
            for non-default `add_dim_x` or `add_dim_y`), overwrite the
            `mutual_information()` method.
        """
        output_dim_x = base_sampler.dim_x + add_dim_x
        output_dim_y = base_sampler.dim_y + add_dim_y
        super().__init__(dim_x=output_dim_x, dim_y=output_dim_y)

        f = identity if transform_x is None else transform_x
        g = identity if transform_y is None else transform_y

        self._base_sampler = base_sampler
        # The mappings act on single points; `vmap` lifts them to batches of points.
        self._vectorized_transform_x = jax.vmap(f)
        self._vectorized_transform_y = jax.vmap(g)

        # True only if neither variable changed its dimension.
        self._dimensions_preserved: bool = add_dim_x == 0 and add_dim_y == 0

    def transform(self, x: SomeArray, y: SomeArray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Applies `f x g` to already drawn samples.

        Args:
            x: samples, (n_points, dim(X))
            y: samples, (n_points, dim(Y))

        Returns:
            f(x), shape (n_points, dim(X) + add_dim_x)
            g(y), shape (n_points, dim(Y) + add_dim_y)
        """
        transformed_x = self._vectorized_transform_x(x)
        transformed_y = self._vectorized_transform_y(y)
        return transformed_x, transformed_y

    def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Draws samples from P(f(X), g(Y)).

        The base sampler is sampled first and the result is passed through `transform`.

        Returns:
            paired samples
            from f(X), shape (n_points, dim(X) + add_dim_x) and
            from g(Y), shape (n_points, dim(Y) + add_dim_y)
        """
        base_x, base_y = self._base_sampler.sample(n_points=n_points, rng=rng)
        return self.transform(base_x, base_y)

    def mutual_information(self) -> float:
        """Mutual information of the base sampler, valid when dimensions are preserved.

        Raises:
            ValueError: if `add_dim_x` or `add_dim_y` was non-zero
        """
        if self._dimensions_preserved:
            return self._base_sampler.mutual_information()
        raise ValueError(
            "The dimensions are not preserved. The mutual information may be different."
        )
Empty file added src/bmi/transforms/__init__.py
Empty file.
7 changes: 7 additions & 0 deletions src/bmi/transforms/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from bmi.transforms.rotate import Spiral, skew_symmetrize, so_generator

# Transform utilities that make up the public interface of this module.
__all__ = [
    "Spiral",
    "so_generator",
    "skew_symmetrize",
]
125 changes: 125 additions & 0 deletions src/bmi/transforms/rotate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from typing import Optional, overload

import equinox as eqx
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import expm
from numpy.typing import ArrayLike


class Spiral(eqx.Module):
    """Represents the "spiraling" function
      x -> R(x) x,
    where R(x) is a matrix given by a product `initial` @ `rotation(x)`.
    `initial` can be an arbitrary invertible matrix
    and `rotation(x)` is an SO(n) element given by
      exp(generator * ||x||^2),
    where `generator` is an element of the so(n) Lie algebra, i.e., a skew-symmetric matrix.

    Note:
        Only the shapes of `generator` and `initial` are validated.
        Skew-symmetry of `generator` and invertibility of `initial` are not checked.

    Example:
        >>> a = np.array([[0, -1], [1, 0]])
        >>> spiral = Spiral(a, speed=np.pi/2)
        >>> x = np.array([1, 0])
        >>> spiral(x)  # approximately [0., 1.], up to floating point error
    """

    # Matrix left-multiplying the rotated point, shape (n, n).
    initial: jnp.ndarray
    # The generator already multiplied by `speed`, shape (n, n).
    generator: jnp.ndarray

    def __init__(
        self, generator: ArrayLike, initial: Optional[ArrayLike] = None, speed: float = 1.0
    ) -> None:
        """

        Args:
            generator: a skew-symmetric matrix, an element of so(n) Lie algebra. Shape (n, n)
            initial: an (n, n) matrix used to left-multiply the spiral.
              Default (None) corresponds to the identity.
            speed: for convenience, the passed `generator` will be scaled up by `speed` constant,
              which (for a given `generator`) controls how quickly the spiral will wind

        Raises:
            ValueError: if `generator` is not a square matrix
              or if `initial` has a different shape than `generator`
        """
        # Convert *before* scaling: `generator` may be any array-like, and e.g.
        # a nested list cannot be multiplied by a float (and is repeated by an int).
        scaled_generator = jnp.asarray(generator) * speed

        shape = scaled_generator.shape
        if len(shape) != 2 or shape[0] != shape[1]:
            raise ValueError(f"Generator has wrong shape {shape}.")

        if initial is None:
            initial_matrix = jnp.eye(shape[0])
        else:
            initial_matrix = jnp.asarray(initial)
            if shape != initial_matrix.shape:
                raise ValueError(
                    f"Initial point has shape {initial_matrix.shape} while "
                    f"the generator has {shape}."
                )

        self.generator = scaled_generator
        self.initial = initial_matrix

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Args:
            x: point in the Euclidean space, shape (n,)

        Returns:
            transformation applied to `x`, shape (n,)
        """
        r = jnp.einsum("i, i", x, x)  # We have r = ||x||^2
        # The rotation angle grows with the squared distance from the origin.
        return self.initial @ expm(self.generator * r) @ x


def so_generator(n: int, i: int = 0, j: int = 1) -> np.ndarray:
    """Canonical basis element of the so(n) Lie algebra indexed by the pair (i, j).

    The Lie algebra so(n) consists of all n x n skew-symmetric matrices.
    Its canonical basis is indexed by pairs i < j: the (i,j)th element
    is the matrix A with
        A[i, j] = 1,  A[j, i] = -1
    and zeros everywhere else. There are n(n-1)/2 such matrices.

    Args:
        n: we use the Lie algebra so(n)
        i: index in range {0, 1, ..., j-1}
        j: index in range {i+1, i+2, ..., n-1}

    Returns:
        NumPy array (n, n)

    Raises:
        ValueError: if `n` is smaller than 2 or the indices violate 0 <= i < j < n

    Note:
        This function is NumPy based and is *not* JITtable.
    """
    if n < 2:
        raise ValueError(f"{n = } needs to be at least 2.")

    indices_valid = 0 <= i < j < n
    if not indices_valid:
        raise ValueError(f"Index is wrong: {n = } {i = } {j = }.")

    basis_element = np.zeros((n, n))
    # Set the (i, j) entry to 1 and the mirrored (j, i) entry to -1 in one step.
    basis_element[[i, j], [j, i]] = (1, -1)
    return basis_element


@overload
def skew_symmetrize(a: np.ndarray) -> np.ndarray:
    ...


@overload
def skew_symmetrize(a: jnp.ndarray) -> jnp.ndarray:
    ...


def skew_symmetrize(a):
    """Project a square matrix onto its skew-symmetric part, (a - a^T) / 2.

    Args:
        a: array, shape (n, n)

    Returns:
        skew-symmetric part of `a`, shape (n, n)

    Note:
        Works for both NumPy and JAX NumPy arrays.
    """
    antisymmetric_difference = a - a.T
    return 0.5 * antisymmetric_difference