# Evaluating a Generative Discrete Latent Variable Model

In [1]:
import torch
import pyro
import copy

from pyro.ops.indexing import Vindex
import pyro.distributions as dist
from pyro.distributions import transforms, constraints
from pyro.infer import config_enumerate
import pyro.poutine as poutine
from pyro.infer.autoguide import initialization as mcmc_inits
from pyro.infer import infer_discrete, MCMC, NUTS, HMC, Predictive
from pyro.distributions.util import broadcast_shape
from scipy.stats import pearsonr, spearmanr

import warnings
warnings.filterwarnings('ignore')

## Models

### G Model

In [2]:
def g_model(data=None):
    """Latent-class mixture model with a single observed site ``G``.

    Sample sites, in order: ``lam_G_c`` (Normal base passed through
    ``OrderedTransform``), ``std_G_c`` (Gamma), ``pi_Z_c`` (Dirichlet), then
    inside ``data_plate`` the class assignment ``Z`` (marked for parallel
    enumeration) and the masked Normal site ``G``.

    Args:
        data: nested mapping read as ``data["mask"]["Q"]`` (its first
            dimension sets the plate size), ``data["mask"]["G"]`` (mask applied
            to the ``G`` likelihood) and ``data["data"]["G"]`` (observations,
            or ``None`` to sample). The default of ``None`` is not usable: the
            body indexes ``data`` unconditionally.

    Returns:
        The tuple ``(Z, G)`` of site values.
    """
    num_classes = 4  # number of latent classes

    # ---- priors ----
    loc_base = dist.Normal(torch.ones(num_classes), torch.ones(num_classes))
    lam_G_c = pyro.sample(
        "lam_G_c",
        dist.TransformedDistribution(loc_base, [transforms.OrderedTransform()]),
    )

    scale_prior = dist.Gamma(torch.ones(num_classes), torch.ones(num_classes)).to_event(1)
    std_G_c = pyro.sample("std_G_c", scale_prior)

    weight_prior = dist.Dirichlet(torch.ones(num_classes) / num_classes)
    pi_Z_c = pyro.sample("pi_Z_c", weight_prior)

    # ---- likelihood ----
    masks = data["mask"]
    with pyro.plate('data_plate', masks["Q"].shape[0]):
        Z = pyro.sample('Z', dist.Categorical(pi_Z_c), infer={"enumerate": "parallel"})

        # Select each data point's class-specific location and scale.
        class_idx = Z.long()
        g_lik = dist.Normal(
            Vindex(lam_G_c)[..., class_idx],
            Vindex(std_G_c)[..., class_idx],
        ).mask(masks["G"])
        G = pyro.sample('G', g_lik, obs=data["data"]["G"])

    return Z, G

### GQ Model

In [3]:
def gq_model(data=None):
    """Latent-class mixture model with two observed sites, ``G`` and ``Q``.

    Sample sites, in order: ``lam_G_c`` (Normal base passed through
    ``OrderedTransform``), ``std_G_c`` (Gamma), ``lam_Q_c`` (Gamma base passed
    through ``OrderedTransform``), ``pi_Z_c`` (Dirichlet), then inside
    ``data_plate`` the class assignment ``Z`` (marked for parallel
    enumeration), the masked Normal site ``G`` and the masked Poisson site
    ``Q``.

    Args:
        data: nested mapping read as ``data["mask"][site]`` (likelihood masks;
            the first dimension of the ``Q`` mask sets the plate size) and
            ``data["data"][site]`` (observations, or ``None`` to sample) for
            ``site`` in ``{"G", "Q"}``. The default of ``None`` is not usable:
            the body indexes ``data`` unconditionally.

    Returns:
        The tuple ``(Z, G, Q)`` of site values.
    """
    num_classes = 4  # number of latent classes

    # ---- priors ----
    g_loc_base = dist.Normal(torch.ones(num_classes), torch.ones(num_classes))
    lam_G_c = pyro.sample(
        "lam_G_c",
        dist.TransformedDistribution(g_loc_base, [transforms.OrderedTransform()]),
    )

    g_scale_prior = dist.Gamma(torch.ones(num_classes), torch.ones(num_classes)).to_event(1)
    std_G_c = pyro.sample("std_G_c", g_scale_prior)

    q_rate_base = dist.Gamma(torch.ones(num_classes), torch.ones(num_classes))
    lam_Q_c = pyro.sample(
        "lam_Q_c",
        dist.TransformedDistribution(q_rate_base, [transforms.OrderedTransform()]),
    )

    weight_prior = dist.Dirichlet(torch.ones(num_classes) / num_classes)
    pi_Z_c = pyro.sample("pi_Z_c", weight_prior)

    # ---- likelihood ----
    masks = data["mask"]
    with pyro.plate('data_plate', masks["Q"].shape[0]):
        Z = pyro.sample('Z', dist.Categorical(pi_Z_c), infer={"enumerate": "parallel"})

        # Select each data point's class-specific parameters.
        class_idx = Z.long()

        g_lik = dist.Normal(
            Vindex(lam_G_c)[..., class_idx],
            Vindex(std_G_c)[..., class_idx],
        ).mask(masks["G"])
        G = pyro.sample('G', g_lik, obs=data["data"]["G"])

        q_lik = dist.Poisson(Vindex(lam_Q_c)[..., class_idx]).mask(masks["Q"])
        Q = pyro.sample('Q', q_lik, obs=data["data"]["Q"])

    return Z, G, Q

