In [None]:
import torch
import h5py
import yaml
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from torch.nn.functional import interpolate
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import pytorch_lightning as pl

from ldm.util import instantiate_from_config

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class VQVAEDataset(torch.utils.data.Dataset):
    """
    Map-style dataset serving high-resolution images for VQ-VAE reconstruction.

    Each item is a dict with a single 'image' entry holding one sample in
    channels-last (H, W, C) layout; no low-resolution counterpart is stored.
    """

    def __init__(self, hr_arrays):
        """
        Args:
            hr_arrays: torch.Tensor laid out as (N, C, H, W).

        Raises:
            AssertionError: if hr_arrays is not a torch.Tensor or is not 4D.
        """
        assert isinstance(hr_arrays, torch.Tensor), "hr_arrays must be a torch.Tensor"
        assert hr_arrays.ndim == 4, f"hr_arrays must be 4D tensor, got shape {hr_arrays.shape}"

        # Keep a reference to the full tensor; samples are sliced lazily.
        self.hr = hr_arrays
        self.num_samples = hr_arrays.size(0)

    def __len__(self):
        """Number of samples (size of the leading dimension)."""
        return self.num_samples

    def __getitem__(self, idx):
        """
        Fetch one sample as {'image': tensor} in (H, W, C) layout.

        The stored sample is (C, H, W); the channel axis is moved last because
        the consuming model's input helper presumably expects channels-last
        batches (verify against VQModelInterface.get_input).
        """
        channels_last = self.hr[idx].permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
        return dict(image=channels_last)

In [None]:
# Load JAX-CFD normalized data (the original note states the range is [-1,1];
# that is not checked here).
path = 'data/data_normalized.h5'

# Open the file in a context manager so the HDF5 handle is released as soon as
# the data has been read. `[:]` copies the whole 'hr' dataset into memory, so
# the resulting tensor stays valid after the file is closed.
with h5py.File(path, 'r') as h5_file:
    hr_data = torch.from_numpy(h5_file['hr'][:])

# Create RGB tensor from data: the same array is repeated three times along a
# new axis 1, giving a 3-channel tensor for the 3-channel model config.
hr_tensor = torch.stack([hr_data for _ in range(3)], 1)

# Split into train and validation (first 90% / last 10%, in stored order —
# no shuffling is applied before the split).
train_size = int(0.9 * len(hr_tensor))
val_size = len(hr_tensor) - train_size

hr_train = hr_tensor[:train_size]
hr_val = hr_tensor[train_size:]

train_dataset = VQVAEDataset(hr_train)
val_dataset = VQVAEDataset(hr_val)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")
print(f"HR tensor shape: {hr_tensor.shape}")

In [None]:
# Set PyTorch DataLoaders
batch_size = 32

# Options common to both loaders; only the dataset and shuffling differ.
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

