In [1]:
using DifferentialEquations, LinearAlgebra, SparseArrays,Sundials,CUDA,PreallocationTools,SciMLSensitivity,LinearSolve
using SparseDiffTools,Symbolics
using BenchmarkTools
using DelimitedFiles
using Plots,Statistics,StatsPlots

In [2]:
function GCH_2D_mask_full!(dc,c,p,t,dx,dy,Nx,Ny,Ψ )
    D, κ, Ω= p
    ψ = @view Ψ[:,:]
    @inline function ∇2c(ix,iy)
        left = ix > 1 ? c[ix-1,iy] : c[ix+1,iy]
        right = ix < Nx ? c[ix+1,iy] : c[ix-1,iy]
        bottom = iy > 1 ? c[ix,iy-1] : c[ix,iy+1]
        top = (iy < Ny ? c[ix,iy+1] : c[ix,iy-1])
        return ((right + left - 2.0*c[ix,iy])/dx^2 + (top + bottom - 2.0*c[ix,iy])/dy^2)
    end
    @inline function ∇ψ∇c(ix,iy)
        ψleft = ix > 1 ? ψ[ix-1,iy] : ψ[ix+1,iy]
        ψright = ix < Nx ? ψ[ix+1,iy] : ψ[ix-1,iy]
        ψbottom = iy > 1 ? ψ[ix,iy-1] : ψ[ix,iy+1]
        ψtop = iy < Ny ? ψ[ix,iy+1] : ψ[ix,iy-1]

        cleft = ix > 1 ? c[ix-1,iy] : c[ix+1,iy]
        cright = ix < Nx ? c[ix+1,iy] : c[ix-1,iy]
        cbottom = iy > 1 ? c[ix,iy-1] : c[ix,iy+1]
        ctop = iy < Ny ? c[ix,iy+1] : c[ix,iy-1]

        return ((ψleft-ψright)/(2*dx))*((cleft-cright)/(2*dx)) + ((ψtop-ψbottom)/(2*dx))*((ctop-cbottom)/(2*dy))
    end
    @inline function μₕ(ix,iy)
        return log(max(1e-10,c[ix,iy]/(1-c[ix,iy]))) + Ω*(1.0-2.0*c[ix,iy])
    end
    @inline function μ(ix,iy)
        return μₕ(ix,iy) - κ*(∇ψ∇c(ix,iy)/ψ[ix,iy] + ∇2c(ix,iy))
    end
    @inline function ∇ψ∇μ(ix,iy)
        ψleft = ix > 1 ? ψ[ix-1,iy] : ψ[ix+1,iy]
        ψright = ix < Nx ? ψ[ix+1,iy] : ψ[ix-1,iy]
        ψbottom = iy > 1 ? ψ[ix,iy-1] : ψ[ix,iy+1]
        ψtop = iy < Ny ? ψ[ix,iy+1] : ψ[ix,iy-1]

        μleft = ix > 1 ? μ(ix-1,iy) : μ(ix+1,iy)
        μright = ix < Nx ? μ(ix+1,iy) : μ(ix-1,iy)
        μbottom = iy > 1 ? μ(ix,iy-1) : μ(ix,iy+1)
        μtop = iy < Ny ? μ(ix,iy+1) : μ(ix,iy-1)

        return ((ψleft-ψright)/(2*dx))*((μleft-μright)/(2*dx)) + ((ψtop-ψbottom)/(2*dx))*((μtop-μbottom)/(2*dy))
    end
    @inline function ∇c∇μ(ix,iy)
        cleft = ix > 1 ? c[ix-1,iy] : c[ix+1,iy]
        cright = ix < Nx ? c[ix+1,iy] : c[ix-1,iy]
        cbottom = iy > 1 ? c[ix,iy-1] : c[ix,iy+1]
        ctop = iy < Ny ? c[ix,iy+1] : c[ix,iy-1]

        μleft = ix > 1 ? μ(ix-1,iy) : μ(ix+1,iy)
        μright = ix < Nx ? μ(ix+1,iy) : μ(ix-1,iy)
        μbottom = iy > 1 ? μ(ix,iy-1) : μ(ix,iy+1)
        μtop = iy < Ny ? μ(ix,iy+1) : μ(ix,iy-1)

        return ((cleft-cright)/(2*dx))*((μleft-μright)/(2*dx)) + ((ctop-cbottom)/(2*dx))*((μtop-μbottom)/(2*dy))
    end
    @inline function ∇2μ(ix,iy)
        left = ix > 1 ? μ(ix-1,iy) : μ(ix+1,iy)
        right = ix < Nx ? μ(ix+1,iy) : μ(ix-1,iy)
        bottom = iy > 1 ? μ(ix,iy-1) : μ(ix,iy+1)
        top = iy < Ny ? μ(ix,iy+1) : μ(ix,iy-1)
        return ((right + left - 2.0*μ(ix,iy))/dx^2 + (top + bottom - 2.0*μ(ix,iy))/dy^2)
    end
    @inline function getD(ix::Int,iy::Int)
        return D*(1.0-c[ix,iy])*c[ix,iy];
    end
    @inline function ∂D∂c(ix,iy)
        return D*(1.0-2*c[ix,iy]);
    end
    @inline function normψ(ix,iy)
        if ((ix > 1) && (ix < Nx)) && ((iy > 1) && (iy < Ny))
            return sqrt(((c[ix+1,iy]-c[ix-1,iy])/(2*dx))^2 + ((c[ix,iy+1]-c[ix,iy-1])/(2*dx))^2)
        else
            return 0.0
        end
    end
    @inbounds @views for I in CartesianIndices((Nx, Ny))
        ix, iy = Tuple(I);
        dc[ix,iy]=(getD(ix,iy)/ψ[ix,iy])*∇ψ∇μ(ix,iy)+ ∂D∂c(ix,iy)*∇c∇μ(ix,iy) + getD(ix,iy)*∇2μ(ix,iy)
    end
    return nothing
