In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import autograd.numpy as np
from autograd import grad
from scipy.optimize import minimize
from scipy.stats import multivariate_normal
import sys, os
sys.path.append('..')
sys.path.append('../..')
from  viabel._distributions import  multivariate_t_logpdf
import autograd.scipy.stats.multivariate_normal as mvn
import autograd.scipy.stats.multivariate_t as mvn


from paragami import (PatternDict,
                      NumericVectorPattern,
                      PSDSymmetricMatrixPattern,
                      FlattenFunctionInput)

from viabel.vb import (mean_field_gaussian_variational_family,
                       mean_field_t_variational_family,
                       t_variational_family,
                       black_box_klvi,
                       black_box_chivi,
                       make_stan_log_density,
                       _get_mu_sigma_pattern,
                       adagrad_optimize
                      )

# Global figure styling used by every plot in this notebook.
sns.set_style('white')
sns.set_context('notebook', font_scale=2, rc={'lines.linewidth': 2})
# Use font type 42 for both the PDF and the PostScript backends.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

In [None]:
def plot_contours(means, covs, colors=None, xlim=[-10,10], ylim=[-3, 3], corr=None):
    """Draw contour lines of several 2-D Gaussian densities on one figure.

    Parameters
    ----------
    means, covs : sequences
        Paired mean vectors and covariance matrices; one contour set is
        drawn per (mean, cov, color) triple.
    colors : sequence, optional
        One color per density. Falls back to ``sns.color_palette()`` when
        omitted or empty.
    xlim, ylim : two-element sequences
        Lower/upper bounds of the 100 x 100 evaluation grid.
    corr : float, optional
        When given, it is shown in the title and the figure is also written
        to the ``kl-vb-corr-<corr>.pdf`` file under the figures directory.
    """
    grid_x, grid_y = np.meshgrid(np.linspace(xlim[0], xlim[1], 100),
                                 np.linspace(ylim[0], ylim[1], 100))
    # Stack the coordinates on a trailing axis so the pdf is evaluated at
    # every grid point in a single call.
    points = np.concatenate([grid_x[:, :, np.newaxis], grid_y[:, :, np.newaxis]], axis=2)
    if not colors:
        colors = sns.color_palette()
    for mean, cov, color in zip(means, covs, colors):
        density = multivariate_normal.pdf(points, mean=mean, cov=cov)
        plt.contour(grid_x, grid_y, density, colors=[color], linestyles='solid')
    if corr is not None:
        plt.title('correlation = {:.2f}'.format(corr))
        plt.savefig('../writing/variational-objectives/figures/kl-vb-corr-{:.2f}.pdf'.format(corr), 
                    bbox_inches='tight')
    plt.show()
    


In [None]:
def plot_approx_and_exact_contours(logdensity, var_family, var_param, colors=None,
                                    xlim=[-2.5,2.5], ylim=[-3, 3],
                                    savepath=None, aux_var=None):
    """Overlay contours of an exact density and of a variational approximation.

    Parameters
    ----------
    logdensity : callable
        Log density of the target; called once with the array of grid points.
    var_family : object
        Variational family; its ``logdensity(points, var_param)`` is plotted.
    var_param : array-like
        Variational parameters passed through to ``var_family.logdensity``.
    colors : sequence, optional
        ``colors[0]`` is used for the exact density and ``colors[1]`` for the
        approximation. Falls back to ``sns.color_palette()`` when omitted.
    xlim, ylim : two-element sequences
        Bounds of the 100 x 100 evaluation grid.
    savepath : str, optional
        When given, the figure is saved there before being shown.
    aux_var : 1-D array, optional
        Extra coordinates appended to every grid point before both log
        densities are evaluated.
    """
    grid_x, grid_y = np.meshgrid(np.linspace(*xlim, 100), np.linspace(*ylim, 100))
    # One row per grid point, columns are (x, y).
    points = np.concatenate([grid_x.ravel()[np.newaxis, :],
                             grid_y.ravel()[np.newaxis, :]]).T
    if aux_var is not None:
        n_points = points.shape[0]
        tiled_aux = np.repeat(aux_var[None, :], n_points, axis=0)
        points = np.concatenate([points, tiled_aux], axis=1)

    exact = np.exp(logdensity(points)).reshape(grid_x.shape)
    approx = np.exp(var_family.logdensity(points, var_param)).reshape(grid_x.shape)

    palette = colors or sns.color_palette()
    plt.contour(grid_x, grid_y, exact, colors=[palette[0]], linestyles='solid')
    plt.contour(grid_x, grid_y, approx, colors=[palette[1]], linestyles='solid')
    if savepath is not None:
        plt.savefig(savepath, bbox_inches='tight')

    plt.show()

In [None]:
# KLVI sweep: for every correlation `rho` and dimension `d`, fit a mean-field
# Student-t family (df=6) to a zero-mean target with unit variances and
# constant off-diagonal correlation `rho`, then record the last value of the
# optimizer's objective history.
rhos = [.5, .88, .94, .99]
ds = np.concatenate([np.arange(2,10,2), np.arange(10,105,20,dtype=int)])
n_iters = 10000  # NOTE(review): not used below; the optimizer call hard-codes 1400.
# Left empty: the inclusive-KL computation for this sweep is disabled.
inc_df = pd.DataFrame(columns=['corr', 'd', 'KL'])

