In [1]:
import torch
import numpy as np
from scipy.special import expit
from causememaybe.cevae_torch import PModel, QGuide, kl_loss_with_normal

In [2]:
# Synthetic CEVAE-style benchmark with a binary latent confounder z.
# SEED makes the simulated data (and the numpy/torch random state that later
# cells draw from) reproducible under Restart & Run All.
SEED = 0
N = 1000
sigma_0 = 3  # standard deviation of x when z = 0
sigma_1 = 5  # standard deviation of x when z = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

# Latent confounder z ~ Bernoulli(0.5).
Z = np.random.binomial(1, 0.5, size=N)
# Proxy x | z ~ Normal(z, sigma_z^2). numpy's `scale` is the standard
# deviation, so pass sigma_z itself (Z is 0/1, so this selects sigma_0 or
# sigma_1). Passing sigma_z**2 made the spread 9 / 25 instead of 3 / 5.
X = np.random.normal(loc=Z, scale=Z * sigma_1 + (1 - Z) * sigma_0)
# Treatment t | z ~ Bernoulli(0.75 if z == 1 else 0.25).
T = np.random.binomial(1, 0.75 * Z + 0.25 * (1 - Z))
# Outcome y | t, z ~ Bernoulli(sigmoid(3 * (z + 2 * (2t - 1)))).
Y = np.random.binomial(1, expit(3 * (Z + 2 * (2 * T - 1))))

def sample_batch(z, x, t, y, batch_size=16):
    """Draw one random minibatch, with replacement, from aligned data arrays.

    Parameters
    ----------
    z, x, t, y : np.ndarray
        Arrays indexed by sample along axis 0. Indices are drawn from
        ``range(len(z))``, so the other arrays must be at least that long.
        Presumably all 1-D of equal length, as produced by the simulation cell.
    batch_size : int, optional
        Number of rows to draw (default 16, the previously hard-coded size).

    Returns
    -------
    tuple of np.ndarray
        ``(z, x, t, y)`` minibatches cast to float32. For 1-D inputs each has
        shape ``(batch_size, 1)``. Row k of every array comes from the same
        sampled index, so the rows stay aligned.
    """
    ind = np.random.randint(0, len(z), size=batch_size)
    # `None` adds a trailing feature axis so the networks see column vectors.
    # A tuple (rather than a one-shot generator) can be indexed and re-iterated.
    return tuple(arr[ind, None].astype(np.float32) for arr in (z, x, t, y))

# Build the model (pmodel) and the guide (qguide) from the same positional
# configuration; PModel is constructed first, then QGuide.
# NOTE(review): the meaning of (1, 1, 1, False) is defined in
# causememaybe.cevae_torch. Presumably these are input/latent dimensions and a
# flag; confirm there.
pmodel, qguide = (net(1, 1, 1, False) for net in (PModel, QGuide))

In [10]:
# Smoke test: run one forward pass of the guide and the model on a random
# minibatch and compute each loss term named below separately. Gradients are
# not tracked (torch.no_grad), so this cell trains nothing; the terms (l1..l7)
# are left in the notebook namespace for inspection.
with torch.no_grad():
    # `z` is the simulated ground-truth latent. It is drawn alongside the
    # observations but not passed to either network in this cell.
    z, x, t, y = sample_batch(Z, X, T, Y)
    x, t, y = torch.from_numpy(x), torch.from_numpy(t), torch.from_numpy(y)
    # Guide outputs: q(z|t,y,x) and the auxiliary predictors q(t|x), q(y|x,t).
    q_z_tyx, q_t_x, q_y_xt = qguide(x, t, y)
    # Model outputs for a reparameterized posterior sample of z and the observed t.
    p_x_z, p_t_z, p_y_zt = pmodel(q_z_tyx.rsample(), t)
    
    # Reconstruction loss
    # NOTE(review): these are log-probabilities as returned by log_prob, neither
    # negated nor reduced over the batch. Presumably shape (batch, 1); confirm.
    l1 = p_x_z.log_prob(x) # p(x|z)
    l2 = p_t_z.log_prob(t) # p(t|z)
    l3 = p_y_zt.log_prob(y) # p(y|t,z)
    
    # REGULARIZATION LOSS
    # approximate KL
    # NOTE(review): presumably KL(q(z|t,y,x) || standard normal prior), per the
    # helper's name; confirm in causememaybe.cevae_torch.
    l4 = kl_loss_with_normal(q_z_tyx)
    
    # AUXILIARY LOSS
    # q(t|x)
    l5 = q_t_x.log_prob(t)
    # q(y|x,t)
    l6 = q_y_xt.log_prob(y)
    
    # Negative sampling
    # Flip the treatment (valid because T was sampled as 0/1) and re-run the
    # guide. Only the posterior over z is kept; the auxiliary outputs are
    # discarded. How l7 enters a training objective is not shown in this cell.
    t_flipped = 1 - t
    q_z_t_flipped_yx, _, _ = qguide(x, t_flipped, y)
    l7 = kl_loss_with_normal(q_z_t_flipped_yx)