# This notebook trains MFWNO on the multi-fidelity (MF) 1D Poisson data (a time-independent problem).
### HF data size = 30

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from utils import *
import matplotlib.pyplot as plt

from timeit import default_timer
from pytorch_wavelets import DWT1D, IDWT1D

In [None]:
# Seed the NumPy and PyTorch random number generators with the same fixed
# value so that weight initialisation and data shuffling are repeatable.
np.random.seed(10)
torch.manual_seed(10)

# WNO

In [None]:
class WaveConv1d(nn.Module):
    """1D wavelet layer: wavelet transform, linear transform, inverse transform.

    The input is decomposed with a ``level``-level discrete wavelet transform
    (zero padding mode). The final approximation coefficients and the last
    entry of the detail-coefficient list are each multiplied by a learnable
    weight tensor, and the signal is reconstructed with the inverse transform.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    level : int
        Number of wavelet decomposition levels (``J`` of ``DWT1D``).
    size : int
        Length of the signals that will be passed to ``forward``. A dummy
        signal of this length is decomposed once to find how many
        coefficients the weights must cover.
    wavelet : str
        Wavelet name forwarded to ``DWT1D`` / ``IDWT1D``.
    """

    def __init__(self, in_channels, out_channels, level, size, wavelet):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.level = level
        self.wavelet = wavelet
        self.dwt_ = DWT1D(wave=self.wavelet, J=self.level, mode='zero')

        # Probe the transform with a dummy signal to learn the number of
        # coefficients at the last level (the size of the weight tensors).
        dummy_data = torch.randn(1, 1, size)
        mode_data, _ = self.dwt_(dummy_data)
        self.modes1 = mode_data.shape[-1]

        # Parameter initialization
        self.scale = 1 / (in_channels * out_channels)
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1))

    # Convolution
    def mul1d(self, input, weights):
        """Contract the channel axis of ``input`` with ``weights``.

        (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        """
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """
        Input parameters:
        -----------------
        x : tensor, shape-[Batch * Channel * x]
        Output parameters:
        ------------------
        x : tensor, shape-[Batch * Channel * x]
        """
        # Reuse the transform built in __init__ instead of constructing a new
        # DWT1D module on every call; .to() is a no-op when it already lives
        # on the input's device.
        dwt = self.dwt_.to(x.device)
        x_ft, x_coeff = dwt(x)

        # Multiply the final low pass wavelet coefficients
        out_ft = self.mul1d(x_ft, self.weights1)
        # Multiply the final high pass wavelet coefficients
        x_coeff[-1] = self.mul1d(x_coeff[-1].clone(), self.weights2)

        # Reconstruct the signal. The inverse transform is still created per
        # call so that the module's state_dict keys stay unchanged.
        idwt = IDWT1D(wave=self.wavelet, mode='zero').to(x.device)
        x = idwt((out_ft, x_coeff))
        return x

In [None]:
class WNO1d(nn.Module):
    """Wavelet neural operator with four wavelet integral layers.

    1. The input is concatenated with a grid coordinate channel and lifted to
       ``width`` channels by ``self.fc0``.
    2. Four layers ``v <- g(K(v) + W(v))`` are applied, where ``K`` is
       ``self.conv_`` (wavelet layer), ``W`` is ``self.w_`` (1x1 convolution)
       and ``g`` is GELU (not applied after the last layer).
    3. The result is projected to one channel by ``self.fc1`` and ``self.fc2``.

    input shape:  (batchsize, x=size, c=in_channel - 1); the grid coordinate
                  is appended inside ``forward``.
    output shape: (batchsize, x=size, c=1)

    Parameters
    ----------
    width : int
        Number of lifted channels.
    level : int
        Number of wavelet decomposition levels.
    size : int
        Number of grid points of the (unpadded) input signal.
    wavelet : str
        Wavelet name.
    in_channel : int
        Number of channels seen by ``fc0`` (input channels + 1 grid channel).
    grid_range : float
        Right end of the grid; the grid spans ``[0, grid_range]``.
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super().__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 2

        self.fc0 = nn.Linear(self.in_channel, self.width)

        # The wavelet layers operate on the padded signal (see forward), so
        # their weights must be sized for size + padding points. Sizing them
        # for `size` only works when both lengths happen to give the same
        # number of wavelet coefficients.
        padded_size = self.size + self.padding
        self.conv0 = WaveConv1d(self.width, self.width, self.level, padded_size, self.wavelet)
        self.conv1 = WaveConv1d(self.width, self.width, self.level, padded_size, self.wavelet)
        self.conv2 = WaveConv1d(self.width, self.width, self.level, padded_size, self.wavelet)
        self.conv3 = WaveConv1d(self.width, self.width, self.level, padded_size, self.wavelet)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Map ``x`` of shape (batch, size, in_channel - 1) to (batch, size, 1)."""
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.fc0(x)
        x = x.permute(0, 2, 1)
        if self.padding != 0:
            x = F.pad(x, [0, self.padding])  # pad the right end of the grid axis

        blocks = (
            (self.conv0, self.w0),
            (self.conv1, self.w1),
            (self.conv2, self.w2),
            (self.conv3, self.w3),
        )
        last = len(blocks) - 1
        for idx, (conv, w) in enumerate(blocks):
            x = conv(x) + w(x)
            if idx != last:
                # no activation after the final integral layer
                x = F.gelu(x)

        if self.padding != 0:
            x = x[..., :-self.padding]  # remove the padding again
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.leaky_relu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        """Return a (batch, size_x, 1) float tensor of grid coordinates.

        The coordinates are ``size_x`` equally spaced points on
        ``[0, grid_range]``, repeated for every sample of the batch.
        """
        batchsize, size_x = shape[0], shape[1]
        gridx = torch.tensor(np.linspace(0, self.grid_range, size_x),
                             dtype=torch.float, device=device)
        return gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
    