end

GCH_2D_mask_full! (generic function with 1 method)

In [3]:
function GCH_2D_mul_slow(du,u,p,t,ψ,∇ψ_x,∇ψ_y,∇x,∇y,∇2x,∇2y)
    c = @view u[:,:]
    dc = @view du[:,:]
    p=D, κ, Ω
    
    # Compute ∇c
    ∇c_x = ∇x * c # Compute (∇c)ₓ = ∇x*c
    ∇c_y = c * ∇y # Compute (∇c)_y = c*∇y

    # Compute ∇2c
    ∇2c = ∇2x * c + c * ∇2y # Compute (∇2c)

    μ = log.(max.(1e-10, c ./ (1.0 .- c))) .+ Ω.*(1.0 .- 2.0.*c) .- κ.*((∇c_x .* ∇ψ_x .+ ∇c_y .* ∇ψ_y)./ψ .+ ∇2c)
    # Compute ∇2μ
    ∇2μ = ∇2x * μ + μ * ∇2y# Compute (∇2μ)

    # Compute ∇μ
    ∇μ_x = ∇x * μ # Compute (∇μ)ₓ = ∇x*μ
    ∇μ_y = μ * ∇y # Compute (∇μ)_y = μ*∇y

    dc = D .* (c .* (1.0 .- c) .* ((∇ψ_x .* ∇μ_x .+ ∇ψ_y .* ∇μ_y) ./ ψ .+ ∇2μ) .+ (1.0 .- 2.0 .* c) .* (∇c_x .* ∇μ_x .+ ∇c_y .* ∇μ_y))
    return nothing
end

GCH_2D_mul_slow (generic function with 1 method)

In [4]:
function GCH_2D_mul_full(du, u, p, t,ψ,∇x,∇y,∇2x,∇2y,∇ψ_x,∇ψ_y,∇c_x,∇c_y,∇2c,μ,∇2μ,∇μ_x,∇μ_y)
    D, κ, Ω=p
    c = @view u[:,:]
    dc = @view du[:,:]
    
    #Set up caches from DiffCache
    ∇c_x_t = get_tmp(∇c_x,u)
    ∇c_y_t = get_tmp(∇c_y,u)
    ∇2c_t = get_tmp(∇2c,u)
    μ_t = get_tmp(μ,u)
    ∇2μ_t = get_tmp(∇2μ,u)
    ∇μ_x_t = get_tmp(∇μ_x,u)
    ∇μ_y_t = get_tmp(∇μ_y,u)
    
    #Compute ∇c
    mul!(∇c_x_t,∇x,c) # Compute (∇c)ₓ = ∇x*c
    mul!(∇c_y_t,c,∇y) # Compute (∇c)_y = c*∇y
    
    #Compute ∇2c
    mul!(∇2c_t,∇2x,c) # Compute (∇2c)ₓ = c*∇2x
    mul!(∇2c_t,c,∇2y,1.0,1.0) #∇2c = 1*(∇2c)ₓ + 1*(∇2y)*c

    @. μ_t = log(max(1e-10,c./(1.0 - c)))+ Ω*(1.0 - 2.0*c) .- κ*((∇c_x_t*∇ψ_x  + ∇c_y_t*∇ψ_y)./ψ + ∇2c_t);

    #Compute ∇2μ
    mul!(∇2μ_t,∇2x,μ_t) # Compute (∇2μ)ₓ = μ*∇2x
    mul!(∇2μ_t,μ_t,∇2y,1.0,1.0) #∇2μ = 1*(∇2μ)ₓ + 1*(∇2y)*μ
    #Compute ∇μ
    mul!(∇μ_x_t,∇x,μ_t) # Compute (∇μ)ₓ = ∇x*μ
    mul!(∇μ_y_t,μ_t,∇y) # Compute (∇μ)_y = μ*∇y
    @. dc = D*(c*(1.0-c)*((∇ψ_x*∇μ_x_t + ∇ψ_y*∇μ_y_t)./ψ + ∇2μ_t) + (1.0-2.0*c)*(∇c_x_t*∇μ_x_t + ∇c_y_t*∇μ_y_t))
    return nothing
