# Import Julia Packages for Benchmarking
---

In [11]:
using BenchmarkTools
using StateSpaceDynamics
using Random
using Statistics
using PythonCall
using DataFrames
using Printf
using LinearAlgebra
using CSV
using Plots

# Import Python Packages for Benchmarking
---

In [2]:
# Import Python packages
const dynamax = pyimport("dynamax.linear_gaussian_ssm")
const jr = pyimport("jax.random")
const pykalman = pyimport("pykalman")
const np = pyimport("numpy")

Python: <module 'numpy' from 'c:\\Users\\ryansenne\\Documents\\GitHub\\ssm_julia\\benchmarking\\.CondaPkg\\env\\Lib\\site-packages\\numpy\\__init__.py'>

# Utility Functions
---

In [3]:
function sample_rotation_matrix(n::Int)
    A = randn(n, n)
    Q, r = qr(A)
    return Matrix(Q)
end

sample_rotation_matrix (generic function with 1 method)

# Benchmarking Functions
---

In [4]:
struct BenchConfig
    latent_dims::Vector{Int}
    obs_dims::Vector{Int} 
    seq_lengths::Vector{Int}
    n_iters::Int
    n_repeats::Int
end

default_config = BenchConfig(
    [2, 4, 8],       # latent dimensions
    [2, 4, 8],      # observation dimensions
    [100, 500, 1000],    # sequence lengths
    100,                 # EM iterations
    5                    # benchmark repeats
)

# Helper function to generate random parameters
function generate_random_params(latent_dim::Int, obs_dim::Int)
    A = sample_rotation_matrix(latent_dim)
    C = randn(obs_dim, latent_dim)
    Q = Matrix(Diagonal(ones(latent_dim)))
    R = Matrix(Diagonal(ones(obs_dim)))
    x0 = zeros(latent_dim)
    P0 = Matrix(Diagonal(ones(latent_dim)))
    return A, C, Q, R, x0, P0
end

"""
Generate random parameters and data for benchmarking
"""
function generate_test_data(latent_dim::Int, obs_dim::Int, seq_len::Int)
    # Generate true parameters for data generation
    A_true, C_true, Q_true, R_true, x0_true, P0_true = generate_random_params(latent_dim, obs_dim)
   
    # Create model with true parameters
    lds_true = GaussianLDS(
        A=A_true, C=C_true, Q=Q_true, R=R_true, x0=x0_true, P0=P0_true,
        obs_dim=obs_dim, latent_dim=latent_dim
    )
    
    # Generate data using true parameters
    x, y = sample(lds_true, seq_len, 1)
    
    # Generate different initial parameters for fitting
    A_init, C_init, Q_init, R_init, x0_init, P0_init = generate_random_params(latent_dim, obs_dim)
    
    
    # Create model with initial parameters for fitting
    lds_fit = GaussianLDS(
        A=A_init, C=C_init, Q=Q_init, R=R_init, x0=x0_init, P0=P0_init,
        obs_dim=obs_dim, latent_dim=latent_dim
    )
    
    return lds_fit, x, y, (A_init, C_init, Q_init, R_init, x0_init, P0_init)
end


function run_single_benchmark(model_type::Symbol, lds_julia, y, y_np, params=nothing; config=default_config)
    try
        if model_type == :julia
            bench = @benchmark begin
                local lds_fresh = deepcopy($lds_julia)
                fit!(lds_fresh, $y, max_iter=$config.n_iters)
            end samples=config.n_repeats
            return (time=median(bench).time, memory=bench.memory, allocs=bench.allocs, success=true)
            
        elseif model_type == :dynamax
            A_np, C_np, Q_np, R_np, x0_np = params
            latent_dim = pyconvert(Int, A_np.shape[0])
            obs_dim = pyconvert(Int, C_np.shape[0])
            
            bench = @benchmark begin
                local lds_dynamax = dynamax.LinearGaussianSSM($latent_dim, $obs_dim)
                local key = jr.PRNGKey(0)
                local params, props = lds_dynamax.initialize(
                    key, dynamics_weights=$A_np, dynamics_covariance=$Q_np,
                    emission_weights=$C_np, emission_covariance=$R_np,
                    initial_mean=$x0_np
                )
                lds_dynamax.fit_em(params, props, $y_np, num_iters=$config.n_iters)
            end samples=config.n_repeats
            return (time=median(bench).time, memory=0, allocs=0, success=true)
            
        elseif model_type == :pykalman
            A_np, C_np, Q_np, R_np, x0_np = params
            latent_dim = pyconvert(Int, A_np.shape[0])
            obs_dim = pyconvert(Int, C_np.shape[0])
            
            bench = @benchmark begin
                local kf = pykalman.KalmanFilter(
                    n_dim_state=$latent_dim, n_dim_obs=$obs_dim,
                    transition_matrices=$A_np, observation_matrices=$C_np,
                    transition_covariance=$Q_np, observation_covariance=$R_np,
                    initial_state_mean=$x0_np
                )
                kf.em($y_np, n_iter=$config.n_iters)
            end samples=config.n_repeats
            return (time=median(bench).time, memory=0, allocs=0, success=true)
        end
    catch e
        @warn "Benchmark failed for $model_type" exception=e stacktrace=catch_backtrace()
        return (time=NaN, memory=0, allocs=0, success=false)
    end
