In [2]:
using Flux

┌ Info: CUDAdrv.jl failed to initialize, GPU functionality unavailable (set JULIA_CUDA_SILENT or JULIA_CUDA_VERBOSE to silence or expand this message)
└ @ CUDAdrv /home/edoardo/.julia/packages/CUDAdrv/mCr0O/src/CUDAdrv.jl:69


In [3]:
versioninfo()

Julia Version 1.3.0
Commit 46ce4d7933 (2019-11-26 06:09 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-6.0.1 (ORCJIT, skylake)


In [4]:
mutable struct MinimalRNNCell{T,H}
    kernel::T
    recurrent_kernel::T
    hidden::H
end

function MinimalRNNCell(input_shape, units::Integer)
    kernel = rand(input_shape[end], units)*0.1.-0.05
    recurrent_kernel = rand(units, units)*0.1.-0.05
    MinimalRNNCell(kernel, recurrent_kernel, zeros(units))
end

# Also checked with all harcoded types, as shown in the commented out function definition.
# function (c::MinimalRNNCell{Array{Float64,2},Array{Float64,1}})(hᵢ₋₁::Array{Float64,1}, x::Array{Float64,1})::Tuple{Array{Float64,1},Array{Float64,1}}

function (c::MinimalRNNCell{T,H})(hᵢ₋₁::AbstractArray, x::AbstractArray)::Tuple{H,H} where {T<:AbstractArray,H<:AbstractArray}
    output = (c.kernel * x) + (c.recurrent_kernel * hᵢ₋₁)
    c.hidden = output
    return output, output
end
Flux.hidden(c::MinimalRNNCell) = c.hidden

Flux.@functor MinimalRNNCell

In [5]:
function GenerateSample(length, output)
    inputs = randn(length, output) * 0.1
    outputs = zeros((length,output))
    for i in 1:length
        if i == 1
            outputs[i,:] = inputs[i,:]
        else
            outputs[i,:] = inputs[i,:] + inputs[i-1,:] 
        end
    end
    inputs, outputs
end

GenerateSample (generic function with 1 method)

In [6]:
train_x = zeros(50,10,3)
train_y = zeros(50,10,3)
for i in 1:50
   train_x[i,:,:], train_y[i,:,:] = GenerateSample(10,3) 
end

In [7]:
opt = Flux.ADAM()
function loss(model, x,y)
    total_loss = 0.
    for i in 1:size(x,1)
        for j in 1:size(x,2)
            @inbounds total_loss += Flux.mse(model(x[i,j,:]),y[i,j,:])
        end    
        Flux.reset!(model)
    end
    total_loss / size(x,1)
end
m = Chain(Flux.Recur(MinimalRNNCell((10,3),3)))

Chain(Recur(MinimalRNNCell{Array{Float64,2},Array{Float64,1}}([-0.027523667462084168 -0.010014997945074898 -0.048027565464752286; 0.016387530488057453 -0.024100254679517086 -0.030129950701895215; 0.037395467202281416 -0.03530250560058528 -0.0002345065028006993], [0.030655089277021694 0.0459878452842276 -0.018470383329286945; -0.04403496865022865 -0.015367590977248068 0.017898986329221933; 0.020266992968376835 -0.014033099536112315 0.04070625466102025], [0.0, 0.0, 0.0])))

In [8]:
# Set up hidden state
m(train_x[1,1,:])
@code_warntype m.layers[1].cell(m.layers[1].state, train_x[1,1,:])

Variables
  c[36m::MinimalRNNCell{Array{Float64,2},Array{Float64,1}}[39m
  hᵢ₋₁[36m::Array{Float64,1}[39m
  x[36m::Array{Float64,1}[39m
  output[36m::Array{Float64,1}[39m

Body[36m::Tuple{Array{Float64,1},Array{Float64,1}}[39m
[90m1 ─[39m %1  = Core.apply_type(Main.Array, Main.Float64, 1)[36m::Core.Compiler.Const(Array{Float64,1}, false)[39m
[90m│  [39m %2  = Core.apply_type(Main.Array, Main.Float64, 1)[36m::Core.Compiler.Const(Array{Float64,1}, false)[39m
[90m│  [39m %3  = Core.apply_type(Main.Tuple, %1, %2)[36m::Core.Compiler.Const(Tuple{Array{Float64,1},Array{Float64,1}}, false)[39m
[90m│  [39m %4  = Base.getproperty(c, :kernel)[36m::Array{Float64,2}[39m
[90m│  [39m %5  = (%4 * x)[36m::Array{Float64,1}[39m
[90m│  [39m %6  = Base.getproperty(c, :recurrent_kernel)[36m::Array{Float64,2}[39m
[90m│  [39m %7  = (%6 * hᵢ₋₁)[36m::Array{Float64,1}[39m
[90m│  [39m       (output = %5 + %7)
[90m│  [39m       Base.setproperty!(c, :hidden, output)
[90m│ 

In [9]:
function fit!(model, train_x, train_y, epochs; opt=Flux.ADAM(), ps=params(model) )
    Flux.reset!(model)
    for i in 1:epochs
        println("Epoch $i")
         start = time_ns()
        gs = gradient(ps) do
            loss(model, train_x, train_y)
        end
        Flux.Optimise.update!(opt, ps, gs)
        stop = time_ns()
        println("$((stop-start) / 1000 ) μs")
    end
end

fit! (generic function with 1 method)

In [10]:
fit!(m, train_x, train_y, 10)

Epoch 1
6.178677965e6 μs
Epoch 2
73667.7 μs
Epoch 3
64103.657 μs
Epoch 4
59882.043 μs
Epoch 5
59172.423 μs
Epoch 6
62505.42 μs
Epoch 7
53048.052 μs
Epoch 8
63016.18 μs
Epoch 9
55187.327 μs
Epoch 10
57850.695 μs


### Gradients and Profiling

In [20]:
using Profile
ps = params(m)
gs = gradient(ps) do
    total_loss = 0.
    for i in 1:size(train_x,1)
        for j in 1:size(train_x,2)
            @inbounds total_loss += Flux.mse(m(train_x[i,j,:]),train_y[i,j,:])
        end    
        Flux.reset!(m)
    end
    total_loss
end
gs.grads

IdDict{Any,Any} with 1006 entries:
  RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
  RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  RefValue{Val{2}}(Val{2}()) => Ref

In [11]:
@profile gradient(ps) do
    total_loss = 0.
    for i in 1:size(train_x,1)
        for j in 1:size(train_x,2)
            @inbounds total_loss += Flux.mse(m(train_x[i,j,:]),train_y[i,j,:])
        end    
        Flux.reset!(m)
    end
    total_loss
end
Profile.print()

362 ./task.jl:333; (::IJulia.var"#15#18")()
 362 ...F1GUo/src/eventloop.jl:8; eventloop(::ZMQ.Socket)
  362 ./essentials.jl:708; invokelatest
   362 ./essentials.jl:709; #invokelatest#1
    362 ...rc/execute_request.jl:67; execute_request(::ZMQ.Socket, ::I...
     362 ...c/SoftGlobalScope.jl:218; softscope_include_string(::Modu...
      362 ./boot.jl:330; eval
       184 ...mpiler/interface.jl:46; gradient(::Function, ::Zygote.P...
        184 ...mpiler/interface.jl:96; pullback(::Function, ::Zygote....
         60 ...mpiler/typeinfer.jl:605; typeinf_ext(::Core.MethodInsta...
          33 ...piler/typeinfer.jl:572; typeinf_ext(::Core.MethodInst...
           33 .../inferencestate.jl:113; Core.Compiler.InferenceState...
            33 ...iler/utilities.jl:103; retrieve_code_info
             33 ...iler/utilities.jl:92; get_staged(::Core.MethodInst...
              33 ./boot.jl:524; (::Core.GeneratedFunctionSt...
               33 ./none:0; #s3085#1700(::Any, ::Any, :...
                

                              1 ./generator.jl:47; iterate
                               1 ...c/utils.jl:126; (::MacroTools.var...
                                1 ...c/utils.jl:126; prewalk(::IRTool...
                             1 ./array.jl:675; collect_to!(::Arra...
                   2 .../src/ir/wrap.jl:198; #IR#23(::Bool, ::Bool, :...
                    2 ./operators.jl:854; |>
                     1 ...sses/passes.jl:163; prune!(::IRTools.Inner.IR)
                      1 ...rc/ir/utils.jl:52; prewalk!
                       1 ...rc/ir/utils.jl:40; map!(::Function, ::IRT...
                        1 ...c/ir/utils.jl:31; map!(::Function, ::IRT...
                         1 ...c/ir/utils.jl:25; map!(::IRTools.Inner....
                          1 ...ractarray.jl:2066; map!(::IRTools.Inne...
                           1 .../ir/utils.jl:25; (::IRTools.Inner.var...
                            1 .../ir/utils.jl:52; (::IRTools.Inner.va...
                             1 ...rc/utils

                     3 ...er/optimize.jl:169; optimize(::Core.Compile...
                      1 ...sair/driver.jl:112; run_passes(::Core.Code...
                       1 ...air/driver.jl:107; just_construct_ssa(::C...
                        1 ...r/slot2ssa.jl:792; construct_ssa!(::Core...
                         1 .../ssair/ir.jl:1309; iterate
                          1 .../ssair/ir.jl:1312; iterate
                      1 ...sair/driver.jl:116; run_passes(::Core.Code...
                       1 ...ir/inlining.jl:71; ssa_inlining_pass!
                        1 ...r/inlining.jl:1040; assemble_inline_todo!...
                         1 ./array.jl:0; countunionsplit(::Arra...
                      1 ...sair/driver.jl:121; run_passes(::Core.Code...
                       1 ...air/passes.jl:540; getfield_elim_pass!(::...
                        1 ...ir/queries.jl:86; is_known_call(::Expr, ...
                         1 ...ir/queries.jl:77; compact_exprtype
                          1 .

            1 ...ayers/recurrent.jl:56; reset!
             1 ...iler/interface2.jl:0; _pullback(::Zygote.Context, :...
              1 ./abstractarray.jl:1920; foreach
               1 ...ler/interface2.jl:0; _pullback(::Zygote.Context, ...
                1 ...yers/recurrent.jl:55; reset!
                 1 ...ler/interface2.jl:0; _pullback(::Zygote.Context,...
                  1 ./Base.jl:21; setproperty!
                   1 .../src/adjoint.jl:47; _pullback
                    1 .../src/lib/lib.jl:202; adjoint
                     1 ...src/lib/lib.jl:194; grad_mut(::Zygote.Conte...
                      1 ./abstractdict.jl:600; getindex
                       1 ...stractdict.jl:596; get
          1  ...rc/compiler/emit.jl:18; _push!
           1 ./array.jl:827; _growend!
          1  ...nssF/src/adjoint.jl:48; _pullback
       176 ...mpiler/interface.jl:47; gradient(::Function, ::Zygote.P...
        63 ...mpiler/typeinfer.jl:605; typeinf_ext(::Core.MethodInsta...
         63 ...mp

                            3 ./none:0; Type
                             3 .../ir/wrap.jl:186; #IR#19(::IRTools.I...
                              1 ...rc/ir/ir.jl:445; push!
                               1 ...c/ir/ir.jl:419; push!(::IRTools.I...
                                1 ...c/ir/ir.jl:398; applyex(::Functi...
                                 1 .../ir/ir.jl:396; applyex(::IRTool...
                                  1 ./boot.jl:223; Expr(::Any, ::V...
                              2 .../ir/wrap.jl:165; (::IRTools.Inner.W...
                               2 ...c/utils.jl:126; prewalk(::IRTools...
                                2 ...c/utils.jl:105; walk(::Expr, ::F...
                                 2 ...tarray.jl:2073; map(::Function,...
                            1 ./operators.jl:854; |>
                             1 ...s/passes.jl:201; ssa!(::IRTools.Inn...
                              1 ...s/passes.jl:194; (::IRTools.Inner.v...
                               1 ...c/util

                       1 ...rpretation.jl:937; abstract_eval(::Any, :...
                      13 ...rpretation.jl:949; abstract_eval(::Any, :...
                       13 ...rpretation.jl:879; abstract_eval_call(::...
                        13 ...rpretation.jl:636; abstract_call(::Any, ...
                         13 ...pretation.jl:850; abstract_call(::Any,...
                          12 ...pretation.jl:50; abstract_call_gf_by_...
                           12 ...flection.jl:838; _methods_by_ftype
                          1  ...pretation.jl:93; abstract_call_gf_by_...
                           1 ...pretation.jl:396; abstract_call_metho...
                            1 ...ypeinfer.jl:465; typeinf_edge(::Meth...
                   14 ...er/typeinfer.jl:33; typeinf(::Core.Compiler....
                    14 ...er/optimize.jl:169; optimize(::Core.Compile...
                     13 ...sair/driver.jl:116; run_passes(::Core.Code...
                      13 ...ir/inlining.jl:71; ssa_inli