InfoNet: Neural Estimation of Mutual Information without Test-Time Optimization

Welcome to InfoNet 😀! This is a PyTorch implementation of our paper InfoNet: Neural Estimation of Mutual Information without Test-Time Optimization. Our project page can be found here. You can use it to compute the mutual information between two sequences quickly!

(This project is under active development. We are continuously working on improving this repo and the project page. ✨)

Method overview

Mutual information (MI) is a valuable metric for assessing the statistical dependence between two variables and has many applications in deep learning. However, current neural MI estimation methods such as MINE are time-consuming: they need over a minute of test-time optimization to produce an estimate for a single pair of sequences. Our work uses a feed-forward neural estimator to make the process much faster.

To achieve this, we design a network architecture and train a model, InfoNet, on a wide variety of synthetic distributions, so that it learns to estimate MI across a broad family of distributions. Unlike MINE, which trains an MLP from scratch for every pair of data whose MI is to be estimated, our model needs only a single forward pass to produce the estimate. Experiments show that our method generalizes well and runs fast.

Getting Started

Requirements

Ensure you have all the necessary dependencies installed. You can install all required packages with:

pip install -r requirements.txt

Estimating Mutual Information

To estimate mutual information with InfoNet at inference time, you can follow the examples below or refer to infer.py for additional instructions.

When x and y are scalar random variables:

import numpy as np
import torch
from scipy.stats import rankdata
from infer import load_model, estimate_mi, compute_smi_mean

config_path = "configs/config.yaml"
ckpt_path = "saved/uniform/model_5000_32_1000-720--0.16.pt"
model = load_model(config_path, ckpt_path)

## generate samples from a correlated bivariate Gaussian
seq_len = 4781
rou = 0.5
x, y = np.random.multivariate_normal(mean=[0,0], cov=[[1,rou],[rou,1]], size=seq_len).T

## copula (rank) preprocessing before estimation; see the Data Preprocessing section for details
x = rankdata(x)/seq_len
y = rankdata(y)/seq_len
result = estimate_mi(model, x, y).squeeze().cpu().numpy()
real_MI = -np.log(1-rou**2)/2
print("estimate mutual information is: ", result, "real MI is ", real_MI)

If x and y are high-dimensional variables, we apply Sliced Mutual Information (SMI) instead:

d = 10
mu = np.zeros(d)
sigma = np.eye(d)
sample_x = np.random.multivariate_normal(mu, sigma, 2000)
sample_y = np.random.multivariate_normal(mu, sigma, 2000)
result = compute_smi_mean(sample_x, sample_y, model, seq_len=2000, proj_num=1024, batchsize=32)
## proj_num is the number of random projections; larger values give more accurate estimates at higher time cost
## seq_len is the number of samples used for the estimation
## batchsize is the number of one-dimensional pairs estimated at a time; this only affects estimation speed
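
Under the hood, sliced MI averages scalar InfoNet estimates over random one-dimensional projections of the two samples. The sketch below spells the idea out with an explicit loop; it is illustrative only (compute_smi_mean batches the projections and is much faster), and it assumes estimate_mi accepts 1-D arrays as in the scalar example above:

import numpy as np
from scipy.stats import rankdata
from infer import estimate_mi

def sliced_mi_sketch(model, sample_x, sample_y, proj_num=128):
    # average scalar MI estimates over random 1-D projections (illustrative)
    n, dx = sample_x.shape
    _, dy = sample_y.shape
    estimates = []
    for _ in range(proj_num):
        # draw random unit-norm projection directions
        theta = np.random.randn(dx)
        theta /= np.linalg.norm(theta)
        phi = np.random.randn(dy)
        phi /= np.linalg.norm(phi)
        # project to 1-D and apply the same rank preprocessing as the scalar case
        x1 = rankdata(sample_x @ theta) / n
        y1 = rankdata(sample_y @ phi) / n
        estimates.append(float(estimate_mi(model, x1, y1).squeeze()))
    return float(np.mean(estimates))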

Note that model inputs have shape [batchsize, sequence length, 2]; InfoNet can estimate MI for multiple pairs in a single forward pass. The pre-trained checkpoint can be found here: Download Checkpoint
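
For example, a batch of rank-normalized pairs can be assembled as follows (a sketch; the model call is omitted since the exact batched entry point depends on infer.py):

import numpy as np
import torch
from scipy.stats import rankdata

seq_len, batchsize = 2000, 8
pairs = []
for _ in range(batchsize):
    rho = np.random.uniform(-0.9, 0.9)
    x, y = np.random.multivariate_normal([0, 0], [[1, rho], [rho, 1]], seq_len).T
    # rank-normalize each sequence, as in the scalar example
    pairs.append(np.stack([rankdata(x) / seq_len, rankdata(y) / seq_len], axis=-1))

batch = torch.tensor(np.stack(pairs), dtype=torch.float32)
print(batch.shape)  # torch.Size([8, 2000, 2])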

Training InfoNet from Scratch

To train the model from scratch or fine-tune it on specific distributions, train.py provides an example. The script walks you through initializing and training the model on the default Gaussian-mixture distribution dataset. Training takes about 4 hours to converge on 2 RTX 4090 GPUs.
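
For intuition about the training data, the snippet below samples one synthetic 2-D Gaussian mixture and approximates its ground-truth MI by Monte Carlo, using MI = E[log p(x,y) - log p(x) - log p(y)]. This is a sketch only; the actual dataset generation used for training lives in train.py and may differ:

import numpy as np
from scipy.stats import multivariate_normal, norm

rng = np.random.default_rng(0)

# a random two-component 2-D Gaussian mixture
weights = np.array([0.5, 0.5])
means = rng.uniform(-2, 2, size=(2, 2))
covs = []
for _ in range(2):
    a = rng.uniform(-0.8, 0.8)
    covs.append(np.array([[1.0, a], [a, 1.0]]))

# sample (x, y) pairs component by component
n = 50_000
comp = rng.choice(2, size=n, p=weights)
samples = np.empty((n, 2))
for k in range(2):
    idx = comp == k
    samples[idx] = rng.multivariate_normal(means[k], covs[k], size=int(idx.sum()))

# Monte Carlo estimate of MI = E[log p(x,y) - log p(x) - log p(y)]
p_xy = sum(w * multivariate_normal.pdf(samples, mean=m, cov=c)
           for w, m, c in zip(weights, means, covs))
p_x = sum(w * norm.pdf(samples[:, 0], loc=m[0], scale=np.sqrt(c[0, 0]))
          for w, m, c in zip(weights, means, covs))
p_y = sum(w * norm.pdf(samples[:, 1], loc=m[1], scale=np.sqrt(c[1, 1]))
          for w, m, c in zip(weights, means, covs))
mi = np.mean(np.log(p_xy) - np.log(p_x) - np.log(p_y))
print(f"approximate ground-truth MI: {mi:.3f}")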

Data Preprocessing

Data preprocessing is crucial to the estimation quality of InfoNet. Make sure to use the same preprocessing at training and test time (e.g., the same copula transformation, soft rank, or linear scaling).

from scipy.stats import rankdata
x = rankdata(x)/seq_len
y = rankdata(y)/seq_len

If you want to use InfoNet inside a training loop, please switch to the softrank branch, since rankdata is not differentiable. You can replace rankdata with soft rank: Fast Differentiable Sorting and Ranking, github repo. An example of soft rank is shown below:

pip install torchsort

import numpy as np
import torch
import torchsort

x, y = np.random.multivariate_normal(mean=[0,0], cov=[[1,0.5],[0.5,1]], size=5000).T
x = torchsort.soft_rank(torch.from_numpy(x).unsqueeze(0), regularization_strength=1e-3)/5000
y = torchsort.soft_rank(torch.from_numpy(y).unsqueeze(0), regularization_strength=1e-3)/5000

If regularization_strength is sufficiently small (e.g., 1e-3), the soft-rank result matches rankdata almost exactly. However, setting it too small can cause exploding gradients. We therefore fix regularization_strength to 0.1 and train a new checkpoint; detailed instructions can be found in the softrank branch.
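
A quick sanity check of this equivalence (a sketch; with regularization_strength this small, gradients can be unstable, which is why the softrank branch checkpoint uses 0.1):

import numpy as np
import torch
import torchsort
from scipy.stats import rankdata

x = np.random.randn(5000)
hard = rankdata(x) / 5000
soft = torchsort.soft_rank(torch.from_numpy(x).unsqueeze(0),
                           regularization_strength=1e-3) / 5000
print(np.abs(soft.squeeze(0).numpy() - hard).max())  # should be close to 0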

For high-dimensional estimation using sliced mutual information, we have found that first applying a linear scaling to each dimension separately (e.g., mapping every dimension to [-1, 1]) before the random projections improves performance.

## linearly scale input_tensor of shape [batchsize, seq_len, dim] to [-1, 1] along the seq_len dimension
min_val = torch.min(input_tensor, dim=1, keepdim=True).values
max_val = torch.max(input_tensor, dim=1, keepdim=True).values
scaled_tensor = 2 * (input_tensor - min_val) / (max_val - min_val) - 1
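
For example, to apply this scaling to raw samples before the SMI call shown earlier (a sketch; compute_smi_mean is assumed to accept numpy arrays, as in the example above):

import numpy as np
import torch

def scale_per_dim(sample):
    # linearly map each dimension of an [n, dim] sample to [-1, 1]
    t = torch.as_tensor(sample, dtype=torch.float32).unsqueeze(0)  # [1, n, dim]
    min_val = torch.min(t, dim=1, keepdim=True).values
    max_val = torch.max(t, dim=1, keepdim=True).values
    return (2 * (t - min_val) / (max_val - min_val) - 1).squeeze(0).numpy()

sample_x = scale_per_dim(np.random.randn(2000, 10))
sample_y = scale_per_dim(np.random.randn(2000, 10))
# result = compute_smi_mean(sample_x, sample_y, model, seq_len=2000, proj_num=1024, batchsize=32)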

Evaluation Dataset

In gmm_eval_dataset, we provide a collection of Gaussian Mixture Model parameters along with the ground-truth mutual information between X and Y. They are categorized by the number of Gaussian components, with 5000 randomly generated distributions per category. Notebooks/estimate_gmm.ipynb contains examples of using this dataset.

Experiments

Experiments can be found in Notebooks; we provide four .ipynb files to reproduce our experimental results in detail.

  • estimate_gmm.ipynb provides the evaluation results of InfoNet on Gaussian mixture distributions, along with the order accuracy.
  • estimate_pointodyssey_track.ipynb provides the reproduction results of InfoNet on the PointOdyssey dataset.
  • evaluation_on_other_distributions.ipynb provides InfoNet results on several completely unseen distributions.
  • independence_test.ipynb provides high-dimensional independence-testing results of InfoNet on three different correlations.

Acknowledgement

We would like to express our gratitude to esceptico/perceiver-io for providing the code base that significantly assisted in the development of our program.

Citing Our Work

If you find our work interesting and useful, please consider citing our paper:

@article{hu2024infonet,
  title={InfoNet: Neural Estimation of Mutual Information without Test-Time Optimization},
  author={Hu, Zhengyang and Kang, Song and Zeng, Qunsong and Huang, Kaibin and Yang, Yanchao},
  journal={arXiv preprint arXiv:2402.10158},
  year={2024}
}
