In [None]:
+using BenchmarkTools
using StateSpaceDynamics
using Random
using Statistics
using PythonCall
using DataFrames
using Printf
using LinearAlgebra
using CSV
using Plots

# Import Python Packages for Benchmarking
---

In [2]:
# Import Python packages
const dynamax = pyimport("dynamax.linear_gaussian_ssm")
const jr = pyimport("jax.random")
const pykalman = pyimport("pykalman")
const np = pyimport("numpy")

Python: <module 'numpy' from 'c:\\Users\\ryansenne\\Documents\\GitHub\\ssm_julia\\benchmarking\\.CondaPkg\\env\\Lib\\site-packages\\numpy\\__init__.py'>

# Utility Functions
---

In [3]:
"""
    sample_rotation_matrix(n::Int)

Draw a random `n × n` proper rotation matrix (orthogonal, determinant `+1`).

The orthogonal factor of the QR decomposition of a standard-normal matrix is
used as the starting point. That factor is orthogonal but can be a reflection
(determinant `-1`), so one column is negated when necessary to guarantee that
the result really is a rotation.

# Returns
A dense `Matrix{Float64}` of size `n × n`.
"""
function sample_rotation_matrix(n::Int)
    Q = Matrix(qr(randn(n, n)).Q)
    # Negating a single column flips the sign of the determinant while
    # keeping the columns orthonormal.
    if det(Q) < 0
        Q[:, 1] .*= -1
    end
    return Q
end

sample_rotation_matrix (generic function with 1 method)

# Benchmarking Functions
---

In [4]:
"""
    BenchConfig

Settings for one benchmark sweep.

# Fields
- `latent_dims`: latent state dimensions to sweep over.
- `obs_dims`: observation dimensions to sweep over.
- `seq_lengths`: sequence lengths to sweep over.
- `n_iters`: number of EM iterations requested from each backend per fit.
- `n_repeats`: value passed as `samples` to `@benchmark` for each backend.
"""
struct BenchConfig
    latent_dims::Vector{Int}
    obs_dims::Vector{Int} 
    seq_lengths::Vector{Int}
    n_iters::Int
    n_repeats::Int
end

# Sweep used whenever a benchmark function is called without an explicit configuration.
default_config = let state_sizes = [2, 4, 8], series_lengths = [100, 500, 1000]
    em_iterations = 100   # EM iterations per fit
    repeats = 5           # benchmark repeats
    # Latent and observation dimensions share the same values but are stored
    # as two independent vectors.
    BenchConfig(state_sizes, copy(state_sizes), series_lengths, em_iterations, repeats)
end

"""
    generate_random_params(latent_dim::Int, obs_dim::Int)

Build one random parameter set for a linear-Gaussian state-space model.

Returns the tuple `(A, C, Q, R, x0, P0)`, where `A` comes from
`sample_rotation_matrix`, `C` is an `obs_dim × latent_dim` standard-normal
matrix, `Q`, `R` and `P0` are dense identity matrices, and `x0` is a zero
vector of length `latent_dim`.
"""
function generate_random_params(latent_dim::Int, obs_dim::Int)
    dense_identity(k) = Matrix{Float64}(I, k, k)

    # The dynamics matrix is drawn before the emission matrix so the random
    # number stream is consumed in a fixed order.
    dynamics = sample_rotation_matrix(latent_dim)
    emission = randn(obs_dim, latent_dim)

    return (
        dynamics,
        emission,
        dense_identity(latent_dim),
        dense_identity(obs_dim),
        zeros(latent_dim),
        dense_identity(latent_dim),
    )
end

"""
Generate random parameters and data for benchmarking
"""
function generate_test_data(latent_dim::Int, obs_dim::Int, seq_len::Int)
    # Wrap a parameter tuple (A, C, Q, R, x0, P0) in a GaussianLDS of the requested size.
    function build_lds(p)
        A, C, Q, R, x0, P0 = p
        return GaussianLDS(
            A=A, C=C, Q=Q, R=R, x0=x0, P0=P0,
            obs_dim=obs_dim, latent_dim=latent_dim
        )
    end

    # Data-generating model: built from its own random draw and used only for sampling.
    truth = generate_random_params(latent_dim, obs_dim)
    x, y = sample(build_lds(truth), seq_len, 1)

    # A second, independent random draw provides the starting point for fitting.
    init = generate_random_params(latent_dim, obs_dim)
    lds_fit = build_lds(init)

    # The initial parameter tuple is returned as well so callers can hand the
    # same starting values to other libraries.
    return lds_fit, x, y, init
