In [12]:
import wandb
import torch
import h5py
import numpy as np
import diffusion_pde as dpde
from wandb.apis.public.runs import Runs
from wandb.apis.public.artifacts import Artifacts, RunArtifacts, ArtifactFiles
#from hydra import initialize, compose
from omegaconf import OmegaConf
from copy import deepcopy

In [2]:
# Authenticate with W&B, then open a public-API client used below to query runs and artifacts.
wandb.login()
api = wandb.Api()

[34m[1mwandb[0m: Currently logged in as: [33mphiliphohwy[0m ([33mphiliphohwy-danmarks-tekniske-universitet-dtu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Only the `wandb` section (entity / project) of the training config is needed to locate runs.
wandb_cfg = OmegaConf.load("../conf/train.yaml").wandb

In [4]:
filters = {}  # empty filter dict: no filtering criteria applied to the run query

# Paginated collection of runs in the configured entity/project.
runs = Runs(
    filters=filters,
    project=wandb_cfg.project,
    entity=wandb_cfg.entity,
    client=api.client,
)

In [5]:
for run in runs:
    # One line per run so the ID to select in the next cell can be read off.
    print("Run ID: {}, Name: {}".format(run.id, run.name))

Run ID: hodjisac, Name: heat-logt/forward/unet-v2
Run ID: 8nb62ytp, Name: heat-logt/forward/unet-v2
Run ID: 4labp6a8, Name: heat-logt/joint/unet-v2
Run ID: lykjcqiu, Name: heat-logt/joint/unet-v2


In [6]:
# Select the run by its stable W&B ID instead of its position in the listing: list order can
# change as runs are added or deleted, which would silently load a different model.
RUN_ID = "4labp6a8"  # heat-logt/joint/unet-v2 (was position 2 in the listing above)

run_ids = [r.id for r in runs]
if RUN_ID not in run_ids:
    raise ValueError(f"Run {RUN_ID!r} not found in {wandb_cfg.entity}/{wandb_cfg.project}")
run_idx = run_ids.index(RUN_ID)  # still defined so later cells can keep using runs[run_idx]
selected_run = runs[run_idx]

# Training config the run was logged with (model / dataset / sampling settings).
run_cfg = OmegaConf.create(selected_run.config)

artifacts = RunArtifacts(
    client=api.client,
    run=selected_run,
)

for artifact in artifacts:
    print(f"Artifact Name: {artifact.name}, Type: {artifact.type}")

Artifact Name: heat-logt-joint-unet-v2:v0, Type: model
Artifact Name: run-4labp6a8-history:v0, Type: wandb-history


In [7]:
# Dump the run's full logged config (model, dataset, sampling defaults) as YAML.
print(OmegaConf.to_yaml(cfg=run_cfg))

lr: 0.0001
model:
  name: unet-v2
  emb_ch: 256
  obs_ch: 0
  base_ch: 64
  dropout: 0
  ch_mults:
  - 1
  - 2
  - 2
  noise_ch: 64
  sigma_data: 0.5
  n_res_blocks: 2
wandb:
  dir: /home/s204790/dynamical-pde-diffusion/logs
  name: None
  entity: philiphohwy-danmarks-tekniske-universitet-dtu
  project: dynamical-pde-diffusion-final-final
epochs: 1000
dataset:
  net:
    in_ch: 2
    label_ch: 2
  data:
    pde: heat
    name: heat_logt
    datapath: data/heat_logt.hdf5
  loss: edm
  method: joint
  sampling:
    ch_a: 1
    zeta_a: 50
    zeta_u: 40
    zeta_pde: 0.4
    loss_func: diffusion_pde.sampling.heat_loss
    num_steps: 50
    batch_size: 32
    sample_shape:
    - 2
    - 64
    - 64
  training:
    shuffle: true
    batch_size: 64
    num_epochs: 1000
    physics_loss: false
    weight_decay: 0
    learning_rate: 0.0001
    physics_loss_coeff: 1
  start_at_t0: true
run_name: heat-logt/joint/unet-v2
weight_decay: 0
model_save_path: /home/s204790/dynamical-pde-diffusion/pretr

In [8]:
# Pick the checkpoint by artifact type rather than list position: the run also logs a
# `wandb-history` artifact, so index 0 is not guaranteed to be the model.
model_artifact = next((a for a in artifacts if a.type == "model"), None)
if model_artifact is None:
    raise ValueError(f"Run {runs[run_idx].id} has no artifact of type 'model'")
model_path = model_artifact.download(root=f"../pretrained_models/{runs[run_idx].id}/")

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [9]:
# Local directory the model artifact was downloaded into (the `root` passed above).
print(str(model_path))

../pretrained_models/4labp6a8/


In [10]:
# Rebuild the architecture from the run's config, then load the downloaded weights.
# map_location="cpu" lets a checkpoint saved from a GPU load on any machine; edm_sampler2
# moves the net to the sampling device itself.
edm = dpde.utils.get_net_from_config(run_cfg)
edm.load_state_dict(
    torch.load(f"{model_path}/model.pth", map_location="cpu", weights_only=True)
)

<All keys matched successfully>

In [11]:
def _karras_sigmas(net, num_steps, sigma_min, sigma_max, rho, dtype, device):
    """Return the EDM (Karras et al. 2022) noise schedule with a trailing 0 (length num_steps + 1).

    sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho for
    i = 0..N-1, passed through ``net.round_sigma`` when the net defines it.
    """
    step_idx = torch.arange(num_steps, dtype=dtype, device=device)
    # max(.., 1): with num_steps == 1 the plain i/(N-1) is 0/0 = NaN; this yields sigma_max instead.
    frac = step_idx / max(num_steps - 1, 1)
    max_inv_rho = sigma_max ** (1.0 / rho)
    min_inv_rho = sigma_min ** (1.0 / rho)
    sigmas = (max_inv_rho + frac * (min_inv_rho - max_inv_rho)) ** rho
    sigmas = getattr(net, "round_sigma", lambda s: s)(sigmas)
    return torch.cat([sigmas, torch.zeros_like(sigmas[:1])])


def _denoise(net, x, sigma, labels, dtype_f, dtype_t):
    """Call ``dpde.sampling.X_and_dXdt_fd`` at noise level ``sigma``; return both outputs in ``dtype_t``.

    The net is evaluated in ``dtype_f`` with a per-sample sigma vector of length ``x.shape[0]``.
    """
    sigma_batch = torch.full((x.shape[0],), float(sigma), device=x.device, dtype=dtype_f)
    x_N, dxdt = dpde.sampling.X_and_dXdt_fd(net, x.to(dtype_f), sigma_batch, labels)
    return x_N.to(dtype_t), dxdt.to(dtype_t)


def _guided_heun_sample(net, device, sample_shape, loss_fn, loss_fn_kwargs, labels,
                        zeta_a, zeta_u, zeta_pde, num_steps, sigma_min, sigma_max, rho,
                        to_cpu, generator, return_losses, anneal_frac, anneal_factor):
    """Core loop of ``edm_sampler2``; assumes ``net`` is already on ``device``.

    Kept separate so all device intermediates go out of scope on return, before the caller
    releases the CUDA cache.
    """
    dtype_f = torch.float32     # net runs in fp32
    dtype_t = torch.float64     # sampler state / time grid in fp64 for stability, as in EDM

    labels = labels.to(device=device, dtype=dtype_f)

    # generator=None draws from torch's default RNG (controlled by torch.manual_seed), so repeated
    # calls give different samples; pass a seeded torch.Generator for reproducible noise.
    latents = torch.randn(sample_shape, device=device, generator=generator)

    # Tensor kwargs are copied to the sampling device in fp64; other values pass through unchanged.
    loss_kwargs = {
        key: val.clone().to(device=device, dtype=dtype_t) if isinstance(val, torch.Tensor) else val
        for key, val in loss_fn_kwargs.items()
    }

    sigmas = _karras_sigmas(net, num_steps, sigma_min, sigma_max, rho, dtype_t, device)

    x_next = latents.to(dtype_t) * sigmas[0]           # start at sigma_0 = sigma_max
    losses = torch.zeros((num_steps, 4))                # per step: loss_a, loss_obs_u, loss_pde, loss_comb

    for i, (sigma_cur, sigma_next) in enumerate(zip(sigmas[:-1], sigmas[1:])):  # i = 0..N-1
        x_cur = x_next.detach().clone().requires_grad_(True)

        # Euler step sigma_cur -> sigma_next.
        x_N, dxdt = _denoise(net, x_cur, sigma_cur, labels, dtype_f, dtype_t)
        d_cur = (x_cur - x_N) / sigma_cur
        x_next = x_cur + (sigma_next - sigma_cur) * d_cur

        # Heun (2nd-order) correction, skipped on the final step where sigma_next == 0.
        if i < num_steps - 1:
            x_N, dxdt = _denoise(net, x_next, sigma_next, labels, dtype_f, dtype_t)
            d_prime = (x_next - x_N) / sigma_next
            x_next = x_cur + (sigma_next - sigma_cur) * (0.5 * d_cur + 0.5 * d_prime)

        # Guidance: the loss is evaluated on the most recent denoiser output and differentiated
        # w.r.t. x_cur (through both network evaluations).
        loss_pde, loss_a, loss_obs_u = loss_fn(x_N, dxdt, **loss_kwargs)

        # Observation weights are scaled down for the last (1 - anneal_frac) of the steps.
        if i <= anneal_frac * num_steps:
            w_a, w_u, w_pde = zeta_a, zeta_u, zeta_pde
        else:
            w_a, w_u, w_pde = anneal_factor * zeta_a, anneal_factor * zeta_u, zeta_pde

        loss_comb = w_a * loss_a + w_u * loss_obs_u + w_pde * loss_pde
        (grad_x,) = torch.autograd.grad(loss_comb, x_cur)
        x_next = (x_next - grad_x).detach()

        losses[i] = torch.tensor([loss_a.item(), loss_obs_u.item(), loss_pde.item(), loss_comb.item()])

    # Result at sigma = 0, in fp32.
    x = x_next.to(dtype_f).detach()
    if to_cpu:
        x = x.cpu()
    return x, (losses.detach().cpu().numpy() if return_losses else None)


def edm_sampler2(
    net,            # EDMWrapper (calls Unet inside)
    device,         # device to run the sampler on
    sample_shape,   # (B, C, H, W) shape of samples
    loss_fn,        # loss function to compute gradients
    loss_fn_kwargs, # extra args to pass to loss function
    labels,         # (B, label_dim) extra conditioning your Unet expects; pass zeros if unconditional
    zeta_a=1.0,     # weight for obs_a loss
    zeta_u=1.0,     # weight for obs_u loss
    zeta_pde=1.0,   # weight for pde loss
    num_steps=18,
    sigma_min=0.002,
    sigma_max=80.0,
    rho=7.0,
    to_cpu=True,
    generator=None,
    return_losses=False,
    anneal_frac=0.8,    # fraction of steps run with full observation weights
    anneal_factor=0.1,  # multiplier on zeta_a / zeta_u after that point (zeta_pde is unchanged)
):
    """Guided EDM (Heun) sampler: denoise from sigma_max to 0 while descending a guidance loss.

    At every step the sampler takes an EDM Euler step (plus a Heun correction except on the
    last step), evaluates ``loss_fn(x_N, dxdt, **loss_fn_kwargs)`` on the denoiser output, and
    subtracts the gradient of ``w_a*loss_a + w_u*loss_obs_u + w_pde*loss_pde`` w.r.t. the
    current state.

    ``loss_fn`` must return ``(loss_pde, loss_a, loss_obs_u)`` as scalar tensors. ``generator``
    controls the initial noise; ``None`` uses torch's global RNG.

    Returns:
        ``(x, losses)``: ``x`` is an fp32 tensor of ``sample_shape`` (on CPU if ``to_cpu``);
        ``losses`` is a ``(num_steps, 4)`` numpy array of
        ``[loss_a, loss_obs_u, loss_pde, loss_comb]`` per step, or ``None`` unless
        ``return_losses``.

    Side effect: ``net`` is moved to ``device`` for sampling and back to the CPU afterwards
    (also if sampling raises).

    Raises:
        ValueError: if ``num_steps < 1``.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    device = torch.device(device)

    net.to(device=device)
    try:
        return _guided_heun_sample(
            net, device, sample_shape, loss_fn, loss_fn_kwargs, labels,
            zeta_a, zeta_u, zeta_pde, num_steps, sigma_min, sigma_max, rho,
            to_cpu, generator, return_losses, anneal_frac, anneal_factor,
        )
    finally:
        net.to(device=torch.device("cpu"))
        # Only touch CUDA when sampling ran there: torch.cuda.synchronize raises without CUDA.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
            torch.cuda.empty_cache()

In [None]:
# Guided sampling for the heat equation with observation + PDE guidance.
# NOTE(review): A, U, mask_a, mask_u, T, alpha and device are not defined in any cell above
# (this cell shows In [None]); check for them up front so a fresh-kernel "Run All" fails with a
# clear message instead of a bare NameError partway through.
_required_names = ("A", "U", "mask_a", "mask_u", "T", "alpha", "device")
_missing_names = [name for name in _required_names if name not in globals()]
if _missing_names:
    raise NameError(f"Define {', '.join(_missing_names)} before running the sampling cell.")

obs_a = torch.tensor(A)
obs_u = torch.tensor(U)

dx = 1. / (A.shape[-1] - 1)  # spacing of A.shape[-1] points spanning [0, 1] — assumes a unit domain, TODO confirm

# Channel layout / spatial shape come from the run's own sampling config instead of being
# re-typed (run_cfg.dataset.sampling: ch_a=1, sample_shape=[2, 64, 64]); only the batch differs.
sampling_cfg = run_cfg.dataset.sampling
ch_a = int(sampling_cfg.ch_a)
sample_shape = (16, *(int(s) for s in sampling_cfg.sample_shape))

# Guidance weights and step count deliberately override the config defaults
# (zeta_a=50, zeta_u=40, zeta_pde=0.4, num_steps=50).
zeta_a = 30.0
zeta_u = 10.0
zeta_pde = 0.5

num_steps = 40

SEED = 0  # fixed seed so the initial noise, and hence the samples, are reproducible

t_cond = torch.ones(sample_shape[0]) * T
alpha_cond = torch.ones_like(t_cond) * alpha
labels = torch.stack([t_cond, alpha_cond], dim=-1)

loss_fn_kwargs = {
    "obs_a": obs_a,
    "obs_u": obs_u,
    "mask_a": mask_a,
    "mask_u": mask_u,
    "dx": dx,
    "dy": dx,
    "ch_a": ch_a,
    "labels": alpha_cond,
}

torch.backends.cudnn.conv.fp32_precision = 'tf32'

generator = torch.Generator(device=device).manual_seed(SEED)
samples, losses = edm_sampler2(
    net=edm,
    device=device,
    sample_shape=sample_shape,
    loss_fn=dpde.sampling.heat_loss,
    loss_fn_kwargs=loss_fn_kwargs,
    labels=labels,
    zeta_a=zeta_a,
    zeta_u=zeta_u,
    zeta_pde=zeta_pde,
    num_steps=num_steps,
    to_cpu=True,
    generator=generator,
    return_losses=True,
)
print(f"Final total loss: {losses[-1, 3]:.4f}")