In [69]:
using CuArrays, Flux, Statistics, Random

FILE = "D:/downloads/fra-eng/eng-fra.txt"

mutable struct Lang
    name
    word2index
    word2count
    index2word
    n_words
end

Lang(name) = Lang(
    name,
    Dict{String, Int}(),
    Dict{String, Int}(),
    Dict{Int, String}(1=>"SOS", 2=>"EOS", 3=>"UNK", 4=>"PAD"),
    4)

function (l::Lang)(sentence::String)
    for word in split(sentence, " ")
            if word ∉ keys(l.word2index)
                l.word2index[word] = l.n_words + 1
                l.word2count[word] = 1
                l.index2word[l.n_words + 1] = word
                l.n_words += 1
            else
                l.word2count[word] += 1
            end
    end
end

function normalizeString(s)
    s = strip(lowercase(s))
    s = replace(s, r"([.!?,])"=>s" \1")
    s = replace(s, "'"=>" ' ")
    return s
end

function readLangs(lang1, lang2; rev=false)
    println("Reading lines...")
    lines = readlines(FILE)
    pairs = [normalizeString.(pair) for pair in split.(lines, "\t")]
    if rev
        pairs = reverse.(pairs)
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)
    end
    return(input_lang, output_lang, pairs)
end
        
MAX_LENGTH = 10

eng_prefixes = [
    "i am ", "i ' m ",
    "he is ", "he ' s ",
    "she is ", "she ' s ",
    "you are ", "you ' re ",
    "we are ", "we ' re ",
    "they are ", "they ' re "]
        
function filterPair(pair)
    return(false ∉ (length.(split.(pair, " ")) .<= MAX_LENGTH) && true ∈ (startswith.(pair[1], eng_prefixes)))
end

function prepareData(lang1, lang2; rev=false)
    input_lang, output_lang, pairs = readLangs(lang1, lang2; rev=rev)
    println("Read $(length(pairs)) sentence pairs.")
    pairs = [pair for pair in pairs if filterPair(pair)]
    println("Trimmed to $(length(pairs)) sentence pairs.\n")
    xs = []
    ys = []
    for pair in pairs
        push!(xs, pair[1])
        push!(ys, pair[2])
    end
    println("Counting words...")
    for pair in pairs
        input_lang(pair[2])
        output_lang(pair[1])
    end
    println("Counted words:")
    println("• ", input_lang.name, ": ", input_lang.n_words)
    println("• ", output_lang.name, ": ", output_lang.n_words)
    return(input_lang, output_lang, xs, ys)
end

fr, eng, xs, ys = prepareData("fr", "eng")
indices = shuffle([1:length(xs)...])
xs = xs[indices]
ys = ys[indices];
        
BATCH_SIZE = 64

indexesFromSentence(lang, sentence) = append!(get.(Ref(lang.word2index), split(lowercase(sentence), " "), 3), 2)

function batch(data, batch_size, voc_size; gpu=true)
    chunks = Iterators.partition(data, batch_size)
    batches = []
    for chunk in chunks
        max_length = maximum(length.(chunk))
        chunk = map(sentence->append!(sentence, fill(4, max_length-length(sentence))), chunk)
        chunk = hcat(reshape.(chunk, :, 1)...)
        batch = []
        for i in 1:size(chunk, 1)
            if gpu
                push!(batch, cu(Flux.onehotbatch(chunk[i, :], [1:voc_size...])))
            else
                push!(batch, Flux.onehotbatch(chunk[i, :], [1:voc_size...]))
            end
        end
        push!(batches, batch)
    end
    return(batches)
end

x, y = batch.([indexesFromSentence.([eng], xs), indexesFromSentence.([fr], ys)], [BATCH_SIZE], [eng.n_words, fr.n_words]; gpu=true);

Reading lines...
Read 154883 sentence pairs.
Trimmed to 11645 sentence pairs.

Counting words...
Counted words:
• fr: 4831
• eng: 3047


In [70]:
struct Encoder
    embedding
    dropout
    rnn
    out
end
Encoder(voc_size::Integer; h_size::Integer=HIDDEN, dropout::Number=DROPOUT) = Encoder(
    param(Flux.glorot_uniform(h_size, voc_size)),
    Dropout(dropout),
    GRU(h_size, h_size),
    Dense(h_size, h_size))
function (e::Encoder)(x)
    x = map(x->e.dropout(e.embedding*x), x)
    enc_outputs = e.rnn.(x)
    h = e.out(enc_outputs[end])
    return(enc_outputs, h)
end
Flux.@treelike Encoder

$ u_i^t = v^T tanh(W_1'h_i+W_2'd_t) $

$ a_i^t = softmax(u_i^t) $

$ \sum\limits_{i=1}^{T_a} a_i^t h_i$

In [71]:
struct Attention
    W1
    W2
    v
end
Attention(h_size) = Attention(
    Dense(h_size, h_size),
    Dense(h_size, h_size),
    param(Flux.glorot_uniform(1, h_size)))
