# Recurrent Neural Networks

To learn how to work with **Recurrent Neural Networks** we will build the ([<b>Character-Level Language Model</b>](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)). Our goal is to train the conditional probability model, which will predict the next character in the sequence given the previous elements: 

$$P(c_k|\{c_1,c_2,\dots,c_{k-1}\})$$


We will work on the dataset containing all William Shakespeare playwrights.

[![](https://upload.wikimedia.org/wikipedia/commons/a/a2/Shakespeare.jpg)](https://en.wikipedia.org/wiki/William_Shakespeare)

>Tomorrow, and tomorrow, and tomorrow,
Creeps in this petty pace from day to day,
To the last syllable of recorded time;
And all our yesterdays have lighted fools
The way to dusty death. Out, out, brief candle!
Life's but a walking shadow, a poor player
That struts and frets his hour upon the stage,
And then is heard no more. It is a tale
Told by an idiot, full of sound and fury,
Signifying nothing.

    Macbeth, Act V, scene v.

We will build a model, that generates a new playwright (or at least its snippet) based on previously generated letters. But before we start implementing code, let us talk about the theory:

### Recurrent neural networks

Compared to the neural network architectures we previously discussed, RNN networks are directed cyclic graphs. It means that the data could flow in the network not only in one direction (forward) but also can be propagated through neurons in the same layer:

[![](http://karpathy.github.io/assets/rnn/diags.jpeg)](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)

As a result, RNNs are especially useful for building **n-gram** language models: 

[![](http://karpathy.github.io/assets/rnn/charseq.jpeg)](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)

### Long short-term memory

The biggest issue with a basic form of Recurrent Neural Network architecture is the problem of vanishing information, which is especially visible in the long sequences, where meaningful pieces of information are often separated by long chains of less impactful data. Basic RNNs can learn the relationships easily only when intertwined elements are close in the chain:

[![](http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-shorttermdepdencies.png)](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)

But when the gap between them is large, then the relation can get lost in the noise:

[![](http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-longtermdependencies.png)](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)

To avoid that we must redefine the model. A good solution is to use **Long-Short Term Memory** (LSTM) layers, which can control the flow of information and filter the data:

[![](http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-chain.png)](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)

### Alternatives

Instead of LSTMs and their [modifications](https://en.wikipedia.org/wiki/Long_short-term_memory) one might use other kinds of RNN layers, e.g.  <b>Gated Recurrent Unit<b> (GRU) networks:
    
[![](https://upload.wikimedia.org/wikipedia/commons/5/5f/Gated_Recurrent_Unit.svg)](https://en.wikipedia.org/wiki/Gated_recurrent_unit)

or more case-specific models, e.g. layers designed specifically for modelling [time-series.](https://github.com/sdobber/FluxArchitectures)

### Implementation

In [1]:
using Random
using Flux
using Flux: onehot, onehotbatch, argmax, chunk, batchseq, crossentropy
using StatsBase: wsample
using Base.Iterators: partition
using BSON
using JLD2
using CUDA

In [2]:
use_cuda = true

true

In [3]:
 if use_cuda && CUDA.functional()
    device = gpu
    @info "Training on GPU"
else
    device = cpu
    @info "Training on CPU"
end

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTraining on GPU


Firstly, we will collect and prepare the dataset:

In [4]:
isfile("shakespeare.txt") ||
        download("https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt","shakespeare.txt")

true

In [5]:
text = collect(read("shakespeare.txt",String));
alphabet = [unique(text)..., '_'];
stop = '_'

'_': ASCII/Unicode U+005F (category Pc: Punctuation, connector)

In [6]:
N = length(alphabet);
seqlen = 100;
batch_size = 64;
epochs = 20

20

In [7]:
#split text into parts of size batch_size:
X = [collect(t) for t in chunk(text, batch_size)]
Y = [collect(t) for t in chunk(text[2:end], batch_size)]
#match corresponding elements of each chunk from previous step:
X = partition(batchseq(X, stop), seqlen)
Y = partition(batchseq(Y, stop), seqlen)
#collect batches of data:
X = [Flux.onehotbatch.(b, (alphabet,)) for b in X]
Y = [Flux.onehotbatch.(b, (alphabet,)) for b in Y];

Train/test split:

In [8]:
perm = shuffle(1:length(X))
split = floor(Int, 0.95 * length(X))

trainX, trainY = X[perm[1:split]], Y[perm[1:split]] |> device
testX,  testY =  X[perm[(split+1):end]], Y[perm[(split+1):end]] |> device;

Model definition:

In [9]:
m = Chain(
    Embedding(N, 256),
    LSTM(256, 1024),
    Dense(1024, N),
    softmax) |> device

function loss(model, xs, ys, ϵ = 1.0f-8)
    Flux.reset!(m)
    l = sum(crossentropy.(broadcast(x -> model(x) .+ ϵ, xs), ys))
    return l
end


loss (generic function with 2 methods)

In [10]:
@time sum(loss.(Ref(m), testX, testY)) / (batch_size * seqlen * length(testX))

 31.927109 seconds (46.40 M allocations: 1.903 GiB, 1.87% gc time, 75.08% compilation time: 1% of which was recompilation)


0.06648952f0

In [11]:
opt = Adam(0.001)
opt_state = Flux.setup(opt, m);

Sampling function:

In [12]:
function sample(m, alphabet, len)
    model = cpu(m)
    Flux.reset!(model)
    buf = IOBuffer()
    c = rand(alphabet)
    for i = 1:len
        write(buf, c)
        c = wsample(alphabet, model(onehot(c, alphabet)))
      end
    return String(take!(buf))
end

sample (generic function with 1 method)

In [13]:
sample(m, alphabet, 50)

"yfr\nj'rfjjTg-tFOh_L3KJg[Xk3M_jIijag_tUM]jtvbBc?]bG"

Training time!

In [14]:
@info("Beginning training loop...")
best_ls = Inf
last_improvement = 0
for epoch = 1:epochs
    @info "Epoch: $epoch"
    global best_ls, last_improvement
    @info sample(m, alphabet, 100)
    Flux.train!(loss, m, zip(trainX, trainY), opt_state)
    ls = sum(loss.(Ref(m), testX, testY)) / (batch_size * seqlen * length(testX))
    @show ls
    if ls <= best_ls      
        @info "New best result: $ls"
        char_model = cpu(Flux.state(m)) 
        BSON.@save "char_model.bson" char_model
        jldsave("char_model.jld2"; char_model)
        best_ls = ls
        last_improvement = epoch
    end
    if epoch - last_improvement >= 5
        @warn(" -> We're calling this converged.")
        break
    end
end

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mBeginning training loop...
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 1
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mKlh]Tk],.ObkgAV cgs,
[36m[1m│ [22m[39mRV
[36m[1m│ [22m[39myh&eu!d&FVJWazjp_Xt
[36m[1m│ [22m[39mXyxDVdC-.
[36m[1m└ [22m[39m!WZ]f,IMc zWE-&uudwT$L[[LRHBe$nW_k-$OeJo$MJRbv


ls = 0.036813036f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mNew best result: 0.036813036
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 2
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mXnnnnnnnnnnn nnnngngngd ncken ne nn  gns wne in n knos; no fu de me nue.
[36m[1m│ [22m[39m
[36m[1m└ [22m[39mge enin be itid no dfd me 


ls = 0.034421377f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mNew best result: 0.034421377
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 3
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m_!QCxJJpJQ[M&AxR;VKG3C]AqcGE3Mf_V;eeeeeeeieeeeeeeeeeeeeeeeeeeeeeeieeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee


ls = 0.03215535f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mNew best result: 0.03215535
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 4
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m]SJgB&Cq?FD&RL]LxRmTVqQccLg]t.t you sle dist 's nihon Puffut thy kes us WARANIUS:
[36m[1m└ [22m[39mWick oy deathy mot


ls = 0.03226347f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 5
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mc.
[36m[1m│ [22m[39mr th th me, the sfart.
[36m[1m│ [22m[39m
[36m[1m│ [22m[39mROTOBENZO:
[36m[1m│ [22m[39mWhich: ye, shall wn liffepring consse gersigngg.
[36m[1m│ [22m[39mUS:
[36m[1m└ [22m[39mI withee,


ls = 0.030669462f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mNew best result: 0.030669462
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 6
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mccr, oas! poonle we impry-wason alosampon Bil swer;
[36m[1m└ [22m[39mHearn, it do lord mier, my rorane; or onf's rest


ls = 0.03070914f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 7
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mQCxFDwNwQLqGcoll, sp.
[36m[1m│ [22m[39m
[36m[1m│ [22m[39mFORIO:
[36m[1m│ [22m[39mI ver ! rindss Ant they lich cullly u certlan,
[36m[1m│ [22m[39mWhrLARnge,
[36m[1m└ [22m[39mCoart me ton


ls = 0.030402586f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mNew best result: 0.030402586
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 8
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mH,'t es bese, he jounansose ok
[36m[1m│ [22m[39mA no; God! which h me, lear be vart peice of hey handeediAft, I
[36m[1m└ [22m[39mse he


ls = 0.029987825f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mNew best result: 0.029987825
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 9
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m'rrst the defUS:
[36m[1m└ [22m[39mO, and HeC nevery I had Speome the thy may to he lour at of hight om! mades it in a


ls = 0.03059442f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 10
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m.
[36m[1m│ [22m[39m
[36m[1m└ [22m[39mHow good is thand, be to you her, my chose thou sice, and w nigh'd as operched I at shou ofna par


ls = 0.030644182f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 11
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mhaiscas can is I with say, disI knone a a till IOLANUSIUS:
[36m[1m│ [22m[39mGrodshe hadfights.
[36m[1m│ [22m[39m
[36m[1m│ [22m[39mSROSIUMNIN NARESBO:
[36m[1m└ [22m[39mO


ls = 0.0299684f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mNew best result: 0.0299684
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 12
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mR,
[36m[1m│ [22m[39mn's les, sing thus,
[36m[1m│ [22m[39mAt ey, an will e'ee thou in wild; the: mothe prove ow
[36m[1m│ [22m[39me
[36m[1m└ [22m[39mMot't the do the dius


ls = 0.030566663f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 13
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mLE,, b gue blow be the fe, led me ? 'S FRILINs
[36m[1m└ [22m[39mI shall five to kne trutbednglak you, sight the feith


ls = 0.031807307f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 14
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m]K
[36m[1m│ [22m[39mLoetehere hearm,
[36m[1m│ [22m[39mThere andeed:
[36m[1m│ [22m[39mI in thee ifINUNIUS:
[36m[1m└ [22m[39mGood doughereyal theieapiet, and inideeeeineo


ls = 0.030466527f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 15
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39m
[36m[1m│ [22m[39mBy sure sir, them
[36m[1m│ [22m[39mth;
[36m[1m│ [22m[39m
[36m[1m│ [22m[39mBIRONY:
[36m[1m└ [22m[39mO, lder wao had supon of man shall then pdiust menceize not dry sfis


ls = 0.03131429f0


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mEpoch: 16
[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mKcouroesperinessiterive.
[36m[1m│ [22m[39mYet,
[36m[1m│ [22m[39mHad noteadld bid, to the good beone.
[36m[1m│ [22m[39m
[36m[1m│ [22m[39mPCINIUS:
[36m[1m│ [22m[39mWhiceral dayes;
[36m[1m│ [22m[39mA.
[36m[1m│ [22m[39m
[36m[1m└ [22m[39mPoe


ls = 0.030072836f0


[33m[1m└ [22m[39m[90m@ Main In[14]:20[39m


In [48]:
m = Chain(
    Embedding(N, 256),
    LSTM(256, 1024),
    Dense(1024, N),
    softmax) 

#BSON.@load "char_model.bson" char_model

#Flux.loadmodel!(m, char_model)

ps = JLD2.load("char_model.jld2", "char_model")

Flux.loadmodel!(m, ps)

LoadError: Encountered tied destination parameters with untied and mismatched sources.

In [43]:
ps.layers[5]

LoadError: BoundsError: attempt to access Tuple{@NamedTuple{weight::Matrix{Float32}}, @NamedTuple{cell::@NamedTuple{Wi::Matrix{Float32}, Wh::Matrix{Float32}, b::Vector{Float32}, state0::Tuple{Matrix{Float32}, Matrix{Float32}}}, state::Tuple{Matrix{Float32}, Matrix{Float32}}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}, σ::Tuple{}}, Tuple{}} at index [5]

In [18]:
print(sample(m, alphabet, 100))

KmYpDPRrVG]iXAF,IXv$Rg
yIDk3]bZe,Yn!3ieksMfb?XKA$fQK''W,fAufDu?
crqK l
v&xbWqajT]E.u;?wKCFA-X]AySqbg