# Import Julia Packages for Benchmarking
---

In [7]:
using BenchmarkTools
using StateSpaceDynamics
using Random
using Statistics
using PythonCall
using DataFrames
using Printf
using LinearAlgebra
using CSV

# Import Python Packages for Benchmarking
---

In [2]:
# Import Python packages
const dynamax = pyimport("dynamax.linear_gaussian_ssm")
const jr = pyimport("jax.random")
const pykalman = pyimport("pykalman")
const np = pyimport("numpy")

Python: <module 'numpy' from 'c:\\Users\\ryansenne\\Documents\\GitHub\\ssm_julia\\benchmarking\\.CondaPkg\\env\\Lib\\site-packages\\numpy\\__init__.py'>

# Benchmarking Functions
---

In [5]:

"""
    run_single_benchmark(latent_dim::Int, obs_dim::Int, seq_len::Int, config::BenchConfig)

Time EM fitting for one model size with three implementations and return one result row.

`generate_test_data` supplies the observations and the initial `A`, `C`, `Q`, `R`, `x0`.
The same values initialise three implementations:
- StateSpaceDynamics.jl (`fit!`)
- Dynamax (`fit_em`)
- PyKalman (`em`)

Each runs `config.n_iters` EM iterations and is sampled up to `config.n_repeats` times
with BenchmarkTools. BenchmarkTools also stops at its `seconds` time budget, so slow
fits may get fewer samples.

`fit!` and PyKalman's `em` update the model they are called on in place. Each sample
therefore fits a freshly built/copied model (`setup=... evals=1`). Otherwise every
sample after the first would start from already-fitted parameters and measure a
different (cheaper) workload.

# Returns
A `NamedTuple` with fields `latent_dim, obs_dim, seq_len, status, error, julia_time,
julia_memory, julia_allocs, dynamax_time, pykalman_time`.
- Times are minimum sample times in nanoseconds.
- Memory (bytes) and allocation counts come from the Julia trial.

# Errors
Any error other than `InterruptException` is caught and logged with `@warn`. A row with
`status="failed"`, the error text, and `missing` measurements is returned so a sweep
can continue. `InterruptException` is rethrown so Ctrl-C still stops a sweep.
"""
function run_single_benchmark(latent_dim::Int, obs_dim::Int, seq_len::Int, config::BenchConfig)
    try
        println("\nTrying configuration: latent_dim=$latent_dim, obs_dim=$obs_dim, seq_len=$seq_len")

        # Generate test data (the second output is not used by this benchmark).
        lds_julia, _, y, params = generate_test_data(latent_dim, obs_dim, seq_len)
        A, C, Q, R, x0, _ = params
        n_iters = config.n_iters
        n_samples = config.n_repeats

        # Convert all arrays to numpy
        A_np = np.array(A)
        C_np = np.array(C)
        Q_np = np.array(Q)
        R_np = np.array(R)
        x0_np = np.array(x0)
        # NOTE(review): assumes `y` is (obs_dim, seq_len, 1); the Python libraries get
        # the transposed (seq_len, obs_dim) matrix. Confirm against generate_test_data.
        y_np = np.array(permutedims(dropdims(y, dims=3), (2, 1)))

        # Benchmark Julia implementation. `fit!` mutates its model, so every sample fits
        # a fresh deep copy; `evals=1` makes `setup` run before each single evaluation.
        julia_bench = @benchmark fit!(
            lds, $y, max_iter=$n_iters
        ) setup=(lds = deepcopy($lds_julia)) evals=1 samples=n_samples

        # Setup Python models
        key = jr.PRNGKey(0)
        lds_dynamax = dynamax.LinearGaussianSSM(latent_dim, obs_dim)
        test_params, param_props = lds_dynamax.initialize(
            key, dynamics_weights=A_np, dynamics_covariance=Q_np,
            emission_weights=C_np, emission_covariance=R_np,
            initial_mean=x0_np
        )

        # Benchmark Dynamax. `fit_em` receives the initial parameters and returns the
        # fitted ones (JAX arrays are immutable), so the same inputs are reused per sample.
        dynamax_bench = @benchmark $lds_dynamax.fit_em(
            $test_params, $param_props, $y_np,
            num_iters=$n_iters
        ) samples=n_samples

        # PyKalman's `em` overwrites the filter's parameters in place, so build a new
        # filter holding the initial parameters for every sample.
        # NOTE(review): PyKalman re-estimates only its default `em_vars` subset unless
        # `em_vars` is passed; confirm this matches what `fit!` / `fit_em` estimate.
        new_kalman_filter = () -> pykalman.KalmanFilter(
            n_dim_state=latent_dim, n_dim_obs=obs_dim,
            transition_matrices=A_np, observation_matrices=C_np,
            transition_covariance=Q_np, observation_covariance=R_np,
            initial_state_mean=x0_np
        )

        # Benchmark PyKalman
        pykalman_bench = @benchmark kf.em(
            $y_np, n_iter=$n_iters
        ) setup=(kf = $new_kalman_filter()) evals=1 samples=n_samples

        return (
            latent_dim=latent_dim,
            obs_dim=obs_dim,
            seq_len=seq_len,
            status="success",
            error="",
            julia_time=minimum(julia_bench).time,
            julia_memory=julia_bench.memory,
            julia_allocs=julia_bench.allocs,
            dynamax_time=minimum(dynamax_bench).time,
            pykalman_time=minimum(pykalman_bench).time
        )

    catch e
        # Let Ctrl-C abort the whole sweep instead of being recorded as a failed row.
        e isa InterruptException && rethrow()
        err_msg = sprint(showerror, e)
        @warn "Benchmark failed for config: latent_dim=$latent_dim, obs_dim=$obs_dim, seq_len=$seq_len"
        @warn "Error: " * err_msg
        return (
            latent_dim=latent_dim,
            obs_dim=obs_dim,
            seq_len=seq_len,
            status="failed",
            error=err_msg,
            julia_time=missing,
            julia_memory=missing,
            julia_allocs=missing,
            dynamax_time=missing,
            pykalman_time=missing
        )
    end
