# Conformal Prediction: Quantifying Uncertainty without Assumptions

## Introduction

Welcome to this in-depth exploration of **Conformal Prediction**, a powerful and increasingly popular framework in statistical learning for quantifying uncertainty in predictions. In traditional machine learning, models often output point predictions (e.g., a single value for regression, or a single class label for classification). While these predictions might be accurate on average, they rarely come with reliable guarantees about their uncertainty. Probabilistic models (like Gaussian Processes) provide uncertainty, but often rely on strong distributional assumptions (e.g., Gaussianity) or can be computationally intensive, especially for large datasets or complex models.

Conformal prediction offers a compelling alternative: it provides **finite-sample coverage guarantees** for prediction sets (or intervals) without assuming any particular form for the underlying data distribution. The only requirement is that the data are exchangeable (for example, i.i.d.). This "distribution-free" property makes it versatile and robust, allowing it to be applied to any underlying prediction algorithm, from simple linear models to complex deep neural networks.

In this notebook, we will systematically delve into the core ideas, algorithms, and applications of conformal prediction, accompanied by practical Python code examples using `JAX` and `scikit-learn`.

## 1. The Lofty Goal: Distribution-Free Finite-Sample Coverage

The primary objective of conformal prediction is ambitious yet elegant. Let's formalize it:

Suppose we have $n$ i.i.d. (independent and identically distributed) feature-response pairs $(X_i, Y_i) \sim P$, for $i=1, \ldots, n$, drawn from an unknown distribution $P$ over $\mathcal{X} \times \mathcal{Y}$. Here, $\mathcal{X}$ is the feature space (e.g., $\mathbb{R}^d$) and $\mathcal{Y}$ is the response space (e.g., $\mathbb{R}$ for regression, or a set of discrete labels for classification).

Given a nominal error level $\alpha \in (0, 1)$ (e.g., $\alpha = 0.1$ for 90% coverage), our goal is to construct a **prediction set** $\hat{C}_n(X_{n+1}) \subseteq \mathcal{Y}$ such that for a new, unseen i.i.d. pair $(X_{n+1}, Y_{n+1}) \sim P$:

$$\mathbb{P}(Y_{n+1} \in \hat{C}_n(X_{n+1})) \ge 1 - \alpha$$

This probability is taken over all our data: the $n$ training points and the new $(n+1)$-th test point.

### Why is this "Lofty"?

1.  **Distribution-Free**: The guarantee holds *without any assumptions* about the underlying distribution $P$. This is a powerful statement, as most statistical methods rely on knowing (or assuming) the data-generating process.
2.  **Finite-Sample**: The guarantee is valid for *any finite sample size* $n$, not just asymptotically as $n \to \infty$. This is crucial for real-world applications where data might be limited.
3.  **Model-Agnostic**: Conformal prediction can wrap *any* arbitrary prediction algorithm (often called the "base predictor" or "black-box predictor"). This means you can use your favorite complex model (e.g., a deep neural network) and still get valid uncertainty quantification.

While a trivial prediction set (e.g., returning the entire response space $\mathcal{Y}$ with probability $1-\alpha$ and the empty set otherwise) could satisfy this, the true challenge is to achieve this guarantee with a **nontrivial** prediction set that adapts to the "hardness" of the problem. Ideally, the prediction set $\hat{C}_n(X_{n+1})$ should be smaller when $Y_{n+1}$ is easier to predict from $X_{n+1}$.

Remarkably, this goal is achievable, and the core ideas are surprisingly simple, rooted in the concept of **exchangeability**.

## 2. First Key Idea: Exchangeability and Ranks (No Features)

Let's start with the simplest scenario: predicting a single real-valued response $Y_{n+1}$ when we have $n$ i.i.d. observations $Y_1, \ldots, Y_n$, and no features $X$. Our goal is to find a one-sided prediction interval $\hat{C}_n = (-\infty, \hat{q}_n]$ such that $\mathbb{P}(Y_{n+1} \le \hat{q}_n) \ge 1 - \alpha$.

A naive approach might be to set $\hat{q}_n$ to be the $(1-\alpha)$-th sample quantile of $Y_1, \ldots, Y_n$. However, this only provides an *approximate* coverage that becomes exact asymptotically. Conformal prediction provides an exact finite-sample guarantee.

### The Power of Exchangeability

The key insight comes from the property of **exchangeability**. If $Y_1, \ldots, Y_n, Y_{n+1}$ are i.i.d., then their joint distribution is invariant under any permutation of their indices. A direct consequence of this is:

**The rank of $Y_{n+1}$ among the full set of $n+1$ observations $\{Y_1, \ldots, Y_n, Y_{n+1}\}$ is uniformly distributed over $\{1, 2, \ldots, n+1\}$.**

This means that $Y_{n+1}$ is equally likely to be the smallest, second smallest, ..., or $(n+1)$-th smallest value in the combined set.
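
As a quick sanity check of the rank claim, the sketch below draws many independent batches of $n+1$ values and tabulates the rank of the last one. The sample size, the number of trials, and the exponential distribution are arbitrary choices made for illustration; any continuous distribution should give the same picture.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 9                 # number of "training" observations (illustrative choice)
num_trials = 20_000   # independent repetitions

# Each row is one batch (Y_1, ..., Y_n, Y_{n+1}); the distribution is deliberately non-Gaussian.
batches = rng.exponential(scale=1.0, size=(num_trials, n + 1))

# Rank of Y_{n+1} within its batch: 1 = smallest, n + 1 = largest.
ranks = 1 + np.sum(batches[:, :-1] < batches[:, -1:], axis=1)

# Empirical frequency of each rank; every entry should be close to 1 / (n + 1).
rank_freq = np.bincount(ranks, minlength=n + 2)[1:] / num_trials
print(np.round(rank_freq, 3))
print(f"Target frequency: {1 / (n + 1):.3f}")
```

Swapping the exponential draw for another continuous distribution is a simple way to see that the result does not depend on the shape of $P$.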

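From this, we can deduce that, for any $k \in \{1, \ldots, n+1\}$:
$$\mathbb{P}(Y_{n+1} \le \text{the } k\text{-th smallest of } \{Y_1, \ldots, Y_n, Y_{n+1}\}) \ge \frac{k}{n+1}$$
with equality when ties occur with probability zero. To achieve at least $1-\alpha$ coverage, we choose $k = \lceil (1-\alpha)(n+1) \rceil$.

The crucial step is then realizing that, for $k \le n$, $Y_{n+1}$ is at most the $k$-th smallest of all $n+1$ values if and only if it is at most the $k$-th smallest of $Y_1, \ldots, Y_n$ alone, because removing $Y_{n+1}$ from the pool does not change the outcome of the comparison. Hence:
$$\mathbb{P}(Y_{n+1} \le \text{the } \lceil (1-\alpha)(n+1) \rceil\text{-th smallest of } \{Y_1, \ldots, Y_n\}) \ge 1 - \alpha$$
This equivalence is profound because the threshold inside the probability, $\hat{q}_n = \text{the } \lceil (1-\alpha)(n+1) \rceil\text{-th smallest of } \{Y_1, \ldots, Y_n\}$, is **computable solely from the observed training data** $Y_1, \ldots, Y_n$. If $\lceil (1-\alpha)(n+1) \rceil > n$, which happens when $\alpha < 1/(n+1)$, we set $\hat{q}_n = +\infty$.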

Thus, by defining $\hat{q}_n$ this way, we obtain a prediction interval $(-\infty, \hat{q}_n]$ that guarantees at least $1-\alpha$ coverage for $Y_{n+1}$ in finite samples, without any distributional assumptions beyond i.i.d. (or the weaker condition of exchangeability).

An equivalent formulation using the empirical quantile function is:
$$\hat{q}_n = \text{Quantile}\left(\frac{\lceil (1-\alpha)(n+1) \rceil}{n}; \text{empirical distribution of } Y_1, \ldots, Y_n\right)$$
This means we compute the sample quantile at an "adjusted level" $\frac{\lceil (1-\alpha)(n+1) \rceil}{n}$ instead of the naive $1-\alpha$. This adjustment accounts for the finite sample size.

If there are no ties in the data, the coverage can be sharpened to $\mathbb{P}(Y_{n+1} \le \hat{q}_n) \in [1-\alpha, 1-\alpha + \frac{1}{n+1})$.
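
To make the adjusted level concrete, here is a small sketch that computes the conformal rank for $n = 50$ and $\alpha = 0.1$ and then estimates coverage by redrawing the training sample and the test point together in every trial. The Gaussian data and the trial count are assumptions chosen for illustration, and `np.quantile` with its default interpolation stands in for the "naive" sample quantile.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
n, alpha = 50, 0.1
num_trials = 20_000

k = math.ceil((1 - alpha) * (n + 1))   # conformal rank: 46 for n = 50, alpha = 0.1
adjusted_level = k / n                 # 0.92 instead of the naive 0.90
print(f"k = {k}, adjusted level = {adjusted_level:.2f}")

# Redraw the training sample and the test point together in every trial.
Y = rng.standard_normal(size=(num_trials, n + 1))
Y_train, Y_test = Y[:, :n], Y[:, n]

q_conformal = np.sort(Y_train, axis=1)[:, k - 1]    # k-th smallest training value
q_naive = np.quantile(Y_train, 1 - alpha, axis=1)   # plain (interpolated) sample quantile

cov_conformal = np.mean(Y_test <= q_conformal)
cov_naive = np.mean(Y_test <= q_naive)
print(f"Conformal coverage: {cov_conformal:.4f}  (theory: k/(n+1) = {k / (n + 1):.4f})")
print(f"Naive coverage:     {cov_naive:.4f}")
```

Because both the training data and the test point are resampled, this estimates the marginal coverage that the guarantee refers to. With continuous data it should land near $k/(n+1)$, while the naive quantile should come out somewhat lower.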

### Code Example: Rank-Based Quantile (No Features)

In [None]:
import jax.numpy as jnp
import jax.random as random
import numpy as np  # For sorting and quantile function
import matplotlib.pyplot as plt


def calculate_conformal_quantile_no_features(
    Y_train: jnp.ndarray, alpha: float
) -> float:
    """
    Calculates the conformal quantile for the no-features case.
    This quantile defines a one-sided prediction interval (-inf, q_n].

    Args:
        Y_train: 1D array of observed response values (training data).
        alpha: Nominal error level (e.g., 0.1 for 90% coverage).

    Returns:
        The conformal quantile q_n.
    """
    n = Y_train.shape[0]

    # Calculate the adjusted rank index.
    # This is the rank (1-indexed) that Y_{n+1} should be less than or equal to,
    # among the n+1 observations (Y_1, ..., Y_n, Y_{n+1}).
    # We then find the value at this rank from Y_1, ..., Y_n.
    adjusted_rank_index = jnp.ceil((1 - alpha) * (n + 1)).astype(int)

    # Sort the training data to find the order statistic
    Y_sorted = jnp.sort(Y_train)

    # Handle edge cases for the rank index.
    # For alpha in (0, 1) the rank is always at least 1, so only the upper edge can occur.
    if adjusted_rank_index > n:
        # alpha < 1 / (n + 1): no finite threshold guarantees 1 - alpha coverage,
        # so the interval must be the whole real line. Returning the maximum
        # observed value instead would break the coverage guarantee.
        print(
            f"Warning: Adjusted rank index ({adjusted_rank_index}) > n ({n}). Returning +inf."
        )
        return float("inf")

    # Return the (adjusted_rank_index)-th smallest value.
    # JAX/Python indexing is 0-based, so we subtract 1.
    return float(Y_sorted[adjusted_rank_index - 1])


# --- Example Usage and Coverage Simulation ---
key = random.PRNGKey(0)
num_training_samples = 50  # n
alpha_level = 0.1  # Desired coverage: 1 - alpha = 0.9 (90%)

# Simulate i.i.d. training data from a standard normal distribution
Y_train_data = random.normal(key, (num_training_samples,))
print(f"Training data (first 5 samples): {Y_train_data[:5]}\n")

# Calculate the conformal quantile
q_n_conformal = calculate_conformal_quantile_no_features(Y_train_data, alpha_level)
print(f"Calculated conformal quantile (q_n): {q_n_conformal:.4f}")

# For comparison, calculate the naive sample quantile (using 1-alpha directly)
q_n_naive = jnp.quantile(Y_train_data, 1 - alpha_level)
print(f"Naive sample quantile (1-alpha): {q_n_naive:.4f}\n")

# --- Simulate Coverage ---
# Note: the training sample (and hence q_n) is held fixed here, so this estimates
# coverage conditional on this particular training set. That value fluctuates around
# the marginal guarantee, which also averages over the draw of the training data.
num_simulations = 10000

# Use a fresh key: reusing `key` would return the same "random" draw on every call.
test_key = random.PRNGKey(1)
Y_test_samples = random.normal(test_key, (num_simulations,))

# Fraction of test observations that fall into the prediction interval (-inf, q_n]
simulated_coverage = float(jnp.mean(Y_test_samples <= q_n_conformal))
print(f"Simulated coverage for Y_n+1 <= q_n: {simulated_coverage:.4f}")
print(f"Desired coverage (at least): {1 - alpha_level:.4f}")
print(
    f"Upper bound on marginal coverage (no ties): {1 - alpha_level + 1 / (num_training_samples + 1):.4f}\n"
)

# --- Visualization ---
plt.figure(figsize=(10, 6))
plt.hist(
    Y_train_data,
    bins=15,
    density=True,
    alpha=0.6,
    color="skyblue",
    label="Training Data Distribution",
)
plt.axvline(
    q_n_conformal,
    color="red",
    linestyle="-",
    linewidth=2,
    label=f"Conformal Quantile (q_n={q_n_conformal:.2f})",
)
plt.axvline(
    q_n_naive,
    color="green",
    linestyle="--",
    linewidth=1,
    label=f"Naive Quantile (q_n={q_n_naive:.2f})",
)
plt.title(
    f"Conformal Prediction (No Features) for n={num_training_samples}, $\\alpha$={alpha_level}"
)
plt.xlabel("Y value")
plt.ylabel("Density")
plt.legend()
plt.grid(True)
plt.show()

## 3. Naive Attempt to Lift to Regression Problems (Why it Fails)

The success of the rank-based approach in the "no features" case naturally leads us to try and extend it to regression.

Suppose we have training data $(X_i, Y_i)$ for $i=1, \ldots, n$. We train a point predictor $\hat{f}_n$ (e.g., linear regression, random forest, neural network) on this data.

A naive attempt to form a prediction interval for a new point $(X_{n+1}, Y_{n+1})$ would be:

1.  Compute absolute residuals on the training set: $R_i = |Y_i - \hat{f}_n(X_i)|$ for $i=1, \ldots, n$.
2.  Find the conformal quantile $\hat{q}_n = \text{the } \lceil (1-\alpha)(n+1) \rceil \text{ smallest of } R_1, \ldots, R_n$.
3.  Form the prediction interval: $\hat{C}_n(x) = [\hat{f}_n(x) - \hat{q}_n, \hat{f}_n(x) + \hat{q}_n]$.

The hope is that $\mathbb{P}(Y_{n+1} \in \hat{C}_n(X_{n+1})) \ge 1-\alpha$. This is equivalent to $\mathbb{P}(R_{n+1} \le \hat{q}_n) \ge 1-\alpha$, where $R_{n+1} = |Y_{n+1} - \hat{f}_n(X_{n+1})|$ is the residual for the new test point.

### Why This Naive Approach Fails

This approach fails because the crucial property of **exchangeability is broken**. The test residual $R_{n+1}$ is **not exchangeable with the training residuals $R_1, \ldots, R_n$**.

The reason is that $\hat{f}_n$ was trained on $(X_1, Y_1), \ldots, (X_n, Y_n)$. It has "seen" this data. When a model is trained on data, it often fits the training data very well, leading to *smaller* training residuals. However, it has *not* seen $(X_{n+1}, Y_{n+1})$. Thus, the test residual $R_{n+1}$ will generally be stochastically *larger* than the training residuals $R_1, \ldots, R_n$.

Because $\hat{q}_n$ is computed from the (potentially artificially small) training residuals, it will be too small. Consequently, the prediction interval $\hat{C}_n(x)$ will be too narrow, and the actual coverage probability will be **less than the desired $1-\alpha$**. This is known as **undercoverage**.

This highlights the need for a more sophisticated strategy that ensures the scores used for calibration are treated symmetrically with respect to the test point.
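
An extreme case makes the failure easy to see. The sketch below uses a 1-nearest-neighbour predictor, which interpolates the training data, so every training residual is exactly zero. The toy data model and the helper names `sample` and `predict_1nn` are assumptions introduced for this illustration.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
n, alpha = 100, 0.1

def sample(m):
    """Draw m i.i.d. pairs from a toy model: y = sin(x) + Gaussian noise."""
    x = rng.uniform(-3, 3, size=m)
    return x, np.sin(x) + rng.normal(0, 0.5, size=m)

def predict_1nn(x_query, x_train, y_train):
    """1-nearest-neighbour prediction: copy the response of the closest training x."""
    nearest = np.abs(x_query[:, None] - x_train[None, :]).argmin(axis=1)
    return y_train[nearest]

x_train, y_train = sample(n)
x_test, y_test = sample(5_000)

# Naive recipe: calibrate on the same points the model was fit to.
train_residuals = np.abs(y_train - predict_1nn(x_train, x_train, y_train))
k = math.ceil((1 - alpha) * (n + 1))
q_naive = np.sort(train_residuals)[k - 1]

# Evaluate the resulting interval on fresh test points.
test_residuals = np.abs(y_test - predict_1nn(x_test, x_train, y_train))
naive_coverage = np.mean(test_residuals <= q_naive)

print(f"Naive quantile from training residuals: {q_naive:.4f}")
print(f"Test coverage of the naive interval:    {naive_coverage:.4f} (target {1 - alpha})")
```

Here the naive interval collapses to a single point and covers essentially no test responses. Less extreme models show the same effect in milder form.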

## 4. Split Conformal Prediction: The Second Key Idea

To restore the crucial exchangeability property and achieve valid coverage in regression, we introduce the **second key idea**: constructing scores symmetrically. **Split Conformal Prediction (Split CP)** is the most common and computationally efficient way to do this.

The core idea is to divide the data into two parts: one for training the base predictor, and another for calibrating the prediction set.

### The Procedure:

1.  **Data Splitting**: Divide the original dataset $D = \{(X_i, Y_i)\}_{i=1}^n$ into two disjoint sets:
    * $D_1$: The **proper training set** (e.g., 50% of the data). Used *only* to train the base predictor $\hat{f}_{n_1}$.
    * $D_2$: The **calibration set** (e.g., the remaining 50% of the data). Used *only* to compute the conformal quantile.
    Let $n_1 = |D_1|$ and $n_2 = |D_2|$.

2.  **Train Base Predictor**: Fit your chosen point predictor $\hat{f}_{n_1}$ (e.g., `LinearRegression`, `RandomForestRegressor`, `MLPRegressor`) exclusively on the data in $D_1$.

3.  **Compute Calibration Scores**: For each data point $(X_i, Y_i)$ in the calibration set $D_2$, compute a **non-conformity score** (often an absolute residual):
    $$R_i = |Y_i - \hat{f}_{n_1}(X_i)| \quad \text{for } i \in D_2$$
    Crucially, $\hat{f}_{n_1}$ has *not* seen these $D_2$ points during its training.

4.  **Compute Conformal Quantile**: Find the conformal quantile $\hat{q}_{n_2}$ from these calibration scores:
    $$\hat{q}_{n_2} = \text{the } \lceil (1-\alpha)(n_2+1) \rceil \text{ smallest of } \{R_i\}_{i \in D_2}$$

5.  **Form Conformal Prediction Set**: For a new test input $x$, the prediction set is:
    $$\hat{C}_n(x) = [\hat{f}_{n_1}(x) - \hat{q}_{n_2}, \hat{f}_{n_1}(x) + \hat{q}_{n_2}]$$

### The Guarantee

The key guarantee of split conformal prediction is:
$$\mathbb{P}(Y_{n+1} \in \hat{C}_n(X_{n+1}) | (X_i, Y_i)_{i \in D_1}) \in [1-\alpha, 1-\alpha + \frac{1}{n_2+1})$$
The lower bound always holds, and the upper bound holds if scores are distinct.

**Why it Works**: By conditioning on $D_1$ (making $\hat{f}_{n_1}$ fixed), the calibration scores $\{R_i\}_{i \in D_2}$ and the test score $R_{n+1} = |Y_{n+1} - \hat{f}_{n_1}(X_{n+1})|$ become conditionally i.i.d. This restores the exchangeability property, allowing the rank-based argument to apply.

### Remarks on Split Conformal Prediction

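* **Protection Against Overfitting**: Split CP inherently protects against the overfitting issue that plagues the naive approach. Because $\hat{f}_{n_1}$ never sees the calibration points, the calibration residuals have the same distribution as the test residual instead of being optimistically small.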
* **Computational Efficiency**: It's highly efficient. You train the base model once on $D_1$, compute residuals on $D_2$ once, and then prediction involves a single model inference and a lookup of the precomputed quantile.
* **Constant-Width Bands**: When using absolute residuals, the prediction bands are constant-width across the input space. This is a limitation if the noise or uncertainty varies with $X$. We'll address this later.
* **Quality of Base Predictor**: The better the base predictor $\hat{f}_{n_1}$ is (in terms of point prediction accuracy), the tighter (smaller) the conformal prediction bands will be, while still maintaining the coverage guarantee. This means you can use powerful ML models and still get valid uncertainty.
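
The last remark can be checked with a small experiment. The sketch below calibrates two predictors on the same split: a constant predictor that ignores $x$, and a least-squares line. The toy data model and the helper names `sample` and `conformal_quantile` are assumptions introduced for this illustration.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
alpha, n1, n2 = 0.1, 200, 200

def sample(m):
    """Toy data: y = 2x + 3 + standard Gaussian noise."""
    x = rng.uniform(-5, 5, size=m)
    return x, 2 * x + 3 + rng.normal(0, 1, size=m)

def conformal_quantile(scores, alpha):
    """k-th smallest score with k = ceil((1 - alpha)(n + 1)); +inf if k exceeds n."""
    n = len(scores)
    k = math.ceil((1 - alpha) * (n + 1))
    return np.inf if k > n else np.sort(scores)[k - 1]

x1, y1 = sample(n1)          # proper training set D1
x2, y2 = sample(n2)          # calibration set D2
x_te, y_te = sample(5_000)   # fresh test points

# Weak predictor: ignores x and always predicts the mean of the D1 responses.
def weak(x):
    return np.full_like(x, y1.mean())

# Better predictor: least-squares line fit on D1.
slope, intercept = np.polyfit(x1, y1, deg=1)
def strong(x):
    return slope * x + intercept

results = {}
for name, f in [("constant", weak), ("linear", strong)]:
    q = conformal_quantile(np.abs(y2 - f(x2)), alpha)
    coverage = np.mean(np.abs(y_te - f(x_te)) <= q)
    results[name] = (q, coverage)
    print(f"{name:>8}: half-width = {q:.2f}, test coverage = {coverage:.3f}")
```

Both bands should cover at roughly the target rate, but the band around the linear fit should be much narrower.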

### Code Example: Split Conformal Prediction for Regression

In [None]:
import jax.numpy as jnp
import jax.random as random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression  # Our base predictor

# Set random seeds for reproducibility (the data below is generated with NumPy)
key = random.PRNGKey(0)
np.random.seed(0)

# --- 1. Simulate Regression Data ---
num_samples = 200
# Generate data with some noise
X_data_np = np.linspace(-5, 5, num_samples)[:, None]  # Features (1D)
y_true = 2 * X_data_np.squeeze() + 3  # True linear relationship
noise_std = 5  # Constant noise standard deviation
y_data_np = y_true + np.random.normal(0, noise_std, num_samples)  # Add Gaussian noise

X_data = jnp.array(X_data_np)
y_data = jnp.array(y_data_np)

print(f"Total samples: {num_samples}")

# --- 2. Split Data into Proper Training (D1) and Calibration (D2) Sets ---
# We'll use a 50/50 split for simplicity
# train_test_split expects NumPy-compatible inputs and an integer seed
X_D1, X_D2, y_D1, y_D2 = train_test_split(
    np.asarray(X_data), np.asarray(y_data), test_size=0.5, random_state=0
)
n1 = X_D1.shape[0]
n2 = X_D2.shape[0]
print(f"Proper training set size (n1): {n1}")
print(f"Calibration set size (n2): {n2}\n")

# --- 3. Fit Predictor on Proper Training Set (D1) ---
base_predictor = LinearRegression()
# Scikit-learn models typically expect numpy arrays
base_predictor.fit(np.array(X_D1), np.array(y_D1))

# --- 4. Compute Calibration Scores (Absolute Residuals) on D2 ---
y_D2_pred = base_predictor.predict(np.array(X_D2))
calibration_scores = jnp.abs(y_D2 - jnp.array(y_D2_pred))
print(f"Calibration scores (absolute residuals, first 5): {calibration_scores[:5]}\n")

# --- 5. Compute Conformal Quantile from Calibration Scores ---
alpha_level = 0.1  # Desired coverage: 1 - alpha = 0.9 (90%)

# Calculate the adjusted rank index for n2 calibration points
adjusted_rank_index = jnp.ceil((1 - alpha_level) * (n2 + 1)).astype(int)

# Sort calibration scores
sorted_calibration_scores = jnp.sort(calibration_scores)

# Get the conformal quantile (the rank is 1-indexed, arrays are 0-indexed).
# For alpha in (0, 1) the rank is at least 1. If it exceeds n2, no finite quantile
# gives the guarantee, so the band must be infinitely wide.
if adjusted_rank_index > n2:
    conformal_quantile = jnp.inf
else:
    conformal_quantile = sorted_calibration_scores[adjusted_rank_index - 1]

print(f"Conformal quantile (q_n2): {conformal_quantile:.4f}\n")

# --- 6. Form Conformal Prediction Band ---
# Generate a range of test points for plotting the prediction band
X_test_plot = jnp.linspace(X_data.min() - 1, X_data.max() + 1, 100)[:, None]
y_test_pred_mean = base_predictor.predict(np.array(X_test_plot))

# Construct the prediction band: [mean - quantile, mean + quantile]
lower_bound = y_test_pred_mean - conformal_quantile
upper_bound = y_test_pred_mean + conformal_quantile

# --- Plotting Results ---
plt.figure(figsize=(10, 7))
plt.scatter(
    X_D1[:, 0], y_D1, color="blue", label="Proper Training Data (D1)", alpha=0.7, s=30
)
plt.scatter(
    X_D2[:, 0], y_D2, color="green", label="Calibration Data (D2)", alpha=0.7, s=30
)
plt.plot(
    X_test_plot[:, 0],
    y_test_pred_mean,
    color="red",
    linewidth=2,
    label="Regression Mean ($\\hat{f}_{n_1}$)",
)
plt.fill_between(
    X_test_plot[:, 0],
    lower_bound,
    upper_bound,
    color="red",
    alpha=0.2,
    label=f"Conformal Prediction Band (1-$\\alpha$={1 - alpha_level})",
)

plt.xlabel("X")
plt.ylabel("Y")
plt.title("Split Conformal Prediction for Regression (Constant Width)")
plt.legend()
plt.grid(True)
plt.show()

# --- Simulate Test Coverage to Verify Guarantee ---
num_test_points_sim = 1000
# Generate new i.i.d. test data
X_test_sim_np = np.random.uniform(
    X_data.min(), X_data.max(), size=(num_test_points_sim, 1)
)
y_test_sim_true = (
    2 * X_test_sim_np.squeeze()
    + 3
    + np.random.normal(0, noise_std, num_test_points_sim)
)

X_test_sim = jnp.array(X_test_sim_np)
y_test_sim = jnp.array(y_test_sim_true)

y_test_sim_pred_mean = base_predictor.predict(np.array(X_test_sim))

# Fraction of test points whose true response falls within the predicted interval
covered = jnp.abs(y_test_sim - jnp.array(y_test_sim_pred_mean)) <= conformal_quantile
simulated_coverage = float(jnp.mean(covered))
print(f"Simulated test coverage: {simulated_coverage:.4f}")
print(f"Desired coverage (at least): {1 - alpha_level:.4f}")
print(f"Upper bound on marginal coverage (distinct scores): {1 - alpha_level + 1 / (n2 + 1):.4f}")

# The calibration set is held fixed here, so the simulated value is the coverage
# conditional on this particular split. It fluctuates around the 1 - alpha target
# (and can fall slightly below it); the finite-sample guarantee holds on average
# over the draw of the calibration set.

## 5. Generalizing Conformity Scores

The power of conformal prediction lies in its flexibility regarding the choice of **conformity score function**. While we used absolute residuals $|Y_i - \hat{f}_{n_1}(X_i)|$ in the regression example, any score function that quantifies "non-conformity" (how unusual a point is with respect to the model) can be used.

Let $V(x, y)$ be a **negatively-oriented score function**, meaning smaller values of $V(x,y)$ indicate better conformity (e.g., smaller prediction errors). The procedure remains the same:

1.  Compute calibration scores $R_i = V(X_i, Y_i)$ for $i \in D_2$.
2.  Find $\hat{q}_{n_2} = \text{the } \lceil (1-\alpha)(n_2+1) \rceil \text{ smallest of } \{R_i\}_{i \in D_2}$.
3.  The prediction set is $\hat{C}_n(x) = \{y : V(x, y) \le \hat{q}_{n_2}\}$.
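
The three steps above carry over directly to classification. The sketch below uses the score $V(x, y) = 1 - \hat{p}(y \mid x)$, a common choice that is an assumption here and not prescribed by the text. The classifier is replaced by a hypothetical stand-in, `fake_model_and_labels`, which produces random class probabilities and draws labels from them.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
alpha, n2, num_classes = 0.1, 500, 4

def fake_model_and_labels(m):
    """Stand-in for a classifier fit on D1: random class probabilities, labels drawn from them."""
    logits = rng.normal(0, 2, size=(m, num_classes))
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    labels = np.array([rng.choice(num_classes, p=p) for p in probs])
    return probs, labels

def score(probs, labels):
    """Negatively-oriented score V(x, y) = 1 - p_hat(y | x)."""
    return 1.0 - probs[np.arange(len(labels)), labels]

# Steps 1-2: calibration scores and their conformal quantile.
probs_cal, y_cal = fake_model_and_labels(n2)
k = math.ceil((1 - alpha) * (n2 + 1))
q_hat = np.sort(score(probs_cal, y_cal))[k - 1]

# Step 3: the prediction set keeps every label whose score is at most q_hat.
probs_test, y_test = fake_model_and_labels(2_000)
in_set = (1.0 - probs_test) <= q_hat   # boolean matrix: test point x candidate label
coverage = in_set[np.arange(len(y_test)), y_test].mean()
avg_set_size = in_set.sum(axis=1).mean()
print(f"q_hat = {q_hat:.3f}, coverage = {coverage:.3f}, average set size = {avg_set_size:.2f}")
```

The prediction set is now a subset of labels, not an interval; its size varies from one test point to the next depending on how confident the stand-in model is.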

If you have a **positively-oriented score function** (where larger values are better, e.g., predicted probability of the true class in classification), you can either:
1.  Negate it: Use $-V(x,y)$ as the negatively-oriented score.
2.  Adjust the quantile rule: Find $\hat{q}_{n_2} = \text{the } \lfloor \alpha(n_2+1) \rfloor \text{ smallest of } \{R_i\}_{i \in D_2}$, and form the set $\hat{C}_n(x) = \{y : V(x, y) \ge \hat{q}_{n_2}\}$.
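
The two options give the same threshold, which is easy to confirm numerically. The sketch below uses made-up uniform scores and $n_2 = 20$, both chosen only for illustration.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
alpha, n2 = 0.1, 20
scores = rng.uniform(0, 1, size=n2)   # positively-oriented calibration scores (larger = better)

# Option 1: negate, then apply the usual ceil rule to the negatively-oriented scores.
k_upper = math.ceil((1 - alpha) * (n2 + 1))   # 19
q_negated = np.sort(-scores)[k_upper - 1]     # keep y when -V(x, y) <= q_negated

# Option 2: keep the scores as they are and use the floor rule.
k_lower = math.floor(alpha * (n2 + 1))        # 2
q_floor = np.sort(scores)[k_lower - 1]        # keep y when V(x, y) >= q_floor

print(f"k_upper = {k_upper}, k_lower = {k_lower}")
print(f"-q_negated = {-q_negated:.4f}, q_floor = {q_floor:.4f}")
```

The two ranks always sum to $n_2 + 1$, which is why the thresholds coincide. When $\lfloor \alpha(n_2+1) \rfloor = 0$ there is no such order statistic, and the threshold should be taken as $-\infty$ (the set is all of $\mathcal{Y}$).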

### Quantile and CDF Formulations

The prediction set can be expressed in equivalent ways:

* **Quantile Form**:
    $$\hat{C}_n(x) = \left\{y : V(x,y) \le \text{Quantile}\left(\frac{\lceil (1-\alpha)(n_2+1) \rceil}{n_2}; \text{empirical distribution of } \{R_i\}_{i \in D_2}\right)\right\}$$
    This highlights that we're comparing the test score $V(x,y)$ against an adjusted quantile of the calibration scores.

* **CDF Form**:
    $$\hat{C}_n(x) = \left\{y : \frac{1}{n_2}\sum_{i \in D_2}1\{R_i < V(x,y)\} < \frac{\lceil (1-\alpha)(n_2+1) \rceil}{n_2}\right\}$$
    This uses the empirical cumulative distribution function (CDF) of the calibration scores.

### Auxiliary Randomization for Exact Coverage

While the guarantee is $[1-\alpha, 1-\alpha + 1/(n_2+1))$, we can achieve **exact coverage** of precisely $1-\alpha$ by introducing a small amount of auxiliary randomization. This is particularly useful when there are ties in the conformity scores.

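The idea is to randomize the empirical CDF. For a score $R_{n+1}$ and calibration scores $\{R_i\}_{i \in D_2}$, we define a randomized rank statistic (called a randomized p-value in the code below):
$$p^*(R_{n+1}) = \frac{1}{n_2+1} \left( \sum_{i \in D_2} 1\{R_i < R_{n+1}\} + U \cdot \left( \sum_{i \in D_2} 1\{R_i = R_{n+1}\} + 1 \right) \right)$$
where $U \sim \text{Unif}(0,1)$ is an independent random variable. Under exchangeability, $p^*(R_{n+1})$ is exactly uniform on $[0,1]$.
Because $V$ is negatively oriented, large values of $p^*$ flag poorly conforming points, so the prediction set is $\hat{C}_n^*(x) = \{y : p^*(V(x,y)) \le 1-\alpha\}$. This guarantees $\mathbb{P}(Y_{n+1} \in \hat{C}_n^*(X_{n+1})) = 1-\alpha$. Equivalently, $1 - p^*$ has the form of the usual conformal p-value (it counts $R_i > R_{n+1}$ instead, with $1-U$ in place of $U$), and the set can be written as $\{y : 1 - p^*(V(x,y)) \ge \alpha\}$.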

### Code Example: Auxiliary Randomization (Conceptual)

In [None]:
import jax.numpy as jnp
import jax.random as random
import numpy as np  # For sorting
import matplotlib.pyplot as plt


def calculate_randomized_p_value(
    test_score: float, calibration_scores: jnp.ndarray, key: random.PRNGKey
) -> float:
    """
    Calculates a randomized p-value for a test score against calibration scores.
    This helps achieve exact coverage when ties are present.

    Args:
        test_score: The conformity score of the test point.
        calibration_scores: 1D array of conformity scores from the calibration set.
        key: JAX random key for generating uniform random number.

    Returns:
        The randomized p-value.
    """
    n2 = calibration_scores.shape[0]

    # Count scores strictly less than test_score
    num_less = jnp.sum(calibration_scores < test_score)

    # Count scores equal to test_score
    num_equal = jnp.sum(calibration_scores == test_score)

    # Generate a uniform random number
    U = random.uniform(key)

    # Calculate the randomized p-value
    # The +1 in the denominator accounts for the test score itself
    randomized_p_val = (num_less + U * (num_equal + 1)) / (n2 + 1)

    return randomized_p_val


# --- Example Usage ---
key = random.PRNGKey(0)
num_calibration_samples = 10
alpha_level = 0.1  # Desired coverage: 0.9

# Simulate calibration scores (introduce some ties for demonstration)
calibration_scores_data = jnp.array([1.2, 0.8, 1.5, 0.8, 1.0, 1.2, 0.9, 1.1, 1.0, 1.3])
print(f"Calibration scores: {calibration_scores_data}\n")

# Simulate a test score
test_score_data = 1.05
print(f"Test score: {test_score_data}\n")

# Calculate the randomized p-value
# Split the key so that each random draw below uses its own independent subkey.
key_first, key_slide = random.split(key)
p_value = calculate_randomized_p_value(
    test_score_data, calibration_scores_data, key_first
)
print(f"Randomized p-value for test score {test_score_data}: {p_value:.4f}")

# Orientation matters when turning this statistic into a set-membership rule.
# calculate_randomized_p_value counts calibration scores *below* the test score, so it
# is a randomized rank (CDF-style) statistic: for a negatively-oriented score, the
# test point is in the conformal set when the statistic is <= 1 - alpha.
# The slides instead use the p-value p(y) = (1/n) * sum 1{R_i >= R_n+1}, with the set
# {y : p(y) >= floor(alpha * (n + 1)) / n}. There, a higher p(y) means better
# conformity, and the randomized analogue keeps y when p(y) >= alpha.


def calculate_randomized_p_value_slide_style(
    test_score: float, calibration_scores: jnp.ndarray, key: random.PRNGKey
) -> float:
    """
    Calculates a randomized p-value where higher values are better (slide style).
    It counts calibration scores above the test score and breaks ties at random.
    """
    n2 = calibration_scores.shape[0]

    # Count scores strictly greater than test_score
    num_greater = jnp.sum(calibration_scores > test_score)

    # Count scores equal to test_score
    num_equal = jnp.sum(calibration_scores == test_score)

    U = random.uniform(key)

    # The +1 counts the test score itself among the ties. Without it the p-value
    # is not uniform on [0, 1] and the coverage is no longer exactly 1 - alpha.
    randomized_p_val = (num_greater + U * (num_equal + 1)) / (n2 + 1)

    return randomized_p_val


p_value_slide_style = calculate_randomized_p_value_slide_style(
    test_score_data, calibration_scores_data, key_slide
)
print(
    f"Randomized p-value (slide style) for test score {test_score_data}: {p_value_slide_style:.4f}"
)

# The conformal set in this style is {y : p(y) >= alpha}
is_in_set = p_value_slide_style >= alpha_level
print(f"Is test score {test_score_data} in the conformal set (p >= alpha)? {is_in_set}")

# Simulate coverage for this style.
# Calibration and test scores are redrawn together in every trial so that they are
# exchangeable. Scores take values on a coarse grid (0.5, 0.6, ..., 1.6) to force ties.
num_simulations_exact = 10000
sim_key = random.PRNGKey(42)
score_key, u_key = random.split(sim_key)

# Each row holds n2 calibration scores followed by one test score.
all_scores = (
    random.randint(
        score_key, (num_simulations_exact, num_calibration_samples + 1), 5, 17
    )
    / 10.0
)
cal_scores_sim = all_scores[:, :-1]
test_scores_sim = all_scores[:, -1:]
U_sim = random.uniform(u_key, (num_simulations_exact,))

# Randomized p-value (slide style), computed for all trials at once
num_greater_sim = jnp.sum(cal_scores_sim > test_scores_sim, axis=1)
num_equal_sim = jnp.sum(cal_scores_sim == test_scores_sim, axis=1)
p_vals_sim = (num_greater_sim + U_sim * (num_equal_sim + 1)) / (
    num_calibration_samples + 1
)

simulated_coverage_exact = float(jnp.mean(p_vals_sim >= alpha_level))
print(
    f"\nSimulated coverage with randomization (should be close to 1 - alpha): {simulated_coverage_exact:.4f}"
)
print(f"Desired coverage (1 - alpha): {1 - alpha_level:.4f}")

# Note: the randomized p-value is uniform on [0, 1], so it falls below alpha with
# probability alpha. The prediction set {y : p(y) >= alpha} therefore has coverage
# exactly 1 - alpha, even though the scores contain many ties.

## 6. Improving Local Adaptivity

As noted, standard split conformal prediction with absolute residuals produces constant-width prediction bands. This is suboptimal when the uncertainty (e.g., noise level) varies across the input space. We want **local adaptivity**: narrower bands where prediction is easy, and wider bands where it's hard. This can be achieved by changing the conformity score function.

### 6.1 Studentized Residuals

A simple way to introduce local adaptivity is to use **studentized residuals**, which normalize the absolute residual by an estimate of the local spread.

**Procedure (Split Conformal with Studentized Residuals):**

1.  **Train Mean Predictor $\hat{f}_{n_1}$**: Fit your base predictor on $D_1$ to predict the mean response.
2.  **Train Spread Predictor $\hat{\sigma}_{n_1}$**: Also on $D_1$, train a *separate* model (e.g., another regression model) to predict the local spread of the residuals. This model $\hat{\sigma}_{n_1}(x)$ could be trained to predict $|Y_i - \hat{f}_{n_1}(X_i)|$ (the absolute residuals from $D_1$).
3.  **Compute Studentized Calibration Scores**: For $i \in D_2$, compute:
    $$R_i = \frac{|Y_i - \hat{f}_{n_1}(X_i)|}{\hat{\sigma}_{n_1}(X_i)}$$
4.  **Compute Conformal Quantile**: Find $\hat{q}_{n_2}$ from these studentized residuals.
5.  **Form Conformal Set**:
    $$\hat{C}_n(x) = [\hat{f}_{n_1}(x) - \hat{\sigma}_{n_1}(x)\hat{q}_{n_2}, \hat{f}_{n_1}(x) + \hat{\sigma}_{n_1}(x)\hat{q}_{n_2}]$$
The width of the band now scales with $\hat{\sigma}_{n_1}(x)$, adapting to local uncertainty. The coverage guarantee remains valid.

### 6.2 Conformalized Quantile Regression (CQR)

CQR (Romano et al., 2019) is a more principled approach to local adaptivity. Instead of predicting the mean and then the spread, it directly predicts quantiles.

**Procedure (Split Conformal with CQR):**

1.  **Train Quantile Predictors**: On $D_1$, fit *two* quantile regression models:
    * $\hat{f}^{\alpha/2}_{n_1}(x)$: Estimates the $\alpha/2$ quantile of $Y|X=x$.
    * $\hat{f}^{1-\alpha/2}_{n_1}(x)$: Estimates the $1-\alpha/2$ quantile of $Y|X=x$.
    These can be trained using specific quantile regression algorithms (e.g., LightGBM with quantile loss, or neural networks with pinball loss).
2.  **Compute CQR Calibration Scores**: For $i \in D_2$, compute:
    $$R_i = \max\left\{ \hat{f}^{\alpha/2}_{n_1}(X_i) - Y_i, Y_i - \hat{f}^{1-\alpha/2}_{n_1}(X_i) \right\}$$
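    This score measures how far the true $Y_i$ falls outside the predicted central quantile interval. It is positive when $Y_i$ lies outside the interval and negative when it lies inside, so $\hat{q}_{n_2}$ can be negative, in which case the final interval is narrower than the raw quantile band.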
3.  **Compute Conformal Quantile**: Find $\hat{q}_{n_2}$ from these scores.
4.  **Form Conformal Set**:
    $$\hat{C}_n(x) = [\hat{f}^{\alpha/2}_{n_1}(x) - \hat{q}_{n_2}, \hat{f}^{1-\alpha/2}_{n_1}(x) + \hat{q}_{n_2}]$$
CQR often provides superior local adaptivity, especially for non-Gaussian conditional distributions.
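
The calibration step of CQR can be sketched without committing to a particular quantile regression library. In the example below, `f_lo` and `f_hi` are hypothetical stand-ins for quantile models already fit on $D_1$. They are hand-written and deliberately too narrow so that the conformal correction has something to fix, and the heteroscedastic data model is likewise an assumption.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
alpha, n2 = 0.1, 500

def sample(m):
    """Heteroscedastic toy data: the noise standard deviation grows with |x|."""
    x = rng.uniform(-5, 5, size=m)
    return x, np.sin(x) + (0.5 + 0.5 * np.abs(x)) * rng.normal(size=m)

# Hypothetical stand-ins for the lower and upper quantile models fit on D1.
def f_lo(x):
    return np.sin(x) - 1.0 * (0.5 + 0.5 * np.abs(x))

def f_hi(x):
    return np.sin(x) + 1.0 * (0.5 + 0.5 * np.abs(x))

# CQR calibration scores: positive when y falls outside [f_lo, f_hi], negative inside.
x_cal, y_cal = sample(n2)
scores = np.maximum(f_lo(x_cal) - y_cal, y_cal - f_hi(x_cal))
k = math.ceil((1 - alpha) * (n2 + 1))
q_hat = np.sort(scores)[k - 1]

# Conformalized interval: move both ends outward by q_hat (inward if q_hat < 0).
x_test, y_test = sample(5_000)
lower, upper = f_lo(x_test) - q_hat, f_hi(x_test) + q_hat
coverage = np.mean((y_test >= lower) & (y_test <= upper))
raw_coverage = np.mean((y_test >= f_lo(x_test)) & (y_test <= f_hi(x_test)))
print(f"q_hat = {q_hat:.3f}")
print(f"Coverage before calibration: {raw_coverage:.3f}, after: {coverage:.3f}")
```

In practice the stand-ins would be replaced by fitted quantile regressors; the calibration lines stay the same. The interval width still varies with $x$ because the correction $\hat{q}_{n_2}$ is added to bands that are already adaptive.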

### Code Example: Studentized Residuals for Regression

In [None]:
import jax.numpy as jnp
import jax.random as random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor  # Used for spread predictor

# Set random seeds for reproducibility (the data below is generated with NumPy)
key = random.PRNGKey(0)
np.random.seed(0)

# --- 1. Simulate Heteroscedastic Regression Data ---
# Data where the noise level changes with X
num_samples = 300
X_data_np = np.linspace(-5, 5, num_samples)[:, None]
y_true = np.sin(X_data_np * 2) * 5  # A non-linear true function
noise_level = 0.5 + 0.5 * np.abs(X_data_np)  # Noise increases with |X|
y_data_np = y_true + noise_level * np.random.normal(0, 1, size=(num_samples, 1))

X_data = jnp.array(X_data_np)
y_data = jnp.array(y_data_np).squeeze()  # Ensure y_data is 1D

print(f"Total samples: {num_samples}")

# --- 2. Split Data into Proper Training (D1) and Calibration (D2) Sets ---
X_D1, X_D2, y_D1, y_D2 = train_test_split(
    X_data, y_data, test_size=0.5, random_state=0
)
n1 = X_D1.shape[0]
n2 = X_D2.shape[0]
print(f"Proper training set size (n1): {n1}")
print(f"Calibration set size (n2): {n2}\n")

# --- 3. Fit Point Predictor on D1 (e.g., Linear Regression or RandomForest) ---
# Using RandomForestRegressor for mean prediction to better capture non-linearity
mean_predictor = RandomForestRegressor(
    n_estimators=100, random_state=1
)
mean_predictor.fit(np.array(X_D1), np.array(y_D1))

# --- 4. Fit Spread Predictor on D1 ---
# Compute absolute residuals on D1 for training the spread predictor
y_D1_pred_mean = mean_predictor.predict(np.array(X_D1))
abs_residuals_D1 = jnp.abs(y_D1 - jnp.array(y_D1_pred_mean))

# Train a model to predict these absolute residuals (as a proxy for local spread)
spread_predictor = RandomForestRegressor(
    n_estimators=100, random_state=2
)
spread_predictor.fit(np.array(X_D1), np.array(abs_residuals_D1))

# --- 5. Compute Studentized Calibration Residuals on D2 ---
y_D2_pred_mean = mean_predictor.predict(np.array(X_D2))
y_D2_pred_spread = spread_predictor.predict(np.array(X_D2))

# Add a small epsilon to spread predictions to avoid division by zero
# Ensure spread predictions are positive and non-zero
y_D2_pred_spread_stable = jnp.maximum(jnp.array(y_D2_pred_spread), 1e-6)

studentized_calibration_scores = (
    jnp.abs(y_D2 - jnp.array(y_D2_pred_mean)) / y_D2_pred_spread_stable
)
print(
    f"Studentized calibration scores (first 5): {studentized_calibration_scores[:5]}\n"
)

# --- 6. Compute Conformal Quantile from Studentized Calibration Scores ---
alpha_level = 0.1  # Desired coverage: 1 - alpha = 0.9 (90%)
adjusted_rank_index = jnp.ceil((1 - alpha_level) * (n2 + 1)).astype(int)
sorted_studentized_calibration_scores = jnp.sort(studentized_calibration_scores)

if adjusted_rank_index > n2:
    # Too few calibration points for this alpha: only an unbounded band keeps the guarantee
    conformal_quantile_studentized = jnp.inf
else:
    # Convert the 1-based rank to a 0-based index
    conformal_quantile_studentized = sorted_studentized_calibration_scores[
        adjusted_rank_index - 1
    ]

print(
    f"Conformal quantile for studentized residuals: {conformal_quantile_studentized:.4f}\n"
)

# --- 7. Form Conformal Prediction Band ---
# Generate a range of test points for plotting the prediction band
X_test_plot = jnp.linspace(X_data.min() - 1, X_data.max() + 1, 200)[:, None]
y_test_pred_mean = mean_predictor.predict(np.array(X_test_plot))
y_test_pred_spread = spread_predictor.predict(np.array(X_test_plot))

# Ensure spread predictions are positive and non-zero for test points too
y_test_pred_spread_stable = jnp.maximum(jnp.array(y_test_pred_spread), 1e-6)

# Construct the prediction band: [mean - spread * quantile, mean + spread * quantile]
lower_bound_studentized = (
    y_test_pred_mean - y_test_pred_spread_stable * conformal_quantile_studentized
)
upper_bound_studentized = (
    y_test_pred_mean + y_test_pred_spread_stable * conformal_quantile_studentized
)

# --- Plotting Results ---
plt.figure(figsize=(10, 7))
plt.scatter(
    X_D1[:, 0], y_D1, color="blue", label="Proper Training Data (D1)", alpha=0.7, s=30
)
plt.scatter(
    X_D2[:, 0], y_D2, color="green", label="Calibration Data (D2)", alpha=0.7, s=30
)
plt.plot(
    X_test_plot[:, 0],
    y_test_pred_mean,
    color="red",
    linewidth=2,
    label="Regression Mean ($\\hat{f}_{n_1}$)",
)
plt.fill_between(
    X_test_plot[:, 0],
    lower_bound_studentized,
    upper_bound_studentized,
    color="red",
    alpha=0.2,
    label=f"Studentized Conformal Band (1-$\\alpha$={1 - alpha_level})",
)

plt.xlabel("X")
plt.ylabel("Y")
plt.title("Split Conformal Prediction with Studentized Residuals (Adaptive Width)")
plt.legend()
plt.grid(True)
plt.show()

# --- Simulate Test Coverage to Verify Guarantee ---
num_test_points_sim = 1000
# Generate new i.i.d. test data with heteroscedastic noise
X_test_sim_np = np.random.uniform(
    X_data.min(), X_data.max(), size=(num_test_points_sim, 1)
)
y_true_sim = np.sin(X_test_sim_np * 2) * 5
noise_level_sim = 0.5 + 0.5 * np.abs(X_test_sim_np)
y_test_sim_true = y_true_sim + noise_level_sim * np.random.normal(
    0, 1, size=(num_test_points_sim, 1)
)

X_test_sim = jnp.array(X_test_sim_np)
y_test_sim = jnp.array(y_test_sim_true).squeeze()

y_test_sim_pred_mean = mean_predictor.predict(np.array(X_test_sim))
y_test_sim_pred_spread = spread_predictor.predict(np.array(X_test_sim))
y_test_sim_pred_spread_stable = jnp.maximum(jnp.array(y_test_sim_pred_spread), 1e-6)

coverage_count_sim_studentized = 0
for i in range(num_test_points_sim):
    lower = (
        y_test_sim_pred_mean[i]
        - y_test_sim_pred_spread_stable[i] * conformal_quantile_studentized
    )
    upper = (
        y_test_sim_pred_mean[i]
        + y_test_sim_pred_spread_stable[i] * conformal_quantile_studentized
    )
    if (y_test_sim[i] >= lower) and (y_test_sim[i] <= upper):
        coverage_count_sim_studentized += 1

simulated_coverage_studentized = coverage_count_sim_studentized / num_test_points_sim
print(f"Simulated test coverage (Studentized): {simulated_coverage_studentized:.4f}")
print(f"Desired coverage (at least): {1 - alpha_level:.4f}")
print(f"Theoretical upper bound (approx): {1 - alpha_level + 1 / (n2 + 1):.4f}")

## 7. Full Conformal Prediction: Guaranteed Coverage Without Splitting

Full Conformal Prediction (often just "conformal prediction" in older literature) is a more general approach that achieves the coverage guarantee **without splitting the data**. This means all $n$ data points can be used for training the base predictor. However, this comes at a significant computational cost.

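The core idea is still to treat all data points symmetrically. Here the test point itself joins the training set, paired with a hypothesized response, so that it is handled exactly like the $n$ observed points.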

### The Procedure:

1.  **Query Point Augmentation**: For any fixed test input $x \in \mathcal{X}$, we want to determine if a trial response value $y \in \mathcal{Y}$ should be included in the prediction set $\hat{C}_n(x)$. We consider $(x,y)$ as a hypothetical "query point."

2.  **Train on Augmented Data**: For each trial value $y$, we train our prediction algorithm on an **augmented dataset** consisting of the original $n$ training points *plus* the single query point $(x,y)$:
    $$D_{\text{aug}} = \{(X_1, Y_1), \ldots, (X_n, Y_n), (x,y)\}$$
    This yields a point predictor $\hat{f}_{n,(x,y)}$ that has been trained on $n+1$ points.

3.  **Define Residuals (Non-conformity Scores)**: Compute residuals for *all* $n+1$ points in the augmented dataset using this newly trained predictor $\hat{f}_{n,(x,y)}$:
    $$R_i^{(x,y)} = |Y_i - \hat{f}_{n,(x,y)}(X_i)| \quad \text{for } i=1, \ldots, n$$
    $$R_{n+1}^{(x,y)} = |y - \hat{f}_{n,(x,y)}(x)| \quad \text{(residual for the query point itself)}$$

4.  **Form Conformal Set**: The prediction set $\hat{C}_n(x)$ is the set of all trial values $y$ for which the residual of the query point $R_{n+1}^{(x,y)}$ is less than or equal to the $\lceil (1-\alpha)(n+1) \rceil$-th smallest of *all* $n+1$ residuals $\{R_1^{(x,y)}, \ldots, R_n^{(x,y)}, R_{n+1}^{(x,y)}\}$:
    $$\hat{C}_n(x) = \left\{y : R_{n+1}^{(x,y)} \le \text{the } \lceil (1-\alpha)(n+1) \rceil \text{ smallest of } \{R_j^{(x,y)}\}_{j=1}^{n+1} \right\}$$

### The Guarantee

The guarantee for full conformal prediction is:
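$$\mathbb{P}(Y_{n+1} \in \hat{C}_n(X_{n+1})) \in \left[1-\alpha,\; 1-\alpha + \frac{1}{n+1}\right)$$
This holds because, when the true test point $(X_{n+1}, Y_{n+1})$ is plugged in as the query point $(x,y)$, the resulting $n+1$ residuals are exchangeable (assuming the base predictor is a symmetric function of its training data). If the residuals have no ties, the rank of $R_{n+1}$ among them is uniform on $\{1, \ldots, n+1\}$, so the coverage equals $\lceil (1-\alpha)(n+1) \rceil / (n+1)$, which lies in the stated interval. If ties can occur, the lower bound still holds but the upper bound may not.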
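
The rank argument behind this guarantee can be checked numerically. The sketch below is not a full conformal run: it assumes i.i.d. exponential draws as stand-ins for the $n+1$ exchangeable residuals (any continuous distribution would do) and simply counts how often the last one is at most the $\lceil (1-\alpha)(n+1) \rceil$ smallest. The values of `n`, `alpha`, and `num_trials` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n, alpha, num_trials = 20, 0.1, 100_000

k = int(np.ceil((1 - alpha) * (n + 1)))  # rank used by the conformal set

# Each row holds n + 1 i.i.d. (hence exchangeable) continuous scores;
# the last column plays the role of the test residual R_{n+1}.
scores = rng.exponential(size=(num_trials, n + 1))
kth_smallest = np.sort(scores, axis=1)[:, k - 1]  # k-th smallest of all n + 1 scores
covered = scores[:, -1] <= kth_smallest

empirical = covered.mean()
exact = k / (n + 1)  # value implied by the uniform-rank argument
print(f"k = {k}, exact = {exact:.4f}, empirical = {empirical:.4f}")
print(f"bounds: [{1 - alpha:.4f}, {1 - alpha + 1 / (n + 1):.4f})")
```

The empirical frequency should sit close to $k/(n+1)$, which is at least $1-\alpha$ and below $1-\alpha+\frac{1}{n+1}$ by construction of $k$.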

### Remarks on Full Conformal Prediction

* **Computational Expense**: This is the major drawback. For each test input $x$, and for each candidate $y$ in a grid of possible responses, the base predictor must be *re-trained* on $n+1$ data points. This makes it extremely computationally intensive, especially for large $n$ or complex base predictors.
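* **Use Cases**: Rarely used in practice for large-scale problems, except when the base predictor has a shortcut that avoids re-training from scratch for every candidate $y$ (e.g., least squares, ridge regression, and kernel ridge regression, whose fitted residuals are affine functions of $y$).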
* **Theoretical Elegance**: Despite its computational cost, it's theoretically very elegant as it avoids data splitting and uses all available data for model training.
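
As one concrete instance of such a shortcut, consider ordinary least squares with an intercept (chosen here for simplicity; it is an assumption of this sketch, not a method prescribed above). The hat matrix of the augmented design depends only on the features, so the residual vector is affine in the trial response, $r(y) = a + y\,b$, and every candidate $y$ can be checked without re-fitting. The function name `full_conformal_ols` and the toy data are hypothetical.

```python
import numpy as np


def full_conformal_ols(X_train, y_train, x_test, alpha, y_grid):
    """Membership of each y in y_grid in the full conformal set, for least squares
    with an intercept, computed without re-fitting per candidate."""
    n = X_train.shape[0]
    Z = np.column_stack([np.ones(n + 1), np.vstack([X_train, x_test])])  # augmented design
    H = Z @ np.linalg.pinv(Z)            # hat matrix; depends on the features only
    M = np.eye(n + 1) - H                # maps responses to fitted residuals
    a = M @ np.append(y_train, 0.0)      # residuals when the trial response is 0
    b = M[:, -1]                         # change in residuals per unit of trial response
    k = int(np.ceil((1 - alpha) * (n + 1)))
    # One row of |a + y * b| per candidate y: shape (len(y_grid), n + 1)
    abs_res = np.abs(a[None, :] + y_grid[:, None] * b[None, :])
    threshold = np.sort(abs_res, axis=1)[:, k - 1]  # k-th smallest of the n + 1 residuals
    return abs_res[:, -1] <= threshold


rng = np.random.default_rng(0)
X_train = np.linspace(0, 10, 20)[:, None]
y_train = 0.5 * X_train[:, 0] + 1 + rng.normal(size=20)
y_grid = np.linspace(-5, 15, 401)

in_set = full_conformal_ols(X_train, y_train, np.array([[5.0]]), 0.1, y_grid)
print(f"Set spans roughly [{y_grid[in_set].min():.2f}, {y_grid[in_set].max():.2f}]")
```

The cost is one pseudo-inverse per test input instead of one model fit per candidate, so a fine grid becomes cheap. The result should agree with brute-force re-fitting up to floating-point effects at the set boundary.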

### Code Example: Full Conformal Prediction (Conceptual)

In [None]:
import jax.numpy as jnp
import jax.random as random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression  # Our base predictor

# Set random seeds for reproducibility (NumPy's global generator drives the simulated noise)
key = random.PRNGKey(0)
np.random.seed(0)

# --- 1. Simulate Regression Data ---
num_samples = 20  # Small N for conceptual illustration
X_data_np = np.linspace(0, 10, num_samples)[:, None]
y_true = 0.5 * X_data_np.squeeze() + 1
noise_std = 1.0
y_data_np = y_true + np.random.normal(0, noise_std, num_samples)

X_data = jnp.array(X_data_np)
y_data = jnp.array(y_data_np)

print(f"Original training samples (n): {num_samples}\n")


# --- Full Conformal Prediction Function ---
def full_conformal_predict(
    X_train: jnp.ndarray,
    y_train: jnp.ndarray,
    X_test_point: jnp.ndarray,  # Single test feature point x
    alpha: float,
    y_candidate_grid: jnp.ndarray,  # Grid of candidate y values for the prediction set
    base_predictor_class,  # e.g., LinearRegression
) -> jnp.ndarray:
    """
    Conceptual implementation of Full Conformal Prediction for a single test point.
    NOTE: This is computationally very expensive and not practical for large N.

    Args:
        X_train: Original training features.
        y_train: Original training responses.
        X_test_point: The feature value (x) for which to build the prediction set.
        alpha: Nominal error level.
        y_candidate_grid: A grid of y values to check for inclusion in the prediction set.
        base_predictor_class: The class of the base predictor (e.g., LinearRegression).

    Returns:
        A boolean array indicating which y_candidate_grid values are in the set.
    """
    n = X_train.shape[0]
    is_in_prediction_set = jnp.zeros(y_candidate_grid.shape[0], dtype=bool)

    print(
        f"  Checking {y_candidate_grid.shape[0]} candidate y values for X_test_point={X_test_point.item():.2f}..."
    )

    # For each candidate y in the grid
    for idx, y_candidate in enumerate(y_candidate_grid):
        # 1. Augment the training data with the query point (X_test_point, y_candidate)
        X_augmented = jnp.vstack([X_train, X_test_point])
        y_augmented = jnp.concatenate([y_train, jnp.array([y_candidate])])

        # 2. Train the base predictor on the augmented data
        current_predictor = base_predictor_class()
        current_predictor.fit(np.array(X_augmented), np.array(y_augmented))

        # 3. Define residuals for all n+1 points
        y_augmented_pred = current_predictor.predict(np.array(X_augmented))
        all_residuals = jnp.abs(y_augmented - jnp.array(y_augmented_pred))

        # The residual for the query point is the last one
        R_n_plus_1_query = all_residuals[-1]

        # 4. Form Conformal Set: Compare R_n_plus_1_query to the quantile of all residuals
        # The rank index is based on n+1 total residuals
        adjusted_rank_index = jnp.ceil((1 - alpha) * (n + 1)).astype(int)

        # Sort all residuals
        sorted_all_residuals = jnp.sort(all_residuals)

        # Get the threshold quantile (0-indexed)
        # Handle potential edge cases for rank_index
        if adjusted_rank_index <= 0:
            threshold_quantile = sorted_all_residuals[0]
        elif adjusted_rank_index > n + 1:  # Max index is n
            threshold_quantile = sorted_all_residuals[-1]
        else:
            threshold_quantile = sorted_all_residuals[adjusted_rank_index - 1]

        # Check if the query point's residual is less than or equal to the threshold
        if R_n_plus_1_query <= threshold_quantile:
            is_in_prediction_set = is_in_prediction_set.at[idx].set(True)

        if (idx + 1) % 50 == 0:
            print(f"    Processed {idx + 1}/{y_candidate_grid.shape[0]} candidates...")

    return is_in_prediction_set


# --- Example Usage for Full Conformal Prediction ---
alpha_level = 0.1  # 90% coverage
X_test_point = jnp.array(
    [[5.0]]
)  # A single test feature point for which to build the set

# Define a grid of candidate y values for the prediction set
y_candidate_grid = jnp.linspace(
    y_data.min() - 5, y_data.max() + 5, 200
)  # Increased grid density for better visualization

print("Starting Full Conformal Prediction for a single test point...")
is_in_set_flags = full_conformal_predict(
    X_data, y_data, X_test_point, alpha_level, y_candidate_grid, LinearRegression
)

# Extract the prediction set
full_conformal_set_y_values = y_candidate_grid[is_in_set_flags]

print(f"\nFull Conformal Prediction Set for X={X_test_point.item():.2f}:")
if full_conformal_set_y_values.shape[0] > 0:
    print(f"  Min Y: {full_conformal_set_y_values.min():.2f}")
    print(f"  Max Y: {full_conformal_set_y_values.max():.2f}")
    print(
        f"  Approximate Width: {full_conformal_set_y_values.max() - full_conformal_set_y_values.min():.2f}"
    )
else:
    print("  Prediction set is empty.")

# --- Plotting Results ---
plt.figure(figsize=(10, 7))
plt.scatter(X_data[:, 0], y_data, color="blue", label="Training Data", alpha=0.7, s=30)
plt.axvline(X_test_point.item(), color="purple", linestyle="--", label="Test Point X")

# Plot the prediction set as a vertical band at X_test_point
if full_conformal_set_y_values.shape[0] > 0:
    plt.vlines(
        X_test_point.item(),
        full_conformal_set_y_values.min(),
        full_conformal_set_y_values.max(),
        colors="red",
        linewidth=4,
        label=f"Full Conformal Set (1-$\\alpha$={1 - alpha_level})",
    )
    # Also plot the individual points that are in the set
    plt.scatter(
        jnp.full_like(full_conformal_set_y_values, X_test_point.item()),
        full_conformal_set_y_values,
        color="red",
        marker=".",
        s=10,
        alpha=0.5,
    )

plt.xlabel("X")
plt.ylabel("Y")
plt.title("Full Conformal Prediction (Conceptual)")
plt.legend()
plt.grid(True)
plt.show()

# Note: Estimating coverage needs one re-fit per test pair (X_n+1, Y_n+1), because only the
# true response has to be checked for membership. Building the whole set over a grid for
# every test input is what makes full conformal expensive, so it is not done here.

## 8. Conformal Classification

Conformal prediction can also be applied to **classification problems**, where the output $Y$ is a discrete class label (e.g., $Y \in \{1, \ldots, K\}$). The core idea is the same, but the conformity score functions need to be adapted for discrete outputs.

We assume we have a probabilistic classifier $\hat{f}_{n_1}$ (trained on $D_1$) that estimates class probabilities: $\hat{f}_{n_1}(x; k)$ is the estimated probability of class $k$ given input $x$.

### 8.1 Likelihood Scores

A straightforward approach uses the predicted probability of the *true class* as the conformity score. This is a **positively-oriented score** (higher is better).

**Procedure (Split Conformal with Likelihood Scores):**

1.  **Train Probabilistic Classifier**: Fit $\hat{f}_{n_1}$ on $D_1$ (e.g., Logistic Regression, Softmax Neural Network).
2.  **Compute Calibration Scores**: For $(X_i, Y_i) \in D_2$, compute:
    $$R_i = \hat{f}_{n_1}(X_i; Y_i)$$
    (The predicted probability of the true class $Y_i$).
3.  **Compute Conformal Quantile**: Since $R_i$ is positively-oriented, we use the lower quantile:
    $$\hat{q}_{n_2} = \text{the } \lfloor \alpha(n_2+1) \rfloor \text{ smallest of } \{R_i\}_{i \in D_2}$$
4.  **Form Conformal Prediction Set**: For a new input $x$, the prediction set $\hat{C}_n(x)$ includes all classes $k$ whose predicted probability is greater than or equal to this quantile:
    $$\hat{C}_n(x) = \{k : \hat{f}_{n_1}(x; k) \ge \hat{q}_{n_2}\}$$
This set can contain one, multiple, or no classes.
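
A small sketch of steps 3 and 4 on hand-picked numbers. The helper names `lower_conformal_quantile` and `likelihood_sets`, the toy scores, and the threshold `0.40` are all hypothetical; the point is the conversion from the 1-based rank $\lfloor \alpha(n_2+1) \rfloor$ to a 0-based array index, and the fact that sets may hold one, several, or no classes.

```python
import numpy as np


def lower_conformal_quantile(scores, alpha):
    """floor(alpha * (n2 + 1))-th smallest score. When that rank is below 1 the
    threshold is 0.0, which keeps every class (probabilities are never negative)."""
    k = int(np.floor(alpha * (len(scores) + 1)))
    return float(np.sort(scores)[k - 1]) if k >= 1 else 0.0


def likelihood_sets(probs, q_hat):
    """For each row of class probabilities, list the classes with probability >= q_hat."""
    return [np.flatnonzero(row >= q_hat).tolist() for row in probs]


# Toy calibration scores 0.01, 0.02, ..., 1.00 (n2 = 100): rank floor(10.1) = 10
toy_scores = np.arange(1, 101) / 100
q_toy = lower_conformal_quantile(toy_scores, alpha=0.1)
print(q_toy)  # 0.1, the 10th smallest score

probs = np.array([
    [0.90, 0.05, 0.05],  # confident prediction
    [0.45, 0.45, 0.10],  # two plausible classes
    [0.34, 0.33, 0.33],  # no class reaches the threshold
])
sets = likelihood_sets(probs, q_hat=0.40)
print(sets)  # [[0], [0, 1], []]
```

An empty set signals an input on which the classifier is unusually unsure of every class relative to the calibration data.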

### 8.2 Cumulative Likelihood (Adaptive Prediction Sets - APS / RAPS)

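To achieve better local adaptivity (sets that grow for hard inputs and shrink for easy ones), the **cumulative likelihood score** is used (Romano et al., 2020). The resulting sets are not necessarily smaller on average than those from likelihood scores. This is a **negatively-oriented score** (lower is better).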

**Procedure (Split Conformal with Cumulative Likelihood Scores):**

1.  **Train Probabilistic Classifier**: Fit $\hat{f}_{n_1}$ on $D_1$.
2.  **Compute Calibration Scores**: For each $(X_i, Y_i) \in D_2$:
    * Sort the predicted probabilities $\hat{f}_{n_1}(X_i; k)$ in decreasing order to get a permutation $\pi_i$.
    * Find the rank $k_i$ of the true class $Y_i$ in this sorted list (i.e., $\pi_i(k_i) = Y_i$).
    * Compute the cumulative probability up to the true class's rank:
        $$R_i = \sum_{j=1}^{k_i} \hat{f}_{n_1}(X_i; \pi_i(j))$$
    This score represents the cumulative probability of all classes "at least as likely" as the true class.
3.  **Compute Conformal Quantile**: Find $\hat{q}_{n_2}$ from these negatively-oriented scores:
    $$\hat{q}_{n_2} = \text{the } \lceil (1-\alpha)(n_2+1) \rceil \text{ smallest of } \{R_i\}_{i \in D_2}$$
4.  **Form Conformal Prediction Set**: For a new input $x$:
    * Sort the predicted probabilities $\hat{f}_{n_1}(x; k)$ in decreasing order to get $\pi_x$.
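    * Find the smallest $k_x$ such that the cumulative probability of the top $k_x$ classes reaches $\hat{q}_{n_2}$ (if it never does, take $k_x = K$):
        $$k_x = \min \left\{ k : \sum_{j=1}^k \hat{f}_{n_1}(x; \pi_x(j)) \ge \hat{q}_{n_2} \right\}$$
    * The prediction set is the top $k_x$ most likely classes:
        $$\hat{C}_n(x) = \{\pi_x(1), \ldots, \pi_x(k_x)\}$$
This set contains every class whose cumulative score is at most $\hat{q}_{n_2}$, plus possibly the one class that crosses the threshold, so the $1-\alpha$ lower bound on coverage is preserved. Its size adapts to the input: confident predictions yield small sets and ambiguous ones yield larger sets.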
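
Set construction for the cumulative likelihood score can be sketched in vectorized NumPy. This is an illustrative sketch under assumptions: the function name `cumulative_likelihood_sets`, the toy probabilities, and the threshold `0.80` are hypothetical, and the rule used is "keep the top classes up to and including the first one at which the running total reaches the threshold", so the class that crosses the threshold is included.

```python
import numpy as np


def cumulative_likelihood_sets(probs, q_hat):
    """Per row of probs, return the top classes up to and including the first one
    at which the running probability total reaches q_hat."""
    order = np.argsort(-probs, axis=1)  # classes by decreasing probability
    cumulative = np.cumsum(np.take_along_axis(probs, order, axis=1), axis=1)
    reached = cumulative >= q_hat
    # Position of the first class reaching q_hat; keep every class if none does
    k_x = np.where(reached.any(axis=1), reached.argmax(axis=1) + 1, probs.shape[1])
    return [order[i, :k].tolist() for i, k in enumerate(k_x)]


probs = np.array([
    [0.95, 0.03, 0.02],  # one class is enough
    [0.70, 0.20, 0.10],  # running totals 0.70, 0.90: two classes needed
    [0.40, 0.35, 0.25],  # running totals 0.40, 0.75, 1.00: all three needed
])
print(cumulative_likelihood_sets(probs, q_hat=0.80))  # [[0], [0, 1], [0, 1, 2]]
```

With the same threshold, the set size ranges from one to three classes depending on how concentrated the predicted probabilities are, which is the adaptivity this score is designed for.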

### Code Example: Conformal Classification (Likelihood Scores)

In [None]:
import jax.numpy as jnp
import jax.random as random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression  # A probabilistic classifier
from sklearn.datasets import make_classification

# Set a random seed for reproducibility
key = random.PRNGKey(0)

# --- 1. Simulate Classification Data ---
num_samples = 200
num_classes = 3  # For multi-class classification
X_data_np, y_data_np = make_classification(
    n_samples=num_samples,
    n_features=2,
    n_redundant=0,
    n_informative=2,
    n_clusters_per_class=1,
    n_classes=num_classes,
    random_state=0,
)
X_data = jnp.array(X_data_np)
y_data = jnp.array(y_data_np)

print(f"Total samples: {num_samples}")
print(f"Number of classes: {num_classes}")

# --- 2. Split Data into Proper Training (D1) and Calibration (D2) Sets ---
X_D1, X_D2, y_D1, y_D2 = train_test_split(
    X_data, y_data, test_size=0.5, random_state=1
)
n1 = X_D1.shape[0]
n2 = X_D2.shape[0]
print(f"Proper training set size (n1): {n1}")
print(f"Calibration set size (n2): {n2}\n")

# --- 3. Fit Probabilistic Classifier on Proper Training Set (D1) ---
# LogisticRegression provides predict_proba() for class probabilities
# With the lbfgs solver, scikit-learn fits a multinomial (softmax) model for multi-class targets
classifier = LogisticRegression(random_state=2, solver="lbfgs")
classifier.fit(np.array(X_D1), np.array(y_D1))

# --- 4. Compute Calibration Scores (Likelihood Scores) on D2 ---
# Get predicted probabilities for calibration set
y_D2_pred_proba = classifier.predict_proba(np.array(X_D2))  # Shape (n2, num_classes)

# Calibration score R_i = P(Y_i | X_i) for the true class Y_i
# We select the probability corresponding to the true class label for each calibration point
calibration_scores = y_D2_pred_proba[jnp.arange(n2), y_D2]
print(
    f"Calibration scores (predicted likelihoods for true class, first 5): {calibration_scores[:5]}\n"
)

# --- 5. Compute Conformal Quantile from Calibration Scores ---
alpha_level = 0.1  # Desired coverage: 1 - alpha = 0.9 (90%)

# For positively-oriented scores (higher is better), the conformal quantile is the
# floor(alpha * (n2 + 1))-th smallest score, where the rank is 1-based.
adjusted_rank_index = jnp.floor(alpha_level * (n2 + 1)).astype(int)

sorted_calibration_scores = jnp.sort(calibration_scores)

# Convert the 1-based rank to a 0-based index.
# If the rank is 0 (too few calibration points for this alpha), use a threshold of 0
# so that every class is included. The rank never exceeds n2 when alpha < 1.
if adjusted_rank_index < 1:
    conformal_quantile_class = 0.0
else:
    conformal_quantile_class = sorted_calibration_scores[adjusted_rank_index - 1]

print(f"Conformal quantile (q_n2) for classification: {conformal_quantile_class:.4f}\n")

# --- 6. Form Conformal Prediction Set for a New Test Point ---
# Let's pick a single test point for demonstration
X_test_single = jnp.array([[0.5, 0.5]])  # Example test point
# For checking coverage, assume true class is 1 for this point
y_test_single_true_class = 1

# Get predicted probabilities for the test point across all classes
prob_test_single = classifier.predict_proba(np.array(X_test_single))[
    0
]  # Shape (num_classes,)
print(f"Predicted probabilities for X_test_single: {prob_test_single}")

# Construct the prediction set: {k : P(Y=k|X) >= q_n2}
prediction_set = []
for k in range(num_classes):
    if prob_test_single[k] >= conformal_quantile_class:
        prediction_set.append(k)

print(f"Conformal Prediction Set for X_test_single: {prediction_set}")
if y_test_single_true_class in prediction_set:
    print(f"True class {y_test_single_true_class} IS in the prediction set.\n")
else:
    print(f"True class {y_test_single_true_class} IS NOT in the prediction set.\n")


# --- Estimate Coverage by Re-splitting the Calibration Data ---
# make_classification draws new class centroids and covariances for every random_state,
# so samples generated with a different seed would not follow the training distribution.
# Instead, D2 (which is independent of the fitted classifier) is repeatedly split into a
# calibration half and a held-out half.
num_resplits = 500
resplit_rng = np.random.default_rng(3)
true_class_probs = np.asarray(y_D2_pred_proba)[np.arange(n2), np.asarray(y_D2)]
n_cal = n2 // 2
resplit_rank = int(np.floor(alpha_level * (n_cal + 1)))  # 1-based rank of the threshold

coverage_per_resplit = []
for _ in range(num_resplits):
    perm = resplit_rng.permutation(n2)
    cal_scores = np.sort(true_class_probs[perm[:n_cal]])
    held_out_scores = true_class_probs[perm[n_cal:]]
    threshold = cal_scores[resplit_rank - 1] if resplit_rank >= 1 else 0.0
    # The true class is covered when its predicted probability reaches the threshold
    coverage_per_resplit.append(np.mean(held_out_scores >= threshold))

simulated_coverage_class = float(np.mean(coverage_per_resplit))
print(f"Simulated test coverage (Classification): {simulated_coverage_class:.4f}")
print(f"Desired coverage (at least): {1 - alpha_level:.4f}")
print(f"Theoretical upper bound (approx): {1 - alpha_level + 1 / (n_cal + 1):.4f}")

# --- Plotting Decision Regions (Conceptual) ---
# For classification, visualizing prediction sets (which can be multi-valued) is complex.
# Here, we'll visualize the decision boundary for the most likely class, and the conformal threshold.
x1_min, x1_max = X_data[:, 0].min() - 0.5, X_data[:, 0].max() + 0.5
x2_min, x2_max = X_data[:, 1].min() - 0.5, X_data[:, 1].max() + 0.5
xx1, xx2 = jnp.meshgrid(
    jnp.linspace(x1_min, x1_max, 100), jnp.linspace(x2_min, x2_max, 100)
)
X_grid = jnp.c_[xx1.ravel(), xx2.ravel()]

# Predict probabilities on the grid for all classes
prob_grid_all_classes = classifier.predict_proba(np.array(X_grid))
# Get the probability of the most likely class for visualization
max_prob_grid = jnp.max(prob_grid_all_classes, axis=1).reshape(xx1.shape)

plt.figure(figsize=(10, 8))
# Contour plot of the maximum predicted probability
plt.contourf(
    xx1, xx2, max_prob_grid, levels=jnp.linspace(0, 1, 11), cmap="viridis", alpha=0.6
)
plt.colorbar(label="Max Predicted Probability")

# Plot the training data points
colors = ["red", "blue", "green"]
markers = ["o", "x", "s"]
labels = [f"Class {k}" for k in range(num_classes)]
for k in range(num_classes):
    plt.scatter(
        X_data[y_data == k, 0],
        X_data[y_data == k, 1],
        color=colors[k],
        marker=markers[k],
        label=labels[k],
        edgecolor="black",
        s=50,
    )

# Add a contour line for the conformal quantile threshold (for the most likely class)
# This shows regions where even the most likely class might not meet the threshold.
plt.contour(
    xx1,
    xx2,
    max_prob_grid,
    levels=[conformal_quantile_class],
    colors="white",
    linewidths=2,
    linestyles="--",
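)
# contour() does not create a legend entry from `label`, so add an empty proxy line instead
plt.plot([], [], color="white", linewidth=2, linestyle="--", label="Conformal Threshold")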

plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.title("Max Predicted Probability and Conformal Threshold")
plt.legend()
plt.grid(True)
plt.show()

### Code Example: Conformal Classification (Cumulative Likelihood - Conceptual)

In [None]:
import jax.numpy as jnp
import jax.random as random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression  # A probabilistic classifier
from sklearn.datasets import make_classification

# Set a random seed for reproducibility
key = random.PRNGKey(0)

# --- 1. Simulate Classification Data ---
num_samples = 200
num_classes = 3
X_data_np, y_data_np = make_classification(
    n_samples=num_samples,
    n_features=2,
    n_redundant=0,
    n_informative=2,
    n_clusters_per_class=1,
    n_classes=num_classes,
    random_state=0,
)
X_data = jnp.array(X_data_np)
y_data = jnp.array(y_data_np)

# --- 2. Split Data ---
X_D1, X_D2, y_D1, y_D2 = train_test_split(
    X_data, y_data, test_size=0.5, random_state=1
)
n2 = X_D2.shape[0]

# --- 3. Fit Probabilistic Classifier ---
# With the lbfgs solver, scikit-learn fits a multinomial (softmax) model for multi-class targets
classifier = LogisticRegression(random_state=2, solver="lbfgs")
classifier.fit(np.array(X_D1), np.array(y_D1))

# --- 4. Compute Calibration Scores (Cumulative Likelihood) on D2 ---
y_D2_pred_proba = classifier.predict_proba(np.array(X_D2))  # Shape (n2, num_classes)

cumulative_calibration_scores = jnp.zeros(n2)

for i in range(n2):
    true_class = y_D2[i]
    predicted_probs = y_D2_pred_proba[i]

    # Get the permutation that sorts probabilities in decreasing order
    # argsort returns indices that would sort an array. We want descending order.
    sorted_indices = jnp.argsort(predicted_probs)[::-1]

    cumulative_sum = 0.0
    for j, class_idx in enumerate(sorted_indices):
        cumulative_sum += predicted_probs[class_idx]
        if class_idx == true_class:
            # This is the cumulative probability of all classes "at least as likely" as the true one
            cumulative_calibration_scores = cumulative_calibration_scores.at[i].set(
                cumulative_sum
            )
            break
print(
    f"Cumulative likelihood calibration scores (first 5): {cumulative_calibration_scores[:5]}\n"
)

# --- 5. Compute Conformal Quantile ---
alpha_level = 0.1  # Desired coverage: 1 - alpha = 0.9 (90%)

# For negatively-oriented scores (higher is worse), we use ceil((1-alpha)*(n2+1)) smallest.
adjusted_rank_index = jnp.ceil((1 - alpha_level) * (n2 + 1)).astype(int)
sorted_cumulative_scores = jnp.sort(cumulative_calibration_scores)

if adjusted_rank_index > n2:
    # Too few calibration points for this alpha: an infinite threshold keeps every class
    conformal_quantile_cumulative = jnp.inf
else:
    # Convert the 1-based rank to a 0-based index
    conformal_quantile_cumulative = sorted_cumulative_scores[adjusted_rank_index - 1]

print(
    f"Conformal quantile (cumulative likelihood): {conformal_quantile_cumulative:.4f}\n"
)


# --- 6. Form Conformal Prediction Set for a New Test Point ---
def get_conformal_set_cumulative(
    X_point: jnp.ndarray, classifier_model, conformal_q: float, num_classes: int
) -> list[int]:
    """
    Forms the conformal prediction set using cumulative likelihood.
    """
    predicted_probs = classifier_model.predict_proba(np.array(X_point))[0]

    # Get the permutation that sorts probabilities in decreasing order
    sorted_indices = jnp.argsort(predicted_probs)[::-1]

    prediction_set_classes = []
    cumulative_prob_sum = 0.0

    for class_idx in sorted_indices:
        cumulative_prob_sum += predicted_probs[class_idx]
        prediction_set_classes.append(class_idx.item())  # Add to set

        # Stop when cumulative probability exceeds the conformal quantile
        if cumulative_prob_sum > conformal_q:
            break

    return prediction_set_classes


# Example test point
X_test_single = jnp.array([[0.5, 0.5]])
y_test_single_true_class = 1  # For checking

prediction_set_cumulative = get_conformal_set_cumulative(
    X_test_single, classifier, conformal_quantile_cumulative, num_classes
)
print(
    f"Conformal Prediction Set (Cumulative Likelihood) for X_test_single: {prediction_set_cumulative}"
)
if y_test_single_true_class in prediction_set_cumulative:
    print(f"True class {y_test_single_true_class} IS in the prediction set.\n")
else:
    print(f"True class {y_test_single_true_class} IS NOT in the prediction set.\n")

# --- Estimate Coverage by Re-splitting the Calibration Data ---
# make_classification draws new class centroids and covariances for every random_state,
# so samples generated with a different seed would not follow the training distribution.
# Instead, D2 (which is independent of the fitted classifier) is repeatedly split into a
# calibration half and a held-out half.
num_resplits = 20
resplit_rng = np.random.default_rng(3)
all_cumulative_scores = np.asarray(cumulative_calibration_scores)
n_cal = n2 // 2
resplit_rank = int(np.ceil((1 - alpha_level) * (n_cal + 1)))  # 1-based rank

coverage_per_resplit = []
for _ in range(num_resplits):
    perm = resplit_rng.permutation(n2)
    cal_scores = np.sort(all_cumulative_scores[perm[:n_cal]])
    threshold = cal_scores[resplit_rank - 1] if resplit_rank <= n_cal else np.inf
    hits = 0
    for i in perm[n_cal:]:
        X_point = X_D2[i : i + 1]  # Slice to keep 2D shape
        current_prediction_set = get_conformal_set_cumulative(
            X_point, classifier, threshold, num_classes
        )
        if int(y_D2[i]) in current_prediction_set:
            hits += 1
    coverage_per_resplit.append(hits / (n2 - n_cal))

simulated_coverage_cumulative = float(np.mean(coverage_per_resplit))
print(
    f"Simulated test coverage (Cumulative Likelihood): {simulated_coverage_cumulative:.4f}"
)
print(f"Desired coverage (at least): {1 - alpha_level:.4f}")
# The set also includes the class that crosses the threshold, so coverage can exceed
# 1 - alpha + 1 / (n_cal + 1); only the lower bound is checked here.

# --- Plotting (Conceptual) ---
# Visualizing adaptive set sizes is harder for multi-class.
# We can plot the decision boundary for the most likely class, similar to before.
x1_min, x1_max = X_data[:, 0].min() - 0.5, X_data[:, 0].max() + 0.5
x2_min, x2_max = X_data[:, 1].min() - 0.5, X_data[:, 1].max() + 0.5
xx1, xx2 = jnp.meshgrid(
    jnp.linspace(x1_min, x1_max, 100), jnp.linspace(x2_min, x2_max, 100)
)
X_grid = jnp.c_[xx1.ravel(), xx2.ravel()]

prob_grid_all_classes = classifier.predict_proba(np.array(X_grid))
# For visualization, let's plot the probability of the most likely class
max_prob_grid = jnp.max(prob_grid_all_classes, axis=1).reshape(xx1.shape)

plt.figure(figsize=(10, 8))
plt.contourf(
    xx1, xx2, max_prob_grid, levels=jnp.linspace(0, 1, 11), cmap="viridis", alpha=0.6
)
plt.colorbar(label="Max Predicted Probability")

colors = ["red", "blue", "green"]
markers = ["o", "x", "s"]
labels = [f"Class {k}" for k in range(num_classes)]
for k in range(num_classes):
    plt.scatter(
        X_data[y_data == k, 0],
        X_data[y_data == k, 1],
        color=colors[k],
        marker=markers[k],
        label=labels[k],
        edgecolor="black",
        s=50,
    )

# Add a contour line for the conformal quantile threshold (for the most likely class)
# This is a simplification, as the actual set depends on cumulative probabilities.
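# plt.contour does not use a `label` argument, so the legend entry comes from
# an empty proxy line drawn in the same style (black, so it shows up in the legend).
plt.contour(
    xx1,
    xx2,
    max_prob_grid,
    levels=[float(conformal_quantile_cumulative)],
    colors="black",
    linewidths=2,
    linestyles="--",
)
plt.plot([], [], color="black", linewidth=2, linestyle="--", label="Conformal Threshold")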

plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.title("Max Predicted Probability and Cumulative Conformal Threshold")
plt.legend()
plt.grid(True)
plt.show()

## Conclusion: The Power of Conformal Prediction

Conformal prediction offers a unique and powerful approach to uncertainty quantification in machine learning. Its ability to provide **distribution-free, finite-sample coverage guarantees** for prediction sets makes it an invaluable tool, especially in high-stakes applications where reliability is paramount.

Key takeaways from this notebook:

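* **Rigorous Guarantees**: Unlike many methods that rely on asymptotic approximations or strong distributional assumptions, conformal prediction provides a finite-sample guarantee: the prediction set contains the true response with marginal probability at least $1 - \alpha$, assuming only that the data are exchangeable (i.i.d. data suffice).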
* **Model Agnostic**: It can be wrapped around any existing machine learning model, allowing you to leverage the predictive power of complex algorithms while adding a layer of statistical validity.
* **Exchangeability is Key**: The core principle relies on the exchangeability of conformity scores, which is achieved through clever data splitting (Split Conformal) or symmetric re-training (Full Conformal).
* **Computational Efficiency (Split CP)**: Split conformal prediction fits the underlying model only once and then needs just a quantile of the calibration scores, making it practical for large datasets.
* **Local Adaptivity**: By choosing appropriate conformity scores (e.g., studentized residuals, CQR, cumulative likelihood), prediction sets can adapt their size to the local difficulty of prediction, providing more informative uncertainty estimates.
* **Versatile Applications**: Applicable to both regression (continuous outputs) and classification (discrete outputs), yielding prediction intervals and prediction sets, respectively.
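
As a compact recap of the split conformal recipe summarized in these takeaways, here is a minimal sketch in plain NumPy. It is an illustration under stated assumptions, not the notebook's own implementation: the function names (`split_conformal_interval`, `fit_linear`, `predict_linear`), the least-squares model, and the synthetic heavy-tailed data are all hypothetical choices made for this example. The conformity score is the absolute residual, and the threshold is the $\lceil (1-\alpha)(n_{\text{calib}}+1) \rceil$-th smallest calibration score.

```python
import numpy as np


def split_conformal_interval(fit, predict, X, y, X_new, alpha=0.1, seed=0):
    """Split conformal intervals with absolute residuals as conformity scores.

    `fit(X, y)` returns a fitted model and `predict(model, X)` returns point
    predictions; both are hypothetical callables supplied by the user.
    """
    rng = np.random.default_rng(seed)
    n = len(y)
    perm = rng.permutation(n)
    train_idx, calib_idx = perm[: n // 2], perm[n // 2:]

    # Fit on the first half only, so calibration scores stay exchangeable
    # with the score of a new test point.
    model = fit(X[train_idx], y[train_idx])
    scores = np.abs(y[calib_idx] - predict(model, X[calib_idx]))

    # Finite-sample corrected rank: the ceil((1 - alpha)(n_calib + 1))-th smallest score.
    n_calib = len(calib_idx)
    k = int(np.ceil((1 - alpha) * (n_calib + 1)))
    # Too few calibration points for this alpha: only the trivial interval is valid.
    q_hat = np.inf if k > n_calib else np.sort(scores)[k - 1]

    center = predict(model, X_new)
    return center - q_hat, center + q_hat


def fit_linear(X, y):
    """Ordinary least squares with an intercept (assumed model for the demo)."""
    design = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef


def predict_linear(coef, X):
    return coef[0] + X @ coef[1:]


# Synthetic data with heavy-tailed (Student-t) noise, so no Gaussian assumption holds.
rng = np.random.default_rng(0)
beta = np.array([1.0, -2.0])
X = rng.normal(size=(2000, 2))
y = X @ beta + rng.standard_t(df=3, size=2000)
X_test = rng.normal(size=(2000, 2))
y_test = X_test @ beta + rng.standard_t(df=3, size=2000)

lower, upper = split_conformal_interval(
    fit_linear, predict_linear, X, y, X_test, alpha=0.1
)
empirical_coverage = np.mean((y_test >= lower) & (y_test <= upper))
print(f"Empirical test coverage: {empirical_coverage:.3f} (target: at least 0.900)")
```

Because the model enters only through `fit` and `predict`, any regressor can be substituted without touching the calibration step. The intervals here have constant width; the locally adaptive scores mentioned above would change only how `scores` and the final interval are computed.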

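Some limitations remain, most notably that exact $X$-conditional coverage cannot be guaranteed in a distribution-free way, so the guarantee holds marginally rather than for every individual $x$. Even so, conformal prediction provides a robust and theoretically sound framework for building trustworthy machine learning systems.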