In [8]:
using Flux: Optimise.update!
using Flux: Chain, Dense, params, crossentropy, onehotbatch,
            ADAM, train!, softmax
using Test

In [9]:
# Display the imported `update!` to confirm which generic function is in scope
# (the cell output reports it and its method count).
update!

update! (generic function with 3 methods)

In [10]:
# Data preparation
"""
    fizzbuzz(x::Int)

Return the FizzBuzz class of `x` as a string: `"fizzbuzz"` when `x` is divisible
by both 3 and 5, `"fizz"` when divisible by 3 only, `"buzz"` when divisible by 5
only, and `"else"` otherwise.
"""
function fizzbuzz(x::Int)
    by3 = iszero(rem(x, 3))
    by5 = iszero(rem(x, 5))

    # Most specific case first, so multiples of 15 are not reported as "fizz".
    by3 && by5 && return "fizzbuzz"
    by3 && return "fizz"
    by5 && return "buzz"
    return "else"
end

fizzbuzz (generic function with 1 method)

In [11]:
# Class names. Their order matters: it fixes the row order of the one-hot
# targets built with `onehotbatch(raw_y, LABELS)` and the meaning of the index
# returned by `argmax` in `deepbuzz`.
const LABELS = ["fizz", "buzz", "fizzbuzz", "else"];

# Sanity check: one representative input per class, listed in LABELS order.
@test map(fizzbuzz, [3, 5, 15, 98]) == LABELS

