In [1]:
from PARCtorch.PARCv2 import PARCv2
from PARCtorch.differentiator.differentiator import Differentiator
from PARCtorch.differentiator.finitedifference import FiniteDifference
from PARCtorch.integrator.integrator import Integrator
from PARCtorch.integrator.heun import Heun
from PARCtorch.utilities.unet import UNet

In [2]:
import torch
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from glob import glob
from torch.optim import Adam
from tqdm import tqdm

### Model declaration

In [3]:
# Burgers: u, v
# Adv: all vars
# Dif: all vars
n_fe_features = 64
unet_burgers = UNet(
    [64, 64 * 2, 64 * 4],
    3,
    n_fe_features,
    up_block_use_concat=[False, True],
    skip_connection_indices=[0],
).cuda()
right_diff = FiniteDifference(padding_mode="replicate").cuda()
heun_int = Heun().cuda()
diff_burgers = Differentiator(
    1,  # 1 state variables: mu. We always assume 2 velocity being the last 2 channels
    n_fe_features,  # Number of features returned by the feature extraction network: 64
    [1, 2],  # Channel indices to calculate advection: u and v
    [1, 2],  # Channel indices to calculate diffusion: u and v
    unet_burgers,  # Feature extraction network: unet_burgers
    "constant",  # Padding mode: constant padding of zero
    right_diff,  # Finite difference method: replication of image_gradients
).cuda()
burgers_int = Integrator(
    True, [], heun_int, [None, None, None], "constant", right_diff
)
criterion = torch.nn.L1Loss().cuda()
model = PARCv2(diff_burgers, burgers_int, criterion).cuda()

### Data loading

In [4]:
batch_size = 8
npy_patterns = "/home/xc7ts/experiments/burgers_2d/train_data/*.npy"
future_steps = 1
max_epoch = 10

In [5]:
class BurgersSimulation(Dataset):
    def __init__(
        self,
        npy_path,
        future_steps=1,
        max_len=None,
        timestep=None,
        min_val=torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32),
        max_val=torch.tensor([15000.0, 1.0, 1.0], dtype=torch.float32),
        epsilon=1e-9,
    ):
        self.npy_path = npy_path
        self.future_steps = future_steps
        self.min_val = min_val
        self.max_val = max_val
        if max_len is not None:
            self.max_len = max_len
        else:
            tmp = np.load(npy_path, mmap_mode="r")
            self.max_len = tmp.shape[2] - self.future_steps
        if timestep is not None:
            self.timestep = timestep
        else:
            self.timestep = 1.0 / self.max_len
        self.re = self.extract_Re_number(npy_path)
        self.t0 = 0.0
        self.t1 = (
            torch.tensor(range(1, self.future_steps + 1), dtype=torch.float32)
            * self.timestep
        )
        self.epsilon = epsilon

    def extract_Re_number(self, filename):
        base_name = os.path.basename(
            filename
        )  # e.g., 'burgers_train_7500_9_8.npy'
        try:
            parts = base_name.split("_")
            Re_str = parts[2]  # '7500' in 'burgers_train_7500_9_8.npy'
            Re_number = float(Re_str)
        except (IndexError, ValueError) as e:
            raise ValueError(
                f"Filename '{filename}' is not in the expected format 'burgers_train_<Re>_*.npy'."
            ) from e
        return Re_number

    def __len__(self):
        return self.max_len

    def __getitem__(self, i):
        tmp = np.load(self.npy_path, mmap_mode="r")
        all_snaps = torch.tensor(
            tmp[:, :, i : i + self.future_steps + 1, :], dtype=torch.float32
        ).permute(2, 3, 0, 1)
        all_re = (
            torch.ones(
                all_snaps.shape[0],
                1,
                all_snaps.shape[2],
                all_snaps.shape[3],
                dtype=torch.float32,
            )
            * self.re
        )
        sim = torch.cat([all_re, all_snaps], 1)
        # Normalization
        sim = (sim - self.min_val[None, :, None, None]) / (
            self.max_val[None, :, None, None]
            - self.min_val[None, :, None, None]
            + self.epsilon
        )
        return sim[0], self.t0, self.t1, sim[1:]


