In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy

from autoencoder import LayerNorm, FeedForward, clones, Encoder, Decoder, VectorQuantizer

In [2]:
# Seed torch before any module initialisation or torch.rand call so that a
# fresh "Restart & Run All" reproduces the recorded outputs further down.
SEED = 0
torch.manual_seed(SEED)

d_model, N, head_num, d_ff = 10, 3, 2, 20  # N: number of layers, head_num: number of heads
encoder = Encoder(d_model, N, head_num, d_ff)
decoder = Decoder(d_model, N, head_num, d_ff)
# NOTE(review): the signature of autoencoder.VectorQuantizer is not visible here;
# presumably (dim, num_embeddings=10, commitment_cost=0.25) like the local class
# defined later in this notebook -- confirm against autoencoder.py.
vq = VectorQuantizer(d_model, 10, 0.25)

In [3]:
# Dummy batch: uniform random features of shape (batch_size, seq_len, d_model)
# and an all-ones mask over the sequence positions.
batch_size = 2
seq_len = 64
x = torch.rand((batch_size, seq_len, d_model))
# NOTE(review): presumably 1 means "attend to this position" -- confirm Encoder's mask convention.
mask = torch.ones((batch_size, 1, seq_len))

In [4]:
# Encode the dummy batch. The recorded shape below is [2, 8, 10] for an input of
# (2, 64, 10), so the Encoder apparently shortens the sequence axis (64 -> 8).
memory = encoder(x, mask)

In [5]:
# Recorded: torch.Size([2, 8, 10])
memory.size()

torch.Size([2, 8, 10])

In [6]:
# Quantize the encoder memory with the quantizer imported from autoencoder.
# NOTE(review): `outputs` is not consumed by the decoder cell that follows (the
# decoder is fed `memory`), and it is overwritten later by the local quantizer
# test. The return format of autoencoder.VectorQuantizer is not visible here.
outputs = vq(memory)

In [8]:
# Decode from the encoder memory using an all-ones mask over its positions.
# NOTE(review): this passes the un-quantized `memory`, so the VQ bottleneck is
# bypassed; if this is meant to be a VQ autoencoder, feed the quantized latents
# from the previous cell instead -- confirm intent and the quantizer's output format.
output = decoder(memory, torch.ones(batch_size, 1, memory.shape[1]))

In [9]:
# Recorded: torch.Size([2, 64, 10]) -- same shape as the input batch `x`.
output.size()

torch.Size([2, 64, 10])

In [30]:
class VectorQuantizer(nn.Module):
    """Codebook vector quantizer with a straight-through gradient estimator.

    Each vector along the last dimension of the input is replaced by its
    nearest codebook entry (squared Euclidean distance).

    NOTE: this redefinition shadows the ``VectorQuantizer`` imported from
    ``autoencoder`` at the top of the notebook.

    Args:
        dim_embedding: size of each codebook vector; must equal ``inputs.shape[-1]``.
        num_embeddings: number of codebook entries.
        commitment_cost: weight of the commitment term in the returned loss.
    """

    def __init__(self, dim_embedding, num_embeddings, commitment_cost):
        super().__init__()
        self.dim_embedding = dim_embedding
        self.num_embeddings = num_embeddings
        self._commitment_cost = commitment_cost

        # Codebook: weight has shape [num_embeddings, dim_embedding]; row i is code i.
        self.embed = nn.Embedding(num_embeddings, dim_embedding)

    def forward(self, inputs):
        """Quantize ``inputs`` to their nearest codebook vectors.

        Args:
            inputs: tensor of shape [..., dim_embedding]. It does not need to be
                contiguous (e.g. a transposed tensor is accepted).

        Returns:
            dict with keys:
              'quantize': tensor shaped like ``inputs``. Its forward value equals the
                  selected codebook vectors; its gradient flows straight through
                  to ``inputs``.
              'loss': scalar = codebook loss + commitment_cost * commitment loss.
              'encodings': one-hot assignments, shape [N, num_embeddings], where N
                  is the number of input vectors (all leading dims flattened).
              'encoding_indices': selected code index per vector, shape
                  ``inputs.shape[:-1]``.

        Raises:
            AssertionError: if ``inputs.shape[-1] != dim_embedding``.
        """
        assert inputs.shape[-1] == self.dim_embedding
        # reshape (not view): view() raises on non-contiguous inputs such as
        # transposed encoder outputs; reshape copies only when it has to.
        flat_inputs = inputs.reshape(-1, self.dim_embedding)

        # Nearest-code search needs no gradients.
        with torch.no_grad():
            codebook = self.embed.weight
            # ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2  ->  [N, num_embeddings]
            distances = (flat_inputs.pow(2).sum(dim=1, keepdim=True)
                         - 2 * flat_inputs @ codebook.t()
                         + codebook.pow(2).sum(dim=1).unsqueeze(0))
            flat_indices = distances.argmin(dim=1)  # [N]
            encodings = F.one_hot(flat_indices, self.num_embeddings)
            encoding_indices = flat_indices.view(inputs.shape[:-1])

        # Look the codes up with autograd enabled so the codebook loss can
        # update the embedding weights.
        quantized = self.embed(encoding_indices)

        # Codebook loss pulls codes toward the (frozen) inputs; the commitment
        # loss pulls inputs toward the (frozen) codes.
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        e_latent_loss = F.mse_loss(inputs, quantized.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the code, while d/d(inputs)
        # is the identity.
        quantized = inputs + (quantized - inputs).detach()

        return {'quantize': quantized,
                'loss': loss,
                'encodings': encodings,
                'encoding_indices': encoding_indices}

In [31]:
# Small standalone check of the locally defined quantizer.
# NOTE(review): this rebinds `vq`, replacing the d_model-sized quantizer from cell 2.
vq = VectorQuantizer(dim_embedding=3, num_embeddings=10, commitment_cost=1)

In [32]:
# Two identical 3-d vectors, as a gradient-tracking leaf tensor.
inputs = torch.tensor([[1., 2., 3.]] * 2, requires_grad=True)

In [33]:
# No backward pass has run yet, so .grad is None and the cell displays nothing.
inputs.grad

In [34]:
# Run the small quantizer; returns a dict with 'quantize', 'loss', 'encodings', 'encoding_indices'.
outputs = vq(inputs=inputs)

In [35]:
# Quantized output keeps the input's shape (recorded: torch.Size([2, 3])).
outputs['quantize'].size()

torch.Size([2, 3])

In [36]:
# Scalar: codebook loss + commitment_cost * commitment loss.
loss = outputs['loss']

In [37]:
# Value depends on the randomly initialised codebook; it is reproducible only if torch is seeded.
loss

tensor(1.8320, grad_fn=<AddBackward0>)

In [38]:
# Backpropagate the quantizer loss into `inputs` (and the codebook weights).
torch.autograd.backward(loss)

In [39]:
# Only the commitment term reaches `inputs` (the codebook term detaches them), so
# grad = commitment_cost * 2 * (inputs - quantized) / inputs.numel(). Both rows
# match because the two input rows are identical.
inputs.grad

tensor([[0.0605, 0.4485, 0.3170],
        [0.0605, 0.4485, 0.3170]])

In [41]:
# `inputs` is a user-created leaf tensor, so grad_fn is None (nothing displayed).
inputs.grad_fn

In [42]:
from random import randint

In [43]:
# NOTE(review): leftover scratch cell unrelated to the VQ experiment; the stdlib
# `random` module is never seeded here, so the value is not reproducible.
# Consider deleting this cell together with the import cell above it.
randint(0, 10)

1