# Train MF-WNO on multi-fidelity (MF) Darcy data (2D, time-independent problem)
### HF data size = 20

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
torch.cuda.empty_cache()
import matplotlib.pyplot as plt
from utils import *

from timeit import default_timer
from pytorch_wavelets import DWT, IDWT # (or import DWT, IDWT)
from pytorch_wavelets import DTCWTForward, DTCWTInverse

In [None]:
# Seed the PyTorch and NumPy global RNGs with a fixed value (0).
torch.manual_seed(0)
np.random.seed(0)

# WNO

In [None]:
class WaveConv2dCwt(nn.Module):
    """2D wavelet integral layer based on the dual-tree complex wavelet transform.

    The input is decomposed with ``DTCWTForward``.  The lowpass (approximation)
    band and the band-pass coefficients of the coarsest level are multiplied by
    learnable weights (contracting the input-channel axis), the band-pass
    coefficients of every finer level are set to zero, and the result is mapped
    back to physical space with ``DTCWTInverse``.

    !! It is computationally more expensive than the discrete "WaveConv2d" !!

    Parameters
    ----------
    in_channels : int, number of input channels
    out_channels : int, number of output channels
    level : int, number of decomposition levels (passed as ``J``)
    size : sequence of two ints, spatial size used to infer the coefficient shapes
           (a dummy tensor of this size is transformed once in the constructor)
    wavelet1 : str, passed as ``biort`` to the transform
    wavelet2 : str, passed as ``qshift`` to the transform
    """

    # Suffixes of the band-pass weight attributes.  Entry k is applied to index k
    # of dimension 2 of the coarsest band-pass tensor; the numbers presumably
    # denote the sub-band orientation in degrees (see the pytorch_wavelets docs).
    _BAND_ANGLES = ("15", "45", "75", "105", "135", "165")
    # "r" weights act on index 0 and "c" weights on index 1 of the last dimension
    # (presumably the real and imaginary parts of the complex coefficients).
    _BAND_PARTS = ("r", "c")

    def __init__(self, in_channels, out_channels, level, size, wavelet1, wavelet2):
        super(WaveConv2dCwt, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.level = level
        self.wavelet_level1 = wavelet1
        self.wavelet_level2 = wavelet2

        # Transform a dummy tensor once, only to read off the coefficient shapes.
        dummy_data = torch.randn(1, 1, *size)
        dwt_ = DTCWTForward(J=self.level, biort=self.wavelet_level1,
                            qshift=self.wavelet_level2)
        mode_data, mode_coef = dwt_(dummy_data)
        self.modes1 = mode_data.shape[-2]
        self.modes2 = mode_data.shape[-1]
        self.modes21 = mode_coef[-1].shape[-3]
        self.modes22 = mode_coef[-1].shape[-2]

        # Parameter initialization.  The attribute names and the creation order
        # (weights0, 15r, 15c, 45r, ..., 165c) are kept so that checkpoints and
        # seeded initialisation stay unchanged.
        self.scale = (1 / (in_channels * out_channels))
        self.weights0 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2))
        for angle in self._BAND_ANGLES:
            for part in self._BAND_PARTS:
                setattr(self, f"weights{angle}{part}", nn.Parameter(
                    self.scale * torch.rand(in_channels, out_channels, self.modes21, self.modes22)))

    # Convolution
    def mul2d(self, input, weights):
        """Contract the input-channel axis of ``input`` with ``weights``.

        (batch, in_channel, x, y), (in_channel, out_channel, x, y) -> (batch, out_channel, x, y)
        """
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """
        Input parameters:
        -----------------
        x : tensor, shape-[Batch * Channel * x * y]
        Output parameters:
        ------------------
        x : tensor, shape-[Batch * Channel * x * y]
        """
        # Compute dual tree complex Wavelet coefficients
        cwt = DTCWTForward(J=self.level, biort=self.wavelet_level1, qshift=self.wavelet_level2).to(x.device)
        x_ft, x_coeff = cwt(x)

        # Band-pass outputs start as zeros; only the coarsest level is filled below,
        # so the finer levels contribute nothing to the reconstruction.
        out_coeff = [torch.zeros_like(coeffs, device=x.device) for coeffs in x_coeff]

        # Multiply the final approximate Wavelet modes
        out_ft = self.mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights0)

        # Multiply the final detailed wavelet coefficients, one sub-band at a time
        coarse_in, coarse_out = x_coeff[-1], out_coeff[-1]
        for k, angle in enumerate(self._BAND_ANGLES):
            for j, part in enumerate(self._BAND_PARTS):
                band_weights = getattr(self, f"weights{angle}{part}")
                coarse_out[:, :, k, :, :, j] = self.mul2d(
                    coarse_in[:, :, k, :, :, j].clone(), band_weights)

        # Return to physical space
        icwt = DTCWTInverse(biort=self.wavelet_level1, qshift=self.wavelet_level2).to(x.device)
        x = icwt((out_ft, out_coeff))
        return x


