# 2.5.1. A Simple Function

In [68]:
using LinearAlgebra, Zygote, ChainRules

In [69]:
x = Vector(0:1.:3)

4-element Vector{Float64}:
 0.0
 1.0
 2.0
 3.0

In [70]:
y = x -> 2 * dot(x, x)
y(x)

28.0

In [71]:
y'(x)

4-element Vector{Float64}:
  0.0
  4.0
  8.0
 12.0

In [72]:
y'(x) .== 4 * x

4-element BitVector:
 1
 1
 1
 1

In [73]:
y = x -> sum(x)
gradient(y, x)

(Fill(1.0, 4),)

# 2.5.2. Backward for Non-Scalar Variables

In [52]:
gradient(x -> dot(x, x), x)

([0.0, 2.0, 4.0, 6.0],)

# 2.5.3. Detaching Computation

In [96]:
stop_gradient(f) = f()
Zygote.@nograd stop_gradient
y = dot(x, x)
z = stop_gradient() do
    y * x
end
gradient(z, x)

LoadError: MethodError: objects of type Vector{Float64} are not callable
Use square brackets [] for indexing an Array.

# 2.5.4. Gradients and Python Control Flow

In [97]:
function f(a)
    b = a * 2
    while norm(b) < 1000
        b = b * 2
    end
    if sum(b) > 0
        c = b
    else
        c = 100 * b
    end
    return c
end

f (generic function with 1 method)

In [102]:
a = randn()
d = f(a)
f'(a)

2048.0

In [103]:
f'(a) == d / a

true