# Simulate and estimate a very simple Bayesian causal model

In [55]:
import math
import os
import scipy.stats as stats
import numpy as np
import seaborn as sns
import logging

In [2]:
import torch
import torch.distributions.constraints as constraints
import pyro
from torch import nn
from pyro.nn import PyroModule, PyroSample
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

### Simulate a Bayesian latent variable model with perturbations

In [9]:
# Experiment parameters
n_perturb = 5
n_latent = 5
n_features = 10
n_obs = 1000

# Prior distributions
latent_mean = 0
latent_std = 1

# Define perturbation-latent space effect, should be sparse
beta = stats.norm.rvs(scale=2, size=(n_perturb, n_latent))
beta_mask = np.random.choice([0, 1], size=(n_perturb, n_latent)).astype(bool)
beta[beta_mask] = 0

# Define latent-feature mapping, should be sparse
W = stats.norm.rvs(scale=2, size=(n_latent, n_features))
b = stats.norm.rvs(size=(1, n_features))
W_mask = np.random.choice([0, 1], size=(n_latent, n_features)).astype(bool)
W[W_mask] = 0

In [10]:
# Generate the guide assignment matrix
P = np.zeros((n_obs,n_perturb), dtype=int)
for i in range(n_obs):
    P[i, np.random.choice(n_perturb)] = 1

# Generate the latent space variables
Z = stats.norm.rvs(loc=latent_mean, scale=latent_std, size=(n_obs, n_latent))
Z += P@beta

# Generate means for X's
X_means = Z@W+b
X = stats.norm.rvs(loc=X_means, scale=1.5)

In [11]:
X = torch.tensor(X).cuda()
P = torch.tensor(P).cuda()

In [21]:
device = X.device

### Create a Pyro model for the easy case - treat Z like X's

Basically a multiple multivariate regression

### Create a Pyro model

In [25]:
class CausalLDAE(PyroModule):
    
    def __init__(self, n_perturb, n_latent, n_features):
        
        super().__init__()
        
        # Declare beta as a module
        self.beta = PyroModule[nn.Linear](n_perturb, n_latent)
        self.beta.weight = PyroSample(dist.Normal(tfn(0.), tfn(1.)).expand([n_latent, n_perturb]).to_event(2))
        self.beta.bias = PyroSample(dist.Normal(tfn(0.), tfn(1.)).expand([n_latent]).to_event(1))
        
        # Define encoder
        self.encoder = nn.Linear(n_features, n_latent).to(device)
        
        # Define decoder
        self.decoder = nn.Linear(n_latent, n_features).to(device)

    def forward(self, x, p):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        effects = self.beta(p)
        
        # Sample the mask probabilities
        prior_a = torch.ones((n_perturb, n_latent), device=X.device)
        prior_b = torch.ones((n_perturb, n_latent), device=X.device)*n_perturb
        
        mask_prob = pyro.sample('mask_prob', dist.Beta(prior_a, prior_b))
        
        # Sample the mask
        mask = pyro.sample('mask', dist.Bernoulli(mask_prob))
        
        masked_effects = self.beta
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

In [26]:
CausalLDAE(n_perturb, n_latent, n_features)

CausalLDAE(
  (beta): PyroLinear(in_features=5, out_features=5, bias=True)
  (encoder): Encoder(
    (fc1): Linear(in_features=784, out_features=2, bias=True)
    (fc21): Linear(in_features=2, out_features=5, bias=True)
    (fc22): Linear(in_features=2, out_features=5, bias=True)
    (softplus): Softplus(beta=1, threshold=20)
  )
  (decoder): Linear(in_features=5, out_features=10, bias=True)
)

### SUPER simple SPARSE bayesian regression

In [163]:
beta = 1
n= 1000
labels = np.random.choice(2, size=n)
means = beta*labels
data = stats.norm.rvs(means, 1)

labels = torch.tensor(labels).cuda()
data = torch.tensor(data).cuda()
device = labels.device

def tfn(value, device=device):
    
    return torch.tensor(value, device=device)

In [164]:
def model(data, label):
    
    # Effect size 
    prior_es_mean = tfn(0)
    prior_es_variance = tfn(10)
    effect_size = pyro.sample('es', dist.Normal(prior_es_mean, prior_es_variance))
    
    # Mask probability
    
    
    means = label*effect_size
    
    with pyro.plate("data", len(data)):
        pyro.sample('x', dist.Normal(means, tfn(1.)), obs=data)
    
def guide(x, label):
    
    mean_q = pyro.param('mean_q', tfn(0.))
    var_q = pyro.param('var_q', tfn(1.), constraint=constraints.positive)
    
    pyro.sample('es', dist.Normal(mean_q, var_q))