In [None]:
class WNO1d_linear(nn.Module):
    """Wavelet neural operator with two wavelet layers and no activation between them.

    1. The input is concatenated with a grid coordinate channel and lifted to
       ``width`` channels by ``self.fc0``.
    2. Two layers ``v <- K(v) + W(v)`` are applied, where ``K`` is
       ``self.conv_`` (wavelet layer) and ``W`` is ``self.w_`` (1x1
       convolution). No non-linearity is applied between them.
    3. The result is projected to one channel by ``self.fc1``, a leaky ReLU
       and ``self.fc2``.

    input shape:  (batchsize, x=size, c=in_channel - 1); the grid coordinate
                  is appended inside ``forward``.
    output shape: (batchsize, x=size, c=1)

    Parameters
    ----------
    width : int
        Number of lifted channels.
    level : int
        Number of wavelet decomposition levels.
    size : int
        Number of grid points of the (unpadded) input signal.
    wavelet : str
        Wavelet name.
    in_channel : int
        Number of channels seen by ``fc0`` (input channels + 1 grid channel).
    grid_range : float
        Right end of the grid; the grid spans ``[0, grid_range]``.
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super().__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 2

        self.fc0 = nn.Linear(self.in_channel, self.width)

        # The wavelet layers operate on the padded signal (see forward), so
        # their weights must be sized for size + padding points. Sizing them
        # for `size` only works when both lengths happen to give the same
        # number of wavelet coefficients.
        padded_size = self.size + self.padding
        self.conv0 = WaveConv1d(self.width, self.width, self.level, padded_size, self.wavelet)
        self.conv1 = WaveConv1d(self.width, self.width, self.level, padded_size, self.wavelet)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Map ``x`` of shape (batch, size, in_channel - 1) to (batch, size, 1)."""
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.fc0(x)
        x = x.permute(0, 2, 1)
        if self.padding != 0:
            x = F.pad(x, [0, self.padding])  # pad the right end of the grid axis

        for conv, w in ((self.conv0, self.w0), (self.conv1, self.w1)):
            x = conv(x) + w(x)

        if self.padding != 0:
            x = x[..., :-self.padding]  # remove the padding again
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.leaky_relu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        """Return a (batch, size_x, 1) float tensor of grid coordinates.

        The coordinates are ``size_x`` equally spaced points on
        ``[0, grid_range]``, repeated for every sample of the batch.
        """
        batchsize, size_x = shape[0], shape[1]
        gridx = torch.tensor(np.linspace(0, self.grid_range, size_x),
                             dtype=torch.float, device=device)
        return gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
    

