### Setup Environment

In [1]:
Project_Root = '/gdrive/MyDrive/CV_Project/'
from google.colab import drive
drive.mount('/gdrive')
%cd -q $Project_Root

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [2]:
!ls

checkpoints		     models		 train_pixelCNN
data			     __pycache__	 train_vae.ipynb
documents		     README.md		 train_vqvae.ipynb
GetData.ipynb		     requirements.txt	 utils.py
hierachical_vae_train.ipynb  residualDataset.py  visualization.py
images			     train_hrvae.ipynb	 visualize.ipynb


In [31]:
# !pip install -r requirements.txt --upgrade

### Setup all the data and hyper-parameters

In [3]:
import torch
import torchvision
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [4]:
data_dir = './data'

BATCH_SIZE = 64
transform = torchvision.transforms.ToTensor()

mnist_trainset = datasets.MNIST(root=data_dir, train=True, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(mnist_trainset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

mnist_testset = datasets.MNIST(root=data_dir, train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(mnist_testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
print(len(mnist_testset))

10000


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

# Assuming these files are present in the environment
from models.decompose import DecomposeVAE

class Config:
    # --- Paths ---
    WEIGHT_PATH = "checkpoints/save_3_best.pth"
    DATA_DIR = "./data"
    OUTPUT_DIR = "./checkpoints"

    # --- Training Hyperparameters ---
    DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 64
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 10
    LOG_INTERVAL = 100

    # --- Model Configuration ---
    NUM_HIERARCHY_LAYERS = 2

    # --- Loss Weights (Adjust to tune training) ---
    LOSS_WEIGHT_SMOOTHNESS = 1.0
    LOSS_WEIGHT_RESIDUAL = 0.1

class ResidualLatentUNet(nn.Module):
    def __init__(self, model_container, device="cpu", num_layers=2):
        super().__init__()
        self.device = torch.device(device)
        self.fullvae = model_container.getFullVAE().to(self.device)
        self.fullvae.eval()

        # Freeze VAE parameters
        for param in self.fullvae.parameters():
            param.requires_grad = False

        # Get latent dimensions from VAE instance
        with torch.no_grad():
            dummy = torch.zeros(1, 1, 28, 28, device=self.device)
            zq, *_ = self.fullvae.quantize(dummy)
            _, latent_ch, latent_h, latent_w = zq.shape

        in_ch = latent_ch * 2  # concatenated (image + residual)
        print(f"  U-Net Input Latent Size: {latent_h}x{latent_w}, {in_ch} channels")

        # --- Encoder Path (Compression) ---
        # Enc1: 7x7 -> 4x4 (Skip 1)
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_ch, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )

        # Enc2: 4x4 -> 2x2 (Skip 2)
        self.enc2 = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )

        # --- Bottleneck ---
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 1024, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 512, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        # --- Decoder Path (Decompression & Skip) ---

        # Dec2 (Innermost): Upsample (512ch) + Skip (512ch) -> 256ch
        self.up2 = nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(512 + 512, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        # Dec1 (Outermost): Upsample (256ch) + Skip (256ch) -> 128ch
        self.up1 = nn.ConvTranspose2d(256, 256, 2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(256 + 256, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        # Final projection: 128ch -> 128ch (in_ch)
        self.final = nn.Conv2d(128, in_ch, 1)

    def forward(self, img_tensor):
        img_tensor = img_tensor.to(self.device)

        with torch.no_grad():
            # Get VAE latents and reconstruction
            zq_img, *_ = self.fullvae.quantize(img_tensor)
            recon = self.fullvae.decoder(zq_img)
            residual = img_tensor - recon

            # Get residual latent
            zq_res, *_ = self.fullvae.quantize(residual)

            # Align spatial dims (Safety check, should match)
            if zq_img.shape[-2:] != zq_res.shape[-2:]:
                zq_res = F.interpolate(zq_res, size=zq_img.shape[-2:], mode='nearest')

        # U-Net Input (7x7, 128ch)
        z_concat = torch.cat([zq_img, zq_res], dim=1)

        # Encoder path
        e1 = self.enc1(z_concat)    # Skip 1 (e1, expected 4x4)
        e2 = self.enc2(e1)          # Skip 2 (e2, expected 2x2)

        # Bottleneck
        b = self.bottleneck(e2)     # (expected 2x2)

        # Decoder 2 (Skip: e2)
        d2_up = self.up2(b)         # Upconvolution (expected 4x4)

        # Interpolation check for Dec 2: Target size is e2.shape (expected 2x2)
        if d2_up.shape[2:] != e2.shape[2:]:
            # Resize upsampled feature to match the skip connection
            d2_up = F.interpolate(d2_up, size=e2.shape[2:], mode='nearest')
        d2 = self.dec2(torch.cat([d2_up, e2], dim=1))

        # Decoder 1 (Skip: e1)
        d1_up = self.up1(d2)        # Upconvolution (expected 8x8 or 6x6)

        # Interpolation check for Dec 1: Target size is z_concat.shape (expected 7x7)
        if d1_up.shape[2:] != z_concat.shape[2:]:
            d1_up = F.interpolate(d1_up, size=z_concat.shape[2:], mode='nearest') # This forces d1_up to 7x7

        # --- FIX: SPATIAL MISMATCH RESOLUTION ---
        # The error occurs because e1 (e.g., 4x4 or 3x3) does not match d1_up (7x7).
        # We must resize e1 to match the target size of d1_up (which is 7x7).
        if d1_up.shape[2:] != e1.shape[2:]:
            e1_resized = F.interpolate(e1, size=d1_up.shape[2:], mode='nearest')
        else:
            e1_resized = e1
        # ----------------------------------------

        d1 = self.dec1(torch.cat([d1_up, e1_resized], dim=1))

        # Final projection
        z_refined = self.final(d1)

        return {
            "z_image": zq_img,
            "z_residual": zq_res,
            "z_concat": z_concat,
            "z_refined": z_refined,
            "recon": recon,
            "residual": residual,
        }

In [55]:
class HierarchicalLoss(nn.Module):
    def __init__(self, smoothness_weight, residual_weight):
        super().__init__()
        self.smoothness_weight = smoothness_weight
        self.residual_weight = residual_weight

    def forward(self, unet_output, original_image):
        z_refined = unet_output["z_refined"]
        z_concat = unet_output["z_concat"]
        smoothness_loss = F.mse_loss(z_refined, z_concat)

        z_residual = unet_output["z_residual"]
        residual_energy_loss = torch.mean(z_residual ** 2)

        total_loss = (self.smoothness_weight * smoothness_loss) + \
                     (self.residual_weight * residual_energy_loss)

        return total_loss, {
            "total": total_loss.item(),
            "smoothness": smoothness_loss.item(),
            "residual": residual_energy_loss.item(),
        }


# ==============================================================================
# 4. TRAINING AND UTILITY FUNCTIONS
# ==============================================================================

def train_epoch(model, dataloader, criterion, optimizer, device, epoch):
    model.train()
    total_loss, total_smoothness, total_residual = 0, 0, 0
    pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{Config.NUM_EPOCHS}')

    for batch_idx, (data, _) in enumerate(pbar):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss, loss_dict = criterion(output, data)
        loss.backward()
        optimizer.step()

        total_loss += loss_dict['total']
        total_smoothness += loss_dict['smoothness']
        total_residual += loss_dict['residual']

        if batch_idx % Config.LOG_INTERVAL == 0:
            pbar.set_postfix({
                'loss': f"{loss_dict['total']:.4f}",
                'smooth': f"{loss_dict['smoothness']:.4f}",
                'res_e': f"{loss_dict['residual']:.4f}"
            })

    num_batches = len(dataloader)
    return {
        'total': total_loss / num_batches,
        'smoothness': total_smoothness / num_batches,
        'residual': total_residual / num_batches
    }


def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss, total_smoothness, total_residual = 0, 0, 0
    with torch.no_grad():
        for data, _ in tqdm(dataloader, desc='Validating', leave=False):
            data = data.to(device)
            output = model(data)
            loss, loss_dict = criterion(output, data)

            total_loss += loss_dict['total']
            total_smoothness += loss_dict['smoothness']
            total_residual += loss_dict['residual']

    num_batches = len(dataloader)
    return {
        'total': total_loss / num_batches,
        'smoothness': total_smoothness / num_batches,
        'residual': total_residual / num_batches
    }

In [56]:
def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs, save_dir='checkpoints'):
    """
    Complete training loop with validation and checkpoint saving
    """

    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Initialize best loss tracking
    best_val_loss = float('inf')
    train_history = {'train': [], 'val': []}

    # Training loop
    for epoch in range(num_epochs):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"{'='*60}")

        # Train for one epoch
        train_metrics = train_epoch(
            model=model,
            dataloader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epoch=epoch
        )

        # Validate
        val_metrics = validate(
            model=model,
            dataloader=val_loader,
            criterion=criterion,
            device=device
        )

        # Store history
        train_history['train'].append(train_metrics)
        train_history['val'].append(val_metrics)

        # Print epoch summary
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Train - Total: {train_metrics['total']:.6f}, "
              f"Smoothness: {train_metrics['smoothness']:.6f}, "
              f"Residual: {train_metrics['residual']:.6f}")
        print(f"  Val   - Total: {val_metrics['total']:.6f}, "
              f"Smoothness: {val_metrics['smoothness']:.6f}, "
              f"Residual: {val_metrics['residual']:.6f}")

        # Save best model
        if val_metrics['total'] < best_val_loss:
            best_val_loss = val_metrics['total']
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_metrics': train_metrics,
                'val_metrics': val_metrics,
                'best_val_loss': best_val_loss,
            }, os.path.join(save_dir, 'best_model.pth'))
            print(f"  ✓ Saved best model with val loss: {best_val_loss:.6f}")

        # Save regular checkpoint (every 10 epochs or adjust as needed)
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_history': train_history,
            }, os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth'))
            print(f"  ✓ Saved checkpoint at epoch {epoch+1}")

    # Save final model
    torch.save({
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_history': train_history,
        'final_val_loss': val_metrics['total'],
    }, os.path.join(save_dir, 'final_model.pth'))

    print(f"\n{'='*60}")
    print("Training completed!")
    print(f"Best validation loss: {best_val_loss:.6f}")
    print(f"{'='*60}")

    return model, train_history

