**This is the replication of vae from nn4nlp - 15 vae**


When we attempted to train the model from the original dynet implementation, training resulted in an error due to infinite gradients on the 30th epoch.
Secondly, the training was done word by word instead of batches being processed simultaneously. 
Thirdly, vae was trained with a parallel corpus, input data was japanese and the output was english.
Lastly, there was no kl divergence annealing or any other method to prevent posterior collapse.

In our implementation, we trained for 300 epochs on English to English. To prevent posterior collapse, we additionally implemented KL divergence annealing
with a linearly increasing coefficient on the KL loss. To speed up training, we trained with minibatches.

As a result, the pretrained model can be loaded in the second-to-last cell. In the last cell (which works on CPU), a random z is sampled from a standard Gaussian
and used to sample an English sentence from the probability distribution our model learned.

In [2]:
using Pkg
# Add every package this notebook relies on that is not yet listed by Pkg.installed().
for pkgname in ("Knet", "Random", "Statistics", "AutoGrad", "IterTools")
    if !haskey(Pkg.installed(), pkgname)
        Pkg.add(pkgname)
    end
end

Comment out the last line of the following cell (the `Knet.atype()` override) when running on CPU.

In [4]:
# Packages used throughout the notebook. Base.Iterators provides `flatten`, Random provides
# `randn!`/`shuffle!`, and Knet provides the layers, `nll`, `adam` and `progress!` used below.
using Knet, Random, Statistics, Base.Iterators, Test, IterTools
using AutoGrad: @gcheck
# Make Knet's default array type KnetArray{Float32}.
# Comment this line out when running on CPU.
Knet.atype() = KnetArray{Float32}

The following cells are from previous projects

In [9]:
"""
    Vocab

Bidirectional mapping between words and integer ids.

# Fields
- `w2i`: word => id.
- `i2w`: id => word (the inverse of `w2i`).
- `unk`: id of the unknown-word token.
- `eos`: id of the end-of-sentence token.
- `tokenizer`: function that splits a line of text into tokens.
"""
struct Vocab
    w2i::Dict{String,Int}
    i2w::Vector{String}
    unk::Int
    eos::Int
    tokenizer
end

"""
    Vocab(file::String; tokenizer=split, vocabsize=Inf, mincount=1, unk="<unk>", eos="<s>")

Build a vocabulary from the tokens found in `file`.

Tokens are counted, sorted by ascending count, and assigned ids in that order, so the most
frequent words receive the largest ids. `unk` and `eos` are seeded with an artificial count
of 100000, which places them at the high-count end of that ordering.

# Keywords
- `tokenizer`: function applied to every line to obtain its tokens.
- `vocabsize`: upper bound on the number of entries kept. The `vocabsize` highest-count
  entries of the sorted list survive; `unk` and `eos` are members of that list.
- `mincount`: entries with a smaller count are dropped.
- `unk`, `eos`: spellings of the unknown-word and end-of-sentence tokens.
"""
function Vocab(file::String; tokenizer=split, vocabsize=Inf, mincount=1, unk="<unk>", eos="<s>")
    reserved = 100000          # artificial count given to the two special tokens
    counts = Dict()
    counts[unk] = reserved
    counts[eos] = reserved

    # Count every token of every line.
    for line in eachline(file)
        for token in tokenizer(line)
            counts[token] = get(counts, token, 0) + 1
        end
    end

    # Ascending by count; drop everything rarer than `mincount`.
    ranked = sort(collect(counts), by=last)
    start = findfirst(entry -> last(entry) >= mincount, ranked)
    kept = ranked[start:length(ranked)]

    # Keep only the `vocabsize` most frequent entries (the tail of the ascending list).
    if length(kept) > vocabsize
        kept = kept[length(kept) - vocabsize + 1 : length(kept)]
    end

    # Ids follow the position in `kept`.
    wdict = Dict()
    i2w = []
    for (word, _) in kept
        get!(wdict, word, 1 + length(wdict))
        push!(i2w, word)
    end

    return Vocab(wdict, i2w, wdict[unk], wdict[eos], tokenizer)
