In [1]:
import numpy as np
import torch 
from torch import nn
from torch.nn import functional as F

from typing import List, Callable, Union, Any, TypeVar, Tuple
Tensor = TypeVar('torch.tensor')

import torch.optim as optim


# Data preprocessing utils : 
from utils.acdc_dataset import ACDC_Dataset, One_hot_Transform, load_dataset
from torchvision.transforms import Compose
from torchvision import transforms

from torch.utils.data import DataLoader


# Visuals utils
import os
import matplotlib.pyplot as plt
from tqdm import tqdm


# my defined model



In [2]:
from torch import nn
import torch
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss

from typing import List, Callable, Union, Any, TypeVar, Tuple
# from torch import tensor as Tensor
Tensor = TypeVar('torch.tensor')

from vector_quantize_pytorch import VectorQuantize


###### Hyper Parameters of the Model ######
in_channels = 4 


In [3]:


class ResidualLayer(nn.Module):

    def __init__(self,
                 in_channels: int,
                 out_channels: int):
        super(ResidualLayer, self).__init__()
        self.resblock = nn.Sequential(nn.Conv2d(in_channels, out_channels,
                                                kernel_size=3, padding=1, bias=False),
                                      nn.ReLU(True),
                                      nn.Conv2d(out_channels, out_channels,
                                                kernel_size=1, bias=False))

    def forward(self, input: Tensor) -> Tensor:
        return input + self.resblock(input)




