In [37]:
import sys, os, time
import torch
sys.path.append("../src/")
from run_pdebench_finetuning import get_args, get_model, build_pdebench_dataset
from engine_for_pdebench_finetuning import get_targets, unnorm_batch
import utils
from einops import rearrange
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [66]:
# Evaluation configuration: choose the fine-tuning run to evaluate and reload
# the arguments that were saved with it.
# model_dir is joined onto utils.get_ceph_dir() to locate args.json
# (presumably the root of the shared storage — confirm in utils).
model_dir = 'pdebench_finetuning/k400_b/final_runs/k400_b_rand_128_0.08'
# model_dir = 'pdebench_finetuning/k400_s/k400_s_turb_512_4chan_test_2'
args_json = os.path.join(utils.get_ceph_dir(), model_dir, "args.json")
args = utils.load_args(args_json)

# Overrides for interactive evaluation.
args.num_workers = 1
# Prefer the first GPU, but fall back to CPU so the notebook still runs on a
# machine without CUDA instead of failing when the model is moved to the device.
args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# NOTE(review): unlike args_json, this path is not prefixed with the storage
# root; get_model presumably resolves it — confirm before changing.
args.checkpoint = os.path.join(model_dir, 'checkpoint-499')

In [67]:
# Echo the settings that determine what is being evaluated, one per line.
# The labels carry their own tab padding so the values line up in the output.
config_summary = [
    ("Dataset:\t\t", args.data_set),
    ("Fields:\t\t\t", args.fields),
    ("Model:\t\t\t", args.model),
    ("Checkpoint:\t\t", args.checkpoint),
    ("Batch size:\t\t", args.batch_size),
    ("Number of workers:\t", args.num_workers),
    ("Mask type:\t\t", args.mask_type),
    ("Mask ratio:\t\t", args.mask_ratio),
    ("Norm target mode:\t", args.norm_target_mode),
    ("Num frames:\t\t", args.num_frames),
    ("Device:\t\t\t", args.device),
]
for label, value in config_summary:
    print(label, value)

Dataset:		 compNS_rand
Fields:			 ['Vx', 'Vy', 'density', 'pressure']
Model:			 pretrain_videomae_base_patch16_128_4chan_18f
Checkpoint:		 pdebench_finetuning/k400_b/final_runs/k400_b_rand_128_0.08/checkpoint-499
Batch size:		 1
Number of workers:	 1
Mask type:		 last_frame
Mask ratio:		 0.9
Norm target mode:	 last_frame
Num frames:		 18
Device:			 cuda:0


In [68]:
# Build everything the evaluation cells need: device, model, test dataset and
# a deterministic (sequential) data loader.
device = torch.device(args.device)

# Load model
model = get_model(args)
model.to(device)
# Count only the parameters that require gradients.
n_parameters = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
print("Model loaded")
print('number of params: {} M'.format(n_parameters / 1e6))

# Load dataset
dataset = build_pdebench_dataset(args, set_type='test')
# CustomNormalize object to unnormalize data; it sits at index 1 of the nested
# transform pipeline.
data_norm_tf = dataset.transform.transform.transforms[1]
# Fixed number of timesteps per sample, with random start offsets disabled.
dataset.timesteps = 21
dataset.random_start = False

# Data loader
# Swap in torch.utils.data.RandomSampler(dataset) to draw samples in random order.
sampler = torch.utils.data.SequentialSampler(dataset)
loader_kwargs = dict(
    sampler=sampler,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=True,
    worker_init_fn=utils.seed_worker,
)
data_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

Creating model: pretrain_videomae_base_patch16_128_4chan_18f
Position interpolate from 8x14x14 to 9x8x8
Position interpolate from 8x14x14 to 9x8x8
Adapting checkpoint for PDEBench
Model loaded
number of params: 94.80128 M
Loading dataset file /mnt/home/gkrawezik/ceph/AI_DATASETS/PDEBench/2D/CFD/2D_Train_Rand/2D_CFD_Rand_M0.1_Eta0.01_Zeta0.01_periodic_128_Train.hdf5
Raw dataset compNS_rand has 10000 samples of shape (128, 128) and 21 timesteps.


In [69]:
def rearrange_ouput(output):
    """Fold a sequence of flattened patches back into frame layout.

    Applies the einops pattern
    'b (t h w) (p0 p1 p2 c) -> b t c p0 (h p1) (w p2)', i.e. the token axis
    is split into (t, h, w) patch positions and the feature axis into a
    (p0, p1, p2) patch with c channels.

    The axis sizes come from the notebook-level `args`: the spatial patch
    size from `args.patch_size`, the patch grid from the last two entries of
    `args.window_size`, and the channel count from `len(args.fields)`. The
    temporal patch depth is fixed to 2 and t to 1 (last frame prediction).

    The misspelled name is kept on purpose: it is the function's interface.
    """
    patch_h, patch_w = args.patch_size[0], args.patch_size[1]
    grid_h, grid_w = args.window_size[-2:]
    axis_sizes = dict(
        p0=2,
        p1=patch_h,
        p2=patch_w,
        c=len(args.fields),
        t=1,  # For last frame prediction
        h=grid_h,
        w=grid_w,
    )
    return rearrange(
        output,
        'b (t h w) (p0 p1 p2 c) -> b t c p0 (h p1) (w p2)',
        **axis_sizes,
    )

In [70]:
# Plain mean-squared error, built from torch.nn.MSELoss with its default settings.
loss_func_mse = nn.MSELoss()

def loss_func_nmse(input, target, mean_dim=None):
    """Normalized mean-squared error.

    For every slice spanned by the last two axes, the mean squared error is
    divided by the mean of (target**2 + 1e-7); the small constant keeps the
    denominator away from zero for all-zero targets.

    Args:
        input: predicted tensor, same shape as `target`.
        target: reference tensor.
        mean_dim: axis (or axes) of the per-slice ratios to average over.
            None averages over all of them and yields a scalar tensor.

    Returns:
        Tensor of averaged ratios.
    """
    spatial_axes = (-1, -2)
    error_power = torch.square(input - target).mean(dim=spatial_axes)
    target_power = (torch.square(target) + 1e-7).mean(dim=spatial_axes)
    ratio = error_power / target_power
    return ratio.mean(dim=mean_dim)

def loss_func_nrmse(input, target, mean_dim=None):
    """Normalized root-mean-squared error.

    For every slice spanned by the last two axes, takes the square root of
    (mean squared error / mean of (target**2 + 1e-7)). The root is taken per
    slice, before any averaging over `mean_dim`.

    Args:
        input: predicted tensor, same shape as `target`.
        target: reference tensor.
        mean_dim: axis (or axes) of the per-slice values to average over.
            None averages over all of them and yields a scalar tensor.

    Returns:
        Tensor of averaged per-slice NRMSE values.
    """
    spatial_axes = (-1, -2)
    error_power = torch.square(input - target).mean(dim=spatial_axes)
    target_power = (torch.square(target) + 1e-7).mean(dim=spatial_axes)
    per_slice = torch.sqrt(error_power / target_power)
    return per_slice.mean(dim=mean_dim)

In [71]:
# ## 1-step predictions (cell disabled; the active evaluation cell below also reports 1-step metrics)

# losses_mse = []
# losses_nmse = []
# losses_nrmse = []
# losses_nrmse_per_field = []

# model.eval()
# for samples, masks in data_loader:
#     samples = samples.to(device, non_blocking=True)
#     samples_unnorm = data_norm_tf.unnormalize(samples.cpu())
    
#     bool_masked_pos = masks.to(device, non_blocking=True).flatten(1).to(torch.bool)

#     p0, p1, p2 = 2, args.patch_size[0], args.patch_size[1]
#     nchan = samples.shape[1]
#     target = get_targets(samples, bool_masked_pos, args.norm_target_mode, p0=p0, p1=p1, p2=p2)
#     target_unnorm = samples_unnorm[:, :, -2:, :, :].squeeze()

#     with torch.no_grad():
#         outputs = model(samples, bool_masked_pos)
#         outputs_unnorm = unnorm_batch(outputs,
#                                       norm_mode=args.norm_target_mode,
#                                       patch_size=(p0, p1, p2),
#                                       context=samples,
#                                       bool_masked_pos=bool_masked_pos)
#         outputs_unnorm = data_norm_tf.unnormalize(rearrange_ouput(outputs_unnorm.cpu())).squeeze()

#         # Only keep first frame
#         outputs_unnorm = outputs_unnorm[:, :1]
#         target_unnorm = target_unnorm[:, :1]

#         loss_mse = loss_func_mse(input=outputs_unnorm, target=target_unnorm)
#         loss_nmse = loss_func_nmse(input=outputs_unnorm, target=target_unnorm)
#         loss_nrmse = loss_func_nrmse(input=outputs_unnorm, target=target_unnorm)
#         loss_nrmse_per_field = loss_func_nrmse(input=outputs_unnorm, target=target_unnorm, mean_dim=1)
        
