# Triad Model Parameter Control and Sensitivity Analysis

This notebook demonstrates parameter control and sensitivity analysis for a triad model system with multiplicative noise. It covers simulation, score function estimation, response matrix computation, optimization, and visualization.

In [None]:
using Pkg
Pkg.activate("../..")
Pkg.instantiate()

using LinearAlgebra
using Random
using StatsBase
using GLMakie
using CairoMakie
using QuadGK
using Base.Threads
using ProgressBars
using KernelDensity
using HDF5
using Flux
using BSON
using MarkovChainHammer
using ClustGen

In [None]:
# --- Triad Model System Definition ---
# Base coefficients of the triad model, grouped by role.
const L11, L12, L13 = -2.0, 0.2, 0.1
const g2, g3 = 0.6, 0.4
const s2_param, s3 = 1.2, 0.8
# NOTE(review): the doubled name `II` presumably avoids a clash with `LinearAlgebra.I`,
# which this notebook also uses — confirm.
const II = 1.0
const ϵ = 0.1

# Coefficients of the reduced one-dimensional SDE, derived from the base values above.
const a = L11 + ϵ * ( (II^2 * s2_param^2) / (2 * g2^2) - (L12^2) / g2 - (L13^2) / g3 )
const b = -2 * (L12 * II) / g2 * ϵ
const c = (II^2) / g2 * ϵ
const B = -(II * s2_param) / g2 * sqrt(ϵ)
const A = -(L12 * B) / II
const s_noise = (L13 * s3) / g3 * sqrt(ϵ)
const F_tilde = (A * B) / 2

# Parameter vector layout consumed by F, sigma1, sigma2 and the analysis routines:
# [a, b, c, F_tilde, A, B, s_noise]
const params_triad = [a, b, c, F_tilde, A, B, s_noise]

In [None]:
# --- Drift and Noise Functions ---
"""
    F(x, t; p=params_triad)

Drift of the reduced triad SDE, returned as a one-element vector:
`-p[4] + p[1] x + p[2] x^2 - p[3] x^3`, evaluated at the first component of `x`.
The time argument `t` is unused.
"""
function F(x, t; p=params_triad)
    u = x[1]
    cubic = -p[4] + p[1] * u + p[2] * u^2 - p[3] * u^3
    return [cubic]
end

"""
    sigma1(x, t; p=params_triad)

Multiplicative noise amplitude `(p[5] - p[6] x) / √2`, using the first component of `x`.
"""
function sigma1(x, t; p=params_triad)
    u = x[1]
    return (p[5] - p[6] * u) / √2
end

"""
    sigma2(x, t; p=params_triad)

Additive noise amplitude `p[7] / √2` (independent of `x` and `t`).
"""
function sigma2(x, t; p=params_triad)
    return p[7] / √2
end

In [None]:
# --- Score Functions & Derivatives ---
"""
    create_true_score_and_derivative_triad(p)

Return `(score_func, score_derivative_func)` for the stationary density of the reduced
triad SDE with parameters `p = [a, b, c, F_tilde, A, B, s]`.

The density is the one integrated by `potential_triad`,

    p(x) ∝ exp(∫₀ˣ 2 f(y) / Σ(y) dy) / Σ(x),
    f(x) = -F_tilde + a x + b x² - c x³,
    Σ(x) = s² + (A - B x)²,

so the score is `d/dx log p(x) = (2 f(x) - Σ'(x)) / Σ(x)`, and the second return value is
its exact derivative. Both functions take a scalar `x`.
"""
function create_true_score_and_derivative_triad(p)
    a, b, c, F_tilde, A, B, s_val = p
    drift(u) = -F_tilde + a*u + b*u^2 - c*u^3
    drift_prime(u) = a + 2*b*u - 3*c*u^2
    den(u) = s_val^2 + (A - B*u)^2
    den_prime(u) = -2 * B * (A - B*u)
    den_second = 2 * B^2
    # Numerator of the score: 2f - Σ'. Writing it through `drift` and `den_prime`
    # (instead of folding AB - F_tilde into "+F_tilde") keeps it exact for every p,
    # not only when F_tilde == A*B/2, and retains the -2B²x part of -Σ'.
    num(u) = 2 * drift(u) - den_prime(u)
    num_prime(u) = 2 * drift_prime(u) - den_second
    score_func(x) = num(x) / den(x)
    function score_derivative_func(u)
        # Quotient rule on num/den.
        return (num_prime(u) * den(u) - num(u) * den_prime(u)) / den(u)^2
    end
    return score_func, score_derivative_func
