# Packages

In [1]:
using Pkg
Pkg.add("MLDatasets")
Pkg.add("Flux")

include("graph.jl")

[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.10/Manifest.toml`


backward (generic function with 9 methods)

# Dataset

In [2]:
using MLDatasets
train_data = MLDatasets.MNIST(split=:train)
test_data  = MLDatasets.MNIST(split=:test)

dataset MNIST:
  metadata  =>    Dict{String, Any} with 3 entries
  split     =>    :test
  features  =>    28√ó28√ó10000 Array{Float32, 3}
  targets   =>    10000-element Vector{Int64}

In [3]:
using Flux
function loader(data; batchsize::Int=1)
    x1dim = reshape(data.features, 28 * 28, :) # reshape 28√ó28 pixels into a vector of pixels
    yhot  = Flux.onehotbatch(data.targets, 0:9) # make a 10√ó60000 OneHotMatrix
    Flux.DataLoader((x1dim, yhot); batchsize, shuffle=true)
end

loader (generic function with 1 method)

In [4]:
train_loader = loader(train_data)
test_loader = loader(test_data)

10000-element DataLoader(::Tuple{Matrix{Float32}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, shuffle=true)
  with first element:
  (784√ó1 Matrix{Float32}, 10√ó1 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)

# Implemented operations

## Multiplication

In [87]:
import Base: *
import LinearAlgebra: mul!

# x * y (aka matrix multiplication)
*(A::GraphNode, x::GraphNode) = BroadcastedOperator(mul!, A, x)
forward(::BroadcastedOperator{typeof(mul!)}, A, x) = return A * x
backward(::BroadcastedOperator{typeof(mul!)}, A, x, g) = tuple(g * x', A' * g)

# x .* y (element-wise multiplication)
Base.Broadcast.broadcasted(*, x::GraphNode, y::GraphNode) = BroadcastedOperator(*, x, y)
forward(::BroadcastedOperator{typeof(*)}, x, y) = return x .* y
backward(node::BroadcastedOperator{typeof(*)}, x, y, g) = let
    ùüè = ones(length(node.output))
    Jx = diagm(y .* ùüè)
    Jy = diagm(x .* ùüè)
    tuple(Jx' * g, Jy' * g)
end

backward (generic function with 14 methods)

## Addition

In [6]:
Base.Broadcast.broadcasted(+, x::GraphNode, y::GraphNode) = BroadcastedOperator(+, x, y)
forward(::BroadcastedOperator{typeof(+)}, x, y) = return x .+ y
backward(::BroadcastedOperator{typeof(+)}, x, y, g) = let
    println("ADD")
    return tuple(g, g)
end

backward (generic function with 9 methods)

## Summation

In [38]:
import Base: sum
sum(x::GraphNode) = BroadcastedOperator(sum, x)
forward(::BroadcastedOperator{typeof(sum)}, x) = let
    println("SUM_FWD")
    display(x)
    return sum(x)
end
backward(::BroadcastedOperator{typeof(sum)}, x, g) = let
    println("SUM")
    println("x: ")
    display(x)
    ùüè = ones(Float32, size(x))
    println("1:")
    display(ùüè)
    tuple(ùüè .* g)
end

backward (generic function with 14 methods)

## Tanh

In [8]:
import Base: tanh

tanh(x::GraphNode) = ScalarOperator(tanh, x)
forward(::ScalarOperator{typeof(tanh)}, x) = return tanh.(x)
backward(::ScalarOperator{typeof(tanh)}, x, g) = let
    println("TANH")
    return tuple((1 - tanh(x)^2) * g)
end

backward (generic function with 10 methods)

## Log

In [9]:
import Base: log
Base.Broadcast.broadcasted(log, x::GraphNode) = BroadcastedOperator(log, x)
forward(::BroadcastedOperator{typeof(log)}, x) = return log.(x)
backward(::BroadcastedOperator{typeof(log)}, x, g) = let
    println("LOG")
    tuple(g ./ x)
end

backward (generic function with 11 methods)

## Subtraction

In [10]:
import Base: -
-(x::GraphNode) = ScalarOperator(-, x)
forward(::ScalarOperator{typeof(-)}, x) = return .-x
backward(::ScalarOperator{typeof(-)}, x, g) = tuple(-g)

Base.Broadcast.broadcasted(-, x::GraphNode, y::GraphNode) = BroadcastedOperator(-, x, y)
forward(::BroadcastedOperator{typeof(-)}, x, y) = return x .- y
backward(::BroadcastedOperator{typeof(-)}, x, y, g) = let
    println("SUB")
    tuple(g,-g)
end

backward (generic function with 12 methods)

## Exp

In [11]:
import Base: exp 
Base.Broadcast.broadcasted(exp, x::GraphNode) = BroadcastedOperator(exp, x)
forward(::BroadcastedOperator{typeof(exp)}, x) = return exp.(x)
backward(::BroadcastedOperator{typeof(exp)}, x, g) = let
    println("EXP")
    tuple(exp.(x) .* g)
end

backward (generic function with 13 methods)

## Softmax

In [12]:
Softmax(x::GraphNode) = BroadcastedOperator(Softmax, x)
forward(::BroadcastedOperator{typeof(Softmax)}, x) = return exp.(x) ./ sum(exp.(x))
backward(node::BroadcastedOperator{typeof(Softmax)}, x, g) = let
    println("SOFTMAX")
    y = node.output
    J = diagm(y) .- y * y'
    tuple(J' * g)
end

backward (generic function with 14 methods)

## Division

In [13]:
Base.Broadcast.broadcasted(/, x::GraphNode, y::GraphNode) = BroadcastedOperator(/, x, y)
forward(::BroadcastedOperator{typeof(/)}, x, y) = return x ./ y
backward(node::BroadcastedOperator{typeof(/)}, x, y::Real, g) = let
    println("DIV")
    ùüè = ones(Float32, length(node.output))
    Jx = diagm(ùüè ./ y)
    Jy = (-x ./ y .^2)
    tuple(Jx' * g, Jy' * g)
end

backward (generic function with 14 methods)

# Net

In [14]:
INPUT_SIZE = 196
HIDDEN_SIZE = 64
OUTPUT_SIZE = 10

STEP_COUNT = 4

4

In [15]:
using LinearAlgebra

bound = 1/sqrt(HIDDEN_SIZE)  # read somewhere that this is a good way to init weights for tanh

Wi = Variable(bound .* rand(INPUT_SIZE, HIDDEN_SIZE), name="wi")
Wh = Variable(bound .* rand(HIDDEN_SIZE, HIDDEN_SIZE), name="wh")
Wo = Variable(randn(HIDDEN_SIZE, OUTPUT_SIZE), name="wo")

var wo
 ‚î£‚îÅ ^ 64√ó10 Matrix{Float64}
 ‚îó‚îÅ ‚àá Nothing

In [77]:
function cross_entropy_loss(prediction, label)
    return sum(Constant(-1) .* label .* log.(prediction))
end

cross_entropy_loss (generic function with 1 method)

In [78]:
function net(sample, input_weights, hidden_weights, output_weights, label)
    i_1 = Variable(transpose(sample[1:196]), name="first_step_input")
    i_2 = Variable(transpose(sample[197:392]), name="second_step_input")
    i_3 = Variable(transpose(sample[393:588]), name="third_step_input")
    i_4 = Variable(transpose(sample[589:784]), name="fourth_step_input")

    s_1 = tanh(i_1 * Wi)
    s_1.name = "s_1"
    s_2 = tanh(i_2 * Wi .+ s_1 * Wh)
    s_2.name = "s_2"
    s_3 = tanh(i_3 * Wi .+ s_2 * Wh)
    s_3.name = "s_3"
    s_4 = tanh(i_4 * Wi .+ s_3 * Wh)
    s_4.name = "s_4"
    prediction = Softmax(s_4 * Wo)
    prediction.name = "prediction"

    E = cross_entropy_loss(prediction, Variable(label'))
    E.name = "loss"

    return topological_sort(E)
end

for (s, l) in test_loader
    graph = net(s, Wi, Wh, Wo, l)
    forward!(graph)
    backward!(graph)
end

xy MUL x:


-1

xy MUL y:


1√ó10 adjoint(OneHotMatrix(::Vector{UInt32})) with eltype Bool:
 ‚ãÖ  ‚ãÖ  1  ‚ãÖ  ‚ãÖ  ‚ãÖ  ‚ãÖ  ‚ãÖ  ‚ãÖ  ‚ãÖ

Ax MUL A:


1√ó196 transpose(::Vector{Float32}) with eltype Float32:
 0.0  0.0  0.0  0.0  0.992157  0.988235  ‚Ä¶  0.0  0.0  0.0  0.0  0.0  0.0  0.0

Ax MUL x:


196√ó64 Matrix{Float64}:
 0.0304327   0.0690618  0.0506888   ‚Ä¶  0.0641123   0.00242877  0.0850739
 0.113224    0.109251   0.0159016      0.0843462   0.122879    0.0960529
 0.0337808   0.0842965  0.0394637      0.0357925   0.107245    0.0620308
 0.0252732   0.100767   0.0598819      0.00541035  0.0189774   0.118941
 0.00803884  0.0987203  0.0437645      0.0399419   0.104393    0.0497051
 0.0296385   0.108659   0.115907    ‚Ä¶  0.0819186   0.0511352   0.123874
 0.039099    0.0745427  0.0463541      0.0314367   0.0103932   0.0393173
 0.0483966   0.0595271  0.00622421     0.112435    0.119857    0.0902706
 0.0450593   0.110979   0.0813287      0.0914159   0.034374    0.106267
 0.0742788   0.0459861  0.0315831      0.0872023   0.105413    0.0654476
 0.00077064  0.0623376  0.108166    ‚Ä¶  0.0640638   0.0183233   0.0216957
 0.0555468   0.0711175  0.0848264      0.0622482   0.0466525   0.123737
 0.100259    0.0283314  0.0144103      0.0972441   0.0957441   0.105552
 ‚ãÆ                     

Ax MUL A:


1√ó196 transpose(::Vector{Float32}) with eltype Float32:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ‚Ä¶  0.278431  0.0  0.0  0.0  0.0  0.0

Ax MUL x:


196√ó64 Matrix{Float64}:
 0.0304327   0.0690618  0.0506888   ‚Ä¶  0.0641123   0.00242877  0.0850739
 0.113224    0.109251   0.0159016      0.0843462   0.122879    0.0960529
 0.0337808   0.0842965  0.0394637      0.0357925   0.107245    0.0620308
 0.0252732   0.100767   0.0598819      0.00541035  0.0189774   0.118941
 0.00803884  0.0987203  0.0437645      0.0399419   0.104393    0.0497051
 0.0296385   0.108659   0.115907    ‚Ä¶  0.0819186   0.0511352   0.123874
 0.039099    0.0745427  0.0463541      0.0314367   0.0103932   0.0393173
 0.0483966   0.0595271  0.00622421     0.112435    0.119857    0.0902706
 0.0450593   0.110979   0.0813287      0.0914159   0.034374    0.106267
 0.0742788   0.0459861  0.0315831      0.0872023   0.105413    0.0654476
 0.00077064  0.0623376  0.108166    ‚Ä¶  0.0640638   0.0183233   0.0216957
 0.0555468   0.0711175  0.0848264      0.0622482   0.0466525   0.123737
 0.100259    0.0283314  0.0144103      0.0972441   0.0957441   0.105552
 ‚ãÆ                     

Ax MUL A:


1√ó196 transpose(::Vector{Float32}) with eltype Float32:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ‚Ä¶  0.0  0.0  0.0  0.0  0.0  0.0  0.0

Ax MUL x:


196√ó64 Matrix{Float64}:
 0.0304327   0.0690618  0.0506888   ‚Ä¶  0.0641123   0.00242877  0.0850739
 0.113224    0.109251   0.0159016      0.0843462   0.122879    0.0960529
 0.0337808   0.0842965  0.0394637      0.0357925   0.107245    0.0620308
 0.0252732   0.100767   0.0598819      0.00541035  0.0189774   0.118941
 0.00803884  0.0987203  0.0437645      0.0399419   0.104393    0.0497051
 0.0296385   0.108659   0.115907    ‚Ä¶  0.0819186   0.0511352   0.123874
 0.039099    0.0745427  0.0463541      0.0314367   0.0103932   0.0393173
 0.0483966   0.0595271  0.00622421     0.112435    0.119857    0.0902706
 0.0450593   0.110979   0.0813287      0.0914159   0.034374    0.106267
 0.0742788   0.0459861  0.0315831      0.0872023   0.105413    0.0654476
 0.00077064  0.0623376  0.108166    ‚Ä¶  0.0640638   0.0183233   0.0216957
 0.0555468   0.0711175  0.0848264      0.0622482   0.0466525   0.123737
 0.100259    0.0283314  0.0144103      0.0972441   0.0957441   0.105552
 ‚ãÆ                     

Ax MUL A:


1√ó196 transpose(::Vector{Float32}) with eltype Float32:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ‚Ä¶  0.0  0.0  0.0  0.0  0.0  0.0

Ax MUL x:


196√ó64 Matrix{Float64}:
 0.0304327   0.0690618  0.0506888   ‚Ä¶  0.0641123   0.00242877  0.0850739
 0.113224    0.109251   0.0159016      0.0843462   0.122879    0.0960529
 0.0337808   0.0842965  0.0394637      0.0357925   0.107245    0.0620308
 0.0252732   0.100767   0.0598819      0.00541035  0.0189774   0.118941
 0.00803884  0.0987203  0.0437645      0.0399419   0.104393    0.0497051
 0.0296385   0.108659   0.115907    ‚Ä¶  0.0819186   0.0511352   0.123874
 0.039099    0.0745427  0.0463541      0.0314367   0.0103932   0.0393173
 0.0483966   0.0595271  0.00622421     0.112435    0.119857    0.0902706
 0.0450593   0.110979   0.0813287      0.0914159   0.034374    0.106267
 0.0742788   0.0459861  0.0315831      0.0872023   0.105413    0.0654476
 0.00077064  0.0623376  0.108166    ‚Ä¶  0.0640638   0.0183233   0.0216957
 0.0555468   0.0711175  0.0848264      0.0622482   0.0466525   0.123737
 0.100259    0.0283314  0.0144103      0.0972441   0.0957441   0.105552
 ‚ãÆ                     

Ax MUL A:


1√ó64 Matrix{Float64}:
 0.797119  0.824287  0.883752  0.805685  ‚Ä¶  0.828911  0.874194  0.826445

Ax MUL x:


64√ó64 Matrix{Float64}:
 0.0639274   0.0905363   0.0706276   ‚Ä¶  0.0747666   0.0183594   0.10817
 0.0889023   0.0231719   0.0746507      0.0340268   0.0568271   0.107503
 0.0846382   0.0197928   0.0686405      0.0940777   0.0320004   0.0910258
 0.0565107   0.114285    0.00957795     0.114657    0.0770781   0.00440119
 0.0713441   0.0298513   0.0335321      0.0501846   0.0517012   0.0790391
 0.0762887   0.114771    0.0925048   ‚Ä¶  0.0812257   0.0555651   0.0963269
 0.0842594   0.00593419  0.0903976      0.0204095   0.0419067   0.0601329
 0.0706046   0.0397648   0.0901694      0.037765    0.116477    0.0322432
 0.0766536   0.107443    0.0902471      0.0616452   0.0534763   0.00285245
 0.0871677   0.0383124   0.0233348      0.0841025   0.00986813  0.0625841
 0.00729231  0.0121733   0.0519198   ‚Ä¶  0.0622069   0.0158511   0.116336
 0.103651    0.0542458   0.079765       0.0326171   0.0694933   0.0126675
 0.104624    0.00857396  0.092909       0.0716568   0.122432    0.046267
 ‚ãÆ       

Ax MUL A:


1√ó64 Matrix{Float64}:
 0.999996  0.999989  0.999999  0.999998  ‚Ä¶  0.999995  0.999995  0.999994

Ax MUL x:


64√ó64 Matrix{Float64}:
 0.0639274   0.0905363   0.0706276   ‚Ä¶  0.0747666   0.0183594   0.10817
 0.0889023   0.0231719   0.0746507      0.0340268   0.0568271   0.107503
 0.0846382   0.0197928   0.0686405      0.0940777   0.0320004   0.0910258
 0.0565107   0.114285    0.00957795     0.114657    0.0770781   0.00440119
 0.0713441   0.0298513   0.0335321      0.0501846   0.0517012   0.0790391
 0.0762887   0.114771    0.0925048   ‚Ä¶  0.0812257   0.0555651   0.0963269
 0.0842594   0.00593419  0.0903976      0.0204095   0.0419067   0.0601329
 0.0706046   0.0397648   0.0901694      0.037765    0.116477    0.0322432
 0.0766536   0.107443    0.0902471      0.0616452   0.0534763   0.00285245
 0.0871677   0.0383124   0.0233348      0.0841025   0.00986813  0.0625841
 0.00729231  0.0121733   0.0519198   ‚Ä¶  0.0622069   0.0158511   0.116336
 0.103651    0.0542458   0.079765       0.0326171   0.0694933   0.0126675
 0.104624    0.00857396  0.092909       0.0716568   0.122432    0.046267
 ‚ãÆ       

Ax MUL A:


1√ó64 Matrix{Float64}:
 1.0  0.999999  1.0  1.0  1.0  1.0  1.0  ‚Ä¶  0.999999  1.0  1.0  1.0  1.0  1.0

Ax MUL x:


64√ó64 Matrix{Float64}:
 0.0639274   0.0905363   0.0706276   ‚Ä¶  0.0747666   0.0183594   0.10817
 0.0889023   0.0231719   0.0746507      0.0340268   0.0568271   0.107503
 0.0846382   0.0197928   0.0686405      0.0940777   0.0320004   0.0910258
 0.0565107   0.114285    0.00957795     0.114657    0.0770781   0.00440119
 0.0713441   0.0298513   0.0335321      0.0501846   0.0517012   0.0790391
 0.0762887   0.114771    0.0925048   ‚Ä¶  0.0812257   0.0555651   0.0963269
 0.0842594   0.00593419  0.0903976      0.0204095   0.0419067   0.0601329
 0.0706046   0.0397648   0.0901694      0.037765    0.116477    0.0322432
 0.0766536   0.107443    0.0902471      0.0616452   0.0534763   0.00285245
 0.0871677   0.0383124   0.0233348      0.0841025   0.00986813  0.0625841
 0.00729231  0.0121733   0.0519198   ‚Ä¶  0.0622069   0.0158511   0.116336
 0.103651    0.0542458   0.079765       0.0326171   0.0694933   0.0126675
 0.104624    0.00857396  0.092909       0.0716568   0.122432    0.046267
 ‚ãÆ       

Ax MUL A:


1√ó64 Matrix{Float64}:
 0.999968  0.999924  0.999984  0.999976  ‚Ä¶  0.999949  0.999925  0.999975

Ax MUL x:


64√ó10 Matrix{Float64}:
  1.15144     -0.40001     0.232525   ‚Ä¶  -0.0892798   1.57632   -0.495668
  0.049746    -0.887612   -1.32883        1.11027     1.24757    1.0812
 -0.612986     0.228951    0.12828        1.94442    -1.28154   -0.0323962
  0.228032     0.398862    0.7913         0.0390806   0.120643   1.70973
 -1.24985     -1.12473    -1.85214       -0.334245   -0.294998  -0.955993
  0.442234    -0.111571    0.223134   ‚Ä¶  -0.470594    0.572351  -0.399368
  0.00259773   0.202605   -1.09072        1.29724    -1.33476    0.848125
  0.910998     0.427703   -0.648541      -0.561135    0.281861  -2.71355
 -1.36184      0.339811    0.309479       2.17584    -1.50443   -0.622939
 -0.366944    -0.745275   -0.564239      -1.82825     0.181382   0.976081
 -0.98292     -0.650674    2.59853    ‚Ä¶  -1.65074     0.684271   1.76332
  0.82307     -0.318352    1.17192       -0.576592   -0.331817   0.732123
  2.28221      0.752609    0.760229      -0.197305    1.16622    1.62622
  ‚ãÆ        

xy MUL x:


1√ó10 Matrix{Int64}:
 0  0  -1  0  0  0  0  0  0  0

xy MUL y:


1√ó10 Matrix{Float64}:
 -14.6188  -15.4004  -3.57505  -10.3034  ‚Ä¶  -9.56071  -15.1188  -0.894974

LoadError: DimensionMismatch: matrix A has dimensions (1,10), matrix B has dimensions (1,10)

In [89]:
forward!(topological_sort(Variable([1 2]) * Variable([1 2])))

LoadError: DimensionMismatch: matrix A has dimensions (1,2), matrix B has dimensions (1,2)