#         loss_mse_value = loss_mse.item()
#         loss_nmse_value = loss_nmse.item()
#         loss_nrmse_value = loss_nrmse.item()
#         loss_nrmse_per_field_value = loss_nrmse_per_field.numpy()

#         losses_mse.append(loss_mse_value)
#         losses_nmse.append(loss_nmse_value)
#         losses_nrmse.append(loss_nrmse_value)
#         losses_nrmse_per_field.append(loss_nrmse_per_field_value)

# losses_mse = np.array(losses_mse)
# losses_nmse = np.array(losses_nmse)
# losses_nrmse = np.array(losses_nrmse)
# losses_nrmse_per_field = np.array(losses_nrmse_per_field)

# print(f"MSE: {np.mean(losses_mse):.4f} +/- {np.std(losses_mse):.4f}")
# print(f"NMSE: {np.mean(losses_nmse):.4f} +/- {np.std(losses_nmse):.4f}")
# print(f"NRMSE: {np.mean(losses_nrmse):.4f} +/- {np.std(losses_nrmse):.4f}")
# for i in range(len(args.fields)):
#     print(f"NRMSE {args.fields[i]}: {np.mean(losses_nrmse_per_field[:, i]):.4f} +/- {np.std(losses_nrmse_per_field[:, i]):.4f}")

In [72]:
## 1 and 5 step prediction
#
# Autoregressive rollout over the test loader. For each sample, a window of
# args.num_frames frames is slid forward n_pred_frames times; at every
# position the model fills in the masked patches and one reconstructed frame
# is written back into the sequence, so later windows see earlier predictions.
# The last n_pred_frames frames are then compared with the ground truth in
# unnormalized units.
#
# Metrics are computed per sample by iterating over the batch axis, so the
# field/time axes stay in place for any batch size (a bare .squeeze() only
# gives that layout when the batch size is exactly 1).

# Unused in this cell (metrics use loss_func_mse); retained in case later
# cells rely on the name.
loss_func = nn.MSELoss()

n_pred_frames = 5

losses_mse_1_step = []
losses_mse_5_step = []
losses_nmse_1_step = []
losses_nmse_5_step = []
losses_nrmse_1_step = []
losses_nrmse_5_step = []
losses_nrmse_per_field_1_step = []
losses_nrmse_per_field_5_step = []

model.eval()
for samples_base, masks in data_loader:
    # Layout is (batch, field, time, height, width), as required by the
    # 'b c (t p0) (h p1) (w p2)' rearrange pattern used below.
    samples_base = samples_base.to(device, non_blocking=True)
    samples_truth_unnorm = data_norm_tf.unnormalize(samples_base.cpu())

    bool_masked_pos = masks.to(device, non_blocking=True).flatten(1).to(torch.bool)

    p0, p1, p2 = 2, args.patch_size[0], args.patch_size[1]
    h, w = args.input_size // p1, args.input_size // p2
    nchan = samples_base.shape[1]  # unused here; retained in case later cells rely on it

    # Extend samples_base by one copy of its last frame so that the final
    # window of the rollout has args.num_frames frames.
    # Assumes each sample holds args.num_frames + n_pred_frames - 2 frames
    # (21 with the settings above) — TODO confirm against the dataset.
    samples_base = torch.cat([samples_base, samples_base[:, :, -1:]], dim=2)

    for i in range(args.num_frames, args.num_frames + n_pred_frames):
        samples = samples_base[:, :, i-args.num_frames:i]

        with torch.no_grad():
            outputs = model(samples, bool_masked_pos)
            pred_patches = unnorm_batch(outputs,
                                        norm_mode=args.norm_target_mode,
                                        patch_size=(p0, p1, p2),
                                        context=samples,
                                        bool_masked_pos=bool_masked_pos)

            # Patchify the window, overwrite the masked patches with the
            # prediction, then fold the patches back into frames.
            recon_full = rearrange(samples, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2 c)', p0=p0, p1=p1, p2=p2)
            recon_full[bool_masked_pos] = pred_patches.flatten(start_dim=0, end_dim=1)
            recon_full = rearrange(recon_full, 'b (t h w) (p0 p1 p2 c) -> b c (t p0) (h p1) (w p2)', p0=p0, p1=p1, p2=p2, h=h, w=w)
            # Keep only the second-to-last reconstructed frame; it replaces
            # frame i-2 of the sequence and feeds the next window.
            samples_base[:, :, i-2] = recon_full[:, :, -2]

    # Drop the padding frame appended before the rollout.
    samples_base = samples_base[:, :, :-1]

    target_batch = samples_truth_unnorm[:, :, -n_pred_frames:]
    outputs_batch = data_norm_tf.unnormalize(samples_base[:, :, -n_pred_frames:].cpu())

    # One metrics entry per sample; each tensor here is (field, time, H, W).
    for target_unnorm, outputs_unnorm in zip(target_batch, outputs_batch):
        # 1-step metrics only look at the first predicted frame.
        outputs_1_step = outputs_unnorm[:, :1]
        target_1_step = target_unnorm[:, :1]

        losses_mse_1_step.append(loss_func_mse(input=outputs_1_step, target=target_1_step).item())
        losses_mse_5_step.append(loss_func_mse(input=outputs_unnorm, target=target_unnorm).item())
        losses_nmse_1_step.append(loss_func_nmse(input=outputs_1_step, target=target_1_step).item())
        losses_nmse_5_step.append(loss_func_nmse(input=outputs_unnorm, target=target_unnorm).item())
        losses_nrmse_1_step.append(loss_func_nrmse(input=outputs_1_step, target=target_1_step).item())
        losses_nrmse_5_step.append(loss_func_nrmse(input=outputs_unnorm, target=target_unnorm).item())
        # mean_dim=1 averages over time only, leaving one value per field.
        losses_nrmse_per_field_1_step.append(
            loss_func_nrmse(input=outputs_1_step, target=target_1_step, mean_dim=1).numpy())
        losses_nrmse_per_field_5_step.append(
            loss_func_nrmse(input=outputs_unnorm, target=target_unnorm, mean_dim=1).numpy())

