In [1]:
using Plots

In [2]:
using StateSpaceDynamics

In [3]:
using LinearAlgebra

In [4]:
# Number of discrete modes: used as the count of LDS emission models built below
# and passed as the last argument to the SwitchingLinearDynamicalSystem constructor.
K = 2

2

In [5]:
# create a state-space model for the tutorial
obs_dim = 10
latent_dim = 2

# set up the state parameters
# A is a 2x2 rotation by 0.25 rad scaled by 0.95, so latent_dim must remain 2
# unless A is rebuilt for a different latent dimension.
A = 0.95 * [cos(0.25) -sin(0.25); sin(0.25) cos(0.25)]
Q = Matrix(0.1 * I(latent_dim))

# initial state mean and covariance, sized from latent_dim rather than hard-coded
x0 = zeros(latent_dim)
P0 = Matrix(0.1 * I(latent_dim))

# set up the observation parameters (sized from obs_dim rather than a literal 10)
C = randn(obs_dim, latent_dim)
R = Matrix(0.5 * I(obs_dim))

# create the state-space model: one GaussianLDS per mode.
# Every one of the K modes is built from the same parameter values above.
emissions = [
    GaussianLDS(;
        A=A,
        Q=Q,
        C=C,
        R=R,
        x0=x0,
        P0=P0,
        obs_dim=obs_dim,
        latent_dim=latent_dim,
        fit_bool=fill(true, 6),
    ) for _ in 1:K
]