def simulation_collate_fn(batch):
    ic, t0, t1, target = zip(*batch)
    # Stack the initial conditions into a tensor
    ic = torch.stack(ic, dim=0)  # (batch_size, 3, 64, 64)
    # Since t0 is always 0.0, return a single scalar tensor
    t0 = torch.tensor(0.0, dtype=torch.float32)  # Scalar tensor
    # Since t1 is consistent across all samples, take the first one
    t1 = t1[0]  # (future_steps,)
    # Stack targets into a tensor and permute to match desired shape
    target = torch.stack(target, dim=0).permute(
        1, 0, 2, 3, 4
    )  # (future_steps, batch_size, 3, 64, 64)
    return ic, t0, t1, target


list_burgers_train_dataset = []
for each_npy in glob(npy_patterns):
    list_burgers_train_dataset.append(BurgersSimulation(each_npy))
burgers_train_dataset = ConcatDataset(list_burgers_train_dataset)
train_dataloader = DataLoader(
    burgers_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    collate_fn=simulation_collate_fn,
)

In [6]:
len(train_dataloader)

1250

In [7]:
for batch in train_dataloader:
    ic, t0, t1, target = batch
    print("ic shape:", ic.shape)  # (batch_size, 3, 64, 64)
    print("t0:", t0)  # 0.0
    print("t1:", t1)  # (future_steps,)
    print("target shape:", target.shape)
    break

ic shape: torch.Size([8, 3, 64, 64])
t0: tensor(0.)
t1: tensor([0.0100])
target shape: torch.Size([1, 8, 3, 64, 64])


### Training

In [8]:
model.train()
optimizer = Adam(model.parameters(), lr=1e-5)

In [9]:
for epoch in range(max_epoch):
    epoch_loss = 0.0
    for each_train_data in tqdm(train_dataloader, total=len(train_dataloader)):
        ic, t0, t1, gt = each_train_data
        optimizer.zero_grad()
        pred = model(ic.cuda(), t0.cuda(), t1.cuda())
        loss = criterion(pred[:, :, 1:, :, :], gt.cuda()[:, :, 1:, :, :])
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= len(train_dataloader)
    print(f"Epoch [{epoch+1}/{max_epoch}], Loss: {epoch_loss:.6f}")

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:46<00:00, 26.83it/s]


Epoch [1/10], Loss: 0.000815


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:44<00:00, 28.02it/s]


Epoch [2/10], Loss: 0.000169


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:44<00:00, 27.92it/s]


Epoch [3/10], Loss: 0.000096


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:45<00:00, 27.67it/s]


Epoch [4/10], Loss: 0.000073


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:44<00:00, 28.12it/s]


Epoch [5/10], Loss: 0.000063


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:45<00:00, 27.70it/s]


Epoch [6/10], Loss: 0.000055


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:44<00:00, 27.78it/s]


Epoch [7/10], Loss: 0.000051


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:44<00:00, 27.93it/s]


Epoch [8/10], Loss: 0.000046


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:44<00:00, 27.96it/s]


Epoch [9/10], Loss: 0.000043


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1250/1250 [00:45<00:00, 27.70it/s]

Epoch [10/10], Loss: 0.000040





### Prediction

In [10]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation


vmin = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
vmax = torch.tensor([15000.0, 1.0, 1.0], dtype=torch.float32)

t0 = torch.tensor(0.0, dtype=torch.float32).cuda()
t1 = torch.tensor(range(1, 101), dtype=torch.float32) * 0.01
t1 = t1.cuda()

