In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

import torchvision.transforms.functional as TF
import torchvision.transforms as T

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
# from scipy.signal import savgol_filter

In [3]:
class DogDataset(torch.utils.data.Dataset):
    """Dataset of fixed-size square crops built from every file under a folder.

    Each item is read with OpenCV, resized so that its shorter side equals
    ``target_img_size`` (aspect ratio preserved), centre-cropped to a
    ``target_img_size x target_img_size`` square and finally passed through
    the optional ``transform``.

    Args:
        data_folder: Root directory. It is walked recursively and every file
            found is treated as an image.
        transform: Optional callable applied to the cropped ``H x W x C``
            array. When omitted the raw cropped array is returned.
        target_img_size: Side length, in pixels, of the square crop.
    """

    def __init__(self, data_folder, transform=None, target_img_size=256) -> None:
        super().__init__()

        self._img_paths = [
            os.path.join(path, name)
            for path, _subdirs, files in os.walk(data_folder)
            for name in files
        ]
        # os.walk yields entries in filesystem order; sorting keeps the
        # index -> file mapping stable across machines and kernel restarts.
        self._img_paths.sort()

        self._target_img_size = target_img_size
        self._transform = transform

    def __len__(self):
        """Return the number of files discovered under ``data_folder``."""
        return len(self._img_paths)

    def __getitem__(self, index):
        """Load, resize, centre-crop and (optionally) transform one image.

        Raises:
            OSError: If OpenCV cannot decode the file at ``index``.
        """
        selected_img_path = self._img_paths[index]

        img = cv2.imread(selected_img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image file: {selected_img_path}")
        # NOTE(review): OpenCV decodes to BGR channel order; confirm that the
        # normalisation statistics and any plotting code expect BGR.

        size = self._target_img_size
        h, w = img.shape[:2]

        # Scale so that the shorter side becomes exactly `size`.
        scale = size / min(h, w)
        new_w, new_h = round(scale * w), round(scale * h)

        # cv2.resize expects dsize as (width, height), not (height, width).
        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)

        # Center crop
        h, w = img.shape[:2]
        top = (h - size) // 2
        left = (w - size) // 2
        crop_img = img[top:top + size, left:left + size]

        if self._transform is not None:
            return self._transform(crop_img)
        return crop_img


In [4]:
# Root of the image folder. Set the DOG_DATA_DIR environment variable to point
# the notebook at another location; the original path is kept as the fallback.
data_folder = os.environ.get(
    "DOG_DATA_DIR", "/home/jaswant/Documents/DiscreteVAE/data/Images"
)

# Per-channel statistics handed to T.Normalize below.
# NOTE(review): presumably measured on this dataset — confirm their provenance
# and that the channel order matches what DogDataset returns.
mean_tensor = (0.4360, 0.4408, 0.4332)
std_tensor = (0.2619, 0.2639, 0.2616)

# Dataset whose items are converted to tensors and then normalised per channel.
dog_data_normalized = DogDataset(
    data_folder=data_folder,
    transform=T.Compose(
        [
            T.ToTensor(),
            T.Normalize(mean=mean_tensor, std=std_tensor),
        ]
    ),
)

In [5]:
# Sanity check: fetch a single item from the normalised dataset and display
# its shape as the cell output.
sample_index = 178  # arbitrary example index
sample_dog = dog_data_normalized[sample_index]
sample_dog.shape

torch.Size([3, 256, 256])

In [None]:
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantisation layer with a learnable codebook.

    Every spatial position of a ``(B, C, H, W)`` input is replaced by the
    closest (squared Euclidean distance) row of the codebook. Gradients are
    passed straight through the quantisation step to the input.

    Args:
        num_embeddings: Number of codebook vectors.
        embedding_dim: Length of each codebook vector; must equal the channel
            dimension ``C`` of the inputs.
        commitment_cost: Weight of the term that pulls the inputs towards
            their (detached) codebook vectors.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        nn.init.uniform_(
            self._embedding.weight,
            -1 / self._num_embeddings,
            1 / self._num_embeddings,
        )
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        """Quantise ``inputs`` against the codebook.

        Args:
            inputs: Tensor of shape ``(B, embedding_dim, H, W)``.

        Returns:
            Tuple ``(loss, quantized, perplexity, encodings)`` where ``loss``
            is the codebook loss plus ``commitment_cost`` times the commitment
            loss, ``quantized`` has the same ``(B, C, H, W)`` shape as
            ``inputs``, ``perplexity`` is the exponential of the entropy of
            the average code usage, and ``encodings`` is the
            ``(B*H*W, num_embeddings)`` one-hot assignment matrix.

        Raises:
            ValueError: If ``inputs`` is not 4-D or its channel dimension
                differs from ``embedding_dim``.
        """
        # Without this check `view(-1, embedding_dim)` would silently regroup
        # channels whenever the element count happens to be divisible.
        if inputs.dim() != 4 or inputs.shape[1] != self._embedding_dim:
            raise ValueError(
                f"Expected input of shape (B, {self._embedding_dim}, H, W), "
                f"got {tuple(inputs.shape)}"
            )

        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)
        codebook = self._embedding.weight

        # Calculate distances: |x|^2 + |e|^2 - 2 x.e
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(codebook**2, dim=1)
                     - 2 * torch.matmul(flat_input, codebook.t()))

        # Encoding: one-hot row per flattened position. The dtype follows the
        # codebook so the matmul below also works for double/half modules.
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(
            encoding_indices.shape[0],
            self._num_embeddings,
            device=inputs.device,
            dtype=codebook.dtype,
        )
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, codebook).view(input_shape)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # moves inputs
        q_latent_loss = F.mse_loss(quantized, inputs.detach())  # moves codebook
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the codebook vector,
        # backward gradient flows to `inputs` unchanged.
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings