In [47]:
begin
	using Flux
	using Flux: onehot
	using Flux: gradient
	using Flux.Optimise: update!
	using Flux: onecold
	using CUDA
	using Flux: @functor
	using NNlib
	using Tullio
	using Plots 
	using Transformers
	using Transformers.Basic 
	using Transformers.Datasets: batched
	using Transformers: Abstract3DTensor, Container, epsilon, batchedmul, batched_triu!
	using CUDAKernels, KernelAbstractions 
	enable_gpu(true)
end

ErrorException: CUDA not functional

In [48]:
begin
	labels = map(string, 1:10)
	startsym = "11"
	endsym = "12"
	unksym = "0"
	labels = [unksym, startsym, endsym, labels...]
	vocab = Vocabulary(labels, unksym)
end

Vocabulary{String}(13, unk=0)

In [49]:
#function for generate training datas 
sample_data() = (d = map(string, rand(1:10, 10)); (d,d))

sample_data (generic function with 1 method)

In [50]:
#function for adding start & end symbol
preprocess(x) = [startsym, x..., endsym]

preprocess (generic function with 1 method)

In [51]:
begin
    @show sample_ex = preprocess.(sample_data())
    @show encoded_sample_ex = vocab(sample_ex[1]) #use Vocabulary to encode the training data, +3
    end

sample_ex = preprocess.(sample_data()) = (["11", "6", "5", "6", "9", "7", "6", "4", "5", "4", "5", "12"], ["11", "6", "5", "6", "9", "7", "6", "4", "5", "4", "5", "12"])
encoded_sample_ex = vocab(sample_ex[1]) = [2, 9, 8, 9, 12, 10, 9, 7, 8, 7, 8, 3]


12-element Vector{Int64}:
  2
  9
  8
  9
 12
 10
  9
  7
  8
  7
  8
  3

In [52]:
sample = preprocess.(sample_data())

(["11", "6", "9", "2", "6", "7", "7", "9", "9", "9", "2", "12"], ["11", "6", "9", "2", "6", "7", "7", "9", "9", "9", "2", "12"])

In [53]:
encoded_sample = vocab(sample[1])

12-element Vector{Int64}:
  2
  9
 12
  5
  9
 10
 10
 12
 12
 12
  5
  3

In [54]:
#define a Word embedding layer which turn word index to word vector
embed = Embed(512, length(vocab)) |> gpu

Embed(512)

In [55]:
#define a position embedding layer mentioned above
pe = PositionEmbedding(512) |> gpu

PositionEmbedding(512)

In [56]:
#wrapper to get embedding
function embedding(x)
    we = embed(x, inv(sqrt(512)))
    e = we .+ pe(we)
    return e
  end

embedding (generic function with 1 method)

In [57]:
abstract type AbstractAttention end

In [58]:
#=begin
    create_atten_mask1(T::Type, score::AbstractArray, ::Nothing, future) = create_atten_mask1(T, score, fill!(similar(score, size(score,1), size(score, 2), 1), one(T)), future)
    function create_atten_mask1(T::Type, score::AbstractArray, _mask::AbstractArray, future::Bool=false)
      #size(mask) == (q, k, n, b)
    
      # ql, kl = size(mask)
      mask = copy(_mask)
    
      maskval = convert(T, -1e9)
      !future && batched_triu!(mask, 0)
      mask .= (1 .- mask) .* maskval
      return mask
    end
    end

ErrorException: syntax: incomplete: unterminated multi-line comment #= ... =#

In [59]:
#Flux.@nograd create_atten_mask1

In [60]:
begin

	struct MultiheadAttention_Linear{Q<:Dense, K<:Dense, V<:Dense, O<:Dense, DP<:Dropout} <: AbstractAttention
    head::Int
    future::Bool
    iqproj::Q
    ikproj::K
    ivproj::V
    oproj::O
    drop::DP
