# This notebook performs reliability analysis of the 1D stochastic Poisson equation using MFWNO (time-independent reliability).
### HF data size = 50

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from utils import *
import matplotlib.pyplot as plt

from timeit import default_timer
from pytorch_wavelets import DWT1D, IDWT1D

In [None]:
# Seed the PyTorch and NumPy global generators with one shared value so that
# random draws made later in the notebook repeat from run to run.
_SEED = 10
torch.manual_seed(_SEED)
np.random.seed(_SEED)

# WNO

In [None]:
class WaveConv1d(nn.Module):
    """
    1D wavelet layer: wavelet transform -> learned linear mixing of the
    coefficients -> inverse wavelet transform.
    """

    def __init__(self, in_channels, out_channels, level, size, wavelet):
        """
        Input parameters:
        -----------------
        in_channels  : int, channels of the incoming signal
        out_channels : int, channels produced by the layer
        level        : int, number of decomposition levels (passed to DWT1D as J)
        size         : int, signal length used to size the weight tensors
        wavelet      : wavelet identifier handed to DWT1D / IDWT1D
        """
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.level, self.wavelet = level, wavelet
        self.dwt_ = DWT1D(wave=wavelet, J=level, mode='zero')

        # Run a throw-away signal through the transform once: the length of
        # the returned low-pass band fixes the last dimension of the weights.
        probe = torch.randn(1, 1, size)
        low_band, _ = self.dwt_(probe)
        self.modes1 = low_band.shape[-1]

        # Parameter initialisation: uniform draws scaled by 1/(in*out).
        self.scale = 1 / (in_channels * out_channels)
        weight_shape = (in_channels, out_channels, self.modes1)
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape))

    def mul1d(self, input, weights):
        """
        Channel mixing applied independently at every coefficient index:
        (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        """
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """
        Input parameters:
        -----------------
        x : tensor, shape-[Batch * Channel * x]
        Output parameters:
        ------------------
        x : tensor, shape-[Batch * Channel * x]
        """
        dev = x.device
        # Fresh transform modules are built on every call and moved to the
        # device of the input.
        analysis = DWT1D(wave=self.wavelet, J=self.level, mode='zero').to(dev)
        synthesis = IDWT1D(wave=self.wavelet, mode='zero').to(dev)

        low_band, high_bands = analysis(x)

        # Learned mixing of the low-pass band ...
        mixed_low = self.mul1d(low_band, self.weights1)
        # ... and of the last entry of the high-pass list; the remaining
        # high-pass entries are passed through unchanged.
        high_bands[-1] = self.mul1d(high_bands[-1].clone(), self.weights2)

        # Back to the signal domain.
        return synthesis((mixed_low, high_bands))

