# PML Final Project
## Demo of the CEVAE implementation in Pyro (`pyro.contrib.cevae`)

In [1]:
import pyro
from pyro.contrib.cevae import CEVAE
import torch
import numpy as np

In [2]:
# TODO: adapt the package source code so that it is CUDA-compatible

# # Step 0: Determine the device
# if torch.cuda.is_available():
#     device = torch.device("cuda")
#     print("CUDA is available! Training on GPU.")
# else:
#     device = torch.device("cpu")
#     print("CUDA not available. Training on CPU.")

In [3]:
def generate_synthetic_cevae_data(n=1000, seed=42, n_cont=3, n_binary=2):
    """Simulate a confounded dataset with a known treatment effect for CEVAE.

    Generative process, driven by a single latent confounder ``z``:

    * ``z ~ N(0, 1)``, shape ``(n, 1)``
    * continuous proxies ``x_cont = z + 0.1 * eps``, shape ``(n, n_cont)``
    * binary proxies ``x_binary ~ Bernoulli(sigmoid(z))``, shape ``(n, n_binary)``
    * treatment ``t ~ Bernoulli(sigmoid(2 * z))``, shape ``(n,)``
    * potential outcomes ``y0 = z - 3 + 0.1 * eps`` and ``y1 = z + 3 + 0.1 * eps``.
      The observed outcome is ``y = y1`` where ``t == 1``, else ``y0``, shape ``(n,)``.

    Because ``z`` raises both the treatment probability and the outcome, a
    naive treated-vs-control difference of means is biased. The population
    ATE is exactly 6.

    Args:
        n: Number of samples (must be >= 1).
        seed: Seed applied to the *global* torch and numpy RNGs. Calling this
            function therefore resets global random state.
        n_cont: Number of continuous proxy columns (default 3).
        n_binary: Number of binary proxy columns (default 2).

    Returns:
        dict with tensors ``'x_cont'``, ``'x_binary'``, ``'t'``, ``'y'``, ``'z'``
        and the float ``'true_ate'``. ``'true_ate'`` is the in-sample mean of
        ``y1 - y0``: close to, but not exactly, 6.

    Raises:
        ValueError: if ``n < 1`` or a proxy count is negative.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    if n_cont < 0 or n_binary < 0:
        raise ValueError(
            f"proxy counts must be non-negative, got n_cont={n_cont}, n_binary={n_binary}"
        )

    # numpy is seeded as well for callers that rely on it, even though no
    # numpy randomness is used below.
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Latent confounder z ~ N(0, 1), shape (n, 1).
    z = torch.randn(n, 1)

    # Proxies of z: noisy continuous copies, and Bernoulli(sigmoid(z)) draws.
    x_cont = z + 0.1 * torch.randn(n, n_cont)
    x_binary = torch.bernoulli(torch.sigmoid(z.repeat(1, n_binary)))

    # Treatment t ~ Bernoulli(sigmoid(2 z)): strong dependence on z.
    # The (n, 1) draw is flattened to shape (n,); the RNG draw order is unchanged.
    t = torch.bernoulli(torch.sigmoid(2.0 * z)).flatten()

    # Potential outcomes with N(0, 0.1^2) noise; treatment shifts the mean by +6.
    # squeeze(1) (rather than squeeze()) keeps shape (n,) even when n == 1.
    z_flat = z.squeeze(1)
    y0 = (z_flat - 3) + 0.1 * torch.randn(n)
    y1 = (z_flat + 3) + 0.1 * torch.randn(n)
    y = torch.where(t == 1, y1, y0)

    # In-sample ATE = E[Y(1) - Y(0)] over the simulated units.
    true_ate = (y1 - y0).mean().item()

    return {
        'x_cont': x_cont,
        'x_binary': x_binary,
        't': t,
        'y': y,
        'z': z,
        'true_ate': true_ate
    }


# Simulate the default dataset (explicit n/seed for readability) and report its in-sample ATE.
data = generate_synthetic_cevae_data(n=1000, seed=42)
print(
    f"True ATE: {data['true_ate']:.4f}"
)


True ATE: 5.9983


In [4]:
# Assemble the model inputs from the dataset simulated in the previous cell.
# Reusing `data` instead of calling the generator a second time keeps a single
# source of truth: editing the generator arguments above can no longer silently
# desynchronize the data that is trained on. The generator reseeds the global
# RNG, so the random state after that cell is the same as after a re-generation.
x_cont = data['x_cont']
x_binary = data['x_binary']
t = data['t']
y = data['y']

# Covariates: binary proxies first, then continuous ones -> shape (n, 5) by default.
x = torch.cat([x_binary, x_cont], dim=1)

# Sanity checks on the shapes CEVAE.fit receives.
assert x.shape == (len(y), x_binary.shape[1] + x_cont.shape[1])
assert t.shape == y.shape == (len(y),)

In [5]:
# One latent dimension mirrors the single simulated confounder z; the outcome
# is real-valued, hence the normal outcome distribution.
cevae = CEVAE(
    feature_dim=x.shape[1],  # number of covariate columns (binary + continuous)
    outcome_dist='normal',
    latent_dim=1,
)

In [6]:
# Train CEVAE for 500 epochs (100 epochs was not enough here); the trailing `;`
# suppresses display of the call's return value in the notebook.
cevae.fit(x=x, t=t, y=y, num_epochs=500);

INFO 	 Training with 10 minibatches per epoch


**NOTE:** training with only `num_epochs=100` did not work well, which is why 500 epochs are used above.

In [7]:
# Average the per-unit effect estimates (ITEs) into an ATE, then compare it
# with the in-sample ATE recorded by the data generator.
ite = cevae.ite(x, num_samples=100)
ate = float(ite.mean())

print(
    f"Estimated ATE: {ate:.4f}",
    f"True ATE: {data['true_ate']:.4f}",
    sep="\n",
)


INFO 	 Evaluating 1 minibatches


Estimated ATE: 5.9980
True ATE: 5.9983
