# 'Single Case' optimization

Here we demonstrate the optimization of the K-Profile Parameterization (KPP) in a perfect-model setting for a single 'simple flux' case.

```julia
using Pkg; Pkg.activate("..")

using PyPlot, Printf, Statistics, OceanTurb, Dao, JLD2,
        ColumnModelOptimizationProject, ColumnModelOptimizationProject.KPPOptimization
```

```julia
     model_N = 30                # Model resolution. Perfect model resolution is N=600
    model_dt = 10*minute         # Model timestep. Perfect model timestep is 1 minute.
initial_data = 1                 # Choose initial condition for forward runs
 target_data = (4, 7, 10)        # Target samples of saved data for model-data comparison
        case = "unstable_weak"   # Case to run comparison

# Initialize the 'data' and the 'model'
datadir = joinpath("..", "data", "perfect_model_experiment")
filepath = joinpath(datadir, case * ".jld2")
data = ColumnData(filepath; initial=initial_data, targets=target_data)
model = KPPOptimization.ColumnModel(data, model_dt, N=model_N)

# Pick a set of parameters to optimize
defaults = DefaultFreeParameters(BasicParameters)

println("We will optimize the following parameters:")
@show propertynames(defaults);
```

```
We will optimize the following parameters:
propertynames(defaults) = (:CRi, :CKE, :CNL, :Cτ, :Cstab, :Cunst, :Cb_U, :Cb_T, :Cd_U, :Cd_T)
```


# Loss function definition

Our optimization strategy uses a 'composite' loss function constructed from error estimates for each field.
For a field $\Phi$ and case $i$, the loss function for a given parameter vector $\boldsymbol{C}$ is defined as

$$ \mathcal{L}^i_\Phi(\boldsymbol{C}) = \int_{t_0}^{t_1}  \mathrm{d} t \sqrt{\int_{-L_z}^0 \mathrm{d} z \left ( \Phi_\mathrm{model} - \Phi_\mathrm{data} \right )^2 }$$
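On a uniform grid, both integrals can be approximated with Riemann sums. The sketch below uses a hypothetical helper, `field_loss`, which is not the `temperature_loss` or `velocity_loss` function used later in this notebook; it assumes model and data profiles are stored as `Nz × Nt` matrices sampled at the same depths and times, with spacings `Δz` and `Δt`.

```julia
# Hypothetical helper (not part of ColumnModelOptimizationProject): a Riemann-sum
# approximation of L_Φ = ∫ dt sqrt(∫ dz (Φ_model - Φ_data)²).
# Rows of each matrix are depths, columns are times.
function field_loss(Φ_model::AbstractMatrix, Φ_data::AbstractMatrix, Δz::Real, Δt::Real)
    size(Φ_model) == size(Φ_data) || throw(DimensionMismatch("model and data sizes differ"))
    loss = 0.0
    for j in axes(Φ_model, 2)
        # Inner integral over depth at time index j
        squared_error = sum(abs2, view(Φ_model, :, j) .- view(Φ_data, :, j)) * Δz
        # Outer integral over time
        loss += sqrt(squared_error) * Δt
    end
    return loss
end

# Usage: a constant offset of 0.1 over 100 grid points (Δz = 1), at 3 times (Δt = 1)
Φ_data  = zeros(100, 3)
Φ_model = fill(0.1, 100, 3)
field_loss(Φ_model, Φ_data, 1.0, 1.0)  # ≈ 3.0
```

Taking the square root inside the time integral means each time sample contributes an RMS-like profile error, rather than squared errors accumulating over the whole run.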

The total loss function for a case $i$ is then the weighted sum

$$ \mathcal{L}^i(\boldsymbol{C}) = \sum_{\Phi} \mathcal{W}_\Phi \mathcal{L}^i_\Phi(\boldsymbol{C}) \, , $$

where $\mathcal{W}_\Phi$ essentially non-dimensionalizes the error associated with the comparison of different fields such as velocity $U$ or temperature $T$.

In this notebook we only consider a single case.
With multiple cases (representing different physical scenarios, flux conditions, initial conditions, *et cetera*), the composite loss function is the sum over cases,

$$ \mathcal{L}(\boldsymbol{C}) = \sum_i \mathcal{L}^i(\boldsymbol{C}) \, . $$
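The two sums can be sketched as follows, assuming the per-field losses $\mathcal{L}^i_\Phi$ have already been computed for each case. The functions `case_loss` and `composite_loss`, and the numbers in the usage example, are hypothetical illustrations rather than this project's API.

```julia
# Hypothetical sketch: weighted sum over fields Φ within a case,
# then an unweighted sum over cases i.
# Each case is a NamedTuple of per-field losses; weights is a NamedTuple of W_Φ.
case_loss(field_losses::NamedTuple, weights::NamedTuple) =
    sum(weights[Φ] * field_losses[Φ] for Φ in keys(field_losses))

composite_loss(case_losses, weights::NamedTuple) =
    sum(case_loss(L, weights) for L in case_losses)

# Usage: two hypothetical cases with fields U and T, weighting T by 10
example_weights = (U = 1.0, T = 10.0)
example_cases = [(U = 0.02, T = 0.003), (U = 0.01, T = 0.001)]
composite_loss(example_cases, example_weights)  # ≈ 0.07
```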

Below, we choose the appropriate $\mathcal{W}_\Phi$ for our problem based on the velocity- and temperature-specific errors associated with our initial (perfect) parameter choices.

Note that a `MarkovLink` stores a `(parameter, error)` pair: a parameter vector together with the error that a given `NegativeLogLikelihood` function assigns to it.
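To make the idea concrete, here is a hypothetical stand-in for such a link. The `Link` type and `link_from_loss` function are illustrations only, not Dao.jl's definitions; the point is that the loss is evaluated once, when the link is created, and stored alongside the parameters.

```julia
# Hypothetical stand-in for a Markov link (NOT Dao.jl's MarkovLink):
# a parameter vector paired with the error a loss function assigns to it.
struct Link{P, E}
    param :: P
    error :: E
end

# Evaluate the loss once and store the result together with the parameters
link_from_loss(loss, param) = Link(param, loss(param))

# Usage: a quadratic "negative log likelihood" centered on [1.0, 2.0]
quadratic_nll(C) = sum(abs2, C .- [1.0, 2.0])
link = link_from_loss(quadratic_nll, [0.0, 0.0])
link.error  # 5.0
```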

```julia
# Obtain an estimate of the relative error in the temperature and velocity fields
test_nll_temperature = NegativeLogLikelihood(model, data, temperature_loss)
   test_nll_velocity = NegativeLogLikelihood(model, data, velocity_loss)

# Calculate the error of each field for the default parameters
test_link_temperature = MarkovLink(test_nll_temperature, defaults)
   test_link_velocity = MarkovLink(test_nll_velocity, defaults)

@show error_ratio = test_link_velocity.error / test_link_temperature.error

# Build the weighted NLL: weight the temperature loss by the velocity-to-temperature
# error ratio, rounded to the nearest multiple of 10.
@show weights = (1, 1, 10*round(Int, error_ratio/10), 0)

nll = NegativeLogLikelihood(model, data, weighted_fields_loss, weights=weights);
```

```
error_ratio = test_link_velocity.error / test_link_temperature.error = 6.290572515193099
weights = (1, 1, 10 * round(Int, error_ratio / 10), 0) = (1, 1, 10, 0)
```
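The expression for the temperature weight rounds the error ratio to the nearest multiple of 10, so the measured ratio of about 6.3 becomes 10. The hypothetical function `temperature_weight` below isolates that arithmetic. Julia's `round` uses `RoundNearest` (ties go to the even integer), and a ratio below 5 would give the temperature field zero weight.

```julia
# Round the velocity-to-temperature error ratio to the nearest multiple of 10.
# round(Int, x) uses RoundNearest: ties go to the even integer.
temperature_weight(error_ratio) = 10 * round(Int, error_ratio / 10)

temperature_weight(6.290572515193099)  # 10, as in the output above
temperature_weight(4.9)                # 0: temperature would drop out of the loss
temperature_weight(5.0)                # 0 (0.5 rounds to the even integer 0)
temperature_weight(14.9)               # 10
temperature_weight(15.0)               # 20 (1.5 rounds to the even integer 2)
```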


# Specification of Markov Chain Monte Carlo parameters

Next we demonstrate how to set up a Markov Chain Monte Carlo (MCMC) sampler.
Currently, only the Metropolis-Hastings algorithm is available: a simple
algorithm in which proposals are generated by normally distributed random perturbations of the current parameters.

The algorithm consists of three steps:

1. Generate a set of proposal parameters by normally perturbing the current parameters.
2. Evaluate the Negative Log Likelihood for the proposed parameters.
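3. Decide whether to 'accept' the proposed parameters as the new current parameters using the criterion:

```julia
using Distributions: Uniform  # provides the uniform distribution U(0, 1)

# Accept if the decrease in error exceeds scale * log(u), with u drawn from U(0, 1)
accept(proposal, current, scale) = current.error - proposal.error > scale * log(rand(Uniform(0, 1)))
```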
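Written in this log form, the criterion is the standard Metropolis acceptance rule. With $u \sim \mathcal{U}(0, 1)$, a scale $s > 0$, and errors $E$ (the negative log likelihood values stored in each link), dividing by $s$ and exponentiating (both monotone operations) gives

$$ E_\mathrm{current} - E_\mathrm{proposal} > s \log u \iff u < \exp \left( \frac{E_\mathrm{current} - E_\mathrm{proposal}}{s} \right) \, , $$

so the probability of accepting a proposal is

$$ P(\mathrm{accept}) = \min \left[ 1, \exp \left( - \frac{E_\mathrm{proposal} - E_\mathrm{current}}{s} \right) \right] \, . $$

A proposal that lowers the error is always accepted, while one that raises it by $\Delta$ is accepted with probability $e^{-\Delta / s}$. Because the normal perturbations are symmetric, this samples a distribution proportional to $\exp(-E / s)$.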

The code for this algorithm is:

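```julia
accepted = 0
for i = 1:nlinks
    # Steps 1 and 2: perturb the current parameters and evaluate the proposal's error
    proposal = MarkovLink(nll, sampler.perturb(current.param))
    # Step 3: the proposal becomes the current link if it passes `accept`
    current = ifelse(accept(proposal, current, nll.scale), proposal, current)
    # Every proposal is stored, whether or not it was accepted
    push!(links, proposal)
    # `path` records which step produced the current link
    if current === proposal
        accepted += 1
        push!(path, i)
    else
        @inbounds push!(path, path[end])
    end
end
```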
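The loop above depends on `nll`, `sampler`, `links`, and `path` from the package. For experimenting without the ocean model, here is a self-contained toy version with the same structure: a random-walk Metropolis sampler on a quadratic error. All names (`ToyLink`, `make_link`, `toy_accept`, `toy_metropolis`, `quadratic_error`) are hypothetical, and unlike the original, `path` here stores indices into `links`, which also holds the initial link.

```julia
using Random

# Toy link: a parameter vector paired with its error (hypothetical, not Dao.jl's MarkovLink)
struct ToyLink
    param :: Vector{Float64}
    error :: Float64
end

# Evaluate the error once and store it with the parameters
make_link(toy_error, param) = ToyLink(param, toy_error(param))

# Log-form Metropolis criterion, as in `accept` above; rand() draws u from U(0, 1)
toy_accept(proposal, current, scale) = current.error - proposal.error > scale * log(rand())

# Random-walk Metropolis with normal perturbations of standard deviation σ
function toy_metropolis(toy_error, C0, nlinks; scale = 1.0, σ = 0.1)
    current = make_link(toy_error, C0)
    links = [current]   # initial link, followed by every proposal
    path = [1]          # path[k]: index in `links` of the current link after k - 1 steps
    accepted = 0
    for i = 1:nlinks
        proposal = make_link(toy_error, current.param .+ σ .* randn(length(current.param)))
        push!(links, proposal)
        if toy_accept(proposal, current, scale)
            current = proposal
            accepted += 1
            push!(path, length(links))
        else
            push!(path, path[end])
        end
    end
    return links, path, accepted / nlinks
end

# Usage: quadratic error with its minimum (zero) at C = [1.0, 2.0]
quadratic_error(C) = sum(abs2, C .- [1.0, 2.0])
Random.seed!(1)
links, path, acceptance = toy_metropolis(quadratic_error, [0.0, 0.0], 10_000; scale = 0.01, σ = 0.05)
best = links[argmin([link.error for link in links])]
```

Keeping every proposal, not just accepted ones, mirrors `push!(links, proposal)` in the original loop; `path` is what recovers the actual chain of current states.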

That's it.

The two parameters of the MCMC algorithm are thus the 'scale' of the Negative Log Likelihood function and the standard deviation of the random perturbations that generate proposal parameter sets.

Below we set the error scale to half the error associated with the initial parameter choices.
Note that a larger error scale leads to a higher acceptance rate, since proposals that increase the error are penalized less by the acceptance criterion.
We also fix the standard deviation of the random perturbations at 5% of the initial parameter values.
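A quick numerical check of how the scale affects acceptance, using the same test as `accept` for a proposal whose error is worse by a fixed amount. The helper names are hypothetical, and `perturb` assumes, as the text states, that each parameter's standard deviation is 5% of its default value; it is not the package's `NormalPerturbation`.

```julia
using Random

# Probability that the criterion accepts a proposal whose error exceeds the
# current error by Δ (proposals with Δ ≤ 0 are always accepted)
acceptance_probability(Δ, scale) = min(1.0, exp(-Δ / scale))

# Monte Carlo estimate with the same test as `accept`: (current - proposal) > scale * log(u)
empirical_acceptance(Δ, scale; n = 100_000) = count(_ -> -Δ > scale * log(rand()), 1:n) / n

# Hypothetical normal perturbation: standard deviation of 5% of each default value
perturb(C, defaults; rel_std = 0.05) = C .+ rel_std .* abs.(defaults) .* randn(length(C))

Random.seed!(0)
for scale in (0.25, 0.5, 1.0)
    # Both estimates grow with the scale: a larger scale accepts worse proposals more often
    println((scale, acceptance_probability(0.5, scale), empirical_acceptance(0.5, scale)))
end
```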

```julia
# Obtain the first link in the Markov chain
first_link = MarkovLink(nll, defaults)
@show nll.scale = first_link.error * 0.5

# Use a normal perturbation with standard deviation set to 5% of default values
# (named `perturbation_std` to avoid shadowing `Statistics.std`)
perturbation_std = DefaultStdFreeParameters(0.05, typeof(defaults))
sampler = MetropolisSampler(NormalPerturbation(perturbation_std))

# Build an initial chain of `ninit` links
ninit = 10^2
chain = MarkovChain(ninit, first_link, nll, sampler)
@show chain.acceptance
```

```
nll.scale = first_link.error * 0.5 = 0.008473015501915086
chain.acceptance = 0.87
```
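Because `nll.scale` is half of `first_link.error`, the initial error is exactly twice the scale. The `status(chain)` printout in the sampling loop reports 'scaled errors'. Assuming these are `error / nll.scale` (an inference from the printed numbers, not documented here), they can be reproduced from the printed errors:

```julia
# Values copied from the printed output of this notebook
scale         = 0.008473015501915086   # nll.scale
initial_error = 0.016946031003830173   # chain[1].error
optimal_error = 0.009129877759768668   # optimal(chain).error after the first batch

initial_error / scale   # ≈ 2.0, the "initial scaled error"
optimal_error / scale   # ≈ 1.0775, the "optimal scaled error"
```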

# Collecting samples

Finally, we extend the chain in four batches of `dsave` = 1000 links and save it to disk after each batch.
Fortunately, `JLD2` can save the entire Markov chain as a single object.
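The sampling loop moves the previous save aside before writing the updated chain, so an interrupted write does not destroy the only copy. The same backup-then-write pattern can be sketched with Base file functions only. `safe_write` is a hypothetical helper, and plain `write` stands in for JLD2's `@save`.

```julia
# Hypothetical helper: replace the file at `path` with new contents, keeping a
# backup of the previous version until the new write has finished.
function safe_write(path::AbstractString, contents)
    backup = path * ".old"
    isfile(path) && mv(path, backup; force = true)  # keep the previous version
    write(path, contents)                            # write the new version
    isfile(backup) && rm(backup)                     # remove the backup only after success
    return path
end

# Usage in a temporary directory
dir = mktempdir()
file = joinpath(dir, "chain.txt")
safe_write(file, "batch 1")
safe_write(file, "batch 2")
read(file, String)  # "batch 2"
```

If `write` throws, the backup file is left in place, which is the point of the pattern.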

```julia
dsave = 10^3                  # Number of links added per batch
chainname = "test_markov_chain"
chainpath = "$chainname.jld2"
@save chainpath chain

tstart = time()
for i = 1:4
    tint = @elapsed extend!(chain, dsave)

    @printf("tᵢ: %.2f seconds. Elapsed wall time: %.4f minutes.\n\n", tint, (time() - tstart)/60)
    @printf("First, optimal, and last links:\n")
    println((chain[1].error, chain[1].param))
    println((optimal(chain).error, optimal(chain).param))
    println((chain[end].error, chain[end].param))
    println(" ")

    println(status(chain))

    # Move the previous save aside, write the updated chain, then delete the backup
    oldchainpath = chainname * "_old.jld2"
    mv(chainpath, oldchainpath, force=true)
    @save chainpath chain
    rm(oldchainpath)
end
```

tᵢ: 6.13 seconds. Elapsed wall time: 0.1052 minutes.
First, optimal, and last links:
(0.016946031003830173, [0.3, 4.32, 6.33, 0.4, 2.0, 6.4, 0.599, 1.36, 0.5, 2.5])
(0.009129877759768668, [0.149149, 14.7253, 4.49359, 0.372537, 2.34237, 15.412, 0.264218, 3.80903, 0.0982459, 5.55969])
(0.013594107012545245, [0.12174, 16.9578, 6.14141, 0.342488, 4.65758, 17.3954, 0.806114, 4.78111, 0.0619675, 7.28097])
 
               length | 1100
           acceptance | 0.824545455
 initial scaled error | 2.000000000
 optimal scaled error | 1.077524024

tᵢ: 6.00 seconds. Elapsed wall time: 0.2086 minutes.
First, optimal, and last links:
(0.016946031003830173, [0.3, 4.32, 6.33, 0.4, 2.0, 6.4, 0.599, 1.36, 0.5, 2.5])
(0.009129877759768668, [0.149149, 14.7253, 4.49359, 0.372537, 2.34237, 15.412, 0.264218, 3.80903, 0.0982459, 5.55969])
(0.014214405743991833, [0.131329, 13.7006, 2.69985, 0.392087, 5.63164, 16.7019, 1.59188, 3.86147, 0.0748555, 3.45431])
 
               length | 2100
           acceptance |