In [14]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
from scipy.interpolate import griddata
from pyDOE import lhs

import torch
import torch.nn as nn
import torch.autograd as autograd

from torch.optim import Adam, SGD, LBFGS
from torch.utils.data import Dataset, DataLoader

In [15]:
def linear_block(in_channel, out_channel):
    block = nn.Sequential(
        nn.Linear(in_channel, out_channel), 
        nn.Tanh()
    )
    return block

class FCNet(nn.Module):
    def __init__(self, layers=[2, 10, 1]):
        super(FCNet, self).__init__()
        fc_list = [linear_block(in_size, out_size) for in_size, out_size in zip(layers, layers[1:-1])]
        fc_list.append(nn.Linear(layers[-2], layers[-1]))
        self.fc = nn.Sequential(*fc_list)
    
    def forward(self, x):
        return self.fc(x)

In [16]:
class BurgerData(Dataset):
    '''
    members: 
        - t, x, Exact: raw data
        - X, T: meshgrid 
        - X_star, u_star: flattened (x, t), u array
        - lb, ub: lower bound and upper bound vector
        - X_u, u: boundary condition data (x, t), u
    '''
    def __init__(self, datapath):
        data = scipy.io.loadmat(datapath)

        # raw 2D data
        self.t = np.linspace(0,1,101) # (100,1)
        self.x = np.linspace(0,1,128) # (256, 1)
        self.Exact = np.real(data['output'][100,:,:]).T # (100, 256)

        # Flattened sequence
        self.get_flatten_data()
        self.get_boundary_data()

    def __len__(self):
        return self.Exact.shape[0]

    def __getitem__(self, idx):
        return self.X_star[idx], self.u_star[idx]
    
    def get_flatten_data(self):
        X, T = np.meshgrid(self.x, self.t)
        self.X, self.T = X, T
        self.X_star = np.hstack((X.flatten()[:,None], T.flatten()[:,None])) # 
        self.u_star = self.Exact.flatten()[:,None]
        
        self.lb = self.X_star.min(0) # lower bound of (x, t): 2-dimensional vector 
        self.ub = self.X_star.max(0) # upper bound of (x, t): 2-dimensional vector

    def get_boundary_data(self):
        xx1 = np.hstack((self.X[0:1,:].T, self.T[0:1,:].T))
        uu1 = self.Exact[0:1,:].T
        xx2 = np.hstack((self.X[:,0:1], self.T[:,0:1]))
        uu2 = self.Exact[:,0:1]
        xx3 = np.hstack((self.X[:,-1:], self.T[:,-1:]))
        uu3 = self.Exact[:,-1:]
        self.X_u = np.vstack([xx1, xx2, xx3])
        self.u = np.vstack([uu1, uu2, uu3])
    
    def sample_xt(self, N=10000):
        '''
        Sample (x, t) pairs within the boundary
        Return:
            - X_f: (N, 2) array
        '''
        X_f = self.lb + (self.ub-self.lb)*lhs(2, N)
        X_f = np.vstack((X_f, self.X_u))
        return X_f

    def sample_xu(self, N=100):
        '''
        Sample N points from boundary data
        Return: 
            - X_u: (N, 2) array 
            - u: (N, 1) array
        '''
        idx =  np.random.choice(self.X_u.shape[0], N, replace=False)
        X_u = self.X_u[idx, :]
        u = self.u[idx,:]
        return X_u, u

In [17]:
def count_parameters(model):
    count = 0
    for p in model.parameters():
        count += p.numel()
    return count

def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

def zero_grad(params):
    '''
    set grad field to 0
    '''
    if isinstance(params, torch.Tensor):
        if params.grad is not None:
            params.grad.detach()
            params.grad.zero_()
    else:
        for p in params:
            if p.grad is not None:
                p.grad.detach()
                p.grad.zero_()
                
def mse(X, Y):
    loss = (X - Y) ** 2
    return loss.mean()

