In [None]:
import numpy as np
from matplotlib import pyplot as plt

In [None]:
# Global figure styling: large font, and one shared stroke width for the axes
# frame and for both major and minor ticks on both axes.
plt.rcParams.update(
    {
        "font.size": 24,
        **dict.fromkeys(
            (
                "axes.linewidth",
                "xtick.major.width",
                "ytick.major.width",
                "xtick.minor.width",
                "ytick.minor.width",
            ),
            3,
        ),
    }
)

In [None]:
# Number of simulation steps; also the length of the step grid used below.
steps = 3000
# Keep only the first 260 rows of the measurement (header row skipped).
# Column 0 is matched against the step grid below, column 1 is the signal.
exp_data = np.loadtxt("perm_v2_absorbance_525nm.csv", delimiter=",", skiprows=1)[:260]
# normalise the experimental data so that its maximum is 1
exp_data[:, 1] = exp_data[:, 1] / np.max(exp_data[:, 1])

# Map every experimental x-value to the index of a simulation step.
# With right=True, digitize returns 0 for any value <= 0, so subtracting 1
# would give -1, which silently indexes the LAST simulated step. Clipping keeps
# such points on step 0 instead (the upper bound is already steps - 1).
# NOTE(review): with right=True an integer time t maps to step t - 1, while a
# non-integer time maps to floor(t) - confirm this is the intended alignment.
exp_indices = np.clip(
    np.digitize(exp_data[:, 0], np.arange(steps), right=True) - 1,
    0,
    steps - 1,
)

In [None]:
def run_skrabal(
    k1=1e-1,
    k2=1e-1,
    k3=1e-1,
    k4=1e-1,
    k5=1e-1,
    k6=1e-1,
    k7=1e-1,
    k8=1e-1,
):
    """Simulate the eight-rate kinetic scheme and score it against the data.

    The concentrations of eight species (permanganate, oxalic acid, Mn6, Mn4,
    Mn3, Mn2, CO2 and CO2_neg) are advanced for ``steps`` iterations. In each
    iteration the eight rates ``r1..r8`` are evaluated from the current
    concentrations and the resulting increments are added directly to the
    concentrations (explicit update, one unit per step).

    Reads the module-level names ``steps``, ``exp_indices`` and ``exp_data``.

    Parameters
    ----------
    k1, ..., k8 : float
        Rate constants of the eight rate terms; each defaults to 0.1.

    Returns
    -------
    rmse : float
        Root-mean-square difference between ``exp_data[:, 1]`` and the
        simulated permanganate concentration sampled at ``exp_indices``.
    a_perm_filt : numpy.ndarray
        The simulated permanganate concentration at ``exp_indices``.
    """
    # initial concentrations in mmol/L
    perm, oxalic = 1, 2.5
    mn6, mn4, mn3 = 0, 0, 0
    mn2 = 1e-3
    co2, co2_neg = 0, 0

    # one row per species (same order as the tuple stored below), one column per step
    history = np.zeros((8, steps))

    for t in range(steps):
        # record the state at the start of this step
        history[:, t] = (perm, oxalic, mn6, mn4, mn3, mn2, co2, co2_neg)

        # the eight rate terms, each a rate constant times two concentrations
        r1 = k1 * perm * mn2
        r2 = k2 * mn6 * mn2
        r3 = k3 * mn4 * mn2
        r4 = k4 * mn6 * oxalic
        r5 = k5 * mn4 * oxalic
        r6 = k6 * mn3 * oxalic
        r7 = k7 * mn3 * co2_neg
        r8 = k8 * perm * mn3

        # net change of every species, all taken from the start-of-step state
        d_perm = -r1 - r8
        d_oxalic = -r4 - r5 - r6
        d_mn6 = r1 - r2 - r4 + r8
        d_mn4 = 2 * r2 - r3 + r4 - 2 * r5 + r8
        d_mn3 = r1 + 2 * r3 + 2 * r5 - r6 - r7 - r8
        d_mn2 = -r1 - r2 - r3 + r6 + r7
        d_co2 = 2 * r4 + 2 * r5 + r6 + r7
        d_co2_neg = r6 - r7

        # advance all species together
        perm += d_perm
        oxalic += d_oxalic
        mn6 += d_mn6
        mn4 += d_mn4
        mn3 += d_mn3
        mn2 += d_mn2
        co2 += d_co2
        co2_neg += d_co2_neg

    # sample the permanganate trace at the steps matched to the measurements
    a_perm_filt = history[0][exp_indices]
    residuals = exp_data[:, 1] - a_perm_filt
    rmse = np.sqrt(np.mean(residuals**2))
    return rmse, a_perm_filt

In [None]:
from skopt.plots import plot_evaluations, plot_objective, plot_convergence
from skopt import Optimizer
from tqdm import tqdm

In [None]:
# Search space: one (low, high) bound pair per rate constant k1..k8, each
# ranging from 0.0 to 0.1.
# random_state pins the optimizer's random number generator so that a fresh
# "Restart & Run All" proposes the same points and reproduces the same fit.
# Other settings (base_estimator, acq_func, acq_optimizer,
# initial_point_generator) are left at their skopt defaults.
opt = Optimizer(
    dimensions=[(0.0, 1e-1)] * 8,
    random_state=0,
)

In [None]:
# Ask/tell loop (50 evaluations): request a candidate set of rate constants,
# score it with the simulation, and report the score back to the optimizer.
for it in tqdm(range(50)):
    next_params = opt.ask()
    # the RMSE is reported as-is (not negated): lower is better and the
    # objective is being minimised
    rmse, sim_data = run_skrabal(*next_params)
    # res holds the value returned by the most recent tell() call
    res = opt.tell(next_params, rmse)

In [None]:
# Report the best parameter set found so far and its score.
best_result = opt.get_result()
for number, rate_constant in enumerate(best_result["x"], start=1):
    print(f"k{number}:", rate_constant)
print("RMSE:", best_result["fun"])

In [None]:
from matplotlib import pyplot as plt

# Re-run the simulation with the best parameters to get the permanganate
# trace sampled at the experimental points.
best_params = opt.get_result()["x"]
a_perm_filt = run_skrabal(*best_params)[1]

# Overlay measurement and simulation on a single set of axes.
fig, ax = plt.subplots(figsize=(12, 6))
exp_x = exp_data[:, 0]
ax.plot(exp_x, exp_data[:, 1], label="Experimental data")
ax.plot(exp_x, a_perm_filt, label="Simulation for Permanganate")

ax.legend()
ax.set_xlabel("Time step")
ax.set_ylabel("Absorbance (normalised)")
ax.set_title("Comparison of experimental and simulated data")
fig.savefig("perm_v2_exp.svg", bbox_inches="tight", dpi=300, transparent=True)
plt.show()

In [None]:
# _ = plot_objective(res)

In [None]:
# _ = plot_convergence(res)