end

"""
Benchmark model fitting with separate runs for each package
"""
function benchmark_fitting(config::BenchConfig=default_config)
    results = []
    
    for latent_dim in config.latent_dims
        for obs_dim in config.obs_dims
            obs_dim < latent_dim && continue
            
            for seq_len in config.seq_lengths
                println("\nTesting configuration: latent_dim=$latent_dim, obs_dim=$obs_dim, seq_len=$seq_len")
                
                # Generate test data once
                lds_julia, x, y, params = generate_test_data(latent_dim, obs_dim, seq_len)
                A, C, Q, R, x0, _ = params

                # Convert all arrays to numpy
                A_np = np.array(A)
                C_np = np.array(C)
                Q_np = np.array(Q)
                R_np = np.array(R)
                x0_np = np.array(x0)
                y_np = np.array(permutedims(dropdims(y, dims=3), (2, 1)))
                
                numpy_params = (A_np, C_np, Q_np, R_np, x0_np)

                # Run benchmarks separately
                julia_result = run_single_benchmark(:julia, lds_julia, y, y_np; config=config)
                dynamax_result = run_single_benchmark(:dynamax, lds_julia, y, y_np, numpy_params; config=config)
                pykalman_result = run_single_benchmark(:pykalman, lds_julia, y, y_np, numpy_params; config=config)
                
                push!(results, Dict(
                    "config" => (latent_dim=latent_dim, obs_dim=obs_dim, seq_len=seq_len),
                    "julia" => julia_result,
                    "dynamax" => dynamax_result,
                    "pykalman" => pykalman_result
                ))
                
                # Print immediate results for this configuration
                print_single_result(results[end])
            end
        end
    end
    
    return results
end

"""
Print results for a single configuration
"""
function print_single_result(result)
    config = result["config"]
    println("\nConfiguration:")
    println("  Latent dim: $(config.latent_dim)")
    println("  Obs dim: $(config.obs_dim)")
    println("  Seq length: $(config.seq_len)")
    
    for pkg in [:julia, :dynamax, :pykalman]
        r = result[string(pkg)]
        if r.success
            @printf("\n%s timing: %.3f seconds", uppercase(string(pkg)), r.time/1e9)
            if pkg == :julia
                println("\n  Memory: $(r.memory) bytes")
                println("  Allocations: $(r.allocs)")
            end
        else
            println("\n$(uppercase(string(pkg))) benchmark failed")
        end
    end
    println("\n" * "-"^50)
end

print_single_result

In [5]:
results = benchmark_fitting()

