# Evaluating Generative Discrete Latent Variable Model

In [1]:
import torch
import pyro
import copy
import numpy as np
from sklearn.metrics import mean_squared_error, f1_score
from scipy.stats import pearsonr, spearmanr

from pyro.ops.indexing import Vindex
import pyro.distributions as dist
from pyro.distributions import transforms, constraints
from pyro.infer import config_enumerate
import pyro.poutine as poutine
from pyro.infer.autoguide import initialization as mcmc_inits
from pyro.infer import infer_discrete, MCMC, NUTS, HMC, Predictive
from pyro.distributions.util import broadcast_shape

import warnings
warnings.filterwarnings('ignore')

## Models

### GQ Model

In [118]:
def gq_model(data=None):
    
    n_c = 4 ## number of latent classes

    ## PRIOR
    base_G_c = dist.Normal(torch.ones(n_c), torch.ones(n_c))
    lam_G_c = pyro.sample("lam_G_c", dist.TransformedDistribution(base_G_c, [transforms.OrderedTransform()]))
    std_G_c = pyro.sample("std_G_c", dist.Gamma(torch.ones(n_c), torch.ones(n_c)).to_event(1))

    base_Q_c = dist.Gamma(torch.ones(n_c), torch.ones(n_c))
    lam_Q_c = pyro.sample("lam_Q_c", dist.TransformedDistribution(base_Q_c, [transforms.OrderedTransform()]))

    pi_Z_c = pyro.sample("pi_Z_c", dist.Dirichlet(torch.ones(n_c) / n_c))  

    ## LIKELIHOOD
    with pyro.plate('data_plate', data["mask"]["Q"].shape[0]):

        Z = pyro.sample('Z', dist.Categorical(pi_Z_c), infer={"enumerate": "parallel"})
        G = pyro.sample('G', dist.Normal(Vindex(lam_G_c)[...,Z.long()], Vindex(std_G_c)[...,Z.long()]).mask(data["mask"]["G"]),obs=data["data"]["G"])
        Q = pyro.sample('Q', dist.Poisson(Vindex(lam_Q_c)[...,Z.long()]).mask(data["mask"]["Q"]), obs=data["data"]["Q"])

        return Z, G, Q

## Generate Data

In [4]:
def generate_data(model, params, n_data = 2000):
    
    none_data = {"G": None,
               "Q": None}
    data_mask = {"G": torch.ones(n_data).bool(),
                "Q": torch.ones(n_data).bool()}
    none_data = {"data": none_data, "mask": data_mask}

    model = poutine.condition(gq_model, true_params)
    trace = poutine.trace(model).get_trace(none_data)
    
    data = dict()
    
    for n in trace.nodes.keys():
        if n in ["G", "Q"]:
            data[n] = trace.nodes[n]["value"]
            
    return {"data": data, "mask": data_mask}


true_params = {"lam_Q_c": torch.tensor([[3.5, 14.0, 22., 28.0]]), "lam_G_c": torch.tensor([[1.0, 5.5, 12.0, 15.0]]), "std_G_c": torch.tensor([[1.0, 0.5, 1.0, 2.0]]), "pi_Z_c": torch.tensor([[0.25, 0.45, 0.15, 0.15]])}
train_data = generate_data(gq_model, true_params, n_data = 2000)
test_data = generate_data(gq_model, true_params, n_data = 2000)

print("G - Q spearmanr", spearmanr(train_data["data"]["G"].numpy(), train_data["data"]["Q"].numpy())[0])
print("G - Q pearsonr", pearsonr(train_data["data"]["G"].numpy(), train_data["data"]["Q"].numpy())[0])

G - Q spearmanr 0.8555847101426786
G - Q pearsonr 0.8752201882063496


## MCMC

In [5]:
p = Predictive(gq_model, num_samples = 2000)
init_params = p(train_data)
init_params = {k: torch.mean(v, dim = 0) for k,v in init_params.items() if k not in ["Z", "G", "Q"]}
init_params

{'lam_G_c': tensor([[ 1.0119,  5.6123, 10.1842, 14.5476]]),
 'std_G_c': tensor([[1.0086, 0.9856, 0.9943, 0.9906]]),
 'lam_Q_c': tensor([[ 1.0088, 14.6024, 21.3029, 29.6443]]),
 'pi_Z_c': tensor([[0.2396, 0.2465, 0.2570, 0.2569]])}

In [68]:
nuts_kernel = NUTS(gq_model, init_strategy = mcmc_inits.init_to_value(values = init_params))
gq_mcmc = MCMC(nuts_kernel, num_samples= 50, warmup_steps= 50, num_chains=1)
gq_mcmc.run(train_data)
gq_mcmc.summary(prob=0.75)

Sample: 100%|██| 100/100 [00:34,  2.87it/s, step size=7.15e-02, acc. prob=0.963]


                  mean       std    median     12.5%     87.5%     n_eff     r_hat