end


"""
    run_single_benchmark(model_type, lds_julia, y, y_np, params=nothing; config=default_config)

Time one EM fit for a single backend and summarize the measurement.

# Arguments
- `model_type::Symbol`: backend to benchmark; one of `:julia`, `:dynamax`, `:pykalman`.
- `lds_julia`: model passed to `fit!`. It is deep-copied inside the timed
  expression so every sample starts from the same state. Used for `:julia` only.
- `y`: observations passed to `fit!` (`:julia` only).
- `y_np`: observations passed to the Python backends.
- `params`: tuple `(A_np, C_np, Q_np, R_np, x0_np)` of initial parameters;
  required for `:dynamax` and `:pykalman`.

# Keywords
- `config`: object providing `n_iters` (EM iterations) and `n_repeats`
  (passed to `@benchmark` as `samples`).

# Returns
A named tuple `(time, memory, allocs, success)`. `time` is the median sample
time reported by BenchmarkTools (nanoseconds). `memory` and `allocs` are only
measured for `:julia` and are `0` for the Python backends. If the benchmark
raises an error, a warning is logged and
`(time=NaN, memory=0, allocs=0, success=false)` is returned.

# Throws
- `ArgumentError` if `model_type` is not one of the supported backends.
- `InterruptException` is rethrown so a running sweep can be cancelled.
"""
function run_single_benchmark(model_type::Symbol, lds_julia, y, y_np, params=nothing; config=default_config)
    # Reject unknown backends up front instead of silently returning `nothing`.
    model_type in (:julia, :dynamax, :pykalman) ||
        throw(ArgumentError("unknown model_type :$model_type; expected :julia, :dynamax or :pykalman"))

    # NOTE(review): `samples` is only an upper bound. BenchmarkTools also stops once
    # its `seconds` time budget is spent, so slow fits may produce fewer than
    # `config.n_repeats` samples -- confirm whether a larger budget is wanted.
    try
        if model_type == :julia
            bench = @benchmark begin
                local lds_fresh = deepcopy($lds_julia)
                fit!(lds_fresh, $y, max_iter=$config.n_iters)
            end samples=config.n_repeats
            return (time=median(bench).time, memory=bench.memory, allocs=bench.allocs, success=true)

        elseif model_type == :dynamax
            A_np, C_np, Q_np, R_np, x0_np = params
            latent_dim = pyconvert(Int, A_np.shape[0])
            obs_dim = pyconvert(Int, C_np.shape[0])

            # Model construction and initialization are part of the timed expression.
            bench = @benchmark begin
                local lds_dynamax = dynamax.LinearGaussianSSM($latent_dim, $obs_dim)
                local key = jr.PRNGKey(0)
                local params, props = lds_dynamax.initialize(
                    key, dynamics_weights=$A_np, dynamics_covariance=$Q_np,
                    emission_weights=$C_np, emission_covariance=$R_np,
                    initial_mean=$x0_np
                )
                lds_dynamax.fit_em(params, props, $y_np, num_iters=$config.n_iters)
            end samples=config.n_repeats
            # Julia-side memory statistics are not meaningful for Python work.
            return (time=median(bench).time, memory=0, allocs=0, success=true)

        elseif model_type == :pykalman
            A_np, C_np, Q_np, R_np, x0_np = params
            latent_dim = pyconvert(Int, A_np.shape[0])
            obs_dim = pyconvert(Int, C_np.shape[0])

            bench = @benchmark begin
                local kf = pykalman.KalmanFilter(
                    n_dim_state=$latent_dim, n_dim_obs=$obs_dim,
                    transition_matrices=$A_np, observation_matrices=$C_np,
                    transition_covariance=$Q_np, observation_covariance=$R_np,
                    initial_state_mean=$x0_np
                )
                kf.em($y_np, n_iter=$config.n_iters)
            end samples=config.n_repeats
            return (time=median(bench).time, memory=0, allocs=0, success=true)
        end
    catch e
        # A user interrupt must stop the sweep rather than be recorded as a failed run.
        e isa InterruptException && rethrow()
        @warn "Benchmark failed for $model_type" exception=(e, catch_backtrace())
        return (time=NaN, memory=0, allocs=0, success=false)
    end
