In [1]:
%matplotlib widget

In [2]:
import os
import pickle
import copy

import pandas as pd
import seaborn as sns
import numpy
import torch
import scipy
import scipy.stats
import numpy as np

import pyro
import pyro.infer
import pyro.infer.mcmc
import pyro.distributions as dist
import torch.distributions.constraints as constraints
from tqdm.auto import tqdm

import matplotlib.pyplot as plot
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D
import warnings
warnings.filterwarnings('ignore')
sns.set(style="whitegrid")
import glob
import matplotlib.pyplot as plt

from torch.distributions.normal import Normal
from torch.distributions.log_normal import LogNormal

from multi_virt_v2 import *

## Load Params

In [11]:

base_dir='/Users/phumon/Documents/Research/nlu-test-sets/params_mvirt_syncfact0_dim1_2pl'
responses_dir='/Users/phumon/Documents/Research/nlu-test-sets/data_synthetic_0pl/sync_factors0_dim1_mean0_alpha-lognormal-0.00_theta-normal-0.00_irt_all_coded.csv'
file_name=f'lr-0.0001-steps7000_numfactors2_alpha-lognormal-identity-dim*_theta-normal-identity_nosubsample_1.00_0.40_particles32/params.p'
exp_dir=os.path.join(base_dir, file_name)


print(exp_dir)

/Users/phumon/Documents/Research/nlu-test-sets/params_mvirt_syncfact0_dim1_2pl/lr-0.0001-steps7000_numfactors2_alpha-lognormal-identity-dim*_theta-normal-identity_nosubsample_1.00_0.40_particles32/params.p


In [31]:
responses = torch.tensor(pd.read_csv(responses_dir, index_col=0).to_numpy()).float()
responses

tensor([[0., 0., 1.,  ..., 1., 1., 1.],
        [0., 1., 1.,  ..., 0., 1., 1.],
        [1., 1., 0.,  ..., 0., 1., 1.],
        ...,
        [0., 1., 1.,  ..., 1., 0., 1.],
        [1., 0., 0.,  ..., 0., 0., 1.],
        [1., 0., 1.,  ..., 1., 1., 1.]])

In [69]:
def get_posterior(exp_dir, dimension):
    posterior_stats = {}
    
    pyro.clear_param_store()
    pyro.get_param_store().load(exp_dir)
    with torch.no_grad():
        pyro_param_dict = dict(pyro.get_param_store().named_parameters())
    print("pyro dict: ", pyro_param_dict.keys())
    for k, v in pyro_param_dict.items():
        if k == "b mu":
            n_items = v.size(0)
        elif k == "t mu":
            n_models = v.size(0)
        posterior_stats[k] = v.mean().item()
    
    betas = Normal(posterior_stats['b mu']*torch.ones(n_items, dimension), torch.exp(torch.tensor(posterior_stats['b logstd'])))
    thetas = Normal(posterior_stats['t mu'] * torch.ones(n_models, dimension), torch.exp(torch.tensor(posterior_stats['t logstd'])))
    
    if "g mu" in posterior_stats.keys():
        log_gamma = Normal(posterior_stats['g mu']*torch.ones(n_items), torch.exp(torch.tensor(posterior_stats['g logstd'])))
    else:
        log_gamma = None
        
    if "a mu" in posterior_stats.keys():
        alphas = LogNormal(posterior_stats['a mu'] * torch.ones(n_items, dimension), torch.exp(torch.tensor(posterior_stats['a logstd'])))
    else:
        alphas = None
    probs = {
                "beta": betas,
                "log_gamma": log_gamma,
                "alpha": alphas,
                "theta": thetas
    }
    
    return posterior_stats, probs, n_items, n_models

