# This notebook trains the multi-fidelity WNO (MF-WNO) on the multi-fidelity 1D Poisson data (a time-independent problem).
### High-fidelity (HF) training set size = 50

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from utils import *
import matplotlib.pyplot as plt

from timeit import default_timer
from pytorch_wavelets import DWT1D, IDWT1D

In [None]:
# Seed the PyTorch and NumPy random number generators so repeated runs are comparable.
torch.manual_seed(10)
np.random.seed(10)

# WNO

In [None]:
class WaveConv1d(nn.Module):
    """1D wavelet layer: wavelet transform, linear transform, inverse wavelet transform.

    The input is decomposed with a multi-level discrete wavelet transform
    (``DWT1D``). The final low-pass coefficients and the last entry of the
    high-pass coefficient list are mixed across channels with learned
    weights, and the signal is rebuilt with ``IDWT1D``.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    level : int
        Number of decomposition levels (``J``) passed to ``DWT1D``.
    size : int
        Length of the signals the layer will see. It is only used to work out
        how many coefficients the transform returns, which fixes the size of
        the weight tensors.
    wavelet : str
        Wavelet name passed to ``DWT1D`` / ``IDWT1D``.
    """

    def __init__(self, in_channels, out_channels, level, size, wavelet):
        super(WaveConv1d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.level = level
        self.wavelet = wavelet
        # Registered as a sub-module, so it follows the layer across devices
        # and is reused by forward().
        self.dwt_ = DWT1D(wave=self.wavelet, J=self.level, mode='zero')

        # Probe the transform once to learn the number of low-pass coefficients.
        # torch.randn is kept on purpose: it consumes the same random numbers as
        # before, so seeded weight initialisation below is unchanged.
        dummy_data = torch.randn(1, 1, size)
        mode_data, _ = self.dwt_(dummy_data)
        self.modes1 = mode_data.shape[-1]

        # Parameter initialization
        self.scale = (1 / (in_channels*out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1))

    # Convolution
    def mul1d(self, input, weights):
        """Contract the channel axis, independently for every coefficient.

        (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        """
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """
        Input parameters:
        -----------------
        x : tensor, shape-[Batch * Channel * x]
        Output parameters:
        ------------------
        x : tensor, shape-[Batch * Channel * x]
        """
        # Reuse the transform built in __init__ instead of constructing an
        # identical DWT1D on every call.
        x_ft, x_coeff = self.dwt_(x)

        # Multiply the final low pass wavelet coefficients
        out_ft = self.mul1d(x_ft, self.weights1)
        # Multiply the final high pass wavelet coefficients
        x_coeff[-1] = self.mul1d(x_coeff[-1].clone(), self.weights2)

        # Reconstruct the signal. The inverse transform is still created per
        # call so that no new buffers appear in the module's state_dict.
        idwt = IDWT1D(wave=self.wavelet, mode='zero').to(x.device)
        x = idwt((out_ft, x_coeff))
        return x

In [None]:
class WNO1d(nn.Module):
    """Wavelet neural operator for 1D fields.

    Pipeline:
    1. ``fc0`` lifts the input, with one grid-coordinate channel appended,
       to ``width`` channels.
    2. Four integral blocks ``v <- K(v) + W(v)``, where K is ``conv0..conv3``
       (``WaveConv1d``) and W is ``w0..w3`` (1x1 ``Conv1d``). GELU follows the
       first three blocks; the fourth has no activation.
    3. ``fc1`` -> leaky ReLU -> ``fc2`` projects to a single output channel.

    input shape:  (batchsize, x=s, c=in_channel - 1)  # the grid is appended in forward()
    output shape: (batchsize, x=s, c=1)
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(WNO1d, self).__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 2

        # Lifting layer; in_channel counts the appended grid coordinate.
        self.fc0 = nn.Linear(self.in_channel, self.width)

        # Attribute names and creation order are kept (conv0..conv3, then
        # w0..w3) so parameter ordering and seeded initialisation are unchanged.
        for k in range(4):
            setattr(self, f"conv{k}",
                    WaveConv1d(self.width, self.width, self.level, self.size, self.wavelet))
        for k in range(4):
            setattr(self, f"w{k}", nn.Conv1d(self.width, self.width, 1))

        # Projection head.
        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        coords = self.get_grid(x.shape, x.device)
        v = self.fc0(torch.cat((x, coords), dim=-1)).permute(0, 2, 1)

        pad = self.padding
        if pad != 0:
            v = F.pad(v, [0, pad])  # pad the right end of the spatial axis

        blocks = (
            (self.conv0, self.w0),
            (self.conv1, self.w1),
            (self.conv2, self.w2),
            (self.conv3, self.w3),
        )
        final = len(blocks) - 1
        for pos, (kernel, skip) in enumerate(blocks):
            v = kernel(v) + skip(v)
            if pos != final:  # no activation after the last block
                v = F.gelu(v)

        if pad != 0:
            v = v[..., :-pad]  # drop the padded tail again
        v = v.permute(0, 2, 1)
        return self.fc2(F.leaky_relu(self.fc1(v)))

    def get_grid(self, shape, device):
        """Return a (batch, size_x, 1) tensor of coordinates on [0, grid_range]."""
        n_batch, n_points = shape[0], shape[1]
        axis = np.linspace(0, self.grid_range, n_points)
        grid = torch.tensor(axis, dtype=torch.float).reshape(1, n_points, 1)
        return grid.repeat([n_batch, 1, 1]).to(device)
    

In [None]:
class WNO1d_linear(nn.Module):
    """Linear variant of the 1D wavelet neural operator.

    Pipeline:
    1. ``fc0`` lifts the input, with one grid-coordinate channel appended,
       to ``width`` channels.
    2. Two integral blocks ``v <- K(v) + W(v)`` with K = ``conv0``/``conv1``
       (``WaveConv1d``) and W = ``w0``/``w1`` (1x1 ``Conv1d``). No activation
       is applied between or after these blocks.
    3. ``fc1`` -> leaky ReLU -> ``fc2`` projects to a single output channel.

    input shape:  (batchsize, x=s, c=in_channel - 1)  # the grid is appended in forward()
    output shape: (batchsize, x=s, c=1)
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(WNO1d_linear, self).__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 2

        # Lifting layer; in_channel counts the appended grid coordinate.
        self.fc0 = nn.Linear(self.in_channel, self.width)

        # Same attribute names and creation order as before (conv0, conv1, w0, w1).
        for k in range(2):
            setattr(self, f"conv{k}",
                    WaveConv1d(self.width, self.width, self.level, self.size, self.wavelet))
        for k in range(2):
            setattr(self, f"w{k}", nn.Conv1d(self.width, self.width, 1))

        # Projection head.
        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        coords = self.get_grid(x.shape, x.device)
        v = self.fc0(torch.cat((x, coords), dim=-1)).permute(0, 2, 1)
        v = F.pad(v, [0, self.padding])  # pad the right end of the spatial axis

        for kernel, skip in ((self.conv0, self.w0), (self.conv1, self.w1)):
            v = kernel(v) + skip(v)

        v = v[..., :-self.padding]  # drop the padded tail again
        v = v.permute(0, 2, 1)
        return self.fc2(F.leaky_relu(self.fc1(v)))

    def get_grid(self, shape, device):
        """Return a (batch, size_x, 1) tensor of coordinates on [0, grid_range]."""
        n_batch, n_points = shape[0], shape[1]
        axis = np.linspace(0, self.grid_range, n_points)
        grid = torch.tensor(axis, dtype=torch.float).reshape(1, n_points, 1)
        return grid.repeat([n_batch, 1, 1]).to(device)
    

In [None]:
class MFWNO(nn.Module):
    """Combination of a linear and a nonlinear WNO branch.

    Both branches receive the same input. Their outputs are summed and passed
    through a small pointwise head (Linear 1->12, GELU, Linear 12->1).
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(MFWNO, self).__init__()

        self.width = width
        self.level = level
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range

        branch_args = (self.width, self.level, self.size, self.wavelet,
                       self.in_channel, self.grid_range)
        self.conv1 = WNO1d_linear(*branch_args)  # linear branch
        self.conv2 = WNO1d(*branch_args)         # nonlinear branch
        self.fc0 = nn.Linear(1, 12)
        self.fc1 = nn.Linear(12, 1)

    def forward(self, x):
        linear_part = self.conv1(x)
        nonlinear_part = self.conv2(x)
        hidden = F.gelu(self.fc0(linear_part + nonlinear_part))
        return self.fc1(hidden)


# Multifidelity

In [None]:
# Data split and resolution
ntrain = 50      # number of training samples
ntest = 40       # number of test samples
n_total = ntrain + ntest
last_m = 400     # end index (exclusive) of the sample window sliced from the data arrays
s = 100          # number of grid points; used to reshape the coordinate array

batch_size = 5
learning_rate = 0.001

w_decay = 1e-4   # weight decay passed to the Adam optimizer
epochs = 500
step_size = 50   # StepLR: number of epochs between learning-rate reductions
gamma = 0.5      # StepLR: multiplicative learning-rate decay factor

wavelet = 'db6'  # wavelet basis function
level = 3        # level of wavelet decomposition -- NOTE(review): the model cells below pass level=4 literally
width = 64       # uplifting dimension
layers = 4       # no of wavelet layers -- NOTE(review): not read by the visible model code

h = 100           # total grid size divided by the subsampling rate
grid_range = 1   # upper end of the coordinate grid built on [0, grid_range]
in_channel = 3   # NOTE(review): the model cells pass in_channel literally (3 for MF, 2 for HF), so this is not read


In [None]:
# Load the dataset and keep the n_total samples that end at index last_m.
PATH = 'data/possion_10pt_100pt__lscale_01.npz'
data = np.load(PATH)

# Forcing term (labelled 'Forcing' in a later plot); used as network input.
x_data_h = data['f_stoch'][last_m-n_total:last_m]
# presumably the low-fidelity solution on the 100-point grid -- confirm against the data generator
y_data_l = data['y_low_100'][last_m-n_total:last_m]
# presumably the high-fidelity solution -- confirm against the data generator
y_data_h = data['yhi'][last_m-n_total:last_m]
# Coordinates flattened to length s; used as the x-axis of the plots.
x_coords = data['xhi'].reshape((s,))

In [None]:
 # Notebook inspection: shape of the full (unsliced) forcing array.
 data['f_stoch'].shape

In [None]:
# Assemble the multi-fidelity dataset.
# Inputs stack the forcing (channel 0) and the low-fidelity solution (channel 1);
# targets are the residual between the high- and low-fidelity solutions.
x_mf = np.stack((x_data_h, y_data_l), axis=-1)
y_mf = y_data_h - y_data_l

# The first ntrain samples are used for training, the last ntest for testing.
x_train_mf, y_train_mf, x_test_mf, y_test_mf = (
    torch.tensor(arr, dtype=torch.float)
    for arr in (x_mf[:ntrain, ...], y_mf[:ntrain, ...],
                x_mf[-ntest:, ...], y_mf[-ntest:, ...])
)

train_loader_mf = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_train_mf, y_train_mf),
    batch_size=batch_size,
    shuffle=True,
)
test_loader_mf = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_test_mf, y_test_mf),
    batch_size=batch_size,
    shuffle=False,
)


In [None]:
# Sanity check: shapes of the full, training and test input/target arrays.
print(x_mf.shape, y_mf.shape, x_train_mf.shape, y_train_mf.shape, x_test_mf.shape, y_test_mf.shape)

In [None]:
# model
# The multi-fidelity network is a plain WNO1d with 3 lifted input channels
# (forcing, low-fidelity solution, grid coordinate).
# NOTE(review): level=4 and in_channel=3 are passed literally here instead of the
# `level` / `in_channel` hyperparameters defined earlier, and the MFWNO class is not used.
# `device` and `count_params` are not defined in the visible code -- presumably
# provided by `from utils import *`.
model_mf = WNO1d(width=width, level=4, size=h, wavelet=wavelet,
              in_channel=3, grid_range=grid_range).to(device)
print(count_params(model_mf))

# Adam with weight decay; StepLR multiplies the learning rate by gamma every step_size epochs.
optimizer = torch.optim.Adam(model_mf.parameters(), lr=learning_rate, weight_decay=w_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [None]:
# Train the multi-fidelity model on the residual targets and report per-epoch metrics.
# LpLoss comes from utils; with size_average=False it presumably returns the loss
# summed over the batch, which is why the totals are divided by the sample counts below.
myloss = LpLoss(size_average=False)
for ep in range(epochs):
    model_mf.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader_mf:
        # Follow the device the model was moved to instead of assuming CUDA.
        x, y = x.to(device), y.to(device)
        # Use the real batch size so a smaller final batch does not break view().
        n_batch = x.shape[0]

        optimizer.zero_grad()
        out = model_mf(x)

        mse = F.mse_loss(out.view(n_batch, -1), y.view(n_batch, -1), reduction='mean')
        l2 = myloss(out.view(n_batch, -1), y.view(n_batch, -1))
        l2.backward() # use the l2 relative loss

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()
    model_mf.eval()
    test_l2 = 0.0
    test_mse = 0
    with torch.no_grad():
        for x, y in test_loader_mf:
            x, y = x.to(device), y.to(device)
            n_batch = x.shape[0]

            out = model_mf(x)
            test_l2 += myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item()
            tmse = F.mse_loss(out.view(n_batch, -1), y.view(n_batch, -1), reduction='mean')
            test_mse += tmse.item()

    # MSE is averaged over batches; the relative L2 loss over samples.
    train_mse /= len(train_loader_mf)
    train_l2 /= ntrain
    test_l2 /= ntest
    test_mse /= len(test_loader_mf)
    t2 = default_timer()
    print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}, Test-MSE-{:0.4f}'
          .format(ep, t2-t1, train_mse, train_l2, test_l2, test_mse))


Epoch-92, Time-0.4096, Train-MSE-0.0000, Train-L2-0.0764, Test-L2-0.5392, Test-MSE-0.0005
Epoch-93, Time-0.4097, Train-MSE-0.0000, Train-L2-0.0835, Test-L2-0.5352, Test-MSE-0.0005
Epoch-94, Time-0.3971, Train-MSE-0.0000, Train-L2-0.0920, Test-L2-0.5376, Test-MSE-0.0005
Epoch-95, Time-0.3863, Train-MSE-0.0000, Train-L2-0.1037, Test-L2-0.5341, Test-MSE-0.0005
Epoch-96, Time-0.4043, Train-MSE-0.0000, Train-L2-0.0934, Test-L2-0.5321, Test-MSE-0.0005
Epoch-97, Time-0.4085, Train-MSE-0.0000, Train-L2-0.0965, Test-L2-0.5367, Test-MSE-0.0005
Epoch-98, Time-0.3920, Train-MSE-0.0000, Train-L2-0.0832, Test-L2-0.5333, Test-MSE-0.0005
Epoch-99, Time-0.3790, Train-MSE-0.0000, Train-L2-0.0886, Test-L2-0.5240, Test-MSE-0.0005
Epoch-100, Time-0.3953, Train-MSE-0.0000, Train-L2-0.0734, Test-L2-0.5277, Test-MSE-0.0005
Epoch-101, Time-0.3997, Train-MSE-0.0000, Train-L2-0.0618, Test-L2-0.5284, Test-MSE-0.0005
Epoch-102, Time-0.3885, Train-MSE-0.0000, Train-L2-0.0561, Test-L2-0.5308, Test-MSE-0.0005
Epoch-1

Epoch-183, Time-0.3843, Train-MSE-0.0000, Train-L2-0.0231, Test-L2-0.5308, Test-MSE-0.0005
Epoch-184, Time-0.4004, Train-MSE-0.0000, Train-L2-0.0232, Test-L2-0.5298, Test-MSE-0.0005
Epoch-185, Time-0.3851, Train-MSE-0.0000, Train-L2-0.0226, Test-L2-0.5314, Test-MSE-0.0005
Epoch-186, Time-0.3725, Train-MSE-0.0000, Train-L2-0.0215, Test-L2-0.5311, Test-MSE-0.0005
Epoch-187, Time-0.3624, Train-MSE-0.0000, Train-L2-0.0276, Test-L2-0.5302, Test-MSE-0.0005
Epoch-188, Time-0.3811, Train-MSE-0.0000, Train-L2-0.0263, Test-L2-0.5292, Test-MSE-0.0005
Epoch-189, Time-0.3791, Train-MSE-0.0000, Train-L2-0.0255, Test-L2-0.5324, Test-MSE-0.0005
Epoch-190, Time-0.3959, Train-MSE-0.0000, Train-L2-0.0238, Test-L2-0.5288, Test-MSE-0.0005
Epoch-191, Time-0.3768, Train-MSE-0.0000, Train-L2-0.0239, Test-L2-0.5330, Test-MSE-0.0005
Epoch-192, Time-0.3790, Train-MSE-0.0000, Train-L2-0.0221, Test-L2-0.5272, Test-MSE-0.0005
Epoch-193, Time-0.3743, Train-MSE-0.0000, Train-L2-0.0224, Test-L2-0.5327, Test-MSE-0.0005

Epoch-274, Time-0.3915, Train-MSE-0.0000, Train-L2-0.0097, Test-L2-0.5307, Test-MSE-0.0005
Epoch-275, Time-0.4025, Train-MSE-0.0000, Train-L2-0.0097, Test-L2-0.5302, Test-MSE-0.0005
Epoch-276, Time-0.3775, Train-MSE-0.0000, Train-L2-0.0095, Test-L2-0.5308, Test-MSE-0.0005
Epoch-277, Time-0.3908, Train-MSE-0.0000, Train-L2-0.0094, Test-L2-0.5305, Test-MSE-0.0005
Epoch-278, Time-0.3947, Train-MSE-0.0000, Train-L2-0.0093, Test-L2-0.5306, Test-MSE-0.0005
Epoch-279, Time-0.4010, Train-MSE-0.0000, Train-L2-0.0093, Test-L2-0.5306, Test-MSE-0.0005
Epoch-280, Time-0.3705, Train-MSE-0.0000, Train-L2-0.0092, Test-L2-0.5303, Test-MSE-0.0005
Epoch-281, Time-0.3902, Train-MSE-0.0000, Train-L2-0.0092, Test-L2-0.5305, Test-MSE-0.0005
Epoch-282, Time-0.3951, Train-MSE-0.0000, Train-L2-0.0093, Test-L2-0.5304, Test-MSE-0.0005
Epoch-283, Time-0.3915, Train-MSE-0.0000, Train-L2-0.0093, Test-L2-0.5306, Test-MSE-0.0005
Epoch-284, Time-0.3754, Train-MSE-0.0000, Train-L2-0.0093, Test-L2-0.5302, Test-MSE-0.0005

Epoch-365, Time-0.3707, Train-MSE-0.0000, Train-L2-0.0076, Test-L2-0.5305, Test-MSE-0.0005
Epoch-366, Time-0.3885, Train-MSE-0.0000, Train-L2-0.0076, Test-L2-0.5305, Test-MSE-0.0005
Epoch-367, Time-0.3986, Train-MSE-0.0000, Train-L2-0.0076, Test-L2-0.5306, Test-MSE-0.0005
Epoch-368, Time-0.3784, Train-MSE-0.0000, Train-L2-0.0076, Test-L2-0.5306, Test-MSE-0.0005
Epoch-369, Time-0.3695, Train-MSE-0.0000, Train-L2-0.0076, Test-L2-0.5305, Test-MSE-0.0005
Epoch-370, Time-0.3849, Train-MSE-0.0000, Train-L2-0.0076, Test-L2-0.5306, Test-MSE-0.0005
Epoch-371, Time-0.3763, Train-MSE-0.0000, Train-L2-0.0076, Test-L2-0.5305, Test-MSE-0.0005
Epoch-372, Time-0.3730, Train-MSE-0.0000, Train-L2-0.0076, Test-L2-0.5305, Test-MSE-0.0005
Epoch-373, Time-0.4024, Train-MSE-0.0000, Train-L2-0.0076, Test-L2-0.5306, Test-MSE-0.0005
Epoch-374, Time-0.4010, Train-MSE-0.0000, Train-L2-0.0075, Test-L2-0.5305, Test-MSE-0.0005
Epoch-375, Time-0.3908, Train-MSE-0.0000, Train-L2-0.0075, Test-L2-0.5306, Test-MSE-0.0005

Epoch-456, Time-0.3996, Train-MSE-0.0000, Train-L2-0.0072, Test-L2-0.5306, Test-MSE-0.0005
Epoch-457, Time-0.3962, Train-MSE-0.0000, Train-L2-0.0071, Test-L2-0.5306, Test-MSE-0.0005
Epoch-458, Time-0.4108, Train-MSE-0.0000, Train-L2-0.0071, Test-L2-0.5306, Test-MSE-0.0005
Epoch-459, Time-0.4035, Train-MSE-0.0000, Train-L2-0.0071, Test-L2-0.5306, Test-MSE-0.0005
Epoch-460, Time-0.4128, Train-MSE-0.0000, Train-L2-0.0071, Test-L2-0.5306, Test-MSE-0.0005
Epoch-461, Time-0.3948, Train-MSE-0.0000, Train-L2-0.0071, Test-L2-0.5306, Test-MSE-0.0005
Epoch-462, Time-0.4189, Train-MSE-0.0000, Train-L2-0.0071, Test-L2-0.5306, Test-MSE-0.0005
Epoch-463, Time-0.4043, Train-MSE-0.0000, Train-L2-0.0071, Test-L2-0.5306, Test-MSE-0.0005
Epoch-464, Time-0.4075, Train-MSE-0.0000, Train-L2-0.0071, Test-L2-0.5306, Test-MSE-0.0005
Epoch-465, Time-0.4158, Train-MSE-0.0000, Train-L2-0.0071, Test-L2-0.5306, Test-MSE-0.0005
Epoch-466, Time-0.4050, Train-MSE-0.0000, Train-L2-0.0071, Test-L2-0.5306, Test-MSE-0.0005

In [None]:
# Save the MF-WNO model
# NOTE(review): this pickles the whole module object, so loading it later needs the
# same class definitions to be importable; the 'model/' directory must already exist.

torch.save(model_mf, 'model/MF_WNO_poisson1D_50')

In [None]:
# Prediction:
# Run the trained multi-fidelity model over the test set, print the per-batch
# relative loss, and collect the predicted residuals on the CPU.
pred_mf = []
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf):
        x, y = x.to(device), y.to(device)
        # Use the real batch size so a smaller final batch does not break view().
        n_batch = x.shape[0]

        out = model_mf(x).squeeze(-1)
        test_l2 = myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item()
        pred_mf.append(out.cpu())
        print("Batch-{}, Test-loss-{:0.6f}".format(index, test_l2))

pred_mf = torch.cat(pred_mf)
print('Mean mse_mf-{}'.format(F.mse_loss(pred_mf, y_test_mf).item()))
    

In [None]:
# Notebook inspection: shape of the concatenated residual predictions.
pred_mf.shape

In [None]:
# The network is trained on the residual (high- minus low-fidelity solution), so the
# low-fidelity channel (index 1 of the input) is added back to recover full solutions.
inp_mf  = x_test_mf 
real_mf = y_test_mf + inp_mf[:,:,1]    # reference high-fidelity solution
output_mf  =  pred_mf + inp_mf[:,:,1]  # multi-fidelity prediction of the solution

In [None]:
# Mean squared error between the reconstructed prediction and the reference solution.
mse_pred = F.mse_loss(output_mf, real_mf).item()

print('MSE-Predicted solution-{:0.6f}'.format(mse_pred))


In [None]:
# Plot reference vs. multi-fidelity prediction for every 10th test sample (i % 10 == 1).
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

plotted = [i for i in range(ntest) if i % 10 == 1]

colormap = plt.cm.jet
# At least 5 colours (the original palette), more if more samples are drawn,
# so the colour lookup cannot run out of range for a larger ntest.
colors = [colormap(v) for v in np.linspace(0, 1, max(5, len(plotted)))]

fig2 = plt.figure(figsize = (10, 4), dpi=300)
# Title corrected: this figure shows the MF-WNO result for the 1D Poisson problem.
fig2.suptitle('Poisson 1D - MF-WNO - Multi-fidelity')

for index, i in enumerate(plotted):
    plt.plot(x_coords, real_mf[i, :], color=colors[index], label='Actual')
    plt.plot(x_coords, output_mf[i,:], '--', color=colors[index], label='Prediction')
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


In [None]:
# Plot the forcing input (channel 0) of every 10th test sample (i % 10 == 1).
plotted = [i for i in range(ntest) if i % 10 == 1]

colormap = plt.cm.jet
# At least 5 colours (the original palette), more if more samples are drawn.
colors2 = [colormap(v) for v in np.linspace(0, 1, max(5, len(plotted)))]

fig1 = plt.figure(figsize = (10, 4), dpi=300)
# Title corrected: the curves are the forcing inputs of the 1D Poisson problem.
fig1.suptitle('Poisson 1D - Forcing inputs')

for index, i in enumerate(plotted):
    plt.plot(x_coords, inp_mf[i, :, 0], color=colors2[index], label='Forcing-{}'.format(i))
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


# High Fidelity

In [None]:
# read data
# High-fidelity-only baseline: the forcing is the single input channel and the
# high-fidelity solution is the target.
x_hf = torch.tensor( x_data_h, dtype=torch.float ).unsqueeze(-1)
y_hf = torch.tensor( y_data_h, dtype=torch.float )

# The first ntrain samples are used for training, the last ntest for testing.
x_train_hf, y_train_hf = x_hf[:ntrain, ...], y_hf[:ntrain, ...]
x_test_hf, y_test_hf = x_hf[-ntest:, ...], y_hf[-ntest:, ...]

train_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train_hf, y_train_hf),
                                              batch_size=batch_size, shuffle=True)
# The test targets must be y_test_hf. Pairing the inputs with themselves made the
# test loss measure the distance to the forcing instead of the solution.
test_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_hf, y_test_hf),
                                             batch_size=batch_size, shuffle=False)


In [None]:
# Sanity check: shapes of the high-fidelity training and test inputs/targets.
print(x_train_hf.shape, y_train_hf.shape, x_test_hf.shape, y_test_hf.shape)

In [None]:
# model
# High-fidelity baseline: WNO1d with 2 lifted input channels (forcing, grid coordinate).
# NOTE(review): level=4 and in_channel=2 are passed literally rather than taken from
# the hyperparameter cell.
model_hf = WNO1d(width=width, level=4, size=h, wavelet=wavelet,
              in_channel=2, grid_range=grid_range).to(device)
print(count_params(model_hf))

# These rebind `optimizer` and `scheduler`, replacing the ones created for model_mf.
optimizer = torch.optim.Adam(model_hf.parameters(), lr=learning_rate, weight_decay=w_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [None]:
# Train the high-fidelity baseline and report per-epoch metrics.
# LpLoss comes from utils; with size_average=False it presumably returns the loss
# summed over the batch, which is why the totals are divided by the sample counts below.
myloss = LpLoss(size_average=False)
for ep in range(epochs):
    model_hf.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader_hf:
        # Follow the device the model was moved to instead of assuming CUDA.
        x, y = x.to(device), y.to(device)
        # Use the real batch size so a smaller final batch does not break view().
        n_batch = x.shape[0]

        optimizer.zero_grad()
        out = model_hf(x)

        mse = F.mse_loss(out.view(n_batch, -1), y.view(n_batch, -1), reduction='mean')
        l2 = myloss(out.view(n_batch, -1), y.view(n_batch, -1))
        l2.backward() # use the l2 relative loss

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()
    model_hf.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader_hf:
            x, y = x.to(device), y.to(device)
            n_batch = x.shape[0]

            out = model_hf(x)
            test_l2 += myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item()

    # MSE is averaged over batches; the relative L2 loss over samples.
    train_mse /= len(train_loader_hf)
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}'
          .format(ep, t2-t1, train_mse, train_l2, test_l2))


Epoch-113, Time-0.4042, Train-MSE-0.0017, Train-L2-0.0450, Test-L2-1.5263
Epoch-114, Time-0.3791, Train-MSE-0.0015, Train-L2-0.0417, Test-L2-1.5394
Epoch-115, Time-0.4040, Train-MSE-0.0012, Train-L2-0.0342, Test-L2-1.5465
Epoch-116, Time-0.4003, Train-MSE-0.0011, Train-L2-0.0340, Test-L2-1.5406
Epoch-117, Time-0.3992, Train-MSE-0.0012, Train-L2-0.0415, Test-L2-1.5569
Epoch-118, Time-0.4027, Train-MSE-0.0014, Train-L2-0.0474, Test-L2-1.5358
Epoch-119, Time-0.3988, Train-MSE-0.0016, Train-L2-0.0484, Test-L2-1.5553
Epoch-120, Time-0.4033, Train-MSE-0.0016, Train-L2-0.0476, Test-L2-1.5487
Epoch-121, Time-0.4037, Train-MSE-0.0013, Train-L2-0.0431, Test-L2-1.5336
Epoch-122, Time-0.4123, Train-MSE-0.0012, Train-L2-0.0385, Test-L2-1.5431
Epoch-123, Time-0.4031, Train-MSE-0.0011, Train-L2-0.0347, Test-L2-1.5376
Epoch-124, Time-0.4142, Train-MSE-0.0014, Train-L2-0.0417, Test-L2-1.5397
Epoch-125, Time-0.4132, Train-MSE-0.0012, Train-L2-0.0450, Test-L2-1.5394
Epoch-126, Time-0.3953, Train-MSE-0.00

Epoch-224, Time-0.3995, Train-MSE-0.0002, Train-L2-0.0113, Test-L2-1.5466
Epoch-225, Time-0.4111, Train-MSE-0.0002, Train-L2-0.0114, Test-L2-1.5464
Epoch-226, Time-0.3956, Train-MSE-0.0002, Train-L2-0.0112, Test-L2-1.5478
Epoch-227, Time-0.3909, Train-MSE-0.0002, Train-L2-0.0118, Test-L2-1.5489
Epoch-228, Time-0.3820, Train-MSE-0.0002, Train-L2-0.0126, Test-L2-1.5444
Epoch-229, Time-0.3782, Train-MSE-0.0002, Train-L2-0.0132, Test-L2-1.5481
Epoch-230, Time-0.3898, Train-MSE-0.0002, Train-L2-0.0162, Test-L2-1.5438
Epoch-231, Time-0.3686, Train-MSE-0.0002, Train-L2-0.0151, Test-L2-1.5497
Epoch-232, Time-0.3726, Train-MSE-0.0002, Train-L2-0.0140, Test-L2-1.5463
Epoch-233, Time-0.3718, Train-MSE-0.0002, Train-L2-0.0126, Test-L2-1.5459
Epoch-234, Time-0.3692, Train-MSE-0.0002, Train-L2-0.0136, Test-L2-1.5470
Epoch-235, Time-0.3981, Train-MSE-0.0002, Train-L2-0.0128, Test-L2-1.5465
Epoch-236, Time-0.3437, Train-MSE-0.0002, Train-L2-0.0129, Test-L2-1.5455
Epoch-237, Time-0.3361, Train-MSE-0.00

Epoch-335, Time-0.4042, Train-MSE-0.0001, Train-L2-0.0078, Test-L2-1.5471
Epoch-336, Time-0.4022, Train-MSE-0.0001, Train-L2-0.0078, Test-L2-1.5471
Epoch-337, Time-0.4082, Train-MSE-0.0001, Train-L2-0.0078, Test-L2-1.5471
Epoch-338, Time-0.4019, Train-MSE-0.0001, Train-L2-0.0078, Test-L2-1.5469
Epoch-339, Time-0.4126, Train-MSE-0.0001, Train-L2-0.0079, Test-L2-1.5471
Epoch-340, Time-0.4154, Train-MSE-0.0001, Train-L2-0.0079, Test-L2-1.5473
Epoch-341, Time-0.4315, Train-MSE-0.0001, Train-L2-0.0080, Test-L2-1.5466
Epoch-342, Time-0.4036, Train-MSE-0.0001, Train-L2-0.0080, Test-L2-1.5474
Epoch-343, Time-0.3925, Train-MSE-0.0001, Train-L2-0.0079, Test-L2-1.5469
Epoch-344, Time-0.4056, Train-MSE-0.0001, Train-L2-0.0078, Test-L2-1.5471
Epoch-345, Time-0.4170, Train-MSE-0.0001, Train-L2-0.0077, Test-L2-1.5471
Epoch-346, Time-0.4101, Train-MSE-0.0001, Train-L2-0.0078, Test-L2-1.5472
Epoch-347, Time-0.4152, Train-MSE-0.0001, Train-L2-0.0077, Test-L2-1.5470
Epoch-348, Time-0.4158, Train-MSE-0.00

Epoch-446, Time-0.4012, Train-MSE-0.0001, Train-L2-0.0072, Test-L2-1.5471
Epoch-447, Time-0.4144, Train-MSE-0.0001, Train-L2-0.0071, Test-L2-1.5472
Epoch-448, Time-0.3824, Train-MSE-0.0001, Train-L2-0.0071, Test-L2-1.5471
Epoch-449, Time-0.3856, Train-MSE-0.0001, Train-L2-0.0071, Test-L2-1.5472
Epoch-450, Time-0.4274, Train-MSE-0.0001, Train-L2-0.0071, Test-L2-1.5471
Epoch-451, Time-0.4238, Train-MSE-0.0001, Train-L2-0.0071, Test-L2-1.5472
Epoch-452, Time-0.4119, Train-MSE-0.0001, Train-L2-0.0071, Test-L2-1.5472
Epoch-453, Time-0.4127, Train-MSE-0.0001, Train-L2-0.0071, Test-L2-1.5471
Epoch-454, Time-0.3836, Train-MSE-0.0001, Train-L2-0.0071, Test-L2-1.5472
Epoch-455, Time-0.3922, Train-MSE-0.0001, Train-L2-0.0071, Test-L2-1.5472
Epoch-456, Time-0.4035, Train-MSE-0.0001, Train-L2-0.0071, Test-L2-1.5471
Epoch-457, Time-0.3999, Train-MSE-0.0001, Train-L2-0.0071, Test-L2-1.5472
Epoch-458, Time-0.3903, Train-MSE-0.0001, Train-L2-0.0071, Test-L2-1.5472
Epoch-459, Time-0.3764, Train-MSE-0.00

In [None]:
# Save the HF-WNO model
# NOTE(review): this pickles the whole module object, so loading it later needs the
# same class definitions to be importable; the 'model/' directory must already exist.

torch.save(model_hf, 'model/HF_WNO_poisson1D_50')

In [None]:
# Prediction:
# Run the high-fidelity baseline over the multi-fidelity test loader so its
# predictions line up sample-for-sample with the multi-fidelity ones.
pred_hf = []
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf):
        x, y = x.to(device), y.to(device)
        # Use the real batch size so a smaller final batch does not break view().
        n_batch = x.shape[0]

        # test_loader_mf yields residual targets (high- minus low-fidelity solution).
        # Add the low-fidelity channel back so the loss is measured against the
        # high-fidelity solution that model_hf actually predicts.
        y_true_hf = y + x[:, :, 1]

        # model_hf only takes the forcing channel as input.
        out = model_hf(x[:,:,0:1]).squeeze(-1)
        test_l2 = myloss(out.view(n_batch, -1), y_true_hf.view(n_batch, -1)).item()
        pred_hf.append(out.cpu())
        print("Batch-{}, Test-loss-{:0.6f}".format(index, test_l2))

pred_hf = torch.cat(pred_hf)
# print('Mean mse_mf-{}'.format(F.mse_loss(y_test_hf, pred_hf).item()))


In [None]:
# Rebuild the reference high-fidelity solution from the residual targets by adding
# back the low-fidelity channel; the baseline predicts the solution directly.
inp = x_test_mf
real_hf = y_test_mf + inp[:,:,1] 
output_hf = pred_hf

In [None]:
# Mean squared error between the baseline prediction and the reference solution.
mse_pred_hf = F.mse_loss(output_hf, real_hf).item()

print('MSE-Predicted solution-{:0.4f}'.format(mse_pred_hf))


In [None]:
# Plot reference vs. high-fidelity-only prediction for every 10th test sample (i % 10 == 1).
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

plotted = [i for i in range(ntest) if i % 10 == 1]

colormap = plt.cm.jet
# At least 5 colours (the original palette), more if more samples are drawn,
# so the colour lookup cannot run out of range for a larger ntest.
colors = [colormap(v) for v in np.linspace(0, 1, max(5, len(plotted)))]

fig2 = plt.figure(figsize = (10, 4), dpi=300)
# Title corrected: this figure shows the HF-WNO result for the 1D Poisson problem.
fig2.suptitle('Poisson 1D - HF-WNO - High fidelity')

for index, i in enumerate(plotted):
    plt.plot(x_coords, real_hf[i, :], color=colors[index], label='Actual')
    plt.plot(x_coords, output_hf[i,:], '--', color=colors[index], label='Prediction')
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


In [None]:
# 2x2 comparison of the reference solution, the low-fidelity solution and the
# HF-only / multi-fidelity predictions for four selected test samples.
fig5, axs = plt.subplots(2, 2,figsize=(18,10), dpi=100)
plt.subplots_adjust(wspace=0.25)

# Test-sample indices shown in panels [0,0], [0,1], [1,0], [1,1] respectively.
n0, n1, n2, n3 = 15, 28, 37, 38

for ax, n in zip(axs.flat, (n0, n1, n2, n3)):
    ax.plot(x_coords, real_mf[n], linestyle='-', color='tab:green', lw=2)
    # Channel 1 of the input is the low-fidelity solution. Channel 0 is the
    # forcing term, which is what was previously drawn under the 'LF-WNO' label.
    ax.plot(x_coords, inp_mf[n,:,1], linestyle=':', color='tab:blue', lw=2)
    ax.plot(x_coords, output_hf[n], linestyle='-.', color='tab:orange', lw=2)
    ax.plot(x_coords, output_mf[n], linestyle='--', color='tab:red', lw=3)
    ax.legend(['HF-Truth','LF-WNO','HF-WNO','MF-WNO'])
    ax.margins(0)
    ax.grid(True, alpha=0.3)