In [165]:
adam_params = {"lr": 0.001, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

In [166]:
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())


In [168]:
stats.linregress(labels.to('cpu'), data.to('cpu'))

LinregressResult(slope=1.1017099392393317, intercept=-0.06815267501502076, rvalue=0.48943287007876385, pvalue=2.3282197303676225e-61, stderr=0.062136352612193496, intercept_stderr=0.04371679913252673)

In [162]:
stats.linregress(labels.to('cpu'), data.to('cpu'))

LinregressResult(slope=0.955888673984559, intercept=-0.03551017655501176, rvalue=0.3862891932709075, pvalue=7.200627997075127e-05, stderr=0.23056348270902785, intercept_stderr=0.1563757621791224)

In [150]:
print(n, pyro.param('mean_q').item(), pyro.param('var_q').item())

100 2.9425899982452393 0.4835539758205414


In [129]:
print(n, pyro.param('mean_q').item(), pyro.param('var_q').item())

10000 2.9410760402679443 0.5252532362937927


In [124]:
# print(n, pyro.param('mean_q').item(), pyro.param('var_q').item())

1000 2.9311881065368652 0.5552800297737122


In [116]:
print('1000', pyro.param('mean_q').item(), pyro.param('var_q').item())

1000 2.922926902770996 0.5885440111160278


In [None]:
 def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

In [None]:
def model(X):
    
    _beta = pyro.sample("beta", dist.Normal(0, 1))
    
    
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    with pyro.plate("data", len(ruggedness)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

In [159]:
def model(is_cont_africa, ruggedness, log_gdp):
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    with pyro.plate("data", len(ruggedness)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

In [87]:
P@beta

ValueError: matmul: Input operand 0 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)

In [81]:
Z

array([[-1.91002742, -0.92186503, -1.08265123,  0.10128784,  1.13183064],
       [-0.27033453,  0.09442541, -0.21412823, -1.0213405 , -0.04262896],
       [ 0.25084065,  0.22240159,  0.17268779,  0.25823634, -0.04667609],
       ...,
       [ 0.72570834,  0.65298513, -0.28390392, -1.10068273, -0.46834919],
       [ 0.21971457,  1.48248596, -0.59271296,  1.06924575,  0.57892723],
       [-0.24633521, -1.06654928, -0.4405091 ,  0.97873046,  0.77464404]])

In [76]:
P

array([[1, 0, 0, 0, 0],
       [0, 1, 0, 0, 0],
       [0, 0, 0, 0, 1],
       [0, 0, 0, 1, 0],
       [0, 0, 1, 0, 0]])

In [69]:
Z.shape

(1000,)

In [49]:
beta

array([[ 0.        , -1.18208369,  0.        ,  0.        ,  0.        ],
       [ 0.80622373, -0.25971326,  1.73093022,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        , -0.61205513,  2.68592717,  0.24531006, -2.3916667 ],
       [ 0.        ,  1.86658371,  0.        ,  0.02836729,  0.        ]])

In [20]:
# def model(data):

#     # sample f from the Beta prior
#     f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    
#     # loop over the observed data
#     for i in range(len(data)):
#         # observe datapoint i using the Bernoulli
#         # likelihood Bernoulli(f)
#         pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
        
def model(data):
    
    # define the hyperparameters that control the Beta prior
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data [WE ONLY CHANGE THE NEXT LINE]
    for i in pyro.plate("data_loop", len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

In [21]:
def guide(data):
    
    # register the two variational parameters with Pyro.
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0),
                         constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0),
                        constraint=constraints.positive)
    
    # sample latent_fairness from the distribution Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

In [22]:
# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 2000

assert pyro.__version__.startswith('1.8.5')

# clear the param store in case we're in a REPL
pyro.clear_param_store()

# create some data with 6 observed heads and 4 observed tails
data = []
for _ in range(6):
    data.append(torch.tensor(1.0))
for _ in range(4):
    data.append(torch.tensor(0.0))

In [23]:
# set up the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 5000
# do gradient steps
for step in range(n_steps):
    svi.step(data)

In [24]:
# grab the learned variational parameters
alpha_q = pyro.param("alpha_q").item()
beta_q = pyro.param("beta_q").item()

# here we use some facts about the Beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = alpha_q / (alpha_q + beta_q)

# compute inferred standard deviation
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * math.sqrt(factor)

print("\nBased on the data and our prior belief, the fairness " +
      "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))


Based on the data and our prior belief, the fairness of the coin is 0.531 +- 0.090
