# Local Variational Inference for a Softmax Model

The local variational method can be extended to multidimensional models by use of the softmax function (see e.g. Ahmed, 2013). In this demo we consider the following model:

\begin{align*}
    s &\sim \mathcal{N}(m_s, V_s)\\
    p_i &\sim \mathcal{N}(A_i \cdot s, I_2)\\
    y_i &\sim \mathcal{C}at(\sigma(p_i))\,,
\end{align*}

where $m_s$ and $V_s$ are the prior mean and covariance of the player skills $s$, $I_2$ is the $2 \times 2$ identity matrix, $A_i$ is a $2 \times 5$ matrix that encodes the lineup of encounter $i$ (its first row selects the winner, its second row the loser), and $\sigma$ denotes the softmax function. We are interested in estimating a belief over the player skills. In this demo, we use the softmax to implement a "greater than" constraint, following the example from Infer.NET (2020, see references). The example uses the results of six head-to-head encounters between 5 players to estimate their (relative) skills, including the uncertainty on each skill. This notebook was developed by Keith Myerscough of Sioux LIME, borrowing heavily from other notebooks in this demo folder.
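To make the "greater than" construction concrete, the following plain-Julia sketch builds $A_i$ for a single hypothetical match in which player 1 beats player 2, using illustrative skill values that are not results from this notebook. It checks that, with two categories, the softmax probability of the winner's category equals the logistic sigmoid of the performance difference. The helper names `softmax` and `logistic` are our own.

```julia
# Numerically stable softmax (subtracting the maximum does not change the result)
function softmax(x)
    z = exp.(x .- maximum(x))
    return z ./ sum(z)
end

# Logistic sigmoid
logistic(t) = 1 / (1 + exp(-t))

n_players = 5

# Lineup matrix for one hypothetical match in which player 1 beats player 2:
# the first row selects the winner, the second row the loser
A_i = zeros(2, n_players)
A_i[1, 1] = 1
A_i[2, 2] = 1

s = [9.0, 5.0, 3.0, 7.0, 6.0]  # illustrative skills, not results from this notebook
p_mean = A_i * s               # mean performances of (winner, loser)

prob_win = softmax(p_mean)[1]  # probability of observing y_i = [1, 0]
println("P(winner category) = ", round(prob_win, digits=3))

# With two categories, the softmax reduces to a logistic function of the difference
println(prob_win ≈ logistic(p_mean[1] - p_mean[2]))
```

Observing $y_i = (1, 0)$ therefore favours configurations in which the winner's performance exceeds the loser's, which is the soft "greater than" constraint.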

In [1]:
using Random
using ForneyLab
using PyPlot
using LinearAlgebra
;

# Generate Data Set

In [2]:
# Generate data set
Random.seed!(21)
σ(x) = exp.(x)/sum(exp.(x)) # Softmax function

n_players = 5
winners = [1, 1, 1, 2, 4, 5]
losers = [2, 4, 5, 3, 2, 3]
n_matches = length(winners)

# Define matrices for match outcomes
A = [zeros(2, n_players) for _ in 1:n_matches]
for i_m = 1:n_matches
    println("$(winners[i_m]) beats $(losers[i_m])")
    A[i_m][1, winners[i_m]] = 1 # The first row of A_i always encodes the winner
    A[i_m][2, losers[i_m]] = 1
end
;

1 beats 2
1 beats 4
1 beats 5
2 beats 3
4 beats 2
5 beats 3


# Model specification

The model specification includes the local variational parameters `xi` and `a`, which are used to define an upper bound on the log-sum-exp normalizer of the softmax (Bouchard, 2007).
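As a standalone sketch, independent of how ForneyLab implements its `Softmax` node, the bound from Bouchard (2007) can be checked numerically. First, $\log \sum_k e^{x_k} \le a + \sum_k \log(1 + e^{x_k - a})$ for any scalar $a$; each softplus term is then bounded by a Jaakkola-Jordan quadratic with its own parameter $\xi_k$. We assume that the scalar `a` and the two-dimensional `xi` in the model play these roles; their dimensions match, but the notebook does not confirm it. All function names below are our own.

```julia
# Numerically stable softplus: log(1 + exp(t))
softplus(t) = t > 0 ? t + log1p(exp(-t)) : log1p(exp(t))

# Numerically stable log-sum-exp
function logsumexp(x)
    m = maximum(x)
    return m + log(sum(exp.(x .- m)))
end

# Bouchard (2007): log Σ_k exp(x_k) ≤ a + Σ_k log(1 + exp(x_k - a)), for any scalar a
bouchard_bound(x, a) = a + sum(softplus.(x .- a))

# Jaakkola-Jordan coefficient λ(ξ) = tanh(ξ/2) / (4ξ), with limit 1/8 at ξ = 0
λ(ξ) = abs(ξ) < 1e-8 ? 1 / 8 : tanh(ξ / 2) / (4ξ)

# Quadratic upper bound on softplus(t), tight at t = ±ξ
softplus_quad(t, ξ) = λ(ξ) * (t^2 - ξ^2) + (t - ξ) / 2 + softplus(ξ)

# Quadratic bound on log-sum-exp with one ξ_k per category and a shared a
quadratic_bound(x, a, ξ) = a + sum(softplus_quad.(x .- a, ξ))

x = [2.0, -1.0]    # illustrative performances
a = 0.5
ξ_tight = x .- a   # makes each quadratic term touch its softplus term
ξ_loose = [1.0, 1.0]

println("log-sum-exp:           ", logsumexp(x))
println("Bouchard bound:        ", bouchard_bound(x, a))
println("quadratic (tight ξ):   ", quadratic_bound(x, a, ξ_tight))
println("quadratic (loose ξ):   ", quadratic_bound(x, a, ξ_loose))
```

The quadratic form is what makes the bound convenient here: it is Gaussian-like in $x$, so expectations under Gaussian beliefs over the performances remain tractable, while $a$ and $\xi$ are tuned to tighten it.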

In [3]:
g = FactorGraph()

perf_prec = 1.0
m_s_prior = 6.0*ones(n_players)
v_s_prior = 9.0*eye(n_players)

@RV s ~ GaussianMeanVariance(m_s_prior, v_s_prior)

p = Vector{Variable}(undef, n_matches)
xi = Vector{Variable}(undef, n_matches)
a = Vector{Variable}(undef, n_matches)
y = Vector{Variable}(undef, n_matches)
for i_m = 1:n_matches
    @RV p[i_m] ~ GaussianMeanPrecision(A[i_m]*s, perf_prec*eye(2))
    @RV xi[i_m]
    @RV a[i_m]
    @RV y[i_m] ~ Softmax(p[i_m], xi[i_m], a[i_m])

    placeholder(y[i_m], :y, index=i_m, dims=(2,))
end
;

# Algorithm generation

Since we are interested in optimizing the local variational parameters `xi` and `a` together with the hidden skills `s` and performances `p`, we construct an algorithm that also updates `xi` and `a`.
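Roughly, the factorization requested via `PosteriorFactorization` with the ids `:S`, `:P`, `:Xi` and `:A` corresponds to a posterior with four factors, assuming each vector of variables (such as `p`) forms a single factor, with $n = 6$ matches:

\begin{align*}
    q(s, p, \xi, a) = q(s)\, q(p_1, \dots, p_n)\, q(\xi_1, \dots, \xi_n)\, q(a_1, \dots, a_n)\,.
\end{align*}

Each generated step function (`stepS!`, `stepP!`, `stepXi!`, `stepA!`) then updates one of these factors while the others are held fixed, so iterating them performs coordinate-wise updates of the posterior.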

In [4]:
q = PosteriorFactorization(s, p, xi, a, ids=[:S, :P, :Xi, :A])
algo = messagePassingAlgorithm(s, q)
source_code = algorithmSourceCode(algo);

In [5]:
# println(source_code) # Uncomment to inspect algorithm code

# Execution

For execution, we initialize the marginals, including those of the local variational parameters, and iterate the automatically derived algorithm.

In [6]:
eval(Meta.parse(source_code));

In [7]:
# Pre-initialize marginals
marginals = Dict()
marginals[:s] = ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=m_s_prior, v=v_s_prior)
for t=1:n_matches
    marginals[:p_*t] = ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=m_s_prior[1]*ones(2), v=v_s_prior[1, 1]*eye(2))
    marginals[:xi_*t] = ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=ones(2), w=eye(2))
    marginals[:a_*t] = ProbabilityDistribution(Univariate, GaussianMeanPrecision)
end

# Prepare data dictionary
# We encoded A_i so that the player encoded by the first row of A_i is always the winner
data = Dict(:y  => [[1, 0] for _ in 1:n_matches])

# Execute algorithm
n_its = 100
for i = 1:n_its
    stepA!(data, marginals)  # Update local variational parameters a
    stepXi!(data, marginals) # Update local variational parameters xi
    stepS!(data, marginals)  # Update skills s
    stepP!(data, marginals)  # Update performances p
end
;

# Results

The posterior skill means order the players consistently with all six match outcomes; player 1, who won all three of their matches, receives the highest rating.

In [8]:
# Extract posterior state statistics
m_s = [mean(marginals[:s])[j_p] for j_p = 1:n_players]
v_s = [cov(marginals[:s])[j_p, j_p] for j_p = 1:n_players]

# Print posterior mean ± posterior standard deviation per player
for j_p = 1:n_players
    println("player $(j_p) with rating $(round(m_s[j_p], digits=3)) ± $(round(sqrt(v_s[j_p]), digits=3))")
end

player 1 with rating 9.678 ± 0.567
player 2 with rating 4.879 ± 0.567
player 3 with rating 2.637 ± 0.688
player 4 with rating 6.692 ± 0.688
player 5 with rating 6.072 ± 0.688
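Since no ground-truth skills are available for this data set, one simple sanity check is whether the posterior means order the players consistently with the match outcomes. The sketch below hard-codes the means printed above and the `winners`/`losers` lists from the data-generation cell, so it runs without ForneyLab.

```julia
# Posterior means as printed above (players 1 to 5)
m_s = [9.678, 4.879, 2.637, 6.692, 6.072]

# Match outcomes from the data-generation cell
winners = [1, 1, 1, 2, 4, 5]
losers  = [2, 4, 5, 3, 2, 3]

# Does every winner end up with a higher posterior mean than the player they beat?
consistent = all(m_s[w] > m_s[l] for (w, l) in zip(winners, losers))
println("All outcomes consistent with posterior means: ", consistent)

# Players sorted from highest to lowest posterior mean
ranking = sortperm(m_s, rev=true)
println("Ranking: ", ranking)
```

Note that consistency with the outcomes is a weak check: several orderings can satisfy all six results.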


## Comparing to Reference

The output is compared to that of Infer.NET, 2020:

In [9]:
# Infer.NET reference values
ref_mean = [9.517, 4.955, 2.639, 6.834, 6.054]
ref_dev = [3.926, 3.503, 4.288, 3.892, 4.731]
;

Note that, while the estimated skills are very close to the Infer.NET reference, the variances are not. This might be because the current results are based on Variational Message Passing (VMP), while Infer.NET uses Expectation Propagation (EP). VMP is known to exhibit mode-seeking behaviour, which usually leads to an underestimation of the variance. In contrast, EP exhibits mode-covering behaviour and generally overestimates the variance. Expectation Propagation for the `Softmax` node is not yet implemented in ForneyLab.
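To quantify how close the skill estimates are, the following sketch compares the posterior means printed above with the Infer.NET reference means, both hard-coded so the snippet is self-contained. It compares only the means, leaving the spread aside.

```julia
# Posterior means from the ForneyLab run above and the Infer.NET reference means
m_s      = [9.678, 4.879, 2.637, 6.692, 6.072]
ref_mean = [9.517, 4.955, 2.639, 6.834, 6.054]

abs_diff = abs.(m_s .- ref_mean)
println("Absolute differences: ", round.(abs_diff, digits=3))
println("Largest difference:   ", round(maximum(abs_diff), digits=3))

# Do both methods rank the players identically?
same_ranking = sortperm(m_s, rev=true) == sortperm(ref_mean, rev=true)
println("Same ranking: ", same_ranking)
```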

### References

Bouchard, 2007, "Efficient Bounds for the Softmax Function"

Ahmed, 2013, "Bayesian Multicategorical Soft Data Fusion for Human-Robot Collaboration"

Infer.NET, 2020, https://docs.microsoft.com/en-us/dotnet/machine-learning/how-to-guides/matchup-app-infer-net