# Forward Mode

Modified from the reference: https://vladium.com/tutorials/study_julia_with_me/multiple_dispatch/

The operator overloading approach was inspired by Julia's `Complex` implementation: https://github.com/JuliaLang/julia/blob/147bdf428cd14c979202678127d1618e425912d6/base/complex.jl

In [5]:
"""
    Context{T<:Number} <: Number

Value/derivative pair used for forward-mode differentiation: `v` holds the value and
`∂` holds the derivative that is propagated alongside it.
"""
struct Context{T<:Number} <: Number
    v::T
    ∂::T
end

# A bare value is treated as a constant, i.e. its derivative part is zero. The `0.0`
# literal means the fields are promoted together with `Float64`.
Context(x::T) where {T<:Number} = Context(x, 0.0)
# Fields of different types are promoted to a common type before construction.
Context(x::T, y::S) where {T<:Number, S<:Number} = Context(promote(x, y)...)

# Conversion of a plain real number into a `Context{T}` constant.
Context{T}(x::Real) where {T<:Real} = Context{T}(x, 0)
# Conversion between `Context` parameter types. The fields are read directly because
# no accessor functions named `v` or `∂` are defined.
Context{T}(z::Context) where {T<:Real} = Context{T}(z.v, z.∂)

Context(z::Context) = z

# Mixing a `Context` with another number (or another `Context`) yields a `Context`
# whose field type is the promotion of the underlying types.
Base.promote_rule(::Type{Context{T}}, ::Type{S}) where {T<:Number, S<:Number} =
    Context{promote_type(T, S)}
Base.promote_rule(::Type{Context{T}}, ::Type{Context{S}}) where {T<:Real, S<:Real} =
    Context{promote_type(T, S)}

Define operators along with their derivatives

In [7]:
# Binary operators: each method returns the value together with its derivative rule.

# Sum rule: derivatives add.
function Base.:(+)(lhs::Context, rhs::Context)
    return Context(lhs.v + rhs.v, lhs.∂ + rhs.∂)
end

# Difference rule: derivatives subtract.
function Base.:(-)(lhs::Context, rhs::Context)
    return Context(lhs.v - rhs.v, lhs.∂ - rhs.∂)
end

# Product rule.
function Base.:(*)(lhs::Context, rhs::Context)
    value = lhs.v * rhs.v
    slope = lhs.v * rhs.∂ + lhs.∂ * rhs.v
    return Context(value, slope)
end

# Quotient rule.
function Base.:(/)(lhs::Context, rhs::Context)
    value = lhs.v / rhs.v
    slope = (lhs.∂ * rhs.v - lhs.v * rhs.∂) / rhs.v^2
    return Context(value, slope)
end

# Unary operators: identity and negation of both fields.
Base.:(+)(x::Context) = x

function Base.:(-)(x::Context)
    return Context(-x.v, -x.∂)
end

# Basic trigonometric functions; the inner derivative `x.∂` is applied via the chain rule.
function Base.sin(x::Context)
    return Context(sin(x.v), cos(x.v) * x.∂)
end

function Base.cos(x::Context)
    return Context(cos(x.v), -sin(x.v) * x.∂)
end

function Base.tan(x::Context)
    return Context(tan(x.v), sec(x.v)^2 * x.∂)
end

# Exponential functions
function Base.exp(x::Context)
    # exp is its own derivative, so the value is reused for the derivative part.
    value = exp(x.v)
    return Context(value, value * x.∂)
end

# General power: one term for the dependence on the base and one for the dependence on
# the exponent. The second term evaluates `log(base.v)`.
function Base.:(^)(base::Context, expo::Context)
    value = base.v^expo.v
    wrt_base = base.v^(expo.v - 1) * expo.v * base.∂
    wrt_expo = value * log(base.v) * expo.∂
    return Context(value, wrt_base + wrt_expo)
end

# Logarithms. The natural logarithm in Julia is `log`; `ln` is not defined in Base, so
# the base-change constants below are written as `log(2)` and `log(10)`.

# Natural logarithm: d/dx log(x) = 1 / x
Base.log(x::Context) = Context(log(x.v), 1 / x.v * x.∂)
# Base-2 logarithm: d/dx log2(x) = 1 / (x * log(2))
Base.log2(x::Context) = Context(log2(x.v), 1 / log(2) / x.v * x.∂)
# Base-10 logarithm: d/dx log10(x) = 1 / (x * log(10))
Base.log10(x::Context) = Context(log10(x.v), 1 / log(10) / x.v * x.∂)
# Logarithm of `x` in base `b`, differentiated with respect to both the base and `x`.
Base.log(b::Context, x::Context) = Context(
    log(b.v, x.v), -log(x.v) / b.v / log(b.v)^2 * b.∂ + 1 / x.v / log(b.v) * x.∂
)

# And more, if necessary.

Also compare the results against a numerical derivative computed with the **four-point method** (a central finite difference) to check for errors.

