In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from RecursiveFNO import RecursiveFNO
from glob import glob
from torch.utils.data import TensorDataset, DataLoader
from neuralop.models.base_model import get_model
from configmypy import ConfigPipeline, YamlConfig, ArgparseConfig


# Pin the run to GPU 0 (only takes effect if CUDA has not been initialised yet)
# and seed every RNG so the evaluation is reproducible.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.manual_seed(66)
np.random.seed(66)
torch.set_default_dtype(torch.float32)

BATCH_SIZE = 1
time_steps = 100       # Total time steps per simulation
# TIME_BATCH_SIZE was previously used here without ever being defined, which
# raised NameError on Restart & Run All.
# NOTE(review): assumed to cover the whole rollout (== time_steps, matching the
# RecursiveFNO(..., time_steps + 1) construction below) — confirm against training.
TIME_BATCH_SIZE = time_steps
steps = TIME_BATCH_SIZE + 1
model_select = "0pde_1mse"  # 0pde_1mse for FNO, 1pde_1mse for PIFNO

In [None]:
def load_checkpoint(model, optimizer, scheduler, save_dir, map_location=None):
    '''Restore model (and optionally optimizer/scheduler) state from a checkpoint.

    Args:
        model: module whose weights are loaded from checkpoint['model_state_dict'].
        optimizer: optimizer to restore from checkpoint['optimizer_state_dict'],
            or None to skip optimizer state (e.g. for inference).
        scheduler: LR scheduler to restore from checkpoint['scheduler_state_dict'],
            or None to skip it. Checked independently of ``optimizer``.
        save_dir: path of the checkpoint file passed to ``torch.load``.
        map_location: forwarded to ``torch.load``; lets a checkpoint saved on one
            device be loaded on another. Defaults to None (original behaviour).

    Returns:
        (model, optimizer, scheduler) — the same objects that were passed in.
    '''
    checkpoint = torch.load(save_dir, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Guarded separately: passing an optimizer without a scheduler used to
    # crash with AttributeError on None.load_state_dict.
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    print('Pretrained model loaded!')
    return model, optimizer, scheduler

In [None]:
# Build the FNO backbone from the YAML config, wrap it in the recursive
# rollout model and restore the trained weights for inference.
config_name = "default"
pipe = ConfigPipeline(
    [
        # Use config_name rather than repeating the literal, so changing it above takes effect.
        YamlConfig("burgers_pino_config.yaml", config_name=config_name, config_folder="config"),
    ]
)
config = pipe.read_conf()
config_name = pipe.steps[-1].config_name

CHECKPOINT_PATH = "model/checkpoint1000_" + model_select + ".pt"

model_fno = get_model(config).cuda()
# time_steps + 1 presumably includes the initial condition — TODO confirm in RecursiveFNO.
model = RecursiveFNO(model_fno, time_steps + 1).cuda()
model, _, _ = load_checkpoint(model, None, None, CHECKPOINT_PATH)
model.eval()

In [None]:
def read_data(datapath):
    '''Load every simulation file matching a glob pattern into one stacked array.

    File names are expected to end in ``_<R>_<a>_<w>.npy`` with integer fields.
    R / 10000 is appended to each sample as an extra constant channel.

    Args:
        datapath: glob pattern selecting the ``.npy`` files to load.

    Returns:
        (vel_seq_whole, R, a, w): ``vel_seq_whole`` is a float32 array with one
        leading entry per file (files in sorted path order). Each file's array is
        extended by the constant channel on its last axis and transposed with
        axes (2, 3, 1, 0). ``R``, ``a`` and ``w`` are int arrays parsed from the
        file names, in the same order.

    Raises:
        ValueError: if no file matches ``datapath``.
    '''
    # glob order depends on the filesystem; sort so sample order is reproducible.
    files = sorted(glob(datapath))
    if not files:
        raise ValueError("read_data: no files match pattern %r" % datapath)

    vel_seq_whole = []
    R = []
    a = []
    w = []
    for each_file in files:
        # splitext drops exactly the extension; str.strip(".npy") would strip any
        # of the characters '.', 'n', 'p', 'y' from both ends of the name.
        stem = os.path.splitext(os.path.basename(each_file))[0]
        vis, aa, ww = stem.split("_")[-3:]
        R.append(int(vis))
        a.append(int(aa))
        w.append(int(ww))
        vis_value = np.float32(vis) / 10000.0
        # assumes each file holds a 4-D array (channels on the last axis) — TODO confirm
        sim_data = np.load(each_file).astype(np.float32)
        # Keep the constant channel float32: np.empty defaulted to float64, which
        # upcast the whole dataset and doubled its memory footprint.
        vis_data = np.full(sim_data.shape[:3] + (1,), vis_value, dtype=np.float32)
        run_data = np.concatenate([sim_data, vis_data], axis=-1).transpose(2, 3, 1, 0)
        vel_seq_whole.append(run_data)
    vel_seq_whole = np.stack(vel_seq_whole, axis=0)
    return vel_seq_whole, np.array(R), np.array(a), np.array(w)

In [None]:
# Roll the model out from each test initial condition and save the predicted
# trajectories, one .npy per simulation, named after its (R, a, w) parameters.
TEST_REYNOLDS = [100, 500, 3000, 6500, 12500, 15000]
pred_dir = os.path.join("data/pred_fno", model_select)
os.makedirs(pred_dir, exist_ok=True)  # np.save does not create missing folders

with torch.no_grad():
    for r_tmp in TEST_REYNOLDS:
        file_reg = "data/test_data/burgers_test_%i_*.npy" % r_tmp
        test_seq_clipped, r_list, a_list, w_list = read_data(file_reg)
        test_ic = torch.Tensor(test_seq_clipped[:, 0, :, :, :])
        test_seq = torch.Tensor(test_seq_clipped[:, 1:, :2, :, :])
        test_dataset = TensorDataset(test_ic, test_seq)
        test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        # Running sample counter: with shuffle=False the loader yields samples in
        # read_data order, so this indexes r_list/a_list/w_list for any BATCH_SIZE
        # (the old batch_id index was only correct for BATCH_SIZE == 1).
        sample_id = 0
        for ic, _ in test_dataloader:
            # assumes model output is (batch, time, channel, H, W) — TODO confirm;
            # drop the last time step, keep the first two channels, swap axes 2 and 4.
            pred_batch = model(ic.cuda())[:, :-1, :2, :, :].transpose(2, 4).cpu().numpy()
            for pred in pred_batch:
                fname = os.path.join(pred_dir, "burgers_test_%i_%i_%i.npy" % (r_list[sample_id],
                                                                             a_list[sample_id],
                                                                             w_list[sample_id]))
                np.save(fname, pred)
                sample_id += 1