In [1]:
using BenchmarkTools
using StateSpaceDynamics
using Random
using Statistics
using PythonCall
using DataFrames
using Printf
using LinearAlgebra
using CSV
using Plots

[32m[1m    CondaPkg [22m[39m[0mFound dependencies: c:\Users\ryansenne\Documents\GitHub\ssm_julia\benchmarking\CondaPkg.toml
[32m[1m    CondaPkg [22m[39m[0mFound dependencies: C:\Users\ryansenne\.julia\packages\PythonCall\Nr75f\CondaPkg.toml
[32m[1m    CondaPkg [22m[39m[0mDependencies already up to date


# Import Python Packages for Benchmarking
---

In [2]:
# Import Python packages
const dynamax = pyimport("dynamax.linear_gaussian_ssm")
const jr = pyimport("jax.random")
const pykalman = pyimport("pykalman")
const np = pyimport("numpy")

Python: <module 'numpy' from 'c:\\Users\\ryansenne\\Documents\\GitHub\\ssm_julia\\benchmarking\\.CondaPkg\\env\\Lib\\site-packages\\numpy\\__init__.py'>

In [4]:
const dyn = pyimport("dynamax")

Python: <module 'dynamax' from 'c:\\Users\\ryansenne\\Documents\\GitHub\\ssm_julia\\benchmarking\\.CondaPkg\\env\\Lib\\site-packages\\dynamax\\__init__.py'>

In [5]:
dyn.__version__

Python: '0.1.5'

# Utility Functions
---

In [3]:
function sample_rotation_matrix(n::Int)
    A = randn(n, n)
    Q, r = qr(A)
    return Matrix(Q)
end

sample_rotation_matrix (generic function with 1 method)

# Benchmarking Functions
---

In [4]:
struct BenchConfig
    latent_dims::Vector{Int}
    obs_dims::Vector{Int} 
    seq_lengths::Vector{Int}
    n_iters::Int
    n_repeats::Int
end

default_config = BenchConfig(
    [2, 4, 8],       # latent dimensions
    [2, 4, 8],      # observation dimensions
    [100, 500, 1000],    # sequence lengths
    100,                 # EM iterations
    5                    # benchmark repeats
)

# Helper function to generate random parameters
function generate_random_params(latent_dim::Int, obs_dim::Int)
    A = sample_rotation_matrix(latent_dim)
    C = randn(obs_dim, latent_dim)
    Q = Matrix(Diagonal(ones(latent_dim)))
    R = Matrix(Diagonal(ones(obs_dim)))
    x0 = zeros(latent_dim)
    P0 = Matrix(Diagonal(ones(latent_dim)))
    return A, C, Q, R, x0, P0
end

"""
Generate random parameters and data for benchmarking
"""
function generate_test_data(latent_dim::Int, obs_dim::Int, seq_len::Int)
    # Generate true parameters for data generation
    A_true, C_true, Q_true, R_true, x0_true, P0_true = generate_random_params(latent_dim, obs_dim)
   
    # Create model with true parameters
    lds_true = GaussianLDS(
        A=A_true, C=C_true, Q=Q_true, R=R_true, x0=x0_true, P0=P0_true,
        obs_dim=obs_dim, latent_dim=latent_dim
    )
    
    # Generate data using true parameters
    x, y = sample(lds_true, seq_len, 1)
    
    # Generate different initial parameters for fitting
    A_init, C_init, Q_init, R_init, x0_init, P0_init = generate_random_params(latent_dim, obs_dim)
    
    
    # Create model with initial parameters for fitting
    lds_fit = GaussianLDS(
        A=A_init, C=C_init, Q=Q_init, R=R_init, x0=x0_init, P0=P0_init,
        obs_dim=obs_dim, latent_dim=latent_dim
    )
    
    return lds_fit, x, y, (A_init, C_init, Q_init, R_init, x0_init, P0_init)
end


function run_single_benchmark(model_type::Symbol, lds_julia, y, y_np, params=nothing; config=default_config)
    try
        if model_type == :julia
            bench = @benchmark begin
                local lds_fresh = deepcopy($lds_julia)
                fit!(lds_fresh, $y, max_iter=$config.n_iters)
            end samples=config.n_repeats
            return (time=median(bench).time, memory=bench.memory, allocs=bench.allocs, success=true)
            
        elseif model_type == :dynamax
            A_np, C_np, Q_np, R_np, x0_np = params
            latent_dim = pyconvert(Int, A_np.shape[0])
            obs_dim = pyconvert(Int, C_np.shape[0])
            
            bench = @benchmark begin
                local lds_dynamax = dynamax.LinearGaussianSSM($latent_dim, $obs_dim)
                local key = jr.PRNGKey(0)
                local params, props = lds_dynamax.initialize(
                    key, dynamics_weights=$A_np, dynamics_covariance=$Q_np,
                    emission_weights=$C_np, emission_covariance=$R_np,
                    initial_mean=$x0_np
                )
                lds_dynamax.fit_em(params, props, $y_np, num_iters=$config.n_iters)
            end samples=config.n_repeats
            return (time=median(bench).time, memory=0, allocs=0, success=true)
            
        elseif model_type == :pykalman
            A_np, C_np, Q_np, R_np, x0_np = params
            latent_dim = pyconvert(Int, A_np.shape[0])
            obs_dim = pyconvert(Int, C_np.shape[0])
            
            bench = @benchmark begin
                local kf = pykalman.KalmanFilter(
                    n_dim_state=$latent_dim, n_dim_obs=$obs_dim,
                    transition_matrices=$A_np, observation_matrices=$C_np,
                    transition_covariance=$Q_np, observation_covariance=$R_np,
                    initial_state_mean=$x0_np
                )
                kf.em($y_np, n_iter=$config.n_iters)
            end samples=config.n_repeats
            return (time=median(bench).time, memory=0, allocs=0, success=true)
        end
    catch e
        @warn "Benchmark failed for $model_type" exception=e stacktrace=catch_backtrace()
        return (time=NaN, memory=0, allocs=0, success=false)
    end
end

"""
Benchmark model fitting with separate runs for each package
"""
function benchmark_fitting(config::BenchConfig=default_config)
    results = []
    
    for latent_dim in config.latent_dims
        for obs_dim in config.obs_dims
            obs_dim < latent_dim && continue
            
            for seq_len in config.seq_lengths
                println("\nTesting configuration: latent_dim=$latent_dim, obs_dim=$obs_dim, seq_len=$seq_len")
                
                # Generate test data once
                lds_julia, x, y, params = generate_test_data(latent_dim, obs_dim, seq_len)
                A, C, Q, R, x0, _ = params

                # Convert all arrays to numpy
                A_np = np.array(A)
                C_np = np.array(C)
                Q_np = np.array(Q)
                R_np = np.array(R)
                x0_np = np.array(x0)
                y_np = np.array(permutedims(dropdims(y, dims=3), (2, 1)))
                
                numpy_params = (A_np, C_np, Q_np, R_np, x0_np)

                # Run benchmarks separately
                julia_result = run_single_benchmark(:julia, lds_julia, y, y_np; config=config)
                dynamax_result = run_single_benchmark(:dynamax, lds_julia, y, y_np, numpy_params; config=config)
                pykalman_result = run_single_benchmark(:pykalman, lds_julia, y, y_np, numpy_params; config=config)
                
                push!(results, Dict(
                    "config" => (latent_dim=latent_dim, obs_dim=obs_dim, seq_len=seq_len),
                    "julia" => julia_result,
                    "dynamax" => dynamax_result,
                    "pykalman" => pykalman_result
                ))
                
                # Print immediate results for this configuration
                print_single_result(results[end])
            end
        end
    end
    
    return results
end

"""
Print results for a single configuration
"""
function print_single_result(result)
    config = result["config"]
    println("\nConfiguration:")
    println("  Latent dim: $(config.latent_dim)")
    println("  Obs dim: $(config.obs_dim)")
    println("  Seq length: $(config.seq_len)")
    
    for pkg in [:julia, :dynamax, :pykalman]
        r = result[string(pkg)]
        if r.success
            @printf("\n%s timing: %.3f seconds", uppercase(string(pkg)), r.time/1e9)
            if pkg == :julia
                println("\n  Memory: $(r.memory) bytes")
                println("  Allocations: $(r.allocs)")
            end
        else
            println("\n$(uppercase(string(pkg))) benchmark failed")
        end
    end
    println("\n" * "-"^50)
end

print_single_result

In [5]:
results = benchmark_fitting()


Testing configuration: latent_dim=2, obs_dim=2, seq_len=100


[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:09 (96.78 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.00 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 2.96 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.19 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.34 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 2.66 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 2.68 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 2.65 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:42] |----------------------------------------| 2.00% [2/100 00:00<00:21] |X---------------------------------------| 3.00% [3/100 00:00<00:14] |X---------------------------------------| 4.00% [4/100 00:00<00:10] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 42.00% [42/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX----------------| 60.00% [60/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX------| 85.00% [85/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX





JULIA timing: 0.265 seconds
  Memory: 235310552 bytes
  Allocations: 2386036

DYNAMAX timing: 0.493 seconds
PYKALMAN timing: 2.111 seconds
--------------------------------------------------

Testing configuration: latent_dim=2, obs_dim=2, seq_len=500


[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.50 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (11.81 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (11.83 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (12.53 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (12.82 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (12.67 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.16 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (12.47 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:39] |----------------------------------------| 2.00% [2/100 00:00<00:19] |X---------------------------------------| 3.00% [3/100 00:00<00:12] |X---------------------------------------| 4.00% [4/100 00:00<00:09] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:00] |XXXXXXXXXXXXXXXXX-----------------------| 43.00% [43/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX----------------| 62.00% [62/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX------| 85.00% [85/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (25.61 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (23.92 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (24.65 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (23.56 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (23.24 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (25.57 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (24.06 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:37] |----------------------------------------| 2.00% [2/100 00:00<00:18] |X---------------------------------------| 3.00% [3/100 00:00<00:12] |X---------------------------------------| 4.00% [4/100 00:00<00:09] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:04] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 33.00% [33/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXX---------------------| 49.00% [49/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXX--------------| 67.00% [67/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX------| 86.00% [86/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.67 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.00 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.34 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.83 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.57 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.77 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.00 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.66 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:39] |----------------------------------------| 2.00% [2/100 00:00<00:20] |X---------------------------------------| 3.00% [3/100 00:00<00:13] |X---------------------------------------| 4.00% [4/100 00:00<00:09] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXXXX-----------------------| 43.00% [43/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX----------------| 62.00% [62/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-----| 88.00% [88/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.34 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (13.92 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.91 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.34 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.67 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.01 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.48 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.33 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:41] |----------------------------------------| 2.00% [2/100 00:00<00:20] |X---------------------------------------| 3.00% [3/100 00:00<00:13] |X---------------------------------------| 4.00% [4/100 00:00<00:09] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 42.00% [42/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXX-----------------| 59.00% [59/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---------| 79.00% [79/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (27.75 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (26.17 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (26.67 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (26.83 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (26.12 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (25.16 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:56] |----------------------------------------| 2.00% [2/100 00:00<00:27] |X---------------------------------------| 3.00% [3/100 00:00<00:18] |X---------------------------------------| 4.00% [4/100 00:00<00:13] |XX--------------------------------------| 5.00% [5/100 00:00<00:10] |XX--------------------------------------| 6.00% [6/100 00:00<00:08] |XXX-------------------------------------| 8.00% [8/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:05] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:03] |XXXXXX----------------------------------| 17.00% [17/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:02] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXX-------------------------| 39.00% [39/100 00:00<00:01] |XXXXXXXXXXXXXXXXXX------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.71 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.67 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.35 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.51 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.85 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 2.97 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.01 ms/it)[39m[K[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.65 ms/it)[39m[K
[32mFitting LDS via EM... 100%|█████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:50] |----------------------------------------| 2.00% [2/100 00:00<00:25] |X---------------------------------------| 3.00% [3/100 00:00<00:16] |X---------------------------------------| 4.00% [4/100 00:00<00:12] |XX--------------------------------------| 5.00% [5/100 00:00<00:09] |XX--------------------------------------| 6.00% [6/100 00:00<00:08] |XXX-------------------------------------| 8.00% [8/100 00:00<00:05] |XXXX------------------------------------| 11.00% [11/100 00:00<00:04] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:02] |XXXXXXXXXX------------------------------| 27.00% [27/100 00:00<00:01] |XXXXXXXXXXXXXX--------------------------| 37.00% [37/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXX--------------------| 51.00% [51/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.34 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.87 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.55 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (16.59 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.84 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.01 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (16.63 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.35 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:58] |----------------------------------------| 2.00% [2/100 00:00<00:28] |X---------------------------------------| 3.00% [3/100 00:00<00:18] |X---------------------------------------| 4.00% [4/100 00:00<00:14] |XX--------------------------------------| 5.00% [5/100 00:00<00:11] |XX--------------------------------------| 6.00% [6/100 00:00<00:09] |XXX-------------------------------------| 8.00% [8/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:05] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:03] |XXXXXX----------------------------------| 17.00% [17/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:02] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXX-------------------------| 38.00% [38/100 00:00<00:01] |XXXXXXXXXXXXXXXXXXXX----

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (29.66 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (26.69 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (27.00 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (26.51 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (28.36 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (27.97 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:58] |----------------------------------------| 2.00% [2/100 00:00<00:28] |X---------------------------------------| 3.00% [3/100 00:00<00:18] |X---------------------------------------| 4.00% [4/100 00:00<00:14] |XX--------------------------------------| 5.00% [5/100 00:00<00:11] |XX--------------------------------------| 6.00% [6/100 00:00<00:09] |XXX-------------------------------------| 8.00% [8/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:05] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:03] |XXXXXX----------------------------------| 17.00% [17/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:02] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXX-------------------------| 38.00% [38/100 00:00<00:01] |XXXXXXXXXXXXXXXXXXX-----

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (10.05 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.32 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.85 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.47 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.62 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.88 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.06 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.11 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:36] |----------------------------------------| 2.00% [2/100 00:00<00:18] |X---------------------------------------| 3.00% [3/100 00:00<00:12] |X---------------------------------------| 4.00% [4/100 00:00<00:09] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 33.00% [33/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXX--------------------| 50.00% [50/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXX-----------| 73.00% [73/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |--------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (18.67 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (20.36 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (19.50 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (18.99 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (19.83 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (19.49 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (20.35 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:37] |----------------------------------------| 2.00% [2/100 00:00<00:18] |X---------------------------------------| 3.00% [3/100 00:00<00:12] |X---------------------------------------| 4.00% [4/100 00:00<00:09] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 33.00% [33/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXX---------------------| 49.00% [49/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXX--------------| 65.00% [65/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--------| 82.00% [82/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (34.60 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (34.82 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (35.17 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (36.29 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (35.16 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (33.47 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:51] |----------------------------------------| 2.00% [2/100 00:00<00:25] |X---------------------------------------| 3.00% [3/100 00:00<00:16] |X---------------------------------------| 4.00% [4/100 00:00<00:12] |XX--------------------------------------| 5.00% [5/100 00:00<00:09] |XX--------------------------------------| 6.00% [6/100 00:00<00:08] |XXX-------------------------------------| 8.00% [8/100 00:00<00:05] |XXXX------------------------------------| 11.00% [11/100 00:00<00:04] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:02] |XXXXXXXXXX------------------------------| 27.00% [27/100 00:00<00:01] |XXXXXXXXXXXXXX--------------------------| 37.00% [37/100 00:00<00:01] |XXXXXXXXXXXXXXXXXXX---------------------| 48.00% [48/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXX-

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.95 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.67 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.79 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.08 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.73 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.17 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.88 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.46 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:34] |----------------------------------------| 2.00% [2/100 00:00<00:17] |X---------------------------------------| 3.00% [3/100 00:00<00:11] |X---------------------------------------| 4.00% [4/100 00:00<00:08] |XX--------------------------------------| 5.00% [5/100 00:00<00:06] |XX--------------------------------------| 7.00% [7/100 00:00<00:04] |XXXX------------------------------------| 11.00% [11/100 00:00<00:02] |XXXXXX----------------------------------| 17.00% [17/100 00:00<00:01] |XXXXXXXXXX------------------------------| 26.00% [26/100 00:00<00:00] |XXXXXXXXXXXXXXXX------------------------| 40.00% [40/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX----------------| 61.00% [61/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-----| 89.00% [89/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |--------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (20.26 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (19.18 ms/it)[39m[K[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (19.95 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (20.82 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (21.11 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (19.09 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (20.68 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:54] |----------------------------------------| 2.00% [2/100 00:00<00:26] |X---------------------------------------| 3.00% [3/100 00:00<00:17] |X---------------------------------------| 4.00% [4/100 00:00<00:13] |XX--------------------------------------| 5.00% [5/100 00:00<00:10] |XX--------------------------------------| 6.00% [6/100 00:00<00:08] |XXX-------------------------------------| 8.00% [8/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:04] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:03] |XXXXXX----------------------------------| 17.00% [17/100 00:00<00:02] |XXXXXXXXX-------------------------------| 23.00% [23/100 00:00<00:01] |XXXXXXXXXXXX----------------------------| 31.00% [31/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 42.00% [42/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXX---

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (37.40 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (35.84 ms/it)[39m[K
[32mFitting LDS via EM...  80%|█████████████████████████████████████████         |  ETA: 0:00:01 (36.88 ms/it)[39m[KExcessive output truncated after 524415 bytes.

18-element Vector{Any}:
 Dict{String, NamedTuple}("julia" => (time = 2.648843e8, memory = 235310552, allocs = 2386036, success = true), "dynamax" => (time = 4.926646e8, memory = 0, allocs = 0, success = true), "config" => (latent_dim = 2, obs_dim = 2, seq_len = 100), "pykalman" => (time = 2.111497e9, memory = 0, allocs = 0, success = true))
 Dict{String, NamedTuple}("julia" => (time = 1.26140525e9, memory = 1100285608, allocs = 11630487, success = true), "dynamax" => (time = 7.987482e8, memory = 0, allocs = 0, success = true), "config" => (latent_dim = 2, obs_dim = 2, seq_len = 500), "pykalman" => (time = 1.06553109e10, memory = 0, allocs = 0, success = true))
 Dict{String, NamedTuple}("julia" => (time = 2.4065127e9, memory = 2228086496, allocs = 24064819, success = true), "dynamax" => (time = 1.1627641e9, memory = 0, allocs = 0, success = true), "config" => (latent_dim = 2, obs_dim = 2, seq_len = 1000), "pykalman" => (time = 2.15521513e10, memory = 0, allocs = 0, success = true))
 Dic

# Post-Processing Functions
---

In [None]:
function transform_to_df(data_vector::Vector)
    # Initialize vectors for all our columns
    packages = String[]
    times = Float64[]
    memories = Int[]
    allocs = Int[]
    successes = Bool[]
    latent_dims = Int[]
    obs_dims = Int[]
    seq_lens = Int[]
    
    # Process each dictionary in the vector
    for dict in data_vector
        # Get configuration values for this batch
        config = dict["config"]
        latent_dim = config.latent_dim
        obs_dim = config.obs_dim
        seq_len = config.seq_len
        
        # Process each package's results
        for (pkg_name, results) in dict
            if pkg_name != "config"
                push!(packages, pkg_name)
                push!(times, results.time)
                push!(memories, results.memory)
                push!(allocs, results.allocs)
                push!(successes, results.success)
                push!(latent_dims, latent_dim)
                push!(obs_dims, obs_dim)
                push!(seq_lens, seq_len)
            end
        end
    end
    
    # Create the DataFrame
    DataFrame(
        package = packages,
        time = times,
        memory = memories,
        allocs = allocs,
        success = successes,
        latent_dim = latent_dims,
        obs_dim = obs_dims,
        seq_length = seq_lens
    )
end

function plot_benchmarks(df::DataFrame)
    # Create a unique identifier for each obs_dim/latent_dim combination
    df.dim_combo = string.(df.obs_dim, "x", df.latent_dim)
    
    # Define marker shapes that will cycle if we have more combinations than shapes
    base_markers = [:circle, :square, :diamond, :utriangle, :dtriangle, :star5]
    dim_combos = unique(df.dim_combo)
    
    # Create marker dictionary by cycling through available markers
    marker_dict = Dict(
        combo => base_markers[mod1(i, length(base_markers))]
        for (i, combo) in enumerate(dim_combos)
    )
    
    # Create the plot
    p = plot(
        xlabel="log(Sequence Length)",
        ylabel="log(Runtime (s))",
        title="Expectation-Maximization Benchmark",
        legend=:outertopright,
        xscale=:log10,
        yscale=:log10
    )
    
    # Plot each package with a different color
    packages = unique(df.package)
    for (i, pkg) in enumerate(packages)
        pkg_data = df[df.package .== pkg, :]
        
        # Plot each dimension combination for this package
        for dim_combo in dim_combos
            combo_data = pkg_data[pkg_data.dim_combo .== dim_combo, :]
            if !isempty(combo_data)
                plot!(
                    p,
                    combo_data.seq_length,
                    combo_data.time,
                    label="$(pkg) ($(dim_combo))",
                    color=i,
                    marker=marker_dict[dim_combo],
                    markersize=6,
                    markerstrokewidth=1
                )
            end
        end
    end
    
    # Add gridlines and adjust layout
    plot!(
        p,
        grid=true,
        minorgrid=true,
        size=(900, 600),
        margin=10Plots.mm
    )
    
    return p
end

plot_benchmarks (generic function with 1 method)

# Process Results
---

In [11]:
df = transform_to_df(results)
df.time = df.time / 1e9;

CSV.write("lds_benchmark_results_no_elbo.csv", df)

"lds_benchmark_results_no_elbo.csv"

In [12]:
df

Row,package,time,memory,allocs,success,latent_dim,obs_dim,seq_length
Unnamed: 0_level_1,String,Float64,Int64,Int64,Bool,Int64,Int64,Int64
1,julia,0.264884,235310552,2386036,true,2,2,100
2,dynamax,0.492665,0,0,true,2,2,100
3,pykalman,2.1115,0,0,true,2,2,100
4,julia,1.26141,1100285608,11630487,true,2,2,500
5,dynamax,0.798748,0,0,true,2,2,500
6,pykalman,10.6553,0,0,true,2,2,500
7,julia,2.40651,2228086496,24064819,true,2,2,1000
8,dynamax,1.16276,0,0,true,2,2,1000
9,pykalman,21.5522,0,0,true,2,2,1000
10,julia,0.402609,245513616,2387217,true,2,4,100


In [13]:
benchmark_plot = plot_benchmarks(df)
# save plot_benchmarks, set dpi to 300
savefig(benchmark_plot, "benchmark_plot.pdf")

"c:\\Users\\ryansenne\\Documents\\GitHub\\ssm_julia\\benchmarking\\benchmark_plot.pdf"