# Kernelized Stein Gradient

This notebook accompanies the original blog post [here](https://www.sanyamkapoor.com/machine-learning/stein-gradient).

## Install Dependencies

We use [PyTorch](https://pytorch.org/) for all our differentiation needs and [Altair](https://altair-viz.github.io/) for plotting.

In [1]:
# Uncomment this if the imports throw an error
# ! pip install "altair>=2.4" "numpy>=1.16" "torch>=1.0" pandas

In [1]:
import math
import numpy as np
import pandas as pd
import torch
import torch.autograd as autograd
import torch.optim as optim
import altair as alt

alt.data_transformers.enable('default', max_rows=None)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


### Drawing Utilities

In [2]:
def get_density_chart(P, d=7.0, step=0.1):
  xv, yv = torch.meshgrid([
      torch.arange(-d, d, step), 
      torch.arange(-d, d, step)
  ])
  pos_xy = torch.cat((xv.unsqueeze(-1), yv.unsqueeze(-1)), dim=-1)
  p_xy = P.log_prob(pos_xy.to(device)).exp().unsqueeze(-1).cpu()
  
  df = torch.cat([pos_xy, p_xy], dim=-1).numpy()
  df = pd.DataFrame({
      'x': df[:, :, 0].ravel(),
      'y': df[:, :, 1].ravel(),
      'p': df[:, :, 2].ravel(),
  })
  
  chart = alt.Chart(df).mark_point().encode(
    x='x:Q',
    y='y:Q',
    color=alt.Color('p:Q', scale=alt.Scale(scheme='viridis')),
    tooltip=['x','y','p']
  )
  
  return chart


def get_particles_chart(X):
  df = pd.DataFrame({
      'x': X[:, 0],
      'y': X[:, 1],
  })

  chart = alt.Chart(df).mark_circle(color='red').encode(
    x='x:Q',
    y='y:Q'
  )
  
  return chart

## RBF Kernel

In these experiments, we use the *rbf* (squared exponential) kernel. Its value decays exponentially with the squared Euclidean distance between the two vectors, and it is parametrized by a bandwidth $\sigma$.

$$
k_{rbf}(\mathbf{x}, \mathbf{x}^\prime) = \exp\left(-\frac{1}{2\sigma^2}\|\mathbf{x}-\mathbf{x}^\prime\|^2\right)
$$

A vectorized version of the kernel is given below. A few notes on the implementation follow.

In [3]:
class RBF(torch.nn.Module):
  def __init__(self, sigma=None):
    super(RBF, self).__init__()

    self.sigma = sigma

  def forward(self, X, Y):
    XX = X.matmul(X.t())
    XY = X.matmul(Y.t())
    YY = Y.matmul(Y.t())

    dnorm2 = -2 * XY + XX.diag().unsqueeze(1) + YY.diag().unsqueeze(0)

    # Apply the median heuristic (PyTorch does not give true median)
    if self.sigma is None:
      np_dnorm2 = dnorm2.detach().cpu().numpy()
      h = np.median(np_dnorm2) / (2 * np.log(X.size(0) + 1))
      sigma = np.sqrt(h).item()
    else:
      sigma = self.sigma

    gamma = 1.0 / (1e-8 + 2 * sigma ** 2)
    K_XY = (-gamma * dnorm2).exp()

    return K_XY
  
# Let us initialize a reusable instance right away.
K = RBF()
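The `forward` method relies on the expansion $\|\mathbf{x}-\mathbf{y}\|^2 = \|\mathbf{x}\|^2 - 2\,\mathbf{x}\cdot\mathbf{y} + \|\mathbf{y}\|^2$, which avoids materializing every pairwise difference. As a sanity check, the minimal sketch below compares that expansion against a direct broadcasted computation. The helper names `pairwise_sq_dists_expanded` and `pairwise_sq_dists_direct` are illustrative and not part of the notebook.

```python
import torch


def pairwise_sq_dists_expanded(X, Y):
    # Same algebra as RBF.forward: ||x||^2 - 2 x.y + ||y||^2
    XX = (X * X).sum(dim=1)
    YY = (Y * Y).sum(dim=1)
    return XX.unsqueeze(1) - 2 * X.matmul(Y.t()) + YY.unsqueeze(0)


def pairwise_sq_dists_direct(X, Y):
    # Explicit (n, m, d) tensor of differences, then squared norm over d
    diff = X.unsqueeze(1) - Y.unsqueeze(0)
    return diff.pow(2).sum(dim=-1)


torch.manual_seed(0)
X = torch.randn(5, 2)
Y = torch.randn(4, 2)

D_expanded = pairwise_sq_dists_expanded(X, Y)
D_direct = pairwise_sq_dists_direct(X, Y)
max_abs_diff = (D_expanded - D_direct).abs().max().item()
print(max_abs_diff)
```

The two results agree up to floating-point rounding. The expanded form can come out very slightly negative for near-identical points, so clamping at zero is one possible guard if that ever matters.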

Selecting the bandwidth parameter $\sigma$ can be a painful task in itself. A popular choice in the literature is the *median* heuristic, which sets the bandwidth to

$$
\sigma^2 = \frac{\mathrm{median}^2}{2 \log{(n + 1)}}
$$

where $\mathrm{median}$ is the median of the pairwise distances between the $n$ particles. In the code, $\mathrm{median}^2$ is computed as the median of the squared pairwise distances. This scaling lets every pair contribute to the gradient of the kernel during the simulation of the ODE. Note that we use the `numpy` median function because `torch.median` returns the lower of the two central elements when the number of elements is even, rather than their mean.
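The difference between the two median functions is easy to see on a small even-length input. The sketch below also isolates the bandwidth computation in a hypothetical helper, `median_heuristic_sigma`, which mirrors the arithmetic in `RBF.forward` but is not part of the notebook.

```python
import numpy as np
import torch

values = torch.tensor([1.0, 2.0, 3.0, 4.0])

# torch.median picks the lower of the two middle values for even-length input
torch_median = torch.median(values).item()
# np.median averages the two middle values
numpy_median = float(np.median(values.numpy()))


def median_heuristic_sigma(sq_dists, n):
    # sq_dists: squared pairwise distances, n: number of particles
    h = np.median(sq_dists) / (2 * np.log(n + 1))
    return float(np.sqrt(h))


# Two particles at squared distance 4 from each other
sigma = median_heuristic_sigma(np.array([[0.0, 4.0], [4.0, 0.0]]), n=2)
print(torch_median, numpy_median, sigma)
```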

## Stein Variational Gradient Descent

We now simulate the following ODE for each particle $x_i$ in the system, where the sum runs over all $n$ particles $x_j$.

$$
\dot{x}_i = \frac{1}{n} \sum_{j = 1}^n \left[ k(x_j, x_i) \nabla_{x_j} \log{p(x_j)} + \nabla_{x_j} k(x_j, x_i)  \right]
$$
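For the *rbf* kernel, both terms of this update have closed forms. The first term is a kernel-weighted average of the scores $\nabla \log p$, and the kernel-gradient term evaluates to $\frac{1}{\sigma^2}(x - x')\,k(x', x)$, which pushes a particle at $x$ away from a neighbour at $x'$. The following is a minimal NumPy sketch under the assumption of a fixed bandwidth. The function `svgd_direction` and the two toy score functions are hypothetical names used only for illustration.

```python
import numpy as np


def svgd_direction(X, score, sigma=1.0):
    """Closed-form SVGD direction for an RBF kernel with fixed bandwidth."""
    n = X.shape[0]
    diff = X[:, None, :] - X[None, :, :]  # diff[i, j] = x_i - x_j
    K = np.exp(-np.sum(diff ** 2, axis=-1) / (2 * sigma ** 2))
    drive = K @ score(X)  # kernel-weighted sum of scores
    repulse = np.sum(K[:, :, None] * diff, axis=1) / sigma ** 2
    return (drive + repulse) / n


def standard_normal_score(X):
    # Score of N(0, I): grad log p(x) = -x
    return -X


def zero_score(X):
    # Flat target, so only the repulsive term acts
    return np.zeros_like(X)


one = svgd_direction(np.array([[1.0, -2.0]]), standard_normal_score)
pair = svgd_direction(np.array([[-1.0], [1.0]]), zero_score)
print(one, pair.ravel())
```

With a single particle the direction equals the score, and with a flat target the two particles move apart, which is the repulsion that keeps the particles from collapsing onto one mode.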

For stability, we use Adam, which adapts the step size during the simulation. Any other adaptive gradient method can be used in its place.

This is encapsulated in the `step` function below.

In [4]:
class SVGD:
  def __init__(self, P, K, optimizer):
    self.P = P
    self.K = K
    self.optim = optimizer

  def phi(self, X):
    X = X.detach().requires_grad_(True)

    log_prob = self.P.log_prob(X)
    score_func = autograd.grad(log_prob.sum(), X)[0]

    K_XX = self.K(X, X.detach())
    grad_K = -autograd.grad(K_XX.sum(), X)[0]

    phi = (K_XX.detach().matmul(score_func) + grad_K) / X.size(0)

    return phi

  def step(self, X):
    self.optim.zero_grad()
    X.grad = -self.phi(X)
    self.optim.step()
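The `step` method works by writing the negated update direction into `X.grad` and letting the optimizer do the rest. To make that mechanism explicit, here is a small sketch that swaps Adam for plain `optim.SGD`, which is an assumption made purely for illustration: with SGD the optimizer step is exactly one forward Euler step, $X \leftarrow X + \text{lr} \cdot \phi(X)$. The tensor `direction` is a made-up stand-in for `phi(X)`.

```python
import torch
import torch.optim as optim

X = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
direction = torch.tensor([[0.5, 0.0], [0.0, -0.5]])  # stand-in for phi(X)
lr = 0.1

# What one forward Euler step along `direction` should produce
expected = X + lr * direction

optimizer = optim.SGD([X], lr=lr)
optimizer.zero_grad()
X.grad = -direction  # optimizers descend, so negate to move along `direction`
optimizer.step()     # updates X in place

print(X)
```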

# Experiments

## Unimodal Gaussian

We will first run this on a unimodal Gaussian. We initialize the particles in an overdispersed manner and see how they converge around the typical set of the distribution.

**NOTE**: Try increasing the number of particles $n$ and varying the initialization to see how the particles distribute themselves.

In [5]:
gauss = torch.distributions.MultivariateNormal(torch.Tensor([-0.6871,0.8010]).to(device),
        covariance_matrix=5 * torch.Tensor([[0.2260,0.1652],[0.1652,0.6779]]).to(device))

n = 10
X_init = (3 * torch.randn(n, *gauss.event_shape)).to(device)

In [6]:
gauss_chart = get_density_chart(gauss, d=7.0, step=0.1)

Let us see what this overdispersed initialization looks like. Note that initializations much farther away from the typical set of the distribution may take longer to converge.

In [7]:
gauss_chart + get_particles_chart(X_init.cpu().numpy())

In [8]:
X = X_init.clone()
svgd = SVGD(gauss, K, optim.Adam([X], lr=1e-1))
for _ in range(1000):
    svgd.step(X)

In [9]:
gauss_chart + get_particles_chart(X.cpu().numpy())
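Besides the visual check, one could compare the empirical moments of the particles against the known mean and covariance of the target. The sketch below is only an illustration: `moment_errors` is a hypothetical helper, and the "particles" are exact draws from the target used as a stand-in, not SVGD output. With only 10 particles the estimates would be much coarser.

```python
import numpy as np


def moment_errors(particles, mean, cov):
    """Max absolute error of the empirical mean and covariance."""
    emp_mean = particles.mean(axis=0)
    emp_cov = np.cov(particles, rowvar=False)
    return np.abs(emp_mean - mean).max(), np.abs(emp_cov - cov).max()


# Parameters of the unimodal Gaussian used in this section
mean = np.array([-0.6871, 0.8010])
cov = 5 * np.array([[0.2260, 0.1652], [0.1652, 0.6779]])

# Stand-in particles: exact samples from the target (an assumption for the demo)
rng = np.random.default_rng(0)
samples = rng.multivariate_normal(mean, cov, size=20000)

mean_err, cov_err = moment_errors(samples, mean, cov)
print(mean_err, cov_err)
```

To apply it to the notebook's particles, something like `moment_errors(X.cpu().numpy(), mean, cov)` should work.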

## Mixture of Gaussians

The exact same simulation works, without any manual fine-tuning, even for a multimodal mixture of Gaussians. We will first create a generic PyTorch distribution that lets us build different kinds of Gaussian mixtures.

In [11]:
class MoG(torch.distributions.Distribution):
  def __init__(self, loc, covariance_matrix):
    self.num_components = loc.size(0)
    self.loc = loc
    self.covariance_matrix = covariance_matrix

    self.dists = [
      torch.distributions.MultivariateNormal(mu, covariance_matrix=sigma)
      for mu, sigma in zip(loc, covariance_matrix)
    ]
    
    super(MoG, self).__init__(torch.Size([]), torch.Size([loc.size(-1)]))

  @property
  def arg_constraints(self):
    return self.dists[0].arg_constraints

  @property
  def support(self):
    return self.dists[0].support

  @property
  def has_rsample(self):
    return False

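  def log_prob(self, value):
    # Equal-weight mixture: logsumexp over components, minus log(num_components)
    component_log_probs = torch.stack(
      [p.log_prob(value) for p in self.dists], dim=-1)
    return component_log_probs.logsumexp(dim=-1) - math.log(self.num_components)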

  def enumerate_support(self):
    return self.dists[0].enumerate_support()
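The mixture density is combined in log space with `logsumexp` rather than by exponentiating and summing, which matters far from the modes where each component density underflows. The NumPy sketch below illustrates this for two isotropic components with variance $0.5$. The helpers `component_log_probs` and `logsumexp` are written here for illustration and are not taken from the notebook.

```python
import numpy as np


def component_log_probs(x, locs, var=0.5):
    # log N(x; mu, var * I) in two dimensions, one value per component mean
    sq = np.sum((x - locs) ** 2, axis=-1)
    return -np.log(2 * np.pi * var) - sq / (2 * var)


def logsumexp(a):
    # Subtract the max before exponentiating so the largest term is exp(0)
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))


locs = np.array([[-5.0, 0.0], [5.0, 0.0]])
far_point = np.array([100.0, 0.0])
log_p = component_log_probs(far_point, locs)

with np.errstate(divide='ignore'):
    naive = np.log(np.sum(np.exp(log_p)))  # every exp underflows to 0

stable = logsumexp(log_p)
print(naive, stable)
```

The naive version returns negative infinity, whose gradient is useless for a particle that starts far away, while the stable version stays finite.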

### Mixture of Two Gaussians

Here we create a mixture of two Gaussians whose means are placed symmetrically at $x = \pm 5$ and whose covariance matrices are both $\begin{pmatrix}0.5 & 0 \\ 0 & 0.5\end{pmatrix}$.

In [12]:
class MoG2(MoG):
  def __init__(self, device=None):
    loc = torch.Tensor([[-5.0, 0.0], [5.0, 0.0]]).to(device)
    cov = torch.Tensor([0.5, 0.5]).diag().unsqueeze(0).repeat(2, 1, 1).to(device)
  
    super(MoG2, self).__init__(loc, cov)
    
mog2 = MoG2(device=device)

In [13]:
n = 100
X_init = (5 * torch.randn(n, *mog2.event_shape)).to(device)

In [14]:
X = X_init.clone()
svgd = SVGD(mog2, K, optim.Adam([X], lr=1e-1))
for _ in range(1000):
    svgd.step(X)

In [15]:
mog2_chart = get_density_chart(mog2, d=7.0, step=0.1)

(mog2_chart + get_particles_chart(X_init.cpu().numpy())) | (mog2_chart + get_particles_chart(X.cpu().numpy()))

### Mixture of Six Gaussians

Here we create a mixture of six Gaussians whose means are spread evenly around a circle of radius $5$ and whose covariance matrices are all $\begin{pmatrix}0.5 & 0 \\ 0 & 0.5\end{pmatrix}$.

In [29]:
class MoG6(MoG):
  def __init__(self, device=None):
    def _compute_mu(i):
      return 5.0 * torch.Tensor([[
        torch.tensor(i * math.pi / 3.0).sin(),
        torch.tensor(i * math.pi / 3.0).cos()]])

    loc = torch.cat([_compute_mu(i) for i in range(1, 7)], dim=0).to(device)
    cov = torch.Tensor([0.5, 0.5]).diag().unsqueeze(0).to(device).repeat(6, 1, 1)

    super(MoG6, self).__init__(loc, cov)

mog6 = MoG6(device=device)
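To see where the six modes land, the component means can be reproduced with a few lines of NumPy. This sketch uses the same formula as `_compute_mu`, namely $5\,(\sin(i\pi/3), \cos(i\pi/3))$ for $i = 1, \dots, 6$, and the variable names are illustrative.

```python
import math
import numpy as np

# Angles i * pi / 3 for i = 1..6, as in _compute_mu
angles = np.array([i * math.pi / 3.0 for i in range(1, 7)])
means = 5.0 * np.stack([np.sin(angles), np.cos(angles)], axis=1)

# Every mean should sit on the circle of radius 5
radii = np.linalg.norm(means, axis=1)
print(np.round(means, 3))
```

The means are spaced 60 degrees apart, and the last one ($i = 6$) sits at the top of the circle, at $(0, 5)$ up to rounding.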

In [30]:
n = 100
X_init = (5 * torch.randn(n, *mog6.event_shape)).to(device)

In [31]:
X = X_init.clone()
svgd = SVGD(mog6, K, optim.Adam([X], lr=1e-1))
for _ in range(1000):
    svgd.step(X)

In [32]:
mog6_chart = get_density_chart(mog6, d=7.0, step=0.1)

(mog6_chart + get_particles_chart(X_init.cpu().numpy())) | (mog6_chart + get_particles_chart(X.cpu().numpy()))

## Mixture of Six Gaussians with One Particle

As noted in the blog post, with a single particle and a kernel for which $\nabla_x k(x,x) = 0$, we recover the classic MAP estimator. The *rbf* kernel satisfies this gradient property, so let us see whether we get a MAP estimate. Just as in the classic mode-finding setting, we should expect the particle to end up in different modes on different runs.
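As a short worked step, take $n = 1$ and the *rbf* kernel. Then $k(x, x) = 1$, and the kernel gradient $\nabla_{x'} k(x', x) = -\frac{1}{\sigma^2}(x' - x)\,k(x', x)$ vanishes at $x' = x$, so the update reduces to plain gradient ascent on the log-density:

$$
\dot{x} = k(x, x)\,\nabla_x \log{p(x)} + \nabla_{x'} k(x', x)\Big|_{x' = x} = \nabla_x \log{p(x)}
$$

The stationary points of this flow are the critical points of $\log p$, and which mode the particle reaches depends on where it starts.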

In [33]:
X_init = (5 * torch.randn(1, *mog6.event_shape)).to(device)
X = X_init.clone()
svgd = SVGD(mog6, K, optim.Adam([X], lr=1e-1))

for _ in range(1000):
    svgd.step(X)

In [34]:
(mog6_chart + get_particles_chart(X_init.cpu().numpy())) | (mog6_chart + get_particles_chart(X.cpu().numpy()))