In [1]:
import torch
import h5py
import yaml
import torchvision.transforms  as T
import numpy                   as np
import matplotlib.pyplot       as plt
from   torch.nn.functional import interpolate
from   omegaconf           import OmegaConf
from   pytorch_lightning   import Trainer

from   ldm.util  import instantiate_from_config

# Build the model described by the `model` section of the YAML config
config_path = "ldm/yaml_config.yaml"
config      = OmegaConf.load(config_path)
model       = instantiate_from_config(config.model)

# Load the pretrained weights on the CPU
# Checkpoint: unzip from https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints obtained from a trusted source.
ckpt_path = "model.ckpt"
sd        = torch.load(ckpt_path, map_location="cpu")["state_dict"]

# strict=False tolerates key mismatches between checkpoint and model, so the
# result is inspected to make any mismatch visible instead of silently ignored.
load_result = model.load_state_dict(sd, strict=False)
print(
    f"Checkpoint loaded: {len(load_result.missing_keys)} missing key(s), "
    f"{len(load_result.unexpected_keys)} unexpected key(s)"
)

print(model)

  from .autonotebook import tqdm as notebook_tqdm


LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 113.62 M params.
Keeping EMAs of 308.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 3, 64, 64) = 12288 dimensions.
making attention of type 'vanilla' with 512 in_channels
LatentDiffusion(
  (model): DiffusionWrapper(
    (diffusion_model): UNetModel(
      (time_embed): Sequential(
        (0): Linear(in_features=160, out_features=640, bias=True)
        (1): SiLU()
        (2): Linear(in_features=640, out_features=640, bias=True)
      )
      (input_blocks): ModuleList(
        (0): TimestepEmbedSequential(
          (0): Conv2d(6, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (1-2): 2 x TimestepEmbedSequential(
          (0): ResBlock(
            (in_layers): Sequential(
              (0): GroupNorm32(32, 160, eps=1e-05, affine=True)
              (1): SiLU()
              (2): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    

In [2]:
# Prefer the GPU when CUDA is usable; otherwise stay on the CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

model = model.to(device)
# `logvar` is moved explicitly as well
# (presumably it is not covered by `model.to` - TODO confirm)
model.logvar = model.logvar.to(model.device)
print(f"model.device: {model.device}")

model.device: cuda:0


In [3]:
class SRDataset(torch.utils.data.Dataset):
    """
    Paired high-/low-resolution dataset backed by two in-memory tensors.

    Both tensors are indexed along their first dimension, so sample ``i`` is
    built from ``hr_arrays[i]`` and ``lr_arrays[i]``:
      hr_arrays: shape (N, C, H_hr, W_hr) - e.g., (1000, 3, 256, 256)
      lr_arrays: shape (N, C, H_lr, W_lr) - e.g., (1000, 3, 64, 64)
    """

    def __init__(self, hr_arrays, lr_arrays):
        """
        Validate the two tensors and keep references to them.

        Args:
            hr_arrays: torch.Tensor of shape (N, C, H_hr, W_hr)
            lr_arrays: torch.Tensor of shape (N, C, H_lr, W_lr)

        Raises:
            AssertionError: if an input is not a 4D torch.Tensor, or if the
                first (sample) or second (channel) dimensions differ.
        """
        named_inputs = (("hr_arrays", hr_arrays), ("lr_arrays", lr_arrays))

        # Type checks run for both inputs before any rank check
        for label, tensor in named_inputs:
            assert isinstance(tensor, torch.Tensor), f"{label} must be a torch.Tensor"
        for label, tensor in named_inputs:
            assert tensor.ndim == 4, f"{label} must be 4D tensor, got shape {tensor.shape}"

        # Sample count and channel count must agree between the two tensors
        n_hr, c_hr = hr_arrays.shape[:2]
        n_lr, c_lr = lr_arrays.shape[:2]
        assert n_hr == n_lr, f"Batch size mismatch: HR {n_hr}, LR {n_lr}"
        assert c_hr == c_lr, f"Channel mismatch: HR {c_hr}, LR {c_lr}"

        # No copy is made here; the dataset holds the caller's tensors
        self.hr = hr_arrays
        self.lr = lr_arrays
        self.num_samples = n_hr

    def __len__(self):
        """Number of HR/LR pairs (size of the first tensor dimension)."""
        return self.num_samples

    def __getitem__(self, idx):
        """
        Fetch one HR/LR pair in channels-last layout.

        Returns:
            dict with 'image' (HR sample) and 'LR_image' (LR sample), each
            permuted from (C, H, W) to (H, W, C).
        """
        return {
            'image': self.hr[idx].permute(1, 2, 0),
            'LR_image': self.lr[idx].permute(1, 2, 0),
        }

In [4]:
# Load JAX-CFD normalized data (range [-1,1])
# See 'data/generate.ipynb'
path = 'data/data_normalized.h5'

# Read both datasets through a single handle that is closed on exit.
# `[:]` copies each dataset into memory, so the arrays remain valid after
# the file is closed.
with h5py.File(path, 'r') as h5_file:
    hr_data = torch.from_numpy(h5_file['hr'][:])
    lr_data = torch.from_numpy(h5_file['lr'][:])

# Replicate each array three times along a new dim 1 to form an RGB-like
# tensor (assumes the stored arrays are (N, H, W) - TODO confirm)
hr_tensor     = torch.stack([hr_data for _ in range(3)], 1)
lr_tensor     = torch.stack([lr_data for _ in range(3)], 1)
train_dataset = SRDataset(hr_tensor, lr_tensor)

In [5]:
# Wrap the training dataset in a DataLoader: shuffled batches of 32 samples,
# loaded by 4 worker processes
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    num_workers=4,
    shuffle=True,
    batch_size=32,
)



In [6]:
# Fine-tune the model on `train_loader`
trainer = Trainer(
    max_epochs=128,      # stop after this many epochs; adjust as needed
    accelerator='auto',  # let Lightning pick the available hardware
    callbacks=[],        # none registered; add e.g. ModelCheckpoint here
    logger=True,         # default logger: CSVLogger unless tensorboard/tensorboardX is installed
)
trainer.fit(model, train_loader)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
/scratch/coop/drozda/torch-env/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/scratch/coop/drozda/torch-env/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_da

Epoch 127: 100%|██████████| 16/16 [00:08<00:00,  1.97it/s, v_num=12, train/loss_simple_step=0.0185, train/loss_vlb_step=0.000161, train/loss_step=0.0185, global_step=2047.0, train/loss_simple_epoch=0.0232, train/loss_vlb_epoch=0.000433, train/loss_epoch=0.0232]

`Trainer.fit` stopped: `max_epochs=128` reached.


Epoch 127: 100%|██████████| 16/16 [00:11<00:00,  1.35it/s, v_num=12, train/loss_simple_step=0.0185, train/loss_vlb_step=0.000161, train/loss_step=0.0185, global_step=2047.0, train/loss_simple_epoch=0.0232, train/loss_vlb_epoch=0.000433, train/loss_epoch=0.0232]


In [7]:
# Write the trainer's current checkpoint (fine-tuned model) to disk
ckpt_tuned_path = "model_tuned.ckpt"

trainer.save_checkpoint(ckpt_tuned_path)

# Confirm where the checkpoint was written
print(f"Model saved to {ckpt_tuned_path}")

`weights_only` was not set, defaulting to `False`.


Model saved to model_tuned.ckpt


In [8]:
# Reload the fine-tuned model from the saved checkpoint (uncomment to use)
# model_tuned = instantiate_from_config(config.model)
# sd_tuned    = torch.load(ckpt_tuned_path, map_location="cpu")["state_dict"]
# model_tuned.load_state_dict(sd_tuned, strict=False)

# print(model_tuned)