# Import Julia Packages for Benchmarking
---

In [50]:
using BenchmarkTools
using StateSpaceDynamics
using Random
using Statistics
using PythonCall
using DataFrames
using Printf
using LinearAlgebra
using CSV
using Plots

# Import Python Packages for Benchmarking
---

In [2]:
# Import Python packages
const dynamax = pyimport("dynamax.linear_gaussian_ssm")
const jr = pyimport("jax.random")
const pykalman = pyimport("pykalman")
const np = pyimport("numpy")

Python: <module 'numpy' from 'c:\\Users\\ryansenne\\Documents\\GitHub\\ssm_julia\\benchmarking\\.CondaPkg\\env\\Lib\\site-packages\\numpy\\__init__.py'>

# Utility Functions
---

In [3]:
function sample_rotation_matrix(n::Int)
    A = randn(n, n)
    Q, r = qr(A)
    return Matrix(Q)
end

sample_rotation_matrix (generic function with 1 method)

# Benchmarking Functions
---

In [None]:

struct BenchConfig
    latent_dims::Vector{Int}
    obs_dims::Vector{Int} 
    seq_lengths::Vector{Int}
    n_iters::Int
    n_repeats::Int
end

default_config = BenchConfig(
    [2, 4, 8],       # latent dimensions
    [2, 4, 8],      # observation dimensions
    [100, 500, 1000],    # sequence lengths
    100,                 # EM iterations
    5                    # benchmark repeats
)

# Helper function to generate random parameters
function generate_random_params(latent_dim::Int, obs_dim::Int)
    A = sample_rotation_matrix(latent_dim)
    C = randn(obs_dim, latent_dim)
    Q = Matrix(Diagonal(ones(latent_dim)))
    R = Matrix(Diagonal(ones(obs_dim)))
    x0 = zeros(latent_dim)
    P0 = Matrix(Diagonal(ones(latent_dim)))
    return A, C, Q, R, x0, P0
end

"""
Generate random parameters and data for benchmarking
"""
function generate_test_data(latent_dim::Int, obs_dim::Int, seq_len::Int)
    # Generate true parameters for data generation
    A_true, C_true, Q_true, R_true, x0_true, P0_true = generate_random_params(latent_dim, obs_dim)
   
    # Create model with true parameters
    lds_true = GaussianLDS(
        A=A_true, C=C_true, Q=Q_true, R=R_true, x0=x0_true, P0=P0_true,
        obs_dim=obs_dim, latent_dim=latent_dim
    )
    
    # Generate data using true parameters
    x, y = sample(lds_true, seq_len, 1)
    
    # Generate different initial parameters for fitting
    A_init, C_init, Q_init, R_init, x0_init, P0_init = generate_random_params(latent_dim, obs_dim)
    
    
    # Create model with initial parameters for fitting
    lds_fit = GaussianLDS(
        A=A_init, C=C_init, Q=Q_init, R=R_init, x0=x0_init, P0=P0_init,
        obs_dim=obs_dim, latent_dim=latent_dim
    )
    
    return lds_fit, x, y, (A_init, C_init, Q_init, R_init, x0_init, P0_init)
end

