# Implementation of an Archetype Restricted Boltzmann Machine (ARCH-RBM)
## Libraries

In [None]:
using LinearAlgebra
using Printf
using Random
using DelimitedFiles

## Definition of structs
We define two structs that will help keep the code more organized and easier to maintain. By using a struct for the architecture of the RBM as well as one for the hyperparameters, we can easily pass them as arguments to functions. Additionally, we could define default values for some of the fields, which would make it easier to modify and experiment with different hyperparameters.

In [None]:
# State of the ARCH-RBM. Built with keyword arguments (via `Base.@kwdef`); only `W`
# has a default, so every other field must be supplied by the caller.
Base.@kwdef mutable struct RBM{T<:AbstractFloat}
    # Number of visible units (row count of `W`).
    num_visible::Int
    # Number of hidden units (column count of `W`).
    num_hidden::Int
    # Number of examples per archetype, as read by the evaluation routines.
    num_examples::Int
    # Weight matrix of size (num_visible, num_hidden); by default, standard normal
    # draws scaled by 0.01.
    W::Matrix{T} = 0.01 .* randn(num_visible, num_hidden)
    # Scalar factor β that multiplies the weights whenever a field is computed.
    beta_parameter::T
end;

# Training hyperparameters, built with keyword arguments (via `Base.@kwdef`).
# The three floating-point fields share the type parameter `T`, so they must all be
# given with the same float type.
Base.@kwdef mutable struct hyperparameters{T<:AbstractFloat}
    # Step size applied to the velocity when the weights are updated.
    learning_rate :: T         = 0.001
    # Coefficient of the L2 penalty subtracted from the gradient estimate.
    weight_decay  :: T         = 0.000001
    # Weight of the previous velocity in the moving average of gradient estimates.
    momentum      :: T         = 0.9
    # Number of training columns per mini-batch.
    batch_size    :: Int       = 100
    # Number of passes over the training set.
    num_epochs    :: Int       = 1000
    # Number of sampling sweeps performed by `arc_gradients`.
    CDK           :: Int       = 1
    # NOTE(review): not read by any code visible in this file — confirm its purpose.
    skip          :: Int       = 1
end;

## Activation function

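For neurons that take values in $\{-1,1\}$, i.e. $\sigma \in \{-1,1\}^{N}$ and $z \in \{-1,1\}^{K}$, the activation function (the probability that a unit takes the value $+1$ given its input $x$) becomes $\frac{1}{2}(1 + \tanh(x))$ instead of the sigmoid.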

In [None]:
"""
    activation(x)

Return `(1 + tanh(x)) / 2`, the probability used by `sampleOne` for a ±1 unit to take
the value `+1` given its input `x`.

Accepts a scalar or an array; arrays are processed elementwise and the result has the
same shape as `x`.
"""
function activation(x)
    # Every operator is dotted: the previous `1.0 + tanh.(x)` only worked for scalars,
    # because adding a scalar to an array without broadcasting is a MethodError.
    return 0.5 .* (1.0 .+ tanh.(x));
end;

## Expectation function
The expected value of a unit given its input $x$: in the case of binary neurons $\{0,1\}$ it is the sigmoid, but for binary neurons $\{-1,1\}$ it is the hyperbolic tangent, $\tanh(x)$.

In [None]:
"""
    expectation(x)

Apply `tanh` elementwise to `x` (scalar or array) and return the result.
"""
expectation(x) = broadcast(tanh, x);

## Sampling procedures
When calculating the gradients, we need to use Gibbs sampling in order to approximate the sum over all the state space. For that, we include a function for sampling any state given its activation probabilities.

In [None]:
"""
    sampleOne(px)

Draw one ±1 configuration from the activation probabilities `px`.

A uniform random number is drawn for every entry; the entry becomes `+1.0` when the
draw is less than or equal to the corresponding probability and `-1.0` otherwise.
The random matrix is always two-dimensional, `size(px, 1) × size(px, 2)`, so a vector
input yields a one-column matrix.
"""
function sampleOne(px)
    rows, cols = size(px, 1), size(px, 2)
    draws      = rand(rows, cols)
    return ifelse.(draws .<= px, 1.0, -1.0)
end;

## Calculating the gradients
We implement the proposed gradient approximation based on archetype selection with respect to the parameters, and then update them in the direction of the negative gradient.

In [None]:
"""
    arc_gradients(rbm, V0, H0, parameters)

Estimate the weight gradient for one mini-batch.

`V0` holds the visible configurations and `H0` the hidden configurations, one column
per example. The estimate is

    β * (V0 * H0' - VK * HK') / b

where `β = rbm.beta_parameter`, `b` is the number of columns of `V0`, and `(VK, HK)`
are the states reached after `parameters.CDK` sampling sweeps started from `(V0, H0)`.

In each sweep both conditional probability matrices are computed from the states of
the previous sweep (hidden from visible through `W'`, visible from hidden through `W`)
before either layer is resampled with `sampleOne`.

Returns a matrix with the same shape as `rbm.W`.
"""
function arc_gradients(rbm::RBM, V0::Matrix{T}, H0::Matrix{T}, parameters::hyperparameters) where T<:AbstractFloat
    β           = rbm.beta_parameter
    n_columns   = size(V0, 2)

    # Correlations measured on the data                  (Nv,Nh)
    data_corr   = V0 * H0'

    # Sampling chain, started at the data
    visible, hidden = V0, H0
    sweep = 0
    while sweep < parameters.CDK
        p_hidden  = activation.(β*rbm.W'*visible)         # (Nh,b)
        p_visible = activation.(β*rbm.W*hidden)           # (Nv,b)
        # Hidden layer is drawn first, then the visible one, both from the
        # probabilities computed above.
        hidden, visible = sampleOne(p_hidden), sampleOne(p_visible)
        sweep += 1
    end

    # Correlations measured on the sampled states         (Nv,Nh)
    model_corr  = visible * hidden'

    return (β * (data_corr .- model_corr)) ./ n_columns
end;

## Mini-batch Gradient Descent
With our function that estimates the gradients, we only need to implement the iterative process over all epochs in which each step looks at a batch of a certain size, computes the gradients with respect to that batch and updates the parameters.

In [None]:
"""
    trainMyData(rbm, TSet, Test, OSet, parameters)

Train `rbm.W` in place with shuffled mini-batches, momentum and L2 weight decay.

# Arguments
- `rbm`: model whose weight matrix `W` is updated in place.
- `TSet`: training configurations for the visible layer, one column per example.
- `Test`: accepted for interface compatibility; not used by the training loop.
- `OSet`: hidden-layer configurations paired column by column with `TSet`.
- `parameters`: `batch_size`, `num_epochs`, `momentum`, `weight_decay` and
  `learning_rate` are read here; the whole struct is forwarded to `arc_gradients`.

# Returns
`nothing`.
"""
function trainMyData(
    rbm::RBM,
    TSet::Matrix{T},
    Test::Matrix{T},
    OSet::Matrix{T},
    parameters::hyperparameters,
) where T<:AbstractFloat
    NTSet       = size(TSet,2);
    indices     = collect(1:NTSet);          # materialized because shuffle! mutates it
    batch_size  = parameters.batch_size;
    num_batches = cld(NTSet,batch_size);     # last batch may be shorter

    # Velocity: moving average of the regularized gradient estimates      (Nv,Nh)
    vW          = zero(rbm.W);

    for _ in 1:parameters.num_epochs
        shuffle!(indices)

        for batch in 1:num_batches
            # Columns belonging to this batch
            first_col       = (batch-1)*batch_size + 1;
            last_col        = min(batch*batch_size,NTSet);
            batch_indices   = indices[first_col:last_col];

            # Paired visible / hidden configurations of the batch
            V0              = TSet[:,batch_indices];
            H0              = OSet[:,batch_indices];

            δW              = arc_gradients(rbm,V0,H0,parameters);

            # L2 regularization to punish large weights
            δW              = δW .- parameters.weight_decay .* rbm.W;

            # Update velocities (single fused pass, no intermediate arrays)
            vW              = parameters.momentum .* vW .+ (1.0 - parameters.momentum) .* δW;

            # Update parameters in place
            rbm.W         .+= parameters.learning_rate .* vW;
        end
    end

    return nothing
end;

## Useful functions for probability calculation
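We define some numerically robust helper functions: `log_sum_exp` computes $\log(e^{u} + e^{v})$ elementwise, `_prob` uses it to evaluate the logarithm of the unnormalized marginal of a visible state, $\sum_j \log\left(2\cosh\left(\beta\,(W^\top \sigma)_j\right)\right)$, and `softmax` will be used for accuracy testing.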

In [None]:
"""
    log_sum_exp(u, v)

Compute `log(exp(u) + exp(v))` elementwise without overflow.

The result is evaluated as `max(a, b) + log1p(exp(-|a - b|))` for each pair of
entries. When the larger entry is `Inf`, `-Inf` or `NaN`, that entry is returned
directly, so `log_sum_exp([-Inf], [-Inf])` is `[-Inf]` rather than `NaN`.

`u` and `v` are combined by broadcasting; the result has the broadcast shape.
"""
function log_sum_exp(u::AbstractVecOrMat, v::AbstractVecOrMat)
    function logaddexp(a, b)
        larger = float(max(a, b))
        # `a - b` would be Inf - Inf = NaN when the larger value is infinite, so the
        # non-finite case is answered before any subtraction takes place.
        return isfinite(larger) ? larger + log1p(exp(-abs(a - b))) : larger
    end
    return logaddexp.(u, v);
end;

"""
    _prob(rbm, x)

For each column of `x`, return the sum over hidden units of
`log(exp(a) + exp(-a))`, where `a = rbm.beta_parameter * rbm.W' * x` is the input
received by the hidden layer.

The sum is taken along the first dimension with `dims=1`, so a vector `x` gives a
one-element result.
"""
function _prob(rbm::RBM,x::AbstractVecOrMat)
    hidden_input = rbm.beta_parameter*rbm.W'*x;
    return sum(log_sum_exp(hidden_input, -hidden_input), dims=1);
end;

"""
    softmax(X, dim, theta=1.0)

Return the softmax of `theta .* X` taken along dimension `dim`.

Each slice along `dim` is shifted by its own maximum before exponentiation, so the
largest entry of every slice maps to `exp(0) = 1` and no slice can underflow to an
all-zero denominator.

# Arguments
- `X`: vector or matrix of floating-point scores.
- `dim`: dimension along which the entries are normalized to sum to one.
- `theta`: scale applied to the scores before exponentiation.
"""
function softmax(X::AbstractVecOrMat{T}, dim::Integer, theta::AbstractFloat=1.0)::AbstractVecOrMat where T <: AbstractFloat
    # Scale first so that the shift below is the maximum of what is actually
    # exponentiated (this also keeps the shift correct for a negative `theta`).
    scaled   = X .* theta;

    # Per-slice shift: a single global maximum would send slices that lie far below
    # it to exp(-large) = 0 everywhere, and the normalization would become 0/0.
    weights  = exp.(scaled .- maximum(scaled, dims = dim));

    return weights ./ sum(weights, dims = dim)
end;

## Accuracy function
The accuracy function takes a matrix of data, which can be the validation set or the test set, and computes the ratio between correct predictions and total examples, as well as the mean probability of assigning the correct label given an example over the whole input dataset.

In [None]:
"""
    correctHiddenProbability(TSet, Test, rbm)

Evaluate how strongly the model associates each visible pattern with its label.

For label `i` the hidden pattern `z` is column `i` of the K×K matrix that has `+1`
on the diagonal and `-1` elsewhere. For a visible column `x` the score is

    β * (x' * W * z) - _prob(rbm, x)[1]

with `β = rbm.beta_parameter` and `W = rbm.W`. The score is computed for column `i`
of `Test` and for columns `M*(i-1)+1 : M*i` of `TSet`, where `K = rbm.num_hidden`
and `M = rbm.num_examples`.

# Returns
A 3-tuple:
1. `exp` of the `Test` scores, each repeated `M` times, as a `(K*M)×1` matrix;
2. `exp` of the `TSet` scores, as a vector of length `K*M`;
3. the mean over all `K*M` entries of `exp(score_Test - score_TSet)`.
"""
function correctHiddenProbability(
    TSet::Matrix{Float64},
    Test::Matrix{Float64},
    rbm ::RBM
)
    K       = rbm.num_hidden;
    M       = rbm.num_examples;

    β       = rbm.beta_parameter;
    W       = rbm.W;

    # Label patterns: +1 on the diagonal, -1 elsewhere
    Id      = Matrix(2I,K,K) .- 1;

    # Score of pairing visible column `x` with hidden pattern `z`
    score(x, z) = β*(x'*W*z) - _prob(rbm,x)[1];

    log_num = zeros(K);
    log_den = zeros(K*M);

    for i in 1:K
        Z              = Id[:,i];
        log_num[i]     = score(Test[:,i], Z);

        for j in 1:M
            k          = M*(i - 1) + j;     # column of the j-th example of label i
            log_den[k] = score(TSet[:,k], Z);
        end
    end

    # One numerator per example, laid out as a (K*M)×1 matrix
    log_num_rep = repeat(log_num, inner=(M,1));

    ratio       = exp.(log_num_rep - log_den);

    return exp.(log_num_rep), exp.(log_den), sum(ratio)/length(ratio);
end;

## Overlap function
We need a function that calculates the overlap between the expected value of the visible layer when fixing a label and the original archetype related to that label. It also calculates the overlap between the same expected visible configuration and the distorted examples so that we can compare both values.

In [None]:
"""
    calculateOverlap(TSet, Test, rbm)

Compare the expected visible configuration for each label with the stored patterns.

For label `i` the hidden layer is fixed to column `i` of the K×K matrix with `+1` on
the diagonal and `-1` elsewhere, and the expected visible configuration is
`expectation(rbm.W * z * rbm.beta_parameter)`. Its overlap (dot product divided by
`N = rbm.num_visible`) is taken with column `i` of `Test` and with each of the
`M = rbm.num_examples` columns of `TSet` that belong to label `i`.

# Returns
`(mean overlap with the Test columns, mean overlap with the TSet columns)`.
"""
function calculateOverlap(
    TSet::Matrix{Float64},
    Test::Matrix{Float64},
    rbm ::RBM
)
    N       = rbm.num_visible;
    K       = rbm.num_hidden;
    M       = rbm.num_examples;

    # 1: Hidden patterns in {-1,1}, one column per label (K x K)
    Id      = Matrix(2I,K,K) .- 1;

    # 2: Expected visible configuration for every label (N x K). `expectation` is
    #    elementwise, so it is applied to the whole matrix at once.
    V_M     = expectation(rbm.W*Id*rbm.beta_parameter);

    # 3: Overlaps. Example column i belongs to label cld(i, M), so the matching
    #    expectation column is looked up directly instead of repeating V_M.
    m = (1/N) .* [dot(view(V_M,:,i), view(Test,:,i)) for i in 1:K];
    n = (1/N) .* [dot(view(V_M,:,cld(i,M)), view(TSet,:,i)) for i in 1:K*M];

    return sum(m)/K, sum(n)/(K*M);
end;

## Training dataset generation
We generate the dataset for training, validation and test sets inside a function for memory allocation issues.

In [None]:
"""
    archetype_dataset(N, K, r, M)

Generate `K` random ±1 archetypes of length `N` together with `M` noisy copies of each.

Every bit of a copy keeps the archetype's value with probability `(r + 1) / 2` and is
sign-flipped otherwise.

# Returns
`(TSet, Test, OSet)` where
- `TSet` is `N × (K*M)`: the noisy copies, with the `M` copies of one archetype in
  consecutive columns;
- `Test` is `N × K`: the archetypes;
- `OSet` is `K × (K*M)`: for each copy, the label pattern of its archetype (`+1` in
  the archetype's row, `-1` elsewhere).
"""
function archetype_dataset(
        N::Int,
        K::Int,
        r::T,
        M::Int
) where T<:AbstractFloat

    # Probability that a bit of an example agrees with its archetype
    keep_probability = (r + 1.0)*0.5

    # Archetypes: independent fair ±1 bits
    archetypes       = ifelse.(rand(N, K) .< 0.5, 1.0, -1.0)

    # Label patterns: +1 on the diagonal, -1 elsewhere
    label_patterns   = Matrix(2I, K, K) .- 1.0

    # Examples: each archetype repeated M times, multiplied by a ±1 noise mask
    noise_mask       = ifelse.(rand(N, K*M) .< keep_probability, 1.0, -1.0)
    examples         = repeat(archetypes, inner=(1, M)) .* noise_mask
    example_labels   = repeat(label_patterns, inner=(1, M))

    return examples, archetypes, example_labels
end;

## The main code
The main code that will accept the problem variables, as well as the datasets containing training, validation and test sets. It calculates and prints both the classification accuracy and the overlaps for the training set (blurred examples) and the test set (archetypes).

In [None]:
"""
    main(num_visible, num_hidden, num_examples, quality_examples; kwargs...)

Generate an archetype dataset, train an RBM on it and evaluate the result.

# Arguments
- `num_visible`: number of visible units (length of every pattern).
- `num_hidden`: number of hidden units; also used as the number of archetypes.
- `num_examples`: number of noisy examples generated per archetype.
- `quality_examples`: quality `r` forwarded to `archetype_dataset`.

# Keywords
- `beta_parameter`, `learning_rate`, `weight_decay`: forwarded to the model and the
  training hyperparameters.
- `momentum`, `batch_size`, `num_epochs`: training hyperparameters; the defaults
  (0.9, 50, 5000) are the values that used to be fixed inside this function.

# Returns
`(rbm, P_ARC, P_EX, avg_m, avg_n)`: the trained model, the first two outputs of
`correctHiddenProbability`, and the two mean overlaps from `calculateOverlap`.
"""
function main(
        num_visible::Int, 
        num_hidden::Int, 
        num_examples::Int,
        quality_examples::T;
        beta_parameter::T=1.0,
        learning_rate::T=0.001,
        weight_decay::T=0.00001,
        momentum::T=0.9,
        batch_size::Int=50,
        num_epochs::Int=5000
) where T<:AbstractFloat

    # Initialize parameters
    parameters = hyperparameters(
        learning_rate  = learning_rate,
        weight_decay   = weight_decay,
        momentum       = momentum,
        batch_size     = batch_size,
        num_epochs     = num_epochs,
        skip           = 10
    ) 

    # Initialize RBM
    rbm        = RBM(
        num_visible    = num_visible,
        num_hidden     = num_hidden,
        num_examples   = num_examples,
        beta_parameter = beta_parameter
    );
    
    # Generate training set
    TSet, Test, OSet = archetype_dataset(
        num_visible,
        num_hidden,
        quality_examples,
        num_examples
    );

    # Train on all vectors in the training set (updates rbm.W in place)
    trainMyData(rbm,TSet,Test,OSet,parameters);

    # The third output (mean probability ratio) is not part of this function's
    # return value, so it is discarded explicitly.
    P_ARC, P_EX, _      = correctHiddenProbability(TSet,Test,rbm);
    avg_m, avg_n        = calculateOverlap(TSet,Test,rbm);
    return rbm, P_ARC, P_EX, avg_m, avg_n;
end;

## Define problem variables

Before calling the main function, we need to define the number of visible units $N$ that encode the training examples and the number of archetypes (or labels) $K$. In this implementation each archetype is assigned one hidden unit, so the number of hidden units is $N_h = K$: `main` receives $N$ as `num_visible` and $K$ as `num_hidden`.

In order to generate the training set, we need the quality of the examples $r$ and the number of examples per archetype $M$. $M_{validation}$ is used to create a smaller dataset for the hyperparameter tuning function. 

In [None]:
# Problem variables
# N is passed to `main` as `num_visible`.
N = 50;
# K is passed to `main` as `num_hidden`, which is also the number of archetypes generated.
K = 4;

# Training set variables
# r is the example quality: each bit of an example keeps its archetype value with probability (r + 1)/2.
r = 0.66;
# M is the number of examples generated per archetype.
M = 4;

# Generate the data, train and evaluate; the trailing `;` suppresses the cell output.
rbm, P_ARC, P_EX, avg_m, avg_n = main(N,K,M,r;beta_parameter=1.0,learning_rate=0.001,weight_decay=0.00001);