# Rows are collected in a list and turned into a DataFrame once at the end;
# DataFrame.append no longer exists in pandas >= 2.0.
klvi_records = []
for rho in rhos:
    for d in ds:
        # Target covariance: ones on the diagonal, rho everywhere else.
        c2 = rho*np.ones((d,d))
        c2[np.diag_indices_from(c2)] = 1
        m2 = np.zeros(d)

        # Initial variational parameters: target mean followed by a constant 0.2
        # vector (presumably log scales -- confirm against viabel).
        init_log_std = np.ones(d)*0.2
        init_var_param1 = np.concatenate([m2, init_log_std])
        mf_t_var_family = mean_field_t_variational_family(d, df=6)

        # NOTE(review): `mvn` is bound twice in the import cell and the later
        # import wins -- confirm it resolves to the intended log density.
        lnpdf = lambda z: mvn.logpdf(z, m2, c2)

        # 2000 is presumably the Monte Carlo sample count and 1400 the number
        # of AdaGrad iterations -- confirm against viabel.vb.
        klvi_objective_and_grad = black_box_klvi(mf_t_var_family, lnpdf, 2000)
        klvi_var_param, klvi_param_history, value_history, grad_norm_history, oplog = \
            adagrad_optimize(1400, klvi_objective_and_grad, init_var_param1,
                             learning_rate=.02, learning_rate_end=0.001)

        if d == 2:
            # Plot the fitted family of *this* iteration. The previous code read
            # `res.x`, which is only assigned by a later cell and therefore
            # fails on a fresh kernel.
            plt.title('correlation = {:.2f}'.format(rho))
            plot_approx_and_exact_contours(
                lnpdf, mf_t_var_family, klvi_var_param,
                colors=[(0.,0.,0.)]+sns.color_palette(),
                savepath='../writing/variational-objectives/figures/kl-vb-corr-{:.2f}.pdf'.format(rho))

        klvi_records.append(dict(corr=rho, dimension=d, KL=value_history[-1]))

# Column names match the keys of the recorded rows (and the plotting cell,
# which reads 'dimension').
df = pd.DataFrame(klvi_records, columns=['corr', 'dimension', 'KL'])

#### First we plot KL-divergence at the KLVI solution for increasing correlation and increasing dimensions

In [None]:
# KL at the KLVI solution against dimension, one line per correlation level.
ax = sns.lineplot(data=df, x='dimension', y='KL', hue='corr', legend='full')
ax.set_ylabel('KL divergence')
# Relabel the legend with the correlation values and place it above the axes.
ax.legend(rhos, loc='upper center', bbox_to_anchor=(0.5, 1.4),
          ncol=3, frameon=False)
sns.despine()
ax.figure.savefig('../writing/variational-objectives/figures/kl-gaussian_mean_t_kl.pdf', bbox_inches='tight')
plt.show()

In [None]:

# Reduced KLVI sweep: same fit as the sweep above, restricted to a single
# dimension and three correlation levels, with contour plots for d == 2.
rhos = [0.50, .94, .99]
# np.arange(2, 10, 100) evaluates to array([2]): only d = 2 is run here.
ds = np.arange(2, 10, 100)
n_iters = 10000  # NOTE(review): not used below; the optimizer call hard-codes 2000.
# Left empty: the inclusive-KL computation for this sweep is disabled.
inc_df = pd.DataFrame(columns=['corr', 'd', 'KL'])

# Rows are collected in a list and turned into a DataFrame once at the end;
# DataFrame.append no longer exists in pandas >= 2.0.
klvi_records = []
for rho in rhos:
    for d in ds:
        # Target covariance: ones on the diagonal, rho everywhere else.
        c2 = rho*np.ones((d,d))
        c2[np.diag_indices_from(c2)] = 1
        m2 = np.zeros(d)

        init_log_std = np.ones(d)*0.2
        init_var_param1 = np.concatenate([m2, init_log_std])
        mf_t_var_family = mean_field_t_variational_family(d, df=6)

        # NOTE(review): `mvn` is bound twice in the import cell and the later
        # import wins -- confirm it resolves to the intended log density.
        lnpdf = lambda z: mvn.logpdf(z, m2, c2)

        klvi_objective_and_grad = black_box_klvi(mf_t_var_family, lnpdf, 2000)
        klvi_var_param, klvi_param_history, value_history, grad_norm_history, oplog = \
            adagrad_optimize(2000, klvi_objective_and_grad, init_var_param1,
                             learning_rate=.02, learning_rate_end=0.001)

        if d == 2:
            plot_approx_and_exact_contours(lnpdf, mf_t_var_family, klvi_var_param,
                                           colors=[(0.,0.,0.)]+sns.color_palette())

        klvi_records.append(dict(corr=rho, dimension=d, KL=value_history[-1]))

# Column names match the keys of the recorded rows.
df = pd.DataFrame(klvi_records, columns=['corr', 'dimension', 'KL'])