model.load_state_dict(
    torch.load("parcv2_burger/checkpoints/epoch_000490.pt", weights_only=True)
)
model.eval()
os.makedirs("parcv2_burger/test_animation", exist_ok=True)
for each in glob("/home/xc7ts/experiments/burgers_2d/test_data/*.npy"):
    # Re
    base_name = os.path.basename(each)
    parts = base_name.split("_")
    re = float(parts[2])
    # Initial condition
    test_case = (
        torch.tensor(np.load(each), dtype=torch.float32)
        .permute(2, 3, 0, 1)
        .unsqueeze(1)
    )
    test_re = (
        torch.ones(
            test_case.shape[0],
            test_case.shape[1],
            1,
            test_case.shape[3],
            test_case.shape[4],
            dtype=torch.float32,
        )
        * re
    )
    test_gt = torch.cat([test_re, test_case], 2)
    test_gt = (test_gt - vmin[None, None, :, None, None]) / (
        vmax[None, :, None, None] - vmin[None, :, None, None] + 1e-9
    )
    test_ic = test_gt[0, :, :, :, :].cuda()
    with torch.no_grad():
        pred = model(test_ic, t0, t1)
    vel_mag = torch.sqrt(
        pred[:, 0, 1, :, :] * pred[:, 0, 1, :, :]
        + pred[:, 0, 2, :, :] * pred[:, 0, 2, :, :]
    )
    vel_mag = vel_mag.detach().cpu().numpy()
    # Animation
    vel_mag_gt = (
        torch.sqrt(
            test_gt[1:, 0, 1, :, :] * test_gt[1:, 0, 1, :, :]
            + test_gt[1:, 0, 2, :, :] * test_gt[1:, 0, 2, :, :]
        )
        .detach()
        .cpu()
        .numpy()
    )
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(18, 9))
    p_max = np.max(vel_mag_gt)
    im0 = ax0.imshow(vel_mag_gt[0, :, :], vmin=0.0, vmax=p_max)
    ax0.set_title("GT")
    im1 = ax1.imshow(vel_mag[0, :, :], vmin=0.0, vmax=p_max)
    ax1.set_title("PARCtorch")

    def animate(i):
        im0.set_array(vel_mag_gt[i, :, :])
        im1.set_array(vel_mag[i, :, :])
        return im0, im1

    anim_path = (
        "parcv2_burger/test_animation/" + base_name.strip(".npy") + ".gif"
    )
    anim = animation.FuncAnimation(fig, animate, frames=100)
    anim.save(anim_path, fps=5)
    plt.close(fig)

