# Reliability analysis of the multi-fidelity (MF) Darcy problem using MF-WNO (2D, time-independent)
### HF data size = 50

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
torch.cuda.empty_cache()
import matplotlib.pyplot as plt
from utils import *

from timeit import default_timer
from pytorch_wavelets import DWT, IDWT # (or import DWT, IDWT)
from pytorch_wavelets import DTCWTForward, DTCWTInverse

In [None]:
# Fix the global torch and NumPy RNG seeds so the notebook is reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# WNO

In [None]:
class WaveConv2dCwt(nn.Module):
    """
    2D wavelet integral layer based on the dual-tree complex wavelet transform.

    It applies a forward DTCWT and multiplies two sets of coefficients by
    learnable weights:

    * the approximation (lowpass) coefficients, by ``weights0``;
    * the detail coefficients of the last level, for each of the six
      orientations and each of the two components on the last axis, by their
      own weights.

    The detail coefficients of every other level are set to zero. Finally it
    applies the inverse DTCWT.

    !! It is more computationally expensive than the discrete "WaveConv2d" !!

    Parameters
    ----------
    in_channels, out_channels : int
        Number of input / output channels.
    level : int
        Number of DTCWT decomposition levels (``J``).
    size : sequence of int
        Spatial size of the inputs; used with a dummy transform to infer the
        number of coefficient modes.
    wavelet1, wavelet2 : str
        ``biort`` and ``qshift`` filter names for DTCWTForward/DTCWTInverse.
    """

    # Weight-name suffixes, in the order of the orientation axis (dim 2) of
    # the highpass coefficients.
    _ORIENTATIONS = ('15', '45', '75', '105', '135', '165')
    # Weight-name suffixes for index 0 / 1 of the last axis of the highpass
    # coefficients (presumably real / imaginary parts).
    _PARTS = ('r', 'c')

    def __init__(self, in_channels, out_channels, level, size, wavelet1, wavelet2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.level = level
        self.wavelet_level1 = wavelet1
        self.wavelet_level2 = wavelet2

        # Run one dummy transform to learn the spatial size of the coefficients.
        dummy_data = torch.randn(1, 1, *size)
        dwt_ = DTCWTForward(J=self.level, biort=self.wavelet_level1,
                            qshift=self.wavelet_level2)
        mode_data, mode_coef = dwt_(dummy_data)
        self.modes1 = mode_data.shape[-2]
        self.modes2 = mode_data.shape[-1]
        self.modes21 = mode_coef[-1].shape[-3]
        self.modes22 = mode_coef[-1].shape[-2]

        # Parameter initialization: uniform in [0, scale).
        self.scale = (1 / (in_channels * out_channels))
        self.weights0 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2))
        # Registers weights15r, weights15c, weights45r, ..., weights165c in the
        # same order as explicit assignments would, so the attribute names (and
        # therefore pickled / saved models) are unchanged.
        for angle in self._ORIENTATIONS:
            for part in self._PARTS:
                setattr(self, f'weights{angle}{part}', nn.Parameter(
                    self.scale * torch.rand(in_channels, out_channels, self.modes21, self.modes22)))

    # Convolution
    def mul2d(self, input, weights):
        # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """
        Input parameters:
        -----------------
        x : tensor, shape-[Batch * Channel * x * y]
        Output parameters:
        ------------------
        x : tensor, shape-[Batch * Channel * x * y]
        """
        # Compute the dual-tree complex wavelet coefficients.
        cwt = DTCWTForward(J=self.level, biort=self.wavelet_level1, qshift=self.wavelet_level2).to(x.device)
        x_ft, x_coeff = cwt(x)

        # Detail coefficients of every level except the last stay zero.
        out_coeff = [torch.zeros_like(coeffs, device=x.device) for coeffs in x_coeff]

        # Multiply the final approximation wavelet modes.
        out_ft = self.mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights0)
        # Multiply the last-level detail coefficients, per orientation and part.
        for o_idx, angle in enumerate(self._ORIENTATIONS):
            for p_idx, part in enumerate(self._PARTS):
                weights = getattr(self, f'weights{angle}{part}')
                out_coeff[-1][:, :, o_idx, :, :, p_idx] = self.mul2d(
                    x_coeff[-1][:, :, o_idx, :, :, p_idx].clone(), weights)

        # Return to physical space.
        icwt = DTCWTInverse(biort=self.wavelet_level1, qshift=self.wavelet_level2).to(x.device)
        return icwt((out_ft, out_coeff))


In [None]:
class WNO2d_mf(nn.Module):
    """
    Multi-fidelity WNO network with 4 dual-tree wavelet integral layers.

    1. Lift the input using v(x) = self.fc0.
    2. Apply 4 layers of the integral operator v(+1) = g(K(.) + W)(v),
       where W is self.w_ (1x1 convolution) and K is self.conv_
       (WaveConv2dCwt). The activation g (GELU) is skipped after the last layer.
    3. Project the output of the last layer using self.fc1 and self.fc2.

    Input: tensor of shape (batchsize, x=s, y=s, c=in_channel - 2). The two
        grid coordinates (x, y) are appended in ``forward``, giving
        ``in_channel`` features per point. In this notebook in_channel = 4,
        i.e. (a(x, y), LF solution, x, y).
    Output: tensor of shape (batchsize, x=s, y=s, c=1).

    Parameters
    ----------
    width : int
        Uplifting (hidden channel) dimension.
    level : int
        Number of wavelet decomposition levels.
    size : sequence of int
        Spatial size used by the wavelet layers to infer their number of modes.
    wavelet : sequence of two str
        (biort, qshift) filter names for the dual-tree transform.
    in_channel : int
        Number of features per point after the grid is appended.
    grid_range : sequence of two float
        Upper bounds of the x and y grid coordinates (lower bounds are 0).
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super().__init__()
        self.level = level
        self.width = width
        self.size = size
        self.wavelet1 = wavelet[0]
        self.wavelet2 = wavelet[1]
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 1  # zero-padding added at the right/top before the wavelet layers

        self.fc0 = nn.Linear(self.in_channel, self.width)  # in_channel = input channels + 2 grid coordinates

        self.conv0 = WaveConv2dCwt(self.width, self.width, self.level, self.size, self.wavelet1, self.wavelet2)
        self.conv1 = WaveConv2dCwt(self.width, self.width, self.level, self.size, self.wavelet1, self.wavelet2)
        self.conv2 = WaveConv2dCwt(self.width, self.width, self.level, self.size, self.wavelet1, self.wavelet2)
        self.conv3 = WaveConv2dCwt(self.width, self.width, self.level, self.size, self.wavelet1, self.wavelet2)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.w2 = nn.Conv2d(self.width, self.width, 1)
        self.w3 = nn.Conv2d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)

        x = self.fc0(x)
        x = x.permute(0, 3, 1, 2)
        if self.padding > 0:
            x = F.pad(x, [0, self.padding, 0, self.padding])

        layers = ((self.conv0, self.w0), (self.conv1, self.w1),
                  (self.conv2, self.w2), (self.conv3, self.w3))
        for depth, (conv, pointwise) in enumerate(layers):
            x = conv(x) + pointwise(x)
            if depth < len(layers) - 1:  # no activation after the last layer
                x = F.gelu(x)

        # Remove the padding. Guarded because x[..., :-0, :-0] would be empty.
        if self.padding > 0:
            x = x[..., :-self.padding, :-self.padding]
        x = x.permute(0, 2, 3, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        """Return (batch, size_x, size_y, 2) grid coordinates spanning [0, grid_range]."""
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, self.grid_range[0], size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, self.grid_range[1], size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)

