# Evaluating Generative Discrete Latent Variable Model

In [9]:
import torch
import pyro
import copy
import numpy as np
from sklearn.metrics import mean_squared_error, f1_score
from scipy.stats import pearsonr, spearmanr

from pyro.ops.indexing import Vindex
import pyro.distributions as dist
from pyro.distributions import transforms, constraints
from pyro.infer import config_enumerate
import pyro.poutine as poutine
from pyro.infer.autoguide import initialization as mcmc_inits
from pyro.infer import infer_discrete, MCMC, NUTS, HMC, Predictive
from pyro.distributions.util import broadcast_shape

import warnings
warnings.filterwarnings('ignore')

## Models

### G Model

In [2]:
def g_model(data=None):
    
    n_c = 4 ## number of latent classes

    ## PRIOR
    base_G_c = dist.Normal(torch.ones(n_c), torch.ones(n_c))
    lam_G_c = pyro.sample("lam_G_c", dist.TransformedDistribution(base_G_c, [transforms.OrderedTransform()]))
    std_G_c = pyro.sample("std_G_c", dist.Gamma(torch.ones(n_c), torch.ones(n_c)).to_event(1))

    pi_Z_c = pyro.sample("pi_Z_c", dist.Dirichlet(torch.ones(n_c) / n_c))

    ## LIKELIHOOD
    with pyro.plate('data_plate', data["mask"]["Q"].shape[0]):

        Z = pyro.sample('Z', dist.Categorical(pi_Z_c), infer={"enumerate": "parallel"})
        G = pyro.sample('G', dist.Normal(Vindex(lam_G_c)[...,Z.long()], Vindex(std_G_c)[...,Z.long()]).mask(data["mask"]["G"]),obs=data["data"]["G"])
        return Z, G

### Q Model

In [3]:
def gq_model(data=None):
    
    n_c = 4 ## number of latent classes

    ## PRIOR
    base_G_c = dist.Normal(torch.ones(n_c), torch.ones(n_c))
    lam_G_c = pyro.sample("lam_G_c", dist.TransformedDistribution(base_G_c, [transforms.OrderedTransform()]))
    std_G_c = pyro.sample("std_G_c", dist.Gamma(torch.ones(n_c), torch.ones(n_c)).to_event(1))

    base_Q_c = dist.Gamma(torch.ones(n_c), torch.ones(n_c))
    lam_Q_c = pyro.sample("lam_Q_c", dist.TransformedDistribution(base_Q_c, [transforms.OrderedTransform()]))

    pi_Z_c = pyro.sample("pi_Z_c", dist.Dirichlet(torch.ones(n_c) / n_c))  

    ## LIKELIHOOD
    with pyro.plate('data_plate', data["mask"]["Q"].shape[0]):

        Z = pyro.sample('Z', dist.Categorical(pi_Z_c), infer={"enumerate": "parallel"})
        G = pyro.sample('G', dist.Normal(Vindex(lam_G_c)[...,Z.long()], Vindex(std_G_c)[...,Z.long()]).mask(data["mask"]["G"]),obs=data["data"]["G"])
        Q = pyro.sample('Q', dist.Poisson(Vindex(lam_Q_c)[...,Z.long()]).mask(data["mask"]["Q"]), obs=data["data"]["Q"])

        return Z, G, Q

## Generate Data

In [4]:
def generate_data(model, params, n_data = 2000):
    
    none_data = {"G": None,
               "Q": None}
    data_mask = {"G": torch.ones(n_data).bool(),
                "Q": torch.ones(n_data).bool()}
    none_data = {"data": none_data, "mask": data_mask}

    model = poutine.condition(gq_model, true_params)
    trace = poutine.trace(model).get_trace(none_data)
    
    data = dict()
    
    for n in trace.nodes.keys():
        if n in ["G", "Q"]:
            data[n] = trace.nodes[n]["value"]
            
    return {"data": data, "mask": data_mask}


