In [1]:
import logging
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# import torch.fft as fft
from torch.nn.parameter import Parameter
import matplotlib.pyplot as plt
import scipy.io as sio
# import h5py

import operator
from functools import reduce
from functools import partial
from timeit import default_timer

# Seed PyTorch's and NumPy's global RNGs so stochastic code is repeatable across runs.
torch.manual_seed(0)
np.random.seed(0)

from NLS_Residual_Loss import FNO1dComplexTime, SpectralConv1d, NLS_Residual_Loss

In [3]:

class TimeDataSet(torch.utils.data.Dataset):
    """Map-style dataset pairing each trajectory's initial condition with a later snapshot.

    Every item corresponds to one (trajectory, time index) pair for all time
    indices after the first. The model input is always built from ``X[b, 0]``;
    the target is ``X[b, k]`` at time ``t[k]`` for ``k >= 1``.

    Parameters
    ----------
    X : array_like
        Complex snapshots indexed as ``X[trajectory, time, space]``;
        ``X.shape[1]`` must equal ``t_grid.shape[-1]``.
    t_grid : array_like
        Time points; flattened, so e.g. a ``(1, n_t)`` row vector is accepted.
    x_grid : array_like
        Spatial grid; reshaped to a column vector and appended to each input.
    """

    def __init__(self, X, t_grid, x_grid):
        super(TimeDataSet, self).__init__()
        assert X.shape[1] == t_grid.shape[-1]
        self.X = torch.tensor(X, dtype=torch.cfloat)
        self.t = torch.tensor(t_grid.flatten(), dtype=torch.float)
        # Column vector so it can be concatenated as an extra input channel.
        self.x_grid = torch.tensor(x_grid, dtype=torch.float).view(-1, 1)
        # Time index 0 is the input, so only the remaining steps are targets.
        self.n_tsteps = self.t.shape[0] - 1
        self.n_batches = self.X.shape[0]
        self.dataset_len = self.n_tsteps * self.n_batches

    def make_x_train(self, x_in):
        """Build a model input from one complex snapshot.

        ``x_in`` (a 1-D complex tensor over the spatial grid) is split into
        real/imaginary channels and the spatial grid is appended, giving a
        float tensor of shape ``(grid_size, 3)``.
        """
        x_in = torch.view_as_real(x_in)
        return torch.cat([x_in, self.x_grid], dim=1)

    def __getitem__(self, idx):
        """Return ``(input, target, time)`` for flat index ``idx``.

        Time varies fastest: flat index ``b * n_tsteps + (k - 1)`` maps to
        trajectory ``b`` and time index ``k``. Negative indices count from the
        end, as for Python sequences.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[-len(self), len(self))``. Raising here
            (instead of wrapping around) is what lets plain ``for`` iteration
            over the dataset terminate.
        """
        requested = int(idx)
        flat_idx = requested + self.dataset_len if requested < 0 else requested
        if not 0 <= flat_idx < self.dataset_len:
            raise IndexError("TimeDataSet index {} out of range for length {}".format(
                requested, self.dataset_len))
        batch_idx, t_offset = divmod(flat_idx, self.n_tsteps)
        # Targets start at time index 1; index 0 is the initial condition.
        t_idx = t_offset + 1
        x = self.make_x_train(self.X[batch_idx, 0])
        y = self.X[batch_idx, t_idx]
        t = self.t[t_idx]
        return x, y, t

    def __len__(self):
        return self.dataset_len

    def __repr__(self):
        return "TimeDataSet with length {}, n_tsteps {}, n_batches {}".format(self.dataset_len,
                                                                                            self.n_tsteps,
                                                                                            self.n_batches)


In [2]:
def load_data(fp):
    """Load a MATLAB ``.mat`` file and return its ``'output'`` and ``'t'`` entries.

    A leading ``~`` in ``fp`` is expanded to the user's home directory before
    reading.
    """
    logging.info("Loading data from {}".format(fp))
    mat_path = os.path.expanduser(fp)
    contents = sio.loadmat(mat_path)
    solutions, times = contents['output'], contents['t']
    return solutions, times

def load_model(fp, device):
    """Deserialize a checkpoint written by ``torch.save``, mapping its tensors onto ``device``.

    NOTE(review): ``torch.load`` unpickles the file, which can execute
    arbitrary code -- only load trusted checkpoints. On PyTorch >= 2.6 the
    default ``weights_only=True`` may refuse whole pickled modules; confirm the
    installed version if loading a full model fails.
    """
    # Model datatypes come from train_models.py; if the checkpoint is a whole
    # pickled model, those classes must be importable here.
    return torch.load(fp, map_location=device)

def l2_normalized_error(pred, actual):
    """Relative L2 error of ``pred`` with respect to ``actual`` along the last axis.

    Parameters
    ----------
    pred : torch.Tensor or numpy.ndarray
        Predicted values (real or complex).
    actual : torch.Tensor or numpy.ndarray
        Reference values; must broadcast against ``pred``.

    Returns
    -------
    numpy.ndarray
        ``||pred - actual||_2 / ||actual||_2`` over the last axis, one entry per
        leading index. Entries where ``actual`` is identically zero come out as
        ``inf`` (or ``nan`` when ``pred`` is also zero there).
    """
    def _as_torch(value):
        # Reuse tensors as-is: torch.tensor(<tensor>) makes a needless copy and
        # emits a UserWarning recommending clone().detach().
        return value if isinstance(value, torch.Tensor) else torch.tensor(value)

    pred_t = _as_torch(pred)
    actual_t = _as_torch(actual)
    error_norms = torch.linalg.norm(pred_t - actual_t, dim=-1, ord=2)
    actual_norms = torch.linalg.norm(actual_t, dim=-1, ord=2)
    normalized_errors = torch.divide(error_norms, actual_norms)
    # .cpu() so GPU-resident inputs can also be converted to numpy.
    return normalized_errors.detach().cpu().numpy()
