In [1]:
import Base: show, show_unquoted, print

struct Variable
  name::Symbol
  number::Int
end
Symbol(x::Variable) = Symbol(x.name, x.number)

show(io::IO, x::Variable) = print(io, ":(", x.name, x.number, ")")
print(io::IO, x::Variable) = show_unquoted(io, x, 0, -1)
show_unquoted(io::IO, x::Variable, ::Int, ::Int) =
  print(io, x.name, x.number)

show_unquoted (generic function with 20 methods)

In [2]:
import Base: keys, lastindex, getindex, push!
import MacroTools: unblock

struct Wengert
  variable::Symbol
  instructions::Vector{Any}
end

Wengert(; variable = :y) = Wengert(variable, [])

keys(w::Wengert) = (Variable(w.variable, i) for i = 1:length(w.instructions))
lastindex(w::Wengert) = Variable(w.variable, length(w.instructions))
getindex(w::Wengert, v::Variable) = w.instructions[v.number]

function Base.show(io::IO, w::Wengert)
  println(io, "Wengert List")
  if length(w.instructions) < 1
    println(io, " (empty)")
  end
  for (i, x) in enumerate(w.instructions)
    print(io, Variable(w.variable, i), " = ")
    println(io, x)
  end
end

function push!(w::Wengert, x)
  return x
end

function push!(w::Wengert, x::Expr)
  x = Expr(x.head, map(x -> x isa Expr ? push!(w, x) : x, x.args)...)
  push!(w.instructions, x)
  return lastindex(w)
end

function Wengert(ex; variable = :y)
  w = Wengert(variable = variable)
  push!(w, ex)
  return w
end

function Expr(w::Wengert)
  cs = Dict()
  for x in w.instructions
    x isa Expr || continue
    for v in x.args
      v isa Variable || continue
      cs[v] = get(cs, v, 0) + 1
    end
  end
  bs = Dict()
  rename(ex::Expr) = Expr(ex.head, map(x -> get(bs, x, x), ex.args)...)
  rename(x) = x
  ex = :(;)
  for v in keys(w)
    if get(cs, v, 0) > 1
      push!(ex.args, :($(Symbol(v)) = $(rename(w[v]))))
      bs[v] = Symbol(v)
    else
      bs[v] = rename(w[v])
    end
  end
  push!(ex.args, rename(bs[lastindex(w)]))
  return unblock(ex)
end

Expr

In [106]:
import MacroTools
import MacroTools: @capture

addm(a, b) = a == 0 ? b : b == 0 ? a : :($a + $b)
mulm(a, b) = 0 in (a, b) ? 0 : a == 1 ? b : b == 1 ? a : :($a * $b)
mulm(a, b, c...) = mulm(mulm(a, b), c...)

function derive_forward(w::Wengert, x; out=Wengert())
  ds = Dict()
  ds[x] = 1
  d(y) = get(ds, y, 0)
  for v in keys(w)
    ex = w[v]
    Δ = @capture(ex, a_ + b_) ? addm(d(a), d(b)) :
        @capture(ex, a_ * b_) ? addm(mulm(a, d(b)), mulm(b, d(a))) :
        @capture(ex, a_^n_)   ? mulm(d(a),n,:($a^$(n-1))) :
        @capture(ex, a_ / b_) ? :($(mulm(b, d(a))) - $(mulm(a, d(b))) / $b^2) :
        @capture(ex, sin(a_)) ? mulm(:( cos($a)), d(a)) :
        @capture(ex, cos(a_)) ? mulm(:(-sin($a)), d(a)) :
        @capture(ex, exp(a_)) ? mulm(v, d(a)) :
        @capture(ex, log(a_)) ? mulm(:(1/$a), d(a)) :
        error("$ex is not differentiable")
    ds[v] = push!(out, Δ)
    println(ds)
  end
  return out
end

function derive_reverse(w::Wengert, x; out=Wengert())
  ds = Dict()
  d(x) = get(ds, x, 0)
  d(x, Δ) = ds[x] = haskey(ds, x) ? addm(ds[x],Δ) : Δ
  d(lastindex(w), 1)
  for v in reverse(collect(keys(w)))
    ex = w[v]
    Δ = d(v)
    if     @capture(ex, a_ + b_)
      d(a, Δ)
      d(b, Δ)
    elseif @capture(ex, a_ * b_)
      d(a, push!(out, mulm(Δ, b)))
      d(b, push!(out, mulm(Δ, a)))
    elseif @capture(ex, a_^n_)
      d(a, mulm(Δ, n, :($a^$(n-1))))
    elseif @capture(ex, a_ / b_)
      d(a, push!(out, mulm(Δ, b)))
      d(b, push!(out, :(-$(mulm(Δ, a))/$b^2)))
    elseif @capture(ex, sin(a_))
      d(a, push!(out, mulm(Δ, cos(a))))
    elseif @capture(ex, cos(a_))
      d(a, push!(out, mulm(Δ,-sin(a))))
    elseif @capture(ex, exp(a_))
      d(a, push!(out, mulm(Δ, v)))
    elseif @capture(ex, log(a_))
      d(a, push!(out, mulm(Δ, 1/a)))
    else
      error("$ex is not differentiable")
    end
  end
  push!(out, d(x))
  return out
end

derive_reverse (generic function with 1 method)

In [95]:
struct Staged
  w::Wengert
  var
end

