In [1]:
using Pkg;
Pkg.activate("..")

[32m[1m  Activating[22m[39m project at `~/Projects/comp441/latentplan.jl`


In [2]:
using Knet
using AutoGrad
using AutoGrad: params
using Distributions
using Debugger



In [3]:
macro size(e::Union{Symbol, Expr})
    quote
        println("###########")
        println($(string(e)), " = ")
        display($(esc(e)))
        if applicable(size, $(esc(e)))
            println("size(",$(string(e)), ") = ", size($(esc(e))))
        end
    end
end

@size (macro with 1 method)

# VQEmbedding and VectorQuantization

In [73]:
function vq(inputs, codebook)
    embedding_size = size(codebook,1)
    inputs_size = size(inputs)
    inputs_flatten = reshape(inputs, (embedding_size, :))
    
    codebook_sqr = transpose(sum(codebook .^ 2, dims=1))
    inputs_sqr = sum(inputs_flatten .^ 2, dims=1)
    distances = (codebook_sqr .+ inputs_sqr) + -2 * (transpose(codebook) * inputs_flatten)
    indices_cartesian = argmin(distances, dims=1)
    indices_flatten = broadcast(x->x[1], indices_cartesian)
    indices = reshape(indices_flatten, inputs_size[2:end])
    return indices
end

function vq_st(inputs, codebook)
    indices = vq(inputs, codebook)
    indices_flatten = reshape(indices, :)
    codes_flatten = codebook[:, indices]
    codes = reshape(codes_flatten, size(inputs))
    return codes, indices_flatten
end

function vq_st_codebook_backprop(inputs, codebook, output, grad_output)
    println("Backprop vq_st")
    _, indices = output
    embedding_size = size(codebook, 1)
    grad_output_flatten = reshape(grad_output[1], (embedding_size, :))
    grad_codebook = zeros(Float32, size(codebook))
    grad_codebook[:, indices] += grad_output_flatten
    return grad_codebook
end

@primitive vq_st(inputs, codebook),dy,y dy[1] vq_st_codebook_backprop(inputs,codebook, y, dy)  

In [5]:
mutable struct Embedding
    weight::Param

    function Embedding(D, K)
        weight = Param(rand(Uniform(-1/K, 1/K), (D, K)))
        new(weight)
    end
end

function (e::Embedding)(x)
    print("Embedding forward")
    weight * transpose(x)
end

In [36]:
mutable struct VQEmbedding
    embedding::Embedding

    function VQEmbedding(D, K)
        println("Creating VQEmbedding")
        embedding = Embedding(D, K)
        new(embedding)
    end
end

function (v::VQEmbedding)(z_e_x)
    println("VQEmbedding Forward")
    latents = vq(z_e_x, v.embedding.weight)
    return latents
end

function straight_through(v::VQEmbedding, z_e_x)
    z_q_x, indices = vq_st(z_e_x, v.embedding.weight)
    z_q_x_bar = v.embedding.weight[:, indices]
    return z_q_x, z_q_x_bar
end

straight_through (generic function with 3 methods)

In [37]:
inputs = Param(zeros(Float32, 4, 2, 1))
fill!(view(inputs, :, 2, :), 1)
@show inputs
codebook = VQEmbedding(4,8)
@show codebook.embedding.weight

inputs = P(Array{Float32, 3}(4,2,1))
Creating VQEmbedding
codebook.embedding.weight = P(Matrix{Float64}(4,8))


4×8 Param{Matrix{Float64}}:
 -0.0013834  -0.0537871  -0.0838884  …  -0.00517189  -0.0313119  -0.0834134
 -0.123116    0.122444   -0.119361      -0.0753161   -0.0719027  -0.0572111
  0.0247497  -0.0676382   0.0190005      0.0334828   -0.0264266   0.0787029
 -0.0171251   0.0138449  -0.0768991     -0.0673715   -0.114724   -0.0218046

In [38]:
codebook(inputs)

VQEmbedding Forward


2×1 Matrix{Int64}:
 5
 2

In [42]:
straight_through(codebook, inputs)[1]

4×2×1 Array{Float64, 3}:
[:, :, 1] =
 -0.0609056  -0.0537871
 -0.0231273   0.122444
  0.0203875  -0.0676382
 -0.0398396   0.0138449

In [67]:
function get_loss(latents)
    return latents[1][3, 2, 1] + 2
end

get_loss (generic function with 1 method)

In [44]:
loss = @diff get_loss(straight_through(codebook, inputs))

T(1.939094423267588)

In [12]:
grad(loss, inputs)

4×2×1 Array{Float64, 3}:
[:, :, 1] =
 1.0  0.0
 0.0  0.0
 0.0  0.0
 0.0  0.0

In [13]:
grad(loss, codebook.embedding.weight)

4×8 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0

# One Hot Encoding

In [46]:
function one_hot(Type, indices, class_num)
    onehot = zeros(Type, class_num, size(indices)...)
    for index in CartesianIndices(indices)
        onehot[indices[index], index] = convert(Type, 1)
    end
    onehot
end

@primitive one_hot(Type, indices, class_num),dy,y nothing nothing nothing

# VQEmbeddingMovingAverage

In [82]:
mutable struct VQEmbeddingMovingAverage
    embedding
    decay::Float32
    ema_count
    ema_w

    function VQEmbeddingMovingAverage(D, K; decay=0.99f0)
        embedding = rand(Uniform(-1/K, 1/K), (D, K))
        ema_count = ones(K)
        ema_w = deepcopy(embedding)
        new(embedding, decay, ema_count, ema_w)
    end
end

function (v::VQEmbeddingMovingAverage)(z_e_x)
    vq(z_e_x, v.embedding.weight)
end

function straight_through(v::VQEmbeddingMovingAverage, z_e_x, train::Bool=true)
    D, K = size(v.embedding)
    z_q_x, indices = vq_st(z_e_x, v.embedding)
    
    if train
        encodings = one_hot(Float32, indices, K)
        v.ema_count = v.decay .* v.ema_count + (1 - v.decay) .* sum(encodings, dims=2)[:, 1]
        dw = reshape(z_e_x, (D, :)) * transpose(encodings) 
        v.ema_w = v.decay .* v.ema_w + (1 - v.decay) .* dw
        v.embedding = v.ema_w ./ reshape(v.ema_count, (1, :))
        @size v.embedding
    end

    z_q_x_bar_flatten = v.embedding[:, indices]
    z_q_x_bar = reshape(z_q_x_bar_flatten, size(z_e_x))

    return z_q_x, z_q_x_bar
end

straight_through (generic function with 3 methods)

In [89]:
inputs = Param(zeros(4, 2, 1))
fill!(view(inputs, :, 2, :), 1)
@size inputs
codebook = VQEmbeddingMovingAverage(4,8)
@size codebook.embedding

###########
inputs = 


4×2×1 Param{Array{Float64, 3}}:
[:, :, 1] =
 0.0  1.0
 0.0  1.0
 0.0  1.0
 0.0  1.0

size(inputs) = (4, 2, 1)
###########
codebook.embedding = 


4×8 Matrix{Float64}:
  0.0162179   0.0520691  -0.0154652  …  -0.0657131   0.096129   -0.0960628
 -0.113334    0.017766    0.097886      -0.0482823  -0.0650824  -0.0513131
  0.0347081  -0.0591741   0.0612955      0.0573917  -0.0136534  -0.00142566
  0.0292792   0.122822   -0.109728       0.0514431   0.0166219   0.112911

size(codebook.embedding) = (4, 8)


In [90]:
codebook.embedding

4×8 Matrix{Float64}:
  0.0162179   0.0520691  -0.0154652  …  -0.0657131   0.096129   -0.0960628
 -0.113334    0.017766    0.097886      -0.0482823  -0.0650824  -0.0513131
  0.0347081  -0.0591741   0.0612955      0.0573917  -0.0136534  -0.00142566
  0.0292792   0.122822   -0.109728       0.0514431   0.0166219   0.112911

In [91]:
loss = @diff get_loss(straight_through(codebook, inputs))
grad(loss, inputs)

###########
v.embedding = 


4×8 AutoGrad.Result{Matrix{Float64}}:
  0.0162179   0.0615484  -0.0154652  …  -0.065056    0.096129   -0.0960628
 -0.113334    0.0275883   0.097886      -0.0477994  -0.0650824  -0.0513131
  0.0347081  -0.0485824   0.0612955      0.0568178  -0.0136534  -0.00142566
  0.0292792   0.131593   -0.109728       0.0509286   0.0166219   0.112911

size(v.embedding) = (4, 8)


4×2×1 Array{Float64, 3}:
[:, :, 1] =
 0.0  0.0
 0.0  0.0
 0.0  1.0
 0.0  0.0

In [92]:
codebook.embedding

4×8 AutoGrad.Result{Matrix{Float64}}:
  0.0162179   0.0615484  -0.0154652  …  -0.065056    0.096129   -0.0960628
 -0.113334    0.0275883   0.097886      -0.0477994  -0.0650824  -0.0513131
  0.0347081  -0.0485824   0.0612955      0.0568178  -0.0136534  -0.00142566
  0.0292792   0.131593   -0.109728       0.0509286   0.0166219   0.112911

# Grad test