# Simplified Confidence and Prediction Intervals with Test-Set Dots

This notebook:
1. Simulates $y \sim N(\mu=10, \sigma^2=81)$ with $n=300$ observations.
2. Splits into 150 training and 150 test observations.
3. Computes a 95% confidence interval for the population mean using the training set.
4. Computes a 95% "prediction interval" for the population mean using the test set (strictly speaking, this is a second confidence interval, centred on the test-set mean).
5. Computes a 95% prediction interval for individual observations using the training estimate.
6. Plots a vertical stack of histograms of the test data, progressively adding the true mean, the training estimate, the confidence interval, and the prediction interval, with the test-set points overlaid as dots.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

# Fix the global NumPy seed first so every run draws the same sample
np.random.seed(0)

# Population parameters: mean 10, standard deviation 9 (variance 81)
mu_true, sigma = 10.0, 9.0

# Draw 300 normal observations; the first 150 form the training set,
# the remaining 150 the test set
y = np.random.normal(mu_true, sigma, 300)
y_train, y_test = np.split(y, [150])

In [None]:
# Two-sided 95% critical value of the standard normal (about 1.96)
z = norm.ppf(0.975)

# Confidence interval for mean using training set.
# sigma is treated as known, so the standard error of the sample mean
# is sigma / sqrt(n).
n_train = len(y_train)
mu_hat_train = y_train.mean()
se_train = sigma / np.sqrt(n_train)
ci_lower = mu_hat_train - z * se_train
ci_upper = mu_hat_train + z * se_train

# "Prediction interval" for population mean using test set.
# NOTE: this is the same construction as the interval above, applied to the
# test-set mean, so statistically it is a confidence interval; the
# "prediction interval" name is kept for the variable names and the printed label.
n_test = len(y_test)
mu_hat_test = y_test.mean()
se_test = sigma / np.sqrt(n_test)
pi_mu_test_lower = mu_hat_test - z * se_test
pi_mu_test_upper = mu_hat_test + z * se_test

# Prediction interval for individual observations (train estimate).
# A new observation is independent of the training mean, so the variance of
# (y_new - mu_hat_train) is sigma^2 + sigma^2 / n_train. Using sigma alone
# would ignore the uncertainty in mu_hat_train and make the interval
# slightly too narrow.
se_pred = sigma * np.sqrt(1.0 + 1.0 / n_train)
pi_obs_lower = mu_hat_train - z * se_pred
pi_obs_upper = mu_hat_train + z * se_pred

# Display results
print(f"Training mean estimate: {mu_hat_train:.3f}")
print(f"95% CI for population mean (train): [{ci_lower:.3f}, {ci_upper:.3f}]")
print(f"Test mean estimate:     {mu_hat_test:.3f}")
print(f"95% PI for population mean (test):  [{pi_mu_test_lower:.3f}, {pi_mu_test_upper:.3f}]")
print(f"95% PI for individual obs:          [{pi_obs_lower:.3f}, {pi_obs_upper:.3f}]")

In [None]:
# Plotting vertical sequence of subplots with test-set dots.
# Each panel repeats the histogram and test points, then adds one more group
# of reference lines than the panel above it.
fig, axes = plt.subplots(4, 1, figsize=(8, 20), sharex=True)

# Reference-line groups in the order they are introduced;
# panel k draws groups 0..k.
line_groups = [
    [(mu_true, dict(color='red', linestyle='--', label='True Mean'))],
    [(mu_hat_train, dict(color='blue', linestyle='-', label='Estimated Mean'))],
    [
        (ci_lower, dict(color='green', linestyle='-.', label='95% CI Lower')),
        (ci_upper, dict(color='green', linestyle='-.', label='95% CI Upper')),
    ],
    [
        (pi_obs_lower, dict(color='purple', linestyle=':', label='95% PI Lower')),
        (pi_obs_upper, dict(color='purple', linestyle=':', label='95% PI Upper')),
    ],
]

panel_titles = [
    '(a) Histogram of Test Data with True Mean',
    '(b) + Estimated Mean',
    '(c) + 95% Confidence Interval for Mean',
    '(d) + 95% Prediction Interval for Individual Observations',
]

for panel_idx, (ax, title) in enumerate(zip(axes, panel_titles)):
    ax.hist(y_test, bins=20, alpha=0.7)
    # Test observations drawn as dots along the baseline (y = 0);
    # labelled in every panel so each legend is complete
    ax.scatter(y_test, np.zeros_like(y_test), marker='o', color='black',
               alpha=0.6, label='Test Points')
    for group in line_groups[:panel_idx + 1]:
        for x_pos, line_style in group:
            ax.axvline(x_pos, **line_style)
    ax.set_title(title)
    ax.set_ylabel('Count')
    ax.legend()

# The x-axis is shared, so only the bottom panel needs the label
axes[-1].set_xlabel('y')

plt.tight_layout()
plt.show()