In [None]:
class WNO2d_mf(nn.Module):
    """Multi-fidelity WNO network built from four ``WaveConv2dCwt`` wavelet layers.

    1. Lift the input using v(x) = self.fc0 .
    2. 4 layers of the integral operators v(+1) = g(K(.) + W)(v).
        W is defined by self.w_; K is defined by self.conv_.
    3. Project the output of last layer using self.fc1 and self.fc2.

    input: the data channels; the two grid coordinates (x, y) are appended inside
           ``forward``, so the input must carry ``in_channel - 2`` channels
    input shape: (batchsize, x=s, y=s, c=in_channel-2)
    output: the solution
    output shape: (batchsize, x=s, y=s, c=1)

    Parameters
    ----------
    width : int, number of channels after lifting
    level : int, wavelet decomposition level handed to the wavelet layers
    size : sequence of two ints, spatial size handed to the wavelet layers
    wavelet : sequence of two str, handed to the wavelet layers as (wavelet1, wavelet2)
    in_channel : int, input width of ``fc0`` (data channels + 2 grid channels)
    grid_range : sequence of two numbers, upper bounds of the x and y coordinates

    NOTE(review): the wavelet layers are sized from ``size`` while ``forward``
    feeds them the padded tensor; confirm both give the same coefficient shapes
    when ``size`` or ``padding`` is changed.
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(WNO2d_mf, self).__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet1 = wavelet[0]
        self.wavelet2 = wavelet[1]
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 1

        # Lifting layer; its input is the data channels plus the (x, y) grid.
        self.fc0 = nn.Linear(self.in_channel, self.width)

        self.conv0 = WaveConv2dCwt(self.width, self.width, self.level, self.size, self.wavelet1, self.wavelet2)
        self.conv1 = WaveConv2dCwt(self.width, self.width, self.level, self.size, self.wavelet1, self.wavelet2)
        self.conv2 = WaveConv2dCwt(self.width, self.width, self.level, self.size, self.wavelet1, self.wavelet2)
        self.conv3 = WaveConv2dCwt(self.width, self.width, self.level, self.size, self.wavelet1, self.wavelet2)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.w2 = nn.Conv2d(self.width, self.width, 1)
        self.w3 = nn.Conv2d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Map a (batch, x, y, in_channel-2) tensor to a (batch, x, y, 1) tensor."""
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)

        x = self.fc0(x)
        x = x.permute(0, 3, 1, 2)
        if self.padding > 0:
            # pad the end of the last two (spatial) dimensions
            x = F.pad(x, [0, self.padding, 0, self.padding])

        x = F.gelu(self.conv0(x) + self.w0(x))
        x = F.gelu(self.conv1(x) + self.w1(x))
        x = F.gelu(self.conv2(x) + self.w2(x))
        # no activation after the last integral layer
        x = self.conv3(x) + self.w3(x)

        if self.padding > 0:
            # Guarded because x[..., :-0, :-0] would return an empty tensor.
            x = x[..., :-self.padding, :-self.padding]
        x = x.permute(0, 2, 3, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        """Return the (batch, size_x, size_y, 2) tensor of x and y coordinates.

        The coordinates are equally spaced on [0, grid_range[0]] and
        [0, grid_range[1]]; ``shape`` supplies batch size and grid size.
        """
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, self.grid_range[0], size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, self.grid_range[1], size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)

# Training and Data

In [None]:
# Data split, optimisation and model settings for the multi-fidelity (MF) run.
ntrain = 20
ntest = 40
epochs = 500
last_m = 600      # end index of the slice of samples taken from the data file
batch_size = 5

n_total = ntrain + ntest
learning_rate = 0.001

# StepLR schedule: multiply the learning rate by `gamma` every `step_size` epochs
step_size = 50
gamma = 0.75

wavelet = ['near_sym_a', 'qshift_a']  # wavelet basis functions (passed on as biort / qshift)
level = 2        # level of wavelet decomposition
width = 64       # uplifting dimension
grid_range = [1, 1]
in_channel = 4   # two data channels + the two grid coordinates appended in the model

# Sub-sampling rate and resulting grid size; 101 is presumably the native resolution -- TODO confirm
r = 2
h = int(((101 - 1)/r) + 1)
s = h


In [None]:
# %%
""" Read data """
# MatReader is provided by the star-import from `utils`; presumably a .mat file reader -- confirm there.
PATH = 'data/Darcy_Triangular_FNO_multifid_hmax018_hmin016.mat'
reader = MatReader(PATH)

# Keep every r-th grid point along the two spatial axes, then an s x s window.
# Judging by the names: 'boundCoeff' is the input field, 'sol' the high-fidelity (HF)
# solution and 'lressol' the low-fidelity (LF) solution -- TODO confirm against the data file.
x_train = np.array(reader.read_field('boundCoeff')[:,::r,::r][:,:s,:s])
y_train = np.array(reader.read_field('sol')[:,::r,::r][:,:s,:s])
y_train_l = np.array(reader.read_field('lressol')[:,::r,::r][:,:s,:s])
# Use the n_total samples that end at index last_m; inputs get a trailing channel axis.
x_or_h = x_train[last_m-n_total:last_m].reshape((n_total,s,s,1))
y_or_h = y_train[last_m-n_total:last_m]
y_or_l = y_train_l[last_m-n_total:last_m].reshape((n_total,s,s,1))

In [None]:
# Sanity check: the reshape above makes this (n_total, s, s, 1).
y_or_l.shape

In [None]:
# MF input: input field and LF solution stacked along the channel axis.
x_mf = np.concatenate((x_or_h,y_or_l),axis=-1)
# MF target: the residual between the HF and the LF solution.
y_mf = y_or_h - y_or_l.reshape((n_total,s,s))

x_mf = torch.tensor( x_mf, dtype=torch.float ) 
y_mf = torch.tensor( y_mf, dtype=torch.float ) 
    
# Random train/test split driven by a dedicated, seeded generator.
generator = torch.Generator().manual_seed(453)
dataset = torch.utils.data.random_split(torch.utils.data.TensorDataset(x_mf, y_mf),
                                    [ntrain, ntest], generator=generator)
train_dataset_mf, test_dataset_mf = dataset[0], dataset[1]


In [None]:
# Split the training and testing datasets
# Indexing each subset with [:] yields its (input, target) pair as whole tensors.

x_train_mf, y_train_mf = train_dataset_mf[:][0], train_dataset_mf[:][1]
x_test_mf, y_test_mf = test_dataset_mf[:][0], test_dataset_mf[:][1]

In [None]:
# Sanity check of the MF test-input tensor shape.
x_test_mf.shape