end

"""
Benchmark model fitting with separate runs for each package
"""
function benchmark_fitting(config::BenchConfig=default_config)
    results = []

    for latent_dim in config.latent_dims,
        obs_dim in config.obs_dims,
        seq_len in config.seq_lengths

        # Configurations with fewer observed than latent dimensions are skipped.
        obs_dim >= latent_dim || continue

        println("\nTesting configuration: latent_dim=$latent_dim, obs_dim=$obs_dim, seq_len=$seq_len")

        # One data set and one set of starting parameters are shared by all backends.
        lds_julia, _, y, params = generate_test_data(latent_dim, obs_dim, seq_len)

        # NumPy versions of A, C, Q, R and x0; the sixth entry (P0) is not
        # forwarded to the Python libraries.
        numpy_params = map(arr -> np.array(arr), params[1:5])
        # Remove the third axis of the observations and swap the remaining two
        # before handing them to Python.
        y_np = np.array(permutedims(dropdims(y, dims=3), (2, 1)))

        # Each backend is benchmarked in its own call, one after the other.
        julia_result = run_single_benchmark(:julia, lds_julia, y, y_np; config=config)
        dynamax_result = run_single_benchmark(:dynamax, lds_julia, y, y_np, numpy_params; config=config)
        pykalman_result = run_single_benchmark(:pykalman, lds_julia, y, y_np, numpy_params; config=config)

        entry = Dict(
            "config" => (latent_dim=latent_dim, obs_dim=obs_dim, seq_len=seq_len),
            "julia" => julia_result,
            "dynamax" => dynamax_result,
            "pykalman" => pykalman_result
        )
        push!(results, entry)

        # Report this configuration right away instead of waiting for the whole sweep.
        print_single_result(entry)
    end

    return results
end

"""
Print results for a single configuration
"""
function print_single_result(result)
    cfg = result["config"]
    println("\nConfiguration:")
    println("  Latent dim: $(cfg.latent_dim)")
    println("  Obs dim: $(cfg.obs_dim)")
    println("  Seq length: $(cfg.seq_len)")

    # Backends are always reported in this fixed order.
    for pkg in ("julia", "dynamax", "pykalman")
        outcome = result[pkg]
        label = uppercase(pkg)

        if !outcome.success
            println("\n$(label) benchmark failed")
            continue
        end

        # Stored times are divided by 1e9 for display in seconds.
        @printf("\n%s timing: %.3f seconds", label, outcome.time/1e9)
        # Memory statistics are only printed for the Julia backend.
        if pkg == "julia"
            println("\n  Memory: $(outcome.memory) bytes")
            println("  Allocations: $(outcome.allocs)")
        end
    end
    println("\n" * "-"^50)
end

print_single_result

In [None]:
results = benchmark_fitting()


Testing configuration: latent_dim=2, obs_dim=2, seq_len=100


