# **Multivariate Distributions**

## 📑 Table of Contents

- **1. [Foundation - What Are Multivariate Distributions?](#1-foundation---what-are-multivariate-distributions)**
    - 1.1 [The Big Picture](#11-the-big-picture)

- **2. [Setup & First Steps](#2-setup--first-steps)**
    - 2.1 [Essential Imports & Setup](#21-essential-imports--setup)  
    - 2.2 [Your First Multivariate Distribution - The Multivariate Normal](#22-your-first-multivariate-distribution---the-multivariate-normal)

- **3. [Multivariate vs Univariate Batching](#3-multivariate-vs-univariate-batching)**
    - 3.1 [The Critical Distinction: Multivariate vs Batched Univariate](#31-the-critical-distinction-multivariate-vs-batched-univariate)

- **4. [Log Probability Computations](#4-log-probability-computations)**
    - 4.1 [Understanding Joint vs Independent Probabilities](#41-understanding-joint-vs-independent-probabilities)

- **5. [Batch Multivariate Distributions](#5-batch-multivariate-distributions)**
    - 5.1 [Multiple Multivariate Distributions](#51-multiple-multivariate-distributions)  
    - 5.2 [Sampling from Batched Multivariate](#52-sampling-from-batched-multivariate)

- **6. [Advanced Shape Analysis](#6-advanced-shape-analysis)**
    - 6.1 [Complex Batch-Event Shape Interactions](#61-complex-batch-event-shape-interactions)  
    - 6.2 [Shape Rules Summary](#62-shape-rules-summary)

- **7. [Types of Multivariate Distributions](#7-types-of-multivariate-distributions)**
    - 7.1 [MultivariateNormal Family](#71-multivariatenormal-family)  
    - 7.2 [Other Multivariate Distributions](#72-other-multivariate-distributions)

- **8. [Professional Patterns & Best Practices](#8-professional-patterns--best-practices)**
    - 8.1 [Efficient Covariance Parameterization](#81-efficient-covariance-parameterization)  
    - 8.2 [Covariance Matrix Validation](#82-covariance-matrix-validation)  
    - 8.3 [Loss Functions for Multivariate Outputs](#83-loss-functions-for-multivariate-outputs)

- **9. [Expert Applications](#9-expert-applications)**
    - 9.1 [Uncertainty Quantification for Multivariate Outputs](#91-uncertainty-quantification-for-multivariate-outputs)  
    - 9.2 [Custom Multivariate Distribution](#92-custom-multivariate-distribution)  
    - 9.3 [Correlation Analysis](#93-correlation-analysis)

- **10. [Complete Reference Guide](#10-complete-reference-guide)**
    - 10.1 [Multivariate Distribution Constructor Cheat Sheet](#101-multivariate-distribution-constructor-cheat-sheet)  
    - 10.2 [Essential Multivariate Operations Reference](#102-essential-multivariate-operations-reference)

- **[Final Notes](#final-notes)**

## 1. **Foundation - What Are Multivariate Distributions?**

### 1.1 **The Big Picture**
Multivariate distributions describe **multi-dimensional random variables**: they model several, possibly correlated, features jointly. Unlike univariate distributions, whose samples are scalars, a multivariate distribution produces **vectors** whose components can depend on one another.[1][2]

**Key Insight**: Multivariate distributions capture **joint probability relationships** between multiple variables, enabling modeling of correlation, covariance, and complex dependencies that univariate distributions cannot represent.

**Critical Difference**: 
- **Univariate**: `event_shape = []` (scalar outcomes)
- **Multivariate**: `event_shape = [k]` (k-dimensional vector outcomes)



## 2. **Setup & First Steps**

### 2.1 **Essential Imports & Setup**

```python
# Import TensorFlow and TensorFlow Probability
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

# Standard alias for distributions
tfd = tfp.distributions

# Set random seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)
```

### 2.2 **Your First Multivariate Distribution - The Multivariate Normal**

```python
# Define our first multivariate distribution object:
# a 2D multivariate normal with diagonal covariance,
# mean [-1.0, 0.5] and standard deviations [1.0, 1.5]
mv_normal = tfd.MultivariateNormalDiag(loc=[-1., 0.5], scale_diag=[1., 1.5])
print(mv_normal)

# Inspect the distribution properties
print("Batch shape:", mv_normal.batch_shape)  # []
print("Event shape:", mv_normal.event_shape)  # [2]
print("Data type:", mv_normal.dtype)          # float32
```

```
tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[], event_shape=[2], dtype=float32)
Batch shape: ()
Event shape: (2,)
Data type: <dtype: 'float32'>
```


**Key Properties Explained:**
- **`batch_shape=[]`**: Single distribution (no batching)
- **`event_shape=[2]`**: 2-dimensional vectors (bivariate)
- **`dtype=float32`**: Numerical precision

## 3. **Multivariate vs Univariate Batching**

### 3.1 **The Critical Distinction: Multivariate vs Batched Univariate**

```python
# MULTIVARIATE DISTRIBUTION - Single 2D distribution
mv_normal = tfd.MultivariateNormalDiag(loc=[-1., 0.5], scale_diag=[1., 1.5])
print("Multivariate Normal:", mv_normal)
# batch_shape=[], event_shape=[2]

# Sample from multivariate (returns 2D vectors)
mv_samples = mv_normal.sample(3)
print("Multivariate samples shape:", mv_samples.shape)  # (3, 2)
print("Multivariate samples:\n", mv_samples)
# Each row is one 2D sample from the joint distribution
```

```
Multivariate Normal: tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[], event_shape=[2], dtype=float32)
Multivariate samples shape: (3, 2)
Multivariate samples:
 tf.Tensor(
[[-0.6725315  -0.76393867]
 [-0.6805663  -1.6113279 ]
 [-3.3880599  -1.0588717 ]], shape=(3, 2), dtype=float32)
```


```python
# BATCHED UNIVARIATE DISTRIBUTION - Two independent 1D distributions
batched_normal = tfd.Normal(loc=[-1., 0.5], scale=[1., 1.5])
print("Batched Normal:", batched_normal)
# batch_shape=[2], event_shape=[]

# Sample from batched univariate (returns independent scalars)
batch_samples = batched_normal.sample(3)
print("Batched samples shape:", batch_samples.shape)  # (3, 2)
print("Batched samples:\n", batch_samples)
# Each column holds independent samples from a separate distribution
```

```
Batched Normal: tfp.distributions.Normal("Normal", batch_shape=[2], event_shape=[], dtype=float32)
Batched samples shape: (3, 2)
Batched samples:
 tf.Tensor(
[[-0.9157754  -0.7913556 ]
 [-0.62187696  0.4922056 ]
 [-1.494532    1.4267287 ]], shape=(3, 2), dtype=float32)
```


**🎯 Key Difference:**
- **Multivariate**: Models **joint relationships** between variables
- **Batched Univariate**: Models **independent** variables processed together

## 4. **Log Probability Computations**

### 4.1 **Understanding Joint vs Independent Probabilities**

```python
# MULTIVARIATE LOG PROBABILITY - Joint probability
mv_normal = tfd.MultivariateNormalDiag(loc=[-1., 0.5], scale_diag=[1., 1.5])
mv_log_prob = mv_normal.log_prob([0.2, -1.8])
print("Multivariate log prob:", mv_log_prob)
print("Shape:", mv_log_prob.shape)  # ()
# Single scalar: joint log probability of the vector [0.2, -1.8]

# BATCHED UNIVARIATE LOG PROBABILITY - Independent probabilities
batched_normal = tfd.Normal(loc=[-1., 0.5], scale=[1., 1.5])
batch_log_prob = batched_normal.log_prob([0.2, -1.8])
print("Batched log prob:", batch_log_prob)
print("Shape:", batch_log_prob.shape)  # (2,)
# [log P(0.2), log P(-1.8)]: separate log probabilities
```

```
Multivariate log prob: tf.Tensor(-4.1388974, shape=(), dtype=float32)
Shape: ()
Batched log prob: tf.Tensor([-1.6389385 -2.499959 ], shape=(2,), dtype=float32)
Shape: (2,)
```


**Mathematical Insight:**
- **Multivariate**: `log P(x₁, x₂)` - joint probability
- **Batched**: `[log P(x₁), log P(x₂)]` - independent probabilities
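
With a diagonal covariance the joint density factorizes, so the joint log probability should equal the sum of the per-component log probabilities. The NumPy-only sketch below re-derives the numbers printed above by hand; `normal_log_pdf` is an illustrative helper written for this check, not a TFP function.

```python
import numpy as np

def normal_log_pdf(x, loc, scale):
    """Elementwise log density of a univariate normal (illustrative helper)."""
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)

x = np.array([0.2, -1.8])
loc = np.array([-1.0, 0.5])
scale = np.array([1.0, 1.5])

# What the batched Normal returns: one value per component, shape (2,)
per_component = normal_log_pdf(x, loc, scale)

# What MultivariateNormalDiag returns: the sum over the event axis, shape ()
joint = per_component.sum()

print(per_component)
print(joint)
```

The two printed results should agree with the TFP outputs above up to float32 rounding. Within TFP, `tfd.Independent(batched_normal, reinterpreted_batch_ndims=1)` performs the same reduction by turning the batch dimension into an event dimension.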

## 5. **Batch Multivariate Distributions**

### 5.1 **Multiple Multivariate Distributions**

```python
# Create batch of 3 different 2D multivariate normal distributions
batched_mv_normal = tfd.MultivariateNormalDiag(
    loc=[[-1., 0.5], [2., 0], [-0.5, 1.5]],        # 3 different mean vectors
    scale_diag=[[1., 1.5], [2., 0.5], [1., 1.]]     # 3 different scale vectors
)

print(batched_mv_normal)

print("Batch shape:", batched_mv_normal.batch_shape)  # [3] - 3 distributions
print("Event shape:", batched_mv_normal.event_shape)  # [2] - 2D vectors each
```

```
tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[3], event_shape=[2], dtype=float32)
Batch shape: (3,)
Event shape: (2,)
```


### 5.2 **Sampling from Batched Multivariate**

```python
# Sample from batch of multivariate distributions
batched_mv_normal = tfd.MultivariateNormalDiag(
    loc=[[-1., 0.5], [2., 0], [-0.5, 1.5]],
    scale_diag=[[1., 1.5], [2., 0.5], [1., 1.]]
)

batch_mv_samples = batched_mv_normal.sample(2)
print("Batched multivariate samples shape:", batch_mv_samples.shape)  # (2, 3, 2)
print("Batched multivariate samples:\n", batch_mv_samples)

# Shape interpretation: (sample_shape, batch_shape, event_shape)
# - 2 samples
# - 3 different distributions
# - 2D vectors each
#
# Output structure:
# [[[sample1_dist1_x1, sample1_dist1_x2],   # Sample 1 from distribution 1
#   [sample1_dist2_x1, sample1_dist2_x2],   # Sample 1 from distribution 2
#   [sample1_dist3_x1, sample1_dist3_x2]],  # Sample 1 from distribution 3
#
#  [[sample2_dist1_x1, sample2_dist1_x2],   # Sample 2 from distribution 1
#   [sample2_dist2_x1, sample2_dist2_x2],   # Sample 2 from distribution 2
#   [sample2_dist3_x1, sample2_dist3_x2]]]  # Sample 2 from distribution 3
```

```
Batched multivariate samples shape: (2, 3, 2)
Batched multivariate samples:
 tf.Tensor(
[[[-1.5590973  -1.8588896 ]
  [ 0.93055725  0.4027528 ]
  [ 1.8730333   0.66612303]]

 [[-0.69388777 -1.8304234 ]
  [ 6.532099    0.18987766]
  [-0.2143586   2.2664626 ]]], shape=(2, 3, 2), dtype=float32)
```
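
The `(sample_shape, batch_shape, event_shape)` layout follows from ordinary broadcasting. The sketch below mimics a batched diagonal normal in plain NumPy by drawing standard normal noise and applying `loc + scale * eps`. It uses NumPy's random generator, so the numbers will not match the TFP draws above; only the shapes are meant to carry over.

```python
import numpy as np

rng = np.random.default_rng(0)

loc = np.array([[-1.0, 0.5], [2.0, 0.0], [-0.5, 1.5]])   # (batch=3, event=2)
scale = np.array([[1.0, 1.5], [2.0, 0.5], [1.0, 1.0]])   # (batch=3, event=2)

n = 2  # sample_shape
eps = rng.standard_normal(size=(n,) + loc.shape)          # (2, 3, 2) standard normal noise
samples = loc + scale * eps                               # broadcasts to (2, 3, 2)

print(samples.shape)
# samples[i, j] is the i-th draw (a 2D vector) from the j-th distribution
print(samples[0, 1])
```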



## 6. **Advanced Shape Analysis**

### 6.1 **Complex Batch-Event Shape Interactions**

```python
# Create complex batched multivariate distribution
batched_mv_normal = tfd.MultivariateNormalDiag(
    loc=[[0.3, 0.8, 1.1], [2.3, -0.3, -1.]],           # 2 distributions
    scale_diag=[[1.5, 1., 0.4], [2.5, 1.5, 0.5]]       # 3D vectors each
)

print("Complex batch-multivariate:", batched_mv_normal)
print("Batch shape:", batched_mv_normal.batch_shape)  # [2]
print("Event shape:", batched_mv_normal.event_shape)  # [3]

# Question: what is the shape of the tensor returned by the following call?
result = batched_mv_normal.log_prob([0., -1., 1.])
print("Result shape:", result.shape)  # (2,)
print("Result values:", result)

# SHAPE ANALYSIS:
# - Input: [0., -1., 1.] has shape (3,) - matches event_shape=[3]
# - batch_shape=[2] means we have 2 different distributions
# - Each distribution computes log_prob for the same input vector
# - Output shape: (2,) - one log probability per distribution
#
# RESULT INTERPRETATION:
# [log_prob_dist1, log_prob_dist2] where:
# - log_prob_dist1: log P([0., -1., 1.]) under distribution 1
# - log_prob_dist2: log P([0., -1., 1.]) under distribution 2
```

```
Complex batch-multivariate: tfp.distributions.MultivariateNormalDiag("MultivariateNormalDiag", batch_shape=[2], event_shape=[3], dtype=float32)
Batch shape: (2,)
Event shape: (3,)
Result shape: (2,)
Result values: tf.Tensor([ -3.9172401 -11.917513 ], shape=(2,), dtype=float32)
```
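
The same broadcast can be traced by hand. In this NumPy sketch the input vector of shape `(3,)` is broadcast against parameters of shape `(2, 3)`, and the event axis is then summed away, leaving one log probability per distribution. It assumes the diagonal-normal density formula and is only a cross-check of the values printed above.

```python
import numpy as np

loc = np.array([[0.3, 0.8, 1.1], [2.3, -0.3, -1.0]])    # (batch=2, event=3)
scale = np.array([[1.5, 1.0, 0.4], [2.5, 1.5, 0.5]])    # (batch=2, event=3)
x = np.array([0.0, -1.0, 1.0])                          # (event=3,), broadcast over the batch

z = (x - loc) / scale                                   # (2, 3)
per_dim = -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)

# Summing over the event axis gives one joint log probability per distribution
log_prob = per_dim.sum(axis=-1)                         # (2,)
print(log_prob)
```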



### 6.2 **Shape Rules Summary**

| **Distribution Type** | **Batch Shape** | **Event Shape** | **Sample Shape** | **Final Output** |
|----------------------|----------------|-----------------|------------------|------------------|
| **Univariate** | `[]` | `[]` | `[n]` | `(n,)` |
| **Batched Univariate** | `[k]` | `[]` | `[n]` | `(n, k)` |
| **Multivariate** | `[]` | `[d]` | `[n]` | `(n, d)` |
| **Batched Multivariate** | `[k]` | `[d]` | `[n]` | `(n, k, d)` |
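
The table reduces to two rules, sketched here as plain Python helpers. The function names are illustrative and not part of TFP; the `log_prob` rule assumes the value passed in has the same shape as a sample.

```python
def sample_output_shape(sample_shape, batch_shape, event_shape):
    """Shape of dist.sample(sample_shape): the three shapes concatenated."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)

def log_prob_output_shape(sample_shape, batch_shape):
    """Shape of dist.log_prob(x) for x shaped like a sample: event axes are reduced away."""
    return tuple(sample_shape) + tuple(batch_shape)

n, k, d = 5, 3, 2
print(sample_output_shape([n], [], []))      # univariate            -> (5,)
print(sample_output_shape([n], [k], []))     # batched univariate    -> (5, 3)
print(sample_output_shape([n], [], [d]))     # multivariate          -> (5, 2)
print(sample_output_shape([n], [k], [d]))    # batched multivariate  -> (5, 3, 2)
print(log_prob_output_shape([n], [k]))       # log_prob of those samples -> (5, 3)
```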

## 7. **Types of Multivariate Distributions**

### 7.1 **MultivariateNormal Family**

```python
# === DIAGONAL COVARIANCE (Independent components) ===
mv_diag = tfd.MultivariateNormalDiag(
    loc=[0., 1.],
    scale_diag=[2., 0.5]  # Per-component standard deviations (variances are 4 and 0.25)
)
print("Diagonal covariance shape:", mv_diag.covariance().shape)
```

```
Diagonal covariance shape: (2, 2)
```


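```python
# === FULL COVARIANCE (Correlated components) ===
# The matrix must be symmetric positive definite:
# here det = 4 * 0.61 - 1.2**2 = 1.0 > 0
# Note: MultivariateNormalFullCovariance is deprecated in recent TFP releases;
# the suggested equivalent is
# tfd.MultivariateNormalTriL(loc, scale_tril=tf.linalg.cholesky(covariance_matrix))
covariance_matrix = [[4., 1.2],    # σ₁² = 4, ρσ₁σ₂ = 1.2
                     [1.2, 0.61]]  # σ₂² = 0.61
mv_full = tfd.MultivariateNormalFullCovariance(
    loc=[0., 1.],
    covariance_matrix=covariance_matrix
)
print("Full covariance:\n", mv_full.covariance())
```

```
Full covariance:
 tf.Tensor(
[[4.   1.2 ]
 [1.2  0.61]], shape=(2, 2), dtype=float32)
```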


```python
# === TRIANGULAR FORM (Efficient parameterization) ===
scale_tril = [[2., 0.],     # Lower triangular Cholesky factor
              [0.6, 0.5]]   # Σ = L L^T
mv_tril = tfd.MultivariateNormalTriL(
    loc=[0., 1.],
    scale_tril=scale_tril
)
print("TriL covariance:\n", mv_tril.covariance())
```

```
TriL covariance:
 tf.Tensor(
[[4.   1.2 ]
 [1.2  0.61]], shape=(2, 2), dtype=float32)
```
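
To see why the triangular form is a safe parameterization, the following NumPy sketch rebuilds the covariance as `L @ L.T`, reads off the implied correlation, and draws correlated samples with `x = loc + L @ eps`. It is an illustration with NumPy's generator, not TFP's sampler, so only the statistics should be comparable.

```python
import numpy as np

L = np.array([[2.0, 0.0],
              [0.6, 0.5]])          # lower-triangular Cholesky factor
loc = np.array([0.0, 1.0])

cov = L @ L.T                       # mathematically [[4.0, 1.2], [1.2, 0.61]]
rho = cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])   # implied correlation, about 0.77

# Reparameterized sampling: x = loc + L @ eps with eps ~ N(0, I)
rng = np.random.default_rng(0)
eps = rng.standard_normal(size=(50_000, 2))
samples = loc + eps @ L.T

print(cov)
print(rho)
print(np.cov(samples, rowvar=False))   # empirical covariance, close to cov
```

Any lower-triangular matrix with a strictly positive diagonal yields a positive definite `L @ L.T`, which is why no separate validity check is needed for this form.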


### 7.2 **Other Multivariate Distributions**

```python
# === DIRICHLET DISTRIBUTION (Probability Simplex) ===
# Well suited to categorical probabilities that sum to 1
dirichlet = tfd.Dirichlet(concentration=[1., 2., 3.])
dirichlet_sample = dirichlet.sample(3)
print("Dirichlet samples (sum to 1):\n", dirichlet_sample)
print("Row sums:", tf.reduce_sum(dirichlet_sample, axis=1))  # All ≈ 1.0
```

```
Dirichlet samples (sum to 1):
 tf.Tensor(
[[0.11105072 0.36103418 0.5279151 ]
 [0.1413823  0.20639223 0.65222555]
 [0.04453953 0.54176676 0.41369376]], shape=(3, 3), dtype=float32)
Row sums: tf.Tensor([1.        1.0000001 1.       ], shape=(3,), dtype=float32)
```
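
One way to build intuition for the simplex constraint is the standard construction of a Dirichlet draw as normalized independent Gamma variables. The NumPy sketch below uses that construction (it is not how TFP is implemented internally, as far as this example assumes) and compares the empirical mean with the analytic mean `concentration / sum(concentration)`.

```python
import numpy as np

rng = np.random.default_rng(0)
concentration = np.array([1.0, 2.0, 3.0])

# Draw independent Gamma(alpha_i, 1) variables and normalize each row
gamma_draws = rng.gamma(shape=concentration, size=(100_000, 3))
samples = gamma_draws / gamma_draws.sum(axis=1, keepdims=True)

print(samples.sum(axis=1)[:3])                # each row sums to 1 (up to rounding)
print(samples.mean(axis=0))                   # empirical mean
print(concentration / concentration.sum())    # analytic mean: [1/6, 1/3, 1/2]
```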


```python
# === MULTIVARIATE STUDENT-T (Heavy-tailed alternative) ===
mv_student_t = tfd.MultivariateStudentTLinearOperator(
    df=3.,  # Degrees of freedom
    loc=[0., 0.],
    scale=tf.linalg.LinearOperatorIdentity(2)
)
print("Student-T samples:\n", mv_student_t.sample(3))
```

```
Student-T samples:
 tf.Tensor(
[[ 0.59606385 -1.8027201 ]
 [-3.6440275  -1.564675  ]
 [-0.29609033 -0.02172232]], shape=(3, 2), dtype=float32)
```


```python
# === VON MISES-FISHER (Directional data on sphere) ===
# For modeling directions, angles, or unit vectors
von_mises_fisher = tfd.VonMisesFisher(
    mean_direction=[1., 0., 0.],  # 3D unit vector
    concentration=2.0
)
vmf_samples = von_mises_fisher.sample(3)
print("Unit vector samples:\n", vmf_samples)
print("Norms (should be 1):", tf.norm(vmf_samples, axis=1))
```

```
Unit vector samples:
 tf.Tensor(
[[-0.08670211 -0.27101254  0.95866317]
 [ 0.77641785  0.31677374  0.5448207 ]
 [ 0.58069456 -0.32908437  0.7446458 ]], shape=(3, 3), dtype=float32)
Norms (should be 1): tf.Tensor([1.         0.99999994 1.        ], shape=(3,), dtype=float32)
```


## 8. **Professional Patterns & Best Practices**

### 8.1 **Efficient Covariance Parameterization**

```python
def create_stable_multivariate_normal(raw_loc, raw_scale_diag, raw_scale_tril=None):
    """
    Create numerically stable multivariate normal from raw parameters
    """
    loc = raw_loc  # Location parameter needs no transformation

    if raw_scale_tril is None:
        # Diagonal covariance case
        scale_diag = tf.nn.softplus(raw_scale_diag) + 1e-6
        return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)
    else:
        # Full covariance case using triangular parameterization
        # Ensure positive diagonal elements
        diag_part = tf.nn.softplus(tf.linalg.diag_part(raw_scale_tril)) + 1e-6
        scale_tril = tf.linalg.set_diag(raw_scale_tril, diag_part)
        return tfd.MultivariateNormalTriL(loc=loc, scale_tril=scale_tril)

# Example usage in neural network
class MultivariateOutput(tf.keras.layers.Layer):
    def __init__(self, event_size, use_full_covariance=False):
        super().__init__()
        self.event_size = event_size
        self.use_full_covariance = use_full_covariance

        # Location parameters
        self.loc_layer = tf.keras.layers.Dense(event_size)

        if use_full_covariance:
            # Full covariance matrix (triangular parameterization)
            tril_size = event_size * (event_size + 1) // 2
            self.scale_layer = tf.keras.layers.Dense(tril_size)
        else:
            # Diagonal covariance
            self.scale_layer = tf.keras.layers.Dense(event_size)

    def call(self, inputs):
        raw_loc = self.loc_layer(inputs)
        raw_scale = self.scale_layer(inputs)

        if self.use_full_covariance:
            # Construct lower triangular matrix
            scale_tril = tfp.bijectors.FillScaleTriL()(raw_scale)
            return tfd.MultivariateNormalTriL(loc=raw_loc, scale_tril=scale_tril)
        else:
            scale_diag = tf.nn.softplus(raw_scale) + 1e-6
            return tfd.MultivariateNormalDiag(loc=raw_loc, scale_diag=scale_diag)
```
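
The layer above needs `event_size * (event_size + 1) // 2` raw outputs for the full-covariance case because that is the number of free entries in a lower-triangular matrix. The NumPy sketch below shows the idea of packing such a flat vector into a valid Cholesky factor. The helper `fill_scale_tril` is hypothetical: it fills row by row for readability, whereas TFP's `FillScaleTriL` uses its own element ordering and its own default diagonal shift.

```python
import numpy as np

def softplus(x):
    """Numerically stable softplus: log(1 + exp(x))."""
    return np.logaddexp(0.0, x)

def fill_scale_tril(raw, event_size, eps=1e-6):
    """Pack a flat vector of d*(d+1)/2 numbers into a valid Cholesky factor.

    NOTE: illustrative row-major fill order, not TFP's FillScaleTriL ordering.
    """
    expected = event_size * (event_size + 1) // 2
    if raw.shape[-1] != expected:
        raise ValueError(f"expected {expected} values, got {raw.shape[-1]}")
    tril = np.zeros((event_size, event_size))
    tril[np.tril_indices(event_size)] = raw       # fill the lower triangle
    diag = np.diag_indices(event_size)
    tril[diag] = softplus(tril[diag]) + eps       # force a strictly positive diagonal
    return tril

raw = np.array([-1.0, 0.3, 0.0, -0.7, 2.0, 0.5])  # 3 * 4 / 2 = 6 unconstrained values
L = fill_scale_tril(raw, event_size=3)
cov = L @ L.T

print(L)
print(np.linalg.eigvalsh(cov).min() > 0)          # positive definite by construction
```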

### 8.2 **Covariance Matrix Validation**

```python
import tensorflow as tf

def validate_covariance_matrix(cov_matrix):
    """
    Check that a matrix is a valid covariance matrix
    (symmetric and positive definite)
    """
    # Convert input to TensorFlow tensor
    cov_matrix = tf.convert_to_tensor(cov_matrix, dtype=tf.float32)

    # Check symmetry
    is_symmetric = tf.reduce_all(
        tf.abs(cov_matrix - tf.transpose(cov_matrix)) < 1e-6
    )

    # Check positive definiteness via eigenvalues.
    # eigvalsh assumes a symmetric matrix and returns real eigenvalues,
    # so its result is only meaningful when is_symmetric is True.
    eigenvals = tf.linalg.eigvalsh(cov_matrix)
    is_positive_definite = tf.reduce_all(eigenvals > 1e-6)

    return is_symmetric, is_positive_definite


# Example validation: this matrix is symmetric but NOT positive definite,
# because det = 4.0 * 0.25 - 1.2**2 = -0.44 < 0
cov_matrix = [[4.0, 1.2], [1.2, 0.25]]
is_sym, is_pd = validate_covariance_matrix(cov_matrix)
print(f"Symmetric: {is_sym.numpy()}, Positive Definite: {is_pd.numpy()}")
```

```
Symmetric: True, Positive Definite: False
```


```python
def validate_covariance_matrix_cholesky(cov_matrix):
    """
    Validate covariance matrix using Cholesky decomposition
    """
    cov_matrix = tf.convert_to_tensor(cov_matrix, dtype=tf.float32)

    # Check symmetry
    is_symmetric = tf.reduce_all(
        tf.abs(cov_matrix - tf.transpose(cov_matrix)) < 1e-6
    )

    # Check positive definiteness using Cholesky decomposition.
    # Depending on the TensorFlow version and device, a failed factorization
    # either raises InvalidArgumentError or silently returns NaNs,
    # so both cases are handled.
    try:
        chol = tf.linalg.cholesky(cov_matrix)
        is_positive_definite = tf.reduce_all(tf.math.is_finite(chol))
    except tf.errors.InvalidArgumentError:
        is_positive_definite = tf.constant(False)

    return is_symmetric, is_positive_definite


# Example: same non-positive-definite matrix as above (det = -0.44)
cov_matrix = [[4.0, 1.2], [1.2, 0.25]]
is_sym, is_pd = validate_covariance_matrix_cholesky(cov_matrix)
print(f"Symmetric: {is_sym.numpy()}, Positive Definite: {is_pd.numpy()}")
```

```
Symmetric: True, Positive Definite: False
```
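
For a quick check outside TensorFlow, the same test can be sketched in NumPy, where `np.linalg.cholesky` raises `LinAlgError` on a matrix that is not positive definite. The helper name `is_valid_covariance` is illustrative. The example matrix `[[4.0, 1.2], [1.2, 0.25]]` has determinant `4 * 0.25 - 1.2**2 = -0.44`, so it cannot be a covariance matrix; raising the second variance to `0.61` gives determinant `1.0` and a valid one.

```python
import numpy as np

def is_valid_covariance(matrix, tol=1e-8):
    """Return (is_symmetric, is_positive_definite) for a square matrix."""
    m = np.asarray(matrix, dtype=np.float64)
    is_symmetric = bool(np.allclose(m, m.T, atol=tol))
    try:
        np.linalg.cholesky(m)     # raises LinAlgError if m is not positive definite
        is_pd = is_symmetric      # Cholesky reads one triangle only, so require symmetry too
    except np.linalg.LinAlgError:
        is_pd = False
    return is_symmetric, is_pd

print(is_valid_covariance([[4.0, 1.2], [1.2, 0.25]]))   # (True, False): det = -0.44 < 0
print(is_valid_covariance([[4.0, 1.2], [1.2, 0.61]]))   # (True, True):  det = 1.0 > 0
```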


### 8.3 **Loss Functions for Multivariate Outputs**

```python
def multivariate_regression_loss(y_true, distribution):
    """
    Negative log-likelihood loss for multivariate regression
    """
    return -tf.reduce_mean(distribution.log_prob(y_true))

def multivariate_classification_loss(y_true_one_hot, dirichlet_dist, smoothing=1e-3):
    """
    Loss for categorical outcomes using Dirichlet distribution.

    Exact one-hot vectors lie on the boundary of the simplex, where the
    Dirichlet log density is infinite or undefined, so the targets are
    smoothed slightly into the interior first. The smoothing value is an
    assumption and should be tuned for the task.
    """
    y = tf.convert_to_tensor(y_true_one_hot, dtype=dirichlet_dist.dtype)
    num_classes = tf.cast(tf.shape(y)[-1], y.dtype)
    y_smooth = y * (1.0 - smoothing) + smoothing / num_classes  # still sums to 1
    return -tf.reduce_mean(dirichlet_dist.log_prob(y_smooth))

# Example training step for multivariate regression
@tf.function
def multivariate_train_step(features, targets, model, optimizer):
    with tf.GradientTape() as tape:
        # Model outputs multivariate distribution
        pred_distribution = model(features)

        # Compute negative log-likelihood
        nll_loss = multivariate_regression_loss(targets, pred_distribution)

        # Add regularization
        reg_loss = sum(model.losses)
        total_loss = nll_loss + 0.01 * reg_loss

    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return total_loss, nll_loss
```

## 9. **Expert Applications**

### 9.1 **Uncertainty Quantification for Multivariate Outputs**

```python
def multivariate_uncertainty_prediction(model, x_test, n_samples=100):
    """
    Predict multivariate outputs with uncertainty quantification
    """
    predictions = []

    for _ in range(n_samples):
        # training=True keeps stochastic layers such as dropout active, so each
        # forward pass differs; this assumes the model contains such layers
        dist = model(x_test, training=True)
        pred_sample = dist.sample()
        predictions.append(pred_sample)

    predictions = tf.stack(predictions)  # (n_samples, batch_size, event_size)

    # Compute statistics across samples
    mean_pred = tf.reduce_mean(predictions, axis=0)
    std_pred = tf.math.reduce_std(predictions, axis=0)

    # Compute covariance for each test point:
    # average the outer products of the centered samples over the sample axis
    centered_preds = predictions - tf.expand_dims(mean_pred, 0)
    covariance_pred = tf.reduce_mean(
        tf.matmul(centered_preds[..., :, None], centered_preds[..., None, :]),
        axis=0
    )  # (batch_size, event_size, event_size)

    return {
        'mean': mean_pred,
        'std': std_pred,
        'covariance': covariance_pred,
        'samples': predictions
    }
```
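
The per-test-point covariance is the least obvious step in the function above. This NumPy sketch reproduces it with `einsum` on random stand-in data (no model is involved; the array merely has the same `(n_samples, batch_size, event_size)` layout) and cross-checks one test point against `np.cov` with the same `1/N` normalization.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, batch_size, event_size = 200, 4, 3

# Stand-in for the stacked predictive draws
predictions = rng.standard_normal(size=(n_samples, batch_size, event_size))

mean_pred = predictions.mean(axis=0)            # (batch, event)
centered = predictions - mean_pred              # (n_samples, batch, event)

# Average of outer products over the sample axis:
# one (event, event) covariance matrix per test point
cov_pred = np.einsum("nbi,nbj->bij", centered, centered) / n_samples

print(cov_pred.shape)
# Cross-check the first test point against NumPy's estimator (bias=True means 1/N)
print(np.allclose(cov_pred[0], np.cov(predictions[:, 0, :], rowvar=False, bias=True)))
```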

### 9.2 **Custom Multivariate Distribution**

```python
class TruncatedMultivariateNormal:
    """
    Simplified box-truncated multivariate normal (approximate, see method notes)
    """
    def __init__(self, loc, covariance_matrix, low, high):
        self.base_dist = tfd.MultivariateNormalFullCovariance(
            loc=loc, covariance_matrix=covariance_matrix
        )
        self.low = tf.convert_to_tensor(low)
        self.high = tf.convert_to_tensor(high)

    def sample(self, sample_shape=()):
        """Approximate sampling by clipping draws to [low, high].

        This is NOT rejection sampling: clipped draws pile up on the
        boundary, so the result is not a true truncated normal.
        """
        samples = self.base_dist.sample(sample_shape)
        return tf.clip_by_value(samples, self.low, self.high)

    def log_prob(self, value):
        """Unnormalized log density: base log_prob inside the box, -inf outside.

        The truncation normalizer (the log of the base distribution's
        probability mass inside the box) is omitted.
        """
        base_log_prob = self.base_dist.log_prob(value)

        # Check if all components are within bounds
        in_bounds = tf.logical_and(
            tf.reduce_all(value >= self.low, axis=-1),
            tf.reduce_all(value <= self.high, axis=-1)
        )

        return tf.where(in_bounds, base_log_prob, -np.inf)
```
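
Clipping moves out-of-range draws onto the boundary, which is different from truncation. A minimal sketch of actual rejection sampling, written in NumPy with a hypothetical helper `sample_truncated_mvn`, keeps only the proposals that fall inside the box and repeats until enough have been accepted. It assumes the box holds a reasonable share of the probability mass; otherwise the acceptance rate becomes impractically low.

```python
import numpy as np

def sample_truncated_mvn(loc, cov, low, high, n, rng, max_rounds=1000):
    """Draw n samples from N(loc, cov) restricted to the box [low, high]."""
    loc, low, high = (np.asarray(a, dtype=float) for a in (loc, low, high))
    accepted = []
    count = 0
    for _ in range(max_rounds):
        proposal = rng.multivariate_normal(loc, cov, size=n)
        # Accept a proposal only if every component lies inside the box
        keep = np.all((proposal >= low) & (proposal <= high), axis=-1)
        accepted.append(proposal[keep])
        count += int(keep.sum())
        if count >= n:
            return np.concatenate(accepted)[:n]
    raise RuntimeError("acceptance rate too low; the box holds very little probability mass")

rng = np.random.default_rng(0)
samples = sample_truncated_mvn(
    loc=[0.0, 0.0], cov=[[1.0, 0.8], [0.8, 1.0]],
    low=[-1.0, -1.0], high=[1.0, 1.0], n=500, rng=rng,
)
print(samples.shape)
print(samples.min(axis=0), samples.max(axis=0))   # all values stay inside [-1, 1]
```

A matching exact `log_prob` would subtract the log of the base distribution's probability mass inside the box, which generally has to be estimated numerically.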

### 9.3 **Correlation Analysis**

```python
def analyze_correlation_structure(samples):
    """
    Analyze correlation structure of multivariate samples
    """
    # Compute sample covariance matrix
    centered_samples = samples - tf.reduce_mean(samples, axis=0)
    # Cast the divisor to float32 to match the tensor dtype
    divisor = tf.cast(tf.shape(samples)[0] - 1, tf.float32)
    covariance = tf.matmul(centered_samples, centered_samples, transpose_a=True) / divisor

    # Convert to correlation matrix
    std_devs = tf.sqrt(tf.linalg.diag_part(covariance))
    correlation = covariance / tf.matmul(std_devs[:, None], std_devs[None, :])

    return {
        'covariance': covariance,
        'correlation': correlation,
        'std_devs': std_devs
    }


# Example usage
mv_dist = tfd.MultivariateNormalFullCovariance(
    loc=[0., 0.],
    covariance_matrix=[[1., 0.8], [0.8, 1.]]
)
samples = mv_dist.sample(1000)
analysis = analyze_correlation_structure(samples)
print("Sample correlation matrix:\n", analysis['correlation'])
```

```
Sample correlation matrix:
 tf.Tensor(
[[1.         0.79317343]
 [0.79317343 1.        ]], shape=(2, 2), dtype=float32)
```


**A More Defensive Variant with Input Validation**

```python
def analyze_correlation_structure_robust(samples):
    """
    Robust correlation analysis with comprehensive error handling
    """
    try:
        samples = tf.convert_to_tensor(samples, dtype=tf.float32)

        # Validate input shape
        if len(samples.shape) != 2:
            raise ValueError("Samples must be a 2D tensor (n_samples, n_features)")

        n_samples, n_features = tf.shape(samples)[0], tf.shape(samples)[1]

        if n_samples < 2:
            raise ValueError("Need at least 2 samples for correlation analysis")

        # Compute statistics
        n_samples_float = tf.cast(n_samples, tf.float32)
        mean_vals = tf.reduce_mean(samples, axis=0)
        centered_samples = samples - mean_vals

        # Covariance matrix
        covariance = tf.matmul(centered_samples, centered_samples, transpose_a=True) / (n_samples_float - 1)

        # Standard deviations
        variances = tf.linalg.diag_part(covariance)
        std_devs = tf.sqrt(tf.maximum(variances, 1e-8))  # Avoid sqrt of negative numbers

        # Correlation matrix
        outer_std = tf.matmul(std_devs[:, None], std_devs[None, :])
        correlation = covariance / tf.maximum(outer_std, 1e-8)  # Avoid division by zero

        return {
            'covariance': covariance,
            'correlation': correlation,
            'std_devs': std_devs,
            'mean': mean_vals,
            'n_samples': n_samples,
            'n_features': n_features
        }

    except Exception as e:
        print(f"Error in correlation analysis: {e}")
        return None


# Example usage
mv_dist = tfd.MultivariateNormalFullCovariance(
    loc=[0., 0.],
    covariance_matrix=[[1., 0.8], [0.8, 1.]]
)
samples = mv_dist.sample(1000)
analysis = analyze_correlation_structure_robust(samples)

if analysis:
    print("Sample correlation matrix:")
    print(analysis['correlation'].numpy())
    print("\nCovariance matrix:")
    print(analysis['covariance'].numpy())
    print(f"\nNumber of samples: {analysis['n_samples'].numpy()}")
```

```
Sample correlation matrix:
[[1.        0.8004023]
 [0.8004023 1.0000001]]

Covariance matrix:
[[1.0191175 0.828036 ]
 [0.828036  1.0501652]]

Number of samples: 1000
```
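
The manual covariance and correlation estimates can be sanity-checked against library routines. The sketch below does so in NumPy, where `np.cov` defaults to the same `N - 1` divisor used above. The data come from NumPy's generator, so the numbers differ from the TFP run.

```python
import numpy as np

rng = np.random.default_rng(0)
true_cov = np.array([[1.0, 0.8], [0.8, 1.0]])
samples = rng.multivariate_normal(mean=[0.0, 0.0], cov=true_cov, size=1000)

# Manual estimates, mirroring the functions above
centered = samples - samples.mean(axis=0)
cov_manual = centered.T @ centered / (samples.shape[0] - 1)
std = np.sqrt(np.diag(cov_manual))
corr_manual = cov_manual / np.outer(std, std)

# Library estimates (rowvar=False: rows are observations, columns are features)
print(np.allclose(cov_manual, np.cov(samples, rowvar=False)))
print(np.allclose(corr_manual, np.corrcoef(samples, rowvar=False)))
print(corr_manual)
```

TFP also ships `tfp.stats.covariance` and `tfp.stats.correlation`. Per its documentation, `tfp.stats.covariance` normalizes by `N` rather than `N - 1`, so small differences from the manual estimate are expected for small samples.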


## 10. **Complete Reference Guide**

### 10.1 **Multivariate Distribution Constructor Cheat Sheet**

```python
# === MULTIVARIATE NORMAL FAMILY ===
# Diagonal covariance (independent components)
tfd.MultivariateNormalDiag(loc=[0., 1.], scale_diag=[1., 2.])

# Full covariance (correlated components) 
tfd.MultivariateNormalFullCovariance(loc=[0., 1.], covariance_matrix=[[1., 0.5], [0.5, 2.]])

# Triangular parameterization (efficient)
tfd.MultivariateNormalTriL(loc=[0., 1.], scale_tril=[[1., 0.], [0.5, 1.4]])

# Linear operator (advanced)
tfd.MultivariateNormalLinearOperator(loc=[0., 1.], scale=tf.linalg.LinearOperatorLowerTriangular(...))  # scale must be a tf.linalg.LinearOperator

# === OTHER MULTIVARIATE DISTRIBUTIONS ===
# Probability simplex
tfd.Dirichlet(concentration=[1., 2., 3.])

# Heavy-tailed multivariate
tfd.MultivariateStudentTLinearOperator(df=3., loc=[0., 0.], scale=...)

# Directional data (unit sphere)
tfd.VonMisesFisher(mean_direction=[1., 0., 0.], concentration=2.)

# Matrix-valued distributions  
tfd.MatrixNormalLinearOperator(loc=matrix_loc, scale_row=..., scale_column=...)

# Correlation matrices
tfd.LKJ(dimension=3, concentration=2.)
```

### 10.2 **Essential Multivariate Operations Reference**

```python
# Creating multivariate distribution
mv_dist = tfd.MultivariateNormalDiag(loc=[0., 1.], scale_diag=[1., 2.])

# === SAMPLING ===
single_vector = mv_dist.sample()                    # Shape: [2]
batch_vectors = mv_dist.sample(100)                 # Shape: [100, 2]
seeded_sample = mv_dist.sample(100, seed=42)        # Reproducible

# === PROBABILITY EVALUATION ===
joint_prob = mv_dist.prob([0.5, 1.5])             # Joint PDF value
joint_log_prob = mv_dist.log_prob([0.5, 1.5])     # Joint log PDF (preferred)
batch_log_prob = mv_dist.log_prob([[0., 1.], [1., 0.]])  # Batch evaluation

# === STATISTICS ===
mean_vector = mv_dist.mean()                        # E[X] - vector
cov_matrix = mv_dist.covariance()                  # Cov(X) - matrix
var_vector = mv_dist.variance()                    # Var(X) - diagonal of covariance
std_vector = mv_dist.stddev()                      # σ(X) - element-wise std dev

# === SPECIAL OPERATIONS ===
marginal_dist = tfd.Normal(loc=mv_dist.mean()[0], scale=mv_dist.stddev()[0])  # Marginal of X₁, built by hand (Gaussian marginals keep the matching mean and stddev)
entropy_val = mv_dist.entropy()                    # Differential entropy
```

## **Final Notes**

- **The Multivariate Mindset**: While univariate distributions model individual uncertainties, multivariate distributions model **joint uncertainties and relationships**, capturing how variables covary and influence each other.

- **Start with Diagonal**: Begin with `MultivariateNormalDiag`. It's simpler and often sufficient. Move to full covariance only when you need to model correlations.

- **Shape is Everything**: Master the `(sample_shape, batch_shape, event_shape)` hierarchy. This understanding transfers to all advanced TFP concepts.

- **Correlation vs Independence**: The key power of multivariate distributions is modeling dependence. Use them when variables are naturally related.

- **Numerical Stability**: Always use triangular parameterization (`MultivariateNormalTriL`) for learnable full covariance matrices.