# Symbolic Auto-Differentiation

A quick crash course on symbolic auto-differentiation in Julia using [Symbolics.jl](https://symbolics.juliasymbolics.org/dev/).

In [148]:
# code setup
@show versioninfo()
@show pwd()
import Pkg; Pkg.activate("..")
using Symbolics
using Symbolics: derivative
using BenchmarkTools

Julia Version 1.7.2
Commit bf53498635 (2022-02-06 15:21 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: AMD Ryzen 9 3950X 16-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, znver2)
versioninfo() = nothing
pwd() = "/home/mantas/.julia/dev/PRONTO.jl/dev"


[32m[1m  Activating[22m[39m project at `~/.julia/dev/PRONTO.jl`


Consider the dynamics:
$$ \dot{x} = f(x,u,t) = (H_0 + u H_1) x $$
which we implement like so (in this demo `x`, `u` and `t` are scalars, so `f` returns a 4×4 matrix):

In [149]:
H0 = [0 0 1 0;
      0 0 0 -1;
     -1 0 0 0;
      0 1 0 0]

H1 = [0 -1 0 0;
      1 0 0 0;
      0 0 0 -1;
      0 0 1 0]

f(x,u,t) = (H0 + u*H1)*x

f (generic function with 1 method)

Because the definition of `f(x,u,t)` is generic, it works for any argument types that support addition and multiplication, the operations used inside the function. For example, numbers or symbolic variables (and since `t` is never used, it can be anything at all):

In [150]:
f(1,2,1)

4×4 Matrix{Int64}:
  0  -2  1   0
  2   0  0  -1
 -1   0  0  -2
  0   1  2   0

In [151]:
@variables x u t # make symbolic x,u,t
f(x,u,t)

4×4 Matrix{Num}:
   0  -u*x    x     0
 u*x     0    0    -x
  -x     0    0  -u*x
   0     x  u*x     0

In [152]:
f(3,u,"not used")

4×4 Matrix{Num}:
  0  -3u   3    0
 3u    0   0   -3
 -3    0   0  -3u
  0    3  3u    0
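A standalone illustration in plain Julia (the function `g` is hypothetical, not part of the notebook): the same generic definition runs on integers, floats, rationals, complex numbers and even matrices, and `Base.return_types` shows that Julia infers a different concrete result type for each combination of argument types.

```julia
# Hypothetical generic function: no type annotations, only + and * inside
g(x, u, t) = (1 + u) * x          # like f, it never uses t

@show g(2, 3, nothing)             # Int inputs
@show g(0.5, 3, nothing)           # Float64 input
@show g(1//2, 1//3, "unused")      # Rational inputs
@show g(1 + 2im, 1, :unused)       # Complex input
@show g([1 2; 3 4], 1, 0)          # Matrix input

# Julia infers (and compiles) a separate specialization per argument-type combination
@show Base.return_types(g, (Int, Int, Nothing))
@show Base.return_types(g, (Float64, Int, Nothing))
```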

Under the hood, Julia compiles a separate, specialized copy of machine code for each combination of argument types the function is called with. Everything works quickly and conveniently. The same is true if we define the derivatives manually:

In [153]:
fx(x,u,t) = H0 + u*H1
fu(x,u,t) = H1*x

fx(x,u,t) # called using symbolic x,u,t

4×4 Matrix{Num}:
  0  -u  1   0
  u   0  0  -1
 -1   0  0  -u
  0   1  u   0
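For reference, the derivatives that `fx` and `fu` encode by hand follow directly from $f(x,u,t) = (H_0 + u H_1) x$; the second derivatives are the ones the notebook later computes automatically as `fxx`, `fxu` and `fuu`:

$$ f_x = \frac{\partial f}{\partial x} = H_0 + u H_1, \qquad f_u = \frac{\partial f}{\partial u} = H_1 x $$
$$ f_{xx} = 0, \qquad f_{xu} = \frac{\partial^2 f}{\partial x \, \partial u} = H_1, \qquad f_{uu} = 0 $$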

But what if we want to do that automatically for any `f(x,u,t)`? Like Dr. Hauser, Symbolics.jl knows how to take derivatives. So we create a symbolic representation of `f` by calling it with the symbolic arguments `(x,u,t)`, and take its derivative with respect to the symbolic `x`. However, the result is a matrix of symbolic expressions, not the callable function we need.

In [154]:
fx_sym = Symbolics.derivative(f(x,u,t), x) # x,u,t symbolic

# fx_sym(x,u,t) # throws error

4×4 Matrix{Num}:
  0  -u  1   0
  u   0  0  -1
 -1   0  0  -u
  0   1  u   0
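The same situation on a scalar expression, as a minimal sketch assuming a recent Symbolics.jl where `Symbolics.derivative`, `substitute` and `Symbolics.value` are available (the names `expr`, `dexpr` and `val` are illustrative). `substitute` can plug numbers into the symbolic result for a quick check, but it walks the expression every time, so it is no substitute for a compiled function.

```julia
using Symbolics

@variables x u                         # symbolic scalars
expr  = u * x^2 + 3x                   # a symbolic expression, not a function
dexpr = Symbolics.derivative(expr, x)  # symbolic derivative: 2u*x + 3 (term order may differ)

# dexpr(2, 5) would throw: a symbolic expression is not callable.
# Substituting numbers for the variables evaluates it once:
val = Symbolics.value(substitute(dexpr, Dict(x => 2, u => 5)))
@show val                              # 2*5*2 + 3
```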

From the [metaprogramming section of the Julia manual](https://docs.julialang.org/en/v1/manual/metaprogramming/):
"Julia represents its own code as a data structure of the language itself. Since code is represented by objects that can be created and manipulated from within the language, it is possible for a program to transform and generate its own code."

This is incredibly cool.
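A tiny plain-Julia sketch of "code as data" (the names `body`, `fdef` and `g_gen` are illustrative): quoting produces an `Expr` object that can be inspected, combined into a larger expression, and finally turned into a callable function with `eval`, which is the same pattern Symbolics.jl relies on.

```julia
# Quoting with :( ... ) produces an Expr object instead of running the code
body = :(u * x^2)
@show typeof(body)          # Expr
@show body.head             # :call
@show body.args             # the function and its arguments: *, u, and x^2

# Code can build new code: wrap the body into an anonymous-function definition,
# equivalent to writing :((x, u) -> u * x^2)
fdef = Expr(:->, Expr(:tuple, :x, :u), body)

# eval turns the expression into an actual callable function
g_gen = eval(fdef)
@show g_gen(3, 2)           # 2 * 3^2
```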

[Symbolics.jl](https://symbolics.juliasymbolics.org/dev/tutorials/symbolic_functions/) uses this machinery to generate Julia code from symbolic expressions with `build_function`:

In [155]:
fx_exp = build_function(fx_sym, x, u, t)
Base.remove_linenums!.(fx_exp)
fx_exp

(:(function (x, u, t)
      begin
          (SymbolicUtils.Code.create_array)(Array, nothing, Val{2}(), Val{(4, 4)}(), 0, u, -1, 0, (*)(-1, u), 0, 0, 1, 1, 0, 0, u, 0, -1, (*)(-1, u), 0)
      end
  end), :(function (ˍ₋out, x, u, t)
      begin
          #= /home/mantas/.julia/packages/SymbolicUtils/v2ZkM/src/code.jl:398 =# @inbounds begin
                  ˍ₋out[1] = 0
                  ˍ₋out[2] = u
                  ˍ₋out[3] = -1
                  ˍ₋out[4] = 0
                  ˍ₋out[5] = (*)(-1, u)
                  ˍ₋out[6] = 0
                  ˍ₋out[7] = 0
                  ˍ₋out[8] = 1
                  ˍ₋out[9] = 1
                  ˍ₋out[10] = 0
                  ˍ₋out[11] = 0
                  ˍ₋out[12] = u
                  ˍ₋out[13] = 0
                  ˍ₋out[14] = -1
                  ˍ₋out[15] = (*)(-1, u)
                  ˍ₋out[16] = 0
                  nothing
              end
      end
  end))

`fx_exp` is still not a callable function. Rather, it is an *expression* representing the Julia code that defines a function.

Technically, `fx_exp` is a tuple containing two expression (`Expr`) objects, each defining an anonymous function. The first is the standard version of the function, which allocates a new array for its output. The second is an in-place version, which writes its output into a pre-allocated container passed as the first argument. By convention, the names of functions that modify their arguments end with an exclamation point.
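The allocating/in-place split is a general Julia idiom, not something specific to Symbolics.jl. A minimal hand-written pair (the names `double` and `double!` are hypothetical) shows the convention: the `!` version takes the output container first and overwrites it.

```julia
# Allocating version: returns a freshly allocated array on every call
double(v) = 2 .* v

# In-place version: name ends in `!`, output container comes first,
# and the result is written into `out` instead of a new array
function double!(out, v)
    out .= 2 .* v
    return nothing
end

v   = [1.0, 2.0, 3.0]
out = similar(v)          # pre-allocated output container
double!(out, v)           # out now holds [2.0, 4.0, 6.0]
@show out double(v)
```

Reusing `out` across many calls (for example inside an ODE solver loop) is where the in-place form pays off.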

These expressions can be evaluated, and the resulting anonymous functions bound to variable names. Maybe relevant: functions in Julia are [first-class objects](https://en.wikipedia.org/wiki/First-class_citizen).

In [156]:
fx_gen = eval(fx_exp[1]) # defines fx_gen(x,u,t)
fx_gen! = eval(fx_exp[2]) # defines fx_gen!(out,x,u,t)

#144 (generic function with 1 method)

Now we can call `fx_gen` like any other generic function, as if we had defined it manually.

In [157]:
fx_gen(x,u,t)

4×4 Matrix{Num}:
  0  -u  1   0
  u   0  0  -1
 -1   0  0  -u
  0   1  u   0

In [158]:
fx_gen(1,2,3) == fx(1,2,3)

true
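As an aside, assuming a reasonably recent Symbolics.jl, `build_function` also accepts the keyword `expression = Val{false}`, in which case it returns compiled functions (built with RuntimeGeneratedFunctions.jl) instead of expressions, so the separate `eval` step disappears. A small sketch with illustrative names:

```julia
using Symbolics

@variables x u t
g_sym = Symbolics.derivative(u * x^3 + t, x)       # 3u*x^2, still symbolic

# Scalar output: a single compiled function, no eval needed
g_fun = build_function(g_sym, x, u, t; expression = Val{false})
@show g_fun(2.0, 1.0, 0.0)                          # 3 * 1 * 2^2

# Array output: an (allocating, in-place) pair, mirroring the two Exprs above
v_sym = Symbolics.derivative.([u * x, x^2], x)     # [u, 2x]
v_fun, v_fun! = build_function(v_sym, x, u, t; expression = Val{false})
buf = zeros(2)
v_fun!(buf, 3.0, 2.0, 0.0)                          # writes the result into buf
@show v_fun(3.0, 2.0, 0.0) buf
```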

OK, but that takes four steps (declare symbolic variables, differentiate, build the expression, evaluate it), which is at least three too many...

So, why not automate the automatic differentiation process?

In [159]:
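function Dx(f)
    @variables x u t                              # fresh symbolic arguments
    fx_sym = Symbolics.derivative(f(x,u,t), x)    # symbolic derivative w.r.t. x
    fx_exp = build_function(fx_sym, x, u, t)      # generated Julia code (Expr)
    # scalar output -> a single Expr; array output -> (allocating, in-place) tuple
    return fx_exp isa Expr ? eval(fx_exp) : eval(fx_exp[1])
end

function Du(f)
    @variables x u t                              # fresh symbolic arguments
    fu_sym = Symbolics.derivative(f(x,u,t), u)    # symbolic derivative w.r.t. u
    fu_exp = build_function(fu_sym, x, u, t)      # generated Julia code (Expr)
    # keep only the allocating version, as in Dx
    return fu_exp isa Expr ? eval(fu_exp) : eval(fu_exp[1])
end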

Du (generic function with 1 method)

In [160]:
fx_auto = Dx(f) # creates an anonymous function (x,u,t)->(...) and binds it to fx_auto
fx_auto(x,u,t) # calls the function bound to fx_auto using the arguments x,u,t, which are currently symbolic variables

# or, if we want to be fancy
Dx(f)(x,u,t)

4×4 Matrix{Num}:
  0  -u  1   0
  u   0  0  -1
 -1   0  0  -u
  0   1  u   0

And it works! Note that as written, these methods are limited to 3-argument functions of the form `f(x,u,t)` which **must** return an array. However, they are still quite useful:

In [177]:
fxx = Dx(Dx(f))
fxu = Dx(Du(f))
fuu = Du(Du(f))

#198 (generic function with 1 method)

In [178]:
fxu(x,u,t)

4×4 Matrix{Int64}:
 0  -1  0   0
 1   0  0   0
 0   0  0  -1
 0   0  1   0

In [179]:
fxu(x,u,t) == H1

true

If we want to support a different function signature, we could hack it together using multiple dispatch:

In [173]:
m(x) = [x^5] # our "real" function -> must return an array
m(x,u,t) = m(x) # form-matching definition
Dx(Dx(Dx(m)))(x,u,t)

1-element Vector{Num}:
 60(x^2)
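The method-based adapter above can be checked in plain Julia, next to a hypothetical closure-based alternative (`as_xut` is an illustrative name, not part of Symbolics.jl) that wraps any one-argument function without adding methods to it. With the notebook's `Dx` in scope, `Dx(q)` should then behave like `Dx(m)`.

```julia
# Approach used above: add a 3-argument method to the same generic function
m(x) = [x^5]
m(x, u, t) = m(x)             # u and t are accepted but ignored

# Hypothetical alternative: a closure-based adapter for any one-argument function
as_xut(g) = (x, u, t) -> g(x)
q = as_xut(x -> [x^3])

@show m(2, nothing, nothing)  # dispatches to the 3-argument method
@show q(2, nothing, nothing)
@show length(methods(m))      # one method per signature
```

The method approach keeps a single name for both signatures; the closure approach leaves the original function untouched, which helps when it comes from a package you do not own.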

What about speed? Because the auto-diff machinery doesn't have to worry about human readability or code length, it can potentially generate more efficient function definitions than you would write by hand, especially when several array operations can be collapsed into one. At least in this example, the auto-diff version of the function runs **2 to 3 times faster** than the manual definition (median 56.9 ns vs. 157.4 ns below).

In [169]:
@benchmark fx(1,2,3)

BenchmarkTools.Trial: 10000 samples with 852 evaluations.
 Range (min … max):  142.850 ns …   4.327 μs  ┊ GC (min … max):  0.00% … 93.37%
 Time  (median):     157.372 ns               ┊ GC (median):     0.00%
 Time  (mean ± σ):   195.054 ns ± 283.495 ns  ┊ GC (mean ± σ):  10.32% ±  6.81%

In [170]:
@benchmark fx_auto(1,2,3)

BenchmarkTools.Trial: 10000 samples with 986 evaluations.
 Range (min … max):  51.405 ns …   4.071 μs  ┊ GC (min … max):  0.00% … 97.24%
 Time  (median):     56.851 ns               ┊ GC (median):     0.00%
 Time  (mean ± σ):   76.585 ns ± 187.313 ns  ┊ GC (mean ± σ):  13.22% ±  5.31%

The in-place version is **more than 10x faster** than the manual `fx` (12.3 ns vs. 142.9 ns minimum), and the benchmark reports no GC time at all.

In [171]:
out = zeros(4,4)
@benchmark fx_gen!(out,1,2,3)

BenchmarkTools.Trial: 10000 samples with 999 evaluations.
 Range (min … max):  12.325 ns … 56.321 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     12.947 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   13.266 ns ±  1.548 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%