In [None]:
class WNO1d(nn.Module):
    """
    The WNO network. It stacks 4 wavelet integral layers.

    1. Lift the input (with the grid coordinate appended) using self.fc0.
    2. Apply 4 blocks v <- K(v) + W(v), where K is self.conv_ (WaveConv1d)
       and W is self.w_ (1x1 Conv1d); GELU follows every block but the last.
    3. Project to one output channel with self.fc1 -> leaky ReLU -> self.fc2.

    input shape : (batchsize, x=s, c=in_channel - 1); forward() concatenates
                  one grid channel before self.fc0.
    output shape: (batchsize, x=s, c=1)
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super().__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 2  # samples appended on the right before the wavelet blocks

        # Lifting layer; in_channel counts the appended grid coordinate.
        self.fc0 = nn.Linear(in_channel, width)

        # Wavelet kernels first, then the pointwise skip convolutions, so the
        # modules are created in the order conv0..conv3, w0..w3.
        for k in range(4):
            setattr(self, 'conv{}'.format(k), WaveConv1d(width, width, level, size, wavelet))
        for k in range(4):
            setattr(self, 'w{}'.format(k), nn.Conv1d(width, width, 1))

        # Projection head.
        self.fc1 = nn.Linear(width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        coords = self.get_grid(x.shape, x.device)
        x = self.fc0(torch.cat((x, coords), dim=-1))
        x = x.permute(0, 2, 1)  # channels-first layout for the conv blocks

        use_padding = self.padding != 0
        if use_padding:
            x = F.pad(x, [0, self.padding])  # pad on the right only

        blocks = (
            (self.conv0, self.w0),
            (self.conv1, self.w1),
            (self.conv2, self.w2),
            (self.conv3, self.w3),
        )
        final = len(blocks) - 1
        for k, (kernel, skip) in enumerate(blocks):
            x = kernel(x) + skip(x)
            if k != final:  # no activation after the last block
                x = F.gelu(x)

        if use_padding:
            x = x[..., :-self.padding]  # drop the samples added above
        x = x.permute(0, 2, 1)
        x = F.leaky_relu(self.fc1(x))
        return self.fc2(x)

    def get_grid(self, shape, device):
        """Return a (batch, size_x, 1) float tensor of evenly spaced
        coordinates from 0 to self.grid_range, repeated for each batch item.
        Only shape[0] and shape[1] are read."""
        n_batch, n_points = shape[0], shape[1]
        axis = np.linspace(0, self.grid_range, n_points)
        grid = torch.tensor(axis, dtype=torch.float).reshape(1, n_points, 1)
        return grid.repeat([n_batch, 1, 1]).to(device)
    

In [None]:
class WNO1d_linear(nn.Module):
    """
    Activation-free variant of the WNO network with 2 wavelet integral layers.

    1. Lift the input (with the grid coordinate appended) using self.fc0.
    2. Apply 2 blocks v <- K(v) + W(v) with no non-linearity in between;
       K is self.conv_ (WaveConv1d) and W is self.w_ (1x1 Conv1d).
    3. Project to one output channel with self.fc1 -> leaky ReLU -> self.fc2.

    input shape : (batchsize, x=s, c=in_channel - 1); forward() concatenates
                  one grid channel before self.fc0.
    output shape: (batchsize, x=s, c=1)
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(WNO1d_linear, self).__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 2  # samples appended on the right before the wavelet blocks

        # Lifting layer; in_channel counts the appended grid coordinate.
        self.fc0 = nn.Linear(self.in_channel, self.width)

        self.conv0 = WaveConv1d(self.width, self.width, self.level, self.size, self.wavelet)
        self.conv1 = WaveConv1d(self.width, self.width, self.level, self.size, self.wavelet)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)

        # Projection head.
        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.fc0(x)
        x = x.permute(0, 2, 1)  # channels-first layout for the conv blocks
        if self.padding != 0:
            x = F.pad(x, [0, self.padding])  # pad on the right only

        x1 = self.conv0(x)
        x2 = self.w0(x)
        x = x1 + x2

        x1 = self.conv1(x)
        x2 = self.w1(x)
        x = x1 + x2

        # Guarded like the padding above: with padding == 0 the slice
        # x[..., :-0] would return an empty tensor instead of x itself.
        if self.padding != 0:
            x = x[..., :-self.padding]
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.leaky_relu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        """Return a (batch, size_x, 1) float tensor of evenly spaced
        coordinates from 0 to self.grid_range, repeated for each batch item.
        Only shape[0] and shape[1] are read."""
        batchsize, size_x = shape[0], shape[1]
        gridx = torch.tensor(np.linspace(0, self.grid_range, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
        return gridx.to(device)
    

In [None]:
class MFWNO(nn.Module):
    """
    Multi-fidelity WNO: the outputs of an activation-free branch
    (WNO1d_linear) and a non-linear branch (WNO1d) are added and passed
    through a small pointwise network (1 -> 12 -> GELU -> 1).
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super().__init__()

        self.width = width
        self.level = level
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range

        # Both branches are built from the same hyperparameters.
        branch_args = (width, level, size, wavelet, in_channel, grid_range)
        self.conv1 = WNO1d_linear(*branch_args)
        self.conv2 = WNO1d(*branch_args)

        # Pointwise head applied to the summed branch outputs.
        self.fc0 = nn.Linear(1, 12)
        self.fc1 = nn.Linear(12, 1)

    def forward(self, x):
        combined = self.conv1(x) + self.conv2(x)
        return self.fc1(F.gelu(self.fc0(combined)))


# Multifidelity

In [None]:
# --- dataset sizes ---------------------------------------------------------
ntrain, ntest = 2000, 2000
n_total = ntrain + ntest
last_m = 400
s = 100                    # number of grid points (x_coords is reshaped to (s,))

# --- optimisation settings ---------------------------------------------------
# The models in this notebook are loaded from disk, so the values from here
# down are not referenced again in the visible cells.
batch_size = 100           # also the batch size of the evaluation loaders
learning_rate = 0.001
w_decay = 1e-4
epochs = 500
step_size = 50             # presumably the scheduler step (in epochs) - TODO confirm against the training script
gamma = 0.5                # presumably the scheduler decay factor - TODO confirm

# --- network architecture ----------------------------------------------------
wavelet = 'db6'            # wavelet basis function
level = 3                  # level of wavelet decomposition
width = 64                 # uplifting dimension
layers = 4                 # no. of wavelet layers

h = 100                    # total grid size divided by the subsampling rate
grid_range = 1
in_channel = 3             # two data channels stacked into x_mf plus the grid coordinate appended in forward()


In [None]:
# Load the archive and keep the first `ntest` rows of the stochastic forcing,
# the low-fidelity solution and the high-fidelity solution.
PATH = 'data/possion_10pt_100pt__lscale_01.npz'
data = np.load(PATH)

x_data_h, y_data_l, y_data_h = (
    data[key][:ntest, ...] for key in ('f_stoch', 'y_low_100', 'yhi')
)
# Spatial coordinates flattened to a vector of length s.
x_coords = data['xhi'].reshape((s,))

In [None]:
# Build the multi-fidelity dataset: the input stacks the forcing (channel 0)
# and the low-fidelity solution (channel 1); the target is the residual
# between the high- and low-fidelity solutions.
x_mf = np.stack((x_data_h, y_data_l), axis=-1)
y_mf = y_data_h - y_data_l


def _as_float_tensor(array):
    """Copy a NumPy array into a float32 torch tensor."""
    return torch.tensor(array, dtype=torch.float)


# NOTE(review): both splits are taken from the start of the arrays, so they
# overlap whenever the same rows fall inside [:ntrain] and [:ntest].
x_train_mf = _as_float_tensor(x_mf[:ntrain, ...])
y_train_mf = _as_float_tensor(y_mf[:ntrain, ...])
x_test_mf = _as_float_tensor(x_mf[:ntest, ...])
y_test_mf = _as_float_tensor(y_mf[:ntest, ...])

_train_set_mf = torch.utils.data.TensorDataset(x_train_mf, y_train_mf)
_test_set_mf = torch.utils.data.TensorDataset(x_test_mf, y_test_mf)

train_loader_mf = torch.utils.data.DataLoader(_train_set_mf, batch_size=batch_size, shuffle=True)
# No shuffling at test time: predictions must stay aligned with y_test_mf.
test_loader_mf = torch.utils.data.DataLoader(_test_set_mf, batch_size=batch_size, shuffle=False)


In [None]:
# Sanity check: shapes of the full arrays and of the train/test tensors.
print(*(arr.shape for arr in (x_mf, y_mf, x_train_mf, y_train_mf, x_test_mf, y_test_mf)))

In [None]:
# Restore the trained multi-fidelity model from disk and report its size.
# `device`, `count_params` and `LpLoss` are not defined in this file;
# presumably they come from `from utils import *`.
# NOTE(review): torch.load unpickles the stored object - only load
# checkpoints from a trusted source.
mf_checkpoint = 'model/MF_WNO_poisson1D_50'
model_mf = torch.load(mf_checkpoint, map_location=device)
print(count_params(model_mf))

# Loss used for the per-batch test error reported during prediction.
myloss = LpLoss(size_average=False)


In [None]:
# Prediction: run the multi-fidelity model over the test loader batch by
# batch, timing each forward pass and collecting the outputs on the CPU.
pred_mf = []
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf):
        x, y = x.to(device), y.to(device)
        t1 = default_timer()
        out = model_mf(x).squeeze(-1)
        t2 = default_timer()
        # Flatten with the actual batch length so a shorter final batch does
        # not break the reshape (the global batch_size was hard-coded before).
        n_batch = x.shape[0]
        test_l2 = myloss(out.reshape(n_batch, -1), y.reshape(n_batch, -1)).item()
        pred_mf.append(out.cpu())
        print("Batch-{}, Time={:0.4f}, Test-loss-{:0.6f}".format(index, t2 - t1, test_l2))

# Stack the per-batch outputs; the loader is not shuffled, so rows line up
# with y_test_mf.
pred_mf = torch.cat(pred_mf)
print('Mean mse_mf-{}'.format(F.mse_loss(y_test_mf, pred_mf).item()))
    

In [None]:
# Notebook display: show the shape of the concatenated multi-fidelity predictions.
pred_mf.shape

In [None]:
# The model was evaluated on the residual target (high minus low fidelity),
# so add the low-fidelity solution (input channel 1) back to both the
# reference and the prediction to obtain full solutions.
inp_mf = x_test_mf
_low_fidelity = inp_mf[:, :, 1]
real_mf = y_test_mf + _low_fidelity
output_mf = pred_mf + _low_fidelity

In [None]:
# Point-wise squared error of the reconstructed multi-fidelity solution,
# with its mean and standard deviation over all samples and grid points.
error = (output_mf - real_mf).pow(2)
error_mean = error.mean()
error_std = error.std()
mse_pred = F.mse_loss(output_mf, real_mf).item()

print(f'MSE-Predicted solution-{mse_pred:0.6f}, mean-{error_mean:0.4f}, std-{error_std:0.4f}')


In [None]:
# Overlay reference and multi-fidelity prediction for every 500th test sample.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

colormap = plt.cm.jet
colors = [colormap(i) for i in np.linspace(0, 1, 5)]

fig2 = plt.figure(figsize=(10, 4), dpi=300)
# Title corrected: this figure shows the multi-fidelity WNO on the Poisson
# problem (it previously read 'Stochastic Heat - FNO - High fidelity').
fig2.suptitle('Stochastic Poisson - MFWNO - Multi-fidelity')

# Step directly through the sample indices instead of scanning every index
# with a modulo test; zip stops at the last available colour.
for i, color in zip(range(0, ntest, 500), colors):
    plt.plot(x_coords, real_mf[i, :], color=color, label='Actual')
    plt.plot(x_coords, output_mf[i, :], '--', color=color, label='Prediction')
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


# High Fidelity

In [None]:
# read data
# High-fidelity dataset: forcing as a single-channel input (the trailing
# unsqueeze adds the channel axis) and the high-fidelity solution as target.
x_hf = torch.tensor(x_data_h, dtype=torch.float).unsqueeze(-1)
y_hf = torch.tensor(y_data_h, dtype=torch.float)

x_train_hf, y_train_hf = x_hf[:ntrain, ...], y_hf[:ntrain, ...]
x_test_hf, y_test_hf = x_hf[-ntest:, ...], y_hf[-ntest:, ...]

train_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train_hf, y_train_hf),
                                              batch_size=batch_size, shuffle=True)
# Fix: pair the test inputs with their targets (the dataset previously held
# x_test_hf twice, so the "targets" were the inputs themselves).
test_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_hf, y_test_hf),
                                             batch_size=batch_size, shuffle=False)


In [None]:
# Sanity check: shapes of the high-fidelity train/test tensors.
print(*(tensor.shape for tensor in (x_train_hf, y_train_hf, x_test_hf, y_test_hf)))

In [None]:
# Restore the trained high-fidelity-only model from disk and report its size.
# `device`, `count_params` and `LpLoss` are not defined in this file;
# presumably they come from `from utils import *`.
# NOTE(review): torch.load unpickles the stored object - only load
# checkpoints from a trusted source.
hf_checkpoint = 'model/HF_WNO_poisson1D_50'
model_hf = torch.load(hf_checkpoint, map_location=device)
print(count_params(model_hf))

# Loss used for the per-batch test error reported during prediction.
myloss = LpLoss(size_average=False)


In [None]:
# Prediction: evaluate the high-fidelity-only model on the same test samples
# as the multi-fidelity model. The multi-fidelity loader is reused: channel 0
# of x is the forcing fed to the model, channel 1 is the low-fidelity
# solution, and y is the (high - low) residual.
pred_hf = []
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf):
        x, y = x.to(device), y.to(device)
        t1 = default_timer()
        out = model_hf(x[:, :, 0:1]).squeeze(-1)
        t2 = default_timer()
        # Fix: model_hf predicts the full high-fidelity solution, so rebuild
        # that target as residual + low fidelity before measuring the error
        # (the loss was previously taken against the residual alone).
        y_high = y + x[:, :, 1]
        # Flatten with the actual batch length so a shorter final batch works.
        n_batch = x.shape[0]
        test_l2 = myloss(out.reshape(n_batch, -1), y_high.reshape(n_batch, -1)).item()
        pred_hf.append(out.cpu())
        print("Batch-{}, Time-{:0.4f}, Test-loss-{:0.6f}".format(index, t2 - t1, test_l2))

# Stack the per-batch outputs; the loader is not shuffled, so row order
# matches the test tensors.
pred_hf = torch.cat(pred_hf)


In [None]:
# Reference high-fidelity solution = residual target + low-fidelity solution
# (input channel 1). The high-fidelity model output is used as-is.
inp = x_test_mf
real_hf = y_test_mf + inp[..., 1]
output_hf = pred_hf

In [None]:
# Point-wise squared error of the high-fidelity-only prediction, with its
# mean and standard deviation over all samples and grid points.
error = (output_hf - real_hf).pow(2)
error_mean = error.mean()
error_std = error.std()
mse_pred_hf = F.mse_loss(output_hf, real_hf).item()

print(f'MSE-Predicted solution-{mse_pred_hf:0.6f}, mean-{error_mean:0.4f}, std-{error_std:0.4f}')


In [None]:
# Overlay reference and high-fidelity-only prediction for every 500th test sample.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

colormap = plt.cm.jet
colors = [colormap(i) for i in np.linspace(0, 1, 5)]

fig2 = plt.figure(figsize=(10, 4), dpi=300)
# Title corrected: the notebook treats the Poisson problem with WNO models
# (it previously read 'Stochastic Heat - FNO - High fidelity').
fig2.suptitle('Stochastic Poisson - WNO - High fidelity')

# Step directly through the sample indices instead of scanning every index
# with a modulo test; zip stops at the last available colour.
for i, color in zip(range(0, ntest, 500), colors):
    plt.plot(x_coords, real_hf[i, :], color=color, label='Actual')
    plt.plot(x_coords, output_hf[i, :], '--', color=color, label='Prediction')
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


In [None]:
# Plot the stochastic forcing of every 500th sample.
colormap = plt.cm.jet
colors2 = [colormap(i) for i in np.linspace(0, 1, 5)]

fig1 = plt.figure(figsize=(10, 4), dpi=300)
# Title corrected: this figure shows forcing samples of the Poisson problem
# (it previously read 'Stochastic Heat - FNO - High fidelity').
fig1.suptitle('Stochastic Poisson - Forcing samples')

# Step directly through the sample indices instead of scanning every index
# with a modulo test; zip stops at the last available colour.
for i, color in zip(range(0, ntest, 500), colors2):
    plt.plot(x_coords, x_data_h[i, :], color=color, label='Forcing-{}'.format(i))
plt.legend(ncol=5, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


# First passage failure

In [None]:
# %%
# First-passage failure: a sample fails when its solution exceeds the
# threshold `eh` at any grid point. The failure probability is the fraction
# of the `ntest` samples that fail.
eh = 4.5


def _failure_flags(field):
    """Return a float array of length ntest with 1.0 where the sample
    exceeds `eh` somewhere and 0.0 otherwise."""
    return np.array([float((field[i, ...] > eh).any()) for i in range(ntest)])


eh_mcs = _failure_flags(real_mf)        # reference solution
eh_wno_mf = _failure_flags(output_mf)   # multi-fidelity prediction
eh_wno_hf = _failure_flags(pred_hf)     # high-fidelity-only prediction

pf_wno_mf = np.count_nonzero(eh_wno_mf) / ntest
pf_wno_hf = np.count_nonzero(eh_wno_hf) / ntest
pf_mcs = np.count_nonzero(eh_mcs) / ntest
print('Prob. of failure, MFWNO-{}, HFWNO-{}, MCS-{}'.format(pf_wno_mf, pf_wno_hf, pf_mcs))


# Plotting

In [None]:
# Four-panel comparison (one test sample per panel) of the reference
# solution, the forcing, and the two model predictions.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

fig5, axs = plt.subplots(2, 2, figsize=(14, 8), dpi=100)
plt.subplots_adjust(hspace=0.35, wspace=0.15)
axs = axs.flatten()

n0, n1, n2, n3 = 50, 500, 1000, 1500
legend_labels = ['HF-Truth ($u(x)$)', 'Force ($g(x)$)', 'HF-WNO ($u^{H}(x)$)', 'MF-WNO ($u^{L}(x)$)']

# (sample index, panel tag, extra legend keywords) for each axis; the first
# panel leaves the legend location at matplotlib's default.
panels = (
    (n0, '(a)', {}),
    (n1, '(b)', {'loc': 4}),
    (n2, '(c)', {'loc': 4}),
    (n3, '(d)', {'loc': 1}),
)

for ax, (n, tag, legend_kw) in zip(axs, panels):
    ax.plot(x_coords, real_mf[n], linestyle='-', color='tab:green', lw=2)
    ax.plot(x_coords, inp_mf[n, :, 0], linestyle=':', color='blue', lw=2)
    ax.plot(x_coords, output_hf[n], linestyle='-.', color='tab:orange', lw=2)
    ax.plot(x_coords, output_mf[n], linestyle='--', color='tab:red', lw=3)
    ax.legend(legend_labels, columnspacing=0.4, handletextpad=0.5, ncol=2, **legend_kw)
    ax.margins(0)
    ax.grid(True, alpha=0.3)
    ax.set_ylabel('$g(x)$ / $u(x)$')
    ax.set_xlabel('Space ($x$)')
    ax.set_title('{} Sample-{}'.format(tag, n))

# fig5.savefig('pred_poisson.pdf', format='pdf', dpi=600, bbox_inches='tight')


In [None]:
# Four-panel comparison (one test sample per panel) of the reference solution
# and the two model predictions; the figure is written to a PDF.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

fig6, axs = plt.subplots(2, 2, figsize=(14, 8), dpi=100)
plt.subplots_adjust(hspace=0.35, wspace=0.15)
axs = axs.flatten()

n0, n1, n2, n3 = 50, 499, 999, 1499
legend_labels = ['HF-Truth ($u(x)$)', 'HF-WNO ($u^{H}(x)$)', r'MF-WNO ($u^{\mathrm{MF}}(x)$)']

# (sample index, panel tag, number shown in the title, extra legend keywords).
# Panels (b)-(d) show a 1-based sample number in the title; panel (c) leaves
# the legend location at matplotlib's default.
panels = (
    (n0, '(a)', n0, {'loc': 2}),
    (n1, '(b)', n1 + 1, {'loc': 1}),
    (n2, '(c)', n2 + 1, {}),
    (n3, '(d)', n3 + 1, {'loc': 4}),
)

for ax, (n, tag, shown, legend_kw) in zip(axs, panels):
    ax.plot(x_coords, real_mf[n], linestyle='-', color='tab:green', lw=2)
    ax.plot(x_coords, output_hf[n], linestyle='-.', color='tab:orange', lw=2)
    ax.plot(x_coords, output_mf[n], linestyle='--', color='tab:red', lw=3)
    ax.legend(legend_labels, columnspacing=0.4, handletextpad=0.5, ncol=1, **legend_kw)
    ax.margins(0)
    ax.grid(True, alpha=0.3)
    ax.set_ylabel('$g(x)$ / $u(x)$')
    ax.set_xlabel('Space ($x$)')
    ax.set_title('{} Sample-{}'.format(tag, shown))

fig6.savefig('pred_poisson.pdf', format='pdf', dpi=600, bbox_inches='tight')


In [None]:
# Export the reference solution, both model predictions and the grid to a
# MATLAB .mat file.
# scipy is imported explicitly here: no cell in this file imports it by name,
# so the call previously depended on the name leaking in through
# `from utils import *`.
import scipy.io

scipy.io.savemat('data/mfwno_poissons_n50.mat', mdict={'real_mf': real_mf.cpu().numpy(),
                                                       'output_hf': output_hf.cpu().numpy(),
                                                       'output_mf': output_mf.cpu().numpy(),
                                                       'x_coords': x_coords})
