# Neural Ordinary Differential Equation

This notebook is meant to play and experiment with code in preperation for a talk I plan to give.

Original code is from this blogpost:
https://sebastiancallh.github.io/post/neural-ode-weather-forecast/


In [1]:
using DiffEqFlux

┌ Info: Precompiling DiffEqFlux [aae7a2af-3d4f-5e19-a356-7da93b79d9d0]
└ @ Base loading.jl:1278


In [2]:
function neural_ode(t, data_dim; saveat = t)
    f = FastChain(FastDense(data_dim, 64, swish),
          FastDense(64, 32, swish),
          FastDense(32, data_dim))

    node = NeuralODE(f, (minimum(t), maximum(t)), Tsit5(),
             saveat = saveat, abstol = 1e-9,
             reltol = 1e-9)
end

neural_ode (generic function with 1 method)

In [6]:
using DataFrames, CSV

┌ Info: Precompiling CSV [336ed68f-0bac-5ca0-87d4-7b16caf5d00b]
└ @ Base loading.jl:1278


In [8]:
delhi_train = CSV.read("data/DailyDelhiClimateTrain.csv")
delhi_test = CSV.read("data/DailyDelhiClimateTest.csv")
delhi = vcat(delhi_train, delhi_test)

Unnamed: 0_level_0,date,meantemp,humidity,wind_speed,meanpressure
Unnamed: 0_level_1,Date…,Float64,Float64,Float64,Float64
1,2013-01-01,10.0,84.5,0.0,1015.67
2,2013-01-02,7.4,92.0,2.98,1017.8
3,2013-01-03,7.16667,87.0,4.63333,1018.67
4,2013-01-04,8.66667,71.3333,1.23333,1017.17
5,2013-01-05,6.0,86.8333,3.7,1016.5
6,2013-01-06,7.0,82.8,1.48,1018.0
7,2013-01-07,7.0,78.6,6.3,1020.0
8,2013-01-08,8.85714,63.7143,7.14286,1018.71
9,2013-01-09,14.0,51.25,12.5,1017.0
10,2013-01-10,11.0,62.0,7.4,1015.67


In [None]:
using Statistics
using Base.Iterators: take, cycle
using Dates

In [22]:
delhi[:,:year] = Float64.(year.(delhi[:,:date]))
delhi[:,:month] = Float64.(month.(delhi[:,:date]))
df_mean = by(delhi, [:year, :month],
         :meantemp => mean,
         :humidity => mean,
         :wind_speed => mean,
         :meanpressure => mean)

rename!(df_mean, [:year, :month, :meantemp,
          :humidity, :wind_speed, :meanpressure])

df_mean[!,:date] .= df_mean[:,:year] .+ df_mean[:,:month] ./ 12;

In [33]:
t = df_mean[:, :date] |>
    t -> t .- minimum(t) |>
    t -> reshape(t, 1, :)

y = df_mean[:, ["meantemp", "humidity", "wind_speed", "meanpressure"]] |>
    y -> Matrix(y)' |>
    y -> (y .- mean(y, dims = 2)) ./ std(y, dims = 2)

T = 20
train_dates = df_mean[1:T, :date]
test_dates = df_mean[T+1:end, :date]
train_t, test_t = t[1:T], t[T:end]
train_y, test_y = y[:,1:T], y[:,T:end];

In [36]:
using OrdinaryDiffEq, Flux, Random

In [40]:
function train_one_round(node, θ, y, opt, maxiters,
                 y0 = y[:, 1]; kwargs...)
    predict(θ) = Array(node(y0, θ))
    loss(θ) = begin
    ŷ = predict(θ)
    Flux.mse(ŷ, y)
    end

    θ = θ == nothing ? node.p : θ
    res = DiffEqFlux.sciml_train(
    loss, θ, opt,
    maxiters = maxiters;
    kwargs...
    )
    return res.minimizer
end


function train(θ = nothing, maxiters = 150, lr = 1e-2)
    log_results(θs, losses) =
    (θ, loss) -> begin
        push!(θs, copy(θ))
        push!(losses, loss)
        false
    end

    θs, losses = [], []
    num_obs = 4:4:length(train_t)
    for k in num_obs
    node = neural_ode(train_t[1:k], size(y, 1))
    θ = train_one_round(
        node, θ, train_y[:, 1:k],
        ADAMW(lr), maxiters;
        cb = log_results(θs, losses)
    )
    end
    θs, losses
end

train (generic function with 4 methods)

In [41]:
Random.seed!(1)
θs, losses = train();

[32mloss: 0.000253: 100%|█████████████████████████████████████████| Time: 0:00:13[39m
[32mloss: 0.0173: 100%|█████████████████████████████████████████| Time: 0:00:13[39m
[32mloss: 0.0565: 100%|█████████████████████████████████████████| Time: 0:00:19[39m
[32mloss: 0.0518: 100%|█████████████████████████████████████████| Time: 0:00:29[39m
[32mloss: 0.0557: 100%|█████████████████████████████████████████| Time: 0:00:37[39m


In [47]:
θs

755-element Array{Any,1}:
 Float32[0.028512908, 0.1896838, 0.040241625, 0.034062505, -0.13109177, 0.16505428, 0.18624657, -0.29472885, 0.118629284, 0.040455647  …  0.36531135, -0.24339958, 0.057699196, -0.19168895, 0.40598547, -0.048518263, 0.0, 0.0, 0.0, 0.0]
 Float32[0.0385129, 0.1796838, 0.050241616, 0.044062488, -0.14109176, 0.17505427, 0.19624655, -0.30472884, 0.10862929, 0.050452232  …  0.35531136, -0.23339976, 0.06769919, -0.20168895, 0.41598547, -0.05851818, 0.009999999, -0.009999999, 0.01, -0.009999991]
 Float32[0.04799506, 0.16981769, 0.06005909, 0.04785661, -0.15107925, 0.18498752, 0.20619051, -0.31472215, 0.09881389, 0.05796078  …  0.35430884, -0.23207727, 0.07730205, -0.21129222, 0.42558947, -0.06821115, 0.019982912, -0.019983353, 0.01998151, -0.019942634]
 Float32[0.056750584, 0.16027273, 0.06958249, 0.047821477, -0.16107944, 0.1949102, 0.21613178, -0.3247328, 0.08937666, 0.06636284  …  0.35972518, -0.23699127, 0.0867938, -0.22078784, 0.43508342, -0.07794266, 0.029935744,