# Geometric Intuition for Jensen's Inequality

## Introduction

Jensen's inequality is fundamental in many fields, including machine learning and statistics. For example, it is useful in the [diffusion models paper](https://maitbayev.github.io/posts/denoising-diffusion-probabilistic-models/) for understanding the variational lower bound. In this post, I will give a simple geometric intuition for Jensen's inequality.

Feel free to leave feedback on my [telegram channel](https://t.me/swemanml).

## Setup 

The post contains collapsed code sections that are used to produce the visualizations. They're optional, hence collapsed.

In [1]:
#| code-summary: code for fig_jensen_inequality
#| code-fold: true

import itertools

import numpy as np
import plotly.graph_objects as go


def alpha_profiles(n: int):
    if n == 2:
        space = np.linspace(0.01, 0.99, 100)
        return np.column_stack((space, 1.0 - space))
    space = np.linspace(0.01, 0.99, 15 - max(0, (n - 3) * 5))
    space_prod = itertools.product(*[space for _ in range(n - 1)])
    profiles = np.array(list(space_prod))
    profiles = profiles[np.sum(profiles, axis=1) < 1.0]
    return np.concatenate([profiles, 1 - np.sum(profiles, axis=1).reshape(-1, 1)], axis=1)


def fig_jensen_inequality(f, x_range: np.ndarray, x: np.ndarray, show_hull_point_legend: bool = True):
    points = np.column_stack([x, f(x)])
    n = len(points)
    steps = []
    hull_points = []
    titles = []
    for index, alphas in enumerate(alpha_profiles(n)):
        hp = np.average(points, weights=alphas, axis=0)
        hull_points.append(hp)
        title = ",".join(["\\lambda_" + f"{i + 1}={a:.2f}" for i, a in enumerate(alphas)])
        title = f"${title}$"
        titles.append(title)
        step = dict(name=index, label=index, method="update",
                    args=[{
                        "x": [[hp[0]], [hp[0], hp[0]]],
                        "y": [[hp[1]], [f(hp[0]), hp[1]]],
                    }, {"title": title}, [2, 3]])
        steps.append(step)
    active_index = len(steps) // 2
    sliders = [dict(active=active_index, steps=steps)]
    return go.Figure(data=[
        go.Scatter(
            name="f", x=x_range, y=f(x_range), hoverinfo="none"
        ),
        go.Scatter(
            name="Convex Hull", x=np.append(points[:, 0], points[0][0]),
            y=np.append(points[:, 1], points[0][1]),
            fillcolor="rgba(239, 85, 59, 0.2)", fill="toself", mode="lines",
            line=dict(width=3), hoverinfo="none",
            showlegend=points.shape[0] > 2
        ),
        go.Scatter(
            name="$(\\sum \\lambda_i x_i, \\sum \\lambda_i f(x_i))$",
            x=[hull_points[active_index][0]],
            y=[hull_points[active_index][1]],
            mode=f"markers{'+text' if show_hull_point_legend else ''}",
            text=["$(\\sum \\lambda_i x_i, \\sum \\lambda_i f(x_i))$"],
            textposition="top center",
            hovertemplate="(%{x:.2f}, %{y:.2f})<extra></extra>",
            marker={"size": 20, "color": "black"},
            legendrank=1001,
            showlegend=show_hull_point_legend,
        ),
        go.Scatter(
            x=[hull_points[active_index][0], hull_points[active_index][0]],
            y=[f(hull_points[active_index][0]), hull_points[active_index][1]],
            mode="lines",
            textposition="bottom center",
            hovertemplate="(%{x:.2f}, %{y:.2f})<extra></extra>",
            line={"color": "black", "dash": "dot", "width": 1},
            showlegend=False
        ),
        go.Scatter(
            name="$(x_i, f(x_i))$",
            x=points[:, 0], y=points[:, 1],
            mode="markers+text",
            marker={"size": 20},
            text=[f"$(x_{i},f(x_{i}))$" for i in range(1, n + 1)],
            textposition="top center",
            hovertemplate="(%{x:.2f}, %{y:.2f})<extra></extra>"
        ),

    ], layout=go.Layout(
        title=titles[active_index],
        xaxis=dict(fixedrange=True),
        yaxis=dict(fixedrange=True, scaleanchor="x", scaleratio=1),
        sliders=sliders,
        legend=dict(
            yanchor="top",
            xanchor="right",
            x=1,
            y=1
        ),
        margin=dict(l=5, r=5, t=50, b=50)
    ))


def sample_parabola(x):
    return 0.15 * (x - 15) ** 2 + 15



## Convex Function

A function is **convex** when the line segment joining any two points on its graph lies on or above the graph. In the simplest terms, a convex function is shaped like $\cup$ and a **concave function** is shaped like $\cap$. If `f` is convex, then `-f` is concave.

A visualization from [Wikipedia](https://en.wikipedia.org/wiki/Convex_function):


In [2]:
#| code-summary: display image from Wikipedia 
#| code-fold: true

from IPython.display import Image
Image(url='https://upload.wikimedia.org/wikipedia/commons/c/c7/ConvexFunction.svg', width=400)

### Definition

A function $f$ is called **convex** if the following holds for all $x_1, x_2$ in its domain and all $\lambda \in [0, 1]$:

$$
f(\lambda x_1 + (1-\lambda) x_2) \le \lambda f(x_1) + (1-\lambda) f(x_2)
$$

and **concave** when:

$$
f(\lambda x_1 + (1-\lambda) x_2) \ge \lambda f(x_1) + (1-\lambda) f(x_2)
$$
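The two inequalities above can be spot-checked numerically. The sketch below is only an illustration: `satisfies_convexity` is a hypothetical helper (not part of this post's plotting code), and passing the check on random samples is evidence rather than a proof of convexity.

```python
import numpy as np


def sample_parabola(x):
    # Same parabola as the one used for the figures in this post.
    return 0.15 * (x - 15) ** 2 + 15


def satisfies_convexity(f, low, high, trials=10_000, seed=0, tol=1e-9):
    """Spot-check f(l*x1 + (1-l)*x2) <= l*f(x1) + (1-l)*f(x2) on random samples.

    Hypothetical helper: returning True is not a proof, while returning
    False means at least one counterexample was found.
    """
    rng = np.random.default_rng(seed)
    x1 = rng.uniform(low, high, trials)
    x2 = rng.uniform(low, high, trials)
    lam = rng.uniform(0.0, 1.0, trials)
    lhs = f(lam * x1 + (1 - lam) * x2)     # f evaluated at the mixed input
    rhs = lam * f(x1) + (1 - lam) * f(x2)  # the same mix applied to the outputs
    return bool(np.all(lhs <= rhs + tol))


is_convex = satisfies_convexity(sample_parabola, 0, 30)
negated_is_convex = satisfies_convexity(lambda x: -sample_parabola(x), 0, 30)
print(is_convex, negated_is_convex)
```

The parabola passes the check, while its negation fails it, which matches the remark that `-f` is concave when `f` is convex.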

### Geometric Intuition

In [3]:
fig = fig_jensen_inequality(
    f=sample_parabola,
    x_range=np.linspace(0, 30, 100),
    x=np.array([2, 22]),
)
fig.show(renderer="iframe")

An interactive visualization of the convex function $f(x)=0.15(x - 15)^2 + 15$. We will use the same parabola throughout this post unless stated otherwise. You can use the slider to try different values of $(\lambda_1, \lambda_2)$, where $\lambda_2=1-\lambda_1$.

We have a line segment that joins $(x_1, f(x_1))$ and $(x_2, f(x_2))$. We can sample any point along the line segment with $(\lambda_1 x_1 + \lambda_2 x_2, \lambda_1 f(x_1) + \lambda_2 f(x_2))$. For example:

- When $\lambda_1=1$, we get the first point
- When $\lambda_1=0$, we get the second point
- When $\lambda_1=0.5$, we get the midpoint of the line segment
- And so on...
 
This point is shown as the black point in the figure above. Let's call it **$A$**.

The point where the function graph intersects the dotted line segment is $(\lambda_1 x_1 + \lambda_2 x_2, f(\lambda_1 x_1 + \lambda_2 x_2))$. Let's call it **$B$**.

Since $A_x = B_x$, the definition above simply asserts that $B_y \le A_y$. Note that the figure shows a single line segment, but the statement must hold for every such segment.
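As a small worked example, the two points can be computed directly for the endpoints used in the figure, $x_1=2$ and $x_2=22$. The choice $\lambda_1=0.5$ and the variable names below are assumptions made for illustration.

```python
import numpy as np


def sample_parabola(x):
    # Same parabola as the one used for the figures in this post.
    return 0.15 * (x - 15) ** 2 + 15


x1, x2 = 2.0, 22.0
lam1 = 0.5          # illustrative weight; any value in [0, 1] works
lam2 = 1.0 - lam1

mixed_x = lam1 * x1 + lam2 * x2
# A lies on the line segment joining (x1, f(x1)) and (x2, f(x2)).
A = np.array([mixed_x, lam1 * sample_parabola(x1) + lam2 * sample_parabola(x2)])
# B lies on the graph of f, directly below or on A.
B = np.array([mixed_x, sample_parabola(mixed_x)])

print("A =", A)
print("B =", B)
print("B_y <= A_y:", bool(B[1] <= A[1]))
```

With these values, $A$ is approximately $(12, 31.35)$ and $B$ is approximately $(12, 16.35)$, so the segment sits well above the graph at the midpoint.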

## Jensen's Inequality

Jensen's inequality generalizes the convex function definition above to more than two points.


### Definition 

Assume we have a **convex function** $f$, points $x_1, x_2, \cdots, x_n$ in $f$'s domain, and nonnegative weights $\lambda_1, \lambda_2, \cdots, \lambda_n$ where $\sum_{i=1}^n \lambda_i = 1$. Then Jensen's inequality can be stated as:

$$
f\left(\sum_{i=1}^n \lambda_i x_i\right) \le \sum_{i=1}^n \lambda_i f(x_i)
$$

The inequality is flipped for a **concave function** $g$:

$$
g\left(\sum_{i=1}^n \lambda_i x_i\right) \ge \sum_{i=1}^n \lambda_i g(x_i)
$$

Note that for $n=2$ we recover the definition of a convex function.
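The inequality can also be checked numerically for several points at once. The following is a minimal sketch, assuming weights drawn at random from a Dirichlet distribution (so they are nonnegative and sum to 1); `jensen_sides` is a hypothetical helper introduced only for this example.

```python
import numpy as np


def sample_parabola(x):
    # Same parabola as the one used for the figures in this post.
    return 0.15 * (x - 15) ** 2 + 15


def jensen_sides(f, x, weights):
    """Return (f(sum w_i x_i), sum w_i f(x_i)) for weights on the simplex."""
    x = np.asarray(x, dtype=float)
    weights = np.asarray(weights, dtype=float)
    if np.any(weights < 0) or not np.isclose(weights.sum(), 1.0):
        raise ValueError("weights must be nonnegative and sum to 1")
    lhs = f(np.dot(weights, x))  # f applied to the weighted average
    rhs = np.dot(weights, f(x))  # weighted average of the f values
    return lhs, rhs


rng = np.random.default_rng(0)
x = np.array([2.0, 13.0, 22.0, 25.0])
all_hold = True
for _ in range(1000):
    weights = rng.dirichlet(np.ones(len(x)))  # random point on the simplex
    lhs, rhs = jensen_sides(sample_parabola, x, weights)
    all_hold = all_hold and bool(lhs <= rhs + 1e-9)

print(all_hold)
```

The small tolerance only guards against floating-point rounding; it is not part of the inequality itself.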

### Geometric Intuition

Numerous proofs are already available elsewhere. I encourage you to check out the following resources:

- [en.wikipedia.org/wiki/Jensen%27s_inequality#Proofs](https://en.wikipedia.org/wiki/Jensen%27s_inequality#Proofs)
- [brilliant.org/wiki/jensens-inequality](https://brilliant.org/wiki/jensens-inequality/)
- [artofproblemsolving.com/wiki/...](https://artofproblemsolving.com/wiki/index.php/Jensen%27s_Inequality)

Here I describe a geometric intuition, which resonates more with me.

#### Triangle 

Let's start with a triangle, i.e., $n=3$:

In [4]:
fig = fig_jensen_inequality(
    f=sample_parabola,
    x_range=np.linspace(0, 30, 100),
    x=np.array([2, 13, 25]),
)
fig.show(renderer="iframe")

As before, you can use the slider to try different values of $(\lambda_1, \lambda_2, \lambda_3)$ where $\lambda_1+\lambda_2+\lambda_3=1$.

We have a triangle that connects the points: $(x_1, f(x_1)), (x_2, f(x_2)), (x_3, f(x_3))$. 

In the $n=2$ case, we used $\lambda_1$ and $\lambda_2$ to sample a point along the line segment. 

In this $n=3$ case, it is similar, but we can sample any point inside or on the boundaries of the triangle with:

$$
\left(\lambda_1x_1+\lambda_2x_2+\lambda_3x_3, \lambda_1f(x_1)+\lambda_2f(x_2)+\lambda_3f(x_3)\right)
$$

For example:

- When $\lambda_i=1$ where $i \in \{1, 2, 3\}$, we get the point $(x_i, f(x_i))$
- When $\lambda_1=\lambda_2=\lambda_3=\frac{1}{3}$, we get the center of mass of the triangle

The black point in the visualization shows this point. Let's call it **A**.

Note that $(\lambda_1, \lambda_2, \lambda_3)$ are the [barycentric coordinates](https://en.wikipedia.org/wiki/Barycentric_coordinate_system) of the point with respect to the triangle, in case you are already familiar with them.

The point where the parabola meets the dotted line segment is described by:

$$
(\lambda_1x_1+\lambda_2x_2+\lambda_3x_3, f(\lambda_1x_1+\lambda_2x_2+\lambda_3x_3))
$$

If we call this point **B**, then it is not difficult to see that Jensen's inequality is the same as $B_y \le A_y$.
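The triangle case can be worked through for the same three points as in the figure, $x = (2, 13, 25)$. This sketch mirrors the weighted average used by the plotting code; `points_a_and_b` is a hypothetical helper, and the two weight choices (the centroid and a single vertex) are picked for illustration.

```python
import numpy as np


def sample_parabola(x):
    # Same parabola as the one used for the figures in this post.
    return 0.15 * (x - 15) ** 2 + 15


x = np.array([2.0, 13.0, 25.0])
# Triangle vertices (x_i, f(x_i)), one per row.
vertices = np.column_stack([x, sample_parabola(x)])


def points_a_and_b(weights):
    """Return A (weighted average of the vertices) and B (the graph point below A)."""
    a = np.average(vertices, weights=weights, axis=0)
    b = np.array([a[0], sample_parabola(a[0])])
    return a, b


# Equal weights give the centroid of the triangle.
centroid_a, centroid_b = points_a_and_b([1 / 3, 1 / 3, 1 / 3])
# Putting all the weight on one vertex returns that vertex, where A and B coincide.
vertex_a, vertex_b = points_a_and_b([0.0, 1.0, 0.0])

print("centroid: A =", centroid_a, "B =", centroid_b)
print("vertex 2: A =", vertex_a, "B =", vertex_b)
```

At the centroid, $A_y$ is strictly larger than $B_y$, while at a vertex the two points coincide and the inequality holds with equality.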


#### Four Points or More

This generalizes easily to $n>3$. I am adding it here for the sake of completeness:

In [5]:
fig = fig_jensen_inequality(
    f=sample_parabola,
    x_range=np.linspace(0, 30, 100),
    x=np.array([2, 13, 22, 25]),
)
fig.show(renderer="iframe")

In the general case, the point $(\sum_{i=1}^n \lambda_ix_i, \sum_{i=1}^n \lambda_if(x_i))$ lies inside or on the boundary of the convex hull of the points $(x_1, f(x_1)), (x_2, f(x_2)), \cdots, (x_n, f(x_n))$. Because $f$ is convex, the convex hull always lies on or above the graph, so this point is never below $(\sum_{i=1}^n \lambda_ix_i, f(\sum_{i=1}^n \lambda_ix_i))$, which is exactly Jensen's inequality.


## Applications

### AM–GM inequality

TODO. For now, see the [Wikipedia article on the AM–GM inequality](https://en.wikipedia.org/wiki/AM%E2%80%93GM_inequality).
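Until this section is written, here is a hedged sketch of the usual textbook argument, which is not necessarily the one the finished section will use. The logarithm is concave, so Jensen's inequality with equal weights $\lambda_i = \frac{1}{n}$ and positive $x_i$ gives $\log\left(\frac{1}{n}\sum_{i=1}^n x_i\right) \ge \frac{1}{n}\sum_{i=1}^n \log x_i$. Exponentiating both sides yields the arithmetic mean on the left and the geometric mean on the right. The snippet below only spot-checks this on random positive numbers; the sample sizes and ranges are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
# 1000 random rows, each holding 5 positive numbers.
values = rng.uniform(0.1, 10.0, size=(1000, 5))

arithmetic_mean = values.mean(axis=1)
# Geometric mean computed as exp(mean(log x)), the right-hand side of Jensen's bound.
geometric_mean = np.exp(np.log(values).mean(axis=1))

# Tiny tolerance to absorb floating-point rounding.
am_ge_gm = bool(np.all(arithmetic_mean >= geometric_mean - 1e-12))
print(am_ge_gm)
```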

## The End

I hope you enjoyed this post. You can ask further questions on [my telegram channel](https://t.me/swemanml).