# Hyperparameters and Data

In [None]:
# Sample counts
ntrain = 50            # training samples
ntest = 40             # test samples
n_total = ntrain + ntest
nreliability = 2000    # samples used for the reliability analysis
last_m = 600           # the train/test window is the n_total samples ending before this index

# Optimisation settings
epochs = 500
batch_size = 5
learning_rate = 0.001
step_size = 50
gamma = 0.75

# WNO architecture
wavelet = ['near_sym_a', 'qshift_a']  # wavelet basis function (biort, qshift)
level = 2        # level of wavelet decomposition
width = 64       # uplifting dimension
grid_range = [1, 1]
in_channel = 4

# Spatial resolution: subsample the 101-point grid by r
r = 2
h = int((101 - 1) / r + 1)
s = h


In [None]:
# %%
""" Read data """
PATH = 'data/Darcy_Triangular_FNO_multifid_hmax018_hmin016.mat'
reader = MatReader(PATH)


def _read_subsampled(field):
    """Read ``field`` from ``reader``, keep every r-th point and crop to s x s."""
    return np.array(reader.read_field(field)[:, ::r, ::r][:, :s, :s])


# Each field is read and converted once; the reliability set reuses the same arrays.
x_train = _read_subsampled('boundCoeff')
y_train = _read_subsampled('sol')
y_train_l = _read_subsampled('lressol')