In [None]:
class ManualVQVAE(pl.LightningModule):
    """
    Wrapper for VQModelInterface that uses manual optimization.

    The wrapped model's autoencoder and its loss discriminator are trained with
    two separate Adam optimizers that are zeroed, back-propagated and stepped
    by hand inside ``training_step``. Per the original author's note this is
    required for PyTorch Lightning >= 2.0, which doesn't support optimizer_idx.

    NOTE(review): with manual optimization Lightning does not clip gradients
    on its own, and is documented to reject ``Trainer(gradient_clip_val=...)``
    for such modules — confirm against the installed version before fitting
    this wrapper with a clipping Trainer.
    NOTE(review): ``self.global_step`` is passed to the loss as its step
    counter; with two optimizers stepped per batch it presumably advances
    twice per batch, which would affect step-based loss schedules — verify.
    """

    def __init__(self, vqvae_model):
        """
        Args:
            vqvae_model: the model to train. This wrapper reads its
                ``image_key``, ``learning_rate``, ``loss``, ``encoder``,
                ``decoder``, ``quantize``, ``quant_conv`` and
                ``post_quant_conv`` attributes and calls its ``get_input``
                and ``get_last_layer`` methods.
        """
        super().__init__()
        self.vqvae = vqvae_model
        # Enable manual optimization
        self.automatic_optimization = False

    def forward(self, x):
        """Delegate the forward pass to the wrapped model."""
        return self.vqvae(x)

    def get_input(self, batch, k):
        """Extract entry ``k`` from ``batch`` using the wrapped model's helper."""
        return self.vqvae.get_input(batch, k)

    def training_step(self, batch, batch_idx):
        """
        Run one autoencoder update followed by one discriminator update.

        Both losses are computed from the same forward pass; the loss module is
        called with optimizer index 0 for the autoencoder and 1 for the
        discriminator. Returns nothing, since optimization is manual.
        """
        opt_ae, opt_disc = self.optimizers()

        x = self.get_input(batch, self.vqvae.image_key)
        xrec, qloss, ind = self.vqvae(x, return_pred_indices=True)

        # Optimize autoencoder
        aeloss, log_dict_ae = self.vqvae.loss(
            qloss, x, xrec, 0, self.global_step,
            last_layer=self.vqvae.get_last_layer(),
            split="train",
            predicted_indices=ind
        )

        opt_ae.zero_grad()
        self.manual_backward(aeloss)
        opt_ae.step()

        # Optimize discriminator
        discloss, log_dict_disc = self.vqvae.loss(
            qloss, x, xrec, 1, self.global_step,
            last_layer=self.vqvae.get_last_layer(),
            split="train"
        )

        # Zeroing here also discards any discriminator gradients left over
        # from the autoencoder backward pass above.
        opt_disc.zero_grad()
        self.manual_backward(discloss)
        opt_disc.step()

        # Logging
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)

    def validation_step(self, batch, batch_idx):
        """
        Compute and log the autoencoder and discriminator validation losses.

        "val/rec_loss" is logged separately so that it appears on the progress
        bar and is synchronised across devices; all other entries are logged
        with default options.
        """
        x = self.get_input(batch, self.vqvae.image_key)
        xrec, qloss, ind = self.vqvae(x, return_pred_indices=True)

        aeloss, log_dict_ae = self.vqvae.loss(
            qloss, x, xrec, 0, self.global_step,
            last_layer=self.vqvae.get_last_layer(),
            split="val",
            predicted_indices=ind
        )

        discloss, log_dict_disc = self.vqvae.loss(
            qloss, x, xrec, 1, self.global_step,
            last_layer=self.vqvae.get_last_layer(),
            split="val",
            predicted_indices=ind
        )

        # Lightning rejects logging the same key twice in one hook with
        # different arguments, so take "val/rec_loss" out of the dict before
        # the bulk log_dict call. A copy is used to leave the loss module's
        # returned dict untouched.
        remaining_ae_logs = dict(log_dict_ae)
        rec_loss = remaining_ae_logs.pop("val/rec_loss")

        self.log("val/rec_loss", rec_loss, prog_bar=True, sync_dist=True)
        self.log_dict(remaining_ae_logs)
        self.log_dict(log_dict_disc)

    def configure_optimizers(self):
        """
        Build the two Adam optimizers used by ``training_step``.

        Returns:
            ``[opt_ae, opt_disc]``: the first covers encoder, decoder,
            quantizer and the two quantization convolutions; the second covers
            the loss module's discriminator. Both use the wrapped model's
            ``learning_rate`` and betas (0.5, 0.9).
        """
        lr = self.vqvae.learning_rate

        opt_ae = torch.optim.Adam(
            list(self.vqvae.encoder.parameters()) +
            list(self.vqvae.decoder.parameters()) +
            list(self.vqvae.quantize.parameters()) +
            list(self.vqvae.quant_conv.parameters()) +
            list(self.vqvae.post_quant_conv.parameters()),
            lr=lr, betas=(0.5, 0.9)
        )

        opt_disc = torch.optim.Adam(
            self.vqvae.loss.discriminator.parameters(),
            lr=lr, betas=(0.5, 0.9)
        )

        return [opt_ae, opt_disc]

In [None]:
# Create VQ-VAE config
# Based on the VQModelInterface structure from autoencoder.py

