# Restricted Boltzmann machine

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import BernoulliRBM
from scipy.special import expit

In [None]:
# set random seed
# Seeds NumPy's global RNG, which the later `np.random.choice` and
# `np.random.uniform` calls in this notebook draw from.
# NOTE(review): the scikit-learn calls below pass no `random_state`, so they
# presumably fall back to this global state as well -- confirm in the sklearn docs.
np.random.seed(1234567)

## Load data

In [None]:
# decide between binary and [0, 1]-valued data
# True:  the scaled pixels are thresholded at 0.5 into {0, 1} when the data is loaded
# False: the pixels are left as scaled values in the interval [0, 1]
IS_BINARY = True

In [None]:
# fetch the digit images; the labels are not needed and are discarded
x, _ = load_digits(return_X_y=True)

# cast to float32 and divide by 16 so that all values lie in [0, 1]
x = x.astype(np.float32) / 16.0

if IS_BINARY:
    # threshold at 0.5 to obtain binary data in {0, 1}
    x = np.where(x < 0.5, 0, 1)

# report the shape and the value range of the prepared data
print(f'Data shape: {x.shape}')
if IS_BINARY:
    print(f'Values in: {set(np.unique(x))}')
else:
    print(f'Values in: [{x.min()}, {x.max()}]')

In [None]:
# hold out 20% of the samples as a test set
x_train, x_test = train_test_split(x, test_size=0.2)

# report the size of both splits
print(
    f'Number of train samples: {len(x_train)}',
    f'Number of test samples: {len(x_test)}',
    sep='\n'
)

In [None]:
# show a 5x5 grid of randomly chosen train digits
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(5, 5.5))

# one distinct train index per panel (kept for the hidden-unit plot further below)
random_train_ids = np.random.choice(len(x_train), size=axes.size, replace=False)

for ax, sample_id in zip(axes.ravel(), random_train_ids):
    # each sample is a flat vector of 64 pixels -> 8x8 image
    ax.imshow(x_train[sample_id].reshape(8, 8), cmap='gray', vmin=0, vmax=1)
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')

fig.suptitle('Train samples')
fig.tight_layout()

## Train RBM

In [None]:
# set up a Bernoulli RBM with 16 hidden units and fit it on the train set
rbm = BernoulliRBM(n_components=16, learning_rate=0.1, batch_size=10, n_iter=40, verbose=True)
rbm.fit(x_train)

# learned parameters:
#   components_        -> (num_hidden, num_visible)
#   intercept_hidden_  -> (num_hidden,)
#   intercept_visible_ -> (num_visible,)
print(
    f'Weight matrix shape: {rbm.components_.shape}',
    f'Bias shape (hidden): {rbm.intercept_hidden_.shape}',
    f'Bias shape (visible): {rbm.intercept_visible_.shape}',
    sep='\n'
)

In [None]:
# draw the weight vector of each of the 16 hidden units as an 8x8 image
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(4, 4.5))

for ax, weights in zip(axes.ravel(), rbm.components_):
    # no vmin/vmax here: every panel uses its own colour scale
    ax.imshow(weights.reshape(8, 8), cmap='gray')
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')

fig.suptitle('Learned components')
fig.tight_layout()

## Compute hidden representations

In [None]:
# hidden unit activation probabilities p(h=1|v) = E[h|v] for both splits
z_train, z_test = rbm.transform(x_train), rbm.transform(x_test)

print(
    f'Hidden shape (train): {z_train.shape}',
    f'Hidden shape (test): {z_test.shape}',
    sep='\n'
)

In [None]:
# sample hidden units h_i
def sample_hidden(rbm, v):
    '''Draw a binary sample of the hidden units from p(h|v).

    `rbm.transform(v)` supplies the activation probabilities p(h=1|v); each
    unit is switched on when a uniform random number from NumPy's global RNG
    falls below its probability. The result is a boolean array with the same
    shape as those probabilities.
    '''
    activation_prob = rbm.transform(v)
    noise = np.random.uniform(size=activation_prob.shape)
    return noise < activation_prob

# draw binary hidden samples (train first, then test, so the RNG order is fixed)
z_train_sample, z_test_sample = sample_hidden(rbm, x_train), sample_hidden(rbm, x_test)

print(
    f'Hidden shape (train): {z_train_sample.shape}',
    f'Hidden shape (test): {z_test_sample.shape}',
    sep='\n'
)

In [None]:
# show the sampled hidden vectors of 20 train samples, one strip per panel
fig, axes = plt.subplots(nrows=10, ncols=2, figsize=(6, 4))

# reuses the train ids drawn for the train-sample plot; zip stops after the 20 panels
for ax, sample_id in zip(axes.ravel(), random_train_ids):
    hidden_row = z_train_sample[sample_id].reshape(1, -1)  # 1 x num_hidden strip
    ax.imshow(hidden_row, cmap='gray', vmin=0, vmax=1)
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')

fig.suptitle('Hidden representations')
fig.tight_layout()

## Compute reconstructions

In [None]:
# compute visible unit activation probability p(v=1|h) = E[v|h]
def rbm_inverse_transform(rbm, h):
    '''Return p(v=1|h) for the given hidden units.

    The hidden values are projected through the weight matrix
    `rbm.components_`, the visible bias `rbm.intercept_visible_` is added and
    the logistic sigmoid is applied (see the docstring of
    `BernoulliRBM._sample_visibles`).
    '''
    pre_activation = np.dot(h, rbm.components_) + rbm.intercept_visible_
    return expit(pre_activation)

# reconstruct the visible units from the hidden activation probabilities
x_train_recon, x_test_recon = (
    rbm_inverse_transform(rbm, z_train),
    rbm_inverse_transform(rbm, z_test),
)

print(
    f'Reconstructed shape (train): {x_train_recon.shape}',
    f'Reconstructed shape (test): {x_test_recon.shape}',
    sep='\n'
)

In [None]:
# plot visible units
# Top row: original test samples v; bottom row: their reconstructions v_hat,
# i.e. p(v=1|h) computed from the hidden probabilities of the same samples.
fig, axes = plt.subplots(nrows=2, ncols=8, figsize=(9, 3.5))
random_test_ids = np.random.choice(len(x_test_recon), size=axes.size, replace=False)

# The ids are drawn from the test set, so both rows must index the *test*
# arrays (the previous version indexed `x_train` / `x_train_recon` with them).
# Only the first `ncols` ids are used: one per column.
for idx, (ax_orig, ax_recon) in enumerate(zip(axes[0], axes[1])):
    sample_id = random_test_ids[idx]

    ax_orig.imshow(x_test[sample_id].reshape(8, 8), cmap='gray', vmin=0, vmax=1)
    ax_orig.set_title(f'$v^{{({idx + 1})}}$')
    ax_orig.set(xticks=[], yticks=[], xlabel='', ylabel='')

    ax_recon.imshow(x_test_recon[sample_id].reshape(8, 8), cmap='gray', vmin=0, vmax=1)
    ax_recon.set_title(f'$\\hat{{v}}^{{({idx + 1})}}$')
    ax_recon.set(xticks=[], yticks=[], xlabel='', ylabel='')

fig.suptitle('Reconstructions')
fig.tight_layout()