# Train/test window: the n_total samples ending just before index last_m.
x_or_h = x_train[last_m-n_total:last_m].reshape((n_total, s, s, 1))
y_or_h = y_train[last_m-n_total:last_m]
y_or_l = y_train_l[last_m-n_total:last_m].reshape((n_total, s, s, 1))

# Reliability set: the first nreliability samples of the same fields.
# NOTE(review): with last_m=600 and n_total=90 the train/test window [510, 600)
# lies inside [0, nreliability), so training samples are also in the reliability
# set. Confirm this overlap is intended.
x_train_rel, y_train_rel, y_train_l_rel = x_train, y_train, y_train_l
x_or_h_rel = x_train_rel[:nreliability, ...].reshape((nreliability, s, s, 1))
y_or_h_rel = y_train_rel[:nreliability, ...]
y_or_l_rel = y_train_l_rel[:nreliability, ...].reshape((nreliability, s, s, 1))

In [None]:
print(*(arr.shape for arr in (y_or_l, y_or_l_rel)))  # LF train/test and LF reliability shapes

In [None]:
# MF input: the 'boundCoeff' field stacked with the LF solution on the last axis.
# MF target: the residual between the HF and LF solutions.
x_mf = torch.tensor(np.concatenate((x_or_h, y_or_l), axis=-1), dtype=torch.float)
y_mf = torch.tensor(y_or_h - y_or_l.reshape((n_total, s, s)), dtype=torch.float)

x_mf_rel = torch.tensor(np.concatenate((x_or_h_rel, y_or_l_rel), axis=-1), dtype=torch.float)
y_mf_rel = torch.tensor(y_or_h_rel - y_or_l_rel.reshape((nreliability, s, s)), dtype=torch.float)

# Seeded train/test split of the MF data.
generator = torch.Generator().manual_seed(453)
full_dataset_mf = torch.utils.data.TensorDataset(x_mf, y_mf)
dataset = torch.utils.data.random_split(full_dataset_mf, [ntrain, ntest], generator=generator)
train_dataset_mf, test_dataset_mf = dataset


In [None]:
# Materialise each subset as one (input, target) pair of tensors.
x_train_mf, y_train_mf = train_dataset_mf[:]
x_test_mf, y_test_mf = test_dataset_mf[:]

In [None]:
y_test_mf.size()

In [None]:
# Fit the input normaliser on the training inputs only, then encode the train,
# test and reliability inputs with it.
x_normalizer_mf = UnitGaussianNormalizer(x_train_mf)
x_train_mf, x_test_mf, x_test_mf_rel = [
    x_normalizer_mf.encode(t) for t in (x_train_mf, x_test_mf, x_mf_rel)
]

# Only the training targets are encoded; the loader targets stay un-normalised.
y_normalizer = UnitGaussianNormalizer(y_train_mf)
y_train_mf = y_normalizer.encode(y_train_mf)

test_loader_mf, test_loader_mf_rel = [
    torch.utils.data.DataLoader(torch.utils.data.TensorDataset(inputs, targets),
                                batch_size=batch_size, shuffle=False)
    for inputs, targets in ((x_test_mf, y_test_mf), (x_test_mf_rel, y_mf_rel))
]


