In [1]:
import sys, os, time
import torch
sys.path.append("../src/")
from run_pdebench_finetuning import get_args, get_model, build_pdebench_dataset
from engine_for_pdebench_finetuning import get_targets, unnorm_batch
import utils
from einops import rearrange
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Run directory of the fine-tuned model to evaluate, relative to the ceph root.
# Alternative run: 'pdebench_finetuning/k400_s/k400_s_turb_512_4chan_test_2'
model_dir = 'pdebench_finetuning/k400_b/k400_b_turb_512_sweeps/dcc7rvql/k400_b_turb_512_sweeps_lr_0.008750'

# Restore the arguments that were saved together with the training run.
args_json = os.path.join(utils.get_ceph_dir(), model_dir, "args.json")
args = utils.load_args(args_json)

# Evaluation-time overrides of the saved arguments, applied in order.
# The two split ratios will have to be removed for the new series of models.
# NOTE(review): unlike args_json, the checkpoint path is not prefixed with
# utils.get_ceph_dir(); presumably get_model resolves it — confirm.
for _arg_name, _arg_value in (
    ("num_workers", 1),
    ("train_split_ratio", 0.8),
    ("test_split_ratio", 0.1),
    ("device", 'cuda:0'),
    ("checkpoint", os.path.join(model_dir, 'checkpoint-49')),
):
    setattr(args, _arg_name, _arg_value)

In [3]:
# Echo the configuration that is about to be evaluated. Each attribute is read
# lazily, right before its line is printed.
for _label, _attr in (
    ("Dataset:\t\t", "data_set"),
    ("Fields:\t\t\t", "fields"),
    ("Model:\t\t\t", "model"),
    ("Checkpoint:\t\t", "checkpoint"),
    ("Batch size:\t\t", "batch_size"),
    ("Number of workers:\t", "num_workers"),
    ("Mask type:\t\t", "mask_type"),
    ("Mask ratio:\t\t", "mask_ratio"),
    ("Norm target mode:\t", "norm_target_mode"),
    ("Num frames:\t\t", "num_frames"),
    ("Device:\t\t\t", "device"),
):
    print(_label, getattr(args, _attr))

Dataset:		 compNS_turb
Fields:			 ['Vx', 'Vy', 'density', 'pressure']
Model:			 pretrain_videomae_base_patch16_512_4chan
Checkpoint:		 pdebench_finetuning/k400_b/k400_b_turb_512_sweeps/dcc7rvql/k400_b_turb_512_sweeps_lr_0.008750/checkpoint-49
Batch size:		 1
Number of workers:	 1
Mask type:		 last_frame
Mask ratio:		 0.9
Norm target mode:	 last_frame
Num frames:		 16
Device:			 cuda:0


In [4]:
device = torch.device(args.device)

# Load model
model = get_model(args)
model.to(device)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Model loaded")
print('number of params: {} M'.format(n_parameters / 1e6))

# Load the test split
dataset = build_pdebench_dataset(args, set_type='test')
# CustomNormalize object used to unnormalize data.
# NOTE(review): relies on the normalizer sitting at index 1 of the composed
# transforms — update this if the transform pipeline changes.
data_norm_tf = dataset.transform.transform.transforms[1]
# presumably disables random temporal start offsets so windows are
# deterministic — confirm against the dataset implementation.
dataset.random_start = False

# Sequential pass over the whole test split. drop_last=False so that no test
# samples are silently skipped when len(dataset) is not a multiple of
# batch_size (an evaluation must cover every sample).
sampler = torch.utils.data.SequentialSampler(dataset)
data_loader = torch.utils.data.DataLoader(
    dataset,
    sampler=sampler,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=False,
    worker_init_fn=utils.seed_worker,
)

Creating model: pretrain_videomae_base_patch16_512_4chan
Position interpolate from 8x14x14 to 8x32x32
Position interpolate from 8x14x14 to 8x32x32
Adapting checkpoint for PDEBench
Model loaded
number of params: 94.80128 M
Raw dataset compNS_turb has 1000 samples of shape (512, 512) and 21 timesteps.


In [5]:
def rearrange_ouput(output, p0=2, t=1):
    """Un-patchify flattened model predictions back into a frame layout.

    Args:
        output: tensor laid out as ``b (t h w) (p0 p1 p2 c)``. Here ``h, w``
            come from ``args.window_size[-2:]``, ``(p1, p2)`` from
            ``args.patch_size``, and ``c = len(args.fields)``.
        p0: temporal tubelet size (frames per patch). Defaults to 2, the
            value previously hard-coded here.
        t: number of tubelets along time in ``output``. Defaults to 1, for
            last-frame prediction.

    Returns:
        Tensor laid out as ``b t c p0 (h p1) (w p2)``.
    """
    p1, p2 = args.patch_size[0], args.patch_size[1]
    c = len(args.fields)
    h, w = args.window_size[-2:]
    return rearrange(output, 'b (t h w) (p0 p1 p2 c) -> b t c p0 (h p1) (w p2)',
                     p0=p0, p1=p1, p2=p2, c=c, t=t, h=h, w=w)