end

Vocab

In [10]:
"""
    TextReader(file, vocab)

Iterable over the lines of `file`. Every item is one line, tokenised with
`vocab.tokenizer` and converted to word ids; tokens missing from `vocab.w2i`
map to `vocab.unk`.
"""
struct TextReader
    file::String
    vocab::Vocab
end

# Iteration state is the open stream. It is opened on the first call and closed
# once end-of-file is reached.
function Base.iterate(r::TextReader, s=nothing)
    stream = s === nothing ? open(r.file, "r") : s

    if eof(stream)
        close(stream)
        return nothing
    end

    toindex = token -> get(r.vocab.w2i, token, r.vocab.unk)
    tokens = r.vocab.tokenizer(readline(stream))
    return toindex.(tokens), stream
end

# The number of lines is not known in advance; every element is a vector of word ids.
Base.IteratorSize(::Type{TextReader}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{TextReader}) = Base.HasEltype()
Base.eltype(::Type{TextReader}) = Vector{Int}

In [11]:
# Embedding layer: column `i` of `w` is the embedding of word id `i`.
struct Embed; w; end

# Allocate an (embedsize, vocabsize) embedding matrix with `param`.
Embed(vocabsize::Int, embedsize::Int) = Embed(param(embedsize, vocabsize))

# Look up the embedding of every id in `x`.
# The result has size (embedsize, size(x)...).
function (l::Embed)(x)
    embedsz, _ = size(l.w)
    ids = collect(flatten(x))           # ids in column-major order
    columns = l.w[:, ids]
    return reshape(columns, (embedsz, size(x)...))
end

# Zero out trailing padding in every row of the matrix `a`, in place.
#
# For each row, the run of `pad` values at the end of the row is located; its first
# element is kept and all later ones are set to 0. `pad` values that are followed by
# a non-pad value anywhere later in the row are left untouched. Returns `a`.
function mask!(a, pad)
    nrows, ncols = size(a)
    for row in 1:nrows
        # Column of the last real (non-pad) entry, or 0 if the row is all padding.
        lastreal = findlast(col -> a[row, col] != pad, 1:ncols)
        firstpad = (lastreal === nothing ? 0 : lastreal) + 1
        # Keep the first pad of the trailing run, clear the rest.
        a[row, firstpad+1:ncols] .= 0
    end
    return a
end

mask! (generic function with 1 method)

In [12]:
"""
    MTData(src, tgt; batchmaker, batchsize, maxlength, batchmajor, bucketwidth, numbuckets)

Iterable that pairs the sentences of two `TextReader`s and groups them into
minibatches of similar source length.

# Fields
- `src`, `tgt`: readers for the source and target side.
- `batchsize`: number of sentence pairs that triggers emission of a bucket.
- `maxlength`: pairs whose source sentence is longer are skipped.
- `batchmajor`: when `true` the matrices returned by `batchmaker` are transposed.
- `bucketwidth`: range of source lengths that share one bucket.
- `buckets`: per-length-range storage of pending sentence pairs.
- `batchmaker`: function `(d, bucket) -> batch` that turns a bucket into a batch.
"""
struct MTData
    src::TextReader
    tgt::TextReader
    batchsize::Int
    maxlength::Int
    batchmajor::Bool
    bucketwidth::Int
    buckets::Vector
    batchmaker::Function
end

# Keyword constructor. By default there are at most 128 buckets and fewer when
# `maxlength ÷ bucketwidth` is smaller. All buckets start out empty.
function MTData(src::TextReader, tgt::TextReader; batchmaker = arraybatch, batchsize = 64, maxlength = typemax(Int),
                batchmajor = false, bucketwidth = 10, numbuckets = min(128, maxlength ÷ bucketwidth))
    emptybuckets = [Any[] for _ in 1:numbuckets]
    return MTData(src, tgt, batchsize, maxlength, batchmajor, bucketwidth, emptybuckets, batchmaker)
