In [41]:
import os
import matplotlib.pyplot as plt
import scipy.io as sio
import torch
import numpy as np

In [42]:
print(os.getcwd())

/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals


In [43]:
WD = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/'
os.chdir(WD)
print(os.getcwd())

/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals


In [44]:
from train_models import SpectralConv1d, FNO1dComplexTime

In [54]:
EMULATOR_FP = '/local/meliao/projects/fourier_neural_operator/experiments/08_FNO_pretraining/models/00_pretrain_ep_1000'

MODEL_FP = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/models/00_residual_ep_500'
DATA_FP = '/local/meliao/projects/fourier_neural_operator/data/2021-06-24_NLS_data_04_train.mat'
PLOTS_DIR = '/local/meliao/projects/fourier_neural_operator/experiments/09_predict_residuals/plots/'


In [46]:
d = sio.loadmat(DATA_FP)
emulator = torch.load(EMULATOR_FP, map_location='cpu')

In [47]:
class TimeDataSetResiduals(torch.utils.data.Dataset):
    def __init__(self, X, t_grid, x_grid, emulator):
        super(TimeDataSetResiduals, self).__init__()
        assert X.shape[1] == t_grid.shape[-1]
        self.X = torch.tensor(X, dtype=torch.cfloat)
        self.t = torch.tensor(t_grid.flatten(), dtype=torch.float)
        self.x_grid = torch.tensor(x_grid, dtype=torch.float).view(-1, 1)
        self.n_tsteps = self.t.shape[0] - 1
        self.n_batches = self.X.shape[0]
        self.dataset_len = self.n_tsteps * self.n_batches
        self.emulator = emulator
        self.make_composed_predictions()

    def make_composed_predictions(self):
        t_interval = self.t[1]
        n_tsteps = self.X.shape[1]
        t_tensor = torch.tensor(t_interval, dtype=torch.float).repeat([self.n_batches, 1,1])
        preds = np.zeros(self.X.shape, dtype=np.cfloat)

        # The IC is at time 0
        preds[:,0] = self.X[:,0]

        comp_input_i = self.make_x_train(self.X[:,0])
        for i in range(1, n_tsteps):
            comp_preds_i = self.emulator(comp_input_i, t_tensor).detach().numpy()
            preds[:,i] = comp_preds_i
            comp_input_i = self.make_x_train(comp_preds_i)
        self.emulator_preds = preds

    def make_x_train(self, X, single_batch=False):
        # X has shape (nbatch, 1, grid_size)
        n_batches = X.shape[0] if len(X.shape) > 1 else 1

        # Convert to tensor
        X_input = torch.view_as_real(torch.tensor(X, dtype=torch.cfloat))


        if single_batch:
            X_input = torch.cat((X_input, self.x_grid), dim=1)
        else:
            x_grid_i = self.x_grid.repeat(n_batches, 1, 1)
            X_input = torch.cat((X_input.view((n_batches, -1, 2)), x_grid_i), axis=2)

        return X_input

    def __getitem__(self, idx):
        idx_original = idx
        t_idx = int(idx % self.n_tsteps) + 1
        idx = int(idx // self.n_tsteps)
        batch_idx = int(idx % self.n_batches)
        x = self.make_x_train(self.X[batch_idx, 0], single_batch=True) #.reshape(self.output_shape)
        y = self.X[batch_idx, t_idx] #.reshape(self.output_shape)
        preds = self.emulator_preds[batch_idx, t_idx]
        t = self.t[t_idx]
        return x,y,t,preds

    def __len__(self):
        return self.dataset_len

    def __repr__(self):
        return "TimeDataSetResiduals with length {}, n_tsteps {}, n_batches {}".format(self.dataset_len,
                                                                                            self.n_tsteps,
                                                                                            self.n_batches)


In [48]:
t_dset = TimeDataSetResiduals(d['output'][:, :7], d['t'][:, :7], d['x'], emulator)
t_dloader = torch.utils.data.DataLoader(t_dset, batch_size=1, shuffle=False)

In [64]:
def plot_check(x, y, t, preds, resid, fp=None):

    # X has size (grid_size, 3) with the columns being (Re(u_0), Im(u_0), x)
    fig, ax=plt.subplots(nrows=1, ncols=4)
    fig.set_size_inches(15,10) #15, 20 works well
    fig.patch.set_facecolor('white')

    x_real = x[:, 0].flatten()
    x_imag = x[:, 1].flatten()
    # print("X_REAL:", x_real.shape, "X_IMAG:", x_imag.shape)
    # print("PREDS_REAL:", np.real(preds).shape, "PREDS_IMAG:", np.imag(preds).shape)
    # print("Y_REAL:", np.real(y).shape, "Y_IMAG:", np.imag(y).shape)

    ax[0].set_title("$Re(u)$")
    ax[0].plot(x_real, label='Input')
    ax[0].plot(np.real(y), label='Soln') 
    ax[0].plot(np.real(preds), '--', label='Pred')

    ax[0].legend()  

    ax[1].set_title("Residuals: $Re(u)$")
    ax[1].plot(np.real(y) - np.real(preds), color='red', label='actual')
    ax[1].plot(np.real(resid), color='green', label='predicted')
    ax[1].legend()
    
    ax[2].set_title("$Im(u)$")
    ax[2].plot(x_imag, label='Input')
    ax[2].plot(np.imag(y), label='Soln')
    ax[2].plot(np.imag(preds), '--', label='Pred')
    ax[2].legend()

    ax[3].set_title("Residuals: $Im(u)$")

    ax[3].plot(np.imag(y) - np.imag(preds), color='red', label='actual')
    ax[3].plot(np.imag(resid), color='green', label='predicted')

    ax[3].legend()
    
    plt.tight_layout()
    plt.title("T = {}".format(t))
    if fp is not None:
        plt.savefig(fp)
    else:
        plt.show()
    plt.clf()

In [65]:
model = torch.load(MODEL_FP, map_location=torch.device('cpu'))

In [66]:
n = 0
for x_i, y_i, t_i, preds_i in t_dloader:
    # x_i, y_i, t_i, preds_i = t_dset[i]
    # print(x_i.shape)
    model_resid = model(x_i, t_i)
    fp_i = os.path.join(PLOTS_DIR, 'test_case_{}.png'.format(n))    
    plot_check(x_i[0], y_i[0], t_i.item(), preds_i[0], model_resid[0].detach().numpy(), fp=fp_i)
    n += 1
    if n >= 5:
        break

<Figure size 1080x720 with 0 Axes>

<Figure size 1080x720 with 0 Axes>

<Figure size 1080x720 with 0 Axes>

<Figure size 1080x720 with 0 Axes>

<Figure size 1080x720 with 0 Axes>