## Generate Data

In [4]:
def generate_data(model, params, n_data = 2000):
    """Simulate a data set from ``model`` with its parameters fixed to ``params``.

    The model is conditioned on ``params`` and traced once with every
    observation set to ``None`` and every mask set to all-True, so the ``G``
    and ``Q`` sites are sampled rather than observed.

    Args:
        model: Pyro model callable taking a single ``{"data": ..., "mask": ...}``
            mapping (the layout this function builds).
        params: mapping from sample-site name to the value that site is
            conditioned on.
        n_data: number of data points; used as the length of the masks.

    Returns:
        ``{"data": {...}, "mask": {...}}`` where ``data`` holds the traced
        values of whichever of the sites ``"G"`` and ``"Q"`` the model has, and
        ``mask`` holds an all-True boolean mask of length ``n_data`` for both
        ``"G"`` and ``"Q"``.
    """
    data_mask = {"G": torch.ones(n_data).bool(),
                 "Q": torch.ones(n_data).bool()}
    # Passing obs=None makes the model sample G and Q instead of scoring them.
    none_data = {"data": {"G": None, "Q": None}, "mask": data_mask}

    # Use the caller's model and parameters; module-level globals are not read.
    conditioned_model = poutine.condition(model, params)
    trace = poutine.trace(conditioned_model).get_trace(none_data)

    # Keep only the simulated observation sites (a model may lack "Q").
    data = {name: trace.nodes[name]["value"]
            for name in trace.nodes.keys()
            if name in ("G", "Q")}

    return {"data": data, "mask": data_mask}


# Ground-truth parameter values that the simulated data are generated from;
# the keys match the sample-site names used in gq_model.
true_params = {"lam_Q_c": torch.tensor([[3.5, 14.0, 22., 28.0]]), "lam_G_c": torch.tensor([[1.0, 5.5, 12.0, 15.0]]), "std_G_c": torch.tensor([[1.0, 0.5, 1.0, 2.0]]), "pi_Z_c": torch.tensor([[0.25, 0.45, 0.15, 0.15]])}
# Two separate draws of 2000 points each: one to fit on, one held out for evaluation.
train_data = generate_data(gq_model, true_params, n_data = 2000)
test_data = generate_data(gq_model, true_params, n_data = 2000)

# Rank and linear correlation between the two simulated sites in the training set.
print("G - Q spearmanr", spearmanr(train_data["data"]["G"].numpy(), train_data["data"]["Q"].numpy())[0])
print("G - Q pearsonr", pearsonr(train_data["data"]["G"].numpy(), train_data["data"]["Q"].numpy())[0])

G - Q spearmanr 0.8441171014786846
G - Q pearsonr 0.88024442581624


## MCMC

In [5]:
# Draw 2000 samples from gq_model with Predictive (no guide or posterior samples
# are supplied), then average every returned site except Z, G and Q over the
# sample dimension. The averages are passed to init_to_value when the NUTS
# kernels are built.
p = Predictive(gq_model, num_samples = 2000)
init_params = p(train_data)
init_params = {k: torch.mean(v, dim = 0) for k,v in init_params.items() if k not in ["Z", "G", "Q"]}
# Bare expression: displays the dict when run as a notebook cell.
init_params

{'lam_G_c': tensor([[ 0.9582,  5.6122,  9.9920, 14.3724]]),
 'std_G_c': tensor([[1.0299, 1.0035, 0.9714, 0.9720]]),
 'lam_Q_c': tensor([[ 1.0370, 49.8054, 57.0079, 65.8973]]),
 'pi_Z_c': tensor([[0.2490, 0.2442, 0.2578, 0.2490]])}

