In [1]:
import numpy as np
import pandas as pd
import os
import glob
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn

# Machine epsilon of float64; presumably meant as a guard against log(0) / division by zero
# (NOTE(review): not referenced by any cell visible here — confirm it is still needed).
eps = np.finfo(float).eps

# Default size (inches) for every matplotlib figure in this notebook.
plt.rcParams['figure.figsize'] = 10, 10
%matplotlib inline

import torch
import torch.nn as nn
from torch import optim

import torchvision.transforms as T
import torch.nn.functional as F


%load_ext autoreload
%autoreload 2

In [3]:
## data generation
def gen_x(num_smaple):
    """Draw equiprobable binary symbols in {-1, +1}.

    Parameters
    ----------
    num_smaple : int
        Number of symbols to draw.

    Returns
    -------
    np.ndarray
        float64 array of shape (num_smaple, 1) whose entries are exactly -1.0 or +1.0.
    """
    # Thresholding a standard normal at 0 gives P(+1) = P(-1) = 1/2. np.sign would map an
    # exact 0.0 draw to 0, which is not a valid symbol (the closed-form MI cell assumes
    # x is exactly +/-1). np.where uses the same random draw, so results are identical to
    # np.sign for every nonzero sample and seeded runs are unaffected in practice.
    z = np.random.normal(0., 1., [num_smaple, 1])
    return np.where(z >= 0., 1., -1.)

def gen_y(x, num_smaple, var):
    """Add zero-mean Gaussian noise of variance ``var`` to the symbols ``x``.

    The noise array has shape (num_smaple, 1) and standard deviation sqrt(var);
    it is broadcast-added to ``x`` and the noisy result is returned.
    """
    noise_std = np.sqrt(var)
    noise = np.random.normal(loc=0., scale=noise_std, size=(num_smaple, 1))
    return x + noise

In [4]:
# Monte-Carlo reference value of I(X;Y), in nats, for x from gen_x and Y = X + N(0, var):
#   I(X;Y) = E[ log p(y|x) / p(y) ],  p(y) = 0.5 * p(y|-1) + 0.5 * p(y|+1).
# NOTE(review): no RNG seed is set, so this estimate changes slightly on every run.
num_samples = 1_000_000
var = 2.0
x = gen_x(num_samples)
y = gen_y(x, num_samples, var)

# Gaussian likelihoods without the 1/sqrt(2*pi*var) factor -- it cancels in the ratio below.
two_var = 2 * var
p_y_x = np.exp(-(y - x)**2 / two_var)
p_y_x_minus = np.exp(-(y + 1)**2 / two_var)
p_y_x_plus = np.exp(-(y - 1)**2 / two_var)

p_y = 0.5 * p_y_x_minus + 0.5 * p_y_x_plus
mi = np.mean(np.log(p_y_x / p_y))

In [5]:
# Closed-form Monte-Carlo reference (nats) for the neural estimators trained below.
print(str(mi))

0.20148642082091153


In [2]:
class MIEstimator1(nn.Module):
    """Small statistics network T(x, y) = fc3(relu(fc1(x) + fc2(y))) for MINE.

    Args:
        d: width of the single hidden layer.

    ``x`` and ``y`` must each have a last dimension of 1 (both input layers are
    ``nn.Linear(1, d)``); the output has a last dimension of 1.
    """

    def __init__(self, d):
        super().__init__()
        # Separate projections of x and y; summing them is a Linear layer on cat([x, y]).
        # Both carry a bias, so the pre-activation gets two additive biases (redundant, harmless).
        self.fc1 = nn.Linear(1, d)
        self.fc2 = nn.Linear(1, d)
        self.fc3 = nn.Linear(d, 1)

    def forward(self, x, y):
        hidden = self.fc1(x) + self.fc2(y)
        return self.fc3(F.relu(hidden))
    
class MIEstimator2(nn.Module):
    """Concatenation-based statistics network T(x1, x2) returning two MI lower bounds.

    Args:
        size1: feature dimension of ``x1``.
        size2: feature dimension of ``x2``.
        hidden_size: width of both hidden layers (default 1024, the original value).
    """

    def __init__(self, size1, size2, hidden_size=1024):
        super(MIEstimator2, self).__init__()

        # Vanilla MLP on the concatenated pair [x1, x2] -> one scalar score per row.
        self.net = nn.Sequential(
            nn.Linear(size1 + size2, hidden_size),
            nn.ReLU(True),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(True),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x1, x2):
        """Score aligned and mismatched pairs and return (JSD bound, NWJ-style bound).

        Args:
            x1: assumed shape (batch, size1) -- it is concatenated with x2 along dim 1.
            x2: assumed shape (batch, size2).

        Returns:
            Tuple of two scalar tensors, where P are the aligned pairs (x1, x2) and Q are
            the pairs with x1 rolled by one position along the batch:
            * ``-E_P[softplus(-T)] - E_Q[softplus(T)]``  (Jensen-Shannon-based estimator),
            * ``E_P[T] - E_Q[exp(T)] + 1``  (the NWJ bound written for T' = T + 1).
        """
        pos = self.net(torch.cat([x1, x2], 1))  # Positive Samples
        # Negative samples: rolling x1 one step along the batch breaks its pairing with x2.
        # NOTE(review): with batch size 1 the roll is a no-op, so neg == pos.
        neg = self.net(torch.cat([torch.roll(x1, 1, 0), x2], 1))
        # The original called a bare `softplus` that was never defined/imported (NameError).
        jsd_bound = -F.softplus(-pos).mean() - F.softplus(neg).mean()
        nwj_bound = pos.mean() - neg.exp().mean() + 1
        return jsd_bound, nwj_bound

In [None]:
# Train MIEstimator1 on the Donsker-Varadhan (DV) bound
#   I(X;Y) >= E_joint[T] - log E_marginal[exp(T)];
# the loss is the negated bound, so -loss is directly comparable to `mi` above.
model = MIEstimator1(10)
n_epoch = 500
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
plot_loss = []
for epoch in range(n_epoch):
    # Fresh joint samples every epoch; permuting y breaks the x-y pairing, giving
    # samples from the product of the marginals.
    x_sample = gen_x(num_samples)
    y_sample = gen_y(x_sample, num_samples, var)
    y_shuffle = np.random.permutation(y_sample)

    x_sample = torch.from_numpy(x_sample).float()
    y_sample = torch.from_numpy(y_sample).float()
    y_shuffle = torch.from_numpy(y_shuffle).float()

    pred_xy = model(x_sample, y_sample)
    pred_x_y = model(x_sample, y_shuffle)

    # log(mean(exp(t))) evaluated as logsumexp(t) - log(n): the naive exp() overflows to inf
    # in float32 once any score exceeds ~88, which turns the loss and gradients into inf/NaN.
    log_mean_exp = torch.logsumexp(pred_x_y.flatten(), dim=0) - float(np.log(pred_x_y.numel()))
    loss = -(torch.mean(pred_xy) - log_mean_exp)
    # .item() gives a detached Python float (avoids the discouraged .data access).
    plot_loss.append(loss.item())
    model.zero_grad()
    loss.backward()
    optimizer.step()