# Encoder/decoder architecture settings.
ddconfig = dict(
    double_z=False,
    z_channels=3,
    resolution=256,
    in_channels=3,
    out_ch=3,
    ch=128,
    ch_mult=[1, 2, 4],
    num_res_blocks=2,
    attn_resolutions=[],
    dropout=0.0,
)

# Loss module (target class plus its constructor arguments).
lossconfig = dict(
    target='ldm.vqperceptual.VQLPIPSWithDiscriminator',
    params=dict(
        disc_conditional=False,
        disc_in_channels=3,
        disc_start=10000,
        disc_weight=0.8,
        codebook_weight=1.0,
    ),
)

# Top-level entry: the class to instantiate and the arguments it receives.
vqvae_config = OmegaConf.create(dict(
    target='ldm.autoencoder.VQModelInterface',
    params=dict(
        embed_dim=3,
        n_embed=8192,
        ddconfig=ddconfig,
        lossconfig=lossconfig,
        image_key='image',
    ),
))

print("VQ-VAE config created")
print(OmegaConf.to_yaml(vqvae_config))

In [None]:
# Instantiate VQModelInterface from config
vqvae_model = instantiate_from_config(vqvae_config)

# Set learning rate (VQModel expects this attribute)
vqvae_model.learning_rate = 4.5e-6

# Optional: Load pretrained weights
load_pretrained = False
if load_pretrained:
    ckpt_path = "model.ckpt"
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code — only point this at checkpoints from a trusted source.
    sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    # Extract only VQ-VAE weights (first_stage_model). The prefix is sliced
    # off rather than str.replace()'d so that only the leading occurrence is
    # removed.
    prefix = 'first_stage_model.'
    vqvae_sd = {k[len(prefix):]: v
                for k, v in sd.items()
                if k.startswith(prefix)}

    # With strict=False an empty dict would "load" successfully while leaving
    # every weight at its random initialisation, so fail loudly instead.
    if not vqvae_sd:
        raise ValueError(
            f"No '{prefix}*' entries found in {ckpt_path}; nothing to load"
        )

    load_result = vqvae_model.load_state_dict(vqvae_sd, strict=False)
    print("Loaded pretrained VQ-VAE weights")

    # strict=False tolerates mismatches; report them so they are not silent.
    if load_result.missing_keys:
        print(f"Missing keys (left at initial values): {len(load_result.missing_keys)}")
    if load_result.unexpected_keys:
        print(f"Unexpected keys (ignored): {len(load_result.unexpected_keys)}")
else:
    print("Training VQ-VAE from scratch")

print(f"\nVQ-VAE model type: {type(vqvae_model)}")
print(f"Parameters: {sum(p.numel() for p in vqvae_model.parameters()) / 1e6:.2f}M")
print(f"Trainable parameters: {sum(p.numel() for p in vqvae_model.parameters() if p.requires_grad) / 1e6:.2f}M")

In [None]:
# Setup callbacks
# Keeps the three checkpoints with the lowest 'val/rec_loss' plus the latest.
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints/vqvae',
    # The monitored metric name contains '/'. With automatic metric-name
    # insertion that slash ends up in the file name and creates an extra
    # sub-folder per checkpoint, so the labels are written out by hand and
    # auto-insertion is switched off.
    filename='vqvae-epoch={epoch:02d}-val_rec_loss={val/rec_loss:.4f}',
    auto_insert_metric_name=False,
    save_top_k=3,
    monitor='val/rec_loss',
    mode='min',
    save_last=True,
)

lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Setup trainer
# Note: VQModel uses 2 optimizers (autoencoder + discriminator)
# NOTE(review): gradient_clip_val is applied by Lightning only under automatic
# optimization; a module with automatic_optimization = False is presumably
# rejected with this setting — confirm before fitting such a module.
trainer = Trainer(
    accelerator='auto',
    devices=1,
    max_epochs=100,
    logger=True,
    callbacks=[checkpoint_callback, lr_monitor],
    log_every_n_steps=10,
    gradient_clip_val=1.0,
)

print("Trainer configured")
print(f"Max epochs: {trainer.max_epochs}")
print(f"Gradient clip: {trainer.gradient_clip_val}")

