<a href="https://colab.research.google.com/github/inderpreetsingh01/PyTorch/blob/main/VQ_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [134]:
import numpy as np
import pickle

In [154]:
# Install the PyDrive wrapper & import libraries.
# This only needs to be done once in a notebook.
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
# This only needs to be done once in a notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
gdrive = GoogleDrive(gauth)

In [133]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [126]:
import torch 
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T

In [8]:
%%capture
!pip install datasets
from datasets import load_dataset

In [146]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [9]:
rvl_cdip_dataset = load_dataset('rvl_cdip', streaming=True)

Downloading builder script:   0%|          | 0.00/4.94k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.64k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/6.14k [00:00<?, ?B/s]

In [81]:
train_data = rvl_cdip_dataset['train']
val_data = rvl_cdip_dataset['validation']
test_data = rvl_cdip_dataset['test']

In [82]:
train_data.features

{'image': Image(decode=True, id=None),
 'label': ClassLabel(names=['letter', 'form', 'email', 'handwritten', 'advertisement', 'scientific report', 'scientific publication', 'specification', 'file folder', 'news article', 'budget', 'invoice', 'presentation', 'questionnaire', 'resume', 'memo'], id=None)}

In [83]:
transforms = T.Compose(
    [   T.Lambda(lambda x:np.resize(x, (224,224))),
        T.Lambda(lambda x:np.expand_dims(x, 2)),
        T.ToTensor()
    ]
)

In [128]:
data = list(train_data.take(1))

In [154]:
transforms(np.array(data[0]['image']))

tensor([[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 0.3922, 0.3922, 0.3922],
         [0.3922, 0.3922, 0.3686,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 0.7843, 0.0000, 0.6549]]])

In [84]:
def transform_func(examples):
    examples["pixel_values"] = [transforms(np.array(image)) for image in examples["image"]]
    return examples

In [85]:
train_data1 = train_data.map(transform_func, remove_columns=["image"], batched=True)
val_data1 = val_data.map(transform_func, remove_columns=["image"], batched=True)
test_data1 = test_data.map(transform_func, remove_columns=["image"], batched=True)

In [86]:
data_instance = next(iter(train_data1))

In [87]:
data_instance['pixel_values'].shape

torch.Size([1, 224, 224])

In [173]:
train_dataloader = DataLoader(train_data1, batch_size=8)
val_dataloader = DataLoader(val_data1, batch_size=10)

In [109]:
from tqdm import tqdm

In [131]:
i=0
pbar = tqdm(train_dataloader, desc='Training', total=320000/16)
for batch in pbar:
  # print(i)
  i=i+1
  pbar.set_postfix({'batch':pbar.n})
  

Training:   0%|          | 0/20000.0 [00:04<?, ?it/s]


KeyboardInterrupt: ignored

In [71]:
class VectorQuantizer(nn.Module):

  def __init__(self, num_embedding=8192, embedding_dim=512, commitment_cost=0.2):
    super(VectorQuantizer, self).__init__()

    self._num_embeddings = num_embedding
    self._embedding_dim = embedding_dim
    self._embedding = nn.Embedding(num_embedding, embedding_dim)
    self._embedding.weight.data.uniform_(-1.0/num_embedding, 1.0/num_embedding)
    self._commitment_cost = commitment_cost

  def forward(self, inputs):

    B, C, H, W = inputs.shape

    assert C == self._embedding_dim, "Number of channels from encoder output not same as embedding dimension"

    inputs = inputs.permute(0, 2, 3, 1).contiguous()
    input_shape = inputs.shape

    # Flatten input
    flat_input = inputs.view(-1, self._embedding_dim)
    
    # Calculate distances
    distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                + torch.sum(self._embedding.weight**2, dim=1)
                - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
        
    # Encoding
    encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
    encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
    encodings.scatter_(1, encoding_indices, 1)
    
    # Quantize and unflatten
    quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

    # Loss
    e_latent_loss = F.mse_loss(quantized.detach(), inputs)
    q_latent_loss = F.mse_loss(quantized, inputs.detach())
    loss = q_latent_loss + self._commitment_cost * e_latent_loss
    
    quantized = inputs + (quantized - inputs).detach()
    avg_probs = torch.mean(encodings, dim=0)
    perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
    
    # convert quantized from BHWC -> BCHW
    return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

