In [None]:
import scipy.io
import numpy as np
from matplotlib import pyplot as plt
from rbm import RBM


In [None]:
# Load the Binary AlphaDigits MATLAB file (path is relative to the working directory).
mat = scipy.io.loadmat('binaryalphadigs.mat')
# 'dat' is the entry that read_alpha_digits later indexes along its first axis.
data = np.array(mat['dat'])

In [None]:
# lire_alpha_digits
def read_alpha_digits(data, digits):
    """Read the samples of the requested classes from the AlphaDigits dataset.

    Parameters
    ----------
    data : np.array
        Array whose first axis indexes the character class and whose second
        axis indexes the samples of that class. Each sample is an array, and
        all samples are expected to share the same shape.
    digits : int or sequence of int
        Index (or indexes) of the character classes to read.

    Returns
    -------
    np.array
        All samples of the requested classes stacked along a new first axis,
        class after class, in the order given by ``digits``.
    """
    # Accept a bare class index as shorthand for a one-element list.
    if np.ndim(digits) == 0:
        digits = [digits]

    # np.take already returns a new array, so `data` is never modified and a
    # defensive copy of the whole dataset is unnecessary.
    selected = np.take(data, digits, axis=0)

    # Merge the (class, sample) axes into one; any trailing axes are kept.
    flat = selected.reshape((-1,) + selected.shape[2:])

    # Stack the per-sample arrays into a single dense array.
    return np.array(list(flat))

# read_alpha_digits(data, [10, 11, 12])