end

GCH_2D_mul_full (generic function with 1 method)

In [5]:
ψ = readdlm("psi.csv")
ψ_binary = readdlm("psi_b.csv")
ψ = ψ[end:-1:1, :]
ψ_binary = ψ_binary[end:-1:1, :]
Nx, Ny = size(ψ)
x = LinRange(0.0, 1, Nx)
y = LinRange(0.0, 1, Ny)
dx = x[2] - x[1]
dy = y[2] - y[1]

D = 0.1
κ = 0.001
Ω = 3.0
p = D, κ, Ω

c0 = readdlm("goodc0.csv")

tspan = (0.0, 5)
dc0 = similar(c0);

In [6]:
∇2x = Tridiagonal([1.0 for i in 1:Nx-1],[-2.0 for i in 1:Nx],[1.0 for i in 1:Nx-1])
∇2x[1,2] = 2.0
∇2x[end,end-1] = 2.0
∇2y= deepcopy(∇2x)
∇2y = ∇2y'

∇x= Tridiagonal([-1.0 for i in 1:Nx-1],[0.0 for i in 1:Nx],[1.0 for i in 1:Nx-1]);
∇x[1,2]=0.0
∇x[end,end-1]=0.0

∇y= Tridiagonal([-1.0 for i in 1:Ny-1],[0.0 for i in 1:Ny],[1.0 for i in 1:Ny-1]);
∇y[1,2]=0.0
∇y[end,end-1]=0.0
∇y =∇y'


∇2x ./= dx^2;
∇2y ./= dy^2;
∇x ./= 2*dx;
∇y ./= 2*dy;


∇ψ_x = ∇x*ψ 
∇ψ_y = ψ*∇y


∇c_x=zeros(Nx,Ny);
∇c_y=zeros(Nx,Ny);
∇2c=zeros(Nx,Ny);
μ = zeros(Nx,Ny);
∇2μ=zeros(Nx,Ny);
∇μ_x=zeros(Nx,Ny);
∇μ_y=zeros(Nx,Ny);

chunk_size = 25;

∇c_x_c= DiffCache(∇c_x,chunk_size);
∇c_y_c = DiffCache(∇c_y,chunk_size);
∇2c_c = DiffCache(∇2c,chunk_size);
μ_c = DiffCache(μ,chunk_size);
∇2μ_c = DiffCache(∇2μ,chunk_size);
∇μ_x_c = DiffCache(∇μ_x,chunk_size);
∇μ_y_c = DiffCache(∇μ_y,chunk_size);

In [7]:
GCH_2D_mul_cache!(du,u,p,t) = GCH_2D_mul_full(du, u, p, t,ψ,∇x,∇y,∇2x,∇2y,∇ψ_x,∇ψ_y,∇c_x_c,∇c_y_c,∇2c_c,μ_c,∇2μ_c,∇μ_x_c,∇μ_y_c)
GCH_2D_element!(du,u,p,t) = GCH_2D_mask_full!(du,u,p,t,dx,dy,Nx,Ny,ψ)
GCH_2D_mul_slow!(du,u,p,t) = GCH_2D_mul_slow(du,u,p,t,ψ,∇ψ_x,∇ψ_y,∇x,∇y,∇2x,∇2y)

GCH_2D_mul_slow! (generic function with 1 method)

