In [None]:
import os
import wandb
import torch
import h5py
import numpy as np
import diffusion_pde as dpde
import matplotlib.pyplot as plt
import seaborn as sns
from wandb.apis.public.runs import Runs, Run
from wandb.apis.public.artifacts import RunArtifacts
from omegaconf import OmegaConf
from pathlib import Path
from mpl_toolkits.axes_grid1 import make_axes_locatable
from skopt import gp_minimize
from tqdm import tqdm

In [21]:
wandb.login()
API = wandb.Api()

In [22]:
wandb_cfg = OmegaConf.load("../conf/train.yaml").wandb

In [23]:
filters = {}#{"tags": "fine-tune"}

runs = Runs(
    client=API.client,
    entity=wandb_cfg.entity,
    project=wandb_cfg.project,
    filters=filters
)

In [24]:
for i, run in enumerate(runs):
    print(f"Index: {i} - Run ID: {run.id}, Name: {run.name}")

Index: 0 - Run ID: hodjisac, Name: heat-logt/forward/unet-v2
Index: 1 - Run ID: 8nb62ytp, Name: heat-logt/forward/unet-v2
Index: 2 - Run ID: 4labp6a8, Name: heat-logt/joint/unet-v2
Index: 3 - Run ID: lykjcqiu, Name: heat-logt/joint/unet-v2
Index: 4 - Run ID: rot73q78, Name: heat-logt/joint/unet-v2/fine-tune
Index: 5 - Run ID: 48vnbamd, Name: heat-logt/joint/unet-v2/fine-tune
Index: 6 - Run ID: qxoyf4d4, Name: heat-logt/joint/unet-v2/fine-tune
Index: 7 - Run ID: wu23ezqs, Name: heat-logt/joint/unet-v2/fine-tune
Index: 8 - Run ID: orbatwtx, Name: heat-logt/joint/unet-v2/fine-tune
Index: 9 - Run ID: u0316v5n, Name: heat-logt/joint/unet-v2/fine-tune
Index: 10 - Run ID: wwx06zrf, Name: heat-logt/forward/unet-v2/fine-tune
Index: 11 - Run ID: 3hh696qs, Name: heat-logt/forward/unet-v2/fine-tune


In [6]:
idx_to_compare = [3, 9]

In [47]:
run_idx = 3
run_cfg = OmegaConf.create(runs[run_idx].config)

print(f"dataset: {run_cfg.dataset.data.name}")
print(f"model: {run_cfg.model.name}")
print(f"method: {run_cfg.dataset.method}")

dataset: heat_logt
model: unet-v2
method: joint


In [None]:
def get_model(run: Run, api: wandb.Api = API) -> str:
    model_path = Path(f"../pretrained_models/{run.id}/model.pth").resolve()
    if not os.path.isfile(model_path):
        arts = RunArtifacts(client=api.client, run=run)
        _ = arts[0].download(root=f"../pretrained_models/{run.id}/")
    return model_path
    
model_path = get_model(runs[run_idx])
print(f"Model path: {model_path}")

NameError: name 'API' is not defined

In [50]:
edm = dpde.utils.get_net_from_config(run_cfg)
edm.load_state_dict(torch.load(model_path, weights_only=True))

<All keys matched successfully>

In [51]:
with h5py.File(dpde.utils.get_repo_root() / "data" / "heat_logt_validate.hdf5", "r") as f:
    data_A = f["A"][:]
    data_U = f["U"][:]
    data_labels = f["labels"][:]
    t_steps = f["t_steps"][:]
    attrs = dict(f.attrs)

dx = attrs["dx"]

In [52]:
print("Validation data attributes:")
for attr in attrs:
    print(f"  {attr}: {attrs[attr]}")

Validation data attributes:
  Lx: 1.0
  Ly: 1.0
  N: 200
  S: 64
  T: 0.5
  alpha_logrange: [-2.5  0.5]
  description: 2D heat equation with linear Dirichlet BCs, pseudospectral interior DST with lifting. log-spaced time.
  dx: 0.015873015873015872
  dy: 0.015873015873015872
  name: heat_logt_validate
  steps: 100


