# VQ-UNET Experiment
This notebook implements the VQ-UNET architecture for image reconstruction on the PASCAL VOC dataset. It includes data loading, model training, and evaluation steps.

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from configs.experiment_config import config # Correct way to import
from src.dataset import VOCReconstructionDataset, get_transforms # Updated import
from src.trainer import Trainer
from src.vq_unet_model import VQUNet

# Display config
config.display()

# Set device
device = torch.device(config.DEVICE)
print(f"Using device: {device}")

# Seed for reproducibility
torch.manual_seed(config.SEED)
# `device` is a torch.device object, which does not compare equal to the plain
# string 'cuda'; test its `.type` instead (this also matches "cuda:0" etc.).
if device.type == 'cuda':
    torch.cuda.manual_seed_all(config.SEED)
np.random.seed(config.SEED)

# Data loading and transformation
transform = get_transforms(image_size=config.IMAGE_SIZE)

# Pinned (page-locked) host memory only helps host-to-GPU copies; on a
# CPU-only run DataLoader just emits a warning, so enable it for CUDA only.
use_pinned_memory = device.type == 'cuda'

# NOTE(review): the train split uses download=False while the val split uses
# download=True. Left unchanged here; confirm against VOCReconstructionDataset
# whether the train split can be loaded before anything has been downloaded.
print(f"Loading PASCAL VOC 2012 train set from {config.DATA_ROOT}...")
train_dataset = VOCReconstructionDataset(root=config.DATA_ROOT, year='2012', image_set='train', download=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=use_pinned_memory)

print(f"Loading PASCAL VOC 2012 val set from {config.DATA_ROOT}...")
val_dataset = VOCReconstructionDataset(root=config.DATA_ROOT, year='2012', image_set='val', download=True, transform=transform)
# Validation order is kept fixed (shuffle=False) so batches are repeatable.
val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=use_pinned_memory)
print("Datasets loaded.")

# Build the VQ-UNet
# Channel widths handed to the encoder (example values, can be configured)
encoder_channel_dims = [64, 128, 256, 512]
model = VQUNet(
    in_channels=3,
    out_channels=3,
    codebook_size=config.CODEBOOK_SIZE,
    encoder_channel_dims=encoder_channel_dims,
    commitment_cost=config.COMMITMENT_COST,
)
model = model.to(device)

print("Model initialized.")

# Reconstruction criterion and optimizer
criterion = nn.MSELoss()  # mean squared error used as the reconstruction loss
optimizer = optim.Adam(
    model.parameters(),
    lr=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY,
)
print("Optimizer and criterion initialized.")

# Initialize trainer
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    num_epochs=config.NUM_EPOCHS,
    checkpoint_dir=config.CHECKPOINT_DIR,
    vq_loss_weight=config.VQ_LOSS_WEIGHT
)
print("Trainer initialized.")

# Resume from a saved checkpoint only if the file is actually present, so a
# fresh "Restart & Run All" does not depend on an earlier run's artifacts.
# The path is built from config.CHECKPOINT_DIR, the same directory the trainer
# is given above, rather than a second hard-coded copy of it.
resume_checkpoint = os.path.join(config.CHECKPOINT_DIR, "checkpoint_epoch_6.pth")  # Example checkpoint
if os.path.isfile(resume_checkpoint):
    trainer.load_checkpoint(resume_checkpoint)
else:
    print(f"No checkpoint found at {resume_checkpoint}; training from scratch.")

# Start training
print("Starting training...")
trainer.train()
print("Training finished.")


Experiment Configuration:
DATA_ROOT: ./data
BATCH_SIZE: 16
IMAGE_SIZE: (128, 128)
CODEBOOK_SIZE: 512
COMMITMENT_COST: 0.25
LEARNING_RATE: 0.0001
NUM_EPOCHS: 25
WEIGHT_DECAY: 1e-05
VQ_LOSS_WEIGHT: 1.0
LOG_INTERVAL: 10
CHECKPOINT_DIR: checkpoints/
USE_CUDA: True
DEVICE: cpu
SEED: 42
Using device: cpu
Loading PASCAL VOC 2012 train set from ./data...
Loading PASCAL VOC 2012 val set from ./data...
Datasets loaded.
Model initialized.
Optimizer and criterion initialized.
Trainer initialized.
Checkpoint loaded: checkpoints/checkpoint_epoch_6.pth, starting from epoch 6
Starting training...


Training Epoch 1/25:   3%|▎         | 12/358 [00:28<13:52,  2.41s/it, Recon Loss=0.0882, VQ Loss=1.4199, Total Loss=1.5081]


KeyboardInterrupt: 

## Visualize Results
After training, we compare a batch of validation images against the model's reconstructions to judge reconstruction quality visually.

In [None]:
# Visualize some results
def visualize_results(model, data_loader, device, num_images=8):
    """Plot original images (top row) against model reconstructions (bottom row).

    Takes the first batch from ``data_loader``, runs it through ``model``
    without gradient tracking, and draws up to ``num_images`` pairs in a
    2 x ``num_images`` grid. Columns beyond the batch size are left blank.
    The model's train/eval mode is restored to what it was on entry.

    Args:
        model: Module called as ``model(images)``; must return a
            ``(reconstructions, vq_loss)`` pair.
        data_loader: Iterable whose first item unpacks as ``(images, _)``.
        device: Device the batch is moved to before the forward pass.
        num_images: Number of columns to draw; must be at least 1.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    # Remember the current mode so visualizing does not silently leave the
    # model in eval mode for whatever runs next.
    was_training = model.training
    model.eval()
    try:
        images, _ = next(iter(data_loader))  # Get a batch of images
        images = images.to(device)

        with torch.no_grad():
            # Model returns reconstructions and VQ loss
            reconstructions, vq_loss = model(images)
    finally:
        model.train(was_training)

    images = images.cpu()
    reconstructions = reconstructions.cpu()

    # Denormalize for visualization.
    # Assumes inputs were normalized with Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
    # -- TODO confirm against get_transforms.
    def denorm(tensor_img):
        return tensor_img * 0.5 + 0.5

    def to_displayable(tensor_img):
        # CHW tensor -> HWC array clipped to the [0, 1] range imshow expects.
        return denorm(tensor_img).permute(1, 2, 0).numpy().clip(0, 1)

    # squeeze=False keeps `axes` two-dimensional even when num_images == 1,
    # so the [row, col] indexing below is always valid. Width scales with the
    # column count (16 for the default of 8).
    fig, axes = plt.subplots(
        2, num_images, figsize=(max(2 * num_images, 4), 4), squeeze=False
    )
    plt.suptitle(f'VQ Loss on this batch: {vq_loss.item():.4f}', fontsize=16)

    # The batch may hold fewer images than requested columns.
    shown = min(num_images, images.shape[0])
    for col in range(shown):
        axes[0, col].imshow(to_displayable(images[col]))
        axes[0, col].set_title('Original')

        axes[1, col].imshow(to_displayable(reconstructions[col]))
        axes[1, col].set_title('Reconstructed')

    # Hide ticks/frames everywhere, including unused trailing columns.
    for axis in axes.ravel():
        axis.axis('off')

    plt.tight_layout(rect=[0, 0, 1, 0.96])  # Adjust layout to make space for suptitle
    plt.show()


# Show original/reconstruction pairs for the first validation batch.
print("Visualizing results from validation set...")
visualize_results(model=model, data_loader=val_loader, device=device)
