-
Notifications
You must be signed in to change notification settings - Fork 2
/
interface.py
108 lines (81 loc) · 3.27 KB
/
interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Most important interfaces of the package.
Note:
The `interface` module CANNOT import anything from the developed package.
This restriction is to ensure that any subpackage can import from
the `interface` module without running into circular-import issues.
"""
import pathlib
from abc import abstractmethod
from typing import Any, Optional, Protocol, Union
import numpy as np
import pydantic
from numpy.typing import ArrayLike
class BaseModel(pydantic.BaseModel):  # pytype: disable=invalid-annotation
    """Empty stand-in for `pydantic.BaseModel`.

    pytype reports a false positive for `pydantic.BaseModel`, which makes
    our CI fail, so this dummy subclass exists as a workaround.

    It can be removed once the upstream problem has been solved:
    https://github.com/google/pytype/issues/1105
    """
# Plain integer alias (the name suggests a random seed).
Seed = int

# A path given either as a string or as a `pathlib.Path` object.
Pathlike = Union[str, pathlib.Path]

# Placeholder for the JAX key type. Once `Array` becomes a part of the public
# JAX API, this alias should be updated to it (or possibly to its union
# with `Any`).
KeyArray = Any
class EstimateResult(BaseModel):
    """Result of a mutual information estimation, with optional metadata.

    Attributes:
        mi_estimate: the mutual information estimate
        time_in_seconds: optional time in seconds (presumably how long the
            estimation took -- confirm with callers); `None` if not provided
        additional_information: free-form dictionary with any extra
            information; a new empty dictionary is created when not provided
    """

    # Required field: the estimate itself.
    mi_estimate: float
    # Optional metadata; both fields may be omitted at construction time.
    time_in_seconds: Optional[float] = None
    # `default_factory` gives every instance its own dictionary.
    additional_information: dict = pydantic.Field(default_factory=dict)
class IMutualInformationPointEstimator(Protocol):
    """Interface of mutual information estimators returning point estimates.

    All estimators should implement this interface.
    """

    @abstractmethod
    def estimate(self, x: ArrayLike, y: ArrayLike) -> float:
        """Computes a point estimate of MI(X; Y).

        The inputs are expected to form an i.i.d. sample from
        the $P(X, Y)$ distribution.

        Args:
            x: shape `(n_samples, dim_x)`
            y: shape `(n_samples, dim_y)`

        Returns:
            mutual information estimate
        """
        raise NotImplementedError()

    def estimate_with_info(self, x: ArrayLike, y: ArrayLike) -> EstimateResult:
        """Runs the estimator and wraps the outcome in an `EstimateResult`.

        This default implementation only fills in `mi_estimate` (obtained
        from `estimate`); implementations may override it to report
        additional information about the run.

        Args:
            x: shape `(n_samples, dim_x)`
            y: shape `(n_samples, dim_y)`

        Returns:
            result object holding the mutual information estimate
        """
        point_estimate = self.estimate(x, y)
        return EstimateResult(mi_estimate=point_estimate)

    @abstractmethod
    def parameters(self) -> BaseModel:
        """Returns the parameters of the estimator."""
        raise NotImplementedError()
class ISampler(Protocol):
    """Interface representing a joint distribution $P(X, Y)$."""

    @abstractmethod
    def sample(self, n_points: int, rng: Union[int, KeyArray]) -> tuple[np.ndarray, np.ndarray]:
        """Draws a sample from the joint distribution P(X, Y).

        Args:
            n_points: sample size
            rng: pseudorandom number generator, given as an integer
                or a `KeyArray`

        Returns:
            X samples, shape (n_points, dim_x)
            Y samples, shape (n_points, dim_y). These samples are paired
              with the X samples.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def dim_x(self) -> int:
        """Dimension of the space in which the `X` variable is valued."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def dim_y(self) -> int:
        """Dimension of the space in which the `Y` variable is valued."""
        raise NotImplementedError()

    @property
    def dim_total(self) -> int:
        """Dimension of the space in which the `(X, Y)` variable is valued.

        Computed as the sum of `dim_x` and `dim_y`; overriding
        implementations should keep it equal to that sum.
        """
        x_dim, y_dim = self.dim_x, self.dim_y
        return x_dim + y_dim

    @abstractmethod
    def mutual_information(self) -> float:
        """Mutual information MI(X; Y)."""
        raise NotImplementedError()