In [1]:
%load_ext autoreload
%autoreload 2

In [149]:
import os
import sys
sys.path.insert(0, '../src')

import torch
from torch import nn
from torch.nn import functional as F
from models.backbone import build_backbone
from models.pointalign import SmallDecoder

In [150]:
class VectorQuantizer(nn.Module):
    '''
    Vector-quantization layer whose codebook is trained with exponential moving averages (EMA).
    Implements the algorithm in 'Generating Diverse High Fidelity Images with VQ-VAE2' by Razavi et al.
    https://arxiv.org/pdf/1906.00446

    Adapted from: https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py,
                  https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py

    NOTE: a class with the same name is defined again in a later cell; when the notebook
    is run top to bottom, that later definition shadows this one.

    Args:
        num_embed: int, the number of codebook vectors in the latent space
        embed_dim: int, the dimensionality of the tensors in the latent space
        decay: float, decay for the moving averages
        eps: small float constant to avoid numerical instability

    Buffers:
        embed: (embed_dim, num_embed) codebook, one code vector per column
        cluster_size: (num_embed,) EMA of the number of inputs assigned to each code
        dw: (embed_dim, num_embed) EMA of the sum of the inputs assigned to each code

    The commitment term returned by forward() is NOT weighted; the caller is expected to
    multiply it by its own commitment cost (beta in Eq 4 of the VQ-VAE 2 paper).
    '''

    def __init__(self, num_embed, embed_dim, decay=0.99, eps=1e-5):
        super().__init__()

        self.num_embed = num_embed
        self.embed_dim = embed_dim
        self.decay = decay
        self.eps = eps

        embed = torch.randn(embed_dim, num_embed)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(num_embed))
        self.register_buffer("dw", embed.clone())


    def forward(self, x):
        '''
        Quantizes every vector along the last axis of x to its nearest codebook entry.

        Inputs:
            x - Tensor of shape (N, C, H, embed_dim)

        Returns:
            quantized - Tensor shaped like x; forward value is the selected code vectors,
                        gradient is passed straight through to x
            diff - scalar Tensor, mean squared distance between x and the (detached) codes
            encoding_inds - LongTensor of shape x.shape[:-1] with the selected code indices
        '''
        # Check before reshaping, otherwise a wrong last dim can be silently folded away.
        assert x.shape[-1] == self.embed_dim
        x_flat = x.reshape(-1, self.embed_dim) # (NCH, embed_dim)

        # Squared euclidean distance ||x||^2 - 2 x.e + ||e||^2 to every code
        dists = (x_flat**2).sum(dim=1, keepdim=True) - 2 * (x_flat @ self.embed) + (self.embed**2).sum(dim=0, keepdim=True) #(NCH, num_embed)

        encoding_inds = dists.argmin(dim=-1) # (NCH,)
        encodings = F.one_hot(encoding_inds, self.num_embed).type(x_flat.dtype) # (NCH, num_embed)
        encoding_inds = encoding_inds.view(*x.shape[:-1]) # (N, C, H)
        quantized = self.quantize(encoding_inds) # (N, C, H, embed_dim)

        if self.training:
            # The codebook is updated by EMA statistics, not by gradients, so the
            # buffers must never become part of the autograd graph of x.
            with torch.no_grad():
                #N^(t) = decay * N^(t-1) + (1 - decay) * n^(t)
                self.cluster_size = self.decay * self.cluster_size + (1 - self.decay) * encodings.sum(dim=0)

                #m^(t) = decay * m^(t-1) + (1 - decay) * \sum_{j}^{n^(t)} E(x)_{i,j}^(t)
                self.dw = self.decay * self.dw + (1 - self.decay) * (x_flat.T @ encodings)

                # Laplace smoothing of the cluster size keeps unused codes from dividing by zero
                n = self.cluster_size.sum()
                cluster_size = n * (self.cluster_size + self.eps) / (n + self.num_embed * self.eps)

                self.embed = self.dw / cluster_size.unsqueeze(0) #e^(t) = m^(t) / N^(t)

        # Commitment term: pulls the encoder output towards its (fixed) code.
        diff = ((quantized.detach() - x)**2).mean()
        # Straight-through estimator: value of the code, gradient of x.
        quantized = x + (quantized - x).detach()

        return quantized, diff, encoding_inds


    def quantize(self, embed_id):
        '''Looks up the code vectors for the integer indices in embed_id -> (*embed_id.shape, embed_dim).'''
        return F.embedding(embed_id, self.embed.transpose(0, 1))