In [None]:
# UnitGaussianNormalizer comes from `utils`; presumably it standardises with the
# statistics of the tensor given to its constructor -- confirm there.
# Input statistics are taken from the training inputs and applied to both splits.
x_normalizer_mf = UnitGaussianNormalizer(x_train_mf)
x_train_mf = x_normalizer_mf.encode(x_train_mf)
x_test_mf = x_normalizer_mf.encode(x_test_mf)

# Only the training targets are encoded; the test targets stay in physical units.
y_normalizer = UnitGaussianNormalizer(y_train_mf)
y_train_mf = y_normalizer.encode(y_train_mf)

# NOTE(review): the training loader is not shuffled here, whereas the HF training
# loader further down uses shuffle=True -- confirm that this is intended.
train_loader_mf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train_mf, y_train_mf),
                                             batch_size=batch_size, shuffle=False)
test_loader_mf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_mf, y_test_mf),
                                             batch_size=batch_size, shuffle=False)


# MF Model

In [None]:
# %%
""" The MD-WNO model definition """
# `device` and `count_params` are not defined in the visible cells; presumably
# they come from the star-import of `utils` -- confirm.
model_mf = WNO2d_mf(width=width, level=level, size=[s,s], wavelet=wavelet,
              in_channel=in_channel, grid_range=grid_range).to(device)
print(count_params(model_mf))

# Adam with a small weight decay and a step-wise learning-rate decay.
optimizer = torch.optim.Adam(model_mf.parameters(), lr=learning_rate, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [None]:
# Train the MF-WNO on the residual targets and report the test error after every epoch.
# LpLoss comes from `utils`; with size_average=False it presumably returns the sum of the
# per-sample relative errors, which is why the totals are divided by ntrain / ntest -- confirm.
myloss = LpLoss(size_average=False)
y_normalizer.to(device)
for ep in range(epochs):
    model_mf.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader_mf:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = model_mf(x).reshape(x.shape[0], s, s)
        # Both prediction and (encoded) training target are decoded, so the loss
        # is evaluated on de-normalised values.
        out = y_normalizer.decode(out)
        y = y_normalizer.decode(y)
        
        # mse is only logged; the gradient comes from the Lp loss.
        mse = F.mse_loss(out.view(x.shape[0], -1), y.view(x.shape[0], -1), reduction='mean')
        loss = myloss(out.view(x.shape[0], -1), y.view(x.shape[0], -1))
        loss.backward()
        optimizer.step()
        
        train_mse += mse.item()
        train_l2 += loss.item()
    
    scheduler.step()
    model_mf.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader_mf:
            x, y = x.to(device), y.to(device)

            out = model_mf(x).reshape(x.shape[0], s, s)
            # Test targets were never encoded, so only the prediction is decoded.
            out = y_normalizer.decode(out)

            test_l2 += myloss(out.view(x.shape[0], -1), y.view(x.shape[0], -1)).item()

    # train_mse is averaged over batches, the two L2 totals over samples.
    train_mse /= len(train_loader_mf)
    train_l2/= ntrain
    test_l2 /= ntest
    t2 = default_timer()
    print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}'
          .format(ep, t2-t1, train_mse, train_l2, test_l2))
    

Epoch-113, Time-1.0185, Train-MSE-0.0000, Train-L2-0.0644, Test-L2-0.1384
Epoch-114, Time-1.0669, Train-MSE-0.0000, Train-L2-0.0637, Test-L2-0.1390
Epoch-115, Time-0.9974, Train-MSE-0.0000, Train-L2-0.0630, Test-L2-0.1418
Epoch-116, Time-1.0217, Train-MSE-0.0000, Train-L2-0.0647, Test-L2-0.1402
Epoch-117, Time-1.0760, Train-MSE-0.0000, Train-L2-0.0654, Test-L2-0.1435
Epoch-118, Time-1.0094, Train-MSE-0.0000, Train-L2-0.0673, Test-L2-0.1378
Epoch-119, Time-0.9586, Train-MSE-0.0000, Train-L2-0.0654, Test-L2-0.1396
Epoch-120, Time-0.9423, Train-MSE-0.0000, Train-L2-0.0642, Test-L2-0.1372
Epoch-121, Time-1.0482, Train-MSE-0.0000, Train-L2-0.0634, Test-L2-0.1381
Epoch-122, Time-1.0524, Train-MSE-0.0000, Train-L2-0.0620, Test-L2-0.1413
Epoch-123, Time-0.9351, Train-MSE-0.0000, Train-L2-0.0637, Test-L2-0.1392
Epoch-124, Time-1.0182, Train-MSE-0.0000, Train-L2-0.0641, Test-L2-0.1428
Epoch-125, Time-0.9724, Train-MSE-0.0000, Train-L2-0.0658, Test-L2-0.1368
Epoch-126, Time-0.9925, Train-MSE-0.00

Epoch-224, Time-1.2013, Train-MSE-0.0000, Train-L2-0.0382, Test-L2-0.1277
Epoch-225, Time-1.0531, Train-MSE-0.0000, Train-L2-0.0379, Test-L2-0.1265
Epoch-226, Time-1.1489, Train-MSE-0.0000, Train-L2-0.0380, Test-L2-0.1276
Epoch-227, Time-1.1581, Train-MSE-0.0000, Train-L2-0.0385, Test-L2-0.1278
Epoch-228, Time-1.1069, Train-MSE-0.0000, Train-L2-0.0396, Test-L2-0.1268
Epoch-229, Time-1.0607, Train-MSE-0.0000, Train-L2-0.0395, Test-L2-0.1280
Epoch-230, Time-1.1081, Train-MSE-0.0000, Train-L2-0.0393, Test-L2-0.1261
Epoch-231, Time-1.0104, Train-MSE-0.0000, Train-L2-0.0387, Test-L2-0.1275
Epoch-232, Time-1.1618, Train-MSE-0.0000, Train-L2-0.0382, Test-L2-0.1266
Epoch-233, Time-1.0810, Train-MSE-0.0000, Train-L2-0.0382, Test-L2-0.1272
Epoch-234, Time-1.0731, Train-MSE-0.0000, Train-L2-0.0382, Test-L2-0.1283
Epoch-235, Time-1.1085, Train-MSE-0.0000, Train-L2-0.0391, Test-L2-0.1265
Epoch-236, Time-0.9915, Train-MSE-0.0000, Train-L2-0.0389, Test-L2-0.1285
Epoch-237, Time-0.9748, Train-MSE-0.00

