In [1]:
# (c) Deniz Yuret, 2018.
# Note that this is an instructional example written in low-level Julia/Knet and it is slow to train.
# For a faster and high-level implementation please see `@doc RNN`.
# TODO: check the 50% speed regression in julia 1.0.
using Pkg; haskey(Pkg.installed(),"Knet") || Pkg.add("Knet")
using Knet; @show Knet.gpu()

┌ Info: Recompiling stale cache file /data/scratch/deniz/.julia/compiled/v1.0/Knet/f4vSz.ji for Knet [1902f260-5fb4-5aff-8c31-6271790ab950]
└ @ Base loading.jl:1184


Knet.gpu() = 2


2

## A one layer MLP vs a simple RNN

([Elman 1990](http://onlinelibrary.wiley.com/doi/10.1207/s15516709cog1402_1/pdf)) A simple RNN takes the previous hidden state as an extra input, and returns the next hidden state as an extra output.

<img src="http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-rolled.png" width="150" />
([image source](http://colah.github.io/posts/2015-08-Understanding-LSTMs))

In [2]:
# Comparison of a single hidden layer MLP and corresponding RNN

# Single hidden layer MLP.
# `param` is indexed 1..4: param[1], param[2] define the affine map into the hidden
# layer and param[3], param[4] the affine map from the hidden layer to the output.
# `input` is multiplied from the left (input * param[1]), so instances are rows.
function mlp1(param, input)
    # tanh must be broadcast: a bare tanh(::Matrix) is the matrix function
    # (square matrices only), not the element-wise activation.
    hidden = tanh.(input * param[1] .+ param[2])
    output = hidden * param[3] .+ param[4]
    return output
end

# Simple RNN step: same computation as a one hidden layer MLP, except that the
# previous hidden state is concatenated column-wise to the input.
# Returns the tuple (hidden, output): the new hidden state first, then the output.
function rnn1(param, input, hidden)
    input2 = hcat(input, hidden)
    # tanh must be broadcast: a bare tanh(::Matrix) is the matrix function
    # (square matrices only), not the element-wise activation.
    hidden = tanh.(input2 * param[1] .+ param[2])
    output = hidden * param[3] .+ param[4]
    return (hidden, output)
end;

## Backpropagation through time (BPTT)

([Werbos, 1988](http://www.sciencedirect.com/science/article/pii/089360808890007X))
An RNN unrolled in time is similar to a deep feed-forward network which (i) has as many layers as time steps, (ii) has weights shared between different layers, and (iii) may have multiple inputs and outputs received and produced at individual layers. Backpropagation can be used to train RNNs.

<img src="http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-unrolled.png" width=800 />
([image source](http://colah.github.io/posts/2015-08-Understanding-LSTMs))

In [3]:
# Loss calculation and training.

# Total loss of the RNN unrolled over a sequence.
# `inputs` and `outputs` are sequences of the same length; `hidden` is the initial
# hidden state. The hidden state returned at step t is fed back in at step t+1.
function rnnloss(param,inputs,hidden,outputs)
    sumloss = 0
    for t in eachindex(inputs)
        # rnn1 returns (hidden, output) in that order; destructure accordingly so the
        # new state is carried forward and the loss is computed on the output.
        hidden,output = rnn1(param,inputs[t],hidden)
        sumloss += loss_function(output,outputs[t])
    end
    return sumloss
end

# Gradient function of rnnloss, built with `grad`.
rnngrad = grad(rnnloss);

# ... train with our usual SGD procedure

## Long Short-Term Memory (LSTM)
([Hochreiter and Schmidhuber, 1997](https://doi.org/10.1162/neco.1997.9.8.1735))
LSTM is a more sophisticated RNN module that performs better with long-range dependencies. 

<img src="http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-chain.png" width=800 />
([image source](http://colah.github.io/posts/2015-08-Understanding-LSTMs))



$$\begin{align}
f_t &= \sigma(W_f\cdot[h_{t-1},x_t] + b_f) & \text{forget gate} \\
i_t &= \sigma(W_i\cdot[h_{t-1},x_t] + b_i) & \text{input gate} \\
\tilde{C}_t &= \tanh(W_C\cdot[h_{t-1},x_t] + b_C) & \text{cell candidate} \\
C_t &= f_t \ast C_{t-1} + i_t \ast \tilde{C}_t & \text{new cell} \\
o_t &= \sigma(W_o\cdot[h_{t-1},x_t] + b_o) & \text{output gate} \\
h_t &= o_t \ast \tanh(C_t) & \text{new output}\\
\end{align}$$

In [4]:
# An LSTM implementation in Knet

# One LSTM step.
# `param` unpacks to (weight, bias), `state` to (hidden, cell). The concatenation
# [input hidden] is mapped to four column blocks, each as wide as the hidden state,
# laid out in the order: forget gate, input gate, output gate, cell candidate.
# Returns the new (hidden, cell) tuple.
function lstm(param, state, input)
    w, b = param
    hprev, cprev = state
    n = size(hprev, 2)                    # width of one gate block
    z = hcat(input, hprev) * w .+ b       # all pre-activations at once
    f = sigm.(z[:, 1:n])                  # forget gate
    i = sigm.(z[:, n+1:2n])               # input gate
    o = sigm.(z[:, 2n+1:3n])              # output gate
    g = tanh.(z[:, 3n+1:4n])              # cell candidate
    cnext = cprev .* f + i .* g
    hnext = o .* tanh.(cnext)
    return (hnext, cnext)
end;

## Sequence to sequence model (S2S)
([Sutskever et al. 2014](https://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf))
S2S models learn to map input sequences to output sequences using an encoder and a decoder RNN.

<img src="http://nzw0301.github.io/images/seq2seq.svg"/>

In [5]:
# S2S loss function and its gradient

# Negative log likelihood of `outputs` given `inputs` under the encoder/decoder model.
# `inputs` and `outputs` are iterated one time step at a time; each step is passed
# to `onehotrows` to index the embedding matrices.
function s2s(model, inputs, outputs)
    # Encoder: run the input steps through the :encode LSTM.
    state = initstate(inputs[1], model[:state0])
    for tokens in inputs
        x = onehotrows(tokens, model[:embed1]) * model[:embed1]
        state = lstm(model[:encode], state, x)
    end

    # Decoder: start from EOS, then feed the gold token of the previous step.
    eos = eosmatrix(outputs[1], model[:embed2])
    prev = eos * model[:embed2]
    total = 0
    for tokens in outputs
        state = lstm(model[:decode], state, prev)
        scores = predict(model[:output], state[1])
        gold = onehotrows(tokens, model[:embed2])
        total += sum(gold .* logp(scores, dims=2))
        prev = gold * model[:embed2]
    end

    # One extra step: the decoder is also scored on producing EOS at the end.
    state = lstm(model[:decode], state, prev)
    scores = predict(model[:output], state[1])
    total += sum(eos .* logp(scores, dims=2))
    return -total
end

# Built with `gradloss`; called in `train` as `grads,loss = s2sgrad(...)`.
s2sgrad = gradloss(s2s);

<img src="https://docs.google.com/drawings/d/1BR871g8k4jpI-mKeXiJfpY5Jl5cKcognvH7hHSugQds/pub?w=958&h=236"/>

In [6]:
# S2S model definition

# Build the S2S parameter dictionary for state size H and vocabulary size V.
# Every array is drawn with `xavier` and converted with `atype`.
function initmodel(H, V; atype=(gpu()>=0 ? KnetArray{Float32} : Array{Float32}))
    newparam(dims...) = atype(xavier(dims...))
    # [weight, bias] pair in the layout `lstm` expects: 2H inputs, 4H gate columns.
    lstmparams() = [ newparam(2H,4H), newparam(1,4H) ]
    return Dict{Symbol,Any}(
        :state0 => [ newparam(1,H), newparam(1,H) ],   # initial (hidden, cell)
        :embed1 => newparam(V,H),                      # encoder embeddings
        :encode => lstmparams(),
        :embed2 => newparam(V,H),                      # decoder embeddings
        :decode => lstmparams(),
        :output => [ newparam(H,V), newparam(1,V) ],   # hidden -> vocabulary scores
    )
end;

In [7]:
# S2S helper functions

# Affine output layer: `input` times param[1], plus param[2] broadcast over the rows.
function predict(param, input)
    weight, bias = param[1], param[2]
    return input * weight .+ bias
end

# Expand the initial (hidden, cell) pair in `state0` to one row per element of `idx`.
# Each state is added, with broadcasting, to a zero array of size
# (length(idx), length(state)) allocated via `similar` on `value(state)`.
function initstate(idx, state0)
    nrows = length(idx)
    expand(s) = s .+ fill!(similar(value(s), nrows, length(s)), 0)
    h0, c0 = state0
    return (expand(h0), expand(c0))
end

# One-hot encode `idx` as rows: the result has length(idx) rows and size(embeddings,1)
# columns, with a 1 in column idx[i] of row i. The result is built as a Float32 Array
# and converted to the type of `value(embeddings)`.
# `idx` may be a single integer (as used in `translate`) or a collection of integers.
function onehotrows(idx, embeddings)
    nrows,ncols = length(idx), size(embeddings,1)
    z = zeros(Float32,nrows,ncols)
    # Deliberately bounds-checked: idx comes from the caller, and an id outside
    # 1:ncols must raise a BoundsError instead of writing outside the array.
    for i in 1:nrows
        z[i,idx[i]] = 1
    end
    return oftype(value(embeddings),z)
end

# One-hot matrix for the EOS token (column 1), with length(idx) rows and
# size(embeddings,1) columns, converted to the type of `value(embeddings)`.
# The last matrix built is cached in the enclosing `let` and reused while both the
# requested size and the target array type stay the same.
let EOS=nothing; global eosmatrix
function eosmatrix(idx, embeddings)
    nrows,ncols = length(idx), size(embeddings,1)
    target = value(embeddings)
    # Rebuild when there is no cache yet, the shape changed, or the array type changed
    # (otherwise a cached matrix of a different type than `embeddings` would be returned).
    if EOS === nothing || size(EOS) != (nrows,ncols) || typeof(EOS) != typeof(target)
        eos = zeros(Float32,nrows,ncols)
        eos[:,1] .= 1
        EOS = oftype(target, eos)
    end
    return EOS
end
end;

In [8]:
# Use reversing English words as an example task
# This loads them from /usr/share/dict/words and converts each character to an int.

# Read one word per line from `file` and encode each word as a vector of integer ids.
# Side effects: (re)binds the globals `strings` (the lines read), `int2tok`
# (id -> Char) and `tok2int` (Char -> id). Ids are assigned in order of first
# appearance, starting after the reserved EOS id.
function readdata(file="/usr/share/dict/words")
    global strings = map(chomp, readlines(file))
    global int2tok = Char['\n']                   # We use '\n'=>1 as the EOS token
    global tok2int = Dict{Char,Int}('\n' => 1)
    encoded = Vector{Vector{Int}}()
    for word in strings
        ids = Int[]
        for ch in word
            # Look up the id of ch, registering a fresh id on first sight.
            id = get!(tok2int, ch) do
                push!(int2tok, ch)
                length(int2tok)
            end
            push!(ids, id)
        end
        push!(encoded, ids)
    end
    return encoded
end

# Load the word list, then print a summary of each data structure and a few sample words.
sequences = readdata();
foreach(x -> println(summary(x)), (sequences, strings, int2tok, tok2int))
foreach(println, strings[501:505])

99171-element Array{Array{Int64,1},1}
99171-element Array{SubString{String},1}
70-element Array{Char,1}
Dict{Char,Int64} with 70 entries
Alvaro's
Alvin
Alvin's
Alyce
Alyce's


In [9]:
# Minibatch sequences putting equal length sequences together:

# Group sequences of equal length into minibatches of exactly `batchsize` sequences.
# Each emitted batch is time-major: batch[t][i] is token t of the i-th sequence.
# Sequences left over in a partially filled length bucket are not emitted.
function minibatch(sequences, batchsize)
    buckets = Dict{Int,Vector{Vector{Int}}}()   # sequence length => pending sequences
    batches = Any[]
    for seq in sequences
        len = length(seq)
        bucket = get!(() -> Vector{Int}[], buckets, len)
        push!(bucket, seq)
        length(bucket) == batchsize || continue
        # Bucket is full: transpose it into one vector of tokens per time step.
        batch = [ [ bucket[row][t] for row in 1:batchsize ] for t in 1:len ]
        push!(batches, batch)
        empty!(bucket)
    end
    return batches
end

# Hyper-parameters: minibatch size, LSTM state size, and one vocabulary entry per token.
batchsize = 128
statesize = 128
vocabsize = length(int2tok)
data = minibatch(sequences,batchsize)
summary(data)

"766-element Array{Any,1}"

In [10]:
# Training loop

# Run one pass over `data`, updating `model` in place after every minibatch.
# The target of each minibatch is its own time-reversed sequence.
# Returns the summed loss divided by the number of predicted tokens.
function train(model, data, opts)
    total = 0
    ntokens = 0
    for batch in data
        grads, batchloss = s2sgrad(model, batch, reverse(batch))
        update!(model, grads, opts)
        total += batchloss
        # One prediction per time step plus the final EOS, for every sequence in the batch.
        steps, width = 1 + length(batch), length(batch[1])
        ntokens += steps * width
    end
    return total / ntokens
end

train (generic function with 1 method)

In [11]:
model = opts = nothing; Knet.gc() # clean memory from previous run
if isfile("rnnreverse.jld2")
    # Reuse the model saved by an earlier run
    model = Knet.load("rnnreverse.jld2","model")
else
    # Initialize model and optimization parameters, train for 10 epochs, then save
    model = initmodel(statesize,vocabsize)
    opts = optimizers(model,Adam)
    @time for epoch=1:10
        @time loss = train(model,data,opts) # 15s (0.6.4) 22s (1.0.0) /epoch
        println((epoch,loss))
    end
    Knet.save("rnnreverse.jld2","model",model)
end
summary(model)

 33.763921 seconds (29.87 M allocations: 1.764 GiB, 3.86% gc time)
(1, 2.1116824f0)
 24.243649 seconds (17.49 M allocations: 1.127 GiB, 3.00% gc time)
(2, 1.1684585f0)
 22.244873 seconds (17.52 M allocations: 1.128 GiB, 3.02% gc time)
(3, 0.5343968f0)
 22.714329 seconds (17.52 M allocations: 1.129 GiB, 2.93% gc time)
(4, 0.2762618f0)
 22.857687 seconds (17.52 M allocations: 1.128 GiB, 2.88% gc time)
(5, 0.16793005f0)
 23.432127 seconds (17.49 M allocations: 1.127 GiB, 2.78% gc time)
(6, 0.112778656f0)
 23.819823 seconds (17.52 M allocations: 1.128 GiB, 2.92% gc time)
(7, 0.07800572f0)
 25.183425 seconds (17.52 M allocations: 1.128 GiB, 2.85% gc time)
(8, 0.061287124f0)
 23.019519 seconds (17.52 M allocations: 1.128 GiB, 2.88% gc time)
(9, 0.04776922f0)
 22.754002 seconds (17.49 M allocations: 1.127 GiB, 2.92% gc time)
(10, 0.030740755f0)
244.698817 seconds (188.41 M allocations: 11.960 GiB, 3.05% gc time)


"Dict{Symbol,Any} with 6 entries"

In [12]:
# Test on some examples:

# Greedy decoding of a single string with a trained model.
# The characters of `str` are encoded one at a time; the decoder then starts from EOS
# and emits the highest scoring token at each step until it predicts EOS (id 1) or
# 100 steps have been taken.
function translate(model, str)
    state = model[:state0]
    for ch in collect(str)
        x = onehotrows(tok2int[ch], model[:embed1]) * model[:embed1]
        state = lstm(model[:encode], state, x)
    end

    prev = eosmatrix(1, model[:embed2]) * model[:embed2]
    chars = Char[]
    for _ in 1:100   # hard cap instead of `while true`
        state = lstm(model[:decode], state, prev)
        scores = predict(model[:output], state[1])
        best = argmax(vec(Array(scores)))
        best == 1 && break
        push!(chars, int2tok[best])
        prev = onehotrows(best, model[:embed2]) * model[:embed2]
    end
    return String(chars)
end

translate(model,"capricorn")

"nrocirpac"