In [1]:
%load_ext autoreload
%autoreload 2

In [52]:
import torch
from torch.autograd import Function
from torch import nn
import torch.nn.functional as F
import ipdb

In [30]:
class VectorQuantization(Function):
    """Nearest-neighbour lookup of input vectors in a codebook.

    The result is an integer index tensor and is not differentiable; use
    ``VectorQuantizationStraightThrough`` when a gradient is required.
    """

    @staticmethod
    def forward(ctx, inputs, codebook):
        """Return, for every input vector, the row index of the closest code.

        Args:
            inputs: tensor whose last dimension equals ``codebook.size(1)``.
                Any leading shape is accepted (including none, i.e. a single
                1-D vector), and the tensor need not be contiguous.
            codebook: ``(K, D)`` tensor of code vectors.

        Returns:
            Integer tensor of shape ``inputs.shape[:-1]`` with the index of the
            code at minimal squared Euclidean distance (ties are broken by
            ``torch.min``). Marked non-differentiable.
        """
        with torch.no_grad():
            embedding_size = codebook.size(1)
            inputs_size = inputs.size()
            # reshape (not view): view fails when the leading dims of a
            # non-contiguous tensor cannot be merged without a copy.
            inputs_flatten = inputs.reshape(-1, embedding_size)

            # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c for all pairs at once:
            # (N, 1) + (K,) broadcasts to (N, K); addmm then adds
            # -2 * inputs_flatten @ codebook^T.
            codebook_sqr = torch.sum(codebook ** 2, dim=1)
            inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)
            distances = torch.addmm(codebook_sqr + inputs_sqr,
                inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)

            _, indices_flatten = torch.min(distances, dim=1)
            # Pass the Size itself (not unpacked) so a 1-D input, whose
            # leading shape is empty, yields a scalar index tensor.
            indices = indices_flatten.view(inputs_size[:-1])
            ctx.mark_non_differentiable(indices)

            return indices

    @staticmethod
    def backward(ctx, grad_output):
        """Always raises: the index lookup has no usable gradient."""
        raise RuntimeError('Trying to call `.grad()` on graph containing '
            '`VectorQuantization`. The function `VectorQuantization` '
            'is not differentiable. Use `VectorQuantizationStraightThrough` '
            'if you want a straight-through estimator of the gradient.')

class VectorQuantizationStraightThrough(Function):
    """Quantize inputs to their nearest codes with a straight-through gradient.

    Forward returns ``(codes, flat_indices)``: ``codes`` has the shape of
    ``inputs`` and ``flat_indices`` is the 1-D tensor of selected codebook rows
    (non-differentiable). Backward hands the output gradient to ``inputs``
    unchanged and scatter-adds it into the selected rows of ``codebook``.
    """

    @staticmethod
    def forward(ctx, inputs, codebook):
        flat_idx = vq(inputs, codebook).view(-1)
        ctx.save_for_backward(flat_idx, codebook)
        ctx.mark_non_differentiable(flat_idx)

        # Gather the chosen code vectors and restore the input layout.
        quantized = codebook.index_select(0, flat_idx).view_as(inputs)
        return quantized, flat_idx

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        want_inputs, want_codebook = ctx.needs_input_grad[:2]

        # Straight-through estimator: the quantizer is treated as identity.
        grad_inputs = grad_output.clone() if want_inputs else None

        grad_codebook = None
        if want_codebook:
            flat_idx, codebook = ctx.saved_tensors
            rows = grad_output.contiguous().view(-1, codebook.size(1))
            # Every code accumulates the gradients of all positions that chose it.
            grad_codebook = torch.zeros_like(codebook).index_add_(0, flat_idx, rows)

        return grad_inputs, grad_codebook

In [31]:
# Functional entry points for the autograd Functions defined above:
#   vq(inputs, codebook)    -> nearest-code indices (no gradient)
#   vq_st(inputs, codebook) -> (codes, flat indices) with a straight-through gradient
vq, vq_st = VectorQuantization.apply, VectorQuantizationStraightThrough.apply

