# Character-based RNN language model trained on 'The Complete Works of William Shakespeare'
Based on Andrej Karpathy's "The Unreasonable Effectiveness of Recurrent Neural Networks": http://karpathy.github.io/2015/05/21/rnn-effectiveness

In [82]:
using Knet
# Batching: number of parallel character streams, and time steps per minibatch.
BATCHSIZE, SEQLENGTH = 256, 100
# Model sizes.
INPUTSIZE = 168                 # character embedding size (rows of wx)
VOCABSIZE = 84                  # distinct characters; should equal length(chars)
HIDDENSIZE = 334                # RNN hidden state size
NUMLAYERS, DROPOUT = 1, 0.0     # forwarded to rnninit
# Adam optimizer settings.
LR, BETA_1, BETA_2, EPS = 0.001, 0.9, 0.999, 1e-08
# Number of passes over the training data.
EPOCHS = 50;

In [83]:
# Load 'The Complete Works of William Shakespeare'
# gutenberg.jl lives in Knet's data directory and defines `shakespeare()`,
# which yields the training split, the test split and the character table.
include(Pkg.dir("Knet","data","gutenberg.jl"))
trn, tst, chars = shakespeare()
(summary(trn), summary(tst), summary(chars))

("4925284-element Array{UInt8,1}", "525665-element Array{UInt8,1}", "84-element Array{Char,1}")

In [84]:
# Print a sample
# Decode training characters 1020:1210 back to text. `join` builds the String directly
# instead of splatting every Char into `string`'s varargs.
println(join(chars[trn[1020:1210]]))

    Cheated of feature by dissembling nature,
    Deform'd, unfinish'd, sent before my time
    Into this breathing world scarce half made up,
    And that so lamely and unfashionable
   


In [85]:
# Minibatch data
"""
    mb(a; batchsize=BATCHSIZE, seqlength=SEQLENGTH)

Turn the token vector `a` into (input, target) minibatches of size (batchsize, seqlength).

`a` is cut into `batchsize` contiguous streams of `N = div(length(a), batchsize)`
tokens each. The trailing `length(a) - N*batchsize` tokens are dropped. Row `j` of every
minibatch continues row `j` of the previous one, so hidden states can be carried across
minibatches. Targets are the inputs shifted by one position.

Throws `ArgumentError` when `a` gives fewer than 2 tokens per stream, since no
(input, target) pair can then be formed.
"""
function mb(a; batchsize=BATCHSIZE, seqlength=SEQLENGTH)
    N = div(length(a), batchsize)
    N >= 2 || throw(ArgumentError(
        "mb: need at least $(2batchsize) tokens (2 per stream), got $(length(a))"))
    # reshape to (N,B) column-major, then transpose: row j of x is the j-th contiguous chunk of a
    x = reshape(a[1:N*batchsize], N, batchsize)'
    # inputs are columns 1:N-1, targets are columns 2:N; split into (B,T) blocks
    return minibatch(x[:, 1:N-1], x[:, 2:N], seqlength)
end
# Each element is an (input, target) pair of (B,T) blocks; display how many each split has.
dtrn = mb(trn)
dtst = mb(tst)
(length(dtrn), length(dtst))

In [86]:
# Define model
"""
    initmodel()

Create the Knet RNN and its parameters, with all parameters stored as Float32 `KnetArray`s.

Returns `(r, (wr, wx, wy, by))` with these parts:

- `r`, `wr`: the RNN object and its weights, both returned by `rnninit`.
- `wx`: the INPUTSIZE×VOCABSIZE character embedding.
- `wy`: the VOCABSIZE×HIDDENSIZE output projection.
- `by`: the VOCABSIZE×1 output bias, initialized to zeros.
"""
function initmodel()
    xavierparam(dims...) = KnetArray(xavier(Float32, dims...))
    zeroparam(dims...) = KnetArray(zeros(Float32, dims...))
    # Keep the initialization order (rnn, embedding, projection) so random draws match.
    rnnobj, rnnweights = rnninit(INPUTSIZE, HIDDENSIZE, numLayers=NUMLAYERS, dropout=DROPOUT)
    embedding = xavierparam(INPUTSIZE, VOCABSIZE)
    projection = xavierparam(VOCABSIZE, HIDDENSIZE)
    bias = zeroparam(VOCABSIZE, 1)
    return rnnobj, (rnnweights, embedding, projection, bias)
end;

