TODO:
 * precompute $\log|\Sigma_i|$
 * precompute the swapped indices used in the parameter updates

In [None]:
using Distributions
using StatPlots
using Plots
pyplot(size=(600, 400))

import CSV
import JLD2

In [None]:
include("src/NGSIM.jl")

In [None]:
# Location of the NGSIM I-101 (07:50-08:05) trajectory extract.
DATA_PATH = "../data/trajdata_i101_trajectories-0750am-0805am"
JLD2.@load joinpath(DATA_PATH, "td.jld") td

# Sparse trajectory representation plus an id lookup table, both returned by
# `td_sparse` (presumably defined in src/NGSIM.jl).
S, id_lookup = td_sparse(td)

# Per-observation features. After the transpose, X has one row per feature
# (velocity, heading, acceleration) and one column per row of X.csv.
X_full = CSV.read(joinpath(DATA_PATH, "X.csv"); nullable=false)
X = Array(X_full[:, [:velocity, :heading, :acceleration]])'

# Split speed into two components via the heading angle; acceleration is kept
# unchanged as the third row. The temporaries stay local to the `let`.
V = let speed = X[1, :], heading = X[2, :]
    [speed .* sin.(heading) speed .* cos.(heading) X[3, :]]'
end

# Vehicle pairs from pairs.csv (header row discarded). Presumably one pair per
# CSV row, so after the transpose each column is one pair — confirm.
(pairs, _) = readcsv(joinpath(DATA_PATH, "pairs.csv"), Int; header=true)
pairs = pairs'
;

In [None]:
# Kernel-density view of each raw feature: one stacked subplot per row of X.
let panel_titles = ("velocity", "heading", "acceleration")
    panels = [plot(X[row, :], seriestype=:density, title=panel_titles[row]) for row in 1:3]
    plot(panels..., legend=false, layout=(3, 1), size=(600, 400))
end

In [None]:
# The same density view for the derived rows of V.
let axis_titles = ["x velocity", "y velocity", "acceleration"]
    subplots = map(1:3) do r
        plot(V[r, :], seriestype=:density, title=axis_titles[r])
    end
    plot(subplots..., legend=false, layout=(3, 1), size=(600, 400))
end

## EM

In [None]:
# Number of hidden states passed to the model constructor.
K = 3

# Initial parameter estimate built from V. The copy is taken before EM runs,
# so the starting point can be compared with the fitted model later. `suff` is
# the ChmmSuffStats object handed to chmm_em! (presumably reusable
# sufficient-statistics storage — confirm in src/NGSIM.jl).
curr = chmm_from_data(V, K)
orig_est, suff = copy(curr), ChmmSuffStats(curr)
;

In [None]:
# Run EM for 50 iterations from `curr`, printing progress every 10 iterations
# (per the keyword names). The result is unpacked into the fitted model and the
# log-likelihood history that the following cells plot and check.
curr, log_like_hist = chmm_em!(S, V, pairs, K, curr, suff;
                               N_iters=50, print_every=10)
;

In [None]:
# Log likelihood recorded at each EM iteration.
scatter(log_like_hist, legend=false,
        xlabel="iteration", ylabel="log likelihood")

In [None]:
# Sanity check: EM never decreases the log likelihood, so the history must be
# non-decreasing. Use `>=` rather than `>`: once EM has converged, two
# consecutive iterations can produce the same value. A small relative
# tolerance absorbs floating-point round-off in the likelihood evaluation.
let Δ = diff(log_like_hist), tol = 1e-8
    ok = all(Δ .>= -tol .* abs.(log_like_hist[2:end]))
    @assert ok "log likelihood decreased during EM; per-iteration changes: $(Δ)"
end

# Analysis (work in progress)

In [None]:
# Stack the mean vectors into matrices with one row per entry of μs
# (presumably one per hidden state): fitted model vs. initial estimate.
ms = hcat(curr.μs...)'
ms_orig = hcat(orig_est.μs...)'

# Overlay both sets of means on the data, using the first two coordinates
# (rows 1 and 2 of V; presumably the means live in the same space as V).
scatter(V[1, :], V[2, :], marker=(:circle, stroke(0)), label="")
scatter!(ms_orig[:, 1], ms_orig[:, 2], label="original", marker=:X, markersize=15)
scatter!(ms[:, 1], ms[:, 2], label="final", marker=:X, markersize=10)

In [None]:
# Make the cmocean `:deep` gradient the default colour gradient (presumably
# picked up by the heatmaps in the following cells).
default_cgrad(:cmocean, default=:deep)

In [None]:
# Initial joint state distribution π₀, shown as a K×K heatmap. `curr.π0` is
# reshaped (column-major) into K×K; presumably rows/columns are the states of
# the two vehicles in a pair — confirm against src/NGSIM.jl.
# Use K instead of a hard-coded 3 so this cell works for any number of states.
l = map(string, 1:K)
heatmap(l, l, reshape(curr.π0, K, K), aspect_ratio=1, title="π₀")
yaxis!(:flip)

In [None]:
# One K×K heatmap per column of `curr.P`. The column for the state pair (i, j)
# is its column-major linear index (j - 1) * K + i, i.e. sub2ind((K, K), i, j).
# Presumably each column is the transition distribution conditioned on the
# joint state (i, j) — confirm against src/NGSIM.jl.
# ps[i, j] holds the heatmap titled "(i, j)".
ps = let labels = map(string, 1:K)
    [heatmap(labels, labels, reshape(curr.P[:, (j - 1) * K + i], K, K),
             title="($i, $j)", aspect_ratio=1)
     for i in 1:K, j in 1:K]
end

# Plots fills a grid layout row by row, but splatting a Julia matrix yields
# column-major order. Permute first so that heatmap (i, j) lands in grid row i,
# column j.
plot(permutedims(ps, (2, 1))..., layout=(K, K),
     aspect_ratio=1, palette=:deep, colorbar=true, size=(700, 700))
yaxis!(:flip, ticks=0)