In [151]:
class VectorQuantizer(nn.Module):
    '''
    Represents the VQ-VAE layer with an exponential moving average (EMA) codebook update
    to accelerate training.
    Implements the algorithm in 'Generating Diverse High Fidelity Images with VQ-VAE2' by Razavi et al.
    https://arxiv.org/pdf/1906.00446

    Adapted from: https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py,
                  https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py.

    Args:
        num_embed: int, the number of codebook vectors in the latent space
        embed_dim: int, the dimensionality of the tensors in the latent space
        decay: float, decay for the moving averages
        eps: small float constant to avoid numerical instability

    Buffers:
        embed: (embed_dim, num_embed) codebook, one code vector per column
        cluster_size: (num_embed,) EMA of the number of inputs assigned to each code
        dw: (embed_dim, num_embed) EMA of the sum of the inputs assigned to each code

    The commitment term returned by forward() is unweighted; scaling it by a commitment
    cost (beta in Eq 4 of the VQ-VAE 2 paper) is left to the caller.
    '''

    def __init__(self, num_embed, embed_dim, decay=0.99, eps=1e-5):
        super().__init__()

        self.num_embed = num_embed
        self.embed_dim = embed_dim
        self.decay = decay
        self.eps = eps

        embed = torch.randn(embed_dim, num_embed)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(num_embed))
        self.register_buffer("dw", embed.clone())


    def forward(self, x):
        '''
        Replaces every vector along the last axis of x by its nearest codebook vector.

        Inputs:
            x - Tensor of shape (N, C, H, embed_dim)

        Returns:
            quantized - Tensor shaped like x holding the selected code vectors; gradients
                        are copied straight through to x
            diff - scalar Tensor, mean squared error between x and the detached codes
            encoding_inds - LongTensor of shape x.shape[:-1], index of the selected code
        '''
        # Validate first: reshape below would otherwise hide a mismatching last dim.
        assert x.shape[-1] == self.embed_dim
        x_flat = x.reshape(-1, self.embed_dim) # (NCH, embed_dim)

        # ||x - e||^2 expanded as ||x||^2 - 2 x.e + ||e||^2
        dists = (x_flat**2).sum(dim=1, keepdim=True) - 2 * (x_flat @ self.embed) + (self.embed**2).sum(dim=0, keepdim=True) #(NCH, num_embed)

        encoding_inds = dists.argmin(dim=-1) # (NCH,)
        encodings = F.one_hot(encoding_inds, self.num_embed).type(x_flat.dtype) # (NCH, num_embed)
        encoding_inds = encoding_inds.view(*x.shape[:-1]) # (N, C, H)
        quantized = self.quantize(encoding_inds) # (N, C, H, embed_dim)

        if self.training:
            self._ema_update(x_flat, encodings)

        # Commitment term: moves the encoder output towards its code (codes get no gradient).
        diff = ((quantized.detach() - x)**2).mean()
        # Straight-through estimator: forward value of the code, backward gradient of x.
        quantized = x + (quantized - x).detach()

        return quantized, diff, encoding_inds


    def _ema_update(self, x_flat, encodings):
        '''Updates cluster_size, dw and embed from one batch of flat inputs and one-hot assignments.'''
        # EMA statistics are not trained by backprop; keep them out of the autograd graph.
        with torch.no_grad():
            #N^(t) = decay * N^(t-1) + (1 - decay) * n^(t)
            self.cluster_size = self.decay * self.cluster_size + (1 - self.decay) * encodings.sum(dim=0)

            #m^(t) = decay * m^(t-1) + (1 - decay) * \sum_{j}^{n^(t)} E(x)_{i,j}^(t)
            self.dw = self.decay * self.dw + (1 - self.decay) * (x_flat.T @ encodings)

            # Laplace smoothing of the cluster size so that unused codes do not divide by zero
            n = self.cluster_size.sum()
            cluster_size = n * (self.cluster_size + self.eps) / (n + self.num_embed * self.eps)

            self.embed = self.dw / cluster_size.unsqueeze(0) #e^(t) = m^(t) / N^(t)


    def quantize(self, embed_ind):
        '''Returns the code vectors for the integer indices in embed_ind -> (*embed_ind.shape, embed_dim).'''
        return F.embedding(embed_ind, self.embed.T)