[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:10 ( 0.11  s/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.16 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.00 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.16 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.00 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.34 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 2.99 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.17 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:46] |----------------------------------------| 2.00% [2/100 00:00<00:22] |X---------------------------------------| 3.00% [3/100 00:00<00:15] |X---------------------------------------| 4.00% [4/100 00:00<00:11] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:04] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 28.00% [28/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 40.00% [40/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX------------------| 56.00% [56/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXX----------| 77.00% [77/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.17 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (12.83 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.17 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (13.23 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (13.42 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (13.83 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.00 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (12.81 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:42] |----------------------------------------| 2.00% [2/100 00:00<00:21] |X---------------------------------------| 3.00% [3/100 00:00<00:14] |X---------------------------------------| 4.00% [4/100 00:00<00:10] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 42.00% [42/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXX-----------------| 59.00% [59/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--------| 80.00% [80/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (26.49 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (25.50 ms/it)[39m[K[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (24.50 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (25.17 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (24.14 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (24.34 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (24.16 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:44] |----------------------------------------| 2.00% [2/100 00:00<00:22] |X---------------------------------------| 3.00% [3/100 00:00<00:14] |X---------------------------------------| 4.00% [4/100 00:00<00:10] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:04] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 28.00% [28/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 40.00% [40/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX------------------| 55.00% [55/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXX------------| 70.00% [70/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.43 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.68 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.65 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.67 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.93 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.97 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.25 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.84 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:41] |----------------------------------------| 2.00% [2/100 00:00<00:20] |X---------------------------------------| 3.00% [3/100 00:00<00:13] |X---------------------------------------| 4.00% [4/100 00:00<00:10] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 42.00% [42/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX----------------| 60.00% [60/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-------| 84.00% [84/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.99 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (17.97 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (13.49 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.66 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.67 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.10 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.51 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (14.32 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:38] |----------------------------------------| 2.00% [2/100 00:00<00:19] |X---------------------------------------| 3.00% [3/100 00:00<00:12] |X---------------------------------------| 4.00% [4/100 00:00<00:09] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 33.00% [33/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXX---------------------| 49.00% [49/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXX--------------| 65.00% [65/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--------| 82.00% [82/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (30.29 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (26.25 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (27.69 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (28.48 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (27.64 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (26.17 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:55] |----------------------------------------| 2.00% [2/100 00:00<00:27] |X---------------------------------------| 3.00% [3/100 00:00<00:18] |X---------------------------------------| 4.00% [4/100 00:00<00:13] |XX--------------------------------------| 5.00% [5/100 00:00<00:10] |XX--------------------------------------| 6.00% [6/100 00:00<00:08] |XXX-------------------------------------| 8.00% [8/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:05] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:03] |XXXXXX----------------------------------| 17.00% [17/100 00:00<00:02] |XXXXXXXXX-------------------------------| 23.00% [23/100 00:00<00:01] |XXXXXXXXXXXX----------------------------| 31.00% [31/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 41.00% [41/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXX---

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.19 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.77 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.72 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.31 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.99 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.99 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.31 ms/it)[39m[K[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.37 ms/it)[39m[K
[32mFitting LDS via EM... 100%|█████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:39] |----------------------------------------| 2.00% [2/100 00:00<00:19] |X---------------------------------------| 3.00% [3/100 00:00<00:12] |X---------------------------------------| 4.00% [4/100 00:00<00:09] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 33.00% [33/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXX---------------------| 49.00% [49/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXX------------| 71.00% [71/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX| 100.00% [100/100 00:00<00:00] |--------------------

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (17.83 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (16.43 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (16.50 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.97 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.68 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (16.40 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.79 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.65 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:56] |----------------------------------------| 2.00% [2/100 00:00<00:27] |X---------------------------------------| 3.00% [3/100 00:00<00:18] |X---------------------------------------| 4.00% [4/100 00:00<00:13] |XX--------------------------------------| 5.00% [5/100 00:00<00:10] |XX--------------------------------------| 6.00% [6/100 00:00<00:08] |XXX-------------------------------------| 8.00% [8/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:05] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:03] |XXXXXX----------------------------------| 17.00% [17/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:02] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXX-------------------------| 39.00% [39/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXX----

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (31.49 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (28.79 ms/it)[39m[K[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (26.86 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (29.03 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (28.01 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (29.00 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:56] |----------------------------------------| 2.00% [2/100 00:00<00:27] |X---------------------------------------| 3.00% [3/100 00:00<00:18] |X---------------------------------------| 4.00% [4/100 00:00<00:13] |XX--------------------------------------| 5.00% [5/100 00:00<00:10] |XX--------------------------------------| 6.00% [6/100 00:00<00:08] |XXX-------------------------------------| 8.00% [8/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:05] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:03] |XXXXXX----------------------------------| 17.00% [17/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:02] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXX-------------------------| 39.00% [39/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXX----

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (11.48 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.32 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.57 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.25 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.14 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.88 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.60 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.87 ms/it)[39m[K
[32mFitting LDS via EM... 100%|████████

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:39] |----------------------------------------| 2.00% [2/100 00:00<00:19] |X---------------------------------------| 3.00% [3/100 00:00<00:13] |X---------------------------------------| 4.00% [4/100 00:00<00:09] |XX--------------------------------------| 5.00% [5/100 00:00<00:07] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:00] |XXXXXXXXXXXXXXXXX-----------------------| 43.00% [43/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX----------------| 62.00% [62/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX-----| 88.00% [88/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (20.71 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (20.75 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (20.15 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (20.44 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (21.31 ms/it)[39m[K
[32mFitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (20.53 ms/it)[39m[K
[32mFitting LDS via EM...  95%|████████████████████████████████████████████████  |  ETA: 0:00:00 (20.84 ms/it)[39m[K

# Post-Processing Functions
---

In [None]:
"""
    transform_to_df(data_vector::Vector)

Flatten benchmark results into a long-format `DataFrame`.

Each element of `data_vector` is expected to be a dictionary with a `"config"`
entry (providing `latent_dim`, `obs_dim`, `seq_len`) plus one entry per
package (providing `time`, `memory`, `allocs`, `success`). The output has one
row per package and configuration with the columns `package`, `time`,
`memory`, `allocs`, `success`, `latent_dim`, `obs_dim` and `seq_length`.
"""
function transform_to_df(data_vector::Vector)
    # Start from an empty table with fixed column types and append row by row.
    table = DataFrame(
        package = String[],
        time = Float64[],
        memory = Int[],
        allocs = Int[],
        success = Bool[],
        latent_dim = Int[],
        obs_dim = Int[],
        seq_length = Int[],
    )

    for entry in data_vector
        cfg = entry["config"]

        for (pkg_name, measurement) in entry
            # The configuration entry describes the run; it is not a package result.
            pkg_name == "config" && continue

            push!(table, (
                package = pkg_name,
                time = measurement.time,
                memory = measurement.memory,
                allocs = measurement.allocs,
                success = measurement.success,
                latent_dim = cfg.latent_dim,
                obs_dim = cfg.obs_dim,
                seq_length = cfg.seq_len,
            ))
        end
    end

    return table
end

"""
    plot_benchmarks(df::DataFrame)

Plot runtime against sequence length on log-log axes.

One line is drawn per package and per observation/latent dimension pair:
packages are distinguished by color, dimension pairs by line style (styles
repeat when there are more pairs than styles).

`df` must provide the columns `package`, `obs_dim`, `latent_dim`,
`seq_length` and `time`. The values in `time` are plotted as given, so any
unit conversion has to happen before calling this function. `df` itself is
not modified.

# Returns
The `Plots` plot object.
"""
function plot_benchmarks(df::DataFrame)
    # Label for each row's obs_dim/latent_dim pair. Kept as a local vector so
    # the caller's DataFrame does not gain an extra column.
    dim_combo = string.(df.obs_dim, "x", df.latent_dim)

    # Line styles cycle if there are more combinations than styles.
    base_styles = [:solid, :dash, :dot, :dashdot, :dashdotdot]
    dim_combos = unique(dim_combo)
    style_dict = Dict(
        combo => base_styles[mod1(i, length(base_styles))]
        for (i, combo) in enumerate(dim_combos)
    )

    p = plot(
        xlabel="log(Sequence Length)",
        ylabel="log(Runtime (s))",
        title="Expectation-Maximization Benchmark",
        legend=:outertopright,
        xscale=:log10,
        yscale=:log10
    )

    # One color per package, one line style per dimension combination.
    packages = unique(df.package)
    for (i, pkg) in enumerate(packages)
        for combo in dim_combos
            rows = (df.package .== pkg) .& (dim_combo .== combo)
            any(rows) || continue
            plot!(
                p,
                df.seq_length[rows],
                df.time[rows],
                label="$(pkg) ($(combo))",
                color=i,
                linestyle=style_dict[combo],
                marker=:circle,
                markersize=4
            )
        end
    end

    # Add gridlines and adjust layout
    plot!(
        p,
        grid=true,
        minorgrid=true,
        size=(900, 600),
        margin=10Plots.mm
    )

    return p
end

# Process Results
---

In [None]:
# Tabulate the raw results, rescale the timings by 1e9 (the same factor used
# when printing them in seconds), and write the table to disk.
df = transform_to_df(results)
df[!, :time] = df.time ./ 1e9

CSV.write("lds_benchmark_results_no_elbo.csv", df)

In [None]:
# Display the processed results table.
df

In [None]:
# Build the benchmark figure from the processed results.
benchmark_plot = plot_benchmarks(df)
# Save the figure as a PDF.
# NOTE(review): an earlier comment mentioned setting dpi to 300, but no dpi is set here.
savefig(benchmark_plot, "benchmark_plot.pdf")