In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
from typing import Callable, Tuple, Any
import numpy as np
from functools import partial
import pdb
import tensorflow as tf

import matplotlib.pyplot as plt
import seaborn

: 

In [14]:
# Load the MNIST dataset
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()
X_train, Y_train, X_test, Y_test = jnp.expand_dims(jnp.array(X_train/255), axis=-1), jnp.expand_dims(jnp.array(Y_train), axis=-1), jnp.expand_dims(jnp.array(X_test/255), axis=-1), jnp.expand_dims(jnp.array(Y_test), axis=-1)

X_train = (X_train > 0.5).astype(jnp.float32)
X_test = (X_test > 0.5).astype(jnp.float32)

In [9]:
# --- 1. The Energy-Based Model (RBM) ---
class RBM:
    def __init__(self, n_visible, n_hidden, learning_rate=0.1):
        self.n_visible = n_visible
        self.n_hidden = n_hidden
        self.lr = learning_rate

        # Initialize weights with small random values
        # W shape: (n_visible, n_hidden)
        self.W = jax.random.normal(jax.random.PRNGKey(45), (n_visible, n_hidden)) #(0, 0.01, (n_visible, n_hidden))
        self.v_bias = jnp.zeros(n_visible)
        self.h_bias = jnp.zeros(n_hidden)

        # Monitoring
        self.errors = []

    def sigmoid(self, x):
        return 1.0 / (1.0 + jnp.exp(-x))

    def sample_hidden(self, v):
        """
        Block Gibbs Step: P(h|v)
        """
        # Calculate activation energy
        print(v.shape, self.W.shape)
        activation = jnp.dot(v, self.W) + self.h_bias
        prob_h = self.sigmoid(activation)
        # Stochastic sampling (Bernoulli)
        return prob_h, jax.random.binomial(jax.random.PRNGKey(397), 1, prob_h)

    def sample_visible(self, h):
        """
        Block Gibbs Step: P(v|h)
        """
        # Calculate activation energy
        activation = jnp.dot(h, self.W.T) + self.v_bias
        prob_v = self.sigmoid(activation)
        # Stochastic sampling (Bernoulli)
        return prob_v, jax.random.binomial(jax.random.PRNGKey(16), 1, prob_v)

    def contrastive_divergence(self, v0):
        """
        The training algorithm (CD-1).
        Approximates the gradient of the Energy function.
        """
        # --- Positive Phase (Data Driven) ---
        # 1. Pass data up to hidden layer
        prob_h0, h0 = self.sample_hidden(v0)

        # --- Negative Phase (Model Daydreaming / Block Gibbs) ---
        # 2. Reconstruct visible (Block Gibbs down)
        prob_v1, v1 = self.sample_visible(h0)
        # 3. Sample hidden again (Block Gibbs up)
        prob_h1, h1 = self.sample_hidden(v1)

        # --- Update Weights (Gradient Descent on Energy) ---
        # Contrastive Divergence: <v0*h0> - <v1*h1>
        batch_size = v0.shape[0]

        positive_grad = jnp.dot(v0.T, prob_h0)
        negative_grad = jnp.dot(v1.T, prob_h1)

        # Update W, biases
        self.W += self.lr * (positive_grad - negative_grad) / batch_size
        self.v_bias += self.lr * jnp.mean(v0 - v1, axis=0)
        self.h_bias += self.lr * jnp.mean(prob_h0 - prob_h1, axis=0)

        # Track reconstruction error (MSE) for visualization
        error = jnp.mean((v0 - v1) ** 2)
        return error

    def transform(self, v):
        """
        Transform raw input into hidden feature representation.
        Returns the probability of hidden units being active.
        """
        prob_h, _ = self.sample_hidden(v)
        return prob_h


In [11]:
# --- 3. Training the PGM (Unsupervised) ---
print("Training RBM (Generative Phase)...")
n_hidden_units = 128
rbm = RBM(n_visible=784, n_hidden=n_hidden_units, learning_rate=0.1)

epochs = 10
batch_size = 64
n_batches = X_train.shape[0] // batch_size

for epoch in range(epochs):
    epoch_error = 0
    # Shuffle data
    indices = np.random.permutation(X_train.shape[0])

    for i in range(0, X_train.shape[0], batch_size):
        batch_idx = indices[i:i+batch_size]
        batch = X_train[batch_idx]

        # Skip incomplete batches
        if len(batch) < batch_size: continue

        err = rbm.contrastive_divergence(batch)
        epoch_error += err

    avg_error = epoch_error / n_batches
    print(f"Epoch {epoch+1}/{epochs} - Reconstruction Error: {avg_error:.4f}")

Training RBM (Generative Phase)...


TypeError: dot_general requires contracting dimensions to have the same shape, got (1,) and (784,).

: 

: 

In [16]:
batch.shape

(64, 28, 28, 1)

In [15]:
X_train.shape

(60000, 28, 28, 1)