In [6]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from jax import random, grad, jit, vmap, lax
from datasets import load_dataset

# Helper functions for data loading and preprocessing
cache_dir="~/.cache/huggingface/datasets"

def one_hot(x, k, dtype=jnp.float32):
    """Return a one-hot encoding of the integer labels in ``x``.

    Every entry of ``x`` is compared against the class ids ``0 .. k-1``; the
    resulting boolean table (one row per entry of ``x``, ``k`` columns) is
    converted to ``dtype``.
    """
    class_ids = jnp.arange(k)
    # Broadcasting a column of labels against the row of class ids marks the
    # single matching position in each row.
    matches = x[:, None] == class_ids
    return jnp.array(matches, dtype)



def load_mnist_data():
    """Fetch MNIST through Hugging Face ``datasets`` and return model-ready arrays.

    Images are reshaped to 28x28, cast to float32 and divided by 255; labels
    are one-hot encoded over 10 classes. The shape of every returned array is
    printed.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)``.
    """
    print("Loading MNIST dataset from Hugging Face with JAX format...")

    # Ask the dataset to hand its columns back in JAX format
    splits = load_dataset("mnist", cache_dir=cache_dir).with_format("jax")

    def _prepare_split(split):
        # Images: reshape each to 28x28, then float32 divided by 255
        images = jnp.array([img.reshape(28, 28) for img in split["image"]])
        labels = jnp.array(split["label"])
        return images.astype(jnp.float32) / 255.0, one_hot(labels, 10)

    x_train, y_train = _prepare_split(splits["train"])
    x_test, y_test = _prepare_split(splits["test"])

    report = (
        ("x_train shape:", x_train),
        ("y_train shape:", y_train),
        ("x_test shape:", x_test),
        ("y_test shape:", y_test),
    )
    for label, array in report:
        print(label, array.shape)

    return x_train, y_train, x_test, y_test


def get_batch(key, x, y, batch_size):
    """Draw a random mini-batch of paired examples.

    ``batch_size`` distinct indices along the first axis of ``x`` are sampled
    without replacement and used to index both ``x`` and ``y``.

    Returns:
        Tuple ``(x_batch, y_batch, new_key)`` where ``new_key`` is the half of
        the split key that was not consumed by the draw.
    """
    next_key, draw_key = random.split(key)
    num_examples = x.shape[0]
    picked = random.choice(draw_key, num_examples, shape=(batch_size,), replace=False)
    return x[picked], y[picked], next_key

# Common utility functions
def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``.

    The input is clipped to ``[-30, 30]`` before exponentiation for numerical
    stability, so the output saturates outside that range.
    """
    bounded = jnp.clip(x, -30.0, 30.0)
    return 1.0 / (1.0 + jnp.exp(-bounded))

def rmsnorm(x, gamma):
    """Root-mean-square normalisation over the last axis, scaled by ``gamma``.

    ``x`` is divided by ``sqrt(mean(x**2) + eps)`` (mean taken over the last
    axis, kept for broadcasting) and the result is multiplied by ``gamma``.
    """
    eps = 1e-5
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    # Floor the divisor at eps as an extra guard against tiny denominators
    divisor = jnp.maximum(rms, eps)
    return (x / divisor) * gamma


# SwiGLU block: Swish-gated linear unit followed by an output projection
def swiglu(x, params):
    """
    SwiGLU block: Swish(x W_gate) * (x W_linear), projected by W_out.

    Swish(z) is computed as z * sigmoid(z) (i.e. beta fixed at 1).

    Args:
        x: Input whose last axis matches the rows of ``W_gate`` / ``W_linear``.
        params: Mapping providing ``'W_gate'``, ``'W_linear'`` and ``'W_out'``.

    Returns:
        The gated activations multiplied by ``W_out``.
    """
    w_gate, w_linear, w_out = params['W_gate'], params['W_linear'], params['W_out']

    gate_proj = x @ w_gate
    value_proj = x @ w_linear

    # Swish on the gate branch, then element-wise gating of the linear branch
    gated = (gate_proj * jax.nn.sigmoid(gate_proj)) * value_proj

    return gated @ w_out


In [13]:
import jax
import time
import jax.numpy as jnp
import optax
from jax import random, lax
import numpy as np
import matplotlib.pyplot as plt



# Modified loss function for autoregressive prediction with closest match
def adaptive_loss_fn(params, forward_fn, batch_forward, x_batch):
    """
    Closest-match loss for free-running autoregressive row generation.

    Starting from a single all-zero input row, the model generates an image
    one row at a time. After each step, the rows generated so far are compared
    (MSE over the same leading rows) against every image in ``x_batch``; the
    next row of the best-matching image becomes the regression target for the
    model's next prediction. The per-row MSEs of rows ``1 .. H-1`` are
    averaged, where ``H`` is the number of rows per image (28 for MNIST).

    The rollout depends only on ``params`` and on ``x_batch`` as a whole; no
    individual example seeds it. Evaluating it once per example and averaging
    therefore yields the same value as evaluating it once, while multiplying
    tracing/compile time and compute by the batch size. It is evaluated once
    here.

    Args:
        params: Model parameters, passed through to ``forward_fn``.
        forward_fn: Callable ``(params, rows) -> predictions`` producing one
            prediction per input row; only the last prediction is used (the
            first one for the initial all-zero row).
        batch_forward: Unused; kept in the signature for existing callers.
        x_batch: Images indexed as ``(batch, row, column)``.

    Returns:
        Scalar loss: mean over predicted rows of the MSE between each
        predicted row and its closest-match target row.

    Raises:
        ValueError: If ``x_batch`` contains no images or fewer than two rows
            per image (there would be nothing to match or predict).
    """
    if x_batch.shape[0] == 0:
        raise ValueError("adaptive_loss_fn requires a non-empty batch")

    num_rows = x_batch.shape[1]
    row_width = x_batch.shape[2]
    if num_rows < 2:
        raise ValueError("adaptive_loss_fn requires images with at least two rows")

    def closest_match_index(generated_rows):
        # Compare only the rows that have been generated so far
        count = generated_rows.shape[0]
        diffs = x_batch[:, :count, :] - generated_rows[jnp.newaxis, :, :]
        return jnp.argmin(jnp.mean(diffs ** 2, axis=(1, 2)))

    # Row 0 has no prior context: it is predicted from a single zero row and
    # does not itself contribute a loss term.
    first_row = forward_fn(params, jnp.zeros((1, row_width)))[0]
    generated = [first_row]

    total_loss = 0.0
    for row_idx in range(1, num_rows):
        partial_image = jnp.stack(generated, axis=0)

        # Target: the next row of the batch image closest to what exists so far
        target_row = x_batch[closest_match_index(partial_image), row_idx, :]

        # Prediction conditioned on the model's own previous outputs
        predicted_row = forward_fn(params, partial_image)[-1]

        total_loss += jnp.mean((predicted_row - target_row) ** 2)
        generated.append(predicted_row)

    # Average over the rows that were actually scored
    return total_loss / (num_rows - 1)

# Modified accuracy for image generation (using PSNR with closest match)
def adaptive_psnr_fn(params, forward_fn, batch_forward, x_batch):
    """
    Average PSNR (dB) of generated images against their nearest neighbours.

    At most 100 images (fewer if ``x_batch`` is smaller) are generated from
    scratch: the first row comes from a small deterministic hidden state pushed
    through SwiGLU and the output layer (clipped to [0, 1]); rows 1..27 come
    from ``forward_fn`` applied to the rows produced so far. Each generated
    image is matched to the lowest-MSE image in the whole of ``x_batch`` and
    the PSNR for that MSE (peak value 1.0) is averaged over the generated
    images.

    ``batch_forward`` is accepted but not used.
    """
    num_images = min(x_batch.shape[0], 100)  # Limit to 100 examples for efficiency
    hidden_size = params['diag_weights'].shape[0]
    unit_ids = jnp.arange(hidden_size)

    def first_row_from_state(state):
        # SwiGLU followed by the output layer, clipped to the pixel range
        projected = swiglu(state, params)
        return jnp.clip(jnp.dot(projected, params['Wo']) + params['bo'], 0.0, 1.0)

    def roll_out(image_idx):
        # Deterministic, image-specific seed state (no PRNG involved)
        seed_state = jnp.zeros(hidden_size) + 0.01 * jnp.sin(unit_ids + image_idx)
        image = jnp.zeros((28, 28)).at[0, :].set(first_row_from_state(seed_state))
        for row in range(1, 28):
            # Condition on every row generated so far; keep the last prediction
            next_row = forward_fn(params, image[:row, :])[-1]
            image = image.at[row, :].set(next_row)
        return image

    def psnr_to_nearest(image):
        per_image_mse = jnp.mean((x_batch - image[jnp.newaxis, :, :]) ** 2, axis=(1, 2))
        min_mse = per_image_mse[jnp.argmin(per_image_mse)]
        # PSNR = 20*log10(peak) - 10*log10(MSE), with peak = 1.0
        return 20 * jnp.log10(1.0) - 10 * jnp.log10(min_mse)

    # Generate everything first, then score
    images = [roll_out(j) for j in range(num_images)]

    psnr_sum = 0.0
    for image in images:
        psnr_sum += psnr_to_nearest(image)

    return psnr_sum / num_images



# Adapt the data loading for autoregressive task
def load_mnist_data_autoregressive():
    """Return ``(x_train, x_test)`` from ``load_mnist_data``, dropping the labels."""
    train_images, _train_labels, test_images, _test_labels = load_mnist_data()

    # Row-by-row generation is trained on the images alone
    return train_images, test_images

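# Helper function to get image-only batches (redefines the earlier get_batch(key, x, y, batch_size), which is shadowed once this cell runs)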
def get_batch(key, x, batch_size):
    """Draw a random mini-batch from ``x``.

    ``batch_size`` distinct indices along the first axis are sampled without
    replacement.

    Returns:
        Tuple ``(x_batch, new_key)`` where ``new_key`` is the half of the
        split key that was not consumed by the draw.
    """
    next_key, draw_key = random.split(key)
    num_examples = x.shape[0]
    picked = random.choice(draw_key, num_examples, shape=(batch_size,), replace=False)
    return x[picked], next_key

# Modify the forward function for autoregressive prediction
def two_layer_fused_scan_forward_fn(params, x_seq):
    """
    Run both recurrent layers over ``x_seq`` inside a single ``lax.scan``.

    Per step:
      * layer 1: ``h1 = x @ Wxh + h1_prev * diag_weights``
      * layer 2: ``h2 = swiglu(h1) + h2_prev * diag_weights``
      * output:  ``h2 @ Wo + bo``

    Both layers use the same ``diag_weights`` vector and both hidden states
    start at zero.

    Returns:
        The stacked per-step outputs, one per element of ``x_seq``; callers
        treat each as the prediction for the following row.
    """
    recurrent_diag = params['diag_weights']
    state_dim = recurrent_diag.shape[0]

    def step(states, row):
        prev_h1, prev_h2 = states

        # Layer 1: input projection plus element-wise recurrence
        h1 = jnp.dot(row, params['Wxh']) + prev_h1 * recurrent_diag

        # Layer 2: SwiGLU-transformed layer-1 state plus element-wise recurrence
        h2 = swiglu(h1, params) + prev_h2 * recurrent_diag

        # Read-out for this time step
        prediction = jnp.dot(h2, params['Wo']) + params['bo']

        return (h1, h2), prediction

    zero_state = jnp.zeros(state_dim)
    _, predictions = lax.scan(step, (zero_state, zero_state), x_seq)

    return predictions

# Modify parameter initialization for autoregressive task
def diagonal_init_params(key):
    """Initialize parameters for a 2-layer diagonal RNN with SwiGLU between layers.

    Every randomly initialised tensor is drawn from its own PRNG subkey, so
    ``diag_weights`` and ``Wo`` no longer share a key (reusing one key for two
    draws makes the draws statistically dependent). ``key`` is split six ways,
    so the values differ from those produced by a five-way split.

    Args:
        key: JAX PRNG key.

    Returns:
        Dict with
          * ``'Wxh'``          (28, 128)  input-to-hidden projection
          * ``'diag_weights'`` (128,)     diagonal recurrent weights
          * ``'W_gate'``, ``'W_linear'`` (128, 128) and ``'W_out'`` (128, 128)
            SwiGLU projections
          * ``'Wo'``           (128, 28)  output projection
          * ``'bo'``           (28,)      output bias, zeros
        All random tensors are standard normal scaled by 0.01.
    """
    hidden_size = 128
    features = 28   # MNIST image width
    output_size = 28  # Output is a full row of pixels
    intermediate_size = 128  # Size for the SwiGLU intermediate representation

    # Initialization scale
    init_scale = 0.01

    # One independent subkey per random tensor
    k_in, k_diag, k_gate, k_linear, k_out, k_readout = random.split(key, 6)

    return {
        # First layer parameters
        'Wxh': random.normal(k_in, (features, hidden_size)) * init_scale,

        # Shared diagonal weights for recurrent connections in both layers
        'diag_weights': random.normal(k_diag, (hidden_size,)) * init_scale,

        # SwiGLU parameters for layer-to-layer transition
        'W_gate': random.normal(k_gate, (hidden_size, intermediate_size)) * init_scale,
        'W_linear': random.normal(k_linear, (hidden_size, intermediate_size)) * init_scale,
        'W_out': random.normal(k_out, (intermediate_size, hidden_size)) * init_scale,

        # Output layer (predicts a row of pixels)
        'Wo': random.normal(k_readout, (hidden_size, output_size)) * init_scale,
        'bo': jnp.zeros(output_size)
    }

# Core evaluation function modified for adaptive autoregressive task
def evaluate_model(init_params_fn, forward_fn, epochs=5, batch_size=64, learning_rate=0.001):
    """
    Train and evaluate an RNN for autoregressive image generation using the
    adaptive closest-match loss and PSNR metric.

    The model is optimised with Adam on ``adaptive_loss_fn``. After every
    epoch ``adaptive_psnr_fn`` is evaluated on the first 1000 test images, and
    once more on the full test set after training.

    Args:
        init_params_fn: Function ``key -> params`` that initializes model parameters
        forward_fn: Function ``(params, rows) -> predictions`` for one sequence
        epochs: Number of training epochs
        batch_size: Training batch size
        learning_rate: Learning rate for optimizer

    Returns:
        trained_params: The trained model parameters
        metrics: Dictionary with ``'train_losses'`` (mean loss per epoch),
            ``'val_psnrs'`` (validation PSNR per epoch), ``'test_psnr'`` and
            ``'training_time'`` (seconds, measured after the final test PSNR
            has been computed)
    """

    x_train, x_test = load_mnist_data_autoregressive()

    # Get PRNGKey for reproducibility
    key = random.PRNGKey(42)

    # Parameter initialisation gets its own subkey. Passing the root key to
    # init_params_fn and then splitting that same key for batch sampling would
    # be PRNG key reuse.
    key, init_key = random.split(key)

    # Initialize parameters
    print("Initializing model parameters...")
    params = init_params_fn(init_key)

    # Create optimizer
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)

    # JIT-compile the forward function
    batch_forward = jax.jit(jax.vmap(forward_fn, in_axes=(None, 0)))

    # Define update step with adaptive loss
    @jax.jit
    def update_step(params, opt_state, x_batch):
        """Perform a single update step using adaptive closest match loss."""
        def batch_loss(p):
            return adaptive_loss_fn(p, forward_fn, batch_forward, x_batch)

        loss_value, grads = jax.value_and_grad(batch_loss)(params)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, loss_value

    # Training loop
    print(f"Starting training for {epochs} epochs...")
    start_time = time.time()

    train_losses = []
    val_psnrs = []

    # Use a smaller validation set
    val_size = 1000
    val_x = x_test[:val_size]

    for epoch in range(epochs):
        epoch_losses = []

        # Number of batches per epoch
        num_batches = x_train.shape[0] // batch_size

        for batch in range(num_batches):
            # Get batch (sampled independently each step, not a pass over a shuffled set)
            x_batch, key = get_batch(key, x_train, batch_size)

            # Update parameters
            params, opt_state, loss_value = update_step(params, opt_state, x_batch)
            epoch_losses.append(loss_value)

            # Print progress occasionally
            if batch % 100 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {batch}/{num_batches}, Loss: {loss_value:.4f}")

        # Compute epoch metrics
        avg_loss = jnp.mean(jnp.array(epoch_losses))
        train_losses.append(float(avg_loss))

        # Compute validation PSNR with adaptive metric
        val_psnr = adaptive_psnr_fn(params, forward_fn, batch_forward, val_x)
        val_psnrs.append(float(val_psnr))

        print(f"Epoch {epoch+1}/{epochs} completed. Avg Loss: {avg_loss:.4f}, Val PSNR: {val_psnr:.2f} dB")

    # Compute final test PSNR using adaptive metric
    test_psnr = adaptive_psnr_fn(params, forward_fn, batch_forward, x_test)
    print(f"Final test PSNR: {test_psnr:.2f} dB")

    # Compute training time
    training_time = time.time() - start_time
    print(f"Training completed in {training_time:.2f} seconds.")

    # Return the trained parameters and metrics
    metrics = {
        'train_losses': train_losses,
        'val_psnrs': val_psnrs,
        'test_psnr': float(test_psnr),
        'training_time': training_time
    }

    return params, metrics

# Generate MNIST images from scratch
def generate_images_from_scratch(params, forward_fn, num_images=5, seed=42):
    """
    Generate 28x28 images row by row with no real image as a starting point.

    For every image a small random hidden state is pushed through SwiGLU and
    the output layer (clipped to [0, 1]) to form row 0; rows 1..27 are the last
    prediction of ``forward_fn`` applied to the rows generated so far.

    Args:
        params: Trained model parameters
        forward_fn: Model forward function ``(params, rows) -> predictions``
        num_images: Number of images to generate
        seed: Random seed

    Returns:
        Array of generated images with shape ``(num_images, 28, 28)``
    """
    rng = random.PRNGKey(seed)
    state_dim = params['diag_weights'].shape[0]
    canvas = jnp.zeros((num_images, 28, 28))

    def first_row_from_state(state):
        """Map a hidden state straight to a row of pixels."""
        projected = swiglu(state, params)
        return jnp.clip(jnp.dot(projected, params['Wo']) + params['bo'], 0.0, 1.0)

    def predict_next_row(context_rows):
        """Last prediction of the model given the rows generated so far."""
        return forward_fn(params, context_rows)[-1]

    for img in range(num_images):
        rng, state_key = random.split(rng)
        seed_state = random.normal(state_key, (state_dim,)) * 0.01

        # The key is advanced twice per image; only the first subkey feeds a
        # draw, because a single hidden state determines row 0.
        rng, _ = random.split(rng)

        # Row 0 comes directly from the seed state
        canvas = canvas.at[img, 0, :].set(first_row_from_state(seed_state))

        # Rows 1-27 are generated autoregressively
        for row in range(1, 28):
            context_rows = canvas[img, :row, :]
            canvas = canvas.at[img, row, :].set(predict_next_row(context_rows))

    return canvas

# Visualize generated images with their closest MNIST matches
def visualize_generations_with_closest_matches(generated_images, mnist_images):
    """
    Plot each generated image next to its nearest real image.

    One figure row per generated image: the generated image on the left and,
    on the right, the image from ``mnist_images`` with the lowest mean squared
    error against it (the MSE is shown in the title).

    Args:
        generated_images: Images generated from scratch
        mnist_images: Dataset of real MNIST images to find matches in
    """
    num_to_show = len(generated_images)

    # Convert the reference set to NumPy once instead of once per image
    reference = np.array(mnist_images)

    plt.figure(figsize=(12, 2*num_to_show))

    for i in range(num_to_show):
        candidate = np.array(generated_images[i])

        # MSE of the candidate against every reference image
        errors = np.mean((reference - candidate[np.newaxis, :, :]) ** 2, axis=(1, 2))
        best_idx = np.argmin(errors)
        best_mse = errors[best_idx]

        # Left panel: generated image
        plt.subplot(num_to_show, 2, 2*i + 1)
        plt.imshow(generated_images[i], cmap='gray')
        plt.title("Generated")
        plt.axis('off')

        # Right panel: closest real image
        plt.subplot(num_to_show, 2, 2*i + 2)
        plt.imshow(mnist_images[best_idx], cmap='gray')
        plt.title(f"Closest MNIST Match\nMSE: {best_mse:.4f}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize row-by-row generation
def visualize_generation_process(params, forward_fn, seed=42):
    """
    Plot snapshots of a single image while it is generated row by row.

    Row 0 is produced from a small random hidden state passed through SwiGLU
    and the output layer (clipped to [0, 1]); rows 1..27 are the last
    prediction of ``forward_fn`` on the rows generated so far. Six snapshots
    of the partially filled image are shown side by side.
    """
    state_dim = params['diag_weights'].shape[0]

    # Seed state for row 0, drawn from the first subkey of PRNGKey(seed)
    _, state_key = random.split(random.PRNGKey(seed))
    seed_state = random.normal(state_key, (state_dim,)) * 0.01

    # Row 0: SwiGLU + output layer applied directly to the seed state
    projected = swiglu(seed_state, params)
    first_row = jnp.clip(jnp.dot(projected, params['Wo']) + params['bo'], 0.0, 1.0)
    canvas = jnp.zeros((28, 28)).at[0, :].set(first_row)

    # JAX arrays are immutable, so each snapshot can be stored without copying
    snapshots = [canvas]

    for row in range(1, 28):
        next_row = forward_fn(params, canvas[:row, :])[-1]
        canvas = canvas.at[row, :].set(next_row)
        snapshots.append(canvas)

    # Snapshot indices to display: after 1, 5, 10, 15, 20 and 28 rows
    steps_to_show = [0, 4, 9, 14, 19, 27]

    plt.figure(figsize=(15, 3))
    for panel, step in enumerate(steps_to_show, start=1):
        plt.subplot(1, len(steps_to_show), panel)
        plt.imshow(snapshots[step], cmap='gray')
        plt.title(f"After {step+1} rows")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Example usage
# Driver: train the two-layer diagonal RNN, plot the training curves, then
# generate images from scratch and visualise them.
if __name__ == "__main__":
    print("Two-Layer Diagonal RNN for Autoregressive MNIST Generation")
    print("--------------------------------------------------------")
    
    # Single-sequence forward function shared by training and generation
    forward_fn = two_layer_fused_scan_forward_fn
    
    # Train the model and collect per-epoch metrics
    trained_params, metrics = evaluate_model(
        init_params_fn=diagonal_init_params,
        forward_fn=forward_fn,
        epochs=40,
        batch_size=64,
        learning_rate=0.0005
    )
    
    print("Evaluation complete!")
    print(f"Test PSNR: {metrics['test_psnr']:.2f} dB")
    print(f"Training time: {metrics['training_time']:.2f} seconds")
    
    # Plot training history: mean loss per epoch (left), validation PSNR per epoch (right)
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(metrics['train_losses'])
    plt.title('Training Loss (MSE)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    
    plt.subplot(1, 2, 2)
    plt.plot(metrics['val_psnrs'])
    plt.title('Validation PSNR')
    plt.xlabel('Epoch')
    plt.ylabel('PSNR (dB)')
    plt.tight_layout()
    plt.show()
    
    # Generate images completely from scratch
    print("\nGenerating new images completely from scratch...")
    generated_images = generate_images_from_scratch(
        params=trained_params, 
        forward_fn=forward_fn,
        num_images=5
    )
    
    # Test images are the pool searched for closest matches.
    # NOTE(review): this calls load_mnist_data() again although evaluate_model
    # already loaded the dataset; consider returning/reusing x_test instead.
    _, _, x_test, _ = load_mnist_data()
    
    # Visualize generated vs closest MNIST images
    visualize_generations_with_closest_matches(generated_images, x_test)
    
    # Visualize the row-by-row generation process
    print("\nVisualizing the row-by-row generation process...")
    visualize_generation_process(trained_params, forward_fn)

Two-Layer Diagonal RNN for Autoregressive MNIST Generation
--------------------------------------------------------
Loading MNIST dataset from Hugging Face with JAX format...


x_train shape: (60000, 28, 28)
y_train shape: (60000, 10)
x_test shape: (10000, 28, 28)
y_test shape: (10000, 10)
Initializing model parameters...
Starting training for 40 epochs...


: 