end

"""
    create_linear_s_ds(x_t)

Gaussian (linear) score model fitted to the samples `x_t`.

Returns `(score, score_derivative)`, where `score(x) = -(x - mean(x_t)) / var(x_t)` and
`score_derivative(x) = -1 / var(x_t)` for every `x`.
"""
function create_linear_s_ds(x_t)
    center, spread = mean(x_t), var(x_t)
    gaussian_score = x -> -(x - center) / spread
    gaussian_slope = _ -> -1 / spread
    return gaussian_score, gaussian_slope
end

In [None]:
# --- Divergence Score Construction ---
"""
    construct_divergence_score(s; n_points=500, range=(-5.0, 5.0))

Tabulate the derivative of the scalar score `s` on a uniform grid of `n_points` over
`range`, then return a fast evaluator for it.

The derivative is a central finite difference with step `1e-6`. If `s` returns an array,
only its first element is used. The returned function interpolates linearly between grid
points and holds the end values constant outside `range`.
"""
function construct_divergence_score(s::Function; n_points::Int=500, range::Tuple{Real,Real}=(-5.0, 5.0))
    lo, hi = range
    grid = Base.range(lo, hi, length=n_points)
    Δ = step(grid)
    fd_step = 1e-6
    scalarize(v) = v isa AbstractArray ? v[1] : v
    s1(x) = scalarize(s(x))

    table = similar(grid)
    @threads for j in eachindex(grid)
        xj = grid[j]
        table[j] = (s1(xj + fd_step) - s1(xj - fd_step)) / (2 * fd_step)
    end

    return function (x_val::Real)
        # Constant extrapolation beyond the tabulated interval.
        x_val < lo && return table[1]
        x_val > hi && return table[end]
        pos = (x_val - lo) / Δ + 1
        left = floor(Int, pos)
        right = min(left + 1, n_points)
        left = max(1, left)
        w = pos - left
        return table[left] * (1 - w) + table[right] * w
    end
end

In [None]:
# --- Potentials and Jacobians ---
"""
    potential_triad(x, p)

Potential `V(x) = -∫₀ˣ 2 f(y) / Σ(y) dy + log Σ(x)` of the reduced triad model.

Here `f(y) = -F_tilde + a y + b y² - c y³` and `Σ(y) = (A - B y)² + s²`, with
`p = [a, b, c, F_tilde, A, B, s]`. The integral is evaluated by `quadgk` with `rtol=1e-6`.
"""
function potential_triad(x, p)
    a_p, b_p, c_p, F_tilde_p, A_p, B_p, s_val_p = p
    Σ(y) = (A_p - B_p*y)^2 + s_val_p^2
    f(y) = -F_tilde_p + a_p*y + b_p*y^2 - c_p*y^3
    drift_integral = first(quadgk(y -> 2 * f(y) / Σ(y), 0, x, rtol=1e-6))
    return log(Σ(x)) - drift_integral
end
"""
    p_unnormalized_triad(x, p)

Unnormalized stationary density `exp(-potential_triad(x, p))`.
"""
function p_unnormalized_triad(x, p)
    V = potential_triad(x, p)
    return exp(-V)
