In [1]:
# This cell will be updated with respect to needs of project
using Pkg; for p in ("Knet","AutoGrad","Plots","Images","ImageMagick","ArgParse","CUDA"); haskey(Pkg.installed(),p) || Pkg.add(p); end
using Knet
using Statistics
using Random
using Test
import Base: length, size, iterate, eltype, IteratorSize, IteratorEltype, haslength, @propagate_inbounds, repeat, rand, tail
import .Iterators: cycle, Cycle, take
#using Plots; default(fmt=:png,ls=:auto)

└ @ Pkg C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\Pkg\src\Pkg.jl:554
└ @ Pkg C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\Pkg\src\Pkg.jl:554
└ @ Pkg C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\Pkg\src\Pkg.jl:554
└ @ Pkg C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\Pkg\src\Pkg.jl:554
└ @ Pkg C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\Pkg\src\Pkg.jl:554
└ @ Pkg C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\Pkg\src\Pkg.jl:554
└ @ Pkg C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\Pkg\src\Pkg.jl:554


## Params

In [None]:
model_dim = 512
BATCH_SIZE = 64

## Multihead Linear Attention - MHLA - 

### Define Structures

In [None]:
struct Vocab
    w2i::Dict{String,Int}
    i2w::Vector{String}
    unk::Int
    eos::Int
    tokenizer
end

function Vocab(
    file::String;
    tokenizer = split,
    vocabsize = Inf,
    mincount = 1,
    unk = "<unk>",
    eos = "<s>",
)
    vocab_freq = Dict{String,Int64}(unk => 1, eos => 1)
    w2i = Dict{String,Int64}(unk => 2, eos => 1)
    i2w = Vector{String}()

    push!(i2w, eos)
    push!(i2w, unk)

    open(file) do f
        for line in eachline(f)
            sentence = strip(lowercase(line))
            sentence = tokenizer(line, [' '], keepempty = false)

            for word in sentence
                word == unk && continue
                word == eos && continue # They are default ones to be added later
                vocab_freq[word] = get!(vocab_freq, word, 0) + 1
            end
        end
        close(f)
    end


    vocab_freq = sort!(
        collect(vocab_freq),
        by = tuple -> last(tuple),
        rev = true,
    )

    if length(vocab_freq) > vocabsize - 2 # eos and unk ones
        vocab_freq = vocab_freq[1:vocabsize-2] # trim to fit the size
    end

    while true
        length(vocab_freq) == 0 && break
        word, freq = vocab_freq[end]
        freq >= mincount && break # since it is already ordered
        vocab_freq = vocab_freq[1:(end-1)]
    end

    for i = 1:length(vocab_freq)
        word, freq = vocab_freq[i]
        ind = (get!(w2i, word, 1 + length(w2i)))
        (length(i2w) < ind) && push!(i2w, word)
    end

    Vocab(w2i, i2w, 2, 1, tokenizer)
end

In [None]:
const datadir = "data"
train_vocab = Vocab("$datadir/train.txt")

In [None]:
struct TextReader
    file::String
    vocab::Vocab
end

word2ind(dict, x) = get(dict, x, 2) # unk is 2

#Implementing the iterate function
function Base.iterate(r::TextReader, s = nothing)
    if s == nothing
        state = open(r.file)
        Base.iterate(r, state)
    else
        if eof(s) == true
            close(s)
            return nothing
        else              
            sent = r.vocab.tokenizer(strip(lowercase(readline(s))), [' '], keepempty = false)
            sent_ind = Int[]
            for word in sent
                ind = word2ind(r.vocab.w2i, word)
                push!(sent_ind, ind)
            end
            return (sent_ind, s)
        end
    end
end

In [None]:
train_sentences, valid_sentences, test_sentences =
    (TextReader("$datadir/$file.txt", train_vocab) for file in ("train","valid","test"))

In [None]:
truct LMData
    src::TextReader
    batchsize::Int
    maxlength::Int
    bucketwidth::Int
    buckets
end

function LMData(src::TextReader; batchsize = 64, maxlength = typemax(Int), bucketwidth = 10)
    numbuckets = min(128, maxlength ÷ bucketwidth)
    buckets = [ [] for i in 1:numbuckets ]
    LMData(src, batchsize, maxlength, bucketwidth, buckets)
end

Base.IteratorSize(::Type{LMData}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{LMData}) = Base.HasEltype()
Base.eltype(::Type{LMData}) = Matrix{Int}

