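# This notebook trains a multi-fidelity WNO (MF-WNO) on the multi-fidelity 1D Poisson data (a time-independent problem) and compares it with a WNO trained on high-fidelity data only.
### Number of high-fidelity training samples = 10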

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from utils import *
import matplotlib.pyplot as plt

from timeit import default_timer
from pytorch_wavelets import DWT1D, IDWT1D

In [None]:
# Fix both RNGs: weight initialisation and DataLoader shuffling draw from torch,
# any NumPy randomness from np.random.
np.random.seed(10)
torch.manual_seed(10)

# WNO

In [None]:
class WaveConv1d(nn.Module):
    """
    1D wavelet integral layer: DWT -> channel mixing -> inverse DWT.

    The input is decomposed with a ``level``-level DWT1D (zero padding mode).
    The coarsest approximation coefficients are mixed with ``weights1`` and the
    last entry of the detail list with ``weights2``; all other detail levels pass
    through unchanged. The signal is then reconstructed with IDWT1D.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels. Only the mixed coefficients
            change channel count, so this file always uses in == out.
        level: number of decomposition levels (J).
        size: signal length used once to probe how many coarsest-level
            coefficients (``modes1``) the transform produces.
        wavelet: wavelet name understood by pytorch_wavelets (e.g. 'db6').
    """

    def __init__(self, in_channels, out_channels, level, size, wavelet):
        super(WaveConv1d, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.level, self.wavelet = level, wavelet
        # Only used here to probe the coefficient count; forward() builds its own.
        self.dwt_ = DWT1D(wave=self.wavelet, J=self.level, mode='zero')

        probe = torch.randn(1, 1, size)
        coarse, _ = self.dwt_(probe)
        self.modes1 = coarse.shape[-1]

        # Uniform [0, scale) initialisation of the two mixing tensors.
        self.scale = 1 / (in_channels * out_channels)
        weight_shape = (in_channels, out_channels, self.modes1)
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape))

    def mul1d(self, input, weights):
        """Per-coefficient channel mixing: (b, in, k) x (in, out, k) -> (b, out, k)."""
        return torch.einsum("nck,cok->nok", input, weights)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch, channels, length).

        Returns:
            The reconstructed tensor (batch, channels, length') produced by IDWT1D.
        """
        device = x.device
        # Transforms are rebuilt on the input's device on every call.
        forward_dwt = DWT1D(wave=self.wavelet, J=self.level, mode='zero').to(device)
        approx, details = forward_dwt(x)

        mixed_approx = self.mul1d(approx, self.weights1)
        details[-1] = self.mul1d(details[-1].clone(), self.weights2)

        inverse_dwt = IDWT1D(wave=self.wavelet, mode='zero').to(device)
        return inverse_dwt((mixed_approx, details))

In [None]:
class WNO1d(nn.Module):
    """
    Wavelet Neural Operator with four wavelet integral layers.

    Pipeline:
      1. append a uniform grid on [0, grid_range] as an extra input channel;
      2. lift with fc0 (in_channel -> width);
      3. right-pad the spatial axis by ``padding`` points (if non-zero);
      4. four layers v <- conv_k(v) + w_k(v), with GELU after the first three;
      5. crop the padding and project with fc1 -> leaky_relu -> fc2.

    Input:  (batch, s, in_channel - 1); the grid supplies the last channel.
    Output: (batch, s, 1)
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(WNO1d, self).__init__()
        self.level, self.width, self.size = level, width, size
        self.wavelet, self.in_channel = wavelet, in_channel
        self.grid_range = grid_range
        self.padding = 2

        # in_channel counts the appended grid channel.
        self.fc0 = nn.Linear(self.in_channel, self.width)

        # Construction order (conv0..conv3, then w0..w3) matches the parameter
        # order of saved models and the RNG draws made during initialisation.
        for k in range(4):
            setattr(self, f'conv{k}',
                    WaveConv1d(self.width, self.width, self.level, self.size, self.wavelet))
        for k in range(4):
            setattr(self, f'w{k}', nn.Conv1d(self.width, self.width, 1))

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        grid = self.get_grid(x.shape, x.device)
        h = self.fc0(torch.cat((x, grid), dim=-1)).permute(0, 2, 1)
        if self.padding != 0:
            h = F.pad(h, [0, self.padding])

        for k in range(4):
            h = getattr(self, f'conv{k}')(h) + getattr(self, f'w{k}')(h)
            if k < 3:  # no activation after the last integral layer
                h = F.gelu(h)

        if self.padding != 0:
            h = h[..., :-self.padding]
        h = self.fc1(h.permute(0, 2, 1))
        return self.fc2(F.leaky_relu(h))

    def get_grid(self, shape, device):
        """Return a (batch, s, 1) float tensor of s evenly spaced points on [0, grid_range]."""
        n_batch, n_x = shape[0], shape[1]
        axis = torch.tensor(np.linspace(0, self.grid_range, n_x), dtype=torch.float)
        return axis.reshape(1, n_x, 1).repeat([n_batch, 1, 1]).to(device)
    

In [None]:
class WNO1d_linear(nn.Module):
    """
    WNO variant with two wavelet integral layers and no activation between them.

    Pipeline:
      1. append a uniform grid on [0, grid_range] as an extra input channel;
      2. lift with fc0 (in_channel -> width);
      3. right-pad the spatial axis by ``padding`` points (if non-zero);
      4. two layers v <- conv_k(v) + w_k(v), with no nonlinearity;
      5. crop the padding and project with fc1 -> leaky_relu -> fc2.

    Input:  (batch, s, in_channel - 1); the grid supplies the last channel.
    Output: (batch, s, 1)
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(WNO1d_linear, self).__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        self.padding = 2

        # in_channel counts the appended grid channel.
        self.fc0 = nn.Linear(self.in_channel, self.width)

        self.conv0 = WaveConv1d(self.width, self.width, self.level, self.size, self.wavelet)
        self.conv1 = WaveConv1d(self.width, self.width, self.level, self.size, self.wavelet)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.fc0(x)
        x = x.permute(0, 2, 1)
        # Guarded like WNO1d: with padding == 0, x[..., :-0] would be empty.
        if self.padding != 0:
            x = F.pad(x, [0, self.padding])

        x = self.conv0(x) + self.w0(x)
        x = self.conv1(x) + self.w1(x)

        if self.padding != 0:
            x = x[..., :-self.padding]
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.leaky_relu(x)
        x = self.fc2(x)
        return x

    def get_grid(self, shape, device):
        """Return a (batch, s, 1) float tensor of s evenly spaced points on [0, grid_range]."""
        batchsize, size_x = shape[0], shape[1]
        gridx = torch.tensor(np.linspace(0, self.grid_range, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
        return gridx.to(device)
    

In [None]:
class MFWNO(nn.Module):
    """
    Sum of a linear-core WNO branch and a nonlinear WNO branch.

    ``conv1`` is a WNO1d_linear and ``conv2`` a WNO1d, both built with the same
    hyper-parameters. Their (batch, s, 1) outputs are added, then passed through
    a pointwise 1 -> 12 -> 1 MLP with a GELU in between.
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(MFWNO, self).__init__()
        self.width, self.level, self.size = width, level, size
        self.wavelet, self.in_channel, self.grid_range = wavelet, in_channel, grid_range

        branch_args = (self.width, self.level, self.size, self.wavelet,
                       self.in_channel, self.grid_range)
        self.conv1 = WNO1d_linear(*branch_args)
        self.conv2 = WNO1d(*branch_args)
        self.fc0 = nn.Linear(1, 12)
        self.fc1 = nn.Linear(12, 1)

    def forward(self, x):
        combined = self.conv1(x) + self.conv2(x)
        return self.fc1(F.gelu(self.fc0(combined)))


# Multifidelity

In [None]:
ntrain = 10      # number of high-fidelity training samples
ntest = 40       # number of test samples
n_total = ntrain + ntest
last_m = 400     # samples are taken from the index window [last_m - n_total, last_m)
s = 100          # number of spatial points per sample

batch_size = 5
learning_rate = 0.001

w_decay = 1e-4   # Adam weight decay (L2 penalty)
epochs = 500
step_size = 50   # StepLR: decay the learning rate every `step_size` epochs
gamma = 0.5      # StepLR: multiplicative learning-rate decay factor

wavelet = 'db6'  # wavelet basis function
level = 3        # level of wavelet decomposition (NOTE: the model cells pass level=4 explicitly)
width = 64       # uplifting dimension
layers = 4       # number of wavelet layers (informational; WNO1d hard-codes 4)

h = 100          # total grid size divided by the subsampling rate
grid_range = 1   # the spatial grid spans [0, grid_range]
in_channel = 3   # MF model inputs: (forcing, low-fidelity solution, grid x)


In [None]:
PATH = 'data/possion_10pt_100pt__lscale_01.npz'
data = np.load(PATH)

# Keep the n_total samples that end at index last_m.
window = slice(last_m - n_total, last_m)
x_data_h = data['f_stoch'][window]
y_data_l = data['y_low_100'][window]
y_data_h = data['yhi'][window]
x_coords = np.reshape(data['xhi'], (s,))

In [None]:
 data['f_stoch'].shape

In [None]:
# read data: the MF model maps (forcing, LF solution) to the HF - LF residual
x_mf = np.stack([x_data_h, y_data_l], axis=-1)
y_mf = y_data_h - y_data_l

x_train_mf = torch.tensor(x_mf[:ntrain], dtype=torch.float)
y_train_mf = torch.tensor(y_mf[:ntrain], dtype=torch.float)
x_test_mf = torch.tensor(x_mf[-ntest:], dtype=torch.float)
y_test_mf = torch.tensor(y_mf[-ntest:], dtype=torch.float)

train_loader_mf = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_train_mf, y_train_mf),
    batch_size=batch_size, shuffle=True)
test_loader_mf = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_test_mf, y_test_mf),
    batch_size=batch_size, shuffle=False)


In [None]:
print(*(arr.shape for arr in (x_mf, y_mf, x_train_mf, y_train_mf, x_test_mf, y_test_mf)))

In [None]:
# model: the MF model is a WNO1d over (forcing, LF solution, grid)
model_mf = WNO1d(width=width, level=4, size=h, wavelet=wavelet,
                 in_channel=3, grid_range=grid_range)
model_mf.to(device)
print(count_params(model_mf))

optimizer = torch.optim.Adam(model_mf.parameters(), lr=learning_rate,
                             weight_decay=w_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size,
                                            gamma=gamma)

In [None]:
myloss = LpLoss(size_average=False)  # presumably summed over the batch; averaged via /ntrain, /ntest below
for ep in range(epochs):
    model_mf.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader_mf:
        # Follow the model's device instead of hard-coding .cuda().
        x, y = x.to(device), y.to(device)
        nb = x.shape[0]  # actual batch size; the last batch may be smaller

        optimizer.zero_grad()
        out = model_mf(x)

        mse = F.mse_loss(out.view(nb, -1), y.view(nb, -1), reduction='mean')
        l2 = myloss(out.view(nb, -1), y.view(nb, -1))
        l2.backward()  # optimise the relative L2 loss; MSE is only monitored

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()
    model_mf.eval()
    test_l2 = 0.0
    test_mse = 0
    with torch.no_grad():
        for x, y in test_loader_mf:
            x, y = x.to(device), y.to(device)
            nb = x.shape[0]

            out = model_mf(x)
            test_l2 += myloss(out.view(nb, -1), y.view(nb, -1)).item()
            tmse = F.mse_loss(out.view(nb, -1), y.view(nb, -1), reduction='mean')
            test_mse += tmse.item()

    train_mse /= len(train_loader_mf)
    train_l2 /= ntrain
    test_l2 /= ntest
    test_mse /= len(test_loader_mf)
    t2 = default_timer()
    print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}, Test-MSE-{:0.4f}'
          .format(ep, t2-t1, train_mse, train_l2, test_l2, test_mse))


Epoch-92, Time-0.1462, Train-MSE-0.0000, Train-L2-0.0822, Test-L2-1.1496, Test-MSE-0.0021
Epoch-93, Time-0.1448, Train-MSE-0.0000, Train-L2-0.0713, Test-L2-1.1618, Test-MSE-0.0021
Epoch-94, Time-0.1457, Train-MSE-0.0000, Train-L2-0.0697, Test-L2-1.1747, Test-MSE-0.0022
Epoch-95, Time-0.1439, Train-MSE-0.0000, Train-L2-0.0675, Test-L2-1.1606, Test-MSE-0.0021
Epoch-96, Time-0.1459, Train-MSE-0.0000, Train-L2-0.0702, Test-L2-1.1673, Test-MSE-0.0021
Epoch-97, Time-0.1443, Train-MSE-0.0000, Train-L2-0.0749, Test-L2-1.1515, Test-MSE-0.0021
Epoch-98, Time-0.1448, Train-MSE-0.0000, Train-L2-0.0668, Test-L2-1.1704, Test-MSE-0.0022
Epoch-99, Time-0.1451, Train-MSE-0.0000, Train-L2-0.0587, Test-L2-1.1606, Test-MSE-0.0021
Epoch-100, Time-0.1447, Train-MSE-0.0000, Train-L2-0.0671, Test-L2-1.1577, Test-MSE-0.0021
Epoch-101, Time-0.1449, Train-MSE-0.0000, Train-L2-0.0536, Test-L2-1.1542, Test-MSE-0.0021
Epoch-102, Time-0.1461, Train-MSE-0.0000, Train-L2-0.0452, Test-L2-1.1385, Test-MSE-0.0021
Epoch-1

Epoch-184, Time-0.1661, Train-MSE-0.0000, Train-L2-0.0166, Test-L2-1.1528, Test-MSE-0.0021
Epoch-185, Time-0.1583, Train-MSE-0.0000, Train-L2-0.0146, Test-L2-1.1549, Test-MSE-0.0021
Epoch-186, Time-0.1362, Train-MSE-0.0000, Train-L2-0.0146, Test-L2-1.1568, Test-MSE-0.0021
Epoch-187, Time-0.1440, Train-MSE-0.0000, Train-L2-0.0139, Test-L2-1.1495, Test-MSE-0.0021
Epoch-188, Time-0.1292, Train-MSE-0.0000, Train-L2-0.0207, Test-L2-1.1580, Test-MSE-0.0021
Epoch-189, Time-0.1281, Train-MSE-0.0000, Train-L2-0.0162, Test-L2-1.1609, Test-MSE-0.0021
Epoch-190, Time-0.1465, Train-MSE-0.0000, Train-L2-0.0161, Test-L2-1.1512, Test-MSE-0.0021
Epoch-191, Time-0.1400, Train-MSE-0.0000, Train-L2-0.0175, Test-L2-1.1628, Test-MSE-0.0021
Epoch-192, Time-0.1415, Train-MSE-0.0000, Train-L2-0.0185, Test-L2-1.1497, Test-MSE-0.0021
Epoch-193, Time-0.1383, Train-MSE-0.0000, Train-L2-0.0231, Test-L2-1.1530, Test-MSE-0.0021
Epoch-194, Time-0.1357, Train-MSE-0.0000, Train-L2-0.0155, Test-L2-1.1558, Test-MSE-0.0021

Epoch-276, Time-0.1514, Train-MSE-0.0000, Train-L2-0.0078, Test-L2-1.1570, Test-MSE-0.0021
Epoch-277, Time-0.1505, Train-MSE-0.0000, Train-L2-0.0077, Test-L2-1.1559, Test-MSE-0.0021
Epoch-278, Time-0.1609, Train-MSE-0.0000, Train-L2-0.0075, Test-L2-1.1552, Test-MSE-0.0021
Epoch-279, Time-0.1508, Train-MSE-0.0000, Train-L2-0.0075, Test-L2-1.1565, Test-MSE-0.0021
Epoch-280, Time-0.1684, Train-MSE-0.0000, Train-L2-0.0075, Test-L2-1.1552, Test-MSE-0.0021
Epoch-281, Time-0.1495, Train-MSE-0.0000, Train-L2-0.0077, Test-L2-1.1563, Test-MSE-0.0021
Epoch-282, Time-0.1621, Train-MSE-0.0000, Train-L2-0.0075, Test-L2-1.1560, Test-MSE-0.0021
Epoch-283, Time-0.1606, Train-MSE-0.0000, Train-L2-0.0073, Test-L2-1.1552, Test-MSE-0.0021
Epoch-284, Time-0.1556, Train-MSE-0.0000, Train-L2-0.0076, Test-L2-1.1574, Test-MSE-0.0021
Epoch-285, Time-0.1490, Train-MSE-0.0000, Train-L2-0.0076, Test-L2-1.1559, Test-MSE-0.0021
Epoch-286, Time-0.1509, Train-MSE-0.0000, Train-L2-0.0074, Test-L2-1.1557, Test-MSE-0.0021

Epoch-368, Time-0.1639, Train-MSE-0.0000, Train-L2-0.0065, Test-L2-1.1561, Test-MSE-0.0021
Epoch-369, Time-0.1552, Train-MSE-0.0000, Train-L2-0.0065, Test-L2-1.1561, Test-MSE-0.0021
Epoch-370, Time-0.1563, Train-MSE-0.0000, Train-L2-0.0065, Test-L2-1.1561, Test-MSE-0.0021
Epoch-371, Time-0.1664, Train-MSE-0.0000, Train-L2-0.0065, Test-L2-1.1561, Test-MSE-0.0021
Epoch-372, Time-0.1652, Train-MSE-0.0000, Train-L2-0.0065, Test-L2-1.1561, Test-MSE-0.0021
Epoch-373, Time-0.1634, Train-MSE-0.0000, Train-L2-0.0065, Test-L2-1.1561, Test-MSE-0.0021
Epoch-374, Time-0.1574, Train-MSE-0.0000, Train-L2-0.0065, Test-L2-1.1561, Test-MSE-0.0021
Epoch-375, Time-0.1587, Train-MSE-0.0000, Train-L2-0.0065, Test-L2-1.1561, Test-MSE-0.0021
Epoch-376, Time-0.1560, Train-MSE-0.0000, Train-L2-0.0065, Test-L2-1.1561, Test-MSE-0.0021
Epoch-377, Time-0.1597, Train-MSE-0.0000, Train-L2-0.0065, Test-L2-1.1561, Test-MSE-0.0021
Epoch-378, Time-0.1539, Train-MSE-0.0000, Train-L2-0.0065, Test-L2-1.1561, Test-MSE-0.0021

Epoch-460, Time-0.1607, Train-MSE-0.0000, Train-L2-0.0063, Test-L2-1.1561, Test-MSE-0.0021
Epoch-461, Time-0.1542, Train-MSE-0.0000, Train-L2-0.0063, Test-L2-1.1561, Test-MSE-0.0021
Epoch-462, Time-0.1649, Train-MSE-0.0000, Train-L2-0.0063, Test-L2-1.1560, Test-MSE-0.0021
Epoch-463, Time-0.1594, Train-MSE-0.0000, Train-L2-0.0063, Test-L2-1.1561, Test-MSE-0.0021
Epoch-464, Time-0.1640, Train-MSE-0.0000, Train-L2-0.0063, Test-L2-1.1561, Test-MSE-0.0021
Epoch-465, Time-0.1548, Train-MSE-0.0000, Train-L2-0.0063, Test-L2-1.1561, Test-MSE-0.0021
Epoch-466, Time-0.1620, Train-MSE-0.0000, Train-L2-0.0063, Test-L2-1.1560, Test-MSE-0.0021
Epoch-467, Time-0.1496, Train-MSE-0.0000, Train-L2-0.0063, Test-L2-1.1561, Test-MSE-0.0021
Epoch-468, Time-0.1643, Train-MSE-0.0000, Train-L2-0.0063, Test-L2-1.1561, Test-MSE-0.0021
Epoch-469, Time-0.1557, Train-MSE-0.0000, Train-L2-0.0063, Test-L2-1.1561, Test-MSE-0.0021
Epoch-470, Time-0.1615, Train-MSE-0.0000, Train-L2-0.0063, Test-L2-1.1560, Test-MSE-0.0021

In [None]:
# Save the MF-WNO model (the whole pickled module, not just its state_dict).
import os

os.makedirs('model', exist_ok=True)  # torch.save does not create missing directories
torch.save(model_mf, 'model/MF_WNO_poisson1D_10')

In [None]:
# Prediction: per-batch relative L2 on the residual targets, then the test-set MSE.
pred_mf = []
model_mf.eval()
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf):
        x, y = x.to(device), y.to(device)
        nb = x.shape[0]  # robust to a smaller final batch

        out = model_mf(x).squeeze(-1)
        test_l2 = myloss(out.view(nb, -1), y.view(nb, -1)).item()
        pred_mf.append(out.cpu())
        print("Batch-{}, Test-loss-{:0.6f}".format(index, test_l2))

pred_mf = torch.cat(pred_mf)
print('Mean mse_mf-{}'.format(F.mse_loss(y_test_mf, pred_mf).item()))
    

In [None]:
print(pred_mf.shape)  # explicit print: a bare expression is a no-op outside a notebook

In [None]:
# Add the LF solution (channel 1 of the MF input) back to residual truth and prediction.
inp_mf = x_test_mf
real_mf = inp_mf[:, :, 1] + y_test_mf
output_mf = inp_mf[:, :, 1] + pred_mf

In [None]:
mse_pred = F.mse_loss(output_mf, real_mf).item()
print(f'MSE-Predicted solution-{mse_pred:0.6f}')


In [None]:
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

colormap = plt.cm.jet
colors = [colormap(i) for i in np.linspace(0, 1, 5)]

fig2 = plt.figure(figsize=(10, 4), dpi=300)
fig2.suptitle('Poisson 1D - MF-WNO prediction vs. HF truth')

# Test samples 1, 11, 21, ... (same selection as `i % 10 == 1`), one colour each.
for index, i in enumerate(range(1, ntest, 10)):
    plt.plot(x_coords, real_mf[i, :], color=colors[index], label='Actual-{}'.format(i))
    plt.plot(x_coords, output_mf[i, :], '--', color=colors[index], label='Prediction-{}'.format(i))
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


In [None]:
colormap = plt.cm.jet
colors2 = [colormap(i) for i in np.linspace(0, 1, 5)]

fig1 = plt.figure(figsize=(10, 4), dpi=300)
fig1.suptitle('Poisson 1D - forcing of selected test samples')

# Channel 0 of the MF input is the forcing; same samples as the prediction plot.
for index, i in enumerate(range(1, ntest, 10)):
    plt.plot(x_coords, inp_mf[i, :, 0], color=colors2[index], label='Forcing-{}'.format(i))
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


# High Fidelity

In [None]:
# read data: the HF-only model maps the forcing directly to the HF solution
x_hf = torch.tensor(x_data_h, dtype=torch.float).unsqueeze(-1)
y_hf = torch.tensor(y_data_h, dtype=torch.float)

x_train_hf, y_train_hf = x_hf[:ntrain, ...], y_hf[:ntrain, ...]
x_test_hf, y_test_hf = x_hf[-ntest:, ...], y_hf[-ntest:, ...]

train_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train_hf, y_train_hf),
                                              batch_size=batch_size, shuffle=True)
# Pair the test inputs with their HF solutions as targets (not with the inputs themselves).
test_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_hf, y_test_hf),
                                             batch_size=batch_size, shuffle=False)


In [None]:
print(*(t.shape for t in (x_train_hf, y_train_hf, x_test_hf, y_test_hf)))

In [None]:
# model: the HF baseline sees (forcing, grid) only
model_hf = WNO1d(width=width, level=4, size=h, wavelet=wavelet,
                 in_channel=2, grid_range=grid_range)
model_hf.to(device)
print(count_params(model_hf))

optimizer = torch.optim.Adam(model_hf.parameters(), lr=learning_rate,
                             weight_decay=w_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size,
                                            gamma=gamma)

In [None]:
myloss = LpLoss(size_average=False)  # presumably summed over the batch; averaged via /ntrain, /ntest below
for ep in range(epochs):
    model_hf.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader_hf:
        # Follow the model's device instead of hard-coding .cuda().
        x, y = x.to(device), y.to(device)
        nb = x.shape[0]  # actual batch size; the last batch may be smaller

        optimizer.zero_grad()
        out = model_hf(x)

        mse = F.mse_loss(out.view(nb, -1), y.view(nb, -1), reduction='mean')
        l2 = myloss(out.view(nb, -1), y.view(nb, -1))
        l2.backward()  # optimise the relative L2 loss; MSE is only monitored

        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()

    scheduler.step()
    model_hf.eval()
    test_l2 = 0.0
    test_mse = 0
    with torch.no_grad():
        for x, y in test_loader_hf:
            x, y = x.to(device), y.to(device)
            nb = x.shape[0]

            out = model_hf(x)
            test_l2 += myloss(out.view(nb, -1), y.view(nb, -1)).item()
            test_mse += F.mse_loss(out.view(nb, -1), y.view(nb, -1), reduction='mean').item()

    train_mse /= len(train_loader_hf)
    train_l2 /= ntrain
    test_l2 /= ntest
    test_mse /= len(test_loader_hf)

    t2 = default_timer()
    print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}, Test-MSE-{:0.4f}'
          .format(ep, t2-t1, train_mse, train_l2, test_l2, test_mse))


Epoch-114, Time-0.1440, Train-MSE-0.0218, Train-L2-0.0998, Test-L2-1.4040
Epoch-115, Time-0.1462, Train-MSE-0.0224, Train-L2-0.0965, Test-L2-1.4042
Epoch-116, Time-0.1437, Train-MSE-0.0217, Train-L2-0.0969, Test-L2-1.4151
Epoch-117, Time-0.1431, Train-MSE-0.0209, Train-L2-0.0953, Test-L2-1.4136
Epoch-118, Time-0.1580, Train-MSE-0.0209, Train-L2-0.0947, Test-L2-1.4084
Epoch-119, Time-0.1553, Train-MSE-0.0204, Train-L2-0.0947, Test-L2-1.4126
Epoch-120, Time-0.1410, Train-MSE-0.0194, Train-L2-0.0905, Test-L2-1.4157
Epoch-121, Time-0.1486, Train-MSE-0.0198, Train-L2-0.0998, Test-L2-1.4113
Epoch-122, Time-0.1546, Train-MSE-0.0193, Train-L2-0.0927, Test-L2-1.4113
Epoch-123, Time-0.1505, Train-MSE-0.0184, Train-L2-0.0969, Test-L2-1.4150
Epoch-124, Time-0.1542, Train-MSE-0.0184, Train-L2-0.0904, Test-L2-1.4174
Epoch-125, Time-0.1474, Train-MSE-0.0179, Train-L2-0.0867, Test-L2-1.4179
Epoch-126, Time-0.1569, Train-MSE-0.0173, Train-L2-0.0916, Test-L2-1.4138
Epoch-127, Time-0.1574, Train-MSE-0.01

Epoch-226, Time-0.1266, Train-MSE-0.0060, Train-L2-0.0499, Test-L2-1.4404
Epoch-227, Time-0.1329, Train-MSE-0.0060, Train-L2-0.0477, Test-L2-1.4410
Epoch-228, Time-0.1343, Train-MSE-0.0060, Train-L2-0.0459, Test-L2-1.4410
Epoch-229, Time-0.1319, Train-MSE-0.0060, Train-L2-0.0487, Test-L2-1.4411
Epoch-230, Time-0.1266, Train-MSE-0.0059, Train-L2-0.0461, Test-L2-1.4391
Epoch-231, Time-0.1318, Train-MSE-0.0059, Train-L2-0.0447, Test-L2-1.4404
Epoch-232, Time-0.1270, Train-MSE-0.0059, Train-L2-0.0443, Test-L2-1.4429
Epoch-233, Time-0.1379, Train-MSE-0.0058, Train-L2-0.0448, Test-L2-1.4410
Epoch-234, Time-0.1312, Train-MSE-0.0059, Train-L2-0.0447, Test-L2-1.4387
Epoch-235, Time-0.1227, Train-MSE-0.0058, Train-L2-0.0438, Test-L2-1.4418
Epoch-236, Time-0.1258, Train-MSE-0.0058, Train-L2-0.0435, Test-L2-1.4428
Epoch-237, Time-0.1293, Train-MSE-0.0057, Train-L2-0.0435, Test-L2-1.4403
Epoch-238, Time-0.1354, Train-MSE-0.0057, Train-L2-0.0439, Test-L2-1.4406
Epoch-239, Time-0.1250, Train-MSE-0.00

Epoch-338, Time-0.1618, Train-MSE-0.0046, Train-L2-0.0371, Test-L2-1.4442
Epoch-339, Time-0.1607, Train-MSE-0.0046, Train-L2-0.0371, Test-L2-1.4441
Epoch-340, Time-0.1642, Train-MSE-0.0046, Train-L2-0.0370, Test-L2-1.4441
Epoch-341, Time-0.1623, Train-MSE-0.0046, Train-L2-0.0370, Test-L2-1.4442
Epoch-342, Time-0.1846, Train-MSE-0.0046, Train-L2-0.0370, Test-L2-1.4443
Epoch-343, Time-0.1667, Train-MSE-0.0046, Train-L2-0.0369, Test-L2-1.4443
Epoch-344, Time-0.1550, Train-MSE-0.0046, Train-L2-0.0369, Test-L2-1.4442
Epoch-345, Time-0.1652, Train-MSE-0.0046, Train-L2-0.0369, Test-L2-1.4442
Epoch-346, Time-0.1700, Train-MSE-0.0046, Train-L2-0.0369, Test-L2-1.4443
Epoch-347, Time-0.1647, Train-MSE-0.0045, Train-L2-0.0368, Test-L2-1.4444
Epoch-348, Time-0.1595, Train-MSE-0.0045, Train-L2-0.0368, Test-L2-1.4444
Epoch-349, Time-0.1545, Train-MSE-0.0045, Train-L2-0.0368, Test-L2-1.4444
Epoch-350, Time-0.1804, Train-MSE-0.0045, Train-L2-0.0367, Test-L2-1.4443
Epoch-351, Time-0.1591, Train-MSE-0.00

Epoch-450, Time-0.1292, Train-MSE-0.0043, Train-L2-0.0356, Test-L2-1.4449
Epoch-451, Time-0.1354, Train-MSE-0.0043, Train-L2-0.0356, Test-L2-1.4449
Epoch-452, Time-0.1337, Train-MSE-0.0043, Train-L2-0.0356, Test-L2-1.4449
Epoch-453, Time-0.1318, Train-MSE-0.0043, Train-L2-0.0356, Test-L2-1.4449
Epoch-454, Time-0.1408, Train-MSE-0.0043, Train-L2-0.0356, Test-L2-1.4449
Epoch-455, Time-0.1318, Train-MSE-0.0043, Train-L2-0.0356, Test-L2-1.4449
Epoch-456, Time-0.1288, Train-MSE-0.0043, Train-L2-0.0356, Test-L2-1.4450
Epoch-457, Time-0.1379, Train-MSE-0.0043, Train-L2-0.0356, Test-L2-1.4450
Epoch-458, Time-0.1289, Train-MSE-0.0043, Train-L2-0.0356, Test-L2-1.4450
Epoch-459, Time-0.1238, Train-MSE-0.0043, Train-L2-0.0356, Test-L2-1.4450
Epoch-460, Time-0.1283, Train-MSE-0.0043, Train-L2-0.0356, Test-L2-1.4450
Epoch-461, Time-0.1255, Train-MSE-0.0043, Train-L2-0.0356, Test-L2-1.4450
Epoch-462, Time-0.1265, Train-MSE-0.0043, Train-L2-0.0355, Test-L2-1.4450
Epoch-463, Time-0.1286, Train-MSE-0.00

In [None]:
# Save the HF-WNO model (the whole pickled module, not just its state_dict).
import os

os.makedirs('model', exist_ok=True)  # torch.save does not create missing directories
torch.save(model_hf, 'model/HF_WNO_poisson1D_10')

In [None]:
# Prediction: the HF model sees only the forcing (channel 0 of the MF input).
# The MF loader's targets are HF - LF residuals, so the HF truth is rebuilt as
# residual + LF solution (channel 1) before computing the loss.
pred_hf = []
model_hf.eval()
with torch.no_grad():
    for index, (x, y) in enumerate(test_loader_mf):
        x, y = x.to(device), y.to(device)
        nb = x.shape[0]  # robust to a smaller final batch
        y_true = y + x[:, :, 1]

        out = model_hf(x[:, :, 0:1]).squeeze(-1)
        test_l2 = myloss(out.view(nb, -1), y_true.view(nb, -1)).item()
        pred_hf.append(out.cpu())
        print("Batch-{}, Test-loss-{:0.6f}".format(index, test_l2))

pred_hf = torch.cat(pred_hf)
print('Mean mse_hf-{}'.format(F.mse_loss(y_test_hf, pred_hf).item()))


In [None]:
# The HF model predicts the solution directly; the truth is LF solution + residual.
inp = x_test_mf
output_hf = pred_hf
real_hf = inp[:, :, 1] + y_test_mf

In [None]:
mse_pred_hf = F.mse_loss(output_hf, real_hf).item()

# Six decimals, as for the MF result: MSEs here are O(1e-3) and 4 decimals hide the comparison.
print('MSE-Predicted solution-{:0.6f}'.format(mse_pred_hf))


In [None]:
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'dejavuserif'

colormap = plt.cm.jet
colors = [colormap(i) for i in np.linspace(0, 1, 5)]

fig2 = plt.figure(figsize=(10, 4), dpi=300)
fig2.suptitle('Poisson 1D - HF-WNO prediction vs. HF truth')

# Test samples 1, 11, 21, ... (same selection as `i % 10 == 1`), one colour each.
for index, i in enumerate(range(1, ntest, 10)):
    plt.plot(x_coords, real_hf[i, :], color=colors[index], label='Actual-{}'.format(i))
    plt.plot(x_coords, output_hf[i, :], '--', color=colors[index], label='Prediction-{}'.format(i))
plt.legend(ncol=4, loc=4, labelspacing=0.25, columnspacing=0.25, handletextpad=0.5, handlelength=1)
plt.grid(True)
plt.margins(0)


In [None]:
fig5, axs = plt.subplots(2, 2, figsize=(18, 10), dpi=100)
plt.subplots_adjust(wspace=0.25)

# Test-sample indices shown in the four panels, in row-major order.
n0, n1, n2, n3 = 15, 28, 37, 38
for ax, n in zip(axs.flat, (n0, n1, n2, n3)):
    ax.plot(x_coords, real_mf[n], linestyle='-', color='tab:green', lw=2)
    # Channel 1 of the MF input is the LF solution (channel 0 is the forcing).
    ax.plot(x_coords, inp_mf[n, :, 1], linestyle=':', color='tab:blue', lw=2)
    ax.plot(x_coords, output_hf[n], linestyle='-.', color='tab:orange', lw=2)
    ax.plot(x_coords, output_mf[n], linestyle='--', color='tab:red', lw=3)
    ax.legend(['HF-Truth', 'LF-WNO', 'HF-WNO', 'MF-WNO'])
    ax.margins(0)
    ax.grid(True, alpha=0.3)
