In [None]:
Project_Root = '/gdrive/MyDrive/CV_Project/'
from google.colab import drive
drive.mount('/gdrive')
%cd -q $Project_Root

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [None]:
!ls

checkpoints    models		   train_hrvae.ipynb  visualization.py
data	       __pycache__	   train_pixelCNN     visualize.ipynb
documents      README.md	   train_vae.ipynb
GetData.ipynb  requirements.txt    train_vqvae.ipynb
images	       residualDataset.py  utils.py


In [None]:
!pip install -r requirements.txt --upgrade



In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from models.decompose import DecomposeVAE
import torch.nn.functional as F

In [None]:
# Load the pretrained VQ-VAE and pull out its parts. Every module is put in eval
# mode because it is only used for inference here (the optimizer below only
# updates the PixelCNN).
weight_path = "checkpoints/save_3_best.pth"
# Fall back to CPU so the notebook still runs on a machine without a GPU.
# NOTE(review): assumes DecomposeVAE maps the checkpoint onto `device` when loading — confirm.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model_container = DecomposeVAE(weight_path=weight_path, device=device)

encoder = model_container.getEncoder().eval()
quantizer = model_container.getQuantizer().eval()
codebook = model_container.getCodeBook()  # (embedding_dim, num_entries); printed as [64, 512] further down
decoder = model_container.getDecoder().eval()
fullvae = model_container.getFullVAE().eval()

In [None]:
data_dir = './data'
transform = torchvision.transforms.ToTensor()

# The PixelCNN prior is trained on the MNIST training split and evaluated on the
# held-out test split (previously both datasets were the test split).
mnist_trainset = datasets.MNIST(root=data_dir, train=True, download=False, transform=transform)
mnist_testset = datasets.MNIST(root=data_dir, train=False, download=False, transform=transform)

trainloader = torch.utils.data.DataLoader(mnist_trainset, batch_size=128, shuffle=True, num_workers=1)
# Evaluation order does not matter, so keep it deterministic.
testloader = torch.utils.data.DataLoader(mnist_testset, batch_size=128, shuffle=False, num_workers=1)

# Collect the first test image of each digit class (stops once all 10 are found).
class_single = {}
for img, label in mnist_testset:
    class_single.setdefault(label, img)  # keep only the first image seen per label
    if len(class_single) == 10:
        break

# One image per class, ordered by label.
single_batch = [class_single[label] for label in sorted(class_single)]

if len(single_batch) < 10:
    print("Not all classes are present")

single_batch = torch.stack(single_batch).to(device)

In [None]:
# Sanity check: encode one image per class and inspect the shapes of the quantized
# latents and of the codebook indices. Inference only, so no autograd graph is built.
with torch.no_grad():
    z = encoder(single_batch)
    (z_quantized, dictionary_loss, commitment_loss, encoding_indices) = quantizer(z)

print(z_quantized.shape)
print(encoding_indices.shape)

torch.Size([10, 64, 7, 7])
torch.Size([10, 49])


In [None]:
from models.pixel_cnn import PixelCNN

In [None]:
EPOCHS = 100
CODEBOOK_SIZE = 512  # PixelCNN output classes; must equal the number of codebook entries (codebook.shape[1])


def latent_codes(images):
    """Return the VQ-VAE codebook indices of `images` as a (B, 1, H, W) long tensor.

    The VQ-VAE is frozen, so its forward pass runs under ``torch.no_grad()``
    instead of building an autograd graph for every batch.
    """
    with torch.no_grad():
        (_, _, _, indices) = quantizer(encoder(images.to(device)))
    # The quantizer returns flattened (B, H*W) indices; the latent grid is assumed square.
    side = int(indices.shape[1] ** 0.5)
    return indices.view(indices.shape[0], 1, side, side).long()


def codes_nll(model, codes):
    """Mean cross-entropy between `model`'s per-position logits and the codes themselves.

    `codes` is a (B, 1, H, W) long tensor; the model is fed the indices as floats.
    """
    logits = model(codes.float())  # (B, CODEBOOK_SIZE, H, W)
    # squeeze(1) drops only the channel axis, so a batch of size 1 keeps its batch dim.
    return F.cross_entropy(logits, codes.squeeze(1))


net = PixelCNN(input_dim=1, hidden_dim=64, output_dim=CODEBOOK_SIZE).to(device)
optimizer = torch.optim.Adam(net.parameters())
best_performance = np.inf
for epoch in range(EPOCHS):
    # Train on the training split.
    net.train()
    err_tr = []
    for images, _ in trainloader:
        loss = codes_nll(net, latent_codes(images))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        err_tr.append(loss.item())

    # Evaluate on the test split; no gradients are needed.
    net.eval()
    err_te = []
    with torch.no_grad():
        for images, _ in testloader:
            err_te.append(codes_nll(net, latent_codes(images)).item())

    mean_tr, mean_te = np.mean(err_tr), np.mean(err_te)

    # Periodic checkpoint, plus a separate copy of the best model by test NLL.
    if epoch % 5 == 0:
        torch.save(net.state_dict(), "checkpoints/pixel_cnn.pth")
    if mean_te < best_performance:
        best_performance = mean_te
        torch.save(net.state_dict(), "checkpoints/best_pixel_cnn.pth")

    print('epoch={}; nll_train={:.7f}; nll_test={:.7f}'.format(epoch, mean_tr, mean_te))


