# PyPPL
A Python-embedded Probabilistic Programming Language

In [1]:
import pyppl

> Joshua Duquette, Christopher Liu, Rishi Upadhyay

# Choose your inference style

## 1. Exact Inference (e.g., Dice)
- Statically analyze the execution graph and compute the distribution.
- Fast inference for discrete programs.
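To make "compute the distribution" concrete, here is a deliberately naive sketch of exact inference. It enumerates every joint outcome of the coin flips and sums the probability mass landing on each return value. This is brute-force enumeration, not the BDD-based knowledge compilation Dice actually uses, and its cost grows exponentially with the number of flips. The model mirrors `first_head_or_both_tails`, defined later in this notebook. The helper names are illustrative only.

```python
from fractions import Fraction
from itertools import product


def first_head_or_both_tails(first: bool, second: bool) -> bool:
    # Deterministic body of the program once both coin outcomes are fixed.
    if first:
        return True
    return not second


def exact_distribution(prob_first: Fraction, prob_second: Fraction) -> dict:
    # Enumerate every joint outcome of the two flips and weight it by its probability.
    dist = {False: Fraction(0), True: Fraction(0)}
    for first, second in product([False, True], repeat=2):
        weight = (prob_first if first else 1 - prob_first) * (
            prob_second if second else 1 - prob_second
        )
        dist[first_head_or_both_tails(first, second)] += weight
    return dist


print(exact_distribution(Fraction(1, 2), Fraction(1, 2)))
# {False: Fraction(1, 4), True: Fraction(3, 4)}
```

The exact answer (True with probability 3/4) is what the rejection-sampling estimate further down should approach.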

In [2]:
# Dice (Julia) N-Queens example from GitHub
N = 5
md = @dice begin
    row_to_col = [
        uniform(DistInt32, 1, 1 + N)
        for _ in 1:N
    ]

    for (i1, col1) in enumerate(row_to_col)
        for (i2, col2) in enumerate(row_to_col)
            i1 == i2 && continue
            # Convert rows to Dist
            row1, row2 = DistInt32(i1), DistInt32(i2)
            # Check column conflicts
            observe(!prob_equals(col1, col2))
            # Check conflicts along one diagonal
            observe(!prob_equals(row1 + col1, row2 + col2))
            # Check conflicts along other diagonal
            observe(!prob_equals(row1 - col1, row2 - col2))
        end
    end    
    row_to_col
end


## 2. Sampling Inference (e.g., WebPPL)
- Infer distribution via sampling techniques (e.g., rejection sampling, MCMC).
- More expressive $\rightarrow$ more program coverage.
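Rejection sampling, the simplest of these techniques, can be sketched in a few lines. The model is run repeatedly, runs that violate an observation are thrown away, and the accepted outputs are tallied into an empirical distribution. In this sketch the convention that a model returns `(value, accepted)` is an assumption made for illustration, and the example model (three fair flips conditioned on at least one head) is a conditioned variant of the WebPPL binomial below.

```python
import random


def rejection_sample(model, num_samples, seed=0):
    # model(rng) returns (value, accepted); rejected runs are discarded.
    # Note: this loops forever if the observation can never be satisfied.
    rng = random.Random(seed)
    counts = {}
    accepted = 0
    while accepted < num_samples:
        value, ok = model(rng)
        if not ok:
            continue
        counts[value] = counts.get(value, 0) + 1
        accepted += 1
    return {value: count / num_samples for value, count in sorted(counts.items())}


def binomial_given_at_least_one(rng):
    # Three fair coin flips, conditioned on at least one head.
    a, b, c = (rng.random() < 0.5 for _ in range(3))
    total = a + b + c
    return total, total >= 1


print(rejection_sample(binomial_given_at_least_one, 20000))
```

The exact conditional distribution is 3/7, 3/7, 1/7 for totals 1, 2, 3, so the printed frequencies should land close to those values.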

In [3]:
# WebPPL (JavaScript) example
var binomial = function() {
  var a = sample(Bernoulli({ p: 0.5 }))
  var b = sample(Bernoulli({ p: 0.5 }))
  var c = sample(Bernoulli({ p: 0.5 }))
  return a + b + c
}

var binomialDist = Infer({ model: binomial })


# Why not both?
Give the user the option to choose between exact inference and sampling inference.

## Design Goals
1. **Flexibility/Coverage:** the user can choose which inference techniques to use.
2. **Readability:** retain as much of Python's existing syntax as possible.

### Example: Keeping Python's Syntax

In [4]:
# Note: return_types is temporary until we implement a type analysis pass.
@pyppl.compile(return_types=pyppl.Flip)
def first_head_or_both_tails(
    prob_first: float,
    prob_second: float
):
    first = pyppl.Flip(prob_first)
    if first:
        return True
    else:
        second = pyppl.Flip(prob_second)
        return not second

In [5]:
with pyppl.RejectionSampling():
    success = first_head_or_both_tails(0.5, 0.5)

print(success)

Flip(DiscreteDistribution({
  False: 0.241
  True : 0.759
}))


In [6]:
# Note: Not implemented yet.
with pyppl.ExactInference():
    success = first_head_or_both_tails(0.5, 0.5)

print(success)

NotImplementedError: ExactInference is unsupported.

### Example: Partial Loop Support

In [7]:
@pyppl.compile(return_types=pyppl.Flip)
def up_to_n_heads_in_a_row(n):
    # Random number of flips from [1, n]
    random_n = pyppl.Integer(pyppl.UniformDistribution(1, n))
    
    n_heads_in_a_row = True
    for _ in range(int(random_n)):
        n_heads_in_a_row &= pyppl.Flip()
    return n_heads_in_a_row

In [8]:
with pyppl.RejectionSampling():
    success = up_to_n_heads_in_a_row(5)
print(success)

Flip(DiscreteDistribution({
  False: 0.810
  True : 0.190
}))
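As a sanity check on the sampled 0.190, the exact probability has a closed form. Given `random_n = k`, all `k` fair flips must be heads, with probability (1/2)^k. Averaging over k assumes `pyppl.UniformDistribution(1, n)` is uniform over the integers 1 through n inclusive, as the `[1, n]` comment in the cell suggests.

```python
from fractions import Fraction


def prob_all_heads(n: int) -> Fraction:
    # random_n is uniform on {1, ..., n}; given random_n = k, all k fair flips
    # must come up heads, which happens with probability (1/2)**k.
    return sum(Fraction(1, 2) ** k for k in range(1, n + 1)) / n


print(prob_all_heads(5), float(prob_all_heads(5)))  # 31/160 0.19375
```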


## Hybrid Inference
Resolve each inference-context frame independently, then aggregate the resulting distributions into the function's return value.
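One way the aggregation step could work for `n_plus_one_heads_in_a_row` below is a minimal sketch that assumes the two frames share no random variables. Under that assumption, P(first_heads and n_heads_in_a_row) is the product of the two frame probabilities. Frame 1 is resolved exactly. Frame 2 contains no `observe`, so rejection sampling accepts every run and reduces to plain Monte Carlo. All function names here are hypothetical, and frames that share variables would need joint reasoning rather than a simple product.

```python
import random
from fractions import Fraction


def exact_first_heads() -> Fraction:
    # Frame 1: a single fair flip, resolved exactly.
    return Fraction(1, 2)


def sampled_n_heads(n: int, num_samples: int, seed: int = 0) -> float:
    # Frame 2: estimate P(n heads in a row) by Monte Carlo sampling.
    rng = random.Random(seed)
    hits = sum(all(rng.random() < 0.5 for _ in range(n)) for _ in range(num_samples))
    return hits / num_samples


def hybrid_and(n: int, num_samples: int = 100_000) -> float:
    # The frames share no random variables, so P(A and B) = P(A) * P(B).
    return float(exact_first_heads()) * sampled_n_heads(n, num_samples)


print(hybrid_and(5))  # close to the exact value 1/64 = 0.015625
```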

In [9]:
# Ideal goal
@pyppl.compile(return_types=pyppl.Flip)
def n_plus_one_heads_in_a_row(n):
    # Unbound inference context.
    # User can choose what inference technique to use here.
    first_heads = pyppl.Flip()
    # Unbound inference context end.
    
    # Always evaluate this code with rejection sampling.
    with pyppl.RejectionSampling():
        n_heads_in_a_row = True
        for _ in range(n):
            n_heads_in_a_row &= pyppl.Flip()
    # Rejection sampling context end.
        
    return first_heads and n_heads_in_a_row

In [10]:
# Ideal goal
with pyppl.ExactInference():
    success = n_plus_one_heads_in_a_row(5)
    
with pyppl.RejectionSampling():
    success = n_plus_one_heads_in_a_row(5)

NotImplementedError: ExactInference is unsupported.

# pyppl.compile
A JIT decorator that will
1. **For sampling inference:** perform the appropriate code transformations for `pyppl.observe`.
2. **For exact inference:** compute a binary decision diagram (BDD) from a control-flow analysis (using `pycfg`).
3. **Return type analysis:** determine the return type of the function (e.g., `pyppl.Flip`, `Tuple[...]`).

Workflow:
1. Retrieves the function source code with the built-in `inspect` library.
2. Parses the source with the built-in `ast` library.
3. Performs the appropriate analysis and transformations.
4. Creates a wrapper that returns a result based on the active inference context.
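Steps 2 and 3 of the workflow above can be illustrated with the standard `ast` module. This sketch parses the source of `always_heads` (from the next section) and reports the line of every `pyppl.observe(...)` call, the kind of analysis a transformation pass would start from. In a real decorator, step 1 would obtain the string with `inspect.getsource(fn)`, which only works when the function's source is retrievable (for example, when it is defined in a file). Here the source is a literal string so the sketch runs anywhere. `find_observe_calls` is a hypothetical helper, and parsing does not require `pyppl` to be importable.

```python
import ast
import textwrap

SOURCE = textwrap.dedent('''
    @pyppl.compile(return_types=pyppl.Flip)
    def always_heads():
        f = pyppl.Flip()
        pyppl.observe(f)
        if f:
            return True
        else:
            return False
''')


def find_observe_calls(source: str) -> list:
    # Step 2: parse the source text into an abstract syntax tree.
    tree = ast.parse(source)
    lines = []
    # Step 3 (analysis): walk every node and record calls of the form pyppl.observe(...).
    for node in ast.walk(tree):
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == "observe"
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == "pyppl"
        ):
            lines.append(node.lineno)
    return lines


print(find_observe_calls(SOURCE))  # [5]
```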

## Supporting `pyppl.observe`
- We wanted `pyppl.observe` to resemble other probabilistic programming DSLs.

In [11]:
@pyppl.compile(return_types=pyppl.Flip)
def always_heads():
    f = pyppl.Flip()
    pyppl.observe(f)
    if f:
        return True
    else:
        return False

In [None]:
with pyppl.RejectionSampling():
    coin_flip = always_heads()
print(coin_flip)

`pyppl.compile` transforms the `pyppl.observe` statement in `always_heads` into:

In [None]:
def always_heads_transformed():
    f = pyppl.Flip()
    
    if not pyppl.observe(f):
        return pyppl.NotObservable
    
    ...

The sampler can then detect when a run returns `pyppl.NotObservable` and discard that sample.
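The rewrite shown above can be sketched with `ast.NodeTransformer`. This is an illustrative reimplementation, not PyPPL's code. It only handles `pyppl.observe(...)` used as a standalone statement inside a function body, and replaces it with an `if not pyppl.observe(...): return pyppl.NotObservable` guard. `ast.unparse` requires Python 3.9 or later.

```python
import ast
import textwrap


class ObserveRewriter(ast.NodeTransformer):
    """Rewrite `pyppl.observe(x)` statements into an early NotObservable return."""

    def visit_Expr(self, node: ast.Expr):
        call = node.value
        is_observe = (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Attribute)
            and call.func.attr == "observe"
            and isinstance(call.func.value, ast.Name)
            and call.func.value.id == "pyppl"
        )
        if not is_observe:
            return node
        # Build: if not pyppl.observe(x): return pyppl.NotObservable
        guard = ast.If(
            test=ast.UnaryOp(op=ast.Not(), operand=call),
            body=[
                ast.Return(
                    value=ast.Attribute(
                        value=ast.Name(id="pyppl", ctx=ast.Load()),
                        attr="NotObservable",
                        ctx=ast.Load(),
                    )
                )
            ],
            orelse=[],
        )
        return ast.copy_location(guard, node)


source = textwrap.dedent('''
    def always_heads():
        f = pyppl.Flip()
        pyppl.observe(f)
        if f:
            return True
        else:
            return False
''')

tree = ObserveRewriter().visit(ast.parse(source))
ast.fix_missing_locations(tree)  # give the new nodes line/column info
print(ast.unparse(tree))
```

Because the guard is an ordinary `return`, the transformed function stays valid Python and can be recompiled with `compile(tree, "<pyppl>", "exec")`.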

## Supporting Inference Contexts
`pyppl.compile` replaces the decorated function with a wrapper that checks the current inference context at runtime.

In [None]:
# Pseudo-code
def always_heads_compiled():
    if cur_context is SamplingInference:
        return sample(always_heads_transformed)
    elif cur_context is ExactInference:
        return exact_inference(always_heads_transformed)
    ...

# Probabilistic Primitives
We plan to support
1. `bool` for both sampling and exact inference.
2. `int`, `float` for sampling inference (possibly exact inference for select cases).

## Booleans
- `pyppl.ProbBool` (alias `pyppl.Flip`).

In [None]:
f = pyppl.Flip(prob_true=0.7)
print(f)

## Numerics
The number types we have are
1. `pyppl.Integer`
2. `pyppl.Real`

We have the following distributions
1. `pyppl.UniformDistribution`
2. `pyppl.GaussianDistribution`
3. `pyppl.DiscreteDistribution`

Each number type is initialized with a distribution.

In [None]:
# range [0, 5]
rand_dist = pyppl.UniformDistribution(0, 5)
rand_int = pyppl.Integer(rand_dist)
print(rand_int)

In [None]:
rand_dist = pyppl.GaussianDistribution(mu=0, sigma=5)
rand_real = pyppl.Real(rand_dist)
print(rand_real)

For sampling inference, we also support common arithmetic and comparison operations on these types (e.g., `+`, `>=`).
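One way such operations could be supported is operator overloading that builds a small expression graph, which is then sampled once per run. This is a hypothetical design, not necessarily how `pyppl.Integer` works. `RandomVar` and `uniform_int` are invented for illustration. Each leaf is drawn at most once per run, so `roll + roll` is always even rather than the sum of two independent rolls.

```python
import operator
import random


class RandomVar:
    """Hypothetical sample-based number: a node in a small expression graph."""

    def __init__(self, sampler=None, op=None, args=()):
        self.sampler = sampler  # leaf: function rng -> value
        self.op = op            # interior node: operator applied to args
        self.args = args

    def sample(self, rng, cache=None):
        # One joint run: the cache ensures each node is evaluated at most once.
        cache = {} if cache is None else cache
        if id(self) not in cache:
            if self.op is None:
                cache[id(self)] = self.sampler(rng)
            else:
                values = [
                    a.sample(rng, cache) if isinstance(a, RandomVar) else a
                    for a in self.args
                ]
                cache[id(self)] = self.op(*values)
        return cache[id(self)]

    def _lift(self, op, other, reverse=False):
        args = (other, self) if reverse else (self, other)
        return RandomVar(op=op, args=args)

    def __add__(self, other):
        return self._lift(operator.add, other)

    def __radd__(self, other):
        return self._lift(operator.add, other, reverse=True)

    def __ge__(self, other):
        return self._lift(operator.ge, other)


def uniform_int(low, high):
    # Inclusive on both ends, matching the "range [0, 5]" comment below.
    return RandomVar(sampler=lambda rng: rng.randint(low, high))


rng = random.Random(0)
roll = uniform_int(1, 6)
doubled = roll + roll
print(all(doubled.sample(rng) % 2 == 0 for _ in range(100)))  # True
```

Python cannot overload `and`, `or`, or `not`, so a graph-based design like this one would need the AST pass to rewrite expressions such as `roll1 >= 3 and roll2 >= 3`.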

### Example: Rolling two dice

In [None]:
@pyppl.compile(return_types=pyppl.Integer)
def two_dice_rolls():
    six_sided_dice = pyppl.UniformDistribution(1, 6)
    roll1 = pyppl.Integer(six_sided_dice)
    roll2 = pyppl.Integer(six_sided_dice)
    return roll1 + roll2

In [None]:
with pyppl.RejectionSampling():
    my_roll = two_dice_rolls()
print(my_roll)

We can combine everything to express more complicated scenarios.

### Example: Rolling two dice BUT I'M LUCKY :)

In [None]:
@pyppl.compile(return_types=pyppl.Integer)
def two_dice_rolls_lucky():
    six_sided_dice = pyppl.UniformDistribution(1, 6)
    roll1 = pyppl.Integer(six_sided_dice)
    roll2 = pyppl.Integer(six_sided_dice)
    pyppl.observe(roll1 >= 3 and roll2 >= 3)
    return roll1 + roll2

In [None]:
with pyppl.RejectionSampling():
    lucky_roll = two_dice_rolls_lucky()
print(lucky_roll)
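For reference, the exact distribution that the sampler above should approximate can be computed by enumerating all 36 pairs of rolls and keeping the 16 that satisfy the observation. This assumes `pyppl.UniformDistribution(1, 6)` is uniform over the integers 1 through 6 inclusive.

```python
from collections import Counter
from fractions import Fraction
from itertools import product


def lucky_roll_distribution():
    # Keep only the (roll1, roll2) pairs that satisfy the observation;
    # each surviving pair is equally likely after conditioning.
    kept = [r1 + r2 for r1, r2 in product(range(1, 7), repeat=2) if r1 >= 3 and r2 >= 3]
    counts = Counter(kept)
    return {total: Fraction(c, len(kept)) for total, c in sorted(counts.items())}


dist = lucky_roll_distribution()
print(dist[9])  # 1/4, the most likely total
```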

### Example: Rolling a `num_sides`-sided die `num_rolls` times

In [None]:
@pyppl.compile(return_types=pyppl.Integer)
def n_rolls_of_a_n_sided_dice(*, num_rolls, num_sides):
    dice = pyppl.UniformDistribution(1, num_sides)
    
    roll_sum = 0
    for _ in range(num_rolls):
        roll_sum += pyppl.Integer(dice)
        
    return roll_sum

In [None]:
with pyppl.RejectionSampling():
    my_roll = n_rolls_of_a_n_sided_dice(num_rolls=3, num_sides=5)
print(my_roll)
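The exact distribution of `roll_sum` can be built by convolving in one uniform die at a time, which gives a reference for the sampled result. This assumes `pyppl.UniformDistribution(1, num_sides)` is uniform over the integers 1 through `num_sides` inclusive. `sum_of_rolls_distribution` is a hypothetical helper.

```python
from fractions import Fraction


def sum_of_rolls_distribution(num_rolls: int, num_sides: int) -> dict:
    # Start with the distribution of an empty sum (always 0), then convolve
    # in one uniform die roll at a time.
    dist = {0: Fraction(1)}
    face_prob = Fraction(1, num_sides)
    for _ in range(num_rolls):
        nxt = {}
        for total, p in dist.items():
            for face in range(1, num_sides + 1):
                nxt[total + face] = nxt.get(total + face, 0) + p * face_prob
        dist = nxt
    return dict(sorted(dist.items()))


print(sum_of_rolls_distribution(3, 5)[9])  # 19/125
```

Convolution costs on the order of `num_rolls * num_sides * (number of reachable totals)` operations, whereas enumerating every roll sequence would cost `num_sides ** num_rolls`.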