function (a::Attention)(enc_outputs, d)
    U = [a.v*tanh.(x) for x in a.W1.(enc_outputs).+[a.W2(d)]]
    A = softmax(vcat(U...))
    out = sum([gpu(collect(A[i, :]')) .* h for (i, h) in enumerate(enc_outputs)])
end
Flux.@treelike Attention

W1 = Dense(16, 16)|>gpu
W2 = Dense(16, 16)|>gpu
v = param(Flux.glorot_uniform(1,16))|>gpu

A = softmax(vcat([v*tanh.(x) for x in W1.(enc_outputs).+[W2(d)]]...))
collect(A[1, :]')
gpu(collect(A[2, :]')).*enc_outputs[1]

m = Chain(x->gpu(Attention(16))(x...), x->sum(x))
opt = SGD(params(m))
#m((enc_outputs, d))
Flux.train!(m, [[(enc_outputs, d)]], opt)

enc_outputs, d = (gpu.([rand(16, 32), rand(16, 32)]), rand(16,32)|>gpu)

In [72]:
struct Decoder
    embedding
    attention
    rnn
    output
end
Decoder(h_size, voc_size) = Decoder(
    param(Flux.glorot_uniform(h_size, voc_size)),
    Attention(h_size),
    GRU(h_size*2, h_size),
    Dense(h_size, voc_size, relu))
function (d::Decoder)(x, enc_outputs; dropout=0)
    x = d.embedding * x
    x = Dropout(dropout)(x)
    decoder_state = d.rnn.state
    context = d.attention(enc_outputs, decoder_state)
    x = d.rnn([x; context])
    x = softmax(d.output(x))
    return(x)
end
Flux.@treelike Decoder

In [83]:
HIDDEN = 128
LEARNING_RATE = 0.03
DROPOUT = 0.2;

In [84]:
testEncoder = Encoder(eng.n_words)|>gpu
testDecoder = Decoder(HIDDEN, fr.n_words)|>gpu;

In [85]:
function model(encoder::Encoder, decoder::Decoder, x, y; teacher_forcing = 0.5, dropout=DROPOUT, voc_size=fr.n_words)
    total_loss = 0
    max_length = length(y)
    batch_size = size(x[1], 2)
    Flux.reset!.([encoder, decoder])
    enc_outputs, h = encoder(x)
    decoder_input = Flux.onehotbatch(ones(batch_size), [1:voc_size...])
    decoder.rnn.state = h
    for i in 1:max_length
        use_teacher_forcing = rand() < teacher_forcing
        decoder_output = decoder(decoder_input, enc_outputs; dropout=dropout)
        total_loss += loss(decoder_output, y[i])
        if use_teacher_forcing
            decoder_input = y[i]
        else
            decoder_input = Flux.onehotbatch(Flux.onecold(decoder_output.data), [1:voc_size...])
        end
    end
    return(total_loss)
end

model(x, y) = model(testEncoder, testDecoder, x, y; dropout = 0.05)

model (generic function with 3 methods)

In [86]:
function model(encoder::Encoder, decoder::Decoder, x; reset=true, voc_size=fr.n_words)
    result = []
    if reset Flux.reset!.([encoder, decoder]) end
    enc_outputs, h = encoder(x)
    decoder_input = Flux.onehot(1, [1:voc_size...])
    decoder.rnn.state = h
    for i in 1:12
        decoder_output = Flux.onecold(decoder(decoder_input, enc_outputs))
        if decoder_output[1] == 2 break end
        push!(result, decoder_output...)
    end
    return(result)
end

model (generic function with 3 methods)

In [87]:
lossmask = ones(fr.n_words)|>gpu
lossmask[4] = 0

loss(logits, target) = Flux.crossentropy(logits, target; weight=lossmask)

opt = SGD(params(testEncoder, testDecoder), LEARNING_RATE)

#43 (generic function with 1 method)

In [88]:
function partitionTrainTest(x, y, at)
    n = length(x)
    idx = shuffle(1:n)
    train_idx = view(idx, 1:floor(Int, at*n))
    test_idx = view(idx, (floor(Int, at*n)+1):n)
    train_x, test_x = x[train_idx,:], x[test_idx,:]
    train_y, test_y = y[train_idx,:], y[test_idx,:]
    return(train_x, train_y, test_x, test_y)
end

train_x, train_y, test_x, test_y = partitionTrainTest(x, y, 0.90);

In [91]:
EPOCHS = 20

for i in 1:EPOCHS
    Flux.train!(model, zip(train_x, train_y), opt)
    println("loss: ", mean(model.(test_x, test_y)).data)
end

loss: 27.571863
loss: 27.486343
loss: 26.830484
loss: 26.309473
loss: 26.658197
loss: 26.309603
loss: 25.964235
loss: 25.645933
loss: 25.657099
loss: 25.125471
loss: 25.051096
loss: 24.310825
loss: 24.868933
loss: 24.411522
loss: 24.269579
loss: 24.17102
loss: 24.438782
loss: 23.875217

In [79]:
EPOCHS = 20

for i in 1:EPOCHS
    Flux.train!(model, zip(train_x, train_y), opt)
    println("loss: ", mean(model.(test_x, test_y)).data)
end

loss: 63.385242
loss: 62.951447
loss: 62.48296
loss: 61.928158
loss: 61.14286
loss: 59.63839
loss: 54.74136
loss: 49.10149
loss: 47.900692
loss: 46.31653
loss: 44.378407
loss: 42.588562
loss: 40.99242
loss: 39.650463
loss: 38.72285
loss: 38.132557
loss: 37.71355
loss: 37.39871
loss: 37.14562
loss: 36.930775


In [81]:
function predict(encoder, decoder, sentence::String)
    sentence = normalizeString(sentence)
    input = append!(get.(Ref(eng.word2index), split(lowercase(sentence), " "), 3), 2)
    input = [Flux.onehot(word, [1:eng.n_words...]) for word in input]
    output = model(encoder, decoder, input)
    output = get.(Ref(fr.index2word), output, "UNK")
    println(output)
end

predict (generic function with 1 method)

In [90]:
predict(testEncoder, testDecoder, "she's doing her thing")
predict(testEncoder, testDecoder, "you're too skinny")
predict(testEncoder, testDecoder, "He is singing")

["il", "est", "'", "de", "de", "'", ".", ".", ".", ".", ".", "."]
["il", "'", "'", "de", "'", "'", ".", ".", ".", ".", ".", "."]
["il", "est", "'", "'", "'", ".", ".", ".", ".", ".", ".", "."]
