In [1]:
using Pkg
Pkg.activate("/media/mat/HDD/EnKF/")

"/media/mat/HDD/EnKF/Project.toml"

In [2]:
using Revise
using EnKF
using Distributions
using DocStringExtensions
using LinearAlgebra
using ProgressMeter
using DifferentialEquations

┌ Info: Recompiling stale cache file /home/mat/.julia/compiled/v1.1/Revise/M1Qoh.ji for Revise [295af30f-e4ad-537b-8983-00126c2a3abe]
└ @ Base loading.jl:1184
┌ Info: Recompiling stale cache file /home/mat/.julia/compiled/v1.1/EnKF/oXK06.ji for EnKF [685896a8-a41b-11e9-3419-3315e75b5d74]
└ @ Base loading.jl:1184
└ @ Base.Docs docs/Docs.jl:223
└ @ Base.Docs docs/Docs.jl:223
┌ Info: Recompiling stale cache file /home/mat/.julia/compiled/v1.1/DifferentialEquations/UQdwS.ji for DifferentialEquations [0c46a032-eb83-5123-abaf-570d42b7fbaa]
└ @ Base loading.jl:1184


In [3]:
using Plots
# Global plot defaults: CMU Serif for ticks (9), titles (14), axis guides (12) and
# legends (10), with the grid turned off; then select the ColorBrewer colour library
# and the PyPlot backend.
let serif = "CMU Serif"
    default(tickfont = font(serif, 9),
            titlefont = font(serif, 14),
            guidefont = font(serif, 12),
            legendfont = font(serif, 10),
            grid = false)
end
clibrary(:colorbrewer)
pyplot()

┌ Info: Recompiling stale cache file /home/mat/.julia/compiled/v1.1/Plots/ld3vC.ji for Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]
└ @ Base loading.jl:1184
┌ Info: Recompiling stale cache file /home/mat/.julia/compiled/v1.1/PyPlot/oatAj.ji for PyPlot [d330b81b-6aea-500a-939a-2ce795aea3ee]
└ @ Base loading.jl:1184


Plots.PyPlotBackend()

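We are interested in simulating the Lorenz-96 model (implemented as `lorenz95` below) with $N = 40$ variables:
$$\frac{d x_i}{dt} = (x_{i+1} - x_{i-2})\,x_{i-1} - x_i + F, \qquad i = 1, \dots, 40,$$
with forcing $F = 8$ and periodic boundary conditions $x_0 = x_{40}$, $x_{-1} = x_{39}$, $x_{41} = x_1$.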

Define the Lorenz-96 right-hand side, the initial condition and time span, and compute the reference ("truth") trajectory

In [4]:
# Right-hand side of the Lorenz-96 model, in the in-place form `f(du, u, p, t)`
# used by DifferentialEquations.jl:
#
#     du[i] = (u[i+1] - u[i-2]) * u[i-1] - u[i] + F,    F = 8
#
# with cyclic indices (u[0] = u[N], u[-1] = u[N-1], u[N+1] = u[1]).
# The dimension N is taken from `length(u)` (N = 40 in this notebook); the model
# needs N >= 4 for the neighbouring indices to be distinct.
# `p` and `t` are unused: the system is autonomous and the forcing is fixed.
# Overwrites every entry of `du` and also returns `du`.
function lorenz95(du, u, p, t)
    N = length(u)
    F = 8.0
    for i in 1:N
        # mod1 wraps indices into 1:N, so the periodic boundary needs no special cases.
        du[i] = (u[mod1(i + 1, N)] - u[mod1(i - 2, N)]) * u[mod1(i - 1, N)] - u[i] + F
    end
    return du
end


using Random
# Fix the global RNG seed so the random initial condition, and any later draws that
# use the global RNG, are reproducible between runs of the notebook.
Random.seed!(1234)
u0 = rand(40)
tspan = (0.0, 500.0)

# Fixed RK4 step size and the matching uniform time grid tspan[1]:Δt:tspan[end].
Δt = 1e-2
T = tspan[1]:Δt:tspan[end]

prob = ODEProblem(lorenz95, u0, tspan)
# Reference ("truth") trajectory, saved at every step; it is later evaluated as sol(t).
sol = solve(prob, RK4(), adaptive = false, dt = Δt)

# Separate integrator for the same problem, stepped manually and not storing the path.
integrator = init(prob, RK4(), adaptive = false, dt = Δt, save_everystep = false)

t: 0.0
u: 40-element Array{Float64,1}:
 0.098818440600275   
 0.6141722464455859  
 0.4999790804728117  
 0.7711691428072494  
 0.16858871073581017 
 0.08898788661152524 
 0.5264096163970087  
 0.9772470764924037  
 0.17885531920050668 
 0.0979838833805815  
 0.2980952441275022  
 0.8905814697024217  
 0.15459637367347767 
 ⋮                   
 0.35625563082063705 
 0.36961260053053935 
 0.2132271105648491  
 0.5010463476622813  
 0.037103303045139535
 0.43772539650123754 
 0.4761738639676185  
 0.652499558777418   
 0.7617681218611547  
 0.33590631653589664 
 0.7763229001998402  
 0.950600566053333   

In [5]:
# Trajectory history recorded by the manual stepping loop; starts with a copy of u0.
states = [copy(u0)]

1-element Array{Array{Float64,1},1}:
 [0.0988184, 0.614172, 0.499979, 0.771169, 0.168589, 0.0889879, 0.52641, 0.977247, 0.178855, 0.0979839  …  0.213227, 0.501046, 0.0371033, 0.437725, 0.476174, 0.6525, 0.761768, 0.335906, 0.776323, 0.950601]

In [6]:
# Advance the fixed-step integrator once per grid interval of T, appending a copy of
# the state after each step so that `states[k]` corresponds to time `T[k]`.
for _ in 2:length(T)
    step!(integrator)
    push!(states, copy(integrator.u))
end

In [None]:
# Plot the first three state components of the recorded trajectory against time.
# The (state x time) matrix is built once with `reduce(hcat, ...)`, which avoids
# splatting one vector per time point into `hcat` (previously done three times).
let X = reduce(hcat, states)
    plot(T, X[1, :], linewidth = 3)
    plot!(T, X[2, :], linewidth = 3)
    plot!(T, X[3, :], linewidth = 3)
end

Define propagation function fprop

In [None]:
# Forecast step of the filter: move every ensemble member forward from time `t` by
# 50 steps of the global `integrator` (built with dt = Δt). The updated states are
# written back into `ENS.S` in place, and `ENS` is returned.
# NOTE(review): the step count 50 must stay in sync with the 50*Δt window passed to
# `enkf` in the assimilation loop; `integrator` is a non-const global.
function (::PropagationFunction)(t::Float64, ENS::EnsembleState{N, TS}) where {N, TS}
    members = ENS.S
    for k in eachindex(members)
        # Restart the shared integrator from this member's state at time t.
        set_t!(integrator, t)
        set_u!(integrator, deepcopy(members[k]))
        for _ in 1:50
            step!(integrator)
        end
        members[k] = deepcopy(integrator.u)
    end
    return ENS
end

In [None]:
# Propagation functor handed to the filter; calling it dispatches to the
# `(::PropagationFunction)(t, ENS)` method defined in this notebook.
fprop = PropagationFunction()

Define measurement function m

In [None]:
# Measurement operator applied to a state: the full state is observed, so the state
# itself is returned unchanged (no copy).
(::MeasurementFunction)(t::Float64, s::TS) where {TS} = s

In [None]:
# Matrix form of the measurement operator at time t: the identity `I`, which is
# consistent with observing the full state.
(::MeasurementFunction)(t::Float64) = I

In [None]:
# Measurement functor; its two methods above return the state itself (t, s) and the
# identity operator (t).
m = MeasurementFunction()

Define the real measurement function `z`: it always measures the true state, which is then corrupted by the noise ϵ

In [None]:
# Synthetic observation at time t: fill every ensemble slot of `ENS` with the true
# state, evaluated from the global reference solution `sol`, and return `ENS`.
# No noise is added here; presumably the filter applies ϵ afterwards (verify in EnKF).
function (::RealMeasurementFunction)(t::Float64, ENS::EnsembleState{N, TZ}) where {N, TZ}
    truth = sol(t)
    fill!(ENS, deepcopy(truth))
    return ENS
end

In [None]:
# Real-measurement functor: produces observations of the reference trajectory `sol`.
z = RealMeasurementFunction()

Define filtering function

In [None]:
# Filtering functor; no method for it is defined in this notebook, so it uses whatever
# the EnKF package provides. It is switched off by `isfiltered = false`.
g = FilteringFunction()

Define covariance inflation

In [None]:
# Covariance inflation applied to the ensemble. The commented lines are alternative
# schemes kept for experimentation; exactly one `A` should be active.
# A = MultiAdditiveInflation(40, 1.05, MvNormal(zeros(40), 2.0*I))
# A = RTPSInflation(0.8)
# A = MultiplicativeInflation(40, 1.05)
# Active: RTPS-type inflation with coefficient 0.75, combined with additive zero-mean
# Gaussian perturbations of covariance 0.1*I (exact semantics: see the EnKF package).
A = RTPSAdditiveInflation(0.75, MvNormal(zeros(40), 0.1*I));
# A = IdentityInflation()

Define the measurement noise ϵ: zero-mean Gaussian with covariance $I$

In [None]:
# Measurement noise: 40-dimensional zero-mean Gaussian with covariance 1.0*I, wrapped
# as an additive perturbation.
ϵ = AdditiveInflation(MvNormal(zeros(40), 1.0*I));

In [None]:
# Ensemble size and measurement dimension (the full 40-component state is observed).
N, NZ = 10, 40
# Filter options: inflation on; filtering and state augmentation off.
isinflated, isfiltered, isaugmented = true, false, false;

In [None]:
# Initial ensemble of N members, presumably sampled from a zero-mean Gaussian with
# covariance 2I. `estimation_state` collects one snapshot of the members per
# assimilation step, starting with this initial ensemble.
ens = initialize(N, MvNormal(zeros(40), 2.0*I))
estimation_state = [deepcopy(ens.S)]

# NOTE(review): `tmp` and `true_state` are not read by any cell shown here; they are
# kept in case other code relies on them.
tmp = copy(u0)
true_state = [copy(u0)]

In [None]:
# Assemble the filter from the pieces defined above: ensemble size N, measurement
# dimension NZ (presumably), propagation fprop, inflation A, filtering g, measurement
# m, real measurement z, noise ϵ, and the three option flags.
enkf = ENKF(N, NZ, fprop, A, g, m, z, ϵ, isinflated, isfiltered, isaugmented)

### Ensemble Kalman filter estimation

In [None]:
# Assimilation window: `fprop` advances each member by 50 integrator steps of size Δt,
# so the forecast horizon passed to `enkf` must be 50*Δt to match.
Δt = 1e-2
Δtobs = 50 * Δt
# Analysis start times: every Δtobs from the start of `tspan` up to (excluding) its
# end, so the last forecast finishes exactly at tspan[end].
Tsub = tspan[1]:Δtobs:tspan[end]-Δtobs

@showprogress for t in Tsub
    global ens
    # Only the updated ensemble (second return value) is kept. The first value is
    # presumably the new time, and the third is unused here.
    _, ens, _ = enkf(t, Δtobs, ens)
    push!(estimation_state, deepcopy(ens.S))
end

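Relative error of the EnKF ensemble mean $\hat S$ with respect to the true trajectory $S$ at the analysis times, measured as $\|S - \hat S\|_2 / \|S\|_2$, where for a matrix with entries $a_i$
$$ \|A\|_p = \left( \sum_{i=1}^n | a_i | ^p \right)^{1/p} $$
(with $p = 2$ this is the Frobenius norm computed by `norm`).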

In [None]:
# Stack trajectories as (state x time) matrices. `reduce(hcat, ...)` is used instead
# of splatting one vector per time point into `hcat`.
s = reduce(hcat, sol(T).u)                 # truth on the fine grid T
ŝ = reduce(hcat, mean.(estimation_state))  # ensemble mean of each stored snapshot
ssub = reduce(hcat, sol(Tsub).u)           # truth at the analysis times Tsub

# Relative Frobenius-norm error of the ensemble mean over all analysis times.
# ŝ has length(Tsub) + 1 columns (it starts with the initial ensemble), so its last
# column is dropped to match ssub.
# NOTE(review): assumes column k of ŝ corresponds to time Tsub[k]; confirm against
# enkf's return convention.
norm(ssub - @view(ŝ[:, 1:end-1])) / norm(ssub)

In [None]:
# Truth (line) vs. EnKF ensemble mean (markers) for the Lorenz-96 components
# x1, x2, x3, one subplot each. ŝ is trimmed by one column so it lines up with Tsub.
plt = plot(layout = (3, 1), legend = true)
for k in 1:3
    plot!(plt[k], T, s[k, :], linewidth = 2, label = "truth")
    scatter!(plt[k], Tsub, ŝ[k, 1:end-1], linewidth = 2, markersize = 3,
             label = "EnKF mean", xlabel = "t", ylabel = "x$k", linestyle = :dash)
end
plt

In [None]:
# First Lorenz-96 component x1: truth (solid line) vs. EnKF ensemble mean (markers).
# ŝ is trimmed by one column so it lines up with Tsub.
plot(T, s[1, :], linewidth = 3, label = "truth")
scatter!(Tsub, ŝ[1, 1:end-1], linewidth = 3, label = "EnKF mean",
         xlabel = "t", ylabel = "x1", linestyle = :dash)