In [55]:
ch_a = 1 if run_cfg.dataset.method == "joint" else 0

sample_shape = (16, 2 if ch_a else 1, 64, 64)

#generator = torch.Generator().manual_seed(0)

interior_a = 0.5
interior_u = 0.0
boundary_a = 0.5
boundary_u = 0.0
same_boundary = False


def generate_masks(interior_a, interior_u, boundary_a, boundary_u, same_boundary):
    boundary_obs_a = dpde.validation.random_boundary_mask(sample_shape[2], sample_shape[3], frac_obs=boundary_a)
    if same_boundary:
        boundary_obs_u = boundary_obs_a[:]
    else:
        boundary_obs_u = dpde.validation.random_boundary_mask(sample_shape[2], sample_shape[3], frac_obs=boundary_u)
    interior_obs_a = dpde.validation.random_interior_mask(sample_shape[2], sample_shape[3], frac_obs=interior_a)
    interior_obs_u = dpde.validation.random_interior_mask(sample_shape[2], sample_shape[3], frac_obs=interior_u)

    mask_a = dpde.validation.combine_masks(boundary_obs_a, interior_obs_a)
    mask_u = dpde.validation.combine_masks(boundary_obs_u, interior_obs_u)
    return mask_a, mask_u

In [56]:
sampler = dpde.sampling.EDMHeatSampler(
    net=edm,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    sample_shape=(64, 64),
    num_channels=2 if run_cfg.dataset.method == "joint" else 1,
    num_steps=50,
    sigma_min=0.002,
    sigma_max=80.0,
    rho=7.0,
)

In [57]:
class sampling_context:
    def __init__(self, sampler: dpde.sampling.EDMHeatSampler):
        self.sampler = sampler

    def __enter__(self):
        self.prev_fp32_prec = torch.backends.cudnn.conv.fp32_precision
        torch.backends.cudnn.conv.fp32_precision = 'tf32'
        self.sampler.net.eval()
        self.sampler.net.to(self.sampler.device)
        #return self.sampler
    
    def __exit__(self, exc_type, exc_value, traceback):
        torch.backends.cudnn.conv.fp32_precision = self.prev_fp32_prec
        self.sampler.net.to(torch.device("cpu"))
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

In [58]:
NUM_TO_EVAL = 20
T = t_steps.shape[0]

rng = np.random.default_rng(seed=42)

sample_idxs = rng.choice(data_U.shape[0], size=NUM_TO_EVAL, replace=False)
t0 = rng.integers(low=0, high=T - 1, size=NUM_TO_EVAL)
remaining = T - t0
dt = rng.integers(low=0, high=remaining)
tf = t0 + dt

print("t0 indices:", t0)
print("tf indices:", tf)
assert np.all(tf < T), "tf indices must be less than T"
assert np.all(t0 >= 0), "t0 indices must be non-negative"
assert np.all(tf >= t0), "tf indices must be greater than or equal t0"

t0 indices: [63 16 75 70 35  6 97 44 89 67 77 75 19 36 46 49  4 54 15 74]
tf indices: [88 94 94 81 98 45 98 95 93 69 88 95 34 66 53 84 50 69 34 89]


In [64]:
perm_data = np.permute_dims(data_U, (0, 4, 1, 2, 3)).squeeze()

sample_data = np.stack([
    perm_data[sample_idxs, t0, ...],
    perm_data[sample_idxs, tf, ...],
], axis=1)

assert np.all(sample_data[:, 0, ...] == perm_data[sample_idxs, t0, ...]), "A data is not matching!"
assert np.all(sample_data[:, 1, ...] == perm_data[sample_idxs, tf, ...]), "U data is not matching!"

print(f"samples shape: {sample_data.shape}")  # (NUM_TO_EVAL, 2, S, S)

sample_alphas = data_labels[sample_idxs, :]
sample_tsteps = (t_steps[tf] - t_steps[t0])[:, np.newaxis]  # (NUM_TO_EVAL,)

sample_labels = np.concatenate([sample_alphas, sample_tsteps], axis=1)  # (NUM_TO_EVAL, 2)
print(f"labels shape: {sample_labels.shape}")  # (NUM_TO_EVAL, 2)

