In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
cd /content/drive/MyDrive/gray_scott_results

/content/drive/.shortcut-targets-by-id/1JV7hS7EmJJsIJBj6IGG6rYBPHje96cuY/gray_scott_results


# 1. 1D problem: the (time-independent) Burgers equation

## 1.1. MWT_1D

In [None]:
from utils_3d import train, test, LpLoss, get_filter, UnitGaussianNormalizer

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import torch.optim as optim
from torch.autograd import Variable
import torch.utils.data as data_utils
from typing import List, Tuple

import sys
import os

import numpy as np
import math
from scipy.io import loadmat, savemat
import matplotlib.pyplot as plt
from scipy.special import eval_legendre
from sympy import Poly, legendre, Symbol
import h5py


import operator
from functools import reduce
from functools import partial
from timeit import default_timer
from scipy.special import eval_legendre, gammaln

In [None]:
# torch.manual_seed(0)
# np.random.seed(0)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def exp_pade_coeff(p, q):
    J = np.arange(p)
    log_num = gammaln(p+q-J+1) + gammaln(p+1) - gammaln(p+q+1) - gammaln(J+1) - gammaln(p-J+1)
    num = np.exp(log_num)
    num[0] = 1
    J = np.arange(q)
    log_dec = gammaln(p+q-J+1) + gammaln(q+1) - gammaln(p+q+1) - gammaln(J+1) - gammaln(q-J+1)
    dec = np.exp(log_dec) * (-1)**(J)
    dec[0] = 1
    return num, dec

In [None]:
def get_initializer(name):
    
    if name == 'xavier_normal':
        init_ = partial(nn.init.xavier_normal_)
    elif name == 'kaiming_uniform':
        init_ = partial(nn.init.kaiming_uniform_)
    elif name == 'kaiming_normal':
        init_ = partial(nn.init.kaiming_normal_)
    return init_

In [None]:
class sparseKernel(nn.Module):
    def __init__(self,
                 k, alpha, c=1, 
                 nl = 1,
                 initializer = None,
                 **kwargs):
        super(sparseKernel,self).__init__()
        
        self.k = k
        self.Li = nn.Linear(c*k, 128)
        self.conv = self.convBlock(c*k, 128)
#         self.Lo = nn.Linear(alpha*k, c*k)
        self.Lo = nn.Linear(128, c*k)
        
    def forward(self, x):
        B, N, c, ich = x.shape # (B, N, c, k)
        x = x.view(B, N, -1)
#         x = F.relu(self.Li(x))
        x = x.permute(0, 2, 1)
        x = self.conv(x)
        x = x.permute(0, 2, 1)
        x = self.Lo(x)
        x = x.view(B, N, c, ich)
        
        return x
        
        
    def convBlock(self, ich, och):
        net = nn.Sequential(
            nn.Conv1d(ich, och, 3, 1, 1),
            nn.ReLU(inplace=True),
#             nn.Conv1d(och, och, 3, 1, 1),
#             nn.ReLU(inplace=True),
        )
        return net 

def compl_mul1d(x, weights):
    # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
    return torch.einsum("bix,iox->box", x, weights)


class sparseKernelFT(nn.Module):
    def __init__(self,
                 k, alpha, c=1, 
                 nl = 1,
                 initializer = None,
                 **kwargs):
        super(sparseKernelFT, self).__init__()        
        
        self.modes1 = alpha
        self.scale = (1 / (c*k*c*k))
        self.weights1 = nn.Parameter(self.scale * torch.rand(c*k, c*k, self.modes1, dtype=torch.cfloat))
        self.k = k
        
    def forward(self, x):
        B, N, ck = x.shape # (B, N, c, k)
        
