In [1]:
import Pkg
Pkg.activate("..")

[32m[1m  Activating[22m[39m project at `~/Projects/comp441/latentplan.jl`


In [2]:
using Knet
using Statistics
using LinearAlgebra

In [3]:
include("../latentplan/models/common.jl")
using .Common: Linear, LayerNorm

In [34]:
macro size(e::Union{Symbol, Expr})
    quote
        println("###########")
        println($(string(e)), " = ")
        display($(esc(e)))
        if applicable(size, $(esc(e)))
            println("size(",$(string(e)), ") = ", size($(esc(e))))
        end
    end
end

@size (macro with 1 method)

In [5]:
ln = LayerNorm(8)

LayerNorm(P(Vector{Float32}(8)), P(Vector{Float32}(8)), 1.0e-5)

# CausalSelfAttention

In [6]:
struct CausalSelfAttention; 
    key; query; value; proj; mask;
    attn_drop; resid_drop;
    n_head;
    
    function CausalSelfAttention(config)
        key = Linear(config["n_embd"], config["n_embd"])
        query = Linear(config["n_embd"], config["n_embd"])
        value = Linear(config["n_embd"], config["n_embd"])
        proj = Linear(config["n_embd"], config["n_embd"])
        
        mask = Matrix(UpperTriangular(ones(config["block_size"],config["block_size"])))
        if haskey(config, "action_dim")
            joined_dim = config["observation_dim"] + config["action_dim"] + 2
            mask[joined_dim:joined_dim:end,:, :, :] .= 0
        end
        new(key,query,value,proj,mask, config["attn_drop"], config["resid_drop"], config["n_head"])
    end
end

In [35]:
function (c::CausalSelfAttention)(x)
    C, T, B = size(x)

    k = permutedims(reshape(c.key(x), (C ÷ c.n_head, c.n_head, T, B)), (1, 3, 2, 4)) # hs, T, nh, B
    q = permutedims(reshape(c.query(x), (C ÷ c.n_head, c.n_head, T, B)), (1, 3, 2, 4)) # hs, T, nh, B
    v = permutedims(reshape(c.value(x), (C ÷ c.n_head, c.n_head, T, B)), (1, 3, 2, 4)) # hs, T, nh, B
    
    # (T, hs, nh, B) x (hs, T, nh, B) -> (T, T, nh, B)
    att = bmm(permutedims(k, (2,1,3,4)), q) .* (1 / sqrt(size(k, 1)))
    att[c.mask[1:T,1:T] .== 0, :, :] .= -Inf
    att = softmax(att, dims=1)
    att_drop = dropout(att, c.attn_drop)
    # (hs, T, nh, B) x (T, T, nh, B)  -> (hs, T, nh, B)
    y = bmm(v, att_drop)
    # (C, T, B)
    y = reshape(permutedims(y, (1, 3, 2, 4)), (C, T, B)) # re-assemble all head outputs side by side
    # output projection
    y = dropout(c.proj(y), c.resid_drop)
    return y
end

# Softmax

In [8]:
function softmax(w; dims::Int)
    probs = exp.(w)
    return probs ./ sum(probs, dims=dims)
end

softmax (generic function with 1 method)

# BMM Broadcast

In [32]:
rand(4,4)[1:2, :]

2×4 Matrix{Float64}:
 0.815119  0.216852  0.0891273  0.480011
 0.495547  0.537869  0.281774   0.20078

In [45]:
w = rand(4, 4)
x = rand(4, 3)
display(w)
display(x)

4×4 Matrix{Float64}:
 0.256093   0.604766   0.102011  0.156603
 0.0398432  0.948473   0.965657  0.60073
 0.420349   0.963705   0.238991  0.472495
 0.115347   0.0669753  0.577598  0.818909

4×3 Matrix{Float64}:
 0.762107  0.660366  0.166072
 0.449898  0.928151  0.559783
 0.404001  0.930401  0.262188
 0.888096  0.542342  0.669676

In [46]:
w * x[:,:]

4×3 Matrix{Float64}:
 0.647544  0.910273  0.512687
 1.38071   2.13089   1.19303
 1.27009   1.65066   0.988353
 1.07866   1.11986   0.756491

In [47]:
reshape(w * reshape(x, size(x)[1], :), size(w)[1], size(x)[2:end]...)

4×3 Matrix{Float64}:
 0.647544  0.910273  0.512687
 1.38071   2.13089   1.19303
 1.27009   1.65066   0.988353
 1.07866   1.11986   0.756491

# Indexing multi-array

In [20]:
input1 = rand(2,2,4,3)
mask = rand(Bool, 2,2,1,1)
input1[mask] .= 0

28-element view(::Vector{Float64}, [1, 2, 3, 4, 5, 7, 8, 10, 13, 15  …  33, 35, 37, 41, 42, 43, 45, 46, 47, 48]) with eltype Float64:
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 ⋮
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0

# CSA

In [36]:
config = Dict("n_embd" => 8, "block_size" => 12, "action_dim" => 1, "observation_dim"=> 2, "attn_drop"=>0.1, "resid_drop"=>0.1, "n_head"=>2)

Dict{String, Real} with 7 entries:
  "attn_drop"       => 0.1
  "n_head"          => 2
  "resid_drop"      => 0.1
  "block_size"      => 12
  "action_dim"      => 1
  "observation_dim" => 2
  "n_embd"          => 8

In [37]:
csa = CausalSelfAttention(config)

CausalSelfAttention(Linear(P(Matrix{Float32}(8,8)), P(Vector{Float32}(8)), 0), Linear(P(Matrix{Float32}(8,8)), P(Vector{Float32}(8)), 0), Linear(P(Matrix{Float32}(8,8)), P(Vector{Float32}(8)), 0), Linear(P(Matrix{Float32}(8,8)), P(Vector{Float32}(8)), 0), [1.0 1.0 … 1.0 1.0; 0.0 1.0 … 1.0 1.0; … ; 0.0 0.0 … 1.0 1.0; 0.0 0.0 … 0.0 1.0;;;;], 0.1, 0.1, 2)

In [39]:
csa.mask

12×12×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  0.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0

In [38]:
x = rand(8, 2, 3)
csa(x)

(C, T, B) = (8, 2, 3)
###########
c.key.w = 


8×8 Param{Matrix{Float32}}:
  0.502183   -0.500032   -0.595933    0.129906    …  -0.426692   -0.344343
  0.242065   -0.0392931   0.0632212   0.330281        0.535651   -0.103906
  0.5475     -0.298522    0.324135   -0.100263       -0.24116     0.0266227
  0.0197597   0.596061   -0.188364   -0.196794        0.173626    0.173176
 -0.103439   -0.169256   -0.0271037   0.00676222      0.0225345   0.533635
  0.432706   -0.148344    0.251621   -0.142082    …  -0.341953   -0.104039
 -0.278676   -0.260532    0.522699    0.080805        0.20958     0.000613149
 -0.454027    0.289852   -0.273026   -0.127625       -0.0446155  -0.231273

size(c.key.w) = (8, 8)
###########
x = 


8×2×3 Array{Float64, 3}:
[:, :, 1] =
 0.0189905  0.962656
 0.306001   0.370853
 0.153435   0.742593
 0.056983   0.560434
 0.295498   0.107924
 0.20769    0.0254402
 0.579207   0.209565
 0.737062   0.474647

[:, :, 2] =
 0.247378  0.434695
 0.533476  0.366239
 0.718791  0.621385
 0.592072  0.769893
 0.890814  0.881763
 0.40257   0.196232
 0.672078  0.108103
 0.579842  0.918691

[:, :, 3] =
 0.867609   0.54565
 0.0486885  0.117704
 0.486755   0.706388
 0.0608208  0.050485
 0.810148   0.170732
 0.610376   0.645981
 0.603534   0.548095
 0.62576    0.638079

size(x) = (8, 2, 3)
###########
k = 


4×2×2×3 Array{Float64, 4}:
[:, :, 1, 1] =
 -0.632553  -0.318889
  0.354462   0.52161
 -0.34383    0.498197
  0.418085   0.127053

[:, :, 2, 1] =
  0.199957     0.0395763
 -0.319975     0.321777
  0.0470139    0.0742407
 -0.00491168  -0.692805

[:, :, 1, 2] =
 -0.821262  -0.557665
  0.755005   0.407118
 -0.544646  -0.277001
  0.436264   0.288172

[:, :, 2, 2] =
 -0.190776    0.052688
 -0.334312   -0.166456
  0.081929   -0.123649
 -0.0359094  -0.33648

[:, :, 1, 3] =
 -0.0570783  -0.29161
  0.813516    0.768314
 -0.0318093   0.228034
  0.284063    0.157159

[:, :, 2, 3] =
 -0.183318    0.0190612
  0.114448    0.236777
 -0.0636734   0.380295
 -0.229011   -0.195464

size(k) = (4, 2, 2, 3)
###########
q = 


4×2×2×3 Array{Float64, 4}:
[:, :, 1, 1] =
  0.261256   0.335778
 -0.431432  -0.73105
 -0.128323  -0.447127
 -0.438559  -0.0928266

[:, :, 2, 1] =
  0.0355718  -0.0878083
 -0.406727    0.848734
  0.622597    0.756347
 -0.036315   -0.446327

[:, :, 1, 2] =
  0.204248   0.101607
 -0.517881  -0.612863
 -0.392984  -0.123933
 -0.367919  -0.281589

[:, :, 2, 2] =
  0.00487568  -0.153825
 -0.16533      0.220364
  0.872089     1.03187
 -0.701363    -0.659936

[:, :, 1, 3] =
  0.601092   0.413607
 -0.52999   -0.695777
  0.109218  -0.211396
 -0.307688   0.115212

[:, :, 2, 3] =
 -0.326472   0.292594
 -0.296382  -0.0574531
  1.19247    0.708477
 -0.700664  -0.163649

size(q) = (4, 2, 2, 3)
###########
v = 


4×2×2×3 Array{Float64, 4}:
[:, :, 1, 1] =
  0.237304    0.747353
 -0.476313   -0.81039
 -0.0597774   0.122288
 -0.491176   -0.726192

[:, :, 2, 1] =
  0.66521   0.395959
 -0.125585  0.00468848
  1.08621   1.24962
  0.27952   0.432295

[:, :, 1, 2] =
  0.389731   0.811612
 -0.370353  -0.3498
  0.27309    0.331682
 -0.598401  -0.832704

[:, :, 2, 2] =
  0.640869   0.562456
 -0.708061  -0.531824
  1.8196     1.74365
  0.542511   0.188617

[:, :, 1, 3] =
  0.438712    0.146307
 -0.473129   -0.659843
  0.0635641  -0.261509
 -0.894591   -0.437657

[:, :, 2, 3] =
  0.787341  0.752244
 -0.105512  0.0356304
  1.38208   1.05293
  0.835584  0.748904

size(v) = (4, 2, 2, 3)
###########
att = 


2×2×2×3 Array{Float64, 4}:
[:, :, 1, 1] =
 -0.228709  -0.1783
 -0.214     -0.361475

[:, :, 2, 1] =
  0.0833522  -0.12569
 -0.0290432   0.317499

[:, :, 1, 2] =
 -0.252609  -0.300754
 -0.160954  -0.176493

[:, :, 2, 2] =
 0.0754882  0.0319571
 0.0779693  0.02484

[:, :, 1, 3] =
 -0.278171  -0.275091
 -0.302967  -0.342643

[:, :, 2, 3] =
 0.0552297  -0.0339235
 0.257022    0.146696

size(att) = (2, 2, 2, 3)
###########
att = 


2×2×2×3 Array{Float64, 4}:
[:, :, 1, 1] =
  -0.228709  -0.1783
 -Inf        -0.361475

[:, :, 2, 1] =
   0.0833522  -0.12569
 -Inf          0.317499

[:, :, 1, 2] =
  -0.252609  -0.300754
 -Inf        -0.176493

[:, :, 2, 2] =
   0.0754882  0.0319571
 -Inf         0.02484

[:, :, 1, 3] =
  -0.278171  -0.275091
 -Inf        -0.342643

[:, :, 2, 3] =
   0.0552297  -0.0339235
 -Inf          0.146696

size(att) = (2, 2, 2, 3)
###########
att = 


2×2×2×3 Array{Float64, 4}:
[:, :, 1, 1] =
 1.0  0.545666
 0.0  0.454334

[:, :, 2, 1] =
 1.0  0.390981
 0.0  0.609019

[:, :, 1, 2] =
 1.0  0.468975
 0.0  0.531025

[:, :, 2, 2] =
 1.0  0.501779
 0.0  0.498221

[:, :, 1, 3] =
 1.0  0.516881
 0.0  0.483119

[:, :, 2, 3] =
 1.0  0.454968
 0.0  0.545032

size(att) = (2, 2, 2, 3)
###########
y = 


4×2×2×3 Array{Float64, 4}:
[:, :, 1, 1] =
  0.237304    0.469036
 -0.476313   -0.628096
 -0.0597774   0.0229413
 -0.491176   -0.597952

[:, :, 2, 1] =
  0.66521    0.501231
 -0.125585  -0.046246
  1.08621    1.18573
  0.27952    0.372563

[:, :, 1, 2] =
  0.389731   0.61376
 -0.370353  -0.359439
  0.27309    0.304204
 -0.598401  -0.722822

[:, :, 2, 2] =
  0.640869   0.601802
 -0.708061  -0.620256
  1.8196     1.78176
  0.542511   0.366194

[:, :, 1, 3] =
  0.438712    0.297446
 -0.473129   -0.563334
  0.0635641  -0.0934848
 -0.894591   -0.673838

[:, :, 2, 3] =
  0.787341   0.768212
 -0.105512  -0.0285846
  1.38208    1.20268
  0.835584   0.78834

size(y) = (4, 2, 2, 3)
###########
y = 


8×2×3 Array{Float64, 3}:
[:, :, 1] =
  0.237304    0.469036
 -0.476313   -0.628096
 -0.0597774   0.0229413
 -0.491176   -0.597952
  0.66521     0.501231
 -0.125585   -0.046246
  1.08621     1.18573
  0.27952     0.372563

[:, :, 2] =
  0.389731   0.61376
 -0.370353  -0.359439
  0.27309    0.304204
 -0.598401  -0.722822
  0.640869   0.601802
 -0.708061  -0.620256
  1.8196     1.78176
  0.542511   0.366194

[:, :, 3] =
  0.438712    0.297446
 -0.473129   -0.563334
  0.0635641  -0.0934848
 -0.894591   -0.673838
  0.787341    0.768212
 -0.105512   -0.0285846
  1.38208     1.20268
  0.835584    0.78834

size(y) = (8, 2, 3)
###########
y = 


8×2×3 Array{Float64, 3}:
[:, :, 1] =
  0.029056  -0.155081
 -0.839923  -0.737924
  0.277747   0.200378
 -0.206205  -0.365214
 -0.981166  -1.0262
  0.948935   1.16492
 -0.542494  -0.491691
 -0.616478  -0.639192

[:, :, 2] =
 -0.426091   -0.532622
 -1.09355    -1.06898
  0.0975944   0.103989
 -0.556545   -0.698734
 -1.26863    -1.07909
  1.31767     1.45332
 -0.393841   -0.271048
 -0.816986   -0.71484

[:, :, 3] =
 -0.161317    0.0608505
 -0.67516    -0.650905
  0.0858993   0.149302
 -0.445579   -0.230135
 -1.39184    -1.38273
  1.67554     1.37706
 -0.547968   -0.697065
 -0.804798   -0.841862

size(y) = (8, 2, 3)


8×2×3 Array{Float64, 3}:
[:, :, 1] =
  0.029056  -0.155081
 -0.839923  -0.737924
  0.277747   0.200378
 -0.206205  -0.365214
 -0.981166  -1.0262
  0.948935   1.16492
 -0.542494  -0.491691
 -0.616478  -0.639192

[:, :, 2] =
 -0.426091   -0.532622
 -1.09355    -1.06898
  0.0975944   0.103989
 -0.556545   -0.698734
 -1.26863    -1.07909
  1.31767     1.45332
 -0.393841   -0.271048
 -0.816986   -0.71484

[:, :, 3] =
 -0.161317    0.0608505
 -0.67516    -0.650905
  0.0858993   0.149302
 -0.445579   -0.230135
 -1.39184    -1.38273
  1.67554     1.37706
 -0.547968   -0.697065
 -0.804798   -0.841862