In [1]:
import numpy as np
import torch 
from torch import nn
from torch.nn import functional as F

from typing import List, Callable, Union, Any, TypeVar, Tuple
Tensor = TypeVar('torch.tensor')

import torch.optim as optim


# Data preprocessing utils : 
from utils.acdc_dataset import ACDC_Dataset, One_hot_Transform, load_dataset
from torchvision.transforms import Compose
from torchvision import transforms

from torch.utils.data import DataLoader


# Visuals utils
import os
import matplotlib.pyplot as plt
from tqdm import tqdm


# my defined model


In [2]:

# from vector_quantize_pytorch import VectorQuantize
from vector_quantize_pytorch import ResidualVQ


In [84]:
# Compare two otherwise identical ResidualVQ quantizers that differ only in
# commitment_weight. Other ResidualVQ options tried earlier
# (stochastic_sample_codes, sample_codebook_temp, shared_codebook) are left at
# the library defaults here.
torch.manual_seed(0)  # make codebook init and the random batch reproducible


def make_residual_vq(commitment_weight: float) -> ResidualVQ:
    """Build the toy quantizer used in this cell: 2 residual stages, 512 codes of dim 64.

    :param commitment_weight: weight forwarded to ``ResidualVQ(commitment_weight=...)``.
    """
    return ResidualVQ(
        dim=64,
        num_quantizers=2,
        codebook_size=512,
        commitment_weight=commitment_weight,
    )


residual_vq = make_residual_vq(commitment_weight=1.)
residual_vq2 = make_residual_vq(commitment_weight=0.5)

# Channels-last batch: [B, H, W, dim] with dim matching the quantizer's dim.
x = torch.randn(16, 32, 32, 64)
quantized, indices, commit_loss = residual_vq(x)
quantized2, indices2, commit_loss2 = residual_vq2(x)



In [181]:
next(iter(residual_vq.layers)).dim  # embedding dim of the first residual stage

64

In [85]:
print(indices.size())  # [B, H, W, num_quantizers]

torch.Size([16, 32, 32, 2])


In [86]:
# Per-stage commitment losses of both quantizers, one per line.
print(commit_loss, commit_loss2, sep="\n")

tensor([[0.9940, 0.9883]], grad_fn=<StackBackward0>)
tensor([[0.4970, 0.4942]], grad_fn=<StackBackward0>)


In [83]:
commit_loss.sum()  # total commitment loss over all residual stages

tensor(1.9806, grad_fn=<SumBackward0>)

In [16]:
# Quantize a channels-first batch [B, C, H, W]. ResidualVQ is fed channels-last
# input elsewhere in this notebook ([B, H, W, dim]), so move C to the end with
# permute(0, 2, 3, 1) -> [B, H, W, C]. The previous permute(0, 3, 2, 1) gave
# [B, W, H, C], silently swapping the spatial axes; the mistake went unnoticed
# only because H == W here.
inputs = torch.randn(16, 64, 32, 32)

inputs = inputs.permute(0, 2, 3, 1)
quantized, indices, commit_loss = residual_vq(inputs)



In [17]:
print(indices.size())                  # code indices per stage
print(quantized.size())                # quantized output, same layout as the input
print(commit_loss)                     # per-stage commitment losses
print(residual_vq.codebooks.size())    # stacked codebooks

torch.Size([16, 32, 32, 8])
torch.Size([16, 32, 32, 64])
tensor([[0.8794, 0.7752, 0.6837, 0.6034, 0.5337, 0.4719, 0.4179, 0.3711]],
       grad_fn=<StackBackward0>)
torch.Size([8, 512, 64])


In [19]:
print(residual_vq.codebooks.select(0, 0))  # codebook of the first residual stage