end

	Flux.functor(mh::MultiheadAttention_Linear) = (mh.iqproj, mh.ikproj, mh.ivproj, mh.oproj), m -> MultiheadAttention_Linear(mh.head, mh.future, m..., mh.drop)

	MultiheadAttention_Linear(head::Int,
                   is::Int,
                   hs::Int,
                   os::Int;
                   future::Bool=true, pdrop = 0.1) = MultiheadAttention_Linear(head,
                                                                        future,
                                                                        Dense(is, hs*head),
                                                                        Dense(is, hs*head),
                                                                        Dense(is, hs*head),
                                                                        Dense(hs*head, os),
                                                                        Dropout(pdrop),
                                                                        )


function Base.show(io::IO, mh::MultiheadAttention_Linear)
    hs = div(size(mh.iqproj.weight)[1], mh.head)
    is = size(mh.iqproj.weight)[end]
    os = size(mh.oproj.weight)[1]

    print(io, "MultiheadAttention(")
    print(io, "head=$(mh.head), ")
    print(io, "head_size=$(hs), ")
    print(io, "$(is)=>$(os)")

    if Flux.istraining()
        print(io, ", dropout=$(mh.drop.p))")
    else
        print(io, ")")
    end
end

	function (mh::MultiheadAttention_Linear)(query::A1,
    key::A2,
    value::A3;
    mask=nothing) where {T,
                         A1 <: Abstract3DTensor{T},
                         A2 <: Abstract3DTensor{T},
                         A3 <: Abstract3DTensor{T}}
qs = size(query)
ks = size(key)
vs = size(value)

#size(ipq) == (h, q_seq_len, batch)
ipq = @toNd mh.iqproj(query)
ipk = @toNd mh.ikproj(key)
ipv = @toNd mh.ivproj(value)

h = size(ipq, 1)
hs = div(h, mh.head)

#size(ipq) == (hs, q_seq_len, head, batch)
ipq = permutedims(reshape(ipq, hs, mh.head, qs[2], qs[3]), [1, 3, 2, 4])
ipk = permutedims(reshape(ipk, hs, mh.head, ks[2], ks[3]), [1, 3, 2, 4])
ipv = permutedims(reshape(ipv, hs, mh.head, vs[2], vs[3]), [1, 3, 2, 4])

#size(ipq) == (hs, q_seq_len, head * batch)
ipq = reshape(ipq, hs, qs[2], :)
ipk = reshape(ipk, hs, ks[2], :)
ipv = reshape(ipv, hs, vs[2], :)
		
atten = attention1(ipq,ipk,ipv,
mask,
mh.future,
mh.drop,mh.head)

atten = permutedims(reshape(atten, hs, qs[2], mh.head, qs[3]), [1, 3, 2, 4]) #size(atten) == (hs, head, ql, b)
atten = reshape(atten, h, qs[2], qs[3]) #size(atten) == (h, ql, b)

out = @toNd mh.oproj(atten)
out #size(out) == (h, q_seq_len, batch)
end

function (mh::MultiheadAttention_Linear)(query::A1,
    key::A2,
    value::A3;
    mask=nothing) where {T,
                         A1 <: AbstractMatrix{T},
                         A2 <: AbstractMatrix{T},
                         A3 <: AbstractMatrix{T}}

# size(query) == (dims, seq_len)
ipq = mh.iqproj(query)
ipk = mh.ikproj(key)
ipv = mh.ivproj(value)

#ipq = cu(ipq)
#ipk = cu(ipk)
#ipv = cu(ipv)
	
h = size(ipq)[1] #h == hs * head
hs = div(h, mh.head)

#size(hq) == (hs, seq_len, head)
hq = permutedims(reshape(ipq, hs, mh.head, :), [1, 3, 2])
hk = permutedims(reshape(ipk, hs, mh.head, :), [1, 3, 2])
hv = permutedims(reshape(ipv, hs, mh.head, :), [1, 3, 2])

atten = attention1(hq, hk, hv,
mask,
mh.future,
mh.drop,mh.head)

# size(atten) == (head*hs, seq_len)
atten = reshape(permutedims(atten, [1, 3, 2]), h, :)

mh.oproj(atten)
end
end

