# Flux for ML/RL

This notebook lays out some basics of [Flux.jl](https://fluxml.ai). Unfortunately, Flux currently takes quite a while to precompile and load. This is being worked on, but I recommend evaluating the next cell before digging into the text that follows. Loading will be faster after the first time, as a cached precompiled version of the library will be available (much like Plots from before).


In [1]:
using Flux, Random, Statistics, BenchmarkTools

## What is Flux?

Flux is a deep learning framework that uses source-to-source automatic differentiation through Zygote.jl. The resulting library is incredibly flexible and can differentiate through many Julia functions right out of the box. The benefit of this is that _all_ of Flux's models are written in pure Julia (even GPU operations!), and the library can take full advantage of multiple dispatch. We will discuss the nice features that come from this down the road, but first let's start with a simple example (artificial regression with a small feed-forward network). We can then move on to how Flux can be used in RL research.
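As a brief aside on the differentiation claim above, here is a minimal sketch, assuming only that `gradient` is available after `using Flux` (Flux re-exports it from Zygote). The helper `pow` is a hypothetical name introduced for illustration; it is ordinary Julia with a loop and contains nothing Flux-specific.

```julia
using Flux  # provides `gradient` (re-exported from Zygote)

# Hypothetical helper: x^n computed with an ordinary loop.
function pow(x, n)
    r = one(x)
    for _ in 1:n
        r *= x
    end
    return r
end

# `gradient` returns a tuple with one entry per argument of the function.
dpow = gradient(x -> pow(x, 3), 5.0)[1]

# Analytically d/dx x^3 = 3x^2, so the two values should agree at x = 5.
expected = 3 * 5.0^2
@show dpow expected
```

Nothing about `pow` was written with differentiation in mind, which is the point: the same mechanism is what lets Flux layers and losses be plain Julia functions.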


Because our problem is artificial, we will need to create a dataset.

In [2]:

Random.seed!(10293)

train_points = 2^14
val_points = 2^9
feature_size = 10
ϵ = 0.01f0

target_model = Chain(Dense(feature_size, 256, relu), Dense(256, 1)) # These layers default to using the global random seed!

X_train = randn(Float32, feature_size, train_points)
Y_train = target_model(X_train) + ϵ*randn(Float32, train_points)'

X_val = randn(Float32, feature_size, val_points)
Y_val = target_model(X_val) + ϵ*randn(Float32, val_points)'

1×512 Array{Float32,2}:
 0.0671812  0.210611  -0.212441  …  0.173896  -0.878224  -0.00244423

Now with the dataset created, we will set up a model and run a simple training loop with mini-batch gradient descent. We will decompose some of the Flux primitives afterwards.

In [6]:
batchsize = 64
opt = Descent(0.01)

model = Chain(Dense(feature_size, 64, relu), Dense(64, 1))
loss(x, y) = Flux.mse(model(x), y)

println("Initial:")
@show loss(X_train, Y_train)
@show loss(X_val, Y_val)
println()

for n ∈ 1:100
    train_loader = Flux.Data.DataLoader(X_train, Y_train, batchsize=batchsize, shuffle=true)
    Flux.train!(
        loss, Flux.params(model), train_loader, opt)
    if (n) % 10 == 0
        println("Epoch: $(n)")
        @show loss(X_train, Y_train)
        @show loss(X_val, Y_val)
        println()
    end
end



Initial:
loss(X_train, Y_train) = 0.33594036f0
loss(X_val, Y_val) = 0.36127582f0

Epoch: 10
loss(X_train, Y_train) = 0.004389817f0
loss(X_val, Y_val) = 0.0050601633f0

Epoch: 20
loss(X_train, Y_train) = 0.0036121837f0
loss(X_val, Y_val) = 0.0042091855f0

Epoch: 30
loss(X_train, Y_train) = 0.003174251f0
loss(X_val, Y_val) = 0.0037046995f0

Epoch: 40
loss(X_train, Y_train) = 0.0028663334f0
loss(X_val, Y_val) = 0.003357669f0

Epoch: 50
loss(X_train, Y_train) = 0.0026439303f0
loss(X_val, Y_val) = 0.0031033668f0

Epoch: 60
loss(X_train, Y_train) = 0.0024611754f0
loss(X_val, Y_val) = 0.0028887577f0

Epoch: 70
loss(X_train, Y_train) = 0.002320534f0
loss(X_val, Y_val) = 0.0027234575f0

Epoch: 80
loss(X_train, Y_train) = 0.0022109547f0
loss(X_val, Y_val) = 0.0026023488f0

Epoch: 90
loss(X_train, Y_train) = 0.0021132152f0
loss(X_val, Y_val) = 0.002481657f0

Epoch: 100
loss(X_train, Y_train) = 0.0020341163f0
loss(X_val, Y_val) = 0.0023871385f0



## Custom Training Loop

The first piece we need to decompose is the training loop. In the above example we are using Flux's built-in `train!` function. The beauty of Julia and Flux is that this is all written in Julia, meaning we can customize our training loop without any extra computational cost. While this matters less for supervised ML, for RL it is critical, as the training loop contains interactions with the environment and various other processing and bookkeeping steps.



In [28]:
function cust_train!(loss::Function, m, ps, data, opt)
    for d in data
        gs = gradient(ps) do
            training_loss = loss(m, d...)
            # Insert whatever code you want here that needs the training loss, e.g. logging
            return training_loss
        end
        # Insert whatever code you want here that needs the gradient,
        # e.g. logging with TensorBoardLogger.jl as a histogram so you can see if it is becoming huge
        Flux.Optimise.update!(opt, ps, gs)
        # Here you might like to check validation set accuracy, and break out to do early stopping
    end
end

cust_train! (generic function with 1 method)
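To make the two steps inside `cust_train!` concrete, here is a single hand-rolled iteration on a tiny model. This is a sketch, assuming the implicit-parameter style (`Flux.params` plus `gradient(ps) do ... end`) used in this notebook; the data and the step size `0.01f0` are made up for illustration, and the manual update stands in for what `update!` does with a plain gradient descent optimiser.

```julia
using Flux

m = Dense(2, 1)                 # tiny model, just for illustration
ps = Flux.params(m)             # implicit parameters: weight matrix and bias vector
x = Float32[1 2; 3 4]           # 2 features × 2 samples
y = Float32[1 0]                # 1 × 2 targets

loss_before = Flux.mse(m(x), y)

# Same pattern as inside `cust_train!`: differentiate the loss w.r.t. `ps`.
gs = gradient(ps) do
    Flux.mse(m(x), y)
end

# `gs[p]` is the gradient for the parameter array `p`, with matching size.
for p in ps
    @assert size(gs[p]) == size(p)
    p .-= 0.01f0 .* gs[p]       # plain gradient descent step, in place
end

loss_after = Flux.mse(m(x), y)
@show loss_before loss_after
```

Because `ps` holds references to the model's own arrays, the in-place update changes `m` itself, which is why the second loss evaluation sees the new parameters.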

## Custom Layer

Just like the training loop, all of Flux's layers are written in Julia. Below is an example of a Dense layer, but there are plenty of other layers (all written in Julia) found [here](https://github.com/FluxML/Flux.jl/tree/master/src/layers).

In [35]:
struct CustDense{S, B, F}
    W::S
    b::B
    σ::F
end

CustDense(W, b) = CustDense(W, b, identity)

function CustDense(in::Integer, out::Integer, σ = identity;
               initW = Flux.glorot_uniform, initb = Flux.zeros)
    return CustDense(initW(out, in), initb(out), σ)
end

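# Forward pass: affine map followed by the elementwise activation
(l::CustDense)(X) = l.σ.(l.W*X .+ l.b)
# Register the struct with Flux so that Flux.params can find W and b
Flux.@functor CustDense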
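A quick way to check that a custom layer is wired up correctly is to confirm that its parameters are found and that gradients come back with sensible values. The sketch below assumes the same Flux version as this notebook (`Flux.@functor` and `Flux.params`); `Affine` is a hypothetical toy layer that follows the `CustDense` pattern without the activation, so the expected gradients are easy to reason about.

```julia
using Flux

# Hypothetical toy layer: same pattern as CustDense, minus the activation.
struct Affine{S, B}
    W::S
    b::B
end
(l::Affine)(X) = l.W * X .+ l.b
Flux.@functor Affine            # lets Flux find W and b inside the struct

l = Affine(randn(Float32, 3, 2), zeros(Float32, 3))
X = Float32[1 2 3 4; 1 0 -1 2]  # 2 features × 4 samples

ps = Flux.params(l)
gs = gradient(() -> sum(l(X)), ps)

@show length(ps)                # W and b
@show size(l(X))                # (outputs, samples)
@show gs[l.b]                   # each bias entry is added once per sample
```

Since the bias is broadcast across all four columns of `X`, its gradient under `sum` counts the samples, which makes this a convenient sanity check for the broadcasting in the forward pass.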


## Custom Optimiser

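Again, there are plenty of optimisers built into Flux, all written in Julia. Below is plain gradient descent as a custom optimiser: `apply!` scales the gradient `Δ` in place by the step size `eta`.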

In [38]:
struct CustDescent
  eta::Float64
end

CustDescent() = CustDescent(0.1)

function Flux.Optimise.apply!(o::CustDescent, x, Δ)
  Δ .*= o.eta
end
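The contract here is that `apply!` turns a raw gradient into the step to take, and the caller (`update!`, in the Flux version this notebook appears to use) subtracts that step from the parameters. The sketch below mirrors that contract in plain Julia so it can be run without Flux; `ToyDescent`, `toy_apply!`, and `toy_update!` are hypothetical stand-ins, not Flux functions.

```julia
# Hypothetical stand-ins that mirror the optimiser contract (no Flux needed).
struct ToyDescent
    eta::Float64
end

# Rescale the gradient in place and return it, like `apply!` above.
toy_apply!(o::ToyDescent, x, Δ) = (Δ .*= o.eta)

# Subtract whatever `toy_apply!` returns from the parameters, in place.
toy_update!(o, x, Δ) = (x .-= toy_apply!(o, x, Δ))

x = [1.0, 2.0, 3.0]      # parameters
Δ = [10.0, 10.0, 10.0]   # pretend gradient

toy_update!(ToyDescent(0.1), x, Δ)
@show x Δ                # x moved by eta * gradient; Δ now holds the step
```

Note that the gradient buffer is overwritten, so anything that needs the raw gradient (logging, clipping) has to read it before the update.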

## Putting it all together

In [40]:
cust_model = Chain(CustDense(feature_size, 64, relu), CustDense(64, 1))
opt = CustDescent(0.01)

println("Initial:")
@show Flux.mse(cust_model(X_train), Y_train)
@show Flux.mse(cust_model(X_val), Y_val)
println()

for n ∈ 1:100
    train_loader = Flux.Data.DataLoader(X_train, Y_train, batchsize=batchsize, shuffle=true)
    cust_train!(cust_model, Flux.params(cust_model), train_loader, opt) do m, X, Y
        Flux.mse(m(X), Y)
    end
    if (n) % 10 == 0
        println("Epoch: $(n)")
        @show Flux.mse(cust_model(X_train), Y_train)
        @show Flux.mse(cust_model(X_val), Y_val)
        println()
    end
end


Initial:
Flux.mse(cust_model(X_train), Y_train) = 0.27601916f0
Flux.mse(cust_model(X_val), Y_val) = 0.2873298f0

Epoch: 10
Flux.mse(cust_model(X_train), Y_train) = 0.003711562f0
Flux.mse(cust_model(X_val), Y_val) = 0.003937993f0

Epoch: 20
Flux.mse(cust_model(X_train), Y_train) = 0.0030475203f0
Flux.mse(cust_model(X_val), Y_val) = 0.0032307252f0

Epoch: 30
Flux.mse(cust_model(X_train), Y_train) = 0.0027394698f0
Flux.mse(cust_model(X_val), Y_val) = 0.0029539869f0

Epoch: 40
Flux.mse(cust_model(X_train), Y_train) = 0.0025184238f0
Flux.mse(cust_model(X_val), Y_val) = 0.002762441f0

Epoch: 50
Flux.mse(cust_model(X_train), Y_train) = 0.0023464428f0
Flux.mse(cust_model(X_val), Y_val) = 0.0026077998f0

Epoch: 60
Flux.mse(cust_model(X_train), Y_train) = 0.0022084129f0
Flux.mse(cust_model(X_val), Y_val) = 0.0024781355f0

Epoch: 70
Flux.mse(cust_model(X_train), Y_train) = 0.002097625f0
Flux.mse(cust_model(X_val), Y_val) = 0.002386366f0

Epoch: 80
Flux.mse(cust_model(X_train), Y_train) = 0.002007