Improve list of estimators (#135)
pawel-czyz committed Jan 12, 2024
1 parent a01b28e commit 2a7b1fd
Showing 5 changed files with 81 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -39,7 +39,7 @@ Then, we will run two estimators on this task.
```python
import bmi

-task = bmi.benchmark.BENCHMARK_TASKS['multinormal-dense-2-5-0.5']
task = bmi.benchmark.BENCHMARK_TASKS['1v1-normal-0.75']
print(f"Task {task.name} with dimensions {task.dim_x} and {task.dim_y}")
print(f"Ground truth mutual information: {task.mutual_information():.2f}")

5 changes: 3 additions & 2 deletions docs/api/interfaces.md
@@ -1,3 +1,4 @@
-## Interfaces
-This section explains the most important interfaces used in the package.
# Interfaces
This section lists the most important interfaces used in the package.

::: bmi.interface.IMutualInformationPointEstimator
52 changes: 47 additions & 5 deletions docs/estimators.md
@@ -2,18 +2,60 @@

The package supports a range of existing mutual information estimators. For the full list, see [below](#list-of-estimators).

-## General usage instructions
## Example

The design of the estimators was motivated by the [SciKit-Learn](https://scikit-learn.org/) API[@SciKit-Learn-API-2013].
All estimators are classes. Once an estimator is instantiated, one can use its `estimate` method, which maps arrays containing data points (of shape `(n_points, n_dim)`)
to mutual information estimates:

```python
import bmi

# Generate a sample with 1000 data points
task = bmi.benchmark.BENCHMARK_TASKS['1v1-normal-0.75']
X, Y = task.sample(1000, seed=42)
print(f"X shape: {X.shape}") # Shape (1000, 1)
print(f"Y shape: {Y.shape}") # Shape (1000, 1)

# Once an estimator is instantiated, it can be used to estimate mutual information
# by using the `estimate` method.
cca = bmi.estimators.CCAMutualInformationEstimator()
print(f"Estimate by CCA: {cca.estimate(X, Y):.2f}")

ksg = bmi.estimators.KSGEnsembleFirstEstimator(neighborhoods=(5,))
print(f"Estimate by KSG: {ksg.estimate(X, Y):.2f}")
```

Additionally, the estimators can be queried for their hyperparameters:
```python
print(cca.parameters()) # CCA does not have tunable hyperparameters
# _EmptyParams()

print(ksg.parameters()) # KSG has tunable hyperparameters
# KSGEnsembleParameters(neighborhoods=[5], standardize=True, metric_x='euclidean', metric_y='euclidean')
```

The returned objects are structured using [Pydantic](https://docs.pydantic.dev/).
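
For example, the hyperparameters can be converted into a plain dictionary, which is convenient for logging or serialization. A minimal sketch, assuming a Pydantic v1-style API (on Pydantic v2, `model_dump()` replaces `dict()`):
```python
params = ksg.parameters()

# Convert the Pydantic model holding the hyperparameters to a plain dictionary.
# Assumption: Pydantic v1-style API; on Pydantic v2 use `params.model_dump()`.
print(params.dict())
# {'neighborhoods': [5], 'standardize': True, 'metric_x': 'euclidean', 'metric_y': 'euclidean'}
```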

## List of estimators

### Neural estimators
-- MINE[@belghazi:18:mine] estimator is implemented as [`MINEEstimator`](api/estimators.md#bmi.estimators.MINEEstimator).
-- InfoNCE[@oord:18:infonce], also known as Contrastive Predictive Coding is implemented as [`InfoNCEEstimator`](api/estimators.md#bmi.estimators.InfoNCEEstimator).

We support several standard neural estimators implemented in [JAX](https://github.com/google/jax), based on the [PyTorch implementations](https://github.com/ermongroup/smile-mi-estimator)[@Song-Ermon-2019]:

- The Donsker-Varadhan estimator[@belghazi:18:mine] is implemented in [`DonskerVaradhanEstimator`](api/estimators.md#bmi.estimators.DonskerVaradhanEstimator).
- The MINE estimator[@belghazi:18:mine], a Donsker-Varadhan estimator with a correction that debiases the gradient during the fitting phase, is implemented in [`MINEEstimator`](api/estimators.md#bmi.estimators.MINEEstimator).
- InfoNCE[@oord:18:infonce], also known as Contrastive Predictive Coding, is implemented in [`InfoNCEEstimator`](api/estimators.md#bmi.estimators.InfoNCEEstimator).
- The NWJ estimator[@NWJ2007] is implemented in [`NWJEstimator`](api/estimators.md#bmi.estimators.NWJEstimator).
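
All the neural estimators share the `estimate` interface shown above. A minimal sketch, assuming the default constructor arguments (training the neural critics takes considerably longer than running the model-based estimators):
```python
import bmi

# Sample a benchmark task, as in the example above.
task = bmi.benchmark.BENCHMARK_TASKS['1v1-normal-0.75']
X, Y = task.sample(5000, seed=42)

# Neural estimators are instantiated and queried like any other estimator.
infonce = bmi.estimators.InfoNCEEstimator()
print(f"Estimate by InfoNCE: {infonce.estimate(X, Y):.2f}")

mine = bmi.estimators.MINEEstimator()
print(f"Estimate by MINE: {mine.estimate(X, Y):.2f}")
```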

### Model-based estimators
-- Canonical correlation analysis[@Brillinger-2004,@kay-elliptic]
- Canonical correlation analysis[@Brillinger-2004,@kay-elliptic] is suitable when $P(X, Y)$ is multivariate normal and does not require hyperparameter tuning. It's implemented in [`CCAMutualInformationEstimator`](api/estimators.md#bmi.estimators.CCAMutualInformationEstimator).

### Histogram-based estimators
- We implement a histogram-based estimator[@Cellucci-HistogramsMI] in [`HistogramEstimator`](api/estimators.md#bmi.estimators.HistogramEstimator). However, note that we do not support adaptive binning schemes.

### Kernel density estimators
- We implement a simple kernel density estimator in [`KDEMutualInformationEstimator`](api/estimators.md#bmi.estimators.KDEMutualInformationEstimator).
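
Both the histogram-based and the kernel density estimators follow the same `estimate` interface. A minimal sketch, assuming the default constructor arguments:
```python
import bmi

task = bmi.benchmark.BENCHMARK_TASKS['1v1-normal-0.75']
X, Y = task.sample(1000, seed=42)

# Histogram-based estimate (no adaptive binning; the default binning is assumed).
hist = bmi.estimators.HistogramEstimator()
print(f"Estimate by histogram: {hist.estimate(X, Y):.2f}")

# Kernel density estimate (the default kernel settings are assumed).
kde = bmi.estimators.KDEMutualInformationEstimator()
print(f"Estimate by KDE: {kde.estimate(X, Y):.2f}")
```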

### Neighborhood-based estimators
- An ensemble of Kraskov-Stögbauer-Grassberger estimators[@kraskov:04:ksg] is implemented as [`KSGEnsembleFirstEstimator`](api/estimators.md#bmi.estimators.KSGEnsembleFirstEstimator).
@@ -34,7 +76,7 @@ The API is [here](api/estimators.md).
### How can I add a new estimator?
Thank you for considering contributing to this project! Please, consult [contributing guidelines](contributing.md) and reach out to us on [GitHub](https://github.com/cbg-ethz/bmi/issues), so we can discuss the best way of adding the estimator to the package.
Generally, the following steps are required:
-1. Implement the interface `IMutualInformationPointEstimator` in a new file inside `src/bmi/estimators` directory. The unit tests should be added in `tests/estimators` directory.
1. Implement the interface [`IMutualInformationPointEstimator`](api/interfaces.md#bmi.interface.IMutualInformationPointEstimator) in a new file inside the `src/bmi/estimators` directory (a minimal skeleton is sketched below). The unit tests should be added in the `tests/estimators` directory.
2. Export the new estimator to the public API by adding an entry in `src/bmi/estimators/__init__.py`.
3. Export the docstring of the new estimator to `docs/api/estimators.md`.
4. Add the estimator to the [list of estimators](#list-of-estimators).
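
For illustration, a hypothetical skeleton for step 1 (the class name `MyEstimator` and its body are placeholders; the exact set of methods required is defined by the interface itself):
```python
import numpy as np

from bmi.interface import IMutualInformationPointEstimator


class MyEstimator(IMutualInformationPointEstimator):
    """A placeholder estimator illustrating the interface."""

    def estimate(self, x, y) -> float:
        # `x` and `y` hold paired samples of shapes (n_points, dim_x) and (n_points, dim_y).
        x, y = np.asarray(x), np.asarray(y)
        assert len(x) == len(y), "X and Y must contain the same number of points."
        # Replace this constant with an actual mutual information estimate.
        return 0.0
```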
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -1,5 +1,7 @@
site_name: Benchmarking Mutual Information
theme:
  icon:
    logo: material/alpha-i-box
  name: material
  features:
    - navigation.tabs
29 changes: 28 additions & 1 deletion references.bib
@@ -67,6 +67,21 @@ @InProceedings{marx:21:myl
year={2021}
}

@article{Song-Ermon-2019,
  author     = {Jiaming Song and Stefano Ermon},
  title      = {Understanding the Limitations of Variational Mutual Information Estimators},
  journal    = {CoRR},
  volume     = {abs/1910.06222},
  year       = {2019},
  url        = {http://arxiv.org/abs/1910.06222},
  eprinttype = {arXiv},
  eprint     = {1910.06222}
}

% ----- Very nice applications of mutual information -----
@article{Nalecz-Jawecki-2023,
@@ -553,6 +568,17 @@ @article{geomstats
url = {http://jmlr.org/papers/v21/19-027.html}
}

% SciKit-Learn API design
@article{SciKit-Learn-API-2013,
  author       = {Buitinck, Lars and others},
  title        = {{API design for machine learning software: experiences from the scikit-learn project}},
  journal      = {arXiv},
  eprint       = {1309.0238},
  primaryClass = {cs.LG},
  month        = {9},
  year         = {2013}
}

% ----- Other works -----
% Normalizing flows: an overview
@@ -655,4 +681,5 @@ @book{Lee-2003-SmoothManifolds
edition={2nd},
url={https://doi.org/10.1007/978-1-4419-9982-5},
publisher={Springer}
-}
}