2-element Vector{StateSpaceDynamics.LinearDynamicalSystem{StateSpaceDynamics.GaussianStateModel{Float64}, StateSpaceDynamics.GaussianObservationModel{Float64}}}:
 StateSpaceDynamics.LinearDynamicalSystem{StateSpaceDynamics.GaussianStateModel{Float64}, StateSpaceDynamics.GaussianObservationModel{Float64}}(StateSpaceDynamics.GaussianStateModel{Float64}([0.9204668006251124 -0.2350337612917968; 0.2350337612917968 0.9204668006251124], [0.1 0.0; 0.0 0.1], [0.0, 0.0], [0.1 0.0; 0.0 0.1]), StateSpaceDynamics.GaussianObservationModel{Float64}([0.338971409126926 -1.615669630479408; -0.7246751951812889 -1.8935141243258438; … ; -0.32247790187722525 -1.245141727893118; 0.4607462256016656 0.08795973609351078], [0.5 0.0 … 0.0 0.0; 0.0 0.5 … 0.0 0.0; … ; 0.0 0.0 … 0.5 0.0; 0.0 0.0 … 0.0 0.5]), 2, 10, Bool[1, 1, 1, 1, 1, 1])
 StateSpaceDynamics.LinearDynamicalSystem{StateSpaceDynamics.GaussianStateModel{Float64}, StateSpaceDynamics.GaussianObservationModel{Float64}}(StateSpaceDynamics.GaussianStateMode

In [6]:
using StateSpaceDynamics: initialize_transition_matrix, initialize_state_distribution

In [7]:
# Assemble the switching model from a K-mode transition matrix, the per-mode
# LDS emission models, a K-mode initial state distribution, and the mode count.
true_ssm = SwitchingLinearDynamicalSystem(
    initialize_transition_matrix(K),
    emissions,
    initialize_state_distribution(K),
    K,
)

SwitchingLinearDynamicalSystem([0.1430513497044162 0.856948650295584; 0.7739542077580015 0.22604579224199836], StateSpaceDynamics.LinearDynamicalSystem[StateSpaceDynamics.LinearDynamicalSystem{StateSpaceDynamics.GaussianStateModel{Float64}, StateSpaceDynamics.GaussianObservationModel{Float64}}(StateSpaceDynamics.GaussianStateModel{Float64}([0.9204668006251124 -0.2350337612917968; 0.2350337612917968 0.9204668006251124], [0.1 0.0; 0.0 0.1], [0.0, 0.0], [0.1 0.0; 0.0 0.1]), StateSpaceDynamics.GaussianObservationModel{Float64}([0.338971409126926 -1.615669630479408; -0.7246751951812889 -1.8935141243258438; … ; -0.32247790187722525 -1.245141727893118; 0.4607462256016656 0.08795973609351078], [0.5 0.0 … 0.0 0.0; 0.0 0.5 … 0.0 0.0; … ; 0.0 0.0 … 0.5 0.0; 0.0 0.0 … 0.0 0.5]), 2, 10, Bool[1, 1, 1, 1, 1, 1]), StateSpaceDynamics.LinearDynamicalSystem{StateSpaceDynamics.GaussianStateModel{Float64}, StateSpaceDynamics.GaussianObservationModel{Float64}}(StateSpaceDynamics.GaussianStateModel{Float64}(

In [8]:
using LinearAlgebra, Random, Distributions
using StatsFuns: logsumexp

In [11]:
using StateSpaceDynamics: sample

In [12]:

# Example usage
# NOTE(review): `obs_dim` was set to 10 when the model parameters were built; this
# reassigns the global to 2 even though `true_ssm` was constructed with obs_dim = 10.
# Confirm whether this reassignment is intended.
obs_dim = 2
state_dim = 2
# NOTE(review): the model above is built with K = 2 modes; `n_modes = 3` does not
# match that value — verify before relying on it.
n_modes = 3
T = 100  # length argument passed to `sample` below

# Generate synthetic data
# `sample` returns three values here; only the second is kept, as `observations`.
_, observations, _ = sample(true_ssm, T)


([0.24476549894621924 0.019359248258796563 … -1.7401734194728746 -1.9790206179131933; 0.21517741599797152 -0.13426780111680525 … 1.377960690303263 1.4562908607225724], [0.22084541775158884 -0.27583698781739274 … -3.723881043841036 -3.5907907796342586; -1.6714596534998334 0.26252993868314045 … -0.6492998192510948 -0.3538406195987812; … ; -0.2073986930158949 -1.665533036991213 … -1.2499421088812703 -0.7216725518926183; 1.337352023420358 0.5447354720033666 … -0.921858072631659 0.6702654838717208], [2, 1, 2, 1, 2, 2, 1, 1, 2, 1  …  1, 1, 2, 1, 2, 1, 2, 1, 2, 1])

In [10]:
"""
    e_step(slds, observations)

Run a log-space forward pass over the discrete modes of `slds`.

Per-mode, per-time log-likelihoods are obtained from `compute_log_likelihoods`
(called with an all-zero `gamma`). The forward messages are initialised from
`slds.πₖ` and propagated with the columns of `slds.A`.

# Returns
- `gamma`: `slds.K × T` matrix, `exp` of `log_alpha` normalised over modes at each
  time step (each column sums to one).
- `log_alpha`: `slds.K × T` matrix of unnormalised log forward messages.
"""
function e_step(slds::SwitchingLinearDynamicalSystem, observations::Matrix{Float64})
    num_steps = size(observations, 2)
    num_modes = slds.K

    # Allocate outputs; `gamma` is still all zeros when handed to the likelihood routine.
    gamma = zeros(num_modes, num_steps)
    log_alpha = zeros(num_modes, num_steps)
    loglik = compute_log_likelihoods(slds, observations, gamma)

    # First column: initial mode distribution plus first-step log-likelihoods.
    log_alpha[:, 1] .= log.(slds.πₖ) + loglik[:, 1]

    # Forward recursion: for target mode k, combine the previous messages with
    # the log of column k of the transition matrix, then add the emission term.
    for t in 2:num_steps, k in 1:num_modes
        incoming = log_alpha[:, t-1] + log.(slds.A[:, k])
        log_alpha[k, t] = logsumexp(incoming) + loglik[k, t]
    end

    # Normalise each column in log space before exponentiating.
    log_norm = logsumexp(log_alpha, dims=1)
    gamma .= exp.(log_alpha .- log_norm)

    return gamma, log_alpha
end


e_step (generic function with 1 method)

In [None]:
"""
    compute_log_likelihoods(slds, observations, gamma)

Return an `slds.K × T` matrix whose `(k, t)` entry is the Gaussian log-density of
`observations[:, t]` under mode `k`'s one-step-ahead predictive distribution
`N(C * x_pred, C * P_pred * C' + R)`.

For each mode the recursion starts from that mode's `x0`/`P0`, applies
`weighted_kalman_update` with the single weight `gamma[k, t]`, and then
propagates through the mode's dynamics (`A`, `Q`) to obtain the next prediction.

# Arguments
- `slds`: switching model; `slds.B[k]` is the linear dynamical system of mode `k`.
- `observations`: `obs_dim × T` matrix, one observation per column.
- `gamma`: `slds.K × T` matrix of mode weights.

# Notes
`weighted_kalman_update` scales the filtered mean and covariance by the weight it
is given. If `gamma[k, t]` is zero (e.g. `gamma` was freshly allocated with
`zeros`), the filtered state collapses to zero and the next prediction is
`(0, Q)`. Callers wanting an ordinary per-mode Kalman filter should pass ones.
"""
function compute_log_likelihoods(slds::SwitchingLinearDynamicalSystem, observations::Matrix{Float64}, gamma::Matrix{Float64})
    T = size(observations, 2)  # Number of time steps
    n_modes = slds.K           # Number of modes

    log_likelihoods = zeros(n_modes, T)  # Log-likelihoods for each mode and time step

    for k in 1:n_modes
        lds = slds.B[k]  # Linear dynamical system for mode k

        # Pull the actual model parameters instead of assuming identity matrices.
        # Converted to plain Float64 arrays so they satisfy the update routine's signature.
        A = Matrix{Float64}(lds.state_model.A)
        Q = Matrix{Float64}(lds.state_model.Q)
        C = Matrix{Float64}(lds.obs_model.C)
        R = Matrix{Float64}(lds.obs_model.R)

        # The prediction for the first observation is the model's initial state prior.
        pred_mean = Vector{Float64}(lds.state_model.x0)
        pred_cov = Matrix{Float64}(lds.state_model.P0)

        for t in 1:T
            y = observations[:, t]

            # Predictive observation distribution: mean C*x, covariance C*P*C' + R.
            obs_mean = C * pred_mean
            # Symmetrise to remove round-off asymmetry before the Cholesky inside MvNormal.
            obs_cov = Matrix(Symmetric(C * pred_cov * C' + R))

            log_likelihoods[k, t] = logpdf(MvNormal(obs_mean, obs_cov), y)

            # Measurement update, weighted by this mode's own responsibility.
            filt_mean, filt_cov = weighted_kalman_update(
                y, [pred_mean], [pred_cov], C, R, [gamma[k, t]]
            )

            # Time update: propagate the filtered state through the dynamics.
            pred_mean = A * filt_mean
            pred_cov = Matrix(Symmetric(A * filt_cov * A' + Q))
        end
    end

    return log_likelihoods
end


In [None]:
"""
    weighted_kalman_update(obs, pred_means, pred_covs, C, R, gamma_t)

Apply a Kalman measurement update to each predicted state `(pred_means[k],
pred_covs[k])` and return the `gamma_t`-weighted sum of the posterior means and
covariances.

# Arguments
- `obs`: observation vector.
- `pred_means`, `pred_covs`: predicted state means and covariances, one per component.
- `C`: observation matrix.
- `R`: observation noise covariance.
- `gamma_t`: weights; entry `k` scales component `k`. Must have at least as many
  entries as `pred_means` (extra entries are ignored).

# Returns
`(mean, cov)` where `mean` is a vector and `cov` is wrapped in `Symmetric`.

# Throws
- `ArgumentError` if `pred_means` is empty.
- `DimensionMismatch` if `pred_covs` or `gamma_t` are too short.
"""
function weighted_kalman_update(
    obs::AbstractVector{<:Real},
    pred_means::AbstractVector{<:AbstractVector{<:Real}},
    pred_covs::AbstractVector{<:AbstractMatrix{<:Real}},
    C::AbstractMatrix{<:Real},
    R::AbstractMatrix{<:Real},
    gamma_t::AbstractVector{<:Real},
)
    isempty(pred_means) && throw(ArgumentError("pred_means must not be empty"))
    length(pred_covs) == length(pred_means) || throw(
        DimensionMismatch("pred_means and pred_covs must have the same length"),
    )
    length(gamma_t) >= length(pred_means) || throw(
        DimensionMismatch("gamma_t must provide a weight for every predicted state"),
    )

    # Accumulators for the weighted posterior mean and covariance
    updated_mean = zeros(size(first(pred_means)))
    updated_cov = zeros(size(first(pred_covs)))

    # zip stops at the shortest input, i.e. after length(pred_means) components
    for (mean_pred, cov_pred, weight) in zip(pred_means, pred_covs, gamma_t)
        obs_pred = C * mean_pred
        innovation_cov = Symmetric(C * cov_pred * C' + R)

        # Kalman gain K = P * C' * inv(S), computed with a Cholesky solve rather
        # than explicit inverses. S is symmetric, so K = (S \ (C * P'))'.
        innovation_chol = cholesky(innovation_cov)
        gain = (innovation_chol \ (C * cov_pred'))'

        mean_k = mean_pred + gain * (obs - obs_pred)
        cov_k = cov_pred - gain * C * cov_pred

        # Weight this component's posterior by its responsibility
        updated_mean = updated_mean + weight * mean_k
        updated_cov = updated_cov + weight * cov_k
    end

    return updated_mean, Symmetric(updated_cov)
end