In [7]:
loss_func_mse = nn.MSELoss()


def _nmse_ratio(input, target):
    """Squared error normalized by target energy, reduced over the last two dims.

    Computes ``mean((input - target)^2) / mean(target^2 + 1e-7)`` over dims
    (-1, -2). The result has ``input``'s shape without its last two dims.
    """
    # The 1e-7 keeps the ratio finite for all-zero target fields.
    return (torch.mean(torch.square(input - target), dim=(-1, -2))
            / torch.mean(torch.square(target) + 1e-7, dim=(-1, -2)))


def _reduce_mean(x, mean_dim):
    """Average ``x`` over ``mean_dim``, or over every element when it is None."""
    # Explicit branch: some older torch releases reject ``Tensor.mean(dim=None)``.
    return x.mean() if mean_dim is None else x.mean(dim=mean_dim)


def loss_func_nmse(input, target, mean_dim=None):
    """Normalized MSE, averaged over ``mean_dim`` (all remaining dims if None)."""
    return _reduce_mean(_nmse_ratio(input, target), mean_dim)


def loss_func_nrmse(input, target, mean_dim=None):
    """Normalized RMSE, averaged over ``mean_dim`` (all remaining dims if None).

    The square root is taken per slice (before averaging), so this is the mean
    of per-slice nRMSE values rather than sqrt of the mean NMSE.
    """
    return _reduce_mean(torch.sqrt(_nmse_ratio(input, target)), mean_dim)

In [8]:
## 1-step predictions
# Score next-frame prediction on the test split in unnormalized units.

# Patch sizes (time, height, width); p0 = 2 matches rearrange_ouput's
# temporal tubelet size. Loop-invariant, so computed once.
p0, p1, p2 = 2, args.patch_size[0], args.patch_size[1]

losses_mse = []
losses_nmse = []
losses_nrmse = []
losses_nrmse_per_field = []

model.eval()
with torch.no_grad():
    for samples, masks in data_loader:
        samples = samples.to(device, non_blocking=True)
        samples_unnorm = data_norm_tf.unnormalize(samples.cpu())
        bool_masked_pos = masks.to(device, non_blocking=True).flatten(1).to(torch.bool)

        outputs = model(samples, bool_masked_pos)
        outputs_unnorm = unnorm_batch(outputs,
                                      norm_mode=args.norm_target_mode,
                                      patch_size=(p0, p1, p2),
                                      context=samples,
                                      bool_masked_pos=bool_masked_pos)
        # rearrange_ouput yields (b, t=1, c, p0, H, W). Index explicitly instead
        # of calling .squeeze(): squeeze drops the batch dim only when
        # batch_size == 1 (and would also drop c for a single field), which
        # silently misaligns the frame slice below.
        outputs_unnorm = data_norm_tf.unnormalize(rearrange_ouput(outputs_unnorm.cpu()))
        # Only keep the first predicted frame -> (b, c, 1, H, W)
        outputs_unnorm = outputs_unnorm[:, 0, :, :1]
        # Ground truth: first of the last p0 frames -> (b, c, 1, H, W).
        # NOTE(review): assumes samples are (b, c, t, H, W) — confirm.
        target_unnorm = samples_unnorm[:, :, -p0:][:, :, :1]
        assert outputs_unnorm.shape == target_unnorm.shape, (outputs_unnorm.shape, target_unnorm.shape)

        losses_mse.append(loss_func_mse(input=outputs_unnorm, target=target_unnorm).item())
        losses_nmse.append(loss_func_nmse(input=outputs_unnorm, target=target_unnorm).item())
        losses_nrmse.append(loss_func_nrmse(input=outputs_unnorm, target=target_unnorm).item())
        # After spatial reduction the error is (b, c, 1): average over batch
        # and time to get one nRMSE value per field.
        losses_nrmse_per_field.append(
            loss_func_nrmse(input=outputs_unnorm, target=target_unnorm, mean_dim=(0, 2)).numpy())

losses_mse = np.array(losses_mse)
losses_nmse = np.array(losses_nmse)
losses_nrmse = np.array(losses_nrmse)
losses_nrmse_per_field = np.array(losses_nrmse_per_field)

print(f"MSE: {np.mean(losses_mse):.4f} +/- {np.std(losses_mse):.4f}")
print(f"NMSE: {np.mean(losses_nmse):.4f} +/- {np.std(losses_nmse):.4f}")
print(f"NRMSE: {np.mean(losses_nrmse):.4f} +/- {np.std(losses_nrmse):.4f}")
for i, field in enumerate(args.fields):
    print(f"NRMSE {field}: {np.mean(losses_nrmse_per_field[:, i]):.4f} +/- {np.std(losses_nrmse_per_field[:, i]):.4f}")

MSE: 0.0853 +/- 0.0120
NMSE: 0.1439 +/- 0.0178
NRMSE: 0.3743 +/- 0.0229
NRMSE Vx: 0.3943 +/- 0.0460
NRMSE Vy: 0.3851 +/- 0.0406
NRMSE density: 0.2912 +/- 0.0203
NRMSE pressure: 0.4265 +/- 0.0312
