# PML Final Project
## Demo of CEAVE package (Pyro)

In [15]:
import pyro
from pyro.contrib.cevae import CEVAE
import torch
import numpy as np

In [None]:
# TODO: optimize the package source code to make it cuda compatible

# # Step 0: Determine the device
# if torch.cuda.is_available():
#     device = torch.device("cuda")
#     print("CUDA is available! Training on GPU.")
# else:
#     device = torch.device("cpu")
#     print("CUDA not available. Training on CPU.")

In [None]:
def generate_synthetic_cevae_data(n=1000, seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Latent variable z ~ N(0, 1)
    z = torch.randn(n, 1)

    # 3 continuous proxies and 2 binary ones
    x_cont = z + 0.1 * torch.randn(n, 3)
    x_binary = torch.bernoulli(torch.sigmoid(z.repeat(1, 2)))

    # Treatment t ~ Bernoulli(sigmoid(linear(z)))
    logits_t = 2.0 * z  # strong dependence on z
    t = torch.bernoulli(torch.sigmoid(logits_t)).unsqueeze(1)  # shape: (n, 1)
    t = t.float().flatten()

    # y ~ Normal(mu, 1) where mu depends on t and z
    y0 = (z - 3).squeeze() + 0.1 * torch.randn(n)
    y1 = (z + 3).squeeze() + 0.1 * torch.randn(n)
    y = torch.where(t == 1, y1, y0)

    # ATE = E[Y(1) - Y(0)]
    true_ate = (y1 - y0).mean().item()

    return {
        'x_cont': x_cont,
        'x_binary': x_binary,
        't': t,
        'y': y,
        'z': z,
        'true_ate': true_ate
    }


data = generate_synthetic_cevae_data()
print(f"True ATE: {data['true_ate']:.4f}")


True ATE: 5.9983


In [None]:
# Generate synthetic data
data = generate_synthetic_cevae_data()
x_cont = data['x_cont']
x_binary = data['x_binary']
t = data['t']
y = data['y']

x = torch.cat([x_binary, x_cont], dim=1)  # shape (n, 5)

In [None]:
cevae = CEVAE(
    feature_dim=x.shape[1],     # total number of covariates
    latent_dim=1,               # dimension of latent z
    outcome_dist='normal'       
)

In [None]:
# Fit the model
cevae.fit(x, t, y, num_epochs=500)

INFO 	 Training with 10 minibatches per epoch


[41.34023767089844,
 26.681786071777342,
 19.796100524902343,
 16.924492919921875,
 14.908769836425781,
 14.88294580078125,
 14.006267456054687,
 13.336826904296874,
 12.695631713867188,
 12.46577099609375,
 12.077224639892577,
 11.809527038574219,
 11.409260894775391,
 11.524563201904297,
 10.86501254272461,
 11.223307678222657,
 10.620431674957276,
 10.257586086273193,
 9.655898635864258,
 10.85896955871582,
 10.542980491638184,
 10.75814488220215,
 8.25803970336914,
 9.449584800720215,
 8.911754028320313,
 8.283021240234374,
 8.792870666503907,
 8.264956176757812,
 7.731203002929687,
 7.662085083007812,
 8.19069955444336,
 7.335792583465576,
 6.899392360687256,
 6.611456052780151,
 6.642475860595703,
 6.012014633178711,
 6.166726829528809,
 5.455362543106079,
 5.57081608581543,
 6.021696166992188,
 7.309531005859375,
 7.561002548217774,
 7.054480655670166,
 6.665626052856445,
 6.369782730102539,
 6.065120994567871,
 7.585767425537109,
 7.947341079711914,
 5.671857955932617,
 5.78393

**NOTE:** with epoch=100 wasn't working well.

In [14]:
# Estimate ITEs and ATE
ite = cevae.ite(x, num_samples=100)
ate = ite.mean().item()

# Compare
print(f"Estimated ATE: {ate:.4f}")
print(f"True ATE: {data['true_ate']:.4f}")


INFO 	 Evaluating 1 minibatches


Estimated ATE: 5.9980
True ATE: 5.9983
