# Implementation of a Supervised Restricted Boltzmann Machine (SRBM)

## Libraries

In [None]:
using LinearAlgebra
using Random
using DelimitedFiles
using Printf
using Plots

## Definition of structs

We define two structs that will help keep the code more organized and easier to maintain. By using a struct for the architecture of the RBM as well as one for the hyperparameters, we can easily pass them as arguments to functions.
Additionally, we could define default values for some of the fields, which would make it easier to modify and experiment with different hyperparameters.

In [None]:
"""
    SRBM{T<:AbstractFloat}

Parameters of a supervised RBM whose visible layer is the concatenation of
`num_visible` data units and `num_labels` label units.

Only the three sizes are required; the remaining fields have defaults:
`W` is drawn from a zero-mean Gaussian with standard deviation 0.01, and both
biases start at zero. The hidden bias `c` holds one column per label.
"""
Base.@kwdef mutable struct SRBM{T<:AbstractFloat}
    # Layer sizes
    num_visible::Int
    num_hidden::Int
    num_labels::Int
    # Weights between the (data + label) visible units and the hidden units
    W::Matrix{T} = 0.01 .* randn(num_visible + num_labels, num_hidden)
    # Bias of the (data + label) visible units
    b::Vector{T} = zeros(num_visible + num_labels)
    # Label-dependent hidden bias: column `k` is used for samples with label `k`
    c::Matrix{T} = zeros(num_hidden, num_labels)
end;

"""
    hyperparameters{T<:AbstractFloat}

Training settings, constructed with keyword arguments; every field has a default.
"""
Base.@kwdef mutable struct hyperparameters{T<:AbstractFloat}
    # Step size applied to the velocities
    learning_rate::T = 0.001
    # L2 penalty coefficient
    weight_decay::T = 1.0e-6
    # Momentum coefficient of the velocity update
    momentum::T = 0.9
    # Nominal number of samples per mini-batch
    batch_size::Int = 100
    # Number of passes over the training set
    num_epochs::Int = 1000
    # Number of Gibbs steps in contrastive divergence (CD-k)
    CDK::Int = 1
    # Metrics are recorded every `skip` epochs by the tracking `train_srbm` method
    skip::Int = 1
end;

## Activation function

For neurons that take values $(\sigma,z) \in \{-1,1\}^{N \times K}$, the activation function, i.e. the probability that a unit takes the value $+1$, becomes $\frac{1}{2}(1 + \tanh(x))$ instead of the sigmoid.

In [None]:
"""
    activation(x)

Return `(1 + tanh(x)) / 2` element-wise, i.e. the probability that a unit with
values in {-1, +1} and input `x` takes the value +1.

Works for scalars and for arrays of any shape. (The previous partially dotted
form only worked for scalars, because adding a scalar to an array without a dot
is a MethodError.)
"""
function activation(x)
    # `@.` dots every call and operator, so scalars and arrays are both handled.
    return @. 0.5 * (1.0 + tanh(x))
end;

## Expectation function
In the case of binary neurons $\{0,1\}$ it is the sigmoid, but for binary neurons $\{-1,1\}$ it is the hyperbolic tangent.

In [None]:
"""
    expectation(x)

Apply `tanh` element-wise to `x` (scalar or array): the mean value of a unit
with values in {-1, +1} and input `x`.
"""
expectation(x) = broadcast(tanh, x);

## Sampling procedures
When calculating the gradients, we need to use Gibbs sampling in order to approximate the sum over all the state space. For that, we include 2 different functions: one for sampling hidden states, and the other for visible states with additional bits that encode the label.

In [None]:
"""
    sample_hidden(p)

Draw a -1/+1 state for every entry of `p`: an entry becomes +1.0 when its
probability exceeds a fresh uniform random number, and -1.0 otherwise.
The result is always a `Float64` matrix of size `(size(p, 1), size(p, 2))`.
"""
function sample_hidden(p)
    thresholds = rand(size(p, 1), size(p, 2))
    # Map the Boolean comparison {false, true} onto {-1.0, +1.0}.
    return @. 2.0 * (p > thresholds) - 1.0
end;

