In [None]:
using MLDatasets
using PyPlot
using Random, Statistics

# Rapid intro to unsupervised learning with Restricted Boltzmann Machines

This notebook gives a rapid introduction to unsupervised learning with Restricted Boltzmann Machines (RBMs). We will again use the MNIST handwritten digits as example data, but now the goal is to teach the computer to write digits that look like those in the training data set. This is a very basic example of an application of generative modelling; click [here](https://thispersondoesnotexist.com/) to see some impressive results generated using a more sophisticated approach in the same spirit.

In contrast to supervised learning tasks, there are no labels associated with the data in unsupervised learning. Instead of predicting labels, the goal of unsupervised learning is to model the distribution $p_X(X)$ of data $X$, for example in order to identify characteristic features of the distribution.

In the following I will refer to the review article ["A high-bias, low-variance introduction to Machine Learning for physicists"](https://arxiv.org/abs/1803.08823) for the technical details of training an RBM.

## Restricted Boltzmann Machine

The RBM is an energy-based generative model defined by an energy function

$$
    E_{\theta}(\vec v, \vec h) = -\sum_{i=1}^{N_v}a_iv_i -\sum_{\mu=1}^{N_h}b_\mu h_\mu 
    -\sum_{i=1}^{N_v}\sum_{\mu=1}^{N_h}W_{i\mu}v_ih_\mu
$$

of $N_v$ *visible units* $\vec v$ and $N_h$ *hidden units* $\vec h$, which take binary values $v_i,h_\mu\in\{0,1\}$. The bias vectors $\vec a$ and $\vec b$ together with the weight matrix $W$ make up the variational parameters $\theta=(\vec a, \vec b, W)$.
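As a quick sanity check of the notation, the sketch below evaluates the energy both in matrix-vector form, $E_\theta=-\vec a\cdot\vec v-\vec b\cdot\vec h-\vec v^{T}W\vec h$, and as the explicit sums of the formula above, for a randomly chosen toy RBM. The function names (`rbm_energy`, `rbm_energy_loops`) and the toy sizes are illustrative choices, not part of the original notebook.

```julia
using LinearAlgebra, Random

# Energy in matrix-vector form: E = -a⋅v - b⋅h - v⋅(W h)
rbm_energy(v, h, W, a, b) = -dot(a, v) - dot(b, h) - dot(v, W * h)

# The same energy written as the explicit sums of the formula above
function rbm_energy_loops(v, h, W, a, b)
    E = 0.0
    for i in eachindex(v)
        E -= a[i] * v[i]
    end
    for mu in eachindex(h)
        E -= b[mu] * h[mu]
    end
    for i in eachindex(v), mu in eachindex(h)
        E -= W[i, mu] * v[i] * h[mu]
    end
    return E
end

Random.seed!(0)
W_toy, a_toy, b_toy = randn(6, 3), randn(6), randn(3)   # toy RBM with N_v = 6, N_h = 3
v_toy, h_toy = rand(0:1, 6), rand(0:1, 3)                # one binary configuration
println(rbm_energy(v_toy, h_toy, W_toy, a_toy, b_toy) ≈ rbm_energy_loops(v_toy, h_toy, W_toy, a_toy, b_toy))  # true
```

The `_toy` suffix keeps these variables from clashing with the `W`, `a`, `b` used for training later in the notebook.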

The corresponding joint distribution of visible and hidden units is defined as $p_{\theta}(\vec v, \vec h)=e^{-E_{\theta}(\vec v, \vec h)}/Z_\theta$, where the partition function $Z_\theta=\sum_{\vec v,\vec h}e^{-E_{\theta}(\vec v, \vec h)}$ normalizes the distribution (this Boltzmann form is the "Boltzmann" in the name RBM). In this setup, the visible units $\vec v$ correspond to the (high-dimensional) data, while the hidden units $\vec h$ are auxiliary degrees of freedom that mediate correlations between different components of $\vec v$. Hence, to model the data distribution we want to find the marginal distribution

$$p_\theta(\vec v) = \sum_{\vec h\in\{0,1\}^{N_h}}p_{\theta}(\vec v, \vec h)$$

that matches the given training data best.
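For a tiny RBM the sum over all $2^{N_h}$ hidden configurations can be carried out explicitly. The sketch below, assuming toy sizes $N_v=4$, $N_h=3$ and hypothetical helper names (`all_states`, `toy_energy`, `toy_marginal`, `unnorm_marginal`), computes the normalization $Z_\theta=\sum_{\vec v,\vec h}e^{-E_\theta(\vec v,\vec h)}$ and the marginal by brute-force enumeration. It then compares the result with the closed form $\sum_{\vec h}e^{-E_\theta(\vec v,\vec h)}=e^{\vec a\cdot\vec v}\prod_\mu\big(1+e^{b_\mu+\sum_i W_{i\mu}v_i}\big)$, which holds because the energy is linear in each $h_\mu$.

```julia
using LinearAlgebra, Random

# All 2^n binary vectors of length n (feasible only for small n)
all_states(n) = [digits(k, base=2, pad=n) for k in 0:2^n-1]

# RBM energy E(v, h) = -a⋅v - b⋅h - v⋅(W h)
toy_energy(v, h, W, a, b) = -dot(a, v) - dot(b, h) - dot(v, W * h)

# Brute-force partition function Z and marginal p(v) = Σ_h exp(-E(v, h)) / Z
function toy_marginal(W, a, b)
    Nv, Nh = size(W)
    vs, hs = all_states(Nv), all_states(Nh)
    weights = [exp(-toy_energy(v, h, W, a, b)) for v in vs, h in hs]  # rows: v, columns: h
    Z = sum(weights)
    return vs, vec(sum(weights, dims=2)) ./ Z, Z
end

# Closed-form sum over h: exp(a⋅v) * Π_μ (1 + exp(b_μ + Σ_i W_iμ v_i))
unnorm_marginal(v, W, a, b) = exp(dot(a, v)) * prod(1 .+ exp.(b .+ W' * v))

Random.seed!(0)
W_toy, a_toy, b_toy = randn(4, 3), randn(4), randn(3)
vs, p_v, Z = toy_marginal(W_toy, a_toy, b_toy)
println(sum(p_v) ≈ 1)  # true
println(all(unnorm_marginal(v, W_toy, a_toy, b_toy) / Z ≈ p for (v, p) in zip(vs, p_v)))  # true
```

Brute-force enumeration scales as $2^{N_v+N_h}$, so it is only usable for such toy checks, never for $28\times 28$ images.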

A suitable cost function that we can aim to **maximize** for this purpose is the (average) **log-likelihood**

$$\mathcal L(\theta)=\frac{1}{|\mathcal T_X|}\sum_{\vec x\in\mathcal T_X}\log\big(p_\theta(\vec x)\big)$$

where $\mathcal T_X$ denotes the training data set.
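The log-likelihood can only be evaluated exactly when the normalization $Z_\theta=\sum_{\vec v,\vec h}e^{-E_\theta(\vec v,\vec h)}$ is tractable. The following minimal sketch, for toy sizes only, uses the free energy $F_\theta(\vec v)=-\log\sum_{\vec h}e^{-E_\theta(\vec v,\vec h)}$, so that $\log p_\theta(\vec v)=-F_\theta(\vec v)-\log Z_\theta$, and computes $\log Z_\theta$ by enumerating all visible states. The helper names are hypothetical. With all parameters set to zero every configuration is equally likely, so $\mathcal L=-N_v\log 2$, which gives a simple check.

```julia
using LinearAlgebra

# Free energy F(v) = -log Σ_h exp(-E(v, h)), using the closed-form sum over h
free_energy(v, W, a, b) = -dot(a, v) - sum(log1p.(exp.(b .+ W' * v)))

# log Z by brute force over all 2^Nv visible states (only feasible for tiny Nv)
function log_partition(W, a, b)
    Nv = size(W, 1)
    negF = [-free_energy(digits(k, base=2, pad=Nv), W, a, b) for k in 0:2^Nv-1]
    m = maximum(negF)                       # log-sum-exp trick for numerical stability
    return m + log(sum(exp.(negF .- m)))
end

# Average log-likelihood (1/|T|) Σ_x log p(x); one training sample per column
function log_likelihood(data, W, a, b)
    logZ = log_partition(W, a, b)
    return sum(-free_energy(x, W, a, b) - logZ for x in eachcol(data)) / size(data, 2)
end

data_toy = rand(0:1, 4, 10)   # ten random binary "images" with N_v = 4 pixels
# With all parameters zero, p(v) = 2^(-N_v) for every v, so L = -4 log 2
println(log_likelihood(data_toy, zeros(4, 3), zeros(4), zeros(3)) ≈ -4 * log(2))  # true
```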

For details see chapter XV in ["A high-bias, low-variance introduction to Machine Learning for physicists"](https://arxiv.org/abs/1803.08823).

## Training data: MNIST

Like in the supervised learning example, we start by loading the training data.

In [None]:
# load full training set
trainDataAll, trainLabels = MNIST.traindata()
trainDataAll = convert.(Int32, floor.(255*trainDataAll));

Training the RBM directly on 8-bit grayscale images of all ten digits would be too ambitious for our purposes. Therefore, we convert the images from grayscale to binary black-and-white images and group them by their labels:

In [None]:
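# Binarize: pixel values >= 128 become 1, all others become 0
trainDataAll = div.(trainDataAll, 128)

# Group the images by digit label; keys are the digits as strings ("0", ..., "9")
trainData = Dict()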

for n in 0:9
    trainData[string(n)] = trainDataAll[:,:,findall(trainLabels.==n)]
end

The `plot_images` function below plots `rows`x`cols` randomly selected examples from a "stack" of images, i.e., a 3-dimensional array whose first two dimensions are the image dimensions and whose last dimension indexes the individual images.

Let's look at some example digits:

In [None]:
function plot_images(data; rows=4, cols=4, figsize=(7,7))
    # For a stack of images `data` (3d-array), plot `rows`x`cols` randomly selected examples
    
    fig, axs = subplots(rows,cols,figsize=figsize)
    
    for i in 1:rows*cols
        idx = rand(1:size(data)[3])
        axs[i].imshow(transpose(data[:,:,idx]))
        axs[i].set_xticks([])
        axs[i].set_yticks([])
    end
    tight_layout() 
    show()
end
            
# Plot some example digits
plot_images(trainData["8"])

## Gibbs sampling

The RBM distribution $p_{\theta}(\vec v, \vec h)$ has the useful property that the conditional distributions of hidden or visible units factorize as

$$
p(\vec v|\vec h)=\prod_ip(v_i|\vec h)\\
p(\vec h|\vec v)=\prod_\mu p(h_\mu|\vec v)
$$

with

$$
p(v_i=1|\vec h)=\sigma(a_i+\sum_\mu W_{i\mu}h_\mu)\\
p(h_\mu=1|\vec v)=\sigma(b_\mu+\sum_i W_{i\mu}v_i)
$$
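The factorization can be checked numerically for a toy RBM. The sketch below (toy sizes and helper names `sigm`, `toy_energy` are illustrative assumptions) computes the exact conditional $p(\vec h|\vec v)$ by normalizing $e^{-E_\theta(\vec v,\vec h)}$ over all hidden configurations for a fixed $\vec v$, and compares it with the product of the single-unit probabilities given by the sigmoid formula.

```julia
using LinearAlgebra, Random

sigm(x) = 1 / (1 + exp(-x))   # logistic function
toy_energy(v, h, W, a, b) = -dot(a, v) - dot(b, h) - dot(v, W * h)

Random.seed!(1)
Nv, Nh = 3, 2                                       # toy sizes
W_toy, a_toy, b_toy = randn(Nv, Nh), randn(Nv), randn(Nh)
v_toy = [1, 0, 1]                                   # fixed visible configuration
hs = [digits(k, base=2, pad=Nh) for k in 0:2^Nh-1]  # all hidden configurations

# Exact conditional p(h|v) = exp(-E(v,h)) / Σ_h' exp(-E(v,h'))
w = [exp(-toy_energy(v_toy, h, W_toy, a_toy, b_toy)) for h in hs]
p_exact = w ./ sum(w)

# Factorized form Π_μ p(h_μ|v) with p(h_μ=1|v) = σ(b_μ + Σ_i W_iμ v_i)
q = sigm.(b_toy .+ W_toy' * v_toy)
p_factorized = [prod(h[mu] == 1 ? q[mu] : 1 - q[mu] for mu in 1:Nh) for h in hs]

println(p_exact ≈ p_factorized)  # true
```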

Here $\sigma(x)=1/(1+e^{-x})$ denotes the logistic (sigmoid) function. This factorization enables a Markov chain Monte Carlo scheme called **Gibbs sampling**, in which realizations of $\vec h$ and $\vec v$ are sampled alternately and *directly* from the conditional distributions above, see Fig. 62 in ["A high-bias, low-variance introduction to Machine Learning for physicists"](https://arxiv.org/abs/1803.08823).

**Let's implement this:**

*Hint:* We can draw random Bernoulli outcomes according to a vector of probabilities `p` with `[rand() < q for q in p]`; each entry is `true` with probability `q`.

In [None]:
function sigmoid(x)
    # Logistic function for a scalar `x`; apply element-wise as `sigmoid.(x)`
    return 1 / (1 + exp(-x))
end

function p_h_given_v(v, W, b)
    # Compute the vector p(h_mu | v)
    
    return sigmoid.(b + transpose(W) * v)
end


function p_v_given_h(h, W, a)
    # Compute the vector p(v_i | h)
    
    return sigmoid.(a + W * h)
end


function gibbs_step(v, W, a, b)
    # This function performs one step of Gibbs sampling by sampling
    # a new hidden outcome followed by a new outcome of visible units
    #
    # Input arguments:  v - starting configuration of visible units
    #                   W, a, b - RBM weights, visible biases and hidden biases

    # sample a realization from p(h_mu | v)
    
    p = p_h_given_v(v, W, b)
    h = [rand() < q for q in p]
    
    # sample a realization from p(v_i | h)
    
    p = p_v_given_h(h, W, a)
    v_new = [rand() < q for q in p]
    
    return v_new
end

function gibbs_sample(v, W, a, b, n)
    # Starting from visible configurations `v` (one configuration per column),
    # perform `n` steps of Gibbs sampling on every column and return the result.
    # The input is copied, so `v` itself is left unchanged.
    sample = copy(v)
    
    for j in 1:n
        for k in 1:size(v, 2)
            sample[:,k] = gibbs_step(sample[:,k], W, a, b)
        end
    end
    
    return sample
end

## Gradients with Contrastive Divergence

Due to the particular form of the RBM, the gradients of our cost function (log-likelihood) have a simple form:

$$
\frac{\partial\mathcal L(W,\vec a, \vec b)}{\partial W_{i\mu}}
=
\langle v_ih_\mu\rangle_{\text{data}}-\langle v_ih_\mu\rangle_{\text{model}}
\\
\frac{\partial\mathcal L(W,\vec a, \vec b)}{\partial a_i}
=
\langle v_i\rangle_{\text{data}}-\langle v_i\rangle_{\text{model}}
\\
\frac{\partial\mathcal L(W,\vec a, \vec b)}{\partial b_\mu}
=
\langle h_\mu\rangle_{\text{data}}-\langle h_\mu\rangle_{\text{model}}
$$
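These gradient formulas can be verified with finite differences on a toy RBM. The sketch below (hypothetical helper names, toy sizes $N_v=4$, $N_h=2$) evaluates the log-likelihood exactly by enumeration, perturbs one weight $W_{21}$, and compares the numerical derivative with $\langle v_ih_\mu\rangle_{\text{data}}-\langle v_ih_\mu\rangle_{\text{model}}$. The model average is computed over the exactly normalized joint distribution; for the data term, $h_\mu$ is replaced by its conditional mean $p(h_\mu=1|\vec v)$ using the factorization from the Gibbs sampling section.

```julia
using LinearAlgebra, Random

all_states(n) = [digits(k, base=2, pad=n) for k in 0:2^n-1]
sigm(x) = 1 / (1 + exp(-x))
free_energy(v, W, a, b) = -dot(a, v) - sum(log1p.(exp.(b .+ W' * v)))

# Average log-likelihood of the columns of `data`, with brute-force log Z
function avg_loglik(data, W, a, b)
    Fs = [free_energy(v, W, a, b) for v in all_states(size(W, 1))]
    logZ = log(sum(exp.(-Fs)))
    return sum(-free_energy(x, W, a, b) - logZ for x in eachcol(data)) / size(data, 2)
end

# Gradient formula <v_i h_mu>_data - <v_i h_mu>_model, model term by full enumeration
function grad_W(data, W, a, b)
    Nv, Nh = size(W)
    vs, hs = all_states(Nv), all_states(Nh)
    joint = [exp(dot(a, v) + dot(b, h) + dot(v, W * h)) for v in vs, h in hs]
    joint ./= sum(joint)                                   # normalized p(v, h)
    model_term = sum(joint[k, l] * vs[k] * hs[l]' for k in eachindex(vs), l in eachindex(hs))
    data_term = sum(x * sigm.(b .+ W' * x)' for x in eachcol(data)) / size(data, 2)
    return data_term - model_term
end

Random.seed!(2)
Nv, Nh = 4, 2
W_toy, a_toy, b_toy = 0.5 .* randn(Nv, Nh), 0.5 .* randn(Nv), 0.5 .* randn(Nh)
data_toy = rand(0:1, Nv, 20)

G = grad_W(data_toy, W_toy, a_toy, b_toy)
delta = 1e-6
Wp = copy(W_toy); Wp[2, 1] += delta
Wm = copy(W_toy); Wm[2, 1] -= delta
fd = (avg_loglik(data_toy, Wp, a_toy, b_toy) - avg_loglik(data_toy, Wm, a_toy, b_toy)) / (2 * delta)
println(isapprox(fd, G[2, 1]; atol=1e-6))  # true
```

The positive sign of the gradient is why the training loop further down *adds* `learningRate` times the gradient to the parameters (gradient ascent on the log-likelihood).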

Here, $\langle \cdot\rangle_{\text{data}}$ denotes a mean over the training data and $\langle \cdot\rangle_{\text{model}}$ denotes the mean over a sample drawn from our RBM distribution $p_{\theta}(\vec v)$. Since $p(\vec v, \vec h)=p(\vec v)p(\vec h|\vec v)=p(\vec v)\prod_{\mu}p(h_\mu|\vec v)$ and $h_\mu\in\{0,1\}$, 

$$
\langle v_ih_\mu\rangle
=\sum_{\vec v,\vec h}p(\vec v, \vec h) v_ih_\mu
=\sum_{\vec v}p(\vec v)\sum_{\vec h}p(\vec h|\vec v) v_ih_\mu
=\sum_{\vec v}p(\vec v)\sum_{h_\mu\in\{0,1\}}p(h_\mu|\vec v) v_ih_\mu
=\sum_{\vec v}p(\vec v)p(h_\mu=1|\vec v) v_i
$$
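The identity above can also be confirmed by brute force. In this sketch (toy sizes and names are illustrative assumptions), the left-hand side sums $v_ih_\mu$ over the full joint distribution, while the right-hand side only sums over visible states and uses $p(h_\mu=1|\vec v)=\sigma(b_\mu+\sum_i W_{i\mu}v_i)$.

```julia
using LinearAlgebra, Random

all_states(n) = [digits(k, base=2, pad=n) for k in 0:2^n-1]
sigm(x) = 1 / (1 + exp(-x))

Random.seed!(3)
Nv, Nh = 3, 2
W_toy, a_toy, b_toy = randn(Nv, Nh), randn(Nv), randn(Nh)
vs, hs = all_states(Nv), all_states(Nh)

# Normalized joint p(v, h) ∝ exp(-E(v, h)) by full enumeration (rows: v, columns: h)
joint = [exp(dot(a_toy, v) + dot(b_toy, h) + dot(v, W_toy * h)) for v in vs, h in hs]
joint ./= sum(joint)

# Left-hand side: Σ_{v,h} p(v,h) v_i h_μ, for all (i, μ) at once
lhs = sum(joint[k, l] * vs[k] * hs[l]' for k in eachindex(vs), l in eachindex(hs))

# Right-hand side: Σ_v p(v) p(h_μ=1|v) v_i, with p(v) = Σ_h p(v,h)
p_v = vec(sum(joint, dims=2))
rhs = sum(p_v[k] * vs[k] * sigm.(b_toy .+ W_toy' * vs[k])' for k in eachindex(vs))

println(lhs ≈ rhs)  # true
```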

and the empirical means can be rewritten, e.g. as

$$
\langle v_ih_\mu\rangle_{\mathcal S}
=\frac{1}{|\mathcal S|}\sum_{\vec v\in\mathcal S} v_ip_\theta(h_\mu=1|\vec v)
$$

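where $\mathcal S$ denotes either the training data set or a sample drawn from $p_{\theta}(\vec v)$. Drawing exact samples from $p_{\theta}(\vec v)$ is expensive, so the model term is approximated with short Gibbs chains: in **contrastive divergence** (CD-$n$) the chains start at the current training batch and run for only $n$ Gibbs steps, while in **persistent contrastive divergence** (PCD) the chains are not reset but continue from where they ended after the previous parameter update.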
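The empirical mean $\langle v_ih_\mu\rangle_{\mathcal S}$ can be computed for all pairs $(i,\mu)$ with a single matrix product. The sketch below stores the samples of $\mathcal S$ as columns of a matrix `S`, forms $P_{\mu k}=p_\theta(h_\mu=1|\vec v_k)$, and returns $S P^{T}/|\mathcal S|$. The function names and toy sizes are assumptions for illustration; `vh_mean_loop` mirrors the per-sample loop used for `W_grad` in `pcd_gradients`.

```julia
using LinearAlgebra, Random

sigm(x) = 1 / (1 + exp(-x))

# Empirical <v_i h_mu>_S for all (i, mu) at once; each column of S is one visible vector
function vh_mean(S, W, b)
    P = sigm.(b .+ W' * S)          # P[mu, k] = p(h_mu = 1 | v_k)
    return S * P' ./ size(S, 2)
end

# Same quantity with an explicit loop over samples
function vh_mean_loop(S, W, b)
    acc = zeros(size(W))
    for k in 1:size(S, 2)
        acc .+= S[:, k] .* transpose(sigm.(b .+ transpose(W) * S[:, k]))
    end
    return acc ./ size(S, 2)
end

Random.seed!(4)
Nv, Nh, nsamples = 784, 16, 32
W_toy, b_toy = 0.01 .* randn(Nv, Nh), zeros(Nh)
S = rand(0:1, Nv, nsamples)
println(vh_mean(S, W_toy, b_toy) ≈ vh_mean_loop(S, W_toy, b_toy))  # true
```

Replacing the loops in `pcd_gradients` with matrix products of this kind should give the same gradients with less overhead; this is a suggestion, not part of the original notebook.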

**Let's implement this:**

In [None]:
function pcd_gradients(batch, W, a, b; n=1, modelSample=nothing)
    
    batchSize = size(batch, 3)
    imgSize = size(batch, 1) * size(batch, 2)
    
    # Without persistent chains (CD-n), start the Gibbs chains at the data
    if modelSample === nothing
        modelSample = copy(batch)
    end
    
    # Sample outcomes from the RBM
    modelSample = gibbs_sample(reshape(modelSample, (imgSize, batchSize)), W, a, b, n)
    
    # Flatten the input batch
    flatBatch = reshape(batch, (imgSize, batchSize))
    
    # Compute W-gradients
    W_grad = zeros(size(W))
    for j in 1:batchSize
        W_grad .+= flatBatch[:,j] .* transpose(p_h_given_v(flatBatch[:,j], W, b))
        W_grad .-= modelSample[:,j] .* transpose(p_h_given_v(modelSample[:,j], W, b))
    end
    W_grad ./= batchSize
    
    # Compute a-gradients
    a_grad = mean(flatBatch, dims=2) - mean(modelSample, dims=2)
    a_grad = reshape(a_grad, :)
    
    # Compute b-gradients
    b_grad = zeros(size(b))
    for j in 1:batchSize
        b_grad .+= p_h_given_v(flatBatch[:,j], W, b) - p_h_given_v(modelSample[:,j], W, b)
    end
    b_grad ./= batchSize
    
    return W_grad, a_grad, b_grad, reshape(modelSample, size(batch))
end

## Training loop

The function below implements the training loop.

Input parameters are

- `W`: weight matrix
- `a`: visible bias
- `b`: hidden bias
- `trainData`: Training data. 3-dimensional array, where the first two dimensions are image dimensions and the last dimension indexes the images.
- `learningRate`: learning rate
- `numEpochs`: number of epochs for training
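- `batchSize`: number of images per mini-batch
- `cg_n`: number of Gibbs sampling steps used to generate the model sample for each gradient estimate (the $n$ in CD-$n$)
- `persistent`: boolean indicating whether to use persistent contrastive divergence (keep the Gibbs chains across mini-batches) or plain contrastive divergence (restart the chains from each new batch)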
- `seed`: seed for random number generator

The function returns the RBM parameters obtained at the end of training.

In [None]:
function train(W, a, b, trainData; learningRate=0.01, numEpochs=10, batchSize=128, cg_n=2, persistent=false, seed=1234)

    Random.seed!(seed)
    
    modelSample = nothing
    
    batchNumber = div(size(trainData)[3], batchSize)
    
    # Training loop over epochs
    for n in 1:numEpochs

        println("Epoch $n")

        # Generate randomly shuffled batches
        order = shuffle(1:size(trainData)[3])
        # (images that do not fill a complete batch are dropped)
        batches = reshape(trainData[:,:,order[1:batchNumber*batchSize]], size(trainData, 1), size(trainData, 2), batchSize, :)

        for k in 1:batchNumber
            batch = batches[:,:,:,k]

            if !persistent
                modelSample = nothing
            end
            
            # Compute gradients
            Wg, ag, bg, modelSample = pcd_gradients(batch, W, a, b, n=cg_n, modelSample=modelSample)

            # Update parameters with gradients
            
            W .+= learningRate .* Wg
            a .+= learningRate .* ag
            b .+= learningRate .* bg
        end
        
        plot_images(modelSample) # Show some example images generated by the RBM
    end
    return W, a, b
end

Finally, we are ready to train the RBM on the images of the digit 8:

In [None]:
numVisible=28*28
numHidden = 256

W = 0.01 * randn((numVisible, numHidden))
a = zeros(numVisible)
b = zeros(numHidden)

W, a, b = train(W, a, b, trainData["8"], numEpochs=20);

## Inspecting the features

Now we can inspect which features were learned in the weight matrix $W$ by reshaping individual columns of $W$ (the weights connecting one hidden unit to all visible units) to the image dimensions and plotting them:

In [None]:
plot_images(reshape(W,(28,28,:)),rows=10,cols=10,figsize=(10,10))

## Learning multiple digits

The RBM is not limited to generating a single digit. Let's train it on two digits at once, 8 and 4:

In [None]:
numVisible=28*28
numHidden = 256

W = 0.01 * randn((numVisible, numHidden))
a = zeros(numVisible)
b = zeros(numHidden)

examples = cat(trainData["8"],trainData["4"], dims=3)

W, a, b = train(W, a, b, examples, numEpochs=20);

In [None]:
plot_images(reshape(W,(28,28,:)),rows=10,cols=10,figsize=(10,10))