# Shape Semantics in JAX and NumPyro

I asked Claude to help me better understand how JAX and NumPyro handle the problems described in [Reasoning about Shapes and Probability Distributions](https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/), where a tensor's shape by itself is overloaded: it does not say which dimensions index samples, independent distributions, or the components of a single draw.

## Core Concepts

### Three Types of Shapes
1. **Event Shape**
   - Shape of a single draw from the distribution
   - Example: `()` for univariate normal, `(3,)` for 3D multivariate normal

2. **Batch Shape**
   - Shape of independent distributions
   - Determined by broadcasting of distribution parameters
   - Example: `(3,)` for batch of 3 independent normals

3. **Sample Shape**
   - Shape of multiple samples drawn from distribution(s)
   - Specified at sampling time
   - Example: `(1000,)` for 1000 samples

## JAX/NumPyro Implementation

### Distribution Properties
```python
import jax.numpy as jnp
import numpyro.distributions as dist

normal = dist.Normal(loc=0., scale=1.)
normal.event_shape    # ()
normal.batch_shape    # ()

batch_normal = dist.Normal(loc=jnp.array([0., 1., 2.]), scale=1.)
batch_normal.event_shape    # ()
batch_normal.batch_shape    # (3,)
```

### Key Features

1. **Explicit Shape Attributes**
   - Every distribution has `event_shape` and `batch_shape`
   - Shapes are known at creation time
   - Sample shapes specified during sampling

2. **Shape Broadcasting**
   - Follows JAX's broadcasting rules
   - Parameter broadcasting determines batch shapes
   - Automatic vectorization through JAX transforms

3. **Log Probability Handling**
   - Respects batch vs event semantics
   - Batch shapes → independent log probs
   - Event shapes → joint log probs

4. **Shape Transformation**
   - `Independent` distribution wrapper
   - Converts batch dims to event dims
   - Useful for grouping independent components

## Common Patterns

### Basic Distribution Creation
```python
import jax.numpy as jnp
import numpyro.distributions as dist

# Single distribution
normal = dist.Normal(0., 1.)

# Batch of distributions
batch_normal = dist.Normal(jnp.array([0., 1.]), 1.)

# Multivariate distribution
mvn = dist.MultivariateNormal(
    loc=jnp.zeros(3),
    covariance_matrix=jnp.eye(3)
)
```

### Sampling
```python
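import jax
import jax.numpy as jnp
import numpyro.distributions as dist

key = jax.random.PRNGKey(0)
normal = dist.Normal(0., 1.)
batch_normal = dist.Normal(jnp.array([0., 1.]), 1.)

# Single sample
x = normal.sample(key)                     # shape: ()

# Multiple samples
x = normal.sample(key, sample_shape=(5,))  # shape: (5,)

# Batch sampling
x = batch_normal.sample(key)               # shape: (2,)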
```

### Shape Transformation
```python
# Convert batch to event
independent = dist.Independent(batch_normal, 1)
independent.event_shape     # (2,)
independent.batch_shape     # ()
```

## Best Practices

1. **Shape Checking**
   - Verify shapes early in development
   - Use `assert` statements for shape validation
   - Print shapes when debugging

2. **Broadcasting**
   - Be explicit about intended shapes
   - Use `jnp.expand_dims` when needed
   - Remember that broadcasting aligns shapes starting from the rightmost dimension

3. **Performance**
   - Use vectorized operations when possible
   - Leverage JAX's automatic vectorization
   - Consider batch shapes for parallel computation

4. **Common Gotchas**
   - Mixing batch and event dimensions
   - Incorrect broadcasting assumptions
   - Forgetting to specify sample shapes

Code:

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

# NumPyro follows the same shape semantics as TensorFlow Probability
# Event shape: shape of a single draw
# Batch shape: shape of independent distributions
# Sample shape: shape of multiple samples

# Example 1: Single univariate normal
normal = dist.Normal(loc=0., scale=1.)
print("Univariate Normal:")
print(f"event_shape: {normal.event_shape}")  # ()
print(f"batch_shape: {normal.batch_shape}")  # ()

# Example 2: Batch of univariate normals
batch_normal = dist.Normal(
    loc=jnp.array([0., 1., 2.]),  # 3 different means
    scale=jnp.array([1., 0.5, 2.])  # 3 different scales
)
print("\nBatch of Univariate Normals:")
print(f"event_shape: {batch_normal.event_shape}")  # ()
print(f"batch_shape: {batch_normal.batch_shape}")  # (3,)

# Example 3: Single multivariate normal
mvn = dist.MultivariateNormal(
    loc=jnp.zeros(3),
    covariance_matrix=jnp.eye(3)
)
print("\nMultivariate Normal:")
print(f"event_shape: {mvn.event_shape}")  # (3,)
print(f"batch_shape: {mvn.batch_shape}")  # ()

# Example 4: Batch of multivariate normals
batch_mvn = dist.MultivariateNormal(
    loc=jnp.array([[0., 0.], [1., 1.]]),  # 2 different 2D means
    covariance_matrix=jnp.array([jnp.eye(2), jnp.eye(2)])  # 2 covariance matrices (identical here)
)
print("\nBatch of Multivariate Normals:")
print(f"event_shape: {batch_mvn.event_shape}")  # (2,)
print(f"batch_shape: {batch_mvn.batch_shape}")  # (2,)

# Sampling demonstrates how these shapes interact
sample_normal = normal.sample(key=jax.random.PRNGKey(0), sample_shape=(4,))
print("\nShape of 4 samples from univariate normal:")
print(f"sample shape: {sample_normal.shape}")  # (4,)

sample_batch = batch_normal.sample(key=jax.random.PRNGKey(0), sample_shape=(4,))
print("\nShape of 4 samples from batch of normals:")
print(f"sample shape: {sample_batch.shape}")  # (4, 3)

# Log probability handling: one log prob per batch element
x = jnp.array([0., 1., 2.])
log_prob_batch = batch_normal.log_prob(x)
print("\nLog prob shape for batch evaluation:")
print(f"log_prob shape: {log_prob_batch.shape}")  # (3,)

# Independent distribution wrapper for changing event shapes
independent_normal = dist.Independent(batch_normal, reinterpreted_batch_ndims=1)
print("\nIndependent distribution (reinterpreted batch):")
print(f"event_shape: {independent_normal.event_shape}")  # (3,)
print(f"batch_shape: {independent_normal.batch_shape}")  # ()
```

Output:

```
Univariate Normal:
event_shape: ()
batch_shape: ()

Batch of Univariate Normals:
event_shape: ()
batch_shape: (3,)

Multivariate Normal:
event_shape: (3,)
batch_shape: ()

Batch of Multivariate Normals:
event_shape: (2,)
batch_shape: (2,)

Shape of 4 samples from univariate normal:
sample shape: (4,)

Shape of 4 samples from batch of normals:
sample shape: (4, 3)

Log prob shape for batch evaluation:
log_prob shape: (3,)

Independent distribution (reinterpreted batch):
event_shape: (3,)
batch_shape: ()
```


## Changing Distribution Shapes and Effect Handlers

The key concepts to remember are:

* `to_event(n)` converts batch dimensions to event dimensions from right to left
* `expand(shape)` adds batch dimensions through broadcasting
* `numpyro.plate` contexts add batch dimensions for vectorized sampling
* The order of transformations matters
* Always validate shapes explicitly when debugging


# Distribution Shape Transformers and Effect Handlers

## Core Shape Transformers

### 1. `to_event(n)`
Converts the rightmost `n` batch dimensions into event dimensions.

```python
import numpyro.distributions as dist
import jax.numpy as jnp

# Create a batch of 2x3 normal distributions
batch_normal = dist.Normal(
    loc=jnp.zeros((2, 3)),
    scale=jnp.ones((2, 3))
)
print(f"Original - batch: {batch_normal.batch_shape}, event: {batch_normal.event_shape}")
# Original - batch: (2, 3), event: ()

# Convert both batch dimensions to event dimensions
transformed = batch_normal.to_event(2)
print(f"Transformed - batch: {transformed.batch_shape}, event: {transformed.event_shape}")
# Transformed - batch: (), event: (2, 3)
```

### 2. `expand(batch_shape)`
Adds batch dimensions through broadcasting.

```python
normal = dist.Normal(0., 1.)
# Add batch dimensions of size 2 and 3
expanded = normal.expand([2, 3])
print(f"Expanded - batch: {expanded.batch_shape}, event: {expanded.event_shape}")
# Expanded - batch: (2, 3), event: ()
```

### 3. `Independent`
The distribution wrapper behind `to_event()`: `dist.Independent(d, n)` gives the same batch and event shapes as `d.to_event(n)`.

```python
# Create a batch of independent normals
batch_normal = dist.Normal(jnp.zeros(3), 1.)
independent = dist.Independent(batch_normal, 1)
print(f"Independent - batch: {independent.batch_shape}, event: {independent.event_shape}")
# Independent - batch: (), event: (3,)
```

## Common Use Cases

### 1. Creating Multivariate Distributions

```python
# Method 1: Using to_event()
mv_normal_1 = dist.Normal(jnp.zeros(3), 1.).to_event(1)

# Method 2: Using Independent
mv_normal_2 = dist.Independent(dist.Normal(jnp.zeros(3), 1.), 1)

# Both create distributions with:
# batch_shape: ()
# event_shape: (3,)
```

### 2. Hierarchical Models

```python
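import numpyro
import numpyro.distributions as dist

def hierarchical_model():
    # Batch of 10 group-level parameters: the 'groups' plate uses dim -2,
    # so group_means is sampled with shape (10, 1)
    with numpyro.plate('groups', 10, dim=-2):
        group_means = numpyro.sample('group_means', dist.Normal(0., 1.))

        # 5 observations per group (10x5): the 'obs' plate uses dim -1
        with numpyro.plate('obs', 5, dim=-1):
            numpyro.sample('y', dist.Normal(group_means, 1.))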
```

### 3. Matrix-Variate Distributions

```python
# A 2x3 matrix-valued normal with independent entries
# (a special case of a matrix-variate normal)
matrix_normal = dist.Normal(
    loc=jnp.zeros((2, 3)),
    scale=1.
).to_event(2)

# batch_shape: ()
# event_shape: (2, 3)
```

## Effect Handlers

### 1. `plate`
Manages batch dimensions for vectorized sampling.

```python
import numpyro
import numpyro.distributions as dist

def plate_example():
    with numpyro.plate('batch', 10):
        # Samples 10 independent values
        x = numpyro.sample('x', dist.Normal(0., 1.))
        # x.shape is (10,)
    return x
```

### 2. `vmap`
JAX's vectorizing transform. It is not a NumPyro effect handler, but it is often used alongside `plate` for the same purpose.

```python
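import jax
from jax import vmap
import numpyro.distributions as dist

def single_sample(key):
    # Draw one scalar from a standard normal using the given PRNG key
    return dist.Normal(0., 1.).sample(key)

# Vectorize the sampling over a batch of 10 keys (one draw per key)
keys = jax.random.split(jax.random.PRNGKey(0), 10)
samples = vmap(single_sample)(keys)  # shape: (10,)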
```

## Shape Transformation Rules

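1. **Right-to-Left Processing**
   - Shape transformations act on the rightmost dimensions first
   - `to_event(n)` and `Independent(d, n)` reinterpret the rightmost `n` batch dimensions as event dimensions

2. **Broadcasting Rules**
   - Follow standard NumPy broadcasting rules
   - Shapes are aligned from the right; missing leading dimensions are treated as size 1
   - Dimensions of size 1 broadcast to any size

3. **Effect Handler Interactions**
   - `plate` contexts create batch dimensions
   - A `plate` only indexes batch dimensions; dimensions moved into the event shape with `to_event()` are not plate dimensions
   - Without an explicit `dim`, each plate takes the rightmost dimension not already used by an enclosing plate, so the outermost plate occupies dim `-1`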

## Best Practices

1. **Explicit Shape Transformations**
```python
# Prefer explicit transformations
normal = dist.Normal(0., 1.).expand([2, 3]).to_event(2)

# Over implicit reshaping
normal = dist.Normal(jnp.zeros((2, 3)), 1.).to_event(2)
```

2. **Shape Validation**
```python
def validate_shapes(d):
    # `d` is a distribution; naming it `d` avoids shadowing the `dist` module
    assert len(d.batch_shape) == 2, "Expected 2 batch dimensions"
    assert len(d.event_shape) == 1, "Expected 1 event dimension"
```

3. **Document Shape Transformations**
```python
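def model():
    # Shape: batch=(10,), event=()
    base = dist.Normal(0., 1.).expand([10])

    # Transform to: batch=(), event=(10,)
    # to_event() is a distribution method: apply it before sampling, not to the sampled array
    x = numpyro.sample('x', base.to_event(1))  # x.shape: (10,)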
```

## Common Pitfalls

1. **Unintended Broadcasting**
```python
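# This raises an error: shapes (3,) and (2,) cannot broadcast
# dist.Normal(jnp.zeros(3), jnp.ones(2))

# This might not do what you expect: (3,) and (2, 1) silently broadcast to a 2x3 batch
dist.Normal(jnp.zeros(3), jnp.ones((2, 1)))  # batch_shape: (2, 3)

# Be explicit about intended shapes
dist.Normal(jnp.zeros((3, 1)), jnp.ones((1, 2)))  # Clear 3x2 batch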
```

2. **Order of Operations**
```python
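# These are different:
base = dist.Normal(jnp.zeros(3), 1.)  # batch: (3,), event: ()
base.expand([2, 3]).to_event(2)       # batch: (), event: (2, 3)
base.to_event(1).expand([2])          # batch: (2,), event: (3,)
# dist.Normal(0., 1.).to_event(1) raises: a scalar Normal has no batch dim to convert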
```

3. **Missing Shape Transformations**
```python
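# This does not fail, it silently broadcasts:
# the scalar Normal scores each of the 3 values separately
log_prob = dist.Normal(0., 1.).log_prob(jnp.zeros(3))  # shape: (3,)

# To score the 3 values as one joint event, make the shapes explicit:
log_prob = dist.Normal(0., 1.).expand([3]).to_event(1).log_prob(jnp.zeros(3))  # shape: ()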
```

# Concrete Example

What would happen if you drew one sample from a distribution with batch shape (2, 3)?

```python
import jax.numpy as jnp
import numpyro.distributions as dist
import jax.random as random

# Create the batch normal distribution
batch_normal = dist.Normal(
    loc=jnp.zeros((2, 3)),    # shape: (2, 3)
    scale=jnp.ones((2, 3))    # shape: (2, 3)
)

# Distribution properties
print(f"Batch shape: {batch_normal.batch_shape}")    # (2, 3)
print(f"Event shape: {batch_normal.event_shape}")    # ()

# Draw one sample
key = random.PRNGKey(0)
sample = batch_normal.sample(key)
print(f"Sample shape: {sample.shape}")    # (2, 3)
```

Output:

```
Batch shape: (2, 3)
Event shape: ()
Sample shape: (2, 3)
```


When you draw 1 sample, the result will have shape (2, 3). Here's why:

1. The final shape is: `sample_shape + batch_shape + event_shape`
2. In this case:

    * `sample_shape = ()` (drawing 1 sample)
    * `batch_shape = (2, 3)` (from your parameters)
    * `event_shape = ()` (it's a univariate normal)


So: `() + (2, 3) + () = (2, 3)`

If you drew multiple samples, say 4, it would look like this:

```python
samples = batch_normal.sample(key, sample_shape=(4,))
print(f"Multiple samples shape: {samples.shape}")    # (4, 2, 3)
```

Output:

```
Multiple samples shape: (4, 2, 3)
```


Each element of the resulting array is a draw from the corresponding normal distribution in the batch. In this case:

* You have 2 × 3 = 6 independent normal distributions
* Each sample gives you one value from each of these distributions
* The value at `sample[i, j]` is drawn from the normal with `loc[i, j]` and `scale[i, j]`, which here is `Normal(0., 1.)` for every entry