In [4]:
import logging
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# import torch.fft as fft
from torch.nn.parameter import Parameter
import matplotlib.pyplot as plt
import scipy.io as sio
# import h5py

import operator
from functools import reduce
from functools import partial
from timeit import default_timer

# Seed both RNGs used by this notebook (numpy and torch) so runs are reproducible.
np.random.seed(0)
torch.manual_seed(0)


In [5]:

class SpectralConv1d(nn.Module):
    """
    1D Fourier layer: FFT, learned linear transform of the lowest Fourier modes,
    then inverse FFT.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        modes1: number of low-frequency Fourier modes to multiply. A grid of N
            points has floor(N/2) + 1 rfft modes; if fewer than ``modes1`` are
            available at call time, only the available ones are used.
    """

    def __init__(self, in_channels, out_channels, modes1):
        super(SpectralConv1d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1

        # Random complex weights scaled down by 1 / (in_channels * out_channels).
        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        """Per-mode channel mixing: (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)."""
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """
        Args:
            x: real tensor of shape (batch, in_channels, N).

        Returns:
            Real tensor of shape (batch, out_channels, N).
        """
        batchsize = x.shape[0]
        n = x.size(-1)
        # Fourier coefficients along the last axis (n // 2 + 1 of them).
        x_ft = torch.fft.rfft(x)

        # Multiply the lowest `modes` coefficients; all higher modes stay zero.
        # Clamp so that grids too coarse for modes1 do not cause a shape mismatch.
        modes = min(self.modes1, x_ft.size(-1))
        out_ft = torch.zeros(batchsize, self.out_channels, n // 2 + 1, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :modes] = self.compl_mul1d(x_ft[:, :, :modes], self.weights1[:, :, :modes])

        # Return to physical space on the original grid size.
        return torch.fft.irfft(out_ft, n=n)


class FNO1dComplexTime(nn.Module):
    """
    Fourier Neural Operator mapping a complex 1D field plus a time value to a
    complex 1D field. It contains 4 Fourier layers.

    1. Lift the 4 input channels (Re(a(x)), Im(a(x)), x, t) to ``width``
       channels with self.fc0.
    2. 4 layers of the integral operator u' = (W + K)(u), with W given by
       self.w* (pointwise Conv1d) and K by self.conv* (SpectralConv1d).
       A ReLU follows every layer except the last.
    3. Project from the channel space to 2 channels with self.fc1 and self.fc2,
       interpreted as (real, imaginary) parts.

    forward input:  x of shape (batchsize, x=s, c=3) holding (Re(a(x)), Im(a(x)), x),
                    and a time t.
    forward output: complex tensor of shape (batchsize, s).
    """

    def __init__(self, modes, width):
        super(FNO1dComplexTime, self).__init__()

        self.modes1 = modes
        self.width = width
        # 4 input channels: (Re(a(x)), Im(a(x)), x) from the caller plus t appended in forward.
        self.fc0 = nn.Linear(4, self.width)

        self.conv0 = SpectralConv1d(self.width, self.width, self.modes1)
        self.conv1 = SpectralConv1d(self.width, self.width, self.modes1)
        self.conv2 = SpectralConv1d(self.width, self.width, self.modes1)
        self.conv3 = SpectralConv1d(self.width, self.width, self.modes1)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x, t):
        """
        Args:
            x: real tensor of shape (batchsize, s, 3).
            t: time, as a Python number or a tensor holding either a single value
                (shared by the whole batch) or one value per batch element.

        Returns:
            Complex tensor of shape (batchsize, s).
        """
        batchsize, s = x.shape[0], x.shape[1]
        # Match x's dtype/device so float64 or Python-float times concatenate cleanly.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        # Broadcast t over the grid (and over the batch when a single value is given).
        t = t.reshape(-1, 1, 1).expand(batchsize, s, 1)
        x = torch.cat([x, t], dim=2)
        x = self.fc0(x)
        x = x.permute(0, 2, 1)  # (batch, width, s) for the convolutions

        layers = (
            (self.conv0, self.w0),
            (self.conv1, self.w1),
            (self.conv2, self.w2),
            (self.conv3, self.w3),
        )
        for k, (spectral, pointwise) in enumerate(layers):
            x = spectral(x) + pointwise(x)
            if k < len(layers) - 1:  # no activation after the last Fourier layer
                x = F.relu(x)

        x = x.permute(0, 2, 1)  # back to (batch, s, width)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Last axis of size 2 -> (real, imag) of a complex (batch, s) tensor.
        return torch.view_as_complex(x)



In [84]:
class NLS_Residual_Loss:
    """
    Physics-informed loss for the nonlinear Schrodinger equation

        NLS: i u_t + 1 / 2 * u_xx + |u|^2 u = 0

    evaluated on u = model(x, t). u_t is a forward difference with step
    ``delta_t``; u_xx is a periodic central difference with spacing ``delta_x``.
    The loss is the sum of |residual|^2 over all entries.
    """

    def __init__(self, delta_x, delta_t, n_grid_points):
        """
        Args:
            delta_x: grid spacing used in the u_xx stencil.
            delta_t: time step used in the forward-difference u_t.
            n_grid_points: number of spatial grid points (stored; not used by
                the residual computation itself).
        """
        self.delta_x = delta_x
        self.delta_t = delta_t
        self.n_grid_points = n_grid_points

    def time_derivative(self, model, x, t, u=None):
        """
        u_t ~= (u(x, t + delta_t) - u(x, t)) / delta_t

        Args:
            model: callable model(x, t) -> u.
            x: model input.
            t: time at which the derivative is taken.
            u: optional precomputed model(x, t); passing it saves one forward pass.
        """
        if u is None:
            u = model(x, t)
        u_next = model(x, t + self.delta_t)
        return (u_next - u) / self.delta_t

    def spatial_discrete_derivatives(self, u, dim=1):
        """
        Periodic second-order central difference (u[j+1] - 2 u[j] + u[j-1]) / delta_x^2
        along axis ``dim`` (wraps around via torch.roll).
        """
        u_shift_right = torch.roll(u, 1, dim)
        u_shift_left = torch.roll(u, -1, dim)
        return (u_shift_left - 2 * u + u_shift_right) / (self.delta_x ** 2)

    def __call__(self, model, x, t):
        # x is the model input; for FNO1dComplexTime it has shape (batch_size, s, 3)
        # and u = model(x, t) is a complex tensor of shape (batch_size, s).
        return self.NLS_residual(model, x, t)

    def NLS_residual(self, model, x, t):
        """Sum over all entries of |i u_t + u_xx / 2 + |u|^2 u|^2."""
        u = model(x, t)

        nonlinear = torch.mul(u, torch.square(torch.abs(u)))  # |u|^2 u
        # Reuse u so the model is evaluated twice (t and t + delta_t), not three times.
        u_t = self.time_derivative(model, x, t, u=u)
        u_xx = self.spatial_discrete_derivatives(u)

        resid = 1j * u_t + 0.5 * u_xx + nonlinear

        return torch.square(resid.abs()).sum()
        

In [85]:
# Root of the fourier_neural_operator project; set FNO_PROJECT_ROOT to run elsewhere.
PROJECT_ROOT = os.environ.get('FNO_PROJECT_ROOT', '/local/meliao/projects/fourier_neural_operator')
DATA_FP = os.path.join(PROJECT_ROOT, 'data/2021-06-11_NLS_data_02/NLS_data_seed_0.mat')
MODEL_FP = os.path.join(PROJECT_ROOT, 'experiments/08_FNO_pretraining/models/00_pretrain_ep_1000')
PLOTS_DIR = os.path.join(PROJECT_ROOT, 'experiments/11_PDE_Loss/superresolution_plots')

In [86]:
# Load the NLS dataset; its 'output' entry feeds FakeModel further down.
d = sio.loadmat(file_name=DATA_FP)

In [104]:
class FakeModel:
    """
    Stand-in for a trained model that replays stored snapshots.

    Calling it at time ``t`` returns ``X[sample_idx, round(t / dt)]`` reshaped
    to (1, -1); ``x`` is ignored.
    """
    def __init__(self, X, dt=0.001, sample_idx=4):
        self.X = torch.tensor(X)
        self.dt = dt
        self.sample_idx = sample_idx

    def __call__(self, x, t):
        # round(), not int(): t / dt often lands a hair below the intended integer
        # (e.g. 0.3 / 0.1 == 2.9999999999999996) and truncation would pick the
        # previous snapshot.
        t_idx = int(round(float(t) / self.dt))
        return self.X[self.sample_idx, t_idx].view((1, -1))


class ConstantFakeModel:
    """
    Stand-in model that ignores its inputs and always returns the same complex128
    tensor of shape (1, n_grid_points) filled with ``a``.
    """
    def __init__(self, a, n_grid_points=1024):
        self.out = torch.mul(torch.ones((1, n_grid_points), dtype=torch.cdouble), a)

    def __call__(self, x, t):
        return self.out
def prepare_input(X):
    """
    Build the real-valued FNO input from complex fields.

    Args:
        X: complex numpy array or tensor of shape (n_batches, grid_size).

    Returns:
        Tensor of shape (n_batches, grid_size, 3) holding (Re(X), Im(X), x), where
        x = torch.linspace(-pi, pi, grid_size).
    """
    s = X.shape[-1]
    n_batches = X.shape[0]

    # as_tensor accepts tensors without the copy-construct warning torch.tensor() emits.
    X_input = torch.view_as_real(torch.as_tensor(X, dtype=torch.cfloat))

    # FNO code appends the spatial grid to the input as below.
    # NOTE(review): linspace includes both endpoints, so the spacing is 2*pi/(s-1).
    x_grid = torch.linspace(-np.pi, np.pi, s, device=X_input.device).view(-1, 1)
    X_input = torch.cat((X_input, x_grid.repeat(n_batches, 1, 1)), dim=2)

    return X_input

In [117]:
fake_model_1 = FakeModel(d['output'])
fake_model_2 = ConstantFakeModel(0.)
# MODEL_FP holds a whole pickled nn.Module (not a state_dict). torch >= 2.6 defaults
# torch.load to weights_only=True, which refuses to unpickle the model class, so opt
# out explicitly (the keyword exists since torch 1.13). Unpickling can execute
# arbitrary code: only load checkpoints you produced yourself.
real_model = torch.load(MODEL_FP, map_location='cpu', weights_only=False)
real_model.eval()

FNO1dComplexTime(
  (fc0): Linear(in_features=4, out_features=64, bias=True)
  (conv0): SpectralConv1d()
  (conv1): SpectralConv1d()
  (conv2): SpectralConv1d()
  (conv3): SpectralConv1d()
  (w0): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
  (w1): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
  (w2): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
  (w3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
  (fc1): Linear(in_features=64, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=2, bias=True)
)

In [118]:
# 1024 grid points over a 2*pi domain; dt = 0.001 matches the time step FakeModel assumes.
# NOTE(review): prepare_input's linspace(-pi, pi, s) grid has spacing 2*pi/(s-1), not 2*pi/s — confirm which the data uses.
loss_obj = NLS_Residual_Loss(delta_x=2 * np.pi / 1024, delta_t=0.001, n_grid_points=1024)

In [121]:
# Evaluate the NLS residual of real_model on time index 0 of sample 0 from d['output'],
# for t = 0, 0.001, ..., 0.009, timing one forward + backward pass each.
aa = fake_model_1.X[:1, 0]
x_input = prepare_input(aa)
print(x_input.shape)

for i in range(10):
    t = torch.tensor(i * 0.001)
    t0 = default_timer()
    v = loss_obj(real_model, x_input, t)
    # NOTE(review): gradients accumulate in real_model's parameters across iterations
    # (nothing zeroes them); this does not affect the printed loss values.
    v.backward()
    t1 = default_timer()
    elapsed = t1 - t0
    print(v.item(), elapsed)

torch.Size([1, 1024, 3])
939334.75 0.10543419700115919
935120.75 0.06653134804219007
926235.25 0.08486056001856923
917506.0 0.0825958059867844
908060.0 0.06718615198042244
902314.0 0.07377398200333118
900137.25 0.08174675796180964
906817.0625 0.07045507500879467
910318.0 0.07308840204495937
910994.25 0.08237527508754283
