In [1]:
using Knet
Atype = gpu() >= 0 ? KnetArray : Array

Knet.KnetArray

In [2]:
include(Knet.dir("data","gutenberg.jl"))
trn,tst,chars = shakespeare()
# summary(x) returns a string with a brief description. 
# By default returns string(typeof(x)), e.g. Int64.
# mapping applies summary to other arguments to give a
# brief look to variables, doesn't change anything.
map(summary,(trn,tst,chars))

("4934845-element Array{UInt8,1}", "526731-element Array{UInt8,1}", "84-element Array{Char,1}")

In [3]:
# There are 84 unique characters in the data and they are mapped to UInt8 values in 1:84.
# The chars array can be used to recover the original text:
chars[trn][end-97: end - 90]

8-element Array{Char,1}:
 'a'
 'n'
 'd'
 ' '
 't'
 'i'
 'm'
 'e'

In [4]:
BATCHSIZE = 256
SEQLENGTH = 100

function minibatch_rnn(data)
    num_batch = div(length(data), BATCHSIZE)
    # reshape full data to (B,num_batch) with contiguous rows
    # 
    #x = reshape(data[1:num_batch * BATCHSIZE], num_batch, BATCHSIZE)'
    x = reshape(data[1:num_batch * BATCHSIZE], BATCHSIZE, num_batch)
    # split into (B,T) blocks
    # Meaning of : x[:,1:num_batch-1], x[:,2:num_batch]
    # Remember in char-rnn we were feeding next char as, sort of,
    # label to the data. We sample at t, then compare with t+1. Then
    # obtain loss.
    minibatch(x[:,1:num_batch-1], x[:,2:num_batch], SEQLENGTH) 
end

dtrain, dtest = minibatch_rnn(trn), minibatch_rnn(tst)
map(length, (dtrain, dtest))

(192, 20)

In [5]:
RNNTYPE = :lstm  # can be :lstm, :gru, :tanh, :relu
NUMLAYERS = 1    # number of RNN layers
INPUTSIZE = 168  # size of the input character embedding
HIDDENSIZE = 334 # size of the hidden layers
VOCABSIZE = 84   # number of unique characters in data

# d... means varargs, 0 or more arguments
# xavier directly passes varargs to rand in implementation
function initmodel()
    w(d...)=Atype(xavier(Float32,d...))
    b(d...)=Atype(zeros(Float32,d...))
    # r is rnn struct.
    # w: single weight array that includes all matrices and biases for the RNN
    r,wr = rnninit(INPUTSIZE,HIDDENSIZE,rnnType=RNNTYPE,numLayers=NUMLAYERS)
    # input embedding matrix
    wx = w(INPUTSIZE,VOCABSIZE)
    # wy: hidden state to output
    wy = w(VOCABSIZE,HIDDENSIZE)
    by = b(VOCABSIZE,1)
    return r,wr,wx,wy,by
end

initmodel (generic function with 1 method)

<img src="lstm.png">

In [6]:
function predict(ws, xs, hx, cx)
    # ws expected to have this structure
    r,wr,wx,wy,by = ws
    # below embeds chars to vectors, wx is char embedding weights
    x = wx[:,xs] # xs=(Batch,Time) x=(X,B,T)
    # y=(H,B,T) hy=cy=(H,B,L)
    y,hy,cy = rnnforw(r, wr, x, hx, cx, hy=true, cy=true)
    ys = by.+wy*reshape(y,size(y,1),size(y,2)*size(y,3)) # ys=(H,B*T)
    return ys, hy, cy
end

predict (generic function with 1 method)

In [7]:
function loss(w, x, y, h)
    # h[1]-> hx, h[2]-> cx
    py, hy, cy = predict(w, x, h...)
    # In order AutoGrad to work we need getval somehow..
    h[1], h[2] = getval(hy), getval(cy)
    return nll(py, y)
end

# reports loss and calculates grads
lossgradient = gradloss(loss)

(::gradfun) (generic function with 1 method)

