# In [4]:
using DifferentialEquations, Plots, DelimitedFiles, Flux, DiffEqSensitivity

# In [30]:
using DifferentialEquations, Plots, DelimitedFiles, Flux

"""
    EpMod(x, p, t)

Out-of-place SEIR right-hand side. Return `[ds/dt, de/dt, di/dt, dr/dt]` for the state
`x = [s, e, i, r]` and parameters `p = [γ, R₀, σ]`. `t` is not used (autonomous system).

Compartment flows:
- s → e: γ·R₀·s·i (new exposures)
- e → i: σ·e      (end of incubation)
- i → r: γ·i      (recovery)

The four derivatives sum to zero analytically, so s + e + i + r is conserved.
"""
function EpMod(x, p, t)
    s, e, i, _ = x                 # r itself does not drive any derivative
    γ, R₀, σ = p

    exposure = γ * R₀ * s * i      # s → e
    onset = σ * e                  # e → i
    recovery = γ * i               # i → r

    # `;` (vcat) literal kept deliberately, matching the original construction.
    return [-exposure;
            exposure - onset;
            onset - recovery;
            recovery]
end

#=
N = 10^7 # Population
i_0 = 2000/N                # Initially infected
e_0 = 4.0 * i_0             # Initially exposed
s_0 = 1.0 - i_0 - e_0       # Initially susceptible
r_0 = 0.0                   # Initially recovered
x_0 = [s_0, e_0, i_0, r_0]  # initial state of the system
γ = 1/18
R₀ = 3.0
σ = 1/5.2 
p_0 = [γ,R₀,σ] # parameters
=#
N = 10^7
i_0 = 1E-7                  # 33 = 1E-7 * 330 million population = initially infected
e_0 = 4.0 * i_0             # 132 = 1E-7 *330 million = initially exposed
s_0 = 1.0 - i_0 - e_0
r_0 = 0.0
x_0 = [s_0, e_0, i_0, r_0]  # initial condition
γ = 1/5
R₀ = 3.0
σ = 1/5.2 # parameters
p_0 = [γ,R₀,σ]


# Load the observations. Assumed layout (from usage below — TODO confirm against
# covid.txt): column 1 = observation time, column 2 = case count.
dat = readdlm("covid.txt", ',');
time_i = dat[:, 1]   # observation times, on the same axis as the ODE time t
# Rescale the counts into the model's population-fraction units.
# NOTE(review): the factor 50 is not explained here — TODO document its origin.
Infected = 50 * dat[:, 2] / N;

# Solve the problem with an ODE solver over the time span covered by the data.
tspan = (0.0, maximum(time_i))
prob = ODEProblem(EpMod, x_0, tspan, p_0)
sol = solve(prob, Tsit5())
# Solution on a uniform 0.1 grid starting at t = 0: grid index k ↔ t = (k - 1) / 10.
sol2 = solve(prob, Tsit5(), saveat=0.1)
# plot the solution
plot(sol, labels = ["s" "e" "i" "r"], title = "SEIR Dynamics", lw = 2, xlabel = "t")

# Position of each observation time on the 0.1 grid (index time*10 + 1 ↔ t = time).
# `Int` throws an InexactError if a time is not integer-valued.
tshort_inds = Int.(time_i) .* 10 .+ 1
# 1-based day index of each observation.
tshort = Int.(time_i) .+ 1

# Training target: the rescaled observations. (A synthetic noisy target built from
# `sol2` used to be computed here and then immediately overwritten; it was dead code.)
A = Infected;
# Plot the data at its own times so it lines up with the `t` axis of `sol` and with
# the times the loss compares against (tshort_inds ↔ t = time_i). Plotting against
# `tshort` would shift every point one day to the right.
scatter!(time_i, A)



# Initial guess for the fitted parameters, in the order EpMod unpacks them: [γ, R₀, σ].
p = [0.17, 5.7, 0.17]
# Implicit-parameter collection handed to Flux.train!, which updates `p` through it.
params = Flux.params(p)