tensor([[ 2.0659e-01, -4.2754e-01,  5.0350e-03,  ...,  2.7439e-01,
          2.9347e-01,  2.2752e-01],
        [ 1.2426e-04,  2.7030e-03, -7.4224e-03,  ...,  1.2880e-02,
          1.2001e-02, -9.1113e-03],
        [-2.1528e-01,  1.8731e-01, -1.9458e-01,  ...,  2.6617e-01,
          2.7675e-02, -1.1047e-02],
        ...,
        [ 1.3675e-01,  1.2865e-02, -5.1116e-02,  ..., -1.1161e-01,
         -2.2483e-01,  6.3654e-01],
        [-4.7515e-01, -3.0455e-01,  2.7131e-01,  ..., -6.7351e-01,
          1.0754e-01, -9.6724e-02],
        [ 3.0639e-01, -3.2468e-01,  2.5370e-01,  ...,  6.2321e-01,
         -5.6234e-02,  1.1464e-01]])


----------------

In [140]:


class ResidualLayer(nn.Module):
    """Residual block computing ``input + conv1x1(relu(conv3x3(input)))``.

    Both convolutions are bias-free. The 3x3 conv uses ``padding=1`` and the
    second conv is 1x1, so H and W are preserved. The skip addition therefore
    requires ``out_channels == in_channels``.

    :param in_channels: channels of the incoming feature map.
    :param out_channels: channels produced by both convolutions.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int):
        super().__init__()
        conv3x3 = nn.Conv2d(in_channels, out_channels,
                            kernel_size=3, padding=1, bias=False)
        conv1x1 = nn.Conv2d(out_channels, out_channels,
                            kernel_size=1, bias=False)
        # Keep the attribute name and submodule order so state_dict keys are unchanged.
        self.resblock = nn.Sequential(conv3x3, nn.ReLU(inplace=True), conv1x1)

    def forward(self, input: Tensor) -> Tensor:
        residual = self.resblock(input)
        return input + residual




class RQVAE(nn.Module):
    """Residual-quantized VAE: conv encoder -> ``ResidualVQ`` -> conv decoder.

    The encoder downsamples H and W with one stride-2 4x4 conv per entry of
    ``hidden_dims``. It then applies a 3x3 conv and two ``ResidualLayer``
    blocks, and projects to ``embedding_dim`` channels with a 1x1 conv. The
    encoding is moved to channels-last and quantized by ``ResidualVQ``. The
    decoder mirrors the encoder and upsamples back to ``in_channels`` channels.

    :param in_channels: channels of the input; the reconstruction has the same count.
    :param embedding_dim: dimension of every codebook vector (encoder output channels).
    :param num_embeddings: number of vectors per codebook.
    :param num_quantizers: number of residual quantization stages.
    :param shared_codebook: forwarded to ``ResidualVQ(shared_codebook=...)``.
    :param downsampling_factor: total spatial downsampling; one of {2, 4, 8}.
    :param decay: forwarded to ``ResidualVQ(decay=...)``.
    :param beta: forwarded to ``ResidualVQ(commitment_weight=...)``.
    :raises Warning: if ``downsampling_factor < 2`` (exception type kept for compatibility).
    :raises ValueError: if ``downsampling_factor`` is not one of {2, 4, 8}.
    """

    # Encoder widths per downsampling factor; each entry adds one stride-2 conv.
    # Tuples so the shared class-level values can never be mutated.
    _HIDDEN_DIMS = {2: (64,), 4: (64, 128), 8: (64, 128, 256)}

    def __init__(self,
                 in_channels: int,
                 embedding_dim: int,
                 num_embeddings: int,
                 num_quantizers: int,
                 shared_codebook: bool = False,
                 downsampling_factor: int = 4,
                 decay: float = 0.8,
                 beta: float = 0.25,
                 **kwargs) -> None:
        super(RQVAE, self).__init__()

        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta
        self.shared_codebook = shared_codebook
        self.num_quantizers = num_quantizers
        self.decay = decay

        if downsampling_factor < 2:
            raise Warning("VQVAE can't have a donwsampling factor less than 2")
        if downsampling_factor not in self._HIDDEN_DIMS:
            # Previously `assert("...")` asserted a non-empty string (always true),
            # so unsupported factors fell through and crashed on an unbound name.
            raise ValueError(
                "downsampling factor must be one of the following values : {2,4,8}, "
                f"got {downsampling_factor!r}")
        # Fresh list: it is reversed in place while building the decoder.
        hidden_dims = list(self._HIDDEN_DIMS[downsampling_factor])

        # Build Encoder
        modules = []
        channels = in_channels  # keep `in_channels` intact for the decoder output
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(channels, out_channels=h_dim,
                              kernel_size=4, stride=2, padding=1),
                    nn.LeakyReLU())
            )
            channels = h_dim

        modules.append(
            nn.Sequential(
                nn.Conv2d(channels, channels,
                          kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU())
        )

        for _ in range(2):
            modules.append(ResidualLayer(channels, channels))
        modules.append(nn.LeakyReLU())

        modules.append(
            nn.Sequential(
                nn.Conv2d(channels, embedding_dim,
                          kernel_size=1, stride=1),
                nn.LeakyReLU())
        )

        self.encoder = nn.Sequential(*modules)

        self.vq_layer = ResidualVQ(dim = embedding_dim,
                                    codebook_size = num_embeddings,
                                    commitment_weight = self.beta,
                                    decay = self.decay,
                                    num_quantizers = self.num_quantizers,
                                    shared_codebook = self.shared_codebook,
                                    )

        # Build Decoder
        modules = []
        modules.append(
            nn.Sequential(
                nn.Conv2d(embedding_dim,
                          hidden_dims[-1],
                          kernel_size=3,
                          stride=1,
                          padding=1),
                nn.LeakyReLU())
        )

        for _ in range(2):
            modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1]))

        modules.append(nn.LeakyReLU())

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=4,
                                       stride=2,
                                       padding=1),
                    nn.LeakyReLU())
            )

        # Map back to the input's channel count. This was hard-coded to 4, which only
        # matched in_channels=4; loss_function compares recons against inputs.
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hidden_dims[-1],
                                   out_channels=in_channels,
                                   kernel_size=4,
                                   stride=2, padding=1),
                # NOTE(review): recons is passed to F.cross_entropy as logits in
                # loss_function; this ReLU keeps them >= 0 — confirm intended.
                nn.ReLU()
                ))

        self.decoder = nn.Sequential(*modules)

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        return [result]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        result = self.decoder(z)
        return result

    def forward(self, inputs: Tensor, **kwargs) -> List[Tensor]:
        """Encode, residual-quantize and decode ``inputs``.

        :param inputs: (Tensor) [B x C x H x W]
        :return: [reconstruction [B x C x H x W], inputs (unchanged),
            code indices [B x H' x W' x num_quantizers],
            per-stage commitment losses as returned by ``ResidualVQ``]
        """
        encoding = self.encode(inputs)[0]
        # ResidualVQ quantizes the last dimension, so go channels-last and back.
        encoding = encoding.permute(0, 2, 3, 1)
        quantized_inputs, indices, commitment_loss_beta = self.vq_layer(encoding)
        quantized_inputs = quantized_inputs.permute(0, 3, 1, 2)
        return [self.decode(quantized_inputs), inputs, indices, commitment_loss_beta]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """Cross-entropy reconstruction loss plus the summed commitment losses.

        :param args: the list returned by ``forward``:
            (recons, inputs, indices, commitment_loss_beta). ``indices`` is unused.
        :return: dict with 'loss', 'Reconstruction_Loss' and 'commitement Loss'
            (per-stage commitment losses, not summed).
        """
        recons = args[0]
        inputs = args[1]
        commitment_loss_beta = args[3]

        # NOTE(review): assumes `inputs` are class probabilities / one-hot maps over
        # the channel dimension (F.cross_entropy target of the same shape) — confirm.
        recons_loss = F.cross_entropy(recons, inputs)

        loss = recons_loss + torch.sum(commitment_loss_beta)  # sum over all stages' commitment losses
        # Key spelling 'commitement Loss' is kept as-is for existing consumers.
        return {'loss': loss,
                'Reconstruction_Loss': recons_loss,
                'commitement Loss': commitment_loss_beta}

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        # Previously the body was empty, so this silently returned None.
        return self.forward(x)[0]

    def codebook_usage(self, inputs):
        """Count how often each code of each residual codebook is selected for ``inputs``.

        :param inputs: (Tensor) batch for the encoder [N x C x H x W]
        :return: (Tensor) [num_quantizers x codebook_size] counts, allocated with
            ``torch.zeros`` defaults (default dtype and device).
        """
        # Counting needs no autograd graph.
        # NOTE(review): the quantizer runs in the model's current train/eval mode; in
        # training mode ResidualVQ may update its EMA codebooks here — confirm acceptable.
        with torch.no_grad():
            encoding = self.encode(inputs)[0]
            encoding = encoding.permute(0, 2, 3, 1)
            _, indices, _ = self.vq_layer(encoding)

        codebook_size = self.vq_layer.codebook_size
        num_codebooks = indices.shape[-1]
        embedding_histogram = torch.zeros(num_codebooks, codebook_size)

        for i in range(num_codebooks):
            # indices[..., i] is a strided [B, H', W'] slice; reshape copes with any strides.
            codes_i = indices[..., i].reshape(-1)
            counts_i = torch.bincount(codes_i, minlength=codebook_size)
            # Match the histogram's dtype/device before storing the counts.
            embedding_histogram[i] = counts_i.to(embedding_histogram)

        return embedding_histogram



In [119]:
indices.select(-1, 0).shape  # indices of the first residual stage only

torch.Size([16, 32, 32])

----------

In [141]:
# Hyper-parameters for the RQVAE built in the next cell.
K, D = 512, 64              # K: num_embeddings (codebook size), D: embedding_dim
in_channels = 4             # channels of the input images
num_quantizers = 2          # number of residual quantization stages
downsampling_factor = 4     # spatial downsampling; RQVAE accepts {2, 4, 8}
shared_codebook = False     # forwarded to ResidualVQ(shared_codebook=...)

In [142]:
# Build the residual-quantized VAE from the hyper-parameters above.
model = RQVAE(
    in_channels=in_channels,
    embedding_dim=D,
    num_embeddings=K,
    num_quantizers=num_quantizers,
    shared_codebook=shared_codebook,
    downsampling_factor=downsampling_factor,
)


In [143]:
input_tensor = torch.randn((16, 4, 128, 128))  # random batch [B, C, H, W]

In [144]:
# forward returns [reconstruction, inputs, code indices, commitment losses]
[output_tensor, inputs, indices, commit_losss] = model(input_tensor)

In [145]:
print(output_tensor.size())   # reconstruction, same shape as the input batch
print(indices.size())         # code indices per residual stage
print(commit_losss.sum())     # total commitment loss

torch.Size([16, 4, 128, 128])
torch.Size([16, 32, 32, 2])
tensor(0.0010, grad_fn=<SumBackward0>)


In [152]:
# Per-codebook usage counts for the batch above: [num_quantizers, K].
# (Was `input_tesor`, a NameError on a fresh kernel.)
hist_codebooks = model.codebook_usage(input_tensor)
print(hist_codebooks.shape)

torch.Size([2, 512])


In [87]:
# Inference forward pass in eval mode. With the installed vector_quantize_pytorch
# version this raised an einops SolveValueException inside ResidualVQ (see the
# output below). Gradients are disabled for inference. The previous train/eval
# mode is restored even if the call fails, so re-running earlier cells is unaffected.
was_training = model.training
model.eval()
input_tensor = torch.randn(16, 4, 128, 128)
try:
    with torch.no_grad():
        output_tensor, inputs, indices, commit_losss = model(input_tensor)
finally:
    model.train(was_training)

SolveValueException: Failed to solve values of expressions. Found contradictory values {16, 1} for equivalent expressions {'1', 'h', '16'}
Input:
    'h [c] d = 1 512 64'
    'h b n = 16 32 32'
    'h b n d = None'
    '1 = 1'
