# Conformal Prediction Under Distribution Shift: Beyond I.I.D. Assumptions

## Introduction

In the previous lecture, we explored the foundations of Conformal Prediction, which provides powerful finite-sample coverage guarantees for prediction sets under the assumption that data is Independent and Identically Distributed (I.I.D.). While the I.I.D. assumption simplifies theoretical analysis and is often a starting point in machine learning, real-world data rarely adheres to it perfectly. Data distributions can *shift* over time or across different collection environments, posing significant challenges for uncertainty quantification.

This notebook delves into advanced conformal methods designed to handle **distribution shifts**, focusing on scenarios where the underlying data generating process changes between training and testing. We will cover:

* **Likelihood-Weighted Conformal Prediction**: A primary method for handling *covariate shift*, where the distribution of features ($X$) changes, but the conditional distribution of the response given features ($Y|X$) remains constant.
* **Estimating Likelihood Ratios**: Practical techniques to estimate the necessary weights for likelihood-weighted conformal prediction using unlabeled data.
* **Conformal Prediction for Structured-X Settings**: A more general theoretical framework for complex feature dependencies.
* **Custom-Weighted Conformal Prediction**: A flexible approach that allows for arbitrary, fixed weights to prioritize certain data points.
* **Adaptive Conformal Inference (ACI)**: An online method for sequential prediction that continuously adjusts to distribution changes over time.

Understanding these extensions is crucial for deploying reliable machine learning models in dynamic environments. As before, we will use `JAX` and `scikit-learn` for practical code examples to illustrate these concepts.

## 1. Likelihood-Weighted Conformal Prediction

A common and important type of distribution shift is **covariate shift**. In this scenario, we assume:

* Training data: $(X_i, Y_i) \sim P = P_X \times P_{Y|X}$ independently, for $i=1, \ldots, n$.
* Test data: $(X_{n+1}, Y_{n+1}) \sim \tilde{P} = \tilde{P}_X \times P_{Y|X}$ independently.

Here, the conditional distribution of $Y|X$ (the relationship between features and response) is assumed to be the same for both training and test data, but the distribution of features $X$ is allowed to change, i.e., $\tilde{P}_X \ne P_X$.

### Why Standard Conformal Prediction Fails Under Covariate Shift

As empirically shown in Tibshirani et al. (2019) (Figure 1 in the original slides), if you apply standard split conformal prediction when covariate shift is present, the coverage guarantee often fails. The prediction intervals tend to under-cover. This happens because the calibration set, drawn from $P_X$, might not be representative of the test point, which is drawn from $\tilde{P}_X$. The notion of exchangeability, crucial for the rank-based quantile argument, is broken when the training and test sets come from different feature distributions.

To remedy this, we need to adapt the empirical distribution of conformity scores using **weights** that account for the shift. These weights are based on the likelihood ratio between the test and training feature distributions.

### Revisiting Exchangeability and Quantiles (with Weights)

Recall that the core of conformal prediction relies on the rank-based argument stemming from exchangeability. We need to generalize this to a "weighted exchangeability" setting. The key idea is that even if the random variables $R_1, \ldots, R_{n+1}$ are not perfectly exchangeable, their probabilities can be related to a permutation-invariant function $g$ and individual weight functions $w_i(r_i)$.

A sequence of random variables $R_1, \ldots, R_{n+1}$ is **weighted exchangeable** if their joint density (or mass function) can be written as:

$$f(r_1, \ldots, r_{n+1}) = \prod_{i=1}^{n+1} w_i(r_i) \cdot g(r_1, \ldots, r_{n+1})$$

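where $g$ is any permutation-invariant function. In this case, the probability that $R_{n+1}$ takes a specific value $r_i$ from the observed set $\{r_1, \ldots, r_{n+1}\}$ (conditioned on observing this set) is no longer $1/(n+1)$. In general it is a ratio of sums over permutations $\sigma$ of $\{1, \ldots, n+1\}$:

$$\mathbb{P}(R_{n+1}=r_i | \{R_1, \ldots, R_{n+1}\} = \{r_1, \ldots, r_{n+1}\}) = \frac{\sum_{\sigma:\sigma(n+1)=i}\prod_{j=1}^{n+1} w_j(r_{\sigma(j)})}{\sum_{\sigma}\prod_{j=1}^{n+1} w_j(r_{\sigma(j)})}$$

When only the last weight function is non-trivial ($w_1 = \cdots = w_n \equiv 1$ and $w_{n+1} = w$), which is the case relevant to covariate shift, this simplifies to:

$$\mathbb{P}(R_{n+1}=r_i | \{R_1, \ldots, R_{n+1}\} = \{r_1, \ldots, r_{n+1}\}) = \frac{w(r_i)}{\sum_{j=1}^{n+1} w(r_j)}$$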

This leads to **Lemma 1 (Quantile Lemma)** from the slides, which states that for weighted exchangeable random variables $Z_1, \ldots, Z_{n+1}$ with weight functions $w_1, \ldots, w_{n+1}$, and scores $R_i$ computed with a symmetric score function $V$, the test score $R_{n+1}$ is bounded by a quantile of a weighted empirical distribution of scores:

$$\mathbb{P}\left\{R_{n+1}\le \text{Quantile}\left(1-\alpha;\sum_{i=1}^{n}p_{i}^{w}(Z_{1},...,Z_{n+1})\delta_{R_{i}}+p_{n+1}^{w}(Z_{1},...,Z_{n+1})\delta_{\infty}\right)\right\}\ge1-\alpha.$$

where $p_i^w$ are the normalized weights that sum to 1.
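
The quantile in the lemma is taken over a discrete distribution with a point mass at $+\infty$. A minimal NumPy sketch of that computation follows; the helper name `weighted_quantile_with_inf` and the toy scores are illustrative assumptions, not part of the original slides.

```python
import numpy as np

def weighted_quantile_with_inf(scores, probs, p_inf, level):
    """Quantile of sum_i probs[i] * delta_{scores[i]} + p_inf * delta_{+inf}.

    Hypothetical helper: assumes probs and p_inf are nonnegative and sum to 1.
    Returns the smallest value whose cumulative probability is >= level.
    """
    scores = np.asarray(scores, dtype=float)
    probs = np.asarray(probs, dtype=float)
    order = np.argsort(scores)
    cum = np.cumsum(probs[order])
    # First sorted position where the cumulative mass reaches the level.
    idx = np.searchsorted(cum, level, side="left")
    if idx >= len(scores):
        # The level is only reached by the point mass at infinity.
        return np.inf
    return scores[order][idx]

# Toy example: n = 4 scores plus the test point, all with mass 1/5.
scores = [0.3, 1.2, 0.7, 2.5]
q50 = weighted_quantile_with_inf(scores, [0.2] * 4, 0.2, level=0.5)
q90 = weighted_quantile_with_inf(scores, [0.2] * 4, 0.2, level=0.9)
print(q50, q90)
```

With only four calibration scores, the 0.9 level cannot be reached by the finite atoms, so the quantile is infinite and the corresponding prediction set would be the whole real line.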

### Application to Covariate Shift (Corollary 1)

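For the covariate shift problem, the weights simplify considerably. Let $w(x) = d\tilde{P}_X(x) / dP_X(x)$ be the likelihood ratio (density ratio) between the test and training feature distributions. The data points $Z_i = (X_i, Y_i)$ are then weighted exchangeable with weight functions $w_i \equiv 1$ for the *training* points (as they come from $P_X$) and $w_{n+1}(x, y) = w(x)$ for the *test* point. Plugging these into the quantile lemma gives normalized probabilities proportional to $w(X_i)$ for each training point and to $w(x)$ for the test point.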

This leads to **Corollary 1** (Tibshirani et al., 2019), which defines the weighted conformal set as:

$$\hat{C}_{n}^{w}(x)=\left\{y:R_{n+1}^{(x,y)}\le \text{Quantile}\left(1-\alpha;\sum_{i=1}^{n}\pi_{i}^{w}(x)\delta_{R_{i}^{(x,y)}}+\pi_{n+1}^{w}(x)\delta_{\infty}\right)\right\}$$

where $R_i^{(x,y)}$ are conformity scores (e.g., absolute residuals) based on a base predictor, and the normalized weights $\pi_i^w(x)$ are given by:

$$\pi_{i}^{w}(x)=\frac{w(X_{i})}{\sum_{j=1}^{n}w(X_{j})+w(x)},\quad i=1,\ldots,n$$
$$\pi_{n+1}^{w}(x)=\frac{w(x)}{\sum_{j=1}^{n}w(X_{j})+w(x)}$$

This weighted conformal set guarantees $\mathbb{P}(Y_{n+1}\in\hat{C}_{n}^{w}(X_{n+1}))\ge1-\alpha$. In a split conformal setup, this is conditional on the proper training set $D_1$.
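
In the split setting the scores $R_i$ do not depend on $(x, y)$, so the set reduces to a threshold on the test score that depends on $x$ only through $w(x)$. The sketch below computes that threshold for a single test point, assuming the likelihood ratios are already known; the function name `covariate_shift_threshold` and the toy inputs are hypothetical.

```python
import numpy as np

def covariate_shift_threshold(cal_scores, w_cal, w_test, alpha):
    """Weighted split-conformal threshold for one test point (illustrative sketch).

    cal_scores: calibration scores R_1..R_n
    w_cal:      likelihood ratios w(X_1)..w(X_n) at the calibration points
    w_test:     likelihood ratio w(x) at the test point
    """
    cal_scores = np.asarray(cal_scores, dtype=float)
    w_cal = np.asarray(w_cal, dtype=float)
    total = w_cal.sum() + w_test
    pi_cal = w_cal / total    # pi_i^w(x), i = 1..n
    pi_test = w_test / total  # pi_{n+1}^w(x), the mass placed at +infinity
    order = np.argsort(cal_scores)
    cum = np.cumsum(pi_cal[order])
    idx = np.searchsorted(cum, 1 - alpha, side="left")
    threshold = cal_scores[order][idx] if idx < len(cal_scores) else np.inf
    return threshold, pi_cal, pi_test

scores = np.arange(1.0, 10.0)  # toy scores 1, 2, ..., 9

# With w = 1 everywhere the rule matches standard split conformal:
# the ceil((1 - alpha) * (n + 1)) = 8th smallest score for n = 9, alpha = 0.25.
t_uniform, pi_cal, pi_test = covariate_shift_threshold(scores, np.ones(9), 1.0, alpha=0.25)

# A test point with a large likelihood ratio puts too much mass at infinity.
t_shift, _, _ = covariate_shift_threshold(scores, np.ones(9), 5.0, alpha=0.25)
print(t_uniform, t_shift)
```

The second call illustrates a practical consequence of the formula: when $w(x)$ is large relative to the calibration weights, the finite atoms cannot reach mass $1-\alpha$ and the interval becomes unbounded.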

### Impact on Coverage and Length

The middle row of Figure 1 in the original slides demonstrates that using these oracle weights restores coverage in the presence of covariate shift. However, the dispersion (variability) in coverage can be larger, and the prediction intervals tend to be wider, compared to standard conformal prediction without any shift but with an equivalent effective sample size. This is because non-uniform weights effectively reduce the "effective" number of calibration samples.
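
One common way to quantify the "effective" number of calibration samples is the heuristic $\hat{n} = (\sum_i w_i)^2 / \sum_i w_i^2$. The sketch below assumes that this is the intended notion; the text above does not specify a formula, and the weight vectors are made up for illustration.

```python
import numpy as np

def effective_sample_size(weights):
    """Heuristic effective sample size: (sum w)^2 / sum(w^2). Assumed definition."""
    w = np.asarray(weights, dtype=float)
    return w.sum() ** 2 / np.sum(w ** 2)

# Uniform weights keep all 100 samples "effective".
n_eff_uniform = effective_sample_size(np.ones(100))

# One dominant weight sharply reduces the effective count.
n_eff_skewed = effective_sample_size([10.0] + [1.0] * 99)
print(n_eff_uniform, n_eff_skewed)
```

A small value relative to $n$ signals that a few calibration points dominate the weighted quantile, which is when wider and more variable intervals should be expected.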

### Code Example: Likelihood-Weighted Split Conformal Prediction

This example simulates a covariate shift scenario and demonstrates how likelihood weighting can restore coverage. We will simulate data from two different $P_X$ distributions (training and test) but with the same $P_{Y|X}$. We will then estimate the likelihood ratio $w(x)$ using a logistic regression classifier trained to distinguish between the two $X$ distributions.

In [None]:
import jax.numpy as jnp
import jax.random as random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.preprocessing import StandardScaler

# Set a random seed for reproducibility
key = random.PRNGKey(0)

# --- 1. Simulate Data with Covariate Shift ---
num_samples_train = 500
num_samples_test = 500
noise_std = 5.0

# Training data (P_X)
# Use separate subkeys for the features and the noise; a JAX key should not be reused
key, x_key, noise_key = random.split(key, 3)
X_train_P = random.normal(x_key, (num_samples_train, 1)) * 2 # Broader distribution
y_train_P = 2 * X_train_P.squeeze() + 3 + random.normal(noise_key, (num_samples_train,)) * noise_std

# Test data (tilde_P_X) - Shifted distribution
# Again split into independent subkeys for the features and the noise
key, x_test_key, noise_test_key = random.split(key, 3)
X_test_tildeP = random.normal(x_test_key, (num_samples_test, 1)) * 0.5 + 2 # Narrower, shifted distribution
y_test_tildeP = 2 * X_test_tildeP.squeeze() + 3 + random.normal(noise_test_key, (num_samples_test,)) * noise_std

print(f"Training data shape: {X_train_P.shape}, {y_train_P.shape}")
print(f"Test data shape: {X_test_tildeP.shape}, {y_test_tildeP.shape}\n")

# --- 2. Split Training Data into Proper Training (D1) and Calibration (D2) ---
# random_state must be a plain integer; JAX key arrays have no `.get` method
X_D1, X_D2, y_D1, y_D2 = train_test_split(
    X_train_P, y_train_P, test_size=0.5, random_state=0
)
n1 = X_D1.shape[0]
n2 = X_D2.shape[0]
print(f"Proper training set size (n1): {n1}")
print(f"Calibration set size (n2): {n2}\n")

# --- 3. Train Base Predictor on Proper Training Set (D1) ---
base_predictor = LinearRegression()
base_predictor.fit(np.array(X_D1), np.array(y_D1))

# --- 4. Estimate Likelihood Ratio (w(x)) ---
# Combine X from P_X (training) and tilde_P_X (test) and assign labels 0 and 1
X_combined = jnp.vstack([X_D2, X_test_tildeP]) # Use D2 for training classifier, as D2 is from P_X
labels = jnp.concatenate([jnp.zeros(X_D2.shape[0]), jnp.ones(X_test_tildeP.shape[0])])

# Scale X for logistic regression stability
scaler = StandardScaler()
X_combined_scaled = scaler.fit_transform(np.array(X_combined))
X_test_tildeP_scaled = scaler.transform(np.array(X_test_tildeP))
X_D2_scaled = scaler.transform(np.array(X_D2))

# Train a logistic regression classifier to predict if X comes from P_X (0) or tilde_P_X (1)
prob_classifier = LogisticRegression(random_state=0, solver='lbfgs')
prob_classifier.fit(X_combined_scaled, np.array(labels))

# Predict P(C=1|X=x) for calibration set (D2) and test set
p_D2 = prob_classifier.predict_proba(X_D2_scaled)[:, 1] # Probability of being from tilde_P_X
p_test = prob_classifier.predict_proba(X_test_tildeP_scaled)[:, 1]

# Estimate weights w(x) = p(C=1|x) / p(C=0|x) = p(C=1|x) / (1-p(C=1|x))
# w(x) = d(tilde_P_X) / d(P_X)
# Add small epsilon to avoid division by zero or inf
epsilon = 1e-6
weights_D2 = (p_D2 + epsilon) / (1 - p_D2 + epsilon)
weights_test = (p_test + epsilon) / (1 - p_test + epsilon)

print(f"Estimated weights for calibration set (first 5): {weights_D2[:5]}")
print(f"Estimated weights for test set (first 5): {weights_test[:5]}\n")

# --- 5. Compute Calibration Scores (Absolute Residuals) on D2 ---
y_D2_pred = base_predictor.predict(np.array(X_D2))
calibration_residuals = jnp.abs(y_D2 - jnp.array(y_D2_pred))

# --- 6. Compute Weighted Conformal Quantiles (one per test point) ---
alpha_level = 0.1 # Desired coverage: 1 - alpha = 0.9 (90%)

# Following Corollary 1, the quantile is taken over the weighted empirical distribution
#   sum_i pi_i^w(x) * delta_{R_i} + pi_{n+1}^w(x) * delta_inf,
# where pi_i^w(x) is proportional to w(X_i) and pi_{n+1}^w(x) is proportional to w(x).
# The test point contributes its own weight w(x) (not 1), so the threshold depends on x.

# Sort calibration residuals and carry their weights along
residuals_np = np.asarray(calibration_residuals)
sorted_indices = np.argsort(residuals_np)
sorted_residuals = residuals_np[sorted_indices]
sorted_weights = np.asarray(weights_D2)[sorted_indices]

# Unnormalized weighted empirical CDF of the calibration scores
cumulative_weights = np.cumsum(sorted_weights)
total_calibration_weight = cumulative_weights[-1]

# For each test point, find the smallest residual whose cumulative weight reaches
# (1 - alpha) * (sum of calibration weights + w(x_test)).
target_weights = (1 - alpha_level) * (total_calibration_weight + weights_test)
quantile_indices = np.searchsorted(cumulative_weights, target_weights, side='left')

# If the target exceeds the total calibration weight, the quantile falls on the
# point mass at infinity and the interval is the whole real line.
reached = quantile_indices < n2
conformal_quantile_weighted = np.where(
    reached, sorted_residuals[np.minimum(quantile_indices, n2 - 1)], np.inf
)

print(f"Weighted conformal quantiles (first 5): {conformal_quantile_weighted[:5]}\n")

# --- 7. Form Conformal Prediction Band ---
# Prediction on the test set (tilde_P_X)
y_test_pred_mean = base_predictor.predict(np.array(X_test_tildeP))

# The prediction intervals are centered at the mean prediction, with the weighted quantile as half-width:
# [mean - quantile, mean + quantile]
lower_bound_weighted = y_test_pred_mean - conformal_quantile_weighted
upper_bound_weighted = y_test_pred_mean + conformal_quantile_weighted

# --- Plotting Test Data and Prediction Band ---
plt.figure(figsize=(12, 8))
plt.scatter(X_train_P[:, 0], y_train_P, color='blue', label='Training Data (P_X)', alpha=0.5, s=20)
plt.scatter(X_test_tildeP[:, 0], y_test_tildeP, color='orange', label=r'Test Data ($\tilde{P}_X$)', alpha=0.5, s=20)
plt.plot(X_test_tildeP[jnp.argsort(X_test_tildeP.squeeze()), 0], y_test_pred_mean[jnp.argsort(X_test_tildeP.squeeze())], color='red', linewidth=2, label='Regression Mean')

# Plot the weighted conformal prediction band for the test data range
plt.fill_between(
    X_test_tildeP[jnp.argsort(X_test_tildeP.squeeze()), 0].squeeze(),
    lower_bound_weighted[jnp.argsort(X_test_tildeP.squeeze())],
    upper_bound_weighted[jnp.argsort(X_test_tildeP.squeeze())],
    color='red', alpha=0.2, label=rf'Weighted Conformal Band (1-$\alpha$={1-alpha_level})'
)

plt.xlabel('X')
plt.ylabel('Y')
plt.title('Likelihood-Weighted Conformal Prediction Under Covariate Shift')
plt.legend()
plt.grid(True)
plt.show()

# --- Simulate Test Coverage to Verify Guarantee ---
coverage_count_weighted = 0
for i in range(num_samples_test):
    if (y_test_tildeP[i] >= lower_bound_weighted[i]) and \
       (y_test_tildeP[i] <= upper_bound_weighted[i]):
        coverage_count_weighted += 1

simulated_coverage_weighted = coverage_count_weighted / num_samples_test
print(f"Simulated weighted conformal coverage on test set: {simulated_coverage_weighted:.4f}")
print(f"Desired coverage (at least): {1 - alpha_level:.4f}")
print(f"Upper bound for unweighted split conformal under exchangeability (reference only): {1 - alpha_level + 1/(n2 + 1):.4f}\n")

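# --- Compare with Naive Conformal (without weighting) ---
# This simulates standard split conformal prediction under covariate shift.
# Undercoverage appears when the score distribution differs between P_X and tilde_P_X.
# Here the linear model is well specified and the noise is homoscedastic, so the
# naive and weighted coverages may turn out to be close.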

naive_conformal_quantile_idx = jnp.ceil((1 - alpha_level) * (n2 + 1)).astype(int)
naive_conformal_quantile = jnp.sort(calibration_residuals)[naive_conformal_quantile_idx - 1]

lower_bound_naive = y_test_pred_mean - naive_conformal_quantile
upper_bound_naive = y_test_pred_mean + naive_conformal_quantile

coverage_count_naive = 0
for i in range(num_samples_test):
    if (y_test_tildeP[i] >= lower_bound_naive[i]) and \
       (y_test_tildeP[i] <= upper_bound_naive[i]):
        coverage_count_naive += 1

simulated_coverage_naive = coverage_count_naive / num_samples_test
print(f"Simulated NAIVE conformal coverage on test set (may undercover under shift): {simulated_coverage_naive:.4f}")
print(f"Desired coverage: {1 - alpha_level:.4f}")

## 2. Estimating the Likelihood Ratio from Unlabeled Data

In real-world scenarios, the true likelihood ratio $w(x) = d\tilde{P}_X(x) / dP_X(x)$ is typically unknown. However, if we have access to unlabeled data $X_{n+1}, \ldots, X_{n+m}$ from the test distribution $\tilde{P}_X$ (along with the training data $X_1, \ldots, X_n$ from $P_X$), we can estimate this ratio.

The core idea is to frame the problem as a **binary classification task**:

1.  Create a combined dataset of features: $X_{\text{combined}} = (X_1, \ldots, X_n, X_{n+1}, \ldots, X_{n+m})$.
2.  Assign binary labels to these features:
    * $C_i = 0$ for $X_i$ coming from $P_X$ (i.e., for $i=1, \ldots, n$).
    * $C_i = 1$ for $X_i$ coming from $\tilde{P}_X$ (i.e., for $i=n+1, \ldots, n+m$).
3.  Train a probabilistic classifier (e.g., Logistic Regression, Random Forest Classifier) on $(X_{\text{combined}}, C_{\text{labels}})$ to estimate $P(C=1|X=x)$. Let this estimate be $\hat{p}(x)$.

The crucial insight is that the likelihood ratio $w(x)$ is directly related to the conditional odds ratio of this binary classification problem:

$$\frac{P(C=1|X=x)}{P(C=0|X=x)} = \frac{P(C=1)}{P(C=0)} \frac{d\tilde{P}_X(x)}{dP_X(x)}$$
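
To see why, write $p$ and $\tilde{p}$ for the densities of $P_X$ and $\tilde{P}_X$ (assuming both exist with respect to a common dominating measure). Bayes' rule applied to the pooled sample gives:

$$P(C=1|X=x) = \frac{P(C=1)\,\tilde{p}(x)}{P(C=0)\,p(x) + P(C=1)\,\tilde{p}(x)}, \qquad P(C=0|X=x) = \frac{P(C=0)\,p(x)}{P(C=0)\,p(x) + P(C=1)\,\tilde{p}(x)}$$

Dividing the first expression by the second cancels the shared denominator and leaves $\frac{P(C=1)}{P(C=0)} \cdot \frac{\tilde{p}(x)}{p(x)}$. With $n$ training and $m$ test features in the pooled sample, the prior odds $P(C=1)/P(C=0)$ are approximately $m/n$.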

Since we only need the likelihood ratio up to a proportionality constant for the weighted conformal procedure, we can use the following as our estimated weight function:

$$\hat{w}(x) = \frac{\hat{p}(x)}{1 - \hat{p}(x)}$$

The better calibrated the classifier $\hat{p}(x)$ is, the more accurate our estimated weights $\hat{w}(x)$ will be. The bottom row of Figure 1 in the original slides demonstrates the effectiveness of this estimation strategy, showing restored coverage when using weights derived from Logistic Regression or Random Forests.
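
The claim that the proportionality constant does not matter can be checked numerically. The sketch below converts classifier probabilities to odds, optionally rescales them by $n/m$ to undo the class-prior factor, and compares the normalized conformal weights. The helper names and the probability values are hypothetical.

```python
import numpy as np

def odds_weights(p_hat, n_train=None, n_test=None, eps=1e-12):
    """Turn estimates of P(C=1|x) into likelihood-ratio weights (illustrative helper).

    If n_train and n_test are given, the odds are multiplied by n_train / n_test
    to undo the class-prior factor P(C=1)/P(C=0), approximated by n_test / n_train.
    """
    p = np.clip(np.asarray(p_hat, dtype=float), eps, 1 - eps)
    w = p / (1 - p)
    if n_train is not None and n_test is not None:
        w = w * (n_train / n_test)
    return w

def normalize(w_cal, w_test):
    """Normalized weights (pi_1, ..., pi_n, pi_{n+1}) for one test point."""
    total = np.sum(w_cal) + w_test
    return np.append(w_cal, w_test) / total

p_cal = np.array([0.2, 0.5, 0.8])  # made-up classifier outputs on calibration points
p_test = 0.6                       # made-up classifier output on the test point

raw = normalize(odds_weights(p_cal), odds_weights(p_test))
corrected = normalize(odds_weights(p_cal, 300, 100), odds_weights(p_test, 300, 100))
print(raw)
print(corrected)
```

Both printed vectors agree, because any constant factor cancels between numerator and denominator. This only holds when the test point's weight is computed with the same $\hat{w}$ as the calibration weights.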

The code example in the previous section (Likelihood-Weighted Split Conformal Prediction) already incorporates this estimation strategy.

## 3. Conformal Prediction for Structured-X Settings

Beyond simple covariate shift, we can consider even more general scenarios where the features themselves have a complex, dependent structure. This is captured by **Theorem 2** in the original slides, which considers data distributed according to:

* Features: $(X_1, \ldots, X_{n+1}) \sim \Lambda$, where $\Lambda$ can be an arbitrary joint distribution (not necessarily i.i.d.).
* Responses: $Y_i|X_i \sim P_{Y|X}$, independently, for $i=1, \ldots, n+1$.

In this setting, the conformal set is defined using a generalized form of weights $p_i^\lambda(x_1, \ldots, x_{n+1})$ that account for the complex dependencies within the feature vector:

$$\hat{C}_{n}^{\lambda}(x)=\left\{y:R_{n+1}^{(x,y)}\le \text{Quantile}\left(1-\alpha;\sum_{i=1}^{n}p_{i}^{\lambda}(X_{1},...,X_{n},x)\delta_{R_{i}^{(x,y)}}+p_{n+1}^{\lambda}(X_{1},...,X_{n},x)\delta_{\infty}\right)\right\}$$

where $R_i^{(x,y)}$ are conformity scores (as before), and $p_i^\lambda$ are probabilities derived from the density function of $\Lambda$ under permutations:

$$p_{i}^{\lambda}(x_{1},...,x_{n+1})=\frac{\sum_{\sigma:\sigma(n+1)=i}\lambda(x_{\sigma(1)},...,x_{\sigma(n+1)})}{\sum_{\sigma}\lambda(x_{\sigma(1)},...,x_{\sigma(n+1)})}, \quad i=1,\ldots,n+1$$

This theorem provides a very general guarantee of $\mathbb{P}(Y_{n+1}\in\hat{C}_{n}^{\lambda}(X_{n+1}))\ge1-\alpha$. However, computing these weights $p_i^\lambda$ can be extraordinarily difficult due to the combinatorial sums over permutations, making this approach computationally intractable for most real-world scenarios unless $\Lambda$ has a very specific, easily factorizable structure (e.g., Markov property for time series). This method is primarily of theoretical interest, demonstrating the extensibility of conformal principles.
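
Two special cases show how the permutation sums collapse when $\lambda$ factorizes. Here $p$ and $\tilde{p}$ denote the densities of $P_X$ and $\tilde{P}_X$, and $w = \tilde{p}/p$.

If the features are i.i.d., $\lambda(x_1, \ldots, x_{n+1}) = \prod_{j=1}^{n+1} p(x_j)$ is permutation-invariant, so every term in both sums is identical. The numerator contains $n!$ permutations and the denominator $(n+1)!$, giving:

$$p_{i}^{\lambda}(x_{1},...,x_{n+1}) = \frac{n!}{(n+1)!} = \frac{1}{n+1}$$

Under covariate shift, $\lambda(x_1, \ldots, x_{n+1}) = \prod_{j=1}^{n} p(x_j) \cdot \tilde{p}(x_{n+1}) = w(x_{n+1}) \prod_{j=1}^{n+1} p(x_j)$. For a permutation $\sigma$, only the factor $w(x_{\sigma(n+1)})$ changes, so:

$$p_{i}^{\lambda}(x_{1},...,x_{n+1}) = \frac{n!\, w(x_i) \prod_{j} p(x_j)}{\sum_{k=1}^{n+1} n!\, w(x_k) \prod_{j} p(x_j)} = \frac{w(x_i)}{\sum_{k=1}^{n+1} w(x_k)}$$

which recovers the weights $\pi_i^w$ of likelihood-weighted conformal prediction.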

### Code Example: Conformal Prediction for Structured-X Settings (Conceptual)

This example illustrates the theoretical concept of calculating $p_i^\lambda$ for a very small `n`. It defines a simple $\Lambda$ (joint distribution of features) where order matters. **Warning**: This implementation is purely conceptual and will not scale to larger `n` due to the factorial complexity of permutations. It serves to show the *logic* behind the weight calculation, not a practical implementation.

In [None]:
import jax.numpy as jnp
import jax.random as random
import numpy as np
import itertools # For generating permutations
from sklearn.linear_model import LinearRegression

# Set a random seed for reproducibility
key = random.PRNGKey(0)

# --- 1. Simulate Very Small Structured Data ---
# For conceptual illustration, let's use n=2 training points.
# We'll have (X1, Y1), (X2, Y2) as training, and (x_test, y_test) as the query point.
# So total N+1 = 3 points.

num_training_points = 2 # n
noise_std = 0.5

# Define a simple sequential dependency for Lambda (e.g., Markov-like)
# Lambda will be P(X1) * P(X2|X1) * P(X3|X2)

def P_X_initial(x): # Probability of the first X
    return jnp.exp(-((x - 0.0)**2) / (2 * 1.0**2)) / jnp.sqrt(2 * jnp.pi * 1.0**2)

def P_X_conditional(x_curr, x_prev): # P(X_curr | X_prev)
    # Simple dependency: X_curr is drawn from a normal distribution centered at X_prev
    return jnp.exp(-((x_curr - x_prev)**2) / (2 * 0.5**2)) / jnp.sqrt(2 * jnp.pi * 0.5**2)

def Lambda_density(x_sequence): # Lambda(x1, x2, ..., x_N+1)
    # Assumes x_sequence is (x_1, x_2, ..., x_{n+1})
    if len(x_sequence) == 0: return 1.0
    if len(x_sequence) == 1: return P_X_initial(x_sequence[0])
    
    density = P_X_initial(x_sequence[0])
    for i in range(1, len(x_sequence)):
        density *= P_X_conditional(x_sequence[i], x_sequence[i-1])
    return density

# Simulate training X values (these are fixed for the calculation)
X_train_values_fixed = jnp.array([1.0, 2.0]) # Our X1, X2
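# Split the key so the training and test noise draws are independent (a JAX key should not be reused)
key, train_noise_key, test_noise_key = random.split(key, 3)
y_train_values = 2 * X_train_values_fixed + random.normal(train_noise_key, (num_training_points,)) * noise_std

X_test_point_fixed = jnp.array([1.5]) # Our x_test
y_test_point_true = 2 * X_test_point_fixed[0] + random.normal(test_noise_key, (1,)) * noise_std # True y for test point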

print(f"Training X values: {X_train_values_fixed}")
print(f"Test X point: {X_test_point_fixed.item()}\n")

# --- 2. Full Conformal Setup for a Query Point (x_test, y_candidate) ---
# For simplicity, we'll try to find if a specific y_candidate is in the set for x_test
y_candidate_example = jnp.array([4.5]) # A hypothetical y value for x_test

# Combined sequence of X values for the permutation sums
X_all_points = jnp.concatenate([X_train_values_fixed, X_test_point_fixed]).tolist()
n_plus_1 = len(X_all_points)

print(f"All X points (fixed): {X_all_points}")
print(f"Total points (n+1): {n_plus_1}\n")

def calculate_p_lambda(x_values_sequence_fixed, query_x_idx_in_sequence):
    """
    Calculates p_i^lambda weights for the structured-X setting.
    x_values_sequence_fixed: List of fixed X values (x1, ..., xn, x_test).
    query_x_idx_in_sequence: The 0-indexed position of the test point x_test in x_values_sequence_fixed.
    
    NOTE: This is computationally expensive, as it iterates through (n+1)! permutations.
    Only for very small 'n'.
    """
    n_plus_1 = len(x_values_sequence_fixed)
    all_indices = list(range(n_plus_1))
    
    denominator_sum = 0.0
    numerator_sums = jnp.zeros(n_plus_1) # For each p_i^lambda

    # Iterate through all (n+1)! permutations of the indices
    for perm_indices in itertools.permutations(all_indices):
        # Reconstruct the X sequence based on the permutation of fixed values
        perm_x_sequence = [x_values_sequence_fixed[idx] for idx in perm_indices]
        
        # Calculate lambda for this permuted sequence
        lambda_val = Lambda_density(perm_x_sequence)
        denominator_sum += lambda_val
        
        # The (n+1)-th element of the permuted sequence is perm_x_sequence[n_plus_1 - 1]
        # We need to find which original index maps to this position.
        original_idx_at_last_pos = perm_indices[n_plus_1 - 1]
        numerator_sums = numerator_sums.at[original_idx_at_last_pos].add(lambda_val)
    
    if denominator_sum == 0: # Avoid division by zero
        print("Warning: Denominator sum is zero. Check Lambda_density or input values.")
        return jnp.ones(n_plus_1) / n_plus_1 # Fallback to uniform weights
        
    p_lambda_weights = numerator_sums / denominator_sum
    return p_lambda_weights

# Calculate the weights p_i^lambda (for fixed X values)
p_lambda_for_fixed_X_values = calculate_p_lambda(X_all_points, num_training_points)

print(f"Calculated p_lambda weights: {p_lambda_for_fixed_X_values}\n")
print(f"Sum of p_lambda weights: {jnp.sum(p_lambda_for_fixed_X_values):.4f}\n")

# --- 3. Compute Scores R_i^(x,y) ---
# In the full conformal setting, a base predictor is trained on augmented data for each (x,y).
# Here, (X_train_values_fixed, y_train_values) + (X_test_point_fixed, y_candidate_example).

X_augmented = jnp.concatenate([X_train_values_fixed, X_test_point_fixed]).reshape(-1, 1)
y_augmented = jnp.concatenate([y_train_values, y_candidate_example])

base_predictor_for_scores = LinearRegression()
base_predictor_for_scores.fit(np.array(X_augmented), np.array(y_augmented))

all_residuals_for_scores = jnp.abs(y_augmented - jnp.array(base_predictor_for_scores.predict(np.array(X_augmented))))

R_n_plus_1_query = all_residuals_for_scores[-1] # Residual for the query point

print(f"All residuals for scores (n+1 points): {all_residuals_for_scores}")
print(f"Residual for query point (R_n+1^(x,y)): {R_n_plus_1_query:.4f}\n")

# --- 4. Form Weighted Quantile for a Specific y_candidate ---
alpha_level = 0.1 # 90% coverage

# Create a weighted empirical distribution for the residuals
# The delta_infinity term is handled by effectively treating the test point's weight separately
# Here we are trying to check if R_n+1^(x,y) <= Quantile(1-alpha; sum(pi_i^w * delta_R_i) + pi_n+1^w * delta_inf)

# For the sum term, we take the original training residuals, weighted by p_lambda
# The p_lambda are already calculated using the specific X_test_point.
training_residuals_for_quantile = all_residuals_for_scores[:-1]
training_p_lambda_weights = p_lambda_for_fixed_X_values[:-1]
p_n_plus_1_lambda = p_lambda_for_fixed_X_values[-1] # Weight for the test point

# Create pairs of (residual, weight) for the training points
weighted_training_score_pairs = sorted([(training_residuals_for_quantile[i], training_p_lambda_weights[i]) for i in range(num_training_points)])

cumulative_weighted_sum = 0.0
weighted_quantile = 0.0
target_cumulative_weight = 1 - alpha_level # We want 1-alpha quantile of the normalized distribution

# Iterate through sorted weighted training scores to find the quantile
# Note: accumulating weights over the sorted scores yields the quantile of this discrete
# weighted distribution; the remaining mass p_{n+1} sits at infinity.
for res, weight in weighted_training_score_pairs:
    cumulative_weighted_sum += weight
    if cumulative_weighted_sum >= target_cumulative_weight:
        weighted_quantile = res
        break
else:
    # If target not reached, it means quantile is larger than max training residual.
    # Given delta_infinity, it will be infinity unless the target is very low.
    # For simplicity, if target isn't met by sum of training weights, it means the quantile is effectively large.
    weighted_quantile = jnp.inf # If target is higher than sum of training weights

print(f"Weighted quantile for this query point (approx): {weighted_quantile:.4f}")

# Check if the query point's residual is in the set
is_in_set = R_n_plus_1_query <= weighted_quantile
print(f"Is (X_test_point={X_test_point_fixed.item()}, Y_candidate={y_candidate_example.item()}) in the conformal set? {is_in_set}")

print("\nNOTE: This code is a conceptual demonstration for very small N due to the factorial complexity of permutations.")
print("A practical implementation for structured-X settings would require specialized algorithms to compute p_lambda weights efficiently (e.g., for specific Markov structures)." )

## 4. Custom-Weighted Conformal Prediction

While likelihood-weighted conformal prediction targets specific distribution shifts like covariate shift, **custom-weighted conformal prediction** (Barber et al., 2023) takes a more general approach. Instead of inferring weights from a likelihood ratio, this method allows for fixed, arbitrary weights $w_i \in [0,1]$ for each training point. This can be useful when one has prior knowledge about the relevance or representativeness of certain training data points for the test distribution, without necessarily needing to model the exact shift.

### Fixed Arbitrary Weights (Theorem 3)

Given fixed weights $w_i \in [0,1]$ for $i=1, \ldots, n$, we define normalized weights $\tilde{w}_i$ including a unit weight for the test point:

$$\tilde{w}_{i}=\frac{w_{i}}{w_{1}+\cdots+w_{n}+1}, \quad i=1,\ldots,n$$
$$\tilde{w}_{n+1}=\frac{1}{w_{1}+\cdots+w_{n}+1}$$

The weighted conformal set is then formed as:

$$\hat{C}_{n}^{w}(x)=\left\{y:R_{n+1}^{(x,y)}\le \text{Quantile}\left(1-\alpha;\sum_{i=1}^{n}\tilde{w}_{i}\delta_{R_{i}^{(x,y)}}+\tilde{w}_{n+1}\delta_{\infty}\right)\right\}$$

where $R_i^{(x,y)}$ are conformity scores based on a base predictor. Importantly, **Theorem 3** states that *without any assumptions on the joint distribution of $Z_i=(X_i,Y_i)$*, this set satisfies:

$$\mathbb{P}(Y_{n+1}\in\hat{C}_{n}^{w}(X_{n+1}))\ge1-\alpha-\sum_{i=1}^{n}\tilde{w}_{i}\cdot TV(R(Z),R(Z^{i}))$$

Here, $TV(A,B)$ is the total variation distance between the distributions of random variables $A$ and $B$, and $R(Z^i)$ denotes the score vector if $Z_i$ and $Z_{n+1}$ were swapped. The term $\sum_{i=1}^{n}\tilde{w}_{i}\cdot TV(R(Z),R(Z^{i}))$ is called the **coverage gap**.

**Interpretation of the Coverage Gap**: This result implies that the guarantee is approximate, with a "gap" that depends on how much swapping $Z_i$ with $Z_{n+1}$ changes the distribution of scores. The gap is small when the training points $Z_i$ are representative of the test point $Z_{n+1}$ (small $TV$ distances), or when the weights $\tilde{w}_i$ are concentrated on such representative points and kept small on points whose $TV$ distance is large.
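
The arithmetic of the bound can be illustrated with made-up numbers. In practice the total variation distances are unknown, so the values below are purely hypothetical and only show how the choice of weights changes the gap; `coverage_lower_bound` is an illustrative helper.

```python
import numpy as np

def coverage_lower_bound(alpha, weights, tv_distances):
    """Evaluate 1 - alpha - sum_i w_tilde_i * TV_i for given (hypothetical) TV values."""
    w = np.asarray(weights, dtype=float)
    tv = np.asarray(tv_distances, dtype=float)
    w_tilde = w / (w.sum() + 1.0)  # the +1 is the unit weight of the test point
    gap = float(np.sum(w_tilde * tv))
    return 1 - alpha - gap, gap

# Hypothetical drift: older points (listed first) differ more from the test point.
tv = [0.3, 0.2, 0.1, 0.0]

bound_uniform, gap_uniform = coverage_lower_bound(0.1, [1.0, 1.0, 1.0, 1.0], tv)
bound_decay, gap_decay = coverage_lower_bound(0.1, [0.1, 0.2, 0.5, 1.0], tv)
print(bound_uniform, bound_decay)
```

Down-weighting the points with large total variation distance shrinks the gap, at the cost of a smaller effective calibration set.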

**Special Cases**:
* **I.I.D. Setting**: If the data is truly I.I.D. (or exchangeable), then $TV(R(Z), R(Z^i))=0$ for all $i$. In this case, the coverage gap is zero and the guarantee reduces to the usual lower bound of $1-\alpha$. This means that even with arbitrary fixed weights, conformal prediction remains valid under I.I.D. data.
* **Split Version**: Similar to likelihood-weighted CP, a split conformal version exists where the base predictor is trained on an external dataset $Z_0$. The coverage gap then becomes conditional on $Z_0$.

### Nonsymmetric Score Functions (Theorem 4)

Traditional conformal prediction assumes the score function $V$ is symmetric with respect to its inputs. However, some models (e.g., autoregressive models in time series) are inherently non-symmetric as they depend on the order of data points. **Theorem 4** addresses this by introducing a "random swap" into the score computation:

It uses conformity scores $R_i^{(x,y), K}$ computed after swapping components $K$ and $n+1$ of the augmented data vector $(Z_1, \ldots, Z_n, (x,y))$, where the index $K \in \{1, \ldots, n+1\}$ is drawn at random with probabilities given by the normalized weights $\tilde{w}_1, \ldots, \tilde{w}_{n+1}$ (drawing $K = n+1$ leaves the data unchanged).

This extension broadens the applicability of custom-weighted conformal prediction to a wider range of algorithms and structured data settings (e.g., time series with decaying weights, as shown in Figure 2 of the original slides), while still providing a controlled coverage guarantee that depends on the coverage gap.
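
As a rough sketch of the two ingredients mentioned above, the code below builds geometrically decaying weights for a time series and draws the random swap index $K$. The decay form $w_i = \rho^{\,n+1-i}$ and the value of `rho` are assumptions chosen for illustration, and the helper names are hypothetical.

```python
import numpy as np

def geometric_weights(n, rho=0.99):
    """Assumed decay scheme: w_i = rho**(n + 1 - i), so recent points get weights near 1."""
    i = np.arange(1, n + 1)
    return rho ** (n + 1 - i)

def sample_swap_index(weights, rng):
    """Draw K in {1, ..., n+1} with P(K = i) = w_tilde_i (the test point has unit weight)."""
    w = np.append(np.asarray(weights, dtype=float), 1.0)
    w_tilde = w / w.sum()
    k = int(rng.choice(len(w_tilde), p=w_tilde)) + 1  # shift to 1-based indexing
    return k, w_tilde

rng = np.random.default_rng(0)
w = geometric_weights(n=5, rho=0.5)  # 1/32, 1/16, 1/8, 1/4, 1/2
K, w_tilde = sample_swap_index(w, rng)
print(w, K)
```

Because the test point carries the largest normalized weight under this scheme, $K = n+1$ (no swap) is the most likely single outcome, and swaps favor recent observations.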

### Code Example: Custom-Weighted Split Conformal Prediction

This example demonstrates Custom-Weighted Split Conformal Prediction with arbitrary fixed weights. We'll simulate a regression problem and assign weights to calibration points, for instance, to reflect their perceived relevance or quality.

In [None]:
import jax.numpy as jnp
import jax.random as random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

# Set a random seed for reproducibility
key = random.PRNGKey(0)

# --- 1. Simulate Regression Data ---
num_samples_train = 500
noise_std = 5.0

rng = np.random.default_rng(0) # Seeded NumPy generator so the simulated noise is reproducible
X_data_np = np.linspace(-5, 5, num_samples_train)[:, None]
y_true = 2 * X_data_np.squeeze() + 3
y_data_np = y_true + rng.normal(0, noise_std, num_samples_train)

X_data = jnp.array(X_data_np)
y_data = jnp.array(y_data_np)

print(f"Total samples: {num_samples_train}\n")

# --- 2. Split Data into Proper Training (D1) and Calibration (D2) ---
# random_state must be a plain integer; JAX key arrays have no `.get` method
X_D1, X_D2, y_D1, y_D2 = train_test_split(
    X_data, y_data, test_size=0.5, random_state=0
)
n1 = X_D1.shape[0]
n2 = X_D2.shape[0]
print(f"Proper training set size (n1): {n1}")
print(f"Calibration set size (n2): {n2}\n")

# --- 3. Assign Custom Weights to Calibration Points ---
# For demonstration, let's assign higher weights to points closer to 0 (more central/reliable)
# Weights should be in [0, 1]. A simple decaying function based on |X|.
custom_weights_D2 = 1.0 - jnp.abs(X_D2.squeeze()) / jnp.max(jnp.abs(X_D2.squeeze()))
custom_weights_D2 = jnp.clip(custom_weights_D2, 0.1, 1.0) # Ensure min weight of 0.1

print(f"Custom weights for calibration set (first 5): {custom_weights_D2[:5]}\n")

# --- 4. Train Base Predictor on Proper Training Set (D1) ---
base_predictor = LinearRegression()
base_predictor.fit(np.array(X_D1), np.array(y_D1))

# --- 5. Compute Calibration Scores (Absolute Residuals) on D2 ---
y_D2_pred = base_predictor.predict(np.array(X_D2))
calibration_residuals = jnp.abs(y_D2 - jnp.array(y_D2_pred))
print(f"Calibration residuals (first 5): {calibration_residuals[:5]}\n")

# --- 6. Compute Custom-Weighted Conformal Quantile ---
alpha_level = 0.1 # Desired coverage: 1 - alpha = 0.9 (90%)

# Normalize the calibration weights so that, together with the test point's unit weight, they sum to 1
sum_of_fixed_weights = jnp.sum(custom_weights_D2)
denominator = sum_of_fixed_weights + 1 # +1 for the implicit unit weight of the test point
normalized_weights_D2 = custom_weights_D2 / denominator

# Create pairs of (residual, normalized_weight) and sort by residual
sorted_weighted_pairs = sorted(zip(calibration_residuals, normalized_weights_D2))

cumulative_normalized_weights = 0.0
custom_weighted_conformal_quantile = 0.0

# Target for the (1-alpha) quantile, where the sum of weights is 1 (after normalization)
target_cumulative_normalized_weight = 1 - alpha_level

for res, norm_weight in sorted_weighted_pairs:
    cumulative_normalized_weights += norm_weight
    if cumulative_normalized_weights >= target_cumulative_normalized_weight:
        custom_weighted_conformal_quantile = res
        break
else:
    # Target not reached: the remaining mass sits on the test point, whose score is
    # treated as +infinity, so the quantile is infinite (the band is the whole real line).
    custom_weighted_conformal_quantile = jnp.inf

print(f"Custom-weighted conformal quantile: {custom_weighted_conformal_quantile:.4f}\n")

# --- 7. Form Conformal Prediction Band ---
num_test_points = 500
X_test_np = np.linspace(X_data.min() - 1, X_data.max() + 1, num_test_points)[:, None]
y_test_true = 2 * X_test_np.squeeze() + 3 + np.random.normal(0, noise_std, num_test_points)

X_test = jnp.array(X_test_np)
y_test = jnp.array(y_test_true)

y_test_pred_mean = base_predictor.predict(np.array(X_test))

lower_bound_custom_weighted = y_test_pred_mean - custom_weighted_conformal_quantile
upper_bound_custom_weighted = y_test_pred_mean + custom_weighted_conformal_quantile

# --- Plotting Results ---
plt.figure(figsize=(12, 8))
plt.scatter(X_D1[:, 0], y_D1, color='blue', label='Proper Training Data (D1)', alpha=0.5, s=20)
plt.scatter(X_D2[:, 0], y_D2, color='green', label='Calibration Data (D2)', alpha=0.5, s=20)
weights_scatter = plt.scatter(X_D2[:, 0], y_D2, c=custom_weights_D2, cmap='viridis', s=custom_weights_D2 * 50, alpha=0.8, label='Calibration Weights (Size & Color)')
# Build the colorbar from the scatter itself so its scale matches the plotted weights
cbar = plt.colorbar(weights_scatter, ax=plt.gca())
cbar.set_label('Custom Weight')

plt.plot(X_test[jnp.argsort(X_test.squeeze()), 0], y_test_pred_mean[jnp.argsort(X_test.squeeze())], color='red', linewidth=2, label='Regression Mean')

plt.fill_between(
    X_test[jnp.argsort(X_test.squeeze()), 0].squeeze(),
    lower_bound_custom_weighted[jnp.argsort(X_test.squeeze())],
    upper_bound_custom_weighted[jnp.argsort(X_test.squeeze())],
    # Raw f-string: otherwise '\a' in '\alpha' is read as a bell-character escape
    color='red', alpha=0.2, label=rf'Custom-Weighted Conformal Band (1-$\alpha$={1-alpha_level})'
)

plt.xlabel('X')
plt.ylabel('Y')
plt.title('Custom-Weighted Conformal Prediction')
plt.legend()
plt.grid(True)
plt.show()

# --- Simulate Test Coverage ---
coverage_count_custom_weighted = 0
for i in range(num_test_points):
    if (y_test[i] >= lower_bound_custom_weighted[i]) and \
       (y_test[i] <= upper_bound_custom_weighted[i]):
        coverage_count_custom_weighted += 1

simulated_coverage_custom_weighted = coverage_count_custom_weighted / num_test_points
print(f"Simulated custom-weighted conformal coverage on test set: {simulated_coverage_custom_weighted:.4f}")
print(f"Desired coverage (approx): {1 - alpha_level:.4f}")

print("\nNOTE: The coverage guarantee for custom weights is approximate (1-alpha - coverage_gap).")
print("The coverage_gap depends on the total variation distance, which is hard to compute.")
print("This example illustrates the mechanics, not a strict coverage verification for non-IID data.")

## 5. Adaptive Conformal Inference (ACI)

While the previous methods deal with a fixed, known shift, **Adaptive Conformal Inference (ACI)** (Gibbs and Candès, 2021) is designed for **sequential prediction problems** where the data distribution can shift arbitrarily and continuously over time. It's an online method that dynamically adjusts the confidence level of prediction sets to maintain coverage.

Assume we have a sequence of observations $(X_t, Y_t)$ indexed by time $t=1, 2, 3, \ldots$. At each time $t$, we have a method that can produce a prediction set $C_t^\beta$ for $Y_t$ at any nominal level $\beta \in \mathbb{R}$. All that's required is that $C_t^\beta = \emptyset$ for $\beta \le 0$ and $C_t^\beta = \mathcal{Y}$ for $\beta \ge 1$.
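
The requirement on $C_t^\beta$ can be made concrete with a small sketch. The function below is a hypothetical helper (not part of the notebook's code) that builds a split conformal interval from absolute-residual calibration scores and extends it to every real level $\beta$: an empty set for $\beta \le 0$ and the whole real line for $\beta \ge 1$. Representing the empty set as an interval whose lower end exceeds its upper end is a convention chosen here for illustration.

```python
import numpy as np

def split_conformal_interval(cal_scores, y_pred, beta):
    """Interval for one test point at nominal coverage level beta (any real)."""
    if beta <= 0:
        return (np.inf, -np.inf)       # empty set: lower end above upper end
    if beta >= 1:
        return (-np.inf, np.inf)       # all of Y (here, the real line)
    n = len(cal_scores)
    rank = int(np.ceil(beta * (n + 1)))
    if rank > n:
        # Not enough calibration points for this level: the quantile is +infinity
        return (-np.inf, np.inf)
    q = np.sort(cal_scores)[rank - 1]  # rank-th smallest calibration score
    return (y_pred - q, y_pred + q)

scores = np.array([1.0, 2.0, 3.0, 4.0])
mid_level = split_conformal_interval(scores, y_pred=10.0, beta=0.5)   # rank 3 of 4
empty_set = split_conformal_interval(scores, y_pred=10.0, beta=-0.2)
full_set = split_conformal_interval(scores, y_pred=10.0, beta=1.3)
```

Handling the out-of-range levels explicitly matters because the working level in ACI is allowed to leave $[0, 1]$ temporarily.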

### The ACI Algorithm

ACI aims to keep the realized coverage as close to $1-\alpha$ as possible, where $\alpha \in (0,1)$ is a prespecified error tolerance. It initializes $\alpha_1 = \alpha$ and updates the working error level $\alpha_t$ according to:

$$\alpha_{t+1} = \alpha_t - \eta(\text{err}_t - \alpha), \quad t=1, 2, 3, \ldots$$
where:
* $\text{err}_t = 1\{Y_t \notin C_t^{1-\alpha_t}\}$ is an indicator variable that is 1 if the prediction set at time $t$ fails to cover $Y_t$, and 0 otherwise.
* $\eta > 0$ is a step size.

These updates are intuitive: if the set fails to cover ($\text{err}_t=1$), we decrease $\alpha_t$ (making future sets larger and more conservative). If it covers ($\text{err}_t=0$), we increase $\alpha_t$ (making future sets smaller and more efficient).
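
A short worked example makes the asymmetry of the two moves visible. The function name `aci_update` and the values $\alpha = 0.1$, $\eta = 0.05$ are chosen here purely for illustration.

```python
def aci_update(alpha_t, err_t, alpha, eta):
    """One ACI step: alpha_{t+1} = alpha_t - eta * (err_t - alpha)."""
    return alpha_t - eta * (err_t - alpha)

alpha, eta = 0.1, 0.05

# Miss (err_t = 1): the level drops by eta * (1 - alpha) = 0.045
after_miss = aci_update(0.1, 1, alpha, eta)    # 0.1 - 0.05 * 0.9

# Cover (err_t = 0): the level rises by eta * alpha = 0.005
after_cover = aci_update(0.1, 0, alpha, eta)   # 0.1 + 0.05 * 0.1
```

A miss moves the level nine times further than a cover in this setting, so the two kinds of steps balance out only when misses occur about a fraction $\alpha$ of the time.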

### Key Properties

* **Boundedness (Lemma 2)**: The ACI iterates $\alpha_t$ are always uniformly bounded within $[-\eta, 1+\eta]$. This self-correcting property prevents $\alpha_t$ from diverging.

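* **Long-Run Coverage (Theorem 5)**: This is the central guarantee. For any interval $[t_0+1, t_0+T]$, the average error satisfies:
    $$\left|\frac{1}{T}\sum_{t=t_0+1}^{t_0+T}\text{err}_t - \alpha\right| \le \frac{1+2\eta}{\eta T}$$
    The bound follows by summing the update rule over the interval, which gives $\alpha_{t_0+T+1} = \alpha_{t_0+1} - \eta\sum_{t=t_0+1}^{t_0+T}(\text{err}_t - \alpha)$, and then applying Lemma 2 to both endpoints. In particular, as $T \to \infty$, the long-run average error converges to $\alpha$: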
    $$\lim_{T\to\infty}\frac{1}{T}\sum_{t=1}^{T}\text{err}_t = \alpha$$
    This implies that ACI achieves long-run coverage of $1-\alpha$ *always*, regardless of the sequence $(X_t, Y_t)$ (even if chosen adversarially) and without any distributional assumptions beyond the existence of prediction sets $C_t^\beta$.

* **Connection to Online Gradient Descent**: ACI can be viewed as an instance of online (sub)gradient descent with step size $\eta$ applied to a convex pinball (quantile) loss in $\alpha_t$, whose subgradient at $\alpha_t$ is $\text{err}_t - \alpha$.
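
The boundedness and long-run coverage properties can be checked numerically. The sketch below uses a toy model that is an assumption of this example, not part of the lecture: each time step has a "difficulty" $u_t \in (0,1)$, and the set at level $1-\alpha_t$ covers $Y_t$ exactly when $u_t \le 1-\alpha_t$. This makes the set empty for levels $\le 0$ and all of $\mathcal{Y}$ for levels $\ge 1$, as ACI requires. The names `run_aci` and `u` are hypothetical.

```python
import numpy as np

def run_aci(u, alpha, eta):
    """Run the ACI update against a difficulty sequence u_t in (0, 1)."""
    alpha_t = alpha
    alphas, errs = [], []
    for u_t in u:
        err_t = int(u_t > 1.0 - alpha_t)   # 1 if the set at level 1 - alpha_t misses
        alphas.append(alpha_t)
        errs.append(err_t)
        alpha_t = alpha_t - eta * (err_t - alpha)
    alphas.append(alpha_t)                 # include the final iterate
    return np.array(alphas), np.array(errs)

rng = np.random.default_rng(0)
T, alpha, eta = 2000, 0.1, 0.05

# A deliberately non-stationary sequence: easy points, then hard points, then noise
u = np.concatenate([
    np.full(500, 0.05),
    np.full(500, 0.99),
    rng.uniform(0.0, 1.0, size=T - 1000),
])
alphas, errs = run_aci(u, alpha, eta)

gap = abs(errs.mean() - alpha)             # left-hand side of Theorem 5
bound = (1 + 2 * eta) / (eta * T)          # right-hand side of Theorem 5
print(f"alpha_t range: [{alphas.min():.3f}, {alphas.max():.3f}]")
print(f"|average error - alpha| = {gap:.5f}, bound = {bound:.5f}")
```

Because the guarantee is deterministic, the inequality holds for any choice of `u`, including sequences chosen to be as unfavorable as possible.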

Figure 3 from the original slides illustrates ACI's performance on financial time series data, showing that its local coverage frequencies remain stable even during periods when a non-adaptive method fails.

### Code Example: Adaptive Conformal Inference (ACI)

This example demonstrates Adaptive Conformal Inference in a simulated sequential setting with a gradual distribution shift. At each timestep, a prediction set is generated based on a sliding window of past data, and the `alpha` level is updated using the ACI algorithm to maintain long-run coverage.

In [None]:
import jax.numpy as jnp
import jax.random as random
import numpy as np
import matplotlib.pyplot as plt
from collections import deque # For sliding window
from sklearn.linear_model import LinearRegression

# Set a random seed for reproducibility
key = random.PRNGKey(0)

# --- 1. Simulate Sequential Data with Gradual Shift ---
total_timesteps = 500
initial_X_mean = 0.0
shift_magnitude = 0.01 # How much X_mean shifts per timestep
noise_std = 5.0
regression_slope = 2.0
regression_intercept = 3.0

X_stream = jnp.zeros(total_timesteps)
Y_stream = jnp.zeros(total_timesteps)
true_X_means = jnp.zeros(total_timesteps)

current_X_mean = initial_X_mean
for t in range(total_timesteps):
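    # Use separate subkeys for X and for the noise; a key should not be reused after it is split
    key, x_key, noise_key = random.split(key, 3)
    X_t = random.normal(x_key, (1,))[0] + current_X_mean
    Y_t = regression_slope * X_t + regression_intercept + random.normal(noise_key, (1,))[0] * noise_std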
    
    X_stream = X_stream.at[t].set(X_t)
    Y_stream = Y_stream.at[t].set(Y_t)
    true_X_means = true_X_means.at[t].set(current_X_mean)
    
    # Gradual shift: mean of X increases over time
    current_X_mean += shift_magnitude

print(f"Simulated {total_timesteps} timesteps of data with gradual shift.\n")

# --- 2. ACI Algorithm Setup ---
alpha_target = 0.1 # Desired average error rate (1 - coverage)
eta = 0.1 # ACI step size
window_size = 100 # Size of the sliding window for base predictor training/calibration

alpha_t_history = [alpha_target] # History of alpha_t values, starting with alpha_0
err_t_history = [] # History of err_t (0 or 1)
coverage_history = [] # History of actual coverage at each step
prediction_set_widths = [] # History of prediction set widths

# Deques for sliding window (training and calibration combined for simplicity in this online context)
data_window_X = deque(maxlen=window_size)
data_window_Y = deque(maxlen=window_size)

base_predictor = LinearRegression()

print(f"ACI setup: target alpha={alpha_target}, eta={eta}, window_size={window_size}\n")

# --- 3. Run ACI over the Sequential Data Stream ---
for t in range(total_timesteps):
    # Add current data point to window
    data_window_X.append(X_stream[t])
    data_window_Y.append(Y_stream[t])

    if len(data_window_X) < 20: # Need enough data to train a simple model
        # Not enough data for prediction set, skip this step but record current alpha
        alpha_t_history.append(alpha_t_history[-1])
        err_t_history.append(0) # Assume no error for warm-up
        coverage_history.append(1) # Assume 100% coverage for warm-up
        prediction_set_widths.append(np.nan) # No valid width yet
        continue

    # Current alpha level from previous step's update
    current_alpha = alpha_t_history[-1]

    # --- Generate Prediction Set C_t^{1-alpha_t} (split conformal on the sliding window) ---
    # The window already contains the current point (X_t, Y_t), so drop the last entry:
    # Y_t must not be used to build its own prediction set.
    # The remaining past points are split below: first half for training, second half
    # for calibration, mimicking split conformal within the window.
    window_X_np = np.array(data_window_X)[:-1].reshape(-1, 1)
    window_Y_np = np.array(data_window_Y)[:-1]

    # Simple split of current window (e.g., first half for training, second for calibration)
    split_idx = len(window_X_np) // 2
    X_train_window, y_train_window = window_X_np[:split_idx], window_Y_np[:split_idx]
    X_cal_window, y_cal_window = window_X_np[split_idx:], window_Y_np[split_idx:]

    if len(X_cal_window) == 0: # Ensure calibration set is not empty
        alpha_t_history.append(alpha_t_history[-1])
        err_t_history.append(0)
        coverage_history.append(1)
        prediction_set_widths.append(np.nan)
        continue
    
    # Train base predictor on training part of the window
    base_predictor.fit(X_train_window, y_train_window)

    # Compute calibration scores (absolute residuals) on calibration part of the window
    y_cal_pred = base_predictor.predict(X_cal_window)
    calibration_scores_window = jnp.abs(y_cal_window - jnp.array(y_cal_pred))

    n_cal_window = len(X_cal_window)
    # Calculate quantile for current_alpha (which is the error rate for C_{1-alpha_t}^t)
    # We need 1 - current_alpha coverage.
    adjusted_rank_index_window = jnp.ceil((1 - current_alpha) * (n_cal_window + 1)).astype(int)
    sorted_scores_window = jnp.sort(calibration_scores_window)
    
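    if adjusted_rank_index_window <= 0:
        q_t = sorted_scores_window[0]
    elif adjusted_rank_index_window > n_cal_window:
        # Rank beyond the calibration set: the conformal quantile is +infinity,
        # i.e., the prediction set is the whole real line.
        q_t = jnp.inf
    else:
        q_t = sorted_scores_window[adjusted_rank_index_window - 1]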

    # Predict for the current X_t (the current point in the stream)
    X_t_reshaped = X_stream[t].reshape(-1, 1)
    y_t_pred_mean = base_predictor.predict(np.array(X_t_reshaped))[0]
    
    # Form the prediction set for Y_t using current_alpha
    lower_bound_t = y_t_pred_mean - q_t
    upper_bound_t = y_t_pred_mean + q_t

    # Check if Y_t is covered by C_{1-alpha_t}^t
    is_covered = (Y_stream[t] >= lower_bound_t) and (Y_stream[t] <= upper_bound_t)
    err_t = 1 - is_covered # err_t is 1 if not covered, 0 if covered
    
    # Store histories
    err_t_history.append(err_t)
    coverage_history.append(is_covered)
    prediction_set_widths.append(upper_bound_t - lower_bound_t)

    # --- ACI Update Rule ---
    alpha_next = current_alpha - eta * (err_t - alpha_target)
    
    # Lemma 2 shows the unclipped iterates stay within [-eta, 1+eta] on their own.
    # Clipping to (0, 1) is a practical choice that keeps alpha_t a valid error level,
    # but it departs from the exact update that Theorem 5 analyzes.
    alpha_next = jnp.clip(alpha_next, 0.001, 0.999)

    alpha_t_history.append(alpha_next)

    if (t + 1) % 50 == 0:
        print(f"Timestep {t+1}: Current alpha={current_alpha:.4f}, err_t={err_t}, new alpha={alpha_next:.4f}")
        print(f"  Coverage at t={t+1}: {is_covered}, Prediction Set Width: {prediction_set_widths[-1]:.2f}")

# Drop the initial alpha so the history has one entry per timestep
# (entry t is the level produced by the update at step t)
alpha_t_history = alpha_t_history[1:]

print("\n--- Simulation Complete ---")

# --- 4. Plotting Results --- 
fig, axes = plt.subplots(3, 1, figsize=(14, 15), sharex=True)

timesteps = jnp.arange(total_timesteps)

# Plot 1: True X Mean Shift
axes[0].plot(timesteps, true_X_means, label='True Mean of X', color='purple', linestyle='--')
axes[0].set_ylabel('X Mean')
axes[0].set_title('Simulated Distribution Shift Over Time (Mean of X)')
axes[0].legend()
axes[0].grid(True)

# Plot 2: Adaptive Alpha
# Raw strings: otherwise '\a' in '\alpha' is read as a bell-character escape
axes[1].plot(timesteps, alpha_t_history, label=r'Adaptive $\alpha_t$', color='blue')
axes[1].axhline(alpha_target, color='red', linestyle=':', label=rf'Target $\alpha$={alpha_target}')
axes[1].set_ylabel(r'$\alpha_t$ (Error Level)')
axes[1].set_title('Adaptive Error Level ($\\alpha_t$) by ACI')
axes[1].legend()
axes[1].grid(True)

# Plot 3: Long-Run Coverage
cumulative_errors = jnp.cumsum(jnp.array(err_t_history))
average_errors = cumulative_errors / (jnp.arange(len(err_t_history)) + 1)

axes[2].plot(timesteps, 1 - average_errors, label='Average Coverage', color='green')
axes[2].axhline(1 - alpha_target, color='red', linestyle=':', label=rf'Target Coverage (1-$\alpha$)={1-alpha_target}')

# Moving average for local coverage (e.g., over a window)
window_for_local_coverage = 50
local_coverage = jnp.convolve(jnp.array(coverage_history), jnp.ones(window_for_local_coverage)/window_for_local_coverage, mode='valid')
axes[2].plot(timesteps[window_for_local_coverage-1:], local_coverage, label=f'Local Coverage ({window_for_local_coverage}-step avg)', color='cyan', alpha=0.7)

axes[2].set_xlabel('Timestep (t)')
axes[2].set_ylabel('Coverage')
axes[2].set_title('Actual Coverage Over Time (ACI)')
axes[2].set_ylim(0.5, 1.05) # Keep y-axis reasonable
axes[2].legend()
axes[2].grid(True)

plt.tight_layout()
plt.show()

print(f"Final average error rate: {average_errors[-1]:.4f}")
print(f"Target error rate (alpha): {alpha_target:.4f}")
print("\nNOTE: ACI aims for long-run average coverage. Local coverage may fluctuate, especially with strong shifts or small windows.")
print("The simple linear regression model inside ACI might not adapt perfectly to non-linear relationships or complex shifts.")

## Conclusion: Conformal Prediction in a Dynamic World

This notebook has extended our understanding of Conformal Prediction beyond the ideal I.I.D. setting to more realistic scenarios involving distribution shifts. The advancements in likelihood-weighted, custom-weighted, and adaptive conformal prediction techniques are vital for building robust and reliable machine learning systems in dynamic and unpredictable environments.

Key insights:

* **Covariate Shift**: Standard conformal prediction fails under covariate shift due to broken exchangeability. Likelihood-weighted conformal prediction restores validity by re-weighting calibration scores based on the density ratio between test and training feature distributions.
* **Weight Estimation**: The necessary likelihood ratios can be estimated using a binary classifier trained to distinguish between features from the source and target distributions.
* **Generalization**: The theoretical framework extends to highly structured feature dependencies (Theorem 2) and allows for arbitrary custom weights (Theorem 3), albeit with potential computational challenges or approximate guarantees.
* **Adaptive Online Learning**: ACI provides a powerful solution for truly online, sequential prediction where distributions shift arbitrarily, guaranteeing long-run average coverage without strong assumptions.

Together, these methods show how flexible the conformal framework is: practitioners can provide statistically valid uncertainty quantification even when real-world data departs from the I.I.D. ideal. As machine learning models are deployed in increasingly complex and critical applications, reliable uncertainty estimates under distribution shift become ever more important.