"""
    predict_rd()

Solve the global `prob` with the current trainable parameters `p` (a one-layer
"neural network") and return the infected compartment (state 3), sampled every
0.1 time units.
"""
function predict_rd()
    # `p = p` overrides the parameters stored in `prob` with the ones being trained.
    trajectory = solve(prob, Tsit5(); p=p, saveat=0.1, maxiters=1e7)
    return trajectory[3, :]
end

"""
    loss_rd()

Sum of squared errors (not a mean, despite the old "MSE" label) between the predicted
infected fraction at the observation grid points `tshort_inds` and the data `A`.
Its scale therefore grows with the number of observations.
"""
loss_rd() = sum(abs2, predict_rd()[tshort_inds] .- A)


# Bookkeeping for the training animation: `iter` counts callback calls and numbers the
# frames saved as "<name_vorlage><iter>.png" ("Vorlage" = template).
iter = 0
name_vorlage = "video/Flux/flux_realdata_im"
# savefig errors if the target directory is missing, so create it once up front.
mkpath(dirname(name_vorlage))

# Perform training: Flux.train! takes one optimizer step per element of `data`, so
# 200 repetitions of the empty tuple give 200 steps on the argument-free `loss_rd`.
data = Iterators.repeated((), 200)
opt = ADAM()

# Callback for Flux.train!: print the current loss and save one frame comparing the
# data with the infected compartment of the model under the current parameters `p`.
cb = function ()
    display(loss_rd())
    global iter
    iter += 1
    im_name = name_vorlage * string(iter) * ".png"
    # Y ticks are in model units; the labels show matching case counts
    # (value * N / 50 with N = 10^7, i.e. undoing the data scaling). Positions and
    # labels must have the same length — 0:0.05:0.25 has six entries.
    # Data is drawn at `time_i`, the times the loss actually fits (not `tshort`,
    # which is shifted by one day relative to the ODE time axis).
    pl = scatter(time_i, A,
                 label="Data",
                 yticks=([0:0.05:0.25;], ["0", "10000", "20000", "30000", "40000", "50000"]),
                 ylabel="Infected",
                 xlabel="t",
                 legend=:outertopright,
                 size=(1500, 1000),
                 dpi=500,
                 xtickfont=font(20),
                 ytickfont=font(20),
                 yguidefontsize=18,
                 xguidefontsize=18,
                 legendfontsize=20,
                 titlefontsize=20,
                 markersize=10,
                 left_margin=10Plots.mm,
                 bottom_margin=10Plots.mm)
    # `remake` re-creates `prob` with the current parameters `p`; state 3 is i.
    plot!(pl, solve(remake(prob, p=p), Tsit5(), saveat=0.1, maxiters=1e7);
          vars=3,
          ylim=(0, 0.25),
          labels="SEIR",
          color="darkgreen",
          linewidth=5)
    # Save this frame explicitly rather than whatever plot happens to be current.
    savefig(pl, im_name)
end

# Out[30]: #30 (generic function with 1 method)

# In [31]:
# Save a frame with the initial parameter guess, then fit `p` to the data with ADAM.
# Flux.train! calls `cb` after each of its steps, printing the loss and saving a frame,
# so together with this first call one loss value is printed per frame.
cb()
Flux.train!(loss_rd, params, data, opt, cb = cb)

0.29616427937098294

0.286358505493441

0.2762839504551078

0.2659784519356142

0.25548952339852776

0.24484232125949365

0.234076134495983

0.22324008085206246

0.21238914617401

0.20157928507318987

0.19085123311783128

0.1802565398548582

0.16984235516446305

0.159644477159539

0.14969643133888044

0.1400213509304226

0.13063676458196674

0.12155976318852327

0.11281332370629285

0.10440379989803024

0.09631935559470557

0.08856865868706609

0.08117082530605073

0.07415212491194245

0.06753273677648627

0.06134261389235338

0.055608113077070115