# MF Model

In [None]:
# %%
""" The MF-WNO model definition """
# `device` is not defined anywhere in this notebook (presumably it comes from
# `from utils import *`). Fall back to the usual CUDA/CPU choice if it is missing.
try:
    device
except NameError:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The checkpoint appears to be a whole pickled nn.Module, not a state_dict.
# PyTorch >= 2.6 defaults to weights_only=True, which cannot load that, so
# request full unpickling explicitly. Only load trusted checkpoints: unpickling
# can execute arbitrary code.
model_mf = torch.load('model/MF_WNO_Darcy2D_50', map_location=device, weights_only=False)
print(count_params(model_mf))

myloss = LpLoss(size_average=False)
y_normalizer.to(device)

In [None]:
# Prediction: run MF-WNO on the reliability set. Outputs are decoded with
# y_normalizer, then the per-sample LpLoss is reported every 25 batches.
pred_mf = []
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf_rel):
        x = x.to(device)
        y = y.to(device)
        n_batch = x.shape[0]

        t1 = default_timer()
        out = y_normalizer.decode(model_mf(x).reshape(n_batch, s, s))
        t2 = default_timer()

        test_l2 = myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item() / n_batch
        if index % 25 == 0:
            print('Batch-{}, Time-{}, Test-L2-{:0.4f}'.format(index, t2 - t1, test_l2))

        pred_mf.append(out.cpu())

pred_mf = torch.cat(pred_mf, dim=0)

print('Mean mse_hf-{}'.format(F.mse_loss(y_mf_rel, pred_mf).item()))


In [None]:
# Rebuild full HF fields: add the LF solution (channel 1 of the decoded MF
# input) to both the true and the predicted residuals.
input_mf = x_normalizer_mf.decode(x_test_mf_rel.cpu())
lf_rel = input_mf[..., 1]
real_mf = y_mf_rel + lf_rel
output_mf = pred_mf + lf_rel


In [None]:
print(*(field.shape for field in (real_mf, output_mf)))

In [None]:
# MSE of the MF-WNO solution, of the raw LF solution and of the predicted
# residual, each against its reference, all in physical (decoded) units.
mse_pred = F.mse_loss(output_mf, real_mf).item()
# x_test_mf_rel holds *normalised* inputs, so compare against the decoded LF
# channel (input_mf), the same one used to build real_mf.
mse_LF = F.mse_loss(real_mf, input_mf[..., 1]).item()
mse_residual = F.mse_loss(y_mf_rel, pred_mf).item()

print('MSE-Predicted solution-{:0.4f}, MSE-LF Data-{:0.4f}, MSE-Residual-{:0.4f}'
      .format(mse_pred, mse_LF, mse_residual))


In [None]:
# Mean and standard deviation of the pointwise squared error of MF-WNO.
error = (real_mf - output_mf) ** 2
mse_mean, mse_std = torch.mean(error), torch.std(error)

print('MSE_mean-{}, MSE-std-{}'.format(mse_mean, mse_std))


In [None]:
fig1, axs = plt.subplots(nrows=3, ncols=5, figsize=(16, 6), facecolor='w', edgecolor='k')
fig1.subplots_adjust(hspace=0.35, wspace=0.2)

fig1.suptitle(f'Predictions MFWNO Darcy2d, HF data size {ntrain}', fontsize=16)

# Every 9th sample, one per column (never more than the number of columns):
# row 0 = reference (real_mf), row 1 = MF-WNO prediction, row 2 = absolute error.
samples = list(range(0, ntest, 9))[:axs.shape[1]]
for col, sample in enumerate(samples):
    im = axs[0, col].imshow(real_mf[sample, :, :], cmap='nipy_spectral', origin='lower')
    plt.colorbar(im, ax=axs[0, col])
    im = axs[1, col].imshow(output_mf[sample, :, :], cmap='nipy_spectral', origin='lower')
    plt.colorbar(im, ax=axs[1, col])
    im = axs[2, col].imshow(torch.abs(real_mf[sample, :, :] - output_mf[sample, :, :]),
                            cmap='jet', origin='lower')
    plt.colorbar(im, ax=axs[2, col])
        