end

# The number of batches is not known in advance; each element is a 2-tuple.
Base.IteratorSize(::Type{MTData}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{MTData}) = Base.HasEltype()
Base.eltype(::Type{MTData}) = NTuple{2}

In [13]:
# Turn bucket `idx` into a batch with `d.batchmaker`, transpose both matrices when
# `d.batchmajor` is set, and reset the bucket to empty.
function _flushbucket!(d::MTData, idx::Integer)
    batch = d.batchmaker(d, d.buckets[idx])
    if d.batchmajor
        batch = (transpose(batch[1]), transpose(batch[2]))
    end
    d.buckets[idx] = []
    return batch
end

"""
    Base.iterate(d::MTData, state=nothing)

Return the next minibatch of `d` together with the iteration state, or `nothing`
when every sentence pair has been emitted.

The state is the pair of `Iterators.Stateful` wrappers around `d.src` and `d.tgt`.
On the first call all buckets are cleared. Sentence pairs are then read in lockstep
and stored in the bucket selected by the source length; pairs whose source sentence
is longer than `d.maxlength` are skipped. A bucket is emitted as soon as it holds
`d.batchsize` pairs, or when the pair just stored was the last one available.
Once the input is exhausted, the remaining non-empty buckets are emitted one per
call, in bucket order.

Differences from the previous implementation:
- a zero-length source sentence goes to the first bucket instead of indexing
  bucket 0 (which raised a `BoundsError`);
- partially filled buckets are still emitted when the final pair is skipped
  because of `maxlength`, or when one reader runs out before the other.
"""
function Base.iterate(d::MTData, state=nothing)
    if state === nothing
        for i in eachindex(d.buckets)
            d.buckets[i] = []
        end
        src = Iterators.Stateful(d.src)
        tgt = Iterators.Stateful(d.tgt)
    else
        src, tgt = state
    end

    while !isempty(src) && !isempty(tgt)
        src_sentence = popfirst!(src)
        tgt_sentence = popfirst!(tgt)
        src_length = length(src_sentence)
        src_length > d.maxlength && continue

        # Bucket k holds lengths (k-1)*bucketwidth+1 : k*bucketwidth. Anything longer
        # shares the last bucket, and empty sentences share the first one.
        idx = clamp(cld(src_length, d.bucketwidth), 1, length(d.buckets))
        bucket = d.buckets[idx]
        push!(bucket, (src_sentence, tgt_sentence))

        exhausted = isempty(src) || isempty(tgt)
        if exhausted || length(bucket) == d.batchsize
            return (_flushbucket!(d, idx), (src, tgt))
        end
    end

    # No more pairs to read: drain whatever is still waiting in the buckets.
    leftover = findfirst(!isempty, d.buckets)
    leftover === nothing && return nothing
    return (_flushbucket!(d, leftover), (src, tgt))
end

In [14]:
"""
    arraybatch(d::MTData, bucket)

Convert `bucket`, a collection of `(source, target)` sentence pairs (vectors of word
ids), into a tuple `(x, y)` of integer matrices with one row per sentence pair.

- `x` holds the source sentences right-aligned; shorter sentences are padded on the
  left with `d.src.vocab.eos`.
- `y` holds each target sentence wrapped in `d.tgt.vocab.eos` on both sides and padded
  on the right with `d.tgt.vocab.eos`.

The sentences stored in `bucket` are left unmodified. The previous implementation
pushed the eos markers and the padding into the caller's target vectors.
"""
function arraybatch(d::MTData, bucket)
    src_eos = d.src.vocab.eos
    tgt_eos = d.tgt.vocab.eos

    nsent = length(bucket)
    xlen = mapreduce(pair -> length(pair[1]), max, bucket; init=0)
    # +2 for the eos marker before and after every target sentence
    ylen = mapreduce(pair -> length(pair[2]), max, bucket; init=0) + 2

    # Start from all-eos matrices so that only the real words need to be written.
    padded_x = fill(src_eos, nsent, xlen)
    padded_y = fill(tgt_eos, nsent, ylen)

    for (row, pair) in enumerate(bucket)
        src_sent, tgt_sent = pair[1], pair[2]
        padded_x[row, xlen-length(src_sent)+1:xlen] = src_sent
        padded_y[row, 2:length(tgt_sent)+1] = tgt_sent
    end

    return (padded_x, padded_y)
