In [1]:
using BenchmarkTools
using Random
using Statistics
using PythonCall
using DataFrames
using Printf
using LinearAlgebra
using CSV
using Distributions
using HiddenMarkovModels
using StateSpaceDynamics


In [2]:
# Handles to the Python modules used on the Dynamax side of the benchmark,
# loaded through `pyimport`.
const np = pyimport("numpy")
const dynamax = pyimport("dynamax.hidden_markov_model")
const jr = pyimport("jax.random")
const jnp = pyimport("jax.numpy")

Python: <module 'jax.numpy' from 'c:\\Users\\zachl\\OneDrive\\Documents\\GitHub\\StateSpaceDynamics.jl\\benchmarking\\.CondaPkg\\env\\Lib\\site-packages\\jax\\numpy\\__init__.py'>

In [101]:
"""
    BenchConfig(latent_dims, input_dims, obs_dims, seq_lengths, n_iters, n_repeats)

Parameter grid and run settings for a benchmark sweep. The four vector fields list
the values to sweep over; the two integers control how each fit is run and timed.
"""
struct BenchConfig
    # Numbers of hidden states to benchmark.
    latent_dims::Vector{Int}
    # Regression input (feature) dimensions to benchmark.
    input_dims::Vector{Int}
    # Observation dimensions to benchmark.
    obs_dims::Vector{Int}
    # Lengths of the sampled sequences.
    seq_lengths::Vector{Int}
    # Number of EM iterations per fit.
    n_iters::Int
    # Number of benchmark repeats (samples) per configuration.
    n_repeats::Int
end

# Default sweep settings used when no explicit configuration is supplied.
default_config = let
    latent_dims = [2, 4]      # latent dimensions
    input_dims = [2, 4]       # input dimensions
    obs_dims = [1]            # observation dimensions
    seq_lengths = [100, 500]  # sequence lengths
    n_iters = 100             # EM iterations
    n_repeats = 5             # benchmark repeats

    BenchConfig(latent_dims, input_dims, obs_dims, seq_lengths, n_iters, n_repeats)
end

"""
    initialize_transition_matrix(K::Int)

Return a random `K x K` row-stochastic transition matrix.

Each row is drawn from a flat Dirichlet distribution, `0.5` is added to every
diagonal entry, and the rows are then renormalised so that each sums to one.
"""
function initialize_transition_matrix(K::Int)
    A = Matrix{Float64}(undef, K, K)

    # One Dirichlet draw per row, in row order.
    for row in eachrow(A)
        row .= rand(Dirichlet(ones(K)))
    end

    # Boost the self-transition weight before renormalising.
    for i in 1:K
        A[i, i] += 0.5
    end

    return A ./ sum(A; dims=2)
end

"""
    initialize_state_distribution(K::Int)

Return a random length-`K` probability vector drawn from a flat Dirichlet
distribution.
"""
function initialize_state_distribution(K::Int)
    flat_prior = Dirichlet(fill(1.0, K))
    return rand(flat_prior)
end


"""
    generate_random_hmm(latent_dim::Int, input_dim::Int, obs_dim::Int)

Create a randomly parameterised switching Bernoulli regression model in
StateSpaceDynamics.jl together with a Dynamax `LogisticRegressionHMM` of matching
size.

# Arguments
- `latent_dim`: number of hidden states.
- `input_dim`: dimension of the regression inputs.
- `obs_dim`: dimension of the observations.

# Returns
A tuple `(true_model, dynamax_model, params, props)`: the Julia model, the Dynamax
model object, and the parameters/properties returned by the Dynamax model's
`initialize` call.
"""
function generate_random_hmm(latent_dim::Int, input_dim::Int, obs_dim::Int)
    # --- StateSpaceDynamics.jl model -------------------------------------------
    true_model = StateSpaceDynamics.SwitchingBernoulliRegression(K=latent_dim, input_dim=input_dim, output_dim=obs_dim, include_intercept=false)

    # Give every state its own random Bernoulli regression emission.
    for state in 1:latent_dim
        # Random coefficients (input_dim x obs_dim); no intercept term is used.
        β = randn(input_dim, obs_dim)
        true_model.B[state] = BernoulliRegressionEmission(input_dim=input_dim, output_dim=obs_dim, β=β, include_intercept=false)
    end

    true_model.A = initialize_transition_matrix(latent_dim)
    true_model.πₖ = initialize_state_distribution(latent_dim)

    # --- Dynamax model ---------------------------------------------------------
    dynamax_model = dynamax.LogisticRegressionHMM(
        num_states=latent_dim,
        input_dim=input_dim
    )
    key = jr.PRNGKey(1)

    # The Dynamax parameters are drawn from the model's prior with a fixed key; they
    # are NOT copied from `true_model`. An earlier attempt to pass the Julia
    # parameters (initial_probs, transition_matrix, emission_weights,
    # emission_biases) to `initialize` was disabled because the result was wrong.
    # NOTE(review): presumably the stacked β' weights had shape
    # (latent_dim, obs_dim, input_dim) where Dynamax expects (latent_dim, input_dim)
    # -- confirm against the Dynamax API before re-enabling matched initialisation.
    params, props = dynamax_model.initialize(
        key=key,
        method="prior"
    )

    return true_model, dynamax_model, params, props