In [61]:
#=begin
    function apply_mask1(score, mask)
        s = size(score)
        ms = size(mask)
        bxn = s[end]
        b = ms[end]
        if bxn == b || b == 1
          return score .+ mask
        else
          return reshape(reshape(score, s[1:end-1]..., :, b) .+
                         reshape(mask, ms[1:end-1]..., 1, b), s)
        end
      end
      
      apply_mask1(score::AbstractArray{T}, ::Nothing, future) where T = future ? score : apply_mask1(score, create_atten_mask1(T, score, nothing, future))
      apply_mask1(score::AbstractArray{T}, mask, future) where T = apply_mask(score, create_atten_mask1(T, score, mask, future))
    end

ErrorException: syntax: incomplete: unterminated multi-line comment #= ... =#

In [62]:
# of type float (to allow for integer inputs)
begin
    function oftf(x, y)
        oftype(float(x), y)
    end
    function nelu(x, α=1)
        ifelse(x ≥ 0, float(x)+1, @fastmath oftf(x, α) * (exp(x) - 1)+1)
    end
    end

nelu (generic function with 2 methods)

In [63]:
x = embedding(encoded_sample)

512×12 Matrix{Float32}:
  0.518426   -0.360292    -0.906279   -0.628539   …   0.0295304   0.837439
  0.848443    0.90189      0.176757   -0.628054      -0.89795    -0.838942
  0.601532   -0.318341    -0.957628   -0.708553      -0.329715    0.520984
  0.764399    0.926681     0.302408   -0.438326      -0.615051   -0.959725
  0.570712   -0.335453    -0.961979   -0.861065      -0.713385    0.216894
  0.798447    0.921795     0.489404   -0.457053   …  -0.457616   -0.869925
  0.662358   -0.242714    -0.891689   -0.869533      -0.869262   -0.254222
  0.780816    0.954626     0.48597    -0.38974       -0.173683   -0.781652
  0.629271   -0.120007    -0.857341   -0.915282      -0.961683   -0.632615
  0.811738    0.963074     0.588408   -0.167424       0.264717   -0.597176
  ⋮                                               ⋱   ⋮          
 -0.0152957  -0.00441111   0.0492921   0.0180724      0.0188807  -0.0406102
  1.07367     0.938749     0.997868    0.944274       0.944273    0.960846
 -0.02516

In [64]:
function splitHeads(x, batch_size, head, depth)
    x = reshape(x, (batch_size, :, head, depth))
end

splitHeads (generic function with 1 method)

In [65]:
opt = ADAM(1e-4) 

ADAM(0.0001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())

In [66]:
function attention1(query::A1,
    key::A2,
    value::A3,
    mask, future::Bool,
    dropout, head) where {T,
                    A1 <: Abstract3DTensor{T},
                    A2 <: Abstract3DTensor{T},
                    A3 <: Abstract3DTensor{T}}
#size(query) == (dims, {q,k}_seq_len, batch) == size(key) == size(value)
#size(ipq) == (hs, q_seq_len, head * batch)
#size(score) == (k_seq_len, q_seq_len, batch)

#size(query) == (batch, {q,k}_seq_len, dims) == size(key) == size(value)

query = permutedims(query,(3,2,1))
key = permutedims(key,(3,2,1))
value = permutedims(value,(3,2,1))

qs = size(query)
#batch = 8, seq_len = 12, d_model = 64
ks = size(key)
vs = size(value)

d_model = qs[3]

#depth = div(d_model, n_heads)
depth = div(d_model, head)   

# size(key) == (batch_size,  seq_len_k, num_heads, depth_k) (m,j,h,d)
query = nelu.(splitHeads(query, qs[1], head, depth))
key = nelu.(splitHeads(key, ks[1], head, depth))
value = splitHeads(value, vs[1], head, depth)

# size(k_v) == (batch_size, depth_k, depth_v, seq_len_v)
@tullio k_v[m,d,e,h] := key[m,j,h,d]*value[m,j,h,e]


#padmask = padding_mask(key)
#key = key .* padmask

k_reduced = dropdims(sum(key, dims=2), dims=2) .+ 1e-8

#k_reduced = k_reduced .+ 1e-8

@tullio z_1[m,l,h] := query[m,l,h,d]*k_reduced[m,h,d]

