In [None]:
!pip install timm

Collecting timm
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[?25l[K     |▉                               | 10 kB 26.3 MB/s eta 0:00:01[K     |█▊                              | 20 kB 8.9 MB/s eta 0:00:01[K     |██▋                             | 30 kB 7.7 MB/s eta 0:00:01[K     |███▌                            | 40 kB 7.3 MB/s eta 0:00:01[K     |████▍                           | 51 kB 5.2 MB/s eta 0:00:01[K     |█████▏                          | 61 kB 5.6 MB/s eta 0:00:01[K     |██████                          | 71 kB 5.5 MB/s eta 0:00:01[K     |███████                         | 81 kB 6.1 MB/s eta 0:00:01[K     |███████▉                        | 92 kB 4.9 MB/s eta 0:00:01[K     |████████▊                       | 102 kB 5.4 MB/s eta 0:00:01[K     |█████████▋                      | 112 kB 5.4 MB/s eta 0:00:01[K     |██████████▍                     | 122 kB 5.4 MB/s eta 0:00:01[K     |███████████▎                    | 133 kB 5.4 MB/s eta 0:00:01[K     |█

In [None]:
import torch
from torch import nn
from torch.nn import functional as F


import numpy as np
import timm
import cv2
import os
import random
from tqdm.notebook import tqdm

In [None]:
# Mount Google Drive; checkpoints, images and saved latents are all read from /
# written to paths under this mount point in later cells.
from google.colab import drive
drive.mount(mountpoint="/content/gdrive")

Mounted at /content/gdrive


In [None]:
class Encoder(nn.Module):
    """Truncated timm backbone used as the VQ-VAE encoder.

    The whole backbone is kept as a registered submodule, so its parameters
    live in ``state_dict`` under ``backbone.*``. ``forward`` only runs the
    top-level children that remain after dropping the last four.

    Args:
        backbone: timm model name passed to ``timm.create_model``.
    """

    def __init__(self, backbone = 'resnet34'):
        super().__init__()
        self.backbone = timm.create_model(backbone, pretrained = True)
        # Deliberately a plain list, not nn.ModuleList: these layers are already
        # registered through self.backbone, and a ModuleList would add extra
        # ``List.*`` keys to the state_dict.
        # NOTE(review): for resnet34 this presumably keeps the stem through
        # layer2 (128 channels, matching discretize(512, 128)) — verify.
        top_level_children = list(self.backbone.children())
        self.List = top_level_children[:-4]

    def forward(self, X):
        """Apply the retained backbone stages to ``X`` in order."""
        out = X
        for stage in self.List:
            out = stage(out)
        return out


class discretize(nn.Module):
    """Vector-quantization bottleneck (VQ-VAE style).

    Replaces every spatial feature vector of a channels-first feature map with
    its nearest (squared Euclidean) entry in a learned codebook.

    Args:
        n_e: number of codebook entries.
        e_dim: size of each entry; must equal the channel dimension of the
            tensor passed to ``forward``.
        beta: weight of the commitment term in the returned loss.
    """

    def __init__(self, n_e, e_dim, beta=0.25):
        super(discretize, self).__init__()
        self.e_dim = e_dim
        self.n_e = n_e
        self.beta = beta
        self.code_book = nn.Embedding(n_e, e_dim)
        # Not used by forward(); kept so existing attribute access keeps working.
        self.softmax = nn.Softmax(dim=1)
        nn.init.uniform_(self.code_book.weight, -1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, enc):
        """Quantize ``enc`` against the codebook.

        Args:
            enc: 4-D tensor (B, e_dim, H, W).

        Returns:
            latent_reps: (B, e_dim, H, W) quantized tensor. Its value equals
                the selected codes, and gradients pass straight through to
                ``enc``.
            codebook_loss: scalar, codebook term + beta * commitment term.
            perplexity: scalar, exp of the entropy of code usage in the batch.
            min_encoding_ids: (B*H*W, 1) int64 code indices, rows ordered by
                (b, h, w).
        """
        enc = enc.permute(0, 2, 3, 1).contiguous()      # (B, H, W, C)
        enc_flattened = enc.view(-1, self.e_dim)         # (B*H*W, C)
        weight = self.code_book.weight                   # (n_e, C)

        # ||z||^2 + ||e||^2 - 2 z.e for every (vector, code) pair -> (B*H*W, n_e).
        distances = (torch.sum(enc_flattened ** 2, dim=1, keepdim=True)
                     + torch.sum(weight ** 2, dim=1)
                     - 2 * torch.matmul(enc_flattened, weight.t()))

        min_encoding_ids = torch.argmin(distances, dim=1).unsqueeze(1)
        # One-hot rows select codebook entries via matmul so gradients reach the codebook.
        min_encodings_mask = F.one_hot(min_encoding_ids.squeeze(1), self.n_e).to(weight.dtype)
        latent_reps = torch.matmul(min_encodings_mask, weight).view(enc.shape)

        # VQ-VAE objective (van den Oord et al., 2017):
        #   ||sg[z_e] - e||^2          -> trains the codebook (weight 1)
        #   beta * ||z_e - sg[e]||^2   -> commitment, trains the encoder
        # The detach() placement previously had these two weights swapped.
        # The forward value is unchanged; only the gradients differ.
        codebook_loss = (F.mse_loss(latent_reps, enc.detach())
                         + self.beta * F.mse_loss(latent_reps.detach(), enc))

        # Straight-through estimator: forward value = codes, gradient copied to enc.
        latent_reps = enc + (latent_reps - enc).detach()

        e_mean = torch.mean(min_encodings_mask, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        latent_reps = latent_reps.permute(0, 3, 1, 2).contiguous()
        return latent_reps, codebook_loss, perplexity, min_encoding_ids


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# Rebuild the encoder and quantizer and restore their trained weights.
# load_state_dict is strict by default, so these constructor arguments must
# match the shapes stored in the checkpoints.
# NOTE: Encoder() also downloads pretrained resnet34 weights via timm (see the
# log below); they are immediately overwritten by encoder.pth.
CHECKPOINT_DIR = "/content/gdrive/My Drive/vqvae"
# .eval(): both modules are only used for inference in this notebook. The
# ResNet backbone presumably contains BatchNorm layers, which in train mode
# would normalise with per-image batch statistics and keep updating their
# running averages.
encoder = Encoder().to(DEVICE).eval()
Discrete = discretize(512, 128).to(DEVICE).eval()
encoder.load_state_dict(torch.load(CHECKPOINT_DIR + "/encoder.pth", map_location = DEVICE))
Discrete.load_state_dict(torch.load(CHECKPOINT_DIR + "/discrete.pth", map_location = DEVICE))

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /root/.cache/torch/hub/checkpoints/resnet34-43635321.pth


<All keys matched successfully>

In [None]:
# Collect up to MAX_IMAGES image paths from the "birds/train" split.
# Listings are sorted because os.listdir order is arbitrary; without sorting,
# the selected subset (and hence the reported sizes) could differ between runs.
MAX_IMAGES = 6000
root_pth = "/content/gdrive/MyDrive/vqvae"
# The other two splits are listed but not referenced by any later cell shown here.
birds285_folders = os.listdir(root_pth + "/285 birds/train")
birds_folders = os.listdir(root_pth + "/birds/train")
birdsrev2_folders = os.listdir(root_pth + "/birds_rev2/train")

img_pths = []
for folder in sorted(birds_folders):
    folder_pth = root_pth + "/birds/train/" + folder
    for f in sorted(os.listdir(folder_pth)):
        if len(img_pths) >= MAX_IMAGES:
            break
        img_pths.append(folder_pth + "/" + f)
    if len(img_pths) >= MAX_IMAGES:
        break
# Kept for compatibility with the previous cell contents.
count = len(img_pths)

In [None]:
# Encode each selected image into codebook indices, save them as .npy, and
# accumulate the on-disk sizes of the originals and of the saved index files.
LATENT_DIR = "/content/gdrive/My Drive/vqvae/latent_data"
os.makedirs(LATENT_DIR, exist_ok=True)
# Inference only: eval() puts normalisation/dropout layers in inference mode,
# and no_grad() (below) avoids building an autograd graph for every image.
encoder.eval()
Discrete.eval()

imgs_size_in_bytes = 0
latent_size_in_bytes = 0
with torch.no_grad():
    for img_pth in tqdm(img_pths):
        imgs_size_in_bytes += os.path.getsize(img_pth)
        bgr = cv2.imread(img_pth)
        if bgr is None:
            # cv2.imread returns None instead of raising for unreadable files.
            raise ValueError("could not read image: " + img_pth)
        # (H, W, C) uint8 -> (1, C, H, W) float in [0, 1]. cv2 yields BGR channel
        # order; kept as-is, presumably matching how the encoder was trained.
        img = np.moveaxis(bgr, 2, 0)[np.newaxis, :, :, :]
        img = torch.tensor(img, dtype = torch.float).to(DEVICE)
        img = img / 255.0
        enc = encoder(img)
        d_latent, _, _, ids = Discrete(enc)
        # NOTE: argmin indices are int64, i.e. 8 bytes each on disk; changing the
        # dtype would change the file format any downstream reader expects.
        ids = ids.cpu().numpy()
        id_file_name = img_pth.split("/")[-2] + "_" + img_pth.split("/")[-1].split(".")[0]
        latent_pth = LATENT_DIR + "/" + id_file_name + ".npy"
        np.save(latent_pth, ids)
        latent_size_in_bytes += os.path.getsize(latent_pth)


  0%|          | 0/6000 [00:00<?, ?it/s]

AttributeError: ignored

In [None]:
# Raw byte totals accumulated by the encoding loop.
print(f"size of original dataset : {imgs_size_in_bytes}")
print(f"size of compressed dataset : {latent_size_in_bytes}")

size of original dataset : 127526920
size of compressed dataset : 38400000


In [None]:
# Bytes of the original images per byte of saved indices (higher is better).
# The latent size includes each .npy file's header, not just the index data.
if latent_size_in_bytes == 0:
    # Happens when the encoding loop failed before saving anything.
    raise RuntimeError("no latent files were written; re-run the encoding cell first")
compression_ratio = imgs_size_in_bytes / latent_size_in_bytes
print(compression_ratio)

3.321013541666667