Epoch-335, Time-1.2830, Train-MSE-0.0000, Train-L2-0.0268, Test-L2-0.1218
Epoch-336, Time-1.0939, Train-MSE-0.0000, Train-L2-0.0267, Test-L2-0.1221
Epoch-337, Time-1.1654, Train-MSE-0.0000, Train-L2-0.0266, Test-L2-0.1219
Epoch-338, Time-1.0807, Train-MSE-0.0000, Train-L2-0.0266, Test-L2-0.1220
Epoch-339, Time-0.9107, Train-MSE-0.0000, Train-L2-0.0267, Test-L2-0.1222
Epoch-340, Time-1.0658, Train-MSE-0.0000, Train-L2-0.0269, Test-L2-0.1217
Epoch-341, Time-1.0560, Train-MSE-0.0000, Train-L2-0.0269, Test-L2-0.1223
Epoch-342, Time-1.1809, Train-MSE-0.0000, Train-L2-0.0268, Test-L2-0.1215
Epoch-343, Time-1.2319, Train-MSE-0.0000, Train-L2-0.0267, Test-L2-0.1224
Epoch-344, Time-1.0059, Train-MSE-0.0000, Train-L2-0.0268, Test-L2-0.1218
Epoch-345, Time-1.0270, Train-MSE-0.0000, Train-L2-0.0272, Test-L2-0.1222
Epoch-346, Time-1.1838, Train-MSE-0.0000, Train-L2-0.0274, Test-L2-0.1219
Epoch-347, Time-1.1058, Train-MSE-0.0000, Train-L2-0.0275, Test-L2-0.1217
Epoch-348, Time-1.0967, Train-MSE-0.00

Epoch-446, Time-0.9578, Train-MSE-0.0000, Train-L2-0.0216, Test-L2-0.1198
Epoch-447, Time-0.9956, Train-MSE-0.0000, Train-L2-0.0215, Test-L2-0.1198
Epoch-448, Time-0.9541, Train-MSE-0.0000, Train-L2-0.0215, Test-L2-0.1197
Epoch-449, Time-1.0658, Train-MSE-0.0000, Train-L2-0.0215, Test-L2-0.1197
Epoch-450, Time-1.0007, Train-MSE-0.0000, Train-L2-0.0214, Test-L2-0.1197
Epoch-451, Time-1.0071, Train-MSE-0.0000, Train-L2-0.0214, Test-L2-0.1197
Epoch-452, Time-0.9979, Train-MSE-0.0000, Train-L2-0.0214, Test-L2-0.1197
Epoch-453, Time-0.9343, Train-MSE-0.0000, Train-L2-0.0213, Test-L2-0.1197
Epoch-454, Time-1.0388, Train-MSE-0.0000, Train-L2-0.0213, Test-L2-0.1197
Epoch-455, Time-0.9716, Train-MSE-0.0000, Train-L2-0.0213, Test-L2-0.1197
Epoch-456, Time-0.9567, Train-MSE-0.0000, Train-L2-0.0212, Test-L2-0.1197
Epoch-457, Time-0.9217, Train-MSE-0.0000, Train-L2-0.0212, Test-L2-0.1196
Epoch-458, Time-1.0295, Train-MSE-0.0000, Train-L2-0.0212, Test-L2-0.1196
Epoch-459, Time-0.9361, Train-MSE-0.00

In [None]:
# Save the MF-WNO model
# NOTE(review): the whole module object is saved, not its state_dict; the target
# directory 'model/' must already exist.

torch.save(model_mf, 'model/MF_WNO_Darcy2D_20')

In [None]:
# Prediction:
# Run the trained MF-WNO over the (unshuffled) test loader and collect the
# decoded residual predictions in test-set order.
pred_mf = [] 
with torch.no_grad():
    index = 0
    for x, y in test_loader_mf:
        x, y = x.to(device), y.to(device)
        
        out = model_mf(x).reshape(x.shape[0], s, s)
        out = y_normalizer.decode(out)
        test_l2 = myloss(out.view(x.shape[0], -1), y.view(x.shape[0], -1)).item()
        
        # per-sample average over the batch
        test_l2 /= x.shape[0]
        print('Batch-{}, Test-L2-{:0.4f}'.format(index, test_l2))
        
        pred_mf.append(out.cpu())
        index += 1

pred_mf = torch.cat(( pred_mf ), dim=0 )

# Despite the 'mse_hf' label this is the MSE of the predicted residual.
print('Mean mse_hf-{}'.format(F.mse_loss(y_test_mf, pred_mf).item()))


In [None]:
# Add the residual operator to LF-dataset 
# Decode the test inputs back to physical units; channel 1 is the LF solution.
input_mf = x_normalizer_mf.decode( x_test_mf.cpu() ) 

# residual + LF solution: true and predicted HF solution, both in test-set order
real_mf = y_test_mf + input_mf[..., 1]
output_mf = pred_mf + input_mf[..., 1]


In [None]:
# Sanity check: both tensors should have the same shape.
print(real_mf.shape, output_mf.shape)