In [169]:
x = batch['pixel_values']

In [11]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=in_channels,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=in_channels,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=in_channels,
                      kernel_size=1, stride=1, bias=False)
        )
    
    def forward(self, x):
        return x + self._block(x)

In [12]:
class EncoderLayer(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(EncoderLayer, self).__init__()
    
    self.conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 
                                kernel_size=3, stride=2, padding=1, bias=False)
    self.residual_block = ResidualBlock(out_channels)

  def forward(self, input):
    x = self.conv_layer(input)
    x = self.residual_block(x)
    return x

In [34]:
class Encoder(nn.Module):
    def __init__(self, in_channels=1, out_channels=384):
        super(Encoder, self).__init__()
        assert in_channels == 1, 'Number of Input channels not equal to 1'
        self.layers = nn.ModuleList()
        self.layers.append(EncoderLayer(1, 128))
        self.layers.append(EncoderLayer(128, 256))
        self.layers.append(EncoderLayer(256, 384))

    def forward(self, x):
        for layer in self.layers:
          x = layer(x)
        return x

In [30]:
class DecoderLayer(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(DecoderLayer, self).__init__()
    
    self.conv_layer = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, 
                                kernel_size=4, stride=2, padding=1, bias=False)
    self.residual_block = ResidualBlock(out_channels)

  def forward(self, input):
    x = self.conv_layer(input)
    x = self.residual_block(x)
    return x

