In [1]:
import numpy as np
import torch
import pyro
import pyro.optim
from pyro.infer import *
from torch.distributions import constraints
from pyro import distributions as dst

In [2]:
from __future__ import absolute_import, division, print_function

import numbers
from abc import ABCMeta, abstractmethod
from collections import OrderedDict, defaultdict

from six import add_metaclass

import pyro.poutine as poutine
from pyro.distributions import Categorical, Empirical
from pyro.ops.stats import waic
from pyro.infer.util import site_is_subsample

In [3]:
def factorAnalysis(X):
    N, D = X.shape
    K = 2
    locloc = 0.
    locscale = 1.
    scaleloc = 0.
    scalescale = 1.
    cov_factor_loc = torch.zeros(K,D)
    cov_factor_scale = torch.ones(K,D)*10
    with pyro.plate('D', D):
        loc = pyro.sample('loc', dst.Normal(locloc, locscale))
        cov_diag = pyro.sample('scale', dst.LogNormal(scaleloc, scalescale))
        with pyro.plate('K', K):
            cov_factor = pyro.sample('cov_factor', dst.Normal(cov_factor_loc,cov_factor_scale))
        cov_factor = cov_factor.transpose(0,1)
    with pyro.plate('N', N):
        X = pyro.sample('obs', dst.LowRankMultivariateNormal(loc, cov_factor=cov_factor, cov_diag=cov_diag))
    return X

In [4]:
N = 800
D = 5
data = factorAnalysis(np.ones((N,D)))

In [5]:
def guide(X):
    K, locloc, locscale, scaleloc, scalescale, cov_factor_loc, cov_factor_scale = 2,torch.zeros(D),torch.ones(D),torch.zeros(D),torch.ones(D),torch.zeros(1,D),torch.ones(1,D)*10
    with pyro.plate('D', D, dim=-1):
        loc_loc = pyro.param('loc_loc', locloc)
        loc_scale = pyro.param('loc_scale', locscale, constraint=constraints.positive)
        cov_diag_loc = pyro.param('scale_loc', scaleloc)
        cov_diag_scale = pyro.param('scale_scale', scalescale, constraint=constraints.positive)
        # sample variables
        loc = pyro.sample('loc', dst.Normal(loc_loc,loc_scale))
        with pyro.plate('K', K, dim=-2):
            cov_factor_loc = pyro.param('cov_factor_loc_{}'.format(K), cov_factor_loc)
            cov_factor_scale = pyro.param('cov_factor_scale_{}'.format(K), cov_factor_scale, constraint=constraints.positive)
            cov_factor = pyro.sample('cov_factor', dst.Normal(cov_factor_loc, cov_factor_scale))
        cov_factor = cov_factor.transpose(0,1)
        cov_diag = pyro.sample('scale', dst.LogNormal(cov_diag_loc, cov_diag_scale))
    return loc, cov_factor, cov_diag

In [6]:
def per_param_callable(module_name, param_name):
    return {"lr": 0.01, 'betas': [0.9, 0.99]}

In [7]:
conditioned_model = pyro.condition(factorAnalysis, data = {'obs': data})
optim = pyro.optim.Adam(per_param_callable)
elbo = Trace_ELBO()
svi = SVI(conditioned_model, guide, optim, loss=elbo,num_steps=1000, num_samples=100).run(data)

In [8]:
trace_pred = TracePredictive(factorAnalysis, svi, num_samples=1000).run(data)

AssertionError: 

In [9]:
class TracePredictive(TracePosterior):
    """
    Generates and holds traces from the posterior predictive distribution,
    given model execution traces from the approximate posterior. This is
    achieved by constraining latent sites to randomly sampled parameter
    values from the model execution traces and running the model forward
    to generate traces with new response ("_RETURN") sites.
    :param model: arbitrary Python callable containing Pyro primitives.
    :param TracePosterior posterior: trace posterior instance holding samples from the model's approximate posterior.
    :param int num_samples: number of samples to generate.
    :param keep_sites: The sites which should be sampled from posterior distribution (default: all)
    """
    def __init__(self, model, posterior, num_samples, keep_sites=None):
        self.model = model
        self.posterior = posterior
        self.num_samples = num_samples
        self.keep_sites = keep_sites
        super(TracePredictive, self).__init__()

    def _traces(self, *args, **kwargs):
        if not self.posterior.exec_traces:
            self.posterior.run(*args, **kwargs)
        data_trace = poutine.trace(self.model).get_trace(*args, **kwargs)
        for _ in range(self.num_samples):
            model_trace = self.posterior().copy()
            self._remove_dropped_nodes(model_trace)
            self._adjust_to_data(model_trace, data_trace)
            resampled_trace = poutine.trace(poutine.replay(self.model, model_trace)).get_trace(*args, **kwargs)
            yield (resampled_trace, 0., 0)

    def _remove_dropped_nodes(self, trace):
        if self.keep_sites is None:
            return
        for name, site in list(trace.nodes.items()):
            if name not in self.keep_sites:
                trace.remove_node(name)
                continue

    def _adjust_to_data(self, trace, data_trace):
        subsampled_idxs = dict()
        for name, site in trace.iter_stochastic_nodes():
            # Adjust subsample sites
            if site_is_subsample(site):
                site["fn"] = data_trace.nodes[name]["fn"]
                site["value"] = data_trace.nodes[name]["value"]
            # Adjust sites under conditionally independent stacks
            orig_cis_stack = site["cond_indep_stack"]
            site["cond_indep_stack"] = data_trace.nodes[name]["cond_indep_stack"]
            assert len(orig_cis_stack) == len(site["cond_indep_stack"])
            site["fn"] = data_trace.nodes[name]["fn"]
            for ocis, cis in zip(orig_cis_stack, site["cond_indep_stack"]):
                # Select random sub-indices to replay values under conditionally independent stacks.
                # Otherwise, we assume there is an dependence of indexes between training data
                # and prediction data.
                assert ocis.name == cis.name
                if site_is_subsample(site):
                    batch_dim = cis.dim
                    subsampled_idxs[cis.name] = torch.randint(0, site['value'].size(batch_dim), (cis.size,),device=site["value"].device)
                else:    
                    batch_dim = cis.dim - site["fn"].event_dim
                    subsampled_idxs[cis.name] = subsampled_idxs.get(cis.name,
                                                                torch.randint(0, ocis.size, (cis.size,),
                                                                              device=site["value"].device))
                    site["value"] = site["value"].index_select(batch_dim, subsampled_idxs[cis.name])

In [10]:
trace_pred = TracePredictive(factorAnalysis, svi, num_samples=1000).run(data)

In [11]:
trace_pred.marginal().support()['_RETURN'].shape

torch.Size([1000, 800, 5])