true_params = {"lam_Q_c": torch.tensor([[3.5, 14.0, 22., 28.0]]), "lam_G_c": torch.tensor([[1.0, 5.5, 12.0, 15.0]]), "std_G_c": torch.tensor([[1.0, 0.5, 1.0, 2.0]]), "pi_Z_c": torch.tensor([[0.25, 0.45, 0.15, 0.15]])}
train_data = generate_data(gq_model, true_params, n_data = 2000)
test_data = generate_data(gq_model, true_params, n_data = 2000)

print("G - Q spearmanr", spearmanr(train_data["data"]["G"].numpy(), train_data["data"]["Q"].numpy())[0])
print("G - Q pearsonr", pearsonr(train_data["data"]["G"].numpy(), train_data["data"]["Q"].numpy())[0])

G - Q spearmanr 0.8427085529263129
G - Q pearsonr 0.869439653917436


## MCMC

In [5]:
p = Predictive(gq_model, num_samples = 2000)
init_params = p(train_data)
init_params = {k: torch.mean(v, dim = 0) for k,v in init_params.items() if k not in ["Z", "G", "Q"]}
init_params

{'lam_G_c': tensor([[ 1.0146,  5.5856, 10.1043, 14.6680]]),
 'std_G_c': tensor([[1.0145, 0.9698, 1.0560, 0.9983]]),
 'lam_Q_c': tensor([[ 0.9730, 11.0158, 53.3710, 62.4217]]),
 'pi_Z_c': tensor([[0.2501, 0.2467, 0.2575, 0.2457]])}

In [6]:
nuts_kernel = NUTS(gq_model, init_strategy = mcmc_inits.init_to_value(values = init_params))
mcmc = MCMC(nuts_kernel, num_samples= 50, warmup_steps= 50, num_chains=1)
mcmc.run(train_data)
mcmc.summary(prob=0.75)

Sample: 100%|██| 100/100 [00:44,  2.22it/s, step size=6.16e-02, acc. prob=0.952]


                  mean       std    median     12.5%     87.5%     n_eff     r_hat
lam_G_c[0,0]      0.91      0.04      0.91      0.86      0.94     57.52      1.00
lam_G_c[0,1]      5.52      0.02      5.52      5.50      5.55     72.11      0.99
lam_G_c[0,2]     11.88      0.09     11.87     11.75     11.94     44.90      1.03
lam_G_c[0,3]     14.85      0.21     14.86     14.68     15.14     70.81      0.99
lam_Q_c[0,0]      3.53      0.08      3.53      3.45      3.63     51.04      0.99
lam_Q_c[0,1]     14.04      0.14     14.05     13.91     14.15     49.28      1.05
lam_Q_c[0,2]     21.23      0.40     21.22     20.60     21.42     34.73      0.98
lam_Q_c[0,3]     28.52      0.41     28.49     28.16     29.13     32.13      0.98
 pi_Z_c[0,0]      0.25      0.01      0.26      0.24      0.26     40.93      1.02
 pi_Z_c[0,1]      0.45      0.01      0.45      0.44      0.47     45.72      1.01
 pi_Z_c[0,2]      0.13      0.01      0.13      0.12      0.14     61.17      1.01
 pi




## evaluation – predictive likelihood



In [28]:
def impute_data(data, sites=[]):
    
    impute_data = copy.deepcopy(data)
    for k in sites:
        impute_data['data'][k] = None ## remove data site
        impute_data['mask'][k] = torch.zeros(impute_data["mask"][k].shape[0]).bool() ## set data mask to False
    return impute_data


def evaluate_point_predictons(pred_sites, gt_data):
    pred_comp = dict()

    for k in gt_data["data"].keys():
        if k in pred_sites.keys():
            
            hat_data = torch.mean(pred_sites[k], dim = 0)
            
            mse = mean_squared_error(gt_data["data"][k].type(torch.float), hat_data.type(torch.float))
            f1 = f1_score(gt_data["data"][k].type(torch.int), hat_data.type(torch.int), average='weighted')
            print(f"{str(k)}: mse {mse}, weighted f1 {f1}")


def compute_exp_pred_lik(post_loglik):
    
    ### computes pointwise expected log predictive density at each data point
    sample_mean_exp_n = torch.mean(torch.exp(post_loglik), 0)
    exp_log_lik = torch.exp(torch.mean(torch.log(sample_mean_exp_n), axis=0))
    #exp_log_density[k] = (post_loglik[k].logsumexp(0) - math.log(post_loglik[k].shape[0])).sum().item()
    return exp_log_lik.item()



