# Structures

In [None]:
# Root of the node hierarchy: every element of the computation graph is a GraphNode.
abstract type GraphNode end
# Graph nodes that carry `inputs` and derive their output from them via `forward`
# (see `compute!`); they propagate gradients to those inputs via `backward`.
abstract type Operator <: GraphNode end

# Immutable leaf node wrapping a fixed value of type `T`.
# It has no gradient field: `reset!`, `compute!`, `update!` and `backward!`
# are all no-ops for constants.
struct Constant{T} <: GraphNode
    # The wrapped value, read by operators as `input.output`.
    output :: T
end

"""
    Variable(output; name="?")

Mutable leaf node holding a value that can receive gradients.

# Fields
- `output`: the current value of the variable.
- `gradient`: last gradient written by `update!`; starts as `nothing` and is cleared
  again by `reset!`.
- `name`: free-form label, `"?"` by default.
- `batch_gradient`: running sum of gradients written by `update!`; starts as `nothing`.
"""
mutable struct Variable <: GraphNode
    output::Any
    gradient::Any
    name::String
    batch_gradient::Any

    function Variable(output; name="?")
        return new(output, nothing, name, nothing)
    end
end

"""
    ScalarOperator(fun, inputs...; name="?")

Operator node whose type parameter `F` is `typeof(fun)`.

# Fields
- `inputs`: tuple of the nodes passed as `inputs...`.
- `output`: starts as `nothing`; filled in by `compute!`.
- `gradient`: starts as `nothing`.
- `name`: free-form label, `"?"` by default.

NOTE(review): no `forward`/`backward` methods for `ScalarOperator` are visible in this
file; presumably `F` is meant for dispatch the same way as in `BroadcastedOperator`.
"""
mutable struct ScalarOperator{F} <: Operator
    inputs::Any
    output::Any
    gradient::Any
    name::String

    function ScalarOperator(fun, inputs...; name="?")
        return new{typeof(fun)}(inputs, nothing, nothing, name)
    end
end

"""
    BroadcastedOperator(fun, inputs...; name="?")

Operator node whose type parameter `F` is `typeof(fun)`, so that `forward` and
`backward` methods can be selected per operation by dispatching on
`BroadcastedOperator{typeof(fun)}`.

# Fields
- `inputs`: tuple of the nodes passed as `inputs...`.
- `output`: starts as `nothing`; filled in by `compute!`.
- `gradient`: starts as `nothing`.
- `name`: free-form label, `"?"` by default.
- `cache`: starts as `nothing`; not read or written by the constructor.
"""
mutable struct BroadcastedOperator{F} <: Operator
    inputs::Any
    output::Any
    gradient::Any
    name::String
    cache::Any

    function BroadcastedOperator(fun, inputs...; name="?")
        return new{typeof(fun)}(inputs, nothing, nothing, name, nothing)
    end
end

# Graph builder

In [2]:
"""
    visit(node::GraphNode, visited, order)

Record a node in the traversal: unless `node` is already in `visited`, add it to
`visited` and append it to `order`. Always returns `nothing`.
"""
function visit(node::GraphNode, visited, order)
    # Already seen: nothing to record.
    node ∈ visited && return nothing
    push!(visited, node)
    push!(order, node)
    return nothing
end
    
"""
    visit(node::Operator, visited, order)

Depth-first visit of an operator: mark `node` as visited, visit each of its inputs,
then append `node` to `order`, so every input precedes the operator that consumes it.
Does nothing if `node` is already in `visited`. Always returns `nothing`.
"""
function visit(node::Operator, visited, order)
    node ∈ visited && return nothing
    # Mark before descending so a node reached twice is only expanded once.
    push!(visited, node)
    foreach(input -> visit(input, visited, order), node.inputs)
    push!(order, node)
    return nothing
end

"""
    topological_sort(head::GraphNode)

Return the nodes reachable from `head` as a `Vector{Any}` in the order produced by
`visit`, with `head` as the last element when `head` is an operator or a leaf.
"""
function topological_sort(head::GraphNode)
    ordering = Any[]
    seen = Set{Any}()
    visit(head, seen, ordering)
    return ordering
end

topological_sort (generic function with 1 method)

# Forward pass

In [3]:
"""
    reset!(node)

Clear the gradient stored on `node`. Constants carry no gradient, so this is a no-op
for them; for variables and operators `node.gradient` is set to `nothing`.
"""
reset!(::Constant) = nothing

