In [120]:
Project_Root = '/gdrive/MyDrive/CV_Project/'
from google.colab import drive
drive.mount('/gdrive')
%cd -q $Project_Root

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [2]:
!ls

checkpoints		     metric.ipynb	 train_pixelCNN
data			     models		 train_vae.ipynb
documents		     __pycache__	 train_vqvae.ipynb
GetData.ipynb		     README.md		 utils.py
hierachical_vae_train.ipynb  requirements.txt	 visualization.py
hierarchical_train_pixelCNN  residualDataset.py  visualize.ipynb
images			     train_hrvae.ipynb


In [70]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
from skimage import data, img_as_float
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error
import os
from tqdm import tqdm
from models.vae import VanillaVAE

In [106]:
BATCH_SIZE = 64
transform = torchvision.transforms.ToTensor()

data_dir = './data'

mnist_testset = datasets.MNIST(root=data_dir, train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(mnist_testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
print(len(mnist_testset))

10000


In [5]:
class Config:
    # --- Paths ---
    WEIGHT_PATH = "checkpoints/save_3_best.pth"
    DATA_DIR = "./data"
    OUTPUT_DIR = "./checkpoints"

    # --- Training Hyperparameters ---
    DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 64
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 10
    LOG_INTERVAL = 100

    # --- Model Configuration ---
    NUM_HIERARCHY_LAYERS = 2

    # --- Loss Weights (Adjust to tune training) ---
    LOSS_WEIGHT_SMOOTHNESS = 1.0
    LOSS_WEIGHT_RESIDUAL = 0.1

class ResidualLatentUNet(nn.Module):
    def __init__(self, model_container, device="cpu", num_layers=2):
        super().__init__()
        self.device = torch.device(device)
        self.fullvae = model_container.getFullVAE().to(self.device)
        self.fullvae.eval()

        # Freeze VAE parameters
        for param in self.fullvae.parameters():
            param.requires_grad = False

        # Get latent dimensions from VAE instance
        with torch.no_grad():
            dummy = torch.zeros(1, 1, 28, 28, device=self.device)
            zq, *_ = self.fullvae.quantize(dummy)
            _, latent_ch, latent_h, latent_w = zq.shape

        in_ch = latent_ch * 2  # concatenated (image + residual)
        print(f"  U-Net Input Latent Size: {latent_h}x{latent_w}, {in_ch} channels")

        # --- Encoder Path (Compression) ---
        # Enc1: 7x7 -> 4x4 (Skip 1)
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_ch, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )

        # Enc2: 4x4 -> 2x2 (Skip 2)
        self.enc2 = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )

        # --- Bottleneck ---
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 1024, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 512, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        # --- Decoder Path (Decompression & Skip) ---

        # Dec2 (Innermost): Upsample (512ch) + Skip (512ch) -> 256ch
        self.up2 = nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(512 + 512, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        # Dec1 (Outermost): Upsample (256ch) + Skip (256ch) -> 128ch
        self.up1 = nn.ConvTranspose2d(256, 256, 2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(256 + 256, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        # Final projection: 128ch -> 128ch (in_ch)
        self.final = nn.Conv2d(128, in_ch, 1)

    def forward(self, img_tensor):
        img_tensor = img_tensor.to(self.device)

        with torch.no_grad():
            # Get VAE latents and reconstruction
            zq_img, *_ = self.fullvae.quantize(img_tensor)
            recon = self.fullvae.decoder(zq_img)
            residual = img_tensor - recon

            # Get residual latent
            zq_res, *_ = self.fullvae.quantize(residual)

            # Align spatial dims (Safety check, should match)
            if zq_img.shape[-2:] != zq_res.shape[-2:]:
                zq_res = F.interpolate(zq_res, size=zq_img.shape[-2:], mode='nearest')

        # U-Net Input (7x7, 128ch)
        z_concat = torch.cat([zq_img, zq_res], dim=1)

        # Encoder path
        e1 = self.enc1(z_concat)    # Skip 1 (e1, expected 4x4)
        e2 = self.enc2(e1)          # Skip 2 (e2, expected 2x2)

        # Bottleneck
        b = self.bottleneck(e2)     # (expected 2x2)

        # Decoder 2 (Skip: e2)
        d2_up = self.up2(b)         # Upconvolution (expected 4x4)

        # Interpolation check for Dec 2: Target size is e2.shape (expected 2x2)
        if d2_up.shape[2:] != e2.shape[2:]:
            # Resize upsampled feature to match the skip connection
            d2_up = F.interpolate(d2_up, size=e2.shape[2:], mode='nearest')
        d2 = self.dec2(torch.cat([d2_up, e2], dim=1))

        # Decoder 1 (Skip: e1)
        d1_up = self.up1(d2)        # Upconvolution (expected 8x8 or 6x6)

        # Interpolation check for Dec 1: Target size is z_concat.shape (expected 7x7)
        if d1_up.shape[2:] != z_concat.shape[2:]:
            d1_up = F.interpolate(d1_up, size=z_concat.shape[2:], mode='nearest') # This forces d1_up to 7x7

        # --- FIX: SPATIAL MISMATCH RESOLUTION ---
        # The error occurs because e1 (e.g., 4x4 or 3x3) does not match d1_up (7x7).
        # We must resize e1 to match the target size of d1_up (which is 7x7).
        if d1_up.shape[2:] != e1.shape[2:]:
            e1_resized = F.interpolate(e1, size=d1_up.shape[2:], mode='nearest')
        else:
            e1_resized = e1
        # ----------------------------------------

        d1 = self.dec1(torch.cat([d1_up, e1_resized], dim=1))

        # Final projection
        z_refined = self.final(d1)

        return {
            "z_image": zq_img,
            "z_residual": zq_res,
            "z_concat": z_concat,
            "z_refined": z_refined,
            "recon": recon,
            "residual": residual,
        }

In [6]:
from models.vqvae import VQVAE
from models.decompose import DecomposeVAE

In [7]:
image_vae_path = "checkpoints/save_3_best.pth"
device = "cuda:0"

model_container = DecomposeVAE(weight_path=image_vae_path, device = device) # this runs in the datsaet so should be on cpu
fullvae = model_container.getFullVAE()

In [64]:
## Hierarchical Residual VAE
config = Config()
model_container = DecomposeVAE(config.WEIGHT_PATH, config.DEVICE)
hrvae = ResidualLatentUNet(model_container=model_container, device=config.DEVICE).to(config.DEVICE)
state_dict= torch.load("checkpoints/final_model.pth")["model_state_dict"]
hrvae.load_state_dict(state_dict)
hrvae.eval()

  U-Net Input Latent Size: 7x7, 128 channels


ResidualLatentUNet(
  (fullvae): VQVAE(
    (encoder): Encoder(
      (conv): Sequential(
        (down0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (relu0): ReLU()
        (down1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (relu1): ReLU()
        (final_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (residual_stack): ResidualStack(
        (layers): ModuleList(
          (0-1): 2 x Sequential(
            (0): ReLU()
            (1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU()
            (3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
          )
        )
      )
    )
    (pre_vq_conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
    (vq): VectorQuantizer(
      (N_i_ts): SonnetExponentialMovingAverage()
      (m_i_ts): SonnetExponentialMovingAverage()
    )
    (decoder): Decoder(
      (conv): Conv2d(64, 128, kerne

In [75]:
checkpoint = torch.load("checkpoints/vae_save1_best.pth")
vae = VanillaVAE(in_channels = 1, latent_dim = 128).to(device)
vae.load_state_dict(checkpoint["model_state_dict"])
vae.eval()

VanillaVAE(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
  )
  (fc_m

### Reconstruction Metrics

In [54]:
from skimage.metrics import peak_signal_noise_ratio as psnr

In [81]:
def metrics(model, testloader, device="cuda:0"):
    mse_hist = []
    ssim_hist = []
    psnr_hist = []

    for item in tqdm(testloader):
        img, label = item

        if model == "VQ-VAE":
            pred = fullvae(img.to(device))["x_recon"]
        elif model == "VAE":
            pred, _, mu, logvar = vae(img.to(device))
        else:
            pred = hrvae(img.to(device))["recon"]

        pred_numpy = pred.cpu().detach().squeeze().numpy()
        img_numpy = img.squeeze().numpy()

        mse = mean_squared_error(img_numpy.reshape(-1), pred_numpy.reshape(-1))
        mse_hist.append(mse)

        batch_ssim = []
        batch_psnr = []
        for i in range(len(img)):
            ssim_val = ssim(img_numpy[i], pred_numpy[i],
                          data_range=1.0,
                          win_size=7,
                          gaussian_weights=True)
            batch_ssim.append(ssim_val)
            batch_psnr.append(psnr(img_numpy[i], pred_numpy[i]))

        ssim_hist.append(np.mean(batch_ssim))
        psnr_hist.append(np.mean(batch_psnr))

    return mse_hist, ssim_hist, psnr_hist

In [86]:
mse_hist, ssim_hist, psnr_hist = metrics("VQ-VAE", testloader, device = "cuda:0")

100%|██████████| 10000/10000 [00:56<00:00, 176.96it/s]


In [87]:
mse_mn1, mse_stdd1 = np.mean(mse_hist), np.std(mse_hist)
print(f"MSE: {mse_mn1} +- {mse_stdd1}")
ssim_mn1, ssim_stdd1 = np.mean(ssim_hist), np.std(ssim_hist)
print(f"SSIM: {ssim_mn1} +- {ssim_stdd1}")
psnr_mn1, psnr_stdd1 = np.mean(psnr_hist), np.std(psnr_hist)
print(f"PSNR: {psnr_mn1} +- {psnr_stdd1}")

MSE: 0.0016477555109603325 +- 0.0007014837833603243
SSIM: 0.9558864325842032 +- 0.02242286187122203
PSNR: 50.47285651026326 +- 1.6234415458649374


In [88]:
mse_hist1, ssim_hist1, psnr_hist1 = metrics("HR-VAE", testloader, device = "cuda:0")

100%|██████████| 10000/10000 [01:29<00:00, 111.75it/s]


In [89]:
mse_mn2, mse_stdd2 = np.mean(mse_hist1), np.std(mse_hist1)
print(f"MSE: {mse_mn2} +- {mse_stdd2}")
ssim_mn2, ssim_stdd2 = np.mean(ssim_hist1), np.std(ssim_hist1)
print(f"SSIM: {ssim_mn2} +- {ssim_stdd2}")
psnr_mn2, psnr_stdd2 = np.mean(psnr_hist1), np.std(psnr_hist1)
print(f"PSNR: {psnr_mn2} +- {psnr_stdd2}")

MSE: 0.0014575542728746899 +- 0.0006315229046148367
SSIM: 0.95929028042509 +- 0.02058485633848428
PSNR: 50.70523145790703 +- 1.5850760627531841


In [78]:
BATCH_SIZE = 1
transform = torchvision.transforms.Compose([
    torchvision.transforms.Pad(2),
    torchvision.transforms.ToTensor()
])

data_dir = './data'

mnist_testset = datasets.MNIST(root=data_dir, train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(mnist_testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
print(len(mnist_testset))

10000


In [83]:
mse_hist2, ssim_hist2, psnr_hist2 = metrics("VAE", testloader, device = "cuda:0")

100%|██████████| 10000/10000 [00:47<00:00, 210.89it/s]


In [84]:
mse_mn3, mse_stdd3 = np.mean(mse_hist2), np.std(mse_hist2)
print(f"MSE: {mse_mn3} +- {mse_stdd3}")
ssim_mn3, ssim_stdd3 = np.mean(ssim_hist2), np.std(ssim_hist2)
print(f"SSIM: {ssim_mn3} +- {ssim_stdd3}")
psnr_mn3, psnr_stdd3 = np.mean(psnr_hist2), np.std(psnr_hist2)
print(f"PSNR: {psnr_mn3} +- {psnr_stdd3}")

MSE: 0.002700993245734508 +- 0.0011634623116972328
SSIM: 0.9462470612784327 +- 0.03392890303367568
PSNR: 48.941994357349536 +- 3.03701663739415


### Generaton Metrics


In [136]:
!pip install torch-fidelity
!pip install prdc

Collecting prdc
  Downloading prdc-0.2-py3-none-any.whl.metadata (8.7 kB)
Downloading prdc-0.2-py3-none-any.whl (6.0 kB)
Installing collected packages: prdc
Successfully installed prdc-0.2


In [138]:
from models.pixel_cnn import PixelCNN
from torch.autograd import Variable
from torchvision import transforms
from prdc import compute_prdc
import torch_fidelity

net = PixelCNN(input_dim=1,hidden_dim=64,output_dim=512).to(device)
net.load_state_dict(torch.load("checkpoints/best_pixel_cnn.pth"))
net.eval()

codebook = model_container.getCodeBook()
decoder = model_container.getDecoder().eval()

In [146]:
class VQVAEGenerator(torch.nn.Module):
    def __init__(self, net, codebook, decoder, device='cuda'):
        super().__init__()
        self.net = net
        self.codebook = codebook
        self.decoder = decoder
        self.device = device

    def forward(self, batch_size = 64):
        num_gen = batch_size
        sample = torch.zeros(num_gen, 1, 7, 7, device=self.device)
        self.net.eval()
        with torch.no_grad():
            for i in range(7):
                for j in range(7):
                    out = self.net(Variable(sample))
                    probs = F.softmax(out[:, :, i, j], dim=1)
                    sample[:, :, i, j] = torch.multinomial(probs, 1).float()

        sample_viewed = sample.view(sample.shape[0], -1).long()
        sample_after_codebook = self.codebook[:, sample_viewed]
        sample_after_codebook_reshape = sample_after_codebook.permute((1, 0, 2)).reshape(num_gen, 64, 7, 7)
        decoded_image = self.decoder(sample_after_codebook_reshape)
        return decoded_image

generator = VQVAEGenerator(net, codebook, decoder, device=device)

In [151]:
prec = []
recall = []
for real, _ in testloader:
  generated_image = generator()

  real = (real - real.min()) / (real.max() - real.min())
  generated_image = (generated_image - generated_image.min()) / (generated_image.max() - generated_image.min())

  real_flat = real.view(real.shape[0], -1).cpu().numpy()
  gen_flat = generated_image.view(generated_image.shape[0], -1).cpu().numpy()

  prdc_metrics = compute_prdc(real_features=real_flat, fake_features=gen_flat, nearest_k=5)
  recall.append(prdc_metrics["precision"].item())
  prec.append(prdc_metrics["recall"].item())

Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64 Num fake: 64
Num real: 64

In [152]:
print(np.mean(recall), np.mean(prec))
print(np.std(recall), np.std(prec))

0.9716361464968153 0.12161624203821655
0.06276867153091269 0.06871738675594134


In [153]:
net1 = PixelCNN(input_dim=1,hidden_dim=64,output_dim=256).to(device)
net1.load_state_dict(torch.load("checkpoints/hier_best_pixel_cnn.pth"))
net1.eval()

PixelCNN(
  (net): Sequential(
    (0): MaskedConv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaskedConv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaskedConv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): MaskedConv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): MaskedConv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (13): BatchNorm2d(64, eps=1e-0

In [155]:
def genHRVAE(batch_size = 64):
  num_gen = batch_size
  sample = torch.zeros(num_gen, 1, 28, 28).to(device)
  sample.fill_(0)
  net.train(False)
  with torch.no_grad():
    for i in range(28):
        for j in range(28):
            out = net1(Variable(sample, volatile=True))
            probs = F.softmax(out[:, :, i, j]).data
            sample[:, :, i, j] = torch.multinomial(probs, 1).float()
  return sample

In [158]:
prec1 = []
recall1 = []
for idx, (real, _) in enumerate(testloader):
  if(idx>10):
    continue
  generated_image = genHRVAE()

  real = (real - real.min()) / (real.max() - real.min())
  generated_image = (generated_image - generated_image.min()) / (generated_image.max() - generated_image.min())

  real_flat = real.view(real.shape[0], -1).cpu().numpy()
  gen_flat = generated_image.view(generated_image.shape[0], -1).cpu().numpy()

  prdc_metrics = compute_prdc(real_features=real_flat, fake_features=gen_flat, nearest_k=5)
  recall1.append(prdc_metrics["precision"].item())
  prec1.append(prdc_metrics["recall"].item())

  out = net1(Variable(sample, volatile=True))
  probs = F.softmax(out[:, :, i, j]).data


Num real: 64 Num fake: 64


  out = net1(Variable(sample, volatile=True))
  probs = F.softmax(out[:, :, i, j]).data


Num real: 64 Num fake: 64


  out = net1(Variable(sample, volatile=True))
  probs = F.softmax(out[:, :, i, j]).data


Num real: 64 Num fake: 64


  out = net1(Variable(sample, volatile=True))
  probs = F.softmax(out[:, :, i, j]).data


Num real: 64 Num fake: 64


  out = net1(Variable(sample, volatile=True))
  probs = F.softmax(out[:, :, i, j]).data


Num real: 64 Num fake: 64


  out = net1(Variable(sample, volatile=True))
  probs = F.softmax(out[:, :, i, j]).data


Num real: 64 Num fake: 64


  out = net1(Variable(sample, volatile=True))
  probs = F.softmax(out[:, :, i, j]).data


Num real: 64 Num fake: 64


  out = net1(Variable(sample, volatile=True))
  probs = F.softmax(out[:, :, i, j]).data


Num real: 64 Num fake: 64


  out = net1(Variable(sample, volatile=True))
  probs = F.softmax(out[:, :, i, j]).data


Num real: 64 Num fake: 64


  out = net1(Variable(sample, volatile=True))
  probs = F.softmax(out[:, :, i, j]).data


Num real: 64 Num fake: 64


  out = net1(Variable(sample, volatile=True))
  probs = F.softmax(out[:, :, i, j]).data


Num real: 64 Num fake: 64


In [159]:
print(np.mean(recall1), np.mean(prec1))
print(np.std(recall1), np.std(prec1))

0.9417613636363636 0.5625
0.04111964439241514 0.051607676490298154
