In [5]:
import sys
# Install PyMC with the pip belonging to the interpreter that runs this kernel
# (`sys.executable`), rather than whichever `pip` happens to be first on PATH.
# NOTE(review): the package version is not pinned, so a later re-run may pull a
# different PyMC release — consider pinning once a known-good version is chosen.
!{sys.executable} -m pip install pymc





[notice] A new release of pip is available: 24.0 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [None]:
import pymc as pm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --- Synthetic data ---------------------------------------------------------
# A noisy sinusoidal covariate `x` drives the response `y` with slope 2.
# From index `intervention` onward `y` is shifted up by 1, which plays the
# role of the causal effect the model should recover.
np.random.seed(42)
n = 100
intervention = 70

# Index ranges before / after the intervention
pre_period = slice(0, intervention)
post_period = slice(intervention, n)

time = np.arange(n)

# Control series: sine wave plus Gaussian noise (sd 0.3)
x = np.sin(time / 10) + np.random.normal(loc=0, scale=0.3, size=n)

# Treated series: linear in x plus Gaussian noise (sd 0.2), then the
# post-intervention shift
y = 2 * x + np.random.normal(loc=0, scale=0.2, size=n)
y[post_period] += 1

with pm.Model() as model:
    # Local level: a Gaussian random walk gives every time step its own
    # intercept. There is no slope state, so this is a local-level model
    # rather than a local linear trend.
    level = pm.GaussianRandomWalk("level", sigma=0.1, shape=n)

    # Regression coefficient on the control series x
    beta = pm.Normal("beta", mu=0, sigma=1)

    # Observation noise (standard deviation of the Normal likelihood)
    sigma_obs = pm.Exponential("sigma_obs", 1.0)

    # Expected value of y at every time step (pre- and post-intervention)
    mu = level + beta * x

    # Likelihood: only the pre-intervention observations inform the model
    y_obs = pm.Normal("y_obs", mu=mu[pre_period], sigma=sigma_obs, observed=y[pre_period])

    # Counterfactual for the post period: an unobserved Normal with the same
    # mean structure and noise. Because it has no `observed=` data it is
    # sampled alongside the parameters and ends up in `trace.posterior`.
    y_pred = pm.Normal("y_pred", mu=mu[post_period], sigma=sigma_obs, shape=mu[post_period].shape)

    # Seed the sampler explicitly: the global `np.random.seed` call used for
    # the simulated data is not what seeds `pm.sample`, so without
    # `random_seed` every run yields different draws.
    trace = pm.sample(1000, tune=1000, target_accept=0.95, random_seed=42)

# Posterior summaries of the regression coefficient and observation noise
import arviz as az
az.plot_posterior(trace, var_names=["beta", "sigma_obs"])
plt.show()

# Counterfactual draws for the post period: one row per post-intervention
# time step, one column per stacked (chain, draw) sample
y_post_pred = trace.posterior["y_pred"].stack(draws=("chain", "draw")).values
y_post_mean = y_post_pred.mean(axis=1)

# Compute the interval from the labelled InferenceData instead of a bare 2-D
# array: ArviZ warns that its interpretation of 2-D NumPy input is going to
# change, whereas named chain/draw dimensions are unambiguous. The result has
# one row per post-period step with columns (lower, upper).
y_post_hpd = az.hdi(trace, var_names=["y_pred"], hdi_prob=0.95)["y_pred"].values

# Observed series with the counterfactual mean and interval overlaid
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(time, y, label="Observed")
ax.plot(time[post_period], y_post_mean, label="Counterfactual", color='orange')
ax.fill_between(time[post_period], y_post_hpd[:, 0], y_post_hpd[:, 1], color='orange', alpha=0.3, label='95% CI')
ax.axvline(intervention, color='gray', linestyle='--')  # intervention time
ax.legend()
ax.set_title("Causal Impact via Bayesian Structural Time-Series (PyMC)")
ax.set_xlabel("Time")
ax.set_ylabel("y")
plt.show()


Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [level, beta, sigma_obs, y_pred]


In [None]:
import scipy.stats as stats

# --- Gibbs sampler setup ----------------------------------------------------
n_iter = 1000
# NOTE(review): burn_in is defined here but not referenced in the visible
# code; presumably a later cell discards the first `burn_in` draws — confirm.
burn_in = 200

# Pre-intervention covariate / response used to fit the model
X, Y = x[pre_period], y[pre_period]
n_total = len(Y)

# Hyperparameters
sigma2_mu = 1.0           # variance of the local-level innovations
a0, b0 = 2.0, 1.0         # Inverse-Gamma(a0, b0) prior on the observation variance
beta_prior_mean = 0       # Normal prior on beta: mean ...
beta_prior_var = 10.0     # ... and variance

# Storage for the draws of every iteration
mu_samples = np.zeros(shape=(n_iter, n_total))
beta_samples = np.zeros(shape=n_iter)
sigma2_samples = np.zeros(shape=n_iter)

# Starting state of the chain
mu = np.zeros(shape=n_total)
beta = 0.0
sigma2 = 1.0

for i in range(n_iter):
    # ---- Step 1: draw the whole level path mu via FFBS ----------------------
    # (Forward Filtering, Backward Sampling)
    #   observation: y_t  = mu_t + beta * x_t + epsilon_t
    #   state:       mu_t = mu_{t-1} + eta_t

    # Forward pass: Kalman filter means and variances
    mu_filt = np.zeros(n_total)
    P = np.zeros(n_total)
    mu_pred, P_pred = 0, 1.0

    for t in range(n_total):
        # Predict: a random walk keeps the mean and adds the innovation variance
        P_pred += sigma2_mu

        # Update with the response after removing the regression part
        detrended = Y[t] - beta * X[t]
        gain = P_pred / (P_pred + sigma2)
        mu_filt[t] = mu_pred + gain * (detrended - mu_pred)
        P[t] = (1 - gain) * P_pred

        # The filtered moments become the next step's prior
        mu_pred, P_pred = mu_filt[t], P[t]

    # Backward pass: sample the last state, then each earlier state
    # conditional on the state drawn after it
    last = n_total - 1
    mu[last] = np.random.normal(mu_filt[last], np.sqrt(P[last]))
    for t in range(last - 1, -1, -1):
        back_gain = P[t] / (P[t] + sigma2_mu)
        cond_mean = mu_filt[t] + back_gain * (mu[t + 1] - mu_filt[t])
        cond_var = P[t] - P[t] ** 2 / (P[t] + sigma2_mu)
        mu[t] = np.random.normal(cond_mean, np.sqrt(cond_var))

    # ---- Step 2: draw beta | mu, sigma^2 (conjugate Normal update) ----------
    level_removed = Y - mu
    beta_precision = np.sum(X**2) / sigma2 + 1 / beta_prior_var
    V_beta = 1 / beta_precision
    M_beta = V_beta * (np.sum(X * level_removed) / sigma2 + beta_prior_mean / beta_prior_var)
    beta = np.random.normal(M_beta, np.sqrt(V_beta))

    # ---- Step 3: draw sigma^2 | mu, beta (conjugate Inverse-Gamma update) ---
    resid = Y - mu - beta * X
    sigma2 = stats.invgamma.rvs(
        a=a0 + n_total / 2,
        scale=b0 + 0.5 * np.sum(resid**2),
    )

    # ---- Record this iteration's state --------------------------------------
    mu_samples[i], beta_samples[i], sigma2_samples[i] = mu, beta, sigma2