In [43]:
class VQEmbedding(nn.Module):
    """Learnable codebook of ``K`` vectors of size ``D`` for vector quantization.

    The codes live in an ``nn.Embedding`` whose weight is initialised
    uniformly in ``[-1/K, 1/K]`` and is trained by ordinary gradients.
    """

    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        # Small symmetric starting range that shrinks as the codebook grows.
        nn.init.uniform_(self.embedding.weight, -1. / K, 1. / K)

    def forward(self, z_e_x):
        """Return the nearest-code index for each vector along the last dim."""
        return vq(z_e_x.contiguous(), self.embedding.weight)

    def straight_through(self, z_e_x):
        """Quantize ``z_e_x`` and return the codes in two gradient flavours.

        Returns:
            z_q_x: codes from ``vq_st``; its backward copies gradients straight
                through to ``z_e_x`` and scatter-adds them into the codebook.
            z_q_x_bar: the same codes re-gathered from the weight with a plain
                ``index_select``, so its gradient reaches only the codebook.
        """
        encoded = z_e_x.contiguous()
        quantized, flat_idx = vq_st(encoded, self.embedding.weight)

        regathered = self.embedding.weight.index_select(0, flat_idx).view_as(encoded)
        return quantized.contiguous(), regathered.contiguous()


In [45]:
def main(seed=None):
    """Smoke-test the straight-through gradient of ``VQEmbedding``.

    Quantizes a two-vector input (one all-zero vector, one all-one vector),
    back-propagates a loss that reads one element of each quantized vector,
    and prints the codebook, the quantized outputs and both gradients.

    Args:
        seed: optional value for ``torch.manual_seed`` so the random codebook
            (and thus the printed output) is reproducible. ``None`` leaves the
            global RNG untouched.
    """
    if seed is not None:
        torch.manual_seed(seed)

    codebook = VQEmbedding(K=8, D=4)
    # ``codebook.requires_grad = True`` would only create a plain Python
    # attribute on the Module; ``requires_grad_`` flags every parameter.
    codebook.requires_grad_(True)
    print(codebook.embedding.weight)

    trajectory_feature = torch.zeros(1, 2, 4)
    trajectory_feature[0, 1, :].fill_(1)
    # Enable grad only after the in-place fill: autograd rejects in-place
    # writes to a leaf tensor that already requires grad.
    trajectory_feature.requires_grad_(True)

    latents_st, latents = codebook.straight_through(trajectory_feature)
    print("Latent code shape:", latents_st.shape)
    print("Latents_st:", latents_st)
    print("Latents:", latents)

    # d(loss)/d(latents_st) is 2 at [0,0,0] and 1 at [0,1,1]. The
    # straight-through estimator copies that pattern onto the input, and it
    # is scatter-added into the codebook rows chosen for the two vectors.
    loss = 2 * latents_st[0, 0, 0] + latents_st[0, 1, 1]
    loss.backward()
    print(trajectory_feature.grad)
    print(codebook.embedding.weight.grad)


main()

Parameter containing:
tensor([[-0.1017, -0.0756, -0.0260,  0.0423],
        [ 0.0249, -0.0127, -0.0857, -0.0423],
        [-0.0179, -0.1194, -0.0493,  0.0043],
        [ 0.1053,  0.0444,  0.0938, -0.0371],
        [-0.1228, -0.0994,  0.0341,  0.0482],
        [ 0.0325, -0.0910,  0.0846,  0.0073],
        [ 0.0806,  0.0586, -0.0904, -0.0681],
        [ 0.1024, -0.0948,  0.1128, -0.0157]], requires_grad=True)
Latent code shape: torch.Size([1, 2, 4])
Latents_st: tensor([[[ 0.0249, -0.0127, -0.0857, -0.0423],
         [ 0.1053,  0.0444,  0.0938, -0.0371]]],
       grad_fn=<VectorQuantizationStraightThroughBackward>)