# size(z) == (batch_size, num_heads, seq_len_q)
z = 1 ./ z_1 # ...

# size(output) == (batch_size,len_q, heads, depth_v)
@tullio output[m,l,h,e] := query[m,l,h,d]*k_v[m,d,e,h]*z[m,l,h]

#output = reshape(output, (qs[1], qs[2], mh.head*depth))
# size(output) == (batch_size,len_q, d_model)
output = reshape(output, (qs[1], :, head*depth))

output = permutedims(output,(3,2,1))
#(dims, q_seq_len, batch)
end

attention1 (generic function with 1 method)

In [67]:
abstract type AbstractTransformer end

In [68]:
begin
    struct PwFFN{Di<:Dense, Do<:Dense}
        din::Di
        dout::Do
    end
    
    @functor PwFFN
    
    
    "just a wrapper for two dense layer."
    function PwFFN(size::Int, h::Int, act = nelu)
        PwFFN(
        Dense(size, h, act),
        Dense(h, size)
    )
    end
    function (pw::PwFFN)(x::AbstractMatrix)
      #size(x) == (dims, seq_len)
      pw.dout(pw.din(x))
    end
    
    function (pw::PwFFN)(x::A) where {T, N, A<:AbstractArray{T, N}}
      new_x = reshape(x, size(x, 1), :)
      y = pw(new_x)
      return reshape(y, Base.setindex(size(x), size(y, 1), 1))
    end
    end

In [69]:
begin
	struct Transformer1{MA<:MultiheadAttention_Linear, LA<:LayerNorm, P<:PwFFN, LP<:LayerNorm, DP<:Dropout} <: AbstractTransformer
    mh::MA
    mhn::LA
    pw::P
    pwn::LP
    drop::DP
end

@functor Transformer1


"""
    Transformer(size::Int, head::Int, ps::Int;
                future::Bool = true, act = relu, pdrop = 0.1)
    Transformer(size::Int, head::Int, hs::Int, ps::Int;
                future::Bool = true, act = relu, pdrop = 0.1)  

Transformer layer.

`size` is the input size. if `hs` is not specify, use `div(size, head)` as the hidden size of multi-head attention. 
`ps` is the hidden size & `act` is the activation function of the positionwise feedforward layer. 
When `future` is `false`, the k-th token can't see the j-th tokens where j > k. `pdrop` is the dropout rate.
"""

function Transformer1(size::Int, head::Int, ps::Int; future::Bool = true, act = relu, pdrop = 0.1)  
    rem(size, head) != 0 && error("size not divisible by head")
    Transformer1(size, head, div(size, head), ps;future=future, act=act, pdrop=pdrop)
end

Transformer1(size::Int, head::Int, hs::Int, ps::Int; future::Bool = true, act = relu, pdrop = 0.1) = Transformer1(
    MultiheadAttention_Linear(head, size, hs, size; future=future, pdrop=pdrop),
    LayerNorm(size),
    PwFFN(size, ps, act),
    LayerNorm(size),
    Dropout(pdrop),     
)

function (t::Transformer1)(x::A, mask=nothing) where {T, N, A<:AbstractArray{T, N}}
    dropout = t.drop
    a = t.mh(x, x, x; mask=mask)
    a = dropout(a)
    res_a = x + a
    res_a = t.mhn(res_a)
    pwffn = t.pw(res_a)
    pwffn = dropout(pwffn)
    res_pwffn = res_a + pwffn
    res_pwffn = t.pwn(res_pwffn)
    res_pwffn
end

function Base.show(io::IO, t::Transformer1) 
    hs = div(size(t.mh.iqproj.weight)[1], t.mh.head)
    h, ps = size(t.pw.dout.weight)

    print(io, "Transformer(")
    print(io, "head=$(t.mh.head), ")
    print(io, "head_size=$(hs), ")
    print(io, "pwffn_size=$(ps), ")
    print(io, "size=$(h)")
    if Flux.istraining()
        print(io, ", dropout=$(t.drop.p))")
    else
        print(io, ")")
    end
end
end

