# AutoDiff via Symbolic Representation in Julia

In [1]:
using Symbolics

In [2]:
# Layers of the composition, applied in the order they appear in `w_vec`.
# `i` is an identity "input" layer, so the chain computes f(g(h(x))).
i(x) = x
h(x) = x^2
g(x) = 2 * x^2
f(x) = 3 * x^2
w_vec = [i, h, g, f]

@variables x  # declare the symbolic variable `x` (also binds the global `x`); the cell shows the returned 1-element vector

1-element Vector{Num}:
 x

In [3]:
"""
    forward_fn(w_vec, x, i::Int = firstindex(w_vec))

Run the forward pass of the layer chain `w_vec`, starting at layer `i`, and
return every intermediate output as a vector.

The first entry is `w_vec[i](x)`. Each following entry applies the next layer to
the previous entry, and the last entry is the output of `w_vec[end]`.

The result is always a vector, even when only one layer is evaluated, so
`x_vec[end]` is the final output. Element types are combined with `vcat`
promotion, so numeric inputs and symbolic (`Num`) inputs both work.

Throws a `BoundsError` if `i` is not a valid index of `w_vec` (including when
`w_vec` is empty).
"""
function forward_fn(w_vec, x, i::Int = firstindex(w_vec))
    y = w_vec[i](x)
    # The base case returns `[y]`, not the bare scalar `y`, so the return type
    # does not depend on how many layers remain.
    return i == lastindex(w_vec) ? [y] : [y; forward_fn(w_vec, y, i + 1)]
end

forward_fn (generic function with 1 method)

In [4]:
# Forward pass: keep every intermediate output of the chain, starting at the first layer.
x_vec = forward_fn(w_vec, x, firstindex(w_vec))
display(x_vec)

4-element Vector{Num}:
             x
           x^2
        2(x^4)
 (12//1)*(x^8)

In [5]:
"""
    gradient(w_i, x_i_1)

Differentiate the scalar function `w_i` symbolically and evaluate the
derivative at `x_i_1`, which may be a number or a symbolic expression.

Returns a 1-tuple, so callers index it with `[1]`. This matches how
`Zygote.gradient(...)[1]` is used later in the notebook.
"""
function gradient(w_i, x_i_1)
    @variables x  # fresh local symbol to differentiate with respect to
    d_wi = expand_derivatives(Differential(x)(w_i(x)))
    at_point = substitute(d_wi, Dict(x => x_i_1))
    return (at_point,)
end

"""
    reverse_autodiff(w_vec, x_vec, i::Int)

Apply the chain rule to compute d x_vec[i] / d x_vec[1]. The result is the
product of the local derivatives `gradient(w_vec[k], x_vec[k-1])[1]` for
`k = i, i-1, ..., 2`. Here `x_vec` is the forward-pass output, as produced by
`forward_fn(w_vec, x, 1)` in this notebook.

Note: the recursion stops at `i == 1` and returns `1` without differentiating
`w_vec[1]`. The result is therefore the derivative with respect to the original
input only when `w_vec[1]` has derivative 1 (e.g. the identity layer `i(x) = x`
used here).
"""
function reverse_autodiff(w_vec, x_vec, i::Int)
    if i == 1
        return 1
    end
    local_grad = gradient(w_vec[i], x_vec[i-1])[1]
    upstream = reverse_autodiff(w_vec, x_vec, i - 1)
    return local_grad * upstream
end

reverse_autodiff (generic function with 1 method)

In [6]:
y_ad = last(x_vec)  # output of the final layer
display(y_ad)
dy_ad = reverse_autodiff(w_vec, x_vec, length(w_vec))  # chain rule over all layers
display(dy_ad)

(12//1)*(x^8)

96(x^7)

## Check against the analytic derivative

In [7]:
# Differentiate the closed-form composition directly, for comparison.
y_th = (f ∘ g ∘ h)(x)
display(y_th)
dy_th = expand_derivatives(Differential(x)(y_th))
display(dy_th)

(12//1)*(x^8)

(96//1)*(x^7)

## Check with Zygote

In [18]:
using Symbolics
using Zygote

h(x) = x^2
g(x) = 2 * x^2
f(x) = 3 * x^2
# The whole chain as a single function, so Zygote can differentiate it end to end.
y(x) = f(g(h(x)))
display(y(x))
dy(x) = only(Zygote.gradient(y, x))  # Zygote returns a 1-tuple for a single argument
display(dy(x))

(12//1)*(x^8)

(96//1)*(x^7)

In [27]:
"""
    y(x)

Return `x^5`, computed by five repeated multiplications. The explicit loop is
used below to check that Zygote can differentiate through it.

The accumulator starts at `one(x)` rather than the integer `1`, so its type does
not change inside the loop (type-stable for any numeric or symbolic `x`).
"""
function y(x)
    N = 5
    acc = one(x)
    for _ in 1:N
        acc *= x
    end
    return acc
end

display(y(x))
dy(x) = only(Zygote.gradient(y, x))  # single argument -> 1-tuple of gradients
dy(x)

x^5

5(x^4)

In [28]:
"""
    y(x, N)

Return `x^N` computed by `N` repeated multiplications. The loop count is an
argument here instead of the fixed 5.

For `N <= 0` the loop does not run and `one(x)` is returned. The accumulator
starts at `one(x)`, so its type matches `x` and never changes inside the loop.
"""
function y(x, N)
    acc = one(x)
    for _ in 1:N
        acc *= x
    end
    return acc
end

display(y(x, 5))
dy(x, N) = first(Zygote.gradient(y, x, N))  # keep only the derivative w.r.t. x
dy(x, 5)

x^5

5(x^4)

## All Code

In [8]:
"""
    gradient(w_i, x_i_1)

Differentiate the scalar function `w_i` symbolically and evaluate the
derivative at `x_i_1`, which may be a number or a symbolic expression.

Returns a 1-tuple, so callers index it with `[1]`.
"""
function gradient(w_i, x_i_1) # 1) Newly added: this cell carries its own copy
    @variables x  # fresh local symbol to differentiate with respect to
    d_wi = expand_derivatives(Differential(x)(w_i(x)))
    at_point = substitute(d_wi, Dict(x => x_i_1))
    return (at_point,)
end

"""
    main(w_vec)

Run the symbolic forward pass over `w_vec`, then the chain-rule backward pass.

Returns `(x_vec, y_ad, dy_ad)`: all intermediate outputs, the final output, and
its derivative with respect to a fresh symbolic `x`. Relies on `forward_fn` and
`reverse_autodiff` from earlier cells.
"""
function main(w_vec)
    @variables x  # 2) Replaced from x = 2.0 (symbolic x instead of a numeric point)
    x_vec = forward_fn(w_vec, x, 1)
    return x_vec, last(x_vec), reverse_autodiff(w_vec, x_vec, length(w_vec))
end

# Layers of the composition, applied in order; `i` is the identity input layer.
i(x) = x
h(x) = x^2
g(x) = 2 * x^2
f(x) = 3 * x^2
w_vec = [i, h, g, f]
x_vec, y_ad, dy_ad = main(w_vec)
foreach(display, (x_vec, y_ad, dy_ad))  # show activations, output, derivative in order

# 3) Verification code: differentiate the closed form f(g(h(x))) directly.
@variables x
y_th = (f ∘ g ∘ h)(x)
display(y_th)
dy_th = expand_derivatives(Differential(x)(y_th))
display(dy_th)

4-element Vector{Num}:
             x
           x^2
        2(x^4)
 (12//1)*(x^8)

(12//1)*(x^8)

96(x^7)

(12//1)*(x^8)

(96//1)*(x^7)