end
"""
    compute_observables_triad(p, observables; int_bounds=(-6, 6), rtol=1e-8)

Return the vector of stationary expectations `∫ obs(x) p(x) dx / ∫ p(x) dx`, one entry per
function in `observables`, where `p(x) = p_unnormalized_triad(x, p)`.

# Keywords
- `int_bounds`: integration window. The density is assumed negligible outside it.
- `rtol`: relative tolerance passed to `quadgk`.

# Throws
- `ErrorException` if the normalization integral is zero or not finite.
"""
function compute_observables_triad(p, observables; int_bounds=(-6, 6), rtol=1e-8)
    lo, hi = int_bounds
    weight(x) = p_unnormalized_triad(x, p)
    norm_const, _ = quadgk(weight, lo, hi; rtol=rtol)
    if iszero(norm_const)
        error("Normalization constant is zero. Check potential function or integration bounds.")
    end
    # A NaN/Inf normalizer would silently poison every expectation below.
    if !isfinite(norm_const)
        error("Normalization constant is not finite ($norm_const). Check potential function or integration bounds.")
    end
    return map(observables) do obs
        obs_integral, _ = quadgk(x -> obs(x) * weight(x), lo, hi; rtol=rtol)
        obs_integral / norm_const
    end
end
"""
    compute_jacobian_triad(p, param_indices, observables; ε=1e-5)

Central-difference Jacobian of `compute_observables_triad` with respect to the parameters
`p[param_indices]`.

Row `k`, column `j` holds `∂ E[observables[k]] / ∂ p[param_indices[j]]`, computed with an
absolute step `ε`. The input `p` is not modified.
"""
function compute_jacobian_triad(p, param_indices, observables; ε=1e-5)
    J = zeros(length(observables), length(param_indices))
    for (col, idx) in zip(axes(J, 2), param_indices)
        forward = copy(p)
        forward[idx] += ε
        backward = copy(p)
        backward[idx] -= ε
        J[:, col] .= (compute_observables_triad(forward, observables) .-
                      compute_observables_triad(backward, observables)) ./ (2 * ε)
    end
    return J
end

In [None]:
# --- Conjugate Observables & Response Matrix ---
"""
    create_conjugate_observables_triad(score, score_derivative, p)

Build the two conjugate observables used by `create_response_matrix` for perturbations of
`B = p[6]` and `s = p[7]`. Returns `[B_B, B_s]`, each a function of a scalar `x`.

`score` and `score_derivative` may return a scalar or an array; only the first element is
used. With `S = score(x)` and `S' = score_derivative(x)`:
- `B_s(x) = -0.5 s (S' + S²)`
- `B_B(x) = -(g'(x) + g(x) S)`, where `g(x) = 0.5 ((-A x + B x²) S - A + 2 B x)`.

NOTE(review): the 0.5 prefactors correspond to an effective diffusion Σ(x)/4 with
Σ(x) = (A - B x)² + s², while `potential_triad` uses exp(∫2f/Σ)/Σ (effective diffusion Σ/2).
Confirm which convention `evolve` implements.
"""
function create_conjugate_observables_triad(score, score_derivative, p)
    A_p, B_p, s_val_p = p[5], p[6], p[7]
    firstval(v) = v isa AbstractArray ? v[1] : v
    S(x) = firstval(score(x))
    dS(x) = firstval(score_derivative(x))

    B_B = function (x)
        sx, dsx = S(x), dS(x)
        q = -A_p*x + B_p*x^2          # half of ∂Σ/∂B
        q_prime = -A_p + 2*B_p*x      # dq/dx
        g = 0.5 * (q * sx - A_p + 2*B_p*x)
        g_prime = 0.5 * (q_prime * sx + q * dsx + 2*B_p)
        return firstval(-(g_prime + g * sx))
    end

    B_s = function (x)
        sx, dsx = S(x), dS(x)
        return firstval(-0.5 * s_val_p * (dsx + sx^2))
    end

    return [B_B, B_s]
end

"""
    create_response_matrix(time_series, dt, observables, conjugate_observables; max_lag_time=30.0)

Estimate the linear-response matrix from a single trajectory sampled every `dt`.

Entry `R[k, i]` integrates, with the trapezoidal rule over lags `0:floor(max_lag_time/dt)`,
the demeaned cross-covariance (StatsBase `crosscov`) between `observables[k]` and
`conjugate_observables[i]` evaluated along `time_series`. Rows are filled in parallel.
"""
function create_response_matrix(time_series, dt, observables, conjugate_observables; max_lag_time=30.0)
    lags = 0:floor(Int, max_lag_time / dt)
    along_series(fs) = [f.(time_series) for f in fs]
    obs_series = along_series(observables)
    conj_series = along_series(conjugate_observables)

    R = zeros(length(observables), length(conjugate_observables))
    @threads for row in eachindex(observables)
        for col in eachindex(conjugate_observables)
            ccf = crosscov(obs_series[row], conj_series[col], lags; demean=true)
            # Trapezoidal rule: full sum minus half of each endpoint.
            R[row, col] = dt * (sum(ccf) - 0.5 * (first(ccf) + last(ccf)))
        end
    end
    return R