lam_G_c[0,0]      1.06      0.05      1.06      1.02      1.11     20.05      0.98
lam_G_c[0,1]      5.51      0.02      5.51      5.49      5.53     91.62      0.98
lam_G_c[0,2]     11.93      0.09     11.93     11.84     12.04     26.69      0.99
lam_G_c[0,3]     14.90      0.16     14.88     14.75     15.04     32.41      1.03
lam_Q_c[0,0]      3.46      0.09      3.46      3.36      3.55     33.77      0.98
lam_Q_c[0,1]     13.81      0.12     13.81     13.67     13.90     70.58      0.98
lam_Q_c[0,2]     21.55      0.38     21.57     21.15     22.00     25.65      1.00
lam_Q_c[0,3]     28.18      0.30     28.12     27.81     28.35     38.77      0.99
 pi_Z_c[0,0]      0.26      0.01      0.26      0.25      0.27     48.79      0.98
 pi_Z_c[0,1]      0.43      0.01      0.43      0.42      0.44     43.14      0.98
 pi_Z_c[0,2]      0.14      0.01      0.14      0.12      0.15     34.13      1.01
 pi




## evaluation – predictive likelihood



In [114]:
def impute_data(data, sites=[]):
    
    impute_data = copy.deepcopy(data)
    for k in sites:
        impute_data['data'][k] = None ## remove data site
        impute_data['mask'][k] = torch.zeros(impute_data["mask"][k].shape[0]).bool() ## set data mask to False
    return impute_data


def evaluate_point_predictons(pred_sites, gt_data):
    pred_comp = dict()

    for k in gt_data["data"].keys():
        if k in pred_sites.keys():
            
            hat_data = torch.mean(pred_sites[k], dim = 0)
            
            mse = mean_squared_error(gt_data["data"][k].type(torch.float), hat_data.type(torch.float))
            f1 = f1_score(gt_data["data"][k].type(torch.int), hat_data.type(torch.int), average='weighted')
            print(f"{str(k)}: mse {mse}, weighted f1 {f1}")


def compute_exp_pred_lik(post_loglik):
    
    ### computes pointwise expected log predictive density at each data point
    sample_mean_exp_n = torch.mean(torch.exp(post_loglik), 0)
    exp_log_lik = torch.exp(torch.mean(torch.log(sample_mean_exp_n), axis=0))
    #exp_log_density[k] = (post_loglik[k].logsumexp(0) - math.log(post_loglik[k].shape[0])).sum().item()
    return exp_log_lik.item()



def evaluate_pred_lik(mcmc, model, data, sites = ["G", "Q"]):

    ### computes predictive likelihood
    params = mcmc.get_samples()    
    num_samples = list(params.values())[0].shape[0] 
    sample_plate = pyro.plate("samples", num_samples, dim=-2)
    
    pred_sites = dict()
    
    for site in sites: ## loop through observed sites

        ## infer P(Z | Q, T, params)______
        infer_z_model = poutine.condition(model, params)
        infer_z_model = sample_plate(infer_z_model)
        infer_z_model = infer_discrete(infer_z_model, first_available_dim=-3, temperature=1)
        
        imputed_data = impute_data(data, sites=[site]) ## impute observed site
        impute_trace = poutine.trace(infer_z_model).get_trace(imputed_data)
        Z = impute_trace.nodes["Z"]["value"]
        Z_params = {"Z": Z, **params}

        ## infer P(G | Z, params)______
        infer_site_model = poutine.condition(model, Z_params)
        infer_site_model = sample_plate(infer_site_model)
        #infer_site_model = infer_discrete(infer_site_model, first_available_dim=-3, temperature=1)
        trace = poutine.trace(infer_site_model).get_trace(test_data)
        trace.compute_log_prob()
        
        exp_pred_lik = compute_exp_pred_lik(trace.nodes[site]["log_prob"])
        print(f"{site}: exp_pred_lik {exp_pred_lik}")
            
        pred_sites[site] = impute_trace.nodes[site]["value"]
        
    return pred_sites

In [115]:
gq_pred_sites = evaluate_pred_lik(gq_mcmc, gq_model, test_data, sites = ["G", "Q"])
evaluate_point_predictons(gq_pred_sites, test_data)

G: exp_pred_lik 0.19960589706897736
Q: exp_pred_lik 0.06829415261745453
G: mse 5.509785175323486, weighted f1 0.2778919213301624
Q: mse 16.895063400268555, weighted f1 0.05578961422286238


## G Model

In [119]:
def g_model(data=None):
    
    n_c = 4 ## number of latent classes

    ## PRIOR
    base_G_c = dist.Normal(torch.ones(n_c), torch.ones(n_c))
    lam_G_c = pyro.sample("lam_G_c", dist.TransformedDistribution(base_G_c, [transforms.OrderedTransform()]))
    std_G_c = pyro.sample("std_G_c", dist.Gamma(torch.ones(n_c), torch.ones(n_c)).to_event(1))

    pi_Z_c = pyro.sample("pi_Z_c", dist.Dirichlet(torch.ones(n_c) / n_c))

    ## LIKELIHOOD
    with pyro.plate('data_plate', data["mask"]["Q"].shape[0]):

        Z = pyro.sample('Z', dist.Categorical(pi_Z_c), infer={"enumerate": "parallel"})
        G = pyro.sample('G', dist.Normal(Vindex(lam_G_c)[...,Z.long()], Vindex(std_G_c)[...,Z.long()]).mask(data["mask"]["G"]),obs=data["data"]["G"])
        return Z, G