# High Fidelity

In [None]:
class WaveConv2d(nn.Module):
    """
    2D wavelet integral layer based on the discrete wavelet transform.

    It applies a forward DWT ('symmetric' mode) and multiplies two sets of
    coefficients by learnable weights:

    * the approximation coefficients, by ``weights1``;
    * the three detail sub-bands of the last level, by ``weights2``,
      ``weights3`` and ``weights4`` respectively.

    Detail coefficients of the other levels pass through unchanged. Finally it
    applies the inverse DWT.

    Parameters
    ----------
    in_channels, out_channels : int
        Number of input / output channels.
    level : int
        Number of DWT decomposition levels (``J``).
    size : sequence of int
        Spatial size of the inputs; used with a dummy transform to infer the
        number of coefficient modes.
    wavelet : str
        Wavelet name passed to DWT/IDWT (e.g. 'db6').
    """

    def __init__(self, in_channels, out_channels, level, size, wavelet):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.level = level
        self.wavelet = wavelet

        # Run one dummy transform to learn the spatial size of the coefficients.
        dummy_data = torch.randn(1, 1, *size)
        dwt_ = DWT(J=self.level, mode='symmetric', wave=self.wavelet)
        mode_data, mode_coef = dwt_(dummy_data)
        self.modes1 = mode_data.shape[-2]
        self.modes2 = mode_data.shape[-1]

        # Parameter initialization: uniform in [0, scale).
        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2))
        self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2))
        self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2))

    # Convolution
    def mul2d(self, input, weights):
        # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """
        Input parameters:
        -----------------
        x : tensor, shape-[Batch * Channel * x * y]
        Output parameters:
        ------------------
        x : tensor, shape-[Batch * Channel * x * y]
        """
        # Compute the single-tree discrete wavelet coefficients.
        dwt = DWT(J=self.level, mode='symmetric', wave=self.wavelet).to(x.device)
        x_ft, x_coeff = dwt(x)

        # Multiply the final approximation wavelet modes.
        out_ft = self.mul2d(x_ft, self.weights1)
        # Multiply the last-level detail coefficients; dim 2 indexes the sub-band.
        for band, weights in enumerate((self.weights2, self.weights3, self.weights4)):
            x_coeff[-1][:, :, band, :, :] = self.mul2d(x_coeff[-1][:, :, band, :, :].clone(), weights)

        # Return to physical space.
        idwt = IDWT(mode='symmetric', wave=self.wavelet).to(x.device)
        return idwt((out_ft, x_coeff))

In [None]:
class WNO2d(nn.Module):
    """
    High-fidelity WNO network with 4 discrete wavelet integral layers.

    1. Lift the input using v(x) = self.fc0.
    2. Apply 4 layers of the integral operator v(+1) = g(K(.) + W)(v),
       where W is self.w_ (1x1 convolution) and K is self.conv_ (WaveConv2d).
       The activation g (GELU) is skipped after the last layer.
    3. Project the output of the last layer using self.fc1 and self.fc2.

    Input: tensor of shape (batchsize, x=s, y=s, c=in_channel - 2). The two
        grid coordinates (x, y) are appended in ``forward``, giving
        ``in_channel`` features per point. In this notebook in_channel = 3,
        i.e. (a(x, y), x, y).
    Output: tensor of shape (batchsize, x=s, y=s, c=1).

    Parameters
    ----------
    width : int
        Uplifting (hidden channel) dimension.
    level : int
        Number of wavelet decomposition levels.
    size : sequence of int
        Spatial size used by the wavelet layers to infer their number of modes.
    wavelet : str
        Discrete wavelet name (e.g. 'db6').
    in_channel : int
        Number of features per point after the grid is appended.
    grid_range : sequence of two float
        Upper bounds of the x and y grid coordinates (lower bounds are 0).
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super().__init__()
        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 1  # zero-padding added at the right/top before the wavelet layers

        self.fc0 = nn.Linear(self.in_channel, self.width)  # in_channel = input channels + 2 grid coordinates

        self.conv0 = WaveConv2d(self.width, self.width, self.level, self.size, self.wavelet)
        self.conv1 = WaveConv2d(self.width, self.width, self.level, self.size, self.wavelet)
        self.conv2 = WaveConv2d(self.width, self.width, self.level, self.size, self.wavelet)
        self.conv3 = WaveConv2d(self.width, self.width, self.level, self.size, self.wavelet)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.w2 = nn.Conv2d(self.width, self.width, 1)
        self.w3 = nn.Conv2d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)

        x = self.fc0(x)
        x = x.permute(0, 3, 1, 2)
        if self.padding > 0:
            x = F.pad(x, [0, self.padding, 0, self.padding])

        layers = ((self.conv0, self.w0), (self.conv1, self.w1),
                  (self.conv2, self.w2), (self.conv3, self.w3))
        for depth, (conv, pointwise) in enumerate(layers):
            x = conv(x) + pointwise(x)
            if depth < len(layers) - 1:  # no activation after the last layer
                x = F.gelu(x)

        # Remove the padding. Guarded because x[..., :-0, :-0] would be empty.
        if self.padding > 0:
            x = x[..., :-self.padding, :-self.padding]
        x = x.permute(0, 2, 3, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        """Return (batch, size_x, size_y, 2) grid coordinates spanning [0, grid_range]."""
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, self.grid_range[0], size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, self.grid_range[1], size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)
    

