# Packages

In [None]:
include("graph.jl")

In [22]:
using Pkg
Pkg.add("MLDatasets")
Pkg.add("Flux")

[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Manifest.toml`


backward (generic function with 12 methods)

---
# Dataset

In [4]:
using MLDatasets
train_data = MLDatasets.MNIST(split=:train)
test_data  = MLDatasets.MNIST(split=:test)

using Flux
function loader(data; batchsize::Int=1)
    x1dim = reshape(data.features, 28 * 28, :) # reshape 28×28 pixels into a vector of pixels
    yhot  = Flux.onehotbatch(data.targets, 0:9) # make a 10×60000 OneHotMatrix
    Flux.DataLoader((x1dim, yhot); batchsize, shuffle=true)
end

train_loader = loader(train_data)
test_loader = loader(test_data)

10000-element DataLoader(::Tuple{Matrix{Float32}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, shuffle=true)
  with first element:
  (784×1 Matrix{Float32}, 10×1 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)

---
# Operations

## Tanh

In [5]:
import Base: tanh

tanh(x::GraphNode) = ScalarOperator(tanh, x)
forward(::ScalarOperator{typeof(tanh)}, x) = return tanh.(x)
backward(::ScalarOperator{typeof(tanh)}, x, g) = let
    println("TANH")
    return tuple((1 - tanh(x)^2) * g)
end

backward (generic function with 10 methods)

## Softmax

In [56]:
Softmax(x::GraphNode) = BroadcastedOperator(Softmax, x)
forward(::BroadcastedOperator{typeof(Softmax)}, x) = return exp.(x) ./ sum(exp.(x))
backward(node::BroadcastedOperator{typeof(Softmax)}, x, g) = let
    y = node.output
    J = diagm(y) .- y * y'
    tuple(J' * g)
end

backward (generic function with 12 methods)

## Log

In [58]:
import Base: log
Base.Broadcast.broadcasted(log, x::GraphNode) = BroadcastedOperator(log, x)
forward(::BroadcastedOperator{typeof(log)}, x) = return log.(x)
backward(::BroadcastedOperator{typeof(log)}, x, g) = let
    tuple(g ./ x)
end

backward (generic function with 12 methods)

---
# Network

In [11]:
INPUT_SIZE = 196
HIDDEN_SIZE = 64
OUTPUT_SIZE = 10

STEP_COUNT = 4

4

In [12]:
using LinearAlgebra

bound = 1/sqrt(HIDDEN_SIZE)  # read somewhere that this is a good way to init weights for tanh

Wi = Variable(bound .* rand(INPUT_SIZE, HIDDEN_SIZE), name="wi")
Wh = Variable(bound .* rand(HIDDEN_SIZE, HIDDEN_SIZE), name="wh")
Wo = Variable(randn(HIDDEN_SIZE, OUTPUT_SIZE), name="wo")

var wo
 ┣━ ^ 64×10 Matrix{Float64}
 ┗━ ∇ Nothing

In [153]:
function cross_entropy_loss(prediction, label)
    return Constant(-1) .* sum(Variable(label') .* log.(prediction))
end

cross_entropy_loss (generic function with 1 method)

In [155]:
include("graph.jl")

function net(sample, input_weights, hidden_weights, output_weights, label)
    i_1 = Variable(transpose(sample[1:196]), name="first_step_input")
    i_2 = Variable(transpose(sample[197:392]), name="second_step_input")
    i_3 = Variable(transpose(sample[393:588]), name="third_step_input")
    i_4 = Variable(transpose(sample[589:784]), name="fourth_step_input")

    s_1 = tanh(i_1 * Wi)
    s_1.name = "s_1"
    s_2 = tanh(i_2 * Wi .+ s_1 * Wh)
    s_2.name = "s_2"
    s_3 = tanh(i_3 * Wi .+ s_2 * Wh)
    s_3.name = "s_3"
    s_4 = tanh(i_4 * Wi .+ s_3 * Wh)
    s_4.name = "s_4"
    prediction = Softmax(s_4 * Wo)
    prediction.name = "prediction"

    E = cross_entropy_loss(prediction, label)
    E.name = "loss"

    return topological_sort(E)
end

for (s, l) in test_loader
    graph = net(s, Wi, Wh, Wo, l)
    forward!(graph)
    backward!(graph)
end

LoadError: MethodError: no method matching diagm(::Matrix{Float64})

[0mClosest candidates are:
[0m  diagm(::Any, [91m::Any[39m, [91m::Pair{<:Integer, <:ChainRulesCore.AbstractThunk}[39m, [91m::Pair{<:Integer, <:ChainRulesCore.AbstractThunk}...[39m)
[0m[90m   @[39m [32mChainRulesCore[39m [90m~/.julia/packages/ChainRulesCore/I1EbV/src/tangent_types/[39m[90m[4mthunks.jl:77[24m[39m
[0m  diagm([91m::StaticArraysCore.StaticArray{Tuple{N}, T, 1} where {N, T}[39m)
[0m[90m   @[39m [36mStaticArrays[39m [90m~/.julia/packages/StaticArrays/YN0oL/src/[39m[90m[4mlinalg.jl:175[24m[39m
[0m  diagm([91m::Pair{<:Val, <:StaticArraysCore.StaticArray{Tuple{N}, T, 1} where {N, T}}[39m, Pair{<:Val, <:StaticArraysCore.StaticArray{Tuple{N}, T, 1} where {N, T}}...)
[0m[90m   @[39m [36mStaticArrays[39m [90m~/.julia/packages/StaticArrays/YN0oL/src/[39m[90m[4mlinalg.jl:153[24m[39m
[0m  ...