#### Now we plot inclusive KL-divergence at the KLVI solution for increasing correlation and increasing dimensions

In [None]:
# Inclusive KL at the KLVI solution against dimension.
# `inc_df` is only filled when the inclusive-KL computation in the sweep cell
# is enabled; otherwise it is empty and has no 'dimension' column, which makes
# sns.lineplot raise. Skip the figure in that case instead of failing.
if 'dimension' in inc_df.columns and not inc_df.empty:
    sns.lineplot(data=inc_df, x='dimension', y='KL', hue='corr', legend='full')
    plt.ylabel('Inclusive KL divergence')
    plt.legend(rhos, loc='upper center', bbox_to_anchor=(0.5, 1.4),
               ncol=3, frameon=False)
    sns.despine()
    plt.savefig('../writing/variational-objectives/figures/inckl-vb-d.pdf', bbox_inches='tight')
    plt.show()
else:
    print('inc_df holds no inclusive-KL results; skipping the inckl-vb-d figure.')

In [None]:
### Next, we repeat the analysis at the inclusive-KLVI solution

In [None]:
# Inclusive-KL sweep: for every correlation `rho` and dimension `d`, minimize
# gaussianKL(m2, c2, m2, c1) over the log-diagonal of a diagonal covariance c1
# with BFGS, then also evaluate gaussianKL with the two distributions swapped
# at that solution.
# NOTE(review): `gaussianKL` is neither defined nor imported in the visible
# notebook; it must be provided before this cell runs. Its argument order is
# presumably (mean_1, cov_1, mean_2, cov_2) -- confirm.

# Rows are collected in lists and turned into DataFrames once at the end;
# DataFrame.append no longer exists in pandas >= 2.0.
inc_inc_records = []
kl_inc_records = []
for rho in rhos:
    for d in ds:
        # Target covariance: ones on the diagonal, rho everywhere else.
        c2 = rho*np.ones((d,d))
        c2[np.diag_indices_from(c2)] = 1
        m2 = np.zeros(d)

        def objective(logc1):
            # Optimize in log space so the diagonal entries stay positive.
            c1 = np.diag(np.exp(logc1))
            return gaussianKL(m2, c2, m2, c1)

        res = minimize(objective, np.ones(d)*0.4, method='BFGS', jac=grad(objective))
        fitted_cov = np.diag(np.exp(res.x))
        if d == 2:
            plot_contours(means=[m2]*2, covs=[c2, fitted_cov],
                          colors=[(0.,0.,0.)]+sns.color_palette(),
                          xlim=[-2.5,2.5], corr=rho)

        inc_inc_records.append(dict(corr=rho, d=d, KL=res.fun))
        kl = gaussianKL(m2, fitted_cov, m2, c2)
        kl_inc_records.append(dict(corr=rho, d=d, KL=kl))

inc_inc_df = pd.DataFrame(inc_inc_records, columns=['corr', 'd', 'KL'])
kl_inc_df = pd.DataFrame(kl_inc_records, columns=['corr', 'd', 'KL'])

#### Now we plot inclusive KL-divergence at the inclusive KLVI solution for increasing correlation and increasing dimensions

In [None]:
# Inclusive KL at the inclusive-KL solution against dimension, one line per
# correlation level.
ax = sns.lineplot(data=inc_inc_df, x='d', y='KL', hue='corr', legend='full')
ax.set_ylabel('Inclusive KL divergence')
# Relabel the legend with the correlation values and place it above the axes.
ax.legend(rhos, loc='upper center', bbox_to_anchor=(0.5, 1.4),
          ncol=3, frameon=False)
sns.despine()
ax.figure.savefig('../writing/variational-objectives/figures/inc_kl_soln_inckl-vb-d.pdf', bbox_inches='tight')
plt.show()

#### Now we plot KL-divergence at the inclusive KLVI solution for increasing correlation and increasing dimensions

In [None]:
# KL at the inclusive-KL solution against dimension, one line per correlation
# level.
ax = sns.lineplot(data=kl_inc_df, x='d', y='KL', hue='corr', legend='full')
ax.set_ylabel('KL divergence')
# Relabel the legend with the correlation values and place it above the axes.
ax.legend(rhos, loc='upper center', bbox_to_anchor=(0.5, 1.4),
          ncol=3, frameon=False)
sns.despine()
ax.figure.savefig('../writing/variational-objectives/figures/kl_at_inckl-vb-d.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Show the table of minimized objective values (one row per corr/d pair)
# collected by the BFGS sweep.
print(inc_inc_df)

In [None]:
# Show the table of gaussianKL values evaluated at the BFGS solutions
# (one row per corr/d pair).
print(kl_inc_df)

In [None]:
# Print the dimension and KL value of one row of kl_inc_df.
# Row label 125 only exists for a much larger rho/d grid than the one defined
# above; fall back to the last row so the cell does not raise KeyError.
# NOTE(review): confirm which row was meant to be inspected.
row_label = 125 if 125 in kl_inc_df.index else kl_inc_df.index[-1]
print(kl_inc_df.d[row_label], kl_inc_df.KL[row_label])