
Mixtures and pointwise mutual information: add features from TensorFlow Probability #110

Merged · 14 commits into main from pawel/mixtures · Jul 24, 2023

Conversation

@pawel-czyz (Member) commented Jul 11, 2023

> *I have decided that mixtures, like tequila, are inherently evil and should be avoided at all costs.*
>
> – Larry Wasserman

This PR adds the distributions from TensorFlow Probability (on JAX), which can be used to estimate MI via Monte Carlo integration. It supports selected transformed distributions (made possible by TFP bijectors) and mixture models.
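Concretely, MI is the expectation of the pointwise mutual information under $P_{XY}$, so with samples $(x_n, y_n) \sim P_{XY}$ the Monte Carlo estimator is simply

$$\widehat{I}(X; Y) = \frac{1}{N} \sum_{n=1}^{N} \log \frac{p_{XY}(x_n, y_n)}{p_X(x_n)\, p_Y(y_n)}.$$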

📈 Introduced changes

🚂 Features added

  • A distribution class, based on TensorFlow Probability, which is used to model a joint probability distribution $P_{XY}$ with known marginals $P_X$ and $P_Y$.
  • The main examples implemented are the multivariate normal and multivariate Student distributions.
  • These distributions can be arbitrarily (a) transformed using TensorFlow Probability bijectors and (b) combined into mixture distributions.
  • Utilities for drawing samples from the pointwise MI distribution ("PMI profile") and estimating mutual information using Monte Carlo.
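The exact class and function names introduced in the PR are not shown above, so here is only a minimal sketch of the underlying idea in plain TFP-on-JAX terms: a bivariate normal joint with known standard-normal marginals, its PMI profile, and the Monte Carlo MI estimate (names such as `pmi` and the correlation value are illustrative).

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Joint P_XY: a bivariate normal with correlation rho; both marginals are standard normal.
rho = 0.8
scale_tril = jnp.linalg.cholesky(jnp.array([[1.0, rho], [rho, 1.0]]))
joint = tfd.MultivariateNormalTriL(loc=jnp.zeros(2), scale_tril=scale_tril)
marg_x = tfd.Normal(loc=0.0, scale=1.0)
marg_y = tfd.Normal(loc=0.0, scale=1.0)

def pmi(xy):
    """Pointwise mutual information: log p_XY(x, y) - log p_X(x) - log p_Y(y)."""
    return joint.log_prob(xy) - marg_x.log_prob(xy[..., 0]) - marg_y.log_prob(xy[..., 1])

key = jax.random.PRNGKey(0)
samples = joint.sample(20_000, seed=key)  # draws from P_XY
profile = pmi(samples)                    # the "PMI profile"
mi_mc = jnp.mean(profile)                 # Monte Carlo estimate of I(X; Y)

# Sanity check: for a bivariate normal, I(X; Y) = -0.5 * log(1 - rho^2).
print(mi_mc, -0.5 * jnp.log(1.0 - rho**2))
```

A bijector-transformed or mixture joint works the same way, as long as `log_prob` is available for the joint and for both marginals.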

🔬 Experiments performed

  • Visualisations of PMI for several $(1+1)$-dimensional distributions. (See the cool plot below!)
  • PMI profiles for several distributions.

*(figure: PMI visualisations for the $(1+1)$-dimensional distributions)*

✨ Other changes

  • Workflows used to generate the Beyond Normal paper have been moved to a subdirectory (workflows/Beyond_Normal). Note: although it may seem that many files have been modified, the majority of them have just been moved to another directory, so they do not need to be reviewed at all.
  • A new subdirectory (workflows/Pointwise_MI) has been created, which will host the workflows for the new project.
  • As a dependency we now require TensorFlow Probability on JAX. The requirements file (pyproject.toml) has been amended.

⛔ Known limitations

  • The transformed distributions can be somewhat unstable: they can produce NaNs for PMI values and float overflow. (Perhaps using float64 could help? But see below...) Perhaps Principled approach to handling NaNs #111 is the reason.
  • Monte Carlo Standard Error seems to be too small. I suspect that the variance of PMI can perhaps be infinite (or very large), so finite samples underestimate it. Monte Carlo integration relies on the strong law of large numbers, so we do not need finite variance to estimate the expected value (MI), but the convergence can be very slow, and the guarantees for the Monte Carlo Standard Error assume that the variance is finite and estimate it using the sample variance. As a simple (but more costly) alternative, we can run Monte Carlo 20 times with $N = 20{,}000$ and use the spread of the estimates to quantify the error (see the first sketch after this list). In my experience the standard error calculated this way is larger than the MCSE calculated from one sample, but still small for a large number of data points.
  • An alternative way to estimate MI (although the overall Monte Carlo Standard Error seems trickier to control) is to estimate differential entropies via Monte Carlo sampling. Perhaps it would be more stable numerically: the distinction is between estimating $\mathbb E\big[\log p_{XY}(x, y) - \left(\log p_X(x) + \log p_Y(y)\right)\big]$ as a single expectation and estimating $\mathbb E[\log p_{XY}(x, y)] - \big(\mathbb E[\log p_X(x)] + \mathbb E[\log p_Y(y)]\big)$ term by term, each expectation from its own sample (see the second sketch after this list).
  • Using 64-bit float does not seem to work. Further investigation is needed.
  • Currently the TFP distributions are not compatible with the benchmark (i.e., the Sampler and Task classes). This is a third object with similar functionality, which may be somewhat confusing. We need a wrapper to turn them into a Sampler or Task, but we can't really replace Sampler with them:
    • The APIs are quite different, as they serve slightly different purposes.
    • Samplers can be transformed via arbitrary continuous injective mappings implemented in JAX, while TFP supports only bijectors (to calculate PMI we need $\log p$; Samplers don't provide this utility, only mutual information).
    • We have several custom JAX mappings (e.g., wiggly mapping, normal CDF, "bimodal") which haven't been implemented as TFP bijectors.
  • Unit tests and workflows are very basic.
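For the standard-error point above, a minimal sketch of the repeated-runs alternative, reusing `joint` and `pmi` from the earlier sketch (the constants, 20 runs of $N = 20{,}000$, come directly from the text):

```python
import jax
import jax.numpy as jnp

def mi_once(key, n=20_000):
    # One Monte Carlo estimate of MI from n fresh samples.
    return jnp.mean(pmi(joint.sample(n, seed=key)))

keys = jax.random.split(jax.random.PRNGKey(42), 20)
estimates = jnp.stack([mi_once(k) for k in keys])

mi_hat = jnp.mean(estimates)
# Standard error across independent runs, rather than the within-run MCSE.
se_runs = jnp.std(estimates, ddof=1) / jnp.sqrt(estimates.shape[0])
```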
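And a sketch of the entropy-difference estimator, with each expectation estimated from its own sample (again reusing `joint`, `marg_x` and `marg_y`; whether this is actually more stable is exactly the open question):

```python
key1, key2, key3 = jax.random.split(jax.random.PRNGKey(1), 3)
xy = joint.sample(20_000, seed=key1)
x = joint.sample(20_000, seed=key2)[..., 0]  # a sample from the marginal P_X
y = joint.sample(20_000, seed=key3)[..., 1]  # a sample from the marginal P_Y

# I(X; Y) = -H(X, Y) + H(X) + H(Y), each (cross-)entropy term by Monte Carlo.
mi_entropies = (
    jnp.mean(joint.log_prob(xy))
    - (jnp.mean(marg_x.log_prob(x)) + jnp.mean(marg_y.log_prob(y)))
)
```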

🪨 Tasks

This list of tasks is for me to keep track of the progress. It does not need to be read during the review.

Added features:

  • Core object representing the distributions $P_X$, $P_Y$ and $P_{XY}$.
  • Mixture distribution.
  • Transformed distribution via product diffeomorphism $f\times g$.
  • Sampling.
  • PMI profile calculation.
  • Implement multivariate normal distribution.
  • Implement multivariate Student distribution.

Unit tests:

  • Test whether a mixture can be formed.
  • Test whether a transformed distribution can be defined.
  • Check whether the MI estimate of transformed distribution is the same as of the original one.
  • Check whether the MI estimate of multivariate normal agrees with the analytical formula.
  • Check whether the MI estimate of multivariate Student agrees with the analytical formula.
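For the invariance test, a minimal pytest-style sketch (reusing `joint`, `marg_y` and `rho` from the earlier sketch; the PR's actual tests may look different): MI is invariant under bijections applied to each variable separately, so transforming $x \mapsto e^x$ should leave the estimate unchanged.

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd, tfb = tfp.distributions, tfp.bijectors

def test_mi_invariant_under_exp_on_x():
    # Transform (x, y) -> (exp(x), y); the mutual information should not change.
    transformed = tfd.TransformedDistribution(
        joint, tfb.Blockwise([tfb.Exp(), tfb.Identity()], block_sizes=[1, 1])
    )
    marg_x_t = tfd.LogNormal(loc=0.0, scale=1.0)  # pushforward of N(0, 1) under exp

    xy = transformed.sample(50_000, seed=jax.random.PRNGKey(7))
    pmi_t = (
        transformed.log_prob(xy)
        - marg_x_t.log_prob(xy[..., 0])
        - marg_y.log_prob(xy[..., 1])
    )
    expected = -0.5 * jnp.log(1.0 - rho**2)  # analytical MI of the original joint
    assert jnp.abs(jnp.mean(pmi_t) - expected) < 0.05
```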

Workflows:

  • Visualisations of PMI for several $(1+1)$-dimensional distributions.
  • PMI profiles for several distributions.

@pawel-czyz added the 👕 effort: M, 🚂 type: enhancement (New feature or request) and 🧪 type: experiment (Used for experiments and proof of concepts) labels on Jul 11, 2023
@pawel-czyz pawel-czyz changed the title Adding sampling using TensorFlow Probability Mixtures and pointwise mutual information: add features from TensorFlow Probability Jul 12, 2023
@pawel-czyz pawel-czyz marked this pull request as ready for review July 12, 2023 15:47
@grfrederic (Collaborator) commented:

Nice coding! I think the new class colliding with Sampler and Task is not that big of a deal, since for now we use the previous ones for the main project and the new one for the new project. But we could think about integrating them somehow. For example, we could give Sampler an optional pdf(x) function, and automatically implement it for transformed samplers (if the transforms are TF bijectors, or just differentiable surjections written in JAX).
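Something like the following could work as a sketch of that idea (`Sampler` here is a hypothetical stand-in, not the actual benchmark class; the change-of-variables step assumes the transform is a TFP bijector):

```python
from typing import Callable, Optional

class Sampler:
    """Hypothetical stand-in for the benchmark's Sampler, with an optional density."""

    def __init__(self, sample_fn: Callable, logpdf_fn: Optional[Callable] = None):
        self.sample = sample_fn
        self._logpdf_fn = logpdf_fn

    def logpdf(self, x):
        if self._logpdf_fn is None:
            raise NotImplementedError("This sampler does not expose a density.")
        return self._logpdf_fn(x)

def pushforward_logpdf(base_logpdf: Callable, bijector) -> Callable:
    # Change of variables: log q(y) = log p(f^{-1}(y)) + log |det J_{f^{-1}}(y)|.
    def logpdf(y):
        x = bijector.inverse(y)
        return base_logpdf(x) + bijector.inverse_log_det_jacobian(y, event_ndims=1)
    return logpdf
```

For a general differentiable injection written in JAX one would additionally need its inverse and Jacobian (e.g., via `jax.jacfwd`), which is exactly why arbitrary Sampler transforms cannot expose a density for free.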

@pawel-czyz pawel-czyz merged commit 3b6feb7 into main Jul 24, 2023
2 checks passed
@pawel-czyz pawel-czyz deleted the pawel/mixtures branch July 24, 2023 15:15