# Stein Variational Gradient Descent
A step-by-step walk through a general SVGD implementation in `tensorflow 2`

In [6]:
import tensorflow as tf
import tensorflow_probability as tfp

For the actual algorithm we need an update loop, a reasonable optimizer (e.g. Adam) and a couple of gradients. We also need a kernel, but we will 
take care of that afterwards. Let us assume the algorithm is already initialized with a kernel function that we can call with the current samples and that 
returns a differentiable kernel matrix. In each iteration, the main method `update` first computes `kernelMatrix`, `kernelGrad` and `logprobGradient` and then 
simply computes the SVGD gradient $\frac{1}{N} (K \nabla \log p(x) - \nabla K)$. Check out [this discussion](https://github.com/activatedgeek/svgd/issues/1)
to see why the vectorized kernel gradient computation needs a **minus sign**.

Also note that SVGD is defined for perturbations which are **added** to the particles, not subtracted as is customary for gradient descent. We therefore counter the optimizer's standard behaviour with another **minus sign** in front of the gradient.
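
To make the two minus signs concrete, here is a minimal NumPy sketch of a single SVGD direction, independent of TensorFlow. It assumes an RBF kernel with a fixed bandwidth `h` (no median heuristic) and a standard normal target whose score is $-x$; the helper name `svgd_direction` is hypothetical and not part of the implementation in this notebook.

```python
import numpy as np

def svgd_direction(x, score, h=1.0):
    """(K @ score - gradK) / N for an RBF kernel with fixed bandwidth h (assumption)."""
    diff = x[:, None, :] - x[None, :, :]           # diff[i, j] = x_i - x_j, shape (N, N, k)
    K = np.exp(-0.5 * np.sum(diff**2, axis=-1) / h**2)
    # gradient of sum_j k(x_i, x_j) w.r.t. x_i with x_j held fixed
    gradK = -np.einsum('ij,ijk->ik', K, diff) / h**2
    return (K @ score - gradK) / x.shape[0]

x = np.array([[-0.1], [0.1]])   # two nearby particles in 1D
score = -x                      # gradient of log N(0, 1)
phi = svgd_direction(x, score)

# A descent-style optimizer computes x - lr * g, so passing g = -phi adds phi.
lr = 0.5
g = -phi
x_new = x - lr * g
print(phi.ravel(), x_new.ravel())
```

With the particles this close together, the kernel-gradient term dominates and pushes them apart, which is the repulsive behaviour the first minus sign is there to preserve.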

In [None]:
class SVGD:
    def __init__(self, kernel, targetDistribution, learningRate):
        self.kernel = kernel
        self.targetDistribution = targetDistribution
        self.optimizer = tf.keras.optimizers.Adam(learningRate)

    def update(self, x, nIterations):
        # x must be a tf.Variable of shape (N, k); apply_gradients updates it in place
        for _ in range(nIterations):
            kernelMatrix, kernelGrad = self.computeKernel(x)
            logprobGradient = self.logprobGradient(x)

            # minus to cancel out the negative descent of Adam
            completeGrad = -(kernelMatrix @ logprobGradient + kernelGrad) / x.shape[0]
            self.optimizer.apply_gradients([(completeGrad, x)])

        return x

    def computeKernel(self, x):
        with tf.GradientTape() as tape:
            kernelMatrix = self.kernel(x)

        # why minus? see https://github.com/activatedgeek/svgd/issues/1 (right at the end)
        return kernelMatrix, -tape.gradient(kernelMatrix, [x])[0]

    def logprobGradient(self, x):
        with tf.GradientTape() as tape:
            logprob = tf.math.log(self.targetDistribution(x))
        return tape.gradient(logprob, [x])[0]
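
As a sanity check of the update loop, the same iteration can be sketched in plain NumPy. This is only an illustration under stated assumptions: a fixed bandwidth `h = 1`, a plain additive step instead of Adam, a 1D Gaussian target with mean 2 and unit variance, and hypothetical names (`svgd_direction`, `score_fn`, `particles`) that do not appear in the class above.

```python
import numpy as np

def svgd_direction(x, score_fn, h=1.0):
    """SVGD direction for an RBF kernel with fixed bandwidth h (assumption)."""
    diff = x[:, None, :] - x[None, :, :]           # pairwise differences x_i - x_j
    K = np.exp(-0.5 * np.sum(diff**2, axis=-1) / h**2)
    gradK = -np.einsum('ij,ijk->ik', K, diff) / h**2
    return (K @ score_fn(x) - gradK) / x.shape[0]

target_mean = 2.0

def score_fn(x):
    # gradient of log N(target_mean, 1) with respect to x
    return -(x - target_mean)

rng = np.random.default_rng(0)
particles = rng.normal(loc=-3.0, scale=0.5, size=(50, 1))  # start far from the target

for _ in range(500):
    # perturbations are added to the particles, no optimizer sign flip needed here
    particles = particles + 0.1 * svgd_direction(particles, score_fn)

print(particles.mean(), particles.std())
```

The particles should drift from around $-3$ towards the target mean while the repulsive term keeps them from collapsing onto a single point.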

Now for the kernel: we implement the RBF kernel, but any other proper kernel can be used as well. We start by defining the pairwise squared Euclidean distance, generalized to a matrix of particles (where particles can be in any $\mathbb{R}^k$), which is the matrix version of 

$\left\lVert x - x' \right\rVert^2_2$.

We stop the gradient of $x'$ since we want to take the gradient only w.r.t $x$:

In [5]:
@tf.function
def euclideanPairwiseDistance(x):
    distance = tf.expand_dims(x, 1) - tf.expand_dims(tf.stop_gradient(x), 0)
    return tf.einsum('ijk,kji->ij', distance, tf.transpose(distance))

We use `tf.einsum` ([Einstein summation](https://www.tensorflow.org/api_docs/python/tf/einsum)) to do matrix computations in three dimensions.
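
The index pattern `'ijk,kji->ij'` can look cryptic, so here is a small NumPy sketch of the same contraction (`np.einsum` and `np.transpose` are used as stand-ins for their TensorFlow counterparts; the three example points are made up). Reversing all axes of the difference tensor and contracting over `k` multiplies every entry with itself, which gives the squared distances.

```python
import numpy as np

x = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]])   # 3 particles in R^2
diff = x[:, None, :] - x[None, :, :]                  # shape (3, 3, 2)

# same contraction as in euclideanPairwiseDistance, with the axes of diff reversed
sq_einsum = np.einsum('ijk,kji->ij', diff, np.transpose(diff))

# two more direct ways to write the same thing
sq_same_index = np.einsum('ijk,ijk->ij', diff, diff)
sq_direct = np.sum(diff**2, axis=-1)

print(sq_einsum)
```

All three variants agree; for example, the entry for the first two points is $3^2 + 4^2 = 25$.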

Now we define the RBF kernel $\exp\left(-\dfrac{1}{2h^2} \left\lVert x - x' \right\rVert^2_2\right)$ and estimate the bandwidth as $h = \text{med}^2 / \log(N + 1)$, where $\text{med}$ is the median pairwise distance between the $N$ particles, see [the SVGD paper, Section 5](https://arxiv.org/pdf/1608.04471.pdf). This lets the bandwidth adapt dynamically to the current particles.

We also stop the gradient through the bandwidth, so that $h$ is treated as a constant when the kernel is differentiated. Since `tensorflow` has no specific function to 
calculate the median, we use the percentile-based one-liner from [this answer on SO](https://stackoverflow.com/a/47657076).

In [8]:
class RbfKernel:
    @tf.function
    def __call__(self, x):
        normedDist = euclideanPairwiseDistance(x)
        bandwidth = tf.stop_gradient(self.computeBandWidth(normedDist))
        return tf.exp(-0.5 * normedDist / bandwidth**2)

    @tf.function
    def computeBandWidth(self, euclideanPwDistances):
        pwDistanceMedian = tfp.stats.percentile(
            euclideanPwDistances, 50.0, interpolation='midpoint')

        # number of particles N, cast to the dtype of the distances
        n = tf.cast(tf.shape(euclideanPwDistances)[0], euclideanPwDistances.dtype)
        return pwDistanceMedian / tf.math.log(n + 1)
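
The bandwidth heuristic can be checked by hand on a tiny example. The NumPy sketch below mirrors the steps of `RbfKernel` under the assumption that `np.median` behaves like the 50th percentile with midpoint interpolation; the three example points and the variable names are made up for illustration.

```python
import numpy as np

x = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]])   # 3 particles in R^2
diff = x[:, None, :] - x[None, :, :]
sq_dist = np.sum(diff**2, axis=-1)                    # squared pairwise distances

n = x.shape[0]
med_sq = np.median(sq_dist)       # median over all N*N entries, zero diagonal included
h = med_sq / np.log(n + 1)        # bandwidth h = med^2 / log(N + 1)

K = np.exp(-0.5 * sq_dist / h**2)  # RBF kernel matrix
print(med_sq, h)
print(K)
```

Here the nine squared distances are 0, 0, 0, 2, 2, 13, 13, 25, 25, so the median is 2 and $h = 2 / \log 4$. The resulting kernel matrix is symmetric with ones on the diagonal.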

You can also find this implementation in the repo; we will import it from there for the demo notebooks.