"""
Run a single benchmark trial with error handling
"""
function run_single_benchmark(model_type::Symbol, lds_julia, y, y_np, params=nothing; config=default_config)
    try
        if model_type == :julia
            bench = @benchmark fit!($lds_julia, $y, max_iter=$config.n_iters) samples=config.n_repeats
            return (time=minimum(bench).time, memory=bench.memory, allocs=bench.allocs, success=true)
            
        elseif model_type == :dynamax
            A_np, C_np, Q_np, R_np, x0_np = params
            
            # Get dimensions directly from numpy arrays
            latent_dim = pyconvert(Int, A_np.shape[0])
            obs_dim = pyconvert(Int, C_np.shape[0])
            
            lds_dynamax = dynamax.LinearGaussianSSM(latent_dim, obs_dim)
            key = jr.PRNGKey(0)
            test_params, param_props = lds_dynamax.initialize(
                key, dynamics_weights=A_np, dynamics_covariance=Q_np,
                emission_weights=C_np, emission_covariance=R_np,
                initial_mean=x0_np
            )
            bench = @benchmark $lds_dynamax.fit_em(
                $test_params, $param_props, $y_np, 
                num_iters=$config.n_iters
            ) samples=config.n_repeats
            return (time=minimum(bench).time, memory=0, allocs=0, success=true)
            
        elseif model_type == :pykalman
            A_np, C_np, Q_np, R_np, x0_np = params
            
            # Get dimensions directly from numpy arrays
            latent_dim = pyconvert(Int, A_np.shape[0])
            obs_dim = pyconvert(Int, C_np.shape[0])
            
            kf = pykalman.KalmanFilter(
                n_dim_state=latent_dim, n_dim_obs=obs_dim,
                transition_matrices=A_np, observation_matrices=C_np,
                transition_covariance=Q_np, observation_covariance=R_np,
                initial_state_mean=x0_np
            )
            bench = @benchmark $kf.em($y_np, n_iter=$config.n_iters) samples=config.n_repeats
            return (time=minimum(bench).time, memory=0, allocs=0, success=true)
        end
    catch e
        @warn "Benchmark failed for $model_type" exception=e stacktrace=catch_backtrace()
        return (time=NaN, memory=0, allocs=0, success=false)
    end
end

"""
Benchmark model fitting with separate runs for each package
"""
function benchmark_fitting(config::BenchConfig=default_config)
    results = []
    
    for latent_dim in config.latent_dims
        for obs_dim in config.obs_dims
            obs_dim < latent_dim && continue
            
            for seq_len in config.seq_lengths
                println("\nTesting configuration: latent_dim=$latent_dim, obs_dim=$obs_dim, seq_len=$seq_len")
                
                # Generate test data once
                lds_julia, x, y, params = generate_test_data(latent_dim, obs_dim, seq_len)
                A, C, Q, R, x0, _ = params

                # Convert all arrays to numpy
                A_np = np.array(A)
                C_np = np.array(C)
                Q_np = np.array(Q)
                R_np = np.array(R)
                x0_np = np.array(x0)
                y_np = np.array(permutedims(dropdims(y, dims=3), (2, 1)))
                
                numpy_params = (A_np, C_np, Q_np, R_np, x0_np)

                # Run benchmarks separately
                julia_result = run_single_benchmark(:julia, lds_julia, y, y_np; config=config)
                dynamax_result = run_single_benchmark(:dynamax, lds_julia, y, y_np, numpy_params; config=config)
                pykalman_result = run_single_benchmark(:pykalman, lds_julia, y, y_np, numpy_params; config=config)
                
                push!(results, Dict(
                    "config" => (latent_dim=latent_dim, obs_dim=obs_dim, seq_len=seq_len),
                    "julia" => julia_result,
                    "dynamax" => dynamax_result,
                    "pykalman" => pykalman_result
                ))
                
                # Print immediate results for this configuration
                print_single_result(results[end])
            end
        end
    end
    
    return results
end

"""
Print results for a single configuration
"""
function print_single_result(result)
    config = result["config"]
    println("\nConfiguration:")
    println("  Latent dim: $(config.latent_dim)")
    println("  Obs dim: $(config.obs_dim)")
    println("  Seq length: $(config.seq_len)")
    
    for pkg in [:julia, :dynamax, :pykalman]
        r = result[string(pkg)]
        if r.success
            @printf("\n%s timing: %.3f seconds", uppercase(string(pkg)), r.time/1e9)
            if pkg == :julia
                println("\n  Memory: $(r.memory) bytes")
                println("  Allocations: $(r.allocs)")
            end
        else
            println("\n$(uppercase(string(pkg))) benchmark failed")
        end
    end
    println("\n" * "-"^50)