end

arraybatch (generic function with 1 method)

Data used is from nn4nlp, in our implementation we used only the English corpus

In [96]:
# Locations of the parallel corpus: .ja files are the source side, .en files the target side.
train_src_file = "nn4nlp-code-master/data/parallel/train.ja"
train_tgt_file = "nn4nlp-code-master/data/parallel/train.en"
dev_src_file = "nn4nlp-code-master/data/parallel/dev.ja"
dev_tgt_file = "nn4nlp-code-master/data/parallel/dev.en"
test_src_file = "nn4nlp-code-master/data/parallel/test.ja"
test_tgt_file = "nn4nlp-code-master/data/parallel/test.en"



# Vocabularies are built from the training files only; words seen fewer than 2 times are dropped.
ja_vocab = Vocab(train_src_file, mincount=2)
en_vocab = Vocab(train_tgt_file, mincount=2)
# One reader per file; dev and test readers reuse the training vocabularies.
ja_train = TextReader(train_src_file, ja_vocab)
en_train = TextReader(train_tgt_file, en_vocab)
ja_dev = TextReader(dev_src_file, ja_vocab)
en_dev = TextReader(dev_tgt_file, en_vocab)
ja_test = TextReader(test_src_file, ja_vocab)
en_test = TextReader(test_tgt_file, en_vocab)

TextReader("nn4nlp-code-master/data/parallel/test.en", Vocab(Dict("enjoy" => 3043,"shouldn" => 1987,"chocolate" => 1630,"fight" => 2560,"helping" => 1988,"whose" => 2231,"hurried" => 1631,"favor" => 2759,"borders" => 1,"star" => 1632…), ["borders", "stress", "fireworks", "methods", "parted", "shakespeare", "customer", "musical", "regarded", "21"  …  "you", "he", "is", "a", "i", "to", "the", ".", "<unk>", "<s>"], 3710, 3711, split))

In [16]:
# English-to-English datasets: the same English reader is used as both source and target,
# with the default MTData settings (batchsize = 64, bucketwidth = 10).
dtrn = MTData(en_train, en_train)
ddev = MTData(en_dev, en_dev)
dtst = MTData(en_test, en_test)
summary(dtrn)

"MTData"

In [17]:
# Reparameterization trick: draw standard-normal noise with the shape and element type
# of `μ`, scale it element-wise by `σ` and shift it by `μ`.
function reparameterize(μ, σ)
    ε = randn!(similar(μ))
    return μ .+ ε .* σ
end

reparameterize (generic function with 1 method)

In [18]:
# Affine layer: weight matrix `w` and bias vector `b`.
struct Linear; w; b; end

# Allocate an (outputsize, inputsize) weight with `param` and an `outputsize` bias with `param0`.
Linear(inputsize::Int, outputsize::Int) = Linear(param(outputsize, inputsize), param0(outputsize))

# Apply the layer: w * x with the bias broadcast over the columns of the product.
(l::Linear)(x) = l.w * x .+ l.b

In [19]:
# Multi-layer perceptron with one tanh hidden layer.
#   W, b : input -> hidden weight and bias
#   V    : hidden -> output weight (no output bias)
mutable struct MLP
    W
    V
    b
end