In [None]:
# Mean-squared errors on the test set, all evaluated in physical units.
# MF-WNO prediction (LF solution + predicted residual) against the HF truth:
mse_pred = F.mse_loss(output_mf, real_mf).item()
# Raw LF solution against the HF truth.  The decoded `input_mf` has to be used
# here: `x_test_mf` still holds the normalised inputs, which are not comparable
# with `real_mf`.
mse_LF = F.mse_loss(real_mf, input_mf[..., 1]).item()
# Predicted residual against the true residual:
mse_residual = F.mse_loss(y_test_mf, pred_mf).item()

print('MSE-Predicted solution-{:0.4f}, MSE-LF Data-{:0.4f}, MSE-Residual-{:0.4f}'
      .format(mse_pred, mse_LF, mse_residual))


In [None]:
# Truth (row 0), MF-WNO prediction (row 1) and absolute error (row 2) for every
# 9th test sample.
fig1, axs = plt.subplots(nrows=3, ncols=5, figsize=(16, 6), facecolor='w', edgecolor='k')
fig1.subplots_adjust(hspace=0.35, wspace=0.2)

fig1.suptitle(f'Predictions MFWNO AC2d Size', fontsize=16)
# The selection is capped at the number of subplot columns, so a larger `ntest`
# cannot index past the last column.
for index, sample in enumerate(range(0, ntest, 9)[:axs.shape[1]]):
    im = axs[0, index].imshow(real_mf[sample, :, :], cmap='nipy_spectral', origin='lower' )
    plt.colorbar(im, ax=axs[0, index])
    im = axs[1, index].imshow(output_mf[sample, :, :], cmap='nipy_spectral', origin='lower' )
    plt.colorbar(im, ax=axs[1, index])
    im = axs[2, index].imshow(torch.abs(real_mf[sample, :, :] - output_mf[sample, :, :]),
                                cmap='jet', origin='lower')
    plt.colorbar(im, ax=axs[2, index])
        

# High Fidelity

In [None]:
class WaveConv2d(nn.Module):
    """2D wavelet integral layer based on the discrete wavelet transform (DWT).

    The input is decomposed with ``DWT``; the approximation band and the three
    detail bands of the coarsest level are multiplied by learnable weights
    (contracting the input-channel axis); the detail bands of the finer levels
    are handed to ``IDWT`` unchanged.

    Parameters
    ----------
    in_channels : int, number of input channels
    out_channels : int, number of output channels
    level : int, number of decomposition levels (passed as ``J``)
    size : sequence of two ints, spatial size used to infer the coefficient shapes
    wavelet : str, wavelet name (passed as ``wave``)
    """

    def __init__(self, in_channels, out_channels, level, size, wavelet):
        super(WaveConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.level = level
        self.wavelet = wavelet

        # One transform of a dummy tensor reveals the size of the approximation band.
        probe = torch.randn(1, 1, *size)
        probe_transform = DWT(J=self.level, mode='symmetric', wave=self.wavelet)
        approx, _details = probe_transform(probe)
        self.modes1, self.modes2 = approx.shape[-2], approx.shape[-1]

        # Parameter initialization: weights1 acts on the approximation band,
        # weights2..weights4 on the three coarsest detail bands.
        self.scale = (1 / (in_channels * out_channels))
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2)
        for number in (1, 2, 3, 4):
            setattr(self, f"weights{number}", nn.Parameter(self.scale * torch.rand(*weight_shape)))

    # Convolution
    def mul2d(self, input, weights):
        """(batch, in_channel, x, y), (in_channel, out_channel, x, y) -> (batch, out_channel, x, y)"""
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """
        Input parameters:
        -----------------
        x : tensor, shape-[Batch * Channel * x * y]
        Output parameters:
        ------------------
        x : tensor, shape-[Batch * Channel * x * y]
        """
        # Forward single-tree discrete wavelet transform
        transform = DWT(J=self.level, mode='symmetric', wave=self.wavelet).to(x.device)
        x_ft, x_coeff = transform(x)

        # Approximation band
        out_ft = self.mul2d(x_ft, self.weights1)

        # Coarsest detail bands, overwritten in place one band at a time
        coarsest = x_coeff[-1]
        for band, band_weights in enumerate((self.weights2, self.weights3, self.weights4)):
            coarsest[:, :, band, :, :] = self.mul2d(coarsest[:, :, band, :, :].clone(), band_weights)

        # Back to physical space
        inverse = IDWT(mode='symmetric', wave=self.wavelet).to(x.device)
        return inverse((out_ft, x_coeff))