In [None]:
# Train VQ-VAE using VQModelInterface
# NOTE(review): the model is fitted directly; the ManualVQVAE wrapper defined
# earlier in this notebook is never instantiated. If the installed Lightning
# version needs manual optimization for two optimizers, fit the wrapper
# instead — but check its validation logging and the Trainer's
# gradient_clip_val setting first, as both look incompatible with that path.
print(
    "Starting VQ-VAE training...",
    "Note: VQModel uses 2 optimizers (autoencoder + discriminator)",
    sep="\n",
)
trainer.fit(vqvae_model, train_loader, val_loader)

In [None]:
# Save final VQ-VAE model
vqvae_path = "vqvae_trained.ckpt"

# Bundle the weights with the config that built the model, so the file holds
# both under the 'state_dict' and 'config' keys.
checkpoint_payload = dict(
    state_dict=vqvae_model.state_dict(),
    config=vqvae_config,
)
torch.save(checkpoint_payload, vqvae_path)

print(f"VQ-VAE saved to {vqvae_path}")

In [None]:
# Visualize reconstructions
vqvae_model.eval()
# Run on whichever device the model's parameters currently live on.
device = next(vqvae_model.parameters()).device

# Get a batch from validation set; keep at most 8 samples and slice before the
# device transfer so only the displayed images are moved.
batch = next(iter(val_loader))
hr_images = vqvae_model.get_input(batch, 'image')[:8].to(device)

with torch.no_grad():
    # The second return value is discarded; only the reconstructions are shown.
    reconstructed, _ = vqvae_model(hr_images)

# Plot original vs reconstructed.
# The grid is sized from the actual batch: a validation batch with fewer than
# 8 samples would otherwise raise IndexError. squeeze=False keeps `axes`
# two-dimensional even when only one column is drawn.
num_shown = hr_images.shape[0]
fig, axes = plt.subplots(2, num_shown, figsize=(2.5 * num_shown, 5), squeeze=False)

for i in range(num_shown):
    # Original (channel 0 only; all three channels are copies of the same
    # field when the data comes from the stacking step in this notebook)
    axes[0, i].imshow(hr_images[i, 0].cpu().numpy(), cmap='viridis')
    axes[0, i].axis('off')
    if i == 0:
        axes[0, i].set_title('Original', fontsize=12)

    # Reconstructed
    axes[1, i].imshow(reconstructed[i, 0].cpu().numpy(), cmap='viridis')
    axes[1, i].axis('off')
    if i == 0:
        axes[1, i].set_title('Reconstructed', fontsize=12)

plt.tight_layout()
plt.savefig('vqvae_reconstructions.png', dpi=150, bbox_inches='tight')
plt.show()

print("Reconstruction visualization saved")

In [None]:
# Test encoding and decoding separately
# Smoke test: push a single image through encode() and decode() and report the
# tensor shapes and the L1 distance to the input. Relies on `hr_images` and
# `vqvae_model` left in the kernel by the visualization cell above.
print("Testing VQ-VAE encode/decode pipeline...")

with torch.no_grad():
    # Get one sample (slice, not index, so the batch dimension is kept)
    test_img = hr_images[:1]
    print(f"Input shape: {test_img.shape}")

    # Encode — per the original note this returns the continuous latent;
    # NOTE(review): confirm against VQModelInterface.encode.
    latent = vqvae_model.encode(test_img)
    print(f"Encoded latent shape: {latent.shape}")

    # Decode — per the original note this step includes quantization;
    # NOTE(review): confirm against VQModelInterface.decode.
    # This rebinds `reconstructed`, replacing the batch of reconstructions
    # produced by the visualization cell.
    reconstructed = vqvae_model.decode(latent)
    print(f"Reconstructed shape: {reconstructed.shape}")

    # Compute reconstruction error (mean absolute difference to the input)
    rec_error = torch.nn.functional.l1_loss(reconstructed, test_img)
    print(f"Reconstruction L1 error: {rec_error.item():.6f}")

print("\nVQ-VAE is ready for use in latent diffusion!")