In [1]:
/// Returns `x` multiplied by itself.
///
/// - Parameter x: The value to square.
/// - Returns: `x * x`.
func square(_ x: Float) -> Float {
    let squared = x * x
    return squared
}

/// Hand-written derivative of `x * x` with respect to `x`.
///
/// - Parameter x: The point at which to evaluate the derivative.
/// - Returns: `2 * x`.
func square_derivative(_ x: Float) -> Float {
    let factor: Float = 2
    return factor * x
}

In [2]:
import Glibc

/// Computes `sin(x * x)`.
///
/// - Parameter x: The input value.
/// - Returns: The sine of `x` squared.
func exampleFunction(_ x: Float) -> Float {
    let squared = x * x
    return sin(squared)
}

/// Hand-written derivative of `sin(x * x)` with respect to `x`.
///
/// - Parameter x: The point at which to evaluate the derivative.
/// - Returns: `2 * x * cos(x * x)`.
func exampleFunctionDerivative(_ x: Float) -> Float {
    let squared = x * x
    // Chain rule: slope of the inner `x * x` times slope of the outer `sin`.
    let innerSlope = 2 * x
    let outerSlope = cos(squared)
    return innerSlope * outerSlope
}

In [3]:
/// Computes `sin(x * x)` together with a deferred derivative.
///
/// The intermediate `x * x` is computed once and shared between the returned
/// value and the `backward` function, so it is not recomputed when the
/// derivative is requested.
///
/// - Parameter x: The input value.
/// - Returns: `value` is `sin(x * x)`; `backward` evaluates
///   `2 * x * cos(x * x)` when called.
func exampleFunctionDerivativeEfficient(_ x: Float) -> (value: Float, backward: () -> Float) {
    let squared = x * x

    // Nested function capturing both `x` and the already-computed `squared`.
    func derivative() -> Float {
        return 2 * x * cos(squared)
    }

    return (value: sin(squared), backward: derivative)
}

In [4]:
/// Computes `sin(x * x)` together with a function that maps an incoming
/// gradient `v` to the gradient with respect to `x`.
///
/// - Parameter x: The input value.
/// - Returns: `value` is `sin(x * x)`; `deriv(v)` is `v * cos(x * x) * 2 * x`.
func exampleFunctionValueWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {
    let squared = x * x

    let pullback: (Float) -> Float = { seed in
        // First scale the seed through `sin` (gradient w.r.t. `x * x`),
        // then through the squaring step (gradient w.r.t. `x`).
        return seed * cos(squared) * 2 * x
    }

    return (value: sin(squared), deriv: pullback)
}

In [5]:
/// Computes `sin(x)` together with a function that scales an incoming
/// gradient by the derivative of `sin` at `x`.
///
/// - Parameter x: The input value.
/// - Returns: `value` is `sin(x)`; `deriv(v)` is `cos(x) * v`.
func sinValueWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {
    let sine = sin(x)
    let pullback: (Float) -> Float = { upstream in
        return cos(x) * upstream
    }
    return (value: sine, deriv: pullback)
}

/// Computes `x * x` together with a function that scales an incoming
/// gradient by the derivative of squaring at `x`.
///
/// - Parameter x: The input value.
/// - Returns: `value` is `x * x`; `deriv(v)` is `2 * x * v`.
func squareValueWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {
    let squared = x * x
    let pullback: (Float) -> Float = { upstream in
        return 2 * x * upstream
    }
    return (value: squared, deriv: pullback)
}

/// Computes `sin(x * x)` and its gradient function by composing
/// `squareValueWithDeriv` and `sinValueWithDeriv`.
///
/// - Parameter x: The input value.
/// - Returns: `value` is the result of applying the square step followed by
///   the sine step; `deriv(v)` feeds `v` through the sine step's gradient
///   function and then through the square step's gradient function.
func exampleFunctionWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {
    // Forward pass: inner step first, then the outer step on its result.
    let inner = squareValueWithDeriv(x)
    let outer = sinValueWithDeriv(inner.value)

    // Backward pass: gradient functions are applied in the reverse order.
    return (value: outer.value, deriv: { v in inner.deriv(outer.deriv(v)) })
}

In [6]:
/// Computes `sin(square(x)) + square(x)`, evaluating `square(x)` only once.
///
/// - Parameter x: The input value.
/// - Returns: The sine of the squared input plus the squared input.
func myComplexFunction(_ x: Float) -> Float {
    let squared = square(x)
    return sin(squared) + squared
}

/// Computes `x + y` together with a function that distributes an incoming
/// gradient to both operands.
///
/// - Parameters:
///   - x: The left operand.
///   - y: The right operand.
/// - Returns: `value` is `x + y`; `deriv(v)` is the pair `(v, v)`.
func plusWithDeriv(_ x: Float, _ y: Float) -> (value: Float, deriv: (Float) -> (Float, Float)) {
    let sum = x + y
    // The incoming gradient is handed to each operand as its own copy of the value.
    let pullback: (Float) -> (Float, Float) = { upstream in
        return (upstream, upstream)
    }
    return (value: sum, deriv: pullback)
}

In [7]:
/// Computes the same value as `sin(x * x) + x * x` by chaining
/// `squareValueWithDeriv`, `sinValueWithDeriv` and `plusWithDeriv`, and
/// returns a gradient function assembled from their gradient functions.
///
/// - Parameter x: The input value.
/// - Returns: `value` is the result of the forward pass; `deriv(v)` runs the
///   backward pass seeded with `v` and returns the accumulated gradient for `x`.
func myComplexFunctionValueWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {
    // Forward pass: keep each step's result and its gradient function.
    let (squared, squarePullback) = squareValueWithDeriv(x)
    let (sine, sinePullback) = sinValueWithDeriv(squared)
    let (total, plusPullback) = plusWithDeriv(sine, squared)

    let pullback: (Float) -> Float = { seed in
        // One accumulator per forward value, each starting at zero.
        var grads = (x: Float(0.0), squared: Float(0.0), sine: Float(0.0), total: Float(0.0))

        // Walk the forward steps in reverse, adding each contribution.
        grads.total += seed
        let (towardSine, towardSquared) = plusPullback(grads.total)
        grads.sine += towardSine
        grads.squared += towardSquared
        // `squared` feeds both the sum and the sine, so it receives two contributions.
        grads.squared += sinePullback(grads.sine)
        grads.x += squarePullback(grads.squared)

        return grads.x
    }

    return (value: total, deriv: pullback)
}

In [8]:
// Helper: multiplies the argument by itself.
func square(_ x: Float) -> Float {
    var result = x
    result *= x
    return result
}

In [9]:
/// Returns `x * x`.
///
/// Declared with the `@differentiable` attribute; the notebook later passes
/// this function to `valueWithPullback(at:in:)`.
@differentiable
func myFunction(_ x: Float) -> Float {
    return x * x
}

In [10]:
// Call `valueWithPullback` on `myFunction` at the point 3 and destructure the
// result into `value` and `deriv`.
let (value, deriv) = valueWithPullback(at: 3, in: myFunction)
// The recorded notebook output for this cell is `9.0` followed by `(Float) -> Float`.
print(value)
print(type(of: deriv))

9.0
(Float) -> Float


In [11]:
// Invoke the `deriv` closure obtained above with 1; the recorded notebook output is `6.0`.
deriv(1)

6.0