function reset!(node::Variable)
    node.gradient = nothing
    return nothing
end

function reset!(node::Operator)
    node.gradient = nothing
    return nothing
end

"""
    compute!(node)

Evaluate `node` during the forward pass. Leaves (`Constant`, `Variable`) already hold
their value, so nothing happens. For an operator, `forward` is called with the node
and the `output` of each input; the result is stored in `node.output` and returned.
"""
compute!(::Constant) = nothing
compute!(::Variable) = nothing

function compute!(node::Operator)
    arguments = map(input -> input.output, node.inputs)
    result = forward(node, arguments...)
    node.output = result
    return result
end

"""
    forward!(order::Vector)

Run the forward pass over `order` (expected to be topologically sorted): each node is
computed and then has its gradient reset. Returns the `output` of the last node.
"""
function forward!(order::Vector)
    foreach(order) do node
        compute!(node)
        reset!(node)
    end
    return last(order).output
end

forward! (generic function with 1 method)

# Backward pass

In [4]:
"""
    update!(node, gradient)

Deliver `gradient` to `node` during the backward pass.

- `Constant`: ignored.
- Any other `GraphNode`: `node.gradient` is overwritten with `gradient`.
- `Variable`: additionally, `gradient` is added to `node.batch_gradient`, which
  accumulates contributions until the optimiser consumes them.

NOTE(review): `node.gradient` is overwritten, not accumulated, so a non-variable node
consumed by several operators only keeps the last contribution it received.
"""
update!(node::Constant, gradient) = nothing

function update!(node::GraphNode, gradient)
    node.gradient = gradient
    return nothing
end

function update!(node::Variable, gradient)
    node.gradient = gradient
    if isnothing(node.batch_gradient)
        # Copy so that later in-place accumulation never mutates an array that is
        # still referenced as a gradient elsewhere in the graph.
        node.batch_gradient = copy(gradient)
    else
        node.batch_gradient .+= gradient
    end
    return nothing
end

"""
    backward!(order::Vector; seed=1.0)

Run the backward pass over a topologically sorted `order`.

The last node is treated as the result: its output must have length 1, its gradient is
set to `seed`, and `backward!(node)` is then called on every node from last to first.
Returns `nothing`.

Fails with an `AssertionError` if the result's output does not have length 1; the
check happens before any gradient is written.
"""
function backward!(order::Vector; seed=1.0)
    result = last(order)
    # Validate before mutating so a failed call leaves the graph untouched.
    @assert length(result.output) == 1 "Gradient is defined only for scalar functions"
    result.gradient = seed
    # Lazy reverse iteration avoids copying the whole node vector on every pass.
    for node in Iterators.reverse(order)
        backward!(node)
    end
    return nothing
end

"""
    backward!(node)

Propagate the gradient stored on `node` to its inputs. Leaves (`Constant`, `Variable`)
have no inputs, so nothing happens. For an operator, `backward` is called with the
node, the `output` of each input and `node.gradient`; the returned gradients are paired
with the inputs in order and handed to `update!`. Returns `nothing`.
"""
backward!(::Constant) = nothing
backward!(::Variable) = nothing

function backward!(node::Operator)
    operands = node.inputs
    operand_values = map(operand -> operand.output, operands)
    operand_gradients = backward(node, operand_values..., node.gradient)
    # `zip` stops at the shorter collection, so surplus gradients are ignored.
    for (operand, operand_gradient) in zip(operands, operand_gradients)
        update!(operand, operand_gradient)
    end
    return nothing
end

backward! (generic function with 4 methods)

# Operators

In [5]:
import Base: sum

"""
    rnnCell(U, W, h, b, x)

Build an operator node for one recurrent step, `tanh.(U * x + W * h + b)`, where `U`
multiplies the input `x`, `W` multiplies the previous hidden state `h`, and `b` is the
bias.
"""
function rnnCell(U::GraphNode, W::GraphNode, h::GraphNode, b::GraphNode, x::GraphNode)
    return BroadcastedOperator(rnnCell, U, W, h, b, x)
end

# Forward pass: affine combination of input and previous state, squashed by tanh.
function forward(::BroadcastedOperator{typeof(rnnCell)}, U, W, h, b, x)
    preactivation = U * x + W * h + b
    return tanh.(preactivation)
end

