In [1]:
using Pkg;
Pkg.activate("..")

[32m[1m  Activating[22m[39m project at `~/Projects/comp441/latentplan.jl`


In [2]:
using Knet
using AutoGrad
using AutoGrad: params
using Distributions
using Debugger



In [3]:
macro size(e::Union{Symbol, Expr})
    quote
        println("###########")
        println($(string(e)), " = ")
        display($(esc(e)))
        if applicable(size, $(esc(e)))
            println("size(",$(string(e)), ") = ", size($(esc(e))))
        end
    end
end

@size (macro with 1 method)

# VQEmbedding and VectorQuantization

In [4]:
# VectorQuantization
function vq(inputs::atype, codebook::atype)
    embedding_size = size(codebook,1)
    inputs_size = size(inputs)
    inputs_flatten = reshape(inputs, (embedding_size, :))
    
    codebook_sqr = dropdims(sum(codebook .^ 2, dims=1), dims=1)
    inputs_sqr = sum(inputs_flatten .^ 2, dims=1)
    distances = (codebook_sqr .+ inputs_sqr) + -2 * (transpose(codebook) * inputs_flatten)
    indices_cartesian = argmin(distances, dims=1)
    indices_flatten = broadcast(x->x[1], indices_cartesian)
    indices = reshape(indices_flatten, inputs_size[2:end])
    return indices
end

# VectorQuantizationStraightThrough
function vq_st(inputs::atype, codebook::atype)
    indices = vq(inputs, codebook)
    indices_flatten = reshape(indices, :)
    codes_flatten = codebook[:, indices]
    codes = reshape(codes_flatten, size(inputs))
    return codes, indices_flatten
end

# VectorQuantizationStraightThrough Backwards gradient calculation
function vq_st_codebook_backprop(codebook, output, grad_output)
    _, indices = output
    embedding_size = size(codebook, 1)
    grad_output_flatten = reshape(grad_output[1], (embedding_size, :))
    grad_codebook = atype(zeros(Float32, size(codebook)))
    grad_codebook[:, indices] += grad_output_flatten
    return grad_codebook
end

# gradient definition for straight through estimation
@primitive vq_st(inputs, codebook),dy,y dy[1] vq_st_codebook_backprop(codebook, y, dy)  

# One Hot Encoding

In [5]:
function get_loss(latents)
    return latents[1][3, 2, 1] + 2
end

get_loss (generic function with 1 method)

In [6]:
function one_hot(Type, indices, class_num)
    onehot = zeros(Type, class_num, size(indices)...)
    for index in CartesianIndices(indices)
        onehot[indices[index], index] = convert(Type, 1)
    end
    onehot
end

@primitive one_hot(Type, indices, class_num),dy,y nothing nothing nothing


# VQEmbeddingMovingAverage

In [7]:
mutable struct VQEmbeddingMovingAverage
    embedding
    decay
    ema_count
    ema_w

    function VQEmbeddingMovingAverage(D, K; decay=0.99f0)
        embedding = atype(Float32.(rand(Uniform(-1/K, 1/K), (D, K))))
        ema_count = atype(ones(Float32, K))
        ema_w = deepcopy(embedding)
        new(embedding, decay, ema_count, ema_w)
    end
end

function (v::VQEmbeddingMovingAverage)(z_e_x)
    vq(z_e_x, v.embedding.weight)
end

function straight_through(v::VQEmbeddingMovingAverage, z_e_x, train::Bool=true)
    D, K = size(v.embedding)
    z_q_x, indices = vq_st(z_e_x, v.embedding)
    
    if train
        encodings = one_hot(Float32, indices, K)
        v.ema_count = v.decay .* v.ema_count + (1 - v.decay) .* sum(encodings, dims=2)[:, 1]
        dw = reshape(z_e_x, (D, :)) * transpose(encodings) 
        v.ema_w = v.decay .* v.ema_w + (1 - v.decay) .* dw
        v.embedding = v.ema_w ./ reshape(v.ema_count, (1, :))
    end

    z_q_x_bar_flatten = v.embedding[:, indices]
    z_q_x_bar = reshape(z_q_x_bar_flatten, size(z_e_x))

    return z_q_x, z_q_x_bar
end

straight_through (generic function with 2 methods)

In [8]:
inputs = Param(zeros(4, 2, 1))
fill!(view(inputs, :, 2, :), 1)
@size inputs
codebook = VQEmbeddingMovingAverage(4,8)
@size codebook.embedding

4×2×1 Param{Array{Float64, 3}}:
[:, :, 1] =
 0.0  1.0
 0.0  1.0
 0.0  1.0
 0.0  1.0

###########
inputs = 
size(

inputs) = (4, 2, 1)


4×8 Matrix{Float64}:
  0.0885293  -0.116828     0.0383815  …   0.00128087   0.0333134  -0.0529688
 -0.020553    0.00260281  -0.0304915      0.0481597    0.12287    -0.0887993
  0.0827049  -0.108817    -0.117196      -0.091772     0.0761856   0.105826
  0.0461754   0.00643331  -0.064942       0.0872004   -0.0710666  -0.0769491

###########
codebook.embedding = 
size(

codebook.embedding) = (4, 8)


In [9]:
loss = @diff get_loss(straight_through(codebook, inputs))
grad(loss, inputs)

4×8 AutoGrad.Result{Matrix{Float64}}:
  0.0966772  -0.116828     0.0383815  …   0.00128087   0.0333134  -0.0529688
 -0.010245    0.00260281  -0.0304915      0.0481597    0.12287    -0.0887993
  0.0909681  -0.108817    -0.117196      -0.091772     0.0761856   0.105826
  0.055162    0.00643331  -0.064942       0.0872004   -0.0710666  -0.0769491

###########
v.embedding = 
size(

v.embedding) = (4, 8)


4×2×1 Array{Float64, 3}:
[:, :, 1] =
 0.0  0.0
 0.0  0.0
 0.0  1.0
 0.0  0.0