In [31]:
using JLD2
using OrdinaryDiffEq, Lux, StableRNGs, ComponentArrays
using DelimitedFiles

# Load the previously trained PINN solution object into the global `PINN_sol`.
# Its `.u` field is passed below as the network parameters of the ODEProblem built on NIK_PINN!.
# NOTE(review): JLD2 emits "reconstructing datatypes" warnings on this load, which suggests
# some stored types are not defined in this session (presumably the package that produced
# PINN_sol is not loaded here) — confirm that PINN_sol.u is still usable as expected.
@load "models/PINN_sol.jld2" PINN_sol

└ @ JLD2 /Users/michal/.julia/packages/JLD2/pdSa4/src/data/reconstructing_datatypes.jl:574
└ @ JLD2 /Users/michal/.julia/packages/JLD2/pdSa4/src/data/reconstructing_datatypes.jl:458
└ @ JLD2 /Users/michal/.julia/packages/JLD2/pdSa4/src/data/reconstructing_datatypes.jl:574
└ @ JLD2 /Users/michal/.julia/packages/JLD2/pdSa4/src/data/reconstructing_datatypes.jl:570
└ @ JLD2 /Users/michal/.julia/packages/JLD2/pdSa4/src/data/reconstructing_datatypes.jl:458
└ @ JLD2 /Users/michal/.julia/packages/JLD2/pdSa4/src/data/reconstructing_datatypes.jl:574
└ @ JLD2 /Users/michal/.julia/packages/JLD2/pdSa4/src/data/reconstructing_datatypes.jl:458
└ @ JLD2 /Users/michal/.julia/packages/JLD2/pdSa4/src/data/reconstructing_datatypes.jl:560
└ @ JLD2 /Users/michal/.julia/packages/JLD2/pdSa4/src/data/reconstructing_datatypes.jl:566
└ @ JLD2 /Users/michal/.julia/packages/JLD2/pdSa4/src/data/reconstructing_datatypes.jl:566
└ @ JLD2 /Users/michal/.julia/packages/JLD2/pdSa4/src/data/reconstructing_datatypes.jl:566

1-element Vector{Symbol}:
 :PINN_sol

In [29]:
"""
    Valve(R, deltaP, open)

Return `deltaP / R` when `open > 0`, otherwise `0.0` (a one-way, diode-like valve).

`open` is used only as the opening signal; `deltaP` is the quantity divided by the
resistance `R`. A NaN `open` is treated as closed.
"""
function Valve(R, deltaP, open)
    # Closed (non-positive or NaN opening signal): no contribution.
    open > 0.0 || return 0.0
    return deltaP / R
end

"""
    ShiElastance(t, Eₘᵢₙ, Eₘₐₓ, τ, τₑₛ, τₑₚ, Eshift)

Evaluate the Shi double-cosine elastance at time `t`.

`τₑₛ` and `τₑₚ` mark the end of the rising and falling phases as fractions of the period
`τ`; `Eshift` shifts the curve in time by `Eshift * τ`. With `tᵢ` the time within the
current cycle, the activation is `(1 - cos(π tᵢ / τₑₛ)) / 2` while rising,
`(1 + cos(π (tᵢ - τₑₛ) / (τₑₚ - τₑₛ))) / 2` while falling, and zero for the rest of the
cycle. Returns `Eₘᵢₙ + (Eₘₐₓ - Eₘᵢₙ) * activation`.
"""
function ShiElastance(t, Eₘᵢₙ, Eₘₐₓ, τ, τₑₛ, τₑₚ, Eshift)
    # Phase boundaries as absolute times within one cycle.
    t_es = τₑₛ * τ
    t_ep = τₑₚ * τ
    tc = rem(t + (1 - Eshift) * τ, τ)

    # Both branch values are always computed; the Bool masks zero out the inactive one.
    in_contraction = tc <= t_es
    in_relaxation = t_es < tc <= t_ep
    rising = (1 - cos(tc / t_es * pi)) / 2
    falling = (1 + cos((tc - t_es) / (t_ep - t_es) * pi)) / 2

    # Outside both phases the activation is flat at zero.
    activation = in_contraction * rising + in_relaxation * falling

    return Eₘᵢₙ + (Eₘₐₓ - Eₘᵢₙ) * activation
end