Latents: tensor([[[ 0.0249, -0.0127, -0.0857, -0.0423],
         [ 0.1053,  0.0444,  0.0938, -0.0371]]], grad_fn=<ViewBackward0>)
tensor([[[2., 0., 0., 0.],
         [0., 1., 0., 0.]]])
tensor([[0., 0., 0., 0.],
        [2., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0.,

In [24]:
def vq_test():
    """Print the code index ``vq`` selects for one all-zero 4-d vector."""
    model = VQEmbedding(K=4, D=4)

    zeros = torch.zeros(1, 1, 4)
    zeros.requires_grad_(True)
    # The codebook is passed detached; ``vq`` only returns indices.
    print(vq(zeros, model.embedding.weight.detach()))


vq_test()

tensor([[2]])


# VQEmbeddingMovingAverage: codebook updated by exponential moving averages instead of gradients

In [47]:

class VQEmbeddingMovingAverage(nn.Module):
    """Codebook of ``K`` vectors of size ``D`` maintained by exponential moving averages.

    The codebook is a buffer, not a parameter, so it receives no gradients.
    In training mode every ``straight_through`` call updates, with weight
    ``1 - decay`` on the new batch:

    * ``ema_count[k]`` -- EMA of the number of inputs assigned to code ``k``;
    * ``ema_w[k]``     -- EMA of the sum of the inputs assigned to code ``k``;

    and then sets ``embedding = ema_w / ema_count``.
    """

    def __init__(self, K, D, decay=0.99):
        super().__init__()
        embedding = torch.zeros(K, D)
        embedding.uniform_(-1./K, 1./K)
        self.decay = decay

        self.register_buffer("embedding", embedding)
        # A starting count of one per code makes ema_w / ema_count reproduce
        # the initial embedding exactly.
        self.register_buffer("ema_count", torch.ones(K))
        self.register_buffer("ema_w", self.embedding.clone())

    def forward(self, z_e_x):
        """Return the nearest-code index for each vector along the last dim."""
        z_e_x_ = z_e_x.contiguous()
        # ``embedding`` is a plain tensor buffer (not an nn.Embedding), so it
        # is passed directly; ``self.embedding.weight`` does not exist.
        latents = vq(z_e_x_, self.embedding)
        return latents

    def straight_through(self, z_e_x):
        """Quantize ``z_e_x`` and, in training mode, update the EMA codebook.

        Returns:
            z_q_x: codes chosen with the codebook as it was *before* this
                call's update, produced by ``vq_st`` (straight-through
                gradient to ``z_e_x``).
            z_q_x_bar: codes for the same indices gathered from the codebook
                *after* the update (equal to the chosen codes in eval mode).
        """
        K, D = self.embedding.size()

        z_e_x_ = z_e_x.contiguous()
        z_q_x_, indices = vq_st(z_e_x_, self.embedding)
        z_q_x = z_q_x_.contiguous()

        if self.training:
            # The EMA statistics are bookkeeping, not part of any loss, so no
            # autograd graph is built through z_e_x here.
            with torch.no_grad():
                # Match the input dtype so the matmul below works for
                # float32, float64, etc.
                encodings = F.one_hot(indices, K).to(z_e_x_.dtype)
                self.ema_count = (self.decay * self.ema_count
                                  + (1 - self.decay) * encodings.sum(dim=0))

                dw = encodings.t() @ z_e_x_.reshape(-1, D)
                self.ema_w = self.decay * self.ema_w + (1 - self.decay) * dw

                # Assigning to a registered buffer name replaces the buffer.
                # Rebinding (rather than updating in place) leaves the tensor
                # already handed to vq_st above untouched.
                self.embedding = self.ema_w / self.ema_count.unsqueeze(-1)

        z_q_x_bar_flatten = torch.index_select(self.embedding, dim=0, index=indices)
        z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
        z_q_x_bar = z_q_x_bar_.contiguous()

        return z_q_x, z_q_x_bar

# Gradient test

In [None]:
function test_grad