# Forward pass: V * tanh(W * x + b).
function (m::MLP)(x)
    hidden = tanh.(m.W * x .+ m.b)
    return m.V * hidden
end

# Allocate the parameters for the given layer sizes.
function MLP(input::Int, hidden::Int, output::Int)
    inweight = param(hidden, input)
    inbias = param0(hidden)
    outweight = param(output, hidden)
    return MLP(inweight, outweight, inbias)
end

MLP

In [20]:
# Model parameters, identical to the original dynet implementation
EMBED_SIZE = 64          # word embedding size
HIDDEN_SIZE = 128        # LSTM hidden size (the sampling cell draws a z of this length)
Q_HIDDEN_SIZE = 64       # hidden size of the MLPs that parameterize q
BATCH_SIZE = 16          # NOTE(review): not referenced elsewhere in the visible notebook; MTData uses its own default
MAX_SENT_SIZE = 50       # NOTE(review): not referenced elsewhere in the visible notebook; the sampling loop hard-codes 50
SRC_VOCAB_SIZE = length(ja_vocab.i2w)   # number of entries in the Japanese vocabulary
TGT_VOCAB_SIZE = length(en_vocab.i2w)   # number of entries in the English vocabulary

3711

In [21]:
# Sequence-to-sequence variational autoencoder.
# The encoder's final hidden state is mapped by two MLPs to the parameters of q(z|x);
# a sampled z initializes the decoder, whose outputs are projected to vocabulary scores.
mutable struct s2s_vae
    srcembed::Embed     # source language embedding
    encoder::RNN        # encoder RNN
    tgtembed::Embed     # target language embedding
    decoder::RNN        # decoder RNN
    projection::Linear  # converts decoder output to vocab scores
    mean_mlp::MLP       # MLP for estimating the mean of q
    var_mlp::MLP        # MLP whose output the forward pass uses as the log-variance of q
    srcvocab::Vocab     # source language vocabulary
    tgtvocab::Vocab     # target language vocabulary
end

In [22]:
# Build an untrained s2s_vae.
# Both RNNs are LSTMs with `hidden` units, and both MLPs map hidden -> q_hidden -> hidden.
# `layers` and `bidirectional` are accepted but not used by this constructor.
function s2s_vae(hidden::Int,        # hidden size for both the encoder and decoder RNN
                q_hidden::Int,       # hidden size for MLP hidden layer which estimates mean and std dev for q
                srcembsz::Int,       # embedding size for source language
                tgtembsz::Int,       # embedding size for target language
                srcvocab::Vocab,     # vocabulary for source language
                tgtvocab::Vocab;     # vocabulary for target language
                layers=1,            # number of layers
                bidirectional=false, # whether encoder RNN is bidirectional
                )
    srcvocabsz = length(srcvocab.i2w)
    tgtvocabsz = length(tgtvocab.i2w)

    # Components are created in a fixed order: embeddings, RNNs, MLPs, projection.
    src_embedding = Embed(srcvocabsz, srcembsz)
    tgt_embedding = Embed(tgtvocabsz, tgtembsz)

    enc = RNN(srcembsz, hidden, rnnType = :lstm, h = 0)
    dec = RNN(tgtembsz, hidden, rnnType = :lstm, h = 0)

    q_mean = MLP(hidden, q_hidden, hidden)
    q_var = MLP(hidden, q_hidden, hidden)

    scorer = Linear(hidden, tgtvocabsz)

    return s2s_vae(src_embedding, enc, tgt_embedding, dec, scorer, q_mean, q_var, srcvocab, tgtvocab)
end

s2s_vae

In [23]:
# Global counters shared by the model's forward pass and by `loss`.
# `epoch_no` is incremented on every call to `loss` and drives the KL annealing coefficient.
epoch_no = 0    #for debugging purposes, tracking epoch
softmax_loss_total = 0  #for debugging purposes, tracking softmax loss of current epoch
kl_loss_total = 0 #for debugging purposes, tracking kl loss of current epoch