"""
    DShiElastance(t, Eₘᵢₙ, Eₘₐₓ, τ, τₑₛ, τₑₚ, Eshift)

Return the time derivative of the Shi double-cosine elastance (`ShiElastance` with the
same arguments).

`τₑₛ` and `τₑₚ` are phase ends as fractions of the period `τ`; `Eshift` shifts the curve
by `Eshift * τ`. Inside the rising and falling phases the derivative of the cosine
activation is returned, scaled by `Eₘₐₓ - Eₘᵢₙ`. For the rest of the cycle it is zero.
`Eₘᵢₙ` enters only through that scale factor.
"""
function DShiElastance(t, Eₘᵢₙ, Eₘₐₓ, τ, τₑₛ, τₑₚ, Eshift)
    # Convert the fractional phase ends into absolute times within one cycle.
    τₑₛ = τₑₛ * τ
    τₑₚ = τₑₚ * τ
    tᵢ = rem(t + (1 - Eshift) * τ, τ)

    # Bool factors select the active phase (Bool * x is x or zero):
    #   d/dt (1 - cos(π tᵢ/τₑₛ))/2                = π/τₑₛ · sin(π tᵢ/τₑₛ)/2
    #   d/dt (1 + cos(π (tᵢ-τₑₛ)/(τₑₚ-τₑₛ)))/2    = π/(τₑₚ-τₑₛ) · sin(π (τₑₛ-tᵢ)/(τₑₚ-τₑₛ))/2
    # The flat diastolic phase contributes nothing, so no third term is needed. The
    # previous dangling `(tᵢ <= τₑₚ) * 0` line was a separate, discarded statement.
    DEₚ =
        (tᵢ <= τₑₛ) * pi / τₑₛ * sin(tᵢ / τₑₛ * pi) / 2 +
        (tᵢ > τₑₛ) * (tᵢ <= τₑₚ) * pi / (τₑₚ - τₑₛ) * sin((τₑₛ - tᵢ) / (τₑₚ - τₑₛ) * pi) / 2
    DE = (Eₘₐₓ - Eₘᵢₙ) * DEₚ

    return DE
end


# Shi elastance timing and amplitude parameters (globals).
# NIK_PINN! reads only `τ` and `Eshift` from here; it rebinds the others locally from
# the global `params` vector.
τ = 1.0                      # period of the elastance cycle (rem modulus in ShiElastance)
Eshift = 0.0                 # time shift of the elastance curve, as a fraction of τ
τₑₛ, τₑₚ = 0.3, 0.45         # end of rising / falling phase, as fractions of τ
Eₘᵢₙ, Eₘₐₓ = 0.03, 1.5       # lower / upper elastance bounds
Rmv = 0.006                  # shadowed inside NIK_PINN! by params[3]


"""
    NIK_PINN!(du, u, p, t)

In-place ODE right-hand side (the `f!(du, u, p, t)` form used with `ODEProblem`) of the
lumped circulation model plus an additive neural-network correction.

The state `u` is unpacked as `(pLV, psa, psv, Vlv, Qav, Qmv, Qs)`. `p` holds the
parameters of the Lux network `NN`. `NN` is evaluated on the full state and
`NN_output[i]` is added to `du[i]`.

Reads the globals `params` (unpacked positionally as
`τₑₛ, τₑₚ, Rmv, Zao, Rs, Csa, Csv, Eₘₐₓ, Eₘᵢₙ`), `NN`, `st`, `τ` and `Eshift`.
NOTE(review): these are non-`const` globals, so every RHS call pays dynamic-dispatch
cost. Consider making them `const` or closing over them.
"""
function NIK_PINN!(du, u, p, t)
    pLV, psa, psv, _, Qav, Qmv, Qs = u  # 4th state (Vlv) is not needed by the RHS
    τₑₛ, τₑₚ, Rmv, Zao, Rs, Csa, Csv, Eₘₐₓ, Eₘᵢₙ = params

    # Elastance and its time derivative are computed once and reused below.
    E = ShiElastance(t, Eₘᵢₙ, Eₘₐₓ, τ, τₑₛ, τₑₚ, Eshift)
    dE = DShiElastance(t, Eₘᵢₙ, Eₘₐₓ, τ, τₑₛ, τₑₚ, Eshift)

    # Lux models return (output, state); only the output is used.
    NN_output = first(NN(u, p, st))

    # Equations with NN correction. Order matters: du[5:7] read du[1:3].
    du[1] = (Qmv - Qav) * E + pLV / E * dE + NN_output[1]
    du[2] = (Qav - Qs) / Csa + NN_output[2]  # Systemic arteries
    du[3] = (Qs - Qmv) / Csv + NN_output[3]  # Venous
    du[4] = Qmv - Qav + NN_output[4]  # LV volume
    du[5] = Valve(Zao, du[1] - du[2], pLV - psa) + NN_output[5]  # AV
    du[6] = Valve(Rmv, du[3] - du[1], psv - pLV) + NN_output[6]  # MV
    du[7] = (du[2] - du[3]) / Rs + NN_output[7]  # Systemic flow

    return nothing