def PDELoss(model, x, t, nu, equation='B'):
    '''
    Compute the residual of PDE: 
        residual = u_t + u * u_x - nu * u_{xx} : (N,1)
    
    Params: 
        - model 
        - x, t: (x, t) pairs, (N, 2) tensor
        - nu: constant of PDE
        - equation: 'A' stands for Allen Cahn equation, 'B' stands for Burger's equation
    Return: 
        - mean of residual : scalar 
    '''
    u = model(torch.cat([x, t], dim=1))
    # First backward to compute u_x (shape: N x 1), u_t (shape: N x 1)
    grad_x, grad_t= autograd.grad(outputs=[u.sum()], inputs=[x, t], create_graph=True)
    # grad_x = grad_xt[:, 0]
    # grad_t = grad_xt[:, 1]

    # Second backward to compute u_{xx} (shape N x 1)
    
    gradgrad_x, = autograd.grad(outputs=[grad_x.sum()], inputs=[x], create_graph=True)
    # gradgrad_x = gradgrad[:, 0]
    if equation == 'B':
        residual = grad_t + u * grad_x - nu * gradgrad_x
    elif equation == 'A':
        residual = grad_t + 5 * u ** 3 - 5 * u - nu * gradgrad_x 
    return residual

In [18]:
def train(model, X_u, u, X_f, gamma=1.0,
          nu=1.0, num_epoch=100, equation='B',
          device=torch.device('cpu'), optim='LBFGS'):
    model.to(device)
    model.train()
    optimizer = LBFGS(model.parameters(), 
                      lr=1.0, 
                      max_iter=50000, 
                      max_eval=50000, 
                      history_size=50,
                      tolerance_grad=1e-5, 
                      tolerance_change=1.0 * np.finfo(float).eps,
                      line_search_fn="strong_wolfe")
    # mse = nn.MSELoss()
    # training stage
    xts = torch.from_numpy(X_u).float().to(device)
    us = torch.from_numpy(u).float().to(device)

    xs = torch.from_numpy(X_f[:, 0:1]).float().to(device)
    ts = torch.from_numpy(X_f[:, 1:2]).float().to(device)
    xs.requires_grad = True
    ts.requires_grad = True
    iter = 0

    def loss_closure():
        nonlocal iter
        iter = iter + 1

        optimizer.zero_grad()
        
        zero_grad(xs)
        zero_grad(ts)
        # print(xs.grad)
        # MSE loss of prediction error
        # print(xts.shape)
        pred_u = model(xts)
        mse_u = mse(pred_u, us)
        
        # MSE loss of PDE constraint
        f = PDELoss(model, xs, ts, nu, equation=equation)

        mse_f = torch.mean(f ** 2)
        loss = gamma * mse_u + mse_f
        loss.backward()
        
        if iter % 100==0:
            print('Iter: {}, total loss: {}, mse_u: {}, mse_f: {}'.
                format(iter, loss.item(), mse_u.item(), mse_f.item()))
        return loss
    
    optimizer.step(loss_closure)

    return model

In [19]:
def predict(model, x, t, nu):
    '''
    Params:
        - model: model
        - xt: (N, 2) tensor
    Return: 
        - u: (N, 1) tensor
        - residual: (N, 1) tensor
    '''
    model.eval()
    
    x.requires_grad = True
    t.requires_grad = True

    u = model(torch.cat([x, t], dim=1))

    grad_x, grad_t = autograd.grad(outputs=u.sum(), inputs=[x, t], 
                                   create_graph=True)
    gradgrad_x, = autograd.grad(outputs=grad_x.sum(), inputs=[x])

    residual = grad_t + u * grad_x - nu * gradgrad_x
    return u.detach(), residual.detach()

In [20]:
class LpLoss(object):
    '''
    loss function with rel/abs Lp loss
    '''
    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        #Dimension and Lp-norm type are postive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def abs(self, x, y):
        num_examples = x.size()[0]

        #Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0)

        all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(all_norms)
            else:
                return torch.sum(all_norms)

        return all_norms

    def rel(self, x, y):
        num_examples = x.size()[0]

        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms/y_norms)
            else:
                return torch.sum(diff_norms/y_norms)

        return diff_norms/y_norms

    def __call__(self, x, y):
        return self.rel(x, y)

In [21]:
def eval_plot(model, trainset, tag='default'):
    myloss = LpLoss(size_average=True)
    # Compute prediction and residual of network 
    X_star = trainset.X_star
    target_u = torch.tensor(trainset.u_star).float().to(device).squeeze()

    eval_x = torch.tensor(X_star[:, 0:1]).float().to(device)
    eval_t = torch.tensor(X_star[:, 1:2]).float().to(device)

    pred_u, residual = predict(model, eval_x, eval_t, nu)
#     pred_u, residual = pred_u.squeeze(), residual.squeeze()
    
#     relative_error = torch.norm(pred_u - target_u) / torch.norm(target_u)
    relative_error = myloss(pred_u, target_u)
    print('Relative error ||pred_u - target_u||_2 / ||target_u||_2: {}'.format(relative_error.item()))
    print('Averaged PDE L1 residual error: {}'.format(torch.abs(residual).mean().item()))