In [None]:
class MFWNO(nn.Module):
    """Sum of an activation-free and a non-linear WNO, followed by a small pointwise MLP.

    ``conv1`` is a ``WNO1d_linear`` and ``conv2`` a ``WNO1d``; both receive
    the same input and the same constructor arguments. Their outputs are
    added and passed through ``fc0`` (1 -> 12), GELU and ``fc1`` (12 -> 1).
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super().__init__()

        self.width = width
        self.level = level
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range

        # Shared constructor arguments for both branches.
        branch_args = (self.width, self.level, self.size, self.wavelet,
                       self.in_channel, self.grid_range)
        self.conv1 = WNO1d_linear(*branch_args)
        self.conv2 = WNO1d(*branch_args)
        self.fc0 = nn.Linear(1, 12)
        self.fc1 = nn.Linear(12, 1)

    def forward(self, x):
        """Return ``fc1(gelu(fc0(conv1(x) + conv2(x))))``."""
        combined = self.conv1(x) + self.conv2(x)
        hidden = F.gelu(self.fc0(combined))
        return self.fc1(hidden)


# Multifidelity

In [None]:
# ---- data split -----------------------------------------------------------
ntrain, ntest = 30, 40        # number of training / test samples
n_total = ntrain + ntest      # samples taken from the data file
last_m = 400                  # end index (exclusive) of the sample window in the file
s = 100                       # number of grid points of the coordinate array

# ---- optimisation ---------------------------------------------------------
batch_size = 5
learning_rate = 0.001
w_decay = 1e-4                # Adam weight decay
epochs = 500
step_size = 50                # StepLR: epochs between learning-rate decays
gamma = 0.5                   # StepLR: learning-rate decay factor

# ---- network --------------------------------------------------------------
wavelet = 'db6'               # wavelet basis function
level = 3                     # level of wavelet decomposition
                              # NOTE: the model cells below pass level=4 explicitly
width = 64                    # uplifting dimension
layers = 4                    # no of wavelet layers (not read by the visible code)

h = 100                       # total grid size divided by the subsampling rate
grid_range = 1
in_channel = 3                # MF model: two data channels + the grid coordinate


In [None]:
# Load the data archive and keep the n_total samples that end at index last_m.
PATH = 'data/possion_10pt_100pt__lscale_01.npz'
data = np.load(PATH)

first_m = last_m - n_total
x_data_h = data['f_stoch'][first_m:last_m]      # forcing
y_data_l = data['y_low_100'][first_m:last_m]    # low-fidelity solution
y_data_h = data['yhi'][first_m:last_m]          # high-fidelity solution
x_coords = data['xhi'].reshape((s,))            # grid coordinates as a flat array

In [None]:
 # Display the shape of the full 'f_stoch' array stored in the data file.
 data['f_stoch'].shape

In [None]:
# Assemble the multi-fidelity dataset.
# Input channels: (forcing, low-fidelity solution).
# Target: the residual between the high- and low-fidelity solutions.
x_mf = np.stack([x_data_h, y_data_l], axis=-1)
y_mf = y_data_h - y_data_l

# The first ntrain samples are used for training, the last ntest for testing.
x_train_mf = torch.tensor(x_mf[:ntrain, ...], dtype=torch.float)
y_train_mf = torch.tensor(y_mf[:ntrain, ...], dtype=torch.float)
x_test_mf = torch.tensor(x_mf[-ntest:, ...], dtype=torch.float)
y_test_mf = torch.tensor(y_mf[-ntest:, ...], dtype=torch.float)

# Only the training loader is shuffled; the test order is kept so that the
# predictions can be matched with x_test_mf / y_test_mf afterwards.
train_loader_mf = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_train_mf, y_train_mf),
    batch_size=batch_size,
    shuffle=True,
)
test_loader_mf = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_test_mf, y_test_mf),
    batch_size=batch_size,
    shuffle=False,
)


In [None]:
# Sanity check: shapes of the full, training and test arrays.
print(*(arr.shape for arr in (x_mf, y_mf, x_train_mf, y_train_mf, x_test_mf, y_test_mf)))

In [None]:
# Multi-fidelity model: a WNO1d fed with (forcing, LF solution) plus the grid
# coordinate, hence three channels into the lifting layer.
model_mf = WNO1d(
    width=width,
    level=4,
    size=h,
    wavelet=wavelet,
    in_channel=3,
    grid_range=grid_range,
)
model_mf = model_mf.to(device)
print(count_params(model_mf))

# Adam with weight decay; the learning rate is multiplied by gamma every
# step_size epochs.
optimizer = torch.optim.Adam(
    model_mf.parameters(), lr=learning_rate, weight_decay=w_decay
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=step_size, gamma=gamma
)

In [None]:
# Train the multi-fidelity model on the relative L2 loss and report the
# train/test metrics after every epoch.
myloss = LpLoss(size_average=False)
for ep in range(epochs):
    model_mf.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader_mf:
        # Use the same device as the model instead of a hard-coded .cuda().
        x, y = x.to(device), y.to(device)
        # Flatten with the actual batch length so a smaller last batch works.
        n_batch = x.shape[0]

        optimizer.zero_grad()
        out = model_mf(x)

        mse = F.mse_loss(out.view(n_batch, -1), y.view(n_batch, -1), reduction='mean')
        l2 = myloss(out.view(n_batch, -1), y.view(n_batch, -1))
        l2.backward()  # use the l2 relative loss

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()
    model_mf.eval()
    test_l2 = 0.0
    test_mse = 0
    with torch.no_grad():
        for x, y in test_loader_mf:
            x, y = x.to(device), y.to(device)
            n_batch = x.shape[0]

            out = model_mf(x)
            test_l2 += myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item()
            tmse = F.mse_loss(out.view(n_batch, -1), y.view(n_batch, -1), reduction='mean')
            test_mse += tmse.item()

    # MSE is averaged over batches; the L2 loss is accumulated over batches
    # and divided by the number of samples.
    train_mse /= len(train_loader_mf)
    train_l2 /= ntrain
    test_l2 /= ntest
    test_mse /= len(test_loader_mf)
    t2 = default_timer()
    print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}, Test-MSE-{:0.4f}'
          .format(ep, t2-t1, train_mse, train_l2, test_l2, test_mse))


Epoch-92, Time-0.2585, Train-MSE-0.0000, Train-L2-0.1034, Test-L2-0.7560, Test-MSE-0.0009
Epoch-93, Time-0.2769, Train-MSE-0.0000, Train-L2-0.0723, Test-L2-0.7478, Test-MSE-0.0009
Epoch-94, Time-0.2671, Train-MSE-0.0000, Train-L2-0.0789, Test-L2-0.7605, Test-MSE-0.0009
Epoch-95, Time-0.2685, Train-MSE-0.0000, Train-L2-0.0763, Test-L2-0.7494, Test-MSE-0.0009
Epoch-96, Time-0.2752, Train-MSE-0.0000, Train-L2-0.0873, Test-L2-0.7534, Test-MSE-0.0009
Epoch-97, Time-0.2545, Train-MSE-0.0000, Train-L2-0.0903, Test-L2-0.7501, Test-MSE-0.0009
Epoch-98, Time-0.2632, Train-MSE-0.0000, Train-L2-0.0796, Test-L2-0.7607, Test-MSE-0.0009
Epoch-99, Time-0.2504, Train-MSE-0.0000, Train-L2-0.0967, Test-L2-0.7624, Test-MSE-0.0009
Epoch-100, Time-0.2436, Train-MSE-0.0000, Train-L2-0.0822, Test-L2-0.7488, Test-MSE-0.0009
Epoch-101, Time-0.2346, Train-MSE-0.0000, Train-L2-0.0816, Test-L2-0.7572, Test-MSE-0.0009
Epoch-102, Time-0.2389, Train-MSE-0.0000, Train-L2-0.0740, Test-L2-0.7567, Test-MSE-0.0009
Epoch-1

Epoch-183, Time-0.2696, Train-MSE-0.0000, Train-L2-0.0270, Test-L2-0.7581, Test-MSE-0.0009
Epoch-184, Time-0.2653, Train-MSE-0.0000, Train-L2-0.0242, Test-L2-0.7545, Test-MSE-0.0009
Epoch-185, Time-0.2856, Train-MSE-0.0000, Train-L2-0.0245, Test-L2-0.7542, Test-MSE-0.0009
Epoch-186, Time-0.2888, Train-MSE-0.0000, Train-L2-0.0225, Test-L2-0.7577, Test-MSE-0.0009
Epoch-187, Time-0.2757, Train-MSE-0.0000, Train-L2-0.0227, Test-L2-0.7555, Test-MSE-0.0009
Epoch-188, Time-0.2933, Train-MSE-0.0000, Train-L2-0.0213, Test-L2-0.7553, Test-MSE-0.0009
Epoch-189, Time-0.2947, Train-MSE-0.0000, Train-L2-0.0251, Test-L2-0.7559, Test-MSE-0.0009
Epoch-190, Time-0.2995, Train-MSE-0.0000, Train-L2-0.0245, Test-L2-0.7568, Test-MSE-0.0009
Epoch-191, Time-0.3042, Train-MSE-0.0000, Train-L2-0.0227, Test-L2-0.7570, Test-MSE-0.0009
Epoch-192, Time-0.3092, Train-MSE-0.0000, Train-L2-0.0240, Test-L2-0.7555, Test-MSE-0.0009
Epoch-193, Time-0.3052, Train-MSE-0.0000, Train-L2-0.0224, Test-L2-0.7554, Test-MSE-0.0009

Epoch-274, Time-0.2661, Train-MSE-0.0000, Train-L2-0.0104, Test-L2-0.7571, Test-MSE-0.0009
Epoch-275, Time-0.2648, Train-MSE-0.0000, Train-L2-0.0106, Test-L2-0.7566, Test-MSE-0.0009
Epoch-276, Time-0.2662, Train-MSE-0.0000, Train-L2-0.0104, Test-L2-0.7568, Test-MSE-0.0009
Epoch-277, Time-0.2655, Train-MSE-0.0000, Train-L2-0.0104, Test-L2-0.7568, Test-MSE-0.0009
Epoch-278, Time-0.2657, Train-MSE-0.0000, Train-L2-0.0103, Test-L2-0.7569, Test-MSE-0.0009
Epoch-279, Time-0.2649, Train-MSE-0.0000, Train-L2-0.0105, Test-L2-0.7566, Test-MSE-0.0009
Epoch-280, Time-0.2675, Train-MSE-0.0000, Train-L2-0.0107, Test-L2-0.7570, Test-MSE-0.0009
Epoch-281, Time-0.2703, Train-MSE-0.0000, Train-L2-0.0106, Test-L2-0.7567, Test-MSE-0.0009
Epoch-282, Time-0.2729, Train-MSE-0.0000, Train-L2-0.0107, Test-L2-0.7569, Test-MSE-0.0009
Epoch-283, Time-0.2684, Train-MSE-0.0000, Train-L2-0.0106, Test-L2-0.7570, Test-MSE-0.0009
Epoch-284, Time-0.2675, Train-MSE-0.0000, Train-L2-0.0107, Test-L2-0.7569, Test-MSE-0.0009

Epoch-365, Time-0.3048, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.7571, Test-MSE-0.0009
Epoch-366, Time-0.2965, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.7570, Test-MSE-0.0009
Epoch-367, Time-0.2900, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.7570, Test-MSE-0.0009
Epoch-368, Time-0.2936, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.7571, Test-MSE-0.0009
Epoch-369, Time-0.3087, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.7571, Test-MSE-0.0009
Epoch-370, Time-0.3029, Train-MSE-0.0000, Train-L2-0.0090, Test-L2-0.7570, Test-MSE-0.0009
Epoch-371, Time-0.2966, Train-MSE-0.0000, Train-L2-0.0089, Test-L2-0.7571, Test-MSE-0.0009
Epoch-372, Time-0.3119, Train-MSE-0.0000, Train-L2-0.0089, Test-L2-0.7571, Test-MSE-0.0009
Epoch-373, Time-0.3181, Train-MSE-0.0000, Train-L2-0.0089, Test-L2-0.7571, Test-MSE-0.0009
Epoch-374, Time-0.3009, Train-MSE-0.0000, Train-L2-0.0089, Test-L2-0.7570, Test-MSE-0.0009
Epoch-375, Time-0.3010, Train-MSE-0.0000, Train-L2-0.0089, Test-L2-0.7571, Test-MSE-0.0009

Epoch-456, Time-0.3055, Train-MSE-0.0000, Train-L2-0.0086, Test-L2-0.7571, Test-MSE-0.0009
Epoch-457, Time-0.3093, Train-MSE-0.0000, Train-L2-0.0086, Test-L2-0.7571, Test-MSE-0.0009
Epoch-458, Time-0.2995, Train-MSE-0.0000, Train-L2-0.0086, Test-L2-0.7571, Test-MSE-0.0009
Epoch-459, Time-0.3090, Train-MSE-0.0000, Train-L2-0.0086, Test-L2-0.7571, Test-MSE-0.0009
Epoch-460, Time-0.3086, Train-MSE-0.0000, Train-L2-0.0086, Test-L2-0.7571, Test-MSE-0.0009
Epoch-461, Time-0.3089, Train-MSE-0.0000, Train-L2-0.0086, Test-L2-0.7571, Test-MSE-0.0009
Epoch-462, Time-0.3064, Train-MSE-0.0000, Train-L2-0.0086, Test-L2-0.7571, Test-MSE-0.0009
Epoch-463, Time-0.3066, Train-MSE-0.0000, Train-L2-0.0086, Test-L2-0.7571, Test-MSE-0.0009
Epoch-464, Time-0.3087, Train-MSE-0.0000, Train-L2-0.0086, Test-L2-0.7571, Test-MSE-0.0009
Epoch-465, Time-0.3256, Train-MSE-0.0000, Train-L2-0.0086, Test-L2-0.7571, Test-MSE-0.0009
Epoch-466, Time-0.2847, Train-MSE-0.0000, Train-L2-0.0086, Test-L2-0.7571, Test-MSE-0.0009

In [None]:
# Save the MF-WNO model
import os

# Create the output folder first so that a missing directory cannot make the
# save fail after the whole training run.
os.makedirs('model', exist_ok=True)
torch.save(model_mf, 'model/MF_WNO_poisson1D_30')

In [None]:
# Prediction: run the MF model over the test loader, print the L2 loss of
# every batch and collect the predicted residuals on the CPU.
pred_mf = []
model_mf.eval()
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf):
        x, y = x.to(device), y.to(device)

        out = model_mf(x).squeeze(-1)
        # Flatten with the actual batch length so a smaller last batch works.
        n_batch = out.shape[0]
        test_l2 = myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item()
        pred_mf.append(out.cpu())
        print("Batch-{}, Test-loss-{:0.6f}".format(index, test_l2))

pred_mf = torch.cat(pred_mf)
print('Mean mse_mf-{}'.format(F.mse_loss(pred_mf, y_test_mf).item()))
    

In [None]:
# Display the shape of the concatenated MF predictions.
pred_mf.shape

In [None]:
# The MF model predicts the residual (HF - LF). Adding the low-fidelity
# solution (input channel 1) back gives the high-fidelity truth and the
# high-fidelity prediction.
inp_mf = x_test_mf
real_mf = y_test_mf + inp_mf[..., 1]
output_mf = pred_mf + inp_mf[..., 1]

In [None]:
# Mean squared error of the reconstructed MF prediction against the HF truth.
mse_pred = F.mse_loss(input=output_mf, target=real_mf).item()

print('MSE-Predicted solution-{:0.6f}'.format(mse_pred))


In [None]:
# Plot the HF truth against the MF-WNO prediction for every tenth test
# sample (starting at sample 1), one colour per sample.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

colormap = plt.cm.jet
colors = [colormap(i) for i in np.linspace(0, 1, 5)]

fig2 = plt.figure(figsize=(10, 4), dpi=300)
# The title previously named a different problem and model (copy-paste).
fig2.suptitle('Poisson 1D - MF-WNO - High fidelity')

for index, i in enumerate(range(1, ntest, 10)):
    # The sample index in the label keeps the legend entries distinguishable.
    plt.plot(x_coords, real_mf[i, :], color=colors[index], label='Actual-{}'.format(i))
    plt.plot(x_coords, output_mf[i, :], '--', color=colors[index], label='Prediction-{}'.format(i))
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


In [None]:
# Plot the forcing (input channel 0) of every tenth test sample, starting at
# sample 1.
colormap = plt.cm.jet
colors2 = [colormap(i) for i in np.linspace(0, 1, 5)]

fig1 = plt.figure(figsize=(10, 4), dpi=300)
# The title previously described a different problem and plot (copy-paste).
fig1.suptitle('Poisson 1D - Forcing')

for index, i in enumerate(range(1, ntest, 10)):
    plt.plot(x_coords, inp_mf[i, :, 0], color=colors2[index], label='Forcing-{}'.format(i))
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


# High Fidelity

In [None]:
# High-fidelity-only dataset: the forcing is the single input channel and the
# high-fidelity solution is the target.
x_hf = torch.tensor(x_data_h, dtype=torch.float).unsqueeze(-1)
y_hf = torch.tensor(y_data_h, dtype=torch.float)

# The first ntrain samples are used for training, the last ntest for testing.
x_train_hf, y_train_hf = x_hf[:ntrain, ...], y_hf[:ntrain, ...]
x_test_hf, y_test_hf = x_hf[-ntest:, ...], y_hf[-ntest:, ...]

train_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train_hf, y_train_hf),
                                              batch_size=batch_size, shuffle=True)
# The test targets must be y_test_hf; pairing x_test_hf with itself made the
# reported test loss compare the prediction with the input forcing.
test_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_hf, y_test_hf),
                                             batch_size=batch_size, shuffle=False)


In [None]:
# Sanity check: shapes of the HF training and test tensors.
print(*(t.shape for t in (x_train_hf, y_train_hf, x_test_hf, y_test_hf)))

In [None]:
# High-fidelity-only model: a WNO1d fed with the forcing plus the grid
# coordinate, hence two channels into the lifting layer.
model_hf = WNO1d(
    width=width,
    level=4,
    size=h,
    wavelet=wavelet,
    in_channel=2,
    grid_range=grid_range,
)
model_hf = model_hf.to(device)
print(count_params(model_hf))

# Adam with weight decay; the learning rate is multiplied by gamma every
# step_size epochs.
optimizer = torch.optim.Adam(
    model_hf.parameters(), lr=learning_rate, weight_decay=w_decay
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=step_size, gamma=gamma
)

In [None]:
# Train the high-fidelity-only model on the relative L2 loss and report the
# train/test metrics after every epoch.
myloss = LpLoss(size_average=False)
for ep in range(epochs):
    model_hf.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader_hf:
        # Use the same device as the model instead of a hard-coded .cuda().
        x, y = x.to(device), y.to(device)
        # Flatten with the actual batch length so a smaller last batch works.
        n_batch = x.shape[0]

        optimizer.zero_grad()
        out = model_hf(x)

        mse = F.mse_loss(out.view(n_batch, -1), y.view(n_batch, -1), reduction='mean')
        l2 = myloss(out.view(n_batch, -1), y.view(n_batch, -1))
        l2.backward()  # use the l2 relative loss

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()
    model_hf.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader_hf:
            x, y = x.to(device), y.to(device)
            n_batch = x.shape[0]

            out = model_hf(x)
            test_l2 += myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item()

    # MSE is averaged over batches; the L2 loss is accumulated over batches
    # and divided by the number of samples.
    train_mse /= len(train_loader_hf)
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}'
          .format(ep, t2-t1, train_mse, train_l2, test_l2))


Epoch-113, Time-0.3040, Train-MSE-0.0214, Train-L2-0.0936, Test-L2-1.5716
Epoch-114, Time-0.3016, Train-MSE-0.0217, Train-L2-0.0930, Test-L2-1.5987
Epoch-115, Time-0.3182, Train-MSE-0.0222, Train-L2-0.0952, Test-L2-1.5748
Epoch-116, Time-0.3038, Train-MSE-0.0215, Train-L2-0.0951, Test-L2-1.6041
Epoch-117, Time-0.3100, Train-MSE-0.0190, Train-L2-0.1013, Test-L2-1.5706
Epoch-118, Time-0.3050, Train-MSE-0.0195, Train-L2-0.0842, Test-L2-1.5894
Epoch-119, Time-0.2956, Train-MSE-0.0181, Train-L2-0.0885, Test-L2-1.5819
Epoch-120, Time-0.3044, Train-MSE-0.0194, Train-L2-0.0861, Test-L2-1.5924
Epoch-121, Time-0.3107, Train-MSE-0.0169, Train-L2-0.0857, Test-L2-1.5975
Epoch-122, Time-0.3054, Train-MSE-0.0189, Train-L2-0.0826, Test-L2-1.5811
Epoch-123, Time-0.3002, Train-MSE-0.0153, Train-L2-0.0811, Test-L2-1.5979
Epoch-124, Time-0.3150, Train-MSE-0.0162, Train-L2-0.0806, Test-L2-1.5888
Epoch-125, Time-0.2918, Train-MSE-0.0165, Train-L2-0.0802, Test-L2-1.6065
Epoch-126, Time-0.2967, Train-MSE-0.01

Epoch-224, Time-0.2991, Train-MSE-0.0031, Train-L2-0.0359, Test-L2-1.5954
Epoch-225, Time-0.3067, Train-MSE-0.0031, Train-L2-0.0366, Test-L2-1.5974
Epoch-226, Time-0.2994, Train-MSE-0.0029, Train-L2-0.0338, Test-L2-1.5968
Epoch-227, Time-0.3031, Train-MSE-0.0032, Train-L2-0.0365, Test-L2-1.5998
Epoch-228, Time-0.2912, Train-MSE-0.0028, Train-L2-0.0367, Test-L2-1.5962
Epoch-229, Time-0.2863, Train-MSE-0.0029, Train-L2-0.0319, Test-L2-1.5974
Epoch-230, Time-0.2868, Train-MSE-0.0028, Train-L2-0.0313, Test-L2-1.5964
Epoch-231, Time-0.2747, Train-MSE-0.0029, Train-L2-0.0317, Test-L2-1.5989
Epoch-232, Time-0.3004, Train-MSE-0.0029, Train-L2-0.0309, Test-L2-1.5971
Epoch-233, Time-0.3017, Train-MSE-0.0029, Train-L2-0.0325, Test-L2-1.5968
Epoch-234, Time-0.2858, Train-MSE-0.0029, Train-L2-0.0330, Test-L2-1.5960
Epoch-235, Time-0.3116, Train-MSE-0.0030, Train-L2-0.0366, Test-L2-1.6011
Epoch-236, Time-0.2988, Train-MSE-0.0027, Train-L2-0.0335, Test-L2-1.5936
Epoch-237, Time-0.3007, Train-MSE-0.00

Epoch-335, Time-0.2899, Train-MSE-0.0020, Train-L2-0.0211, Test-L2-1.5971
Epoch-336, Time-0.2995, Train-MSE-0.0020, Train-L2-0.0212, Test-L2-1.5974
Epoch-337, Time-0.2994, Train-MSE-0.0020, Train-L2-0.0213, Test-L2-1.5964
Epoch-338, Time-0.3011, Train-MSE-0.0020, Train-L2-0.0214, Test-L2-1.5975
Epoch-339, Time-0.2964, Train-MSE-0.0020, Train-L2-0.0216, Test-L2-1.5965
Epoch-340, Time-0.3070, Train-MSE-0.0020, Train-L2-0.0216, Test-L2-1.5974
Epoch-341, Time-0.3065, Train-MSE-0.0020, Train-L2-0.0215, Test-L2-1.5970
Epoch-342, Time-0.2971, Train-MSE-0.0020, Train-L2-0.0214, Test-L2-1.5968
Epoch-343, Time-0.3004, Train-MSE-0.0020, Train-L2-0.0213, Test-L2-1.5970
Epoch-344, Time-0.3009, Train-MSE-0.0020, Train-L2-0.0210, Test-L2-1.5974
Epoch-345, Time-0.2441, Train-MSE-0.0020, Train-L2-0.0209, Test-L2-1.5974
Epoch-346, Time-0.2906, Train-MSE-0.0020, Train-L2-0.0209, Test-L2-1.5968
Epoch-347, Time-0.3101, Train-MSE-0.0020, Train-L2-0.0210, Test-L2-1.5969
Epoch-348, Time-0.2991, Train-MSE-0.00

Epoch-446, Time-0.2884, Train-MSE-0.0018, Train-L2-0.0196, Test-L2-1.5970
Epoch-447, Time-0.2981, Train-MSE-0.0018, Train-L2-0.0196, Test-L2-1.5970
Epoch-448, Time-0.2968, Train-MSE-0.0018, Train-L2-0.0196, Test-L2-1.5969
Epoch-449, Time-0.3119, Train-MSE-0.0018, Train-L2-0.0196, Test-L2-1.5971
Epoch-450, Time-0.3126, Train-MSE-0.0018, Train-L2-0.0196, Test-L2-1.5970
Epoch-451, Time-0.2963, Train-MSE-0.0018, Train-L2-0.0196, Test-L2-1.5970
Epoch-452, Time-0.3115, Train-MSE-0.0018, Train-L2-0.0196, Test-L2-1.5970
Epoch-453, Time-0.3003, Train-MSE-0.0018, Train-L2-0.0196, Test-L2-1.5970
Epoch-454, Time-0.2890, Train-MSE-0.0018, Train-L2-0.0196, Test-L2-1.5970
Epoch-455, Time-0.3080, Train-MSE-0.0018, Train-L2-0.0195, Test-L2-1.5970
Epoch-456, Time-0.3095, Train-MSE-0.0018, Train-L2-0.0195, Test-L2-1.5970
Epoch-457, Time-0.3090, Train-MSE-0.0018, Train-L2-0.0195, Test-L2-1.5970
Epoch-458, Time-0.3126, Train-MSE-0.0018, Train-L2-0.0195, Test-L2-1.5970
Epoch-459, Time-0.3041, Train-MSE-0.00

In [None]:
# Save the HF-WNO model
import os

# Create the output folder first so that a missing directory cannot make the
# save fail after the whole training run.
os.makedirs('model', exist_ok=True)
torch.save(model_hf, 'model/HF_WNO_poisson1D_30')

In [None]:
# Prediction: run the HF-only model on the forcing channel of the MF test
# loader, print the L2 loss of every batch and collect the predictions.
pred_hf = []
model_hf.eval()
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf):
        x, y = x.to(device), y.to(device)

        # test_loader_mf yields the residual target (HF - LF). The HF model
        # predicts the HF solution itself, so add the LF channel back before
        # comparing.
        y_hf_true = y + x[:, :, 1]

        out = model_hf(x[:, :, 0:1]).squeeze(-1)
        # Flatten with the actual batch length so a smaller last batch works.
        n_batch = out.shape[0]
        test_l2 = myloss(out.view(n_batch, -1), y_hf_true.view(n_batch, -1)).item()
        pred_hf.append(out.cpu())
        print("Batch-{}, Test-loss-{:0.6f}".format(index, test_l2))

pred_hf = torch.cat(pred_hf)


In [None]:
# HF truth = MF residual target + low-fidelity solution (input channel 1).
# The HF model predicts the HF solution directly, so no correction is added.
inp = x_test_mf
real_hf = y_test_mf + inp[..., 1]
output_hf = pred_hf

In [None]:
# Mean squared error of the HF-only prediction against the HF truth.
mse_pred_hf = F.mse_loss(input=output_hf, target=real_hf).item()

print('MSE-Predicted solution-{:0.4f}'.format(mse_pred_hf))


In [None]:
# Plot the HF truth against the HF-WNO prediction for every tenth test
# sample (starting at sample 1), one colour per sample.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

colormap = plt.cm.jet
colors = [colormap(i) for i in np.linspace(0, 1, 5)]

fig2 = plt.figure(figsize=(10, 4), dpi=300)
# The title previously named a different problem and model (copy-paste).
fig2.suptitle('Poisson 1D - HF-WNO - High fidelity')

for index, i in enumerate(range(1, ntest, 10)):
    # The sample index in the label keeps the legend entries distinguishable.
    plt.plot(x_coords, real_hf[i, :], color=colors[index], label='Actual-{}'.format(i))
    plt.plot(x_coords, output_hf[i, :], '--', color=colors[index], label='Prediction-{}'.format(i))
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


In [None]:
# 2x2 comparison of the HF truth, the low-fidelity solution, the HF-only
# prediction and the MF prediction for four selected test samples.
fig5, axs = plt.subplots(2, 2, figsize=(18, 10), dpi=100)
plt.subplots_adjust(wspace=0.25)

# Test-sample indices shown in the panels (row-major order).
n0, n1, n2, n3 = 15, 28, 37, 38

for ax, n in zip(axs.flat, (n0, n1, n2, n3)):
    ax.plot(x_coords, real_mf[n], linestyle='-', color='tab:green', lw=2)
    # Channel 1 of the MF input is the low-fidelity solution; channel 0 is
    # the forcing, which does not belong under the 'LF-WNO' legend entry.
    ax.plot(x_coords, inp_mf[n, :, 1], linestyle=':', color='tab:blue', lw=2)
    ax.plot(x_coords, output_hf[n], linestyle='-.', color='tab:orange', lw=2)
    ax.plot(x_coords, output_mf[n], linestyle='--', color='tab:red', lw=3)
    ax.legend(['HF-Truth', 'LF-WNO', 'HF-WNO', 'MF-WNO'])
    ax.margins(0)
    ax.grid(True, alpha=0.3)
