In [1]:
%matplotlib widget

In [2]:
import os
import pickle
import copy

import pandas as pd
import seaborn as sns
import numpy
import torch
import scipy
import scipy.stats

import pyro
import pyro.infer
import pyro.infer.mcmc
import pyro.distributions as dist
import torch.distributions.constraints as constraints
from tqdm.auto import tqdm

import matplotlib.pyplot as plot
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D
import warnings
warnings.filterwarnings('ignore')
sns.set(style="whitegrid")
import glob
import matplotlib.pyplot as plt

from multi_virt_v2 import *

## Load Params

In [4]:
base_dir='/Users/phumon/Documents/Research/nlu-test-sets/params_mvirt'
file_name=f'alpha-lognormal-identity-dim*_theta-normal-identity_nosubsample_1.00_0.15/params.p'
exp_dir=os.path.join(base_dir, file_name)

print(exp_dir)

/Users/phumon/Documents/Research/nlu-test-sets/params_mvirt/alpha-lognormal-identity-dim*_theta-normal-identity_nosubsample_1.00_0.15/params.p


In [23]:
def irt_likelihood(posterior_stats, dimension = 1, n_items=1000,n_models=18):
    # Generate betas
    betas = pyro.sample("b", dist.Normal(posterior_stats['b mu']*torch.ones(n_items, dimension), torch.exp(posterior_stats['b logstd'])))
    
    log_gamma = pyro.sample("log c", dist.Normal(posterior_stats['g mu']*torch.ones(n_items), torch.exp(posterior_stats['g logstd'])))
    gamma = sigmoid(log_gamma)

    # Generate alphas
    alphas = pyro.sample("a", dist.Normal(posterior_stats['a mu'] * torch.ones(n_items, dimension), torch.exp(posterior_stats['a logstd'])))
                            
    # Generate thetas
    thetas = pyro.sample("theta", dist.Normal(posterior_stats['t mu'] * torch.ones(n_models, dimension), torch.exp(posterior_stats['t logstd'])))
                            
                            

    #alphas = positive_transform(alphas)
    #thetas = positive_transform(thetas)
    if dimension > 1:
        prob = gamma[None, :] + (1.0 - gamma[None, :]) * sigmoid(torch.sum(alphas[None, :, :] * (thetas[:, None] - betas[None, :]).squeeze(), dim=-1))
    else:
        betas=betas.squeeze()
        gamma=gamma.squeeze()
        alphas=alphas.squeeze()
        thetas=thetas.squeeze()
        prob = gamma[None, :] + (1.0 - gamma[None, :]) * sigmoid(alphas[None, :] * (thetas[:, None] - betas[None, :]))

    try:
        lik = dist.Bernoulli(prob).sample()
    except:
        lik = None

    return lik, prob

In [24]:
def get_lik(exp_dir, dim):
    posterior_stats = {}
    
    pyro.clear_param_store()
    pyro.get_param_store().load(exp_dir)
    with torch.no_grad():
        pyro_param_dict = dict(pyro.get_param_store().named_parameters())
    
    for k, v in pyro_param_dict.items():
        posterior_stats[k] = v.mean()
    lik, prob = irt_likelihood(posterior_stats, dim)
    return lik, prob

In [25]:
marginal_logprob_list=[]
correct_list=[]
dim_list=[]
for filepath in glob.iglob(exp_dir):
    dim = int(filepath.split('/')[-2].split('_')[0].rsplit('-',1)[1][3:])
    dim_list.append(dim)
    
    #responses = pd.read_pickle(os.path.join(filepath.rsplit('/', 1)[0], 'responses.p')).to_numpy()
    
    lik, prob = get_lik(filepath, dim)
    marginal_logprob_list.append(torch.log(prob).mean().item())
    
    #correct_list.append((torch.tensor(responses)==lik).float().mean().item())
    
    
    
    

In [26]:
plot_data = list(zip(dim_list, marginal_logprob_list, correct_list))
plot_data.sort(key=lambda x: x[0])
dim, marginals, acc = list(zip(*plot_data))

plot_data

ValueError: not enough values to unpack (expected 3, got 0)

In [29]:
plot_data = list(zip(dim_list, marginal_logprob_list))
plot_data.sort(key=lambda x: x[0])
dim, marginals = list(zip(*plot_data))

plot_data

[(2, -0.49537527561187744),
 (3, -0.46419763565063477),
 (4, -0.4202454686164856),
 (5, -0.3847504258155823),
 (6, -0.3547004163265228),
 (7, -0.3480048179626465),
 (8, -0.33148425817489624),
 (9, -0.3286436200141907),
 (10, -0.3239445090293884),
 (11, -0.32351961731910706),
 (12, -0.3271813690662384),
 (13, -0.32331356406211853),
 (14, -0.3271670341491699),
 (15, -0.320167601108551),
 (16, -0.32583698630332947),
 (17, -0.3237144351005554),
 (18, -0.3243173658847809),
 (36, nan),
 (54, nan),
 (72, nan),
 (90, nan)]

In [30]:

plt.figure()
plt.plot(dim[:-4], marginals[:-4])
plt.xlabel('Dim')
plt.ylabel('log p(x|z)') 

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:

plt.figure()
plt.plot(dim, acc)
plt.xlabel('Dim')
plt.ylabel('reconstruction acc') 

plt.show()