def evaluate_pred_lik(mcmc, model, data, sites = ["G", "Q"]):

    ### computes predictive likelihood
    params = mcmc.get_samples()
    num_samples = list(params.values())[0].shape[0] 
    sample_plate = pyro.plate("samples", num_samples, dim=-2)
    
    pred_sites = dict()
    
    for site in sites: ## loop through observed sites

        ## infer P(Z | Q, T, params)______
        infer_z_model = poutine.condition(model, params)
        infer_z_model = sample_plate(infer_z_model)
        infer_z_model = infer_discrete(infer_z_model, first_available_dim=-3, temperature=1)
        
        imputed_data = impute_data(data, sites=[site]) ## impute observed site
        impute_trace = poutine.trace(infer_z_model).get_trace(imputed_data)
        Z = impute_trace.nodes["Z"]["value"]
        Z_params = {"Z": Z, **params}

        ## infer P(G | Z, params)______
        infer_site_model = poutine.condition(model, Z_params)
        infer_site_model = sample_plate(infer_site_model)
        cond_model = infer_discrete(infer_site_model, first_available_dim=-3, temperature=1)
        trace = poutine.trace(infer_site_model).get_trace(test_data)
        trace.compute_log_prob()
        
        exp_pred_lik = compute_exp_pred_lik(trace.nodes[site]["log_prob"])
        print(f"{site}: exp_pred_lik {exp_pred_lik}")
            
        pred_sites[site] = impute_trace.nodes[site]["value"]
        
    return pred_sites


gq_pred_sites = evaluate_pred_lik(mcmc, gq_model, test_data, sites = ["G", "Q"])
evaluate_point_predictons(gq_pred_sites, test_data)

G: exp_pred_lik 0.0785723626613617
Q: exp_pred_lik 0.012601005844771862
G: mse 29.3521671295166, weighted f1 0.028663766659509187
Q: mse 729.9135131835938, weighted f1 0.0017907551164431897


### ...now check model with G only

In [24]:
nuts_kernel = NUTS(g_model, init_strategy = mcmc_inits.init_to_value(values = init_params))
mcmc = MCMC(nuts_kernel, num_samples= 50, warmup_steps= 50, num_chains=1)
mcmc.run(train_data)
mcmc.summary(prob=0.75)

Sample: 100%|██| 100/100 [00:18,  5.54it/s, step size=1.48e-01, acc. prob=0.922]


                  mean       std    median     12.5%     87.5%     n_eff     r_hat
lam_G_c[0,0]      0.91      0.05      0.91      0.86      0.96     26.16      0.99
lam_G_c[0,1]      5.52      0.02      5.52      5.51      5.54     47.66      1.02
lam_G_c[0,2]     11.90      0.12     11.87     11.76     12.04     30.41      1.08
lam_G_c[0,3]     14.71      0.37     14.65     14.31     15.07      3.50      1.85
 pi_Z_c[0,0]      0.25      0.01      0.25      0.24      0.26     33.48      0.98
 pi_Z_c[0,1]      0.46      0.01      0.45      0.44      0.46     41.71      0.98
 pi_Z_c[0,2]      0.12      0.02      0.11      0.09      0.14      3.12      1.95
 pi_Z_c[0,3]      0.17      0.02      0.17      0.15      0.20      3.29      1.85
std_G_c[0,0]      1.00      0.04      1.01      0.98      1.07     36.38      0.98
std_G_c[0,1]      0.52      0.01      0.52      0.50      0.53    172.69      0.98
std_G_c[0,2]      1.01      0.11      0.99      0.83      1.08      5.60      1.34
std




In [29]:
g_pred_sites = evaluate_pred_lik(mcmc, g_model, test_data, sites = ["G"])
evaluate_point_predictons(g_pred_sites, test_data)

G: exp_pred_lik 0.08768624067306519
G: mse 25.92630386352539, weighted f1 0.06497787063851514
