This notebook gives the simplest implementation of KAN-ODEs. It shows how to generate the data, define a KAN and pass it to a `NeuralODE`, and then train it. It also introduces the `ProgressBars.jl` package.

In [1]:
using Random, Lux, LinearAlgebra
using NNlib, ConcreteStructs, WeightInitializers, ChainRulesCore
using ComponentArrays
using BenchmarkTools
using OrdinaryDiffEq, Plots, DiffEqFlux, ForwardDiff
using Flux: Adam, mae, update!
using Flux
using Optimisers
using MAT
using Statistics # provides `mean`, used in the loss functions
using ProgressBars
using Zygote: gradient as Zgrad

# Load the KAN package from https://github.com/vpuri3/KolmogorovArnold.jl
include("src/KolmogorovArnold.jl")
using .KolmogorovArnold
#load the activation function getter (written for this project, see the corresponding script):
include("Activation_getter.jl")



activation_getter (generic function with 1 method)

# Defining the ODE and Generating Data

First we generate data from the Lotka-Volterra (predator-prey) system

$$ \begin{cases}
x' &= \alpha x - \beta x y \\
y' &= \gamma x y - \delta y \end{cases}$$

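Note that we define the function in the in-place style: `lotka!(du, u, p, t)` writes the derivative into the preallocated array `du` instead of returning a new vector. This avoids allocating memory at every evaluation of the right-hand side, which speeds up the solver.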

In [2]:
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]  # prey:     x' = αx - βxy
    du[2] = γ * u[1] * u[2] - δ * u[2]  # predator: y' = γxy - δy
    return nothing                      # the result is written into du
end

lotka! (generic function with 1 method)
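As a quick check of the two styles, the sketch below evaluates the in-place `lotka!` (repeated here, with an explicit `return nothing`, so the block runs on its own) next to an out-of-place variant. The name `lotka_oop` and the `_demo` variables are illustrative only. Both give the same derivative; the in-place call simply reuses the buffer `du_demo` instead of allocating a new vector each time.

```julia
# In-place form: writes the derivative into the preallocated `du`
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
    return nothing
end

# Out-of-place form (hypothetical name): builds and returns a new vector on every call
function lotka_oop(u, p, t)
    α, β, γ, δ = p
    return [α * u[1] - β * u[2] * u[1], γ * u[1] * u[2] - δ * u[2]]
end

p_demo = [1.5, 1.0, 1.0, 3.0]
u_demo = [1.0, 1.0]
du_demo = zeros(2)                    # allocated once, reused by every in-place call
lotka!(du_demo, u_demo, p_demo, 0.0)  # du_demo now holds [0.5, -2.0]
du_oop = lotka_oop(u_demo, p_demo, 0.0)
```

`ODEProblem` tells the two forms apart by the number of arguments of the function: four for in-place, three for out-of-place.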

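Next we simulate the ODE and generate the data by defining an object of type [ODEProblem](https://docs.sciml.ai/DiffEqDocs/stable/types/ode_types/) and calling [solve](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/#CommonSolve.solve-Tuple%7BSciMLBase.AbstractDEProblem,%20Vararg%7BAny%7D%7D) on it. Both belong to the `DifferentialEquations.jl` ecosystem; here they are loaded through `OrdinaryDiffEq.jl`. The very tight tolerances (`abstol = reltol = 1e-12`) make the generated data essentially exact, and `saveat = timestep` stores the solution every 0.1 time units.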

In [3]:
#data generation parameters
timestep=0.1
n_plot_save=1000
rng = Random.default_rng()
Random.seed!(rng, 0)
tspan = (0.0, 14.0)       # full time span (training + test)
tspan_train = (0.0, 3.5)  # training window
u0 = [1.0, 1.0]           # initial condition (x, y)
p_ = Float32[1.5, 1, 1, 3]
prob = ODEProblem(lotka!, u0, tspan, p_)

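#generate the data, then split it into train/test windows
solution = solve(prob, Tsit5(), abstol = 1e-12, reltol = 1e-12, saveat = timestep)
end_index=Int64(floor(length(solution.t)*tspan_train[2]/tspan[2])) #index of the last training point
t = solution.t #full dataset
t_train=t[1:end_index] #training cut
# X: 2×N matrix of the solution (row 1 = x, row 2 = y, one column per saved time point)
# Xn: independent copy of X that serves as the training/test target; changing Xn leaves X untouched
X = Array(solution)
Xn = deepcopy(X);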

2×141 Matrix{Float64}:
 1.0  1.06108   1.14403   1.24917   1.37764   …  1.58284   1.7727    1.99376
 1.0  0.821084  0.679053  0.566893  0.478813     0.392679  0.343959  0.30753
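The split index can be checked by hand. With `saveat = 0.1` over $(0, 14)$ the solver stores 141 time points, and $\lfloor 141 \cdot 3.5 / 14 \rfloor = \lfloor 35.25 \rfloor = 35$, so the last training point sits at $t = 3.4$ rather than exactly at $3.5$. The sketch below reproduces this arithmetic with plain Julia ranges; the names (`t_all`, `n_train`, `t_fit` and the `_demo` variables) are hypothetical and chosen so as not to overwrite the notebook's `t`, `end_index` and `t_train`.

```julia
timestep_demo = 0.1
tspan_demo = (0.0, 14.0)
tspan_train_demo = (0.0, 3.5)

# Time grid produced by `saveat = timestep` over the full span (endpoints included)
t_all = collect(tspan_demo[1]:timestep_demo:tspan_demo[2])

# Same formula as `end_index`: keep the fraction of points that lies in the training window
n_train = floor(Int, length(t_all) * tspan_train_demo[2] / tspan_demo[2])
t_fit = t_all[1:n_train]

println(length(t_all), " points, ", n_train, " for training, last training time ", t_fit[end])
```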

# Defining the KAN

Here we define the architecture of the KAN. The `KolmogorovArnold.jl` package provides ready-made Lux layers (`KDense`), which saves us from defining them ourselves. We may choose
* which basis functions to use (defined in `src/utils.jl`)
* which normalizer function to use (defined in `src/utils.jl`)
* the grid size the basis functions use
* the number of layers in our KAN
* the width of the hidden layers
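To make these options concrete, here is a simplified, self-contained sketch of what a KAN layer computes. It is *not* the `KDense` implementation: the Gaussian form of the radial basis function, the uniform grid on $[-1, 1]$, the `tanh` normalizer and the SiLU base activation are assumptions chosen to mirror the options above, and all names (`KANLayerSketch`, `init_kan_layer`, `rbf_sketch`, `base_act_sketch`) are hypothetical. Each input-output edge carries its own learnable one-dimensional function $\phi_{o,i}(x) = \sum_g C_{o,g,i}\, b_g(\tanh x)$, output $o$ is $\sum_i \phi_{o,i}(x_i)$, and an optional linear term on a base activation is added (the role we assume `use_base_act` plays).

```julia
using Random

# Gaussian radial basis function centered at z with width h (assumed form of an RBF basis)
rbf_sketch(x, z, h) = exp(-((x - z) / h)^2)

# Base activation, assumed here to be SiLU/swish: x * sigmoid(x)
base_act_sketch(x) = x / (1 + exp(-x))

# One KAN layer: a learnable 1-D function on every input-output edge
struct KANLayerSketch
    C::Array{Float64,3}    # (out, grid_size, in) basis-function coefficients
    W::Matrix{Float64}     # (out, in) weights of the base-activation term
    grid::Vector{Float64}  # RBF centers, uniformly spaced on [-1, 1]
    h::Float64             # RBF width
end

function init_kan_layer(rng, in_dims, out_dims, grid_size)
    grid = collect(range(-1.0, 1.0; length = grid_size))
    h = 2.0 / (grid_size - 1)
    C = 0.1 .* randn(rng, out_dims, grid_size, in_dims)
    W = 0.1 .* randn(rng, out_dims, in_dims)
    return KANLayerSketch(C, W, grid, h)
end

# y[o] = Σ_i [ W[o,i] * base_act(x[i]) + Σ_g C[o,g,i] * rbf(tanh(x[i]), grid[g], h) ]
function (layer::KANLayerSketch)(x::AbstractVector)
    out_dims, grid_size, in_dims = size(layer.C)
    y = layer.W * base_act_sketch.(x)  # base-activation term
    xn = tanh.(x)                      # normalizer: squash inputs into (-1, 1)
    for i in 1:in_dims, g in 1:grid_size
        b = rbf_sketch(xn[i], layer.grid[g], layer.h)
        for o in 1:out_dims
            y[o] += layer.C[o, g, i] * b
        end
    end
    return y
end

rng_demo = Random.MersenneTwister(0)
layer1 = init_kan_layer(rng_demo, 2, 10, 5)  # 2 -> layer_width = 10, grid_size = 5
layer2 = init_kan_layer(rng_demo, 10, 2, 5)  # 10 -> 2
du_pred = layer2(layer1([1.0, 1.0]))         # KAN output for the state (x, y) = (1, 1)
```

In this sketch `grid_size` is the number of basis functions per edge and `layer_width` the number of hidden nodes, so each layer holds `out × grid_size × in` coefficients plus `out × in` base-activation weights.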

In [5]:
basis_func = rbf      # rbf, rswaf
normalizer = tanh_fast # sigmoid(_fast), tanh(_fast), softsign 
grid_size=5 #Grid size for the activation functions 
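# The normalizer squashes each KDense layer's inputs into a bounded range
# (e.g. (-1, 1) for tanh, (0, 1) for sigmoid) before the basis functions are evaluated on the grid.
num_layers=2 #number of KDense layers (for reference only: the Chain below hardcodes two layers)
layer_width=10 #width of the hidden layer (number of KAN nodes)

kan1 = Lux.Chain(
    KDense( 2, layer_width, grid_size; use_base_act = true, basis_func, normalizer),
    KDense(layer_width,  2, grid_size; use_base_act = true, basis_func, normalizer),
)
pM , stM  = Lux.setup(rng, kan1) #initialize the parameters (pM) and the state (stM)

# flatten the nested parameters into a plain vector, keeping the axes so that
# ComponentArray(p, pM_axis) can rebuild the structured parameters later
pM_axis = getaxes(ComponentArray(pM))
pM_data = getdata(ComponentArray(pM))
p = (deepcopy(pM_data))./1e5 ; #scale the initial parameters by 1e-5 so training starts from a near-zero model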

# Preparing to train

## Construct a `NeuralODE` 

Because our KAN is a Lux model, we can pass it directly to `NeuralODE` from `DiffEqFlux.jl`, which uses the network as the right-hand side of the ODE, $u' = \text{KAN}(u; \theta)$. To train the model we need to
* Construct a `NeuralODE`
* Define a loss function
* Choose an optimizer
* Choose a method for computing the gradient

In addition to training the model, we want to see how well it predicts the future state of the system. To do this, we instantiate a second `NeuralODE`, `train_node_test`, built from the same `kan1` but integrated over the full time span `tspan` (training and test windows) and saved at every point of `t`. Because both share `kan1`, evaluating `train_node_test` with the current parameters shows how the model performs beyond the training window.

In [None]:
train_node      = NeuralODE(kan1, tspan_train, Tsit5(), saveat = t_train); #training window only
train_node_test = NeuralODE(kan1, tspan, Tsit5(), saveat = t); #same KAN; only the time span and save points differ

function predict(p)
    Array(train_node(u0, p, stM)[1])
end
function predict_test(p)
    Array(train_node_test(u0, p, stM)[1])
end
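Conceptually, `train_node(u0, p, stM)` integrates $u' = \text{KAN}(u; \theta)$ from `u0` and returns the state at the `saveat` points. The sketch below illustrates this with a hand-written fixed-step fourth-order Runge-Kutta (RK4) integrator, not the adaptive `Tsit5` used in the notebook, and a small hypothetical `tanh` network `mlp_rhs` standing in for the KAN; `node_forward`, `lv_rhs` and the `_demo`/`_node` variables are hypothetical names. It shows why `train_node_test` is a natural test model: started from the same `u0` with the same parameters, its first 35 saved states coincide with the training-window prediction (exactly here, up to solver tolerance with an adaptive solver), and the remaining columns are pure extrapolation. As a sanity check, plugging in the true Lotka-Volterra right-hand side reproduces the second data column printed earlier to about four decimal places.

```julia
using Random

# Hypothetical stand-in for the KAN: any function (u, θ) -> du/dt could be used here
mlp_rhs(u, θ) = θ.W2 * tanh.(θ.W1 * u)

# Fixed-step RK4, one step between consecutive save points; returns a (state × time) matrix
function node_forward(rhs, u0, θ, ts)
    U = zeros(length(u0), length(ts))
    U[:, 1] = u0
    u = copy(u0)
    for k in 1:length(ts)-1
        h = ts[k+1] - ts[k]
        k1 = rhs(u, θ)
        k2 = rhs(u .+ (h / 2) .* k1, θ)
        k3 = rhs(u .+ (h / 2) .* k2, θ)
        k4 = rhs(u .+ h .* k3, θ)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        U[:, k+1] = u
    end
    return U
end

rng_node = Random.MersenneTwister(1)
θ_node = (W1 = 0.1 .* randn(rng_node, 10, 2), W2 = 0.1 .* randn(rng_node, 2, 10))
ts_full = collect(0.0:0.1:14.0)  # plays the role of `t`
ts_fit = ts_full[1:35]           # plays the role of `t_train`
u0_demo = [1.0, 1.0]

pred_fit = node_forward(mlp_rhs, u0_demo, θ_node, ts_fit)    # analogue of predict(p)
pred_full = node_forward(mlp_rhs, u0_demo, θ_node, ts_full)  # analogue of predict_test(p)

# With the true Lotka-Volterra right-hand side as the "model", the same routine reproduces the data
lv_rhs(u, θ) = [θ[1] * u[1] - θ[2] * u[1] * u[2], θ[3] * u[1] * u[2] - θ[4] * u[2]]
pred_lv = node_forward(lv_rhs, u0_demo, [1.5, 1.0, 1.0, 3.0], ts_full)
```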

## Define a loss function

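In order to train our model, we need to specify a loss function. Probably the simplest choice is the mean squared error (MSE)
$$\mathcal{L}_1(\theta) = \operatorname{MSE}\big(u^{\text{KAN}}(t, \theta), u^{\text{obs}}(t)\big) = \frac{1}{2N}\sum_{i=1}^N \lVert u^{\text{KAN}}(t_i, \theta) - u^{\text{obs}}(t_i) \rVert^2 ,$$
where $N$ is the number of time points and the factor $2$ is the number of state variables ($x$ and $y$), because `mean` averages over every entry of the $2 \times N$ array. We implement it as a function of the flat parameter vector `p`, which `ComponentArray(p, pM_axis)` turns back into the structured parameters Lux expects: `loss_train` compares against the training window and `loss_test` against the full time span.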


In [None]:
function loss_train(p)
    mean(abs2, Xn[:, 1:end_index].- predict(ComponentArray(p,pM_axis)))
end
function loss_test(p)
    mean(abs2, Xn .- predict_test(ComponentArray(p,pM_axis)))
end
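Note that `mean(abs2, ·)` averages over every entry of the $2 \times N$ array, so it equals the sum of the squared per-time-point errors divided by $2N$ (two state variables times $N$ time points). A small numerical check with made-up $2 \times 3$ arrays:

```julia
using Statistics, LinearAlgebra

obs  = [1.0 2.0 3.0; 0.5 0.5 0.5]  # made-up 2×N observations (N = 3)
pred = [1.5 2.0 2.0; 0.5 1.5 0.5]  # made-up 2×N predictions
N = size(obs, 2)

mse_code = mean(abs2, obs .- pred)  # what loss_train / loss_test compute
mse_formula = sum(norm(pred[:, i] - obs[:, i])^2 for i in 1:N) / (size(obs, 1) * N)
# both equal 0.375
```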

As in the KAN-ODEs paper, we may also want to include a term that encourages sparsity in our model. One way to do this is to add an $\ell_1$ norm penalty on the parameters to the loss:
$$\mathcal{L}_2(\theta) = \mathcal{L}_1(\theta) + \gamma_{sp} \lVert \theta \rVert_1 ,$$
which introduces a sparsity hyperparameter $\gamma_{sp}$ that we control. We implement the penalty as a function `reg_loss` and add it to the training MSE to form the total loss function `loss`.

In [3]:
# sparsity (L1) regularization: reg * ||p||_1, where reg plays the role of γ_sp
function reg_loss(p, reg=1.0)
    return reg * sum(abs, p)
end

sparse_on = true  # set to false to train on the MSE alone

# overall loss: training MSE plus (optionally) the sparsity regularization term
function loss(p)
    loss_temp = loss_train(p)
    if sparse_on
        loss_temp += reg_loss(p, 5e-4)  # γ_sp = 5e-4
    end
    return loss_temp
end

loss (generic function with 1 method)
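One way to see why the $\ell_1$ term promotes sparsity (as intuition for plain gradient descent; Adam rescales each step) is its subgradient, $\gamma_{sp}\,\mathrm{sign}(\theta)$: every nonzero parameter is pulled toward zero by the same amount regardless of its size, whereas an $\ell_2$ penalty $\lambda\lVert\theta\rVert_2^2$ pulls with strength $2\lambda\theta$, which fades as $\theta$ shrinks. A small sketch with hypothetical helper names `l1_penalty` and `l1_subgrad`, taking $\mathrm{sign}(0) = 0$ as the subgradient at zero:

```julia
# Hypothetical helpers mirroring `reg_loss`: value and subgradient of γ_sp * ‖θ‖₁
l1_penalty(θ, γsp) = γsp * sum(abs, θ)
l1_subgrad(θ, γsp) = γsp .* sign.(θ)  # sign(0) = 0 is used as the subgradient at zero

θ_l1 = [0.5, -1.0, 0.0, 2.0]
γsp = 5e-4
penalty = l1_penalty(θ_l1, γsp)  # 5e-4 * 3.5 ≈ 1.75e-3
pull = l1_subgrad(θ_l1, γsp)     # [5e-4, -5e-4, 0.0, 5e-4]: same magnitude for 0.5 and 2.0
l2_pull = 2 * γsp .* θ_l1        # ℓ2 gradient for comparison: proportional to θ
```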

## Choosing an Optimizer
`Flux.jl` provides an assortment of optimizers. Here we choose `Adam` with a learning rate of `5e-4`.

In [None]:
optimizer = Flux.Adam(5e-4)
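For reference, Adam keeps running averages of the gradient and of its square and divides one by the square root of the other, so each parameter gets its own effective step size. The sketch below writes out the textbook update; `AdamSketch`, `adam_step!` and `f_quad` are hypothetical names, and details of Flux's own implementation (for example exactly where $\epsilon$ enters) may differ. It uses $\beta_1 = 0.9$, $\beta_2 = 0.999$, which are Flux's defaults, and the learning rate `5e-4` chosen above.

```julia
# Textbook Adam state; names are hypothetical
mutable struct AdamSketch
    η::Float64          # learning rate
    β1::Float64         # decay of the running gradient mean
    β2::Float64         # decay of the running squared-gradient mean
    ϵ::Float64          # small constant that avoids division by zero
    m::Vector{Float64}  # running mean of gradients
    v::Vector{Float64}  # running mean of squared gradients
    t::Int              # step counter for bias correction
end
AdamSketch(η, n) = AdamSketch(η, 0.9, 0.999, 1e-8, zeros(n), zeros(n), 0)

function adam_step!(opt::AdamSketch, θ, g)
    opt.t += 1
    opt.m .= opt.β1 .* opt.m .+ (1 - opt.β1) .* g
    opt.v .= opt.β2 .* opt.v .+ (1 - opt.β2) .* g .^ 2
    m_hat = opt.m ./ (1 - opt.β1^opt.t)  # bias-corrected first moment
    v_hat = opt.v ./ (1 - opt.β2^opt.t)  # bias-corrected second moment
    θ .-= opt.η .* m_hat ./ (sqrt.(v_hat) .+ opt.ϵ)
    return θ
end

# Usage on a toy quadratic f(θ) = ‖θ‖², whose gradient is 2θ
f_quad(θ) = sum(abs2, θ)
θ_adam = [1.0, -2.0]
opt_demo = AdamSketch(5e-4, length(θ_adam))
adam_step!(opt_demo, θ_adam, 2 .* θ_adam)
θ_first = copy(θ_adam)  # ≈ [0.9995, -1.9995]: each entry moved by about η
for _ in 1:99
    adam_step!(opt_demo, θ_adam, 2 .* θ_adam)
end
```

The first step illustrates the per-parameter scaling: after bias correction $\hat m = g$ and $\hat v = g^2$, so every parameter initially moves by about $\eta$ in the downhill direction, whatever the magnitude of its gradient.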

## Choosing a Gradient Algorithm

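We can also choose how to compute the gradient of the loss function. We use reverse-mode automatic differentiation from `Zygote.jl` (`Zgrad` is `Zygote.gradient`, imported under that alias in the first cell) and call it explicitly in the training loop. Reverse mode suits this problem: the loss is a single scalar while `p` has many entries, and one reverse pass yields the whole gradient, whereas forward mode (e.g. `ForwardDiff.jl`) needs work that grows with the number of parameters. The gradient through the ODE solve inside `NeuralODE` is supplied by the SciML sensitivity tooling that `DiffEqFlux.jl` builds on.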
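A cheap way to gain confidence in the gradient returned by `Zgrad` is to compare a few of its entries against central finite differences, $\partial_i f \approx \big(f(p + h e_i) - f(p - h e_i)\big)/(2h)$. This also illustrates the cost argument above: a finite-difference gradient needs two evaluations of `f` per parameter, and in our case each evaluation is a full ODE solve. `fd_grad` and `f_toy` are hypothetical names; the helper is shown here on a toy function with a known gradient.

```julia
# Central finite differences for the entries listed in `idx` (two evaluations of f per entry)
function fd_grad(f, p; h = 1e-6, idx = eachindex(p))
    g = zeros(length(idx))
    for (k, i) in enumerate(idx)
        p_plus = copy(p);  p_plus[i] += h
        p_minus = copy(p); p_minus[i] -= h
        g[k] = (f(p_plus) - f(p_minus)) / (2h)
    end
    return g
end

f_toy(x) = sum(abs2, x) + 3 * x[1]  # exact gradient: 2x + [3, 0]
g_fd = fd_grad(f_toy, [1.0, 2.0])   # ≈ [5.0, 4.0]
```

In the notebook one could, for instance, compare `fd_grad(loss, p; idx = 1:5)` with `Zgrad(loss, p)[1][1:5]`; agreement is only approximate, since the ODE solver's tolerances add noise to the finite differences.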

# Training Loop

At this stage we have everything we need to start the training loop. 

To monitor training progress we use the `ProgressBars.jl` package. Wrapping a range in `ProgressBar` gives an object that Julia can iterate over just like the range itself; while the `for` loop runs, a progress bar in the terminal shows how far along the loop is and how fast it is iterating. We can also display additional information, such as the current loss, with `set_description`.

Each iteration of the training loop performs three steps:
1. Compute the gradient of `loss` with `Zgrad`
2. Update the parameters of the model with the optimizer
3. Compute the current training and test losses and display them on the `ProgressBar`

In [2]:
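# TRAINING
N_iter = 10  # number of optimization steps

# wrap the iteration range in a ProgressBar so progress is shown while training
iters = ProgressBar(1:N_iter)
for i in iters

    # 1. gradient of the total loss w.r.t. the flat parameter vector p (reverse-mode AD via Zygote)
    grad = Zgrad(loss, p)[1]

    # 2. model update: Adam step, applied to p in place
    update!(optimizer, p, grad)

    # 3. loss metrics on the training window and on the full time span
    loss_curr = loss_train(p)
    loss_curr_test = loss_test(p)

    set_description(iters, string("Loss: ", loss_curr, "  Test Loss: ", loss_curr_test))

    #=
    if i%n_plot_save==0
        plot_save(l, l_test, p_list, i)
    end
    =#

end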