losses_mse_1_step = np.array(losses_mse_1_step)
losses_mse_5_step = np.array(losses_mse_5_step)
losses_nmse_1_step = np.array(losses_nmse_1_step)
losses_nmse_5_step = np.array(losses_nmse_5_step)
losses_nrmse_1_step = np.array(losses_nrmse_1_step)
losses_nrmse_5_step = np.array(losses_nrmse_5_step)
losses_nrmse_per_field_1_step = np.array(losses_nrmse_per_field_1_step)
losses_nrmse_per_field_5_step = np.array(losses_nrmse_per_field_5_step)

# Mean and standard deviation across samples. The "5 step" rows cover all
# n_pred_frames predicted frames.
print(f"MSE 1 step: {np.mean(losses_mse_1_step):.4f} +/- {np.std(losses_mse_1_step):.4f}")
print(f"MSE 5 step: {np.mean(losses_mse_5_step):.4f} +/- {np.std(losses_mse_5_step):.4f}")
print(f"NMSE 1 step: {np.mean(losses_nmse_1_step):.4f} +/- {np.std(losses_nmse_1_step):.4f}")
print(f"NMSE 5 step: {np.mean(losses_nmse_5_step):.4f} +/- {np.std(losses_nmse_5_step):.4f}")
print(f"NRMSE 1 step: {np.mean(losses_nrmse_1_step):.4f} +/- {np.std(losses_nrmse_1_step):.4f}")
print(f"NRMSE 5 step: {np.mean(losses_nrmse_5_step):.4f} +/- {np.std(losses_nrmse_5_step):.4f}")
for i in range(len(args.fields)):
    print(f"NRMSE 1 step {args.fields[i]}: {np.mean(losses_nrmse_per_field_1_step[:, i]):.4f} +/- {np.std(losses_nrmse_per_field_1_step[:, i]):.4f}")
    print(f"NRMSE 5 step {args.fields[i]}: {np.mean(losses_nrmse_per_field_5_step[:, i]):.4f} +/- {np.std(losses_nrmse_per_field_5_step[:, i]):.4f}")

MSE 1 step: 0.0014 +/- 0.0020
MSE 5 step: 0.0023 +/- 0.0030
NMSE 1 step: 0.0023 +/- 0.0037
NMSE 5 step: 0.0070 +/- 0.0182
NRMSE 1 step: 0.0330 +/- 0.0163
NRMSE 5 step: 0.0495 +/- 0.0291
NRMSE 1 step Vx: 0.0576 +/- 0.0290
NRMSE 5 step Vx: 0.0892 +/- 0.0543
NRMSE 1 step Vy: 0.0605 +/- 0.0339
NRMSE 5 step Vy: 0.0921 +/- 0.0576
NRMSE 1 step density: 0.0122 +/- 0.0085
NRMSE 5 step density: 0.0150 +/- 0.0097
NRMSE 1 step pressure: 0.0015 +/- 0.0011
NRMSE 5 step pressure: 0.0018 +/- 0.0015
