In [1]:
import Pkg;
Pkg.activate("..")

[32m[1m  Activating[22m[39m project at `~/Projects/comp441/latentplan.jl`


In [2]:
using Knet

In [None]:
include("../latentplan/models/common.jl")
using .Common: Linear

In [46]:
"""
    @size expr

Debugging helper: print a banner line, the source text of `expr`, the displayed value
of `expr` and, when `size` is applicable to that value, its size.

`expr` is evaluated exactly once in the caller's scope, so expressions that are
expensive or have side effects can be inspected safely.
"""
macro size(e::Union{Symbol, Expr})
    label = string(e)
    quote
        println("###########")
        println($label, " = ")
        # Evaluate the caller's expression a single time and reuse the result;
        # splicing it in repeatedly would re-run it for every use below.
        local value = $(esc(e))
        display(value)
        if applicable(size, value)
            println("size(", $label, ") = ", size(value))
        end
    end
end

@size (macro with 1 method)

# MaxPool1d

In [32]:
"""
    MaxPool1d(window, stride)

Pooling layer configuration. Stores the `window` and `stride` values that are passed
on to `pool` when the layer is applied to an input.
"""
struct MaxPool1d
    window
    stride
end

# Apply the pooling layer along the first dimension of `x`.
#
# All trailing dimensions of `x` are folded into one batch axis for the call to `pool`
# and restored afterwards, so the result has size `(pooled_length, size(x)[2:end]...)`.
# Inputs with a single sequence (all trailing dimensions equal to 1, or a plain vector)
# give the same result as before; inputs with more than one sequence no longer fail in
# `reshape`.
function (m::MaxPool1d)(x)
    trailing = size(x)[2:end]
    # `init=1` keeps this valid for a plain vector, whose `trailing` is an empty tuple.
    batch = reduce(*, trailing; init=1)
    # NOTE(review): assumes `pool` treats a 3-D array as (length, channels, batch), as
    # the original single-sequence call did — confirm for GPU array types.
    stacked = reshape(x, size(x, 1), 1, batch)
    pooled = pool(stacked; window=m.window, stride=m.stride)
    return reshape(pooled, size(pooled, 1), trailing...)
end

In [33]:
# Build a pooling layer with window 4 and stride 2.
mp = MaxPool1d(4,2)

MaxPool1d(4, 2)

In [35]:
# Apply the pooling layer to `x`.
# NOTE(review): `x` is not defined in any cell above this one; it presumably comes from
# a cell that was removed from the notebook — confirm before re-running.
mp(x)

4×1 Matrix{Int64}:
  4
  6
  8
 10

# MultiHeadAttention

In [37]:
"""
    ScaledDotProductAttention()

Field-less callable type; calling an instance with `(q, k, v)` computes scaled
dot-product attention.
"""
struct ScaledDotProductAttention end

# Scaled dot-product attention.
#
# Scores are the batched product of `k` (first two dimensions swapped) with `q`,
# multiplied by `1 / sqrt(size(k, 1))`. They are normalised with `softmax` along
# dimension 1 and then used to mix `v` through a second batched product.
# NOTE(review): presumably q, k, v are (feature, sequence, batch) arrays — confirm
# against the callers.
function (s::ScaledDotProductAttention)(q, k, v)
    scale = 1 / sqrt(size(k, 1))
    keys_t = permutedims(k, (2, 1, 3))
    weights = softmax(bmm(keys_t, q) .* scale, dims=1)
    return bmm(v, weights)
end

In [38]:
"""
    MultiHeadAttention(embed_dim, num_head)

Multi-head attention layer holding four `Linear(embed_dim, embed_dim)` projections:
`linear_q`, `linear_k`, `linear_v` for the inputs and `linear_o` for the output.

# Throws
- `ArgumentError` if `num_head` is not positive or does not divide `embed_dim`
  evenly. The per-head size is computed as `embed_dim ÷ num_head`, so an uneven split
  would otherwise only surface later as a reshape failure.
"""
struct MultiHeadAttention
    embed_dim
    num_head
    linear_q
    linear_k
    linear_v
    linear_o

    function MultiHeadAttention(embed_dim, num_head)
        num_head > 0 ||
            throw(ArgumentError("num_head must be positive, got $num_head"))
        embed_dim % num_head == 0 || throw(
            ArgumentError(
                "embed_dim ($embed_dim) must be divisible by num_head ($num_head)",
            ),
        )
        q = Linear(embed_dim, embed_dim)
        k = Linear(embed_dim, embed_dim)
        v = Linear(embed_dim, embed_dim)
        o = Linear(embed_dim, embed_dim)
        return new(embed_dim, num_head, q, k, v, o)
    end
end

# Forward pass: project `q`, `k`, `v` with the layer's linear maps, split each into
# heads, run scaled dot-product attention, merge the heads again and apply the output
# projection.
function (m::MultiHeadAttention)(q, k, v)
    projected = (m.linear_q(q), m.linear_k(k), m.linear_v(v))
    heads_q, heads_k, heads_v = map(t -> _reshape_to_heads(m, t), projected)
    attended = ScaledDotProductAttention()(heads_q, heads_k, heads_v)
    merged = _reshape_from_heads(m, attended)
    return m.linear_o(merged)
end

# Split the first dimension of `x` into `m.num_head` heads and fold the head axis into
# the last dimension: (embed, seq, batch) -> (embed ÷ num_head, seq, num_head * batch).
function _reshape_to_heads(m::MultiHeadAttention, x)
    features, steps, batch = size(x)
    per_head = features ÷ m.num_head
    by_head = reshape(x, per_head, m.num_head, steps, batch)
    # Move the head axis next to the batch axis so the two can be merged.
    seq_first = permutedims(by_head, (1, 3, 2, 4))
    return reshape(seq_first, per_head, steps, m.num_head * batch)
end

# Inverse of `_reshape_to_heads`: unfold the head axis from the last dimension and merge
# it back into the first: (head_dim, seq, num_head * batch) -> (head_dim * num_head, seq, batch).
function _reshape_from_heads(m::MultiHeadAttention, x)
    head_dim, seq_len, merged_batch = size(x)
    # Integer division: `reshape` needs Int dimensions, and `/` would yield a Float64.
    batch_size = merged_batch ÷ m.num_head
    out_dim = head_dim * m.num_head
    x = reshape(x, head_dim, seq_len, m.num_head, batch_size)
    # Bring the head axis back next to the per-head feature axis before merging them.
    x = permutedims(x, (1, 3, 2, 4))
    return reshape(x, out_dim, seq_len, batch_size)
end
    


_reshape_from_heads (generic function with 1 method)

# Repeat interleave

In [48]:
# Tiling: repeat the whole array twice along dimension 1 (1,2,3,1,2,3 per column).
x = reshape(collect(1:6), (3, 2, 1))
@size x
y = repeat(x; outer=(2, 1, 1))
@size y

###########
x = 


3×2×1 Array{Int64, 3}:
[:, :, 1] =
 1  4
 2  5
 3  6

size(x) = (3, 2, 1)
###########
y = 


6×2×1 Array{Int64, 3}:
[:, :, 1] =
 1  4
 2  5
 3  6
 1  4
 2  5
 3  6

size(y) = (6, 2, 1)


In [53]:
# Interleaving: repeat every element twice in place along dimension 1
# (1,1,2,2,3,3 per column).
# NOTE(review): `repeat(...; inner=...)` looks like a Base method, so `using StatsBase`
# may be unnecessary for this cell — confirm before removing it.
using StatsBase
x = reshape(collect(1:6), (3, 2, 1))
@size x
y = repeat(x; inner=(2, 1, 1))
@size y

###########
x = 


3×2×1 Array{Int64, 3}:
[:, :, 1] =
 1  4
 2  5
 3  6

size(x) = (3, 2, 1)
###########
y = 


6×2×1 Array{Int64, 3}:
[:, :, 1] =
 1  4
 1  4
 2  5
 2  5
 3  6
 3  6

size(y) = (6, 2, 1)