data = torch.tensor(sample_data, dtype=torch.float32)
labels = torch.tensor(sample_labels, dtype=torch.float32)

samples shape: (20, 2, 64, 64)
labels shape: (20, 2)


In [None]:
def sampler_obj_fun(params, sampler, method, data, labels, batch_size, dx, num_steps):
    N, C, H, W = data.shape
    zeta_a = zeta_u = zeta_pde = None
    if method == "joint":
        zeta_a, zeta_u, zeta_pde = params
    elif method == "forward":
        zeta_u, zeta_pde = params

    total_mse = 0.0
    with sampling_context(sampler):
        for i in tqdm(range(data.shape[0])):
            mask_a, mask_u = generate_masks(
                interior_a=interior_a,
                interior_u=interior_u,
                boundary_a=boundary_a,
                boundary_u=boundary_u,
                same_boundary=same_boundary,
            )
            obs_a = data[i, 0].expand(batch_size, 1, H, W) if method == "joint" else None
            obs_u = data[i, 1].expand(batch_size, 1, H, W)
            net_obs = data[i, 0].expand(batch_size, 1, H, W) if method == "joint" else None
            lbls = labels[i].expand(batch_size, -1)

            samples, losses = sampler.sample_conditional(
                labels=lbls,
                dx=dx,
                net_obs=net_obs,
                obs_a=obs_a,
                obs_u=obs_u,
                mask_a=mask_a,
                mask_u=mask_u,
                zeta_a=zeta_a,
                zeta_u=zeta_u,
                zeta_pde=zeta_pde,
                num_steps=num_steps,
                return_losses=False,
            )

            if method == "joint":
                total_mse += torch.mean((samples - torch.stack([obs_a, obs_u], dim=0).unsqueeze(0))**2).item()
            elif method == "forward":
                total_mse += torch.mean((samples - obs_u.unsqueeze(0).unsqueeze(0))**2).item()
            
    return total_mse / N


bounds = [
    (100.0, 20000.0),   # bounds for zeta_a, comment out if method is "forward"
    (100.0, 20000.0),   # bounds for zeta_u
    (1.0, 100.0),    # bounds for zeta_pde
]

wrapper_kwargs = {
    "sampler": sampler,
    "method": run_cfg.dataset.method,
    "data": data,
    "labels": labels,
    "batch_size": 16,
    "dx": dx,
    "num_steps": 50,
}

In [68]:
res = gp_minimize(
    func=lambda params: sampler_obj_fun(params, **wrapper_kwargs),
    dimensions=bounds,
    n_calls=30,
    n_initial_points=10,
    random_state=42,
)



In [73]:
res.keys()

dict_keys(['x', 'fun', 'func_vals', 'x_iters', 'models', 'space', 'random_state', 'specs'])

In [75]:
for k, v in res.items():
    print(f"{k}: {v}")

x: [100.0, 13722.394220050628, 1.0]
fun: 0.013466795653221198
func_vals: [1.84583912e+01 1.47831648e+01 1.16305858e+01 1.46394261e+01
 2.03273044e+01 1.46152473e+01 2.50619684e-01 7.88899866e-01
 2.47976245e+00 1.96080255e+01 4.20600257e-02 1.45072287e-02
 4.62257431e-02 1.45387571e-02 1.56014476e+01 5.06286515e-02
 4.72320216e-02 1.34667957e-02 1.36177138e-02 2.27590450e-02
 1.98527799e-02 3.96008874e-02 1.49996944e-02 6.21410084e+00
 1.49611009e-02 1.67604788e-02 1.38789765e-02 1.38827184e-02
 4.76937738e-02 9.17000175e+00]
x_iters: [[15951.205438518638, 3750.35231833666, 78.18940902700417], [11977.318143135093, 8972.071781786466, 10.897516665982288], [9239.052950120758, 6740.801361666536, 15.14381497427214], [13052.680611682175, 1222.5904226392954, 72.47787845441566], [18777.198909413433, 115.49744023618514, 99.22894436983056], [12387.882041591562, 12271.897893716792, 1.6995642167520235], [558.9422583241736, 10543.015739141947, 40.5862361998103], [1028.646697950947, 19477.7348249450