# AWID RNN
#### Tomasz Mycielski

## External packages

In [1]:
using Pkg
Pkg.add("MLDatasets")
Pkg.add("Flux")

[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Manifest.toml`


---
## Boilerplate
(from lecture #4)

In [2]:
import Base: *
import LinearAlgebra: mul!, diagm

In [3]:
abstract type GraphNode end
abstract type Operator <: GraphNode end

struct Constant{T} <: GraphNode
    output :: T
end

mutable struct Variable <: GraphNode
    output :: Any
    gradient :: Any
    name :: String
    Variable(output; name="?") = new(output, nothing, name)
end

mutable struct ScalarOperator{F} <: Operator
    inputs :: Any
    output :: Any
    gradient :: Any
    name :: String
    ScalarOperator(fun, inputs...; name="?") = new{typeof(fun)}(inputs, nothing, nothing, name)
end

mutable struct BroadcastedOperator{F} <: Operator
    inputs :: Any
    output :: Any
    gradient :: Any
    name :: String
    BroadcastedOperator(fun, inputs...; name="?") = new{typeof(fun)}(inputs, nothing, nothing, name)
end

import Base: show, summary
show(io::IO, x::ScalarOperator{F}) where {F} = print(io, "op ", x.name, "(", F, ")");
show(io::IO, x::BroadcastedOperator{F}) where {F} = print(io, "op.", x.name, "(", F, ")");
show(io::IO, x::Constant) = print(io, "const ", x.output)
show(io::IO, x::Variable) = begin
    print(io, "var ", x.name);
    print(io, "\n ┣━ ^ "); summary(io, x.output)
    print(io, "\n ┗━ ∇ ");  summary(io, x.gradient)
end

function visit(node::GraphNode, visited, order)
    if node ∈ visited
    else
        push!(visited, node)
        push!(order, node)
    end
    return nothing
end
    
function visit(node::Operator, visited, order)
    if node ∈ visited
    else
        push!(visited, node)
        for input in node.inputs
            visit(input, visited, order)
        end
        push!(order, node)
    end
    return nothing
end

function topological_sort(head::GraphNode)
    visited = Set()
    order = Vector()
    visit(head, visited, order)
    return order
end

reset!(node::Constant) = nothing
reset!(node::Variable) = node.gradient = nothing
reset!(node::Operator) = node.gradient = nothing

compute!(node::Constant) = nothing
compute!(node::Variable) = nothing
compute!(node::Operator) = let
    node.output = forward(node, [input.output for input in node.inputs]...)
end

function forward!(order::Vector)
    for node in order
        compute!(node)
        reset!(node)
    end
    return last(order).output
end

update!(node::Constant, gradient) = nothing
update!(node::GraphNode, gradient) = if isnothing(node.gradient)
    node.gradient = gradient else node.gradient .+= gradient
end

function backward!(order::Vector; seed=1.0)
    result = last(order)
    result.gradient = seed
    #@assert length(result.output) == 1 "Gradient is defined only for scalar functions"
    for node in reverse(order)
        backward!(node)
    end
    return nothing
end

function backward!(node::Constant) end
function backward!(node::Variable) end
function backward!(node::Operator)
    inputs = node.inputs
    gradients = backward(node, [input.output for input in inputs]..., node.gradient)
    for (input, gradient) in zip(inputs, gradients)
        update!(input, gradient)
    end
    return nothing
end

backward! (generic function with 4 methods)

## Dataset

In [4]:
import MLDatasets
train_data = MLDatasets.MNIST(split=:train)
test_data  = MLDatasets.MNIST(split=:test)

import Flux
function loader(data; batchsize::Int=1)
    x1dim = reshape(data.features, 28 * 28, :) # reshape 28×28 pixels into a vector of pixels
    yhot  = Flux.onehotbatch(data.targets, 0:9) # make a 10×60000 OneHotMatrix
    Flux.DataLoader((x1dim, yhot); batchsize, shuffle=true)
end

loader (generic function with 1 method)

In [5]:
train_loader = loader(train_data)
test_loader = loader(test_data)

10000-element DataLoader(::Tuple{Matrix{Float32}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, shuffle=true)
  with first element:
  (784×1 Matrix{Float32}, 10×1 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)

---
## Implemented operations

### Multiplication

In [6]:
*(A::GraphNode, x::GraphNode) = BroadcastedOperator(mul!, A, x)
forward(::BroadcastedOperator{typeof(mul!)}, A, x) = return A * x
backward(::BroadcastedOperator{typeof(mul!)}, A, x, g) = tuple(g * x', A' * g)

Base.Broadcast.broadcasted(*, x::GraphNode, y::GraphNode) = BroadcastedOperator(*, x, y)
forward(::BroadcastedOperator{typeof(*)}, x, y) = return x .* y
backward(node::BroadcastedOperator{typeof(*)}, x, y, g) = let
    𝟏 = ones(length(node.output))
    Jx = diagm(y .* 𝟏)
    Jy = diagm(vec(x .* 𝟏))
    tuple(Jx' * g, Jy' * g)
end

backward (generic function with 2 methods)

### Addition

In [7]:
Base.Broadcast.broadcasted(+, x::GraphNode, y::GraphNode) = BroadcastedOperator(+, x, y)
forward(::BroadcastedOperator{typeof(+)}, x, y) = return x .+ y
backward(::BroadcastedOperator{typeof(+)}, x, y, g) = tuple(g, g)

backward (generic function with 3 methods)

### Summation

In [8]:
import Base: sum
sum(x::GraphNode) = BroadcastedOperator(sum, x)
forward(::BroadcastedOperator{typeof(sum)}, x) = return sum(x)
backward(::BroadcastedOperator{typeof(sum)}, x, g) = let
    𝟏 = ones(length(x))
    J = 𝟏'
    tuple(J' * g)
end

backward (generic function with 4 methods)

### Tanh

In [9]:
import Base: tanh

In [10]:
tanh(x::GraphNode) = BroadcastedOperator(tanh, x)
forward(::BroadcastedOperator{typeof(tanh)}, x) = return tanh.(x)
backward(node::BroadcastedOperator{typeof(tanh)}, x, g) = return tuple((1 .- tanh.(x).^2) .* g)

backward (generic function with 5 methods)

### Softmax

In [11]:
Softmax(x::GraphNode) = BroadcastedOperator(Softmax, x)
forward(::BroadcastedOperator{typeof(Softmax)}, x) = return exp.(x) ./ sum(exp.(x))
backward(node::BroadcastedOperator{typeof(Softmax)}, x, g) = let
    y = node.output
    J = diagm(y) .- y * y'
    tuple(J' * g)
end

backward (generic function with 6 methods)

### Log

In [12]:
Base.Broadcast.broadcasted(log, x::GraphNode) = BroadcastedOperator(log, x)
forward(::BroadcastedOperator{typeof(log)}, x) = return log.(x)
backward(::BroadcastedOperator{typeof(log)}, x, g) = let
    tuple(g ./ x)
end

backward (generic function with 7 methods)

---
## Network

### Hyperparameters

In [13]:
INPUT_SIZE = 196
HIDDEN_SIZE = 64
OUTPUT_SIZE = 10

10

### Weights init

In [14]:
function glorot(size_a, size_b)
    total_size = size_a + size_b
    denum = sqrt(total_size)
    return randn(size_a, size_b) .* ( sqrt(6)/denum )
end

glorot (generic function with 1 method)

In [15]:
Wi = Variable(glorot(HIDDEN_SIZE, INPUT_SIZE))
Wh = Variable(glorot(HIDDEN_SIZE, HIDDEN_SIZE))
Wo = Variable(glorot(OUTPUT_SIZE, HIDDEN_SIZE))

Bh = Variable(zeros(HIDDEN_SIZE))
Bo = Variable(zeros(OUTPUT_SIZE))

var ?
 ┣━ ^ 10-element Vector{Float64}
 ┗━ ∇ Nothing

### Loss function

In [16]:
function cross_entropy_loss(prediction, label)
    return Constant(-1) .* sum(Variable(label) .* log.(prediction))
end

cross_entropy_loss (generic function with 1 method)

### Network structure

In [17]:
function net(sample, input_weights, hidden_weights, output_weights, label)
    i_1 = Variable(sample[1:196], name="first_step_input")
    i_2 = Variable(sample[197:392], name="second_step_input")
    i_3 = Variable(sample[393:588], name="third_step_input")
    i_4 = Variable(sample[589:784], name="fourth_step_input")

    s_1 = tanh(Wi * i_1 .+ Bh)
    s_1.name = "s_1"
    s_2 = tanh(Wi * i_2 .+ Wh * s_1 .+ Bh)
    s_2.name = "s_2"
    s_3 = tanh(Wi * i_3 .+ Wh * s_2 .+ Bh)
    s_3.name = "s_3"
    s_4 = tanh(Wi * i_4 .+ Wh * s_3 .+ Bh)
    s_4.name = "s_4"
    prediction = Softmax(Wo * s_4 .+ Bo)
    prediction.name = "prediction"

    E = cross_entropy_loss(prediction, label)
    E.name = "loss"

    return topological_sort(E), prediction
end

net (generic function with 1 method)

---
## Model

### Hyperparameters

In [18]:
STEP_SIZE = 15e-3
EPOCHS = 5
BATCH_SIZE = 100

100

### Training loop

In [19]:
for epoch_index in range(start=1, stop=EPOCHS)
    Wi_grad_agg = zeros(HIDDEN_SIZE, INPUT_SIZE)
    Wh_grad_agg = zeros(HIDDEN_SIZE, HIDDEN_SIZE)
    Wo_grad_agg = zeros(OUTPUT_SIZE, HIDDEN_SIZE)

    Bh_grad_agg = zeros(HIDDEN_SIZE)
    Bo_grad_agg = zeros(OUTPUT_SIZE)
    
    @time for (index, (s, l)) in enumerate(train_loader)
        graph, predicted = net(s, Wi, Wh, Wo, l)
        forward!(graph)
        backward!(graph)
    
        Wi_grad_agg .+= Wi.gradient
        Wh_grad_agg .+= Wh.gradient
        Wo_grad_agg .+= Wo.gradient

        Bh_grad_agg .+= Bh.gradient
        Bo_grad_agg .+= Bo.gradient
        
        if index % BATCH_SIZE == 0
            Wi.output .-= ((Wi_grad_agg/BATCH_SIZE) * STEP_SIZE)
            Wh.output .-= ((Wh_grad_agg/BATCH_SIZE) * STEP_SIZE)
            Wo.output .-= ((Wo_grad_agg/BATCH_SIZE) * STEP_SIZE)
            Wi_grad_agg = zeros(HIDDEN_SIZE, INPUT_SIZE)
            Wh_grad_agg = zeros(HIDDEN_SIZE, HIDDEN_SIZE)
            Wo_grad_agg = zeros(OUTPUT_SIZE, HIDDEN_SIZE)

            Bh.output .-= ((Bh_grad_agg/BATCH_SIZE) * STEP_SIZE)
            Bo.output .-= ((Bo_grad_agg/BATCH_SIZE) * STEP_SIZE)
            Bh_grad_agg = zeros(HIDDEN_SIZE)
            Bo_grad_agg = zeros(OUTPUT_SIZE)
        end
    end
end

 20.519057 seconds (72.21 M allocations: 39.282 GiB, 9.28% gc time, 14.24% compilation time)
 17.402250 seconds (60.71 M allocations: 38.523 GiB, 9.42% gc time)
 17.884341 seconds (60.71 M allocations: 38.522 GiB, 9.63% gc time)
 17.410530 seconds (60.71 M allocations: 38.523 GiB, 9.05% gc time)
 17.343649 seconds (60.71 M allocations: 38.523 GiB, 9.02% gc time)


- 17.5 seconds per epoch
- 61.68 milion allocations (38.5 GiB) per epoch
- 8% gc time

Total training time is a minute and a half.

### Testing loop

In [20]:
correct = 0
total = 0

@time for (index, (s, l)) in enumerate(test_loader)
    graph, predicted = net(s, Wi, Wh, Wo, l)
    forward!(graph)
    backward!(graph)
    if argmax(predicted.output) == argmax(l)[1]
        correct += 1
    end
    total += 1
end

println("\nTest accuracy:")
println(round(correct/total*100, digits=4), "%")

  3.047552 seconds (10.21 M allocations: 6.389 GiB, 12.09% gc time, 1.25% compilation time)

Test accuracy:
91.9%


**Decent**, but I've seen it get as high as $96.5\%$.