In [8]:
import os
import matplotlib.pyplot as plt
import scipy.io as sio
import torch
import numpy as np
import pandas as pd
import logging

In [2]:
WD = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/'
os.chdir(WD)
print(os.getcwd())

/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals


In [3]:
from train_models import SpectralConv1d, FNO1dComplexTime, TimeDataSetResiduals

In [4]:
TRAIN_DF = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/results/00_residual_train.txt'
TEST_DF = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/results/00_residual_test.txt'
DATA_FP = '/local/meliao/projects/fourier_neural_operator/data/2021-06-24_NLS_data_04_test.mat'
MODEL_FP = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/models/00_residual_ep_500'
BASELINE_FP = '/local/meliao/projects/fourier_neural_operator/experiments/08_FNO_pretraining/models/00_pretrain_ep_1000'
PLOTS_DIR = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/plots/'



In [11]:
train_df = pd.read_table(TRAIN_DF)
test_df = pd.read_table(TEST_DF)
fp_train_test = os.path.join(PLOTS_DIR, 'first_train_test.png')

In [19]:
def make_train_test_plot(a_train, a_test, fp=None, title=""):
    fig, ax = plt.subplots(1, 2, sharey=False)
    fig.patch.set_facecolor('white')


    # a_train and a_test are the time-dependent FNO data. They're in the first column
    ax[0].set_title("Train")
    ax[0].plot(a_train.epoch, a_train.MSE, '-', color='red', label='train')
    ax[1].plot(a_test.epoch, a_test.test_mse, '--', color='red', label='test')
    ax[0].set_xlabel("Epoch", fontsize=13)
    ax[1].set_xlabel("Epoch", fontsize=13)
    # ax[0].legend()
    ax[0].set_yscale('log')

    ax[0].set_ylabel("MSE", fontsize=13)

    # b_train and b_test are the time-dependent. They're in the seecond column
    ax[1].set_title("Test")
    # ax[1].plot(b_train.epoch, b_train.MSE, '-', color='blue', label='train')
    # ax[1].plot(b_test.epoch, b_test.test_mse, '--', color='blue', label='test')
    ax[1].set_xlabel("Epoch", fontsize=13)
    ax[1].set_yscale('log')
    # ax[1].legend(fontsize=13)

    fig.suptitle(title)

    plt.tight_layout()

    if fp is not None:
        plt.savefig(fp)
    else:
        plt.show()
    plt.close(fig)

In [20]:
TRAIN_PATTERN = "/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/results/residuals_lr_exp_{}_train.txt"
TEST_PATTERN = "/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/results/residuals_lr_exp_{}_test.txt"
for i in [-5, -4, -3, -2.5, -2, -2.5, -1, -0.5, 0, 1]:
    df_test_i = pd.read_table(TEST_PATTERN.format(i))
    df_train_i = pd.read_table(TRAIN_PATTERN.format(i))
    fp_out_i = os.path.join(PLOTS_DIR, 'train_test_{}.png'.format(i))
    t = "Learning Rate: 10e{}".format(i)
    make_train_test_plot(df_train_i, df_test_i, fp=fp_out_i, title=t)
    

In [18]:
make_train_test_plot(train_df, test_df, fp=fp_train_test)

In [12]:
d = sio.loadmat(DATA_FP)
model = torch.load(MODEL_FP, map_location='cpu')
emulator = torch.load(BASELINE_FP, map_location='cpu')

In [13]:
t_dset = TimeDataSetResiduals(d['output'][:, :7], d['t'][:, :7], d['x'], emulator)
t_dloader = torch.utils.data.DataLoader(t_dset, batch_size=1, shuffle=False)