def prepare_input(X):
    """Convert complex snapshots into FNO input channels.

    Parameters
    ----------
    X : array_like or torch.Tensor
        Complex values indexed as ``X[batch, space]`` (e.g. ``t_dset.X[:, 0, :]``).

    Returns
    -------
    torch.Tensor
        Float32 tensor of shape ``(batch, grid_size, 3)``: real part, imaginary
        part, and the spatial coordinate on ``linspace(-pi, pi, grid_size)``.
    """
    grid_size = X.shape[-1]
    n_batches = X.shape[0]

    # Existing tensors are cast with .to() to avoid torch.tensor's copy-construct
    # warning; torch.cat below allocates a fresh tensor, so the result never
    # shares storage with X.
    if isinstance(X, torch.Tensor):
        X_complex = X.to(torch.cfloat)
    else:
        X_complex = torch.tensor(X, dtype=torch.cfloat)
    X_input = torch.view_as_real(X_complex)

    # FNO code appends the spatial grid to the input; it is sized from the data
    # so every resolution (not just 1024 points) is supported.
    x_grid = torch.linspace(-np.pi, np.pi, grid_size).view(-1, 1)
    return torch.cat((X_input, x_grid.repeat(n_batches, 1, 1)), dim=2)

In [None]:
def _notebook_global(name):
    """Return the notebook-level global ``name``; raise NameError if it is undefined."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError("name '{}' is not defined".format(name)) from None


def test_super_res_predictions(model, t_dset, fno_model=None, fno_t_model=None):
    """Evaluate three predictors on every time step of ``t_dset``.

    For each time index ``i + 1`` after the initial condition (IC):

    * ``'FNO'``    -- ``fno_model`` applied autoregressively: each step's
      prediction becomes the next step's input, starting from ``X[:, 0, :]``.
    * ``'TDLRP'``  -- ``model(IC, t[i + 1])`` plus ``t_dset.linear_part_arr[:, i + 1]``.
    * ``'TD-FNO'`` -- ``fno_t_model(IC, t[i + 1])`` evaluated directly from the IC.

    NOTE(review): despite the name, everything is evaluated at the single
    resolution of ``t_dset.X``; no down-sampled grids are tested.
    NOTE(review): ``linear_part_arr`` is not set by ``TimeDataSet`` in this
    notebook, so ``t_dset`` presumably comes from another dataset class --
    confirm.

    Parameters
    ----------
    model : callable
        Time-dependent model called as ``model(ic_input, t)``.
    t_dset : dataset
        Provides ``X`` (indexed ``[batch, time, space]``), ``t`` and
        ``linear_part_arr``.
    fno_model, fno_t_model : callable, optional
        One-step FNO and time-dependent FNO. When omitted, the notebook-level
        globals of the same names are used.

    Returns
    -------
    preds_dd : dict[str, numpy.ndarray]
        Complex predictions of shape ``(n_batch, n_time - 1, grid_size)`` per key.
    errors_dd : dict[str, numpy.ndarray]
        Relative L2 errors of shape ``(n_batch, n_time - 1)`` per key.
    """
    if fno_model is None:
        fno_model = _notebook_global('fno_model')
    if fno_t_model is None:
        fno_t_model = _notebook_global('fno_t_model')

    TEST_KEYS = ['TDLRP', 'FNO', 'TD-FNO']
    n_batches, n_times, grid_size = t_dset.X.shape
    preds_dd = {key: np.zeros((n_batches, n_times - 1, grid_size), dtype=np.cdouble)
                for key in TEST_KEYS}
    errors_dd = {key: np.zeros((n_batches, n_times - 1), dtype=np.double) for key in TEST_KEYS}

    # All-ones second argument for the one-step FNO, shaped (batch, 1, 1).
    one_tensor = torch.tensor(1, dtype=torch.float).repeat([n_batches, 1, 1])
    IC_input = prepare_input(t_dset.X[:, 0, :])

    # The autoregressive rollout starts from the IC; clone so it never shares
    # storage with IC_input, which the other two tests reuse every step.
    comp_input_i = IC_input.clone()
    # Every output is detached to numpy, so skip building autograd graphs.
    with torch.no_grad():
        for i in range(t_dset.t.shape[0] - 1):
            SOLN_I = torch.as_tensor(t_dset.X[:, i + 1, :])

            # First test: composing the one-step FNO.
            comp_preds_i = fno_model(comp_input_i, one_tensor).detach().numpy()
            preds_dd['FNO'][:, i, :] = comp_preds_i
            comp_input_i = prepare_input(comp_preds_i)
            errors_dd['FNO'][:, i] = l2_normalized_error(torch.tensor(comp_preds_i), SOLN_I)

            # Second test: TDLRP prediction from the IC at time t[i + 1].
            i_tensor = torch.as_tensor(t_dset.t[i + 1], dtype=torch.float).repeat([n_batches, 1, 1])
            preds_k = (model(IC_input, i_tensor).detach().numpy()
                       + t_dset.linear_part_arr[:, i + 1].numpy())
            preds_dd['TDLRP'][:, i, :] = preds_k
            errors_dd['TDLRP'][:, i] = l2_normalized_error(torch.tensor(preds_k), SOLN_I)

            # Third test: TD-FNO prediction from the IC at time t[i + 1].
            preds_j = fno_t_model(IC_input, i_tensor).detach().numpy()
            preds_dd['TD-FNO'][:, i] = preds_j
            errors_dd['TD-FNO'][:, i] = l2_normalized_error(torch.tensor(preds_j), SOLN_I)
    return preds_dd, errors_dd