In [57]:
config = Config()
model_container = DecomposeVAE(config.WEIGHT_PATH, config.DEVICE)
model = ResidualLatentUNet(model_container=model_container, device=config.DEVICE).to(config.DEVICE)
train_params = [params for params in model.parameters()]
optimizer = torch.optim.Adam(train_params, lr=config.LEARNING_RATE)
criterion = HierarchicalLoss(smoothness_weight=config.LOSS_WEIGHT_SMOOTHNESS, residual_weight=config.LOSS_WEIGHT_RESIDUAL)

  U-Net Input Latent Size: 7x7, 128 channels


In [58]:
train_model(model, trainloader, testloader, criterion, optimizer, config.DEVICE, config.NUM_EPOCHS, save_dir='checkpoints')


Epoch 1/10


Epoch 1/10: 100%|██████████| 938/938 [00:43<00:00, 21.68it/s, loss=0.0861, smooth=0.0723, res_e=0.1373]



Epoch 1 Summary:
  Train - Total: 0.204569, Smoothness: 0.190549, Residual: 0.140207
  Val   - Total: 0.087627, Smoothness: 0.073557, Residual: 0.140692
  ✓ Saved best model with val loss: 0.087627

Epoch 2/10


Epoch 2/10: 100%|██████████| 938/938 [00:41<00:00, 22.72it/s, loss=0.0581, smooth=0.0443, res_e=0.1375]



