In [19]:
import torch

import pyro
import pyro.distributions as dist
from pyro.contrib.cevae import CEVAE
from torch.utils.data import DataLoader
from pyro import poutine


In [2]:
def generate_data(args):
    """
    This implements the generative process of [1], but using larger feature and
    latent spaces ([1] assumes ``feature_dim=1`` and ``latent_dim=5``).
    """
    z = dist.Bernoulli(0.5).sample([args["num_data"]])
    x = dist.Normal(z, 5 * z + 3 * (1 - z)).sample([args["feature_dim"]]).t()
    t = dist.Bernoulli(0.75 * z + 0.25 * (1 - z)).sample()
    y = dist.Bernoulli(logits=3 * (z + 2 * (2 * t - 2))).sample()

    # Compute true ite for evaluation (via Monte Carlo approximation).
    t0_t1 = torch.tensor([[0.], [1.]])
    y_t0, y_t1 = dist.Bernoulli(logits=3 * (z + 2 * (2 * t0_t1 - 2))).mean
    true_ite = y_t1 - y_t0
    return x, t, y, z, true_ite

In [3]:
args = {
    "num_data":1000,
    "feature_dim":5,
    "latent_dim":20,
    "hidden_dim":200,
    "num_layers":3,
    "num_epochs":50,
    "batch_size":100,
    "learning_rate":1e-3,
    "learning_rate_decay":0.1,
    "weight_decay":1e-4,
    "seed":1234567890}

In [4]:
pyro.set_rng_seed(args['seed'])
x_train, t_train, y_train, z_train, train_ite = generate_data(args)

In [5]:
pyro.clear_param_store()
cevae = CEVAE(feature_dim=args["feature_dim"],
              latent_dim=args["latent_dim"],
              hidden_dim=args["hidden_dim"],
              num_layers=args["num_layers"],
              num_samples=10)

In [6]:
loss = cevae.fit(x_train, t_train, y_train,
              num_epochs=args["num_epochs"],
              batch_size=args["batch_size"],
              learning_rate=args["learning_rate"],
              learning_rate_decay=args["learning_rate_decay"],
              weight_decay=args["weight_decay"])

INFO 	 Training with 10 minibatches per epoch


In [10]:
x_test, t_test, y_test, z_test, true_ite = generate_data(args)
true_ate = true_ite.mean()
print("true ATE = {:0.3g}".format(true_ate.item()))
naive_ate = y_test[t_test == 1].mean() - y_test[t_test == 0].mean()
print("naive ATE = {:0.3g}".format(naive_ate))

true ATE = 0.729
naive ATE = 0.815


In [17]:
batch_size = args["batch_size"]
num_samples = 2
dataloader = [x_test] if batch_size is None else DataLoader(x_test, batch_size=batch_size)

In [40]:
with torch.no_grad():
    for x in dataloader:
                x = cevae.whiten(x)
                with pyro.plate("num_particles", num_samples, dim=-2):
                    with poutine.trace() as tr, poutine.block(hide=["y", "t"]):
                        cevae.guide(x)
                    with poutine.do(data=dict(t=torch.ones(()))):
                        z1 = poutine.replay(cevae.guide.z_dist, tr.trace)(x, torch.ones(()))

TypeError: z_dist() missing 1 required positional argument: 'x'