In [None]:
class WNO2d(nn.Module):
    """WNO network built from four ``WaveConv2d`` wavelet layers.

    1. Lift the input using v(x) = self.fc0 .
    2. 4 layers of the integral operators v(+1) = g(K(.) + W)(v).
        W is defined by self.w_; K is defined by self.conv_.
    3. Project the output of last layer using self.fc1 and self.fc2.

    input: the data channels; the two grid coordinates (x, y) are appended inside
           ``forward``, so the input must carry ``in_channel - 2`` channels
    input shape: (batchsize, x=s, y=s, c=in_channel-2)
    output: the solution
    output shape: (batchsize, x=s, y=s, c=1)

    Parameters
    ----------
    width : int, number of channels after lifting
    level : int, wavelet decomposition level handed to the wavelet layers
    size : sequence of two ints, spatial size handed to the wavelet layers
    wavelet : str, wavelet name handed to the wavelet layers
    in_channel : int, input width of ``fc0`` (data channels + 2 grid channels)
    grid_range : sequence of two numbers, upper bounds of the x and y coordinates

    NOTE(review): the wavelet layers are sized from ``size`` while ``forward``
    feeds them the padded tensor; confirm both give the same coefficient shapes
    when ``size`` or ``padding`` is changed.
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(WNO2d, self).__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 1

        # Lifting layer; its input is the data channels plus the (x, y) grid.
        self.fc0 = nn.Linear(self.in_channel, self.width)

        self.conv0 = WaveConv2d(self.width, self.width, self.level, self.size, self.wavelet)
        self.conv1 = WaveConv2d(self.width, self.width, self.level, self.size, self.wavelet)
        self.conv2 = WaveConv2d(self.width, self.width, self.level, self.size, self.wavelet)
        self.conv3 = WaveConv2d(self.width, self.width, self.level, self.size, self.wavelet)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.w2 = nn.Conv2d(self.width, self.width, 1)
        self.w3 = nn.Conv2d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Map a (batch, x, y, in_channel-2) tensor to a (batch, x, y, 1) tensor."""
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)

        x = self.fc0(x)
        x = x.permute(0, 3, 1, 2)
        if self.padding > 0:
            # pad the end of the last two (spatial) dimensions
            x = F.pad(x, [0, self.padding, 0, self.padding])

        x = F.gelu(self.conv0(x) + self.w0(x))
        x = F.gelu(self.conv1(x) + self.w1(x))
        x = F.gelu(self.conv2(x) + self.w2(x))
        # no activation after the last integral layer
        x = self.conv3(x) + self.w3(x)

        if self.padding > 0:
            # Guarded because x[..., :-0, :-0] would return an empty tensor.
            x = x[..., :-self.padding, :-self.padding]
        x = x.permute(0, 2, 3, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        """Return the (batch, size_x, size_y, 2) tensor of x and y coordinates.

        The coordinates are equally spaced on [0, grid_range[0]] and
        [0, grid_range[1]]; ``shape`` supplies batch size and grid size.
        """
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, self.grid_range[0], size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, self.grid_range[1], size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)
    

In [None]:
# Settings for the HF-only baseline.  The self-assignments below are no-ops that
# keep the values chosen for the multi-fidelity run.
ntrain = ntrain
ntest = ntest
n_total = ntrain + ntest
batch_size = batch_size
learning_rate = 0.001

wavelet = 'db6'  # wavelet basis function
level = 2        # level of wavelet decomposition
width = 64       # uplifting dimension
grid_range = [1, 1]
in_channel = 3   # one data channel + the two grid coordinates appended in the model

epochs = 250
# StepLR schedule: multiply the learning rate by `gamma` every `step_size` epochs
step_size = 50
gamma = 0.75

# Sub-sampling rate and resulting grid size; 101 is presumably the native resolution -- TODO confirm
r = 2
h = int(((101 - 1)/r) + 1)
s = h

In [None]:
# Create the HF-only input and output datasets (the target is the full HF
# solution, not a residual).
x_hf = torch.tensor( x_or_h, dtype=torch.float ) 
y_hf = torch.tensor( y_or_h, dtype=torch.float ) 

# Use the fresh generator created here.  It carries the same seed and the same
# split sizes as the multi-fidelity split, so both splits pick the same samples
# in the same order.  The multi-fidelity generator has already been consumed and
# would produce a different permutation.
generator_hf = torch.Generator().manual_seed(453)
dataset_hf = torch.utils.data.random_split(torch.utils.data.TensorDataset(x_hf, y_hf),
                                    [ntrain, ntest], generator=generator_hf)
train_data_hf, test_data_hf = dataset_hf[0], dataset_hf[1]

# Split the training and testing datasets
x_train_hf, y_train_hf = train_data_hf[:][0], train_data_hf[:][1]
x_test_hf, y_test_hf = test_data_hf[:][0], test_data_hf[:][1]

# Input statistics come from the training inputs and are applied to both splits.
x_normalizer_hf = UnitGaussianNormalizer(x_train_hf)
x_train_hf = x_normalizer_hf.encode(x_train_hf)
x_test_hf = x_normalizer_hf.encode(x_test_hf)

# Only the training targets are encoded; the test targets stay in physical units.
y_normalizer_hf = UnitGaussianNormalizer(y_train_hf)
y_train_hf = y_normalizer_hf.encode(y_train_hf)

# Define the dataloaders
train_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train_hf, y_train_hf),
                                             batch_size=batch_size, shuffle=True)
test_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_hf, y_test_hf),
                                            batch_size=batch_size, shuffle=False)


In [None]:
# HF-only baseline model.  `optimizer` and `scheduler` are rebound here, so from
# this cell on they refer to the HF model.
model = WNO2d(width=width, level=level, size=[s,s], wavelet=wavelet,
              in_channel=in_channel, grid_range=grid_range).to(device)
print(count_params(model))

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [None]:
# Train the HF-only WNO and report the test error after every epoch.
# LpLoss comes from `utils`; with size_average=False it presumably returns the sum of the
# per-sample relative errors, which is why the totals are divided by ntrain / ntest -- confirm.
myloss = LpLoss(size_average=False)
y_normalizer_hf.to(device)
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader_hf:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(x).reshape(x.shape[0], s, s)
        # Both prediction and (encoded) training target are decoded, so the loss
        # is evaluated on de-normalised values.
        out = y_normalizer_hf.decode(out)
        y = y_normalizer_hf.decode(y)
        
        # mse is only logged; the gradient comes from the Lp loss.
        mse = F.mse_loss(out.view(x.shape[0], -1), y.view(x.shape[0], -1), reduction='mean')
        loss = myloss(out.view(x.shape[0],-1), y.view(x.shape[0],-1))
        loss.backward()
        optimizer.step()
        
        train_mse += mse.item()
        train_l2 += loss.item()
    
    scheduler.step()
    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader_hf:
            x, y = x.to(device), y.to(device)

            out = model(x).reshape(x.shape[0], s, s)
            # Test targets were never encoded, so only the prediction is decoded.
            out = y_normalizer_hf.decode(out)

            test_l2 += myloss(out.view(x.shape[0],-1), y.view(x.shape[0],-1)).item()

    # train_mse is averaged over batches, the two L2 totals over samples.
    train_mse /= len(train_loader_hf)
    train_l2/= ntrain
    test_l2 /= ntest
    t2 = default_timer()
    print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}'
          .format(ep, t2-t1, train_mse, train_l2, test_l2))
    

