In [1]:
import torch
import math

import numpy as np
import matplotlib.pyplot as plt

from torch import nn
from scipy.stats import norm, gaussian_kde
from tqdm import tqdm

In [2]:
def calc_mi(vae, x_validation, device='cpu', S=1):
    '''Approximate the mutual information between x and z.

    I(x, z) = E_x E_{q(z|x)} log q(z|x) - E_x E_{q(z|x)} log q(z)

    where q(z) = (1/N) * sum_j q(z|x_j) is the aggregate posterior over the
    validation batch.

    Modified from the implementation by the authors of the paper "LAGGING
    INFERENCE NETWORKS AND POSTERIOR COLLAPSE IN VARIATIONAL AUTOENCODERS",
    see https://github.com/jxhe/vae-lagging-encoder/blob/master/modules/encoders/encoder.py

    This function calculates the mutual information during training with the
    LIN trick, as a criterion for whether the aggressive training phase should
    stop.

    Parameters:
    -----------
    vae: A vae instance exposing a `.z_dim` attribute and an `.infer(x)` method
        that returns `(mean, std)`, each of shape [N_batch, z_dim].

    x_validation: Validation X data set; its first dimension is N_batch.

    device: Kept for backward compatibility. The sampling noise is created on
        the device of the tensors returned by `vae.infer`, so any device
        (including 'cuda:0' or a torch.device) works.

    S: Number of z samples drawn per data point for the Monte Carlo estimate.
        The pairwise density tensor has shape [S, N_batch, N_batch, z_dim],
        so memory grows quadratically with N_batch.

    Returns: Float

    Raises:
    -------
    ValueError: if x_validation contains no samples.
    '''
    N_batch = x_validation.shape[0]
    if N_batch == 0:
        raise ValueError('x_validation must contain at least one sample')

    # This is a pure evaluation metric; no autograd graph is needed.
    with torch.no_grad():
        # Infer the variational parameters with the encoder.
        # 2D Tensors, shape [N_batch, z_dim]
        mean, std = vae.infer(x_validation)
        assert std.shape == (N_batch, vae.z_dim)
        assert mean.shape == (N_batch, vae.z_dim)

        ## Term 1: negative entropy, E_{q(z|x)} log q(z|x) (closed form)
        # = -0.5 * z_dim * log(2*pi) - 0.5 * (1 + log(std**2)).sum(-1)
        # 1D Tensor, shape [N_batch]
        neg_entropy = (-0.5 * vae.z_dim * math.log(2. * math.pi)
                       - 0.5 * (1 + torch.log(std ** 2)).sum(-1))

        ## Term 2: E_{q(z|x)} log q(z) (Monte Carlo)
        # Reparameterised samples z = mean + std * eps, with eps created
        # directly on the encoder outputs' device.
        noise = torch.randn((S, N_batch, vae.z_dim),
                            dtype=mean.dtype, device=mean.device)
        z_samples = noise * std + mean
        assert z_samples.shape == (S, N_batch, vae.z_dim)

        # Evaluate every sample under every posterior in the batch:
        # [S, N_batch, 1, z_dim] against [N_batch, z_dim]
        #   -> log q(z_{s,i} | x_j), shape [S, N_batch(i), N_batch(j)]
        posterior = torch.distributions.Normal(mean, std)
        log_density = posterior.log_prob(z_samples.unsqueeze(2)).sum(-1)

        # Aggregate posterior: log q(z) = logsumexp_j log q(z|x_j) - log N
        # 2D Tensor, shape [S, N_batch]
        log_qz = torch.logsumexp(log_density, dim=-1) - math.log(N_batch)

        return (neg_entropy.mean() - log_qz.mean()).item()