In [82]:
using Flux
using Flux: onehot, chunk, batchseq, throttle, logitcrossentropy
using StatsBase: wsample
using Base.Iterators: partition
using Parameters: @with_kw

# Hyperparameter arguments 
@with_kw mutable struct Args
    lr::Float64 = 1e-2	# Learning rate
    seqlen::Int = 50	# Length of batchseqences
    nbatch::Int = 50	# number of batches text is divided into
    throttle::Int = 3	# Throttle timeout
end

function getdata(args; text=nothing)
    # an array of all unique characters
    alphabet = [unique(text)..., '_']
    
    text = map(ch -> onehot(ch, alphabet), text)
    stop = onehot('_', alphabet)

    N = length(alphabet)
    
    # Partitioning the data as sequence of batches, which are then collected as array of batches
    Xs = collect(partition(batchseq(chunk(text, args.nbatch), stop), args.seqlen))
    Ys = collect(partition(batchseq(chunk(text[2:end], args.nbatch), stop), args.seqlen))

    return Xs, Ys, N, alphabet
end

# Function to construct model
function build_model(N)
    return Chain(
            LSTM(N, 128),
            LSTM(128, 128),
            Dense(128, N))
end 

function train(text, nsteps; kws...)
    # Initialize the parameters
    args = Args(; kws...)
    
    # Get Data
    Xs, Ys, N, alphabet = getdata(args, text=text)

    # Constructing Model
    m = build_model(N)

    function loss(xs, ys)
      l = sum(logitcrossentropy.(m.(xs), ys))
      return l
    end
    
    ## Training
    opt = ADAM(args.lr)
    tx, ty = (Xs[5], Ys[5])
    evalcb = () -> @show loss(tx, ty)
    for i = 1:nsteps
        Flux.train!(loss, params(m), zip(Xs, Ys), opt, cb = throttle(evalcb, args.throttle))
    end
    return m, alphabet
end

# Sampling
function sample(m, alphabet, len; seed="")
    m = cpu(m)
    Flux.reset!(m)
    buf = IOBuffer()
    if seed == ""
        seed = string(rand(alphabet))
    end
    write(buf, seed)
    c = wsample(alphabet, softmax(m.(map(c -> onehot(c, alphabet), collect(seed)))[end]))
    for i = 1:len
        write(buf, c)
        c = wsample(alphabet, softmax(m(onehot(c, alphabet))))
    end
    return String(take!(buf))
end



sample (generic function with 1 method)

In [106]:
cd(@__DIR__)
isfile("male.txt") ||
    download("http://www.cs.cmu.edu/afs/cs/project/ai-repository/ai/areas/nlp/corpora/names/male.txt","male.txt")
male=collect(string(readlines("male.txt",keep=true)[7:end]...))

isfile("female.txt") ||
    download("http://www.cs.cmu.edu/afs/cs/project/ai-repository/ai/areas/nlp/corpora/names/female.txt","female.txt")
female=collect(string(readlines("female.txt",keep=true)[7:end]...))

names=vcat(male,female)
length(names)

55869

In [109]:
m, alphabet = train(names, 100)
sample(m, alphabet, 10000) |> println

loss(tx, ty) = 173.24821f0
loss(tx, ty) = 160.78125f0
loss(tx, ty) = 159.0167f0
loss(tx, ty) = 156.72714f0
loss(tx, ty) = 153.3975f0
loss(tx, ty) = 150.32512f0
loss(tx, ty) = 144.5955f0
loss(tx, ty) = 139.16145f0
loss(tx, ty) = 134.97668f0
loss(tx, ty) = 132.0291f0
loss(tx, ty) = 135.619f0
loss(tx, ty) = 136.07945f0
loss(tx, ty) = 130.47249f0
loss(tx, ty) = 130.97147f0
loss(tx, ty) = 130.43753f0
loss(tx, ty) = 129.66972f0
loss(tx, ty) = 132.54782f0
loss(tx, ty) = 130.30905f0
loss(tx, ty) = 130.32872f0
loss(tx, ty) = 134.45042f0
loss(tx, ty) = 136.48637f0
loss(tx, ty) = 130.51817f0
loss(tx, ty) = 128.08492f0
loss(tx, ty) = 127.74405f0
loss(tx, ty) = 128.09514f0
loss(tx, ty) = 130.36101f0
loss(tx, ty) = 128.12236f0
loss(tx, ty) = 128.42017f0
loss(tx, ty) = 128.3154f0
loss(tx, ty) = 127.52186f0
loss(tx, ty) = 125.74154f0
loss(tx, ty) = 127.91339f0
loss(tx, ty) = 126.26141f0
loss(tx, ty) = 128.44533f0
loss(tx, ty) = 127.21608f0
loss(tx, ty) = 127.09291f0
loss(tx, ty) = 125.93344f0
loss(tx,

Reletld


In [71]:
m, alphabet = train(M)
sample(m, alphabet, 1000) |> println

loss(tx, ty) = 174.7434f0
loss(tx, ty) = 163.19757f0
loss(tx, ty) = 162.01944f0
loss(tx, ty) = 161.3833f0
loss(tx, ty) = 159.84792f0
loss(tx, ty) = 162.13979f0
loss(tx, ty) = 162.96867f0
loss(tx, ty) = 162.30132f0
loss(tx, ty) = 161.68213f0
loss(tx, ty) = 161.45326f0
loss(tx, ty) = 159.66147f0
loss(tx, ty) = 157.4465f0
loss(tx, ty) = 156.85132f0
loss(tx, ty) = 151.71425f0
loss(tx, ty) = 147.25282f0
loss(tx, ty) = 143.75894f0
loss(tx, ty) = 141.19958f0
loss(tx, ty) = 138.89828f0
loss(tx, ty) = 137.14867f0
loss(tx, ty) = 135.71844f0
loss(tx, ty) = 135.0299f0
loss(tx, ty) = 133.69594f0
loss(tx, ty) = 131.97601f0
loss(tx, ty) = 131.41594f0
loss(tx, ty) = 130.62589f0
loss(tx, ty) = 130.16699f0
loss(tx, ty) = 130.06006f0
loss(tx, ty) = 129.6974f0
loss(tx, ty) = 129.00066f0
loss(tx, ty) = 128.92249f0
loss(tx, ty) = 128.8132f0
loss(tx, ty) = 128.09833f0
loss(tx, ty) = 128.2645f0
loss(tx, ty) = 127.211655f0
loss(tx, ty) = 127.98432f0
loss(tx, ty) = 128.33676f0
loss(tx, ty) = 128.86119f0
loss(tx