In [1]:
from pathlib import Path

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import Image, display


def l2_norm(x: torch.Tensor) -> torch.Tensor:
    """Return the Euclidean (L2) norm of ``x`` taken over its last dimension.

    The reduced dimension is kept with size 1, so the result broadcasts
    directly against ``x`` (for example in ``x / l2_norm(x)``).
    """
    return torch.linalg.vector_norm(x, ord=2, dim=-1, keepdim=True)


# Three sample points in 2-D, one per row; every cell below transforms these.
x = torch.tensor([[1.68, 0.64], [1.17, 1.45], [-1.98, -0.93]])

# Directory (relative to the working directory) that receives the SVG figures.
# exist_ok=True keeps this cell safe to re-run.
p = Path("11-NormalizationFunctions")
p.mkdir(exist_ok=True)

In [2]:
# Plot the raw inputs: one segment per point, running from the point to the origin.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
for idx, (px, py) in enumerate(x):
    # With two vertices, markevery=2 draws a marker on the first one only
    # (the point itself), leaving the origin end unmarked.
    ax.plot([px, 0], [py, 0], color=f"C{idx}", marker="o", markevery=2)

# Unit circle as a visual reference.
ax.add_patch(mpatches.Circle((0, 0), 1, fill=False))
lim = 2.1
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
# Thin black lines for the coordinate axes.
for draw_axis_line in (ax.axvline, ax.axhline):
    draw_axis_line(0, lw=0.5, color="black")
ax.set_aspect("equal")
ax.set_title("Inputs")

# Save as SVG, close the figure so it is not rendered inline a second time,
# then display the saved file.
fig.set_facecolor("white")
svg_path = p / "inputs.svg"
fig.savefig(svg_path)
plt.close(fig)
display(Image(url=svg_path))

Norm

$$y = \frac{x}{||x||}$$

All points end up on the unit circle.

In [3]:
# Scale every point to unit length.
y = x / l2_norm(x)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
for idx, ((px, py), (qx, qy)) in enumerate(zip(x, y)):
    colour = f"C{idx}"
    # Original point, faded.
    ax.plot([px, 0], [py, 0], color=colour, marker="o", markevery=2, alpha=0.5)
    # Normalized point, same colour at full opacity.
    ax.plot([qx, 0], [qy, 0], color=colour, marker="o", markevery=2)

# Unit circle as a visual reference.
ax.add_patch(mpatches.Circle((0, 0), 1, fill=False))
lim = 2.1
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
# Thin black lines for the coordinate axes.
for draw_axis_line in (ax.axvline, ax.axhline):
    draw_axis_line(0, lw=0.5, color="black")
ax.set_aspect("equal")
ax.set_title("L2 Norm")

# Save as SVG, close the figure, then display the saved file.
fig.set_facecolor("white")
svg_path = p / "l2_norm.svg"
fig.savefig(svg_path)
plt.close(fig)
display(Image(url=svg_path))

Squash (original)

$$y = \frac{||x||^2}{1 + ||x||^2} \frac{x}{||x||}$$

All points end up inside the unit circle, with a radius that grows monotonically with their original norm (approaching 1 as the norm grows).

In [4]:
# Original squash: y = ||x||^2 / (1 + ||x||^2) * x / ||x||.
# The square belongs to the norm inside the denominator, i.e. 1 + ||x||^2,
# not to the whole (1 + ||x||) term.
norm = l2_norm(x)
sq_norm = norm**2
y = (sq_norm / (1 + sq_norm)) * (x / norm)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
for i in range(x.shape[0]):
    # Original point, faded.
    ax.plot(
        [x[i, 0], 0], [x[i, 1], 0], color=f"C{i}", marker="o", markevery=2, alpha=0.5
    )
    # Squashed point, same colour at full opacity.
    ax.plot([y[i, 0], 0], [y[i, 1], 0], color=f"C{i}", marker="o", markevery=2)

