In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 12 2024

@author: ChatGPT
Editted by Yaning 
"""

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

In [2]:
# Generate synthetic data
torch.manual_seed(42)
num_groups = 3
num_samples_per_group = 100
true_slope = torch.tensor([2.0, 3.0, 4.0])
true_intercept = torch.tensor([-1.0, 0.0, 1.0])
x = torch.randn(num_groups, num_samples_per_group)
y = true_slope.unsqueeze(1) * x + true_intercept.unsqueeze(1) + torch.randn(num_groups, num_samples_per_group)

In [26]:
x

tensor([[ 1.9269e+00,  1.4873e+00,  9.0072e-01, -2.1055e+00,  6.7842e-01,
         -1.2345e+00, -4.3067e-02, -1.6047e+00, -7.5214e-01,  1.6487e+00,
         -3.9248e-01, -1.4036e+00, -7.2788e-01, -5.5943e-01, -7.6884e-01,
          7.6245e-01,  1.6423e+00, -1.5960e-01, -4.9740e-01,  4.3959e-01,
         -7.5813e-01,  1.0783e+00,  8.0080e-01,  1.6806e+00,  1.2791e+00,
          1.2964e+00,  6.1047e-01,  1.3347e+00, -2.3162e-01,  4.1759e-02,
         -2.5158e-01,  8.5986e-01, -1.3847e+00, -8.7124e-01, -2.2337e-01,
          1.7174e+00,  3.1888e-01, -4.2452e-01,  3.0572e-01, -7.7459e-01,
         -1.5576e+00,  9.9564e-01, -8.7979e-01, -6.0114e-01, -1.2742e+00,
          2.1228e+00, -1.2347e+00, -4.8791e-01, -9.1382e-01, -6.5814e-01,
          7.8024e-02,  5.2581e-01, -4.8799e-01,  1.1914e+00, -8.1401e-01,
         -7.3599e-01, -1.4032e+00,  3.6004e-02, -6.3477e-02,  6.7561e-01,
         -9.7807e-02,  1.8446e+00, -1.1845e+00,  1.3835e+00,  1.4451e+00,
          8.5641e-01,  2.2181e+00,  5.

In [20]:
y.view(-1).shape

torch.Size([300])

In [21]:
# Define the hierarchical Bayesian model
def model(x, y):
    # Priors for group-level parameters
    slope_loc = pyro.param("slope_loc", torch.randn(num_groups))
    slope_scale = pyro.param("slope_scale", torch.ones(num_groups), constraint=dist.constraints.positive)
    intercept_loc = pyro.param("intercept_loc", torch.randn(num_groups))
    intercept_scale = pyro.param("intercept_scale", torch.ones(num_groups), constraint=dist.constraints.positive)

    slope_prior = dist.Normal(slope_loc, slope_scale)
    intercept_prior = dist.Normal(intercept_loc, intercept_scale)
    
    # Priors for observation-level parameters
    with pyro.plate("group", num_groups):
        slope = pyro.sample("slope", slope_prior)
        intercept = pyro.sample("intercept", intercept_prior)
    
    # Likelihood
    with pyro.plate("data", num_groups*num_samples_per_group):
        group_indices = torch.arange(num_groups).unsqueeze(1).repeat(1, num_samples_per_group).reshape(-1)
        y_hat = slope[group_indices] * x.view(-1) + intercept[group_indices]
        pyro.sample("obs", dist.Normal(y_hat, 1.0), obs=y.view(-1))

# Define the guide (variational distribution)
def guide(x, y):
    # Variational parameters for group-level parameters
    slope_loc = pyro.param("slope_loc", torch.randn(num_groups))
    slope_scale = pyro.param("slope_scale", torch.ones(num_groups), constraint=dist.constraints.positive)
    intercept_loc = pyro.param("intercept_loc", torch.randn(num_groups))
    intercept_scale = pyro.param("intercept_scale", torch.ones(num_groups), constraint=dist.constraints.positive)
    
    # Sample group-level parameters
    with pyro.plate("group", num_groups):
        slope = pyro.sample("slope", dist.Normal(slope_loc, slope_scale))
        intercept = pyro.sample("intercept", dist.Normal(intercept_loc, intercept_scale))


In [23]:
# Perform stochastic variational inference
pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.03}), loss=Trace_ELBO())
num_iterations = 1000
for i in range(num_iterations):
    loss = svi.step(x, y)
    if i % 100 == 0:
        print("Iteration {}: Loss = {}".format(i, loss))

Iteration 0: Loss = 2936.710205078125
Iteration 100: Loss = 1323.69091796875
Iteration 200: Loss = 575.1324462890625
Iteration 300: Loss = 458.24786376953125
Iteration 400: Loss = 434.82879638671875
Iteration 500: Loss = 456.35809326171875
Iteration 600: Loss = 428.23858642578125
Iteration 700: Loss = 426.2340087890625
Iteration 800: Loss = 420.5310974121094
Iteration 900: Loss = 422.7549743652344


In [None]:
# Get posterior samples for parameters
def get_posterior_samples(model, guide, x, y, num_samples=1000):
    posterior_samples = []
    for _ in range(num_samples):
        posterior_samples.append({k: v.item() for k, v in guide(x, y).items()})
    return posterior_samples

# posterior_samples = get_posterior_samples(model, guide, x, y)

In [27]:
for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name))

slope_loc tensor([2.1813, 3.1389, 4.0655], requires_grad=True)
slope_scale tensor([0.0941, 0.1090, 0.1727], grad_fn=<AddBackward0>)
intercept_loc tensor([-0.7800, -0.0292,  1.0803], requires_grad=True)
intercept_scale tensor([0.0675, 0.0689, 0.0799], grad_fn=<AddBackward0>)


In [25]:
print(true_slope)
print(true_intercept)

tensor([2., 3., 4.])
tensor([-1.,  0.,  1.])
