# 1 - Horizontal Diffusion 

As a first example, I more or less randomly picked the horizontal diffusion struct / function.

But first, we have to load the environment and a state of SpeedyWeather that we can use.

In [1]:
import Pkg 
Pkg.activate("..")

using Enzyme, Test, KernelAbstractions, CUDAKernels, KernelGradients, SpeedyExperiments, LinearAlgebra, Adapt, Parameters, SpeedyWeather

  Activating project at `~/Nextcloud/SpeedyExperiments/scripts`


In [2]:
progn_vars, diagn_vars, model_setup = initialize_speedy();

So, we want to look into rewriting/adjusting `horizontal_diffusion!` (2D Version) here as an example. First, we will look up how this function is called given an initialized model. Looking up the source code we can see that the function signature is 

```julia
 horizontal_diffusion!(  tendency::AbstractMatrix{Complex{NF}}, # tendency of a 
                            A::AbstractMatrix{Complex{NF}},        # spectral horizontal field
                            damp_expl::AbstractMatrix{NF},         # explicit spectral damping
                            damp_impl::AbstractMatrix{NF}          # implicit spectral damping
                            ) where {NF<:AbstractFloat}
```

and it is called in the timestepping routine with 

```julia
@unpack vor = progn
@unpack vor_tend = diagn.tendencies
@unpack damping, damping_impl = M.horizontal_diffusion

# set all tendencies to zero
fill!(vor_tend,zero(Complex{NF}))

  
# PROPAGATE THE SPECTRAL STATE INTO THE DIAGNOSTIC VARIABLES
gridded!(diagn,progn,M,lf2)

# COMPUTE TENDENCIES OF PROGNOSTIC VARIABLES
get_tendencies!(diagn,progn,M,lf2)                   

vor_lf = view(vor,:,:,1,:)                                      # array view for leapfrog index
horizontal_diffusion!(vor_tend,vor_lf,damping,damping_impl)     # diffusion of vorticity
```

in which `progn` is an instance of `PrognosticVariables` like the `progn_vars` we initialized, `diagn` is an instance of `DiagnosticVariables` like the `diagn_vars` we initialized, and `M` is the `model_setup`.


Let's call it once, so we can cross-check our new version against it later. We have to prefix some calls explicitly with `SpeedyWeather.`, as those functions are not exported.

In [3]:
M = model_setup 
diagn = diagn_vars 
progn = progn_vars 
lf2 = 2
NF = Float32

@unpack vor = progn
@unpack vor_tend = diagn.tendencies
@unpack damping, damping_impl = M.horizontal_diffusion

# set all tendencies to zero
fill!(vor_tend,zero(Complex{NF}))

# PROPAGATE THE SPECTRAL STATE INTO THE DIAGNOSTIC VARIABLES
SpeedyWeather.gridded!(diagn,progn,M,lf2)

# COMPUTE TENDENCIES OF PROGNOSTIC VARIABLES
SpeedyWeather.get_tendencies!(diagn,progn,M,lf2)                   

# we want the 2D version: 

vor_tend = vor_tend[:,:,1]
vor_lf = view(vor,:,:,1,1)                                      # array view for leapfrog index
SpeedyWeather.horizontal_diffusion!(vor_tend,vor_lf,damping,damping_impl)


Now that we know how to call the function, let's rewrite it to work on GPU and with Enzyme!

First, we inspect the old version: 


```julia 
function horizontal_diffusion!( tendency::AbstractMatrix{Complex{NF}}, # tendency of a 
                                A::AbstractMatrix{Complex{NF}},        # spectral horizontal field
                                damp_expl::AbstractMatrix{NF},         # explicit spectral damping
                                damp_impl::AbstractMatrix{NF}          # implicit spectral damping
                                ) where {NF<:AbstractFloat}

    lmax,mmax = size(A) .- 1            # degree l, order m but 0-based
    @boundscheck size(A) == size(tendency) || throw(BoundsError())
    @boundscheck size(A) == size(damp_expl) || throw(BoundsError())
    @boundscheck size(A) == size(damp_impl) || throw(BoundsError())
    
    @inbounds for m in 1:mmax+1         # loop through all spectral modes 
        for l in m:lmax+1
            tendency[l,m] = (tendency[l,m] - damp_expl[l,m]*A[l,m])*damp_impl[l,m]
        end
    end
end
```

`horizontal_diffusion!` consists of bounds checks and a double for loop. We'll have to write a kernel for the loop and then call this kernel from a wrapper function that also includes the bounds checks.

An important thing to note is that all matrices that store spherical harmonic coefficients, like `tendency` here, are only filled in the lower triangle. The loop also only runs over the lower triangle of the matrix, so our GPU operation should work only on the lower triangle as well; otherwise we waste computational power. The easiest way (that I can think of) is to translate a linear index to an index of the lower triangle, so that $1\rightarrow(1,1)$, $2\rightarrow(2,1)$, $3\rightarrow(2,2)$, $4\rightarrow(3,1)$ and so on. We will have to do that all the time, so we will just create a translation array that we can reuse for all other parts of the model as well.
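For reference, the same row-by-row translation can also be written in closed form, which may help when checking a lookup array. This is only a sketch: `linear_to_triangle` is a made-up helper name and not part of SpeedyWeather or SpeedyExperiments.

```julia
# Hypothetical helper (not part of SpeedyWeather): map a linear index k = 1, 2, 3, ...
# to the (row, column) of the k-th lower-triangle element, counted row by row.
function linear_to_triangle(k::Integer)
    row = (isqrt(8k - 7) + 1) ÷ 2      # row i starts at k = i(i-1)/2 + 1
    col = k - row * (row - 1) ÷ 2      # position within that row
    return CartesianIndex(row, col)
end

# The first four linear indices, matching the mapping described above
first_four = [linear_to_triangle(k) for k in 1:4]
```

A precomputed array avoids the integer square root inside every kernel call, which is presumably why a lookup table is the more convenient choice here.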

In [4]:
"""
    lowertriangle_indices(Lmax::Integer)

Returns an array of `CartesianIndex` with the indices of the lower triangle of the square matrix with `Lmax` rows/columns
"""
function lowertriangle_indices(Lmax::Integer)
    N = sum(1:Lmax) # number of elements in the lower triangle
    indices = Array{CartesianIndex}(undef, N)

    count = 1
    for i=1:Lmax
        for j=1:i
            indices[count] = CartesianIndex(i,j)
            count += 1
        end 
    end 
    return indices 
end 

"""
    lowertriangle_indices(A::AbstractMatrix)

Returns an array of `CartesianIndex` with the indices of the lower triangle of the square matrix `A`.
"""
function lowertriangle_indices(A::AbstractMatrix) 
    @assert size(A,1) == size(A,2)
    lowertriangle_indices(size(A,1))
end

triangle_indices = lowertriangle_indices(vor_tend)


528-element Vector{CartesianIndex}:
 CartesianIndex(1, 1)
 CartesianIndex(2, 1)
 CartesianIndex(2, 2)
 CartesianIndex(3, 1)
 CartesianIndex(3, 2)
 CartesianIndex(3, 3)
 CartesianIndex(4, 1)
 CartesianIndex(4, 2)
 CartesianIndex(4, 3)
 CartesianIndex(4, 4)
 CartesianIndex(5, 1)
 CartesianIndex(5, 2)
 CartesianIndex(5, 3)
 ⋮
 CartesianIndex(32, 21)
 CartesianIndex(32, 22)
 CartesianIndex(32, 23)
 CartesianIndex(32, 24)
 CartesianIndex(32, 25)
 CartesianIndex(32, 26)
 CartesianIndex(32, 27)
 CartesianIndex(32, 28)
 CartesianIndex(32, 29)
 CartesianIndex(32, 30)
 CartesianIndex(32, 31)
 CartesianIndex(32, 32)
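One detail that might matter once this array has to live on a GPU: the element type shown above is the abstract `CartesianIndex`, not the concrete `CartesianIndex{2}`. GPU array types generally need element types that are stored inline (isbits). The following sketch shows a variant with a concrete element type; `lowertriangle_indices_concrete` is a hypothetical name, and a wrapper that annotates its argument as `AbstractArray{CartesianIndex}` would need `AbstractArray{<:CartesianIndex}` instead to accept it.

```julia
# Hypothetical variant (an assumption, not the notebook's version):
# the comprehension infers the concrete element type CartesianIndex{2}.
function lowertriangle_indices_concrete(Lmax::Integer)
    return [CartesianIndex(i, j) for i in 1:Lmax for j in 1:i]
end

idx = lowertriangle_indices_concrete(32)

abstract_is_bits = isbitstype(CartesianIndex)    # abstract (UnionAll) element type
concrete_is_bits = isbitstype(eltype(idx))       # CartesianIndex{2}
```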

Great! So, now let's write the horizontal diffusion kernel 

In [5]:
@kernel function horizontal_diffusion_kernel!(tendency, @Const(A), @Const(damp_expl), @Const(damp_impl), @Const(triangle_index))
    i = @index(Global, Linear)
    i_cartesian = triangle_index[i]

    tendency[i_cartesian] = (tendency[i_cartesian] - damp_expl[i_cartesian]*A[i_cartesian])*damp_impl[i_cartesian]
end

horizontal_diffusion_kernel! (generic function with 5 methods)
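What a single work item of this kernel does can be sanity-checked on the CPU without KernelAbstractions. The sketch below compares the original double loop with a plain loop over the linear index, using small random test matrices. The function names `diffusion_loops!` and `diffusion_linear!` and the test data are made up for this illustration.

```julia
# CPU reference: double loop over the lower triangle, as in the original function
function diffusion_loops!(tendency, A, damp_expl, damp_impl)
    lmax, mmax = size(A) .- 1
    for m in 1:mmax+1, l in m:lmax+1
        tendency[l, m] = (tendency[l, m] - damp_expl[l, m] * A[l, m]) * damp_impl[l, m]
    end
    return tendency
end

# Same update, one iteration per linear index i, mirroring the kernel body
function diffusion_linear!(tendency, A, damp_expl, damp_impl, triangle_index)
    for i in eachindex(triangle_index)
        I = triangle_index[i]
        tendency[I] = (tendency[I] - damp_expl[I] * A[I]) * damp_impl[I]
    end
    return tendency
end

L = 8                                                     # small test size (arbitrary)
idx = [CartesianIndex(l, m) for l in 1:L for m in 1:l]    # lower-triangle indices
A, t = rand(ComplexF32, L, L), rand(ComplexF32, L, L)
de, di = rand(Float32, L, L), rand(Float32, L, L)

t_loops = diffusion_loops!(copy(t), A, de, di)
t_linear = diffusion_linear!(copy(t), A, de, di, idx)
```

The two loops visit the elements in a different order (column by column versus row by row), but since every element is updated independently, the order does not affect the result.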

We take the bounds checks from the old version and now integrate the kernel and launch it. We might agree on some other utility functions for launching kernels later, but here I have a struct called `DeviceSetup` that holds the currently used device and workgroup size.
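The definition of `DeviceSetup` lives in SpeedyExperiments and is not shown here. Judging only from how it is used (`device_setup.device()` and `device_setup.n`), a minimal stand-in could look like the sketch below. Both `DeviceSetupSketch` and `DummyCPU` are hypothetical names, and the workgroup size `4` is just an example value.

```julia
# Stand-in for a KernelAbstractions device type such as `CPU`
# (an assumption, used only to keep this sketch self-contained)
struct DummyCPU end

# Hypothetical minimal version of `DeviceSetup`; the real struct may differ
struct DeviceSetupSketch{D,N<:Integer}
    device::D   # device *type*; an instance is created with `setup.device()`
    n::N        # workgroup size handed to the kernel constructor
end

setup = DeviceSetupSketch(DummyCPU, 4)
device = setup.device()   # instantiate the device
n = setup.n               # workgroup size
```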

In [6]:
function horizontal_diffusion!(tendency::AbstractMatrix{Complex{NF}},           # tendency of a
                               A::AbstractMatrix{Complex{NF}},                  # spectral horizontal field
                               damp_expl::AbstractMatrix{NF},                   # explicit spectral damping
                               damp_impl::AbstractMatrix{NF},                   # implicit spectral damping
                               device_setup::DeviceSetup,                       # device the function is executed on
                               triangle_indices::AbstractArray{CartesianIndex}  # array with the indices
                               ) where {NF<:AbstractFloat}

    lmax,mmax = size(A) .- 1            # degree l, order m but 0-based
    @boundscheck size(A) == size(tendency) || throw(BoundsError())
    @boundscheck size(A) == size(damp_expl) || throw(BoundsError())
    @boundscheck size(A) == size(damp_impl) || throw(BoundsError())
    device = device_setup.device()      # instantiate the device
    n = device_setup.n                  # workgroup size

    # launch one work item per lower-triangle element and wait for completion
    wait(horizontal_diffusion_kernel!(device, n)(tendency, A, damp_expl, damp_impl, triangle_indices, ndrange=length(triangle_indices)))
end


horizontal_diffusion! (generic function with 1 method)

Now we have to test if it works! We'll compare the old version to the KernelAbstractions version. 

In [7]:

const device_setup = DeviceSetup()

vor_tend_old = deepcopy(vor_tend)
vor_tend_new = deepcopy(vor_tend)

SpeedyWeather.horizontal_diffusion!(vor_tend_old, vor_lf, damping, damping_impl)
horizontal_diffusion!(vor_tend_new, vor_lf, damping, damping_impl, device_setup, triangle_indices)


In [8]:
@test vor_tend_old ≈ vor_tend_new

Test Passed
  Expression: vor_tend_old ≈ vor_tend_new
   Evaluated: ComplexF32[0.0f0 + 0.0f0im 0.0f0 + 0.0f0im … 0.0f0 + 0.0f0im 0.0f0 + 0.0f0im; 9.214844f-7 + 0.0f0im 1.1242739f-14 + 9.101589f-15im … 0.0f0 + 0.0f0im 0.0f0 + 0.0f0im; … ; 8.1235026f-5 + 0.0f0im -6.751766f-7 + 1.2070599f-5im … -9.633636f-7 + 9.024164f-6im 0.0f0 + 0.0f0im; 0.87312204f0 + 0.0f0im -2.1297915f-6 - 1.4817448f-5im … -4.4587105f-6 + 3.97241f-5im 1.017148f-5 - 1.3497451f-5im] ≈ ComplexF32[0.0f0 + 0.0f0im 0.0f0 + 0.0f0im … 0.0f0 + 0.0f0im 0.0f0 + 0.0f0im; 9.214844f-7 + 0.0f0im 1.1242739f-14 + 9.101589f-15im … 0.0f0 + 0.0f0im 0.0f0 + 0.0f0im; … ; 8.1235026f-5 + 0.0f0im -6.751766f-7 + 1.2070599f-5im … -9.633636f-7 + 9.024164f-6im 0.0f0 + 0.0f0im; 0.87312204f0 + 0.0f0im -2.1297915f-6 - 1.4817448f-5im … -4.4587105f-6 + 3.97241f-5im 1.017148f-5 - 1.3497451f-5im]

Great! 

The next step is to bring in Enzyme and show that the kernel is differentiable. For this, we have to allocate the shadow memory that stores the gradient information.

In [9]:
∂vor_tend = fill!(similar(vor_tend), 1)
vor_lf = Array(vor_lf) # needs to be an array not a subarray
∂vor_lf = zero(vor_lf)
∂damping = zero(damping)
∂damping_impl = zero(damping_impl);
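To make the role of the shadow arrays a bit more concrete, here is a hand-written reverse rule for the elementwise update used in the kernel, checked against central finite differences. It is only a sketch: it is restricted to real numbers for simplicity (how Enzyme treats complex values follows its own conventions and is not covered here), and `update` and `update_adjoint` are made-up names. The `seed` plays the role of the ones stored in `∂vor_tend`, and the returned values correspond to what would be accumulated into the shadows of the inputs.

```julia
# Forward update for one element, as in the kernel (real-valued for simplicity)
update(t, a, de, di) = (t - de * a) * di

# Hand-written reverse rule: given the seed for the output,
# return the adjoints of the four inputs
function update_adjoint(t, a, de, di, seed)
    ∂t  = di * seed
    ∂a  = -de * di * seed
    ∂de = -a * di * seed
    ∂di = (t - de * a) * seed
    return ∂t, ∂a, ∂de, ∂di
end

t, a, de, di = 0.3, -1.2, 0.05, 0.9        # arbitrary test values
grads = update_adjoint(t, a, de, di, 1.0)

# Central finite differences as an independent check
h = 1e-6
fd = (
    (update(t + h, a, de, di) - update(t - h, a, de, di)) / (2h),
    (update(t, a + h, de, di) - update(t, a - h, de, di)) / (2h),
    (update(t, a, de + h, di) - update(t, a, de - h, di)) / (2h),
    (update(t, a, de, di + h) - update(t, a, de, di - h)) / (2h),
)
max_err = maximum(abs.(grads .- fd))
```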

In [10]:
typeof(vor_lf)

Matrix{ComplexF32} (alias for Array{Complex{Float32}, 2})

In [11]:
∇horizontal_diffusion! = autodiff(horizontal_diffusion_kernel!(device_setup.device(), device_setup.n))
wait(∇horizontal_diffusion!(Duplicated(vor_tend, ∂vor_tend), Duplicated(vor_lf,∂vor_lf), Duplicated(damping,∂damping), Duplicated(damping_impl,∂damping_impl), Const(device_setup), Const(triangle_indices); ndrange=length(triangle_indices)))

LoadError: TaskFailedException

    nested task error: Return type inferred to be Union{}. Giving up.
    Stacktrace:
     [1] error(s::String)
       @ Base ./error.jl:33
     [2] autodiff_deferred(::typeof(cpu_horizontal_diffusion_kernel!), ::Type{Const}, ::KernelAbstractions.CompilerMetadata{KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.NoDynamicCheck, CartesianIndex{1}, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, KernelAbstractions.NDIteration.NDRange{1, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.StaticSize{(4,)}, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, Nothing}}, ::Vararg{Any})
       @ Enzyme ~/.julia/packages/Enzyme/di3zM/src/Enzyme.jl:440
     [3] (::KernelGradients.var"#df#1"{typeof(cpu_horizontal_diffusion_kernel!), typeof(cpu_horizontal_diffusion_kernel!)})(::KernelAbstractions.CompilerMetadata{KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.NoDynamicCheck, CartesianIndex{1}, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, KernelAbstractions.NDIteration.NDRange{1, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.StaticSize{(4,)}, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, Nothing}}, ::Duplicated{Matrix{ComplexF32}}, ::Vararg{Any})
       @ KernelGradients ~/.julia/packages/KernelGradients/LqkqJ/src/KernelGradients.jl:9
     [4] __thread_run(tid::Int64, len::Int64, rem::Int64, obj::KernelAbstractions.Kernel{CPU, KernelAbstractions.NDIteration.StaticSize{(4,)}, KernelAbstractions.NDIteration.DynamicSize, KernelGradients.var"#df#1"{typeof(cpu_horizontal_diffusion_kernel!), typeof(cpu_horizontal_diffusion_kernel!)}}, ndrange::Tuple{Int64}, iterspace::KernelAbstractions.NDIteration.NDRange{1, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.StaticSize{(4,)}, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, Nothing}, args::Tuple{Duplicated{Matrix{ComplexF32}}, Duplicated{Matrix{ComplexF32}}, Duplicated{Matrix{Float32}}, Duplicated{Matrix{Float32}}, Const{DeviceSetup{DataType, Int64}}, Const{Vector{CartesianIndex}}}, dynamic::KernelAbstractions.NDIteration.NoDynamicCheck)
       @ KernelAbstractions ~/.julia/packages/KernelAbstractions/1ZLga/src/cpu.jl:157
     [5] __run(obj::KernelAbstractions.Kernel{CPU, KernelAbstractions.NDIteration.StaticSize{(4,)}, KernelAbstractions.NDIteration.DynamicSize, KernelGradients.var"#df#1"{typeof(cpu_horizontal_diffusion_kernel!), typeof(cpu_horizontal_diffusion_kernel!)}}, ndrange::Tuple{Int64}, iterspace::KernelAbstractions.NDIteration.NDRange{1, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.StaticSize{(4,)}, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, Nothing}, args::Tuple{Duplicated{Matrix{ComplexF32}}, Duplicated{Matrix{ComplexF32}}, Duplicated{Matrix{Float32}}, Duplicated{Matrix{Float32}}, Const{DeviceSetup{DataType, Int64}}, Const{Vector{CartesianIndex}}}, dynamic::KernelAbstractions.NDIteration.NoDynamicCheck)
       @ KernelAbstractions ~/.julia/packages/KernelAbstractions/1ZLga/src/cpu.jl:130
     [6] (::KernelAbstractions.var"#19#20"{Nothing, Nothing, typeof(KernelAbstractions.__run), Tuple{KernelAbstractions.Kernel{CPU, KernelAbstractions.NDIteration.StaticSize{(4,)}, KernelAbstractions.NDIteration.DynamicSize, KernelGradients.var"#df#1"{typeof(cpu_horizontal_diffusion_kernel!), typeof(cpu_horizontal_diffusion_kernel!)}}, Tuple{Int64}, KernelAbstractions.NDIteration.NDRange{1, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.StaticSize{(4,)}, CartesianIndices{1, Tuple{Base.OneTo{Int64}}}, Nothing}, Tuple{Duplicated{Matrix{ComplexF32}}, Duplicated{Matrix{ComplexF32}}, Duplicated{Matrix{Float32}}, Duplicated{Matrix{Float32}}, Const{DeviceSetup{DataType, Int64}}, Const{Vector{CartesianIndex}}}, KernelAbstractions.NDIteration.NoDynamicCheck}})()
       @ KernelAbstractions ~/.julia/packages/KernelAbstractions/1ZLga/src/cpu.jl:22