end

# Initial state, listed in the order NIK_PINN! unpacks `u`.
u0 = [
    6.0,    # pLV
    6.0,    # psa
    6.0,    # psv
    200.0,  # Vlv
    0.0,    # Qav
    0.0,    # Qmv
    0.0,    # Qs
]

# Model parameters, listed in the order NIK_PINN! unpacks the global `params`.
params = [
    0.3,    # τₑₛ
    0.45,   # τₑₚ
    0.006,  # Rmv
    0.033,  # Zao
    1.11,   # Rs
    1.13,   # Csa
    11.0,   # Csv
    1.5,    # Eₘₐₓ
    0.03,   # Eₘᵢₙ
]

# Correction network: 7 state inputs -> three ELU hidden layers of width 10 -> 7 outputs,
# one additive correction per ODE equation in NIK_PINN!.
NN = Lux.Chain(
    Lux.Dense(7 => 10, elu),
    Lux.Dense(10 => 10, elu),
    Lux.Dense(10 => 10, elu),
    Lux.Dense(10 => 7),
)

# Seeded initialisation; parameters are flattened into a ComponentVector and halved.
rng = StableRNG(5958)
p, st = Lux.setup(rng, NN)
p = 0.5 * ComponentVector{Float64}(p)

[0mComponentVector{Float64}(layer_1 = (weight = [-0.025089571252465248 -0.13317450881004333 … -0.12778112292289734 0.22667565941810608; 0.16364116966724396 -0.20269614458084106 … -0.05486127734184265 0.1415461152791977; … ; -0.1762702912092209 0.006427908316254616 … -0.12591703236103058 0.14767184853553772; -0.13822905719280243 -0.24787046015262604 … -0.18120597302913666 0.20931868255138397], bias = [-0.15276210010051727, 0.14352291822433472, 0.07330312579870224, -0.007145705167204142, 0.08284787833690643, 0.12516358494758606, -0.18695995211601257, 0.06358467042446136, -0.1222396194934845, 0.05672786757349968]), layer_2 = (weight = [-0.07584447413682938 -0.2616642117500305 … -0.16171568632125854 0.005802055820822716; 0.09017910063266754 -0.24110998213291168 … -0.13088202476501465 -0.011641394346952438; … ; -0.035823144018650055 0.21726873517036438 … 0.08900126814842224 0.1112726554274559; -0.05435895547270775 0.27346813678741455 … -0.06243923678994179 -0.11225976794958115], bias = [-0

In [32]:
# Extrapolate the trained PINN over tspan2 and write the sampled trajectory to disk.
tspan2 = (0.0, 20.0)
num_of_samples = 3000
# Derive the save grid from tspan2 so the two cannot drift apart.
tsteps = range(tspan2...; length = num_of_samples)

trained_NN = ODEProblem(NIK_PINN!, u0, tspan2, PINN_sol.u)
s = solve(trained_NN, Vern7(), dtmax = 1e-2, saveat = tsteps, reltol = 1e-7, abstol = 1e-4)

# An aborted solve returns a shorter trajectory; flag it instead of saving it silently.
if s.retcode != ReturnCode.Success
    @warn "PINN extrapolation solve did not succeed; saved data may be truncated" retcode =
        s.retcode npoints = length(s.t)
end

# One row per saved time point, one column per state variable.
data_to_save = permutedims(Array(s))

output_path = "data/presaved_pinn_extrapolation.txt"
writedlm(output_path, data_to_save)
# Report only after the write has actually succeeded, with the real path.
println("Pinn extrapolation saved to $output_path")

Pinn extrapolation saved to presaved_pinn_extrapolation.txt
