In [1]:
%matplotlib widget

In [2]:
import os
import pickle
import copy

import pandas as pd
import seaborn as sns
import numpy
import torch
import scipy
import scipy.stats

import pyro
import pyro.infer
import pyro.infer.mcmc
import pyro.distributions as dist
import torch.distributions.constraints as constraints
from tqdm.auto import tqdm

import matplotlib.pyplot as plot
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D
import warnings
warnings.filterwarnings('ignore')
sns.set(style="whitegrid")
import glob
import matplotlib.pyplot as plt

from torch.distributions.normal import Normal
from torch.distributions.log_normal import LogNormal

from multi_virt_v2 import *

## Load Params

In [3]:

base_dir='/Users/phumon/Documents/Research/nlu-test-sets/params_mvirt_sync'
file_name=f'lr-0.0001-steps-$4000-alpha-lognormal-identity-dim*_theta-normal-identity_nosubsample_2.00_0.40_particles8/params.p'
exp_dir=os.path.join(base_dir, file_name)


print(exp_dir)

/Users/phumon/Documents/Research/nlu-test-sets/params_mvirt_sync/lr-0.0001-steps-$4000-alpha-lognormal-identity-dim*_theta-normal-identity_nosubsample_2.00_0.40_particles8/params.p


In [4]:
def get_posterior(exp_dir, dimension):
    posterior_stats = {}
    
    pyro.clear_param_store()
    pyro.get_param_store().load(exp_dir)
    with torch.no_grad():
        pyro_param_dict = dict(pyro.get_param_store().named_parameters())
    
    for k, v in pyro_param_dict.items():
        if k == "g mu":
            n_items = v.size(0)
        elif k == "t mu":
            n_models = v.size(0)
        posterior_stats[k] = v.mean().item()
    
    betas = Normal(posterior_stats['b mu']*torch.ones(n_items, dimension), torch.exp(torch.tensor(posterior_stats['b logstd'])))
    log_gamma = Normal(posterior_stats['g mu']*torch.ones(n_items), torch.exp(torch.tensor(posterior_stats['g logstd'])))
    alphas = LogNormal(posterior_stats['a mu'] * torch.ones(n_items, dimension), torch.exp(torch.tensor(posterior_stats['a logstd'])))
    thetas = Normal(posterior_stats['t mu'] * torch.ones(n_models, dimension), torch.exp(torch.tensor(posterior_stats['t logstd'])))
    probs = {
                "beta": betas,
                "log_gamma": log_gamma,
                "alpha": alphas,
                "theta": thetas
    }
    
    return posterior_stats, probs, n_items, n_models

In [5]:
def get_prior(item_param_std, alpha_std, dimension, n_items, n_models):
    # Generate params
    betas = Normal(torch.zeros(n_items, dimension), torch.tensor(item_param_std))
    log_gamma = Normal(torch.zeros(n_items), torch.tensor(item_param_std))

    alphas = LogNormal(0 * torch.ones(n_items, dimension), torch.tensor(alpha_std))

    # Generate thetas
    thetas = Normal(0 * torch.ones(n_models, dimension), torch.tensor(item_param_std))
    probs = {
                "beta": betas,
                "log_gamma": log_gamma,
                "alpha": alphas,
                "theta": thetas
    }
    
    return probs

In [8]:
def get_marginal_loglik(prior_probs, posterior_probs, dimension, n_particles=10):
    log_weights = []
    for i in range(n_particles):
        #log p(x) = logsumexp(log p(x|z_k) + log p(z_k) - log q(z_k), dim=DIM) - log(K)
        log_weight = 0
        prior_values = {}
        for k, dist_param in prior_probs.items():
            prior_tmp = prior_probs[k].sample()
            prior_values[k] = prior_tmp
            posterior_tmp = posterior_probs[k].sample()
            log_weight += (prior_probs[k].log_prob(prior_tmp) - posterior_probs[k].log_prob(posterior_tmp)).sum()
        
        gamma = sigmoid(prior_values['log_gamma'])
        alphas = prior_values['alpha']
        thetas = prior_values['theta']
        betas = prior_values['beta']
        if dimension > 1:
            prob = gamma[None, :] + (1.0 - gamma[None, :]) * sigmoid(torch.sum(alphas[None, :, :] * (thetas[:, None] - betas[None, :]).squeeze(), dim=-1))
        else:
            betas=betas.squeeze()
            gamma=gamma.squeeze()
            alphas=alphas.squeeze()
            thetas=thetas.squeeze()
            prob = gamma[None, :] + (1.0 - gamma[None, :]) * sigmoid(alphas[None, :] * (thetas[:, None] - betas[None, :]))
        log_weight += torch.log(prob).sum()
        log_weights.append(log_weight)
    marginal = torch.logsumexp(log_weight, 0) - torch.log(torch.tensor(n_particles))
    return marginal
    #return lik, prob

In [9]:
marginal_logprob_list=[]
correct_list=[]
dim_list=[]
item_param_std=2; alpha_std=0.4

for filepath in glob.iglob(exp_dir):
    dimension = int(filepath.split('/')[-2].split('_')[0].rsplit('-',1)[1][3:])
    dim_list.append(dimension)
    
    posterior_stats, posterior_probs, n_items, n_models = get_posterior(filepath, dimension)
    prior_probs = get_prior(item_param_std, alpha_std, dimension, n_items, n_models)
    marginal = get_marginal_loglik(prior_probs, posterior_probs, dimension)
    
    marginal_logprob_list.append(marginal.item())
    
    
    
    

In [10]:
plot_data = list(zip(dim_list, marginal_logprob_list))
plot_data.sort(key=lambda x: x[0])
dim, marginals = list(zip(*plot_data))

plot_data

[(1, -17705.234375),
 (2, -16971.255859375),
 (3, -18588.919921875),
 (4, -23208.138671875),
 (5, -27083.51953125),
 (6, -27869.66796875),
 (7, -23846.10546875),
 (8, -25149.951171875)]

In [11]:
plot_data = list(zip(dim_list, marginal_logprob_list))
plot_data.sort(key=lambda x: x[0])
dim, marginals = list(zip(*plot_data))

plot_data

[(1, -17705.234375),
 (2, -16971.255859375),
 (3, -18588.919921875),
 (4, -23208.138671875),
 (5, -27083.51953125),
 (6, -27869.66796875),
 (7, -23846.10546875),
 (8, -25149.951171875)]

In [12]:

plt.figure()
plt.plot(dim[:-4], marginals[:-4])
plt.xlabel('Dim')
plt.ylabel('log p(x)') 

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:

plt.figure()
plt.plot(dim, acc)
plt.xlabel('Dim')
plt.ylabel('reconstruction acc') 

plt.show()