In [46]:
import torch
import h5py
import diffusion_pde as dpde
import numpy as np
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from pathlib import Path

In [119]:
class PositionalEmbedding(torch.nn.Module):
    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device)
        freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** freqs
        x = x.ger(freqs.to(x.dtype))
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x


class UnetBlock(torch.nn.Module):
    def __init__(self, 
        in_ch: int, 
        out_ch: int, 
        emb_ch: int, 
        mode ="", 
        act_fn: torch.nn.Module = torch.nn.SiLU, 
        dropout: float = 0.1,
        skip_scale: float = 2**-0.5
    ):
        super().__init__()
        
        self.skip_scale = skip_scale

        if mode == "down":
            self.conv1 = torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
        elif mode == "":
            self.conv1 = torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        elif mode == "up":
            self.conv1 = torch.nn.Sequential(
                torch.nn.Upsample(scale_factor=2, mode='nearest'),
                torch.nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
            )
            #self.conv1 = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1)

        self.conv2 = torch.nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        torch.nn.init.zeros_(self.conv2.weight)
        torch.nn.init.zeros_(self.conv2.bias)

        self.emb_layer = torch.nn.Linear(emb_ch, out_ch)
        self.act_fn = act_fn()

        self.norm1 = torch.nn.GroupNorm(32 if in_ch % 32 == 0 else in_ch, in_ch)
        self.norm2 = torch.nn.GroupNorm(32 if out_ch % 32 == 0 else out_ch, out_ch)

        self.dropout = torch.nn.Dropout(dropout)

        if mode == "" and in_ch == out_ch:
            self.skip = torch.nn.Identity()
        elif mode == "down":
            self.skip = torch.nn.Sequential(
                torch.nn.AvgPool2d(2),
                torch.nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else torch.nn.Identity()
            )
        elif mode == "up":
            self.skip = torch.nn.Sequential(
                torch.nn.Upsample(scale_factor=2, mode='nearest'),
                torch.nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else torch.nn.Identity(),
            )
        else:
            self.skip = torch.nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        orig = x
        emb = self.emb_layer(emb).unsqueeze(-1).unsqueeze(-1)

        x = self.conv1(self.act_fn(self.norm1(x)))
        x = x + emb
        x = self.conv2(self.dropout(self.act_fn(self.norm2(x))))
        x = x + self.skip(orig)

        return x * self.skip_scale
    
    
class Unet2(torch.nn.Module):
    '''
    Unet taken from deep learning course.
    '''
    
    def __init__(
        self,
        chs: list[int], # list of channels including input channel size: (ch_in, ch_1, ..., ch_n), length n+1
        label_ch: int, # label dimension (class label/ time etc)
        noise_ch: int = 32, # embedding channel size 
        act_fn: torch.nn.Module = torch.nn.SiLU,
        debug: bool = False
    ):
        super().__init__()
        self.act_fn = act_fn
        self.debug = debug

        self.down_blocks = torch.nn.ModuleDict()
        self.up_blocks = torch.nn.ModuleDict()

        # create encoder blocks 
        for i in range(len(chs)-1):
            mode = "down" if i < len(chs)-2 else ""
            in_ch = chs[i] * 2 if i == 0 else chs[i]
            out_ch = chs[i+1]
            self.down_blocks[f"down_{in_ch}->{out_ch}_{mode}"] = UnetBlock(
                in_ch=in_ch,
                out_ch=chs[i+1],
                emb_ch=noise_ch,
                mode=mode,
                act_fn=act_fn,
            )
        

        # create decoder blocks
        for i in range(len(chs)-1, 0, -1):
            mode = "up" if i < len(chs)-1 else ""
            in_ch = chs[i] * 2 if i < len(chs)-1 else chs[i]
            out_ch = chs[i-1]
            self.up_blocks[f"up_{chs[i]}->{chs[i-1]}_{mode}"] = UnetBlock(
                in_ch=in_ch,
                out_ch=out_ch,
                emb_ch=noise_ch,
                mode=mode,
                act_fn=act_fn,
            )


        self.sigma_embedding = PositionalEmbedding(noise_ch)
        self.linear_label = torch.nn.Linear(label_ch, noise_ch)

        

    def forward(self, x, sigma, labels, obs) -> torch.Tensor:
        x = torch.cat([x, obs], dim=1)  # concatenate input and observation along channel dimension
        emb_sigma = self.sigma_embedding(sigma)
        emb_label = self.linear_label(labels)

        skips = []
        for i, down_block in enumerate(self.down_blocks.values()):
            x = down_block(x, emb_sigma + emb_label)
            if i < len(self.down_blocks) - 1:
                skips.append(x)
                if self.debug:
                    print(f"Skip Block {i}: {x.shape}")
            if self.debug:
                print(f"Down Block {i}: {x.shape}")

        for i, up_block in enumerate(self.up_blocks.values()):
            if i > 0:
                skip = skips.pop()
                if self.debug:
                    print(f"Using Skip Block {len(self.down_blocks)-2 - i}: {skip.shape}")
                x = torch.cat([x, skip], dim=1)
            x = up_block(x, emb_sigma + emb_label)
            if self.debug:
                print(f"Up Block {i}: {x.shape}")

        return x