end

print_single_result

In [5]:
results = benchmark_fitting()


Testing configuration: latent_dim=2, obs_dim=2, seq_len=100


[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:10 ( 0.10  s/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.27 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.32 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.07 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.14 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.58 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.58 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.78 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:35] |----------------------------------------| 2.00% [2/100 00:00<00:17] |X---------------------------------------| 3.00% [3/100 00:00<00:11] |X---------------------------------------| 4.00% [4/100 00:00<00:08] |XX--------------------------------------| 5.00% [5/100 00:00<00:06] |XX--------------------------------------| 7.00% [7/100 00:00<00:04] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXXX-------------------------------| 23.00% [23/100 00:00<00:01] |XXXXXXXXXXXXXX--------------------------| 35.00% [35/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXX-------------------| 54.00% [54/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--------| 82.00% [82/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |--------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (16.15 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (16.80 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (30.40 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (13.89 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (13.23 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.25 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (16.68 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (17.16 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:36] |----------------------------------------| 2.00% [2/100 00:00<00:18] |X---------------------------------------| 3.00% [3/100 00:00<00:11] |X---------------------------------------| 4.00% [4/100 00:00<00:08] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:04] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXXX-------------------------------| 23.00% [23/100 00:00<00:01] |XXXXXXXXXXXXXX--------------------------| 35.00% [35/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXX-------------------| 53.00% [53/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX----------| 76.00% [76/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |--------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (31.28 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (29.36 ms/it)[39m[K[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (29.38 ms/it)[39m[K[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (31.72 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (29.73 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (31.65 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:32] |----------------------------------------| 2.00% [2/100 00:00<00:16] |X---------------------------------------| 3.00% [3/100 00:00<00:10] |X---------------------------------------| 4.00% [4/100 00:00<00:07] |XX--------------------------------------| 5.00% [5/100 00:00<00:06] |XXX-------------------------------------| 8.00% [8/100 00:00<00:03] |XXXX------------------------------------| 12.00% [12/100 00:00<00:02] |XXXXXXX---------------------------------| 19.00% [19/100 00:00<00:01] |XXXXXXXXXXXX----------------------------| 30.00% [30/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXX---------------------| 48.00% [48/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXX------------| 70.00% [70/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---| 94.00% [94/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |--------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.49 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.31 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.84 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.49 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.80 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.97 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.28 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.30 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:39] |----------------------------------------| 2.00% [2/100 00:00<00:19] |X---------------------------------------| 3.00% [3/100 00:00<00:12] |X---------------------------------------| 4.00% [4/100 00:00<00:09] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 33.00% [33/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXX---------------------| 49.00% [49/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXX------------| 72.00% [72/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |--------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (17.46 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (17.23 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (33.39 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (17.26 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (16.34 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (18.56 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (18.72 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:28] |----------------------------------------| 2.00% [2/100 00:00<00:14] |X---------------------------------------| 3.00% [3/100 00:00<00:09] |X---------------------------------------| 4.00% [4/100 00:00<00:07] |XX--------------------------------------| 5.00% [5/100 00:00<00:05] |XXX-------------------------------------| 8.00% [8/100 00:00<00:03] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:01] |XXXXXXXX--------------------------------| 21.00% [21/100 00:00<00:01] |XXXXXXXXXXXXXX--------------------------| 35.00% [35/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX------------------| 57.00% [57/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-------| 83.00% [83/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |----------------------------------------| 0.00% [0/100 00:00<?] |--------------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:04 (47.29 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (29.25 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (34.91 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (32.72 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (33.61 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (35.01 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:50] |----------------------------------------| 2.00% [2/100 00:00<00:24] |X---------------------------------------| 3.00% [3/100 00:00<00:16] |X---------------------------------------| 4.00% [4/100 00:00<00:12] |XX--------------------------------------| 5.00% [5/100 00:00<00:09] |XX--------------------------------------| 6.00% [6/100 00:00<00:07] |XXX-------------------------------------| 8.00% [8/100 00:00<00:05] |XXXX------------------------------------| 11.00% [11/100 00:00<00:04] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:02] |XXXXXXXXXX------------------------------| 27.00% [27/100 00:00<00:01] |XXXXXXXXXXXXXX--------------------------| 37.00% [37/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXX--------------------| 50.00% [50/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.59 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.72 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.95 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.84 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.89 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.21 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.56 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.08 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:38] |----------------------------------------| 2.00% [2/100 00:00<00:18] |X---------------------------------------| 3.00% [3/100 00:00<00:12] |X---------------------------------------| 4.00% [4/100 00:00<00:09] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 33.00% [33/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXX--------------------| 50.00% [50/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXX-----------| 74.00% [74/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |--------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (17.80 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (19.94 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (20.07 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (20.02 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (19.79 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (19.45 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (19.03 ms/it)[39m[K[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:49] |----------------------------------------| 2.00% [2/100 00:00<00:24] |X---------------------------------------| 3.00% [3/100 00:00<00:16] |X---------------------------------------| 4.00% [4/100 00:00<00:12] |XX--------------------------------------| 5.00% [5/100 00:00<00:09] |XX--------------------------------------| 6.00% [6/100 00:00<00:07] |XXX-------------------------------------| 8.00% [8/100 00:00<00:05] |XXXX------------------------------------| 11.00% [11/100 00:00<00:04] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:02] |XXXXXXXXXX------------------------------| 27.00% [27/100 00:00<00:01] |XXXXXXXXXXXXXX--------------------------| 37.00% [37/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXX--------------------| 51.00% [51/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (32.85 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (33.69 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (32.47 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (36.39 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (35.95 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (36.17 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:52] |----------------------------------------| 2.00% [2/100 00:00<00:26] |X---------------------------------------| 3.00% [3/100 00:00<00:17] |X---------------------------------------| 4.00% [4/100 00:00<00:12] |XX--------------------------------------| 5.00% [5/100 00:00<00:10] |XX--------------------------------------| 6.00% [6/100 00:00<00:08] |XXX-------------------------------------| 8.00% [8/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:04] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:03] |XXXXXX----------------------------------| 17.00% [17/100 00:00<00:02] |XXXXXXXXX-------------------------------| 23.00% [23/100 00:00<00:01] |XXXXXXXXXXXX----------------------------| 31.00% [31/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 42.00% [42/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX--

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 9.80 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.24 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.84 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.00 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.20 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.70 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.63 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.79 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:31] |----------------------------------------| 2.00% [2/100 00:00<00:15] |X---------------------------------------| 3.00% [3/100 00:00<00:10] |X---------------------------------------| 4.00% [4/100 00:00<00:07] |XX--------------------------------------| 5.00% [5/100 00:00<00:06] |XXX-------------------------------------| 8.00% [8/100 00:00<00:03] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:02] |XXXXXXXX--------------------------------| 21.00% [21/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 34.00% [34/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX------------------| 55.00% [55/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX------| 85.00% [85/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |----------------------------------------| 0.00% [0/100 00:00<?] |--------------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (26.97 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (27.72 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (27.35 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (27.42 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (27.43 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (28.30 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:30] |----------------------------------------| 2.00% [2/100 00:00<00:15] |X---------------------------------------| 3.00% [3/100 00:00<00:09] |X---------------------------------------| 4.00% [4/100 00:00<00:07] |XX--------------------------------------| 5.00% [5/100 00:00<00:05] |XXX-------------------------------------| 8.00% [8/100 00:00<00:03] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:02] |XXXXXXXX--------------------------------| 21.00% [21/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 34.00% [34/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXX-------------------| 54.00% [54/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXX------------| 72.00% [72/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX----| 90.00% [90/100 00:01<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:01<00:00] |--------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:05 (50.66 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:05 (51.42 ms/it)[39m[K
[32mFitting LDS via EM...  98%|██████████████████████████████████████████████████|  ETA: 0:00:00 (53.98 ms/it)[39m[K

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:51] |----------------------------------------| 2.00% [2/100 00:00<00:25] |X---------------------------------------| 3.00% [3/100 00:00<00:16] |X---------------------------------------| 4.00% [4/100 00:00<00:12] |XX--------------------------------------| 5.00% [5/100 00:00<00:09] |XX--------------------------------------| 6.00% [6/100 00:00<00:08] |XXX-------------------------------------| 8.00% [8/100 00:00<00:05] |XXXX------------------------------------| 11.00% [11/100 00:00<00:04] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:02] |XXXXXXXXXX------------------------------| 27.00% [27/100 00:00<00:01] |XXXXXXXXXXXXXX--------------------------| 37.00% [37/100 00:00<00:01] |XXXXXXXXXXXXXXXXXXX---------------------| 49.00% [49/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:05 (54.04 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.12 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.97 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.56 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.53 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.04 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.40 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.35 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:29] |----------------------------------------| 2.00% [2/100 00:00<00:14] |X---------------------------------------| 3.00% [3/100 00:00<00:09] |X---------------------------------------| 4.00% [4/100 00:00<00:07] |XX--------------------------------------| 5.00% [5/100 00:00<00:05] |XXX-------------------------------------| 8.00% [8/100 00:00<00:03] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:02] |XXXXXXXX--------------------------------| 21.00% [21/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 34.00% [34/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX------------------| 56.00% [56/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX------| 87.00% [87/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |----------------------------------------| 0.00% [0/100 00:00<?] |--------------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (28.32 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (28.58 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (30.36 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (30.16 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (28.34 ms/it)[39m[K
[32mFitting LDS via EM...  48%|█████████████████████████                         |  ETA: 0:00:01 (28.83 ms/it)[39m[KExcessive output truncated after 524387 bytes.

18-element Vector{Any}:
 Dict{String, NamedTuple}("julia" => (time = 4.080895e8, memory = 444355456, allocs = 3951662, success = true), "dynamax" => (time = 3.852162e8, memory = 0, allocs = 0, success = true), "config" => (latent_dim = 2, obs_dim = 2, seq_len = 100), "pykalman" => (time = 1.7279784e9, memory = 0, allocs = 0, success = true))
 Dict{String, NamedTuple}("julia" => (time = 1.324036e9, memory = 1568853520, allocs = 15164015, success = true), "dynamax" => (time = 6.296255e8, memory = 0, allocs = 0, success = true), "config" => (latent_dim = 2, obs_dim = 2, seq_len = 500), "pykalman" => (time = 8.7068049e9, memory = 0, allocs = 0, success = true))
 Dict{String, NamedTuple}("julia" => (time = 2.9737644e9, memory = 2929387064, allocs = 29501462, success = true), "dynamax" => (time = 9.6624e8, memory = 0, allocs = 0, success = true), "config" => (latent_dim = 2, obs_dim = 2, seq_len = 1000), "pykalman" => (time = 1.77055419e10, memory = 0, allocs = 0, success = true))
 Dict{Stri

# Post-Processing Functions
---

In [53]:
function transform_to_df(data_vector::Vector)
    # Initialize vectors for all our columns
    packages = String[]
    times = Float64[]
    memories = Int[]
    allocs = Int[]
    successes = Bool[]
    latent_dims = Int[]
    obs_dims = Int[]
    seq_lens = Int[]
    
    # Process each dictionary in the vector
    for dict in data_vector
        # Get configuration values for this batch
        config = dict["config"]
        latent_dim = config.latent_dim
        obs_dim = config.obs_dim
        seq_len = config.seq_len
        
        # Process each package's results
        for (pkg_name, results) in dict
            if pkg_name != "config"
                push!(packages, pkg_name)
                push!(times, results.time)
                push!(memories, results.memory)
                push!(allocs, results.allocs)
                push!(successes, results.success)
                push!(latent_dims, latent_dim)
                push!(obs_dims, obs_dim)
                push!(seq_lens, seq_len)
            end
        end
    end
    
    # Create the DataFrame
    DataFrame(
        package = packages,
        time = times,
        memory = memories,
        allocs = allocs,
        success = successes,
        latent_dim = latent_dims,
        obs_dim = obs_dims,
        seq_length = seq_lens
    )
end

function plot_benchmarks(df::DataFrame)
    # Create a unique identifier for each obs_dim/latent_dim combination
    df.dim_combo = string.(df.obs_dim, "x", df.latent_dim)
    
    # Define line styles that will cycle if we have more combinations than styles
    base_styles = [:solid, :dash, :dot, :dashdot, :dashdotdot]
    dim_combos = unique(df.dim_combo)
    
    # Create style dictionary by cycling through available styles
    style_dict = Dict(
        combo => base_styles[mod1(i, length(base_styles))] 
        for (i, combo) in enumerate(dim_combos)
    )
    
    # Create the plot
    p = plot(
        xlabel="Sequence Length",
        ylabel="Time (seconds)",
        title="Package Performance Across Sequence Lengths",
        legend=:outertopright,
        xscale=:log10,
        yscale=:log10
    )
    
    # Plot each package with a different color
    packages = unique(df.package)
    for (i, pkg) in enumerate(packages)
        pkg_data = df[df.package .== pkg, :]
        
        # Plot each dimension combination for this package
        for dim_combo in dim_combos
            combo_data = pkg_data[pkg_data.dim_combo .== dim_combo, :]
            if !isempty(combo_data)
                plot!(
                    p,
                    combo_data.seq_length,
                    combo_data.time ./ 1e9,  # Convert to seconds
                    label="$(pkg) ($(dim_combo))",
                    color=i,
                    linestyle=style_dict[dim_combo],
                    marker=:circle,
                    markersize=4
                )
            end
        end
    end
    
    # Add gridlines and adjust layout
    plot!(
        p,
        grid=true,
        minorgrid=true,
        size=(900, 600),
        margin=10Plots.mm
    )
    
    return p
end

plot_benchmarks (generic function with 1 method)

# Process Results
---

In [None]:
df = transform_to_df(results)
df.time = df.time / 1e9;

#CSV.write("lds_benchmark_results.csv", df)

In [52]:
df

Row,package,time,memory,allocs,success,latent_dim,obs_dim,seq_length,dim_combo
Unnamed: 0_level_1,String,Float64,Int64,Int64,Bool,Int64,Int64,Int64,String
1,julia,0.408089,444355456,3951662,true,2,2,100,2x2
2,dynamax,0.385216,0,0,true,2,2,100,2x2
3,pykalman,1.72798,0,0,true,2,2,100,2x2
4,julia,1.32404,1568853520,15164015,true,2,2,500,2x2
5,dynamax,0.629625,0,0,true,2,2,500,2x2
6,pykalman,8.7068,0,0,true,2,2,500,2x2
7,julia,2.97376,2929387064,29501462,true,2,2,1000,2x2
8,dynamax,0.96624,0,0,true,2,2,1000,2x2
9,pykalman,17.7055,0,0,true,2,2,1000,2x2
10,julia,0.499268,492852832,4126033,true,2,4,100,4x2


In [59]:
benchmark_plot = plot_benchmarks(df)
# save plot_benchmarks, set dpi to 300
savefig(benchmark_plot, "benchmark_plot.pdf")

"c:\\Users\\ryansenne\\Documents\\GitHub\\ssm_julia\\benchmarking\\benchmark_plot.pdf"