function Base.iterate(d::LMData, state=nothing)
    if state == nothing
        for b in d.buckets; empty!(b); end
    end
    bucket,ibucket = nothing,nothing
    while true
        iter = (state === nothing ? iterate(d.src) : iterate(d.src, state))
        if iter === nothing
            ibucket = findfirst(x -> !isempty(x), d.buckets)
            bucket = (ibucket === nothing ? nothing : d.buckets[ibucket])
            break
        else
            sent, state = iter
            if length(sent) > d.maxlength || length(sent) == 0; continue; end
            ibucket = min(1 + (length(sent)-1) ÷ d.bucketwidth, length(d.buckets))
            bucket = d.buckets[ibucket]
            push!(bucket, sent)
            if length(bucket) === d.batchsize; break; end
        end
    end
    if bucket === nothing; return nothing; end
    batchsize = length(bucket)
    maxlen = maximum(length.(bucket))
    batch = fill(d.src.vocab.eos, batchsize, maxlen + 1)
    for i in 1:batchsize
        batch[i, 1:length(bucket[i])] = bucket[i]
    end
    empty!(bucket)
    return batch, state
end

In [None]:
train_batches = collect(LMData(train_sentences))
valid_batches = collect(LMData(valid_sentences))
test_batches = collect(LMData(test_sentences))
tst_inst = first(test_batches)

In [None]:
#Embed
struct Embed
    w
end

function Embed(vocabsize::Int, embedsize::Int)
    Embed(param(embedsize, vocabsize))
end

# Attention here: Format has been changed for now
function (l::Embed)(x)
    permutedims(l.w[:, x],(2,1,3)) # Format has been changed to T,E,B
end

In [21]:
struct Memory
    w # for transformation
end

struct MHA
    head_num::Int # Number of heads, will be given as a hyperparameter
    dim::Int      # Feature space dimension, will be given as a hyperparameter
    Wq
    Wk
    Wv
    Wo
    k::Int        # Projection dimension, required hyperparameter for MHLA
    E             # Projection matrix E, required hyperparameter for MHLA
    F             # Projection matrix F, required hyperparameter for MHLA
end  

In [22]:
struct Embed
    w
end

function Embed(vocabsize::Int, embedsize::Int)
    Embed(param(embedsize, vocabsize))
end

function (l::Embed)(x)
    permutedims(l.w[:, x],(2,1,3)) # Format has been changed to T,E,B
end

In [33]:
function (mha::MHA)(cell, mem; linear::Bool=false)
    T,E,B = size(cell)    
    @assert size(mha.Wq)==size(mha.Wk)==size(mha.Wv)==(E,mha.dim,B,mha.head_num)
    @assert size(mha.Wo)==(mha.head_num*mha.dim,E,B)
    linear && @assert size(mha.E)==size(mha.F)==(mha.k,T,B,mha.head_num)
    
    K, V = mem # T,E,B
    Q = cell # no dimensionality mismatch for now
    

    # Seek for an efficient strategy
    head_container = zeros(T,mha.dim,B,mha.head_num)# T,d,B,h
    head(A,B,C) =  bmm(softmax(bmm(A,permutedims(B,(2,1,3)))*sqrt(mha.dim),dims=2),C)
    
    if linear
        for i in 1:mha.head_num
            head_container[:,:,:,i] = head(bmm(Q,mha.Wq[:,:,:,i]),bmm(bmm(mha.E[:,:,:,i],K),mha.Wk[:,:,:,i]),bmm(bmm(mha.F[:,:,:,i],V),mha.Wv[:,:,:,i]))
        end
    else
        for i in 1:mha.head_num
            head_container[:,:,:,i] = head(bmm(Q,mha.Wq[:,:,:,i]),bmm(K,mha.Wk[:,:,:,i]),bmm(V,mha.Wv[:,:,:,i]))
        end
    end
    
    bmm(reshape(head_container,(T,:,B)),mha.Wo)
end

In [34]:
E,T,d,h,B,k = 512,250,64,8,64,32
Wq = rand(E,d,B,h)
Wk = rand(E,d,B,h)
Wv = rand(E,d,B,h)
Wo = rand(h*d,E,B)
Em = rand(k,T,B,h)
Fm = rand(k,T,B,h)
attn = MHA(h,d,Wq,Wk,Wv,Wo,k,Em,Fm)

