# Packages

In [1]:
include("graph.jl")

backward (generic function with 9 methods)

In [2]:
using Pkg
Pkg.add("MLDatasets")
Pkg.add("Flux")

[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Manifest.toml`


---
# Dataset

In [3]:
using MLDatasets
train_data = MLDatasets.MNIST(split=:train)
test_data  = MLDatasets.MNIST(split=:test)

using Flux
function loader(data; batchsize::Int=1)
    x1dim = reshape(data.features, 28 * 28, :) # reshape 28×28 pixels into a vector of pixels
    yhot  = Flux.onehotbatch(data.targets, 0:9) # make a 10×60000 OneHotMatrix
    Flux.DataLoader((x1dim, yhot); batchsize, shuffle=true)
end

train_loader = loader(train_data)
test_loader = loader(test_data)

10000-element DataLoader(::Tuple{Matrix{Float32}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, shuffle=true)
  with first element:
  (784×1 Matrix{Float32}, 10×1 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)

---
# Operations

## Tanh

In [4]:
import Base: tanh

# tanh(x::GraphNode) = ScalarOperator(tanh, x)
# forward(::ScalarOperator{typeof(tanh)}, x) = return tanh(x)
# backward(::ScalarOperator{typeof(tanh)}, x, g) = let
#     return tuple((1 - tanh(x)^2) * g)
# end
tanh(x::GraphNode) = BroadcastedOperator(tanh, x)
forward(::BroadcastedOperator{typeof(tanh)}, x) = return tanh.(x)
backward(node::BroadcastedOperator{typeof(tanh)}, x, g) = return tuple((1 .- tanh.(x).^2) .* g)

backward (generic function with 10 methods)

## Softmax

In [5]:
Softmax(x::GraphNode) = BroadcastedOperator(Softmax, x)
forward(::BroadcastedOperator{typeof(Softmax)}, x) = return exp.(x) ./ sum(exp.(x))
backward(node::BroadcastedOperator{typeof(Softmax)}, x, g) = let
    y = node.output
    J = diagm(y) .- y * y'
    tuple(J' * g)
end

backward (generic function with 11 methods)

## Log

In [6]:
import Base: log
Base.Broadcast.broadcasted(log, x::GraphNode) = BroadcastedOperator(log, x)
forward(::BroadcastedOperator{typeof(log)}, x) = return log.(x)
backward(::BroadcastedOperator{typeof(log)}, x, g) = let
    tuple(g ./ x)
end

backward (generic function with 12 methods)

---
# Network

In [7]:
INPUT_SIZE = 196
HIDDEN_SIZE = 64
OUTPUT_SIZE = 10

STEP_COUNT = 4

4

In [8]:
using LinearAlgebra

bound = 1/sqrt(HIDDEN_SIZE)  # read somewhere that this is a good way to init weights for tanh

Wi = Variable(bound .* rand(HIDDEN_SIZE, INPUT_SIZE), name="wi")
Wh = Variable(bound .* rand(HIDDEN_SIZE, HIDDEN_SIZE), name="wh")
Wo = Variable(randn(OUTPUT_SIZE, HIDDEN_SIZE), name="wo")

var wo
 ┣━ ^ 10×64 Matrix{Float64}
 ┗━ ∇ Nothing

In [9]:
function cross_entropy_loss(prediction, label)
    return Constant(-1) .* sum(Variable(label) .* log.(prediction))
end

cross_entropy_loss (generic function with 1 method)

In [19]:
function net(sample, input_weights, hidden_weights, output_weights, label)
    i_1 = Variable(transpose(sample[1:196]'), name="first_step_input")
    i_2 = Variable(transpose(sample[197:392]'), name="second_step_input")
    i_3 = Variable(transpose(sample[393:588]'), name="third_step_input")
    i_4 = Variable(transpose(sample[589:784]'), name="fourth_step_input")

    s_1 = tanh(Wi * i_1)
    s_1.name = "s_1"
    s_2 = tanh(Wi * i_2 .+ Wh * s_1)
    s_2.name = "s_2"
    s_3 = tanh(Wi * i_3 .+ Wh * s_2)
    s_3.name = "s_3"
    s_4 = tanh(Wi * i_4 .+ Wh * s_3)
    s_4.name = "s_4"
    prediction = Softmax(Wo * s_4)
    prediction.name = "prediction"

    E = cross_entropy_loss(prediction, label)
    E.name = "loss"

    return topological_sort(E), prediction
end

net (generic function with 1 method)

---
# Training

In [30]:
STEP_SIZE = 15e-3
EPOCHS = 5
BATCHSIZE = 100

100

In [31]:
for (s, l) in test_loader
    graph, predicted = net(s, Wi, Wh, Wo, l)
    forward!(graph)
    backward!(graph)
end