In [1]:
import torch
import h5py
import yaml
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from torch.nn.functional import interpolate
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import pytorch_lightning as pl

from ldm.util import instantiate_from_config

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class VQVAEDataset(torch.utils.data.Dataset):
    """
    Dataset for VQ-VAE training.
    Only needs high-resolution images for reconstruction.
    """
    
    def __init__(self, hr_arrays):
        """
        Args:
            hr_arrays: torch.Tensor of shape (N, C, H, W)
        """
        assert isinstance(hr_arrays, torch.Tensor), "hr_arrays must be a torch.Tensor"
        assert hr_arrays.ndim == 4, f"hr_arrays must be 4D tensor, got shape {hr_arrays.shape}"
        
        self.hr = hr_arrays
        self.num_samples = hr_arrays.shape[0]
        
    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """
        Returns dictionary with 'image' key.
        VQModelInterface expects shape (H, W, C) for input.
        """
        hr_sample = self.hr[idx]  # Shape: (C, H, W)
        hr_sample = hr_sample.permute(1, 2, 0)  # Convert to (H, W, C)
        
        return {'image': hr_sample}

In [3]:
# Load JAX-CFD normalized data (range [-1,1])
path = 'data/data_normalized.h5'
hr_data = torch.from_numpy(h5py.File(path, 'r')['hr'][:])

# Create RGB tensor from data
hr_tensor = torch.stack([hr_data for _ in range(3)], 1)

# Split into train and validation
train_size = int(0.9 * len(hr_tensor))
val_size = len(hr_tensor) - train_size

hr_train = hr_tensor[:train_size]
hr_val = hr_tensor[train_size:]

train_dataset = VQVAEDataset(hr_train)
val_dataset = VQVAEDataset(hr_val)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")
print(f"HR tensor shape: {hr_tensor.shape}")

Train samples: 460
Val samples: 52
HR tensor shape: torch.Size([512, 3, 256, 256])


In [4]:
# Set PyTorch DataLoaders
batch_size = 32

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

Train batches: 15
Val batches: 2




In [5]:
# Create VQ-VAE config
# Based on the VQModelInterface structure from autoencoder.py
vqvae_config = OmegaConf.create({
    'target': 'ldm.autoencoder.VQModelInterface',
    'params': {
        'embed_dim': 3,
        'n_embed': 8192,
        'ddconfig': {
            'double_z': False,
            'z_channels': 3,
            'resolution': 256,
            'in_channels': 3,
            'out_ch': 3,
            'ch': 128,
            'ch_mult': [1, 2, 4],
            'num_res_blocks': 2,
            'attn_resolutions': [],
            'dropout': 0.0
        },
        'lossconfig': {
            'target': 'torch.nn.Identity',
        },
    }
})

print("VQ-VAE config created")
print(OmegaConf.to_yaml(vqvae_config))

VQ-VAE config created
target: ldm.autoencoder.VQModelInterface
params:
  embed_dim: 3
  n_embed: 8192
  ddconfig:
    double_z: false
    z_channels: 3
    resolution: 256
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult:
    - 1
    - 2
    - 4
    num_res_blocks: 2
    attn_resolutions: []
    dropout: 0.0
  lossconfig:
    target: torch.nn.Identity



In [6]:
# Instantiate VQModelInterface from config
vqvae_model = instantiate_from_config(vqvae_config)

# Set learning rate (VQModel expects this attribute)
vqvae_model.learning_rate = 4.5e-6

# Optional: Load pretrained weights
load_pretrained = False
if load_pretrained:
    ckpt_path = "model.ckpt"
    sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    
    # Extract only VQ-VAE weights (first_stage_model)
    vqvae_sd = {k.replace('first_stage_model.', ''): v 
                for k, v in sd.items() 
                if k.startswith('first_stage_model.')}
    
    vqvae_model.load_state_dict(vqvae_sd, strict=False)
    print("Loaded pretrained VQ-VAE weights")
else:
    print("Training VQ-VAE from scratch")

print(f"\nVQ-VAE model type: {type(vqvae_model)}")
print(f"Parameters: {sum(p.numel() for p in vqvae_model.parameters()) / 1e6:.2f}M")
print(f"Trainable parameters: {sum(p.numel() for p in vqvae_model.parameters() if p.requires_grad) / 1e6:.2f}M")