end

In [None]:
# --- Sensitivity Analysis and Optimization Routines ---
"""
    run_sensitivity_analysis(response_matrix, params_original, param_indices,
                             initial_observables, observables;
                             n_control=100, base_Δμ=nothing, λ=0.1)

Compare linear predictions against exact changes for a ladder of requested shifts.

For `k = 1:n_control`, the requested shift is `k * base_Δμ` (default `0.01` per
observable). The parameter step solves the Tikhonov-regularized normal equations
`(RᵀR + λI) δc = Rᵀ Δμ`.

Returns `(predicted_changes, actual_changes)`, each `N_observables × n_control`:
- `predicted_changes[:, k] = R δc`.
- `actual_changes[:, k]` is obtained by re-evaluating `compute_observables_triad` at the
  shifted parameters, minus `initial_observables`.
"""
function run_sensitivity_analysis(response_matrix, params_original, param_indices, initial_observables, observables; n_control=100, base_Δμ=nothing, λ=0.1)
    R = response_matrix
    n_obs = size(R, 1)
    step_Δμ = something(base_Δμ, fill(0.01, n_obs))
    # The regularized normal matrix does not depend on k, so build it once.
    normal_matrix = R' * R + λ * I(length(param_indices))

    predicted = zeros(n_obs, n_control)
    actual = zeros(n_obs, n_control)
    for k in 1:n_control
        requested = k * step_Δμ
        δc = normal_matrix \ (R' * requested)
        predicted[:, k] = R * δc
        trial = copy(params_original)
        trial[param_indices] .+= δc
        actual[:, k] = compute_observables_triad(trial, observables) - initial_observables
    end
    return predicted, actual
end

# Rebuild the sensitivity matrix at `params`: finite differences of the quadrature
# observables (:analytical_jacobian) or a fresh simulation plus a score model (otherwise).
function _recompute_sensitivity_matrix(params, param_indices, observables, jacobian_method, score_type, sim_dt, sim_steps, kgmm_kwargs)
    if jacobian_method == :analytical_jacobian
        return compute_jacobian_triad(params, param_indices, observables)
    end
    # :response_matrix — simulate the SDE at the current parameters.
    drift_p = (x, t) -> F(x, t; p=params)
    noise1_p = (x, t) -> sigma1(x, t; p=params)
    noise2_p = (x, t) -> sigma2(x, t; p=params)
    new_ts = evolve([0.0], sim_dt, sim_steps, drift_p, noise1_p, noise2_p; timestepper=:rk4)[1, :]
    score_func, score_deriv_func = if score_type == :true
        create_true_score_and_derivative_triad(params)
    elseif score_type == :linear
        create_linear_s_ds(new_ts)
    else # :kgmm
        kgmm = calculate_score_kgmm(reshape(new_ts, 1, :); kgmm_kwargs...)
        kgmm.score_function, construct_divergence_score(kgmm.score_function)
    end
    conjugates = create_conjugate_observables_triad(score_func, score_deriv_func, params)
    return create_response_matrix(new_ts, sim_dt, observables, conjugates)
end

"""
    run_newton_optimization(initial_jacobian, params_original, param_indices, initial_obs,
                            target_obs, observables; kwargs...)

Steer the observables toward `target_obs` by adjusting `params_original[param_indices]`.
Each step is a regularized Gauss–Newton update `(JᵀJ + λI) δc = Jᵀ (target - current)`.
The observables are re-evaluated exactly with `compute_observables_triad` after every step.

# Keywords
- `max_iters=10`: maximum number of iterations.
- `tolerance=1e-4`: stop when `norm(target_obs - current) < tolerance`.
- `λ=0.1`: Tikhonov regularization added to `JᵀJ`.
- `recalculate_jacobian=false`: if `true`, rebuild `J` at the current parameters from
  iteration 2 on; otherwise `initial_jacobian` is reused (quasi-Newton).
- `jacobian_method`: `:analytical_jacobian`, or anything else for the response-matrix route.
- `score_type`: `:true`, `:linear`, or anything else for a KGMM refit. Note that the literal
  `:true` evaluates to the Bool `true`.
- `sim_dt`, `sim_steps`: simulation settings for the response-matrix route.
- `kgmm_kwargs`: keywords forwarded to `calculate_score_kgmm` on a KGMM refit.
  NOTE(review): the baseline KGMM fit in this notebook also passes `clustering_prob=0.0005`;
  add it here if the refit should match.

# Returns
A vector of observable vectors: `initial_obs` followed by one entry per parameter update.
"""
function run_newton_optimization(initial_jacobian, params_original, param_indices, initial_obs, target_obs, observables; max_iters=10, tolerance=1e-4, λ=0.1, recalculate_jacobian=false, jacobian_method=:response_matrix, score_type=:true, sim_dt=0.01, sim_steps=10_000_000, kgmm_kwargs=(σ_value=0.05, verbose=false))
    method_name = recalculate_jacobian ? "Full Newton ($jacobian_method, $score_type)" : "Quasi-Newton"
    println("\n--- Starting Optimization: $method_name ---")
    curr_params = copy(params_original)
    curr_observables = copy(initial_obs)
    observables_history = [copy(curr_observables)]
    jacobian = copy(initial_jacobian)
    regularizer = λ * I(length(param_indices))
    for iter in 1:max_iters
        println("--- Iteration $iter ---")
        residual = target_obs - curr_observables
        # Check convergence before any recalculation: rebuilding the response matrix
        # re-simulates `sim_steps` steps, which is wasted work once the target is met.
        if norm(residual) < tolerance
            println("Target reached within tolerance."); break
        end
        if recalculate_jacobian && iter > 1
            println("Recalculating Jacobian/Response Matrix...")
            jacobian = _recompute_sensitivity_matrix(curr_params, param_indices, observables,
                                                     jacobian_method, score_type, sim_dt,
                                                     sim_steps, kgmm_kwargs)
            println("Updated Matrix:"); display(jacobian)
        end
        δc = (jacobian' * jacobian + regularizer) \ (jacobian' * residual)
        curr_params[param_indices] .+= δc
        curr_observables = compute_observables_triad(curr_params, observables)
        push!(observables_history, copy(curr_observables))
        println("New observables (Mean, <x^2>): $(round.(curr_observables, digits=5))")
    end
    return observables_history
end

In [None]:
# --- Generate Baseline Time Series Data ---
dt_sim = 0.01            # integration time step
N_steps = 20_000_000     # number of integration steps
# Keep only row 1 of the `evolve` output (presumably the single state component).
time_series_triad = evolve(
    [0.0], dt_sim, N_steps, F, sigma1, sigma2;
    timestepper=:rk4,
)[1, :]
println("Data generation complete.")

In [None]:
# --- Compute Response Matrices and Jacobian ---
# Controlled observables: the first moment x and the raw second moment x^2.
observables_to_control = [x -> x, x -> x^2]
# Labels must describe the observables above: the second one is <x^2>, not the variance.
observable_labels = ["Mean", "<x^2>"]
# Indices into params_triad = [a, b, c, F_tilde, A, B, s_noise]: perturb B and s_noise.
param_indices_to_control = [6, 7]
param_labels = ["δB", "δs_noise"]
initial_observables = compute_observables_triad(params_triad, observables_to_control)
initial_variance = initial_observables[2] - initial_observables[1]^2
println("Initial Observables (Mean, <x^2>): $initial_observables")
println("Initial Variance: $initial_variance")

# True score
true_score, true_score_deriv = create_true_score_and_derivative_triad(params_triad)
conjugate_obs_true = create_conjugate_observables_triad(true_score, true_score_deriv, params_triad)
R_true = create_response_matrix(time_series_triad, dt_sim, observables_to_control, conjugate_obs_true)

# Linear (Gaussian) score
linear_score, linear_score_deriv = create_linear_s_ds(time_series_triad)
conjugate_obs_linear = create_conjugate_observables_triad(linear_score, linear_score_deriv, params_triad)
R_linear = create_response_matrix(time_series_triad, dt_sim, observables_to_control, conjugate_obs_linear)

# Clustered (KGMM) score; its derivative is obtained by finite differences on a grid.
kgmm_results = calculate_score_kgmm(reshape(time_series_triad, 1, :); σ_value=0.05, clustering_prob=0.0005, verbose=false)
score_clustered = kgmm_results.score_function
score_clustered_derivative = construct_divergence_score(score_clustered)
conjugate_obs_clustered = create_conjugate_observables_triad(score_clustered, score_clustered_derivative, params_triad)
R_clustered = create_response_matrix(time_series_triad, dt_sim, observables_to_control, conjugate_obs_clustered)

# Analytical Jacobian (finite differences of the quadrature-based observables)
J_analytical = compute_jacobian_triad(params_triad, param_indices_to_control, observables_to_control; ε=1e-4)

println("Response Matrix R (from true score):"); display(R_true)
println("Response Matrix R (from linear score):"); display(R_linear)
println("Response Matrix R (from clustered/KGMM score):"); display(R_clustered)
println("Jacobian J (from analytic PDF):"); display(J_analytical)

In [None]:
# --- Run Newton's Method Optimization ---
# Targets: lower the mean by 0.1 and shrink the variance by 10%, then express the
# variance target as a second-moment target via <x^2> = Var + mean^2.
target_mean = initial_observables[1] - 0.1
target_variance = initial_variance * 0.9
target_x_squared = target_variance + target_mean^2
target_observables = [target_mean, target_x_squared]
println("Target Observables (Mean, <x^2>): $target_observables")

common_args = (params_triad, param_indices_to_control, initial_observables, target_observables, observables_to_control)
newton_kwargs = (max_iters=8, λ=0.01)
# One quasi-Newton run per sensitivity matrix, executed in this order.
obs_hist_J_quasi, obs_hist_R_true_quasi, obs_hist_R_linear_quasi, obs_hist_R_clustered_quasi = [
    run_newton_optimization(M, common_args...; newton_kwargs...)
    for M in (J_analytical, R_true, R_linear, R_clustered)
]

In [None]:
# --- Plot Results with Makie ---
GLMakie.activate!()
get_mean(history) = [h[1] for h in history]
get_var(history) = [h[2] - h[1]^2 for h in history]

# (history, legend label, color) for each optimization run, in drawing order.
convergence_runs = [
    (obs_hist_J_quasi, "Jacobian", :purple),
    (obs_hist_R_true_quasi, "True Score", :blue),
    (obs_hist_R_linear_quasi, "Linear Score", :green),
    (obs_hist_R_clustered_quasi, "KGMM Score", :orange),
]
# (grid column, y-label, title, extractor, target line) for each panel.
convergence_panels = [
    (1, "Mean", "Convergence of Mean", get_mean, target_mean),
    (2, "Variance", "Convergence of Variance", get_var, target_variance),
]

fig = Figure(size=(1600, 700), fontsize=20)
ga = fig[1, 1] = GridLayout()
panel_axes = map(convergence_panels) do (col, ylab, ttl, extract, target)
    ax = Axis(ga[1, col], xlabel="Iteration", ylabel=ylab, title=ttl)
    for (history, lbl, clr) in convergence_runs
        lines!(ax, 0:length(history)-1, extract(history), label=lbl, linewidth=3, color=clr)
    end
    hlines!(ax, [target], color=:red, linestyle=:dash, label="Target")
    ax
end
ax1, ax2 = panel_axes
Legend(ga[1, 3], ax1, "Method")
display(fig)

# Export a vector PDF with the Cairo backend, then return to the interactive backend.
CairoMakie.activate!()
save("triad_model_full_analysis.pdf", fig)
println("✅ Figure saved to triad_model_full_analysis.pdf")
GLMakie.activate!()