Epoch-113, Time-0.7404, Train-MSE-0.0036, Train-L2-0.0879, Test-L2-0.2990
Epoch-114, Time-0.6728, Train-MSE-0.0030, Train-L2-0.0819, Test-L2-0.3121
Epoch-115, Time-0.6540, Train-MSE-0.0033, Train-L2-0.0889, Test-L2-0.2980
Epoch-116, Time-0.6021, Train-MSE-0.0030, Train-L2-0.0814, Test-L2-0.3062
Epoch-117, Time-0.6277, Train-MSE-0.0029, Train-L2-0.0854, Test-L2-0.3025
Epoch-118, Time-0.6134, Train-MSE-0.0035, Train-L2-0.0853, Test-L2-0.3036
Epoch-119, Time-0.6241, Train-MSE-0.0024, Train-L2-0.0759, Test-L2-0.3034
Epoch-120, Time-0.6162, Train-MSE-0.0025, Train-L2-0.0767, Test-L2-0.2976
Epoch-121, Time-0.6809, Train-MSE-0.0020, Train-L2-0.0706, Test-L2-0.2922
Epoch-122, Time-0.7502, Train-MSE-0.0019, Train-L2-0.0645, Test-L2-0.3039
Epoch-123, Time-0.6989, Train-MSE-0.0019, Train-L2-0.0640, Test-L2-0.2913
Epoch-124, Time-0.6091, Train-MSE-0.0022, Train-L2-0.0642, Test-L2-0.2968
Epoch-125, Time-0.6033, Train-MSE-0.0016, Train-L2-0.0590, Test-L2-0.2950
Epoch-126, Time-0.6226, Train-MSE-0.00

Epoch-224, Time-0.6994, Train-MSE-0.0003, Train-L2-0.0282, Test-L2-0.2764
Epoch-225, Time-0.7081, Train-MSE-0.0003, Train-L2-0.0288, Test-L2-0.2736
Epoch-226, Time-0.7253, Train-MSE-0.0004, Train-L2-0.0318, Test-L2-0.2784
Epoch-227, Time-0.7416, Train-MSE-0.0005, Train-L2-0.0340, Test-L2-0.2724
Epoch-228, Time-0.6904, Train-MSE-0.0007, Train-L2-0.0374, Test-L2-0.2786
Epoch-229, Time-0.6369, Train-MSE-0.0004, Train-L2-0.0344, Test-L2-0.2741
Epoch-230, Time-0.6350, Train-MSE-0.0003, Train-L2-0.0288, Test-L2-0.2748
Epoch-231, Time-0.6849, Train-MSE-0.0004, Train-L2-0.0295, Test-L2-0.2718
Epoch-232, Time-0.6508, Train-MSE-0.0004, Train-L2-0.0296, Test-L2-0.2746
Epoch-233, Time-0.6874, Train-MSE-0.0004, Train-L2-0.0310, Test-L2-0.2719
Epoch-234, Time-0.6515, Train-MSE-0.0004, Train-L2-0.0304, Test-L2-0.2760
Epoch-235, Time-0.6398, Train-MSE-0.0005, Train-L2-0.0325, Test-L2-0.2703
Epoch-236, Time-0.5918, Train-MSE-0.0005, Train-L2-0.0347, Test-L2-0.2752
Epoch-237, Time-0.6000, Train-MSE-0.00

In [None]:
# Save the HF-WNO model
# NOTE(review): the whole module object is saved, not its state_dict; the target
# directory 'model/' must already exist.

torch.save(model, 'model/HF_WNO_Darcy2D_20')

In [None]:
# Predict on HF data using HF-WNO
# The multi-fidelity test loader is reused so that the HF-only predictions refer
# to exactly the same samples, in the same order, as the multi-fidelity ones.
pred_hf = [] 
true_hf = []
with torch.no_grad():
    index = 0
    for x, y in test_loader_mf:
        # Back to physical units: channel 0 is the input field, channel 1 the LF solution.
        x = x_normalizer_mf.decode(x)
        # The loader's target is the residual (HF - LF); adding the LF solution
        # gives the HF truth that the HF-only model has to be scored against.
        y = y + x[..., 1]
        # Re-normalise with the HF statistics; only channel 0 is used below.
        x = x_normalizer_hf.encode(x)
        
        x, y = x.to(device), y.to(device)

        out = model(x[..., 0:1]).reshape(x.shape[0], s, s)
        out = y_normalizer_hf.decode(out)
        test_l2 = myloss(out.view(x.shape[0], -1), y.view(x.shape[0], -1)).item()
        # per-sample average over the batch
        test_l2 /= x.shape[0]
        print('Batch-{}, Test-L2-{:0.4f}'.format(index, test_l2))
        
        pred_hf.append(out.cpu())
        true_hf.append(y.cpu())
        index += 1

pred_hf = torch.cat(( pred_hf ), dim=0 )
true_hf = torch.cat(true_hf, dim=0)

# Score against the truth collected in the loop, which is aligned with pred_hf
# by construction.
print('Mean mse_hf-{}'.format(F.mse_loss(true_hf, pred_hf).item()))

In [None]:
# MSE of the HF-only prediction.  `pred_hf` was produced from the multi-fidelity
# test loader, so it is compared with `real_mf`, the HF truth of those same
# samples (true residual + LF solution).
mse_pred_hf = F.mse_loss(pred_hf, real_mf).item()

print('MSE-Predicted solution-{:0.4f}'.format(mse_pred_hf))