In [8]:
function train(model, data, optim)
    # rnn forwards assumes zero vector when hidden is nothing
    hiddens = Any[nothing, nothing]
    losses = []
    for (x, y) in data
        grads, current_loss = lossgradient(model, x, y, hiddens)
        update!(model, grads, optim)
        push!(losses, current_loss)
    end
    return mean(losses)
end

function test(model, data)
    hiddens = Any[nothing, nothing]
    losses = []
    for (x, y) in data
        current_loss = loss(model, x, y, hiddens)
        push!(losses, current_loss)
    end
    return mean(losses)
end

test (generic function with 1 method)

In [9]:
# Let's train
EPOCHS = 20
model = initmodel()
optim = optimizers(model, Adam)
@time for epoch in 1:EPOCHS
    @time trnloss = train(model,dtrain,optim) # ~18 seconds
    @time tstloss = test(model,dtest)        # ~0.5 seconds
    println((:epoch, epoch, :trnppl, exp(trnloss), :tstppl, exp(tstloss)))
end

 11.302084 seconds (1.89 M allocations: 221.551 MiB, 6.35% gc time)
  0.529385 seconds (273.74 k allocations: 23.188 MiB, 1.02% gc time)
(:epoch, 1, :trnppl, 25.741215f0, :tstppl, 23.52745f0)
  6.928312 seconds (212.92 k allocations: 130.735 MiB, 9.79% gc time)
  0.218703 seconds (3.64 k allocations: 8.966 MiB)
(:epoch, 2, :trnppl, 24.428959f0, :tstppl, 23.523182f0)
  7.034200 seconds (213.68 k allocations: 130.746 MiB, 10.74% gc time)
  0.305022 seconds (5.59 k allocations: 8.995 MiB, 27.51% gc time)
(:epoch, 3, :trnppl, 24.41456f0, :tstppl, 23.513992f0)
  6.909430 seconds (212.74 k allocations: 130.732 MiB, 9.69% gc time)
  0.306212 seconds (5.97 k allocations: 9.001 MiB, 27.41% gc time)
(:epoch, 4, :trnppl, 24.403934f0, :tstppl, 23.502546f0)
  7.027001 seconds (214.80 k allocations: 130.763 MiB, 10.74% gc time)
  0.221075 seconds (3.64 k allocations: 8.966 MiB)
(:epoch, 5, :trnppl, 24.395208f0, :tstppl, 23.492418f0)
  7.067156 seconds (214.69 k allocations: 130.762 MiB, 10.71% gc ti

In [10]:
function generate(model, n)
    function sample(y)
        p, r = Atype(exp.(y - logsumexp(y))), rand()
        for j=1:length(p)
            (r -= p[j]) < 0 && return j
        end
    end
    h,c = nothing,nothing
    x = findfirst(chars,'\n')
    for i=1:n
        y,h,c = predict(model,[x],h,c)
        x = sample(y)
        print(chars[x])
    end
    println()
end

generate (generic function with 1 method)

__I've trained for a very short time, it's no surpise that it did not converged.__

In [None]:
generate(model,1000)

n t sgnhslvhi
iis ir oc oNfwe yd ft  
b
Ie
 r
eie 
 e spoeetoI r ne  eenuo.c .t  d
u?koo,u la h eoes  rso htnsee 
gdll
oI e .a
  tLoert lnypn AsaeA  iooDsa.tn'ymyli ,ennoiidlda msly eNWel blt[r heIu w
skl Jm
nh etrnesThthhdfoeeain R 

e y
]   ,wimyoasr
 rt euaR
r vsse aw!iooiehus vgh  eg
nleersmhA de   e  ya ,Mss BTfi n a
 :t e   o oow
t
l-he y,  ebh ,hsn   
Tiodht afiy,, eN.tdstisubohog  r o    c  a YtR s  eHh wei
d eohylohdem  euektr 
m ni emyahshah
  Tbetco Kihdhw'n lq


In [1]:
!git status

LoadError: [91msyntax: extra token "status" after end of expression[39m