In [19]:
def plot_check(x, y, t, preds, resid, fp=None):

    # X has size (grid_size, 3) with the columns being (Re(u_0), Im(u_0), x)
    fig, ax=plt.subplots(nrows=1, ncols=4)
    fig.set_size_inches(15,10) #15, 20 works well
    fig.patch.set_facecolor('white')

    x_real = x[:, 0].flatten()
    x_imag = x[:, 1].flatten()
    # print("X_REAL:", x_real.shape, "X_IMAG:", x_imag.shape)
    # print("PREDS_REAL:", np.real(preds).shape, "PREDS_IMAG:", np.imag(preds).shape)
    # print("Y_REAL:", np.real(y).shape, "Y_IMAG:", np.imag(y).shape)

    ax[0].set_title("$Re(u)$")
    ax[0].plot(x_real, label='Input')
    ax[0].plot(np.real(y), label='Soln') 
    ax[0].plot(np.real(preds), '--', label='Pred')

    ax[0].legend()  

    ax[1].set_title("Residuals: $Re(u)$")
    ax[1].plot(np.real(y) - np.real(preds), color='red', label='actual')
    ax[1].plot(np.real(resid), color='green', label='predicted')
    ax[1].legend()
    
    ax[2].set_title("$Im(u)$")
    ax[2].plot(x_imag, label='Input')
    ax[2].plot(np.imag(y), label='Soln')
    ax[2].plot(np.imag(preds), '--', label='Pred')
    ax[2].legend()

    ax[3].set_title("Residuals: $Im(u)$")

    ax[3].plot(np.imag(y) - np.imag(preds), color='red', label='actual')
    ax[3].plot(np.imag(resid), color='green', label='predicted')

    ax[3].legend()
    
    plt.tight_layout()
    # plt.title("T = {}".format(t))
    if fp is not None:
        plt.savefig(fp)
    else:
        plt.show()
    plt.clf()

In [20]:
model = torch.load(MODEL_FP, map_location=torch.device('cpu'))

In [21]:
n = 0
for x_i, y_i, t_i, preds_i in t_dloader:
    # x_i, y_i, t_i, preds_i = t_dset[i]
    # print(x_i.shape)
    model_resid = model(x_i, t_i)
    fp_i = os.path.join(PLOTS_DIR, 'test_case_{}.png'.format(n))    
    plot_check(x_i[0], y_i[0], t_i.item(), preds_i[0], model_resid[0].detach().numpy(), fp=fp_i)
    n += 1
    if n >= 5:
        break

<Figure size 1080x720 with 0 Axes>

<Figure size 1080x720 with 0 Axes>

<Figure size 1080x720 with 0 Axes>

<Figure size 1080x720 with 0 Axes>

<Figure size 1080x720 with 0 Axes>

In [9]:
def load_data(fp):
    logging.info("Loading data from {}".format(fp))
    data = sio.loadmat(os.path.expanduser(fp))
    return data['output'], data['t']

def load_model(fp, device):
    # Model datatypes are loaded from train_models.py
    model = torch.load(fp, map_location=device)
    return model

def l2_normalized_error(pred, actual):
    """Short summary.

    Parameters
    ----------
    pred : type
        Description of parameter `pred`.
    actual : type
        Description of parameter `actual`.

    Returns
    -------
    types
        Description of returned object.

    """
    errors = pred - actual
    error_norms = torch.linalg.norm(torch.tensor(errors), dim=-1, ord=2)
    actual_norms = torch.linalg.norm(torch.tensor(actual), dim=-1, ord=2)
    normalized_errors = torch.divide(error_norms, actual_norms)
    return normalized_errors.detach().numpy()
def prepare_input(X):
    # X has shape (nbatch, 1, grid_size)
    s = X.shape[-1]
    n_batches = X.shape[0]

    # Convert to tensor
    X_input = torch.view_as_real(torch.tensor(X, dtype=torch.cfloat))

    # FNO code appends the spatial grid to the input as below:
    x_grid = torch.linspace(-np.pi, np.pi, 1024).view(-1,1)
    X_input = torch.cat((X_input, x_grid.repeat(n_batches, 1, 1)), axis=2)

    return X_input

