In [1]:
import Pkg
# Activate the Julia project located one directory above the current working
# directory (the notebook lives in a subfolder of the project), so the
# `using` statements in the following cells resolve against that environment.
Pkg.activate("..")

[32m[1m  Activating[22m[39m project at `~/Projects/comp441/latentplan.jl`


In [2]:
using Knet
using Statistics
using LinearAlgebra
using Distributions

In [3]:
include("../latentplan/models/common.jl")
using .Common: Linear, LayerNorm, Chain, Dropout, GELU, softmax

In [4]:
"""
    @size expr

Debug helper: print a separator line and the source text of `expr`, `display`
its value and, when `size` has a method applicable to that value, also print
`size(value)`.

`expr` is evaluated exactly once (after the header lines are printed), so
expressions with side effects such as `pop!(v)` or randomness such as
`rand(3)` are inspected consistently. Returns `nothing`.
"""
macro size(e::Union{Symbol, Expr})
    src = string(e)  # source text, captured at macro-expansion time
    quote
        println("###########")
        println($src, " = ")
        # Evaluate the user expression once; `val` is a hygienic local.
        local val = $(esc(e))
        display(val)
        if applicable(size, val)
            println("size(", $src, ") = ", size(val))
        end
    end
end

@size (macro with 1 method)

# CausalSelfAttention

In [5]:
"""
    CausalSelfAttention(config)

Multi-head causal self-attention layer.

`config` is indexed with the keys `"n_embd"`, `"n_head"`, `"block_size"`,
`"attn_pdrop"`, `"resid_pdrop"` and, when `"action_dim"` is present, also
`"observation_dim"`.

Fields:
- `key`, `query`, `value`, `proj`: `Linear(n_embd, n_embd)` layers.
- `mask`: `block_size × block_size` matrix. Row index = key position,
  column index = query position. An entry is 1 when that key may be attended
  from that query (upper triangular, i.e. key ≤ query), and 0 otherwise.
- `attn_drop`, `resid_drop`: dropout probabilities.
- `n_head`: number of attention heads.

Throws `ArgumentError` if `n_embd` is not divisible by `n_head`.
"""
struct CausalSelfAttention
    key; query; value; proj; mask
    attn_drop; resid_drop
    n_head

    function CausalSelfAttention(config)
        n_embd = config["n_embd"]
        n_head = config["n_head"]
        # The forward pass reshapes (C, T, B) into n_head heads of C ÷ n_head.
        n_embd % n_head == 0 || throw(ArgumentError(
            "n_embd ($n_embd) must be divisible by n_head ($n_head)"))

        # Constructed in this fixed order (key, query, value, proj).
        key = Linear(n_embd, n_embd)
        query = Linear(n_embd, n_embd)
        value = Linear(n_embd, n_embd)
        proj = Linear(n_embd, n_embd)

        block_size = config["block_size"]
        mask = Matrix(UpperTriangular(ones(block_size, block_size)))
        if haskey(config, "action_dim")
            # Tokens come in groups of observation_dim + action_dim + 2; the last
            # token of every group (rows joined_dim, 2*joined_dim, ...) is never
            # attended to by any query.
            joined_dim = config["observation_dim"] + config["action_dim"] + 2
            mask[joined_dim:joined_dim:end, :] .= 0
        end
        return new(key, query, value, proj, mask,
                   config["attn_pdrop"], config["resid_pdrop"], n_head)
    end
end

In [6]:
# Forward pass of CausalSelfAttention.
#
# `x` is unpacked as `(C, T, B)` (embedding size, sequence length, batch size).
# `C` must be divisible by `c.n_head`, otherwise the reshape fails.
# `T` must not exceed the side length of `c.mask`, otherwise the mask view is
# out of bounds.
# Returns an array of shape `(C, T, B)`.
function (c::CausalSelfAttention)(x)
    C, T, B = size(x)
    hs = C ÷ c.n_head  # per-head embedding size

    # (C, T, B) -> (hs, nh, T, B) -> (hs, T, nh, B)
    split_heads(y) = permutedims(reshape(y, (hs, c.n_head, T, B)), (1, 3, 2, 4))
    k = split_heads(c.key(x))
    q = split_heads(c.query(x))
    v = split_heads(c.value(x))

    # att[i, j, h, b] = k_i ⋅ q_j / sqrt(hs): rows are keys, columns are queries.
    # (T, hs, nh, B) x (hs, T, nh, B) -> (T, T, nh, B)
    # NOTE(review): the Float64 scale factor promotes Float32 activations to Float64.
    att = bmm(permutedims(k, (2, 1, 3, 4)), q) .* (1 / sqrt(hs))

    # The mask is applied out-of-place, as an additive bias. In-place writes into
    # `att` are not supported by AutoGrad while it is a tracked value under `@diff`.
    # Disallowed (key, query) pairs get -Inf, so softmax gives them zero weight.
    # The Float32 constants keep the bias from widening att's element type.
    keep = view(c.mask, 1:T, 1:T) .!= 0
    att = att .+ ifelse.(keep, 0f0, -Inf32)

    att = softmax(att, dims=1)  # normalise over keys (dim 1)
    att = dropout(att, c.attn_drop)

    # (hs, T, nh, B) x (T, T, nh, B) -> (hs, T, nh, B)
    y = bmm(v, att)
    # re-assemble the heads side by side: (hs, T, nh, B) -> (hs, nh, T, B) -> (C, T, B)
    y = reshape(permutedims(y, (1, 3, 2, 4)), (C, T, B))
    # output projection followed by residual dropout
    return dropout(c.proj(y), c.resid_drop)
end

# CSA (CausalSelfAttention): example configuration

In [7]:
# Small toy hyper-parameter set used by the cells below.
config = Dict{String, Real}(
    "n_embd" => 8,
    "n_head" => 2,
    "block_size" => 12,
    "observation_dim" => 2,
    "action_dim" => 1,
    "attn_pdrop" => 0.1,
    "resid_pdrop" => 0.1,
)

Dict{String, Real} with 7 entries:
  "resid_pdrop"     => 0.1
  "attn_pdrop"      => 0.1
  "n_head"          => 2
  "block_size"      => 12
  "action_dim"      => 1
  "observation_dim" => 2
  "n_embd"          => 8

# Block

In [8]:
"""
    Block(config)

Transformer block built from `config`, which is indexed with `"n_embd"` and
`"resid_pdrop"` and is also passed whole to `CausalSelfAttention`.

Fields:
- `ln1`, `ln2`: `LayerNorm(n_embd)`.
- `attn`: `CausalSelfAttention(config)`.
- `mlp`: `Chain` of `Linear(n_embd, 4n_embd)`, `GELU()`, `Linear(4n_embd, n_embd)`
  and `Dropout(resid_pdrop)`.
"""
struct Block
    ln1::LayerNorm
    ln2::LayerNorm
    attn::CausalSelfAttention
    mlp::Chain

    function Block(config)
        width = config["n_embd"]
        hidden = 4 * width
        # Sub-layers are created in the same order as the fields they fill.
        norm_attn = LayerNorm(width)
        norm_mlp = LayerNorm(width)
        self_attn = CausalSelfAttention(config)
        feedforward = Chain(
            Linear(width, hidden),
            GELU(),
            Linear(hidden, width),
            Dropout(config["resid_pdrop"]),
        )
        return new(norm_attn, norm_mlp, self_attn, feedforward)
    end
end

# Forward pass of a Block: two residual sub-layers, each fed a layer-normalised
# input. First h = x + attn(ln1(x)), then the result is h + mlp(ln2(h)).
function (b::Block)(x)
    h = x .+ b.attn(b.ln1(x))
    return h .+ b.mlp(b.ln2(h))
end

In [9]:
# Smoke test: build one Block, then run it on a random input of size
# (C, T, B) = (8, 2, 3), i.e. n_embd = 8, a 2-token sequence and a batch of 3.
block = Block(config)
x = rand(Float64, (8, 2, 3))
block(x)

8×2×3 Array{Float64, 3}:
[:, :, 1] =
 -0.252795   0.35683
 -0.217497  -0.0517943
  1.0853     0.232041
 -1.5641    -1.46864
 -0.652647  -0.590191
  1.70763    1.53705
  1.84816    2.09661
  0.198716   0.466664

[:, :, 2] =
  0.922732   0.586288
  0.338404   0.419671
  0.649373   0.776267
  1.17186   -0.0524429
  0.830902   1.05332
  0.419295   0.400618
 -0.369936   0.356866
  0.109032   0.283863

[:, :, 3] =
  0.202959  -0.0982616
 -0.175399   0.269452
  0.599338   1.26965
 -2.35665   -0.942176
 -0.773963   0.204239
  1.07299    0.331395
  2.88936    0.929908
 -1.0094     0.635357