In [1]:
#import os
#os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import diffusion_pde as dpde
import numpy as np
import matplotlib.pyplot as plt
import torch
import h5py
from torch.utils.data import Subset
from torch.func import jvp

from pathlib import Path

In [5]:
# Scratch check: how torch.cat lays out two column vectors joined on the last axis.
b = torch.ones(4, 1)               # column of ones, shape (4, 1), default float dtype
a = torch.arange(4).unsqueeze(-1)  # column 0..3, shape (4, 1), integer dtype

# For 2-D inputs the last axis is axis 1, so the two columns end up side by side
# in a (4, 2) tensor; the printed result shows the integer column as floats.
print(torch.cat((b, a), dim=1))

tensor([[1., 0.],
        [1., 1.],
        [1., 2.],
        [1., 3.]])


In [None]:
def sample_timeseries(
    net,                # EDMWrapper (calls Unet inside)
    device,             # device to run the sampler on
    sample_shape,       # (B, C, H, W) shape of samples
    t_steps,            # array of time steps to sample over
    loss_fn,            # loss function to compute gradients
    loss_fn_kwargs,     # extra args to pass to loss function
    extra_labels=None,  # (B, label_dim-1) extra conditioning your Unet expects aside from time conditioning
    zeta_a=1.0,         # weight for obs_a loss
    zeta_u=1.0,         # weight for obs_u loss
    zeta_pde=1.0,      # weight for pde loss
    num_steps=18,
    sigma_min=0.002,
    sigma_max=80.0,
    rho=7.0,
    S_churn=0.0,
    S_min=0.0,
    S_max=float('inf'),
    S_noise=1.0,
    generator=None,
    debug=False,
):
    """Run `dpde.sampling.edm_sampler` once per entry of `t_steps`.

    For every time step a label tensor of shape `(sample_shape[0], 1)` filled
    with that time value is built, optionally concatenated with `extra_labels`
    along the last dimension, and passed to the sampler as `labels`.

    Args:
        net: Forwarded unchanged to `edm_sampler` as `net`.
        device: Forwarded unchanged to `edm_sampler` as `device`.
        sample_shape: Forwarded unchanged to `edm_sampler`; its first entry is
            also used as the number of rows of the label tensor.
        t_steps: Time values; converted with `torch.tensor(..., dtype=torch.float32)`
            and iterated along its first dimension.
        loss_fn, loss_fn_kwargs: Forwarded unchanged to `edm_sampler`.
        extra_labels: Optional tensor appended to the time column with
            `torch.cat(..., dim=-1)`; skipped when `None`.
            Presumably must have `sample_shape[0]` rows to be concatenable
            with the time column -- verify against caller.
        zeta_a, zeta_u, zeta_pde: Forwarded unchanged to `edm_sampler`.
        num_steps, sigma_min, sigma_max, rho, S_churn, S_min, S_max, S_noise,
        generator, debug: Forwarded unchanged to `edm_sampler`.

    NOTE(review): no accumulation or return of `samples` / `losses` is visible
    directly after the sampler call; each loop iteration rebinds both names.
    Confirm that the per-time-step results are collected and returned.
    """
    # NOTE(review): if callers pass an existing tensor here, torch.tensor() copies it
    # (and PyTorch may warn, recommending clone().detach()) -- confirm expected input type.
    t_steps = torch.tensor(t_steps, dtype=torch.float32)

    for i in range(t_steps.shape[0]):

        # Time-conditioning column: one row per sample, all equal to the current time step.
        # NOTE(review): no `device=` is passed, so this tensor is created on the default
        # device rather than on `device` -- confirm edm_sampler (or net) moves it.
        labels = torch.full((sample_shape[0], 1), t_steps[i])
        if extra_labels is not None:
            # Time column first, extra conditioning columns after it.
            labels = torch.cat([labels, extra_labels], dim=-1)

        # One full sampler run for this time step; all remaining arguments are passed through as given.
        samples, losses = dpde.sampling.edm_sampler(
            net=net,
            device=device,
            sample_shape=sample_shape,
            loss_fn=loss_fn,
            loss_fn_kwargs=loss_fn_kwargs,
            labels=labels,
            zeta_a=zeta_a,
            zeta_u=zeta_u,
            zeta_pde=zeta_pde,
            num_steps=num_steps,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            rho=rho,
            S_churn=S_churn,
            S_min=S_min,
            S_max=S_max,
            S_noise=S_noise,
            generator=generator,
            debug=debug,
        )