In [120]:
nuts_kernel = NUTS(g_model, init_strategy = mcmc_inits.init_to_value(values = init_params))
g_mcmc = MCMC(nuts_kernel, num_samples= 50, warmup_steps= 50, num_chains=1)
g_mcmc.run(train_data)
g_mcmc.summary(prob=0.75)

Sample: 100%|██| 100/100 [00:18,  5.28it/s, step size=2.30e-01, acc. prob=0.789]


                  mean       std    median     12.5%     87.5%     n_eff     r_hat
lam_G_c[0,0]      1.05      0.04      1.05      1.01      1.10     37.00      0.98
lam_G_c[0,1]      5.50      0.02      5.50      5.47      5.52     71.65      0.98
lam_G_c[0,2]     12.02      0.11     12.01     11.87     12.10     18.15      0.99
lam_G_c[0,3]     15.27      0.43     15.30     14.59     15.59     11.46      0.99
 pi_Z_c[0,0]      0.26      0.01      0.26      0.24      0.27     22.56      1.12
 pi_Z_c[0,1]      0.43      0.01      0.43      0.42      0.45     32.18      1.00
 pi_Z_c[0,2]      0.16      0.03      0.16      0.13      0.19     10.78      0.98
 pi_Z_c[0,3]      0.15      0.03      0.14      0.12      0.18     10.80      1.01
std_G_c[0,0]      1.02      0.03      1.03      1.01      1.06     30.98      1.03
std_G_c[0,1]      0.51      0.01      0.50      0.49      0.52     48.92      1.00
std_G_c[0,2]      1.04      0.10      1.07      0.93      1.16     13.66      0.99
std




In [121]:
g_pred_sites = evaluate_pred_lik(g_mcmc, g_model, test_data, sites = ["G"])
evaluate_point_predictons(g_pred_sites, test_data)

G: exp_pred_lik 0.08619570732116699
G: mse 26.321210861206055, weighted f1 0.0561665563468427


## Q Model

In [122]:
def q_model(data=None):
    
    n_c = 4 ## number of latent classes

    ## PRIOR
    base_Q_c = dist.Gamma(torch.ones(n_c), torch.ones(n_c))
    lam_Q_c = pyro.sample("lam_Q_c", dist.TransformedDistribution(base_Q_c, [transforms.OrderedTransform()]))

    pi_Z_c = pyro.sample("pi_Z_c", dist.Dirichlet(torch.ones(n_c) / n_c))  

    ## LIKELIHOOD
    with pyro.plate('data_plate', data["mask"]["Q"].shape[0]):

        Z = pyro.sample('Z', dist.Categorical(pi_Z_c), infer={"enumerate": "parallel"})
        Q = pyro.sample('Q', dist.Poisson(Vindex(lam_Q_c)[...,Z.long()]).mask(data["mask"]["Q"]), obs=data["data"]["Q"])

        return Z, Q

In [123]:
nuts_kernel = NUTS(q_model, init_strategy = mcmc_inits.init_to_value(values = init_params))
q_mcmc = MCMC(nuts_kernel, num_samples= 50, warmup_steps= 50, num_chains=1)
q_mcmc.run(train_data)
q_mcmc.summary(prob=0.75)

Sample: 100%|██| 100/100 [00:14,  7.09it/s, step size=2.44e-02, acc. prob=0.982]


                  mean       std    median     12.5%     87.5%     n_eff     r_hat
lam_Q_c[0,0]      3.42      0.09      3.42      3.32      3.51     42.68      0.98
lam_Q_c[0,1]     13.55      0.43     13.63     13.13     13.93      4.68      1.45
lam_Q_c[0,2]     15.11      0.28     15.09     14.73     15.38     48.34      0.98
lam_Q_c[0,3]     26.72      0.39     26.65     26.27     27.09     24.26      1.01
 pi_Z_c[0,0]      0.26      0.01      0.26      0.26      0.28     17.08      1.04
 pi_Z_c[0,1]      0.25      0.06      0.25      0.22      0.32      3.55      1.80
 pi_Z_c[0,2]      0.24      0.06      0.23      0.16      0.29      3.32      1.83
 pi_Z_c[0,3]      0.25      0.01      0.25      0.24      0.27     34.03      0.98

Number of divergences: 29





In [124]:
q_pred_sites = evaluate_pred_lik(q_mcmc, q_model, test_data, sites = ["Q"])
evaluate_point_predictons(q_pred_sites, test_data)

Q: exp_pred_lik 0.02917259931564331
Q: mse 87.25928497314453, weighted f1 0.02256849137366103
