# Problem Set 2 -- Sampling-based algorithms in Gen

<div class="alert alert-info">
    
**Instructions:** Finish the code following the instructions. Execute all blocks of your Jupyter notebook, write analyses in English in the corresponding Markdown blocks ("marked as YOUR ANSWER HERE"), save the final notebook as a pdf file, and submit it on Canvas.

Acknowledge any contributions if you collaborated with someone else. You must write your own code and analyses.
</div>

First, let's install the necessary packages.

In [None]:
import Pkg
Pkg.activate("CogAI")
Pkg.add(["CSV", "DataFrames"])
# load the packages needed for this problem set
# Note: running this for the first time might take about 15 minutes, so plan ahead
using Gen
using Plots
using CSV
using DataFrames

include("utils/draw.jl")

## Question 1

Nimet has been inspecting 3 colonies of leafbugs; the average size of the members differs from colony to colony and is denoted $\theta_1, \theta_2, \theta_3$.

Nimet knows that these 3 colonies are related: they split off from a common colony last year. That is, she believes the average size of leafbugs in each colony comes from a shared underlying distribution.

She measures the sizes of 40 members of each colony, denoted $\vec{x}_1, \vec{x}_2, \vec{x}_3$.

Based on these observations, Nimet intuitively arrives at an inference about the average size of the members of each colony, as well as about the shared parameter that determines the average size of leafbugs in each colony.

Your task is to formalize Nimet's inference process. You can assume that the average size of leafbugs in a colony (i.e., each $\theta_i$) is drawn from a Gamma distribution, $\mathrm{Gamma}(\alpha, \text{scale}=2)$, shared across the three colonies. Furthermore, you can assume that the prior over the `shape` parameter of the Gamma distribution, denoted $\alpha$, is uniform over the range 1.0 to 15.0. (We assume that the scale parameter of the Gamma distribution is known: `scale=2`.)

Finally, assume that the size measurements come from a normal distribution with known standard deviation, $x_{i,k} \sim \mathrm{Normal}(\theta_i, \sigma=0.1)$, where $i \in \{1, 2, 3\}$ indexes the colony and $k \in \{1, \dots, 40\}$ indexes members in that colony.
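For reference, the generative model described above can be written out as follows. This restates only the assumptions given in the text:

$$
\begin{aligned}
\alpha &\sim \mathrm{Uniform}(1, 15) \\
\theta_i \mid \alpha &\sim \mathrm{Gamma}(\alpha, \text{scale}=2), \quad i \in \{1, 2, 3\} \\
x_{i,k} \mid \theta_i &\sim \mathrm{Normal}(\theta_i, \sigma=0.1), \quad k \in \{1, \dots, 40\}
\end{aligned}
$$

so the joint density factorizes as

$$p(\alpha, \theta_{1:3}, \vec{x}_{1:3}) = p(\alpha) \prod_{i=1}^{3} \left[ p(\theta_i \mid \alpha) \prod_{k=1}^{40} p(x_{i,k} \mid \theta_i) \right].$$

Posterior inference targets $p(\alpha, \theta_{1:3} \mid \vec{x}_{1:3})$. This is proportional to the joint density above, with each $x_{i,k}$ fixed to its observed value.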

The following shows a graphical model of the generative model you can use to formalize Nimet's thought process. 

<img src="./images/leafbug-gm.png" alt="" width="400"/>

### Q 1A [3 pts]

Fill in the following code to write a Gen generative function for the model illustrated in the graphical model, using the parametric distributions specified above.

In [None]:
@gen function leafbug_colonies()
    scale = 2
    σ = 0.1

    # What's the prior of α?
    # your code here
    throw(Exception("Not Implemented."))

    # average size of leafbugs in each colony
    # What's the distribution of each of θ1, θ2, θ3?
    # your code here
    throw(Exception("Not Implemented."))
    
    # make observations for each colony's leafbugs
    # the address structure in trace should be 
    # {:data => colony_id => leafbug_id => :x} ~ ...
    # e.g., {:data => 1 => 3 => :x} ~ ...
    # NOTE: You can perform this in a single loop. It's not required, but definitely a worthwhile challenge!
    # your code here
    throw(Exception("Not Implemented."))
end

Now simulate your generative function to check that its outputs make sense.

In [None]:
# Run your model `leafbug_colonies` forward and store the result in a variable named `trace`
# your code here
throw(Exception("Not Implemented."))
get_choices(trace)

Fill in the following code to turn your data from `./data/leafbug_sizes.csv` into observation constraints. Each column contains the measurements from one colony.

In [None]:
function make_observations(df::DataFrame)
    # Setup your observation constraints
    # your code here
    throw(Exception("Not Implemented."))

    # Apply your constraints
    # To index DataFrames, you need to use the column name as a symbol.
    #   This means you'll need to use `Symbol("colony<N>")` to get the
    #   correct column name to index the DataFrame.
    # your code here
    throw(Exception("Not Implemented."))
    
    constraints
end 
# call the function to get the observations
df = DataFrame(CSV.File("./data/leafbug_sizes.csv"))
observations = make_observations(df)

### Q 1B [1 pt]

Given your generative model and observations, we will perform importance resampling for posterior inference. Fill in the code to obtain 100 posterior samples, each from a separate run of importance resampling with 1000 importance samples.
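The sketch below illustrates what importance resampling does. It is written in plain Julia (no Gen) for a hypothetical one-parameter toy model, not the leafbug model: θ ~ Uniform(0, 10) and x_k ~ Normal(θ, 1).

Each run does three things:
* proposes θ values from the prior,
* weights them by the likelihood, and
* keeps one proposal, drawn in proportion to its weight.

Repeating the run gives independent approximate posterior samples. The helper names (`stable_logsumexp`, `normal_logpdf`, `toy_importance_resampling`) are illustrative only. As we understand it, Gen's `importance_resampling(model, args, observations, num_samples)` follows the same pattern using the model's default proposal, and returns the selected trace together with a log marginal likelihood estimate.

```julia
using Random, Statistics

# Numerically stable log(sum(exp.(v))).
stable_logsumexp(v) = (m = maximum(v); m + log(sum(exp.(v .- m))))

# Log-density of Normal(μ, σ) evaluated at x.
normal_logpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Hypothetical toy model: θ ~ Uniform(0, 10), x_k ~ Normal(θ, 1).
# Returns one resampled θ and an estimate of the log marginal likelihood.
function toy_importance_resampling(rng, data, num_samples)
    θs = 10 .* rand(rng, num_samples)                         # propose from the prior
    logws = [sum(normal_logpdf.(data, θ, 1.0)) for θ in θs]  # log importance weights
    lse = stable_logsumexp(logws)
    cdf = cumsum(exp.(logws .- lse))                          # cumulative normalized weights
    cdf[end] = 1.0                                            # guard against round-off
    idx = searchsortedfirst(cdf, rand(rng))                   # draw an index ∝ weight
    return θs[idx], lse - log(num_samples)
end

# Usage: 100 independent runs, each with 1000 proposals.
rng = MersenneTwister(1)
data = 4.0 .+ randn(rng, 40)
samples = [first(toy_importance_resampling(rng, data, 1000)) for _ in 1:100]
println("posterior mean ≈ ", mean(samples), " (mean of the data: ", mean(data), ")")
```

With more proposals per run, the retained sample is closer to an exact posterior draw, at a proportionally higher cost.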

In [None]:
# Note: this codeblock might take a minute to run

# Using 1000 samples in `importance_resampling`, do posterior inference.
# Ensure that you keep trace of all traces (that is, you should have a Vector of 100 traces from your runs of `importance_resampling`)
# your code here
throw(Exception("Not Implemented."))

### Q 1C [1 pt]

Fill in the following codeblock to visualize your posterior samples. 

First, show the average log probability of all traces (computed as the log of the mean probability, via `logmeanexp`).

Then make a plot with 2 subplots:

* Subplot 1: Make a histogram plot showing $\alpha$.
* Subplot 2: Make another histogram plot showing the three $\theta$s on the same plot.
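The `logmeanexp` helper in the next cell computes $\log\big(\frac{1}{N}\sum_n e^{s_n}\big)$ for trace scores $s_n$. This is the log of the average probability, not the average of the log probabilities. Working on the log scale matters because scores such as $-1000$ underflow to zero when exponentiated directly.

The stdlib-only sketch below shows the difference. It uses hypothetical helpers, with `stable_logsumexp` standing in for Gen's `logsumexp`:

```julia
# Numerically stable log(sum(exp.(v))): factor out the maximum before exponentiating.
stable_logsumexp(v) = (m = maximum(v); m + log(sum(exp.(v .- m))))

# log(mean(exp.(scores))), computed without underflow.
stable_logmeanexp(scores) = stable_logsumexp(scores) - log(length(scores))

scores = [-1000.0, -1001.0]
naive  = log(sum(exp.(scores)) / length(scores))   # exp(-1000) underflows to 0.0, so this is -Inf
stable = stable_logmeanexp(scores)                 # ≈ -1000.38
println((naive, stable))
```

By Jensen's inequality, the log-mean-exp is never smaller than the plain mean of the scores (here, $-1000.5$).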

In [None]:
# compute the log of the mean probability across traces (log-mean-exp of the scores)
function logmeanexp(scores)
    logsumexp(scores) - log(length(scores))
end;

# function for display all traces
function display_results(traces)
    # compute the average log prob of all traces
    log_probs = [get_score(t) for t in traces]
    println("Average log probability: $(logmeanexp(log_probs))")

    # collect the inferred αs and θs across the traces and plot them.
    αs = [t[:α] for t in traces]
    # your code here
    throw(Exception("Not Implemented."))

    # plot the histogram of α
    α_plot = histogram(αs, thickness_scaling=3.5, size=(1000, 800), label="α", xtickfontsize=5, ytickfontsize=5)

    # Plot the θs on the same plot, use the same `thickness_scaling` and `size` as in α_plot. Ensure your plots are labeled correctly.
    # your code here
    throw(Exception("Not Implemented."))
    
    plot(α_plot, θ_plot, layout=(2, 1), size=(1600, 1200))
end;

In [None]:
# run the display_results function to show the average log prob and inferred variables
# your code here
throw(Exception("Not Implemented."))

### Q 1D [3 pts]

Now let's implement an MCMC algorithm with a customized proposal. In particular, each update consists of:

1. A random walk proposal for α
2. Block resimulation for the θs
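Step 2 is an instance of Metropolis-Hastings in which the proposal resimulates the selected choices from their prior. The prior terms then cancel, so the acceptance ratio reduces to a likelihood ratio. Gen's selection-based `mh(trace, selection)` performs this kind of update on the selected addresses.

The sketch below shows the idea in plain Julia on a hypothetical toy model (not the leafbug model): θ ~ Normal(0, 3) and x_k ~ Normal(θ, 1). The function name `resimulation_mh` is illustrative only.

```julia
using Random, Statistics

# Log-likelihood of the data under x_k ~ Normal(θ, 1), up to an additive constant.
loglik(θ, data) = -0.5 * sum((data .- θ) .^ 2)

# MH update that proposes θ_new from the prior Normal(0, 3).
# Because the proposal equals the prior, the acceptance ratio is L(θ_new) / L(θ).
function resimulation_mh(rng, data, num_updates; θ0=0.0)
    θ = θ0
    draws = Vector{Float64}(undef, num_updates)
    for t in 1:num_updates
        θ_new = 3.0 * randn(rng)                              # resimulate from the prior
        if log(rand(rng)) < loglik(θ_new, data) - loglik(θ, data)
            θ = θ_new                                         # accept
        end
        draws[t] = θ
    end
    return draws
end

rng = MersenneTwister(2)
data = 1.5 .+ randn(rng, 10)
mh_draws = resimulation_mh(rng, data, 20_000)
# Conjugate posterior mean for comparison: n * mean(x) / (n + 1/9)
exact = length(data) * mean(data) / (length(data) + 1 / 9)
println((mean(mh_draws[1001:end]), exact))
```

Resimulating from a broad prior is simple, but most proposals land in low-likelihood regions and are rejected. This is one reason to pair it with a tuned local move, such as the random walk on α.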

Expected result: with 1000 updates (a similar amount of compute to importance resampling with 1000 samples), the average log probability should be at least around 100.

You will need to tune the maximum step size of the random walk proposal to obtain good inference results.
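The sketch below illustrates the step-size trade-off. It implements random-walk Metropolis-Hastings in plain Julia for a single parameter restricted to [1, 15], using a hypothetical toy target: a Normal(8, 1) density truncated to [1, 15]. `rw_mh` is an illustrative name, not part of Gen.

This version uses a symmetric window and rejects proposals that leave [1, 15]. If you instead clip the window to [1, 15], the proposal is no longer symmetric. As we understand it, Gen's `mh` with a custom proposal accounts for that by scoring both the forward and the reverse proposal.

```julia
using Random, Statistics

"""
    rw_mh(rng, logtarget, α0, δ, n; lo=1.0, hi=15.0)

Random-walk Metropolis-Hastings on a scalar with support [lo, hi].
Proposes α_new ~ Uniform(α - δ, α + δ). Proposals outside [lo, hi] have zero
target density and are rejected; because the window is symmetric, the
acceptance ratio is simply p(α_new) / p(α).
"""
function rw_mh(rng, logtarget, α0, δ, n; lo=1.0, hi=15.0)
    α = α0
    draws = Vector{Float64}(undef, n)
    accepted = 0
    for t in 1:n
        α_new = α + δ * (2 * rand(rng) - 1)          # symmetric uniform step
        if lo <= α_new <= hi && log(rand(rng)) < logtarget(α_new) - logtarget(α)
            α = α_new
            accepted += 1
        end
        draws[t] = α
    end
    return draws, accepted / n
end

# Hypothetical target: Normal(8, 1) truncated to [1, 15] (unnormalized log density).
logtarget(α) = -0.5 * (α - 8)^2

# Compare step sizes: tiny steps are almost always accepted but move slowly;
# huge steps are mostly rejected.
for δ in (0.05, 2.0, 20.0)
    draws, acc = rw_mh(MersenneTwister(3), logtarget, 1.0, δ, 20_000)
    println("δ = $δ: acceptance = $(round(acc, digits=2)), mean = $(round(mean(draws[2001:end]), digits=2))")
end
```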


In [None]:
@gen function random_work_proposal(current_trace)
    # Random walk proposal on α
    # NOTE: the new α must be within the range of [1, 15]
    # adjust max step size to improve the inference results 
    # a good implementation means that the mean log prob should be at least around 100
    # your code here
    throw(Exception("Not Implemented."))
end;

In [None]:
function MCMC_inference(observations, num_updates)
    # Perform `num_updates` Metropolis-Hastings updates
    # using your own proposal function
    
    # Generate the initial trace and name it `tr` (it is returned at the end)
    # your code here
    throw(Exception("Not Implemented."))
    
    for iter_id = 1:num_updates
        # random walk MH update for α
        # your code here
        throw(Exception("Not Implemented."))
            
        # block resimulation MH update for θs
        # NOTE: use Gen.select() instead of select() to select variables 
        # this is to avoid confusion with select() in DataFrames
        # your code here
        throw(Exception("Not Implemented."))
    end
    tr
end;

### Q 1E [1 pt]

Fill in the code to obtain 100 posterior samples, each using 1000 MH updates; then display the results using the `display_results()` function.

In [None]:
# Note: this codeblock might take a while to run

# run MCMC inference with 1000 updates, repeat the inference 100 times, and save all traces
# your code here
throw(Exception("Not Implemented."))

# run the display_results function to show the average log prob and inferred variables
# your code here
throw(Exception("Not Implemented."))

### Q 1F [1 pt]
Compare the results of MCMC and importance resampling. Briefly explain the difference. Write 1 - 2 sentences in the Markdown block below.

YOUR ANSWER HERE

## Question 2

The iron core of our planet conducts electricity, which creates a magnetic field around Earth. This magnetic field not only provides a protective shield against the Sun's harmful rays, but also creates a kind of map that remains relatively constant over time. Many animal species are thought to rely on this magnetic field for wayfinding, from migratory birds across the open skies to fish, reptiles and crustaceans in the deep ocean. For example, sea turtles can use the magnetic field to plan routes between points A and B on the globe, and to orient themselves, i.e., to know where they are on Earth.

Let's assume that the planet's magnetic field can be described using a two-dimensional grid. The sea turtle moves in this gridworld one step at a time, in one of the four cardinal directions: north (n), south (s), east (e), and west (w). The following struct encodes the mechanics of grid movements and will come in handy as we go along.



In [None]:
struct Movement
    dx::Real
    dy::Real
end

const N = Movement( 0,  1)
const E = Movement( 1,  0)
const S = Movement( 0, -1)
const W = Movement(-1,  0)
const DIRECTIONS = [N, S, E, W]

Each cell in this grid emits the *intensity* and *direction* of the magnetic field at that cell. The `intensity` and `direction` of a cell at coordinates `x` and `y` are defined as:

<img src="./images/magnetic-field.png" alt="" width="400"/>

Your overall goal in this question will be to infer a posterior over the sequence of movements of the sea turtle from a sequence of intensity and direction observations. We will build up to that.

### Q 2A [4 pts]

Imagine that suddenly this sea turtle finds itself in the middle of an oceanic storm. The storm is such that:

* The sea turtle knows where it is at the beginning of the storm (time step `k=0`), that is, its `x` and `y` coordinates.
* At each time step, its movement (a single step in a cardinal direction: north `n`, south `s`, east `e`, or west `w`) is dictated by the waves and turbulence of the ocean, not controlled by the sea turtle. Assume that these dynamics are random, following a categorical (single-draw multinomial) distribution with equal weight on each direction:

$$p(m_{t} \mid m_{t-1}) = p(m_{t}) = \mathrm{Categorical}\left(\{n, s, e, w\};\ \tfrac{1}{4}, \tfrac{1}{4}, \tfrac{1}{4}, \tfrac{1}{4}\right)$$

* At each time step, the sea turtle observes noisy magnetic field measurements (because of the storm, it cannot additionally rely on vision or smell):
$$\text{intensity}_{t} \sim \mathrm{Normal}(x_t + y_t, \sigma)$$
$$\text{direction}_{t} \sim \mathrm{Normal}(|x_t - y_t|, \sigma)$$
where $\sigma=0.1$ and $(x_t, y_t)$ is the turtle's position after the movement at time step $t$.
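Before writing the Gen version, it may help to see the same generative process in plain Julia. The sketch below is not the requested `ferriswheel_kernel`. It uses hypothetical names (`STEPS`, `simulate_storm`) and tuples instead of the `Movement` and `Field` types introduced below, and it records one noisy observation after every step.

```julia
using Random

# Hypothetical encoding of the four cardinal steps as (dx, dy), in the order n, s, e, w.
const STEPS = ((0, 1), (0, -1), (1, 0), (-1, 0))

"""
    simulate_storm(rng, x0, y0, K; σ=0.1)

Starting from the known position (x0, y0), take K uniformly random cardinal
steps. After each step, record the true position and noisy observations of
intensity (x + y) and direction (|x - y|), each with Gaussian noise of std σ.
"""
function simulate_storm(rng, x0, y0, K; σ=0.1)
    x, y = x0, y0
    positions = Tuple{Int,Int}[]
    obs = Tuple{Float64,Float64}[]
    for _ in 1:K
        dx, dy = STEPS[rand(rng, 1:4)]            # each direction has probability 1/4
        x, y = x + dx, y + dy
        push!(positions, (x, y))
        push!(obs, (x + y + σ * randn(rng), abs(x - y) + σ * randn(rng)))
    end
    return positions, obs
end

positions, obs = simulate_storm(MersenneTwister(3), 3, 3, 10)
```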

(You might be able to relate to the experience of our sea turtle friend if you can remember the last time you were on a Ferris Wheel with your eyes closed. In such a scenario, when you rely just on your vestibular system to tell your pose in space, you are likely to experience all sorts of hallucinated backward flips.)

Your task is to write a generative model of this process using Gen's generative functions and the generative function combinator `Unfold`. You will write a generative function (a temporal kernel) called `ferriswheel_kernel`, and input it to the `Unfold` combinator to create a temporal generative model called `ferriswheel`.

Assume that the initial state of the sea turtle is provided to `ferriswheel` as a global variable (not modeled as a random variable). 

Start by implementing a Julia `struct` to represent the state information at each time step; the struct should include the following three fields. (For each field, you must indicate its type.)

- `movement`: current movement
- `x`: current coordinate in the east-west axis
- `y`: current coordinate in the north-south axis

In [None]:
struct Field
    # your code here
    throw(Exception("Not Implemented."))
end

Using the definitions above, fill in the following `intensity` function to compute the intensity of a `Field`.

In [None]:
function intensity(field::Field)
    # your code here
    throw(Exception("Not Implemented."))
end

Using the definitions above, fill in the following `direction` function to compute the direction of a `Field`.

In [None]:
function direction(field::Field)
    # your code here
    throw(Exception("Not Implemented."))
end

We need a way to update the `x` and `y` entries of a `Field` based on a `Movement`. Using Julia's multiple dispatch, we provide a new method for addition, `Base.:+`, which applies a `Movement` to the `x` and `y` entries of a `Field` and returns a new `Field` holding the updated coordinates and the input `Movement`. This will come in handy as we go along.

In [None]:
function Base.:+(field::Field, movement::Movement)
    return Field(
        movement,
        field.x + movement.dx,
        field.y + movement.dy,
    )
end
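Multiple dispatch selects a method from the types of all arguments. So the method above covers `field + movement` but not `movement + field`.

Here is a small illustration with hypothetical types `GridPos` and `GridStep`, which stand in for, but are unrelated to, the notebook's `Field` and `Movement`:

```julia
# Hypothetical stand-ins for `Field` and `Movement`.
struct GridStep
    dx::Int
    dy::Int
end

struct GridPos
    x::Int
    y::Int
end

# Add a method to `+` for the argument types (GridPos, GridStep).
Base.:+(p::GridPos, s::GridStep) = GridPos(p.x + s.dx, p.y + s.dy)

moved = GridPos(3, 3) + GridStep(0, 1)     # GridPos(3, 4)

# The reversed order has no matching method and throws a MethodError, unless a
# second method such as Base.:+(s::GridStep, p::GridPos) = p + s is also defined.
reversed_ok = try
    GridStep(0, 1) + GridPos(3, 3)
    true
catch err
    err isa MethodError || rethrow()
    false
end
```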

Fill in the following code block to complete the definition of `ferriswheel_kernel` and to create a function called `chain` using Gen's `Unfold` combinator and this kernel.

In [None]:
@gen function ferriswheel_kernel(k::Int, curr_field::Field)
    # observation noise of the magnetic field (intensity and direction)
    σ = 0.1
    
    # Draw a movement (according to the probabilistic specification above)
    # your code here
    throw(Exception("Not Implemented."))

    # `curr_field + movement` uses the `Base.:+` method for (Field, Movement)
    #   defined above; it returns a new `Field` that stores `movement`
    #   and the updated coordinates
    next_field = curr_field + movement

    # observe noisy intensity/direction measurements
    # The following variable should be named `obs_intensity`
    # your code here
    throw(Exception("Not Implemented."))
    # The following variable should be named `obs_direction`
    # your code here
    throw(Exception("Not Implemented."))
    
    # Return the updated field
    # your code here
    throw(Exception("Not Implemented."))
end

# Create a function `chain` using Gen's `Unfold` combinator for use in Particle Filtering
# your code here
throw(Exception("Not Implemented."))

Fill in the following code block to create the temporal generative model `ferriswheel`.

In [None]:
@gen function ferriswheel(K::Int)
    # this line allows us to use the init_field definition from the main scope – keep it there!
    global init_field
    # Call `chain` (the `Unfold` of `ferriswheel_kernel`) for K steps starting from `init_field`, storing its choices under the address `:trajectory`
    # your code here
    throw(Exception("Not Implemented."))
end
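Conceptually, `Unfold` threads a state through repeated calls of a kernel. Each call receives the step index and the previous state and returns the next state; the combinator then collects the returned states.

The plain-Julia sketch below mimics only that control flow: no random choices or addresses are recorded. `unfold_sketch` is a hypothetical name, not Gen's API. In Gen, a combinator built as `Unfold(kernel)` is called with the number of steps followed by the initial state.

```julia
# Hypothetical sketch of Unfold's control flow: apply `kernel(t, state)` for
# t = 1, ..., n, feeding each returned state into the next call.
function unfold_sketch(kernel, n::Int, init_state)
    states = Any[]
    state = init_state
    for t in 1:n
        state = kernel(t, state)
        push!(states, state)
    end
    return states
end

# Toy kernel: add the step index to the running state.
partial_sums = unfold_sketch((t, s) -> s + t, 4, 0)   # [1, 3, 6, 10]
```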

### Q 2B [1 pt]

Draw a sequence of 10 movements from your generative model. The initial state is provided (`init_field`). 

Use the `get_choices` and `get_retval` functions to display the random choices and return values of the trace you simulated from the generative function.

In [None]:
# Start at coordinates (3, 3), with initial movement N
init_field = Field(N, 3, 3)

# Run your model forward with 10 movements, name your variable `trace`
# your code here
throw(Exception("Not Implemented."))

# Display the choices from `trace`
# your code here
throw(Exception("Not Implemented."))

In [None]:
# Get the return values for your `trace`
# your code here
throw(Exception("Not Implemented."))

Execute the next code block to visualize your sample. (Review this visualization code; there is nothing to fill in.) It plots a binary heatmap of which movements occurred (n, s, e, w), a grayscale heatmap showing the trajectory, and a line plot of how the intensities and directions changed throughout the sequence.

In [None]:
# a helper function to visualize things
# we visualize the movements and 
#   the predicted intensity and direction
#   values according to the coordinates of 
#   visited cells
function visualize(trace; title="")
    choices = get_choices(trace)
    fields  = get_retval(trace)

    # get the movements and coordinates
    ms = [field.movement for field in fields]
    xs = [field.x for field in fields]
    ys = [field.y for field in fields]
    
    # predicted intensities
    intensities = intensity.(fields)
    
    # predicted directions
    directions = direction.(fields)

    # create a binary matrix of movements (4 x length(ms))
    binary_movements = falses(4, length(ms))
    for (index, movement) in enumerate(ms)
        movement_order = findfirst(isequal(movement), DIRECTIONS)
        binary_movements[movement_order, index] = true
    end

    # create a gray scale matrix of coordinates, [-2, 12] x [-2, 12]
    # the brighter the color, the more recent the step is
    binary_coordinates = zeros(15, 15)
    n = length(xs)
    for t = 1:n
        binary_coordinates[xs[t] + 3, ys[t] + 3] = t / n
    end

    # plot the movements
    p1 = plot(
        binary_movements,
        seriestype=:heatmap,
        legend=false, 
        thickness_scaling=3.5,
        title=title,
        titlefont=5,
    )
    p2 = plot(
        binary_coordinates,
        seriestype=:heatmap,
        legend=false, 
        thickness_scaling=3.5,
        aspect_ratio = :equal,
        size=(800, 800),
        xtickfontsize=3,
        ytickfontsize=3,
    )
    # plot intensities and directions
    p3 = plot(
        collect(1:length(ms)),
        [intensities, directions], 
        thickness_scaling=3.5,
        labels=["intensity" "direction"]
    )
    plot(p1, p2, p3, layout=(3,1), legend=:inside, size=(1200, 1600))
end

visualize(trace)

Execute the following code block to load and visualize the observed intensities and directions.

In [None]:
# load observations (sensory features)
obs_fields = DataFrame(CSV.File("./data/observed_fields.csv"))

# visualize (just run the following code)
plot(
    collect(1:size(obs_fields, 1)),
    [obs_fields[!, :intensity], obs_fields[!, :direction]],
    labels=["intensity" "direction"]
)

### Q 2C [4 pts]

In Gen, write a particle filtering algorithm to infer a posterior distribution over movements given the intensity and direction measurements in `obs_fields`.
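The loop structure of a bootstrap particle filter can be sketched in plain Julia on a hypothetical 1-D toy model (not the sea-turtle model, and without Gen). The hidden position starts at 0 and moves by -1 or +1 with equal probability, and each observation is Normal(position, σ).

At every step the sketch does three things:
* extends each particle with the prior dynamics,
* reweights the particles by the observation likelihood, and
* resamples.

In Gen, these steps correspond to functions such as `initialize_particle_filter`, `particle_filter_step!`, `maybe_resample!` (which resamples only when the effective sample size drops below a threshold), and `sample_unweighted_traces`. Check the Gen documentation for their exact signatures. All names below are illustrative.

```julia
using Random

# Numerically stable log(sum(exp.(v))).
stable_logsumexp(v) = (m = maximum(v); m + log(sum(exp.(v .- m))))

# Multinomial resampling: draw `n` indices with probability ∝ exp(logws).
function resample_indices(rng, logws, n)
    cdf = cumsum(exp.(logws .- stable_logsumexp(logws)))
    cdf[end] = 1.0                                   # guard against round-off
    return [searchsortedfirst(cdf, rand(rng)) for _ in 1:n]
end

# Hypothetical toy: integer position x0, steps of -1 or +1, observations
# Normal(position, σ). Returns the particle histories after the last step.
function bootstrap_pf(rng, ys, num_particles; σ=0.1, x0=0)
    particles = fill(x0, num_particles)
    histories = [Int[] for _ in 1:num_particles]
    for y in ys
        # propose: extend every particle with the prior dynamics
        particles = particles .+ rand(rng, [-1, 1], num_particles)
        for i in 1:num_particles
            push!(histories[i], particles[i])
        end
        # weight: Gaussian log-likelihood of this step's observation
        logws = [-0.5 * ((y - p) / σ)^2 for p in particles]
        # resample particles together with their histories
        idx = resample_indices(rng, logws, num_particles)
        particles = particles[idx]
        histories = [copy(histories[j]) for j in idx]
    end
    return histories
end

rng = MersenneTwister(4)
true_path = cumsum(rand(rng, [-1, 1], 15))
pf_obs = true_path .+ 0.1 .* randn(rng, 15)
histories = bootstrap_pf(rng, pf_obs, 200)
println("fraction of particles matching the true path: ",
        count(h -> h == true_path, histories) / length(histories))
```

Because the toy observations pin down the position almost exactly, the particles collapse onto the true path. With observations that are ambiguous about the hidden state, the surviving histories would instead spread over several explanations.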

In [None]:
function particle_filter(num_particles::Int, obs_fields, num_samples::Int)
    # initial observation
    init_obs = Gen.choicemap(
        (:trajectory => 0 => :obs_intensity, obs_fields[1, :intensity]),
        (:trajectory => 0 => :obs_direction, obs_fields[1, :direction]),
    )
    
    # initialize the particle filter    
    # your code here
    throw(Exception("Not Implemented."))
        
    for (idx, obs_field) in enumerate(eachrow(obs_fields))
        # Resample
        # your code here
        throw(Exception("Not Implemented."))

        # load observations of this time step
        # your code here
        throw(Exception("Not Implemented."))

        # Re-weight by the likelihood 
        # your code here
        throw(Exception("Not Implemented."))
    end
    
    # return a sample of unweighted traces from the weighted collection
    # your code here
    throw(Exception("Not Implemented."))
end

Now call this particle filter inference procedure with 1000 particles, return 100 samples, and store them in a variable named `pf_traces`.

In [None]:
# your code here
throw(Exception("Not Implemented."))

### Q 2D [1 pt]

The following code block visualizes the 100 posterior samples you just computed, one after the other. Each frame shows the inferred sequence of movements and the trajectory according to the posterior sample (top and middle panels), as well as the predicted intensities and directions (bottom panel). View this animation and explain what it reveals about the posterior distribution. Write 1-2 sentences in the Markdown block after the animation.

In [None]:
viz = Plots.@animate for (trace_id, trace) in enumerate(pf_traces)
    visualize(trace;title="Sample $trace_id / 100")
end
gif(viz, fps=1)

YOUR ANSWER HERE