class VQVAE(nn.Module):

    def __init__(self,
                 in_channels: int,
                 embedding_dim: int,
                 num_embeddings: int,
                #hidden_dims: List = None,
                 downsampling_factor :int = 4,
                 beta: float = 0.25,
                #  embedding: Tensor = None,
                 **kwargs) -> None:
        super(VQVAE, self).__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta

        modules = []
        
        if downsampling_factor < 2 :
            raise Warning("VQVAE can't have a donwsampling factor less than 2")
        elif downsampling_factor ==2 :
            hidden_dims = [64]
        elif downsampling_factor == 4 :
            hidden_dims = [64, 128]
        elif downsampling_factor == 8 :
            hidden_dims = [64, 128, 256]
        else:
            assert("donwsamlping factor must be one of the following values : {2,4,8}")



        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=4, stride=2, padding=1),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, in_channels,
                          kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU())
        )

        for _ in range(2):
            modules.append(ResidualLayer(in_channels, in_channels))
        modules.append(nn.LeakyReLU())

        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, embedding_dim,
                          kernel_size=1, stride=1),
                nn.LeakyReLU())
        )

        self.encoder = nn.Sequential(*modules)

        self.vq_layer = VectorQuantize(dim = embedding_dim,
                                        codebook_size = num_embeddings,
                                        commitment_weight = self.beta,
                                        decay = 0.8)

        # Build Decoder
        modules = []
        modules.append(
            nn.Sequential(
                nn.Conv2d(embedding_dim,
                          hidden_dims[-1],
                          kernel_size=3,
                          stride=1,
                          padding=1),
                nn.LeakyReLU())
        )

        for _ in range(2):
            modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1]))

        modules.append(nn.LeakyReLU())

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=4,
                                       stride=2,
                                       padding=1),
                    nn.LeakyReLU())
            )

        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hidden_dims[-1],
                                   out_channels=4,
                                   kernel_size=4,
                                   stride=2, padding=1),
                nn.ReLU()
                ))

        self.decoder = nn.Sequential(*modules)

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        return [result]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        result = self.decoder(z)
        return result

    def forward(self, inputs: Tensor, **kwargs) -> List[Tensor]:
        encoding = self.encode(inputs)[0]
        encoding = encoding.permute(0, 2, 3, 1)
        quantized_inputs, indices, commitment_loss_beta = self.vq_layer(encoding)
        quantized_inputs = quantized_inputs.permute(0, 3, 1, 2)
        return [self.decode(quantized_inputs), inputs, indices, commitment_loss_beta]

    ## !! update codebook_usage

    # def codebook_usage(self, inputs: Tensor, **kwargs) -> List[Tensor]:
    #     encoding = self.encode(inputs)[0]
    #     quantized_hist = self.vq_layer.quantized_latents_hist(encoding)
    #     return quantized_hist



    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        :param args:
        :param kwargs:
        :return:
        """
        recons = args[0]
        inputs = args[1]
        indices = args[2]
        commitment_loss_beta = args[3]

        recons_loss = F.cross_entropy(recons,inputs)

        loss = recons_loss + commitment_loss_beta
        return {'loss': loss,
                'Reconstruction_Loss': recons_loss,
                'commitement Loss':commitment_loss_beta}

    # def sample(self,
    #            num_samples: int,
    #            current_device: Union[int, str], **kwargs) -> Tensor:
    #     raise Warning('VQVAE sampler is not implemented.')

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return (self.forward(x)[0] > 0.5 ) # Since we are dealing with binary image.


In [4]:
x = torch.randn(1,32,32,64)



vq = VectorQuantize(
    dim = 64,
    codebook_size = 512,     # codebook size
    decay = 0.8,             # the exponential moving average decay, lower means the dictionary will change faster
    commitment_weight = 1.   # the weight on the commitment loss
)

quantized, indices, commit_loss = vq(x)

print(x.shape)

torch.Size([1, 32, 32, 64])


In [7]:
print(commit_loss)

tensor([0.9922], grad_fn=<AddBackward0>)


In [5]:
vq.codebook_size

512

In [55]:
# print(indices.view(-1))
print(indices)

tensor([[[377, 192, 133,  ..., 206, 422, 466],
         [268, 479, 319,  ..., 233, 434,  52],
         [ 10, 282, 348,  ..., 146, 133, 414],
         ...,
         [421, 110, 182,  ...,  16, 225, 441],
         [293, 489, 486,  ...,  55, 463, 342],
         [309, 415,  61,  ..., 152, 308, 348]]])


In [56]:
len(np.unique(indices.view(-1).numpy()))

428

In [57]:
print(torch.bincount(indices.view(-1)))

tensor([2, 8, 5, 2, 2, 0, 0, 3, 3, 4, 4, 6, 0, 1, 3, 1, 3, 0, 2, 1, 2, 2, 0, 0,
        2, 2, 1, 2, 2, 3, 1, 2, 2, 0, 2, 2, 3, 9, 1, 4, 0, 0, 3, 2, 1, 1, 2, 2,
        0, 4, 0, 0, 1, 1, 1, 5, 1, 1, 3, 1, 2, 2, 1, 5, 4, 2, 2, 3, 7, 1, 2, 0,
        3, 1, 0, 3, 0, 3, 1, 0, 2, 0, 1, 3, 2, 2, 1, 3, 4, 2, 3, 2, 2, 1, 1, 0,
        0, 2, 2, 1, 1, 6, 1, 0, 3, 3, 1, 2, 2, 3, 3, 4, 1, 1, 3, 0, 0, 1, 2, 2,
        4, 3, 0, 3, 1, 3, 5, 2, 1, 2, 0, 5, 2, 4, 2, 1, 2, 3, 1, 3, 3, 1, 2, 1,
        1, 1, 5, 0, 5, 2, 1, 3, 6, 1, 2, 1, 3, 1, 2, 3, 1, 7, 2, 3, 2, 5, 1, 1,
        0, 6, 0, 1, 1, 2, 1, 1, 2, 2, 1, 0, 1, 2, 2, 1, 1, 1, 4, 1, 3, 1, 2, 1,
        3, 2, 1, 2, 5, 1, 0, 3, 1, 1, 2, 3, 1, 1, 4, 0, 2, 3, 2, 1, 1, 1, 2, 4,
        3, 2, 2, 4, 1, 1, 0, 0, 1, 7, 2, 0, 0, 0, 1, 0, 5, 2, 3, 0, 4, 1, 4, 1,
        1, 3, 8, 1, 5, 1, 0, 3, 1, 3, 1, 1, 2, 6, 0, 0, 3, 1, 3, 4, 5, 3, 2, 0,
        1, 1, 3, 3, 9, 0, 2, 1, 1, 1, 1, 2, 0, 0, 1, 4, 6, 0, 1, 0, 0, 0, 4, 2,
        0, 0, 3, 3, 3, 2, 3, 2, 6, 1, 4,

----------

In [6]:
K =  512 # num_embeddings
D =  64 # embedding_dim
in_channels = 4 
downsampling_factor = 4

In [7]:
# ACDC_VQVAE = VQVAE(in_channels, D, K, downsampling_factor)
model = VQVAE(in_channels, D, K, downsampling_factor)


In [8]:
input_tesor = torch.randn(16, 4, 128, 128)

In [9]:
encoding = model.encode(input_tesor)[0]
encoding.shape

torch.Size([16, 64, 32, 32])

In [10]:
vq = VectorQuantize(
    dim = 64,
    codebook_size = 512,     # codebook size
    decay = 0.8,             # the exponential moving average decay, lower means the dictionary will change faster
    commitment_weight = .25   # the weight on the commitment loss
)

In [11]:
encoding = encoding.permute(0, 2, 3, 1)
print(encoding.shape)
quantized, indices, commit_loss = vq(encoding)

torch.Size([16, 32, 32, 64])


In [12]:
quantized = quantized.permute(0, 3, 1, 2)
quantized.shape

torch.Size([16, 64, 32, 32])

In [13]:
output = model.decode(quantized)
print(output.shape)

torch.Size([16, 4, 128, 128])


---------

In [18]:
input_tesor = torch.randn(16,4,128,128)

In [None]:
output, inputs, indices, commit_loss = model(input_tesor)
# loss = model.loss_function(output, inputs, codebook_loss, commit_loss)['loss']

In [23]:
print(commit_loss)

tensor([0.0006], grad_fn=<AddBackward0>)


In [26]:
model.eval()
output, inputs, indices, commit_loss = model(input_tesor)

SolveValueException: Failed to solve values of expressions. Found contradictory values {16, 1} for equivalent expressions {'1', 'h', '16'}
Input:
    'h [c] d = 1 512 64'
    'h b n = 16 32 32'
    'h b n d = None'
    '1 = 1'