0.0503647311476963

0.04564358778191986

0.04146595850992288

0.03784548644958142

0.03478584518880315

0.032279571634109466

0.030312272826659818

0.028847256123334775

0.02784084693912576

0.027238168340902046

0.026971059386446224

0.02696665890830455

0.027147996310833703

0.02744013836405852

0.02777587067205464

0.028097923056965655

0.02835985909744949

0.028527780470751996

0.02858048064849474

0.028508563138202245

0.02831290451461646

0.028002730615519725

0.027593523900377094

0.027104922900802023

0.02655874689321034

0.02597724560247134

0.025381637899384007

0.024790934050337145

0.024221151709734507

0.023684832796248917

0.023190738497167615

0.022743890505331273

0.022346193128080293

0.02199672023790029

0.021692242860324742

0.02142782086239302

0.021197449610916762

0.02099456071866133

0.020812406720944712

0.02064455393435921

0.02048520209210873

0.02032940424175123

0.020173195788649786

0.02001364391456617

0.019848828363064307

0.01967776786118533

0.019500305888251374

0.019316969789737386

0.0191288140082211

0.018937260374846738

0.018743942980322906

0.01855056653744186

0.01835878316667547

0.018170092231439568

0.01798576550128409

0.017806797437896645

0.017633882040899513

0.017467409550514772

0.017307477920223576

0.017153932298084307

0.01700643184031259

0.016864484026504038

0.016727498066583222

0.016594836972311355

0.016465864831089496

0.016339986671798392

0.01621667920110702

0.016095511714193284

0.015976157338334157

0.01585839505838081

0.015742104181395773

0.015627252375850435

0.01551387976530109

0.015402080276485201

0.015291982584703823

0.01518373151562776

0.015077471840133258

0.014973334891290999

0.014871428389010125

0.014771829884003726

0.014674583986066882

0.01457970213765331

0.014487165272118295

0.014396928284337845

0.014308925809366454

0.014223078408488209

0.014139298917573784

0.014057498277786844

0.013977590446215639

0.013899496343921118

0.01382314653485513

0.013748483191329375

0.013675460235696633

0.013604042118515765

0.013534202841167303

0.013465923968830775

0.013399192866639226

0.013333999920953262

0.01327033692370318

0.01320819506077556

0.013147563262252943

0.013088427606203433

0.013030770237087914

0.012974569614671267

0.01291980036296525

0.012866433910347305

0.012814439052859792

0.012763782981543969

0.012714431833944415

0.012666351466539025

0.012619508429736487

0.012573870224737278

0.012529405805731164

0.012486085781187095

0.012443882595572569

0.012402770276256716

0.01236272454676207

0.01232372218760895

0.012285741354255888

0.012248760659786919

0.012212759310461528

0.012177716853092997

0.012143612890700144

0.012110426961145022

0.01207813864778523

0.012046727262653628

0.012016172393106945

0.011986453436187552

0.011957549907175374

0.011929441676208043

0.011902108823106336

0.011875531895969214

0.011849691919810732

0.011824570414626458

0.011800149486861344

0.011776411689188884

0.011753340208011083

0.011730918674708383

0.011709131167579243

0.011687962199212278

0.011667396570768861

0.011647419489983582

0.011628016385551906

0.011609173028215559

0.01159087527017373

0.011573109314413972

0.011555861499363279

0.011539118445510848

0.011522866997515487

0.011507094176421801

0.011491787267233138

0.011476933944757524

0.011462521974709481

0.011448539474772212

0.011434974858271744

0.011421816774060008

0.011409054160374023

0.011396676252182276

0.01138467252246121

0.011373032660598833

0.011361746713206958

0.011350804792059195

0.011340197512790414

0.011329915455983925

0.011319949515649085

0.011310290826129195

0.0113009307013611

0.011291860696094187

0.011283072585394856

0.01127455828969743

0.011266310007877051

0.011258320102046521

0.011250581100209316

0.011243085862173631