In [8]:
"""
    fp_diff(f, x, h=8.8e-4)

Approximate the derivative of `f` at `x` from four samples of `f` taken at
`x - 2h`, `x - h`, `x + h` and `x + 2h`, where `h` is the step size.
"""
function fp_diff(f::Function, x::Number, h::Number = 8.8e-4)::Number
    far_left = f(x - 2h)
    near_left = f(x - h)
    near_right = f(x + h)
    far_right = f(x + 2h)
    weighted = far_left - 8 * near_left + 8 * near_right - far_right
    return weighted / (12 * h)
end

fp_diff (generic function with 2 methods)

Test with

$$
\begin{align*}
f(x, y) &= \sin(xy) \\
f^\prime(x, y) &= y\cos(xy) dx + x\cos(xy) dy \\
f^\prime(x, c) &= y\cos(xy) dx + \cancel{x\cos(xy) dy} \\
&= y\cos(xy) dx
\end{align*}
$$

In [13]:
# Test function: the sine of the product of the two arguments.
function f1(x::Number, y::Number)::Number
    product = x * y
    return sin(product)
end

# x carries derivative 1, so `computed.∂` is the derivative of f1 with respect to x.
x = Context(2, 1)
y = 3  # given the diff is 0 (y is constant)

computed = f1(x, y)
# Numerical reference: differentiate f1 in its first argument with y held fixed.
numerical = fp_diff(x.v) do u
    f1(u, y)
end
# Closed form y * cos(x * y); only the value part of the result is kept.
symbolic = (y * cos(x * y)).v

println("f($(x.v), $y) = $(computed.v)")
println("f'($(x.v), $y)  = $(computed.∂)")
println("symbolic  = $symbolic")
println("numerical = $numerical")

f(2, 3) = -0.27941549819892586
f'(2, 3)  = 2.880510859951098
symbolic  = 2.880510859951098
numerical = 2.880510859945746


Test with

$$
\begin{align*}
f(x, y) &= xy + \sin(x) \\
f^\prime(x, y) &= (y + \cos(x))dx + x\ dy
\end{align*}
$$

In [15]:
# Test function: the product of the two arguments plus the sine of the first.
function f2(x::Number, y::Number)::Number
    bilinear = x * y
    return bilinear + sin(x)
end

# Both inputs carry derivative 1, so `computed.∂` is the derivative of f2 along the
# direction (1, 1).
x = Context(3, 1)
y = Context(1, 1)

computed = f2(x, y)
# Closed form (y + cos(x)) * 1 + x * 1; only the value part of the result is kept.
symbolic = (y + cos(x) + x).v
# Numerical reference for the same direction: shift both arguments by the same amount
# and differentiate with respect to that shift at zero.
numerical = fp_diff(t -> f2(x.v + t, y.v + t), 0.0)

println("f($(x.v), $(y.v)) = $(computed.v)")
println("f'($(x.v), $(y.v)) = $(computed.∂)")
println("symbolic   = $symbolic")
println("numerical  = $numerical")

f(3, 1) = 3.1411200080598674
f'(3, 1) = 3.010007503399555
actual   = 3.010007503399555


Test with

$$
f(x) = \frac{\log(1 + e^{5x^3})}{\sin(3x)+\cos(x/6)}
$$

We know that the derivative of $f(x)$ is (thanks to WolframAlpha)

$$
f^\prime(x) = \frac{15 e^{5 x^3} x^2}{(e^{5 x^3} + 1) (\sin(3 x) + \cos(x/6))} - \frac{\log(e^{5 x^3} + 1) (3 \cos(3 x) - \frac{1}{6} \sin(x/6))}{(\sin(3 x) + \cos(x/6))^2}
$$

In [17]:
# Test function: log(1 + exp(5x^3)) divided by (sin(3x) + cos(x/6)).
function f3(x)
    numerator = log(1 + exp(5x^3))
    denominator = sin(3x) + cos(x/6)
    return numerator / denominator
end
# Closed-form derivative of f3, written as the difference of two quotient terms.
function df3(x)
    growth = exp(5x^3)
    trig = sin(3x) + cos(x/6)
    first_term = (15 * growth * x^2) / ((growth + 1) * trig)
    second_term = (log(growth + 1) * (3cos(3x) - sin(x/6) / 6)) / trig^2
    return first_term - second_term
end

df3 (generic function with 1 method)

In [21]:
# x carries derivative 1, so `computed.∂` is the derivative of f3 at x.v.
x = Context(1, 1)

# Reference value from the closed-form derivative, evaluated at the plain value.
actual = df3(x.v)
computed = f3(x)

println("f($(x.v)) = $(computed.v)")
println("f'($(x.v))  = $(computed.∂)")
println("actual = $actual")

f(1) = 4.4414784164916785
f'(1)  = 25.028317413642956
actual = 25.02831741364296


# Reverse Mode

*To be continued...*