In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pyjags


def main():
    """Sample a bivariate normal with PyJAGS, print summaries, and plot the draws.

    The target is N(mu, Sigma) with the means, standard deviations and
    correlation set below. JAGS parameterizes ``dmnorm`` by its precision
    matrix, so ``Sigma`` is inverted before being passed in as data. The model
    has no observed data, so JAGS simply samples this joint distribution.

    Prints the target and empirical mean/covariance of the draws pooled over
    all chains, then shows two figures: a scatter of every draw and the path
    traced by a single chain.

    Raises:
        ValueError: if the covariance matrix is not positive definite.
        RuntimeError: if PyJAGS returns samples of an unexpected shape.
    """
    # Parameters of the bivariate normal
    mu1, mu2 = 0.0, 0.0
    sigma1, sigma2 = 1.0, 1.0
    # Choose a correlation and convert it to the off-diagonal covariance.
    r_corr = 0.9  # correlation in (-1, 1)
    cov12 = r_corr * sigma1 * sigma2  # covariance Cov(X1, X2)

    mu = np.array([mu1, mu2], dtype=float)
    Sigma = np.array([[sigma1**2, cov12], [cov12, sigma2**2]], dtype=float)

    # Sanity check: Sigma is symmetric by construction, so eigvalsh is the
    # appropriate routine (real eigenvalues; eigvals may return complex dtype).
    if not np.all(np.linalg.eigvalsh(Sigma) > 0):
        raise ValueError(
            "Covariance matrix is not positive definite. Adjust parameters."
        )
    Tau = np.linalg.inv(Sigma)  # JAGS dmnorm takes a precision matrix

    # JAGS model: draw x ~ N(mu, Sigma)
    model_code = """
    model {
      x[1:2] ~ dmnorm(mu[1:2], Tau[1:2,1:2])
    }
    """

    data = {
        "mu": mu.tolist(),
        "Tau": Tau.tolist(),
    }

    # Build and run the model
    chains = 4
    adapt_steps = 1000
    burnin_steps = 1000
    draws = 5000

    model = pyjags.Model(model_code, data=data, chains=chains, adapt=adapt_steps)
    model.update(burnin_steps)

    # Sample from the prior (the specified joint).
    samples = model.sample(draws, vars=["x"])

    # PyJAGS lays each variable out as (*variable_dims, iterations, chains),
    # i.e. (2, draws, chains) here -- the vector dimension comes FIRST.
    # Reshaping that directly to (-1, 2) would pair unrelated values, so check
    # the layout and move the vector dimension to the end before flattening.
    x_raw = np.asarray(samples["x"])
    expected_shape = (2, draws, chains)
    if x_raw.shape != expected_shape:
        raise RuntimeError(
            f"Unexpected shape for samples['x']: {x_raw.shape}, "
            f"expected {expected_shape} (dims, iterations, chains)."
        )
    x_by_chain = np.moveaxis(x_raw, 0, -1)  # (draws, chains, 2)
    x_all = x_by_chain.reshape(-1, 2)  # pool chains: one (X1, X2) pair per row

    # Print empirical summaries
    print("Target mean:", mu)
    print("Target covariance:\n", Sigma)
    print("Empirical mean:", x_all.mean(axis=0))
    print("Empirical covariance:\n", np.cov(x_all, rowvar=False))

    def _style_axes(title):
        """Add zero reference lines, equal aspect, title and labels to current axes."""
        plt.axhline(0, color="gray", lw=0.5)
        plt.axvline(0, color="gray", lw=0.5)
        plt.gca().set_aspect("equal", adjustable="box")
        plt.title(title)
        plt.xlabel("X1")
        plt.ylabel("X2")

    # Scatter of all draws
    plt.figure(figsize=(5, 5))
    plt.scatter(x_all[:, 0], x_all[:, 1], s=6, alpha=0.35)
    _style_axes("PyJAGS: Samples from bivariate normal")

    # Path plot for a single chain
    chain0 = x_by_chain[:, 0, :]  # shape (draws, 2)
    plt.figure(figsize=(5, 5))
    plt.plot(chain0[:, 0], chain0[:, 1], "-o", ms=2, lw=0.5, alpha=0.7)
    _style_axes("Single-chain path")

    plt.show()

In [None]:
# Entry point: run the demo only when this code is executed as the main
# module, not when it is imported.
if __name__ == "__main__":
    main()