MHA(8, 64, [0.4697252274922721 0.7597540366461784 … 0.45564968245215987 0.7720846809546009; 0.9462852937281017 0.4692137080691037 … 0.5873350827659662 0.6040730779662242; … ; 0.6964578625975812 0.7493350937976635 … 0.05262919126853527 0.7827692652668437; 0.2952634779812848 0.9955923015982266 … 0.855155149477264 0.6701327698110005]

[0.46787324340659997 0.6738000118887386 … 0.17051856213808936 0.39729174288537883; 0.7585458006409989 0.3342266797006497 … 0.11521431814075767 0.4354628997014167; … ; 0.32018890260395083 0.2947628504601354 … 0.25719743210609236 0.08106545266559517; 0.030271905699194246 0.4155263670708491 … 0.93606964588136 0.5172529845179552]

[0.06759174357464981 0.34615776192781866 … 0.5262958328309053 0.710337507190568; 0.28656517204639087 0.14908083816041895 … 0.24537482387741227 0.8482059491270713; … ; 0.4165574575203006 0.2602165911718586 … 0.4266886090118298 0.5712669017174383; 0.924028511251985 0.6067932375754128 … 0.19979835183524663 0.6486756200267993]

...

[0.217

In [35]:
K = rand(T,E,B)
Q = rand(T,E,B)
V = rand(T,E,B)

250×512×64 Array{Float64,3}:
[:, :, 1] =
 0.771289   0.903463   0.931764    …  0.0109122  0.707378   0.432987
 0.205909   0.812696   0.782386       0.648454   0.891587   0.517512
 0.541777   0.374176   0.567321       0.903073   0.936917   0.64889
 0.682265   0.12744    0.123618       0.673131   0.74212    0.133629
 0.691177   0.340256   0.160287       0.42779    0.0931274  0.751539
 0.0359316  0.964522   0.0786498   …  0.236624   0.475669   0.0875776
 0.935546   0.398667   0.425559       0.0416409  0.987103   0.0762924
 0.690947   0.501275   0.178543       0.6946     0.399692   0.422608
 0.890112   0.216404   0.901426       0.344381   0.253609   0.915638
 0.944721   0.801658   0.467741       0.0184369  0.727301   0.65053
 0.950199   0.903615   0.49638     …  0.413981   0.212418   0.342246
 0.435201   0.899393   0.0915574      0.969629   0.0473321  0.27384
 0.732873   0.11807    0.46876        0.361749   0.522907   0.201227
 ⋮                                 ⋱             ⋮          
 0

In [36]:
attn(Q,(K,V))

250×512×64 Array{Float64,3}:
[:, :, 1] =
 32596.5  31285.9  31595.5  32794.9  …  31065.9  32651.7  32861.4  32516.1
 32596.5  31285.9  31595.5  32794.9     31065.9  32651.7  32861.4  32516.1
 32596.5  31285.9  31595.5  32794.9     31065.9  32651.7  32861.4  32516.1
 32596.5  31285.9  31595.5  32794.9     31065.9  32651.7  32861.4  32516.1
 32596.5  31285.9  31595.5  32794.9     31065.9  32651.7  32861.4  32516.1
 32596.5  31285.9  31595.5  32794.9  …  31065.9  32651.7  32861.4  32516.1
 32596.5  31285.9  31595.5  32794.9     31065.9  32651.7  32861.4  32516.1
 32596.5  31285.9  31595.5  32794.9     31065.9  32651.7  32861.4  32516.1
 32596.5  31285.9  31595.5  32794.9     31065.9  32651.7  32861.4  32516.1
 32596.5  31285.9  31595.5  32794.9     31065.9  32651.7  32861.4  32516.1
 32596.5  31285.9  31595.5  32794.9  …  31065.9  32651.7  32861.4  32516.1
 32596.5  31285.9  31595.5  32794.9     31065.9  32651.7  32861.4  32516.1
 32596.5  31285.9  31595.5  32794.9     31065.9  32651.7  3

In [37]:
attn(Q,(K,V);linear=true)

250×512×64 Array{Float64,3}:
[:, :, 1] =
 4.10674e6  3.93584e6  3.97222e6  …  4.10755e6  4.12921e6  4.08851e6
 4.10674e6  3.93584e6  3.97222e6     4.10755e6  4.12921e6  4.08851e6
 4.10674e6  3.93584e6  3.97222e6     4.10755e6  4.12921e6  4.08851e6
 4.10674e6  3.93584e6  3.97222e6     4.10755e6  4.12921e6  4.08851e6
 4.10674e6  3.93584e6  3.97222e6     4.10755e6  4.12921e6  4.08851e6
 4.10674e6  3.93584e6  3.97222e6  …  4.10755e6  4.12921e6  4.08851e6
 4.10674e6  3.93584e6  3.97222e6     4.10755e6  4.12921e6  4.08851e6
 4.10674e6  3.93584e6  3.97222e6     4.10755e6  4.12921e6  4.08851e6
 4.10674e6  3.93584e6  3.97222e6     4.10755e6  4.12921e6  4.08851e6
 4.10674e6  3.93584e6  3.97222e6     4.10755e6  4.12921e6  4.08851e6
 4.10674e6  3.93584e6  3.97222e6  …  4.10755e6  4.12921e6  4.08851e6
 4.10674e6  3.93584e6  3.97222e6     4.10755e6  4.12921e6  4.08851e6
 4.10674e6  3.93584e6  3.97222e6     4.10755e6  4.12921e6  4.08851e6
 ⋮                                ⋱             ⋮          
 4

In [None]:
function (mem::Memory)(cell)
    permutedims!(val, cell, (1, 3, 2))
    mmul(mem.w, val), val
end

# TODO: Refer comp 542
mmul(w, x) = (
    w == 1 ? x :
    w == 0 ? 0 : reshape(w * reshape(x, size(x, 1), :), (:, size(x)[2:end]...))
);

In [20]:
typeof(K)

Array{Float64,3}