# Reliability analysis of the 2D Allen-Cahn equation with a multi-fidelity wavelet neural operator (MF-WNO): time-dependent (first-crossing) reliability analysis
### HF data size = 200 samples x 50 time steps = 10000 input-output pairs

In [None]:
from timeit import default_timer

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from pandas import DataFrame as pdf
from pytorch_wavelets import DWT, IDWT # (or import DWT, IDWT)
from pytorch_wavelets import DTCWTForward, DTCWTInverse
from torch.nn.parameter import Parameter

# Provides MatReader, LpLoss, count_params (and presumably `device`) — TODO confirm.
from utils import *

In [None]:
# Seed the PyTorch and NumPy global RNGs so the run is reproducible.
torch.manual_seed(seed=0)
np.random.seed(seed=0)

# WNO

In [None]:
class WaveConv2dCwt(nn.Module):
    """
    2D wavelet integral layer based on the dual-tree complex wavelet transform (DTCWT).

    The forward pass computes a DTCWT of the input and applies learned per-mode
    channel mixing to the approximation coefficients. It does the same for the six
    oriented subbands of the last detail level in the coefficient list. All other
    detail levels are zeroed, and the result is mapped back with the inverse DTCWT.
    It is more computationally expensive than the discrete ``WaveConv2d`` layer.
    """

    # Orientation labels of the six DTCWT subbands. They only name the weights
    # (weights15r, weights15c, ..., weights165c). Do not rename them: pickled
    # models restore parameters by these attribute names.
    _ORIENTATIONS = (15, 45, 75, 105, 135, 165)

    def __init__(self, in_channels, out_channels, level, size, wavelet1, wavelet2):
        """
        Parameters
        ----------
        in_channels, out_channels : int
            Channel counts mixed by the learned weights.
        level : int
            Number of DTCWT decomposition levels (``J``).
        size : sequence of int
            Spatial size (x, y) of a probe input, used to infer how many wavelet
            modes each weight tensor must cover.
        wavelet1, wavelet2 : str
            Filter names passed as ``biort`` and ``qshift`` to the DTCWT.
        """
        super(WaveConv2dCwt, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.level = level
        self.wavelet_level1 = wavelet1
        self.wavelet_level2 = wavelet2

        # Probe the transform once to get the approximation / detail mode sizes.
        dummy_data = torch.randn(1, 1, *size)
        dwt_ = DTCWTForward(J=self.level, biort=self.wavelet_level1,
                            qshift=self.wavelet_level2)
        mode_data, mode_coef = dwt_(dummy_data)
        self.modes1 = mode_data.shape[-2]
        self.modes2 = mode_data.shape[-1]
        self.modes21 = mode_coef[-1].shape[-3]
        self.modes22 = mode_coef[-1].shape[-2]

        # Parameter initialization. Registration order is kept as weights0, then
        # the 'r' and 'c' weights for each orientation in turn.
        self.scale = (1 / (in_channels * out_channels))
        self.weights0 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2))
        for angle in self._ORIENTATIONS:
            for part in ('r', 'c'):
                setattr(self, 'weights{}{}'.format(angle, part),
                        nn.Parameter(self.scale * torch.rand(in_channels, out_channels,
                                                             self.modes21, self.modes22)))

    # Convolution
    def mul2d(self, input, weights):
        # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """
        Input parameters:
        -----------------
        x : tensor, shape-[Batch * Channel * x * y]
        Output parameters:
        ------------------
        x : tensor, shape-[Batch * Channel * x * y]
        """
        # Compute dual-tree complex wavelet coefficients.
        cwt = DTCWTForward(J=self.level, biort=self.wavelet_level1, qshift=self.wavelet_level2).to(x.device)
        x_ft, x_coeff = cwt(x)

        # Every detail level except the last one stays zero.
        out_coeff = [torch.zeros_like(coeffs, device=x.device) for coeffs in x_coeff]

        # Multiply the final approximate wavelet modes.
        out_ft = self.mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights0)

        # Multiply the last-level detail coefficients. Axis 2 indexes the six
        # orientations. On the last axis, index 0 uses the '*r' weights and
        # index 1 uses the '*c' weights. The inputs are only read, so no clone
        # is needed.
        coeff_in, coeff_out = x_coeff[-1], out_coeff[-1]
        for k, angle in enumerate(self._ORIENTATIONS):
            for p, part in enumerate(('r', 'c')):
                weight = getattr(self, 'weights{}{}'.format(angle, part))
                coeff_out[:, :, k, :, :, p] = self.mul2d(coeff_in[:, :, k, :, :, p], weight)

        # Return to physical space.
        icwt = DTCWTInverse(biort=self.wavelet_level1, qshift=self.wavelet_level2).to(x.device)
        x = icwt((out_ft, out_coeff))
        return x


In [None]:
class WNO2d(nn.Module):
    """
    Wavelet neural operator built on the dual-tree complex wavelet layer ``WaveConv2dCwt``.

    1. Append the (x, y) grid coordinates to the input and lift the channels to
       ``width`` with ``self.fc0``.
    2. Three wavelet integral layers v <- K(v) + W(v), where K is ``self.conv_k``
       and W is the 1x1 convolution ``self.w_k``. GELU follows the first two.
    3. Project to one output channel with ``self.fc1`` -> GELU -> ``self.fc2``.

    Input shape:  (batch, x, y, in_channel - 2); the 2 grid channels are added internally.
    Output shape: (batch, x, y, 1).

    NOTE: a second ``WNO2d`` (built on ``WaveConv2d``) is defined later in this
    notebook and shadows this name. Instances created or unpickled before that
    keep this class. Keep attribute names unchanged so pickled models still load.
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        """
        Parameters
        ----------
        width : int
            Lifted (hidden) channel dimension.
        level : int
            Number of DTCWT decomposition levels.
        size : sequence of int
            Spatial size given to the wavelet layers to size their weights.
            Presumably the padded size (input + ``self.padding``) — confirm.
        wavelet : pair of str
            ``(biort, qshift)`` filter names for the DTCWT.
        in_channel : int
            Input features of ``fc0``: data channels plus the 2 grid channels.
        grid_range : pair of float
            Upper ends of the x and y coordinate grids (both start at 0).
        """
        super(WNO2d, self).__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet1 = wavelet[0]
        self.wavelet2 = wavelet[1]
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 1

        self.fc0 = nn.Linear(self.in_channel, self.width) # input channels: data channels + (x, y) grid

        self.conv0 = WaveConv2dCwt(self.width, self.width, self.level, self.size,
                                            self.wavelet1, self.wavelet2)
        self.conv1 = WaveConv2dCwt(self.width, self.width, self.level, self.size,
                                            self.wavelet1, self.wavelet2)
        self.conv2 = WaveConv2dCwt(self.width, self.width, self.level, self.size,
                                            self.wavelet1, self.wavelet2)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.w2 = nn.Conv2d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Map (batch, x, y, in_channel - 2) to (batch, x, y, 1)."""
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)

        x = self.fc0(x)
        x = x.permute(0, 3, 1, 2)
        # Pad the trailing edge of both spatial axes; removed again before projection.
        if self.padding != 0:
            x = F.pad(x, [0,self.padding, 0,self.padding])

        x1 = self.conv0(x)
        x2 = self.w0(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv1(x)
        x2 = self.w1(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv2(x)
        x2 = self.w2(x)
        x = x1 + x2

        if self.padding != 0:
            x = x[..., :-self.padding, :-self.padding]
        x = x.permute(0, 2, 3, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        """Return a (batch, x, y, 2) grid of coordinates on [0, grid_range[0]] x [0, grid_range[1]]."""
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, self.grid_range[0], size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, self.grid_range[1], size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)

# Configuration and data

In [None]:
# --- Dataset sizes -------------------------------------------------------
ntrain = 4000
ntest = 1000
lst = 4000
n_total = ntrain * 50      # ntrain trajectories x 50 (t -> t+1) pairs each
side = 65
s = side                   # spatial resolution of each field (s x s)
batch_size = 250

# --- Optimiser settings (not used in the evaluation cells shown here) -----
epochs = 105
learning_rate = 0.001
step_size = 20
gamma = 0.5

# --- MF-WNO architecture -------------------------------------------------
wavelet = ['near_sym_a', 'qshift_a']  # wavelet basis functions (biort, qshift)
level = 2                  # level of wavelet decomposition
width = 32                 # uplifting dimension
grid_range = [1, 1]
in_channel = 4             # 2 data channels + 2 grid channels


In [None]:
# Load the LF and HF Allen-Cahn solutions from the .mat file.
path = 'data/ac2dlowhighres_1.mat'
reader = MatReader(path)
u_low, u_high = (np.array(reader.read_field(field)) for field in ('ulr_nextstep', 'uhr'))

In [None]:
print(*(arr.shape for arr in (u_low, u_high)))

In [None]:
# Turn each HF trajectory into (state at t) -> (state at t+1) pairs and flatten
# the (sample, time) axes into one.
x_or_h = u_high[:, :-1, :, :].reshape(-1, s, s, 1)   # HF input: all but the last time index
y_or_h = u_high[:, 1:, :, :].reshape(-1, s, s)       # HF target: all but the first time index
y_or_l = u_low.reshape(-1, s, s, 1)                  # LF field 'ulr_nextstep'

print(*(arr.shape for arr in (x_or_h, y_or_h, y_or_l)))

del u_low, u_high

In [None]:
# Create the input and output (residual) dataset.
# Input channels: [HF state at t (x_or_h), LF field paired with the HF target (y_or_l)].
# Target: residual between the HF target and that LF field.
x_mf = np.concatenate((x_or_h, y_or_l), axis=-1)
# Drop the trailing singleton channel rather than reshaping to the hard-coded
# n_total, so the residual follows the number of samples actually loaded.
y_mf = y_or_h - y_or_l[..., 0]

x_mf = torch.tensor(x_mf, dtype=torch.float)
y_mf = torch.tensor(y_mf, dtype=torch.float)

In [None]:
# Define the dataloader (fixed order, so predictions line up with x_mf / y_mf rows).
test_loader_mf = torch.utils.data.DataLoader(
    dataset=torch.utils.data.TensorDataset(x_mf, y_mf),
    batch_size=batch_size,
    shuffle=False,
)

# MF Model

In [None]:
# %%
# Load the trained MF-WNO model. It is used directly as a callable module below,
# so the checkpoint is a fully pickled nn.Module. The classes defined above must
# be in scope under their original names for unpickling.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects such
# pickles, so full unpickling is requested wherever the argument exists.
# Only load trusted checkpoints: unpickling can execute arbitrary code.
import inspect

_load_kwargs = {'weights_only': False} if 'weights_only' in inspect.signature(torch.load).parameters else {}
model_mf = torch.load('model/MF_WNO_AC2D_10000samples', **_load_kwargs)
model_mf.eval()  # inference mode (the layers above have no dropout/batch-norm, so outputs are unchanged)
print(count_params(model_mf))

myloss = LpLoss(size_average=False)


In [None]:
# Prediction with the MF-WNO.
# The model predicts the residual target y, so the batch loss is computed against y.
pred_mf = []

with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf):
        # NOTE(review): `device` is not defined in this notebook; presumably it comes from utils — confirm.
        x, y = x.to(device), y.to(device)

        # On CUDA, kernels launch asynchronously, so t2 - t1 may understate the run time.
        t1 = default_timer()
        out = model_mf(x).reshape(x.shape[0], s, s)
        t2 = default_timer()
        test_l2 = myloss(out.view(x.shape[0], -1), y.view(x.shape[0], -1)).item()

        # size_average=False presumably makes LpLoss sum over the batch; average per sample.
        test_l2 /= x.shape[0]
        if index % 25 == 0:
            print('Batch-{}, Time-{:0.4f}, Test-L2-{:0.4f}'.format(index, t2-t1, test_l2))
        pred_mf.append(out.cpu())

pred_mf = torch.cat(pred_mf, dim=0)

In [None]:
# Recover HF solutions by adding the LF channel (x_mf[..., 1]) back to the residuals,
# then restore the (sample, time, x, y) layout.
real_mf = (y_mf + x_mf[..., 1]).reshape(ntrain, 50, s, s)
output_mf = (pred_mf + x_mf[..., 1]).reshape(ntrain, 50, s, s)


In [None]:
print(*(t.shape for t in (real_mf, output_mf)))

In [None]:
# MSE of the reconstructed HF solution and of the predicted residual.
mse_pred = F.mse_loss(output_mf, real_mf).item()
mse_residual = F.mse_loss(y_mf, pred_mf).item()

# Scientific notation: a fixed 4-decimal format rounds small MSEs to 0.0000.
print('MSE-Predicted solution-{:0.4e}, MSE-Residual-{:0.4e}'
      .format(mse_pred, mse_residual))


In [None]:
# Mean and standard deviation of the pointwise squared error (MF-WNO).
error = (real_mf - output_mf) ** 2
mse_mean, mse_std = error.mean(), error.std()

print(f'MSE_mean-{mse_mean}, MSE-std-{mse_std}')

In [None]:
fig10, axs = plt.subplots(nrows=3, ncols=5, figsize=(16, 6), facecolor='w', edgecolor='k')
fig10.subplots_adjust(hspace=0.35, wspace=0.2)

# Rows: truth, MF-WNO prediction, absolute error; one column every 10 time indices.
sample = 1000
index = 0
for i in range(50):
    if i % 10 != 0:
        continue
    truth = real_mf[sample, i, :, :]
    pred = output_mf[sample, i, :, :]
    panels = ((0, truth, {'vmin': -1, 'vmax': 1}),
              (1, pred, {'vmin': -1, 'vmax': 1}),
              (2, torch.abs(truth - pred), {}))
    for row, field, limits in panels:
        im = axs[row, index].imshow(field, cmap='jet', **limits)
        plt.colorbar(im, ax=axs[row, index])
    index += 1
        

# High Fidelity

In [None]:
class WaveConv2d(nn.Module):
    """
    2D wavelet integral layer based on the (single-tree) discrete wavelet transform.

    The forward pass computes a DWT of the input and applies learned per-mode
    channel mixing to the approximation coefficients. It also mixes the three
    detail subbands of the last level in the coefficient list. Other detail
    levels pass through unchanged, and the inverse DWT returns to physical space.
    """

    def __init__(self, in_channels, out_channels, level, size, wavelet):
        """
        Parameters
        ----------
        in_channels, out_channels : int
            Channel counts mixed by the learned weights.
        level : int
            Number of DWT decomposition levels (``J``).
        size : sequence of int
            Spatial size (x, y) of a probe input, used to infer the number of modes.
        wavelet : str
            Wavelet name passed to ``DWT`` / ``IDWT`` (e.g. 'db4').
        """
        super(WaveConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.level = level
        self.wavelet = wavelet
        # Probe the transform once to get the mode sizes of the last level.
        dummy_data = torch.randn( 1,1,*size )
        dwt_ = DWT(J=self.level, mode='symmetric', wave=self.wavelet)
        mode_data, _ = dwt_(dummy_data)
        self.modes1 = mode_data.shape[-2]
        self.modes2 = mode_data.shape[-1]

        # Parameter initialization
        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2))
        self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2))
        self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2))

    # Convolution
    def mul2d(self, input, weights):
        # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """
        Input parameters:
        -----------------
        x : tensor, shape-[Batch * Channel * x * y]
        Output parameters:
        ------------------
        x : tensor, shape-[Batch * Channel * x * y]
        """
        # Compute single-tree discrete wavelet coefficients.
        dwt = DWT(J=self.level, mode='symmetric', wave=self.wavelet).to(x.device)
        x_ft, x_coeff = dwt(x)
        x_coeff = list(x_coeff)

        # Multiply the final approximate wavelet modes.
        out_ft = self.mul2d(x_ft, self.weights1)

        # Mix the three detail subbands of the last level (axis 2) out of place
        # instead of writing into the transform's output tensor.
        last = x_coeff[-1]
        subband_weights = (self.weights2, self.weights3, self.weights4)
        x_coeff[-1] = torch.stack(
            [self.mul2d(last[:, :, k, :, :], w) for k, w in enumerate(subband_weights)],
            dim=2,
        )

        # Return to physical space.
        idwt = IDWT(mode='symmetric', wave=self.wavelet).to(x.device)
        x = idwt((out_ft, x_coeff))
        return x

In [None]:
class WNO2d(nn.Module):
    """
    Wavelet neural operator built on the discrete wavelet layer ``WaveConv2d``.

    This definition replaces the earlier ``WNO2d`` for code that runs after this
    cell. Keep attribute names unchanged so that pickled models can be restored.

    Forward pass: append the (x, y) grid, lift with ``fc0``, then apply three
    layers of K(v) + W(v) (GELU after the first two), then ``fc1`` -> GELU -> ``fc2``.
    Input (batch, x, y, in_channel - 2) -> output (batch, x, y, 1).
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(WNO2d, self).__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 1

        # Lift (data channels + 2 grid channels) to `width`.
        self.fc0 = nn.Linear(self.in_channel, self.width)

        # Wavelet integral operators K, then pointwise 1x1 convolutions W
        # (same registration order as conv0, conv1, conv2, w0, w1, w2).
        for k in range(3):
            setattr(self, 'conv{}'.format(k),
                    WaveConv2d(self.width, self.width, self.level, self.size, self.wavelet))
        for k in range(3):
            setattr(self, 'w{}'.format(k), nn.Conv2d(self.width, self.width, 1))

        # Projection head.
        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Map (batch, x, y, in_channel - 2) to (batch, x, y, 1)."""
        h = torch.cat((x, self.get_grid(x.shape, x.device)), dim=-1)
        h = self.fc0(h).permute(0, 3, 1, 2)
        if self.padding:
            h = F.pad(h, [0, self.padding, 0, self.padding])

        layers = ((self.conv0, self.w0), (self.conv1, self.w1), (self.conv2, self.w2))
        for depth, (integral_op, pointwise) in enumerate(layers):
            h = integral_op(h) + pointwise(h)
            if depth < len(layers) - 1:
                h = F.gelu(h)

        if self.padding:
            h = h[..., :-self.padding, :-self.padding]
        h = h.permute(0, 2, 3, 1)
        return self.fc2(F.gelu(self.fc1(h)))

    def get_grid(self, shape, device):
        """Return a (batch, x, y, 2) grid of coordinates on [0, grid_range[0]] x [0, grid_range[1]]."""
        n_batch, n_x, n_y = shape[0], shape[1], shape[2]
        coord_x = torch.tensor(np.linspace(0, self.grid_range[0], n_x), dtype=torch.float)
        coord_y = torch.tensor(np.linspace(0, self.grid_range[1], n_y), dtype=torch.float)
        grid_x = coord_x.reshape(1, n_x, 1, 1).repeat([n_batch, 1, n_y, 1])
        grid_y = coord_y.reshape(1, 1, n_y, 1).repeat([n_batch, n_x, 1, 1])
        return torch.cat((grid_x, grid_y), dim=-1).to(device)
    

In [None]:
# HF-WNO hyper-parameters (these overwrite the MF-WNO values set earlier).
wavelet = 'db4'      # wavelet basis function
level = 2            # level of wavelet decomposition
width = 32           # uplifting dimension
in_channel = 3       # 1 data channel (HF state) + 2 grid channels
grid_range = [1, 1]
s = side

In [None]:
print(*(arr.shape for arr in (x_or_h, y_or_h)))

In [None]:
# Load the trained HF-WNO model. It is used directly as a callable module below,
# so the checkpoint is a fully pickled nn.Module. The classes defined above must
# be in scope under their original names for unpickling.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects such
# pickles, so full unpickling is requested wherever the argument exists.
# Only load trusted checkpoints: unpickling can execute arbitrary code.
import inspect

_load_kwargs = {'weights_only': False} if 'weights_only' in inspect.signature(torch.load).parameters else {}
model = torch.load('model/HF_WNO_AC2D_10000samples', **_load_kwargs)
model.eval()  # inference mode (the layers above have no dropout/batch-norm, so outputs are unchanged)
print(count_params(model))

myloss = LpLoss(size_average=False)


In [None]:
# Predict on HF data using HF-WNO.
# The HF-WNO takes only the HF state at t (channel 0) and outputs the full HF state
# at t+1 (its output is used as-is for output_hf). Its loss must therefore use the
# full HF target, residual + LF channel (as for real_mf), not the residual y alone.
pred_hf = []
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf):
        # NOTE(review): `device` is not defined in this notebook; presumably it comes from utils — confirm.
        x, y = x.to(device), y.to(device)
        t1 = default_timer()
        out = model(x[..., 0:1]).reshape(x.shape[0], s, s)
        t2 = default_timer()

        target = y + x[..., 1]
        test_l2 = myloss(out.view(x.shape[0], -1), target.view(x.shape[0], -1)).item()

        # size_average=False presumably makes LpLoss sum over the batch; average per sample.
        test_l2 /= x.shape[0]
        if index % 25 == 0:
            print('Batch-{}, Time-{:0.4f}, Test-L2-{:0.4f}'.format(index, t2-t1, test_l2))

        pred_hf.append(out.cpu())

pred_hf = torch.cat(pred_hf, dim=0)

In [None]:
# Inspect the residual-target tensor size.
y_mf.size()

In [None]:
# Get the time shape.
# HF ground truth = residual target + LF channel (x_mf[..., 1]), the same
# reconstruction used for real_mf. Channel 0 is the HF state at t; adding it to
# the residual would not recover the HF target.
real_hf = (y_mf + x_mf[..., 1]).reshape(ntrain, 50, s, s)
output_hf = pred_hf.reshape(ntrain, 50, s, s)

In [None]:
print(f'Mean mse_hf-{F.mse_loss(real_hf, output_hf).item()}')

In [None]:
# Mean and standard deviation of the pointwise squared error (HF-WNO).
error = (real_hf - output_hf) ** 2
mse_mean, mse_std = error.mean(), error.std()

print(f'MSE_mean-{mse_mean}, MSE-std-{mse_std}')

In [None]:
fig5, axs = plt.subplots(nrows=3, ncols=5, figsize=(16, 6), facecolor='w', edgecolor='k')
fig5.subplots_adjust(hspace=0.35, wspace=0.2)

# This figure shows the HF-WNO predictions (real_hf / output_hf), not the MF-WNO.
fig5.suptitle('Predictions HFWNO AC2d', fontsize=16)
sample = 1000
index = 0
# Rows: truth, HF-WNO prediction, absolute error; one column every 10 time indices.
for i in range(50):
    if i % 10 == 0:
        im = axs[0, index].imshow(real_hf[sample, i, :, :], cmap='jet', vmin=-1, vmax=1)
        plt.colorbar(im, ax=axs[0, index])
        im = axs[1, index].imshow(output_hf[sample, i, :, :], cmap='jet', vmin=-1, vmax=1)
        plt.colorbar(im, ax=axs[1, index])
        im = axs[2, index].imshow(torch.abs(real_hf[sample, i, :, :] - output_hf[sample, i, :, :]),
                                    cmap='jet')
        plt.colorbar(im, ax=axs[2, index])
        index += 1
        

In [None]:
# Inspect the (sample, time, x, y) size of the MF-WNO output.
output_mf.size()

In [None]:
# First crossing estimate.
# eh_*[i, j] = 1 when, at time index j of sample i, the NUMBER of grid points with
# |u| > eh is itself greater than eh. With eh = 1.01 that means at least 2 points.
# NOTE(review): comparing a point count against the value threshold `eh` looks
# unintended (perhaps "> 0" was meant). It is kept so results stay unchanged — confirm.
# NOTE(review): index 0 of real_mf / output_* comes from y_or_h[:, 0] = u_high[:, 1],
# the first step AFTER the initial condition. Skipping j = 0 therefore skips step 1,
# not the IC — confirm intent.
T = 50
eh = 1.01
eh_mcs = np.zeros((ntrain, T))
eh_wno_mf = np.zeros((ntrain, T))
eh_wno_hf = np.zeros((ntrain, T))

# Vectorised over time and space for each sample, instead of one torch call per (i, j).
indicator_fields = ((eh_mcs, real_mf), (eh_wno_mf, output_mf), (eh_wno_hf, output_hf))
for i in range(ntrain):
    for indicator, field in indicator_fields:
        exceed_counts = (torch.abs(field[i, 1:T]) > eh).sum(dim=(-2, -1))
        indicator[i, 1:T] = (exceed_counts > eh).cpu().numpy()
            

In [None]:
# Debug view: time indices at which sample `i` crosses the threshold. `i` is the
# loop variable left over from the previous cell, i.e. the last sample.
np.nonzero(eh_mcs[i] > 0)

In [None]:
# First-crossing time and failure flag per sample.
#   time_*[i]  : first time index j with eh_*[i, j] > 0, or T if sample i never crosses.
#   count_*[i] : 1 if sample i crosses at least once, else 0.
def _first_crossing(indicator):
    """Return (first-crossing index, failure flag) float arrays for an (n_samples, T) 0/1 indicator.

    Samples that never cross get time ``T`` and flag 0.
    """
    crossed = indicator > 0
    failed = crossed.any(axis=1)
    # argmax on a boolean row returns the index of its first True entry.
    first = np.where(failed, crossed.argmax(axis=1), T).astype(float)
    return first, failed.astype(float)


time_mcs, count_mcs = _first_crossing(eh_mcs)
time_wno_mf, count_wno_mf = _first_crossing(eh_wno_mf)
time_wno_hf, count_wno_hf = _first_crossing(eh_wno_hf)

pf_mcs = np.count_nonzero(count_mcs) / ntrain
pf_wno_mf = np.count_nonzero(count_wno_mf) / ntrain
pf_wno_hf = np.count_nonzero(count_wno_hf) / ntrain

# Report the number of failed samples, not the full per-sample flag arrays.
print('Failure samples, MCS-{}, WNO-MF-{}, WNO-HF-{}'.format(
    int(count_mcs.sum()), int(count_wno_mf.sum()), int(count_wno_hf.sum())))
print('Prob. of failure, MCS-{}, WNO-MF-{}, WNO-HF-{}'.format(pf_mcs, pf_wno_mf, pf_wno_hf))


In [None]:
labels = ['Failure Probability'] * ntrain
df = pdf(data={'Methods': labels, 'MCS': time_mcs, 'MFWNO': time_wno_mf, 'HFWNO': time_wno_hf})

plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 26

# KDE of first-crossing times: filled reference (MCS) plus two surrogate curves.
fig2, ax = plt.subplots(nrows=1, ncols=1, figsize=(12, 8), dpi=600)
sns.kdeplot(data=df, x="MCS", fill=True, common_norm=False, color='cyan', alpha=0.25,
            linewidth=1, multiple="layer", bw_method=0.5, bw_adjust=1.5,
            label='MCS', ax=ax)
for column, colour, dash, legend_label in (("MFWNO", 'red', '-.', 'MF-WNO'),
                                           ("HFWNO", 'green', '--', 'HF-WNO')):
    sns.kdeplot(data=df, x=column, fill=False, color=colour, linestyle=dash,
                linewidth=4, multiple="layer", bw_method=0.41, bw_adjust=2.0,
                label=legend_label, ax=ax)
ax.set_xlabel('First crossing time (s)')
ax.set_ylabel('PDF')
ax.set_xlim([-60, 120])
ax.legend(labelspacing=0.15)
ax.grid(True, alpha=0.35)

# fig2.savefig('pf_AC2d.pdf', format='pdf', dpi=600, bbox_inches='tight')


In [None]:
# Save first-crossing times and failure flags of the three estimators.
scipy.io.savemat(
    'data/results_MFWNO_AC_10000samples.mat',
    mdict=dict(
        time_mcs=time_mcs,
        time_wno_mf=time_wno_mf,
        time_wno_hf=time_wno_hf,
        count_mcs=count_mcs,
        count_wno_mf=count_wno_mf,
        count_wno_hf=count_wno_hf,
    ),
)


In [None]:
# Save the ground-truth and predicted HF fields as NumPy arrays.
scipy.io.savemat(
    'data/MFWNO_Allen_Cahn_n200.mat',
    mdict=dict(
        real_mf=real_mf.cpu().numpy(),
        output_hf=output_hf.cpu().numpy(),
        output_mf=output_mf.cpu().numpy(),
    ),
)


In [None]:
# Inspect the (sample, time, x, y) size of the ground truth.
real_mf.size()

In [None]:
plt.rcParams["font.family"] = "Serif"
plt.rcParams['font.size'] = 10

fig6, ax = plt.subplots(nrows=5, ncols=6, figsize=(14, 8))
plt.subplots_adjust(hspace=0.30, wspace=0.7)

sample = 1000
index = 1

# Column 0 shows the true initial condition. real_mf[:, i] is the HF state at step
# i + 1 (y_or_h comes from u_high[:, 1:]), so the IC u_high[:, 0] is not in real_mf.
# It is the HF input channel of this sample's first (t = 0) row in x_mf.
n_steps = real_mf.shape[1]
initial_condition = x_mf[sample * n_steps, :, :, 0]

im = ax[0,0].imshow(initial_condition, extent=[0,1,0,1], interpolation='gaussian',
                                vmin=-1, vmax=1, cmap='jet')
ax[0,0].set_title('Step-0')
ax[0,0].set_ylabel('IC-{}'.format(sample), color='r', fontsize=12)
plt.colorbar(im, ax=ax[0,0], orientation="vertical", fraction=0.046, pad=0.05)
# Column 0 of the remaining rows only carries the row labels.
ax[1,0].set_ylabel('HFWNO', labelpad=20, color='g', fontsize=12);
ax[1,0].set(frame_on=False); ax[1,0].get_xaxis().set_ticks([]); ax[1,0].get_yaxis().set_ticks([])
ax[2,0].set_ylabel('Error-HFWNO', labelpad=20, color='g', fontsize=12);
ax[2,0].set(frame_on=False); ax[2,0].get_xaxis().set_ticks([]); ax[2,0].get_yaxis().set_ticks([])
ax[3,0].set_ylabel('MFWNO', labelpad=20, color='b', fontsize=12);
ax[3,0].set(frame_on=False); ax[3,0].get_xaxis().set_ticks([]); ax[3,0].get_yaxis().set_ticks([])
ax[4,0].set_ylabel('Error-MFWNO', labelpad=20, color='b', fontsize=12);
ax[4,0].set(frame_on=False); ax[4,0].get_xaxis().set_ticks([]); ax[4,0].get_yaxis().set_ticks([])

# Columns 1-5: time indices 9, 18, 27, 36, 45 (HF steps 10, 19, 28, 37, 46).
for i in range(50):
    if i % 9 == 0 and i != 0:
        im = ax[0,index].imshow(real_mf[sample,i,:,:], extent=[0,1,0,1], interpolation='gaussian',
                                vmin=-1, vmax=1, cmap='jet')
        ax[0,index].set_title('Step-{}'.format(i+1));
        plt.colorbar(im, ax=ax[0,index], orientation="vertical", fraction=0.046, pad=0.05)

        im = ax[1,index].imshow(output_hf[sample,i,:,:], extent=[0,1,0,1], interpolation='gaussian',
                                vmin=-1, vmax=1, cmap='jet')
        plt.colorbar(im, ax=ax[1,index], orientation="vertical", fraction=0.046, pad=0.05)

        im = ax[2,index].imshow(np.abs(real_mf[sample,i,:,:] - output_hf[sample,i,:,:]), extent=[0,1,0,1],
                                interpolation='gaussian', vmin=0, vmax=0.05, cmap='jet')
        plt.colorbar(im, ax=ax[2,index], orientation="vertical", fraction=0.046, pad=0.05)

        im = ax[3,index].imshow(output_mf[sample,i,:,:], extent=[0,1,0,1], interpolation='gaussian',
                                vmin=-1, vmax=1, cmap='jet')
        plt.colorbar(im, ax=ax[3,index], orientation="vertical", fraction=0.046, pad=0.05)

        im = ax[4,index].imshow(np.abs(real_mf[sample,i,:,:] - output_mf[sample,i,:,:]), extent=[0,1,0,1],
                                interpolation='gaussian', vmin=0, vmax=0.05, cmap='jet')
        plt.colorbar(im, ax=ax[4,index], orientation="vertical", fraction=0.046, pad=0.05)
        index += 1

# fig6.savefig('Prediction_AC.pdf', format='pdf', dpi=100, bbox_inches='tight')