import Base: +, -, *, ^, /
+(a::Staged, b::Staged) = Staged(w, push!(b.w, :($(a.var) + $(b.var))))
+(a::Staged, b)         = Staged(w, push!(a.w, :($(a.var) + $(b))))
+(a,         b::Staged) = Staged(w, push!(b.w, :($(a)     + $(b.var))))

-(a::Staged, b::Staged) = Staged(w, push!(b.w, :($(a.var) - $(b.var))))
-(a::Staged, b)         = Staged(w, push!(a.w, :($(a.var) - $(b))))
-(a,         b::Staged) = Staged(w, push!(b.w, :($(a)     - $(b.var))))

*(a::Staged, b::Staged) = Staged(w, push!(a.w, :($(a.var) * $(b.var))))
*(a::Staged, b)         = Staged(w, push!(a.w, :($(a.var) * $(b))))
*(a,         b::Staged) = Staged(w, push!(b.w, :($(a)     * $(b.var))))

^(a::Staged, b::Staged) = Staged(w, push!(a.w, :($(a.var) ^ $(b.var))))
^(a::Staged, b)         = Staged(w, push!(a.w, :($(a.var) ^ $(b))))
^(a,         b::Staged) = Staged(w, push!(b.w, :($(a)     ^ $(b.var))))

/(a::Staged, b::Staged) = Staged(w, push!(a.w, :($(a.var) / $(b.var))))
/(a::Staged, b)         = Staged(w, push!(a.w, :($(a.var) / $(b))))
/(a,         b::Staged) = Staged(w, push!(b.w, :($(a)     / $(b.var))))
nothing

In [60]:
import Base: abs, sin, cos, tan, exp, sqrt, max
abs(a::Staged)  = Staged(w, push!(a.w, :(abs($(a.var)))))
sin(a::Staged)  = Staged(w, push!(a.w, :(sin($(a.var)))))
cos(a::Staged)  = Staged(w, push!(a.w, :(cos($(a.var)))))
tan(a::Staged)  = Staged(w, push!(a.w, :(tan($(a.var)))))
exp(a::Staged)  = Staged(w, push!(a.w, :(exp($(a.var)))))
sqrt(a::Staged) = Staged(w, push!(a.w, :(sqrt($(a.var)))))
max(a::Staged, b::Staged) = Staged(w, push!(a.w, :(max($(a.var), $(b.var)))))
max(a,         b::Staged) = Staged(w, push!(b.w, :(max($(a)    , $(b.var)))))
max(a::Staged, b)         = Staged(w, push!(a.w, :(max($(a.var), $(b)))))
nothing

In [61]:
w = Wengert()
x = Staged(w, :x)

Staged(Wengert List
 (empty)
, :x)

In [62]:
function pow(x, n)
  r = 1
  for i = 1:n
    r *= x
  end
  return r
end

w = Wengert()
x = Staged(w, :x)
y = pow(x, 3)

Staged(Wengert List
y1 = 1x
y2 = y1 * x
y3 = y2 * x
, :(y3))

In [107]:
w = Wengert()
x = Staged(w, :x)
y = Staged(w, :y)
z = x*x + y*y*y
Dx = 2x
Dy = 3y^2
w

Wengert List
y1 = x * x
y2 = y * y
y3 = y2 * y
y4 = y1 + y3
y5 = 2x
y6 = y ^ 2
y7 = 3 * y6


In [109]:
#@show dx = derive_forward(w, :x);
@show dy = derive_forward(w, :y);

Dict{Any,Any}(:y => 1,:(y1) => 0)
Dict{Any,Any}(:y => 1,:(y2) => :(y1),:(y1) => 0)
Dict{Any,Any}(:y => 1,:(y2) => :(y1),:(y3) => :(y3),:(y1) => 0)
Dict{Any,Any}(:y => 1,:(y2) => :(y1),:(y3) => :(y3),:(y1) => 0,:(y4) => :(y3))
Dict{Any,Any}(:y => 1,:(y2) => :(y1),:(y5) => 0,:(y3) => :(y3),:(y1) => 0,:(y4) => :(y3))
Dict{Any,Any}(:y => 1,:(y2) => :(y1),:(y5) => 0,:(y6) => :(y5),:(y3) => :(y3),:(y1) => 0,:(y4) => :(y3))
Dict{Any,Any}(:y => 1,:(y2) => :(y1),:(y5) => 0,:(y6) => :(y5),:(y3) => :(y3),:(y1) => 0,:(y7) => :(y6),:(y4) => :(y3))
dy = derive_forward(w, :y) = Wengert List
y1 = y + y
y2 = y * y1
y3 = y2 + y2
y4 = y ^ 1
y5 = 2 * y4
y6 = 3 * y5



In [82]:
@eval dzdx(x, y) = $(Expr(dx));
@eval dzdy(x, y) = $(Expr(dy));

In [101]:
dzdx(1.0, 2.0)

2.0

In [104]:
dzdy(1.0, 2.0), 3*2.0^2

(16.0, 12.0)

In [97]:
ww = Wengert(:(sin(x) * y))

Wengert List
y1 = sin(x)
y2 = y1 * y


In [100]:
derive_reverse(Wengert(:(x / (1 + x^2))), :x) |> Expr

quote
    y2 = y2 ^ 2
    y2 + ((-x / y2) * 2) * x ^ 1
end

In [None]:
derive_forward(ww, :x)