In [31]:
class Decoder(nn.Module):
  def __init__(self, in_channels=512, out_channels=1):
    super(Decoder, self).__init__()
    self.layers = nn.ModuleList()
    self.layers.append(DecoderLayer(in_channels, in_channels//2))
    self.layers.append(DecoderLayer(in_channels//2, in_channels//4))
    self.layers.append(DecoderLayer(in_channels//4, 1))

  def forward(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

In [72]:
class Model(nn.Module):
  def __init__(self, enc_in_channels=1, enc_out_channels=384, embedding_dim=512,
               num_embedding=8192, commitment_cost=0.2):
    super(Model, self).__init__()
    self.encoder = Encoder(enc_in_channels, enc_out_channels)
    self.decoder = Decoder(embedding_dim, enc_in_channels)
    self.vqlayer = VectorQuantizer(num_embedding, embedding_dim, commitment_cost)
    self.proj_to_embedding = nn.Conv2d(enc_out_channels, embedding_dim, kernel_size=1, stride=1)

  def forward(self, x):
    B,C,H,W = x.shape

    x = self.encoder(x)
    x = self.proj_to_embedding(x)
    loss, q, perplexity, encodings = self.vqlayer(x)
    x_recon = self.decoder(q)
    return loss, x_recon, perplexity

In [179]:
def train(model, optimizer, history=None):
  model.to(device)
  model.train()
  # print('model device', next(model.parameters()).is_cuda)
  d_model = None

  if history:
    vq_vae_model_history = history
  else:
    vq_vae_model_history = dict(loss_lst=[], recon_loss_lst=[], perplexity_lst=[])

  i = 0
  pbar = tqdm(train_dataloader, desc='Training', total=320000/16)
  for batch in pbar:
    batch = batch['pixel_values']
    # print('device 1', batch.device)
    batch = batch.to(device)
    # print('device 2', batch.device)

    optimizer.zero_grad()
    loss, x_recon, perplexity = model(batch)

    recon_loss = F.mse_loss(batch, x_recon)
    loss_final = loss+recon_loss-perplexity
    loss_final.backward()

    optimizer.step()

    vq_vae_model_history['loss_lst'].append(loss.item())
    vq_vae_model_history['recon_loss_lst'].append(recon_loss.item())
    vq_vae_model_history['perplexity_lst'].append(perplexity.item())

    pbar.set_postfix({'loss':loss_final.item()})

    if pbar.n>500*i:
      i=i+1
      MODEL_PATH = 'vq-vae_doc_model.pth'
      torch.save(model, MODEL_PATH)

      with open('vq_vae_model_history', 'wb') as f:
        pickle.dump(vq_vae_model_history, f)

      if d_model:
        d_model.Delete()
        d_history.Delete()

      # Create & upload a file.
      d_model = gdrive.CreateFile({'vq-vae_doc_model.pth': 'vq-vae_doc_model.pth'})
      d_model.SetContentFile('vq-vae_doc_model.pth')
      d_model.Upload()
      print('\nUploaded file with ID {}'.format(d_model.get('id')))

      
      # Create & upload a file.
      d_history = gdrive.CreateFile({'vq_vae_model_history': 'vq_vae_model_history'})
      d_history.SetContentFile('vq_vae_model_history')
      d_history.Upload()
      print('\nUploaded file with ID {}'.format(d_history.get('id')))

In [156]:
model = Model()

In [136]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [180]:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
f'Number of parameters: {round(pytorch_total_params/10_00_000, 3)}M'

'Number of parameters: 14.11M'

In [None]:
train(model, optimizer)

Training:   0%|          | 1/20000.0 [00:11<63:55:29, 11.51s/it, loss=-2.24]


Uploaded file with ID 1_9_BvAlT-bJppeJNMYlQuhkf-Eb43kMr


Training:   0%|          | 2/20000.0 [00:14<37:02:02,  6.67s/it, loss=-2.24]


Uploaded file with ID 1ysVcwM8wh05hcErWmMlI_nGcKfe0s23K


Training:   3%|▎         | 501/20000.0 [02:58<14:41:36,  2.71s/it, loss=-1.55]


Uploaded file with ID 1eWR3-ycD-1IBK0BCd38RCM5YPUluYpBp


Training:   3%|▎         | 502/20000.0 [03:02<16:52:26,  3.12s/it, loss=-1.55]


Uploaded file with ID 1Y-tE3RqDLn4I6jU9GXwaDED9reuxUyHf


Training:   5%|▌         | 1001/20000.0 [05:46<13:27:37,  2.55s/it, loss=-1.32]


Uploaded file with ID 1yrHb6bHd5WfZuqKRLPlqq6NLNuXZBKai


Training:   5%|▌         | 1002/20000.0 [05:51<18:03:27,  3.42s/it, loss=-1.32]


Uploaded file with ID 18CSKJt_YCqM7GDx26yA3HujxuNYGSccM


Training:   8%|▊         | 1501/20000.0 [08:36<13:20:58,  2.60s/it, loss=-1.38]


Uploaded file with ID 1OAPKfjyeGcwxcpAAiXI1iozFddu_xeSH


Training:   8%|▊         | 1502/20000.0 [08:41<17:26:27,  3.39s/it, loss=-1.38]


Uploaded file with ID 1zjkyL-uolkpxJBXxBuHZbSNn68rpaLo8


Training:  10%|█         | 2001/20000.0 [11:28<15:21:02,  3.07s/it, loss=-1.63]


Uploaded file with ID 1RNCBYWAUmCu4VX158zY6GEomTc2qC_i3


Training:  10%|█         | 2002/20000.0 [11:33<18:05:22,  3.62s/it, loss=-1.63]


Uploaded file with ID 19_dD78OIcjviaFvojBmluTxh7FKkQLTa


Training:  13%|█▎        | 2501/20000.0 [14:19<14:14:48,  2.93s/it, loss=-1.23]


Uploaded file with ID 1XK9WBb5cZS61is_rByMxlemnvnlWI1so


Training:  13%|█▎        | 2502/20000.0 [14:24<17:47:13,  3.66s/it, loss=-1.23]


Uploaded file with ID 1z5z_UgB-e1_9po9n8-Zzp_4BdQ_gxIU8


Training:  15%|█▌        | 3001/20000.0 [17:08<13:10:28,  2.79s/it, loss=-2.44]


Uploaded file with ID 1u-i1Z46oz-1_e9VK29HehiZ9HOAgTCG1


Training:  15%|█▌        | 3002/20000.0 [17:13<15:38:16,  3.31s/it, loss=-2.44]


Uploaded file with ID 1tbuM8YN1g5nBKDTR3qIMdPk8_YbHU1G3


Training:  18%|█▊        | 3501/20000.0 [19:58<12:35:43,  2.75s/it, loss=-1.09]


Uploaded file with ID 1K3zgzh4SgXi7B9C3b4rJWNMgCpCbpbdP


Training:  18%|█▊        | 3502/20000.0 [20:03<15:39:08,  3.42s/it, loss=-1.09]


Uploaded file with ID 1tWci-COZXtD941cmZjK3f9ht-0I8ZW7v


Training:  20%|██        | 4001/20000.0 [22:47<11:43:51,  2.64s/it, loss=-2.21]


Uploaded file with ID 1l4zUZt0WQ_KNJBN2pUWE8H5-pDsJmxZE


Training:  20%|██        | 4002/20000.0 [22:52<14:59:54,  3.38s/it, loss=-2.21]


Uploaded file with ID 1FMH7sJAv7F2n4ePVmqfOOPatwWGhk8ts


Training:  23%|██▎       | 4501/20000.0 [26:00<41:27:59,  9.63s/it, loss=-.605]


Uploaded file with ID 1p9yvdB7MxDtEA1X5oLfNqkhKM7zJKZj1


Training:  23%|██▎       | 4502/20000.0 [26:05<36:11:08,  8.41s/it, loss=-.605]


Uploaded file with ID 1yZ2JVYxvJG_3RRoveuGD_qNyXOlSkn76


Training:  25%|██▌       | 5001/20000.0 [28:50<11:01:10,  2.64s/it, loss=-2.27]


Uploaded file with ID 1svCYmhTKt9rrl6y1OUEGRCAgqrp7KQ7u


Training:  25%|██▌       | 5002/20000.0 [28:55<13:49:55,  3.32s/it, loss=-2.27]


Uploaded file with ID 1knew2RrGVgYz2M0L6XQjF_eys2Gbxaq1


Training:  28%|██▊       | 5501/20000.0 [31:39<11:31:42,  2.86s/it, loss=-4.28]


Uploaded file with ID 1W3Rd3ZbO05C9QEZC5EtJOdFjIjdAyMVr


Training:  28%|██▊       | 5502/20000.0 [31:43<13:20:11,  3.31s/it, loss=-4.28]


Uploaded file with ID 1r09uS9SvOXhGFiH_K6eJwkDBNtuu9eIN


Training:  30%|███       | 6001/20000.0 [34:27<10:15:34,  2.64s/it, loss=-2.1] 


Uploaded file with ID 1PMPbkB8oH1gesHJA59vlYTeDoguawKKL


Training:  30%|███       | 6002/20000.0 [34:32<12:15:17,  3.15s/it, loss=-2.1]


Uploaded file with ID 14zGaT9b7mJFwCSJCOimG9cAWjbk66m9e


Training:  33%|███▎      | 6501/20000.0 [37:18<9:39:05,  2.57s/it, loss=-3.3] 


Uploaded file with ID 1TrNnHRa6ALdbsfGU3JszNXSluPjoye9J


Training:  33%|███▎      | 6502/20000.0 [37:22<11:46:37,  3.14s/it, loss=-3.3]


Uploaded file with ID 1yu_b4GDQIa9Rgb0xoaroiut8JBcRtggp


Training:  35%|███▌      | 7001/20000.0 [40:08<10:23:11,  2.88s/it, loss=-2.54]


Uploaded file with ID 1Q9-Cvx9pa9znCV5i7aJWFAsnCTlJDkkL


Training:  35%|███▌      | 7002/20000.0 [40:12<12:32:06,  3.47s/it, loss=-2.54]


Uploaded file with ID 1mY3UvyAMRGBiGoopqOi3iT3Iocg8e_Py


Training:  38%|███▊      | 7501/20000.0 [42:57<8:56:56,  2.58s/it, loss=-1.78]


Uploaded file with ID 1WHtojuYgd8c__hVG7Yv-NP6Qd00gkIiB


Training:  38%|███▊      | 7502/20000.0 [43:02<11:43:39,  3.38s/it, loss=-1.78]


Uploaded file with ID 1NEC_2Gxrtol8ClBO40R7o_BrAvR2tDFS


Training:  40%|████      | 8001/20000.0 [45:47<9:07:13,  2.74s/it, loss=-1.23]


Uploaded file with ID 1au3kTFbCcGqms7NMeh_rrIVTS2DGO07i


Training:  40%|████      | 8002/20000.0 [45:52<11:17:47,  3.39s/it, loss=-1.23]


Uploaded file with ID 1DoKKfN-5pKvSni87c4J7qiv0RdKSfdLu


Training:  43%|████▎     | 8501/20000.0 [48:36<9:02:13,  2.83s/it, loss=-3.59]


Uploaded file with ID 112-CGHB-ALNcOE91BRzZ6QPTH1w0yM-l


Training:  43%|████▎     | 8502/20000.0 [48:41<11:22:09,  3.56s/it, loss=-3.59]


Uploaded file with ID 13hp2x_O6EvCSLBGxDxrklTt75hmVnuNY


Training:  45%|████▌     | 9001/20000.0 [51:26<8:38:05,  2.83s/it, loss=-1.42]


Uploaded file with ID 1IQCPZtuiXRUsJ21jsFWBfk7gSGIvapXW


Training:  45%|████▌     | 9002/20000.0 [51:31<10:59:09,  3.60s/it, loss=-1.42]


Uploaded file with ID 1vzzVaTps_h_cV3re70DiHWTEq0X05oHb


Training:  48%|████▊     | 9501/20000.0 [54:18<9:39:14,  3.31s/it, loss=-1.31]


Uploaded file with ID 1FbEuFMlCFqxwCg31mWHWc2oqQL7OP5s9


Training:  48%|████▊     | 9502/20000.0 [54:23<11:14:45,  3.86s/it, loss=-1.31]


Uploaded file with ID 1FGn4mG2xAAOhUrEJyAW9Q2EGHratO7tG


Training:  50%|█████     | 10001/20000.0 [57:09<7:36:32,  2.74s/it, loss=-3.9] 


Uploaded file with ID 1MSV1gYst9Snf3EFm2Fjf4xjIayzZN97P


Training:  50%|█████     | 10002/20000.0 [57:13<9:01:51,  3.25s/it, loss=-3.9]


Uploaded file with ID 1QEHaDYJtoxGq91dkxlCV7PMvrOzDAWDL


Training:  53%|█████▎    | 10501/20000.0 [1:00:00<7:15:39,  2.75s/it, loss=-2.7] 


Uploaded file with ID 1aGkXKD7_iVzCSTC3oEmfxlpemp8ghPUz


Training:  53%|█████▎    | 10502/20000.0 [1:00:05<9:09:43,  3.47s/it, loss=-2.7]


Uploaded file with ID 1U_kqu2MmQvRbPPfz9W8AM0zUp0HmFWVC


Training:  55%|█████▌    | 11001/20000.0 [1:02:49<6:18:21,  2.52s/it, loss=-1.38]


Uploaded file with ID 13rwTbAGU_0MI3gZcnXtyiIvf637Pix7M


Training:  55%|█████▌    | 11002/20000.0 [1:02:53<7:52:22,  3.15s/it, loss=-1.38]


Uploaded file with ID 18j0hFo9h5fOEswU9q2W1B1hv3XoM4YH5


Training:  58%|█████▊    | 11501/20000.0 [1:05:37<6:41:30,  2.83s/it, loss=-1.27]


Uploaded file with ID 1UZSt1zAZxVM2Dw0wXy4stj2XLTi3DiJb


Training:  58%|█████▊    | 11502/20000.0 [1:05:42<8:06:22,  3.43s/it, loss=-1.27]


Uploaded file with ID 17hRxSEP4NyOLHbqKTMD69l8JAtmJkELj


Training:  59%|█████▉    | 11875/20000.0 [1:07:37<36:01,  3.76it/s, loss=-1.44]

In [None]:
model