In [87]:
# Define loss and its gradient
"""
    predict(r, ws, xs, hx, cx)

Run the RNN `r` with weights `ws = (wr, wx, wy, by)` over the character-index array `xs`.

`hx`/`cx` are the initial hidden and cell states and are passed straight to `rnnforw`.
Returns `(scores, hy, cy)`:

- `scores = wy*h .+ by` has one column per (stream, time step), i.e. size (VOCABSIZE, B*T).
- `hy`, `cy` are the final hidden and cell states.
"""
function predict(r,ws,xs,hx,cx)
    wr, wx, wy, by = ws
    emb = wx[:, xs]                          # embedding lookup: xs=(B,T) -> (X,B,T)
    h, hy, cy = rnnforw(r, wr, emb, hx, cx, hy=true, cy=true)  # h=(H,B,T), hy=cy=(H,B,L)
    H, B, T = size(h, 1), size(h, 2), size(h, 3)  # size(h,3) is 1 when xs is a vector
    hflat = reshape(h, H, B*T)               # (H, B*T)
    return wy*hflat .+ by, hy, cy
end

"""
    loss(r, w, x, y, h)

Return the `nll` of targets `y` under the model's scores for inputs `x`.
The RNN starts from the states `h[1]` (hidden) and `h[2]` (cell).

Mutates `h`: on return it holds the final states with AutoGrad tracking removed (`getval`).
The next call therefore continues from where this one ended, without differentiating
through earlier minibatches.
"""
function loss(r,w,x,y,h)
    scores, hnext, cnext = predict(r, w, x, h...)
    h[1] = getval(hnext)
    h[2] = getval(cnext)
    return nll(scores, y)
end

# gradloss(loss, 2) differentiates `loss` with respect to its 2nd argument (the weights
# tuple). The resulting function returns both the gradient and the loss value.
lossgradient = gradloss(loss,2);

In [88]:
"""
    train(rnn, weights, data, optim)

Make one pass over `data`, calling `update!` on `weights` with `optim` after every minibatch.

Hidden/cell states start as `nothing` and are carried from each minibatch to the next.
Returns the mean of the per-minibatch losses.
"""
function train(rnn,weights,data,optim)
    state = Any[nothing, nothing]   # (hidden, cell), overwritten by `loss` each step
    batchlosses = []
    for batch in data
        x, y = batch
        g, l = lossgradient(rnn, weights, x, y, state)
        update!(weights, g, optim)
        push!(batchlosses, l)
    end
    return mean(batchlosses)
end

"""
    test(rnn, weights, data)

Evaluate the mean per-minibatch loss over `data` without updating the weights.

Hidden/cell states start as `nothing` and are carried across minibatches, as in `train`.
"""
function test(rnn,weights,data)
    state = Any[nothing, nothing]   # (hidden, cell), overwritten by `loss` each step
    batchlosses = []
    for batch in data
        x, y = batch
        push!(batchlosses, loss(rnn, weights, x, y, state))
    end
    return mean(batchlosses)
end;

In [89]:
# Initialize model
# Drop references to any previous model before knetgc(), so its memory can be reclaimed.
rnn = nothing
weights = nothing
optim = nothing
knetgc()
rnn, weights = initmodel()
# One Adam optimizer state per weight array.
optim = optimizers(weights, Adam; lr=LR, beta1=BETA_1, beta2=BETA_2, eps=EPS);

In [90]:
info("Training...")
# The outer @time reports total wall time; the inner ones time each epoch's train/test pass.
@time for epoch in 1:EPOCHS
    @time trnloss = train(rnn,weights,dtrn,optim) # ~18 seconds
    @time tstloss = test(rnn,weights,dtst)        # ~0.5 seconds
    # exp(mean loss) is reported as perplexity (per character, assuming nll averages
    # over the predicted characters -- confirm against Knet's nll defaults).
    println((:epoch, epoch, :trnppl, exp(trnloss), :tstppl, exp(tstloss)))
end

