In [6]:
using Knet, ArgParse, Dates, Printf, Random, MLDatasets

┌ Info: Precompiling ArgParse [c7e460c6-2fb9-53a9-8c5b-16f535851c63]
└ @ Base loading.jl:1192
┌ Info: Precompiling MLDatasets [eb30cadb-4394-5ae3-aed4-317e484a6458]
└ @ Base loading.jl:1192


In [7]:
"""
    initweights(atype, hidden, words, labels, embed, winit=0.01)

Build the nine parameter arrays of the tree model and convert each one to `atype`.

Layout of the returned vector:

1. `3hidden × embed` gate weights applied to a word embedding
2. `3hidden × 1` bias for (1), zeros
3. `3hidden × 2hidden` gate weights applied to the stacked child states
4. `3hidden × 1` bias for (3), zeros
5. `hidden × hidden` forget-gate weights for the first child
6. `hidden × hidden` forget-gate weights for the second child
7. `hidden × 1` forget-gate bias, ones
8. `labels × hidden` output projection
9. `embed × words` embedding matrix (one column per word)

Random entries are drawn from `randn` and scaled by `winit`.
"""
function initweights(atype, hidden, words, labels, embed, winit=0.01)
    # Scaled Gaussian initialiser; the literal below is evaluated left to right,
    # so random draws happen in slot order 1, 3, 5, 6, 8, 9.
    gaussian(dims...) = winit * randn(dims...)
    params = Any[
        gaussian(3hidden, embed),
        zeros(3hidden, 1),
        gaussian(3hidden, 2hidden),
        zeros(3hidden, 1),
        gaussian(hidden, hidden),
        gaussian(hidden, hidden),
        ones(hidden, 1),
        gaussian(labels, hidden),
        gaussian(embed, words),
    ]
    return map(p -> convert(atype, p), params)
end

initweights (generic function with 2 methods)

In [8]:
"""
    lstm(w, ind)

Compute the state of a leaf from the embedding of word index `ind`.

The embedding is column `ind` of `w[end]`; the gate pre-activations are
`w[1] * x .+ w[2]`, split into three equal row blocks (input, output, candidate).
There is no forget gate because a leaf has no incoming cell state.

Returns the tuple `(h, c)`, each with one column and one third of the rows of
`w[1]`.
"""
function lstm(w, ind)
    x = w[end][:, ind]
    x = reshape(x, length(x), 1)
    gates = w[1] * x .+ w[2]
    # The width of each gate is the hidden size, i.e. a third of the rows of
    # `gates`. Taking it from `x` (the embedding size) mis-slices the gates
    # whenever the embedding and hidden sizes differ.
    hsize = size(gates, 1) ÷ 3
    i = sigm.(gates[1:hsize, :])
    o = sigm.(gates[1+hsize:2hsize, :])
    # NOTE(review): the candidate uses `sigm`; the usual LSTM formulation uses
    # `tanh` here — confirm which is intended before changing it.
    u = sigm.(gates[1+2hsize:3hsize, :])
    c = i .* u
    h = o .* tanh.(c)
    return (h, c)
end

lstm (generic function with 1 method)

In [9]:
"""
    slstm(w, h1, h2, c1, c2)

Combine the states of two children into the state of their parent.

`h1`, `c1` and `h2`, `c2` are the hidden and cell states of the first and second
child. The input, output and candidate gates come from `w[3] * vcat(h1, h2) .+ w[4]`;
each child has its own forget-gate weights (`w[5]`, `w[6]`) and both share the
bias `w[7]`.

Returns the tuple `(h, c)` for the parent node.
"""
function slstm(w,h1,h2,c1,c2)
    n = size(h1, 1)
    gates = w[3] * vcat(h1, h2) .+ w[4]

    # Three consecutive row blocks of height n.
    ingate = sigm.(gates[1:n, :])
    outgate = sigm.(gates[n+1:2n, :])
    candidate = sigm.(gates[2n+1:3n, :])

    # One forget gate per child, sharing the bias w[7].
    forget1 = sigm.(w[5] * h1 .+ w[7])
    forget2 = sigm.(w[6] * h2 .+ w[7])

    cell = ingate .* candidate .+ forget1 .* c1 .+ forget2 .* c2
    hidden = outgate .* tanh.(cell)
    return (hidden, cell)
end

slstm (generic function with 1 method)

In [10]:
let
    # Only `traverse` is made global; `helper` stays local to this block.
    global traverse

    # Walk `tree` and return `(hs, ys)`: the hidden state and the `data` field of
    # every node handled by `helper`, in the order the recursion finishes them
    # (children before their parent, so the root comes last).
    function traverse(w, tree)
        _, _, states, labels = helper(w, tree, Any[], Any[])
        return states, labels
    end

    # Recursive worker. Returns `(h, c, hs, ys)` where `h`, `c` are the state of
    # node `t` and `hs`, `ys` are the accumulated lists extended with this node.
    #   * two children      -> recurse into both, then merge them with `slstm`
    #   * one leaf child    -> `lstm` on the leaf's `data`
    #   * anything else     -> error
    function helper(w, t, hs, ys)
        kids = t.children
        nkids = length(kids)
        if nkids == 2
            h1, c1, hs, ys = helper(w, kids[1], hs, ys)
            h2, c2, hs, ys = helper(w, kids[2], hs, ys)
            h, c = slstm(w, h1, h2, c1, c2)
        elseif nkids == 1 && isleaf(kids[1])
            h, c = lstm(w, kids[1].data)
        else
            error("invalid tree")
        end
        # Fresh arrays are built (rather than push!) so the inputs are never mutated.
        return (h, c, [hs..., h], [ys..., t.data])
    end
