# Probabilistic Programming - 3
## Variational inference

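In this notebook, we look at variational inference in a dynamical system: a model with discrete latent states that evolve over time and discrete observations emitted from those states.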

### Preliminaries

- Goal 
  - Learn to apply ForneyLab to a dynamical system.
- Materials        
  - Mandatory
    - These lecture notes.
  - Optional
    - Cheatsheets: [how does Julia differ from Matlab / Python](https://docs.julialang.org/en/v1/manual/noteworthy-differences/index.html).
    - Getting started with [ForneyLab](https://biaslab.github.io/forneylab/docs/getting-started/).

In [3]:
# Activate the project environment in the "workspace" folder and install its dependencies
using Pkg
Pkg.activate("workspace")
Pkg.instantiate()

[32m[1mActivating[22m[39m environment at `~/Documents/biaslab/repos/BMLIP/lessons/notebooks/probprog/workspace/Project.toml`


In [None]:
# Random-number control, probability distributions and plotting
using Random, Distributions, Plots

# Use the PyPlot backend of Plots
pyplot()

# Pull in the helper script for this notebook
# (presumably the source of `one_hot`, which is used below but not defined here — confirm)
include("../scripts/pp-3.jl")

# Seed the global RNG so that the sampled data is the same on every run
Random.seed!(1234);

### Generate data

In [None]:
# There are 3 possible states and each variable is in one of those (one-hot encoding)
K = 3

# Length of time-series
T = 50

# Transition matrix of latent variables.
# NOTE: `transition * x` with a one-hot `x` picks out a *column* of this matrix,
# and that column is renormalised before sampling. The columns as written do not
# sum to one, so the effective transition probabilities are the column-normalised
# entries, not the literal numbers below.
transition = [0.3 0.6 0.1; 
              0.5 0.2 0.3; 
              0.2 0.8 0.1]

# Emission matrix for observed variables (used column-wise and renormalised as well)
emission = [0.7 0.3 0.0; 
            0.2 0.6 0.2; 
            0.0 0.3 0.7]

# Preallocate data arrays:
# X holds the initial state followed by T states, Y holds the T observations
X = zeros(T+1, K)
Y = zeros(T, K)

# Initial state
X[1,:] = [0.0, 1.0, 0.0] 

# Generate data for entire time-series.
# The loop has to run up to T+1: X has T+1 rows and Y[t-1,:] must reach row T.
# Stopping at T would leave X[T+1,:] and Y[T,:] all-zero, and argmax of an
# all-zero row returns 1, which would show up as a spurious "state 1" at the
# final time step of the plot below.
for t = 2:T+1
    
    # Transition from previous state
    A = transition * X[t-1,:]
    
    # Sample from Categorical distribution
    # (qualified so it is unambiguously the Distributions.jl sampler)
    X[t,:] = one_hot(rand(Distributions.Categorical(A ./ sum(A)), 1)[1], K)
    
    # Emission of current state
    B = emission * X[t,:]
    
    # Sample from Categorical distribution
    Y[t-1,:] = one_hot(rand(Distributions.Categorical(B ./ sum(B)), 1)[1], K)
    
end

# For visualization, we collapse the data from a one-hot to a numerical encoding
states = argmax.(eachrow(X))
observations = argmax.(eachrow(Y))

# Visualization. states[1] is the initial state, so states[2:end] lines up with
# the T observations.
plot(1:T, states[2:end], color="red", label="states", ylim=(0, 4), grid=false)
scatter!(1:T, observations, color="blue", label="observations")
xlabel!("time (t)")

### Model specification

In [None]:
using ForneyLab

In [None]:
# Start a new factor graph; the @RV statements below add variables and nodes to it
g = FactorGraph()

# The node names are qualified with `ForneyLab.` because Distributions (loaded
# earlier in this notebook) also exports `Categorical` and `Dirichlet`; the
# unqualified names would be ambiguous or refer to the Distributions types.
@RV A ~ ForneyLab.Dirichlet(ones(3,3)) # Vague prior on transition model
@RV B ~ ForneyLab.Dirichlet([10.0 1.0 1.0; 1.0 10.0 1.0; 1.0 1.0 10.0]) # Stronger prior on observation model
@RV s_0 ~ ForneyLab.Categorical(1/3*ones(3))

# NOTE(review): `n_samples` is not defined in the visible part of this notebook
# (the data cell uses `T`); presumably it comes from the included script — confirm.
s = Vector{Variable}(undef, n_samples) # one-hot coding
x = Vector{Variable}(undef, n_samples) # one-hot coding
s_t_min = s_0
for t = 1:n_samples
    # State at time t depends on the previous state through A
    @RV s[t] ~ Transition(s_t_min, A)
    # Observation at time t depends on the current state through B
    @RV x[t] ~ Transition(s[t], B)
    
    # The current state becomes the "previous" state of the next time step
    s_t_min = s[t]
    
    # Mark x[t] as observed; its value is supplied at run time as entry t of data[:x]
    placeholder(x[t], :x, index=t, dims=(3,))
end;

### Algorithm Generation

In [None]:
# Define the recognition factorization with three factors, labelled :A, :B and :S:
# one for A, one for B, and one for the whole state sequence [s_0; s]
q = RecognitionFactorization(A, B, [s_0; s], ids=[:A, :B, :S])

# Generate VMP algorithm
# (the result is parsed and evaluated as Julia code in the Execution cell)
algo = variationalAlgorithm(q)

# Construct variational free energy evaluation code
# (likewise parsed and evaluated in the Execution cell)
algo_F = freeEnergyAlgorithm(q);

### Execution

In [None]:
# Load algorithms: `algo` and `algo_F` hold generated source code, which is parsed
# and evaluated here (presumably defining stepS!, stepB!, stepA! and freeEnergy
# used in the loop below)
eval(Meta.parse(algo))
eval(Meta.parse(algo_F))

# Initial recognition distributions.
# `Dirichlet` is qualified because Distributions (loaded earlier in this notebook)
# exports a type with the same name, which makes the bare name ambiguous.
marginals = Dict{Symbol, ProbabilityDistribution}(
    :A => vague(ForneyLab.Dirichlet, (3,3)),
    :B => vague(ForneyLab.Dirichlet, (3,3)))

# Initialize data
# NOTE(review): `x_data` is not defined in the visible part of this notebook;
# presumably it is provided by the included script — confirm.
data = Dict(:x => x_data)
n_its = 20

# Run algorithm: each iteration updates the :S, :B and :A factors in turn and
# then records the free energy reached after that iteration
F = Vector{Float64}(undef, n_its)
for i = 1:n_its
    stepS!(data, marginals)
    stepB!(data, marginals)
    stepA!(data, marginals)

    F[i] = freeEnergy(data, marginals)
end
;

In [None]:
### Plot results

In [None]:
using PyPlot

# Plots (loaded earlier in this notebook) and PyPlot both export names such as
# `plot` and `grid`, and `plot` has already been used from Plots. Every call is
# therefore qualified so that the PyPlot (matplotlib) function is the one invoked.

# Plot free energy against the iteration number
PyPlot.plot(1:n_its, F, color="black")

PyPlot.grid("on")
PyPlot.xlabel("Iteration")
PyPlot.ylabel("Free Energy")
PyPlot.xlim(0,n_its);

In [None]:
# All plotting calls are qualified with `PyPlot.`: Plots (loaded earlier in this
# notebook) exports `plot` and `grid` too, so the bare names do not reliably
# refer to the PyPlot functions.
PyPlot.figure(figsize=(10,5))

# Collect state estimates: convert each one-hot vector to the index of its 1.0 entry
# NOTE(review): `x_data` and `s_data` are not defined in the visible part of this
# notebook; presumably they are provided elsewhere — confirm.
x_obs = [findfirst(x_i.==1.0) for x_i in x_data]
s_true = [findfirst(s_i.==1.0) for s_i in s_data]

# Plot simulated state trajectory and observations
PyPlot.subplot(121)
PyPlot.plot(1:n_samples, x_obs, "k*", label="Observations x", markersize=7)
PyPlot.plot(1:n_samples, s_true, "k--", label="True state s")
PyPlot.yticks([1.0, 2.0, 3.0], ["Red", "Green", "Blue"])
PyPlot.grid("on")
PyPlot.xlabel("Time")
PyPlot.legend(loc="upper left")
PyPlot.xlim(0,n_samples)
PyPlot.ylim(0.9,3.1)
PyPlot.title("Data set and true state trajectory")

# Plot inferred state sequence
PyPlot.subplot(122)
# Mean of the marginal of every state variable; the marginals are stored under
# the keys :s_1, :s_2, ...
m_s = [mean(marginals[:s_*t]) for t=1:n_samples]
m_s_1 = [m_s_t[1] for m_s_t in m_s]
m_s_2 = [m_s_t[2] for m_s_t in m_s]
m_s_3 = [m_s_t[3] for m_s_t in m_s]

# Stack the three state beliefs on top of each other: red from 0 to m_s_1,
# green from m_s_1 to m_s_1 + m_s_2, and blue for the remainder up to 1
PyPlot.fill_between(1:n_samples, zeros(n_samples), m_s_1, color="red")
PyPlot.fill_between(1:n_samples, m_s_1, m_s_1 + m_s_2, color="green")
PyPlot.fill_between(1:n_samples, m_s_1 + m_s_2, ones(n_samples), color="blue")
PyPlot.xlabel("Time")
PyPlot.ylabel("State belief")
PyPlot.grid("on")
PyPlot.title("Inferred state trajectory");

In [None]:
# True state transition probabilities
PyPlot.plt.matshow(A_data, cmap="bone", vmin=0.0, vmax=1.0)
ttl = title("True state transition probabilities")
ttl.set_position([.5, 1.15])
yticks([0, 1, 2], ["Red", "Green", "Blue"])
xticks([0, 1, 2], ["Red", "Green", "Blue"], rotation="vertical")
colorbar()

# Inferred state transition probabilities
PyPlot.plt.matshow(mean(marginals[:A]), cmap="bone", vmin=0.0, vmax=1.0)
ttl = title("Inferred state transition probabilities")
ttl.set_position([.5, 1.15])
yticks([0, 1, 2], ["Red", "Green", "Blue"])
xticks([0, 1, 2], ["Red", "Green", "Blue"], rotation="vertical")
colorbar();