[1m[36mINFO: [39m[22m[36mTraining...
[39m

 18.386971 seconds (234.39 k allocations: 131.457 MiB, 0.04% gc time)
  0.625695 seconds (38.05 k allocations: 10.695 MiB)
(:epoch, 1, :trnppl, 17.069067f0, :tstppl, 8.832048f0)
 18.843587 seconds (214.38 k allocations: 130.405 MiB, 0.04% gc time)
  0.564837 seconds (8.04 k allocations: 9.036 MiB, 0.16% gc time)
(:epoch, 2, :trnppl, 7.3479204f0, :tstppl, 6.5189676f0)
 18.204367 seconds (215.41 k allocations: 130.420 MiB, 0.03% gc time)
  0.573242 seconds (4.40 k allocations: 8.980 MiB)
(:epoch, 3, :trnppl, 5.9228806f0, :tstppl, 5.6123896f0)
 18.287730 seconds (216.25 k allocations: 130.433 MiB, 0.03% gc time)
  0.572192 seconds (7.45 k allocations: 9.027 MiB, 0.15% gc time)
(:epoch, 4, :trnppl, 5.2358246f0, :tstppl, 5.093068f0)
 18.305785 seconds (213.11 k allocations: 130.385 MiB, 0.03% gc time)
  0.570167 seconds (8.25 k allocations: 9.039 MiB, 0.16% gc time)
(:epoch, 5, :trnppl, 4.7901196f0, :tstppl, 4.7554827f0)
 18.321947 seconds (216.21 k allocations: 130.433 MiB, 0.03% gc time)


 18.375170 seconds (216.25 k allocations: 130.433 MiB, 0.03% gc time)
  0.576564 seconds (8.04 k allocations: 9.036 MiB, 0.18% gc time)
(:epoch, 46, :trnppl, 2.9528632f0, :tstppl, 3.2962208f0)
 18.360796 seconds (216.43 k allocations: 130.436 MiB, 0.03% gc time)
  0.578960 seconds (4.40 k allocations: 8.980 MiB)
(:epoch, 47, :trnppl, 2.9473572f0, :tstppl, 3.295634f0)
 18.382327 seconds (216.17 k allocations: 130.432 MiB, 0.04% gc time)
  0.580014 seconds (7.54 k allocations: 9.028 MiB, 0.19% gc time)
(:epoch, 48, :trnppl, 2.9417865f0, :tstppl, 3.300782f0)
 18.364350 seconds (213.11 k allocations: 130.385 MiB, 0.03% gc time)
  0.574310 seconds (8.23 k allocations: 9.039 MiB, 0.19% gc time)
(:epoch, 49, :trnppl, 2.9387174f0, :tstppl, 3.297684f0)
 18.370900 seconds (216.23 k allocations: 130.433 MiB, 0.03% gc time)
  0.577678 seconds (4.40 k allocations: 8.980 MiB)
(:epoch, 50, :trnppl, 2.9316425f0, :tstppl, 3.2906218f0)
948.078930 seconds (11.28 M allocations: 6.816 GiB, 0.03% gc time)


In [91]:
# Display the RNN object returned by rnninit.
rnn

Knet.RNN(168, 334, 1, 0.0, 0, 0, 2, 0, Float32, Knet.RD(Ptr{Void} @0x000000001aef2f30), Knet.DD(Ptr{Void} @0x000000001c90e7d0, Knet.KnetArray{UInt8,1}(Knet.KnetPtr(Ptr{Void} @0x000000810cdc0000, 638976, 1, nothing), (638976,))), Knet.KnetArray{Float32,3}(Knet.KnetPtr(Ptr{Void} @0x000000834d940000, 17203200, 1, nothing), (168, 256, 100)), Knet.KnetArray{Float32,3}(Knet.KnetPtr(Ptr{Void} @0x00000083ecdf3800, 342016, 1, nothing), (334, 256, 1)), Knet.KnetArray{Float32,3}(Knet.KnetPtr(Ptr{Void} @0x00000082f9793800, 342016, 1, nothing), (334, 256, 1)))

In [92]:
using JLD
# Only the weights tuple (wr, wx, wy, by) is saved; the RNN object `rnn` is not.
# To use the weights after loading, call rnninit again with the same hyperparameters.
@save "shakespeare01.jld" weights

In [93]:
# Display the trained parameters, in initmodel's order: (wr, wx, wy, by).
weights

(Knet.KnetArray{Float32,3}(Knet.KnetPtr(Ptr{Void} @0x000000810a480000, 2693376, 1, nothing), (1, 1, 673344)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x00000081053e8000, 56448, 1, nothing), (168, 84)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x0000008117aef800, 112224, 1, nothing), (84, 334)), Knet.KnetArray{Float32,2}(Knet.KnetPtr(Ptr{Void} @0x00000081052e0600, 336, 1, nothing), (84, 1)))