In [70]:
def get_prior(item_param_std, alpha_std, dimension, n_items, n_models,num_factors=3):
    # Generate params
    betas = Normal(torch.zeros(n_items, dimension), torch.tensor(item_param_std))
    # Generate thetas
    thetas = Normal(0 * torch.ones(n_models, dimension), torch.tensor(item_param_std))
    
    if num_factors > 2:
        log_gamma = Normal(torch.zeros(n_items), torch.tensor(item_param_std))
    else:
        log_gamma = None

    if num_factors > 1:
        alphas = LogNormal(0 * torch.ones(n_items, dimension), torch.tensor(alpha_std))
    else:
        alphas = None
    
    probs = {
                "beta": betas,
                "log_gamma": log_gamma,
                "alpha": alphas,
                "theta": thetas
    }
    
    return probs

In [71]:
def get_marginal_loglik(prior_probs, posterior_probs, dimension, observations, n_particles=100, num_factors=3):
    log_weights = []
    for i in range(n_particles):
        #log p(x) = logsumexp(log p(x|z_k) + log p(z_k) - log q(z_k), dim=DIM) - log(K)
        log_weight = 0
        posterior_values = {}
        for k, dist_param in prior_probs.items():
            if dist_param is None: continue
            posterior_tmp = posterior_probs[k].sample()
            posterior_values[k] = posterior_tmp
            log_weight += (prior_probs[k].log_prob(posterior_tmp) - posterior_probs[k].log_prob(posterior_tmp)).sum()
        
        thetas = posterior_values['theta']
        betas = posterior_values['beta']
        n_items = betas.size(0)
        
        if "log_gamma" in posterior_values.keys():
            gamma = sigmoid(posterior_values['log_gamma'])
        else:
            gamma = torch.ones(n_items) * 0.5
            
        if "alpha" in posterior_values:
            alphas = posterior_values['alpha']
        else:
            alphas = torch.ones(n_items, dimension)
        
        if dimension > 1:
            prob = gamma[None, :] + (1.0 - gamma[None, :]) * sigmoid(torch.sum(alphas[None, :, :] * (thetas[:, None] - betas[None, :]).squeeze(), dim=-1))
        else:
            betas=betas.squeeze()
            gamma=gamma.squeeze()
            alphas=alphas.squeeze()
            thetas=thetas.squeeze()
            prob = gamma[None, :] + (1.0 - gamma[None, :]) * sigmoid(alphas[None, :] * (thetas[:, None] - betas[None, :]))
        lik_dist = torch.distributions.bernoulli.Bernoulli(prob)
        
        log_weight += lik_dist.log_prob(observations).sum()
        log_weights.append(log_weight.item())
    print(log_weights)
    marginal = torch.logsumexp(torch.tensor(log_weights), 0) - torch.log(torch.tensor(n_particles))
    return marginal
    #return lik, prob

In [90]:
lik_dist = torch.distributions.bernoulli.Bernoulli(0.6)
lik_dist.log_prob(torch.tensor(1.0))

tensor(-0.5108)

In [83]:
posterior_probs

{'beta': Normal(loc: torch.Size([1000, 8]), scale: torch.Size([1000, 8])),
 'log_gamma': None,
 'alpha': LogNormal(),
 'theta': Normal(loc: torch.Size([18, 8]), scale: torch.Size([18, 8]))}

In [72]:
marginal_logprob_list=[]
correct_list=[]
dim_list=[]
item_param_std=1; alpha_std=0.4
num_factors=2

for filepath in glob.iglob(exp_dir):
    dimension = int(filepath.split('/')[-2].split('dim')[1].split('_')[0])
    dim_list.append(dimension)
    
    posterior_stats, posterior_probs, n_items, n_models = get_posterior(filepath, dimension)
    prior_probs = get_prior(item_param_std, alpha_std, dimension, n_items, n_models, num_factors)
    marginal = get_marginal_loglik(prior_probs, posterior_probs, dimension, responses)
    
    marginal_logprob_list.append(marginal.item())
    
    
    
    

