In [1]:
# # Control dependency

#=
Here, we give a example of controlled HMM (also called input-output HMM), in the special case of Markov switching regression.
=#

using Distributions
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using LinearAlgebra
using Random
using StableRNGs
using StatsAPI

#-

rng = StableRNG(63);

# ## Model

#=
A Markov switching regression is like a classical regression, except that the weights depend on the unobserved state of an HMM.
We can represent it with the following subtype of `AbstractHMM` (see [Custom HMM structures](@ref)), which has one vector of coefficients $\beta_i$ per state.
=#

struct ControlledGaussianHMM{T} <: AbstractHMM
    init::Vector{T}
    trans::Matrix{T}
    dist_coeffs::Vector{Vector{T}}
end

#=
In state $i$ with a vector of controls $u$, our observation is given by the linear model $y \sim \mathcal{N}(\beta_i^\top u, 1)$.
Controls must be provided to both `transition_matrix` and `obs_distributions` even if they are only used by one.
=#

function HMMs.initialization(hmm::ControlledGaussianHMM)
    return hmm.init
end

function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control::AbstractVector)
    return hmm.trans
end

function HMMs.obs_distributions(hmm::ControlledGaussianHMM, control::AbstractVector)
    return [Normal(dot(hmm.dist_coeffs[i], control), 1.0) for i in 1:length(hmm)]
end

#=
In this case, the transition matrix does not depend on the control.
=#

# ## Simulation

d = 3
init = [0.6, 0.4]
trans = [0.7 0.3; 0.2 0.8]
dist_coeffs = [-ones(d), ones(d)]
hmm = ControlledGaussianHMM(init, trans, dist_coeffs);

#=
Simulation requires a vector of controls, each being a vector itself with the right dimension.

Let us build several sequences of variable lengths.
=#

control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:1000];
obs_seqs = [rand(rng, hmm, control_seq).obs_seq for control_seq in control_seqs];

obs_seq = reduce(vcat, obs_seqs)
control_seq = reduce(vcat, control_seqs)
seq_ends = cumsum(length.(obs_seqs));

# ## Inference

#=
Not much changes from the case with simple time dependency.
=#

best_state_seq, _ = viterbi(hmm, obs_seq, control_seq; seq_ends)

# ## Learning

#=
Once more, we override the `fit!` function.
The state-related parameters are estimated in the standard way.
Meanwhile, the observation coefficients are given by the formula for [weighted least squares](https://en.wikipedia.org/wiki/Weighted_least_squares).
=#

function StatsAPI.fit!(
    hmm::ControlledGaussianHMM{T},
    fb_storage::HMMs.ForwardBackwardStorage,
    obs_seq::AbstractVector,
    control_seq::AbstractVector;
    seq_ends,
) where {T}
    (; γ, ξ) = fb_storage
    N = length(hmm)

    hmm.init .= 0
    hmm.trans .= 0
    for k in eachindex(seq_ends)
        t1, t2 = HMMs.seq_limits(seq_ends, k)
        hmm.init .+= γ[:, t1]
        hmm.trans .+= sum(ξ[t1:t2])
    end
    hmm.init ./= sum(hmm.init)
    for row in eachrow(hmm.trans)
        row ./= sum(row)
    end

    U = reduce(hcat, control_seq)'
    y = obs_seq
    for i in 1:N
        W = sqrt.(Diagonal(γ[i, :]))
        hmm.dist_coeffs[i] = (W * U) \ (W * y)
    end
end

#=
Now we put it to the test.
=#

init_guess = [0.5, 0.5]
trans_guess = [0.6 0.4; 0.3 0.7]
dist_coeffs_guess = [-2 * ones(d), 2 * ones(d)]
hmm_guess = ControlledGaussianHMM(init_guess, trans_guess, dist_coeffs_guess);

#-

hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends)
first(loglikelihood_evolution), last(loglikelihood_evolution)


(-482033.28522111394, -260982.53860078458)

In [7]:
control_seq

150110-element Vector{Vector{Float64}}:
 [-1.2724663276017179, -0.5425092089741408, -0.006775114305754266]
 [0.39899635102542447, 0.6879079099787647, -0.011462083229865102]
 [-0.3778445956351776, 1.02611610123778, -0.7521106533303417]
 [-0.05023471135145104, 1.1592855860576567, 1.600041146986573]
 [0.3024054684489508, 0.07762793399392462, -2.1032490660891834]
 [-0.4914103455484126, 1.5029814490027942, -0.48991872574893264]
 [0.1676976135048406, -0.6519291268812127, -0.03569567969966787]
 [-1.6538832667367795, -0.42886974189154153, -0.44106095881055]
 [-0.39234414283073954, 0.2287469842189177, 0.0930174991863142]
 [-0.34402635947107824, 0.5003837097671952, 0.7384034098094486]
 ⋮
 [-0.5190687943174294, -0.23376686824753887, 0.5361063482684831]
 [0.5518973819422871, 1.1883943478539016, 0.5546575013552255]
 [-1.0643243410902814, 0.22637917814975755, -0.05665553546986064]
 [-0.3313557353889336, -0.08741062427993526, -0.7741876282484813]
 [0.04135352660160016, -1.6520944862449851, 0.85704563

In [5]:
dist_coeffs_guess

2-element Vector{Vector{Float64}}:
 [-2.0, -2.0, -2.0]
 [2.0, 2.0, 2.0]