0

In [24]:
"""
    (s::s2s_vae)(src, tgt; average=true)

Return the VAE training loss of one minibatch: the summed reconstruction loss plus the
KL divergence between q(z|src) and the standard normal, weighted by an annealing
coefficient λ.

`src` and `tgt` are matrices of word ids with one sentence per row, as produced by
`arraybatch`. `tgt[:, 1:end-1]` is fed to the decoder and `tgt[:, 2:end]` is the gold
output; trailing eos padding in the gold output is masked with `mask!`.

λ grows linearly from 0 to 1 over the first 80 values of the global `epoch_no`.
The globals `softmax_loss_total` and `kl_loss_total` are incremented with the two loss
components for logging. `average` is accepted but not used; the sum is always returned.

The latent sample is drawn with standard deviation `exp(log_var / 2)`. The KL term
treats `log_var` as log σ², so passing `exp(log_var)` (the variance) as the scale, as
was done previously, made the sampled posterior inconsistent with the KL penalty.
"""
function (s::s2s_vae)(src, tgt; average=true) #calculate loss for each batch

    global epoch_no
    global softmax_loss_total
    global kl_loss_total

    #KL loss coefficient increases linearly in the first 80 epochs
    linear_divergence_schedule = 80
    λ = min(epoch_no, linear_divergence_schedule) / linear_divergence_schedule

    #initialize encoder
    s.encoder.c = 0; s.encoder.h = 0

    #get the final hidden state from the LSTM given the source sentence
    y_enc = s.encoder(s.srcembed(src))[:,:,end]

    #estimate means and log variances from the hidden state representation of the input sentence
    mu = s.mean_mlp(y_enc)
    log_var = s.var_mlp(y_enc)

    #calculate kl loss: -1/2 * sum(1 + log σ² - μ² - σ²)
    kl_loss = -0.5 * sum(1 .+ (log_var - mu.*mu - exp.(log_var)))

    #perform reparameterization trick with σ = exp(log σ² / 2)
    z = reparameterize(mu, exp.(log_var ./ 2))
    x,y = size(z)
    z = reshape(z, (x,y,1))

    #initialize decoder according to sampled z
    s.decoder.c = z; s.decoder.h = tanh.(z)
    y_dec = s.decoder(s.tgtembed(tgt[:,1:end-1]))

    #predict next words with decoder and reconstruction loss by comparing against gold answers
    hy, b ,ty = size(y_dec)
    y_dec = reshape(y_dec, (hy, b*ty))
    scores = s.projection(y_dec)
    y_gold = mask!(tgt[:,2:end], s.tgtvocab.eos)
    softmax_loss, _ = nll(scores, y_gold; average = false)

    #track both components for the per-epoch debug printout
    softmax_loss_total += softmax_loss
    kl_loss_total += kl_loss

    #total loss: weighted kl loss plus reconstruction loss
    return λ*kl_loss+softmax_loss
end

In [25]:
# Loss of `model` over every batch of `data`, divided by the total number of entries
# in the target matrices.
#
# Side effects: increments the global `epoch_no`, prints the accumulated softmax and
# KL totals, and resets both totals to 0. `average` is accepted but not used.
function loss(model, data; average=true)
    global epoch_no
    global softmax_loss_total
    global kl_loss_total

    # Advance the counter that the model uses for KL annealing.
    epoch_no += 1

    total = 0
    instances = 0
    for (src, tgt) in data
        instances += length(tgt)
        total += model(src, tgt)
    end
    println("softmax loss for epoch ", epoch_no, " is: ", softmax_loss_total, " and kl loss is: ", kl_loss_total)

    softmax_loss_total = 0
    kl_loss_total = 0

    return total / instances
end

loss (generic function with 1 method)