In [None]:
# HF-WNO settings. ntrain, ntest and batch_size are reused unchanged from the MF section.
n_total = ntrain + ntest
learning_rate = 0.001

# WNO architecture
wavelet = 'db6'  # wavelet basis function (discrete wavelet)
level = 2        # level of wavelet decomposition
width = 64       # uplifting dimension
grid_range = [1, 1]
in_channel = 3

# Optimisation schedule
epochs = 250
step_size = 50
gamma = 0.75

# Spatial resolution: subsample the 101-point grid by r
r = 2
h = int((101 - 1) / r + 1)
s = h

In [None]:
# Create the HF input and output datasets. The HF model predicts the HF
# solution directly, so there is no residual here.
x_hf = torch.tensor(x_or_h, dtype=torch.float)
y_hf = torch.tensor(y_or_h, dtype=torch.float)

x_hf_rel = torch.tensor(x_or_h_rel, dtype=torch.float)
y_hf_rel = torch.tensor(y_or_h_rel, dtype=torch.float)

# Split with the freshly seeded generator_hf. The MF `generator` has already
# been consumed by the MF split, so reusing it would give a different
# partition, and therefore different normaliser statistics, from a fresh
# seed-453 split (presumably the split the HF model was trained with; confirm).
generator_hf = torch.Generator().manual_seed(453)
dataset_hf = torch.utils.data.random_split(torch.utils.data.TensorDataset(x_hf, y_hf),
                                           [ntrain, ntest], generator=generator_hf)
train_data_hf, test_data_hf = dataset_hf[0], dataset_hf[1]

# Split the training and testing datasets
x_train_hf, y_train_hf = train_data_hf[:][0], train_data_hf[:][1]
x_test_hf, y_test_hf = test_data_hf[:][0], test_data_hf[:][1]

# Normalisers are fitted on the training split only.
x_normalizer_hf = UnitGaussianNormalizer(x_train_hf)
x_train_hf = x_normalizer_hf.encode(x_train_hf)
x_test_hf = x_normalizer_hf.encode(x_test_hf)
x_test_hf_rel = x_normalizer_hf.encode(x_hf_rel)

y_normalizer_hf = UnitGaussianNormalizer(y_train_hf)
y_train_hf = y_normalizer_hf.encode(y_train_hf)

# Define the dataloaders (targets are not normalised)
test_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_hf, y_test_hf),
                                             batch_size=batch_size, shuffle=False)
test_loader_hf_rel = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_hf_rel, y_hf_rel),
                                                 batch_size=batch_size, shuffle=False)