# Backward pass: returns one gradient per input, in input order (U, W, h, b, x).
function backward(::BroadcastedOperator{typeof(rnnCell)}, U, W, h, b, x, g)
    # The activation is recomputed here; nothing is cached by the forward pass.
    activation = tanh.(U * x + W * h + b)

    # Gradient w.r.t. the pre-activation: d tanh(s) / ds = 1 - tanh(s)^2.
    delta = @. g * (1 - activation^2)

    grad_U = delta * x'
    grad_W = delta * h'
    grad_h = W' * delta
    grad_b = sum(delta, dims=2)
    grad_x = U' * delta

    return (grad_U, grad_W, grad_h, grad_b, grad_x)
end


"""
    dense(x, w)

Build an operator node computing the bias-free linear map `w * x`.
"""
dense(x::GraphNode, w::GraphNode) = BroadcastedOperator(dense, x, w)

# Forward pass: weight matrix times input.
forward(::BroadcastedOperator{typeof(dense)}, x, w) = w * x

# Backward pass: exactly one gradient per input, in input order (x, w).
# The operator has no bias input, so no third gradient is returned.
backward(::BroadcastedOperator{typeof(dense)}, x, w, g) = tuple(w' * g, g * x')

"""
    identity(x::GraphNode)

Build an operator node that passes its input through unchanged.

NOTE(review): `identity` is not imported from `Base` in the visible code, so this
presumably defines a separate function in the current module — confirm that is intended.
"""
function identity(x::GraphNode)
    return BroadcastedOperator(identity, x)
end

# Forward pass: the value is returned as-is.
function forward(::BroadcastedOperator{typeof(identity)}, x)
    return x
end

# Backward pass: the incoming gradient is handed straight to the single input.
function backward(::BroadcastedOperator{typeof(identity)}, x, g)
    return (g,)
end

"""
    cross_entropy_loss(y_hat, y)

Build an operator node computing the softmax cross-entropy between the raw scores
`y_hat` and the target `y`.
"""
function cross_entropy_loss(y_hat::GraphNode, y::GraphNode)
    return BroadcastedOperator(cross_entropy_loss, y_hat, y)
end

# Forward pass: returns `-sum(log_softmax(y_hat) .* y)` as a Float64-promoted scalar.
#
# Side effect: updates the global counters `predictions` (incremented on every call)
# and `correct_predictions` (incremented when the argmax of `y_hat` matches that of `y`).
function forward(::BroadcastedOperator{typeof(cross_entropy_loss)}, y_hat, y)
    global predictions
    global correct_predictions

    predictions += 1
    if argmax(y_hat) == argmax(y)
        correct_predictions += 1
    end

    # Log-softmax via log-sum-exp: subtracting the maximum keeps `exp` from
    # overflowing, and staying in log space avoids `log(0)` when a probability
    # underflows (which would otherwise yield -Inf or NaN in the loss).
    shifted = y_hat .- maximum(y_hat)
    log_probabilities = shifted .- log(sum(exp, shifted))
    return sum(log_probabilities .* y) * -1.0
end

# Backward pass: gradient w.r.t. the scores is `g .* (softmax(y_hat) .- y)`.
# Only the gradient for `y_hat` is returned; the target `y` receives none.
function backward(::BroadcastedOperator{typeof(cross_entropy_loss)}, y_hat, y, g)
    # Shift by the maximum for numerical stability; evaluate `exp` only once.
    exponentials = exp.(y_hat .- maximum(y_hat))
    probabilities = exponentials ./ sum(exponentials)
    return tuple(g .* (probabilities .- y))
end

backward (generic function with 4 methods)

# Model

In [6]:
using Random
using Printf
# Global running counters, updated as a side effect of the cross-entropy forward pass:
# `predictions` counts evaluated samples and `correct_predictions` counts those whose
# predicted argmax matched the target. `train` and `test` reset both to 0 before a pass.
predictions = 0
correct_predictions = 0

# Container for the nodes that make up the recurrent model (see `build_graph`).
mutable struct myRNN
    # Hidden-to-hidden weights; passed to `rnnCell` as `W` (multiplies the hidden state).
    WW
    # Input-to-hidden weights; passed to `rnnCell` as `U` (multiplies the input chunk).
    WU
    # Hidden-to-output weights; passed to `dense`.
    WV
    # Hidden bias; passed to `rnnCell` as `b`.
    bh
    # Output bias. NOTE(review): not referenced by `build_graph` — confirm whether the
    # output layer is meant to have a bias.
    by
    # Initial hidden state fed to the first `rnnCell`; replaced for every sample by
    # `train` and `test`.
    h
