In [1]:
import os
import matplotlib.pyplot as plt
import scipy.io as sio
import torch
import numpy as np
import pandas as pd

In [2]:
WD = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/'
os.chdir(WD)
print(os.getcwd())

/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals


In [3]:
from train_models import SpectralConv1d, FNO1dComplexTime, TimeDataSetResiduals

In [4]:
TRAIN_DF = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/results/00_residual_train.txt'
TEST_DF = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/results/00_residual_test.txt'
DATA_FP = '/local/meliao/projects/fourier_neural_operator/data/2021-06-24_NLS_data_04_test.mat'
MODEL_FP = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/models/00_residual_ep_500'
BASELINE_FP = '/local/meliao/projects/fourier_neural_operator/experiments/08_FNO_pretraining/models/00_pretrain_ep_1000'
PLOTS_DIR = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/plots/'



In [8]:
train_df = pd.read_table(TRAIN_DF)
test_df = pd.read_table(TEST_DF)
fp_train_test = os.path.join(PLOTS_DIR, 'first_train_test.png')

In [17]:
def make_train_test_plot(a_train, a_test, fp=None):
    fig, ax = plt.subplots(1, 2, sharey=False)
    fig.patch.set_facecolor('white')


    # a_train and a_test are the time-dependent FNO data. They're in the first column
    ax[0].set_title("Train")
    ax[0].plot(a_train.epoch, a_train.MSE, '-', color='red', label='train')
    ax[1].plot(a_test.epoch, a_test.test_mse, '--', color='red', label='test')
    ax[0].set_xlabel("Epoch", fontsize=13)
    ax[1].set_xlabel("Epoch", fontsize=13)
    # ax[0].legend()
    # ax[0].set_yscale('log')

    ax[0].set_ylabel("MSE", fontsize=13)

    # b_train and b_test are the time-dependent. They're in the seecond column
    ax[1].set_title("Test")
    # ax[1].plot(b_train.epoch, b_train.MSE, '-', color='blue', label='train')
    # ax[1].plot(b_test.epoch, b_test.test_mse, '--', color='blue', label='test')
    ax[1].set_xlabel("Epoch", fontsize=13)
    ax[1].set_yscale('log')
    # ax[1].legend(fontsize=13)


    plt.tight_layout()

    if fp is not None:
        plt.savefig(fp)
    else:
        plt.show()
    plt.close(fig)

In [18]:
make_train_test_plot(train_df, test_df, fp=fp_train_test)

In [12]:
d = sio.loadmat(DATA_FP)
model = torch.load(MODEL_FP, map_location='cpu')
emulator = torch.load(BASELINE_FP, map_location='cpu')

In [13]:
t_dset = TimeDataSetResiduals(d['output'][:, :7], d['t'][:, :7], d['x'], emulator)
t_dloader = torch.utils.data.DataLoader(t_dset, batch_size=1, shuffle=False)

In [19]:
def plot_check(x, y, t, preds, resid, fp=None):

    # X has size (grid_size, 3) with the columns being (Re(u_0), Im(u_0), x)
    fig, ax=plt.subplots(nrows=1, ncols=4)
    fig.set_size_inches(15,10) #15, 20 works well
    fig.patch.set_facecolor('white')

    x_real = x[:, 0].flatten()
    x_imag = x[:, 1].flatten()
    # print("X_REAL:", x_real.shape, "X_IMAG:", x_imag.shape)
    # print("PREDS_REAL:", np.real(preds).shape, "PREDS_IMAG:", np.imag(preds).shape)
    # print("Y_REAL:", np.real(y).shape, "Y_IMAG:", np.imag(y).shape)

    ax[0].set_title("$Re(u)$")
    ax[0].plot(x_real, label='Input')
    ax[0].plot(np.real(y), label='Soln') 
    ax[0].plot(np.real(preds), '--', label='Pred')

    ax[0].legend()  

    ax[1].set_title("Residuals: $Re(u)$")
    ax[1].plot(np.real(y) - np.real(preds), color='red', label='actual')
    ax[1].plot(np.real(resid), color='green', label='predicted')
    ax[1].legend()
    
    ax[2].set_title("$Im(u)$")
    ax[2].plot(x_imag, label='Input')
    ax[2].plot(np.imag(y), label='Soln')
    ax[2].plot(np.imag(preds), '--', label='Pred')
    ax[2].legend()

    ax[3].set_title("Residuals: $Im(u)$")

    ax[3].plot(np.imag(y) - np.imag(preds), color='red', label='actual')
    ax[3].plot(np.imag(resid), color='green', label='predicted')

    ax[3].legend()
    
    plt.tight_layout()
    # plt.title("T = {}".format(t))
    if fp is not None:
        plt.savefig(fp)
    else:
        plt.show()
    plt.clf()

In [20]:
model = torch.load(MODEL_FP, map_location=torch.device('cpu'))

In [21]:
n = 0
for x_i, y_i, t_i, preds_i in t_dloader:
    # x_i, y_i, t_i, preds_i = t_dset[i]
    # print(x_i.shape)
    model_resid = model(x_i, t_i)
    fp_i = os.path.join(PLOTS_DIR, 'test_case_{}.png'.format(n))    
    plot_check(x_i[0], y_i[0], t_i.item(), preds_i[0], model_resid[0].detach().numpy(), fp=fp_i)
    n += 1
    if n >= 5:
        break

<Figure size 1080x720 with 0 Axes>

<Figure size 1080x720 with 0 Axes>

<Figure size 1080x720 with 0 Axes>

<Figure size 1080x720 with 0 Axes>

<Figure size 1080x720 with 0 Axes>