#         x = x.view(B, N, -1)
        x = x.permute(0, 2, 1)
        x_fft = torch.fft.rfft(x)
        # Multiply relevant Fourier modes
        l = min(self.modes1, N//2+1)
        out_ft = torch.zeros(B, ck, N//2 + 1,  device=x.device, dtype=torch.cfloat)
        
        out_ft[:, :, :l] = compl_mul1d(x_fft[:, :, :l], self.weights1[:, :, :l])
        
        #Return to physical space
        x = torch.fft.irfft(out_ft, n=N)
        x = x.permute(0, 2, 1)
#         x = x.view(B, N, c, k)
        return x
    

class pade_exponential(nn.Module):
    def __init__(self, 
                k, alpha, c=1,
                p = 3, q = 4,
                initializer = None,
                **kwargs):
        super(pade_exponential, self).__init__()
        
        self.p = p
        self.q = q
        Pp, Pq = exp_pade_coeff(p, q)
        
        self.LinOperator = sparseKernelFT(k, alpha, c)
        self.Linear = nn.Linear(c*k, c*k)
        
        self.register_buffer('Pp', torch.Tensor(Pp))
        self.register_buffer('Pq', torch.Tensor(Pq))
        
    def forward(self, x):
        B, N, c, k = x.shape
        
        x = x.view(B, N, -1)
        aggr_q = self.Pq[0] * x
        for i in range(1, self.q):
            x = self.LinOperator(x)
            aggr_q += self.Pq[i] * x
            
        aggr_q = self.Linear(aggr_q)
        aggr_q = F.relu(aggr_q)
        
        x = self.Pp[0] * aggr_q
        for i in range(1, self.p):
            aggr_q = self.LinOperator(aggr_q)
            x += self.Pp[i] * aggr_q
        
        return x.view(B, N, c, k)

    
class MWT_CZ(nn.Module):
    def __init__(self,
                 k = 3, alpha = 5, 
                 L = 0, c = 1,
                 p = 3, q = 4,
                 base = 'legendre',
                 initializer = None,
                 **kwargs):
        super(MWT_CZ, self).__init__()
        
        self.k = k
        self.L = L
        H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
        H0r = H0@PHI0
        G0r = G0@PHI0
        H1r = H1@PHI1
        G1r = G1@PHI1
        
        H0r[np.abs(H0r)<1e-8]=0
        H1r[np.abs(H1r)<1e-8]=0
        G0r[np.abs(G0r)<1e-8]=0
        G1r[np.abs(G1r)<1e-8]=0
        
        self.A = pade_exponential(k, alpha, c, p, q)
        self.B = pade_exponential(k, alpha, c, p, q)
        self.C = pade_exponential(k, alpha, c, p, q)
        
        self.T0 = nn.Linear(k, k)

        self.register_buffer('ec_s', torch.Tensor(
            np.concatenate((H0.T, H1.T), axis=0)))
        self.register_buffer('ec_d', torch.Tensor(
            np.concatenate((G0.T, G1.T), axis=0)))
        
        self.register_buffer('rc_e', torch.Tensor(
            np.concatenate((H0r, G0r), axis=0)))
        self.register_buffer('rc_o', torch.Tensor(
            np.concatenate((H1r, G1r), axis=0)))
        
        
    def forward(self, x):
        
        B, N, c, ich = x.shape # (B, N, k)
        ns = math.floor(np.log2(N))

        Ud = torch.jit.annotate(List[Tensor], [])
        Us = torch.jit.annotate(List[Tensor], [])
#         decompose
        for i in range(ns-self.L):
            d, x = self.wavelet_transform(x)
            Ud += [self.A(d) + self.B(x)]
            Us += [self.C(d)]
        x = self.T0(x) # coarsest scale transform

#        reconstruct            
        for i in range(ns-1-self.L,-1,-1):
            x = x + Us[i]
            x = torch.cat((x, Ud[i]), -1)
            x = self.evenOdd(x)
        return x

    
    def wavelet_transform(self, x):
        xa = torch.cat([x[:, ::2, :, :], 
                        x[:, 1::2, :, :], 
                       ], -1)
        d = torch.matmul(xa, self.ec_d)
        s = torch.matmul(xa, self.ec_s)
        return d, s
        
        
    def evenOdd(self, x):
        
        B, N, c, ich = x.shape # (B, N, c, k)
        assert ich == 2*self.k
        x_e = torch.matmul(x, self.rc_e)
        x_o = torch.matmul(x, self.rc_o)
        
        x = torch.zeros(B, N*2, c, self.k, 
            device = x.device)
        x[..., ::2, :, :] = x_e
        x[..., 1::2, :, :] = x_o
        return x
    
    
class MWT_exp(nn.Module):
    def __init__(self,
                 ich = 1, k = 3, alpha = 2, c = 1,
                 p = 3, q = 4,
                 nCZ = 3,
                 L = 0,
                 base = 'legendre',
                 initializer = None,
                 **kwargs):
        super(MWT_exp,self).__init__()
        
        self.k = k
        self.c = c
        self.L = L
        self.nCZ = nCZ
        self.Lk = nn.Linear(ich, c*k)
        
        self.MWT_CZ = nn.ModuleList(
            [MWT_CZ(k, alpha, L, c, 
                p, q, base, 
                initializer) 
                for _ in range(nCZ)]
        )
        self.Lc0 = nn.Linear(c*k, 128)
        self.Lc1 = nn.Linear(128, 1)
        
        if initializer is not None:
            self.reset_parameters(initializer)
        
    def forward(self, x):
        
        B, N, ich = x.shape # (B, N, d)
        ns = math.floor(np.log2(N))
        x = self.Lk(x)
        x = x.view(B, N, self.c, self.k)
    
        for i in range(self.nCZ):
            x = self.MWT_CZ[i](x)
#             
            if i < self.nCZ-1:
#                 x = torch.tanh(x)
                x = F.relu(x)
#             x = F.leaky_relu(x)

        x = x.view(B, N, -1) # collapse c and k
        x = self.Lc0(x)
        x = F.relu(x)
        x = self.Lc1(x)
        return x.squeeze()
    
    def reset_parameters(self, initializer):
        initializer(self.Lc0.weight)
        initializer(self.Lc1.weight)    

In [None]:
# Coupled Data

ntrain = 1000
ntest = 200

sub = 2**0 #subsampling rate
h = 2**8 // sub #total grid size divided by the subsampling rate
s = h


batch_size = 20

rw_u = loadmat('/content/drive/MyDrive/gray_scott_results/Coupled_PDE_data/kernel1Drho_t0_1.mat')
x_data = rw_u['rho_t0'].astype(np.float32)
y_data = rw_u['rho_t02'].astype(np.float32)
print(x_data.shape)

x_train_u = x_data[:ntrain,::sub]
y_train_u = y_data[:ntrain,::sub]
x_test_u = x_data[-ntest:,::sub]
y_test_u = y_data[-ntest:,::sub]

x_train_u = torch.from_numpy(x_train_u)
x_test_u = torch.from_numpy(x_test_u)
y_train_u = torch.from_numpy(y_train_u)
y_test_u = torch.from_numpy(y_test_u)

x_train_u = x_train_u.unsqueeze(-1)
x_test_u = x_test_u.unsqueeze(-1)


rw_v = loadmat('/content/drive/MyDrive/gray_scott_results/Coupled_PDE_data/kernel1Dphi_t0_1.mat')
x_data = rw_v['phi_t0'].astype(np.float32)
y_data = rw_v['phi_t02'].astype(np.float32)

x_train_v = x_data[:ntrain,::sub]
y_train_v = y_data[:ntrain,::sub]
x_test_v = x_data[-ntest:,::sub]
y_test_v = y_data[-ntest:,::sub]

x_train_v = torch.from_numpy(x_train_v)
x_test_v = torch.from_numpy(x_test_v)
y_train_v = torch.from_numpy(y_train_v)
y_test_v = torch.from_numpy(y_test_v)
print(y_test_u.shape)

x_train_v = x_train_v.unsqueeze(-1)
x_test_v = x_test_v.unsqueeze(-1)

x_train = torch.cat([x_train_u.reshape(ntrain,s,-1), x_train_v.reshape(ntrain,s,-1)], dim=1)
x_test = torch.cat([x_test_u.reshape(ntest,s,-1), x_test_v.reshape(ntest,s,-1)], dim=1)

y_train = torch.cat([y_train_u.reshape(ntrain,s,-1), y_train_v.reshape(ntrain,s,-1)], dim=1)
y_test = torch.cat([y_test_u.reshape(ntest,s,-1), y_test_v.reshape(ntest,s,-1)], dim=1)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)

(1200, 1024)
torch.Size([200, 1024])


In [None]:
# New model
ich = 1
initializer = get_initializer('xavier_normal') # xavier_normal, kaiming_normal, kaiming_uniform
torch.manual_seed(0)
np.random.seed(0)
model = MWT_exp(ich,
            alpha = 10,
            c = 4*4,
            k = 4,
            p = 4,
            q = 2,
            base = 'legendre',
            nCZ = 1,
            initializer = initializer,
            ).to(device)
learning_rate = 0.001

epochs = 500
step_size = 100
gamma = 0.5

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

maeloss = nn.L1Loss()
myloss = LpLoss(size_average=False)

error_u = []
error_v = []

mae_error_u = []
mae_error_v = []

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()

        optimizer.zero_grad()
        out = model(x)
        mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1), reduction='mean')
        # mse.backward()
        l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
        l2.backward() # use the l2 relative loss

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()
    model.eval()
    test_l2 = 0.0
    test_l2_u = 0.0
    test_l2_v = 0.0
    test_mae = 0.0
    test_mae_u = 0.0
    test_mae_v = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()

            out = model(x)
            test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()
            test_l2_u += myloss(out[:,:s].view(batch_size, -1), y[:,:s,:].view(batch_size, -1)).item()
            test_l2_v += myloss(out[:,s:].view(batch_size, -1), y[:,s:,:].view(batch_size, -1)).item()
            test_mae += maeloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()
            test_mae_u += maeloss(out[:,:s].view(batch_size, -1), y[:,:s,:].view(batch_size, -1)).item()
            test_mae_v += maeloss(out[:,s:].view(batch_size, -1), y[:,s:,:].view(batch_size, -1)).item()           

    train_mse /= len(train_loader)
    train_l2 /= ntrain
    test_l2 /= ntest
    test_l2_u /= ntest
    test_l2_v /= ntest
    test_mae /= ntest
    test_mae_u /= ntest
    test_mae_v /= ntest
    error_u.append(test_l2_u)
    error_v.append(test_l2_v)
    mae_error_u.append(test_mae_u)
    mae_error_v.append(test_mae_v)

    t2 = default_timer()
    print(ep, t2-t1, train_mse,test_mae, test_mae_u,test_mae_v,test_l2_u, test_l2_v)