end

(::getfield(Main, Symbol("#helper#5"))) (generic function with 1 method)

In [11]:
"""
    loss(w, tree, values=[])

Average negative log-likelihood of the node labels of `tree` under weights `w`.

Every hidden state returned by `traverse` is projected with `w[end-1]` and scored
against the matching gold label with `nll`. The summed loss and the number of
labels are appended to `values` (in that order) so the caller can accumulate
totals; the return value is the summed loss divided by the number of labels.
"""
function loss(w, tree, values=[])
    hiddens, gold = traverse(w, tree)
    # One column of scores per visited node.
    scores = w[end-1] * hcat(hiddens...)
    (total, cnt) = nll(scores, gold; average=false)
    len = length(gold)
    @assert len == cnt
    push!(values, total, len)
    return total / len
end

loss (generic function with 2 methods)

In [12]:
"""
    predict(w, tree)

Predict the label of the last node produced by `traverse` (the root of `tree`).

The last hidden state is projected with `w[end-1]` and the index of the largest
score is taken as the predicted label.

Returns `(label, n)` where `label` is that index and `n` is the number of labels
collected by `traverse`.
"""
function predict(w, tree)
    hs, ys = traverse(w, tree)
    scores = w[end-1] * hs[end]
    # Bring the scores to a plain CPU vector so `argmax` yields a linear index.
    scores = vec(convert(Array{Float32}, scores))
    return (argmax(scores), length(ys))
end

predict (generic function with 1 method)

In [14]:

# Gradient function of `loss`, built with `grad`. It is called with the same
# arguments as `loss`; presumably it differentiates with respect to the first
# argument `w` (AutoGrad's default) — its result is passed to `update!` with `w`.
lossgradient = grad(loss)

"""
    train!(w, tree, opt)

Run one update step on a single `tree`.

Evaluates `lossgradient` on `(w, tree)`, applies the result to `w` with
`update!` and `opt`, and returns as a tuple whatever `loss` pushed into its
`values` argument (the summed loss and the label count).
"""
function train!(w,tree,opt)
    stats = Any[]
    update!(w, lossgradient(w, tree, stats), opt)
    return Tuple(stats)
end

train! (generic function with 1 method)

In [15]:
# Load the PTBLM training split. `x` holds the tokenised sentences and `y` the
# matching targets; in the sample printed further down, `y[1]` is `x[1]` shifted
# by one token with "<eos>" appended.
x, y = PTBLM.traindata()

This program has requested access to the data dependency PTBLM.
which is not currently installed. It can be installed automatically, and you will not see this message again.

Dataset: Penn Treebank sentences for language modeling
Website: https://github.com/tomsercu/lstm

-----------------------------------------------------
Please be aware that this dataset is from a secondary
source. The provided interface by this package is not
as developed as those for other datasets. We would
welcome any contribution to provide this dataset in a
more mature manner.
------------------------------------------------------

The PTBLM dataset consists of Penn Treebank sentences
for language modeling, available from tomsercu/lstm.
The unknown words are replaced with <unk> so that the
total vocabulary size becomes 10000.

The files are available for download at the github
repository linked above. Note that using the data
responsibly and respecting copyright remains your
responsibility.



Do you want to 

(Array{String,1}[["aer", "banknote", "berlitz", "calloway", "centrust", "cluett", "fromstein", "gitano", "guterman", "hydro-quebec"  …  "nahb", "punts", "rake", "regatta", "rubens", "sim", "snack-food", "ssangyong", "swapo", "wachter"], ["pierre", "<unk>", "N", "years", "old", "will", "join", "the", "board", "as", "a", "nonexecutive", "director", "nov.", "N"], ["mr.", "<unk>", "is", "chairman", "of", "<unk>", "n.v.", "the", "dutch", "publishing", "group"], ["rudolph", "<unk>", "N", "years", "old", "and", "former", "chairman", "of", "consolidated"  …  "was", "named", "a", "nonexecutive", "director", "of", "this", "british", "industrial", "conglomerate"], ["a", "form", "of", "asbestos", "once", "used", "to", "make", "kent", "cigarette"  …  "exposed", "to", "it", "more", "than", "N", "years", "ago", "researchers", "reported"], ["the", "asbestos", "fiber", "<unk>", "is", "unusually", "<unk>", "once", "it", "enters"  …  "it", "causing", "symptoms", "that", "show", "up", "decades", "later", 

In [28]:
# Show the first training sample: the input tokens, then the target tokens.
println(x[1])
print(y[1])

["aer", "banknote", "berlitz", "calloway", "centrust", "cluett", "fromstein", "gitano", "guterman", "hydro-quebec", "ipo", "kia", "memotec", "mlx", "nahb", "punts", "rake", "regatta", "rubens", "sim", "snack-food", "ssangyong", "swapo", "wachter"]
["banknote", "berlitz", "calloway", "centrust", "cluett", "fromstein", "gitano", "guterman", "hydro-quebec", "ipo", "kia", "memotec", "mlx", "nahb", "punts", "rake", "regatta", "rubens", "sim", "snack-food", "ssangyong", "swapo", "wachter", "<eos>"]