In [25]:
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
import math

In [None]:
p = 0.5

muX_0 = torch.ones(2)
muZ_0 = torch.ones(2)
mu_0 = torch.cat([muX_0, muZ_0])
print(mu_0)

muX_1 = -torch.ones(2)
muZ_1 = -torch.ones(2)
mu_1 = torch.cat([muX_1, muZ_1])
print(mu_1)

cov = 0.0
CXX_0 = torch.eye(2)
CZZ_0 = torch.eye(2)
CZX_0 = cov * torch.eye(2)
CXZ_0 = CZX_0.T
C_0 = torch.cat(
    [
        torch.cat([CXX_0, CXZ_0], dim=1),
        torch.cat([CZX_0, CZZ_0], dim=1)
    ], dim=0
)
print(C_0)


CXX_1 = torch.eye(2)
CZZ_1 = torch.eye(2)
CZX_1 = cov * torch.eye(2)
CXZ_1 = CZX_1.T
C_1 = torch.cat(
    [
        torch.cat([CXX_1, CXZ_1], dim=1),
        torch.cat([CZX_1, CZZ_1], dim=1)
    ], dim=0
)
print(C_1)

mvn_0 = MultivariateNormal(loc=mu_0, covariance_matrix=C_0)
mvn_1 = MultivariateNormal(loc=mu_1, covariance_matrix=C_1)

mvnX_0 = MultivariateNormal(loc=muX_0, covariance_matrix=CXX_0)
mvnX_1 = MultivariateNormal(loc=muX_1, covariance_matrix=CXX_1)

mvnZ_0 = MultivariateNormal(loc=muZ_0, covariance_matrix=CZZ_0)
mvnZ_1 = MultivariateNormal(loc=muZ_1, covariance_matrix=CZZ_1)



tensor([1., 1., 1., 1.])
tensor([-1., -1., -1., -1.])
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])


In [50]:
def sample_Z_given(x, p, n_samples=10, seed=123):
    dist_0 = MultivariateNormal(loc=muZ_0 + CZX_0 @ torch.linalg.solve(CXX_0, x - muX_0), covariance_matrix=CZZ_0 - CZX_0 @ torch.linalg.solve(CXX_0, CXZ_0))
    dist_1 = MultivariateNormal(loc=muZ_1 + CZX_1 @ torch.linalg.solve(CXX_1, x - muX_1), covariance_matrix=CZZ_1 - CZX_1 @ torch.linalg.solve(CXX_1, CXZ_1))
    n1 = int(n_samples * p)
    n0 = n_samples - n1
    torch.manual_seed(seed)
    return torch.cat([dist_0.rsample((n0,)), dist_1.rsample((n1,))], dim=0)

def compute_pxz(x, z, p):
    v = torch.cat([x, z], axis=1)
    lp_1 = math.log(p)
    lp_0 = math.log(1 - p)
    numer = mvn_1.log_prob(v) + lp_1
    denom = (mvn_1.log_prob(v) + lp_1) * (mvn_0.log_prob(v) + lp_0)
    return torch.exp(numer - denom)

def compute_pz(z, p):
    lp_1 = math.log(p)
    lp_0 = math.log(1 - p)
    numer = mvnZ_1.log_prob(z) + lp_1
    denom = (mvnZ_1.log_prob(z) + lp_1) * (mvnZ_0.log_prob(z) + lp_0)
    return torch.exp(numer - denom)

def compute_px(x, p):
    lp_1 = math.log(p)
    lp_0 = math.log(1 - p)
    numer = mvnX_1.log_prob(x) + lp_1
    denom = (mvnX_1.log_prob(x) + lp_1) * (mvnX_0.log_prob(x) + lp_0)
    return torch.exp(numer - denom)

def compute_cmi(n_samples, p, tol=1e-10, seed=123):

    # compute test data
    torch.manual_seed(seed)
    n1 = int(n_samples * p)
    n0 = n_samples - n1
    v = torch.cat([mvn_0.rsample((n0,)), mvn_1.rsample((n1,))], dim=0)
    x = v[:, :len(muX_0)]
    z = v[:, len(muX_0):]
    v_ = torch.cat([mvn_0.rsample((n0,)), mvn_1.rsample((n1,))], dim=0)
    z_ =  v_[:, len(muX_0):]

    pxz = torch.clamp(compute_pxz(x, z, p), min=tol, max=1.0)
    pz = torch.clamp(compute_pz(z_, p), min=tol, max=1.0)

    # possible numerical issue from exp then log.
    H_XZ = (-(pxz) * torch.log(pxz) - (1 - pxz) * torch.log(1 - pxz)).mean()
    H_Z = (-(pz) * torch.log(pz) - (1 - pz) * torch.log(1 - pz)).mean()

    return H_Z - H_XZ

In [45]:
n_samples = 10
n1 = int(n_samples * p)
n0 = n_samples - n1
v = torch.cat([mvn_0.rsample((n0,)), mvn_1.rsample((n1,))], dim=0)
x = v[:, :len(muX_0)]
z = v[:, len(muX_0):]

In [46]:
n_samples = 100000
for i in range(5):
    print(compute_cmi(n_samples, p))

tensor(1.5365e-07)
tensor(1.5365e-07)
tensor(1.5365e-07)
tensor(1.5365e-07)
tensor(1.5365e-07)