chs = [1, 32, 64, 128, 256]
label_ch = 2
noise_ch = 32
net = Unet2(chs=chs, label_ch=label_ch, noise_ch=noise_ch, debug=False)
print(net)

Unet2(
  (down_blocks): ModuleDict(
    (down_2->32_down): UnetBlock(
      (conv1): Conv2d(2, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (emb_layer): Linear(in_features=32, out_features=32, bias=True)
      (act_fn): SiLU()
      (norm1): GroupNorm(2, 2, eps=1e-05, affine=True)
      (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (skip): Sequential(
        (0): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (1): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (down_32->64_down): UnetBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (emb_layer): Linear(in_features=32, out_features=64, bias=True)
      (act_fn): SiLU()
      (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
      (no

In [120]:
test_x = torch.randn(10, 1, 64, 64)
test_sigma = torch.randn(10)
test_labels = torch.randn(10, 2)
test_obs = torch.randn(10, 1, 64, 64)

out = net(test_x, test_sigma, test_labels, test_obs)
print("Output shape:", out.shape)

Output shape: torch.Size([10, 1, 64, 64])


In [121]:
class EDMWrapper(torch.nn.Module):
    def __init__(self, 
        unet: torch.nn.Module,
        sigma_data: float = 0.5,
    ):
        super().__init__()
        self.unet = unet
        self.sigma_data = sigma_data

    def forward(self, x, sigma, *args) -> torch.Tensor:
        # x has shape (b, ch, h, w) and sigma has shape (b,)
        # both should be dtype float32 for now
        sigma = torch.reshape(sigma, (-1, 1, 1, 1))

        # weights given by the EDM paper.
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / torch.sqrt(sigma ** 2 + self.sigma_data ** 2)
        c_in = 1 / torch.sqrt(sigma ** 2 + self.sigma_data ** 2)
        c_noise = torch.flatten(torch.log(sigma) / 4).to(torch.float32)

        F_x = self.unet(c_in * x, c_noise, *args)   # output of u-net
        D_x = c_skip * x + c_out * F_x          # denoised data

        return D_x

In [122]:
edm = EDMWrapper(net, sigma_data=0.5)
out = edm(test_x, test_sigma, test_labels, test_obs)

In [123]:
class DiffusionDataset2(torch.utils.data.Dataset):
    """
    Diffusion dataset compatible with torch DataLoader.
    Each item is a tuple (X, label) where X is the concatenation of the
    initial state and a snapshot of the trajectory at time t. the label
    is a vector containing the time t and any additional labels.
    Note that t is sampled randomly for each item from the t_steps provided.

    Parameters
    ----------
    data : torch.Tensor
        Tensor of shape (N, ch_u, h, w, T) representing the trajectories.
    t_steps : torch.Tensor
        1D tensor of shape (T,) representing the time steps corresponding to the last dimension of `data`.
    labels : Optional[torch.Tensor], optional
        Optional tensor of shape (N, label_dim) representing additional labels, by default None.
    generator : Optional[torch.Generator], optional
        Random number generator for reproducibility, by default None.
    """
    def __init__(self,
        data: np.ndarray, 
        t_steps: np.ndarray, 
        labels: np.ndarray | None = None,
        generator: torch.Generator | None = None
    ) -> None:
        super().__init__()
        # assume data is (N, ch_u, h, w, t)
        assert len(data.shape) == 5, f"Dimensions of 'data' should be (N, ch_u, h, w, t) but got {data.shape}"

        self.data = torch.from_numpy(data).float()
        self.labels = torch.from_numpy(labels).float() if labels is not None else None
        if self.labels is not None and self.labels.ndim == 1:
            self.labels = self.labels.reshape((-1, 1))  # ensure labels is (N, label_dim):
        self.t_steps = torch.from_numpy(t_steps).float()

        self.N, self.T = data.shape[0], data.shape[-1]

        self.g = generator

    def __len__(self) -> int:
        return self.N

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        
        # sample random timestep
        t0_idx = torch.randint(0, self.T-1, (1,), generator=self.g).item()
        tf_idx = torch.randint(t0_idx + 1, self.T, (1,), generator=self.g).item()

        # slice data snapshot at timestep
        # get corresponding time value
        X = self.data[idx, ..., t0_idx]  # shape (ch_u, h, w)
        Y = self.data[idx, ..., tf_idx]  # shape (ch_u, h, w)
        label = self.t_steps[tf_idx] - self.t_steps[t0_idx]  # time difference as label
        if self.labels is not None:
            label = torch.cat((torch.tensor([label]), self.labels[idx]), dim=0)
        return X, Y, label

In [124]:
class EDMLoss:
    '''
    taken from "elucidating the design space..." paper:
    https://github.com/NVlabs/edm/blob/main/training/loss.py
    '''
    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, reduce=True, reduce_method="mean"):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.reduce = reduce
        self.reduce_method = reduce_method

    def __call__(self, net, x, *args):
        rnd_normal = torch.randn([x.shape[0], 1, 1, 1], device=x.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        n = torch.randn_like(x) * sigma
        D_yn = net(x + n, sigma.flatten(), *args)
        loss = weight * ((D_yn - x) ** 2)

        if not self.reduce:
            return loss

        if self.reduce_method == "mean":
            return loss.mean(dim=(1,2,3))
        elif self.reduce_method == "sum":
            return loss.sum(dim=(1,2,3))

In [125]:
data_path = Path("/home/s204790/dynamical-pde-diffusion/data/heat_logt.hdf5")

with h5py.File(data_path, "r") as f:
    data = f["U"][:]
    t_steps = f["t_steps"][:]
    labels = f["labels"][:]

In [126]:
dataset = DiffusionDataset2(data=data, t_steps=t_steps, labels=labels)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

In [127]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
edm = edm.to(device)

In [128]:
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params}")

Number of trainable parameters: 1969491


In [129]:
epochs = 100
print_every = 5

optimizer = torch.optim.Adam(edm.parameters(), lr=1e-4)
loss_fn = EDMLoss()

for epoch in range(epochs):
    for (X, Y, labels) in dataloader:
        X, Y, labels = X.to(device), Y.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = loss_fn(edm, Y, labels, X).sum()
        loss.backward()
        optimizer.step()
    if epoch % print_every == print_every - 1:
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

Epoch 5/100, Loss: 4.7524
Epoch 10/100, Loss: 2.5002
Epoch 15/100, Loss: 3.9297
Epoch 20/100, Loss: 3.1842
Epoch 25/100, Loss: 2.2436
Epoch 30/100, Loss: 2.0965
Epoch 35/100, Loss: 2.2918
Epoch 40/100, Loss: 2.2818
Epoch 45/100, Loss: 2.0736
Epoch 50/100, Loss: 1.6578
Epoch 55/100, Loss: 1.7171
Epoch 60/100, Loss: 1.5671
Epoch 65/100, Loss: 1.5321
Epoch 70/100, Loss: 1.7482
Epoch 75/100, Loss: 1.3385
Epoch 80/100, Loss: 1.5424
Epoch 85/100, Loss: 1.4815
Epoch 90/100, Loss: 1.4001
Epoch 95/100, Loss: 0.8724
Epoch 100/100, Loss: 0.9480
