# Change of Measure Theorem: Visualization with JAX

This notebook demonstrates the change of measure theorem for probability densities under monotonic transformations, using JAX for computation and Plotly for visualization. We use a sum of sigmoid functions to create a strictly increasing function, transform a standard normal density, and visualize the effect of the change of variable formula.

## 1. Import Required Libraries

Import JAX (`jax`, `jax.numpy`, and `jax.scipy.stats` for the normal density), NumPy, SciPy's interpolation module, and Plotly for computation and visualization.

In [142]:
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
from scipy.interpolate import interp1d
import plotly.graph_objects as go
import plotly.io as pio

pio.templates.default = "plotly_white"

## 2. Define Sigmoid Function and Its Derivative in JAX

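Implement a shifted, scaled sigmoid $s(a) = 1 / \left(1 + e^{-(a - \text{loc})/\text{gain}}\right)$. Wrapping it in `jax.value_and_grad` returns both $s(a)$ and its derivative $ds/da$ from a single call, using automatic differentiation.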

In [127]:
def sigmoid(loc: float, gain: float):
    """Return a -> (s(a), ds/da) for s(a) = 1 / (1 + exp(-(a - loc) / gain)); a must be a scalar."""
    return jax.value_and_grad(lambda a: 1.0 / (1.0 + jnp.exp(-(a - loc) / gain)))
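As a sanity check on the autodiff derivative, the shifted, scaled sigmoid has the closed-form derivative $ds/da = s(a)\,(1 - s(a)) / \text{gain}$, so its slope at $a = \text{loc}$ is $1 / (4\,\text{gain})$. The sketch below compares the two on a grid; `sigmoid_closed_form` is a hypothetical helper introduced only for this comparison, and the values of `loc` and `gain` are arbitrary.

```python
import jax
import jax.numpy as jnp


def sigmoid(loc: float, gain: float):
    """Return a -> (s(a), ds/da) for s(a) = 1 / (1 + exp(-(a - loc) / gain))."""
    return jax.value_and_grad(lambda a: 1.0 / (1.0 + jnp.exp(-(a - loc) / gain)))


def sigmoid_closed_form(a, loc, gain):
    """Hypothetical helper: value s and analytic derivative s * (1 - s) / gain."""
    s = 1.0 / (1.0 + jnp.exp(-(a - loc) / gain))
    return s, s * (1.0 - s) / gain


a = jnp.linspace(-3.0, 3.0, 101)
loc, gain = 0.5, 2.0  # illustrative parameters

# Autodiff version, vectorised over the grid
s_ad, ds_ad = jax.vmap(sigmoid(loc, gain))(a)
# Closed-form version, already vectorised through broadcasting
s_cf, ds_cf = sigmoid_closed_form(a, loc, gain)

max_err = float(jnp.max(jnp.abs(ds_ad - ds_cf)))
print(f"max |autodiff - closed form| = {max_err:.2e}")
```

Both paths agree to floating-point precision, so `jax.value_and_grad` can stand in for hand-derived derivatives of more complicated $f$.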

## 3. Set Parameters for Sigmoid Features

Define the locations and gains for three sigmoid functions as JAX arrays, and create the domain for x.

In [96]:
# Parameters for three sigmoid features
locs = jnp.array([-1.0, 0.0, 1.0])
gains = jnp.array([1.0, 1.0, 1.0])

# Domain for x
N = 400
x = jnp.linspace(-3, 3, N, dtype=jnp.float32)

## 4. Compute and Visualize the Monotonic Function

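Sum the three sigmoids to obtain $f(x)$. For any positive gain each sigmoid has a strictly positive derivative, so their sum is strictly increasing and therefore invertible; this is what makes the inverse $v = f^{-1}$ used below well defined. Plot the individual sigmoids, their derivatives, and $f$ with its derivative using Plotly.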

In [128]:
Fs = jnp.asarray(
    [jax.vmap(sigmoid(l, g))(x) for (l, g) in zip(locs, gains)]
)  # [3, 2, N]
F = Fs.sum(axis=0)  # [2, N]: values and derivatives of the sum
# unpack f(x) and df/dx
f = F[0, :]
df = F[1, :]
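Because $f$ is strictly increasing, its inverse can be approximated directly from the grid. A minimal sketch, assuming linear interpolation between grid points is accurate enough: `jnp.interp` requires its sample points to be increasing, which the sorted values $f(x_i)$ satisfy, so swapping the roles of $x$ and $f$ gives $v(y) \approx f^{-1}(y)$. The names `f_scalar` and `v` are illustrative, not part of the notebook.

```python
import jax
import jax.numpy as jnp

locs = jnp.array([-1.0, 0.0, 1.0])
gains = jnp.array([1.0, 1.0, 1.0])
x = jnp.linspace(-3, 3, 400, dtype=jnp.float32)


def f_scalar(a):
    """Sum of the three shifted/scaled sigmoids, evaluated at a scalar a."""
    return jnp.sum(1.0 / (1.0 + jnp.exp(-(a - locs) / gains)))


f, df = jax.vmap(jax.value_and_grad(f_scalar))(x)

# Strict monotonicity on the grid: f'(x) > 0 everywhere
assert bool(jnp.all(df > 0))


def v(y):
    """Approximate inverse f^{-1}(y) by linear interpolation of the pairs (f(x_i), x_i)."""
    return jnp.interp(y, f, x)


y_query = jnp.array([0.5, 1.5, 2.5])  # must lie inside [f(-3), f(3)]
x_back = v(y_query)
print(x_back)
print(jax.vmap(f_scalar)(x_back))  # approximately reproduces y_query
```

Outside the sampled range $[f(-3), f(3)]$, `jnp.interp` clamps to the end values, so this inverse is only meaningful for $y$ inside that interval.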

In [129]:
fig = go.Figure()

# Plot individual sigmoids and their derivatives
for i, (loc, gain) in enumerate(zip(locs, gains)):
    s = Fs[i, 0, :]
    ds = Fs[i, 1, :]
    fig.add_trace(
        go.Scatter(
            x=x,
            y=s,
            mode="lines",
            line=dict(color="gray", width=1),
            name="sigmoid" if i == 0 else None,
            showlegend=(i == 0),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=x,
            y=ds,
            mode="lines",
            line=dict(color="gray", width=1, dash="dash"),
            name="sigmoid derivative" if i == 0 else None,
            showlegend=(i == 0),
        )
    )

# Plot sum of sigmoids and its derivative
fig.add_trace(
    go.Scatter(
        x=x,
        y=f,
        mode="lines",
        line=dict(color="red", width=2),
        name=r"$f(x)\ \text{(sum of sigmoids)}$",
    )
)
fig.add_trace(
    go.Scatter(
        x=x,
        y=df,
        mode="lines",
        line=dict(color="red", width=2, dash="dash"),
        name="$df/dx$",
    )
)

fig.update_layout(
    title=r"$\text{Monotonic Function } f(x) \text{ and Its Derivative}$",
    xaxis_title="$x$",
    yaxis_title="Value",
    legend=dict(itemsizing="constant"),
    width=1200,
    height=500,
)
fig.show()

## 5. Compute and Plot the Original Density $p_X(x)$

Compute the standard normal density with `jax.scipy.stats.norm.pdf` and plot it over the domain using Plotly.

The change of measure formula for transforming a probability density from $x$ to $y = f(x)$ is:

$$
p_Y(y) = p_X\big(v(y)\big) \left| \frac{dv}{dy} \right|
$$

where $v(y) = f^{-1}(y)$ is the inverse of $f$, which exists because $f$ is strictly increasing.

Equivalently, the change of variable formula can be written as:

$$
p_X(x) = p_Y\big(f(x)\big) \left| \frac{df}{dx} \right|
$$

where $f(x)$ is the transformation and $\frac{df}{dx}$ is its derivative.
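An affine map is a quick way to check the formula, because the answer is known in closed form: if $X \sim \mathcal{N}(0, 1)$ and $y = f(x) = a x + b$ with $a \neq 0$, then $v(y) = (y - b)/a$, $|dv/dy| = 1/|a|$, and $Y \sim \mathcal{N}(b, a^2)$. The sketch below compares the change of measure formula with `jax.scipy.stats.norm.pdf` evaluated directly for $\mathcal{N}(b, a^2)$. The values of `a` and `b` are arbitrary; a negative slope is chosen so that the absolute value matters.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

a, b = -2.0, 1.0  # illustrative slope and offset; a < 0 makes f decreasing
y = jnp.linspace(-6.0, 8.0, 201)

v = (y - b) / a   # inverse map v(y) = f^{-1}(y)
dv_dy = 1.0 / a   # derivative of the inverse (constant for an affine map)

p_y_formula = norm.pdf(v) * jnp.abs(dv_dy)    # p_X(v(y)) |dv/dy|
p_y_exact = norm.pdf(y, loc=b, scale=abs(a))  # density of N(b, a^2)
p_naive = norm.pdf(v)                         # p_X(v(y)) with no correction

print(float(jnp.max(jnp.abs(p_y_formula - p_y_exact))))
print(float(jnp.trapezoid(p_y_formula, y)), float(jnp.trapezoid(p_naive, y)))
```

The corrected density matches $\mathcal{N}(b, a^2)$ and integrates to about 1 over this range, while the uncorrected one integrates to about $|a| = 2$: stretching the axis by a factor of 2 without rescaling doubles the total "probability".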

In [137]:
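# Standard normal density p_X on the x-grid
p_x = norm.pdf(x, loc=0, scale=1)
# Uncorrected curve: p_X evaluated directly at the grid values y = f(x), no Jacobian factor
p_f = norm.pdf(f, loc=0, scale=1)
# Corrected density at y = f(x): p_Y(y) = p_X(x) / |df/dx|
p_y = p_x / jnp.abs(df)
# Trapezoidal integrals over the sampled ranges
px_int = jnp.trapezoid(y=p_x, x=x)
f_int = jnp.trapezoid(y=p_f, x=f)
py_int = jnp.trapezoid(y=p_y, x=f)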

In [141]:
fig_px = go.Figure()
fig_px.add_trace(
    go.Scatter(
        x=x,
        y=p_x,
        mode="lines",
        line=dict(color="black"),
        name=r"$p_X(x)\ \text{(standard normal)}$",
    )
)
fig_px.update_layout(
    title=r"$\text{Original Density } p_X(x)$",
    xaxis_title="$x$",
    yaxis_title="Density",
    legend=dict(itemsizing="constant"),
    width=1300,
    height=300,
)
fig_px.show()

## 6. Plot the Uncorrected Density $p_X(v(y))$

As a first, naive attempt, plot $p_X(v(y))$ as a function of $y = f(x)$ without any Jacobian factor. The corrected density $p_Y(f(x)) = p_X(x) / \left| \frac{df}{dx} \right|$ is added in the next section for comparison.

In [147]:
fig_vy = go.Figure()
fig_vy.add_trace(
    go.Scatter(
        x=f,
        y=p_f,
        mode="lines",
        line=dict(color="orange", dash="dashdot"),
        name=r"$p_X(v(y))\ \text{(naive)}$",
    )
)
fig_vy.update_layout(
    title=r"$\text{Transformed Density } p_X(v(y)) \text{ (without correction)}$",
    xaxis_title="$y$",
    yaxis_title="Density",
    legend=dict(itemsizing="constant"),
    width=900,
    height=350,
)
fig_vy.show()

## 7. Apply Change of Measure Correction and Plot $p_Y(y)$

Multiply $p_X(v(y))$ by the absolute value of the derivative of the inverse function to obtain the correct transformed density $p_Y(y)$ and plot it.
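The derivative of the inverse never has to be derived symbolically: by the inverse function theorem, $\frac{dv}{dy} = 1 / f'\big(v(y)\big)$. A minimal sketch, assuming the same three-sigmoid $f$ and swapping in a bisection-based inverse (the helpers `f_scalar`, `v` and `p_Y` are illustrative names, not from the notebook), evaluates $p_Y$ at arbitrary query points rather than only on the grid $y_i = f(x_i)$:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

locs = jnp.array([-1.0, 0.0, 1.0])
gains = jnp.array([1.0, 1.0, 1.0])


def f_scalar(a):
    """Sum of three sigmoids; strictly increasing in a."""
    return jnp.sum(1.0 / (1.0 + jnp.exp(-(a - locs) / gains)))


df_scalar = jax.grad(f_scalar)


def v(y, lo=-10.0, hi=10.0, n_iter=60):
    """Invert f by bisection on [lo, hi]; assumes f(lo) < y < f(hi)."""
    def step(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        go_right = f_scalar(mid) < y  # root lies to the right of mid
        return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)

    init = (jnp.asarray(lo, dtype=jnp.float32), jnp.asarray(hi, dtype=jnp.float32))
    lo, hi = jax.lax.fori_loop(0, n_iter, step, init)
    return 0.5 * (lo + hi)


def p_Y(y):
    """Change of measure: p_Y(y) = p_X(v(y)) * |dv/dy|, with dv/dy = 1 / f'(v(y))."""
    x = v(y)
    return norm.pdf(x) / jnp.abs(df_scalar(x))


y_query = jnp.array([0.5, 1.5, 2.5])
print(jax.vmap(p_Y)(y_query))
```

On the grid points themselves this reduces to the notebook's `p_x / jnp.abs(df)`, since $v(f(x_i)) = x_i$.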

## Why Do We Need the Change of Measure Correction?

When transforming a random variable $X$ with a probability density $p_X(x)$ via a monotonic function $y = f(x)$, the resulting variable $Y = f(X)$ does **not** generally have a density $p_Y(y)$ that is simply $p_X$ evaluated at the inverse, i.e., $p_X(f^{-1}(y))$. 

This is because a probability density measures probability *per unit length*, so it is not invariant under transformations that stretch or compress the axis. This includes linear ones: under $y = 2x$ every interval doubles in width, so the density must halve. The adjustment is given by the **change of measure correction**, which multiplies by the absolute value of the derivative of the inverse function:

$p_Y(y) = p_X(v(y)) \left| \frac{dv}{dy} \right|$

where $v(y) = f^{-1}(y)$.

Alternatively, when expressing everything in terms of $x$:

$p_X(x) = p_Y(f(x)) \left| \frac{df}{dx} \right|$

## Why Is This Correction Necessary?

- **Conservation of Probability:** The total probability must remain 1 after transformation. Without the correction, the transformed density may not integrate to 1.
- **Density Transformation:** The function $f(x)$ can locally stretch or compress intervals, changing the "density" of probability mass. The derivative term $\left| \frac{dv}{dy} \right|$ (or $\left| \frac{df}{dx} \right|$) quantifies this local change.
- **Normalization:** The naive approach $p_Y(y) = p_X(f^{-1}(y))$ ignores this stretching/compression, leading to incorrect probabilities and a density that is not properly normalized.
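These points can be checked empirically by sampling. A hedged sketch, assuming the same three-sigmoid $f$ (sample size, seed and bins are arbitrary choices): draw $X \sim \mathcal{N}(0, 1)$ with `jax.random.normal`, push the samples through $f$, and compare a normalized histogram of $Y = f(X)$ with both the corrected density $p_X(x)/|f'(x)|$ and the naive $p_X(v(y)) = p_X(x)$ at $y = f(x)$. The bin edges in $y$ are chosen as images of evenly spaced $x$ values, so no numerical inverse is needed.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

locs = jnp.array([-1.0, 0.0, 1.0])
gains = jnp.array([1.0, 1.0, 1.0])


def f_scalar(a):
    """Sum of three sigmoids (strictly increasing)."""
    return jnp.sum(1.0 / (1.0 + jnp.exp(-(a - locs) / gains)))


f_vec = jax.vmap(f_scalar)
df_vec = jax.vmap(jax.grad(f_scalar))

# Sample X ~ N(0, 1) and push the samples through f
key = jax.random.PRNGKey(0)
samples_y = f_vec(jax.random.normal(key, (200_000,)))

# Bin edges in y are images of evenly spaced x edges
x_edges = jnp.linspace(-2.0, 2.0, 17)
y_edges = f_vec(x_edges)
counts, _ = jnp.histogram(samples_y, bins=y_edges)
hist_density = counts / (samples_y.size * jnp.diff(y_edges))

# Candidate densities at the bin midpoints, mapped to y = f(x_mid)
x_mid = 0.5 * (x_edges[:-1] + x_edges[1:])
p_corrected = norm.pdf(x_mid) / jnp.abs(df_vec(x_mid))  # p_X(x) / |f'(x)|
p_naive = norm.pdf(x_mid)                                # p_X(v(y)), no Jacobian

print(jnp.max(jnp.abs(hist_density / p_corrected - 1.0)))
print(jnp.max(jnp.abs(hist_density / p_naive - 1.0)))
```

The histogram tracks the corrected density to within sampling noise, while the naive curve is off by the factor $f'(x)$ in every bin.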

## Example

If $f(x)$ is a strictly increasing function, then for a small interval $dx$ around $x$, the corresponding interval in $y$ is $dy = f'(x) dx$. The probability in both intervals must be equal:  
$p_X(x) dx = p_Y(y) dy$

So,  
$p_Y(y) = p_X(x) \left| \frac{dx}{dy} \right| = p_X(f^{-1}(y)) \left| \frac{dv}{dy} \right|$
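The same result follows from the cumulative distribution function, which also shows why the absolute value is needed when $f$ is decreasing. Writing $F_X$ and $F_Y$ for the CDFs of $X$ and $Y$ (notation introduced here), and differentiating with respect to $y$:

$$
\begin{aligned}
\text{increasing } f:&\quad F_Y(y) = P\big(X \le v(y)\big) = F_X\big(v(y)\big)
&&\Rightarrow\quad p_Y(y) = p_X\big(v(y)\big)\,\frac{dv}{dy}, \quad \frac{dv}{dy} > 0 \\
\text{decreasing } f:&\quad F_Y(y) = P\big(X \ge v(y)\big) = 1 - F_X\big(v(y)\big)
&&\Rightarrow\quad p_Y(y) = -\,p_X\big(v(y)\big)\,\frac{dv}{dy}, \quad \frac{dv}{dy} < 0
\end{aligned}
$$

In both cases $p_Y(y) = p_X\big(v(y)\big)\left|\frac{dv}{dy}\right|$.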

## Conclusion

The change of measure correction ensures that the transformed density $p_Y(y)$ is a valid probability density, properly normalized and reflecting the true distribution of $Y = f(X)$. Without this correction, the resulting density would be mathematically and probabilistically incorrect.


In [148]:
fig_py = go.Figure()
fig_py.add_trace(
    go.Scatter(
        x=f,
        y=p_f,
        mode="lines",
        line=dict(color="orange", dash="dashdot"),
        name=r"$p_X(v(y))\ \text{(naive)}$",
    )
)
fig_py.add_trace(
    go.Scatter(
        x=f,
        y=p_y,
        mode="lines",
        line=dict(color="green"),
        name=r"$p_Y(y)\ \text{(corrected)}$",
    )
)
fig_py.update_layout(
    title="Transformed Density with Change of Measure Correction",
    xaxis_title="$y$",
    yaxis_title="Density",
    legend=dict(itemsizing="constant"),
    width=900,
    height=350,
)
fig_py.show()

## 8. Compare Integrals of Densities

Numerically integrate $p_X(x)$, $p_X(v(y))$, and $p_Y(y)$ to verify normalization and demonstrate the effect of the correction.

In [152]:
print(f"Integral of $p_X(x)$: {px_int:.4f}")
print(f"Integral of $p_X(v(y))$: {f_int:.4f}")
print(f"Integral of $p_Y(y)$: {py_int:.4f}")

Integral of $p_X(x)$: 0.9973
Integral of $p_X(v(y))$: 0.4243
Integral of $p_Y(y)$: 0.9973


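As expected, the integrals of $p_X(x)$ and $p_Y(y)$ agree and are both close to 1. They fall slightly short of 1 only because the grid covers $x \in [-3, 3]$, which holds about 99.73% of the standard normal's mass. The naive density $p_X(v(y))$, by contrast, integrates to well below 1, which demonstrates the necessity of the change of measure correction.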
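To confirm that the shortfall from 1 comes from truncating the grid at $x = \pm 3$, the sketch below compares the trapezoidal integrals with the exact mass $\Phi(3) - \Phi(-3)$ from `jax.scipy.stats.norm.cdf`, then widens the grid to $[-8, 8]$ (an arbitrary choice). The helper `integrals` is illustrative and re-creates the notebook's $f$.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

locs = jnp.array([-1.0, 0.0, 1.0])
gains = jnp.array([1.0, 1.0, 1.0])


def f_scalar(a):
    """Sum of three sigmoids (strictly increasing)."""
    return jnp.sum(1.0 / (1.0 + jnp.exp(-(a - locs) / gains)))


def integrals(half_width, n=400):
    """Trapezoidal integrals of p_X over x and of p_Y = p_X / |f'| over y = f(x)."""
    x = jnp.linspace(-half_width, half_width, n)
    f, df = jax.vmap(jax.value_and_grad(f_scalar))(x)
    p_x = norm.pdf(x)
    return jnp.trapezoid(p_x, x), jnp.trapezoid(p_x / jnp.abs(df), f)


exact_mass = norm.cdf(3.0) - norm.cdf(-3.0)  # P(|X| <= 3)
print(float(exact_mass), [float(val) for val in integrals(3.0)])
print([float(val) for val in integrals(8.0, n=2000)])
```

On $[-3, 3]$ both integrals match $\Phi(3) - \Phi(-3)$; on the wider grid both approach 1, as conservation of probability requires.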