In [6]:
# Fit gq_model to the training data with NUTS, initialised at init_params.
# NOTE(review): 50 warm-up steps + 50 kept samples on one chain is a very short
# run — presumably chosen for notebook speed; confirm before relying on results.
nuts_kernel = NUTS(gq_model, init_strategy = mcmc_inits.init_to_value(values = init_params))
mcmc = MCMC(nuts_kernel, num_samples= 50, warmup_steps= 50, num_chains=1)
mcmc.run(train_data)
# Summary table with a 75% interval.
mcmc.summary(prob=0.75)

Sample: 100%|██| 100/100 [00:33,  2.96it/s, step size=7.16e-02, acc. prob=0.926]


                  mean       std    median     12.5%     87.5%     n_eff     r_hat
lam_G_c[0,0]      1.00      0.03      1.00      0.97      1.04     16.10      1.03
lam_G_c[0,1]      5.52      0.01      5.52      5.50      5.53     37.27      0.99
lam_G_c[0,2]     11.95      0.08     11.94     11.85     12.01     56.38      0.98
lam_G_c[0,3]     15.19      0.17     15.21     15.00     15.37     37.83      1.01
lam_Q_c[0,0]      3.59      0.11      3.58      3.51      3.73     45.82      1.03
lam_Q_c[0,1]     13.96      0.15     13.95     13.87     14.22     42.03      1.00
lam_Q_c[0,2]     22.08      0.36     22.09     21.66     22.44     49.34      1.03
lam_Q_c[0,3]     27.97      0.39     27.97     27.70     28.53     75.57      1.00
 pi_Z_c[0,0]      0.25      0.01      0.25      0.24      0.26     62.20      0.99
 pi_Z_c[0,1]      0.46      0.01      0.46      0.45      0.47     64.91      1.01
 pi_Z_c[0,2]      0.15      0.01      0.15      0.14      0.16     33.47      1.01
 pi




## Evaluation – Predictive Likelihood



In [10]:
def impute_data(data, sites=()):
    """Return a deep copy of ``data`` with the given sites marked as unobserved.

    For every name in ``sites`` the copied ``data["data"][name]`` is set to
    ``None`` and the copied ``data["mask"][name]`` is replaced by an all-False
    boolean mask whose length is the first dimension of the original mask.
    The input ``data`` is left untouched.

    Args:
        data: mapping with ``"data"`` and ``"mask"`` sub-mappings keyed by
            site name.
        sites: iterable of site names to blank out. Defaults to no sites, in
            which case the result is simply a deep copy.

    Returns:
        The modified deep copy.
    """
    imputed = copy.deepcopy(data)
    for site in sites:
        n_points = imputed["mask"][site].shape[0]
        imputed["data"][site] = None  # remove the observations for this site
        imputed["mask"][site] = torch.zeros(n_points).bool()  # mask the site out entirely
    return imputed



def compute_exp_pred_lik(post_loglik):
    """Summarise per-sample log-likelihoods as one predictive-likelihood number.

    Along dimension 0 (assumed to index posterior samples — TODO confirm
    against the caller) the likelihoods are averaged, giving a predictive
    density per remaining entry. The return value is the exponential of the
    mean log predictive density over the next dimension, i.e. the geometric
    mean of those per-entry predictive densities.

    The sample average is taken in log space with ``logsumexp`` so that very
    negative log-likelihoods do not underflow to zero (and then to ``-inf``
    after the log) before they are averaged.

    Args:
        post_loglik: floating-point tensor of log-likelihood values with the
            sample dimension first.

    Returns:
        The geometric-mean predictive likelihood as a Python float.
    """
    num_samples = post_loglik.shape[0]
    # log(1/S * sum_s exp(ll_s)) = logsumexp_s(ll_s) - log(S)
    log_num_samples = torch.log(post_loglik.new_tensor(float(num_samples)))
    log_pred_density = torch.logsumexp(post_loglik, dim=0) - log_num_samples
    return torch.exp(torch.mean(log_pred_density, dim=0)).item()