#     #==== reshape pred_u to meshgrid for visualization =====
#     u_2d = griddata(X_star, pred_u.cpu().numpy(), (trainset.X, trainset.T), method='cubic')
#     error_2d = trainset.Exact - u_2d
#     t = trainset.t
#     x = trainset.x

#     # absolute error plot
#     plt.imshow(np.abs(error_2d.T), interpolation='nearest', cmap='Blues', 
#               extent=[t.min(), t.max(), x.min(), x.max()], 
#               origin='lower', aspect='auto')
#     plt.xlabel('$t$', fontsize=15)
#     plt.ylabel('$x$', fontsize=15)
#     plt.title('Absolute approximation error at each point')
#     plt.colorbar()
#     plt.savefig('abs_err-%s.png' % tag)
#     plt.show()

#     # Approximation at specific time slice
#     tvalues = [25, 50, 75]
#     fig, axes = plt.subplots(len(tvalues), 1, sharex=True)
#     fig.set_size_inches(12, 8)
#     for ax, tstep in zip(axes, tvalues):
#         ax.plot(x, u_2d[tstep, :], 'r--', linewidth=3, label='Prediction', alpha=0.9)
#         ax.plot(x, trainset.Exact[tstep, :], 'b-', linewidth=3, label='Exact', alpha=0.7)
#         ax.set_title('t=0.%d' % tstep, fontsize=15)
#         ax.set_ylabel('$u(x,t)$', fontsize=15)
#     ax.set_xlabel('$x$', fontsize=15)
#     ax.legend(loc='right')
#     plt.savefig('timeslice-%s.png' % tag)

#     fig, ax = plt.subplots()
#     # residual error plot
#     res_2d = griddata(X_star, residual.cpu().numpy(), (trainset.X, trainset.T), method='cubic')
#     plt.imshow(res_2d.T, interpolation='nearest', cmap='RdBu', 
#                 extent=[t.min(), t.max(), x.min(), x.max()], 
#                 origin='lower', aspect='auto')
#     plt.xlabel('$t$', fontsize=15)
#     plt.ylabel('$x$', fontsize=15)
#     plt.colorbar()
#     plt.savefig('pde-residual-%s.png' % tag)
#     plt.show()

In [22]:
N_u = 100
N_f = 10000
datapath = 'data/burgers_pino.mat'

trainset = BurgerData(datapath)
uxs, us = trainset.sample_xu(N_u)
X_f = trainset.sample_xt(N_f)

In [23]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
nu = 0.01 / np.pi
# config for FCN
layers = [2, 20, 20, 20, 20, 20, 20, 20, 20, 1]
# config for Fourier net
modes = 12
width = 32

In [24]:
fcn = FCNet(layers)
print('Number of parameters: ', count_parameters(fcn))
fcn = train(fcn, X_u=uxs, u=us, X_f=X_f, nu=nu, num_epoch=1, device=device)

Number of parameters:  3021
Iter: 100, total loss: 0.0010602191323414445, mse_u: 0.0009893907699733973, mse_f: 7.082833326421678e-05
Iter: 200, total loss: 0.0008900378597900271, mse_u: 0.0007325236219912767, mse_f: 0.00015751420869491994
Iter: 300, total loss: 0.0008082585409283638, mse_u: 0.000623744388576597, mse_f: 0.0001845141377998516
Iter: 400, total loss: 0.0007019194308668375, mse_u: 0.0004730932996608317, mse_f: 0.0002288261312060058
Iter: 500, total loss: 0.00046493549598380923, mse_u: 0.0003211726143490523, mse_f: 0.00014376288163475692
Iter: 600, total loss: 0.0003071909595746547, mse_u: 0.0002118288102792576, mse_f: 9.53621492953971e-05
Iter: 700, total loss: 0.00023853438324294984, mse_u: 0.00017714808927848935, mse_f: 6.13863012404181e-05
Iter: 800, total loss: 0.00021138555894140154, mse_u: 0.00015765933494549245, mse_f: 5.372622035793029e-05
Iter: 900, total loss: 0.00018969639495480806, mse_u: 0.00014686949725728482, mse_f: 4.282690133550204e-05
Iter: 1000, total los

In [25]:
eval_plot(fcn, trainset, tag='FCN')

Relative error ||pred_u - target_u||_2 / ||target_u||_2: 3.1248230934143066
Averaged PDE L1 residual error: 0.002849587006494403