In [None]:
# The checkpoint appears to be a whole pickled nn.Module, not a state_dict.
# PyTorch >= 2.6 defaults to weights_only=True, which cannot load that, so
# request full unpickling explicitly. Only load trusted checkpoints: unpickling
# can execute arbitrary code.
# NOTE(review): `device` is not defined in this notebook; presumably it comes
# from `from utils import *`.
model = torch.load('model/HF_WNO_Darcy2D_50', map_location=device, weights_only=False)
print(count_params(model))

myloss = LpLoss(size_average=False)
y_normalizer_hf.to(device)

In [None]:
# Predict on the HF reliability set with HF-WNO. Outputs are decoded with
# y_normalizer_hf, then the per-sample LpLoss is reported every 25 batches.
pred_hf = []
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_hf_rel):
        x, y = x.to(device), y.to(device)
        batch_n = x.shape[0]

        t1 = default_timer()
        out = y_normalizer_hf.decode(model(x).reshape(batch_n, s, s))
        t2 = default_timer()

        test_l2 = myloss(out.view(batch_n, -1), y.view(batch_n, -1)).item() / batch_n
        if index % 25 == 0:
            print('Batch-{}, Time-{}, Test-L2-{:0.4f}'.format(index, t2 - t1, test_l2))

        pred_hf.append(out.cpu())

pred_hf = torch.cat(pred_hf, dim=0)

print('Mean mse_hf-{}'.format(F.mse_loss(y_hf_rel, pred_hf).item()))

In [None]:
# Overall MSE of the HF-WNO prediction against the HF reference.
mse_pred_hf = F.mse_loss(input=pred_hf, target=y_hf_rel).item()
print(f'MSE-Predicted solution-{mse_pred_hf:0.4f}')


In [None]:
# Mean and standard deviation of the pointwise squared error of HF-WNO,
# measured against real_mf (the HF reference rebuilt in the MF section).
error = (real_mf - pred_hf) ** 2
mse_mean, mse_std = torch.mean(error), torch.std(error)

print('MSE_mean-{}, MSE-std-{}'.format(mse_mean, mse_std))


In [None]:
fig2, axs = plt.subplots(nrows=3, ncols=5, figsize=(16, 6), facecolor='w', edgecolor='k')
fig2.subplots_adjust(hspace=0.35, wspace=0.2)

fig2.suptitle(f'Predictions HFWNO Darcy2d, HF data size {ntrain}', fontsize=16)

# Every 9th sample, one per column (never more than the number of columns):
# row 0 = HF reference, row 1 = HF-WNO prediction, row 2 = absolute error.
samples = list(range(0, ntest, 9))[:axs.shape[1]]
for col, sample in enumerate(samples):
    im = axs[0, col].imshow(y_hf_rel[sample, :, :], cmap='nipy_spectral', origin='lower')
    plt.colorbar(im, ax=axs[0, col])
    im = axs[1, col].imshow(pred_hf[sample, :, :], cmap='nipy_spectral', origin='lower')
    plt.colorbar(im, ax=axs[1, col])
    im = axs[2, col].imshow(torch.abs(y_hf_rel[sample, :, :] - pred_hf[sample, :, :]),
                            cmap='jet', origin='lower')
    plt.colorbar(im, ax=axs[2, col])
        

In [None]:
print(*(torch.max(field[0]) for field in (real_mf, output_mf, pred_hf)))  # max of sample 0

In [None]:
print(*(field.shape for field in (real_mf, output_mf, pred_hf)))

In [None]:
# %%
# Probability of failure: a sample fails if its field exceeds the threshold `eh`
# at any grid point. real_mf is the reference (labelled "MCS" below). Every
# reliability sample is used, and the estimate is normalised by their count.
eh = 2.0


def _exceedance_indicator(fields, threshold):
    """Return a float array: 1.0 for samples whose field exceeds ``threshold`` anywhere, else 0.0."""
    fields = torch.as_tensor(fields).cpu()
    exceeds = (fields > threshold).reshape(fields.shape[0], -1).any(dim=1)
    return exceeds.numpy().astype(float)