"""
    sample_visible(rbm, p, y)

Sample a -1/+1 visible state from the probabilities `p` (one column per
sample) and clamp the label units to the labels `y`.

The rows after `rbm.num_visible` are overwritten with the one-hot encoding of
`y` before the state is mapped to {-1.0, +1.0}, so the label part of the
returned state is deterministic.

# Arguments
- `rbm`: model; only `num_visible` and `num_labels` are read.
- `p`: probabilities, with `rbm.num_visible + rbm.num_labels` rows.
- `y`: label indices in `1:rbm.num_labels`, one per column of `p`.
"""
function sample_visible(rbm::SRBM, p, y)
    v = p .> rand(size(p, 1), size(p, 2))
    # Build the one-hot code from the model's own label count instead of relying
    # on a script-level global whose size may not match `rbm.num_labels`.
    onehot = Matrix{Bool}(I, rbm.num_labels, rbm.num_labels)
    v[rbm.num_visible+1:end, :] = onehot[:, y]
    return 2.0 * v .- 1.0
end;

## Calculating the gradients
One can use contrastive divergence to estimate the gradient of the log-likelihood of the model with respect to the parameters, and then update them in the direction of that gradient (gradient ascent on the log-likelihood).

In [None]:
"""
    cd_gradients(rbm, V0, y, parameters)

Estimate the log-likelihood gradients for one mini-batch with k-step
contrastive divergence (`k = parameters.CDK`).

# Arguments
- `rbm`: model whose `W`, `b` and `c` are read (not modified).
- `V0`: visible states of the batch, data rows stacked on top of label rows,
  size `(Nv + Nl, b)`.
- `y`: label index of every column of `V0`; selects the column of `rbm.c`.
- `parameters`: only `CDK` is read.

# Returns
`(dW, db, dc)` where `dW` has size `(Nv + Nl, Nh)` and is already averaged over
the batch, while `db` `(Nv + Nl, b)` and `dc` `(Nh, b)` are per-sample
contributions, already divided by the number of samples, that the caller
still has to sum.

The normalisation uses the actual number of columns of `V0`, so a short last
batch (or a dataset smaller than `parameters.batch_size`) is averaged correctly.
"""
function cd_gradients(rbm::SRBM, V0::Matrix{T}, y::Vector{Int}, parameters::hyperparameters) where T<:AbstractFloat
    hidden_bias = rbm.c[:, y]                                   # (Nh,b)
    # Actual batch size; guarded so that an empty batch yields zeros, not NaN.
    num_samples = max(size(V0, 2), 1)

    # Positive phase: hidden expectations given the data
    PH_V0 = expectation.(rbm.W' * V0 + hidden_bias)             # (Nh,b)

    # k-step Contrastive Divergence (CD-k): alternate hidden/visible sampling
    VK = copy(V0)                                               # (Nv+Nl,b)
    for _ in 1:parameters.CDK
        PH_V = activation.(rbm.W' * VK + hidden_bias)           # (Nh,b)
        HK   = sample_hidden(PH_V)                              # (Nh,b)
        PV_H = activation.(rbm.W * HK .+ rbm.b)                 # (Nv+Nl,b)
        VK   = sample_visible(rbm, PV_H, y)                     # (Nv+Nl,b)
    end

    # Negative phase: hidden expectations given the reconstruction
    PH_VK = expectation.(rbm.W' * VK + hidden_bias)             # (Nh,b)

    # Gradients = (positive statistics - negative statistics) / batch size
    dW = (V0 * PH_V0' - VK * PH_VK') / num_samples              # (Nv+Nl,Nh)
    db = (V0 - VK) / num_samples                                # (Nv+Nl,b)
    dc = (PH_V0 - PH_VK) / num_samples                          # (Nh,b)

    return dW, db, dc
end;

## Mini-batch Gradient Descent
With our function that estimates the gradients, we only need to implement the iterative process over all epochs in which each step looks at a batch of a certain size, computes the gradients with respect to that batch and updates the parameters.

In [None]:
"""
    train_srbm(rbm, TSet, labels, parameters)

Train `rbm` in place by mini-batch contrastive divergence with momentum and L2
weight decay.

# Arguments
- `rbm`: model; its `W`, `b` and `c` fields are replaced after every batch.
- `TSet`: training data with one example per column.
- `labels`: label index in `1:rbm.num_labels` for every column of `TSet`.
- `parameters`: `batch_size`, `num_epochs`, `weight_decay`, `momentum` and
  `learning_rate` are read here; `CDK` is read by `cd_gradients`.

Returns `nothing`.
"""
function train_srbm(rbm::SRBM, TSet::Matrix{T}, labels::Vector{Int}, parameters::hyperparameters) where T<:AbstractFloat
    # Prepare values for creating the batches
    NTSet       = size(TSet, 2)
    indices     = collect(1:NTSet)
    batch_size  = parameters.batch_size
    num_batches = cld(NTSet, batch_size)

    # One-hot label codes (0/1) and the same codes mapped to -1,1. Built from
    # the model's own label count rather than from a script-level global.
    onehot        = Matrix{Float64}(I, rbm.num_labels, rbm.num_labels)
    signed_onehot = 2.0 * onehot .- 1.0

    # Initialize the velocities
    vW = zero(rbm.W)                                            # (Nv+Nl,Nh)
    vb = zero(rbm.b)                                            # (Nv+Nl,)
    vc = zero(rbm.c)                                            # (Nh,Nl)

    for _ in 1:parameters.num_epochs
        shuffle!(indices)

        for batch in 1:num_batches
            # Indices of this batch (the last batch may be shorter)
            lo = (batch - 1) * batch_size + 1
            hi = min(batch * batch_size, NTSet)
            batch_indices = indices[lo:hi]

            # Stack the data on top of the -1,1 encoded labels
            y  = labels[batch_indices]
            V0 = [TSet[:, batch_indices]; signed_onehot[:, y]]
            dW, db_i, dc_i = cd_gradients(rbm, V0, y, parameters)

            # Sum the per-sample bias gradients; for `c`, the sum is taken
            # separately for each label so that it lands in that label's column.
            db = vec(sum(db_i, dims=2))
            dc = dc_i * onehot[:, y]'

            # Add L2 regularization to punish large values
            dW -= parameters.weight_decay * rbm.W
            db -= parameters.weight_decay * rbm.b
            dc -= parameters.weight_decay * rbm.c

            # Update velocities (exponential moving average of the gradients)
            vW = parameters.momentum * vW .+ (1.0 - parameters.momentum) * dW
            vb = parameters.momentum * vb .+ (1.0 - parameters.momentum) * db
            vc = parameters.momentum * vc .+ (1.0 - parameters.momentum) * dc

            # Gradient ascent step on the parameters
            rbm.W += parameters.learning_rate * vW
            rbm.b += parameters.learning_rate * vb
            rbm.c += parameters.learning_rate * vc
        end
    end
    return nothing
end;

"""
    train_srbm(rbm, TSet, labels, parameters, Test, test_labels, P_ARC, P_EX, avg_m, avg_n)

Train `rbm` in place exactly like the four-argument method, and additionally
record metrics every `parameters.skip` epochs.

# Arguments
- `rbm`, `TSet`, `labels`, `parameters`: as in the four-argument method.
- `Test`, `test_labels`: evaluation data (one example per column) and labels.
- `P_ARC`, `P_EX`: output buffers for the mean correct-label probability
  returned by `accuracy` on `Test` and on `TSet` respectively.
- `avg_m`, `avg_n`: output buffers for the two values returned by `overlapping`.

The k-th recording is written to index `k + 1` of each buffer; index 1 is never
written by this function. The buffers therefore need at least
`fld(parameters.num_epochs, parameters.skip) + 1` elements.
NOTE(review): index 1 looks reserved for the metrics of the untrained model;
confirm with the caller whether it should be filled here.

Returns `nothing`.
"""
function train_srbm(
        rbm::SRBM, 
        TSet::Matrix{T}, 
        labels::Vector{Int}, 
        parameters::hyperparameters, 
        Test::Matrix{T}, 
        test_labels::Vector{Int}, 
        P_ARC::Vector{T}, 
        P_EX::Vector{T},
        avg_m::Vector{T},
        avg_n::Vector{T}
) where T<:AbstractFloat

    # Prepare values for creating the batches
    NTSet       = size(TSet, 2)
    indices     = collect(1:NTSet)
    batch_size  = parameters.batch_size
    num_batches = cld(NTSet, batch_size)

    # One-hot label codes (0/1) and the same codes mapped to -1,1. Built from
    # the model's own label count rather than from a script-level global.
    onehot        = Matrix{Float64}(I, rbm.num_labels, rbm.num_labels)
    signed_onehot = 2.0 * onehot .- 1.0

    # Initialize the velocities
    vW = zero(rbm.W)                                            # (Nv+Nl,Nh)
    vb = zero(rbm.b)                                            # (Nv+Nl,)
    vc = zero(rbm.c)                                            # (Nh,Nl)

    # Write position in the output buffers; incremented before every write.
    iter = 1

    for epoch in 1:parameters.num_epochs
        shuffle!(indices)

        for batch in 1:num_batches
            # Indices of this batch (the last batch may be shorter)
            lo = (batch - 1) * batch_size + 1
            hi = min(batch * batch_size, NTSet)
            batch_indices = indices[lo:hi]

            # Stack the data on top of the -1,1 encoded labels
            y  = labels[batch_indices]
            V0 = [TSet[:, batch_indices]; signed_onehot[:, y]]
            dW, db_i, dc_i = cd_gradients(rbm, V0, y, parameters)

            # Sum the per-sample bias gradients; for `c`, the sum is taken
            # separately for each label so that it lands in that label's column.
            db = vec(sum(db_i, dims=2))
            dc = dc_i * onehot[:, y]'

            # Add L2 regularization to punish large values
            dW -= parameters.weight_decay * rbm.W
            db -= parameters.weight_decay * rbm.b
            dc -= parameters.weight_decay * rbm.c

            # Update velocities (exponential moving average of the gradients)
            vW = parameters.momentum * vW .+ (1.0 - parameters.momentum) * dW
            vb = parameters.momentum * vb .+ (1.0 - parameters.momentum) * db
            vc = parameters.momentum * vc .+ (1.0 - parameters.momentum) * dc

            # Gradient ascent step on the parameters
            rbm.W += parameters.learning_rate * vW
            rbm.b += parameters.learning_rate * vb
            rbm.c += parameters.learning_rate * vc
        end

        if epoch % parameters.skip == 0
            iter += 1
            # Mean probability of the correct label (the hit ratio is discarded)
            _, P_ARC[iter] = accuracy(rbm, Test, test_labels)
            _, P_EX[iter]  = accuracy(rbm, TSet, labels)
            # Overlaps with the test and the training examples
            avg_m[iter], avg_n[iter] = overlapping(rbm, Test, test_labels, TSet, labels)
        end
    end
    return nothing
end;

## Useful functions for probability calculation
We define some numerically robust functions: `log_sum_exp`, which `_prob` uses to calculate the unnormalised log marginal probability of a visible state, and a softmax, which we will use for accuracy testing.

In [None]:
"""
    log_sum_exp(u, v)

Return `log(exp(u) + exp(v))` element-wise without overflow.

Uses the identity `max(u, v) + log1p(exp(-abs(u - v)))`. Only a non-positive
number is ever exponentiated, and `log1p` keeps the small correction term that
`log(1 + x)` would round away when `u` and `v` are far apart.
"""
function log_sum_exp(u::AbstractVecOrMat, v::AbstractVecOrMat)
    return @. max(u, v) + log1p(exp(-abs(u - v)))
end;

"""
    _prob(rbm, x, labels)

Return, for every column of `x`, the unnormalised *log* marginal probability of
that visible state (hidden units with values -1/+1 summed out):

    b' * x + sum_j log(exp(-a_j) + exp(a_j)),   a = c[:, label] + W' * x

Despite its name the result is a log-score, not a probability; normalise it
(e.g. with a softmax) before using it as a weight.

# Arguments
- `x`: visible states (data rows stacked on label rows), one per column.
- `labels`: label index of every column, selecting the column of `rbm.c`.
"""
function _prob(rbm::SRBM, x::AbstractVecOrMat, labels::Vector{Int})
    # Input to every hidden unit, one column per visible state
    hidden_input = rbm.c[:, labels] .+ rbm.W' * x
    # log(2 cosh(a)) summed over the hidden units
    hidden_term = sum(log_sum_exp(-hidden_input, hidden_input), dims=1)
    return rbm.b' * x + hidden_term
end;

"""
    softmax(X, dim, theta=1.0)

Return the softmax of `theta * X` along dimension `dim`, so that the result
sums to one along `dim`.

For numerical stability the maximum is subtracted separately for every slice
along `dim`. Subtracting the global maximum, as before, made every entry of a
slice far below that maximum underflow to zero and produced 0/0 = NaN.
"""
function softmax(X::AbstractVecOrMat{T}, dim::Integer, theta::AbstractFloat=1.0)::AbstractVecOrMat where T <: AbstractFloat
    scaled = X .* theta
    # Shift after scaling so the largest exponent is 0, whatever the sign of theta.
    shifted = scaled .- maximum(scaled, dims=dim)
    e = exp.(shifted)
    return e ./ sum(e, dims=dim)
end;

## Accuracy function
The accuracy function takes a matrix of data, which can be the validation set or the test set, and computes the ratio between correct predictions and total examples, as well as the mean probability of assigning the correct label to an example, averaged over the whole input dataset.

In [None]:
"""
    accuracy(rbm, data, labels)

Classify every column of `data` and compare the result with `labels`.

Each example is paired with every possible label, the pairs are scored with
`_prob`, and the scores are turned into a distribution over labels with a
softmax.

# Returns
`(hit_ratio, mean_probability)`: the fraction of examples whose most probable
label equals the true label, and the probability assigned to the true label
averaged over all examples. For an empty `data` the ratio is 0/0 = NaN.
"""
function accuracy(rbm::SRBM, data::Matrix{T}, labels::Vector{Int}) where T<:AbstractFloat
    num_examples = size(data, 2)
    num_labels   = rbm.num_labels

    # Loop-invariant: all label codes mapped to -1,1 (built from the model's
    # label count, not from a script-level global) and the label indices.
    signed_onehot = 2.0 * Matrix{Float64}(I, num_labels, num_labels) .- 1.0
    all_labels    = collect(1:num_labels)

    num_correct      = 0
    mean_probability = 0.0
    for i in 1:num_examples
        # One column per candidate label: the same example with a different label
        candidates = [repeat(data[:, i], inner=(1, num_labels)); signed_onehot]

        # Distribution over labels given the example
        scores    = _prob(rbm, candidates, all_labels)
        posterior = vec(softmax(scores, 2))

        num_correct      += (argmax(posterior) == labels[i])
        mean_probability += posterior[labels[i]] / num_examples
    end

    return num_correct / num_examples, mean_probability
end;

## New state proposal function
We need a function that randomly generates a new state from an original one for our Metropolis sampling procedure. This flip proposal function takes the input and flips the sign of a contiguous segment of random length (at most `min_length` entries) starting at a random position within the input.

In [None]:
"""
    flip_proposal(x::AbstractVector, min_length=10)

Return a copy of `x` in which one contiguous segment has its sign flipped.

The segment starts at a uniformly random position; its length is uniform
between 1 and `min_length`, truncated so that it does not run past the end of
`x`. Note that, despite its name, `min_length` is the *maximum* segment length.
"""
function flip_proposal(x::AbstractVector{T}, min_length::Int=10) where T<:AbstractFloat
    total = length(x)
    # Random start, then a random length that fits in what is left of the vector
    start = rand(1:total)
    span  = rand(1:min(min_length, total - start + 1))
    stop  = start + span - 1

    flipped = copy(x)
    for idx in start:stop
        flipped[idx] = -x[idx]
    end
    return flipped
end;

"""
    flip_proposal(x::AbstractMatrix, min_length=10)

Apply the vector method of `flip_proposal` independently to every column of `x`
and return the results as a new matrix of the same size.

The output is preallocated instead of splatting the columns into `hcat`, so a
matrix with zero columns yields an empty matrix of the right shape rather than
an untyped empty vector.
"""
function flip_proposal(x::AbstractMatrix{T}, min_length::Int=10) where T<:AbstractFloat
    x_new = similar(x)
    for j in axes(x, 2)
        x_new[:, j] = flip_proposal(view(x, :, j), min_length)
    end
    return x_new
end;

## Overlap function
We need a function that calculates the overlap between the expected value of the visible layer when fixing a label and the original archetype related to that label.

In [None]:
"""
    generate_binary_matrix(N)

Return the `N x 2^N` matrix of `Int8` whose column `j` holds the `N`-bit binary
representation of `j - 1`, most significant bit in the first row.
"""
function generate_binary_matrix(N::Int)
    # Entry (row, col) is bit number `N - row` of `col - 1`.
    return Int8[((col - 1) >> (N - row)) & 1 for row in 1:N, col in 1:2^N]
end

"""
    overlapping(rbm, data, labels, train_data, train_labels)

Compute the average overlap between the model's expected visible state for a
label and the examples carrying that label.

For every label `k` the expectation of the data units given `k` is obtained by
exact enumeration of all `2^Nv` visible states (hence the `Nv <= 20` limit),
weighting each state with its normalised probability. `_prob` returns a
log-score, so the weights are `exp(score - max(score))`; using the raw scores
as weights, as before, does not give an expectation.

Each example is compared with the expectation of its own label, so the columns
of `data` and `train_data` may come in any order.

# Returns
`(m, n)`: the overlap (dot product divided by `Nv`) averaged over the columns
of `data`, and the same quantity averaged over the columns of `train_data`.
"""
function overlapping(
        rbm::SRBM, 
        data::Matrix{T}, 
        labels::Vector{Int}, 
        train_data::Matrix{T}, 
        train_labels::Vector{Int}
) where T<:AbstractFloat
    # Check data and labels have same dimensions
    @assert size(data,2) == length(labels)
    @assert size(train_data,2) == length(train_labels)

    Nl = rbm.num_labels
    Nv = rbm.num_visible

    # Label codes mapped to -1,1, built from the model's own label count
    signed_onehot = 2.0 * Matrix{Float64}(I, Nl, Nl) .- 1.0     # (Nl,Nl)

    # Generate matrix with all states (only for N <= 20)
    @assert Nv <= 20
    XAll       = 2.0 * generate_binary_matrix(Nv) .- 1.0        # (Nv,2^Nv)
    num_states = size(XAll, 2)

    # Expected value of the data units given each label
    m_visible = zeros(Nv, Nl)                                   # (Nv,Nl)
    for k in 1:Nl
        # Pair every possible state with label k
        label_vec  = fill(k, num_states)
        log_scores = _prob(rbm, [XAll; signed_onehot[:, label_vec]], label_vec)

        # Normalised weights; the maximum is subtracted to avoid overflow
        weights = exp.(vec(log_scores) .- maximum(log_scores))
        m_visible[:, k] = (XAll * weights) ./ sum(weights)
    end

    # Overlap of every example with the expectation of its own label
    m_prod = [dot(view(m_visible, :, labels[j]), view(data, :, j))
              for j in 1:size(data, 2)] / Nv
    n_prod = [dot(view(m_visible, :, train_labels[j]), view(train_data, :, j))
              for j in 1:size(train_data, 2)] / Nv

    return sum(m_prod) / length(m_prod), sum(n_prod) / length(n_prod)
end;

## Hyperparameter tuning
Hyperparameter tuning consists of exploring a grid of hyperparameters and training our SRBM with each combination in order to find the one that yields the best results on the validation set.

In [None]:
using .Iterators

"""
    grid_search(num_visible, num_labels, train_data, train_labels,
                validation_data, validation_labels, test_data, test_labels)

Train one SRBM per combination of learning rate, weight decay, momentum and
number of hidden units, and return the best combination as the tuple
`(learning_rate, weight_decay, momentum, num_hidden)`.

Candidates are ranked by validation accuracy first; the mean probability of
the correct label only breaks ties. (Requiring both metrics to improve at the
same time, as before, could discard a candidate with strictly higher accuracy.)

Every candidate is trained for 50 epochs with a batch size of 10.
`test_data` and `test_labels` are accepted but not used.
"""
function grid_search(
        num_visible::Int,
        num_labels::Int,
        train_data::Matrix{T}, 
        train_labels::Vector{Int}, 
        validation_data::Matrix{T}, 
        validation_labels::Vector{Int}, 
        test_data::Matrix{T}, 
        test_labels::Vector{Int}
) where T<:AbstractFloat
    # Define hyperparameters to tune
    learning_rates        = [0.001, 0.005, 0.01, 0.05, 0.1]
    weight_decays         = [0.00001, 0.0001, 0.001, 0.01, 0.1]
    momentum_coefficients = [0.0, 0.5, 0.9]
    num_hidden_units      = [10, 20, 50, 100]

    # Lazy Cartesian product of all candidate values
    hyperparameter_grid = Iterators.product(
        learning_rates, weight_decays, momentum_coefficients, num_hidden_units
    )

    # Train SRBM with each set of hyperparameters and evaluate on validation set
    best_validation_accuracy = -Inf
    best_mean_probability    = -Inf
    best_hyperparameters     = (0.0, 0.0, 0.0, 0)
    for hyperparameters_set in hyperparameter_grid
        learning_rate, weight_decay, momentum_coefficient, num_hidden = hyperparameters_set

        rbm = SRBM(num_visible = num_visible, num_hidden = num_hidden, num_labels = num_labels)
        parameters = hyperparameters(
            learning_rate = learning_rate, 
            weight_decay  = weight_decay, 
            momentum      = momentum_coefficient, 
            batch_size    = 10, 
            num_epochs    = 50
        )

        train_srbm(rbm, train_data, train_labels, parameters)

        validation_accuracy, mean_probability = accuracy(rbm, validation_data, validation_labels)

        # Lexicographic comparison: accuracy first, mean probability as tie-break
        is_better = validation_accuracy > best_validation_accuracy ||
            (validation_accuracy == best_validation_accuracy &&
             mean_probability > best_mean_probability)
        if is_better
            best_validation_accuracy = validation_accuracy
            best_mean_probability    = mean_probability
            best_hyperparameters     = hyperparameters_set
        end
    end

    return best_hyperparameters
end;

## Training dataset generation
We generate the training, validation and test sets inside a function, so that the arrays it allocates stay local to the call.

In [None]:
"""
    archetype_dataset(N, K, r, MT, MV)

Generate `K` random archetypes of `N` units with values -1/+1, together with
noisy copies of them.

Every unit of a copy keeps the archetype's value with probability
`(r + 1) / 2` and is flipped otherwise. Copies of the same archetype are stored
in adjacent columns and labelled with the archetype's index.

# Returns
`(train_data, train_labels, validation_data, validation_labels, test_data,
test_labels)`, where the test set is the archetypes themselves (`N x K`), the
validation set holds `MV` copies per archetype and the training set `MT`.
"""
function archetype_dataset(
        N::Int,
        K::Int,
        r::T,
        MT::Int,
        MV::Int
) where T<:AbstractFloat

    # Probability that a unit of a noisy copy keeps the archetype's value
    keep_probability = (r + 1.0)*0.5

    # Noisy copies: `copies` per archetype, sign flipped where the draw fails
    blurred(copies) = repeat(archetypes, inner=(1, copies)) .*
        ifelse.(rand(N, K*copies) .< keep_probability, 1.0, -1.0)

    # Test set: the random archetypes, each unit -1 or +1 with equal probability
    archetypes       = ifelse.(rand(N, K) .< 0.5, 1.0, -1.0)
    archetype_labels = collect(1:K)

    # Validation set first, then training set (the order of the random draws)
    validation_data   = blurred(MV)
    validation_labels = repeat(archetype_labels, inner=MV)

    train_data   = blurred(MT)
    train_labels = repeat(archetype_labels, inner=MT)

    return train_data, train_labels, validation_data, validation_labels, archetypes, archetype_labels
end;

## The main code
Finally, the main function takes the problem variables, generates the training, validation and test sets, trains the SRBM, and returns the classification probabilities and the overlaps.

In [None]:
"""
    main(num_visible, num_hidden, num_labels, quality_examples, M_train, M_validation)

Run one experiment: generate an archetype dataset, train a fresh SRBM on it and
evaluate the trained model.

# Arguments
- `num_visible`, `num_hidden`, `num_labels`: sizes of the SRBM; `num_labels` is
  also the number of archetypes.
- `quality_examples`: quality `r` of the noisy examples, passed to
  `archetype_dataset`.
- `M_train`, `M_validation`: noisy examples per archetype in the training and
  validation sets.

# Returns
`(P_ARC, P_EX, avg_m, avg_n)`: the mean probability of the correct label on the
archetypes and on the training examples, and the two overlaps returned by
`overlapping`.
"""
function main(
        num_visible::Int, 
        num_hidden::Int, 
        num_labels::Int, 
        quality_examples::T,
        M_train::Int,
        M_validation::Int
) where T<:AbstractFloat

    # Generate the datasets (the validation set is only needed by grid_search)
    train_data, train_labels, validation_data, validation_labels, test_data, test_labels =
        archetype_dataset(num_visible, num_labels, quality_examples, M_train, M_validation)

    # Optional: tune hyperparameters
    #=
    learning_rate, weight_decay, momentum, hidden_units = grid_search(
        num_visible,
        num_labels,
        train_data, 
        train_labels, 
        validation_data, 
        validation_labels, 
        test_data, 
        test_labels
    )
    =#

    # Initialize RBM with the default parameters of the struct
    # (Gaussian weights with mean 0.0 and deviation 0.01, zero biases)
    rbm = SRBM(
        num_visible = num_visible,
        num_hidden  = num_hidden,
        num_labels  = num_labels
    )

    # Initialize parameters
    parameters = hyperparameters(
        learning_rate = 0.01,
        weight_decay  = 0.00001,
        momentum      = 0.9,
        batch_size    = 50,
        num_epochs    = 1000,
        skip          = 10
    )

    # Optional: train while recording the metrics every `skip` epochs
    #=
    num_elems         = cld(parameters.num_epochs,parameters.skip) + 1;
    probs_archetypes  = zeros(num_elems);
    probs_examples    = zeros(num_elems);
    avg_m             = zeros(num_elems);
    avg_n             = zeros(num_elems);
    train_srbm(rbm,train_data,train_labels,parameters,test_data,test_labels,probs_archetypes,probs_examples,avg_m,avg_n)
    =#
    train_srbm(rbm, train_data, train_labels, parameters)

    # Mean probability of the correct label; the hit ratio is discarded.
    # `_` is the idiomatic throwaway target (assigning to `~` shadows an operator).
    _, P_ARC = accuracy(rbm, test_data, test_labels)
    _, P_EX  = accuracy(rbm, train_data, train_labels)

    # Calculate overlaps
    avg_m, avg_n = overlapping(rbm, test_data, test_labels, train_data, train_labels)

    return P_ARC, P_EX, avg_m, avg_n
end;

## Performance of the RBM

We define the number of visible units $N$ and the ranges in which the number of examples per archetype $M$ and their quality $r$ will vary.

Then, we nest the loops with $r$ in the outer loop and $M$ in the inner one. This ordering does not affect the results in any way; it only makes the printed output easier to follow.

This code outputs four .txt files containing the classification probabilities and the overlaps $\langle m \rangle$ and $\langle n \rangle$ (each averaged over 100 independent runs). These files can later be used to plot the figures shown in the report.

In [None]:
# Sizes of the SRBM (the number of labels equals the number of archetypes)
num_visible = 20;
num_hidden  = 50;
num_labels  = 4;

# Identity matrix used by the functions above as one-hot label codes
global const Id = Matrix{Float64}(I, num_labels, num_labels)

# Examples per archetype in the validation set
M_validation = 4;

# Qualities r and numbers of training examples per archetype M to scan
rvec = [0.66 0.32 0.18 0.10 0.06 0.04];
Mvec = [1    2    4    8    16   32  ];

# One row per M, one column per r
probs_archetypes = zeros(length(Mvec), length(rvec));
probs_examples   = zeros(length(Mvec), length(rvec));
overlap_m        = zeros(length(Mvec), length(rvec));
overlap_n        = zeros(length(Mvec), length(rvec));

# Number of independent runs averaged for every (M, r) pair
maxit = 100;
for (ir, r) in enumerate(rvec)
    for (iM, M_train) in enumerate(Mvec)
        # Running averages over the `maxit` runs
        P_ARC, P_EX, avg_m, avg_n = 0.0, 0.0, 0.0, 0.0

        for _ in 1:maxit
            outcome = main(num_visible, num_hidden, num_labels, r, M_train, M_validation)
            P_ARC += outcome[1]/maxit
            P_EX  += outcome[2]/maxit
            avg_m += outcome[3]/maxit
            avg_n += outcome[4]/maxit
        end

        probs_archetypes[iM, ir] = P_ARC
        probs_examples[iM, ir]   = P_EX
        overlap_m[iM, ir]        = avg_m
        overlap_n[iM, ir]        = avg_n
        println("M = ", M_train, ", r = ", r, ". Done.")
    end
end

# Export the four result tables, one text file each
str = @sprintf("SRBM-PARC-N%d-avg%d.txt", num_visible, maxit);
writedlm(str, probs_archetypes);

str = @sprintf("SRBM-PEX-N%d-avg%d.txt", num_visible, maxit);
writedlm(str, probs_examples);

str = @sprintf("SRBM-avgm-N%d-avg%d.txt", num_visible, maxit);
writedlm(str, overlap_m);

str = @sprintf("SRBM-avgn-N%d-avg%d.txt", num_visible, maxit);
writedlm(str, overlap_n);