In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import torch
from torch.autograd import Function
from torch import nn
import torch.nn.functional as F
import ipdb
import numpy as np

In [4]:
class VectorQuantization(Function):
    @staticmethod
    def forward(ctx, inputs, codebook):
        with torch.no_grad():
            embedding_size = codebook.size(1)
            inputs_size = inputs.size()
            inputs_flatten = inputs.view(-1, embedding_size)

            codebook_sqr = torch.sum(codebook ** 2, dim=1)
            inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)

            # Compute the distances to the codebook
            distances = torch.addmm(codebook_sqr + inputs_sqr,
                inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)

            _, indices_flatten = torch.min(distances, dim=1)
            indices = indices_flatten.view(*inputs_size[:-1])
            ctx.mark_non_differentiable(indices)

            return indices

    @staticmethod
    def backward(ctx, grad_output):
        raise RuntimeError('Trying to call `.grad()` on graph containing '
            '`VectorQuantization`. The function `VectorQuantization` '
            'is not differentiable. Use `VectorQuantizationStraightThrough` '
            'if you want a straight-through estimator of the gradient.')

class VectorQuantizationStraightThrough(Function):
    @staticmethod
    def forward(ctx, inputs, codebook):
        indices = vq(inputs, codebook)
        indices_flatten = indices.view(-1)
        ctx.save_for_backward(indices_flatten, codebook)
        ctx.mark_non_differentiable(indices_flatten)

        codes_flatten = torch.index_select(codebook, dim=0,
            index=indices_flatten)
        codes = codes_flatten.view_as(inputs)

        return (codes, indices_flatten)

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        grad_inputs, grad_codebook = None, None

        if ctx.needs_input_grad[0]:
            # Straight-through estimator
            grad_inputs = grad_output.clone()
        if ctx.needs_input_grad[1]:
            # Gradient wrt. the codebook
            indices, codebook = ctx.saved_tensors
            embedding_size = codebook.size(1)

            grad_output_flatten = (grad_output.contiguous()
                                              .view(-1, embedding_size))
            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices, grad_output_flatten)

        return (grad_inputs, grad_codebook)

In [5]:
vq = VectorQuantization.apply
vq_st = VectorQuantizationStraightThrough.apply

In [6]:
class VQEmbedding(nn.Module):
    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        self.embedding.weight.data.uniform_(-1./K, 1./K)

    def forward(self, z_e_x):
        z_e_x_ = z_e_x.contiguous()
        latents = vq(z_e_x_, self.embedding.weight)
        return latents

    def straight_through(self, z_e_x):
        z_e_x_ = z_e_x.contiguous()
        z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight)
        z_q_x = z_q_x_.contiguous()

        z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
            dim=0, index=indices)
        z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
        z_q_x_bar = z_q_x_bar_.contiguous()

        return z_q_x, z_q_x_bar


In [7]:

class VQEmbeddingMovingAverage(nn.Module):
    def __init__(self, K, D, decay=0.99):
        super().__init__()
        embedding = torch.zeros(K, D)
        embedding.uniform_(-1./K, 1./K)
        self.decay = decay

        self.register_buffer("embedding", embedding)
        self.register_buffer("ema_count", torch.ones(K))
        self.register_buffer("ema_w", self.embedding.clone())

    def forward(self, z_e_x):
        z_e_x_ = z_e_x.contiguous()
        latents = vq(z_e_x_, self.embedding.weight)
        return latents

    def straight_through(self, z_e_x):
        ipdb.set_trace()
        K, D = self.embedding.size()

        z_e_x_ = z_e_x.contiguous()
        z_q_x_, indices = vq_st(z_e_x_, self.embedding)
        z_q_x = z_q_x_.contiguous()


        if self.training:
            encodings = F.one_hot(indices, K).float()
            self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)

            dw = encodings.transpose(1, 0)@z_e_x_.reshape([-1, D])
            self.ema_w = self.decay * self.ema_w + (1 - self.decay) * dw

            self.embedding = self.ema_w / (self.ema_count.unsqueeze(-1))
            self.embedding = self.embedding.detach()
            self.ema_w = self.ema_w.detach()
            self.ema_count = self.ema_count.detach()

        z_q_x_bar_flatten = torch.index_select(self.embedding, dim=0, index=indices)
        z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
        z_q_x_bar = z_q_x_bar_.contiguous()

        return z_q_x, z_q_x_bar

In [8]:
def main():
    weights = torch.load("../../latentplan.jl/latentplan/test/files/gpt_weights.pt")
    codebook = VQEmbeddingMovingAverage(K=512, D=512)
    codebook.embedding = weights["model.codebook.embedding"]
    codebook.ema_count = weights["model.codebook.ema_count"]
    codebook.ema_w = weights["model.codebook.ema_w"]

    codebook.requires_grad = True
    print(codebook.embedding)
    
    trajectory_feature = torch.tensor(np.load("../../latentplan.jl/latentplan/test/files/trajectory_feature.npy"))
    trajectory_feature.requires_grad = True

    latents_st, latents = codebook.straight_through(trajectory_feature)
    loss = 2 * latents_st[0,0,0] ** 2 + latents_st[0,1,1] ** 2 + latents_st[0,2,2] ** 2
    loss.backward()
    print(loss)
    np.save("trajectory_feature_straight_through_grad.npy", trajectory_feature.grad.cpu().detach().numpy())
    
    print(trajectory_feature.grad)
    print(codebook.embedding.grad)
main()

tensor([[ 0.0019, -0.0002,  0.0009,  ...,  0.0015,  0.0009,  0.0014],
        [ 0.0007, -0.0005, -0.0015,  ..., -0.0017, -0.0005, -0.0014],
        [ 0.0016,  0.0018,  0.0009,  ..., -0.0018,  0.0018,  0.0007],
        ...,
        [ 0.0004, -0.0017, -0.0016,  ..., -0.0012,  0.0002, -0.0004],
        [-0.0015,  0.0014, -0.0008,  ..., -0.0007,  0.0016, -0.0016],
        [ 0.0014, -0.0012,  0.0005,  ...,  0.0006, -0.0003,  0.0015]])
> [0;32m/var/folders/js/rf57f6s5077gn5pslm2v3lsh0000gn/T/ipykernel_92234/6531829.py[0m(19)[0;36mstraight_through[0;34m()[0m
[0;32m     18 [0;31m        [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 19 [0;31m        [0mK[0m[0;34m,[0m [0mD[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0membedding[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m[0;34m[0m[0m
[0m
tensor(1.1151e-06, grad_fn=<AddBackward0>)
tensor([[[-0.0013,  0.0000,  0

In [9]:
test = np.load("trajectory_feature_straight_through_grad.npy")

In [15]:
test[0,:,2]

array([ 0.        ,  0.        , -0.00183442,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ], dtype=float32)

In [14]:
def vq_test():
    codebook = VQEmbedding(K=4, D=4)
    
    trajectory_feature = torch.zeros(1,1,4)
    trajectory_feature.requires_grad = True
    indices = vq(trajectory_feature, codebook.embedding.weight.detach())
    print(indices)
vq_test()

tensor([[0]])


# VQEmbeddingMovingAverage