In [26]:
# Train `model` on the batches of `trn` with Adam and return the snapshot that achieved
# the lowest loss on `dev`.
#
# Every 100 optimizer steps the loss is computed on `dev` and on each dataset in `tst`;
# the values are reported through `progress!`, and the model is deep-copied whenever the
# `dev` loss improves on the best value seen so far.
function train!(model, trn, dev, tst...)
    bestmodel, bestloss = deepcopy(model), loss(model, dev)
    evalsets = (dev, tst...)

    progress!(adam(model, trn), steps=100) do y
        losses = map(d -> loss(model, d), evalsets)
        if losses[1] < bestloss
            bestmodel, bestloss = deepcopy(model), losses[1]
        end
        return losses
    end

    return bestmodel
end

train! (generic function with 1 method)

To initialize a new model, uncomment the first line; to load a pretrained model, uncomment the second line.
To train with the parallel Japanese-to-English corpus, change the first en_vocab in the first line to ja_vocab.

In [140]:
#model = s2s_vae(HIDDEN_SIZE, Q_HIDDEN_SIZE, EMBED_SIZE, EMBED_SIZE, en_vocab, en_vocab)
#model = Knet.load("vae.jld2", "model")

s2s_vae(Embed(P(Array{Float32,2}(64,3711))), LSTM(input=64,hidden=128), Embed(P(Array{Float32,2}(64,3711))), LSTM(input=64,hidden=128), Linear(P(Array{Float32,2}(3711,128)), P(Array{Float32,1}(3711))), MLP(P(Array{Float32,2}(64,128)), P(Array{Float32,2}(128,64)), P(Array{Float32,1}(64))), MLP(P(Array{Float32,2}(64,128)), P(Array{Float32,2}(128,64)), P(Array{Float32,1}(64))), Vocab(Dict("enjoy" => 3043,"shouldn" => 1987,"chocolate" => 1630,"fight" => 2560,"helping" => 1988,"whose" => 2231,"hurried" => 1631,"favor" => 2759,"borders" => 1,"star" => 1632…), ["borders", "stress", "fireworks", "methods", "parted", "shakespeare", "customer", "musical", "regarded", "21"  …  "you", "he", "is", "a", "i", "to", "the", ".", "<unk>", "<s>"], 3710, 3711, split), Vocab(Dict("enjoy" => 3043,"shouldn" => 1987,"chocolate" => 1630,"fight" => 2560,"helping" => 1988,"whose" => 2231,"hurried" => 1631,"favor" => 2759,"borders" => 1,"star" => 1632…), ["borders", "stress", "fireworks", "methods", "parted", "sh

In [138]:
epochs = 300  #no of epochs, other parts of the code are from previous assignments
# Materialize all training batches once.
ctrn = collect(dtrn)
# Training stream: `epochs` passes over the batches, reshuffled (in place) before each pass.
trnx10 = collect(flatten(shuffle!(ctrn) for i in 1:epochs))
# First 20 batches of the (now shuffled) training set, used as an extra evaluation set.
trn20 = ctrn[1:20]
# All development batches.
dev38 = collect(ddev)

11-element Array{Tuple{T,T} where T,1}:
 ([3947 3947 … 3934 3945; 3947 3947 … 3924 3945; … ; 3947 3947 … 3943 3945; 3947 3947 … 3930 3945], [3711 3685 … 3711 3711; 3711 3704 … 3711 3711; … ; 3711 3710 … 3711 3711; 3711 3703 … 3711 3711])
 ([3947 3947 … 3940 3945; 3947 3947 … 3925 3945; … ; 3947 3947 … 3940 3945; 3947 3947 … 3930 3945], [3711 3694 … 3711 3711; 3711 3710 … 3711 3711; … ; 3711 3692 … 3711 3711; 3711 3462 … 3711 3711])
 ([3947 3913 … 3915 3945; 3947 3947 … 3925 3945; … ; 3851 3927 … 3930 3945; 3929 3944 … 3925 3945], [3711 3690 … 3711 3711; 3711 3703 … 3711 3711; … ; 3711 3690 … 3711 3711; 3711 3706 … 3711 3711])
 ([3947 3947 … 3923 3945; 3947 3947 … 3943 3945; … ; 3947 3947 … 3940 3945; 3947 3947 … 3934 3945], [3711 3702 … 3711 3711; 3711 3691 … 3711 3711; … ; 3711 3690 … 3711 3711; 3711 3290 … 3711 3711])
 ([3947 3947 … 3934 3945; 3947 3947 … 3940 3945; … ; 3947 3947 … 3943 3945; 3947 3947 … 3940 3945], [3711 3706 … 3711 3711; 3711 2923 … 3711 3711; … ; 3711 2843 … 3711 

In [139]:
#model = train!(model, trnx10, dev38, trn20)   #uncomment to train
#Knet.save("vae.jld2","model",model)           #uncomment to save trained model

In [156]:
# Load the object stored under the key "model" in vae.jld2; it is used below for sampling.
model = Knet.load("vae.jld2", "model")  #load model for sampling from learned generative model

s2s_vae(Embed(P(Array{Float32,2}(64,3711))), LSTM(input=64,hidden=128), Embed(P(Array{Float32,2}(64,3711))), LSTM(input=64,hidden=128), Linear(P(Array{Float32,2}(3711,128)), P(Array{Float32,1}(3711))), MLP(P(Array{Float32,2}(64,128)), P(Array{Float32,2}(128,64)), P(Array{Float32,1}(64))), MLP(P(Array{Float32,2}(64,128)), P(Array{Float32,2}(128,64)), P(Array{Float32,1}(64))), Vocab(Dict("enjoy" => 3043,"shouldn" => 1987,"chocolate" => 1630,"fight" => 2560,"helping" => 1988,"whose" => 2231,"hurried" => 1631,"favor" => 2759,"borders" => 1,"star" => 1632…), ["borders", "stress", "fireworks", "methods", "parted", "shakespeare", "customer", "musical", "regarded", "21"  …  "you", "he", "is", "a", "i", "to", "the", ".", "<unk>", "<s>"], 3710, 3711, split), Vocab(Dict("enjoy" => 3043,"shouldn" => 1987,"chocolate" => 1630,"fight" => 2560,"helping" => 1988,"whose" => 2231,"hurried" => 1631,"favor" => 2759,"borders" => 1,"star" => 1632…), ["borders", "stress", "fireworks", "methods", "parted", "sh

As mentioned in the beginning, z is sampled from a standard Gaussian distribution.
We use this z as the initial state for our decoder.
We then generate a new sentence by greedy decoding, one word at a time, until the eos token is produced.

In [210]:
# Seed the global random number generator so that the z sampled below is reproducible.
a = Random.seed!(2)
summary(a)

"MersenneTwister"

In [211]:
#sample from standard gaussian
z = reshape(reparameterize(zeros(Float32, 128),ones(Float32, 128)), (128,1,1))  

#inititalize decoder
model.decoder.h = tanh.(z)
model.decoder.c = z
input = [model.tgtvocab.eos]
isDone = false
translated_sentence = [model.tgtvocab.eos]
    
#sample next words with greedy decoding until eos is reached
while (!isDone)
        global input
        input = reshape(input, (1,1))
        y = model.decoder(model.tgtembed(input))
        
        scores = model.projection(mat(y))
        next_word = argmax(scores)[1]
        translated_sentence = push!(translated_sentence, next_word)
        input = [next_word]
        if(next_word == model.tgtvocab.eos || length(translated_sentence)>50)
            isDone = true
        end
    
end
#print out the final sentence
model.tgtvocab.i2w[translated_sentence]

9-element Array{String,1}:
 "<s>"   
 "i"     
 "have"  
 "to"    
 "go"    
 "to"    
 "school"
 "."     
 "<s>"   