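# This notebook trains a multi-fidelity WNO (MF-WNO) on the multi-fidelity 1D Poisson data (a time-independent problem) and, for comparison, a WNO trained on high-fidelity data only.
### Number of high-fidelity training samples = 20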

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from utils import *
import matplotlib.pyplot as plt

from timeit import default_timer
from pytorch_wavelets import DWT1D, IDWT1D

In [None]:
# Seed the global torch and NumPy RNGs so weight initialisation and
# DataLoader shuffling are repeatable across runs.
seed_value = 10
torch.manual_seed(seed_value)
np.random.seed(seed_value)

# WNO

In [None]:
class WaveConv1d(nn.Module):
    """1D wavelet integral layer.

    Takes a ``level``-level discrete wavelet transform (DWT) of the input,
    mixes the channels of the coarsest approximation coefficients with
    ``weights1`` and of the coarsest detail coefficients with ``weights2``
    (one in->out mixing matrix per coefficient), leaves the finer detail
    coefficients untouched, and reconstructs with the inverse DWT.
    """

    def __init__(self, in_channels, out_channels, level, size, wavelet):
        """
        Parameters
        ----------
        in_channels : int
            Number of input channels.
        out_channels : int
            Number of output channels.
        level : int
            Number of DWT decomposition levels.
        size : int
            Signal length used to infer the number of coarsest-level
            coefficients (``self.modes1``). Signals passed to ``forward``
            must yield the same number of coarsest coefficients, otherwise
            the einsum in ``mul1d`` fails.
        wavelet : str
            Wavelet name passed to pytorch_wavelets, e.g. ``'db6'``.
        """
        super(WaveConv1d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.level = level
        self.wavelet = wavelet
        self.dwt_ = DWT1D(wave=self.wavelet, J=self.level, mode='zero')
        # Probe the DWT with a dummy signal to learn the coarsest-level length.
        dummy_data = torch.randn(1, 1, size)
        mode_data, _ = self.dwt_(dummy_data)
        self.modes1 = mode_data.shape[-1]

        # Parameter initialization: uniform in [0, scale).
        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1))

    # Convolution
    def mul1d(self, input, weights):
        """Per-coefficient channel mixing.

        (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        """
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """
        Input parameters:
        -----------------
        x : tensor, shape-[Batch * Channel * x]
        Output parameters:
        ------------------
        x : tensor, shape-[Batch * Channel * x]
            NOTE(review): the output length is expected to match the input
            length (WNO1d adds it to a pointwise-conv output); not checked here.
        """
        # Forward DWT. Reuse the module built in __init__ instead of rebuilding
        # the wavelet filters on every call. DWT1D stores its filters as
        # registered buffers, so self.dwt_ follows the layer across
        # .to(device) / .cuda() like any other submodule.
        x_ft, x_coeff = self.dwt_(x)

        # Multiply the final low-pass (approximation) wavelet coefficients.
        out_ft = self.mul1d(x_ft, self.weights1)
        # Multiply the final (coarsest) high-pass wavelet coefficients.
        x_coeff[-1] = self.mul1d(x_coeff[-1].clone(), self.weights2)

        # Reconstruct the signal. The IDWT is still built per call: registering
        # it as a submodule would add filter buffers to the state_dict and
        # break loading of models saved with the current layout.
        idwt = IDWT1D(wave=self.wavelet, mode='zero').to(x.device)
        x = idwt((out_ft, x_coeff))
        return x

In [None]:
class WNO1d(nn.Module):
    """Four-layer 1D Wavelet Neural Operator.

    Pipeline:
      1. Append a uniform grid coordinate on [0, grid_range] to the input and
         lift the result to ``width`` channels with ``fc0``.
      2. Zero-pad the right end of the domain by ``padding`` points.
      3. Apply four layers ``v <- conv_k(v) + w_k(v)`` (wavelet integral
         operator plus 1x1 convolution), with GELU after the first three.
      4. Remove the padding and project to one output channel through
         ``fc1`` -> leaky ReLU -> ``fc2``.

    Input shape:  (batch, s, in_channel - 1); the grid adds the last channel.
    Output shape: (batch, s, 1).
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(WNO1d, self).__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 2

        # Lifting layer: data channels + grid channel -> width.
        self.fc0 = nn.Linear(self.in_channel, self.width)

        # Creation order (all conv_k, then all w_k, then the head) is kept so
        # the seeded parameter initialisation is unchanged.
        wave_args = (self.width, self.width, self.level, self.size, self.wavelet)
        self.conv0 = WaveConv1d(*wave_args)
        self.conv1 = WaveConv1d(*wave_args)
        self.conv2 = WaveConv1d(*wave_args)
        self.conv3 = WaveConv1d(*wave_args)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)

        # Projection head.
        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = torch.cat((x, self.get_grid(x.shape, x.device)), dim=-1)
        x = self.fc0(x).permute(0, 2, 1)  # -> (batch, width, s)
        if self.padding:
            x = F.pad(x, [0, self.padding])

        layer_pairs = (
            (self.conv0, self.w0),
            (self.conv1, self.w1),
            (self.conv2, self.w2),
            (self.conv3, self.w3),
        )
        final_depth = len(layer_pairs) - 1
        for depth, (wavelet_layer, pointwise_layer) in enumerate(layer_pairs):
            x = wavelet_layer(x) + pointwise_layer(x)
            if depth < final_depth:  # no activation after the last layer
                x = F.gelu(x)

        if self.padding:
            x = x[..., :-self.padding]
        x = x.permute(0, 2, 1)  # -> (batch, s, width)
        return self.fc2(F.leaky_relu(self.fc1(x)))

    def get_grid(self, shape, device):
        """Return a (batch, s, 1) tensor of s equispaced points on [0, grid_range]."""
        batchsize, size_x = shape[0], shape[1]
        coords = np.linspace(0, self.grid_range, size_x)
        gridx = torch.tensor(coords, dtype=torch.float).reshape(1, size_x, 1)
        return gridx.repeat([batchsize, 1, 1]).to(device)
    

In [None]:
class WNO1d_linear(nn.Module):
    """Two-layer 1D WNO without activations between the wavelet layers.

    Pipeline:
      1. Append a uniform grid coordinate on [0, grid_range] to the input and
         lift it to ``width`` channels with ``fc0``.
      2. Zero-pad the right end of the domain by ``padding`` points (if > 0).
      3. Apply two layers ``v <- conv_k(v) + w_k(v)`` with NO nonlinearity in
         between.
      4. Remove the padding and project to one channel through
         ``fc1`` -> leaky ReLU -> ``fc2`` (the head is still nonlinear).

    Input shape:  (batch, s, in_channel - 1); the grid adds the last channel.
    Output shape: (batch, s, 1).
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(WNO1d_linear, self).__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 2

        # Lifting layer: data channels + grid channel -> width.
        self.fc0 = nn.Linear(self.in_channel, self.width)

        self.conv0 = WaveConv1d(self.width, self.width, self.level, self.size, self.wavelet)
        self.conv1 = WaveConv1d(self.width, self.width, self.level, self.size, self.wavelet)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.fc0(x)
        x = x.permute(0, 2, 1)
        # Guard like WNO1d: with padding == 0, x[..., :-0] would be empty.
        if self.padding != 0:
            x = F.pad(x, [0, self.padding])  # pad the right end of the domain

        # Two wavelet layers, deliberately without activation in between.
        x = self.conv0(x) + self.w0(x)
        x = self.conv1(x) + self.w1(x)

        if self.padding != 0:
            x = x[..., :-self.padding]  # remove the padding added above
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.leaky_relu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        """Return a (batch, s, 1) tensor of s equispaced points on [0, grid_range]."""
        batchsize, size_x = shape[0], shape[1]
        gridx = torch.tensor(np.linspace(0, self.grid_range, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
        return gridx.to(device)
    

In [None]:
class MFWNO(nn.Module):
    """Sum of a linear-branch WNO and a nonlinear WNO, followed by a small MLP.

    Both branches receive the same input. Their (batch, s, 1) outputs are added
    and passed through Linear(1, 12) -> GELU -> Linear(12, 1).
    NOTE(review): this class is not instantiated in the cells shown here; the
    multi-fidelity model trained below is a plain WNO1d.
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(MFWNO, self).__init__()

        self.width = width
        self.level = level
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range

        branch_args = (self.width, self.level, self.size, self.wavelet,
                       self.in_channel, self.grid_range)
        self.conv1 = WNO1d_linear(*branch_args)  # linear branch
        self.conv2 = WNO1d(*branch_args)         # nonlinear branch
        self.fc0 = nn.Linear(1, 12)
        self.fc1 = nn.Linear(12, 1)

    def forward(self, x):
        fused = self.conv1(x) + self.conv2(x)
        return self.fc1(F.gelu(self.fc0(fused)))


# Multifidelity

In [None]:
ntrain = 20          # number of high-fidelity training samples
ntest = 40           # number of test samples
n_total = ntrain + ntest
last_m = 400         # samples are taken from data[last_m - n_total : last_m]
s = 100              # number of points in the HF spatial grid (x_coords)

batch_size = 5
learning_rate = 0.001

w_decay = 1e-4       # Adam weight decay
epochs = 500
step_size = 50       # StepLR: decay the learning rate every `step_size` epochs
gamma = 0.5          # StepLR: multiplicative learning-rate decay factor

wavelet = 'db6'      # wavelet basis function
level = 3            # level of wavelet decomposition
                     # NOTE: the model cells below pass level=4 explicitly.
width = 64           # uplifting dimension
layers = 4           # number of wavelet layers (informational; WNO1d hard-codes 4)

h = 100              # total grid size divided by the subsampling rate
grid_range = 1
in_channel = 3       # MF input: (f(x), u_LF(x)) plus the grid x appended in the model


In [None]:
PATH = 'data/possion_10pt_100pt__lscale_01.npz'
data = np.load(PATH)

# Keep the n_total consecutive samples that end at index last_m:
# HF forcing, LF solution and HF solution, in that order.
x_data_h, y_data_l, y_data_h = (
    data[key][last_m - n_total:last_m] for key in ('f_stoch', 'y_low_100', 'yhi')
)
x_coords = data['xhi'].reshape((s,))

In [None]:
 data['f_stoch'].shape

In [None]:
# read data

# Multi-fidelity inputs stack the HF forcing (channel 0) with the LF solution
# (channel 1); the learning target is the HF - LF residual.
x_mf = np.stack((x_data_h, y_data_l), axis=-1)
y_mf = y_data_h - y_data_l

x_train_mf, y_train_mf, x_test_mf, y_test_mf = (
    torch.tensor(arr, dtype=torch.float)
    for arr in (x_mf[:ntrain, ...], y_mf[:ntrain, ...],
                x_mf[-ntest:, ...], y_mf[-ntest:, ...])
)

train_loader_mf = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_train_mf, y_train_mf),
    batch_size=batch_size,
    shuffle=True,
)
test_loader_mf = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_test_mf, y_test_mf),
    batch_size=batch_size,
    shuffle=False,
)


In [None]:
print(*(arr.shape for arr in (x_mf, y_mf, x_train_mf, y_train_mf, x_test_mf, y_test_mf)))

In [None]:
# model: WNO1d trained on the HF - LF residual (the MF-WNO).
# Input channels = forcing + LF solution + grid appended inside WNO1d,
# i.e. the `in_channel` hyperparameter.
# NOTE(review): level=4 is passed explicitly and differs from the `level`
# hyperparameter; kept as-is to preserve the trained configuration.
model_mf = WNO1d(width=width, level=4, size=h, wavelet=wavelet,
                 in_channel=in_channel, grid_range=grid_range).to(device)
print(count_params(model_mf))

optimizer = torch.optim.Adam(model_mf.parameters(), lr=learning_rate, weight_decay=w_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [None]:
# Relative Lp loss from utils. With size_average=False it is presumably summed
# over the batch, hence the division of epoch totals by ntrain/ntest below.
myloss = LpLoss(size_average=False)
for ep in range(epochs):
    model_mf.train()
    t1 = default_timer()
    train_mse = 0.0
    train_l2 = 0.0
    for x, y in train_loader_mf:
        # Same device as the model (moved with .to(device)), not hard-coded CUDA.
        x, y = x.to(device), y.to(device)
        nb = x.shape[0]  # actual batch size; the last batch may be smaller

        optimizer.zero_grad()
        out = model_mf(x)

        mse = F.mse_loss(out.view(nb, -1), y.view(nb, -1), reduction='mean')
        l2 = myloss(out.view(nb, -1), y.view(nb, -1))
        l2.backward()  # optimise the relative L2 loss; MSE is only monitored
        optimizer.step()

        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()

    model_mf.eval()
    test_l2 = 0.0
    test_mse = 0.0
    with torch.no_grad():
        for x, y in test_loader_mf:
            x, y = x.to(device), y.to(device)
            nb = x.shape[0]

            out = model_mf(x)
            test_l2 += myloss(out.view(nb, -1), y.view(nb, -1)).item()
            tmse = F.mse_loss(out.view(nb, -1), y.view(nb, -1), reduction='mean')
            test_mse += tmse.item()

    # MSE is averaged per batch; L2 is averaged per sample.
    train_mse /= len(train_loader_mf)
    train_l2 /= ntrain
    test_l2 /= ntest
    test_mse /= len(test_loader_mf)
    t2 = default_timer()
    print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}, Test-MSE-{:0.4f}'
          .format(ep, t2-t1, train_mse, train_l2, test_l2, test_mse))


Epoch-92, Time-0.2539, Train-MSE-0.0000, Train-L2-0.1235, Test-L2-0.9298, Test-MSE-0.0014
Epoch-93, Time-0.2429, Train-MSE-0.0000, Train-L2-0.1272, Test-L2-0.9710, Test-MSE-0.0014
Epoch-94, Time-0.2514, Train-MSE-0.0000, Train-L2-0.0910, Test-L2-0.9379, Test-MSE-0.0014
Epoch-95, Time-0.2462, Train-MSE-0.0000, Train-L2-0.0902, Test-L2-0.9602, Test-MSE-0.0014
Epoch-96, Time-0.2387, Train-MSE-0.0000, Train-L2-0.0814, Test-L2-0.9540, Test-MSE-0.0014
Epoch-97, Time-0.2479, Train-MSE-0.0000, Train-L2-0.0735, Test-L2-0.9482, Test-MSE-0.0014
Epoch-98, Time-0.2423, Train-MSE-0.0000, Train-L2-0.0681, Test-L2-0.9591, Test-MSE-0.0014
Epoch-99, Time-0.2559, Train-MSE-0.0000, Train-L2-0.0757, Test-L2-0.9554, Test-MSE-0.0014
Epoch-100, Time-0.2437, Train-MSE-0.0000, Train-L2-0.0643, Test-L2-0.9592, Test-MSE-0.0014
Epoch-101, Time-0.2233, Train-MSE-0.0000, Train-L2-0.0565, Test-L2-0.9599, Test-MSE-0.0014
Epoch-102, Time-0.2187, Train-MSE-0.0000, Train-L2-0.0559, Test-L2-0.9483, Test-MSE-0.0014
Epoch-1

Epoch-183, Time-0.2498, Train-MSE-0.0000, Train-L2-0.0250, Test-L2-0.9675, Test-MSE-0.0014
Epoch-184, Time-0.2357, Train-MSE-0.0000, Train-L2-0.0225, Test-L2-0.9628, Test-MSE-0.0014
Epoch-185, Time-0.2432, Train-MSE-0.0000, Train-L2-0.0226, Test-L2-0.9644, Test-MSE-0.0014
Epoch-186, Time-0.2386, Train-MSE-0.0000, Train-L2-0.0220, Test-L2-0.9663, Test-MSE-0.0014
Epoch-187, Time-0.2373, Train-MSE-0.0000, Train-L2-0.0220, Test-L2-0.9631, Test-MSE-0.0014
Epoch-188, Time-0.2455, Train-MSE-0.0000, Train-L2-0.0236, Test-L2-0.9681, Test-MSE-0.0014
Epoch-189, Time-0.2466, Train-MSE-0.0000, Train-L2-0.0247, Test-L2-0.9639, Test-MSE-0.0014
Epoch-190, Time-0.2461, Train-MSE-0.0000, Train-L2-0.0223, Test-L2-0.9661, Test-MSE-0.0014
Epoch-191, Time-0.2460, Train-MSE-0.0000, Train-L2-0.0231, Test-L2-0.9651, Test-MSE-0.0014
Epoch-192, Time-0.2458, Train-MSE-0.0000, Train-L2-0.0263, Test-L2-0.9635, Test-MSE-0.0014
Epoch-193, Time-0.2476, Train-MSE-0.0000, Train-L2-0.0217, Test-L2-0.9647, Test-MSE-0.0014

Epoch-274, Time-0.2465, Train-MSE-0.0000, Train-L2-0.0105, Test-L2-0.9658, Test-MSE-0.0014
Epoch-275, Time-0.2537, Train-MSE-0.0000, Train-L2-0.0104, Test-L2-0.9667, Test-MSE-0.0014
Epoch-276, Time-0.2452, Train-MSE-0.0000, Train-L2-0.0104, Test-L2-0.9655, Test-MSE-0.0014
Epoch-277, Time-0.2411, Train-MSE-0.0000, Train-L2-0.0106, Test-L2-0.9662, Test-MSE-0.0014
Epoch-278, Time-0.2381, Train-MSE-0.0000, Train-L2-0.0106, Test-L2-0.9660, Test-MSE-0.0014
Epoch-279, Time-0.2480, Train-MSE-0.0000, Train-L2-0.0108, Test-L2-0.9666, Test-MSE-0.0014
Epoch-280, Time-0.2446, Train-MSE-0.0000, Train-L2-0.0110, Test-L2-0.9654, Test-MSE-0.0014
Epoch-281, Time-0.2481, Train-MSE-0.0000, Train-L2-0.0108, Test-L2-0.9667, Test-MSE-0.0014
Epoch-282, Time-0.2388, Train-MSE-0.0000, Train-L2-0.0108, Test-L2-0.9657, Test-MSE-0.0014
Epoch-283, Time-0.2393, Train-MSE-0.0000, Train-L2-0.0110, Test-L2-0.9667, Test-MSE-0.0014
Epoch-284, Time-0.2374, Train-MSE-0.0000, Train-L2-0.0108, Test-L2-0.9656, Test-MSE-0.0014

Epoch-365, Time-0.2398, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.9667, Test-MSE-0.0014
Epoch-366, Time-0.2537, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.9667, Test-MSE-0.0014
Epoch-367, Time-0.2525, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.9667, Test-MSE-0.0014
Epoch-368, Time-0.2428, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.9667, Test-MSE-0.0014
Epoch-369, Time-0.2340, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.9667, Test-MSE-0.0014
Epoch-370, Time-0.2431, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.9667, Test-MSE-0.0014
Epoch-371, Time-0.2486, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.9667, Test-MSE-0.0014
Epoch-372, Time-0.2569, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.9668, Test-MSE-0.0014
Epoch-373, Time-0.2496, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.9667, Test-MSE-0.0014
Epoch-374, Time-0.2579, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.9667, Test-MSE-0.0014
Epoch-375, Time-0.2515, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.9667, Test-MSE-0.0014

Epoch-456, Time-0.2239, Train-MSE-0.0000, Train-L2-0.0087, Test-L2-0.9669, Test-MSE-0.0014
Epoch-457, Time-0.2441, Train-MSE-0.0000, Train-L2-0.0087, Test-L2-0.9669, Test-MSE-0.0014
Epoch-458, Time-0.2380, Train-MSE-0.0000, Train-L2-0.0087, Test-L2-0.9669, Test-MSE-0.0014
Epoch-459, Time-0.2481, Train-MSE-0.0000, Train-L2-0.0087, Test-L2-0.9669, Test-MSE-0.0014
Epoch-460, Time-0.2418, Train-MSE-0.0000, Train-L2-0.0087, Test-L2-0.9669, Test-MSE-0.0014
Epoch-461, Time-0.2501, Train-MSE-0.0000, Train-L2-0.0087, Test-L2-0.9669, Test-MSE-0.0014
Epoch-462, Time-0.2445, Train-MSE-0.0000, Train-L2-0.0087, Test-L2-0.9669, Test-MSE-0.0014
Epoch-463, Time-0.2488, Train-MSE-0.0000, Train-L2-0.0087, Test-L2-0.9669, Test-MSE-0.0014
Epoch-464, Time-0.2415, Train-MSE-0.0000, Train-L2-0.0087, Test-L2-0.9669, Test-MSE-0.0014
Epoch-465, Time-0.2373, Train-MSE-0.0000, Train-L2-0.0087, Test-L2-0.9669, Test-MSE-0.0014
Epoch-466, Time-0.2337, Train-MSE-0.0000, Train-L2-0.0087, Test-L2-0.9669, Test-MSE-0.0014

In [None]:
# Save the MF-WNO model (the whole pickled module, not just its state_dict).
import os

os.makedirs('model', exist_ok=True)  # torch.save does not create missing directories
torch.save(model_mf, 'model/MF_WNO_poisson1D_20')

In [None]:
# Prediction: run the trained MF model over the (unshuffled) test set and
# collect its residual predictions in test-set order.
model_mf.eval()
pred_mf = []
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf):
        x, y = x.to(device), y.to(device)
        nb = x.shape[0]  # the last batch may be smaller than batch_size

        out = model_mf(x).squeeze(-1)
        test_l2 = myloss(out.view(nb, -1), y.view(nb, -1)).item()
        pred_mf.append(out.cpu())
        print("Batch-{}, Test-loss-{:0.6f}".format(index, test_l2))

pred_mf = torch.cat(pred_mf)
print('Mean mse_mf-{}'.format(F.mse_loss(y_test_mf, pred_mf).item()))
    

In [None]:
pred_mf.size()

In [None]:
inp_mf = x_test_mf
# HF solution = (true or predicted) HF - LF residual + LF solution (input channel 1).
real_mf, output_mf = (residual + inp_mf[..., 1] for residual in (y_test_mf, pred_mf))

In [None]:
# MSE between the reconstructed MF-WNO solution and the true HF solution.
mse_pred = float(F.mse_loss(output_mf, real_mf))
print('MSE-Predicted solution-{:0.6f}'.format(mse_pred))


In [None]:
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

colormap = plt.cm.jet
colors = [colormap(i) for i in np.linspace(0, 1, 5)]

fig2 = plt.figure(figsize=(10, 4), dpi=300)
fig2.suptitle('Poisson 1D - MF-WNO - High fidelity')

# Every 10th test sample starting at index 1 (same selection as i % 10 == 1).
# Colours cycle, so a larger ntest cannot overrun the 5-colour palette.
for index, i in enumerate(range(1, ntest, 10)):
    color = colors[index % len(colors)]
    plt.plot(x_coords, real_mf[i, :], color=color, label='Actual-{}'.format(i))
    plt.plot(x_coords, output_mf[i, :], '--', color=color, label='Prediction-{}'.format(i))
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


In [None]:
colormap = plt.cm.jet
colors2 = [colormap(i) for i in np.linspace(0, 1, 5)]

fig1 = plt.figure(figsize=(10, 4), dpi=300)
fig1.suptitle('Poisson 1D - Forcing of test samples')

# Channel 0 of the MF input is the forcing f(x); same sample selection as above.
for index, i in enumerate(range(1, ntest, 10)):
    plt.plot(x_coords, inp_mf[i, :, 0], color=colors2[index % len(colors2)],
             label='Forcing-{}'.format(i))
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


# High Fidelity

In [None]:
# read data: HF-only baseline. The input is the forcing alone (one channel)
# and the target is the full HF solution.
x_hf = torch.tensor(x_data_h, dtype=torch.float).unsqueeze(-1)
y_hf = torch.tensor(y_data_h, dtype=torch.float)

x_train_hf, y_train_hf = x_hf[:ntrain, ...], y_hf[:ntrain, ...]
x_test_hf, y_test_hf = x_hf[-ntest:, ...], y_hf[-ntest:, ...]

train_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train_hf, y_train_hf),
                                              batch_size=batch_size, shuffle=True)
# Targets must be y_test_hf. Pairing the inputs with themselves made the
# reported HF test loss meaningless.
test_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_hf, y_test_hf),
                                             batch_size=batch_size, shuffle=False)


In [None]:
print(*(t.shape for t in (x_train_hf, y_train_hf, x_test_hf, y_test_hf)))

In [None]:
# model: HF-only baseline WNO. Its input is the forcing plus the grid channel
# appended inside WNO1d, hence in_channel=2.
model_hf = WNO1d(
    width=width,
    level=4,
    size=h,
    wavelet=wavelet,
    in_channel=2,
    grid_range=grid_range,
).to(device)
print(count_params(model_hf))

optimizer = torch.optim.Adam(model_hf.parameters(), lr=learning_rate, weight_decay=w_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [None]:
# Relative Lp loss from utils. With size_average=False it is presumably summed
# over the batch, hence the division of epoch totals by ntrain/ntest below.
myloss = LpLoss(size_average=False)
for ep in range(epochs):
    model_hf.train()
    t1 = default_timer()
    train_mse = 0.0
    train_l2 = 0.0
    for x, y in train_loader_hf:
        # Same device as the model (moved with .to(device)), not hard-coded CUDA.
        x, y = x.to(device), y.to(device)
        nb = x.shape[0]  # actual batch size; the last batch may be smaller

        optimizer.zero_grad()
        out = model_hf(x)

        mse = F.mse_loss(out.view(nb, -1), y.view(nb, -1), reduction='mean')
        l2 = myloss(out.view(nb, -1), y.view(nb, -1))
        l2.backward()  # optimise the relative L2 loss; MSE is only monitored
        optimizer.step()

        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()

    model_hf.eval()
    test_l2 = 0.0
    test_mse = 0.0  # tracked for parity with the MF training loop
    with torch.no_grad():
        for x, y in test_loader_hf:
            x, y = x.to(device), y.to(device)
            nb = x.shape[0]

            out = model_hf(x)
            test_l2 += myloss(out.view(nb, -1), y.view(nb, -1)).item()
            test_mse += F.mse_loss(out.view(nb, -1), y.view(nb, -1), reduction='mean').item()

    # MSE is averaged per batch; L2 is averaged per sample.
    train_mse /= len(train_loader_hf)
    train_l2 /= ntrain
    test_l2 /= ntest
    test_mse /= len(test_loader_hf)

    t2 = default_timer()
    print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}, Test-MSE-{:0.4f}'
          .format(ep, t2-t1, train_mse, train_l2, test_l2, test_mse))


Epoch-113, Time-0.2045, Train-MSE-0.0394, Train-L2-0.0988, Test-L2-1.5043
Epoch-114, Time-0.2359, Train-MSE-0.0391, Train-L2-0.0936, Test-L2-1.5144
Epoch-115, Time-0.2501, Train-MSE-0.0397, Train-L2-0.0950, Test-L2-1.5020
Epoch-116, Time-0.2511, Train-MSE-0.0394, Train-L2-0.0983, Test-L2-1.5154
Epoch-117, Time-0.2568, Train-MSE-0.0382, Train-L2-0.0943, Test-L2-1.5141
Epoch-118, Time-0.2571, Train-MSE-0.0385, Train-L2-0.0902, Test-L2-1.4995
Epoch-119, Time-0.2529, Train-MSE-0.0388, Train-L2-0.0969, Test-L2-1.5162
Epoch-120, Time-0.2360, Train-MSE-0.0380, Train-L2-0.0888, Test-L2-1.5149
Epoch-121, Time-0.2488, Train-MSE-0.0367, Train-L2-0.0896, Test-L2-1.5114
Epoch-122, Time-0.2481, Train-MSE-0.0376, Train-L2-0.0869, Test-L2-1.5197
Epoch-123, Time-0.2553, Train-MSE-0.0369, Train-L2-0.0842, Test-L2-1.5150
Epoch-124, Time-0.2443, Train-MSE-0.0360, Train-L2-0.0868, Test-L2-1.5188
Epoch-125, Time-0.2465, Train-MSE-0.0365, Train-L2-0.0821, Test-L2-1.5224
Epoch-126, Time-0.2394, Train-MSE-0.03

Epoch-224, Time-0.2341, Train-MSE-0.0245, Train-L2-0.0540, Test-L2-1.5601
Epoch-225, Time-0.2019, Train-MSE-0.0242, Train-L2-0.0547, Test-L2-1.5604
Epoch-226, Time-0.2110, Train-MSE-0.0242, Train-L2-0.0540, Test-L2-1.5594
Epoch-227, Time-0.2078, Train-MSE-0.0244, Train-L2-0.0546, Test-L2-1.5586
Epoch-228, Time-0.2073, Train-MSE-0.0241, Train-L2-0.0552, Test-L2-1.5612
Epoch-229, Time-0.2144, Train-MSE-0.0241, Train-L2-0.0555, Test-L2-1.5601
Epoch-230, Time-0.2048, Train-MSE-0.0241, Train-L2-0.0541, Test-L2-1.5599
Epoch-231, Time-0.2326, Train-MSE-0.0242, Train-L2-0.0553, Test-L2-1.5624
Epoch-232, Time-0.2598, Train-MSE-0.0238, Train-L2-0.0543, Test-L2-1.5620
Epoch-233, Time-0.2583, Train-MSE-0.0240, Train-L2-0.0549, Test-L2-1.5620
Epoch-234, Time-0.2564, Train-MSE-0.0238, Train-L2-0.0545, Test-L2-1.5609
Epoch-235, Time-0.2585, Train-MSE-0.0241, Train-L2-0.0537, Test-L2-1.5611
Epoch-236, Time-0.2562, Train-MSE-0.0238, Train-L2-0.0545, Test-L2-1.5625
Epoch-237, Time-0.2609, Train-MSE-0.02

Epoch-335, Time-0.2362, Train-MSE-0.0213, Train-L2-0.0469, Test-L2-1.5720
Epoch-336, Time-0.2401, Train-MSE-0.0213, Train-L2-0.0469, Test-L2-1.5720
Epoch-337, Time-0.2467, Train-MSE-0.0212, Train-L2-0.0469, Test-L2-1.5724
Epoch-338, Time-0.2483, Train-MSE-0.0212, Train-L2-0.0468, Test-L2-1.5720
Epoch-339, Time-0.2490, Train-MSE-0.0212, Train-L2-0.0468, Test-L2-1.5724
Epoch-340, Time-0.2382, Train-MSE-0.0212, Train-L2-0.0468, Test-L2-1.5724
Epoch-341, Time-0.2470, Train-MSE-0.0212, Train-L2-0.0467, Test-L2-1.5724
Epoch-342, Time-0.2433, Train-MSE-0.0212, Train-L2-0.0467, Test-L2-1.5725
Epoch-343, Time-0.2492, Train-MSE-0.0212, Train-L2-0.0467, Test-L2-1.5726
Epoch-344, Time-0.2488, Train-MSE-0.0211, Train-L2-0.0467, Test-L2-1.5726
Epoch-345, Time-0.2596, Train-MSE-0.0211, Train-L2-0.0466, Test-L2-1.5728
Epoch-346, Time-0.2471, Train-MSE-0.0211, Train-L2-0.0466, Test-L2-1.5726
Epoch-347, Time-0.2421, Train-MSE-0.0211, Train-L2-0.0466, Test-L2-1.5728
Epoch-348, Time-0.2496, Train-MSE-0.02

Epoch-446, Time-0.1792, Train-MSE-0.0205, Train-L2-0.0452, Test-L2-1.5754
Epoch-447, Time-0.1801, Train-MSE-0.0205, Train-L2-0.0452, Test-L2-1.5754
Epoch-448, Time-0.1770, Train-MSE-0.0205, Train-L2-0.0452, Test-L2-1.5755
Epoch-449, Time-0.1807, Train-MSE-0.0205, Train-L2-0.0452, Test-L2-1.5754
Epoch-450, Time-0.1820, Train-MSE-0.0205, Train-L2-0.0451, Test-L2-1.5754
Epoch-451, Time-0.1846, Train-MSE-0.0205, Train-L2-0.0451, Test-L2-1.5754
Epoch-452, Time-0.1775, Train-MSE-0.0205, Train-L2-0.0451, Test-L2-1.5754
Epoch-453, Time-0.1833, Train-MSE-0.0205, Train-L2-0.0451, Test-L2-1.5754
Epoch-454, Time-0.1779, Train-MSE-0.0205, Train-L2-0.0451, Test-L2-1.5755
Epoch-455, Time-0.1791, Train-MSE-0.0204, Train-L2-0.0451, Test-L2-1.5755
Epoch-456, Time-0.1809, Train-MSE-0.0204, Train-L2-0.0451, Test-L2-1.5755
Epoch-457, Time-0.1812, Train-MSE-0.0204, Train-L2-0.0451, Test-L2-1.5755
Epoch-458, Time-0.1833, Train-MSE-0.0204, Train-L2-0.0451, Test-L2-1.5755
Epoch-459, Time-0.1831, Train-MSE-0.02

In [None]:
# Save the HF-WNO model (the whole pickled module, not just its state_dict).
import os

os.makedirs('model', exist_ok=True)  # torch.save does not create missing directories
torch.save(model_hf, 'model/HF_WNO_poisson1D_20')

In [None]:
# Prediction: evaluate the HF-only model on the test set. test_loader_mf is
# reused (unshuffled, same sample order as y_test_hf). The HF model sees only
# the forcing channel, and its target is the full HF solution, rebuilt as the
# (HF - LF) residual target plus the LF channel. Comparing it with the raw
# residual `y` would measure the wrong quantity.
model_hf.eval()
pred_hf = []
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf):
        x, y = x.to(device), y.to(device)
        nb = x.shape[0]  # the last batch may be smaller than batch_size
        y_hf_batch = y + x[..., 1]

        out = model_hf(x[..., 0:1]).squeeze(-1)
        test_l2 = myloss(out.view(nb, -1), y_hf_batch.view(nb, -1)).item()
        pred_hf.append(out.cpu())
        print("Batch-{}, Test-loss-{:0.6f}".format(index, test_l2))

pred_hf = torch.cat(pred_hf)
print('Mean mse_hf-{}'.format(F.mse_loss(pred_hf, y_test_hf).item()))


In [None]:
# The HF-only model already predicts the HF solution; the reference HF solution
# is rebuilt from the residual target plus the LF channel of the MF input.
inp, output_hf = x_test_mf, pred_hf
real_hf = inp[..., 1] + y_test_mf

In [None]:
# MSE of the HF-only WNO against the true HF solution. It is labelled and
# printed with the same precision as the MF-WNO metric so the two compare directly.
mse_pred_hf = F.mse_loss(output_hf, real_hf).item()

print('MSE-Predicted solution (HF-WNO)-{:0.6f}'.format(mse_pred_hf))


In [None]:
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

colormap = plt.cm.jet
colors = [colormap(i) for i in np.linspace(0, 1, 5)]

fig2 = plt.figure(figsize=(10, 4), dpi=300)
fig2.suptitle('Poisson 1D - HF-WNO - High fidelity')

# Every 10th test sample starting at index 1 (same selection as i % 10 == 1).
# Colours cycle, so a larger ntest cannot overrun the 5-colour palette.
for index, i in enumerate(range(1, ntest, 10)):
    color = colors[index % len(colors)]
    plt.plot(x_coords, real_hf[i, :], color=color, label='Actual-{}'.format(i))
    plt.plot(x_coords, output_hf[i, :], '--', color=color, label='Prediction-{}'.format(i))
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


In [None]:
fig5, axs = plt.subplots(2, 2, figsize=(18, 10), dpi=100)
plt.subplots_adjust(wspace=0.25)

# Test samples shown in each panel, in row-major panel order.
n0, n1, n2, n3 = 15, 28, 37, 38
for ax, n in zip(axs.flat, (n0, n1, n2, n3)):
    ax.plot(x_coords, real_mf[n], linestyle='-', color='tab:green', lw=2)
    # Channel 1 of the MF input is the low-fidelity solution; channel 0 is the
    # forcing, which was previously plotted here by mistake.
    ax.plot(x_coords, inp_mf[n, :, 1], linestyle=':', color='tab:blue', lw=2)
    ax.plot(x_coords, output_hf[n], linestyle='-.', color='tab:orange', lw=2)
    ax.plot(x_coords, output_mf[n], linestyle='--', color='tab:red', lw=3)
    ax.legend(['HF-Truth', 'LF-WNO', 'HF-WNO', 'MF-WNO'])
    ax.margins(0)
    ax.grid(True, alpha=0.3)