# Unit circle as a visual reference.
ax.add_patch(mpatches.Circle((0, 0), 1, fill=False))
ax.set_xlim(-2.1, 2.1)
ax.set_ylim(-2.1, 2.1)
ax.axvline(0, lw=0.5, color="black")
ax.axhline(0, lw=0.5, color="black")
ax.set_aspect("equal")
ax.set_title("Squash")

# Save as SVG, close the figure, then display the saved file.
fig.set_facecolor("white")
fig.savefig(p / "squash.svg")
plt.close(fig)
display(Image(url=p / "squash.svg"))

Squash (improved)

$$y = \left(1 - \frac{1}{\exp ||x||}\right) \frac{x}{||x||}$$

All points end up inside the unit circle, with a radius that grows monotonically with their original norm.
However, the function that maps the norm to the radius is different from the original squash.

In [5]:
# Improved squash: radius 1 - 1/exp(||x||), direction x / ||x||.
norm = l2_norm(x)
radius = 1 - 1 / norm.exp()
direction = x / norm
y = radius * direction

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
for idx, ((px, py), (qx, qy)) in enumerate(zip(x, y)):
    colour = f"C{idx}"
    # Original point, faded.
    ax.plot([px, 0], [py, 0], color=colour, marker="o", markevery=2, alpha=0.5)
    # Squashed point, same colour at full opacity.
    ax.plot([qx, 0], [qy, 0], color=colour, marker="o", markevery=2)

# Unit circle as a visual reference.
ax.add_patch(mpatches.Circle((0, 0), 1, fill=False))
lim = 2.1
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
# Thin black lines for the coordinate axes.
for draw_axis_line in (ax.axvline, ax.axhline):
    draw_axis_line(0, lw=0.5, color="black")
ax.set_aspect("equal")
ax.set_title("Squash Improved")

# Save as SVG, close the figure, then display the saved file.
fig.set_facecolor("white")
svg_path = p / "squash_improved.svg"
fig.savefig(svg_path)
plt.close(fig)
display(Image(url=svg_path))

Layer normalization

$$y = \frac{x - \text{mean}(x)}{\text{std}(x)}$$

All points are projected onto the $(D-1)$-dimensional hyperplane perpendicular to $(1, 1, \ldots, 1) \in \mathbb{R}^D$ (in 2D, the bisector of the 2nd and 4th quadrants) and rescaled so that they lie on a hypersphere centered at the origin (here, the unit circle); the outputs therefore land on the intersection of the two.

In [6]:
# Layer normalization over the last dimension: subtract the mean, divide by the std.
# NOTE(review): Tensor.std applies Bessel's correction by default (divides by N - 1),
# whereas torch.nn.LayerNorm uses the biased estimator. For these 2-D inputs the
# default is what puts the outputs on the unit circle rather than at radius sqrt(2).
centered = x - x.mean(dim=-1, keepdim=True)
y = centered / x.std(dim=-1, keepdim=True)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
for idx, ((px, py), (qx, qy)) in enumerate(zip(x, y)):
    colour = f"C{idx}"
    # Original point, faded.
    ax.plot([px, 0], [py, 0], color=colour, marker="o", markevery=2, alpha=0.5)
    # Normalized point, same colour at full opacity.
    ax.plot([qx, 0], [qy, 0], color=colour, marker="o", markevery=2)

# Unit circle as a visual reference.
ax.add_patch(mpatches.Circle((0, 0), 1, fill=False))
lim = 2.1
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
# Thin black lines for the coordinate axes.
for draw_axis_line in (ax.axvline, ax.axhline):
    draw_axis_line(0, lw=0.5, color="black")
ax.set_aspect("equal")
ax.set_title("Layer Normalization")

# Save as SVG, close the figure, then display the saved file.
fig.set_facecolor("white")
svg_path = p / "layer_norm.svg"
fig.savefig(svg_path)
plt.close(fig)
display(Image(url=svg_path))