In [None]:
# Activate the benchmarking environment
using Pkg
Pkg.activate(".")

# Import the necessary packages
include("SSD_Benchmark.jl")
using .SSD_Benchmark
using StableRNGs
using StateSpaceDynamics
import HiddenMarkovModels as HMMs

In [None]:
using StableRNGs
using Printf

# Benchmark configuration
latent_dims = [2, 4]
obs_dims = [1]
seq_lengths = [100, 200]
num_trials = 5  # can increase if you want

# Implementations to benchmark
implementations = [
    SSD_GLMHMM_Implem(),
    HMM_GLMHMM_Implem(),
    DYNAMAX_GLMHMM_Implem()
]

all_results = []

for latent_dim in latent_dims
    for obs_dim in obs_dims
        for seq_len in seq_lengths
            println("\n→ Benchmarking GLMHMM with latent_dim=$latent_dim, obs_dim=$obs_dim, seq_len=$seq_len")

            # Build instance and RNG
            rng = StableRNG(1234)
            instance = HMMInstance(
                num_states=latent_dim,
                num_trials=num_trials,
                seq_length=seq_len,
                input_dim=latent_dim,   # input_dim = latent_dim (adjust as you like)
                output_dim=obs_dim
            )

            # Create the data for this benchmark instance
            gen_instance = HMMInstance(num_states=2, num_trials=5, seq_length=100, input_dim=2, output_dim=1)
            gen_params = init_params(rng, gen_instance)
            gen_model = build_model(SSD_GLMHMM_Implem(), gen_instance, gen_params)
            labels, X, Y, obs_seq, control_seq, seq_ends = build_data(rng, gen_model, gen_instance)

            # Prepare results row
            results_row = Dict{String, Any}()
            results_row["config"] = (latent_dim=latent_dim, obs_dim=obs_dim, seq_len=seq_len)

            # Generate benchmarking init params
            instance_bench = HMMInstance(num_states=2, num_trials=5, seq_length=100, input_dim=2, output_dim=1)
            params_bench = init_params(rng, instance_bench)

            # Loop over implementations and run benchmarks
            for impl in implementations
                print("  Running $(string(impl))... ")
                try
                    if impl isa DYNAMAX_GLMHMM_Implem
                        model, dparams, dprops = build_model(impl, instance_bench, params_bench)
                        result = run_benchmark(impl, model, dparams, dprops, X, Y)
                    else
                        model = build_model(impl, instance_bench, params_bench)
                        result = run_benchmark(impl, model, X, Y)
                    end
                    results_row[string(impl)] = result
                    if result.success
                        @printf("✓ time = %.3f sec\n", result.time / 1e9)
                    else
                        println("✗ failed")
                    end
                catch e
                    results_row[string(impl)] = (time=NaN, memory=0, allocs=0, success=false)
                    println("✗ exception: ", e)
                end
            end
            push!(all_results, results_row)
            println("-"^50)
        end
    end
end

# write to CSV or show sumary
using DataFrames
using CSV
df = DataFrame()
for row in all_results
    config = row["config"]
    for (name, result) in row
        if name == "config"
            continue
        end
        push!(df, (
            implementation = name,
            latent_dim = config.latent_dim,
            obs_dim = config.obs_dim,
            seq_len = config.seq_len,
            time_sec = result[:time] / 1e9,
            memory = result[:memory],
            allocs = result[:allocs],
            success = result[:success]
        ))
    end
end


CSV.write("glmhmm_benchmark_results.csv", df)


→ Benchmarking GLMHMM with latent_dim=2, obs_dim=1, seq_len=100
  Running SSD_GLMHMM_Implem()... ✓ time = 0.074 sec
  Running HMM_GLMHMM_Implem()... ✓ time = 0.094 sec
  Running DYNAMAX_GLMHMM_Implem()... ✓ time = 0.337 sec
--------------------------------------------------

→ Benchmarking GLMHMM with latent_dim=2, obs_dim=1, seq_len=200
  Running SSD_GLMHMM_Implem()... ✓ time = 0.058 sec
  Running HMM_GLMHMM_Implem()... ✓ time = 0.079 sec
  Running DYNAMAX_GLMHMM_Implem()... ✓ time = 0.343 sec
--------------------------------------------------

→ Benchmarking GLMHMM with latent_dim=4, obs_dim=1, seq_len=100
  Running SSD_GLMHMM_Implem()... ✓ time = 0.060 sec
  Running HMM_GLMHMM_Implem()... ✓ time = 0.085 sec
  Running DYNAMAX_GLMHMM_Implem()... ✓ time = 0.346 sec
--------------------------------------------------

→ Benchmarking GLMHMM with latent_dim=4, obs_dim=1, seq_len=200
  Running SSD_GLMHMM_Implem()... ✓ time = 0.134 sec
  Running HMM_GLMHMM_Implem()... ✓ time = 0.238 sec
  R

In [2]:
rng = StableRNG(50)

# Test the instance struct
I1 = HMMInstance(num_states=2, num_trials=5, seq_length=100, input_dim=2, output_dim=1)

# Test the params struct
P1 = init_params(rng, I1)

# Test building the models
M1 = build_model(SSD_GLMHMM_Implem(), I1, P1)

# Test building the data
labels, X, Y, obs_seq, control_seq, seq_ends = build_data(rng, M1, I1)

# Run a benchmark
B1 = run_benchmark(SSD_GLMHMM_Implem(), M1, X, Y)

[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 6.45 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 1.86 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.14 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 2.22 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 2.34 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 2.35 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 3.32 ms/it)[39m[K
[32mRunning EM algorithm... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 2.00 ms/it)[39m[K
[32mRunning EM algorith

(time = 2.2853215e8, memory = 73635184, allocs = 176735, success = true)

In [3]:
# Test building a HiddenMarkovModels.jl implementation
M2 = build_model(HMM_GLMHMM_Implem(), I1, P1)

B2 = run_benchmark(HMM_GLMHMM_Implem(), M2, X, Y)

(time = 1.4574335e8, memory = 135069216, allocs = 182891, success = true)

In [4]:
# Test building the dynamax model
M3, dparams, dprops = build_model(DYNAMAX_GLMHMM_Implem(), I1, P1)

B3 = run_benchmark(DYNAMAX_GLMHMM_Implem(), M3, dparams, dprops, X, Y)

(time = 6.310759e8, memory = 808, allocs = 21, success = true)