epoch=0; nll_train=4.7882872; nll_test=4.1144854
epoch=1; nll_train=4.0010654; nll_test=3.9343831
epoch=2; nll_train=3.9023382; nll_test=3.8934677
epoch=3; nll_train=3.8647943; nll_test=3.8539963
epoch=4; nll_train=3.8264385; nll_test=3.8467280
epoch=5; nll_train=3.7658563; nll_test=3.7219917
epoch=6; nll_train=3.6973830; nll_test=3.6774906
epoch=7; nll_train=3.6352258; nll_test=3.6778157
epoch=8; nll_train=3.5757275; nll_test=3.5377940
epoch=9; nll_train=3.5302673; nll_test=3.5154070
epoch=10; nll_train=3.4881761; nll_test=3.4486524
epoch=11; nll_train=3.4489390; nll_test=3.4675242
epoch=12; nll_train=3.4094684; nll_test=3.3794234
epoch=13; nll_train=3.3705598; nll_test=3.3477763
epoch=14; nll_train=3.3375274; nll_test=3.3416043
epoch=15; nll_train=3.3138925; nll_test=3.2998762
epoch=16; nll_train=3.2770421; nll_test=3.2609803
epoch=17; nll_train=3.2538407; nll_test=3.2416480
epoch=18; nll_train=3.2316283; nll_test=3.1957616
epoch=19; nll_train=3.2073814; nll_test=3.1721684
epoch=20; 

### Latent-space index generation

In [None]:
# Reload the best checkpoint (lowest test NLL during training) for sampling.
# map_location lets a checkpoint saved on GPU load onto whatever `device` is.
net = PixelCNN(input_dim=1, hidden_dim=64, output_dim=512).to(device)
net.load_state_dict(torch.load("checkpoints/best_pixel_cnn.pth", map_location=device))
net.eval();

PixelCNN(
  (net): Sequential(
    (0): MaskedConv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaskedConv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaskedConv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): MaskedConv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): MaskedConv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (13): BatchNorm2d(64, eps=1e-0

In [None]:
# Autoregressively sample grids of codebook indices from the PixelCNN prior,
# one latent position at a time in raster order.
num_gen = 36          # number of latent grids (hence images) to generate
LATENT_SIZE = 7       # side of the square latent grid: encoding_indices has 7 * 7 = 49 entries per image
torch.manual_seed(0)  # make the generated samples reproducible

sample = torch.zeros(num_gen, 1, LATENT_SIZE, LATENT_SIZE, device=device)
net.eval()
with torch.no_grad():
    for i in range(LATENT_SIZE):
        for j in range(LATENT_SIZE):
            # Logits over the 512 codebook entries at every position: (num_gen, 512, H, W).
            logits = net(sample)
            probs = F.softmax(logits[:, :, i, j], dim=1)
            # Indices are stored as floats because the PixelCNN was trained on target.float().
            sample[:, :, i, j] = torch.multinomial(probs, 1).float()


In [None]:
# Replace each sampled index with its codebook vector and lay the vectors out as a
# (num_gen, embedding_dim, H, W) latent map that the decoder can consume.
print(codebook.shape)
embedding_dim = codebook.shape[0]        # codebook is (embedding_dim, num_entries)
latent_h, latent_w = sample.shape[-2:]   # spatial size of the sampled index grid

sample_viewed = sample.view(sample.shape[0], -1).long()   # (num_gen, H*W) indices, row-major
print(sample_viewed.shape)
sample_after_codebook = codebook[:, sample_viewed]        # (embedding_dim, num_gen, H*W)
print(sample_after_codebook.shape)
sample_after_codebook_reshape = (
    sample_after_codebook.permute(1, 0, 2)                # (num_gen, embedding_dim, H*W)
    .reshape(sample.shape[0], embedding_dim, latent_h, latent_w)
)
print(sample_after_codebook_reshape.shape)

torch.Size([64, 512])
torch.Size([36, 49])
torch.Size([64, 36, 49])
torch.Size([36, 64, 7, 7])


In [None]:
# Decode the sampled latent maps into images. Generation needs no gradients, so
# skip building the autograd graph instead of detaching afterwards.
with torch.no_grad():
    decoded_image = decoder(sample_after_codebook_reshape).cpu()
print(decoded_image.shape)

torch.Size([36, 1, 28, 28])


In [None]:
from utils import save_img_tensors_as_grid

In [None]:
# def show_as_image(binary_image, save, figsize=(10, 5)):
#     plt.figure(figsize=figsize)
#     plt.imshow(binary_image, cmap='gray')
#     if save:
#       plt.savefig(save)
#     plt.xticks([]); plt.yticks([])

# def batch_images_to_one(batches_images):
#     n_square_elements = int(np.sqrt(batches_images.shape[0]))
#     rows_images = np.split(np.squeeze(batches_images), n_square_elements)
#     return np.vstack([np.hstack(row_images) for row_images in rows_images])

In [None]:
# Write all decoded samples to disk as a single image grid.
# NOTE(review): 36 samples with nrows=5 do not tile evenly; confirm whether
# nrows should be 6 (sqrt of num_gen) given save_img_tensors_as_grid's semantics.
grid_path = "images/generated_imgs"
save_img_tensors_as_grid(decoded_image, nrows=5, f=grid_path)

1