In [1]:
#import os
#os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import diffusion_pde as dpde
from diffusion_pde.pdes import generate_heat
import numpy as np
import matplotlib.pyplot as plt
import torch
import h5py
from torch.utils.data import Subset
from torch.func import jvp

from pathlib import Path    

In [2]:
data_name = "logt"
model_path = Path().cwd().parent / "pretrained_models" / f"heat_{data_name}.pth"
data_path = Path().cwd().parent / "data" / f"heat_{data_name}.hdf5"
model_path.exists(), data_path.exists()

(True, True)

In [3]:
with h5py.File(data_path, "r") as f:
    attrs = dict(f.attrs)
    data_A = f["test/A"][:]  # (N, 1, S, S)
    data_U = f["test/U"][:]  # (N, 1, S, S, steps+1)
    data_labels = f["test/labels"][:]  # (N,)
    t_steps = f["t_steps"][:]  # (steps+1,)
print("A shape: ", data_A.shape)  # (N, 1, S, S)
print("U shape: ", data_U.shape)  # (N, 1, S, S, steps+1)
print("Labels shape: ", data_labels.shape)  # (N,)
print("t_steps shape: ", t_steps.shape)  # (steps+1,)

[print(k, ":", v) for k, v in attrs.items()];

A shape:  (200, 1, 64, 64)
U shape:  (200, 1, 64, 64, 65)
Labels shape:  (200, 1)
t_steps shape:  (65,)
Lx : 1.0
Ly : 1.0
S : 64
T : 0.5
alpha_logrange : [-2.  0.]
description : 2D heat equation with linear Dirichlet BCs, data generated with sine-pseudospectral method with lifting. Time steps in log-scale.
dx : 0.015873015873015872
name : heat_logt
num_test : 200
num_train : 800
steps : 64


In [10]:
T = attrs["T"]
S = attrs["S"]
Lx = attrs["Lx"]
Ly = attrs["Ly"]

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


  return torch._C._cuda_getDeviceCount() > 0


In [12]:
NUM_TEST_SAMPLES = 50
T_STEPS = 100
SEED = 42

t_steps = torch.linspace(0, T, T_STEPS + 1)
dt = t_steps[1:] - t_steps[:-1]



In [14]:
U, A, t_steps, labels = generate_heat(
    N=NUM_TEST_SAMPLES,
    B=50,
    S=S,
    steps=T_STEPS,
    dt=dt,
    Lx=Lx,
    Ly=Ly,
    device="cuda",
    ic_seed=SEED
)

RuntimeError: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW

In [None]:
def validate_timeseries(
    data_test,          # test data shape: (N, C_in + C_out, H, W)
    net,                # EDMWrapper (calls Unet inside)
    device,             # device to run the sampler on  
    sample_shape,       # (B, C_in + C_out, H, W) shape of samples
    t_steps,            # array of time steps to sample over
    loss_fn,            # loss function to compute gradients
    loss_fn_kwargs,     # extra args to pass to loss function
    extra_labels=None,  # (B, label_dim-1) extra conditioning your Unet expects aside from time conditioning
    zeta_a=1.0,         # weight for obs_a loss
    zeta_u=1.0,         # weight for obs_u loss
    zeta_pde=1.0,      # weight for pde loss
    num_steps=18,
    sigma_min=0.002,
    sigma_max=80.0,
    rho=7.0,
    S_churn=0.0,
    S_min=0.0,
    S_max=float('inf'),
    S_noise=1.0,
    generator=None,
    debug=False,
):  
    errors = np.zeros((t_steps.shape[0], sample_shape[0]))
    t_steps = torch.tensor(t_steps, dtype=torch.float32)

    for i in range(t_steps.shape[0]):

        labels = torch.full((sample_shape[0], 1), t_steps[i])
        if extra_labels is not None:
            labels = torch.cat([labels, extra_labels], dim=-1)

        samples, _ = dpde.sampling.edm_sampler(
            net=net,
            device=device,
            sample_shape=sample_shape,
            loss_fn=loss_fn,
            loss_fn_kwargs=loss_fn_kwargs,
            labels=labels,
            zeta_a=zeta_a,
            zeta_u=zeta_u,
            zeta_pde=zeta_pde,
            num_steps=num_steps,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            rho=rho,
            S_churn=S_churn,
            S_min=S_min,
            S_max=S_max,
            S_noise=S_noise,
            generator=generator,
            debug=debug,
        )

        errors[i, :] = torch.norm(samples - data_test[i, ...], dim=(1,2,3)).detach().cpu().numpy()

    return errors