# Training a Restricted Boltzmann Machine

This notebook demonstrates how to create, train, and sample from a Restricted Boltzmann Machine (RBM) using THRML.

## What is an RBM?

A Restricted Boltzmann Machine is an energy-based model with two layers:
- **Visible layer**: Observed data
- **Hidden layer**: Latent features

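The "restricted" aspect means there are no connections within a layer, only between layers (a bipartite structure). As a result, the units in one layer are conditionally independent given the other layer, so each layer can be sampled as a single block.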

The energy function is:

$$E(v, h) = -\beta \left( \sum_i a_i v_i + \sum_j b_j h_j + \sum_{i,j} W_{ij} v_i h_j \right)$$

where:
- $v_i$ are visible units
- $h_j$ are hidden units
- $a_i, b_j$ are biases
- $W_{ij}$ are connection weights
- $\beta$ is the inverse temperature
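
As a quick check of the formula, the energy can be evaluated by hand for a tiny model. This is a minimal sketch in plain JAX, not the THRML implementation: the function name `rbm_energy`, the toy parameters, and the assumption that spins take values in $\{-1, +1\}$ are all illustrative.

```python
import jax.numpy as jnp


def rbm_energy(v, h, a, b, W, beta=1.0):
    """E(v, h) = -beta * (a.v + b.h + v^T W h) for spin-valued v and h."""
    return -beta * (jnp.dot(a, v) + jnp.dot(b, h) + v @ W @ h)


# Hypothetical toy parameters: 2 visible units, 1 hidden unit.
a = jnp.array([0.5, -0.5])
b = jnp.array([0.25])
W = jnp.array([[1.0], [-1.0]])

# Assumed convention: spins take values in {-1, +1}.
v = jnp.array([1.0, -1.0])
h_up = jnp.array([1.0])
h_down = jnp.array([-1.0])

energy_up = rbm_energy(v, h_up, a, b, W)      # bias and coupling terms sum to 3.25
energy_down = rbm_energy(v, h_down, a, b, W)  # same visible state, hidden spin flipped
```

With the hidden spin up, the three terms sum to 3.25 and the energy is $-3.25$; flipping the hidden spin raises the energy to $+1.25$, so the first configuration is the more probable of the two.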

In [None]:
import jax
import jax.numpy as jnp
from thrml import SpinNode, Block, SamplingSchedule
from thrml.models import (
    RBMEBM,
    RBMSamplingProgram,
    RBMTrainingSpec,
    rbm_init,
    estimate_rbm_grad
)

# Set random seed for reproducibility
key = jax.random.key(42)

## Creating a Simple RBM

We'll create a small RBM with 6 visible units and 3 hidden units.

In [None]:
# Define dimensions
n_visible = 6
n_hidden = 3

# Create nodes
visible_nodes = [SpinNode() for _ in range(n_visible)]
hidden_nodes = [SpinNode() for _ in range(n_hidden)]

# Initialize parameters with small random values
key, subkey = jax.random.split(key)
visible_biases = jax.random.normal(subkey, (n_visible,)) * 0.01

key, subkey = jax.random.split(key)
hidden_biases = jax.random.normal(subkey, (n_hidden,)) * 0.01

key, subkey = jax.random.split(key)
weights = jax.random.normal(subkey, (n_visible, n_hidden)) * 0.01

beta = jnp.array(1.0)

# Create the RBM
rbm = RBMEBM(
    visible_nodes=visible_nodes,
    hidden_nodes=hidden_nodes,
    visible_biases=visible_biases,
    hidden_biases=hidden_biases,
    weights=weights,
    beta=beta
)

print(f"Created RBM with {n_visible} visible units and {n_hidden} hidden units")
print(f"Weight matrix shape: {weights.shape}")

## Sampling from the RBM

We can sample from the joint distribution over visible and hidden units using block Gibbs sampling.
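
The idea behind block Gibbs sampling can be sketched in plain JAX: because the graph is bipartite, all hidden units can be resampled at once given the visible units, and vice versa. The sketch assumes spins in $\{-1, +1\}$ and the energy function defined earlier. `sample_spins` and `gibbs_step` are hypothetical helpers rather than THRML functions, and THRML's own update order may differ.

```python
import jax
import jax.numpy as jnp


def sample_spins(key, field, beta):
    """Sample independent +/-1 spins with P(s = +1) = sigmoid(2 * beta * field)."""
    p_up = jax.nn.sigmoid(2.0 * beta * field)
    return jnp.where(jax.random.bernoulli(key, p_up), 1.0, -1.0)


def gibbs_step(key, v, a, b, W, beta):
    """One block Gibbs sweep: resample the hidden block, then the visible block."""
    key_h, key_v = jax.random.split(key)
    h = sample_spins(key_h, b + v @ W, beta)  # hidden units given visible units
    v = sample_spins(key_v, a + W @ h, beta)  # visible units given the new hidden units
    return v, h


# Hypothetical toy model with the same shape as the RBM in this notebook.
key = jax.random.key(0)
a, b = jnp.zeros(6), jnp.zeros(3)
W = 0.01 * jnp.ones((6, 3))

v = jnp.ones(6)
for _ in range(10):
    key, subkey = jax.random.split(key)
    v, h = gibbs_step(subkey, v, a, b, W, beta=1.0)
```

The factor of 2 in the sigmoid comes from the energy gap between the $+1$ and $-1$ states of a single spin; with a $\{0, 1\}$ encoding it would not appear.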

In [None]:
from thrml.block_sampling import sample_states

# Create a sampling program that samples both visible and hidden units
program = RBMSamplingProgram(
    ebm=rbm,
    free_blocks=[Block(visible_nodes), Block(hidden_nodes)],
    clamped_blocks=[]
)

# Define sampling schedule
schedule = SamplingSchedule(
    n_warmup=100,      # Burn-in samples
    n_samples=500,     # Number of samples to collect
    steps_per_sample=2 # Gibbs steps between samples
)

# Initialize random state
key, subkey = jax.random.split(key)
init_state = rbm_init(
    subkey, 
    rbm, 
    [Block(visible_nodes), Block(hidden_nodes)], 
    ()
)

# Sample from the model
key, subkey = jax.random.split(key)
samples = sample_states(
    key=subkey,
    program=program,
    schedule=schedule,
    init_state_free=init_state,
    state_clamp=[],
    nodes_to_sample=[Block(visible_nodes), Block(hidden_nodes)]
)

visible_samples, hidden_samples = samples

print(f"Visible samples shape: {visible_samples.shape}")
print(f"Hidden samples shape: {hidden_samples.shape}")
print(f"\nVisible activation rate: {jnp.mean(visible_samples.astype(jnp.float32)):.3f}")
print(f"Hidden activation rate: {jnp.mean(hidden_samples.astype(jnp.float32)):.3f}")

## Conditional Sampling

We can also sample hidden units given fixed visible units, which is useful for inference.
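
For an RBM this conditional is also available in closed form, which gives a reference for the sampled estimate. Assuming spins in $\{-1, +1\}$ (booleans mapped `False` to $-1$ and `True` to $+1$) and the energy defined earlier, $P(h_j = +1 \mid v) = \sigma\big(2\beta (b_j + \sum_i W_{ij} v_i)\big)$. The helper `hidden_on_probability` and the toy weights below are illustrative only.

```python
import jax
import jax.numpy as jnp


def hidden_on_probability(v_bool, b, W, beta=1.0):
    """P(h_j = +1 | v) for +/-1 spins, with booleans mapped False -> -1, True -> +1."""
    v = jnp.where(v_bool, 1.0, -1.0)
    return jax.nn.sigmoid(2.0 * beta * (b + v @ W))


v_bool = jnp.array([True, False, True, False, True, False])
b = jnp.zeros(3)
# Hypothetical weights: a single nonzero coupling between visible 0 and hidden 0.
W = jnp.zeros((6, 3)).at[0, 0].set(0.5)

probs = hidden_on_probability(v_bool, b, W)
```

Only the first hidden unit sees a nonzero field, so its probability moves away from 0.5 while the other two stay at exactly 0.5. With parameters as small as the ones initialised in this notebook, all hidden probabilities should sit close to 0.5.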

In [None]:
# Create a sampling program with visible units clamped
program_conditional = RBMSamplingProgram(
    ebm=rbm,
    free_blocks=[Block(hidden_nodes)],
    clamped_blocks=[Block(visible_nodes)]
)

# Create some visible data
visible_data = jnp.array([True, False, True, False, True, False], dtype=jnp.bool_)

# Initialize hidden state
key, subkey = jax.random.split(key)
init_hidden = rbm_init(subkey, rbm, [Block(hidden_nodes)], ())

# Sample hidden units given visible data
schedule_short = SamplingSchedule(n_warmup=50, n_samples=100, steps_per_sample=1)

key, subkey = jax.random.split(key)
hidden_samples_conditional = sample_states(
    key=subkey,
    program=program_conditional,
    schedule=schedule_short,
    init_state_free=init_hidden,
    state_clamp=[visible_data],
    nodes_to_sample=[Block(hidden_nodes)]
)

print(f"Given visible pattern: {visible_data.astype(int)}")
print(f"Sampled hidden states shape: {hidden_samples_conditional[0].shape}")
print(f"Hidden activation probabilities: {jnp.mean(hidden_samples_conditional[0].astype(jnp.float32), axis=0)}")

## Training with Contrastive Divergence

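RBMs are typically trained using contrastive divergence, which approximates the gradient of the log-likelihood by estimating the intractable model expectation with short Gibbs chains.

The gradients of the negative log-likelihood $\mathcal{L}$ with respect to the parameters are:

$$\frac{\partial \mathcal{L}}{\partial W_{ij}} = -\beta (\langle v_i h_j \rangle_{data} - \langle v_i h_j \rangle_{model})$$
$$\frac{\partial \mathcal{L}}{\partial a_i} = -\beta (\langle v_i \rangle_{data} - \langle v_i \rangle_{model})$$
$$\frac{\partial \mathcal{L}}{\partial b_j} = -\beta (\langle h_j \rangle_{data} - \langle h_j \rangle_{model})$$

Because these are gradients of a loss, parameters are updated by stepping against them (gradient descent), which moves the model statistics toward the data statistics.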
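
The expectations in these formulas are plain sample averages, so the three quantities can be sketched directly from two sets of samples. The function `cd_gradients` below is a hypothetical illustration in plain JAX, assuming spin values in $\{-1, +1\}$; it is not how `estimate_rbm_grad` is implemented.

```python
import jax.numpy as jnp


def cd_gradients(v_data, h_data, v_model, h_model, beta=1.0):
    """Return -beta * (data moment - model moment) for weights and both biases.

    Each argument has shape (n_samples, n_units) and holds spin values.
    """
    def moments(v, h):
        vh = jnp.einsum("ni,nj->ij", v, h) / v.shape[0]  # <v_i h_j>
        return vh, v.mean(axis=0), h.mean(axis=0)        # <v_i>, <h_j>

    vh_d, v_d, h_d = moments(v_data, h_data)
    vh_m, v_m, h_m = moments(v_model, h_model)
    return -beta * (vh_d - vh_m), -beta * (v_d - v_m), -beta * (h_d - h_m)


# Toy samples: 2 visible units, 1 hidden unit, 2 samples per phase.
v_data = jnp.array([[1.0, 1.0], [1.0, 1.0]])
h_data = jnp.array([[1.0], [1.0]])
v_model = jnp.array([[1.0, -1.0], [-1.0, 1.0]])
h_model = jnp.array([[1.0], [-1.0]])

grad_W, grad_a, grad_b = cd_gradients(v_data, h_data, v_model, h_model)
```

In this toy case the data and model agree on $\langle v_0 h_0 \rangle$, so that weight receives no signal, while every other entry is negative: subtracting it in a descent step increases the corresponding parameter and pulls the model statistics toward the data.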

In [None]:
# Create training specification
schedule_positive = SamplingSchedule(n_warmup=10, n_samples=50, steps_per_sample=1)
schedule_negative = SamplingSchedule(n_warmup=10, n_samples=50, steps_per_sample=1)

training_spec = RBMTrainingSpec(
    ebm=rbm,
    schedule_positive=schedule_positive,
    schedule_negative=schedule_negative
)

# Create a small training dataset
batch_size = 8
key, subkey = jax.random.split(key)
training_data = [jax.random.bernoulli(subkey, 0.5, shape=(batch_size, n_visible)).astype(jnp.bool_)]

# Initialize states for positive phase (hidden given visible)
n_chains_pos = 2
key, subkey = jax.random.split(key)
init_hidden_pos = rbm_init(
    subkey, 
    rbm, 
    [Block(hidden_nodes)], 
    (n_chains_pos, batch_size)
)

# Initialize states for negative phase (free sampling)
n_chains_neg = 2
key, subkey = jax.random.split(key)
init_neg = rbm_init(
    subkey, 
    rbm, 
    [Block(visible_nodes), Block(hidden_nodes)], 
    (n_chains_neg,)
)

# Estimate gradients
key, subkey = jax.random.split(key)
grad_weights, grad_visible_bias, grad_hidden_bias = estimate_rbm_grad(
    key=subkey,
    training_spec=training_spec,
    visible_data=training_data,
    init_state_positive=init_hidden_pos,
    init_state_negative=init_neg
)

print("Gradient statistics:")
print(f"Weight gradients - mean: {jnp.mean(grad_weights):.6f}, std: {jnp.std(grad_weights):.6f}")
print(f"Visible bias gradients - mean: {jnp.mean(grad_visible_bias):.6f}, std: {jnp.std(grad_visible_bias):.6f}")
print(f"Hidden bias gradients - mean: {jnp.mean(grad_hidden_bias):.6f}, std: {jnp.std(grad_hidden_bias):.6f}")

## Simple Training Loop

Let's implement a basic training loop with gradient descent.

In [None]:
import equinox as eqx

# Training hyperparameters
learning_rate = 0.01
n_epochs = 5

# Start from the initial RBM; eqx.tree_at below returns a new model on each
# update, so the original `rbm` is left unchanged
trained_rbm = rbm

print("Training RBM...")
print(f"Learning rate: {learning_rate}, Epochs: {n_epochs}\n")

for epoch in range(n_epochs):
    # Re-create training spec with updated parameters
    training_spec = RBMTrainingSpec(
        ebm=trained_rbm,
        schedule_positive=schedule_positive,
        schedule_negative=schedule_negative
    )
    
    # Compute gradients
    key, subkey = jax.random.split(key)
    grad_w, grad_vb, grad_hb = estimate_rbm_grad(
        key=subkey,
        training_spec=training_spec,
        visible_data=training_data,
        init_state_positive=init_hidden_pos,
        init_state_negative=init_neg
    )
    
    # Gradient descent update
    new_weights = trained_rbm.weights - learning_rate * grad_w
    new_visible_biases = trained_rbm.visible_biases - learning_rate * grad_vb
    new_hidden_biases = trained_rbm.hidden_biases - learning_rate * grad_hb
    
    # Create updated RBM
    trained_rbm = eqx.tree_at(
        lambda m: (m.weights, m.visible_biases, m.hidden_biases),
        trained_rbm,
        (new_weights, new_visible_biases, new_hidden_biases)
    )
    
    # Track the gradient norm as a rough convergence signal
    grad_norm = jnp.sqrt(jnp.sum(grad_w**2) + jnp.sum(grad_vb**2) + jnp.sum(grad_hb**2))
    
    print(f"Epoch {epoch + 1}/{n_epochs} - Gradient norm: {grad_norm:.6f}")

print("\nTraining complete!")

## Comparing Samples Before and After Training

Let's see how the distribution changes after training.

In [None]:
# Sample from trained model
program_trained = RBMSamplingProgram(
    ebm=trained_rbm,
    free_blocks=[Block(visible_nodes), Block(hidden_nodes)],
    clamped_blocks=[]
)

key, subkey = jax.random.split(key)
init_state_trained = rbm_init(
    subkey, 
    trained_rbm, 
    [Block(visible_nodes), Block(hidden_nodes)], 
    ()
)

key, subkey = jax.random.split(key)
samples_trained = sample_states(
    key=subkey,
    program=program_trained,
    schedule=schedule,
    init_state_free=init_state_trained,
    state_clamp=[],
    nodes_to_sample=[Block(visible_nodes)]
)

print("Visible unit statistics:")
print(f"Original RBM: {jnp.mean(visible_samples.astype(jnp.float32), axis=0)}")
print(f"Trained RBM:  {jnp.mean(samples_trained[0].astype(jnp.float32), axis=0)}")
print(f"Training data: {jnp.mean(training_data[0].astype(jnp.float32), axis=0)}")

## Energy Computation

We can compute the energy of specific configurations.
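
Energies turn into probabilities through the Boltzmann distribution $p(v, h) \propto e^{-E(v, h)}$. For a model as small as the one in this notebook (9 units, 512 joint states) the normalisation can be brute-forced, which gives an exact reference for sampled statistics. The sketch below assumes spins in $\{-1, +1\}$ and that $\beta$ is already part of $E$, as in the energy function at the top; `exact_visible_means` is a hypothetical helper, not part of THRML.

```python
import itertools

import jax
import jax.numpy as jnp


def exact_visible_means(a, b, W, beta=1.0):
    """Exact <v_i> under p(v, h) proportional to exp(-E(v, h)), by enumeration."""
    n_v, n_h = W.shape
    # Every joint configuration of +/-1 spins, one per row.
    states = jnp.array(list(itertools.product([-1.0, 1.0], repeat=n_v + n_h)))
    v, h = states[:, :n_v], states[:, n_v:]
    energies = -beta * (v @ a + h @ b + jnp.einsum("ni,ij,nj->n", v, W, h))
    probs = jax.nn.softmax(-energies)  # normalised Boltzmann weights
    return probs @ v


# Toy check: with no couplings, <v_0> reduces to tanh(beta * a_0).
a = jnp.array([1.0, 0.0])
b = jnp.zeros(1)
W = jnp.zeros((2, 1))
means = exact_visible_means(a, b, W)
```

The cost grows as $2^{n}$ in the total number of units, so this is only practical for toy models; beyond that, sampling is the only option.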

In [None]:
# Create a specific configuration
visible_state = jnp.array([True, True, False, False, True, False], dtype=jnp.bool_)
hidden_state = jnp.array([True, False, True], dtype=jnp.bool_)

# Compute energy
state = [visible_state, hidden_state]
blocks = [Block(visible_nodes), Block(hidden_nodes)]

energy = trained_rbm.energy(state, blocks)

print(f"Visible configuration: {visible_state.astype(int)}")
print(f"Hidden configuration:  {hidden_state.astype(int)}")
print(f"Energy: {energy:.6f}")

## Summary

This notebook demonstrated:

1. Creating an RBM with specified dimensions
2. Sampling from the joint distribution (visible and hidden)
3. Conditional sampling (hidden given visible)
4. Computing gradients via contrastive divergence
5. Training the RBM with gradient descent
6. Computing energies of configurations

RBMs can be used for:
- Dimensionality reduction
- Feature learning
- Collaborative filtering
- Building blocks for deep belief networks

For larger-scale applications (e.g., MNIST), you would:
- Use larger dimensions (784 visible, 128-500 hidden)
- Train with mini-batches
- Use more sophisticated optimizers
- Implement proper validation and early stopping
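
Two of these steps, mini-batching and a momentum optimizer, can be sketched in a few lines of plain JAX. Both helpers are hypothetical illustrations, not THRML functions; in a real loop the gradients would come from a gradient estimator such as the one used in this notebook, whereas the arrays below are placeholders.

```python
import jax
import jax.numpy as jnp


def iterate_minibatches(key, data, batch_size):
    """Yield shuffled mini-batches; a trailing partial batch is dropped."""
    perm = jax.random.permutation(key, data.shape[0])
    for start in range(0, data.shape[0] - batch_size + 1, batch_size):
        yield data[perm[start:start + batch_size]]


def momentum_step(params, grads, velocity, learning_rate=0.01, momentum=0.9):
    """Heavy-ball gradient descent applied leaf-wise to matching pytrees."""
    velocity = jax.tree_util.tree_map(
        lambda vel, g: momentum * vel - learning_rate * g, velocity, grads
    )
    params = jax.tree_util.tree_map(lambda p, vel: p + vel, params, velocity)
    return params, velocity


# Placeholder dataset: 20 rows of 6 features, split into batches of 8.
data = jnp.arange(120).reshape(20, 6)
batches = list(iterate_minibatches(jax.random.key(0), data, batch_size=8))

# Placeholder parameters and gradients to show a single update.
params = {"w": jnp.ones(2)}
grads = {"w": jnp.ones(2)}
velocity = {"w": jnp.zeros(2)}
params, velocity = momentum_step(params, grads, velocity)
```

Dropping the partial batch keeps array shapes fixed across steps, which avoids recompilation if the update is later wrapped in `jax.jit`.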