[32m[1mTest Passed[22m[39m

In [12]:
# Raw dataset: the integers 1 through 100 and their FizzBuzz class strings.
raw_x = 1:100;
raw_y = map(fizzbuzz, raw_x);

# Feature engineering
# Residues of `x` modulo 3, 5 and 15 (exactly the quantities the FizzBuzz
# class depends on), converted to floating point for the network.
features(x) = float.([x % 3, x % 5, x % 15])

# Batch form: column `i` is `features(vec(x)[i])`, giving a 3×N matrix (the
# 3-row layout the first `Dense(3, …)` layer consumes). `reduce(hcat, …)`
# replaces splatting N vectors into `hcat`, which scales poorly for large N.
features(x::AbstractArray) = reduce(hcat, map(features, vec(x)))

# Design matrix: column i holds the feature vector of raw_x[i].
X = raw_x |> features;
# Targets: one-hot encoding of raw_y over the classes in LABELS.
y = onehotbatch(raw_y, LABELS);

In [13]:
# Model: 3 input features -> 10 hidden units -> 4 class probabilities (softmax).
# NOTE(review): no activation is passed to the hidden `Dense`, so with Flux's
# default (`identity`) the two layers compose to a single affine map; consider
# e.g. `Dense(3, 10, relu)` if a nonlinear hidden layer was intended.
m = Chain(Dense(3, 10), Dense(10, 4), softmax)

# Cross-entropy between the predicted class probabilities and one-hot targets.
function loss(x, y)
    ŷ = m(x)
    return crossentropy(ŷ, y)
end

# Learning rate spelled out; 0.001 is the default shown in the cell output.
opt = ADAM(0.001)

ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}())

In [17]:
# Show the trainable parameter arrays collected from the model `m`.
params(m)

Params([Float32[0.023057418 -0.0779104 0.6689654; -0.23048396 -0.53844374 -0.26118523; … ; -0.19421501 -0.38229203 0.089005105; -0.5831439 0.2426148 -0.2227154], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.6195384 0.2286895 … 0.15670998 0.34346035; -0.5088699 -0.5411671 … -0.42071873 0.030781789; 0.05211034 0.38014564 … -0.40414646 -0.23240362; -0.28806746 -0.60102713 … 0.005193459 0.501395], Float32[0.0, 0.0, 0.0, 0.0]])

In [20]:
# Print every trainable array of `m`, one per line.
foreach(println, params(m))

Float32[0.023057418 -0.0779104 0.6689654; -0.23048396 -0.53844374 -0.26118523; 0.2748903 -0.5912878 -0.6728332; -0.40165436 0.42718706 0.5492706; -0.22545971 0.23024796 -0.55546314; -0.28587177 -0.31489873 -0.14026001; 0.33795744 0.5896619 0.4808923; 0.47051287 -0.42391908 -0.15348402; -0.19421501 -0.38229203 0.089005105; -0.5831439 0.2426148 -0.2227154]
Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
Float32[0.6195384 0.2286895 0.13969304 0.37691537 0.38964492 0.004673083 0.28477994 -0.28805906 0.15670998 0.34346035; -0.5088699 -0.5411671 -0.2862296 -0.46749154 0.051294032 0.48093313 -0.20001373 -0.5012951 -0.42071873 0.030781789; 0.05211034 0.38014564 -0.1538554 -0.02091431 -0.16248344 -0.053116284 0.19549423 0.43855277 -0.40414646 -0.23240362; -0.28806746 -0.60102713 0.022345578 0.37853628 -0.19736159 -0.39281538 0.3431777 0.59090185 0.005193459 0.501395]
Float32[0.0, 0.0, 0.0, 0.0]


In [21]:
# Helpers
"""
    deepbuzz(x)

Classify `x` with the model `m`. Return the predicted class string, or `x`
itself when the model predicts the `"else"` class (the usual FizzBuzz output).
"""
function deepbuzz(x)
    label = LABELS[argmax(m(features(x)))]
    # Compare by label rather than by a hard-coded index, so reordering LABELS
    # cannot silently break the "else" case.
    return label == "else" ? x : label
end

"""
    monitor(e)

Print epoch `e`, the current loss on the full dataset (rounded to 4 digits),
and the model's predictions for one representative input per class.
"""
function monitor(e)
    print("epoch $(lpad(e, 4)): loss = $(round(loss(X,y); digits=4)) | ")
    @show deepbuzz.([3, 5, 15, 98])
end

monitor (generic function with 1 method)

In [22]:
# Training: one full-batch optimiser step on (X, y) per epoch, with a progress
# report every 50 epochs.
for epoch in 0:1000
    train!(loss, params(m), [(X, y)], opt)
    epoch % 50 == 0 && monitor(epoch)
end

epoch    0: loss = 1.4572 | deepbuzz.([3, 5, 15, 98]) = Any[3, 5, "buzz", 98]
epoch   50: loss = 0.85 | deepbuzz.([3, 5, 15, 98]) = Any[3, 5, "buzz", 98]
epoch  100: loss = 0.6153 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "buzz", 98]
epoch  150: loss = 0.4991 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "buzz", 98]
epoch  200: loss = 0.4197 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "buzz", 98]
epoch  250: loss = 0.3593 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  300: loss = 0.311 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", "buzz", "fizzbuzz", 98]
epoch  350: loss = 0.2711 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", "buzz", "fizzbuzz", 98]
epoch  400: loss = 0.2375 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", "buzz", "fizzbuzz", 98]
epoch  450: loss = 0.2088 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", "buzz", "fizzbuzz", 98]
epoch  500: loss = 0.184 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", "buzz", "fizzbuzz", 98]
epoch  550: loss = 0.1625 | deepbuzz.([3, 5, 15, 98]

---

In [34]:
using Flux

In [36]:
# Inspect the type produced by calling the `Zygote.Params` constructor directly
# on the model; the cell output reports `Zygote.Params`.
typeof(Flux.Zygote.Params(m))

Zygote.Params

In [43]:
# Parameters of the first layer only (`m[1]`, the `Dense(3, 10)`); the following
# cells take gradients with respect to just these arrays.
ps = params(m[1])

Params([Float32[-0.35291705 0.31448567 0.56221163; -0.67652786 -0.22921392 -0.43120492; … ; -0.9367012 0.051641323 -0.029333679; -1.1013161 0.815074 -0.39341104], Float32[0.2141568, 0.6503017, 0.30419415, -0.276849, 0.2461206, 0.36051977, -0.41794553, 0.16720857, 0.37139475, -0.777117]])

In [40]:
# Gradients of the full-dataset loss with respect to the first-layer parameters.
gs = gradient(() -> loss(X, y), ps)

Grads(...)

In [44]:
# Print the gradient of each parameter array in `ps`, one per line (same style
# as the earlier parameter print-out). A bare `gs[p]` in a loop body is
# discarded, so the value has to be printed explicitly to be seen.
for p in ps
    println(gs[p])
end

In [45]:
# Display the current values of the first-layer parameters in `ps`.
ps

Params([Float32[-0.35291705 0.31448567 0.56221163; -0.67652786 -0.22921392 -0.43120492; … ; -0.9367012 0.051641323 -0.029333679; -1.1013161 0.815074 -0.39341104], Float32[0.2141568, 0.6503017, 0.30419415, -0.276849, 0.2461206, 0.36051977, -0.41794553, 0.16720857, 0.37139475, -0.777117]])

In [49]:
# Apply one optimiser step to exactly the parameters `gs` was computed for.
# `gs` came from `gradient(...)` over `ps` (the first layer only), so iterating
# over every layer of `m` would look up later-layer parameters that `gs` never
# tracked. `update!` modifies the arrays in `ps` in place.
update!(opt, ps, gs)

In [47]:
# Display `ps` again. `ps` is never reassigned, so any change from the earlier
# display reflects in-place updates of its arrays by the optimiser.
ps

Params([Float32[-0.35305926 0.3146432 0.5622069; -0.6767462 -0.22926393 -0.43124616; … ; -0.9372081 0.05188132 -0.029396448; -1.1015414 0.81542957 -0.3934069], Float32[0.2142149, 0.6507139, 0.30468586, -0.27705213, 0.24619964, 0.36074647, -0.418279, 0.16744277, 0.37222475, -0.7775639]])

In [50]:
# Display `ps` once more to observe the effect of a further in-place optimiser step.
ps

Params([Float32[-0.35318756 0.31478557 0.56220263; -0.6769431 -0.22930926 -0.43128335; … ; -0.9376668 0.052098405 -0.029453149; -1.1017461 0.8157507 -0.39340314], Float32[0.2142679, 0.65108687, 0.30513912, -0.27723724, 0.2462748, 0.36095285, -0.41858378, 0.1676554, 0.3729904, -0.7779734]])