In [70]:
begin
	struct TransformerDecoder1{MA<:MultiheadAttention_Linear, LA<:LayerNorm,
        IMA<:MultiheadAttention_Linear, ILA<:LayerNorm,
        P<:PwFFN, LP<:LayerNorm, DP<:Dropout} <: AbstractTransformer
                mh::MA
                mhn::LA
                imh::IMA
                imhn::ILA
                pw::P
                pwn::LP
                drop::DP
end

@functor TransformerDecoder1

"""
TransformerDecoder(size::Int, head::Int, ps::Int; act = relu, pdrop = 0.1)
TransformerDecoder(size::Int, head::Int, hs::Int, ps::Int; act = relu, pdrop = 0.1)

TransformerDecoder layer. Decode the value from a Encoder.

`size` is the input size. if `hs` is not specify, use `div(size, head)` as the hidden size of multi-head attention. 
`ps` is the hidden size & `act` is the activation function of the positionwise feedforward layer. 
`pdrop` is the dropout rate.
"""
function TransformerDecoder1(size::Int, head::Int, ps::Int; act = relu, pdrop = 0.1)
rem(size, head) != 0 && error("size not divisible by head")
TransformerDecoder1(size, head, div(size, head), ps; act=act, pdrop=pdrop)
end

TransformerDecoder1(size::Int, head::Int, hs::Int, ps::Int; act = relu, pdrop = 0.1) = TransformerDecoder1(
MultiheadAttention_Linear(head, size, hs, size; future=false, pdrop=pdrop),
LayerNorm(size),
MultiheadAttention_Linear(head, size, hs, size; future=true, pdrop=pdrop), 
LayerNorm(size),
PwFFN(size, ps, act),
LayerNorm(size),
Dropout(pdrop),
)

function (td::TransformerDecoder1)(x::AbstractArray{T,N}, m, mask=nothing) where {T,N}
dropout = td.drop
a = td.mh(x,x,x)
a = dropout(a)
res_a = x + a
res_a = td.mhn(res_a)

ia = td.imh(res_a, m, m, mask=mask)
ia = dropout(ia)
res_ia = res_a + ia
res_ia = td.imhn(res_ia)

pwffn = td.pw(res_ia)
pwffn = dropout(pwffn)
res_pwffn = res_ia + pwffn
res_pwffn = td.pwn(res_pwffn)
res_pwffn
end

function Base.show(io::IO, td::TransformerDecoder1)
hs = div(size(td.imh.iqproj.weight)[1], td.imh.head)
h, ps = size(td.pw.dout.weight)

print(io, "TransformerDecoder(")
print(io, "head=$(td.mh.head), ")
print(io, "head_size=$(hs), ")
print(io, "pwffn_size=$(ps), ")
print(io, "size=$(h)")
if Flux.istraining()
print(io, ", dropout=$(td.drop.p))")
else
print(io, ")")
end
end
end

In [71]:
#define 4 layer of transformer

encode_t1 = Transformer1(512, 8, 64, 1024) |> gpu

Transformer(head=8, head_size=64, pwffn_size=1024, size=512)

In [72]:
encode_t2 = Transformer1(512, 8, 64, 1024) |> gpu

Transformer(head=8, head_size=64, pwffn_size=1024, size=512)

In [73]:
encode_t3 = Transformer1(512, 8, 64, 1024) |> gpu

Transformer(head=8, head_size=64, pwffn_size=1024, size=512)

In [74]:
encode_t4 = Transformer1(512, 8, 64, 1024) |> gpu

Transformer(head=8, head_size=64, pwffn_size=1024, size=512)

In [75]:
#define 4 layer of transformer decoder

decode_t1 = TransformerDecoder1(512, 8, 64, 1024, act=nelu) |> gpu

TransformerDecoder(head=8, head_size=64, pwffn_size=1024, size=512)

In [76]:
decode_t2 = TransformerDecoder1(512, 8, 64, 1024, act=nelu) |> gpu

TransformerDecoder(head=8, head_size=64, pwffn_size=1024, size=512)

In [77]:
decode_t3 = TransformerDecoder1(512, 8, 64, 1024, act=nelu) |> gpu