In [34]:
def test_TDRP_predictions(X, t_grid, model, emulator):

    # X has shape (nbatch, n_tsteps, grid_size)
    TEST_KEYS = ['TDRP', 'FNO']
    preds_dd = {i: np.zeros((X.shape[0], X.shape[1]-1, X.shape[2]), dtype=np.cdouble) for i in TEST_KEYS}
    errors_dd = {i: np.zeros((X.shape[0], X.shape[1]-1), dtype=np.double) for i in TEST_KEYS}

    one_tensor = torch.tensor(1, dtype=torch.float).repeat([X.shape[0],1,1])
    half_tensor = torch.tensor(0.5, dtype=torch.float).repeat([X.shape[0],1,1])
    IC_input = prepare_input(X[:,0,:])

    # First input is given by the initial condition
    comp_input_i = prepare_input(X[:,0,:])
    # Iterate along timesteps
    for i in range(t_grid.shape[1]-1):
        SOLN_I = torch.tensor(X[:,i+1,:])
        # First test: composing the model
        comp_preds_i = emulator(comp_input_i, one_tensor).detach().numpy()
        # comp_preds_i = emulator(comp_input_i).detach().numpy()
        preds_dd['FNO'][:,i,:] = comp_preds_i
        comp_input_i = prepare_input(comp_preds_i)
        errors_i = l2_normalized_error(torch.tensor(comp_preds_i), SOLN_I)
        errors_dd['FNO'][:,i] = errors_i

        # Second test: prediction from the TDRP
        # if time_dep_model:
        i_tensor = torch.tensor(t_grid[0,i+1], dtype=torch.float).repeat([X.shape[0],1,1])
        preds_k = model(IC_input, i_tensor).detach().numpy() + comp_preds_i
        preds_dd['TDRP'][:,i,:] = preds_k
        errors_k = l2_normalized_error(torch.tensor(preds_k), SOLN_I)
        errors_dd['TDRP'][:,i] = errors_k
    # for k,v in errors_dd.items():
    #     print("{}: {}: {}".format(k, v.shape, v))
    return preds_dd, errors_dd



In [35]:
X, t_grid = load_data(DATA_FP)
model = load_model(MODEL_FP, torch.device('cpu'))
emulator = load_model(BASELINE_FP, torch.device('cpu'))

preds_dd, errors_dd = test_TDRP_predictions(X, t_grid, model, emulator)

In [36]:
def plot_time_errors(errors_dd, t_grid, title, fp):

    n_t_steps = t_grid.shape[1]
    x_vals = t_grid.flatten()
    plt.figure().patch.set_facecolor('white')
    for k, v in errors_dd.items():
        print("{}: {}: {}".format(k, v.shape, v))
        v_means = np.mean(v, axis=0)
        v_stds = np.std(v, axis=0)
        plt.plot(x_vals, v_means, label=k, alpha=0.7)
        plt.fill_between(x_vals,
                            v_means + v_stds,
                            v_means - v_stds,
                            alpha=0.3)
    plt.legend()
    plt.xlabel("Time step")
    plt.xticks(ticks=np.arange(0, n_t_steps),
               labels=make_special_ticks(n_t_steps),
              rotation=45,
              ha='right',
              )
    plt.ylabel("$L_2$-Normalized Errors")
    # plt.yscale('log')
    plt.title(title)
    plt.tight_layout()
    plt.savefig(fp)
    plt.clf()


def make_special_ticks(n):
    s = "$t={} \\ \\to  \\ t={}$"
    return [s.format(i, i+1) for i in range(n)]


In [37]:
fp_time_errors = os.path.join(PLOTS_DIR, 'time_errors.png')
errors_dd_i = {k:np.delete(v, [59], axis=0) for k,v in errors_dd.items()}
plot_time_errors(errors_dd_i, t_grid[:,:-1], "FNO vs. TDRP: Test Dataset", fp_time_errors)

TDRP: (99, 20): [[0.13207102 0.12945511 0.13043608 ... 0.51021225 0.60373977 0.92020205]
 [0.10837244 0.11366866 0.11396072 ... 0.23144641 0.36397116 0.36967749]
 [0.1263909  0.11755799 0.12253479 ... 0.34051434 0.39924998 0.41414662]
 ...
 [0.10550525 0.08919719 0.10368662 ... 0.80181372 1.10381983 1.1059068 ]
 [0.11449791 0.10280088 0.13791706 ... 0.56430704 0.63196105 0.59508274]
 [0.15920514 0.14022268 0.15175473 ... 1.05612821 1.24324342 2.03699109]]
FNO: (99, 20): [[0.01243497 0.01598123 0.02207848 ... 0.43628975 0.49258585 0.88523006]
 [0.01481244 0.01785373 0.02005794 ... 0.18829525 0.23918457 0.25206275]
 [0.02054129 0.03006444 0.04092341 ... 0.24530984 0.31185064 0.33924729]
 ...
 [0.017487   0.0261243  0.04565185 ... 0.81696262 0.97635519 1.11683518]
 [0.01515897 0.02387284 0.03646972 ... 0.52365211 0.59921707 0.59611353]
 [0.01916176 0.02558736 0.03506259 ... 0.90327319 1.00157102 1.06147611]]


<Figure size 432x288 with 0 Axes>