In [8]:
RHS_mul_cache= @benchmark GCH_2D_mul_cache!($dc0,$c0,$p,0.0)
RHS_mul_alloc= @benchmark GCH_2D_element!($dc0,$c0,$p,0.0)
RHS_elementwise= @benchmark GCH_2D_mul_slow!($dc0,$c0,$p,0.0)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m29.200 μs[22m[39m … [35m 21.059 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 99.38%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m46.800 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m54.407 μs[22m[39m ± [32m306.610 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m10.91% ±  1.98%

  [39m [39m█[39m█[39m [39m▃[39m▅[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [34m▁[39m[39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▄[39m█[39m█[39m▆

In [53]:
# Store the benchmarks in a dictionary
benchmarks_RHS = Dict(
    "RHS_mul_cache" => RHS_mul_cache,
    "RHS_mul_alloc" => RHS_mul_alloc,
    "RHS_elementwise" => RHS_elementwise
)

# Extract the mean times and standard deviations
mean_times = zeros(3); std_devs = zeros(3);
k=1;
for name in keys(benchmarks_RHS)
    norm_fac = 1e9;
    mean_times[k]=mean(benchmarks_RHS[name]).time/norm_fac
    std_devs[k]=std(benchmarks_RHS[name]).time/norm_fac
    k +=1;
end

# Create the plot
RHSbench = bar(collect(keys(benchmarks_RHS)), mean_times, ylabel="Time (s)", yaxis=:log,
legend=false,color=[:maroon,:grey,:navy],grid=false,yticks=[1e-5,1e-4,1e-3],ylim=[1e-5,1e-3])#,yticks=[0,50,100,150,200,250,300,350],size=(1200,800))
savefig("RHSbench.png")

"c:\\Users\\Sam\\Desktop\\Research\\Graphite Data\\RHSbench.png"

In [8]:
jac_sparsity_cache = Symbolics.jacobian_sparsity((du, u) -> GCH_2D_mul_cache!(du,u,p,0),dc0,c0);
colorvec_cache = matrix_colors(jac_sparsity_cache);
f_cache = ODEFunction(GCH_2D_mul_cache!;jac_prototype=jac_sparsity_cache,colorvec=colorvec_cache);
sparse_prob_cache = ODEProblem(f_cache,c0,tspan,p);

In [11]:
solve(sparse_prob_cache,CVODE_BDF(linear_solver=:GMRES))

retcode: Success
Interpolation: 3rd order Hermite
t: 321-element Vector{Float64}:
 0.0
 5.9546317043980725e-5
 0.00011909263408796145
 0.00023921103516645566
 0.00035932943624494986
 0.0005806844498297175
 0.0008020394634144851
 0.0010233944769992528
 0.0012447494905840203
 0.0014661045041687879
 ⋮
 4.635816628103401
 4.72426708781937
 4.812717547535339
 4.834830162464331
 4.856942777393323
 4.890140956105591
 4.923339134817859
 4.956537313530126
 5.0
u: 321-element Vector{Matrix{Float64}}:
 [0.5366764007515327 0.5266191080804261 … 0.5196974003625441 0.48205874045476527; 0.5393540439034791 0.5207173918583038 … 0.5137379599142022 0.5019623321746877; … ; 0.49730611197552144 0.5231724683004547 … 0.4609671122718525 0.5482654878987233; 0.48832131235933357 0.5182400625963216 … 0.5194941341560293 0.45285736959836964]
 [0.5369698427911664 0.526402600119962 … 0.5185881297729997 0.48373147177401965; 0.538420452305937 0.5220816861874418 … 0.513261690095814 0.5020536994781607; … ; 0.49798589385614

In [11]:
jac_sparsity_mul_slow = Symbolics.jacobian_sparsity((du, u) -> GCH_2D_mul_slow!(du,u,p,0),dc0,c0);
colorvec_mul_slow = matrix_colors(jac_sparsity_mul_slow);
f_mul_slow = ODEFunction(GCH_2D_mul_slow!;jac_prototype=jac_sparsity_mul_slow,colorvec=colorvec_mul_slow);
sparse_prob_mul_slow = ODEProblem(f_mul_slow,c0,tspan,p);

In [12]:
jac_sparsity_element = Symbolics.jacobian_sparsity((du, u) -> GCH_2D_element!(du,u,p,0),dc0,c0);
colorvec_element = matrix_colors(jac_sparsity_element);
f_element = ODEFunction(GCH_2D_element!;jac_prototype=jac_sparsity_element,colorvec=colorvec_element);
sparse_prob_element = ODEProblem(f_element,c0,tspan,p);

In [95]:
function solver_to_string(solver)
    str = string(solver)  # Convert the solver function to a string
    idx = findfirst("{", str)  # Find the index of the first "{"
    idx2 = findfirst("Krylov",str)
    idx3 = findfirst("GMRES",str)
    if idx !== nothing  # Check if "{" was found
        str = str[1:prevind(str, first(idx))]  # Cut off the string at the position before "{"
    end
    if idx2 !== nothing || idx3 !==nothing
        str *= "_KrylovGMRES"
    end
    return str
end


# Define the solvers
solvers = [
    ROCK2(),
    ROCK4(),
    RKC(),
    SERK2(),
    ESERK5(),
    TRBDF2(),
    KenCarp4(),
    Rosenbrock23(),
    CVODE_BDF(),
    CVODE_BDF(linear_solver=:GMRES),
    TRBDF2(linsolve = KrylovJL_GMRES()),
    KenCarp4(linsolve = KrylovJL_GMRES()),
    Rosenbrock23(linsolve = KrylovJL_GMRES())
]

# Define the problems
problems = [
    sparse_prob_cache, 
    sparse_prob_element
]

# Initialize the dictionary to store benchmarks
benchmarks = Dict()

#Benchmark each solver for each problem
for solver in solvers
    # Initialize an empty vector for the solver
    benchmarks[solver_to_string(solver)] = []
    for prob in problems
        # Benchmark the solve function
        println("Benchmarking problem $(prob) with solver $(solver_to_string(solver))")
        benchmark_result = @benchmark solve($prob, $solver, save_everystep=false) samples =100
        
        # Append the benchmark result to the vector
        push!(benchmarks[solver_to_string(solver)], benchmark_result)
    end
end

benchmarks  # This will display the benchmarks dictionary

Dict{Any, Any} with 13 entries:
  "ESERK5"                   => Any[Trial(426.843 ms), Trial(6.136 s)]
  "Rosenbrock23_KrylovGMRES" => Any[Trial(1.008 s), Trial(8.858 s)]
  "TRBDF2"                   => Any[Trial(447.465 ms), Trial(1.025 s)]
  "SERK2"                    => Any[Trial(268.553 ms), Trial(3.841 s)]
  "KenCarp4_KrylovGMRES"     => Any[Trial(846.181 ms), Trial(7.700 s)]
  "TRBDF2_KrylovGMRES"       => Any[Trial(2.291 s), Trial(17.474 s)]
  "Rosenbrock23"             => Any[Trial(500.574 ms), Trial(1.012 s)]
  "ROCK4"                    => Any[Trial(167.145 ms), Trial(2.356 s)]
  "KenCarp4"                 => Any[Trial(643.835 ms), Trial(1.720 s)]
  "ROCK2"                    => Any[Trial(118.028 ms), Trial(1.744 s)]
  "RKC"                      => Any[Trial(1.973 s), Trial(28.979 s)]
  "CVODE_BDF"                => Any[Trial(62.051 s), Trial(9.443 s)]
  "CVODE_BDF_KrylovGMRES"    => Any[Trial(72.847 ms), Trial(718.586 ms)]

In [102]:
using Plots.PlotMeasures
# Extract the solver names
solver_names = collect(keys(benchmarks))


# Extract the mean times and standard deviations for each solver and problem
mean_times = [[mean(bench).time / 1e9 for bench in benchmarks[solver_name]] for solver_name in solver_names]  # Convert to milliseconds
std_devs = [[std(bench).time / 1e9 for bench in benchmarks[solver_name]] for solver_name in solver_names]  # Convert to milliseconds

# Define the problem names
problem_names = ["Cached Matmul", "Elementwise"]

# Repeat the problem names for each solver
group_names = repeat(problem_names, length(solver_names))

# Flatten the mean_times and std_devs for plotting
mean_times_flat = vcat(mean_times...)
std_devs_flat = vcat(std_devs...)
solver_names = replace.(solver_names, "_KrylovGMRES" => "\nKrylov")
solver_names = replace.(solver_names, "Rosenbrock23" => "Rbrock23")
solver_names = replace.(solver_names, "CVODE_BDF" => "CVODE\nBDF")
solver_names = replace.(solver_names, "KenCarp4" => "KCarp4")
# Create the grouped bar plot
grpbar = groupedbar(repeat(solver_names, inner=length(problems)), mean_times_flat, yerr=std_devs_flat, group = group_names,
bar_width = 0.67, framestyle = :box,size=(900,600),yaxis=:log, legend = :topright,c = [:maroon :grey],grid=false,left_margin=4mm,yticks=[1e-2,1e-1,1e0,1e1,1e2],ylim=[1e-2,1e2])
ylabel!("Time (s)")
savefig("simbench.png")

"c:\\Users\\Sam\\Desktop\\Research\\Graphite Data\\simbench.png"