Epoch 2 Summary:
  Train - Total: 0.069092, Smoothness: 0.055094, Residual: 0.139979
  Val   - Total: 0.061541, Smoothness: 0.047521, Residual: 0.140198
  ✓ Saved best model with val loss: 0.061541

Epoch 3/10


Epoch 3/10: 100%|██████████| 938/938 [00:42<00:00, 21.87it/s, loss=0.0475, smooth=0.0339, res_e=0.1369]



Epoch 3 Summary:
  Train - Total: 0.052538, Smoothness: 0.038560, Residual: 0.139782
  Val   - Total: 0.049551, Smoothness: 0.035529, Residual: 0.140223
  ✓ Saved best model with val loss: 0.049551

Epoch 4/10


Epoch 4/10: 100%|██████████| 938/938 [00:43<00:00, 21.67it/s, loss=0.0404, smooth=0.0269, res_e=0.1355]



Epoch 4 Summary:
  Train - Total: 0.044218, Smoothness: 0.030244, Residual: 0.139743
  Val   - Total: 0.042053, Smoothness: 0.028031, Residual: 0.140222
  ✓ Saved best model with val loss: 0.042053

Epoch 5/10


Epoch 5/10: 100%|██████████| 938/938 [00:42<00:00, 21.93it/s, loss=0.0361, smooth=0.0224, res_e=0.1363]



Epoch 5 Summary:
  Train - Total: 0.038943, Smoothness: 0.024963, Residual: 0.139794
  Val   - Total: 0.037010, Smoothness: 0.022993, Residual: 0.140171
  ✓ Saved best model with val loss: 0.037010

