In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.multivariate_normal import MultivariateNormal

from tqdm import tqdm, trange
from matplotlib import pyplot as plt

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class MINE(nn.Module):
    """Statistics (critic) network T for the MINE mutual-information estimator.

    A small fully connected network with ELU activations that maps each row of
    a batch of concatenated (x, z) inputs to a single scalar score.

    Args:
        input_dim: width of each input row. The default of 2 matches the
            scalar (x, z) pairs used in this notebook.
        hidden_dim: width of both hidden layers (default 50).
    """

    def __init__(self, input_dim=2, hidden_dim=50):
        super().__init__()

        # Layer order is kept stable so state_dict keys ("layers.0.weight", ...)
        # do not change.
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, X):
        """Return critic scores of shape (batch, 1) for X of shape (batch, input_dim)."""
        return self.layers(X)

In [None]:
def get_samples(dist, N_samples):
    """Build one MINE batch from a two-dimensional joint distribution.

    Draws ``N_samples`` joint rows (x, z) from ``dist``. It then appends a third
    column holding the same z values in a random order. Pairing column 0 with
    column 2 breaks the x-z coupling, which approximates draws from the product
    of the marginals.

    Args:
        dist: distribution whose ``sample`` yields rows of length 2
            (assumed to be (x, z)).
        N_samples: number of rows to draw.

    Returns:
        Tensor of shape (N_samples, 3) with columns [x, z, shuffled z].
    """
    joint = dist.sample(torch.Size([N_samples]))
    z_shuffled = joint[:, 1][torch.randperm(N_samples)]
    return torch.cat((joint, z_shuffled.unsqueeze(-1)), dim=1)

In [None]:
# Correlation coefficients to sweep: 20 evenly spaced points on [-1, 1] with
# both endpoints dropped, leaving 18 values. The Gaussian MI
# -0.5*log(1 - rho**2) is infinite at |rho| = 1.
corr = torch.linspace(start=-1, end=1, steps=20)[1:-1].to(device=device)
# Mean of the joint (X, Z) Gaussian: the origin.
mean_XZ = torch.zeros(2).to(device=device)

In [None]:
# Results of the correlation sweep: closed-form MI plus the two MINE estimates,
# one entry per value in `corr`.
gt_MI = []
naive_MI = []
ma_MI = []
# Coefficient of the exponential moving average used by the EMA MINE variant.
beta = .01

# Training steps per correlation value.
epochs = 5000

# Seed the global PyTorch RNG, which also seeds CUDA devices. This makes the
# sampled batches and the network initialisations reproducible.
SEED = 0
torch.manual_seed(SEED)

In [None]:
def _critic_outputs(model, samples):
    """Score joint and shuffled pairs with a critic.

    `samples` follows the `get_samples` layout: columns 0 and 1 are a joint
    draw (x, z), and column 2 is z shuffled across the batch.
    Returns two 1-D tensors: the scores on (x, z) and on (x, shuffled z).
    """
    t_joint = model(samples[:, :2])
    t_marginal = model(torch.stack([samples[:, 0], samples[:, 2]], dim=1))
    return t_joint.flatten(), t_marginal.flatten()


def _dv_bound(t_joint, t_marginal):
    """Donsker-Varadhan bound: mean(T_joint) - log(mean(exp(T_marginal))).

    The log-mean-exp term uses `logsumexp`, so large critic outputs cannot
    overflow `exp` before the log is taken.
    """
    log_mean_exp = torch.logsumexp(t_marginal, dim=0) - math.log(t_marginal.numel())
    return t_joint.mean() - log_mean_exp


BATCH_SIZE = 500      # samples per training step and in the held-out batch
LEARNING_RATE = 1e-4  # Adam step size for both critics
LOG_EVERY = 250       # print running estimates every LOG_EVERY epochs

# Start from empty result lists so re-running this cell replaces, rather than
# extends, the previous sweep.
gt_MI, naive_MI, ma_MI = [], [], []

for rho in corr:
    r = rho.item()
    # Closed-form MI of a bivariate Gaussian with unit variances and correlation r.
    gt = -.5 * math.log(1 - r**2)

    cov_XZ = torch.tensor([[1., r], [r, 1.]]).to(device)
    XZ = MultivariateNormal(mean_XZ, cov_XZ)

    naive_model = MINE().to(device)
    ma_model = MINE().to(device)
    optimizer_1 = optim.Adam(naive_model.parameters(), lr=LEARNING_RATE)
    optimizer_2 = optim.Adam(ma_model.parameters(), lr=LEARNING_RATE)
    ma = None  # running average of mean(exp(T_marginal)) for ma_model

    tqdm.write(f'\n======== Rho = {r} ========')
    tqdm.write(f'GT MI : {gt}')
    for e in trange(1, epochs + 1):
        batch = get_samples(XZ, BATCH_SIZE).to(device)

        # Naive MINE: maximise the DV bound by minimising its negative.
        optimizer_1.zero_grad()
        naive_mi = _dv_bound(*_critic_outputs(naive_model, batch))
        (-naive_mi).backward()
        optimizer_1.step()

        # EMA MINE (Belghazi et al., 2018, Sec. 3.2). The minibatch gradient of
        # log(mean(exp(T))) divides by a batch estimate of E[exp(T)], which is
        # biased. Replace that denominator with an exponential moving average:
        # `beta` is the weight of the newest batch. The average is detached so
        # it acts as a constant in the gradient and does not keep earlier
        # batches' autograd graphs alive.
        optimizer_2.zero_grad()
        t_joint, t_marginal = _critic_outputs(ma_model, batch)
        ma_mi = _dv_bound(t_joint, t_marginal)
        mean_exp = torch.exp(t_marginal).mean()
        if ma is None:
            ma = mean_exp.detach()
        else:
            ma = (1 - beta) * ma + beta * mean_exp.detach()
        # Gradient of this loss: -(E[grad T_joint] - E[exp(T_marginal) * grad T_marginal] / ma).
        ma_loss = -(t_joint.mean() - mean_exp / ma)
        ma_loss.backward()
        optimizer_2.step()

        if e % LOG_EVERY == 0:
            tqdm.write(f'Naive MI : {naive_mi.item()}')
            tqdm.write(f'Moving Avg MI : {ma_mi.item()}')

    # Score both critics on a fresh batch, using it for every term of the bound.
    test_data = get_samples(XZ, BATCH_SIZE).to(device)
    with torch.no_grad():
        naive_MI.append(_dv_bound(*_critic_outputs(naive_model, test_data)).item())
        ma_MI.append(_dv_bound(*_critic_outputs(ma_model, test_data)).item())
    gt_MI.append(gt)

    # The optimizers hold references to the parameters, so release them too
    # before emptying the CUDA cache.
    del naive_model, ma_model, optimizer_1, optimizer_2, batch, test_data
    torch.cuda.empty_cache()

In [None]:
# Copy the sweep values to host memory under a new name. Rebinding `corr` to a
# NumPy array would make `.detach()` fail if this cell were run a second time.
rhos = corr.detach().cpu().numpy()

fig, ax = plt.subplots()
ax.plot(rhos, gt_MI, 'b', label='Ground Truth')
ax.plot(rhos, naive_MI, 'r', label='Naive MINE')
ax.plot(rhos, ma_MI, 'g', label='EMA MINE')

ax.set(
    xlabel='Rho',
    ylabel='Mutual Information',
    title='MINE estimates vs. closed-form MI of a bivariate Gaussian',
)
ax.legend(loc='right')

plt.show()