# **Trainable Distributions**

## 📑 Table of Contents

- **1. [Foundation - What Are Trainable Distributions?](#1--foundation---what-are-trainable-distributions)**
    - 1.1 [The Big Picture](#11-the-big-picture)

- **2. [Setup & Your First Trainable Distribution](#2--setup--your-first-trainable-distribution)**
    - 2.1 [Essential Imports & Setup](#21-essential-imports--setup)  
    - 2.2 [Creating Your First Trainable Distribution](#22-creating-your-first-trainable-distribution)

- **3. [Training Trainable Distributions](#3--training-trainable-distributions)**
    - 3.1 [Complete Training Example - Learning from Data](#31-complete-training-example---learning-from-data)  
    - 3.2 [Loss Function and Training Setup](#32-loss-function-and-training-setup)  
    - 3.3 [Training Loop with Progress Tracking](#33-training-loop-with-progress-tracking)

- **4. [Advanced Trainable Distribution Patterns](#4--advanced-trainable-distribution-patterns)**
    - 4.1 [Multiple Trainable Parameters](#41-multiple-trainable-parameters)  
    - 4.2 [Trainable Multivariate Distributions](#42-trainable-multivariate-distributions)

- **5. [Probabilistic Neural Network Integration](#5--probabilistic-neural-network-integration)**
    - 5.1 [Trainable Distribution Layers](#51-trainable-distribution-layers)  
    - 5.2 [Custom Training with Uncertainty](#52-custom-training-with-uncertainty)

- **6. [Expert Applications](#6--expert-applications)**
    - 6.1 [Bayesian Neural Networks with Trainable Priors](#61-bayesian-neural-networks-with-trainable-priors)  
    - 6.2 [Generative Models with Trainable Distributions](#62-generative-models-with-trainable-distributions)

- **7. [Complete Reference Guide](#7--complete-reference-guide)**
    - 7.1 [Trainable Distribution Creation Patterns](#71-trainable-distribution-creation-patterns)  
    - 7.2 [Training Patterns Reference](#72-training-patterns-reference)  
    - 7.3 [Common Parameter Constraints](#73-common-parameter-constraints)

- **[Final Notes](#final-notes)**

## 1. **Foundation - What Are Trainable Distributions?**

### 1.1 **The Big Picture**

Trainable distributions are TensorFlow Probability distributions whose **parameters can be learned through gradient-based optimization**. Instead of fixed values, their parameters are `tf.Variable` objects that an optimizer updates during training, so the distribution can be fit to data, typically by minimizing the negative log-likelihood.

**Key Insight**: Trainable distributions transform static probability models into **dynamic, learnable components** that can be integrated seamlessly into neural networks and optimized using standard deep learning techniques like backpropagation.

**The Core Transformation**:
- **Static**: `tfd.Normal(loc=0., scale=1.)` → Fixed parameters
- **Trainable**: `tfd.Normal(loc=tf.Variable(0.), scale=tf.Variable(1.))` → Learnable parameters

**Why It Matters**: Because the parameters are ordinary variables, **probabilistic deep learning** tasks such as uncertainty quantification, generative modeling, and variational Bayesian inference can be trained end to end with gradient descent.

## 2. **Setup & Your First Trainable Distribution**

### 2.1 **Essential Imports & Setup**

In [2]:
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

# Standard alias for distributions
tfd = tfp.distributions

# Set random seed for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

### 2.2 **Creating Your First Trainable Distribution**

In [3]:
# Create trainable Normal distribution
normal = tfd.Normal(loc=tf.Variable(0., name='loc'), scale=1.)

# Check trainable variables
print("Trainable variables:", normal.trainable_variables)
# Output: (<tf.Variable 'loc:0' shape=() dtype=float32, numpy=0.0>,)

print("Distribution:", normal)
print("Current location parameter:", normal.loc.numpy())  # 0.0
print("Fixed scale parameter:", normal.scale.numpy())     # 1.0


Trainable variables: (<tf.Variable 'loc:0' shape=() dtype=float32, numpy=0.0>,)
Distribution: tfp.distributions.Normal("Normal", batch_shape=[], event_shape=[], dtype=float32)
Current location parameter: 0.0
Fixed scale parameter: 1.0


**Key Components**:
- **`tf.Variable(0., name='loc')`**: Trainable location parameter, initialized to 0
- **`scale=1.`**: Fixed scale parameter (not trainable)
- **`normal.trainable_variables`**: Tuple containing all trainable parameters


## 3. **Training Trainable Distributions**

### 3.1 **Complete Training Example - Learning from Data**

In [4]:
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

tfd = tfp.distributions

# Create some sample data (you can replace this with your actual data)
np.random.seed(42)  # For reproducibility
x_samples = tf.constant(np.random.normal(2.0, 1.5, 1000), dtype=tf.float32)

print(f"Sample data statistics:")
print(f"  Sample mean: {tf.reduce_mean(x_samples).numpy():.4f}")
print(f"  Sample std: {tf.math.reduce_std(x_samples).numpy():.4f}")
print(f"  Sample size: {len(x_samples)}")

# Define number of training steps
num_steps = 1000

# Create trainable distribution
normal = tfd.Normal(loc=tf.Variable(0., name='loc'), scale=1.)
print(f"Initial location parameter: {normal.loc.numpy():.4f}")

Sample data statistics:
  Sample mean: 2.0290
  Sample std: 1.4681
  Sample size: 1000
Initial location parameter: 0.0000


### 3.2 **Loss Function and Training Setup**

In [5]:
# Negative log-likelihood (NLL) loss
def nll(x_train):
    """
    Compute the mean negative log-likelihood of the data under the current distribution
    """
    log_likelihood = normal.log_prob(x_train)  # Shape: (1000,)
    mean_log_likelihood = tf.reduce_mean(log_likelihood)
    return -mean_log_likelihood  # Negated so that minimizing the loss maximizes the likelihood

@tf.function  # Compile to a graph for performance
def get_loss_and_grads(x_train):
    """
    Take a batch of training examples and return the loss and its gradients
    """
    with tf.GradientTape() as tape:
        # Trainable tf.Variables are watched automatically, so tape.watch is not needed
        loss = nll(x_train)
    grads = tape.gradient(loss, normal.trainable_variables)
    return loss, grads

# Initialize optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)
print("Training setup complete!")

Training setup complete!
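
The loss values this setup produces can be checked by hand. For a Normal with `scale=1`, the mean NLL is ½·log(2π) + ½·mean((x − loc)²), and the second term splits into the sample variance plus the squared distance between `loc` and the sample mean. The sketch below is framework-free and plugs in the sample statistics printed in section 3.1 (mean 2.0290, std 1.4681); the function name `unit_scale_nll` is illustrative and not part of TFP.

```python
import math

def unit_scale_nll(loc, sample_mean, sample_std):
    """Mean NLL of data under Normal(loc, 1), computed from summary statistics.

    Uses mean((x - loc)^2) = sample_std^2 + (sample_mean - loc)^2,
    where sample_std is the population (ddof=0) standard deviation.
    """
    mean_sq_error = sample_std ** 2 + (sample_mean - loc) ** 2
    return 0.5 * math.log(2.0 * math.pi) + 0.5 * mean_sq_error

# Sample statistics printed in section 3.1.
sample_mean, sample_std = 2.0290, 1.4681

nll_at_start = unit_scale_nll(0.0, sample_mean, sample_std)            # loc = 0
nll_at_optimum = unit_scale_nll(sample_mean, sample_mean, sample_std)  # loc = sample mean

print(f"NLL at loc=0:           {nll_at_start:.4f}")
print(f"NLL at loc=sample mean: {nll_at_optimum:.4f}")
```

These two values agree with the first and the converged loss in the training log of section 3.3 (4.0550 and 1.9966). With the scale fixed at 1, the loss cannot drop below ½·log(2π) + ½·variance, however long training runs.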


### 3.3 **Training Loop with Progress Tracking**

In [6]:
# Training loop with progress tracking
print("\nStarting training...")
print("=" * 50)

for step in range(num_steps):
    loss, grads = get_loss_and_grads(x_samples)
    optimizer.apply_gradients(zip(grads, normal.trainable_variables))
    
    # Print progress every 100 steps
    if step % 100 == 0:
        current_loc = normal.loc.numpy()
        print(f"Step {step:4d}: Loss = {loss:.4f}, Estimated loc = {current_loc:.4f}")

# Final results
print("=" * 50)
print(f"\nTraining completed!")
print(f"Final estimated location parameter: {normal.loc.numpy():.4f}")
print(f"True mean of sample data: {tf.reduce_mean(x_samples).numpy():.4f}")
print(f"Estimation error: {abs(normal.loc.numpy() - tf.reduce_mean(x_samples).numpy()):.4f}")


Starting training...
Step    0: Loss = 4.0550, Estimated loc = 0.1014
Step  100: Loss = 1.9967, Estimated loc = 2.0176
Step  200: Loss = 1.9966, Estimated loc = 2.0289
Step  300: Loss = 1.9966, Estimated loc = 2.0290
Step  400: Loss = 1.9966, Estimated loc = 2.0290
Step  500: Loss = 1.9966, Estimated loc = 2.0290
Step  600: Loss = 1.9966, Estimated loc = 2.0290
Step  700: Loss = 1.9966, Estimated loc = 2.0290
Step  800: Loss = 1.9966, Estimated loc = 2.0290
Step  900: Loss = 1.9966, Estimated loc = 2.0290

Training completed!
Final estimated location parameter: 2.0290
True mean of sample data: 2.0290
Estimation error: 0.0000
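
The trajectory in this log follows from a one-line recurrence. With `scale=1`, the gradient of the mean NLL with respect to `loc` is `loc - mean(x)`, so each SGD step moves `loc` a fraction `learning_rate` of the way toward the sample mean. A framework-free sketch, assuming the sample mean 2.0290 and learning rate 0.05 used above (the helper name `sgd_on_loc` is illustrative):

```python
def sgd_on_loc(sample_mean, learning_rate, num_updates, loc=0.0):
    """Run plain SGD on the loc of Normal(loc, 1) under the mean NLL loss."""
    for _ in range(num_updates):
        grad = loc - sample_mean      # d/d(loc) of 0.5 * mean((x - loc)^2)
        loc -= learning_rate * grad
    return loc

sample_mean, learning_rate = 2.0290, 0.05

# The log prints loc after that step's update, so "Step k" has seen k + 1 updates.
for step in (0, 100, 200):
    loc = sgd_on_loc(sample_mean, learning_rate, num_updates=step + 1)
    print(f"Step {step:4d}: loc = {loc:.4f}")
```

In closed form, `loc` after `t` updates is `mean * (1 - (1 - learning_rate)**t)`, which is why the estimate is already within about 0.01 of the sample mean after 100 steps and the remaining steps change nothing visible.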


## 4. **Advanced Trainable Distribution Patterns**

### 4.1 **Multiple Trainable Parameters**

In [7]:
class FullyTrainableNormal:
    """
    Normal distribution with both loc and scale trainable
    """
    def __init__(self, initial_loc=0., initial_scale=1.):
        self.raw_loc = tf.Variable(initial_loc, name='raw_loc')
        self.raw_scale = tf.Variable(
            tf.math.log(tf.math.expm1(initial_scale)), 
            name='raw_scale'
        )
    
    @property
    def distribution(self):
        # Ensure scale is always positive using softplus
        loc = self.raw_loc
        scale = tf.nn.softplus(self.raw_scale) + 1e-6
        return tfd.Normal(loc=loc, scale=scale)
    
    @property
    def trainable_variables(self):
        return [self.raw_loc, self.raw_scale]
    
    def log_prob(self, x):
        return self.distribution.log_prob(x)
    
    def sample(self, sample_shape=()):
        return self.distribution.sample(sample_shape)

# Usage example
fully_trainable = FullyTrainableNormal(initial_loc=0., initial_scale=1.)
print("Trainable variables:", len(fully_trainable.trainable_variables))

# Training both parameters
def train_both_parameters(data, model, steps=500):
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
    
    @tf.function
    def train_step():
        with tf.GradientTape() as tape:
            loss = -tf.reduce_mean(model.log_prob(data))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss
    
    for step in range(steps):
        loss = train_step()
        if step % 100 == 0:
            current_dist = model.distribution
            print(f"Step {step}: Loss = {loss:.4f}, "
                  f"loc = {current_dist.loc.numpy():.4f}, "
                  f"scale = {current_dist.scale.numpy():.4f}")
    
    return model

# Train on the same sample data
trained_model = train_both_parameters(x_samples, fully_trainable)
final_dist = trained_model.distribution
print(f"\nFinal parameters:")
print(f"  Learned loc: {final_dist.loc.numpy():.4f}")
print(f"  Learned scale: {final_dist.scale.numpy():.4f}")
print(f"  Sample std: {tf.math.reduce_std(x_samples).numpy():.4f}")

Trainable variables: 2
Step 0: Loss = 4.0550, loc = 0.0100, scale = 1.0063
Step 100: Loss = 2.1725, loc = 0.7820, scale = 1.4595
Step 200: Loss = 1.9307, loc = 1.2481, scale = 1.5905
Step 300: Loss = 1.8498, loc = 1.5779, scale = 1.5931
Step 400: Loss = 1.8167, loc = 1.8013, scale = 1.5502

Final parameters:
  Learned loc: 1.9312
  Learned scale: 1.5085
  Sample std: 1.4681
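
The initializer `tf.math.log(tf.math.expm1(initial_scale))` is the inverse of softplus: it picks the raw value whose softplus equals the requested scale. A small framework-free sketch of that round trip (function names are illustrative; the `1e-6` offset used in the class is left out so the inverse is exact):

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x)); the result is always positive."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def softplus_inverse(y):
    """Raw value r such that softplus(r) == y, for y > 0."""
    return math.log(math.expm1(y))

for target_scale in (0.1, 1.0, 5.0):
    raw = softplus_inverse(target_scale)
    print(f"target={target_scale:4.1f}  raw={raw:8.4f}  softplus(raw)={softplus(raw):.4f}")

# Even a strongly negative raw value still maps to a positive scale.
print(softplus(-20.0) > 0.0)
```

Initializing the raw variable this way means the distribution starts at exactly the scale the caller asked for, while the optimizer remains free to move the raw value anywhere on the real line.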


### 4.2 **Trainable Multivariate Distributions**

In [8]:
def create_trainable_multivariate_normal(event_size, use_full_covariance=False):
    """
    Create trainable multivariate normal distribution
    """
    # Trainable location parameter
    loc = tf.Variable(
        tf.zeros(event_size), 
        name='mvn_loc'
    )
    
    if use_full_covariance:
        # Full covariance matrix using Cholesky decomposition
        scale_tril_size = event_size * (event_size + 1) // 2
        raw_scale_tril = tf.Variable(
            tf.zeros(scale_tril_size),
            name='raw_scale_tril'
        )
        
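        # Transform to a valid lower-triangular matrix.
        # Note: the bijector is applied eagerly, so scale_tril is a plain tensor
        # and the distribution does not track raw_scale_tril as trainable.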
        scale_tril = tfp.bijectors.FillScaleTriL()(raw_scale_tril)
        distribution = tfd.MultivariateNormalTriL(
            loc=loc, 
            scale_tril=scale_tril
        )
    else:
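        # Diagonal covariance (simpler).
        # Note: softplus is applied eagerly below, so scale_diag is a plain tensor
        # and raw_scale_diag is not tracked by the distribution. Only loc is
        # trained, which is why the output reports a single trainable variable.
        # Wrapping the scale in tfp.util.TransformedVariable (see 7.1) keeps it trainable.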
        raw_scale_diag = tf.Variable(
            tf.zeros(event_size),
            name='raw_scale_diag'
        )
        
        scale_diag = tf.nn.softplus(raw_scale_diag) + 1e-6
        distribution = tfd.MultivariateNormalDiag(
            loc=loc,
            scale_diag=scale_diag
        )
    
    return distribution

# Example usage
mv_dist = create_trainable_multivariate_normal(event_size=3, use_full_covariance=False)
print("Multivariate trainable variables:", len(mv_dist.trainable_variables))
print("Variable shapes:", [var.shape for var in mv_dist.trainable_variables])

# Generate multivariate training data
true_mean = [1., -0.5, 2.]
true_cov = [[1., 0.3, 0.1], [0.3, 0.8, -0.2], [0.1, -0.2, 1.2]]
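# Multiply by the transposed Cholesky factor: rows of z @ L^T have covariance L @ L^T = true_cov
mv_data = tf.matmul(tf.random.normal([500, 3]), tf.linalg.cholesky(true_cov), transpose_b=True) + true_mean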

# Train multivariate distribution
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

@tf.function
def mv_train_step():
    with tf.GradientTape() as tape:
        loss = -tf.reduce_mean(mv_dist.log_prob(mv_data))
    grads = tape.gradient(loss, mv_dist.trainable_variables)
    optimizer.apply_gradients(zip(grads, mv_dist.trainable_variables))
    return loss

print("\nTraining multivariate distribution...")
for step in range(300):
    loss = mv_train_step()
    if step % 50 == 0:
        print(f"Step {step}: Loss = {loss:.4f}")

print(f"\nLearned mean: {mv_dist.mean().numpy()}")
print(f"True mean: {true_mean}")

Multivariate trainable variables: 1
Variable shapes: [TensorShape([3])]

Training multivariate distribution...
Step 0: Loss = 10.0395
Step 50: Loss = 7.3503
Step 100: Loss = 5.9957
Step 150: Loss = 5.3188
Step 200: Loss = 4.9917
Step 250: Loss = 4.8454

Learned mean: [ 0.9956153  -0.46135303  1.8014729 ]
True mean: [1.0, -0.5, 2.0]
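
To draw rows with a given covariance Σ from standard-normal rows `z`, multiply by the transpose of the Cholesky factor: `z @ Lᵀ` has covariance `L Lᵀ = Σ`, whereas `z @ L` has covariance `Lᵀ L`, which is generally a different matrix. The framework-free sketch below checks both products for the `true_cov` used above; the helper names are illustrative.

```python
import math

def cholesky(a):
    """Lower-triangular L with L @ L^T == a, for symmetric positive-definite a."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L

def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

true_cov = [[1.0, 0.3, 0.1], [0.3, 0.8, -0.2], [0.1, -0.2, 1.2]]
L = cholesky(true_cov)

cov_of_z_Lt = matmul(L, transpose(L))   # covariance of rows of z @ L^T
cov_of_z_L = matmul(transpose(L), L)    # covariance of rows of z @ L

print("L L^T diagonal:", [round(cov_of_z_Lt[i][i], 4) for i in range(3)])
print("L^T L diagonal:", [round(cov_of_z_L[i][i], 4) for i in range(3)])
```

Both products have the same trace, so the total variance is unchanged, but the individual variances and the correlations differ.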


## 5. **Probabilistic Neural Network Integration**

### 5.1 **Trainable Distribution Layers**

In [9]:
class ProbabilisticDense(tf.keras.layers.Layer):
    """
    Dense layer that outputs a trainable distribution
    """
    def __init__(self, units, distribution_type='normal', **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.distribution_type = distribution_type
        
    def build(self, input_shape):
        if self.distribution_type == 'normal':
            # Parameters for Normal distribution
            self.loc_layer = tf.keras.layers.Dense(self.units, name='loc_layer')
            self.scale_layer = tf.keras.layers.Dense(self.units, name='scale_layer')
        elif self.distribution_type == 'bernoulli':
            # Parameters for Bernoulli distribution
            self.logits_layer = tf.keras.layers.Dense(self.units, name='logits_layer')
        
        super().build(input_shape)
    
    def call(self, inputs):
        if self.distribution_type == 'normal':
            loc = self.loc_layer(inputs)
            raw_scale = self.scale_layer(inputs)
            scale = tf.nn.softplus(raw_scale) + 1e-6
            
            # Create batched normal distribution
            batched_normal = tfd.Normal(loc=loc, scale=scale)
            # Convert to multivariate for joint operations
            return tfd.Independent(batched_normal, reinterpreted_batch_ndims=1)
            
        elif self.distribution_type == 'bernoulli':
            logits = self.logits_layer(inputs)
            batched_bernoulli = tfd.Bernoulli(logits=logits)
            return tfd.Independent(batched_bernoulli, reinterpreted_batch_ndims=1)

# Example probabilistic model
class ProbabilisticRegressor(tf.keras.Model):
    def __init__(self, hidden_units=64, output_units=3):
        super().__init__()
        self.hidden1 = tf.keras.layers.Dense(hidden_units, activation='relu')
        self.hidden2 = tf.keras.layers.Dense(hidden_units, activation='relu')
        self.output_layer = ProbabilisticDense(output_units, distribution_type='normal')
    
    def call(self, inputs):
        x = self.hidden1(inputs)
        x = self.hidden2(x)
        return self.output_layer(x)  # Returns distribution, not point estimate

# Create and use probabilistic model
model = ProbabilisticRegressor(hidden_units=32, output_units=2)
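
`tfd.Independent(..., reinterpreted_batch_ndims=1)` turns the last batch dimension into an event dimension, so `log_prob` returns one value per example: the sum of the per-output log-densities. A framework-free sketch of that reduction, with made-up locations, scales, and targets:

```python
import math

def normal_log_prob(x, loc, scale):
    """Log-density of a scalar Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

# Hypothetical predictions for a batch of 2 examples with 3 outputs each.
loc = [[0.0, 1.0, -1.0], [0.5, 0.5, 0.5]]
scale = [[1.0, 0.5, 2.0], [1.0, 1.0, 1.0]]
y = [[0.1, 0.9, 0.0], [0.5, 0.0, 1.0]]

# Without Independent: one log-prob per output, shape (2, 3).
per_output = [[normal_log_prob(y[i][j], loc[i][j], scale[i][j]) for j in range(3)]
              for i in range(2)]

# With Independent(reinterpreted_batch_ndims=1): sum over the last axis, shape (2,).
per_example = [sum(row) for row in per_output]

print(len(per_output), len(per_output[0]))
print([round(v, 4) for v in per_example])
```

Because the result has one entry per example, a loss can average `log_prob` over the batch directly, without first summing over output dimensions.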

### 5.2 **Custom Training with Uncertainty**

In [10]:
def train_probabilistic_model(model, x_train, y_train, epochs=50):
    """
    Train probabilistic model with negative log-likelihood loss
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    
    @tf.function
    def train_step(x_batch, y_batch):
        with tf.GradientTape() as tape:
            # Model outputs distribution
            pred_dist = model(x_batch, training=True)
            
            # Negative log-likelihood loss
            nll_loss = -tf.reduce_mean(pred_dist.log_prob(y_batch))
            
            # Add regularization
            reg_loss = sum(model.losses) if model.losses else 0.
            total_loss = nll_loss + 0.01 * reg_loss
        
        gradients = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        
        return total_loss, nll_loss
    
    # Training loop
    for epoch in range(epochs):
        total_loss, nll_loss = train_step(x_train, y_train)
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch:2d}: Total Loss = {total_loss:.4f}, "
                  f"NLL = {nll_loss:.4f}")
    
    return model

# Generate synthetic regression data
n_samples = 1000
x_train = tf.random.normal([n_samples, 5])
true_weights = tf.constant([[1., -0.5], [0.8, 1.2], [0., 0.5], [-0.3, 0.], [0.7, -0.8]])
y_train = x_train @ true_weights + tf.random.normal([n_samples, 2]) * 0.1

# Train the model
print("Training probabilistic regressor...")
trained_model = train_probabilistic_model(model, x_train, y_train, epochs=100)

# Make probabilistic predictions
x_test = tf.random.normal([10, 5])
pred_dist = trained_model(x_test)

print(f"\nPredictions with uncertainty:")
print(f"Mean predictions shape: {pred_dist.mean().shape}")
print(f"Std predictions shape: {pred_dist.stddev().shape}")
print(f"Sample predictions:\n{pred_dist.sample(5).numpy()}")

Training probabilistic regressor...
Epoch  0: Total Loss = 9.2536, NLL = 9.2536
Epoch 10: Total Loss = 5.9630, NLL = 5.9630
Epoch 20: Total Loss = 4.5082, NLL = 4.5082
Epoch 30: Total Loss = 3.7724, NLL = 3.7724
Epoch 40: Total Loss = 3.3635, NLL = 3.3635
Epoch 50: Total Loss = 3.1161, NLL = 3.1161
Epoch 60: Total Loss = 2.9438, NLL = 2.9438
Epoch 70: Total Loss = 2.8014, NLL = 2.8014
Epoch 80: Total Loss = 2.6649, NLL = 2.6649
Epoch 90: Total Loss = 2.5198, NLL = 2.5198

Predictions with uncertainty:
Mean predictions shape: (10, 2)
Std predictions shape: (10, 2)
Sample predictions:
[[[ 0.82389325 -1.4043255 ]
  [-0.96353745  0.6602427 ]
  [ 1.7169242  -0.34885058]
  [ 1.7561765  -0.45557353]
  [-0.3937552   2.03724   ]
  [ 0.3635568  -0.01658844]
  [-0.39817008 -0.04280734]
  [-0.99177986 -0.29655802]
  [ 1.1118047  -1.3519115 ]
  [ 2.4875593  -1.9801378 ]]

 [[ 2.3026707   0.6278113 ]
  [ 0.03576642 -1.5997378 ]
  [ 0.31301594  0.61786246]
  [-0.17076135 -0.12737668]
  [ 1.9682679   

## 6. **Expert Applications**

### 6.1 **Bayesian Neural Networks with Trainable Priors**

In [11]:
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

class BayesianLayer(tf.keras.layers.Layer):
    """
    Bayesian neural network layer with trainable weight distributions
    """
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        
    def build(self, input_shape):
        input_dim = input_shape[-1]
        
        # Variational posterior parameters (trainable); the prior is the fixed Normal(0, 1) in call()
        self.weight_loc = tf.Variable(
            tf.random.normal([input_dim, self.units], stddev=0.1),
            name='weight_loc'
        )
        self.weight_raw_scale = tf.Variable(
            tf.random.normal([input_dim, self.units], stddev=0.1) - 1,
            name='weight_raw_scale'
        )
        
        self.bias_loc = tf.Variable(
            tf.zeros([self.units]),
            name='bias_loc'
        )
        self.bias_raw_scale = tf.Variable(
            tf.random.normal([self.units], stddev=0.1) - 1,
            name='bias_raw_scale'
        )
        
        super().build(input_shape)
    
    def call(self, inputs, training=None):
        # Create weight and bias distributions
        weight_scale = tf.nn.softplus(self.weight_raw_scale) + 1e-6
        bias_scale = tf.nn.softplus(self.bias_raw_scale) + 1e-6
        
        weight_dist = tfd.Normal(loc=self.weight_loc, scale=weight_scale)
        bias_dist = tfd.Normal(loc=self.bias_loc, scale=bias_scale)
        
        # Sample weights and biases
        weights = weight_dist.sample()
        biases = bias_dist.sample()
        
        # Standard linear transformation
        output = tf.matmul(inputs, weights) + biases
        
        if training:
            # Add KL divergence to losses for regularization
            prior_weight = tfd.Normal(loc=0., scale=1.)
            prior_bias = tfd.Normal(loc=0., scale=1.)
            
            weight_kl = tf.reduce_sum(tfd.kl_divergence(weight_dist, prior_weight))
            bias_kl = tf.reduce_sum(tfd.kl_divergence(bias_dist, prior_bias))
            
            self.add_loss(weight_kl + bias_kl)
        
        return output


# Bayesian neural network model
class BayesianNN(tf.keras.Model):
    def __init__(self, hidden_units=50, output_units=1):
        super().__init__()
        self.hidden = BayesianLayer(hidden_units)
        self.output_layer = BayesianLayer(output_units)
    
    def call(self, inputs, training=None):
        x = tf.nn.relu(self.hidden(inputs, training=training))
        return self.output_layer(x, training=training)


# Create and demonstrate Bayesian NN
bayesian_model = BayesianNN(hidden_units=20, output_units=1)

# Generate data
x_data = tf.linspace(-3., 3., 100)[:, None]
y_data = 0.5 * x_data**3 + tf.random.normal([100, 1]) * 0.3

# BUILD THE MODEL - This is crucial!
_ = bayesian_model(x_data[:1], training=True)

print("Training Bayesian Neural Network...")
optimizer = tf.keras.optimizers.Adam(0.01)

for step in range(500):
    with tf.GradientTape() as tape:
        # Clear accumulated losses from previous iterations
        bayesian_model.losses.clear()
        
        pred = bayesian_model(x_data, training=True)
        mse_loss = tf.reduce_mean((pred - y_data)**2)
        kl_loss = sum(bayesian_model.losses)
        total_loss = mse_loss + 0.01 * kl_loss
    
    grads = tape.gradient(total_loss, bayesian_model.trainable_variables)
    
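    # Filter out None gradients and create valid gradient-variable pairs.
    # Caution: this filter also hides the failure case where no variable receives
    # a gradient or trainable_variables is empty (for example, if Keras does not
    # track attributes created with a bare tf.Variable; self.add_weight is the
    # tracked alternative). No update is applied then, and the KL term stays
    # constant from step to step, as in the log below.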
    grads_and_vars = [
        (grad, var) for grad, var in zip(grads, bayesian_model.trainable_variables) 
        if grad is not None
    ]
    
    # Only apply gradients if we have valid gradients
    if grads_and_vars:
        optimizer.apply_gradients(grads_and_vars)
    
    if step % 100 == 0:
        print(f"Step {step}: MSE = {mse_loss:.4f}, KL = {kl_loss:.4f}")

# Uncertainty quantification
predictions = []
for _ in range(100):
    pred = bayesian_model(x_data, training=False)
    predictions.append(pred)

predictions = tf.stack(predictions)
mean_pred = tf.reduce_mean(predictions, axis=0)
std_pred = tf.math.reduce_std(predictions, axis=0)

print(f"\nPredictive uncertainty shape: {std_pred.shape}")
print(f"Average uncertainty: {tf.reduce_mean(std_pred):.4f}")

Training Bayesian Neural Network...
Step 0: MSE = 26.2266, KL = 42.2583
Step 100: MSE = 24.5402, KL = 42.2583
Step 200: MSE = 30.4602, KL = 42.2583
Step 300: MSE = 30.4573, KL = 42.2583
Step 400: MSE = 25.6658, KL = 42.2583

Predictive uncertainty shape: (100, 1)
Average uncertainty: 0.6459
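
The KL term that `tfd.kl_divergence` computes between two univariate Normals has a closed form. Against the standard-normal prior used here it is `-log(σ) + (σ² + μ²)/2 - 1/2` per parameter. A framework-free sketch (the function name is illustrative), evaluated at values typical of this layer's initialization, `μ = 0` and `σ = softplus(-1)`:

```python
import math

def kl_normal_to_standard(mu, sigma):
    """KL( Normal(mu, sigma) || Normal(0, 1) ) for scalar mu and sigma > 0."""
    return -math.log(sigma) + 0.5 * (sigma ** 2 + mu ** 2) - 0.5

sigma_init = math.log1p(math.exp(-1.0))   # softplus(-1), about 0.3133

print(f"KL at the prior itself: {kl_normal_to_standard(0.0, 1.0):.4f}")
print(f"KL at initial values:   {kl_normal_to_standard(0.0, sigma_init):.4f}")
```

The network above has 61 weight and bias entries (1×20 + 20 + 20×1 + 1), so roughly 0.71 per entry at initialization gives a total near 43, in the same range as the logged KL of 42.26.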


### 6.2 **Generative Models with Trainable Distributions**

In [12]:
class SimpleVAE(tf.keras.Model):
    """
    Variational Autoencoder with trainable encoder and decoder distributions
    """
    def __init__(self, latent_dim=10, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        
        # Encoder
        self.encoder_hidden = tf.keras.layers.Dense(50, activation='relu')
        self.encoder_loc = tf.keras.layers.Dense(latent_dim)
        self.encoder_raw_scale = tf.keras.layers.Dense(latent_dim)
        
        # Decoder  
        self.decoder_hidden = tf.keras.layers.Dense(50, activation='relu')
        self.decoder_logits = tf.keras.layers.Dense(784)  # For MNIST-like data
    
    def encode(self, x):
        """Encode input to latent distribution parameters"""
        h = self.encoder_hidden(x)
        loc = self.encoder_loc(h)
        raw_scale = self.encoder_raw_scale(h)
        scale = tf.nn.softplus(raw_scale) + 1e-6
        
        return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
    
    def decode(self, z):
        """Decode latent samples to reconstruction distribution"""
        h = self.decoder_hidden(z)
        logits = self.decoder_logits(h)
        
        return tfd.Independent(
            tfd.Bernoulli(logits=logits),
            reinterpreted_batch_ndims=1
        )
    
    def call(self, inputs):
        # Encode
        posterior = self.encode(inputs)
        
        # Sample from posterior
        z = posterior.sample()
        
        # Decode
        reconstruction_dist = self.decode(z)
        
        return posterior, reconstruction_dist
    
    def compute_loss(self, x):
        """Compute VAE loss (ELBO)"""
        posterior, reconstruction_dist = self(x)
        
        # self(x) has already sampled z from the posterior and decoded it
        
        # Reconstruction loss
        reconstruction_loss = -tf.reduce_mean(reconstruction_dist.log_prob(x))
        
        # KL divergence
        prior = tfd.MultivariateNormalDiag(
            loc=tf.zeros(self.latent_dim),
            scale_diag=tf.ones(self.latent_dim)
        )
        kl_loss = tf.reduce_mean(tfd.kl_divergence(posterior, prior))
        
        # ELBO = -reconstruction_loss - kl_loss (we minimize negative ELBO)
        return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss

# Example usage (with dummy data)
vae = SimpleVAE(latent_dim=5)
dummy_data = tf.random.uniform([64, 784], maxval=1.0)  # Values in [0, 1) standing in for pixel intensities (not binarized)

total_loss, recon_loss, kl_loss = vae.compute_loss(dummy_data)
print(f"VAE Loss components:")
print(f"  Total: {total_loss:.4f}")
print(f"  Reconstruction: {recon_loss:.4f}")
print(f"  KL Divergence: {kl_loss:.4f}")

VAE Loss components:
  Total: 547.0978
  Reconstruction: 544.7975
  KL Divergence: 2.3003
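
The reconstruction figure can be sanity-checked without running the model. A Bernoulli log-likelihood parameterized by logits is `x * logits - softplus(logits)`, so a decoder whose logits are all near zero (a rough assumption for a freshly initialized network) pays `ln 2` per pixel regardless of the target. A framework-free sketch, with illustrative function names:

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bernoulli_nll_from_logits(x, logits):
    """Negative log-likelihood of a target x in [0, 1] under Bernoulli(logits)."""
    return softplus(logits) - x * logits

num_pixels = 784

# With logits = 0 the per-pixel NLL is ln 2, whatever the target value is.
for target in (0.0, 0.3, 1.0):
    total = num_pixels * bernoulli_nll_from_logits(target, 0.0)
    print(f"target={target:.1f}  total NLL over {num_pixels} pixels = {total:.2f}")
```

That baseline of 784 · ln 2 ≈ 543.4 is close to the reconstruction loss of 544.80 reported above, which suggests the untrained decoder is predicting roughly 0.5 for every pixel.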


## 7. **Complete Reference Guide**

### 7.1 **Trainable Distribution Creation Patterns**

In [13]:
# === BASIC TRAINABLE DISTRIBUTIONS ===
# Single trainable parameter
tfd.Normal(loc=tf.Variable(0.), scale=1.)
tfd.Exponential(rate=tf.Variable(1.))
tfd.Beta(concentration1=tf.Variable(1.), concentration0=tf.Variable(1.))

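# Multiple trainable parameters with constraints.
# DeferredTensor re-applies the transform on every read, so the raw variable
# stays visible in the distribution's trainable_variables.
def safe_normal():
    return tfd.Normal(
        loc=tf.Variable(0.),
        scale=tfp.util.DeferredTensor(
            tf.Variable(0.), lambda raw: tf.nn.softplus(raw) + 1e-6
        )
    )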

# === ADVANCED TRAINABLE PATTERNS ===
# Using TransformedVariable for constraints
tfd.Normal(
    loc=tf.Variable(0.),
    scale=tfp.util.TransformedVariable(
        1., bijector=tfp.bijectors.Exp(), name='scale'
    )
)

# Multivariate with trainable parameters
# (the scale is deferred so that its raw variable remains trainable)
def trainable_multivariate_normal(event_size):
    return tfd.MultivariateNormalDiag(
        loc=tf.Variable(tf.zeros(event_size)),
        scale_diag=tfp.util.DeferredTensor(
            tf.Variable(tf.zeros(event_size)),
            lambda raw: tf.nn.softplus(raw) + 1e-6
        )
    )

# === NEURAL NETWORK INTEGRATION ===
# Probabilistic layer outputs
def probabilistic_output_layer(inputs, output_size):
    loc = tf.keras.layers.Dense(output_size)(inputs)
    raw_scale = tf.keras.layers.Dense(output_size)(inputs)
    scale = tf.nn.softplus(raw_scale) + 1e-6
    
    return tfd.Independent(
        tfd.Normal(loc=loc, scale=scale),
        reinterpreted_batch_ndims=1
    )

### 7.2 **Training Patterns Reference**

In [14]:
# === BASIC TRAINING SETUP ===
def setup_basic_training(distribution, data):
    optimizer = tf.keras.optimizers.Adam(0.01)
    
    @tf.function
    def train_step():
        with tf.GradientTape() as tape:
            loss = -tf.reduce_mean(distribution.log_prob(data))
        grads = tape.gradient(loss, distribution.trainable_variables)
        optimizer.apply_gradients(zip(grads, distribution.trainable_variables))
        return loss
    
    return train_step

# === ADVANCED TRAINING WITH REGULARIZATION ===
def setup_regularized_training(model, data, kl_weight=0.01):
    optimizer = tf.keras.optimizers.Adam(0.001)
    
    @tf.function
    def train_step():
        with tf.GradientTape() as tape:
            dist = model(data, training=True)
            nll_loss = -tf.reduce_mean(dist.log_prob(data))
            reg_loss = sum(model.losses)
            total_loss = nll_loss + kl_weight * reg_loss
        
        grads = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return total_loss, nll_loss, reg_loss
    
    return train_step

# === CUSTOM LOSS FUNCTIONS ===
class NegativeLogLikelihood(tf.keras.losses.Loss):
    def call(self, y_true, y_pred_dist):
        return -tf.reduce_mean(y_pred_dist.log_prob(y_true))

class ELBOLoss(tf.keras.losses.Loss):
    def __init__(self, kl_weight=1.0, **kwargs):
        super().__init__(**kwargs)
        self.kl_weight = kl_weight
    
    def call(self, y_true, outputs):
        reconstruction_dist, posterior, prior = outputs
        
        # Reconstruction term
        recon_loss = -tf.reduce_mean(reconstruction_dist.log_prob(y_true))
        
        # KL divergence
        kl_loss = tf.reduce_mean(tfd.kl_divergence(posterior, prior))
        
        return recon_loss + self.kl_weight * kl_loss

### 7.3 **Common Parameter Constraints**

| **Parameter Type** | **Constraint** | **Implementation** | **Use Case** |
|-------------------|----------------|-------------------|--------------|
| **Scale/Std** | Positive | `tf.nn.softplus(raw) + 1e-6` | Normal, Laplace |
| **Probability** | (0, 1) | `tf.nn.sigmoid(raw)` | Bernoulli, Binomial |
| **Concentration** | Positive | `tf.nn.softplus(raw) + 1e-6` | Gamma, Beta, Dirichlet |
| **Rate** | Positive | `tf.exp(raw)` | Exponential, Poisson |
| **Correlation** | (-1, 1) | `tf.nn.tanh(raw)` | Pairwise correlation coefficient |
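
Each row of the table maps an unconstrained real number to the range a parameter requires. A framework-free sketch of the four transforms, using plain `math` in place of the TensorFlow ops (names are illustrative), which prints the range each one reaches over a spread of raw values:

```python
import math

def softplus(raw):
    """Numerically stable log(1 + exp(raw))."""
    return max(raw, 0.0) + math.log1p(math.exp(-abs(raw)))

def sigmoid(raw):
    """Logistic function; branches on the sign so exp() never overflows."""
    if raw >= 0:
        return 1.0 / (1.0 + math.exp(-raw))
    e = math.exp(raw)
    return e / (1.0 + e)

transforms = {
    "scale (softplus + 1e-6)": lambda raw: softplus(raw) + 1e-6,
    "probability (sigmoid)": sigmoid,
    "rate (exp)": math.exp,
    "correlation (tanh)": math.tanh,
}

raw_values = [-10.0, -2.0, 0.0, 2.0, 10.0]
for name, fn in transforms.items():
    outputs = [fn(r) for r in raw_values]
    print(f"{name:26s} min={min(outputs):.3g}  max={max(outputs):.3g}")
```

In floating point, `sigmoid` and `tanh` can saturate to exactly 0, 1, or ±1 for large `|raw|`, and `softplus` can underflow to 0, which is one reason the positive-parameter rows add a small epsilon.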

## 💡 **Final Notes**

- **The Trainable Revolution**: Trainable distributions represent the **convergence of classical statistics and modern deep learning**. They transform static probability models into dynamic, learnable components that can discover patterns in data automatically.

- **Parameter Constraints Are Critical**: Always ensure parameters satisfy their mathematical constraints (e.g., scales > 0, probabilities in [0, 1]). Use appropriate transformations like `softplus`, `sigmoid`, or `TransformedVariable`.

- **Gradient Flow**: Trainable distributions enable **end-to-end differentiable probabilistic models**. The gradient flows through sampling operations (with reparameterization) and log_prob evaluations seamlessly.

- **Loss Function Design**: Negative log-likelihood is your fundamental loss. Add regularization (KL divergence, priors) to prevent overfitting and encourage meaningful parameter values.

- **Numerical Stability**: Always add small epsilon values (`1e-6`) to transformed parameters to avoid numerical issues. Monitor gradients for NaN or explosion.

- **The Bayesian Advantage**: Trainable distributions naturally incorporate uncertainty. Unlike deterministic models that give point estimates, trainable distributions provide **full uncertainty quantification**.

- **Start Simple, Scale Complex**: Begin with single-parameter distributions, then progress to multivariate, then neural integration, then Bayesian networks. Each level builds on the previous.

- **Memory Considerations**: Trainable distributions can be memory-intensive during training due to gradient computation. Use techniques like gradient accumulation for large models.
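
Gradient accumulation, mentioned in the last note, splits a batch into micro-batches, sums their gradients, and applies a single update, so peak memory scales with the micro-batch rather than the full batch. A framework-free sketch for the unit-scale Normal NLL, whose per-example gradient with respect to `loc` is `loc - x` (the data and names are made up for illustration):

```python
def nll_grad_sum(loc, batch):
    """Sum over the batch of d/d(loc) of -log Normal(x | loc, 1), which is loc - x."""
    return sum(loc - x for x in batch)

data = [1.2, 2.5, 0.7, 3.1, 2.0, 1.8, 2.9, 1.4]
loc = 0.0

# Full-batch gradient of the mean NLL.
full_grad = nll_grad_sum(loc, data) / len(data)

# Same gradient, accumulated over micro-batches of 2 examples each.
micro_batch_size = 2
accumulated = 0.0
for start in range(0, len(data), micro_batch_size):
    accumulated += nll_grad_sum(loc, data[start:start + micro_batch_size])
accumulated_grad = accumulated / len(data)

print(full_grad, accumulated_grad)
```

The two gradients agree up to floating-point rounding, so the update is the same while only one micro-batch needs to be held in memory at a time.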

>Trainable distributions are the **gateway to modern probabilistic deep learning**. They enable models that not only make predictions but also quantify their confidence, which is essential for high-stakes applications where uncertainty matters.