[32mFitting LDS via EM...   2%|██                                                |  ETA: 0:03:58 ( 2.43  s/it)[39m[K


Testing configuration: latent_dim=2, obs_dim=2, seq_len=100


[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:05 (52.43 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:38] |----------------------------------------| 2.00% [2/100 00:00<00:19] |X---------------------------------------| 3.00% [3/100 00:00<00:12] |X---------------------------------------| 4.00% [4/100 00:00<00:09] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 33.00% [33/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXX---------------------| 49.00% [49/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXX-----------| 73.00% [73/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |--------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.07 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:47] |----------------------------------------| 2.00% [2/100 00:00<00:23] |X---------------------------------------| 3.00% [3/100 00:00<00:15] |X---------------------------------------| 4.00% [4/100 00:00<00:11] |XX--------------------------------------| 5.00% [5/100 00:00<00:09] |XX--------------------------------------| 7.00% [7/100 00:00<00:06] |XXX-------------------------------------| 9.00% [9/100 00:00<00:04] |XXXX------------------------------------| 12.00% [12/100 00:00<00:03] |XXXXXX----------------------------------| 16.00% [16/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:01] |XXXXXXXXXXXX----------------------------| 31.00% [31/100 00:00<00:01] |XXXXXXXXXXXXXXXXX-----------------------| 43.00% [43/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXX-----------------| 58.00% [58/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.00 ms/it)[39m[K



 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:56] |----------------------------------------| 2.00% [2/100 00:00<00:28] |X---------------------------------------| 3.00% [3/100 00:00<00:18] |X---------------------------------------| 4.00% [4/100 00:00<00:13] |XX--------------------------------------| 5.00% [5/100 00:00<00:10] |XX--------------------------------------| 6.00% [6/100 00:00<00:08] |XXX-------------------------------------| 8.00% [8/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:05] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:03] |XXXXXX----------------------------------| 17.00% [17/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:02] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXX-------------------------| 39.00% [39/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXX---

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.60 ms/it)[39m[K



 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:43] |----------------------------------------| 2.00% [2/100 00:00<00:21] |X---------------------------------------| 3.00% [3/100 00:00<00:14] |X---------------------------------------| 4.00% [4/100 00:00<00:10] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 28.00% [28/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 40.00% [40/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXX-------------------| 54.00% [54/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXX--------------| 67.00% [67/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.75 ms/it)[39m[K



 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:45] |----------------------------------------| 2.00% [2/100 00:00<00:22] |X---------------------------------------| 3.00% [3/100 00:00<00:14] |X---------------------------------------| 4.00% [4/100 00:00<00:11] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:04] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 28.00% [28/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 40.00% [40/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXX-------------------| 53.00% [53/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXX--------------| 66.00% [66/100 00:01<00:00] |XXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.63 ms/it)[39m[K



 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:31] |----------------------------------------| 2.00% [2/100 00:00<00:15] |X---------------------------------------| 3.00% [3/100 00:00<00:10] |X---------------------------------------| 4.00% [4/100 00:00<00:07] |XX--------------------------------------| 5.00% [5/100 00:00<00:06] |XXX-------------------------------------| 8.00% [8/100 00:00<00:03] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:02] |XXXXXXXX--------------------------------| 21.00% [21/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 34.00% [34/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX------------------| 55.00% [55/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX------| 85.00% [85/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |----------------------------------------| 0.00% [0/100 00:00<?] |-------------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.83 ms/it)[39m[K





[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.49 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.51 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:43] |----------------------------------------| 2.00% [2/100 00:00<00:21] |X---------------------------------------| 3.00% [3/100 00:00<00:14] |X---------------------------------------| 4.00% [4/100 00:00<00:10] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 42.00% [42/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX------------------| 57.00% [57/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXX-----------| 74.00% [74/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.67 ms/it)[39m[K





[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.48 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.45 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.40 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 2.57 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.42 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.49 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.40 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.42 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:43] |----------------------------------------| 2.00% [2/100 00:00<00:21] |X---------------------------------------| 3.00% [3/100 00:00<00:14] |X---------------------------------------| 4.00% [4/100 00:00<00:10] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 42.00% [42/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX------------------| 55.00% [55/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXX-------------| 68.00% [68/100 00:01<00:00] |XXXXXXXXXXXXXXXXXXXXXX

18-element Vector{Any}:
 Dict{String, NamedTuple}("julia" => (time = 4.1714e6, memory = 5416944, allocs = 53917, success = true), "dynamax" => (time = 4.065633e8, memory = 0, allocs = 0, success = true), "config" => (latent_dim = 2, obs_dim = 2, seq_len = 100), "pykalman" => (time = 1.8278558e9, memory = 0, allocs = 0, success = true))
 Dict{String, NamedTuple}("julia" => (time = 1.98869e7, memory = 26279216, allocs = 264410, success = true), "dynamax" => (time = 6.69069e8, memory = 0, allocs = 0, success = true), "config" => (latent_dim = 2, obs_dim = 2, seq_len = 500), "pykalman" => (time = 8.8343693e9, memory = 0, allocs = 0, success = true))
 Dict{String, NamedTuple}("julia" => (time = 3.89495e7, memory = 52452560, allocs = 545062, success = true), "dynamax" => (time = 1.056677e9, memory = 0, allocs = 0, success = true), "config" => (latent_dim = 2, obs_dim = 2, seq_len = 1000), "pykalman" => (time = 1.86363996e10, memory = 0, allocs = 0, success = true))
 Dict{String, NamedTuple}(

# Post-Processing Functions
---

In [6]:
function transform_to_df(data_vector::Vector)
    # Initialize vectors for all our columns
    packages = String[]
    times = Float64[]
    memories = Int[]
    allocs = Int[]
    successes = Bool[]
    latent_dims = Int[]
    obs_dims = Int[]
    seq_lens = Int[]
    
    # Process each dictionary in the vector
    for dict in data_vector
        # Get configuration values for this batch
        config = dict["config"]
        latent_dim = config.latent_dim
        obs_dim = config.obs_dim
        seq_len = config.seq_len
        
        # Process each package's results
        for (pkg_name, results) in dict
            if pkg_name != "config"
                push!(packages, pkg_name)
                push!(times, results.time)
                push!(memories, results.memory)
                push!(allocs, results.allocs)
                push!(successes, results.success)
                push!(latent_dims, latent_dim)
                push!(obs_dims, obs_dim)
                push!(seq_lens, seq_len)
            end
        end
    end
    
    # Create the DataFrame
    DataFrame(
        package = packages,
        time = times,
        memory = memories,
        allocs = allocs,
        success = successes,
        latent_dim = latent_dims,
        obs_dim = obs_dims,
        seq_length = seq_lens
    )
end

function plot_benchmarks(df::DataFrame)
    # Create a unique identifier for each obs_dim/latent_dim combination
    df.dim_combo = string.(df.obs_dim, "x", df.latent_dim)
    
    # Define line styles that will cycle if we have more combinations than styles
    base_styles = [:solid, :dash, :dot, :dashdot, :dashdotdot]
    dim_combos = unique(df.dim_combo)
    
    # Create style dictionary by cycling through available styles
    style_dict = Dict(
        combo => base_styles[mod1(i, length(base_styles))] 
        for (i, combo) in enumerate(dim_combos)
    )
    
    # Create the plot
    p = plot(
        xlabel="log(Sequence Length)",
        ylabel="log(Runtime (s))",
        title="Expectation-Maximization Benchmark",
        legend=:outertopright,
        xscale=:log10,
        yscale=:log10
    )
    
    # Plot each package with a different color
    packages = unique(df.package)
    for (i, pkg) in enumerate(packages)
        pkg_data = df[df.package .== pkg, :]
        
        # Plot each dimension combination for this package
        for dim_combo in dim_combos
            combo_data = pkg_data[pkg_data.dim_combo .== dim_combo, :]
            if !isempty(combo_data)
                plot!(
                    p,
                    combo_data.seq_length,
                    combo_data.time,  # Convert to seconds
                    label="$(pkg) ($(dim_combo))",
                    color=i,
                    linestyle=style_dict[dim_combo],
                    marker=:circle,
                    markersize=4
                )
            end
        end
    end
    
    # Add gridlines and adjust layout
    plot!(
        p,
        grid=true,
        minorgrid=true,
        size=(900, 600),
        margin=10Plots.mm
    )
    
    return p
end

plot_benchmarks (generic function with 1 method)

# Process Results
---

In [7]:
df = transform_to_df(results)
df.time = df.time / 1e9;

CSV.write("lds_benchmark_results_no_elbo.csv", df)

"lds_benchmark_results_no_elbo.csv"

In [8]:
df

Row,package,time,memory,allocs,success,latent_dim,obs_dim,seq_length
Unnamed: 0_level_1,String,Float64,Int64,Int64,Bool,Int64,Int64,Int64
1,julia,0.0041714,5416944,53917,true,2,2,100
2,dynamax,0.406563,0,0,true,2,2,100
3,pykalman,1.82786,0,0,true,2,2,100
4,julia,0.0198869,26279216,264410,true,2,2,500
5,dynamax,0.669069,0,0,true,2,2,500
6,pykalman,8.83437,0,0,true,2,2,500
7,julia,0.0389495,52452560,545062,true,2,2,1000
8,dynamax,1.05668,0,0,true,2,2,1000
9,pykalman,18.6364,0,0,true,2,2,1000
10,julia,0.0048905,5712192,53940,true,2,4,100


In [9]:
benchmark_plot = plot_benchmarks(df)
# save plot_benchmarks, set dpi to 300
savefig(benchmark_plot, "benchmark_plot.pdf")

"c:\\Users\\ryansenne\\Documents\\GitHub\\ssm_julia\\benchmarking\\benchmark_plot.pdf"

In [12]:
# create a state-space model for the tutorial
obs_dim = 10
latent_dim = 2

# set up the state parameters
A = 0.95 * [cos(0.25) -sin(0.25); sin(0.25) cos(0.25)] 
Q = Matrix(0.1 * I(2))

x0 = [0.0; 0.0]
P0 = Matrix(0.1 * I(2))

# set up the observation parameters
C = randn(obs_dim, latent_dim)
R = Matrix(0.5 * I(10))


# create the state-space model
true_ssm = GaussianLDS(;A=A, 
                        Q=Q, 
                        C=C, 
                        R=R, 
                        x0=x0, 
                        P0=P0, 
                        obs_dim=obs_dim, 
                        latent_dim=latent_dim, 
                        fit_bool=fill(true, 6))

# simulate data from the model
tSteps = 500
latents, observations = StateSpaceDynamics.sample(true_ssm, tSteps, 1) # one trial for tutorial purposes


# smooth data first
E_z, E_zz, E_zz_prev, x_smooth, p_smooth, ml_total = estep(true_ssm, observations)

# smooth data first
function Q_state_new(
    A::Matrix{<:Real},
    Q::AbstractMatrix{<:Real},
    P0::AbstractMatrix{<:Real},
    x0::Vector{<:Real},
    E_z::Matrix{<:Real},
    E_zz::Array{<:Real,3},
    E_zz_prev::Array{<:Real,3},
)
    T_step = size(E_z, 2)
    state_dim = size(A, 1)
    
    # Pre-compute constants and decompositions once
    Q_chol = cholesky(Symmetric(Q))
    P0_chol = cholesky(Symmetric(P0))
    log_det_Q = logdet(Q_chol)
    log_det_P0 = logdet(P0_chol)
    
    # Pre-allocate temp matrix
    temp = zeros(state_dim, state_dim)
    
    # First time step (handled separately)
    mul!(temp, E_z[:, 1], x0', -1.0, 0.0)  # -E_z[:,1] * x0'
    temp .+= view(E_zz, :, :, 1)           # Add E_zz[:,:,1]
    temp .-= x0 * E_z[:, 1]'               # Subtract x0 * E_z[:,1]'
    temp .+= x0 * x0'                      # Add x0 * x0'
    Q_val = -0.5 * (log_det_P0 + tr(P0_chol \ temp))
    
    # Compute correct sums for t ≥ 2
    sum_E_zz_current = zeros(state_dim, state_dim)     # Sum of E_zz[:,:,t]
    sum_E_zz_prev_cross = zeros(state_dim, state_dim)  # Sum of E_zz_prev[:,:,t]
    sum_E_zz_prev_time = zeros(state_dim, state_dim)   # Sum of E_zz[:,:,t-1]
    
    # Compute the sums with proper temporal alignment
    @inbounds for t in 2:T_step
        sum_E_zz_current .+= view(E_zz, :, :, t)        # t
        sum_E_zz_prev_cross .+= view(E_zz_prev, :, :, t) # t with t-1
        sum_E_zz_prev_time .+= view(E_zz, :, :, t-1)     # t-1
    end
    
    # Compute the summed transition term
    copyto!(temp, sum_E_zz_current)
    mul!(temp, A, sum_E_zz_prev_cross', -1.0, 1.0)  # Subtract A * sum_E_zz_prev_cross'
    temp .-= sum_E_zz_prev_cross * A'               # Subtract sum_E_zz_prev_cross * A'
    temp .+= A * sum_E_zz_prev_time * A'           # Add A * sum_E_zz_prev_time * A'
    
    # Add contribution from all other time steps
    Q_val += -0.5 * ((T_step - 1) * log_det_Q + tr(Q_chol \ temp))
    
    return Q_val
end


Q_state_new (generic function with 1 method)

In [13]:
@benchmark StateSpaceDynamics.Q_state(A, Q, P0, x0, E_z, E_zz, E_zz_prev)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m179.600 μs[22m[39m … [35m78.578 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 98.78%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m370.500 μs              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m421.066 μs[22m[39m ± [32m 1.169 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m4.55% ±  1.71%

  [39m [39m [39m [39m [39m [39m [39m [39m▂[39m [39m▁[39m█[39m▇[39m▅[39m▆[39m▅[34m▅[39m[39m▅[39m▆[39m▆[32m▄[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▂[39m▂[39m▂[39m▂[39

In [15]:
@benchmark Q_state_new(A, Q, P0, x0, E_z[:, :, 1], E_zz[:, :, :, 1], E_zz_prev[:, :, :, 1])

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m25.100 μs[22m[39m … [35m 2.119 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 95.77%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m27.200 μs              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m30.954 μs[22m[39m ± [32m57.007 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m5.28% ±  2.87%

  [39m▅[39m█[39m█[34m▇[39m[39m▆[39m▅[39m▄[39m▄[39m▄[32m▄[39m[39m▄[39m▃[39m▃[39m▂[39m▂[39m▂[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[39m█[34m█[39m[39m█[