making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 3, 64, 64) = 12288 dimensions.
making attention of type 'vanilla' with 512 in_channels
Training VQ-VAE from scratch

VQ-VAE model type: <class 'ldm.autoencoder.VQModelInterface'>
Parameters: 55.32M
Trainable parameters: 55.32M


In [11]:
# Setup callbacks
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints/vqvae',
    filename='vqvae-{epoch:02d}-{val/rec_loss:.4f}',
    save_top_k=3,
    monitor='val/rec_loss',
    mode='min',
    save_last=True,
)

lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Setup trainer
# Note: VQModel uses 2 optimizers (autoencoder + discriminator)
trainer = Trainer(
    accelerator='auto',
    devices=1,
    max_epochs=1,
    logger=True,
    callbacks=[checkpoint_callback, lr_monitor],
    log_every_n_steps=10,
    gradient_clip_val=1.0,
)

print("Trainer configured")
print(f"Max epochs: {trainer.max_epochs}")
print(f"Gradient clip: {trainer.gradient_clip_val}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


Trainer configured
Max epochs: 1
Gradient clip: 1.0


In [12]:
# Train VQ-VAE using VQModelInterface
print("Starting VQ-VAE training...")
trainer.fit(vqvae_model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type             | Params | Mode | FLOPs
--------------------------------------------------------------------
0 | encoder         | Encoder          | 22.3 M | eval | 0    
1 | decoder         | Decoder          | 33.0 M | eval | 0    
2 | loss            | Identity         | 0      | eval | 0    
3 | quantize        | VectorQuantizer2 | 24.6 K | eval | 0    
4 | quant_conv      | Conv2d           | 12     | eval | 0    
5 | post_quant_conv | Conv2d           | 12     | eval | 0    
--------------------------------------------------------------------
55.3 M    Trainable params
0         Non-trainable params
55.3 M    Total params
221.291   Total estimated model params size (MB)
0         Modules in train mode
173       Modules in eval mode
0         Total Flops
SLURM auto-requeueing enabled. Setting signal handlers.


Starting VQ-VAE training...
lr_d 4.5e-06
lr_g 4.5e-06
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

ValueError: too many values to unpack (expected 3)

In [None]:
# Save final VQ-VAE model
vqvae_path = "vqvae_trained.ckpt"
torch.save({
    'state_dict': vqvae_model.state_dict(),
    'config': vqvae_config,
}, vqvae_path)

print(f"VQ-VAE saved to {vqvae_path}")

In [None]:
# Visualize reconstructions
vqvae_model.eval()
device = next(vqvae_model.parameters()).device

# Get a batch from validation set
batch = next(iter(val_loader))
hr_images = vqvae_model.get_input(batch, 'image').to(device)[:8]

with torch.no_grad():
    reconstructed, _ = vqvae_model(hr_images)

# Plot original vs reconstructed
fig, axes = plt.subplots(2, 8, figsize=(20, 5))

for i in range(8):
    # Original
    axes[0, i].imshow(hr_images[i, 0].cpu().numpy(), cmap='viridis')
    axes[0, i].axis('off')
    if i == 0:
        axes[0, i].set_title('Original', fontsize=12)
    
    # Reconstructed
    axes[1, i].imshow(reconstructed[i, 0].cpu().numpy(), cmap='viridis')
    axes[1, i].axis('off')
    if i == 0:
        axes[1, i].set_title('Reconstructed', fontsize=12)

plt.tight_layout()
plt.savefig('vqvae_reconstructions.png', dpi=150, bbox_inches='tight')
plt.show()

print("Reconstruction visualization saved")

In [None]:
# Test encoding and decoding separately
print("Testing VQ-VAE encode/decode pipeline...")

with torch.no_grad():
    # Get one sample
    test_img = hr_images[:1]
    print(f"Input shape: {test_img.shape}")
    
    # Encode (returns continuous latent)
    latent = vqvae_model.encode(test_img)
    print(f"Encoded latent shape: {latent.shape}")
    
    # Decode (includes quantization)
    reconstructed = vqvae_model.decode(latent)
    print(f"Reconstructed shape: {reconstructed.shape}")
    
    # Compute reconstruction error
    rec_error = torch.nn.functional.l1_loss(reconstructed, test_img)
    print(f"Reconstruction L1 error: {rec_error.item():.6f}")

print("\nVQ-VAE is ready for use in latent diffusion!")