eh_mcs = _exceedance_indicator(real_mf, eh)
eh_wno_mf = _exceedance_indicator(output_mf, eh)
eh_wno_hf = _exceedance_indicator(pred_hf, eh)

n_rel = len(eh_mcs)
pf_wno_mf = np.count_nonzero(eh_wno_mf) / n_rel
pf_wno_hf = np.count_nonzero(eh_wno_hf) / n_rel
pf_mcs = np.count_nonzero(eh_mcs) / n_rel
print('Prob. of failure, MFWNO-{}, HFWNO-{}, MCS-{}'.format(pf_wno_mf, pf_wno_hf, pf_mcs))


# Plotting

In [None]:
real_mf.size()

In [None]:
from matplotlib.patches import Rectangle

# Plot extent of the unit domain. It gets its own name so the grid-size global
# `s` used by the prediction cells is not overwritten.
domain_size = 1
xmax = domain_size
ymax = domain_size - 8/51
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams['font.size'] = 12


def _mask_domain(axis, xmax, ymax, top):
    """Paint white, on ``axis``, everything the plots should not show.

    That is: the regions above the two slanted edges that meet at
    (xmax/2, ymax), the strip between ymax and ``top``, and a thin rectangle
    of width 0.02 and height 0.4 at x = 0.5.
    """
    white = [1, 1, 1]
    xf = np.array([0., xmax/2])
    axis.fill_between(xf, xf*(ymax/(xmax/2)), ymax, color=white)
    xf = np.array([xmax/2, xmax])
    axis.fill_between(xf, (xf-xmax)*(ymax/((xmax/2)-xmax)), ymax, color=white)
    axis.fill_between(np.array([0, xmax]), ymax, top, color=white)
    axis.add_patch(Rectangle((0.5, 0), 0.02, 0.4, facecolor='white'))


# One row per quantity: (row label, label colour, field for sample v, is an error row)
row_specs = (
    ('Interpolated BC', 'r', lambda v: real_mf[v, ...], False),
    ('HFWNO', 'green', lambda v: pred_hf[v, ...], False),
    ('Error-HFWNO', 'green', lambda v: np.abs(real_mf[v, ...] - pred_hf[v, ...]), True),
    ('MFWNO', 'blue', lambda v: output_mf[v, ...], False),
    ('Error-MFWNO', 'blue', lambda v: np.abs(real_mf[v, ...] - output_mf[v, ...]), True),
)

figure1, ax = plt.subplots(nrows=5, ncols=6, figsize=(14, 8), dpi=100)
plt.subplots_adjust(hspace=0.30, wspace=0.30)

# Every 340th reliability sample, one per column (never more than the number of columns).
samples = list(range(0, nreliability, 340))[:ax.shape[1]]
for col, value in enumerate(samples):
    for row, (label, colour, field_of, is_error) in enumerate(row_specs):
        axis = ax[row, col]
        # Error rows share a fixed colour range and get their own colour bar.
        clim = {'vmin': 0, 'vmax': 0.1} if is_error else {}
        im = axis.imshow(field_of(value), origin='lower', extent=[0, 1, 0, 1],
                         interpolation='Gaussian', cmap='nipy_spectral', **clim)
        _mask_domain(axis, xmax, ymax, domain_size)
        if is_error:
            plt.colorbar(im, ax=axis, fraction=0.2)
        if col == 0:
            axis.set_ylabel(label, color=colour)
    ax[0, col].set_title('Sample-{}'.format(value), color='r')

figure1.savefig('Prediction_Darcy.pdf', format='pdf', dpi=200, bbox_inches='tight')


In [None]:
# Save the reference, HF-WNO and MF-WNO fields of the reliability set to a .mat file.
results = {
    'real_mf': real_mf.cpu().numpy(),
    'pred_hf': pred_hf.cpu().numpy(),
    'output_mf': output_mf.cpu().numpy(),
}
scipy.io.savemat('data/mfwno_darcy_n50.mat', mdict=results)