In [None]:
# Truth (row 0), HF-only prediction (row 1) and absolute error (row 2) for every
# 9th test sample.  `real_mf` is the HF truth in the order of the multi-fidelity
# test loader, which is the order `pred_hf` was produced in.
# NOTE(review): the figure title is the same as on the multi-fidelity figure -- confirm the wording.
fig2, axs = plt.subplots(nrows=3, ncols=5, figsize=(16, 6), facecolor='w', edgecolor='k')
fig2.subplots_adjust(hspace=0.35, wspace=0.2)

fig2.suptitle(f'Predictions MFWNO AC2d Size', fontsize=16)
# The selection is capped at the number of subplot columns, so a larger `ntest`
# cannot index past the last column.
for index, sample in enumerate(range(0, ntest, 9)[:axs.shape[1]]):
    im = axs[0, index].imshow(real_mf[sample, :, :], cmap='nipy_spectral',origin='lower')
    plt.colorbar(im, ax=axs[0, index])
    im = axs[1, index].imshow(pred_hf[sample, :, :], cmap='nipy_spectral',origin='lower')
    plt.colorbar(im, ax=axs[1, index])
    im = axs[2, index].imshow(torch.abs(real_mf[sample, :, :] - pred_hf[sample, :, :]),
                                cmap='jet',origin='lower')
    plt.colorbar(im, ax=axs[2, index])
        

In [None]:
# Signed prediction error of the HF-only model (row 0) and of the multi-fidelity
# model (row 1) for every 9th test sample.
fig4, axs = plt.subplots(nrows=2, ncols=5, figsize=(10, 4), facecolor='w', edgecolor='k')
fig4.subplots_adjust(hspace=0.35, wspace=0.2)

fig4.suptitle(f'Predictions Error', fontsize=16)
# Both predictions are in test-set order, so the reference must be the test-set
# truth `real_mf`; `y_hf` holds all samples in their original, unsplit order.
# The selection is capped at the number of subplot columns.
for index, sample in enumerate(range(0, ntest, 9)[:axs.shape[1]]):
    im = axs[0, index].imshow(pred_hf[sample] - real_mf[sample], cmap='nipy_spectral',origin='lower')
    plt.colorbar(im, ax=axs[0, index])
    im = axs[1, index].imshow(output_mf[sample] - real_mf[sample], cmap='nipy_spectral',origin='lower')
    plt.colorbar(im, ax=axs[1, index])
        

# Plotting

In [None]:
# Extents used to blank out everything outside the triangle with vertices
# (0, 0), (xmax/2, ymax) and (xmax, 0).
# NOTE: `s` is rebound to 1 here; the evaluation loop earlier in the notebook
# uses `s` in `reshape(x.shape[0], s, s)`, so re-running that cell after this
# one will see the new value.
s = 1
xmax = s
ymax = s-8/51
from matplotlib.patches import Rectangle
plt.rcParams.update({"font.family": "Serif", 'font.size': 10})


def _show_field(field, title, with_colorbar=False):
    """Draw *field* on the current axes with *title*; optionally add a horizontal colorbar."""
    plt.imshow(field, origin='lower', extent=[0, 1, 0, 1],
               interpolation='Gaussian', cmap='nipy_spectral')
    plt.title(title)
    if with_colorbar:
        plt.colorbar(orientation="horizontal", fraction=0.04, pad=0.2)


def _blank_outside_triangle(notch_left, notch_height):
    """Paint the current axes white outside the triangular region.

    Also draws a white rectangle of width 0.01 and height *notch_height* whose
    lower-left corner sits at (*notch_left*, 0).  Reads the module-level
    `xmax`, `ymax` and `s` at call time.
    """
    white = [1, 1, 1]

    # Area above the rising edge from (0, 0) to (xmax/2, ymax).
    rising_x = np.array([0., xmax/2])
    rising_y = rising_x*(ymax/(xmax/2))
    plt.fill_between(rising_x, rising_y, ymax, color=white)

    # Area above the falling edge from (xmax/2, ymax) to (xmax, 0).
    falling_x = np.array([xmax/2, xmax])
    falling_y = (falling_x-xmax)*(ymax/((xmax/2)-xmax))
    plt.fill_between(falling_x, falling_y, ymax, color=white)

    # Horizontal strip between ymax and s.
    plt.fill_between(np.array([0, xmax]), ymax, s, color=white)

    plt.gca().add_patch(Rectangle((notch_left, 0), 0.01, notch_height, facecolor='white'))


figure1 = plt.figure(constrained_layout=False, figsize=(14, 10))
plt.subplots_adjust(hspace=0.25, wspace=0.3)
index = 0
value = 0

# --- Row of fields for sample `value` -------------------------------------
plt.subplot(3, 4, index+1)
_show_field(output_mf[value, :, :], 'MFSM')
_blank_outside_triangle(0.5, 0.4)

plt.subplot(3, 4, index+2)
_show_field(pred_hf[value, :, :], 'HFSM')
_blank_outside_triangle(0.5, 0.4)

plt.subplot(3, 4, index+3)
_show_field(input_mf[value, :, :, 1], 'LFM')
_blank_outside_triangle(0.5, 0.41)

plt.subplot(3, 4, index+4)
_show_field(real_mf[value, :, :], 'Ground Truth')
_blank_outside_triangle(0.5, 0.41)

# --- Absolute error of each field against `real_mf` ------------------------
plt.subplot(3, 4, index+5)
_show_field(np.abs(output_mf[value, :, :]-real_mf[value, :, :]), 'Error', with_colorbar=True)
_blank_outside_triangle(0.49, 0.41)

plt.subplot(3, 4, index+6)
_show_field(np.abs(pred_hf[value, :, :]-real_mf[value, :, :]), 'Error', with_colorbar=True)
_blank_outside_triangle(0.49, 0.41)

plt.subplot(3, 4, index+7)
_show_field(np.abs(input_mf[value, :, :, 1]-real_mf[value, :, :]), 'Error', with_colorbar=True)
_blank_outside_triangle(0.49, 0.41)
