# Packages

In [1]:
using LinearAlgebra: diagm

In [2]:
using Pkg
Pkg.add("MLDatasets")
Pkg.add("Flux")

[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Manifest.toml`


In [3]:
include("graph.jl")

backward (generic function with 9 methods)

---
# Dataset

In [4]:
using MLDatasets
train_data = MLDatasets.MNIST(split=:train)
test_data  = MLDatasets.MNIST(split=:test)

using Flux
function loader(data; batchsize::Int=1)
    x1dim = reshape(data.features, 28 * 28, :) # reshape 28×28 pixels into a vector of pixels
    yhot  = Flux.onehotbatch(data.targets, 0:9) # make a 10×60000 OneHotMatrix
    Flux.DataLoader((x1dim, yhot); batchsize, shuffle=true)
end

train_loader = loader(train_data)
test_loader = loader(test_data)

10000-element DataLoader(::Tuple{Matrix{Float32}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, shuffle=true)
  with first element:
  (784×1 Matrix{Float32}, 10×1 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)

---
# Operations

## Tanh

In [5]:
import Base: tanh

tanh(x::GraphNode) = BroadcastedOperator(tanh, x)
forward(::BroadcastedOperator{typeof(tanh)}, x) = return tanh.(x)
backward(node::BroadcastedOperator{typeof(tanh)}, x, g) = return tuple((1 .- tanh.(x).^2) .* g)

backward (generic function with 10 methods)

## Softmax

In [6]:
Softmax(x::GraphNode) = BroadcastedOperator(Softmax, x)
forward(::BroadcastedOperator{typeof(Softmax)}, x) = return exp.(x) ./ sum(exp.(x))
backward(node::BroadcastedOperator{typeof(Softmax)}, x, g) = let
    y = node.output
    J = diagm(y) .- y * y'
    tuple(J' * g)
end

backward (generic function with 11 methods)

## Log

In [7]:
import Base: log
Base.Broadcast.broadcasted(log, x::GraphNode) = BroadcastedOperator(log, x)
forward(::BroadcastedOperator{typeof(log)}, x) = return log.(x)
backward(::BroadcastedOperator{typeof(log)}, x, g) = let
    tuple(g ./ x)
end

backward (generic function with 12 methods)

---
# Network

## Hyperparameters

In [8]:
INPUT_SIZE = 196
HIDDEN_SIZE = 64
OUTPUT_SIZE = 10

10

## Weight init

I tried glorot, like so:

In [9]:
function glorot(size_a, size_b)
    total_size = size_a + size_b
    denum = sqrt(total_size)
    return randn(size_a, size_b) * sqrt(sqrt(6)/sqrt(denum)) .-= sqrt(6)/denum
end

function glorot(size_a)
    denum = sqrt(size_a)
    return randn(size_a) * sqrt(sqrt(6)/sqrt(denum)) .-= sqrt(6)/denum
end

glorot (generic function with 2 methods)

In [10]:
Wi = Variable(glorot(HIDDEN_SIZE, INPUT_SIZE), name="wi")
Wh = Variable(glorot(HIDDEN_SIZE, HIDDEN_SIZE), name="wh")
Wo = Variable(glorot(OUTPUT_SIZE, HIDDEN_SIZE), name="wo")

Bh = Variable(glorot(HIDDEN_SIZE), name="Bh")
Bo = Variable(glorot(OUTPUT_SIZE), name="Bh")

var Bh
 ┣━ ^ 10-element Vector{Float64}
 ┗━ ∇ Nothing

But *not* using glorot yields better results.

In [11]:
bound = 1/sqrt(HIDDEN_SIZE)  # read somewhere that this is a good way to init weights for tanh

Wi = Variable(bound .* randn(HIDDEN_SIZE, INPUT_SIZE), name="wi")
Wh = Variable(bound .* randn(HIDDEN_SIZE, HIDDEN_SIZE), name="wh")
Wo = Variable(bound .* randn(OUTPUT_SIZE, HIDDEN_SIZE), name="wo")

Bh = Variable(bound .* randn(HIDDEN_SIZE), name="Bh")
Bo = Variable(bound .* randn(OUTPUT_SIZE), name="Bo")

var Bo
 ┣━ ^ 10-element Vector{Float64}
 ┗━ ∇ Nothing

## Loss function

In [12]:
function cross_entropy_loss(prediction, label)
    return Constant(-1) .* sum(Variable(label) .* log.(prediction))
end

cross_entropy_loss (generic function with 1 method)

## Network

In [13]:
function net(sample, input_weights, hidden_weights, output_weights, label)
    i_1 = Variable(sample[1:196], name="first_step_input")
    i_2 = Variable(sample[197:392], name="second_step_input")
    i_3 = Variable(sample[393:588], name="third_step_input")
    i_4 = Variable(sample[589:784], name="fourth_step_input")

    s_1 = tanh(Wi * i_1 .+ Bh)
    s_1.name = "s_1"
    s_2 = tanh(Wi * i_2 .+ Wh * s_1 .+ Bh)
    s_2.name = "s_2"
    s_3 = tanh(Wi * i_3 .+ Wh * s_2 .+ Bh)
    s_3.name = "s_3"
    s_4 = tanh(Wi * i_4 .+ Wh * s_3 .+ Bh)
    s_4.name = "s_4"
    prediction = Softmax(Wo * s_4 .+ Bo)
    prediction.name = "prediction"

    E = cross_entropy_loss(prediction, label)
    E.name = "loss"

    return topological_sort(E), prediction
end

net (generic function with 1 method)

---
# Training

### Hyperparameters setup:

In [14]:
STEP_SIZE = 0.5  # this yields better accuracy than 15e-3
# STEP_SIZE = 15e-3
EPOCHS = 5
BATCH_SIZE = 100

100

### Main loop:

In [15]:
for epoch_index in range(start=1, stop=EPOCHS)
    Wi_grad_agg = zeros(HIDDEN_SIZE, INPUT_SIZE)
    Wh_grad_agg = zeros(HIDDEN_SIZE, HIDDEN_SIZE)
    Wo_grad_agg = zeros(OUTPUT_SIZE, HIDDEN_SIZE)

    Bh_grad_agg = zeros(HIDDEN_SIZE)
    Bo_grad_agg = zeros(OUTPUT_SIZE)
    
    # for (index, (s, l)) in enumerate(train_loader)
    @time for (index, (s, l)) in enumerate(train_loader)
        graph, predicted = net(s, Wi, Wh, Wo, l)
        forward!(graph)
        backward!(graph)
    
        Wi_grad_agg .+= Wi.gradient
        Wh_grad_agg .+= Wh.gradient
        Wo_grad_agg .+= Wo.gradient

        Bh_grad_agg .+= Bh.gradient
        Bo_grad_agg .+= Bo.gradient
        
        if index % BATCH_SIZE == 0
            Wi.output .-= ((Wi_grad_agg/BATCH_SIZE) * STEP_SIZE)
            Wh.output .-= ((Wh_grad_agg/BATCH_SIZE) * STEP_SIZE)
            Wo.output .-= ((Wo_grad_agg/BATCH_SIZE) * STEP_SIZE)
            Wi_grad_agg = zeros(HIDDEN_SIZE, INPUT_SIZE)
            Wh_grad_agg = zeros(HIDDEN_SIZE, HIDDEN_SIZE)
            Wo_grad_agg = zeros(OUTPUT_SIZE, HIDDEN_SIZE)

            Bh.output .-= ((Bh_grad_agg/BATCH_SIZE) * STEP_SIZE)
            Bo.output .-= ((Bo_grad_agg/BATCH_SIZE) * STEP_SIZE)
            Bh_grad_agg = zeros(HIDDEN_SIZE)
            Bo_grad_agg = zeros(OUTPUT_SIZE)
        end
    end
end

 20.396818 seconds (72.03 M allocations: 39.269 GiB, 8.84% gc time, 14.00% compilation time)
 18.065822 seconds (60.68 M allocations: 38.520 GiB, 9.63% gc time)
 17.540481 seconds (60.68 M allocations: 38.520 GiB, 9.40% gc time)
 17.218352 seconds (60.68 M allocations: 38.520 GiB, 8.72% gc time)
 17.524114 seconds (60.68 M allocations: 38.520 GiB, 8.94% gc time)


### Evaluation loop:

In [None]:
correct = 0
total = 0

@time for (index, (s, l)) in enumerate(test_loader)
    graph, predicted = net(s, Wi, Wh, Wo, l)
    forward!(graph)
    backward!(graph)
    if argmax(predicted.output) == argmax(l)[1]
        correct += 1
    end
    total += 1
end

println("\nTest accuracy:")
println(round(correct/total*100, digits=4), "%")

**Decent.**

---