In [1]:
import time
import h5py
import torch
import numpy as np
import diffusion_pde as dpde
import matplotlib.pyplot as plt
from pathlib import Path
from omegaconf import OmegaConf
from hydra import initialize, compose

In [2]:
# Data files for this notebook are read from <repo root>/data.
repo_root = dpde.utils.get_repo_root()
data_dir = repo_root / "data"

In [3]:
# Load the raw heat-equation dataset and its metadata attributes.
with h5py.File(data_dir / "heat_no_cond.hdf5", "r") as f:
    # Printed shape below is (N, 1, S, S, 2); the role of the singleton axis is not
    # visible from this cell — TODO confirm (e.g. a time/snapshot axis).
    data = f["U"][:]
    # Materialize labels while the file is still open: an h5py Dataset handle is
    # invalid once the `with` block closes the file, so keeping `f["labels"]` itself
    # would leave a dead object behind.
    labels = f["labels"][()] if "labels" in f else None
    attrs = dict(f.attrs)

# Report only the labels' shape (not the full array) to keep the output readable.
print(f"Data shape: {data.shape}, Labels: {None if labels is None else np.shape(labels)}")
for key, value in attrs.items():
    print(f"  {key}: {value}")

Data shape: (3000, 1, 64, 64, 2), Labels: None
  Lx: 1.0
  Ly: 1.0
  N: 3000
  S: 64
  T: 0.005
  dx: 0.015873015873015872
  dy: 0.015873015873015872
  notes: Heat equation dataset without conditioning: u_t = u_xx + u_yy, Dirichlet BCs with linear lift.


In [17]:
# Compose the training config without a Hydra-decorated entry point, selecting the
# unconditional heat dataset and the U-Net model via overrides.
# NOTE(review): config_path is relative ("../conf"); this assumes the notebook sits one
# directory below the repo's `conf/` folder — TODO confirm if the notebook is moved.
with initialize(config_path="../conf", version_base=None):
    cfg = compose(
        config_name="train",
        overrides=["dataset=heat_no_cond", "model=unet"],
    )

In [19]:
# Build the network described by `cfg` and report how many trainable parameters it has.
edm = dpde.utils.get_net_from_config(cfg)
n_trainable = sum(param.numel() for param in edm.parameters() if param.requires_grad)
print(f"num parameters: {n_trainable}")

num parameters: 902466


In [20]:
# Build the loaders from `cfg`; the result is unpacked into train and validation loaders.
loaders = dpde.datasets.get_dataloaders(cfg)
trainloader, valloader = loaders

In [21]:
# Smoke test: run the network once on a single training batch and inspect the output.
batch = next(iter(trainloader))
X = batch["X"]  # (B, C, H, W)
labels = batch["labels"]  # (B, label_ch) or None

# Per-sample noise levels drawn around 0.5. They are created on X's device so the call
# also works when batches live on GPU. The clamp guards against the (rare) non-positive
# draw from the Normal, since a noise level must be positive.
sigma = (torch.randn(X.shape[0], device=X.device) * 0.1 + 0.5).clamp_min(1e-3)  # (B,)

# Inference only: skip building an autograd graph for this check.
with torch.no_grad():
    out = edm(X, sigma, labels=labels)

out.shape