In [None]:
# Work up to a deep learning model from first principles.  Comment-heavy.

In [None]:
# These type definitions need to be in a module, so that they can be redefined without headache
module Layers
using IterTools
using Zygote
using Functors

"""
    Dense(σ, weights, bias)

Immutable layer that bundles the non-linear function `σ`, the weight matrix `weights`
and the bias `bias`. Each field has its own type parameter, so the struct is concretely
typed for whatever is stored in it.
"""
struct Dense{Fn, W<:AbstractMatrix, Bias}
    σ::Fn
    weights::W
    bias::Bias
end

# Register Dense with Functors so its fields can be traversed with `fmap`.
@functor Dense

"""
    Dense(x, y, σ=identity)

Convenience constructor: build a layer taking `x` inputs and producing `y` outputs.
Matrices only, not tensors, to keep things easy.

The weights are `rand(y, x)`, the bias is `ones(y, 1)` (not optional for now) and the
non-linearity `σ` defaults to `identity`. The result is built with the default
field-order constructor `Dense(σ, weights, bias)`.

The dimensions are restricted to `Integer` so that a malformed three-argument call
(for example one passing a non-matrix as weights) raises a `MethodError` at the call
site instead of landing in this method and failing inside `rand`.
"""
function Dense(x::Integer, y::Integer, σ=identity)
    return Dense(σ, rand(y, x), ones(y, 1))
end

# Dispatch on the layer type makes every Dense callable like a plain function, so the
# rest of the language can be used to compose layers.
function (layer::Dense)(x::AbstractVecOrMat)
    preactivation = layer.weights * x .+ layer.bias    # affine part
    return layer.σ.(preactivation)                     # element-wise non-linearity
end

# Implement the size protocol to get dimension checking.
# This extends `Base.size` rather than defining a fresh `size` local to the module:
# a module-local `size` shadows Base's, so calling `size` on a plain array anywhere
# inside this module would throw a MethodError.
# Returns the weight matrix dimensions in reverse order, i.e. `(inputs, outputs)`.
function Base.size(l::Dense)
    return reverse(size(l.weights))
end

# ---------- Compose Layers into chain ------
# Predicate over a pair of sizables: true when the second dimension of the first one
# equals the first dimension of the second one.
function conform(pair)
    upstream, downstream = pair
    return size(upstream)[2] == size(downstream)[1]
end

"""
    Chain(ls...)

Sequence of layers, stored as the tuple `layers`.

`layers` is constrained to a `Tuple` so that a single non-tuple argument, e.g.
`Chain(Dense(2, 3))`, does not match the default one-field constructor (which would
store the bare layer instead of a tuple of layers). Such a call now goes through the
varargs constructor and is stored as a one-element tuple.
"""
struct Chain{L<:Tuple}
    layers::L
end
@functor Chain

# Varargs constructor: checks every adjacent pair of layers with `conform` and errors
# out when a pair does not line up. With fewer than two layers there are no pairs to
# check, so the layers are accepted as they are.
function Chain(ls...)
    all(conform, partition(ls, 2, 1)) || error("Layers in chain do not conform, dummy")
    return Chain(ls)
end

# Treat the chain as a function so we keep FP abstractions.
# The layers are reversed and splatted into function composition: `f ∘ g` runs `g`
# first, so reversing makes the first layer run first.
# An empty chain has nothing to compose (`∘()` has no method), so it behaves as the
# identity and hands the input back unchanged.
function (c::Chain)(x::AbstractVecOrMat)
    isempty(c.layers) && return x
    pipeline = ∘(reverse(c.layers)...)
    return pipeline(x)
end


#-------- Training ----------
# Here we'll use automatic differentiation, because it's a primitive now :)
# Trick 1: Zygote will automatically seek out and differentiate with respect to the fields
#          of a struct (think: keys in a map), so because our layers are structs this
#          happens automatically
# Trick 2: Functors lets us traverse the chain and its gradient together to apply the update.
"""
    l2(ŷ, y)

Sum of the squared element-wise differences between `ŷ` and `y`.
"""
function l2(ŷ, y)
    residual = ŷ - y
    return sum(residual .^ 2)
end


"""
    train(x, y, loss, chain, η=0.01f0; iterations=1000)

Optimise `chain` by plain gradient descent so that `chain(x)` approaches `y`, and
return the optimised chain. A new chain is built on every step; the argument itself
is not modified.

# Arguments
- `x`: observed input, fed to `chain`.
- `y`: target for `chain(x)`.
- `loss`: called as `loss(ŷ, y)`, prediction first, matching how `l2` is declared.
- `chain`: the model; it is called on `x` and traversed with `fmap`.
- `η`: step size along the gradient.

# Keywords
- `iterations`: number of gradient steps to take (default 1000).
"""
function train(x, y, loss, chain, η=0.01f0; iterations::Integer=1000)
    for _ in 1:iterations
        # Gradient of the loss with respect to every parameter in the chain.
        # The prediction is passed first; with the arguments the other way round only
        # symmetric losses give the right answer.
        ∇ = gradient(ch -> loss(ch(x), y), chain)[1]

        # Use Functors to traverse the chain and its gradient together and move each
        # parameter one step along the gradient.
        chain = fmap(chain, ∇) do p, dp
            isnothing(p) && return dp     # nothing stored here: pass the gradient entry through
            isnothing(dp) && return p     # no gradient for this leaf: keep it unchanged
            return p .- η .* dp
        end
    end

    return chain
end
end

In [None]:
# Run the thing

ẋ = randn(2, 1)       # Random input... we want the mapping from this to ẏ
ẏ = [3, -1, 1003]     # Give ourselves some target to hit

# The first layer gets a logistic activation. Dense broadcasts it over its output,
# so it is written here for a single number.
c = Layers.Chain(
    Layers.Dense(2, 3, x -> 1 / (1 + exp(-x))),
    Layers.Dense(3, 3),
)

ĉ = Layers.train(ẋ, ẏ, Layers.l2, c)    # Train an optimal chain

ĉ(ẋ)                                    # Look at it