Epoch 6/10


Epoch 6/10: 100%|██████████| 938/938 [00:42<00:00, 21.97it/s, loss=0.0328, smooth=0.0191, res_e=0.1365]



Epoch 6 Summary:
  Train - Total: 0.035175, Smoothness: 0.021207, Residual: 0.139683
  Val   - Total: 0.035540, Smoothness: 0.021529, Residual: 0.140108
  ✓ Saved best model with val loss: 0.035540

Epoch 7/10


Epoch 7/10: 100%|██████████| 938/938 [00:42<00:00, 22.21it/s, loss=0.0296, smooth=0.0161, res_e=0.1355]



Epoch 7 Summary:
  Train - Total: 0.032286, Smoothness: 0.018328, Residual: 0.139583
  Val   - Total: 0.032282, Smoothness: 0.018280, Residual: 0.140015
  ✓ Saved best model with val loss: 0.032282

Epoch 8/10


Epoch 8/10: 100%|██████████| 938/938 [00:42<00:00, 21.93it/s, loss=0.0280, smooth=0.0143, res_e=0.1373]



Epoch 8 Summary:
  Train - Total: 0.030002, Smoothness: 0.016050, Residual: 0.139523
  Val   - Total: 0.030066, Smoothness: 0.016062, Residual: 0.140033
  ✓ Saved best model with val loss: 0.030066

Epoch 9/10


Epoch 9/10: 100%|██████████| 938/938 [00:42<00:00, 22.21it/s, loss=0.0263, smooth=0.0127, res_e=0.1358]



Epoch 9 Summary:
  Train - Total: 0.028121, Smoothness: 0.014173, Residual: 0.139476
  Val   - Total: 0.028044, Smoothness: 0.014041, Residual: 0.140022
  ✓ Saved best model with val loss: 0.028044

Epoch 10/10


Epoch 10/10: 100%|██████████| 938/938 [00:41<00:00, 22.37it/s, loss=0.0257, smooth=0.0119, res_e=0.1381]



Epoch 10 Summary:
  Train - Total: 0.026649, Smoothness: 0.012705, Residual: 0.139438
  Val   - Total: 0.028349, Smoothness: 0.014352, Residual: 0.139966
  ✓ Saved checkpoint at epoch 10

Training completed!
Best validation loss: 0.028044


(ResidualLatentUNet(
   (fullvae): VQVAE(
     (encoder): Encoder(
       (conv): Sequential(
         (down0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
         (relu0): ReLU()
         (down1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
         (relu1): ReLU()
         (final_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       )
       (residual_stack): ResidualStack(
         (layers): ModuleList(
           (0-1): 2 x Sequential(
             (0): ReLU()
             (1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
             (2): ReLU()
             (3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
           )
         )
       )
     )
     (pre_vq_conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
     (vq): VectorQuantizer(
       (N_i_ts): SonnetExponentialMovingAverage()
       (m_i_ts): SonnetExponentialMovingAverage()
     )
     (decoder): Decoder(
       (

In [6]:
from utils import save_img_tensors_as_grid, plot_image_batch

In [8]:
config = Config()
model_container = DecomposeVAE(config.WEIGHT_PATH, config.DEVICE)
model = ResidualLatentUNet(model_container=model_container, device=config.DEVICE).to(config.DEVICE)
state_dict= torch.load("checkpoints/final_model.pth")["model_state_dict"]
model.load_state_dict(state_dict)
model.eval()

  U-Net Input Latent Size: 7x7, 128 channels


ResidualLatentUNet(
  (fullvae): VQVAE(
    (encoder): Encoder(
      (conv): Sequential(
        (down0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (relu0): ReLU()
        (down1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (relu1): ReLU()
        (final_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (residual_stack): ResidualStack(
        (layers): ModuleList(
          (0-1): 2 x Sequential(
            (0): ReLU()
            (1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU()
            (3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
          )
        )
      )
    )
    (pre_vq_conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
    (vq): VectorQuantizer(
      (N_i_ts): SonnetExponentialMovingAverage()
      (m_i_ts): SonnetExponentialMovingAverage()
    )
    (decoder): Decoder(
      (conv): Conv2d(64, 128, kerne

In [9]:
ipt, lbl = next(iter(testloader))

In [12]:
pred = model(ipt)["recon"]
save_img_tensors_as_grid(pred, nrows = 6, f="images/unet_recon_output")

1