- Radial Basis Function (RBF) kernel

$$
K(\mathbf{x}_1, \mathbf{x}_2) = A\exp\left(-\frac{\|\mathbf{x}_1 - \mathbf{x}_2\|^2}{2l^2}\right)
$$

- $A$ : output_scale
- $l$ : length_scale

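- Mean-zero Gaussian random fields (GRFs): random functions $u$ drawn from a Gaussian process whose covariance is the kernel $K$

$$
u(\mathbf{x}) \sim \mathcal{GP}\left(0, K(\mathbf{x}_1, \mathbf{x}_2)\right)
$$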

In [3]:
import jax
import jax.numpy as jnp

In [4]:
# Confirm which device(s) JAX will place arrays on (the output below shows a single CUDA GPU).
jax.devices()

[cuda(id=0)]

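# Pairwise squared distances

The RBF kernel depends on its inputs only through the squared Euclidean distance

$$
\|\mathbf{x}_1 - \mathbf{x}_2\|^2
$$

which the cells below compute for every pair of grid points using broadcasting.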

In [5]:
# Build N evenly spaced 1-D inputs on [0, 1], laid out as an (N, 1) column
# so that rows are points and columns are features.
N = 4
X = jnp.linspace(0.0, 1.0, num=N).reshape(N, 1)
X.shape

(4, 1)

In [6]:
# Display the grid values themselves: one evenly spaced point per row.
X

Array([[0.        ],
       [0.33333334],
       [0.6666667 ],
       [1.        ]], dtype=float32)

In [7]:
# Inserting a singleton axis at position 1 vs. position 0 turns the (N, D) grid
# into (N, 1, D) and (1, N, D) respectively; these two shapes broadcast against each other.
jnp.expand_dims(X, 1).shape, jnp.expand_dims(X, 0).shape

((4, 1, 1), (1, 4, 1))

In [8]:
# Axis inserted at position 1: each point becomes its own (1, D) block.
jnp.expand_dims(X, 1)

Array([[[0.        ]],

       [[0.33333334]],

       [[0.6666667 ]],

       [[1.        ]]], dtype=float32)

In [9]:
# Axis inserted at position 0: the whole grid becomes a single (N, D) slab.
jnp.expand_dims(X, 0)

Array([[[0.        ],
        [0.33333334],
        [0.6666667 ],
        [1.        ]]], dtype=float32)

In [10]:
# Broadcasting (N, 1, D) - (1, N, D) produces every pairwise difference: shape (N, N, D).
(jnp.expand_dims(X, 1) - jnp.expand_dims(X, 0)).shape

(4, 4, 1)

In [11]:
# Entry [i, j, :] is x_i - x_j, so the result is antisymmetric in (i, j) with zeros on the diagonal.
jnp.expand_dims(X, 1) - jnp.expand_dims(X, 0)

Array([[[ 0.        ],
        [-0.33333334],
        [-0.6666667 ],
        [-1.        ]],

       [[ 0.33333334],
        [ 0.        ],
        [-0.33333334],
        [-0.6666666 ]],

       [[ 0.6666667 ],
        [ 0.33333334],
        [ 0.        ],
        [-0.3333333 ]],

       [[ 1.        ],
        [ 0.6666666 ],
        [ 0.3333333 ],
        [ 0.        ]]], dtype=float32)

In [12]:
# Squared differences; squeeze() drops the size-1 feature axis. This yields an (N, N)
# matrix only because D == 1 here. For D > 1 the feature axis must be summed instead.
((jnp.expand_dims(X, 1) - jnp.expand_dims(X, 0))**2).squeeze()

Array([[0.        , 0.11111112, 0.44444448, 1.        ],
       [0.11111112, 0.        , 0.11111112, 0.4444444 ],
       [0.44444448, 0.11111112, 0.        , 0.1111111 ],
       [1.        , 0.4444444 , 0.1111111 , 0.        ]], dtype=float32)

In [13]:
# General form: sum the squared differences over the feature axis (axis=2) to get the
# (N, N) matrix of squared Euclidean distances ||x_i - x_j||^2 for any feature dimension D.
jnp.sum((jnp.expand_dims(X, 1) - jnp.expand_dims(X, 0))**2, axis=2)

Array([[0.        , 0.11111112, 0.44444448, 1.        ],
       [0.11111112, 0.        , 0.11111112, 0.4444444 ],
       [0.44444448, 0.11111112, 0.        , 0.1111111 ],
       [1.        , 0.4444444 , 0.1111111 , 0.        ]], dtype=float32)

# Kernel

- Radial Basis Function (RBF) kernel

$$
K(\mathbf{x}_1, \mathbf{x}_2) = A\exp\left(-\frac{\|\mathbf{x}_1 - \mathbf{x}_2\|^2}{2l^2}\right)
$$


with the squared Euclidean distance

$$
\|\mathbf{x}_1 - \mathbf{x}_2\|^2
$$

In [14]:
# Define RBF kernel
def RBF(x1, x2, params):
    """Squared-exponential (RBF) kernel matrix between two sets of points.

    Computes K[i, j] = A * exp(-0.5 * ||(x1[i] - x2[j]) / l||^2), which for a
    scalar length scale l equals A * exp(-||x1[i] - x2[j]||^2 / (2 l^2)).

    Args:
        x1: Inputs of shape (N, D). A 1-D array of shape (N,) is treated as
            N one-dimensional points and promoted to shape (N, 1).
        x2: Inputs of shape (M, D), or (M,) for one-dimensional points.
        params: Tuple ``(output_scale, lengthscales)``. ``output_scale`` is the
            amplitude A. ``lengthscales`` is either a scalar or an array that
            broadcasts against the feature axis (e.g. shape (D,) for one length
            scale per input dimension).

    Returns:
        Kernel matrix of shape (N, M).
    """
    output_scale, lengthscales = params
    # Promote 1-D inputs to column vectors. Otherwise the broadcast below yields a
    # 2-D difference array and the reduction over axis 2 fails.
    if jnp.ndim(x1) == 1:
        x1 = x1[:, None]
    if jnp.ndim(x2) == 1:
        x2 = x2[:, None]
    # Scale each feature by its length scale, then take all pairwise differences
    # by broadcasting: (N, 1, D) - (1, M, D) -> (N, M, D).
    diffs = jnp.expand_dims(x1 / lengthscales, 1) - \
            jnp.expand_dims(x2 / lengthscales, 0)
    # Scaled squared Euclidean distance for every pair: (N, M).
    r2 = jnp.sum(diffs**2, axis=2)
    return output_scale * jnp.exp(-0.5 * r2)

In [15]:
# Kernel hyperparameters, packed exactly as RBF expects: (output_scale A, length_scale l).
gp_params = (1.0, 0.2)
output_scale, length_scale = gp_params

# Replay RBF's body by hand on the grid X:
# (N, 1, D) - (1, N, D) broadcasts to all pairwise differences of the scaled inputs.
diffs = (jnp.expand_dims(X / length_scale, axis=1)
         - jnp.expand_dims(X / length_scale, axis=0))
# Squared scaled distance per pair, summed over the feature axis -> (N, N).
r2 = (diffs**2).sum(axis=2)
r2.shape

(4, 4)

In [16]:
# Assemble K = A * exp(-r2 / 2) from the hand-computed r2, then check that it agrees
# with the RBF function so the manual walk-through and the function cannot drift apart.
K = output_scale * jnp.exp(-0.5 * r2)
assert jnp.allclose(K, RBF(X, X, gp_params)), "manual kernel disagrees with RBF()"
K

Array([[1.0000000e+00, 2.4935217e-01, 3.8659174e-03, 3.7266532e-06],
       [2.4935217e-01, 1.0000000e+00, 2.4935217e-01, 3.8659209e-03],
       [3.8659174e-03, 2.4935217e-01, 1.0000000e+00, 2.4935229e-01],
       [3.7266532e-06, 3.8659209e-03, 2.4935229e-01, 1.0000000e+00]],      dtype=float32)