In [None]:
def display_samples(X, n_rows = 10, n_cols = 10, fig_x=10, fig_y=10):
    """Display a grid of samples with matplotlib.

    If there are more samples than ``n_rows * n_cols``, a random subset of
    that size (drawn without replacement) is shown. Otherwise every sample is
    shown in order and the grid is shrunk to the number of rows needed.

    Parameters
    ----------
    X : np.array
        Samples to draw. ``X.shape[0]`` is the number of samples and each
        ``X[i]`` is handed to ``imshow``.
    n_rows, n_cols : int
        Maximum number of rows / columns of the grid.
    fig_x, fig_y : float
        Width and height passed to the figure's ``figsize``.
    """
    fig = plt.figure(figsize=(fig_x, fig_y))
    n = X.shape[0]  # number of available samples
    capacity = n_rows * n_cols

    if n <= capacity:
        # Everything fits: show all samples in their original order.
        samples_idx = np.arange(n)
        if n < capacity:
            # Ceiling division, so that an exact multiple of n_cols does not
            # reserve an extra, empty row (n // n_cols + 1 did).
            n_rows = -(-n // n_cols)
    else:
        # Too many samples: pick which ones to show at random.
        samples_idx = np.random.choice(n, size=capacity, replace=False)

    for k, idx in enumerate(samples_idx):
        ax = fig.add_subplot(n_rows, n_cols, k + 1)
        ax.imshow(X[idx], cmap='gray')

        # Hide ticks
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()


# Quick check of read_alpha_digits on classes 10, 11 and 12 (the display call is left disabled).
digits = read_alpha_digits(data, [10, 11, 12])
# display_samples(digits)

## Test bench

In [None]:
# Indexes of the AlphaDigits classes used for this test run.
input_digit_idx = [10,11,12]
digits = read_alpha_digits(data, input_digit_idx)

In [None]:
# display_samples(digits)
# Sanity check on the shape of the samples that were read.
print(digits.shape)

In [None]:
# Flatten every sample into one row vector, the layout passed to the RBM below.
# The row count comes from the array itself rather than a hard-coded
# "39 samples per class", so this keeps working if the selection changes.
digits_vecs = digits.reshape(digits.shape[0], -1)

print(digits_vecs.shape)
print(digits.shape)

In [None]:
# ---------
# Constants
# ---------
# Data sample dim
p = 20*16 # For Binary AlphaDigits
# Latent vector dim
q = 100

# Build an RBM from the two dimensions and train it on the flattened samples.
rbm = RBM(p,q)
rbm.train_RBM(digits_vecs, batch_size=4, n_epoch=100, verbose=True)

In [None]:
# Generate data from the trained RBM with random_init=True and show it as a grid.
# NOTE(review): the first argument is presumably the number of samples — confirm against RBM.generate_data.
out = rbm.generate_data(100, nb_iter_gibbs=100, random_init=True)
display_samples(out)

In [None]:
# Same generation as the previous cell but with random_init=False, for comparison.
# NOTE(review): what the non-random initialisation starts from is defined in RBM.generate_data — confirm there.
out = rbm.generate_data(100, nb_iter_gibbs=100, random_init=False)
display_samples(out)

# Analyse de l'erreur de reconstruction

In [None]:
import seaborn as sns

In [None]:
# Restrict the study to a single character class (index 10) and flatten each
# sample into a row vector. The sample count is taken from the array itself
# instead of a hard-coded 39.
digits = read_alpha_digits(data, [10])
digits_vecs = digits.reshape(digits.shape[0], -1)

### Dimensionality of hidden layer (q)

In [None]:
# Reconstruction error as a function of the hidden-layer size q.
p = 20 * 16
# Single source of truth for the sweep, shared by the loop and the plot's x-axis.
hidden_sizes = range(20, 400, 20)
errors = []
for q in hidden_sizes:
    rbm = RBM(p, q)
    rbm.train_RBM(digits_vecs, batch_size=4, n_epoch=100, verbose=False)

    # Mean squared difference between the inputs and their encode/decode round trip.
    reconstruction = rbm.decode(rbm.encode(digits_vecs))
    errors.append(np.mean(np.power(digits_vecs - reconstruction, 2)))

# Label the axes on the figure itself so it can be read without the code.
ax = sns.lineplot(x=hidden_sizes, y=errors)
ax.set(xlabel="number of hidden units", ylabel="reconstruction error");

### Batch size

In [None]:
# Reconstruction error as a function of the training batch size.
# `p` is reused from the hidden-layer sweep cell above.
q = 100
# Single source of truth for the sweep, shared by the loop and the plot's x-axis.
batch_sizes = range(1, 40, 1)
errors_batch_size = []
for batch_size in batch_sizes:
    rbm = RBM(p, q)
    rbm.train_RBM(digits_vecs, batch_size=batch_size, n_epoch=100, verbose=False)

    # Mean squared difference between the inputs and their encode/decode round trip.
    reconstruction = rbm.decode(rbm.encode(digits_vecs))
    errors_batch_size.append(np.mean(np.power(digits_vecs - reconstruction, 2)))

# Label the axes on the figure itself so it can be read without the code.
ax = sns.lineplot(x=batch_sizes, y=errors_batch_size)
ax.set(xlabel="batch size", ylabel="reconstruction error");


## Analyse en fonction du nombre de caractères à apprendre

In [None]:
# Reconstruction error as a function of how many character classes the RBM
# has to learn: classes are added one at a time, starting at index 10.
digits = []
p = 20*16
q = 100
errors = []
# NOTE(review): `outs` is never filled in this cell; kept in case other cells read it.
outs = []
for new_digit in range(10, len(data)):
    digits.append(new_digit)
    digits_vecs = read_alpha_digits(data, digits).reshape((-1, 20*16))

    # A fresh RBM is trained from scratch on the enlarged class set.
    rbm = RBM(p, q)
    rbm.train_RBM(digits_vecs, batch_size=4, n_epoch=100, verbose=False)

    # Mean squared difference between the inputs and their encode/decode round trip.
    reconstruction = rbm.decode(rbm.encode(digits_vecs))
    errors.append(np.mean(np.power(digits_vecs - reconstruction, 2)))
    #rbm.generate_data(4, nb_iter_gibbs=5)

# The x-axis is derived from the number of recorded errors, not a magic "- 9" offset.
ax = sns.lineplot(x=range(1, len(errors) + 1), y=errors)
ax.set(xlabel="number of character classes", ylabel="reconstruction error");

# Analyse avec MNIST

In [None]:
from sklearn import datasets

# Fetch the 'mnist_784' dataset through scikit-learn's OpenML loader.
mnist = datasets.fetch_openml('mnist_784')
# Make mnist binary
# NOTE(review): the two disabled lines below use an 8x8 shape and a threshold of 8,
# unlike the 28x28 reshape and 128 threshold in the following cells — likely leftover code.
# digits = (digits > 8) * 1
# display_samples(digits.reshape((-1, 8, 8)))

In [None]:
# fetch_openml may hand back either a pandas DataFrame or a NumPy array for
# `.data`, depending on its `as_frame` setting and the scikit-learn version.
# np.asarray accepts both, whereas `.to_numpy()` exists only on the DataFrame.
mnist_digits = np.asarray(mnist.data)
# Each row is reshaped to a 28x28 image for display.
display_samples(mnist_digits.reshape((-1, 28, 28)))

In [None]:
# Make digits binary: pixels strictly above 128 become 1, all others 0.
mnist_digits_bin = (mnist_digits > 128) * 1
# Show the binarized digits, each row reshaped to a 28x28 image.
display_samples(mnist_digits_bin.reshape((-1, 28, 28)))

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 30% of the binarized digits. A fixed random_state makes the split
# identical on every Restart & Run All, so results below are comparable.
X_train, X_test = train_test_split(mnist_digits_bin, test_size=0.3, random_state=0)

In [None]:
# Show the shapes of the train and test splits.
X_train.shape, X_test.shape

In [None]:
# Train an RBM on MNIST
# p: one input dimension per pixel of a 28x28 image.
p = 28*28
# q: latent vector dimension.
q = 400
rbm = RBM(p, q)
rbm.train_RBM(X_train, batch_size=256, n_epoch=100, verbose=True)
# NOTE(review): unlike the AlphaDigits cells, the result is not passed to display_samples;
# presumably height/width tell generate_data how to shape its output — confirm in rbm.py.
rbm.generate_data(10, nb_iter_gibbs=100, height=28, width=28)