def evaluate_pred_lik(mcmc, model, data, sites = ("G", "Q")):
    """Print and return the predictive likelihood of held-out sites under ``model``.

    For each site in ``sites``:

    1. The site is blanked out of ``data`` (see ``impute_data``), and the
       discrete latent ``Z`` is drawn with ``infer_discrete`` from the model
       conditioned on the posterior parameter samples, so ``Z`` is informed
       only by the remaining observed sites.
    2. The model is conditioned on those ``Z`` draws plus the parameter
       samples and traced on the full ``data``; the site's ``log_prob`` is
       summarised with ``compute_exp_pred_lik`` and printed.

    Args:
        mcmc: fitted object exposing ``get_samples()``, which must return a
            mapping from site name to a tensor whose first dimension is the
            number of posterior samples.
        model: Pyro model taking the ``{"data": ..., "mask": ...}`` mapping.
        data: the data set to evaluate on, in that same layout.
        sites: names of the observed sites to score, one at a time.

    Returns:
        The predictive likelihood of the LAST site in ``sites`` (values for
        earlier sites are only printed), or ``None`` if ``sites`` is empty.
    """
    params = mcmc.get_samples()
    num_samples = next(iter(params.values())).shape[0]
    # One slice per posterior sample, placed to the left of the data plate.
    sample_plate = pyro.plate("samples", num_samples, dim=-2)

    exp_pred_lik = None
    for site in sites:  # score each observed site in turn

        ## Step 1: draw Z | remaining observed sites, params
        infer_z_model = poutine.condition(model, params)
        infer_z_model = sample_plate(infer_z_model)
        infer_z_model = infer_discrete(infer_z_model, first_available_dim=-3, temperature=1)

        imputed_data = impute_data(data, sites=[site])  # hide the site being scored
        z_trace = poutine.trace(infer_z_model).get_trace(imputed_data)
        Z_params = {"Z": z_trace.nodes["Z"]["value"], **params}

        ## Step 2: score the held-out site given Z and params
        infer_site_model = poutine.condition(model, Z_params)
        infer_site_model = sample_plate(infer_site_model)
        # Evaluate on the caller's data set, not on a module-level global.
        site_trace = poutine.trace(infer_site_model).get_trace(data)
        site_trace.compute_log_prob()
        exp_pred_lik = compute_exp_pred_lik(site_trace.nodes[site]["log_prob"])

        print(f"{site} exp_pred_lik: {exp_pred_lik}")
    return exp_pred_lik


# Score sites G and Q of gq_model on the held-out test set using the samples in `mcmc`.
exp_pred_lik = evaluate_pred_lik(mcmc, gq_model, test_data, sites = ["G", "Q"])

G exp_pred_lik: 0.06633149832487106
Q exp_pred_lik: 0.008801787160336971


### ...now check model with G only

In [8]:
# Fit the G-only model to the same training data; this rebinds `nuts_kernel` and `mcmc`.
# NOTE(review): init_params was built from gq_model and also carries "lam_Q_c",
# for which g_model has no sample site — confirm init_to_value ignores the extra entry.
nuts_kernel = NUTS(g_model, init_strategy = mcmc_inits.init_to_value(values = init_params))
mcmc = MCMC(nuts_kernel, num_samples= 50, warmup_steps= 50, num_chains=1)
mcmc.run(train_data)
# Summary table with a 75% interval.
mcmc.summary(prob=0.75)

Sample: 100%|██| 100/100 [00:24,  4.16it/s, step size=1.62e-01, acc. prob=0.922]


                  mean       std    median     12.5%     87.5%     n_eff     r_hat
lam_G_c[0,0]      1.01      0.05      1.01      0.95      1.06     34.54      1.02
lam_G_c[0,1]      5.51      0.02      5.51      5.48      5.52     54.11      0.98
lam_G_c[0,2]     12.01      0.10     12.01     11.88     12.08     19.55      1.06
lam_G_c[0,3]     15.21      0.26     15.16     14.85     15.44     15.01      0.99
 pi_Z_c[0,0]      0.25      0.01      0.25      0.25      0.27     36.70      0.99
 pi_Z_c[0,1]      0.46      0.01      0.46      0.45      0.47     57.37      0.98
 pi_Z_c[0,2]      0.15      0.01      0.15      0.14      0.17     13.32      1.01
 pi_Z_c[0,3]      0.13      0.01      0.13      0.12      0.15     12.92      1.00
std_G_c[0,0]      0.99      0.04      0.99      0.95      1.04     80.17      0.98
std_G_c[0,1]      0.49      0.01      0.49      0.47      0.50     40.11      1.00
std_G_c[0,2]      1.00      0.07      1.00      0.91      1.06     11.11      1.09
std




In [9]:
# Score site G of g_model on the held-out test set using the samples in `mcmc`.
exp_pred_lik = evaluate_pred_lik(mcmc, g_model, test_data, sites = ["G"])

G exp_pred_lik: 0.08961589634418488