RuntimeError: Error(s) in loading state_dict for PARCv2:
	Missing key(s) in state_dict: "differentiator.list_adv.1.cdiff.dy_filter", "differentiator.list_adv.1.cdiff.dx_filter", "differentiator.list_dif.1.cdiff.dy_filter", "differentiator.list_dif.1.cdiff.dx_filter", "differentiator.list_mar.1.spade.spade1.initial_conv.0.weight", "differentiator.list_mar.1.spade.spade1.initial_conv.0.bias", "differentiator.list_mar.1.spade.spade1.gamma_conv.weight", "differentiator.list_mar.1.spade.spade1.gamma_conv.bias", "differentiator.list_mar.1.spade.spade1.beta_conv.weight", "differentiator.list_mar.1.spade.spade1.beta_conv.bias", "differentiator.list_mar.1.spade.conv1.weight", "differentiator.list_mar.1.spade.conv1.bias", "differentiator.list_mar.1.spade.spade2.initial_conv.0.weight", "differentiator.list_mar.1.spade.spade2.initial_conv.0.bias", "differentiator.list_mar.1.spade.spade2.gamma_conv.weight", "differentiator.list_mar.1.spade.spade2.gamma_conv.bias", "differentiator.list_mar.1.spade.spade2.beta_conv.weight", "differentiator.list_mar.1.spade.spade2.beta_conv.bias", "differentiator.list_mar.1.spade.conv2.weight", "differentiator.list_mar.1.spade.conv2.bias", "differentiator.list_mar.1.spade.spade_skip.initial_conv.0.weight", "differentiator.list_mar.1.spade.spade_skip.initial_conv.0.bias", "differentiator.list_mar.1.spade.spade_skip.gamma_conv.weight", "differentiator.list_mar.1.spade.spade_skip.gamma_conv.bias", "differentiator.list_mar.1.spade.spade_skip.beta_conv.weight", "differentiator.list_mar.1.spade.spade_skip.beta_conv.bias", "differentiator.list_mar.1.spade.conv_skip.weight", "differentiator.list_mar.1.spade.conv_skip.bias", "differentiator.list_mar.1.resnet.conv1.0.weight", "differentiator.list_mar.1.resnet.conv1.0.bias", "differentiator.list_mar.1.resnet.conv2.0.weight", "differentiator.list_mar.1.resnet.conv2.0.bias", "differentiator.list_mar.1.resnet.path.0.conv1.0.weight", "differentiator.list_mar.1.resnet.path.0.conv1.0.bias", "differentiator.list_mar.1.resnet.path.0.conv2.0.weight", "differentiator.list_mar.1.resnet.path.0.conv2.0.bias", "differentiator.list_mar.1.resnet.path.1.conv1.0.weight", "differentiator.list_mar.1.resnet.path.1.conv1.0.bias", "differentiator.list_mar.1.resnet.path.1.conv2.0.weight", "differentiator.list_mar.1.resnet.path.1.conv2.0.bias", "differentiator.list_mar.1.conv_out.weight", "differentiator.list_mar.1.conv_out.bias". 
	Unexpected key(s) in state_dict: "differentiator.list_adv.3.cdiff.dy_filter", "differentiator.list_adv.3.cdiff.dx_filter", "differentiator.list_dif.3.cdiff.dy_filter", "differentiator.list_dif.3.cdiff.dx_filter", "differentiator.list_mar.2.spade.spade1.initial_conv.0.weight", "differentiator.list_mar.2.spade.spade1.initial_conv.0.bias", "differentiator.list_mar.2.spade.spade1.gamma_conv.weight", "differentiator.list_mar.2.spade.spade1.gamma_conv.bias", "differentiator.list_mar.2.spade.spade1.beta_conv.weight", "differentiator.list_mar.2.spade.spade1.beta_conv.bias", "differentiator.list_mar.2.spade.conv1.weight", "differentiator.list_mar.2.spade.conv1.bias", "differentiator.list_mar.2.spade.spade2.initial_conv.0.weight", "differentiator.list_mar.2.spade.spade2.initial_conv.0.bias", "differentiator.list_mar.2.spade.spade2.gamma_conv.weight", "differentiator.list_mar.2.spade.spade2.gamma_conv.bias", "differentiator.list_mar.2.spade.spade2.beta_conv.weight", "differentiator.list_mar.2.spade.spade2.beta_conv.bias", "differentiator.list_mar.2.spade.conv2.weight", "differentiator.list_mar.2.spade.conv2.bias", "differentiator.list_mar.2.spade.spade_skip.initial_conv.0.weight", "differentiator.list_mar.2.spade.spade_skip.initial_conv.0.bias", "differentiator.list_mar.2.spade.spade_skip.gamma_conv.weight", "differentiator.list_mar.2.spade.spade_skip.gamma_conv.bias", "differentiator.list_mar.2.spade.spade_skip.beta_conv.weight", "differentiator.list_mar.2.spade.spade_skip.beta_conv.bias", "differentiator.list_mar.2.spade.conv_skip.weight", "differentiator.list_mar.2.spade.conv_skip.bias", "differentiator.list_mar.2.resnet.conv1.0.weight", "differentiator.list_mar.2.resnet.conv1.0.bias", "differentiator.list_mar.2.resnet.conv2.0.weight", "differentiator.list_mar.2.resnet.conv2.0.bias", "differentiator.list_mar.2.resnet.path.0.conv1.0.weight", "differentiator.list_mar.2.resnet.path.0.conv1.0.bias", "differentiator.list_mar.2.resnet.path.0.conv2.0.weight", "differentiator.list_mar.2.resnet.path.0.conv2.0.bias", "differentiator.list_mar.2.resnet.path.1.conv1.0.weight", "differentiator.list_mar.2.resnet.path.1.conv1.0.bias", "differentiator.list_mar.2.resnet.path.1.conv2.0.weight", "differentiator.list_mar.2.resnet.path.1.conv2.0.bias", "differentiator.list_mar.2.conv_out.weight", "differentiator.list_mar.2.conv_out.bias", "differentiator.feature_extraction.downBlocks.2.doubleConv.0.weight", "differentiator.feature_extraction.downBlocks.2.doubleConv.0.bias", "differentiator.feature_extraction.downBlocks.2.doubleConv.2.weight", "differentiator.feature_extraction.downBlocks.2.doubleConv.2.bias", "differentiator.feature_extraction.downBlocks.3.doubleConv.0.weight", "differentiator.feature_extraction.downBlocks.3.doubleConv.0.bias", "differentiator.feature_extraction.downBlocks.3.doubleConv.2.weight", "differentiator.feature_extraction.downBlocks.3.doubleConv.2.bias", "differentiator.feature_extraction.upBlocks.2.doubleConv.0.weight", "differentiator.feature_extraction.upBlocks.2.doubleConv.0.bias", "differentiator.feature_extraction.upBlocks.2.doubleConv.2.weight", "differentiator.feature_extraction.upBlocks.2.doubleConv.2.bias", "differentiator.feature_extraction.upBlocks.3.doubleConv.0.weight", "differentiator.feature_extraction.upBlocks.3.doubleConv.0.bias", "differentiator.feature_extraction.upBlocks.3.doubleConv.2.weight", "differentiator.feature_extraction.upBlocks.3.doubleConv.2.bias", "integrator.list_poi.0.poisson.cdiff.dy_filter", "integrator.list_poi.0.poisson.cdiff.dx_filter", "integrator.list_poi.0.conv.0.weight", "integrator.list_poi.0.conv.0.bias", "integrator.list_poi.0.conv.2.weight", "integrator.list_poi.0.conv.2.bias", "integrator.list_poi.0.conv.4.conv1.0.weight", "integrator.list_poi.0.conv.4.conv1.0.bias", "integrator.list_poi.0.conv.4.conv2.0.weight", "integrator.list_poi.0.conv.4.conv2.0.bias", "integrator.list_poi.0.conv.4.path.0.conv1.0.weight", "integrator.list_poi.0.conv.4.path.0.conv1.0.bias", "integrator.list_poi.0.conv.4.path.0.conv2.0.weight", "integrator.list_poi.0.conv.4.path.0.conv2.0.bias", "integrator.list_poi.0.conv.4.path.1.conv1.0.weight", "integrator.list_poi.0.conv.4.path.1.conv1.0.bias", "integrator.list_poi.0.conv.4.path.1.conv2.0.weight", "integrator.list_poi.0.conv.4.path.1.conv2.0.bias", "integrator.list_poi.0.conv.5.weight", "integrator.list_poi.0.conv.5.bias". 
	size mismatch for differentiator.feature_extraction.doubleConv.0.weight: copying a param with shape torch.Size([64, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
	size mismatch for differentiator.feature_extraction.upBlocks.0.doubleConv.0.weight: copying a param with shape torch.Size([512, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 3, 3]).
	size mismatch for differentiator.feature_extraction.upBlocks.0.doubleConv.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for differentiator.feature_extraction.upBlocks.0.doubleConv.2.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 1, 1]).
	size mismatch for differentiator.feature_extraction.upBlocks.0.doubleConv.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for differentiator.feature_extraction.upBlocks.1.doubleConv.0.weight: copying a param with shape torch.Size([256, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 192, 3, 3]).
	size mismatch for differentiator.feature_extraction.upBlocks.1.doubleConv.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for differentiator.feature_extraction.upBlocks.1.doubleConv.2.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 1, 1]).
	size mismatch for differentiator.feature_extraction.upBlocks.1.doubleConv.2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for differentiator.feature_extraction.finalConv.0.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 1, 1]).
	size mismatch for differentiator.feature_extraction.finalConv.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for differentiator.feature_extraction.finalConv.2.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 1, 1]).
	size mismatch for differentiator.feature_extraction.finalConv.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([64]).