In [1]:
using Pkg;
Pkg.activate("..")

[32m[1m  Activating[22m[39m project at `~/Projects/comp441/latentplan.jl`


In [2]:
using Knet
using AutoGrad
using Distributions
using Debugger

In [148]:
macro size(e::Symbol)
    quote
        println($(string(e)), " = ", $(esc(e)))
        if applicable(size, $(esc(e)))
            println("size(",$(string(e)), ") = ", size($(esc(e))))
        end
    end
end

@size (macro with 2 methods)

In [149]:
e = (1,1)
b = rand(4,4)
@size e
@size b

e = (1, 1)
b = [0.29190395824649285 0.7836548578519952 0.05678593146376054 0.9141141870559594; 0.20564022831226347 0.6465200482467607 0.2370892245807702 0.9062440094849393; 0.628226901133905 0.10375718841262294 0.8565299802624302 0.844868260779564; 0.799184825700346 0.7126346920678022 0.6071012828079465 0.4085747277304722]
size(b) = (4, 4)


In [198]:
function vq(inputs, codebook)
    embedding_size = size(codebook,1)
    inputs_size = size(inputs)
    inputs_flatten = reshape(inputs, (embedding_size, :))
    
    codebook_sqr = transpose(sum(codebook .^ 2, dims=1))
    inputs_sqr = sum(inputs_flatten .^ 2, dims=1)
    distances = (codebook_sqr .+ inputs_sqr) + -2 * (transpose(codebook) * inputs_flatten)
    indices_cartesian = argmin(distances, dims=1)
    indices_flatten = broadcast(x->x[1], indices_cartesian)
    indices = reshape(indices_flatten, inputs_size[2:end])
    return indices
end

function vq_st(inputs, codebook)
    indices = vq(inputs, codebook)
    indices_flatten = reshape(indices, :)
    codes_flatten = codebook[:, indices]
    codes = reshape(codes_flatten, size(inputs))
    return codes, indices_flatten
end

function vq_st_codebook_backprop(inputs, codebook, output, grad_output)
    _, indices = output
    embedding_size = size(codebook, 1)
    grad_output_flatten = reshape(grad_output, (embedding_size, :))
    grad_codebook = zeros(Float32, size(codebook))
    grad_codebook[:, indices] += grad_output_flatten[:, indices]
    return grad_codebook
end

@primitive vq_st(inputs, codebook),dy,y dy vq_st_codebook_backprop(inputs,codebook, y, dy)  

In [190]:
mutable struct Embedding
    weight::Param

    function Embedding(D, K)
        weight = Param(rand(Uniform(-1/K, 1/K), (D, K)))
        new(weight)
    end
end

function (e::Embedding)(x)
    print("Embedding forward")
    weight * transpose(x)
end

In [187]:
mutable struct VQEmbedding
    embedding::Embedding

    function VQEmbedding(D, K)
        println("Creating VQEmbedding")
        embedding = Embedding(D, K)
        new(embedding)
    end
end

function (v::VQEmbedding)(z_e_x)
    println("VQEmbedding Forward")
    latents = vq(z_e_x, v.embedding.weight)
    return latents
end

function (v::VQEmbedding)(z_e_x, straight_through::Bool)
    z_q_x, indices = vq_st(z_e_x, v.embedding.weight)
    z_q_x_bar = v.embedding.weight[:, indices]
    return z_q_x, z_q_x_bar
end

In [194]:
inputs = Param(zeros(4, 2, 1))
fill!(view(inputs, :, 2, :), 1)
@show inputs
codebook = VQEmbedding(4,8)
@show codebook.embedding.weight

inputs = P(Array{Float64, 3}(4,2,1))
Creating VQEmbedding
codebook.embedding.weight = P(Matrix{Float64}(4,8))


4×8 Param{Matrix{Float64}}:
 -0.112099    0.122443   0.0564891  -0.0733504  …   0.0282702   -0.118773
  0.0688482  -0.106894  -0.0309742   0.0269659      0.00436635  -0.0201115
 -0.0758263  -0.117534   0.112868   -0.0439643     -0.0788019    0.028299
 -0.119934    0.101106  -0.0955869  -0.017539      -0.089427     0.0368782

In [195]:
typeof(codebook)

VQEmbedding

In [196]:
codebook(inputs)

VQEmbedding Forward


2×1 Matrix{Int64}:
 4
 3

In [203]:
codebook(inputs, straight_through=true)[1][1]

indices = [4; 3;;]
size(indices) = (2, 1)
indices_flatten = [4, 3]
size(indices_flatten) = (2,)
codes_flatten = [-0.07335043048776543 0.05648909989422812; 0.02696587974332243 -0.030974164157332063; -0.04396434630864143 0.11286802874038543; -0.017538952213784653 -0.09558689212070395;;;]
size(codes_flatten) = (4, 2, 1)
codes = [-0.07335043048776543 0.05648909989422812; 0.02696587974332243 -0.030974164157332063; -0.04396434630864143 0.11286802874038543; -0.017538952213784653 -0.09558689212070395;;;]
size(codes) = (4, 2, 1)


-0.07335043048776543

In [201]:
latents_st

4×2×1 Array{Float64, 3}:
[:, :, 1] =
 -0.0733504   0.0564891
  0.0269659  -0.0309742
 -0.0439643   0.112868
 -0.017539   -0.0955869

# Test for primitive usage

In [3]:
function test(x,y)
    return sum(x+2*y)
end

test (generic function with 1 method)

In [13]:
x1, x2 = Param([1,2,3]), Param([4,5,6])
y = @diff test(x1, x2) ^ 2

T(1296)

In [14]:
grad(y, x1)

72

In [9]:
@primitive test(x1, x2),dy,y dy dy

In [24]:
y = @diff test(x1, x2)

T(3)

In [28]:
grad(y, x2)

R(3)

# Test for Param indexing

In [65]:
function select_second(x1, x2)
    y = x1[1] ^ 2 * x2[2] * 3 * x1[3]
    return y
end

select_second (generic function with 2 methods)

In [66]:
x1 = Param([1, 2, 3])
x2 = Param([1, 2, 3])
y = @diff select_second(x1, x2)

T(18)