In [152]:
# Smoke test: push a random batch through the quantizer and check the output shapes.
num_embed = 512
embed_dim = 64
x = torch.rand(128, 3, 64, embed_dim)  # last axis must equal embed_dim

vqvae = VectorQuantizer(num_embed, embed_dim, decay=0.99, eps=1e-5)
# Call the module instance instead of .forward() so that nn.Module hooks are honoured.
quantized, diff, encoding_inds = vqvae(x)
print(quantized.shape)
print(diff.shape)
print(encoding_inds.shape)

torch.Size([128, 3, 64, 64])
torch.Size([])
torch.Size([128, 3, 64])


In [153]:
class VQVAE(nn.Module):
    '''
    Encoder -> vector quantizer -> decoder model.

    The encoder comes from build_backbone('resnet18', pretrained=True), the bottleneck is
    a VectorQuantizer and the decoder is a SmallDecoder.

    Args:
        points: forwarded unchanged as the first argument of SmallDecoder
        num_embed: int, number of codebook vectors of the quantizer
    '''

    def __init__(self, points=None, num_embed=512):
        super().__init__()

        self.encoder, feat_dims = build_backbone('resnet18', pretrained=True)
        # The quantizer asserts that the last axis of its input equals embed_dim, and the
        # decoder is built with feat_dims[-1], so the codebook must use that width too.
        # NOTE(review): assumes feat_dims[-1] is the channel count of the last encoder
        # feature map - confirm against build_backbone.
        self.quantize = VectorQuantizer(num_embed=num_embed, embed_dim=feat_dims[-1])
        self.decoder = SmallDecoder(points, feat_dims[-1], 128)

    def forward(self, images, P=None, return_diff=False):
        '''
        Inputs:
            images - batch passed to the encoder
            P - forwarded unchanged as the second argument of the decoder
            return_diff - if True, also return the quantization loss so it can be trained on

        Returns:
            ptclds - output of the decoder
            diff - (only if return_diff) Tensor of shape (1,) with the quantizer's loss term
        '''
        # Only the last feature map of the encoder is quantized.
        z = self.encoder(images)[-1]

        # Move axis 1 last for the quantizer, then restore the original axis order.
        # (presumably channels-first <-> channels-last; z must be 4-D for these permutes)
        quant, diff, _ = self.quantize(z.permute(0, 2, 3, 1))
        quant = quant.permute(0, 3, 1, 2)

        ptclds = self.decoder(quant, P)

        if return_diff:
            return ptclds, diff.unsqueeze(0)
        return ptclds

In [154]:
# Build the model with the default points=None. The constructor requests a pretrained
# ResNet-18 backbone, so the first run downloads the checkpoint (see output below).
model = VQVAE()

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /home/carolinechoi/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [62]:
# Scratch check: F.one_hot maps each index to a row with a single 1 at that position.
x = torch.arange(4, -1, -1)  # tensor([4, 3, 2, 1, 0])
print(x)
F.one_hot(x, 6)

tensor([4, 3, 2, 1, 0])


tensor([[0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0]])

In [104]:
# Scratch check: on a 2-D index tensor, one_hot appends the class axis as the last dim.
encoding_inds = torch.tensor([
    [1, 0, 0],
    [0, 1, 0],
])
print(encoding_inds)
F.one_hot(encoding_inds, num_classes=6).to(encoding_inds.dtype)

tensor([[1, 0, 0],
        [0, 1, 0]])


tensor([[[0, 1, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0]],

        [[1, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0]]])

In [81]:
# Scratch check: build one-hot style rows by scattering 1s along dim 1.
# NOTE(review): the saved output below fills all three rows, which a (2, 3) index cannot
# do - it was presumably produced with a different encoding_inds (execution count is older).
encodings = torch.zeros((3, 6), device=encoding_inds.device)
encodings.scatter_(dim=1, index=encoding_inds, value=1)

tensor([[1., 1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.]])

In [90]:
# Repeat of the scatter check: write a 1 at every column index listed in encoding_inds.
encodings = torch.zeros((3, 6), device=encoding_inds.device)
encodings.scatter_(dim=1, index=encoding_inds, value=1)

tensor([[1., 1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.]])