end

"""
    update_weights!(graph, lr, batch_size)

Apply one gradient-descent step to every `Variable` in `graph`.

For each variable that has accumulated a `batch_gradient`, the accumulator is divided by
`batch_size` (turning the sum into a mean), `lr` times that mean is subtracted from
`node.output` in place, and the accumulator is then zeroed in place so the next batch
starts from scratch. Variables whose `batch_gradient` is still `nothing` are skipped.

# Arguments
- `graph::Vector`: nodes of the computation graph; non-`Variable` nodes are ignored.
- `lr::Real`: learning rate.
- `batch_size::Integer`: number of samples accumulated in `batch_gradient`.

Returns `nothing`.
"""
function update_weights!(graph::Vector, lr::Real, batch_size::Integer)
    for node in graph
        node isa Variable || continue
        # Nothing accumulated yet for this variable: there is no step to take.
        isnothing(node.batch_gradient) && continue

        node.batch_gradient ./= batch_size
        node.output .-= lr .* node.batch_gradient
        # `fill!` mutates in place; the non-mutating `fill` would leave the
        # accumulator untouched and leak this batch's gradient into the next one.
        fill!(node.batch_gradient, 0)
    end
    return nothing
end


"""
    build_graph(x, y, rnn::myRNN, j)

Build the computation graph for sample `j` (column `j` of `x`) and return its nodes in
topological order.

Column `j` is cut into four consecutive row ranges (`1:196`, `197:392`, `393:588` and
`589:end`); each range is fed to one `rnnCell` step, chained through the hidden state
starting from `rnn.h`. The final state goes through `dense` with `rnn.WV`, then
`identity`, and is compared with the target node `y` by `cross_entropy_loss`.
"""
function build_graph(x, y, rnn::myRNN, j:: Number)
    row_ranges = (1:196, 197:392, 393:588, 589:lastindex(x, 1))

    # Unroll the recurrence: each step consumes the previous step's node as its state.
    state = rnn.h
    for rows in row_ranges
        state = rnnCell(rnn.WU, rnn.WW, state, rnn.bh, Constant(x[rows, j]))
    end

    scores = identity(dense(state, rnn.WV))
    loss = cross_entropy_loss(scores, y)

    return topological_sort(loss)
end


"""
    train(rnn::myRNN, x, y, epochs, batch_size, learning_rate)

Train `rnn` with mini-batch gradient descent, processing one sample (column of `x` with
the matching column of `y`) at a time.

For every sample a graph is built, the forward and backward passes are run, and after
each `batch_size` samples the accumulated gradients are applied with
`update_weights!`. If an epoch ends with an incomplete batch, that remainder is applied
too, averaged over the number of samples it actually contains. After each epoch the
mean loss and the training accuracy are logged with `@info`, and the sample loop is
timed with `@time`.

Side effects: resets and reads the global counters `predictions` and
`correct_predictions`, and replaces `rnn.h` for every sample.
"""
function train(rnn::myRNN, x::Any, y::Any, epochs, batch_size, learning_rate)
    samples = size(x, 2)

    for i in 1:epochs

        epoch_loss = 0.0
        # Samples accumulated since the last weight update, and the most recent graph
        # (any graph works for the update: they all share the same weight variables).
        pending = 0
        graph = nothing

        global correct_predictions = 0
        global predictions = 0

        @time for j in 1:samples
            y_train = Constant(y[:, j])

            graph = build_graph(x, y_train, rnn, j)
            # Fresh all-zero hidden state for the next sample, sized like the current
            # one instead of a hard-coded length.
            rnn.h = Variable(zero(rnn.h.output))
            epoch_loss += forward!(graph)
            backward!(graph)

            pending += 1
            if pending == batch_size
                update_weights!(graph, learning_rate, batch_size)
                pending = 0
            end
        end

        # Apply a trailing partial batch so its gradients are neither dropped nor
        # carried into the next epoch with the wrong scale.
        if pending > 0
            update_weights!(graph, learning_rate, pending)
        end

        epoch = "Epoch $i"
        loss = epoch_loss / samples
        acc_calc = round(100 * (correct_predictions / predictions), digits=2)
        train_acc = "$acc_calc %"

        @info epoch loss train_acc
    end
