### Visualization

#### Setup Environment

In [1]:
# Project folder on the mounted Google Drive; the next cell changes the working directory to it.
Project_Root = "/gdrive/MyDrive/CV_Project/"

In [2]:
# Colab only: mount Google Drive at /gdrive so the project files are reachable,
# then make the project folder the working directory (-q suppresses the path echo).
# Relative paths used below (checkpoints/, data/, images/) resolve against it.
from google.colab import drive
drive.mount('/gdrive')
%cd -q $Project_Root

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [3]:
# Sanity check: list the working directory to confirm we are in the project root.
!ls

checkpoints   documents      __pycache__       train.ipynb	vqvae.py
data	      GetData.ipynb  README.md	       utils.py
decompose.py  images	     requirements.txt  visualize.ipynb


In [4]:
!pip install -r requirements.txt --upgrade



In [5]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from decompose import DecomposeVAE
import torchvision.datasets as datasets
from utils import computeResidual, save_img_tensors_as_grid
import os

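#### Visualize the reconstruction residuals

Reconstruct one test image per MNIST class with the full VAE, then save a grid of the residuals alongside a grid of the ground-truth inputs.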

In [6]:
# Trained checkpoint, relative to the project root (the working directory).
weight_path = "checkpoints/save_2_best.pth"
# Prefer the first GPU, but fall back to the CPU so the notebook still runs on a
# CPU-only runtime instead of failing on an unavailable "cuda:0" device.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model_container = DecomposeVAE(weight_path=weight_path, device=device)
# NOTE(review): presumably returns the complete (non-decomposed) VAE module; confirm in decompose.py.
fullvae = model_container.getFullVAE()

In [7]:
data_dir = './data'

transform = torchvision.transforms.ToTensor()
mnist_testset = datasets.MNIST(root=data_dir, train=False, download=False, transform=transform)
# Number of digit classes, read from the dataset rather than hard-coded.
num_classes = len(mnist_testset.classes)

# Get 1 image from every class: keep the first test image seen for each label
# and stop iterating as soon as every class has been found.
class_single = {}
for img, label in mnist_testset:
    if label not in class_single:
        class_single[label] = img
    if len(class_single) == num_classes:
        break

# Sort by label so batch index i holds digit class i (when all classes are present).
sorted_list = sorted(class_single.items())
single_batch = [item for _, item in sorted_list]

missing_labels = sorted(set(range(num_classes)).difference(class_single))
if missing_labels:
    print(f"Not all classes are present; missing labels: {missing_labels}")
if not single_batch:
    # torch.stack would otherwise fail on an empty list with a less helpful message.
    raise RuntimeError("No images were loaded from the MNIST test set")

single_batch = torch.stack(single_batch).to(device)

In [8]:
residual_save = "images/"

# exist_ok keeps the cell idempotent on re-run and avoids the isdir/makedirs race.
os.makedirs(residual_save, exist_ok=True)

# eval() disables training-only layer behavior (if the model has any); no_grad()
# skips autograd bookkeeping since we only run inference here.
fullvae.eval()
with torch.no_grad():
    pred = fullvae(single_batch)["x_recon"]
    # Residual between the reconstruction and the input batch.
    residual = computeResidual(pred, single_batch)
    print(residual.shape)
    print(single_batch.shape)
    # NOTE(review): the `1` is presumably a grid-layout argument of
    # save_img_tensors_as_grid; confirm its meaning in utils.py.
    # os.path.join yields the same "images/residual" paths without relying on the trailing "/".
    save_img_tensors_as_grid(residual, 1, os.path.join(residual_save, "residual"))
    save_img_tensors_as_grid(single_batch, 1, os.path.join(residual_save, "residual_ground_truth"))

torch.Size([10, 1, 28, 28])
torch.Size([10, 1, 28, 28])


#### Visualize the Latent Space Residual