pyro dict:  dict_keys(['b mu', 'b logstd', 'a mu', 'a logstd', 't mu', 't logstd'])
[-6027.7939453125, -5340.25146484375, -6324.2666015625, -6301.92578125, -6912.79052734375, -6228.2958984375, -5349.55078125, -6260.06396484375, -6212.23095703125, -5178.27392578125, -5889.50927734375, -6221.3271484375, -6678.814453125, -6923.94384765625, -5804.2734375, -5841.3076171875, -6220.97412109375, -6022.92236328125, -6754.3681640625, -6279.78564453125, -6474.1796875, -5181.23486328125, -5988.244140625, -7061.61962890625, -6286.7685546875, -7146.13037109375, -6700.92236328125, -6054.826171875, -6400.5322265625, -5113.79541015625, -7139.33251953125, -5785.388671875, -5228.27294921875, -6030.90771484375, -6572.8935546875, -5099.18115234375, -6517.0478515625, -7048.8466796875, -5792.2041015625, -6147.06494140625, -6175.896484375, -6277.08203125, -5735.07373046875, -7044.33349609375, -6173.88232421875, -5656.58447265625, -5792.52734375, -5701.27490234375, -5941.03466796875, -5920.94287109375, -7176.2

[-7527.92138671875, -6207.50390625, -7870.15966796875, -8574.470703125, -7071.02490234375, -8215.1611328125, -5961.79736328125, -6496.45654296875, -6547.27490234375, -9013.7607421875, -8351.0556640625, -6847.86279296875, -7852.0, -7353.28125, -6366.35302734375, -7790.9775390625, -8476.34765625, -6062.3193359375, -6623.41796875, -6558.126953125, -7660.78759765625, -7401.86279296875, -6909.8955078125, -7715.64013671875, -7154.76318359375, -7804.755859375, -6518.53466796875, -7504.58544921875, -7408.50390625, -6794.13720703125, -7138.4951171875, -7969.11865234375, -7507.8388671875, -6138.05078125, -8284.634765625, -7922.25244140625, -6836.30859375, -5845.91357421875, -7179.7001953125, -8525.4423828125, -7142.93994140625, -8376.9072265625, -6809.14697265625, -8363.04296875, -7386.43798828125, -7653.0029296875, -7113.291015625, -8021.8505859375, -6466.66015625, -5955.80029296875, -5915.07958984375, -6989.40380859375, -7109.2265625, -7444.79150390625, -6712.07421875, -6264.71923828125, -6367

In [73]:
plot_data = list(zip(dim_list, marginal_logprob_list))
plot_data.sort(key=lambda x: x[0])
dim, marginals = list(zip(*plot_data))

plot_data

[(1, -5103.7861328125),
 (2, -5379.85400390625),
 (3, -5527.5537109375),
 (4, -5315.20947265625),
 (5, -5198.419921875),
 (6, -6356.189453125),
 (7, -5612.05126953125),
 (8, -6328.9091796875)]

In [74]:
plot_data = list(zip(dim_list, marginal_logprob_list))
plot_data.sort(key=lambda x: x[0])
dim, marginals = list(zip(*plot_data))

plot_data

[(1, -5103.7861328125),
 (2, -5379.85400390625),
 (3, -5527.5537109375),
 (4, -5315.20947265625),
 (5, -5198.419921875),
 (6, -6356.189453125),
 (7, -5612.05126953125),
 (8, -6328.9091796875)]

In [75]:

plt.figure()
plt.plot(dim, marginals)
plt.xlabel('Dim')
plt.ylabel('log p(x)') 

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Reconstruction

### Load posterior 

In [16]:

repo="/Users/phumon/Documents/Research/nlu-test-sets"
exp_dir = os.path.join(repo, 'params_mvirt_syncfact0_dim1_2pl', f'lr-0.0001-steps7000_numfactors2_alpha-lognormal-identity-dim1_theta-normal-identity_nosubsample_1.00_0.40_particles32')
p = 0.95

#combined_responses = pd.read_pickle(os.path.join(exp_dir, 'responses.p')).reset_index()

In [17]:
param_plot_stats = {}
keys = ['a', 'b', 'g', 't']
for key in keys:
    try:
        with open(os.path.join(exp_dir, 'plot_stats_pickles', f'{key}.p'), 'rb') as f:
            param_plot_stats[key] = pickle.load(f)
    except:
        continue

### Load prior

In [30]:
base_dir='/Users/phumon/Documents/Research/nlu-test-sets/data_synthetic_0pl'
path_prior_1=f'params_sync_factors0_dim1_mean0_alpha-lognormal-0.00_theta-normal-0.00_irt_all_coded.p'
exp_prior_1=os.path.join(base_dir, path_prior_1)

print(exp_prior_1)

with open(exp_prior_1, 'rb') as f:
    prior_dist = pickle.load(f)


/Users/phumon/Documents/Research/nlu-test-sets/data_synthetic_0pl/params_sync_factors0_dim1_mean0_alpha-lognormal-0.00_theta-normal-0.00_irt_all_coded.p


In [31]:
np.vstack(prior_dist['b']).squeeze()

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1.

In [32]:

mse_b = np.square(np.vstack(param_plot_stats['b']['mean']).squeeze() - np.vstack(prior_dist['b']).squeeze()).mean()
mse_b

0.6086030019368591

In [33]:

mse_t = np.square(np.vstack(param_plot_stats['t']['mean']).squeeze() - np.vstack(prior_dist['t']).squeeze()).mean()
mse_t

0.2898738094206891

In [34]:

mse_a = np.square(np.vstack(param_plot_stats['a']['mean']).squeeze() - np.vstack(prior_dist['a']).squeeze()).mean()
mse_a

0.008987708520150168

In [35]:
param_plot_stats['t']['mean']

0     [-0.032381560653448105]
1      [-0.03361646458506584]
2     [-0.027492281049489975]
3      [-0.03303256258368492]
4     [-0.045046307146549225]
5      [-0.04988527297973633]
6      [-0.03277480602264404]
7      [-0.03773510456085205]
8      [-0.04067172855138779]
9      [-0.05370273441076279]
10     [-0.03107420913875103]
11     [-0.03520526364445686]
12     [-0.03344722092151642]
13    [-0.036950379610061646]
14     [-0.04336337000131607]
15      [-0.0324644073843956]
16    [-0.045670296996831894]
17     [-0.04582581669092178]
Name: mean, dtype: object

In [41]:
betas = torch.tensor(np.vstack(param_plot_stats['b']['mean']).squeeze())
thetas =  torch.tensor(np.vstack(param_plot_stats['t']['mean']).squeeze())
alphas =  torch.tensor(np.vstack(param_plot_stats['a']['mean']).squeeze())
bminust_posterior = alphas * (thetas[:, None] - betas[None, :])
print(betas.size())
print(thetas.size())
bminust_posterior.size()

torch.Size([1000])
torch.Size([18])


torch.Size([18, 1000])

In [49]:
param_plot_stats['b']['var'].mean()

array([0.54663284])

In [38]:
bminust_posterior.size()

torch.Size([18, 1000])

In [39]:
thetas.size()

torch.Size([18])

In [42]:
dim=1
betas = torch.ones(betas.size(0),)
thetas =  torch.ones(thetas.size(0)) * 0.5
bminust_prior = thetas[:, None] - betas[None, :]
bminust_prior

tensor([[-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
        ...,
        [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000]])

In [46]:
posterior_dist

NameError: name 'posterior_dist' is not defined

In [None]:
#dim>1

betas_prior = torch.tensor(np.vstack(prior_dist['b']).squeeze())
thetas_prior =  torch.tensor(np.vstack(prior_dist['t']).squeeze())
bminust_prior = thetas_prior[:, None] - betas_prior[None, :]
print(betas_prior.size())
print(thetas_prior.size())

bminust_prior.size()

In [None]:
#dim>1
betas_prior = torch.ones(betas.size(0),betas.size(1))
thetas_prior =  torch.ones(thetas.size(0), thetas.size(1)) * 0.5
bminust_prior = thetas_prior[:, None] - betas_prior[None, :]
print(betas_prior.size())
print(thetas_prior.size())

bminust_prior

In [None]:
bminust_posterior

In [44]:

mse_b = torch.square(bminust_prior - bminust_posterior).squeeze().mean()
mse_b

tensor(0.1354, dtype=torch.float64)