# Introduction to `Gen.jl`

By: John Muchovej

<small>Though this notebook is heavily inspired by the "Introduction to Modeling in Gen" from Gen's tutorial series</small>

In [None]:
# Today's imports
using Gen
using Plots
plotly()

push!(LOAD_PATH, pwd())
# `Intro` is the `Intro.jl` file colocated with these notebooks. We abstract rendering
#   code away into this file to avoid getting sidetracked with lots of rendering code.
import Intro
;

`Gen` programs are a mix of probabilistic models built using the `Gen` modeling language
along with inference programs written in regular Julia code.

## Probabilistic models as generative functions

`Gen` represents probabilistic models as "generative functions", which we write
using Gen's [modeling DSL][modeling]. **This boils down to writing typical Julia
code, then annotating particular sections with Gen's macros.**

We denote generative functions by prefixing the `@gen` macro to a regular
function definition. This function represents the data-generation process we're
modeling. Random choices can be thought of as random variables within a model.

The function below is a probabilistic model of a linear relationship in the 2D
coordinate space $(x, y)$. Given $x$ coordinates, this model randomly chooses a
line in the plane and generates corresponding $y$ values so that each $(x, y)$
is near the line.

Real-world analogies would be modeling the volume of a gas from its observed
temperature or modeling housing prices as a function of square footage.

[modeling]: https://www.gen.dev/dev/ref/modeling/

In [None]:
@gen function linear_model(xs::Vector{Float64})
    n = length(xs)
    
    # We begin by sampling a slope and an intercept for the line. Before we have seen
    #   the data, we don't know the values of these parameters, so we treat them as
    #   random choices. The distributions they are drawn from represent our prior
    #   beliefs about the parameters: in this case, that neither the slope nor the
    #   intercept will be more than a couple of points away from 0.
    slope = @trace(normal(0, 1), :slope)
    intercept = @trace(normal(0, 2), :intercept)
    
    # Given the slope and intercept, we can sample $y$ coordinates for each of the $x$
    #   coordinates in our input vector.
    for (idx, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, 0.1), (:y, idx))
    end
    
    # When using Gen, the return value of a model is often unimportant.
    # Here, though, we simply return `n` – the number of points.
    return n
end
;

Our function `linear_model` consumes a `Vector` of $x$-coordinates, so let's
create one below.

In [None]:
xs = collect(-5:1:5) .|> Float64
;

Given `xs`, `linear_model` samples a `:slope` from a
$\mathcal{N}(\mu=0, \sigma=1)$ (a Normal distribution with mean, $\mu=0$, and
standard deviation, $\sigma=1$) and an `:intercept` from a
$\mathcal{N}(\mu=0, \sigma=2)$.

These distributions from which we draw `:slope` and `:intercept` are our
_prior distributions_ over the `:slope` and `:intercept`, respectively.

We then sample a $y$-coordinate for each $x$-coordinate from a Normal distribution
centered on the line at that $x$, with $\sigma=0.1$.

Finally, we return the `length(xs)` (or, the number of data points).

In [None]:
n = linear_model(xs)
display(n)

However, `n` isn't all that interesting; the values of the random choices within
`linear_model` are. You'll notice that each `@trace` call takes a second
argument: `:slope`, `:intercept`, or `(:y, idx)`. These are **addresses**, and
**`Gen` requires** that they be **unique** within a single execution.

Addresses can be any valid Julia value, though today we will build them only
from `Symbol`s. In `linear_model`, each address is either a `Symbol` or a
`(Symbol, Integer)` tuple. Though we `@trace` in a `for` loop, each iteration
has a different `idx`, so the addresses stay unique.
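As a small illustration of addressing, the sketch below writes the same model twice: once with `@trace`, and once with Gen's `~` shorthand, where `mu ~ dist` uses `:mu` as the address and `{addr} ~ dist` takes an explicit address. The names `points_trace` and `points_tilde` are hypothetical and exist only for this example.

```julia
using Gen

# Explicit `@trace` form: every random choice names its own address.
@gen function points_trace(n::Int)
    mu = @trace(normal(0, 1), :mu)
    for idx in 1:n
        @trace(normal(mu, 0.1), (:y, idx))
    end
    return n
end

# Shorthand form: `mu ~ ...` uses the address `:mu`, while `{addr} ~ ...`
# takes an explicit address, which keeps each loop iteration unique.
@gen function points_tilde(n::Int)
    mu ~ normal(0, 1)
    for idx in 1:n
        {(:y, idx)} ~ normal(mu, 0.1)
    end
    return n
end

trace_a = Gen.simulate(points_trace, (3,))
trace_b = Gen.simulate(points_tilde, (3,))

# Both traces expose the same set of addresses.
for t in (trace_a, trace_b)
    println(t[:mu], " ", [t[(:y, idx)] for idx in 1:3])
end
```

The `~` form appears later in this notebook, so it helps to read it as a more compact spelling of `@trace`.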

`@gen` functions record every `@trace`d choice in an _execution trace_, so we
may access these values even though we don't explicitly return them.

We can run `linear_model` with [`Gen.simulate`][simulate] to obtain its trace.

[simulate]: https://www.gen.dev/dev/ref/gfi/#Gen.simulate

In [None]:
trace = Gen.simulate(linear_model, (xs, ))
;

`Gen.simulate` takes a function to be executed and a tuple of positional
arguments to the function. Then it returns a `trace`, which is a rather complex
data structure (see below).

In [None]:
display(trace)

An execution trace tracks many components of a particular run of our function.
For example, we can retrieve the arguments passed to the function using
[`Gen.get_args`][get_args].

[get_args]: https://www.gen.dev/dev/ref/gfi/#Gen.get_args

In [None]:
Gen.get_args(trace)

We can also retrieve the random choices made. These values are stored in a
`Gen.choicemap` (more on this later). You can use [`Gen.get_choices`][get_choices]
to inspect the random choices made.

[get_choices]: https://www.gen.dev/dev/ref/gfi/#Gen.get_choices

In [None]:
Gen.get_choices(trace)

We can also access individual choices by address, much like indexing a `Dict`.

In [None]:
choices = Gen.get_choices(trace)
choices[:slope]

However, we can also index the `trace` directly to retrieve a choice, without
going through `Gen.get_choices`.

In [None]:
trace[:slope]

Additionally, we can inspect the return value of a trace using
[`Gen.get_retval`][get_retval].

[get_retval]: https://www.gen.dev/dev/ref/gfi/#Gen.get_retval

In [None]:
Gen.get_retval(trace)
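To tie these accessors together, here is a minimal sketch that pulls the arguments, the return value, and every sampled $y$ value out of one trace. `tiny_line` is a hypothetical, compact stand-in for `linear_model`, repeated so the snippet runs on its own; `Gen.get_score` is not used elsewhere in this notebook and is shown only as one more accessor.

```julia
using Gen

# Compact stand-in for `linear_model` (hypothetical name).
@gen function tiny_line(xs::Vector{Float64})
    slope = @trace(normal(0, 1), :slope)
    intercept = @trace(normal(0, 2), :intercept)
    for (idx, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, 0.1), (:y, idx))
    end
    return length(xs)
end

demo_xs = [-1.0, 0.0, 1.0]
demo_trace = Gen.simulate(tiny_line, (demo_xs,))

(trace_xs,) = Gen.get_args(demo_trace)    # arguments of this run
n_points = Gen.get_retval(demo_trace)     # return value of this run

# Tuple addresses let us collect all of the sampled y values.
sampled_ys = [demo_trace[(:y, idx)] for idx in 1:n_points]

# Log probability density of the choices recorded in this trace.
log_score = Gen.get_score(demo_trace)

println("xs = ", trace_xs)
println("ys = ", sampled_ys)
println("log score = ", log_score)
```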

### Aside: Visualizing outputs

It's often useful to visualize traces to better understand the behavior of our
generative functions. `render_line` (below) uses `Plots` to render a trace of
our `linear_model` with options to show $(x, y)$ data points alongside the line.

In [None]:
function render_line(trace, title; show_points=true, overlay=false)
    # Get `xs` from the trace
    (trace_xs, ) = Gen.get_args(trace)
    
    slope = trace[:slope]
    intercept = trace[:intercept]
    endx = [minimum(trace_xs), maximum(trace_xs)]
    endy = slope * endx .+ intercept
    
    info = Intro.PlotMetadata(
        title, trace, endx, endy;
        show_points=show_points, overlay=overlay
    )
    
    return Intro.plot_data(info)
end
;

In [None]:
render_line(trace, "linear")

However, we should also have a way of inspecting many runs to better understand
a given function's behavior. We'll use `render_grid` to visualize many runs.

In [None]:
traces = [Gen.simulate(linear_model, (xs, )) for _=1:12]
Intro.render_grid(render_line, traces; title="line")

### Exercise: Using the same `@trace` address twice

In [None]:
@gen function doublescale(xs::Vector)
    scale = @trace(normal(0, 1), :scale)
    xs = xs .* scale
    scale = @trace(normal(0, 1), :scale)
    return xs .* scale
end;

In [None]:
Gen.simulate(doublescale, (xs, ))

### Exercise: "Probabilistic" Sine Wave

Write a model to generate a sine wave, but with random phase, period, and
amplitude. Then, generate $y$-coordinates from a given vector of $x$-coordinates
by adding noise to the value of the wave at each $x$-coordinate.

Use the following priors (cf. [`Gen.gamma`][gamma] and [`Gen.uniform`][unif]):
- `gamma(5, 1)` for the `period`
- `gamma(1, 1)` for the `amplitude`
- `uniform(0, 2π)` for the `phase`

Also, write a function that renders the `trace` by showing the data points and
the computed sine wave. Then, visualize a grid of traces and discuss the
distribution. _Try tweaking the parameters of each of our priors to see how the
behavior changes._

[gamma]: https://www.gen.dev/dev/ref/distributions/#Gen.gamma
[unif]: https://www.gen.dev/dev/ref/distributions/#Gen.uniform

In [None]:
@gen function sine_model(xs::Vector{Float64})::Number
    n = length(xs)
    
    phase ~ uniform(0, 2π)
    period ~ gamma(5, 1)
    amplitude ~ gamma(1, 1)
    
    for (idx, x) in enumerate(xs)
        μ = amplitude * sin(2π * x / period + phase)
        @trace(normal(μ, 0.1), (:y, idx))
    end
    
    return n
end
;

In [None]:
function render_sine(trace, title; show_points=true, overlay=true)
    xs = Gen.get_args(trace)[1]
    xlims = [minimum(xs), maximum(xs)]
    
    phase = trace[:phase]
    period = trace[:period]
    amplitude = trace[:amplitude]
    
    plot_xs = collect(range(xlims...; length=1_000))
    plot_ys = amplitude * sin.(2π * plot_xs / period .+ phase)
    
    info = Intro.PlotMetadata(
        title, trace, plot_xs, plot_ys;
        show_points=show_points, overlay=overlay
    )
    return Intro.plot_data(info)
end
;

In [None]:
traces = [Gen.simulate(sine_model, (xs, )) for _=1:12]
Intro.render_grid(render_sine, traces; title="sin")

## Doing posterior inference

We now will provide a dataset of $y$-coordinates and try to draw inferences about the process that generated the data. We begin with the following dataset:

In [None]:
ys = [6.75003, 6.1568, 4.26414, 1.84894, 3.09686, 1.94026, 1.36411, -0.83959, -0.976, -1.93363, -2.91303]
scatter(xs, ys, color=:green, size=(600, 600))

Now, let's write an _inference program_ (`infer`) that takes a model we assume
generated our data, the dataset, and the number of steps we would like to
perform. `infer` will then return a trace of the function that is approximately
sampled from the _posterior distribution_ on traces of the function, given the
observed data.

That is, `infer` will try to find a trace that explains our created dataset
well. We can inspect the trace to find estimates of the `:slope` and
`:intercept` of a line that fits the data.

Functions like `Gen.importance_resampling` expect a `model` and a `choicemap`
representing our dataset and relating it to our `model`.
A `choicemap` will map the random choice addresses in `linear_model` to values
from our dataset `ys`. For example, the 4th value in our dataset (`ys[4]`) will
be tied to `(:y, 4)`.
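As a quick sketch of that mapping, the snippet below builds a choice map two ways and then looks values up by address. The values in `demo_ys` are made up for this example, and `Gen.has_value` is used only to show that unconstrained addresses are simply absent.

```julia
using Gen

demo_ys = [6.75, 6.16, 4.26, 1.85]

# Build the map one address at a time...
observations = Gen.choicemap()
for (idx, y) in enumerate(demo_ys)
    observations[(:y, idx)] = y
end

# ...or pass `(address, value)` tuples to the constructor.
same_observations = Gen.choicemap(
    (((:y, idx), y) for (idx, y) in enumerate(demo_ys))...
)

# The 4th observation is tied to the address `(:y, 4)`.
println(observations[(:y, 4)])

# `:slope` was never constrained, so it has no value in the map.
println(Gen.has_value(observations, :slope))
```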

In [None]:
function infer(model, xs::Vector, ys::Vector, steps::Int64, args...)
    # Create a `choicemap` that maps addresses `(:y, i)` to observed values `ys[i]`.
    #   We leave `:slope` and `:intercept` unconstrained because we want them to be
    #   inferred.
    observations = Gen.choicemap()
    for (idx, y) in enumerate(ys)
        observations[(:y, idx)] = y
    end
    
    # Call `importance_resampling` to obtain a likely trace consistent with our
    #   observations.
    (trace, _) = Gen.importance_resampling(model, (xs, args...,), observations, steps)
    return trace
end;
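To build intuition for what the `steps` argument buys us, here is a simplified sketch of the idea behind importance resampling: generate several candidate traces that agree with the observations, weight each one, and pick one in proportion to its weight. This is an illustration under stated assumptions, not Gen's actual implementation; `mean_model` and `resample_sketch` are hypothetical names.

```julia
using Gen

# Hypothetical one-parameter model used only for this sketch.
@gen function mean_model(n::Int)
    mu = @trace(normal(0, 5), :mu)
    for idx in 1:n
        @trace(normal(mu, 1), (:y, idx))
    end
    return n
end

function resample_sketch(model, args::Tuple, observations, steps::Int)
    traces = Vector{Any}(undef, steps)
    log_weights = Vector{Float64}(undef, steps)
    for i in 1:steps
        # `generate` fixes the observed choices, samples the remaining ones,
        # and returns a log importance weight for the resulting trace.
        (traces[i], log_weights[i]) = Gen.generate(model, args, observations)
    end
    # Normalize the weights in log space, then pick one trace in
    # proportion to its weight.
    probs = exp.(log_weights .- Gen.logsumexp(log_weights))
    return traces[categorical(probs)]
end

obs = Gen.choicemap(((:y, 1), 2.9), ((:y, 2), 3.1), ((:y, 3), 3.0))
picked = resample_sketch(mean_model, (3,), obs, 200)
println("inferred mu: ", picked[:mu])
```

With more candidates, it becomes more likely that at least one of them explains the data well, which is why larger `steps` values tend to give better traces at a higher cost.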

In [None]:
trace = infer(linear_model, xs, ys, 100)
render_line(trace, "importance-resampling")

We can see above that `importance_resampling` found a reasonable slope and
intercept to explain our data. Like we did in previous code blocks, we can
also generate many traces and visualize them in a grid.

In [None]:
traces = [infer(linear_model, xs, ys, 100) for _=1:12]
Intro.render_grid(render_line, traces, title="IR")

Inspecting the grid above, the inferred lines vary from run to run, so we cannot
be certain where the line truly is. We can get a better picture of the
variability in the posterior by visualizing all the traces atop one another,
rather than individually.

Since each trace has the same observed data, we'll only plot it once, based on
the first trace.

In [None]:
# n_steps = 100

### Exercise: Run inference on `linear_model` with `1`, `10`, and `1000` steps

We achieved the results above with `steps = 100`. Try running our inference with
`steps` set to `1`, `10`, and `1_000`. **Which values seem like a good tradeoff
between accuracy and run time?**

In [None]:
# n_steps = 1

In [None]:
# n_steps = 10

In [None]:
# n_steps = 1_000

### Example: Consider the following dataset

Let's write an inference program to generate traces of `sine` that explain this
dataset. Visualize the distribution of traces.

Let's also try changing the prior distribution of the period to `gamma(1, 1)`.

- Can you explain the difference in inference when using `period` priors of `gamma(1, 1)` vs `gamma(5, 1)`?
- How many `steps` were needed to achieve good results?

In [None]:
sine_ys = [2.89, 2.22, -0.612, -0.522, -2.65, -0.133, 2.70, 2.77, 0.425, -2.11, -2.76]
scatter(xs, sine_ys, color=:red, size=(600, 600))

In [None]:
@gen function sine(xs::Vector{Float64}, shape::Int=5)::Int
    # CODE HERE
end
;

In [None]:
sin1_traces = [infer(sine, xs, sine_ys, 500, 1) for _=1:12]
Intro.render_grid(render_sine, sin1_traces; title="γ1-sine")

In [None]:
sin5_traces = [infer(sine, xs, sine_ys, 500, 5) for _=1:12]
Intro.render_grid(render_sine, sin5_traces; title="γ5-sine")

## Predicting new data

We can use `Gen.generate` to generate a trace where certain random choices are
constrained to given values. This is done by way of a `Gen.choicemap`.

In [None]:
# constraints

# render line

Note that the points above are still generated randomly. _Run the cell a few
times to verify this._
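A minimal sketch of such a constrained run is shown below, assuming we pin the slope and intercept to `0.0`. `line_demo` is a hypothetical, compact stand-in for `linear_model`, repeated so the snippet runs on its own.

```julia
using Gen

# Compact stand-in for `linear_model` (hypothetical name).
@gen function line_demo(xs::Vector{Float64})
    slope = @trace(normal(0, 1), :slope)
    intercept = @trace(normal(0, 2), :intercept)
    for (idx, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, 0.1), (:y, idx))
    end
    return length(xs)
end

demo_xs = collect(-2.0:1.0:2.0)

# Pin the line's parameters; the `(:y, idx)` choices stay unconstrained.
constraints = Gen.choicemap((:slope, 0.0), (:intercept, 0.0))

# `generate` returns the trace together with a log weight.
(flat_trace, log_weight) = Gen.generate(line_demo, (demo_xs,), constraints)

println(flat_trace[:slope], " ", flat_trace[:intercept])
println([flat_trace[(:y, idx)] for idx in eachindex(demo_xs)])
```

The constrained addresses always come back with the given values, while the $y$ values change from run to run.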

We can use this ability to constrain executions to predict the values of the
$y$-coordinates at new $x$-coordinates by running new executions of the `model`.
Let's inspect the `predict` function below. It takes a `trace` and a `Vector` of
new $x$-coordinates, then returns a `Vector` of predicted $y$-coordinates
corresponding to the `new_xs`.

The `addresses` parameter allows us to use `predict` with an arbitrary model.

In [None]:
function predict(
        model::Gen.DynamicDSLFunction, trace::Gen.DynamicDSLTrace,
        new_xs::Vector{Float64}, addresses::Vector{Symbol}
    )
    # Copy parameter values from the inferred `trace` into a new `choicemap`.
    
    # Run the model with new x coordinates, and parameters fixed to inferred values
    
    # Extract y values and return them
end
;
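One possible way to fill in the three steps above is sketched here. It assumes the model records its outputs at addresses of the form `(:y, idx)`, as the models in this notebook do. `predict_sketch` and `line_for_predict` are hypothetical names, and the type annotations are loosened so the snippet runs on its own.

```julia
using Gen

# Compact stand-in for `linear_model` (hypothetical name).
@gen function line_for_predict(xs::Vector{Float64})
    slope = @trace(normal(0, 1), :slope)
    intercept = @trace(normal(0, 2), :intercept)
    for (idx, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, 0.1), (:y, idx))
    end
    return length(xs)
end

function predict_sketch(model, trace, new_xs::Vector{Float64}, addresses::Vector{Symbol})
    # Copy parameter values from the inferred trace into a new choice map.
    constraints = Gen.choicemap()
    for addr in addresses
        constraints[addr] = trace[addr]
    end

    # Run the model on the new x coordinates with those parameters fixed.
    (new_trace, _) = Gen.generate(model, (new_xs,), constraints)

    # Extract the y values (assumes addresses of the form `(:y, idx)`).
    return [new_trace[(:y, idx)] for idx in eachindex(new_xs)]
end

old_trace = Gen.simulate(line_for_predict, ([0.0, 1.0],))
new_ys = predict_sketch(line_for_predict, old_trace, [2.0, 3.0, 4.0], [:slope, :intercept])
println(new_ys)
```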

The cell below defines a composite function that performs inference on an
observed dataset `(xs, ys)`, then runs `predict` to generate predicted
$y$-coordinates. This process is repeated `n_traces` times, returning a `Vector`
that holds one `Vector` of predicted $y$-coordinates per trace.

In [None]:
function infer_and_predict(
        model::Gen.DynamicDSLFunction,
        xs::Vector{Float64},
        ys::Vector{Float64},
        new_xs::Vector{Float64},
        n_traces::Int,
        n_steps::Int,
        addresses::Vector{Symbol},
    )
    
    # run inference and prediction, for `n_traces`
    pred_ys = []
    
    for _=1:n_traces
        trace = infer(model, xs, ys, n_steps)
        push!(pred_ys, predict(model, trace, new_xs, addresses))
    
    end
    return pred_ys
end
;

In [None]:
function plot_predictions(xs::Vector, ys::Vector, new_xs::Vector, pred_ys::Vector)
    p = plot(size=(600, 600))
    
    for pred_y in pred_ys
        p = scatter!(new_xs, pred_y, color=:black, alpha=0.3, legend=false)
    end
    
    p = scatter!(xs, ys, color=:red)
    return p
end
;

In [None]:
scatter(xs, ys, color=:red, size=(600, 600))

We can use the inferred values of the parameters to predict $y$-coordinates for
$x$-coordinates in $[5, 10]$ (where no data was observed). We can also predict
new data within $[-5, 5]$ to compare against our originally observed data.

Predicting new data from inferred parameters, then comparing this new data to
the observed data is the core idea behind a _posterior predictive check_. We
won't provide a rigorous overview of techniques for checking the quality of a
model, but intend to provide a high-level intuition.

In [None]:
new_xs = collect(range(-5, 10; length=100))
pred_ys = infer_and_predict(linear_model, xs, ys, new_xs, 20, 1_000, [:slope, :intercept])
plot_predictions(xs, ys, new_xs, pred_ys)

The results we see above look quite reasonable – both within the range of the
observed data and in the extrapolated predictions on the right.

Let's run the same experiment, but with more noisy data.

In [None]:
noisy_ys = [5.092, 4.781, 2.46815, 1.23047, 0.903318, 1.11819, 2.10808, 1.09198, 0.0203789, -2.05068, 2.66031]

pred_ys = infer_and_predict(linear_model, xs, noisy_ys, new_xs, 20, 1_000, [:slope, :intercept])
plot_predictions(xs, noisy_ys, new_xs, pred_ys)

It seems like the data `linear_model` generates is less noisy than our data
actually is. The model is overconfident, which is a sign that it is
mis-specified.

In our case, this is because we've assumed that the noise has a standard
deviation of $0.1$. However, it seems like the actual noise in our data is much
larger. We can correct this, though, by making the noise a random choice as well
and inferring its value along with the other parameters.

This new version will sample the noise from a `gamma(1, 1)` prior.

In [None]:
@gen function noisy_linear_model(xs::Vector{Float64})::Nothing    
    # reproduce linear_model, but with a `noise` sampled from `gamma(1, 1)`
end
;
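One way this model might look is sketched below, assuming the noise lives at the address `:noise` (the same address the noisy sine exercise later in this notebook passes to `infer_and_predict`). The name `noisy_line_sketch` is hypothetical.

```julia
using Gen

@gen function noisy_line_sketch(xs::Vector{Float64})
    slope = @trace(normal(0, 1), :slope)
    intercept = @trace(normal(0, 2), :intercept)

    # The noise level is now a random choice with a `gamma(1, 1)` prior,
    # so inference can adapt it to the data.
    noise = @trace(gamma(1, 1), :noise)

    for (idx, x) in enumerate(xs)
        @trace(normal(slope * x + intercept, noise), (:y, idx))
    end
    return nothing
end

noisy_trace = Gen.simulate(noisy_line_sketch, ([-1.0, 0.0, 1.0],))
println("sampled noise: ", noisy_trace[:noise])
```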

Now let's compare the predictions using `infer_and_predict` on the original
`linear_model` and the `noisy_linear_model` on our `ys` data.

In [None]:
inf_args = [xs, ys, new_xs, 20, 1_000]
plt_args = inf_args[1:3]
pred_ys = infer_and_predict(linear_model, inf_args..., [:slope, :intercept])
linear_model_plot = plot_predictions(plt_args..., pred_ys)
title!("Fixed noise level")

pred_ys = infer_and_predict(noisy_linear_model, inf_args..., [:slope, :intercept, :noise])
noisy_linear_model_plot = plot_predictions(plt_args..., pred_ys)
title!("Inferred noise level")

plot(
    linear_model_plot, noisy_linear_model_plot,
    layout=(1, 2), size=(800, 400),
)

Notice that there's more uncertainty in the predictions made using our
`noisy_linear_model`.

Let's also compare these predictions against our `noisy_ys` dataset.

In [None]:
inf_args = [xs, noisy_ys, new_xs, 20, 1_000]
plt_args = inf_args[1:3]
addresses = [:slope, :intercept]

pred_ys = infer_and_predict(linear_model, inf_args..., addresses)
linear_model_plot = plot_predictions(plt_args..., pred_ys)
title!("Fixed noise level")

pred_ys = infer_and_predict(noisy_linear_model, inf_args..., [addresses..., :noise])
noisy_linear_model_plot = plot_predictions(plt_args..., pred_ys)
title!("Inferred noise level")

plot(
    linear_model_plot, noisy_linear_model_plot,
    layout=(1, 2), size=(800, 400),
)

Notice that our `linear_model` is overconfident (as demonstrated earlier), but
our `noisy_linear_model` contains much more uncertainty (as it should) while
still capturing the negative trend of our data.

### Exercise: Modify `sine_model` to make noisy predictions

In [None]:
@gen function noisy_sine_model(xs::Vector{Float64})
    # rewrite `sine_model`, but with `noise` sampled from `gamma(1, 1)`
end;

In [None]:
inf_args = [xs, sine_ys, new_xs, 20, 100]
plt_args = inf_args[1:3]
addresses = [:phase, :period, :amplitude]

pred_ys = infer_and_predict(sine_model, inf_args..., addresses)
sine_model_plot = plot_predictions(plt_args..., pred_ys)
title!("Fixed noise level")

pred_ys = infer_and_predict(noisy_sine_model, inf_args..., [addresses..., :noise])
noisy_sine_model_plot = plot_predictions(plt_args..., pred_ys)
title!("Inferred noise level")

plot(
    sine_model_plot, noisy_sine_model_plot;
    layout=(1, 2),size=(800, 400),
)

In [None]:
inf_args = [xs, noisy_ys, new_xs, 20, 100]
plt_args = inf_args[1:3]
addresses = [:phase, :period, :amplitude]

pred_ys = infer_and_predict(sine_model, inf_args..., addresses)
sine_model_plot = plot_predictions(plt_args..., pred_ys)
title!("Fixed noise level")

pred_ys = infer_and_predict(noisy_sine_model, inf_args..., [addresses..., :noise])
noisy_sine_model_plot = plot_predictions(plt_args..., pred_ys)
title!("Inferred noise level")

plot(
    sine_model_plot, noisy_sine_model_plot;
    layout=(1, 2),size=(800, 400),
)

The model that infers the noise is better able to avoid overconfident,
inaccurate predictions on the dataset (`noisy_ys`) for which the model’s
assumptions are violated.

## Calling other generative functions

Along with making random choices, generative functions may call other generative functions.
Let's illustrate this by combining our `linear_model` and `sine_model`. This
new model will use either `linear_model` or `sine_model` to explain data, based
on a coin flip. This is called _model selection_.

In `Gen`, we may call generative functions in three ways:
1. using Julia's regular function call syntax
1. using `@trace` with an address: `@trace(<func>, <address>)`
1. using `@trace` without an address: `@trace(<func>)`

When invoking using Julia's regular function call syntax, **random choices made
by the called function _will not be traced_.** However, if we use `@trace`, we
will be able to trace random choices. Using `@trace(<func>)` (without an
address), random choices made in the called function are placed in the same
namespace as the caller's random choices. When using `@trace(<func>, <addr>)`,
the random choices are placed under the namespace given by `<addr>`.

In [None]:
@gen function callee()
    y ~ normal(0, 1)
end;

@gen function caller()
    x ~ bernoulli(0.5)
    @trace(callee())
end;

@gen function namespaced()
    x ~ bernoulli(0.5)
    namespaced ~ callee()
end;

We first show the addresses sampled by `caller`:

In [None]:
trace = Gen.simulate(caller, ())
display(Gen.get_choices(trace))

And the addresses sampled by `namespaced`:

In [None]:
trace = Gen.simulate(namespaced, ())
display(Gen.get_choices(trace))

Using `@trace` with a namespace can help avoid address collisions for complex models.
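As a small sketch of both points, the function below calls the same generative function three times: once with regular call syntax (untraced) and twice under different namespaces. The names `noise_term` and `three_calls` are hypothetical.

```julia
using Gen

@gen function noise_term()
    y ~ normal(0, 1)
    return y
end

@gen function three_calls()
    untraced = noise_term()   # regular call: its choices are not recorded
    a ~ noise_term()          # choices recorded under `:a => :y`
    b ~ noise_term()          # choices recorded under `:b => :y`
    return untraced + a + b
end

calls_trace = Gen.simulate(three_calls, ())
calls_choices = Gen.get_choices(calls_trace)

# Both namespaced calls keep their own `:y` without colliding.
println(Gen.has_value(calls_choices, :a => :y))
println(Gen.has_value(calls_choices, :b => :y))

# The regular call left nothing behind at the top level.
println(Gen.has_value(calls_choices, :y))
```

Without the namespaces, the two traced calls would both try to use the address `:y`, which Gen does not allow.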

A hierarchical address is represented as a Julia `Pair`, where the first element of the pair is the first element of the address and the second element of the pair is the rest of the address:

In [None]:
trace[Pair(:namespaced, :y)]

Julia uses the `=>` operator as a shorthand for the `Pair` constructor, so we can access choices at hierarchical addresses like:

In [None]:
trace[:namespaced => :y]

If we have a hierarchical address with more than two elements, we can construct the address by chaining the `=>` operator:

In [None]:
@gen function namespaced_namespace()
    @trace(namespaced(), :parent)
end;

trace = Gen.simulate(namespaced_namespace, ())
display(trace[:parent => :namespaced => :y])

Note that the `=>` operator associates to the right, so this is equivalent to:

In [None]:
trace[Pair(:parent, Pair(:namespaced, :y))]

Now, we write a generative function that combines the line and sine models. It makes a Bernoulli random choice (i.e., a coin flip that returns true or false) that determines which of the two models will generate the data.

In [None]:
@gen function combined_model(xs::Vector{Float64})
    if @trace(bernoulli(0.5), :isline)
        @trace(noisy_linear_model(xs))
    else
        @trace(noisy_sine_model(xs))
    end
end;

We also write a visualization for a trace of this function:

In [None]:
function render_combined(trace, title=""; show_points=true, kwargs...)
    p = plot(size=(600, 600))
    if trace[:isline]
        p = render_line(trace, "isline"; show_points=show_points, kwargs...)
    else
        p = render_sine(trace, "issine"; show_points=show_points, kwargs...)
    end
    return p
end;

In [None]:
traces = map(_ -> Gen.simulate(combined_model, (xs,)), 1:12)
Intro.render_grid(render_combined, traces)

We run inference using this combined model on the `ys` dataset and the `sine_ys` dataset.

In [None]:
traces = map(_ -> infer(combined_model, xs, ys, 10_000), 1:10)
p1 = Intro.render_overlay(render_combined, traces)

traces = map(_ -> infer(combined_model, xs, sine_ys, 10_000), 1:10)
p2 = Intro.render_overlay(render_combined, traces)

plot(
    p1, p2;
    layout=(1, 2), legend=false, size=(800, 400),
)

The results should show that the line model was inferred for the `ys` dataset, and the sine wave model was inferred for the `sine_ys` dataset.
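Beyond inspecting the plots, we can summarize model selection numerically by counting how often the inferred traces pick each branch. The sketch below does this on a deliberately tiny pair of models; `line_branch`, `flat_branch`, `either_model`, and the made-up data in `demo_ys` are all hypothetical and stand in for the models and datasets above.

```julia
using Gen

# Branch 1: a line through the origin with a random slope.
@gen function line_branch(xs::Vector{Float64})
    slope ~ normal(0, 1)
    for (idx, x) in enumerate(xs)
        {(:y, idx)} ~ normal(slope * x, 0.5)
    end
end

# Branch 2: a constant level that ignores x.
@gen function flat_branch(xs::Vector{Float64})
    level ~ normal(0, 1)
    for idx in eachindex(xs)
        {(:y, idx)} ~ normal(level, 0.5)
    end
end

# A coin flip at `:isline` selects which branch generates the data.
@gen function either_model(xs::Vector{Float64})
    if @trace(bernoulli(0.5), :isline)
        @trace(line_branch(xs))
    else
        @trace(flat_branch(xs))
    end
end

demo_xs = collect(-2.0:1.0:2.0)
demo_ys = 1.5 .* demo_xs   # made-up data lying on a line

observations = Gen.choicemap()
for (idx, y) in enumerate(demo_ys)
    observations[(:y, idx)] = y
end

# Run inference several times and count the traces that chose the line.
selected = map(1:20) do _
    (t, _) = Gen.importance_resampling(either_model, (demo_xs,), observations, 200)
    t[:isline]
end
line_fraction = count(selected) / length(selected)
println("fraction choosing the line: ", line_fraction)
```

The fraction is a rough estimate of the posterior probability of the line branch, and it varies between runs because inference is approximate.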