# np.save('/content/drive/MyDrive/gray_scott_results/Gray_scott_u_l1l1_ugrf_vgrf_Pade_cat_error',error_u)
# np.save('/content/drive/MyDrive/gray_scott_results/Gray_scott_v_l1l1_ugrf_vgrf_Pade_cat_error',error_v)
# np.save('/content/drive/MyDrive/gray_scott_results/Gray_scott_u_l1l1_ugrf_vgrf_Pade_cat_mae',mae_error_u)
# np.save('/content/drive/MyDrive/gray_scott_results/Gray_scott_v_l1l1_ugrf_vgrf_Pade_cat_mae',mae_error_v)




0 12.05906741399997 0.015114312984514981 0.0014255493320524692 0.0014888045284897088 0.0013622941076755523 0.35833069801330564 0.16866528034210204
1 7.417544961999965 0.0010060708702076227 0.00101664075627923 0.0010988858249038457 0.0009343957807868719 0.26447882890701296 0.11562112629413605
2 7.3718223710000075 0.0010267835791455582 0.0008344144513830542 0.0008731253072619438 0.0007957036048173904 0.20807263135910034 0.09790749549865722
3 7.429636240000036 0.0005076444527367129 0.00072191605810076 0.0007865834468975664 0.0006572486506775021 0.1838601279258728 0.07970373868942261
4 7.696058691000076 0.00040474608656950297 0.0007785776536911726 0.0008521331660449505 0.0007050221553072334 0.20152315020561218 0.08526223719120025
5 7.466100253999912 0.00045204564870800825 0.0008132560318335891 0.0007028093514963985 0.0009237027727067471 0.16053449988365173 0.10801330029964447
6 7.413696809000044 0.0003574660752201453 0.0005773278605192899 0.0006149974325671792 0.0005396582977846265 0.14063

KeyboardInterrupt: ignored