end

"""
    print_single_result(row)

Pretty-print one benchmark result row to `stdout`. `row` is a `DataFrameRow` or any
object exposing the same properties.

The configuration (latent dim, obs dim, sequence length) is always printed.
- Failed rows (`row.status == "failed"`): prints the stored error message.
- Other rows: prints
  - the three fit times, converted from nanoseconds to seconds;
  - the Julia allocation count and byte total;
  - the speedup ratios `other_time / julia_time`.

Output ends with a 50-character dashed separator line.
"""
function print_single_result(row)
    separator = "-"^50

    println("\nConfiguration:")
    for (label, value) in (("Latent dim", row.latent_dim),
                           ("Obs dim", row.obs_dim),
                           ("Seq length", row.seq_len))
        println("  ", label, ": ", value)
    end

    if row.status != "failed"
        # Stored timings are in nanoseconds; report seconds.
        t_julia, t_dynamax, t_pykalman =
            (row.julia_time, row.dynamax_time, row.pykalman_time) ./ 1e9

        println("\nTimings (seconds):")
        @printf("  %-9s %.3f\n", "Julia:", t_julia)
        @printf("  %-9s %.3f\n", "Dynamax:", t_dynamax)
        @printf("  %-9s %.3f\n", "PyKalman:", t_pykalman)

        println("\nMemory (Julia implementation):")
        println("  Allocations: ", row.julia_allocs)
        println("  Memory: ", row.julia_memory, " bytes")

        # A ratio above 1 means the Julia implementation was faster.
        println("\nSpeedup ratios:")
        @printf("  %-12s %.2fx\n", "vs Dynamax:", t_dynamax / t_julia)
        @printf("  %-12s %.2fx\n", "vs PyKalman:", t_pykalman / t_julia)
    else
        println("Status: Failed")
        println("Error: ", row.error)
    end

    println(separator)
    return nothing