end


"""
    test(rnn::myRNN, x, y)

Evaluate `rnn` on every sample (column of `x` with the matching column of `y`) using the
forward pass only, then log the accuracy with `@info`. The sample loop is timed with
`@time`.

Side effects: resets and reads the global counters `predictions` and
`correct_predictions`, and replaces `rnn.h` for every sample.
"""
function test(rnn::myRNN, x::Any, y::Any)

    samples = size(x, 2)

    global correct_predictions = 0
    global predictions = 0

    @time for j in 1:samples
        target = Constant(y[:, j])
        graph = build_graph(x, target, rnn, j)
        # Fresh all-zero hidden state for the next sample, sized like the current one
        # instead of a hard-coded length.
        rnn.h = Variable(zero(rnn.h.output))
        forward!(graph)
    end

    acc_calc = round(100 * (correct_predictions / predictions), digits=2)
    test_acc = "$acc_calc %"

    @info "Test" test_acc
end

test (generic function with 1 method)

# Settings and launch

In [7]:
using MLDatasets: MNIST
using Flux
# Load the MNIST training and test splits.
train_data = MNIST(split=:train)  
test_data  = MNIST(split=:test)

# Flatten each 28x28 image into one column of 784 values; one-hot encode labels 0-9.
x_train = reshape(train_data.features, 28 * 28, :)
y_train  = Flux.onehotbatch(train_data.targets, 0:9)

x_test = reshape(test_data.features, 28 * 28, :)
y_test  = Flux.onehotbatch(test_data.targets, 0:9)

# Model parameters: 64 hidden units, 14*14 = 196 inputs per recurrent step
# (matching the row ranges used in `build_graph`), 10 output classes.
WW = Variable(Flux.glorot_uniform(64,64))
WU = Variable(Flux.glorot_uniform(64,14*14))
WV = Variable(Flux.glorot_uniform(10,64))
bh = Variable(zeros(64))
# NOTE(review): `by` is stored in the model but `build_graph` never uses it.
by = Variable(zeros(10))
h = Variable(zeros(64))

rnn = myRNN(WW, WU, WV, bh, by, h)

# Hyper-parameters: learning rate (`eta`), number of epochs and mini-batch size.
settings = (;
    eta = 15e-3,
    epochs = 5,
    batch_size = 100,
)

println("Training model...")
train(rnn, x_train, y_train, settings.epochs, settings.batch_size, settings.eta)

println("Testing model...")
test(rnn, x_test, y_test) 

Training model...
 14.473325 seconds (40.37 M allocations: 33.483 GiB, 17.30% gc time, 11.90% compilation time)


┌ Info: Epoch 1
│   loss = 0.9640906713744877
│   train_acc = 74.1 %
└ @ Main /Users/narodzi/Desktop/awid-projekt/main.ipynb:66


 12.008625 seconds (36.07 M allocations: 33.196 GiB, 16.81% gc time)


┌ Info: Epoch 2
│   loss = 0.4704312321099117
│   train_acc = 87.41 %
└ @ Main /Users/narodzi/Desktop/awid-projekt/main.ipynb:66


 12.001599 seconds (36.07 M allocations: 33.196 GiB, 17.12% gc time)


┌ Info: Epoch 3
│   loss = 0.37615601475070337
│   train_acc = 89.71 %
└ @ Main /Users/narodzi/Desktop/awid-projekt/main.ipynb:66


 12.011165 seconds (36.07 M allocations: 33.196 GiB, 17.00% gc time)


┌ Info: Epoch 4
│   loss = 0.32743266330230286
│   train_acc = 90.88 %
└ @ Main /Users/narodzi/Desktop/awid-projekt/main.ipynb:66


 11.922957 seconds (36.07 M allocations: 33.196 GiB, 16.87% gc time)
Testing model...


┌ Info: Epoch 5
│   loss = 0.29511370468676784
│   train_acc = 91.74 %
└ @ Main /Users/narodzi/Desktop/awid-projekt/main.ipynb:66


  0.292641 seconds (2.03 M allocations: 194.522 MiB, 4.88% gc time)


┌ Info: Test
│   test_acc = 92.2 %
└ @ Main /Users/narodzi/Desktop/awid-projekt/main.ipynb:89