TransformerDecoder(head=8, head_size=64, pwffn_size=1024, size=512)

In [78]:
decode_t4 = TransformerDecoder1(512, 8, 64, 1024, act=nelu) |> gpu

TransformerDecoder(head=8, head_size=64, pwffn_size=1024, size=512)

In [79]:
#define the layer to get the final output probabilities

linear = Positionwise(Dense(512, length(vocab)), logsoftmax) |> gpu

Positionwise(Dense(512, 13), logsoftmax)

In [80]:
function encoder_forward(x)
    e = embedding(x)
    t1 = encode_t1(e)
    t2 = encode_t2(t1)
    t3 = encode_t2(t2)
    t4 = encode_t2(t3)
    return t2
end

encoder_forward (generic function with 1 method)

In [81]:
function decoder_forward(x, m)
    e = embedding(x)
    t1 = decode_t1(e, m)
    t2 = decode_t2(t1, m)
    t3 = decode_t3(t2, m)
    t4 = decode_t4(t3, m)
    p = linear(t4)
    return p
  end

decoder_forward (generic function with 1 method)

In [82]:
enc = encoder_forward(encoded_sample);

In [83]:
probs = decoder_forward(encoded_sample, enc);

In [84]:
function smooth(et)
    sm = fill!(similar(et, Float32), 1e-6/size(embed, 2))
    p = sm .* (1 .+ -et)
    label = p .+ et .* (1 - convert(Float32, 1e-6))
    label
end

smooth (generic function with 1 method)

In [85]:
Flux.@nograd smooth

In [86]:
#define loss function

function loss(x, y)
    label = onehot(vocab, y) #turn the index to one-hot encoding
    label = smooth(label) #perform label smoothing
    enc = encoder_forward(x)
    probs = decoder_forward(y, enc)
    # logcrossentropy used by Katharopoulos et al.
    l = logcrossentropy(label[:, 2:end, :], probs[:, 1:end-1, :])  
    return l
  end

loss (generic function with 1 method)

In [87]:
#collect all the parameters

ps = params(embed, pe, encode_t1, encode_t2, encode_t3, encode_t4, decode_t1, decode_t2, decode_t3, decode_t4, linear);

In [88]:
#define training loop

function train!()
    @info "start training"
    losses =[]
    for i = 1:3000
      data = batched([sample_data() for i = 1:32]) #create 32 random sample and batched
      x, y = preprocess.(data[1]), preprocess.(data[2])
      x, y = vocab(x), vocab(y) #encode the data
      x, y = todevice(x, y) #move to gpu
      grad = gradient(()->loss(x, y), ps)
      if i % 8 == 0
          l = loss(x, y)
          println("loss = $l")
      end
      update!(opt, ps, grad)
      push!(losses, loss(x,y))
    end
    return losses
  end

train! (generic function with 1 method)

In [43]:
losses = train!()

InterruptException: InterruptException:

In [44]:
plot(vcat(losses[2:3000]...), label="Losses") # plot losses

┌ Info: start training
└ @ Main /Users/johannes/Documents/GitHub/DM2022-LinearTransformers/test/Pluto_GPU.ipynb:4


UndefVarError: UndefVarError: losses not defined

In [45]:
function translate(x)
    ix = todevice(vocab(preprocess(x)))
    seq = [startsym]

    enc = encoder_forward(ix)

    len = length(ix)
    for i = 1:2len
        trg = todevice(vocab(seq))
        dec = decoder_forward(trg, enc)
        #move back to gpu due to argmax wrong result on CuArrays
        ntok = onecold(collect(dec), labels)
        push!(seq, ntok[end])
        ntok[end] == endsym && break
    end
  seq[2:end-1]
end

translate (generic function with 1 method)

In [46]:
translate(map(string, [5,5,6,6,1,12,3,4,6,]))

21-element Vector{String}:
 "3"
 "3"
 "3"
 "3"
 "3"
 "3"
 "3"
 "3"
 "3"
 "3"
 ⋮
 "3"
 "3"
 "3"
 "3"
 "3"
 "3"
 "3"
 "3"
 "3"