end



"""
    generate_test_data(model, seq_len::Int)

Draw standard-normal inputs of size `input_dim x seq_len` (with `input_dim` taken
from the model's first emission) and sample a sequence of length `seq_len` from
`model`.

Returns `(model, labels, inputs, data)`, where `labels` and `data` are the two
values returned by `StateSpaceDynamics.sample`.
"""
function generate_test_data(model, seq_len::Int)
    n_features = model.B[1].input_dim
    inputs = randn(n_features, seq_len)

    sampled = StateSpaceDynamics.sample(model, inputs, n=seq_len)
    labels, data = sampled

    return model, labels, inputs, data
end


"""
    run_single_benchmark(model_type, hmm_ssd, y, Φ, params=nothing, props=nothing;
                         config=default_config)

Benchmark one EM fit and return a summary named tuple.

# Arguments
- `model_type`: `:julia` benchmarks `StateSpaceDynamics.fit!`; any other value
  benchmarks the Dynamax model's `fit_em`.
- `hmm_ssd`: the model to fit (the StateSpaceDynamics model for `:julia`, the Dynamax
  model object otherwise).
- `y`: observations as a Julia array.
- `Φ`: inputs as a Julia array.
- `params`, `props`: Dynamax parameters and properties (Dynamax branch only).

# Keywords
- `config`: supplies `n_iters` (EM iterations per fit) and `n_repeats` (benchmark
  samples).

# Returns
`(time, memory, allocs, success)` with `time` the median sample time as reported by
BenchmarkTools, and `memory`/`allocs` taken from the benchmark trial.
"""
function run_single_benchmark(model_type::Symbol, hmm_ssd, y, Φ, params=nothing, props=nothing; config=default_config)
    n_iters = config.n_iters

    bench = if model_type == :julia
        @benchmark begin
            model = deepcopy($hmm_ssd)  # fit! mutates, so every sample starts from a fresh copy
            StateSpaceDynamics.fit!(model, $y, $Φ, max_iters=$n_iters, tol=1e-10)
        end samples=config.n_repeats
    else
        # Convert outside the timed region. Inputs are passed to Dynamax transposed
        # (one row per time step); only the first row of `y` is used as the emission
        # sequence.
        # NOTE(review): taking row 0 assumes a single observation dimension -- confirm
        # before benchmarking obs_dim > 1.
        inputs_np = np.array(Φ')
        data_np = np.array(y)[0]

        # Every value is interpolated with `$`: names that are not interpolated are
        # looked up as globals by @benchmark, which would silently benchmark whatever
        # stale notebook variables happen to exist instead of this call's arguments.
        @benchmark begin
            $(hmm_ssd).fit_em($params, $props, $data_np, inputs=$inputs_np, num_iters=$n_iters)
        end samples=config.n_repeats
    end

    return (time=median(bench).time, memory=bench.memory, allocs=bench.allocs, success=true)
end

"""
    benchmark_fitting(config::BenchConfig = default_config)

Benchmark EM fitting for every combination of latent dimension, input dimension,
observation dimension and sequence length in `config`.

For each combination a model is generated and sampled from, a second randomly
initialised model pair is created for fitting, and both the StateSpaceDynamics.jl and
the Dynamax fits are benchmarked. A failing benchmark is reported on stdout and
recorded with `"FAIL"` placeholders and `success=false` rather than aborting the sweep.

# Returns
A vector with one `Dict` per combination, holding the keys `"config"`, `"SSD.jl"` and
`"Dynamax"`.
"""
function benchmark_fitting(config::BenchConfig = default_config)
    results = Dict{String,NamedTuple}[]

    # Leftmost iterator is the outermost loop, so sequence length varies fastest.
    for latent_dim in config.latent_dims,
        input_dim in config.input_dims,
        obs_dim in config.obs_dims,
        seq_len in config.seq_lengths

        println("\nTesting configuration: latent_dim=$latent_dim, input_dim=$input_dim, obs_dim=$obs_dim, seq_len=$seq_len")

        # Data-generating model; only the Julia model of the returned tuple is needed.
        true_model, _, _, _ = generate_random_hmm(latent_dim, input_dim, obs_dim)
        _, _, Φ, data = generate_test_data(true_model, seq_len)

        # Independently initialised models that will actually be fitted.
        test_model, dynamax_model, params, props = generate_random_hmm(latent_dim, input_dim, obs_dim)

        # Benchmark each library separately so that one failure does not lose the other.
        # `config` is forwarded so that the caller's settings are honoured.
        julia_result = try
            run_single_benchmark(:julia, test_model, data, Φ; config=config)
        catch err
            println("Error in SSD.jl benchmarking: ", err)
            (time="FAIL", memory="FAIL", allocs="FAIL", success=false)
        end

        dynamax_result = try
            run_single_benchmark(:dynamax, dynamax_model, data, Φ, params, props; config=config)
        catch err
            println("Error in dynamax benchmarking: ", err)
            (time="FAIL", memory="FAIL", allocs="FAIL", success=false)
        end

        push!(results, Dict{String,NamedTuple}(
            "config" => (latent_dim=latent_dim, input_dim=input_dim, obs_dim=obs_dim, seq_len=seq_len),
            "SSD.jl" => julia_result,
            "Dynamax" => dynamax_result
        ))
    end

    return results
end

benchmark_fitting (generic function with 2 methods)

In [61]:
# Run the sweep with the default configuration and keep the per-configuration results.
results = benchmark_fitting()


Testing configuration: latent_dim=2, input_dim=2, obs_dim=1, seq_len=100


[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.69 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.95 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.61 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 8.55 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.91 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.01 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.53 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.88 ms/it)[39m[K
[32mRunning EM algorith

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:01<01:46] |----------------------------------------| 2.00% [2/100 00:01<00:52] |X---------------------------------------| 3.00% [3/100 00:01<00:34] |X---------------------------------------| 4.00% [4/100 00:01<00:25] |XX--------------------------------------| 5.00% [5/100 00:01<00:20] |XX--------------------------------------| 6.00% [6/100 00:01<00:16] |XX--------------------------------------| 7.00% [7/100 00:01<00:14] |XXX-------------------------------------| 8.00% [8/100 00:01<00:12] |XXX-------------------------------------| 9.00% [9/100 00:01<00:10] |XXXX------------------------------------| 10.00% [10/100 00:01<00:09] |XXXX------------------------------------| 11.00% [11/100 00:01<00:08] |XXXXX-----------------------------------| 13.00% [13/100 00:01<00:07] |XXXXXX----------------------------------| 15.00% [15/100 00:01<00:06] |XXXXXX----------------------

[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (30.07 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (32.32 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:02 (28.20 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (33.22 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (36.60 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (34.84 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:46] |----------------------------------------| 2.00% [2/100 00:00<00:22] |X---------------------------------------| 3.00% [3/100 00:00<00:15] |X---------------------------------------| 4.00% [4/100 00:00<00:11] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:06] |XXX-------------------------------------| 9.00% [9/100 00:00<00:04] |XXXX------------------------------------| 12.00% [12/100 00:00<00:03] |XXXXXX----------------------------------| 17.00% [17/100 00:00<00:02] |XXXXXXXXX-------------------------------| 24.00% [24/100 00:00<00:01] |XXXXXXXXXXXXX---------------------------| 34.00% [34/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXX---------------------| 48.00% [48/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXX--------------| 67.00% [67/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX

[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 4.74 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.76 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.93 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (11.42 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.77 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 8.23 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 8.62 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (13.60 ms/it)[39m[K
[32mRunning EM algorith

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:43] |----------------------------------------| 2.00% [2/100 00:00<00:21] |X---------------------------------------| 3.00% [3/100 00:00<00:14] |X---------------------------------------| 4.00% [4/100 00:00<00:10] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:05] |XXXX------------------------------------| 10.00% [10/100 00:00<00:03] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 29.00% [29/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 42.00% [42/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX----------------| 60.00% [60/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX------| 85.00% [85/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (32.65 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:05 (50.49 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:04 (48.43 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (34.13 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:45] |----------------------------------------| 2.00% [2/100 00:00<00:22] |X---------------------------------------| 3.00% [3/100 00:00<00:14] |X---------------------------------------| 4.00% [4/100 00:00<00:10] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:04] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 28.00% [28/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 40.00% [40/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX------------------| 57.00% [57/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX--------| 80.00% [80/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.55 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.53 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.87 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 8.23 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 8.23 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 8.83 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 8.96 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 8.47 ms/it)[39m[K
[32mRunning EM algorith

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:55] |----------------------------------------| 2.00% [2/100 00:00<00:27] |X---------------------------------------| 3.00% [3/100 00:00<00:18] |X---------------------------------------| 4.00% [4/100 00:00<00:13] |XX--------------------------------------| 5.00% [5/100 00:00<00:10] |XX--------------------------------------| 6.00% [6/100 00:00<00:08] |XXX-------------------------------------| 8.00% [8/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:05] |XXXXX-----------------------------------| 13.00% [13/100 00:00<00:03] |XXXXXX----------------------------------| 17.00% [17/100 00:00<00:02] |XXXXXXXXX-------------------------------| 23.00% [23/100 00:00<00:01] |XXXXXXXXXXXX----------------------------| 31.00% [31/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 42.00% [42/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX--

[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (36.95 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:04 (40.37 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (38.38 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (37.49 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (39.48 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (39.46 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:48] |----------------------------------------| 2.00% [2/100 00:00<00:24] |X---------------------------------------| 3.00% [3/100 00:00<00:15] |X---------------------------------------| 4.00% [4/100 00:00<00:11] |XX--------------------------------------| 5.00% [5/100 00:00<00:09] |XX--------------------------------------| 7.00% [7/100 00:00<00:06] |XXX-------------------------------------| 9.00% [9/100 00:00<00:04] |XXXX------------------------------------| 12.00% [12/100 00:00<00:03] |XXXXXX----------------------------------| 16.00% [16/100 00:00<00:02] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:01] |XXXXXXXXXXXX----------------------------| 30.00% [30/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 42.00% [42/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXX-----------------| 58.00% [58/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXX

[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.82 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 7.94 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 8.90 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:01 (15.23 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 9.62 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 9.45 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 8.97 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 9.74 ms/it)[39m[K
[32mRunning EM algorith

 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<00:46] |----------------------------------------| 2.00% [2/100 00:00<00:22] |X---------------------------------------| 3.00% [3/100 00:00<00:15] |X---------------------------------------| 4.00% [4/100 00:00<00:11] |XX--------------------------------------| 5.00% [5/100 00:00<00:08] |XX--------------------------------------| 7.00% [7/100 00:00<00:06] |XXXX------------------------------------| 10.00% [10/100 00:00<00:04] |XXXXX-----------------------------------| 14.00% [14/100 00:00<00:02] |XXXXXXXX--------------------------------| 20.00% [20/100 00:00<00:01] |XXXXXXXXXXX-----------------------------| 28.00% [28/100 00:00<00:01] |XXXXXXXXXXXXXXXX------------------------| 40.00% [40/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX------------------| 56.00% [56/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---------| 78.00% [78/100 00:00<00:00] |XXXXXXXXXXXXXXXXXXXXXX

[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:04 (49.37 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:04 (45.07 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:04 (40.11 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:04 (40.27 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:04 (42.49 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:03 (39.37 ms/it)[39m[K


 |----------------------------------------| 0.00% [0/100 00:00<?] |----------------------------------------| 1.00% [1/100 00:00<01:17] |----------------------------------------| 2.00% [2/100 00:00<00:38] |X---------------------------------------| 3.00% [3/100 00:00<00:25] |X---------------------------------------| 4.00% [4/100 00:00<00:18] |XX--------------------------------------| 5.00% [5/100 00:00<00:14] |XX--------------------------------------| 6.00% [6/100 00:00<00:12] |XX--------------------------------------| 7.00% [7/100 00:00<00:10] |XXX-------------------------------------| 8.00% [8/100 00:00<00:08] |XXXX------------------------------------| 10.00% [10/100 00:00<00:07] |XXXX------------------------------------| 12.00% [12/100 00:00<00:05] |XXXXXX----------------------------------| 15.00% [15/100 00:00<00:04] |XXXXXXX---------------------------------| 18.00% [18/100 00:00<00:03] |XXXXXXXX--------------------------------| 22.00% [22/100 00:00<00:02] |XXXXXXXXXX----------------

8-element Vector{Any}:
 Dict{String, NamedTuple}("Dynamax" => (time = 4.298737e8, memory = 1088, allocs = 28, success = true), "config" => (latent_dim = 2, input_dim = 2, obs_dim = 1, seq_len = 100), "SSD.jl" => (time = 8.090367e8, memory = 269718160, allocs = 2658909, success = true))
 Dict{String, NamedTuple}("Dynamax" => (time = 5.228302e8, memory = 1088, allocs = 28, success = true), "config" => (latent_dim = 2, input_dim = 2, obs_dim = 1, seq_len = 500), "SSD.jl" => (time = 3.6119903e9, memory = 1350422144, allocs = 13020702, success = true))
 Dict{String, NamedTuple}("Dynamax" => (time = 4.629584e8, memory = 1088, allocs = 28, success = true), "config" => (latent_dim = 2, input_dim = 4, obs_dim = 1, seq_len = 100), "SSD.jl" => (time = 9.405538e8, memory = 310604992, allocs = 2904984, success = true))
 Dict{String, NamedTuple}("Dynamax" => (time = 5.400694e8, memory = 1088, allocs = 28, success = true), "config" => (latent_dim = 2, input_dim = 4, obs_dim = 1, seq_len = 500), "SSD.

In [64]:
using CSV
using DataFrames

"""
    prepare_results_for_csv(results)

Flatten benchmark results into a `DataFrame` with one row per configuration and
library.

Each element of `results` must provide the keys `"config"`, `"SSD.jl"` and
`"Dynamax"`. Every row carries the full configuration (`latent_dim`, `input_dim`,
`obs_dim`, `seq_len`), the library name, and that library's `time`, `memory`,
`allocs` and `success` values.
"""
function prepare_results_for_csv(results)
    # Left untyped on purpose: failed runs carry "FAIL" strings in place of numbers,
    # so the rows do not all share one concrete NamedTuple type.
    rows = []

    for result in results
        config = result["config"]

        # Same row layout for both libraries, SSD.jl first.
        for library in ("SSD.jl", "Dynamax")
            outcome = result[library]
            push!(rows, (
                latent_dim=config.latent_dim,
                # input_dim is part of the configuration; without it rows that differ
                # only in input dimension cannot be told apart.
                input_dim=config.input_dim,
                obs_dim=config.obs_dim,
                seq_len=config.seq_len,
                library=library,
                time=outcome.time,
                memory=outcome.memory,
                allocs=outcome.allocs,
                success=outcome.success,
            ))
        end
    end

    return DataFrame(rows)
end

# Flatten the benchmark results into a table and write it to disk as CSV.
results_df = prepare_results_for_csv(results)
CSV.write("benchmark_results.csv", results_df)


"benchmark_results.csv"

In [103]:
"""
    transform_to_df(data_vector::Vector)

Convert benchmark results into a tidy `DataFrame` with one row per package and
configuration.

Each element of `data_vector` is a dictionary with a `"config"` entry (providing
`latent_dim`, `input_dim`, `obs_dim` and `seq_len`) and one entry per benchmarked
package (providing `time`, `memory`, `allocs` and `success`).

Runs with `success == false` keep their row, with `time`, `memory` and `allocs`
recorded as `missing`. When every run succeeded these columns are plain `Float64`
and `Int` columns.

# Returns
A `DataFrame` with columns `package`, `time`, `memory`, `allocs`, `success`,
`latent_dim`, `input_dim`, `obs_dim` and `seq_length`.
"""
function transform_to_df(data_vector::Vector)
    packages = String[]
    times = Union{Missing,Float64}[]
    memories = Union{Missing,Int}[]
    allocs = Union{Missing,Int}[]
    successes = Bool[]
    latent_dims = Int[]
    input_dims = Int[]
    obs_dims = Int[]
    seq_lens = Int[]

    for dict in data_vector
        config = dict["config"]

        for (pkg_name, outcome) in dict
            pkg_name == "config" && continue

            # Failed runs hold placeholder values instead of numbers; store them as
            # missing rather than failing on the conversion.
            ok = outcome.success
            push!(packages, pkg_name)
            push!(times, ok ? outcome.time : missing)
            push!(memories, ok ? outcome.memory : missing)
            push!(allocs, ok ? outcome.allocs : missing)
            push!(successes, ok)
            push!(latent_dims, config.latent_dim)
            push!(input_dims, config.input_dim)
            push!(obs_dims, config.obs_dim)
            push!(seq_lens, config.seq_len)
        end
    end

    df = DataFrame(
        package = packages,
        time = times,
        memory = memories,
        allocs = allocs,
        success = successes,
        latent_dim = latent_dims,
        input_dim = input_dims,
        obs_dim = obs_dims,
        seq_length = seq_lens
    )

    # Narrow the columns back to non-missing element types wherever no run failed.
    disallowmissing!(df; error=false)
    return df
end

"""
    plot_benchmarks(df::DataFrame)

Plot benchmark time against sequence length on log-log axes and return the plot.

`df` must provide the columns `package`, `time`, `seq_length`, `obs_dim` and
`latent_dim`. Each package gets its own colour and each dimension combination its own
line style. If `df` also has an `input_dim` column it becomes part of the combination
(labelled `obs x latent x input`), so that runs with different input dimensions are
drawn as separate lines. `time` is divided by 1e9 for the "Time (seconds)" axis, so it
is expected in nanoseconds. `df` itself is not modified.
"""
function plot_benchmarks(df::DataFrame)
    # NOTE(review): `plot`, `plot!` and `Plots.mm` need Plots.jl to be loaded; no
    # `using Plots` is visible among this file's imports -- confirm it is loaded.

    # One identifier per dimension combination, kept in a local vector instead of being
    # added to the caller's DataFrame.
    dim_combo = if "input_dim" in names(df)
        string.(df.obs_dim, "x", df.latent_dim, "x", df.input_dim)
    else
        string.(df.obs_dim, "x", df.latent_dim)
    end

    # Line styles cycle when there are more combinations than styles.
    base_styles = [:solid, :dash, :dot, :dashdot, :dashdotdot]
    dim_combos = unique(dim_combo)
    style_dict = Dict(
        combo => base_styles[mod1(i, length(base_styles))]
        for (i, combo) in enumerate(dim_combos)
    )

    p = plot(
        xlabel="Sequence Length",
        ylabel="Time (seconds)",
        title="Package Performance Across Sequence Lengths",
        legend=:outertopright,
        xscale=:log10,
        yscale=:log10
    )

    # One colour per package, one line per dimension combination.
    for (i, pkg) in enumerate(unique(df.package))
        for combo in dim_combos
            mask = (df.package .== pkg) .& (dim_combo .== combo)
            any(mask) || continue

            seq_lengths = df.seq_length[mask]
            seconds = df.time[mask] ./ 1e9  # Convert to seconds

            # Draw the points in order of increasing sequence length so the line
            # does not double back on itself.
            order = sortperm(seq_lengths)
            plot!(
                p,
                seq_lengths[order],
                seconds[order],
                label="$(pkg) ($(combo))",
                color=i,
                linestyle=style_dict[combo],
                marker=:circle,
                markersize=4
            )
        end
    end

    # Add gridlines and adjust layout
    plot!(
        p,
        grid=true,
        minorgrid=true,
        size=(900, 600),
        margin=10Plots.mm
    )

    return p
end

plot_benchmarks (generic function with 1 method)

In [104]:
# Build the tidy results table from the raw benchmark results.
df = transform_to_df(results)
# Rescale the time column by 1e9 (presumably nanoseconds -> seconds) before export.
# NOTE(review): `plot_benchmarks` also divides `time` by 1e9, so plotting this
# rescaled `df` would apply the conversion twice -- confirm which table is plotted.
df.time = df.time / 1e9;

CSV.write("benchmark_results_bernoulli.csv", df)

"benchmark_results_bernoulli.csv"