end

"""
    benchmark_fitting(config::BenchConfig=default_config; output_path="benchmark_results.csv")

Run `run_single_benchmark` for every combination of `config.latent_dims`,
`config.obs_dims` and `config.seq_lengths`. Combinations with `obs_dim < latent_dim`
are skipped.

Results are reported progressively after each configuration:
- the new row is printed with `print_single_result`;
- unless `output_path` is `nothing`, the whole table so far is rewritten to
  `output_path` as CSV. Results collected so far therefore stay on disk if the sweep
  stops early.

# Keywords
- `output_path`: CSV file to rewrite after every configuration. Defaults to
  `"benchmark_results.csv"`; pass `nothing` to disable saving.

# Returns
A `DataFrame` with one row per attempted configuration. Failed configurations have
`status == "failed"`, the error text in `error`, and `missing` measurements.
"""
function benchmark_fitting(config::BenchConfig=default_config;
                           output_path::Union{AbstractString,Nothing}="benchmark_results.csv")
    # Create empty DataFrame with proper types; measurement columns allow `missing`
    # so failed configurations can still be recorded.
    df = DataFrame(
        latent_dim = Int[],
        obs_dim = Int[],
        seq_len = Int[],
        status = String[],
        error = String[],
        julia_time = Union{Float64, Missing}[],
        julia_memory = Union{Int, Missing}[],
        julia_allocs = Union{Int, Missing}[],
        dynamax_time = Union{Float64, Missing}[],
        pykalman_time = Union{Float64, Missing}[]
    )

    for latent_dim in config.latent_dims
        for obs_dim in config.obs_dims
            # Only benchmark models with at least as many observed as latent dimensions.
            obs_dim < latent_dim && continue

            for seq_len in config.seq_lengths
                result = run_single_benchmark(latent_dim, obs_dim, seq_len, config)
                push!(df, result)
                print_single_result(last(df))

                # Optionally checkpoint the full table after each result.
                isnothing(output_path) || CSV.write(output_path, df)
            end
        end
    end

    return df
end

"""
    print_benchmark_results(df::DataFrame)

Print a summary of a results table built by `benchmark_fitting`, then print every row
with `print_single_result`.

The summary reports the total, successful and failed configuration counts. When at
least one configuration succeeded, it also reports the average speedup of Julia over
Dynamax and PyKalman.

Each speedup is a per-configuration time ratio (`other_time / julia_time`), so the
ratios are averaged with the geometric mean. An arithmetic mean of ratios is biased
upward: speedups of 0.5x and 2x would average to 1.25x instead of 1x.
"""
function print_benchmark_results(df::DataFrame)
    n_total = nrow(df)
    n_success = count(==("success"), df.status)
    n_failed = n_total - n_success

    println("\nBenchmark Summary:")
    println("Total configurations: $n_total")
    println("Successful: $n_success")
    println("Failed: $n_failed")

    if n_success > 0
        successful = df[df.status .== "success", :]

        # Geometric mean of the ratios: exp(mean(log(r))).
        avg_dynamax = exp(mean(log, successful.dynamax_time ./ successful.julia_time))
        avg_pykalman = exp(mean(log, successful.pykalman_time ./ successful.julia_time))

        println("\nAverage Speedups (geometric mean):")
        @printf("  vs Dynamax:  %.2fx\n", avg_dynamax)
        @printf("  vs PyKalman: %.2fx\n", avg_pykalman)
    end

    println("\nDetailed Results:")
    for row in eachrow(df)
        print_single_result(row)
    end
end

print_benchmark_results

In [None]:
# Run the full sweep with the default configuration (`default_config`) and keep the
# resulting DataFrame, one row per attempted configuration, in `results`.
results = benchmark_fitting(default_config)