# Training MF-WNO for the 2D time-dependent Allen–Cahn equation
### HF data size = 20 samples with 50 time steps each (1000 points)

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
torch.cuda.empty_cache()
import matplotlib.pyplot as plt
from utils import *

from timeit import default_timer
from pytorch_wavelets import DWT, IDWT # (or import DWT, IDWT)
from pytorch_wavelets import DTCWTForward, DTCWTInverse

In [None]:
# Seed PyTorch first, then NumPy, with the same value so runs are repeatable.
SEED = 0
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(SEED)

# WNO

In [None]:
class WaveConv2dCwt(nn.Module):
    """2D wavelet layer built on the dual-tree complex wavelet transform (DTCWT).

    The layer takes the forward DTCWT of its input, applies a learned
    channel-mixing weight to the approximation coefficients and to the real and
    imaginary parts of the six sub-bands stored in the last entry of the detail
    coefficient list, and transforms the result back with the inverse DTCWT.
    All other detail levels are left at zero.

    It is computationally more expensive than the discrete "WaveConv2d" layer.

    Parameters
    ----------
    in_channels, out_channels : int
        Number of input / output channels of the layer.
    level : int
        Number of decomposition levels (passed as ``J`` to ``DTCWTForward``).
    size : sequence of two ints
        Spatial size of the layer input; a dummy transform of this size is run
        once to obtain the coefficient shapes used for the weights.
    wavelet1, wavelet2 : str
        Passed as ``biort`` and ``qshift`` to the DTCWT transforms.
    """

    # Suffixes of the sub-band weight attributes, in sub-band index order.
    # The attribute names (weights15r, weights15c, ...) are kept unchanged so
    # that existing checkpoints still load.
    # NOTE(review): the numbers presumably denote the DTCWT orientation angles
    # in degrees -- confirm against the pytorch_wavelets documentation.
    _ORIENTATIONS = ("15", "45", "75", "105", "135", "165")

    def __init__(self, in_channels, out_channels, level, size, wavelet1, wavelet2):
        super(WaveConv2dCwt, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.level = level
        self.wavelet_level1 = wavelet1
        self.wavelet_level2 = wavelet2

        # Run the transform once on dummy data to discover coefficient shapes.
        dummy_data = torch.randn(1, 1, *size)
        dwt_ = DTCWTForward(J=self.level, biort=self.wavelet_level1,
                            qshift=self.wavelet_level2)
        mode_data, mode_coef = dwt_(dummy_data)
        self.modes1 = mode_data.shape[-2]
        self.modes2 = mode_data.shape[-1]
        self.modes21 = mode_coef[-1].shape[-3]
        self.modes22 = mode_coef[-1].shape[-2]

        # Parameter initialization (creation order matches the original code so
        # that seeded runs draw identical initial weights).
        self.scale = (1 / (in_channels * out_channels))
        self.weights0 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2))
        for angle in self._ORIENTATIONS:
            for part in ("r", "c"):
                weight = nn.Parameter(
                    self.scale * torch.rand(in_channels, out_channels, self.modes21, self.modes22))
                setattr(self, f"weights{angle}{part}", weight)

    # Convolution
    def mul2d(self, input, weights):
        """Mix channels independently at every spatial mode.

        (batch, in_channel, x, y), (in_channel, out_channel, x, y)
        -> (batch, out_channel, x, y)
        """
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """
        Input parameters:
        -----------------
        x : tensor, shape-[Batch * Channel * x * y]
        Output parameters:
        ------------------
        x : tensor, shape-[Batch * Channel * x * y]
        """
        # Compute dual-tree complex wavelet coefficients
        cwt = DTCWTForward(J=self.level, biort=self.wavelet_level1,
                           qshift=self.wavelet_level2).to(x.device)
        x_ft, x_coeff = cwt(x)

        # Zero-initialised detail coefficients, allocated with the OUTPUT channel
        # count so the layer also works when in_channels != out_channels.
        out_coeff = [
            coeffs.new_zeros((coeffs.shape[0], self.out_channels) + tuple(coeffs.shape[2:]))
            for coeffs in x_coeff
        ]

        # Multiply the final approximate wavelet modes
        out_ft = self.mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights0)

        # Multiply the final detailed wavelet coefficients: one weight per
        # sub-band and per real/imaginary part (last axis index 0 / 1). The
        # results are written into a separate tensor, so no clone of the input
        # slice is needed.
        detail_in, detail_out = x_coeff[-1], out_coeff[-1]
        for band, angle in enumerate(self._ORIENTATIONS):
            for part, suffix in enumerate(("r", "c")):
                weights = getattr(self, f"weights{angle}{suffix}")
                detail_out[:, :, band, :, :, part] = self.mul2d(
                    detail_in[:, :, band, :, :, part], weights)

        # Return to physical space
        icwt = DTCWTInverse(biort=self.wavelet_level1,
                            qshift=self.wavelet_level2).to(x.device)
        x = icwt((out_ft, out_coeff))
        return x


In [None]:
class WNO2d(nn.Module):
    """Wavelet neural operator built from three DTCWT wavelet integral layers.

    1. The input is concatenated with the (x, y) grid coordinates and lifted
       to ``width`` channels by ``self.fc0``.
    2. Three layers v <- K(v) + W(v) follow, where K is ``self.conv_`` (wavelet
       layer) and W is ``self.w_`` (1x1 convolution); a GELU is applied after
       the first two layers only.
    3. The result is projected to one output channel by ``self.fc1`` and
       ``self.fc2``.

    input shape: (batchsize, x, y, c); ``in_channel`` must equal c + 2 because
    ``forward`` appends the two grid coordinates.
    output shape: (batchsize, x, y, 1)

    ``wavelet`` is a pair: element 0 and element 1 are forwarded to the wavelet
    layers as their ``wavelet1`` and ``wavelet2`` arguments.
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(WNO2d, self).__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet1, self.wavelet2 = wavelet[0], wavelet[1]
        self.in_channel = in_channel
        self.grid_range = grid_range
        # Number of rows/columns appended at the end of each spatial axis.
        self.padding = 1

        # Lifting layer: in_channel already counts the two grid coordinates.
        self.fc0 = nn.Linear(self.in_channel, self.width)

        conv_args = (self.width, self.width, self.level, self.size,
                     self.wavelet1, self.wavelet2)
        self.conv0 = WaveConv2dCwt(*conv_args)
        self.conv1 = WaveConv2dCwt(*conv_args)
        self.conv2 = WaveConv2dCwt(*conv_args)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.w2 = nn.Conv2d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Map a (batch, x, y, c) input to a (batch, x, y, 1) output."""
        x = torch.cat((x, self.get_grid(x.shape, x.device)), dim=-1)

        # Lift, then move channels first for the convolutional layers.
        x = self.fc0(x).permute(0, 3, 1, 2)
        if self.padding != 0:
            x = F.pad(x, [0, self.padding, 0, self.padding])

        wavelet_layers = (self.conv0, self.conv1, self.conv2)
        pointwise_layers = (self.w0, self.w1, self.w2)
        last = len(wavelet_layers) - 1
        for idx, (conv, w) in enumerate(zip(wavelet_layers, pointwise_layers)):
            x = conv(x) + w(x)
            # No activation after the final integral layer.
            if idx < last:
                x = F.gelu(x)

        if self.padding != 0:
            # Drop the rows/columns that were appended before the layers.
            x = x[..., :-self.padding, :-self.padding]
        x = x.permute(0, 2, 3, 1)
        x = F.gelu(self.fc1(x))
        return self.fc2(x)

    def get_grid(self, shape, device):
        """Return the (batch, x, y, 2) tensor of grid coordinates.

        The coordinates run linearly from 0 to ``grid_range[0]`` along x and
        from 0 to ``grid_range[1]`` along y.
        """
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        xs = torch.tensor(np.linspace(0, self.grid_range[0], size_x), dtype=torch.float)
        ys = torch.tensor(np.linspace(0, self.grid_range[1], size_y), dtype=torch.float)
        gridx = xs.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = ys.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        grid = torch.cat((gridx, gridy), dim=-1)
        return grid.to(device)

# Training and Data

In [None]:
# Dataset sizes (number of one-step input/output pairs)
ntrain, ntest = 1000, 1000
ntotal = n_total = ntrain + ntest   # both spellings are used further down

# Optimisation settings
epochs = 250
lst = 3000                 # NOTE(review): not used in the visible cells -- confirm it is still needed
batch_size = 100
learning_rate = 0.001
step_size, gamma = 40, 0.5  # StepLR: multiply the learning rate by gamma every step_size epochs

# Spatial resolution (grid points per side)
side = 65
s = side

# Network settings
wavelet = ['near_sym_a', 'qshift_a']  # wavelet basis functions
level = 2                             # level of wavelet decomposition
width = 32                            # uplifting dimension
grid_range = [1, 1]
in_channel = 4


In [None]:
# Read the low- and high-fidelity fields and keep the first ntotal // 50 entries
# along the leading axis.
path = 'data/ac2dlowhighres_1.mat'
reader = MatReader(path)
n_trajectories = ntotal // 50
u_low = np.array(reader.read_field('ulr_nextstep')[:n_trajectories])
u_high = np.array(reader.read_field('uhr')[:n_trajectories])

In [None]:
# Sanity check: shapes of the low- and high-fidelity arrays.
print(u_low.shape, u_high.shape)

In [None]:
# One-step pairs: HF state at step t (input) and at step t+1 (target), plus the
# LF field, all flattened over the first two axes.
x_or_h = np.reshape(u_high[:, :-1, ...], (-1, s, s, 1))
y_or_h = np.reshape(u_high[:, 1:, ...], (-1, s, s))
y_or_l = np.reshape(u_low, (-1, s, s, 1))

print(x_or_h.shape, y_or_h.shape, y_or_l.shape)

In [None]:
# Create the input and output (residual) dataset:
#   input  = [HF state, LF field] stacked on the channel axis
#   target = HF next step minus the LF field

x_mf = torch.tensor(np.concatenate((x_or_h, y_or_l), axis=-1), dtype=torch.float)
y_mf = torch.tensor(y_or_h - y_or_l.reshape((n_total, s, s)), dtype=torch.float)

In [None]:
# Reproducible random train/test split of the MF pairs.
generator = torch.Generator()
generator.manual_seed(453)
dataset = torch.utils.data.random_split(
    torch.utils.data.TensorDataset(x_mf, y_mf),
    [ntrain, ntest],
    generator=generator,
)
train_data, test_data = dataset


In [None]:
# Split the training and testing datasets
# Indexing a subset with [:] gathers all of its samples into an (inputs, targets)
# tuple; do it once per subset instead of once per tensor.

x_train_mf, y_train_mf = train_data[:]
x_test_mf, y_test_mf = test_data[:]

In [None]:
# Display the shape of the MF training targets.
y_train_mf.shape

In [None]:
# Define the dataloaders: training batches are reshuffled every epoch, test
# batches keep the order of the test tensors.

mf_train_set = torch.utils.data.TensorDataset(x_train_mf, y_train_mf)
mf_test_set = torch.utils.data.TensorDataset(x_test_mf, y_test_mf)
train_loader_mf = torch.utils.data.DataLoader(mf_train_set, batch_size=batch_size, shuffle=True)
test_loader_mf = torch.utils.data.DataLoader(mf_test_set, batch_size=batch_size, shuffle=False)

In [None]:
# %%
""" The MD-WNO model definition """
# Build the MF network, move it to the compute device and report its size.
mf_model_kwargs = dict(width=width, level=level, size=[s, s], wavelet=wavelet,
                       in_channel=in_channel, grid_range=grid_range)
model_mf = WNO2d(**mf_model_kwargs).to(device)
print(count_params(model_mf))

# Adam with a small weight decay; StepLR scales the learning rate by gamma
# every step_size epochs.
optimizer = torch.optim.Adam(model_mf.parameters(), lr=learning_rate, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# MF Model

In [None]:
# Train the MF-WNO model on MF-dataset
# The Lp loss is back-propagated; the MSE is only tracked for reporting.

myloss = LpLoss(size_average=False)
for ep in range(epochs):
    model_mf.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader_mf:
        x, y = x.to(device), y.to(device)
        n_batch = x.shape[0]

        optimizer.zero_grad()
        out = model_mf(x).reshape(n_batch, s, s)
        flat_out, flat_y = out.view(n_batch, -1), y.view(n_batch, -1)

        mse = F.mse_loss(flat_out, flat_y, reduction='mean')
        loss = myloss(flat_out, flat_y)
        loss.backward()
        optimizer.step()

        train_mse += mse.item()
        train_l2 += loss.item()

    scheduler.step()

    # Evaluate on the held-out pairs without tracking gradients.
    model_mf.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader_mf:
            x, y = x.to(device), y.to(device)
            n_batch = x.shape[0]

            out = model_mf(x).reshape(n_batch, s, s)
            test_l2 += myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item()

    # MSE is averaged per batch; the summed Lp losses are averaged per sample.
    train_mse /= len(train_loader_mf)
    train_l2 /= ntrain
    test_l2 /= ntest
    t2 = default_timer()
    print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}'
          .format(ep, t2-t1, train_mse, train_l2, test_l2))

Epoch-111, Time-10.3409, Train-MSE-0.0022, Train-L2-0.0778, Test-L2-0.0845
Epoch-112, Time-10.3255, Train-MSE-0.0022, Train-L2-0.0777, Test-L2-0.0844
Epoch-113, Time-10.3060, Train-MSE-0.0022, Train-L2-0.0776, Test-L2-0.0840
Epoch-114, Time-10.3505, Train-MSE-0.0022, Train-L2-0.0775, Test-L2-0.0841
Epoch-115, Time-10.3581, Train-MSE-0.0022, Train-L2-0.0772, Test-L2-0.0836
Epoch-116, Time-10.3614, Train-MSE-0.0022, Train-L2-0.0770, Test-L2-0.0848
Epoch-117, Time-10.3406, Train-MSE-0.0022, Train-L2-0.0770, Test-L2-0.0834
Epoch-118, Time-10.3378, Train-MSE-0.0022, Train-L2-0.0767, Test-L2-0.0839
Epoch-119, Time-10.3365, Train-MSE-0.0022, Train-L2-0.0766, Test-L2-0.0830
Epoch-120, Time-10.2836, Train-MSE-0.0022, Train-L2-0.0757, Test-L2-0.0824
Epoch-121, Time-10.3182, Train-MSE-0.0022, Train-L2-0.0754, Test-L2-0.0823
Epoch-122, Time-10.3123, Train-MSE-0.0022, Train-L2-0.0753, Test-L2-0.0820
Epoch-123, Time-10.2999, Train-MSE-0.0022, Train-L2-0.0751, Test-L2-0.0819
Epoch-124, Time-10.3298, 

Epoch-221, Time-10.3758, Train-MSE-0.0022, Train-L2-0.0691, Test-L2-0.0763
Epoch-222, Time-10.3756, Train-MSE-0.0022, Train-L2-0.0691, Test-L2-0.0762
Epoch-223, Time-10.4011, Train-MSE-0.0022, Train-L2-0.0690, Test-L2-0.0762
Epoch-224, Time-10.4006, Train-MSE-0.0022, Train-L2-0.0690, Test-L2-0.0762
Epoch-225, Time-10.3939, Train-MSE-0.0022, Train-L2-0.0690, Test-L2-0.0762
Epoch-226, Time-10.4035, Train-MSE-0.0022, Train-L2-0.0690, Test-L2-0.0762
Epoch-227, Time-10.3591, Train-MSE-0.0022, Train-L2-0.0690, Test-L2-0.0762
Epoch-228, Time-10.3892, Train-MSE-0.0022, Train-L2-0.0690, Test-L2-0.0761
Epoch-229, Time-10.3848, Train-MSE-0.0022, Train-L2-0.0689, Test-L2-0.0761
Epoch-230, Time-10.3771, Train-MSE-0.0022, Train-L2-0.0689, Test-L2-0.0760
Epoch-231, Time-10.4102, Train-MSE-0.0022, Train-L2-0.0688, Test-L2-0.0760
Epoch-232, Time-10.3891, Train-MSE-0.0022, Train-L2-0.0688, Test-L2-0.0760
Epoch-233, Time-10.4259, Train-MSE-0.0022, Train-L2-0.0688, Test-L2-0.0760
Epoch-234, Time-10.4129, 

In [None]:
# Save the MF-WNO model
# NOTE(review): this stores the whole module object, not just a state_dict, and
# a second class named WNO2d is defined later in this file -- confirm that
# loading resolves to the intended class.

mf_model_path = 'model/MF_WNO_AC2D_1000samples'
torch.save(model_mf, mf_model_path)

In [None]:
# Prediction: run the MF model over the test loader, report the per-sample Lp
# loss of every batch and collect the predicted residuals on the CPU.
pred_mf = []
with torch.no_grad():
    index = 0
    for x, y in test_loader_mf:
        x, y = x.to(device), y.to(device)
        n_batch = x.shape[0]

        out = model_mf(x).reshape(n_batch, s, s)
        test_l2 = myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item() / n_batch
        print('Batch-{}, Test-L2-{:0.4f}'.format(index, test_l2))

        pred_mf.append(out.cpu())
        index += 1

pred_mf = torch.cat(pred_mf, dim=0)

print('Mean mse_hf-{}'.format(F.mse_loss(y_test_mf, pred_mf).item()))

In [None]:
# Add the residual operator to LF-dataset
# Channel 1 of the network input is the LF field, so adding it back to the
# (true or predicted) residual recovers the HF next-step field.

lf_test = x_test_mf[..., 1]
real_mf = y_test_mf + lf_test
output_mf = pred_mf + lf_test

# Group the samples in blocks of 50; the number of blocks is inferred instead of
# being hard-coded to 20, so this keeps working if ntest changes.
# NOTE(review): the test samples come from a random split of all
# (trajectory, time) pairs, so a block of 50 consecutive rows is NOT one
# trajectory in time order -- confirm this grouping is what the plots need.
real_mf_time = real_mf.reshape(-1, 50, s, s)
output_mf_time = output_mf.reshape(-1, 50, s, s)


In [None]:
# Sanity check: the reconstructed truth and prediction must have equal shapes.
print(real_mf.shape, output_mf.shape)

In [None]:
# Error summary: full MF prediction, LF field alone, and the learned residual.
lf_field = x_test_mf[..., 1]
mse_pred = F.mse_loss(output_mf, real_mf).item()
mse_LF = F.mse_loss(real_mf, lf_field)
mse_residual = F.mse_loss(y_test_mf, pred_mf)

mse_report = 'MSE-Predicted solution-{:0.4f}, MSE-LF Data-{:0.4f}, MSE-Residual-{:0.4f}'
print(mse_report.format(mse_pred, mse_LF, mse_residual))


In [None]:
# Three rows (truth, prediction, absolute error) for every 10th entry of the
# selected block of 50 samples.
fig4, axs = plt.subplots(nrows=3, ncols=5, figsize=(16, 6), facecolor='w', edgecolor='k')
fig4.subplots_adjust(hspace=0.35, wspace=0.2)

fig4.suptitle(f'Predictions MFWNO AC2d Size', fontsize=16)
sample = 0
index = 0
for i in range(50):
    if i % 10 != 0:
        continue
    truth = real_mf_time[sample, i, :, :]
    guess = output_mf_time[sample, i, :, :]
    # Truth and prediction share a fixed colour range; the error panel is auto-scaled.
    panels = (
        (truth, dict(vmin=-1, vmax=1)),
        (guess, dict(vmin=-1, vmax=1)),
        (torch.abs(truth - guess), {}),
    )
    for row, (image, limits) in enumerate(panels):
        im = axs[row, index].imshow(image, cmap='jet', **limits)
        plt.colorbar(im, ax=axs[row, index])
    index += 1
        

# High Fidelity

In [None]:
class WaveConv2d(nn.Module):
    """2D wavelet layer: forward DWT, learned linear transform, inverse DWT.

    One weight is applied to the approximation coefficients and one to each of
    the three sub-bands in the last entry of the detail coefficient list; the
    remaining detail levels are passed through unchanged.
    """

    def __init__(self, in_channels, out_channels, level, size, wavelet):
        super(WaveConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.level = level
        self.wavelet = wavelet

        # Transform dummy data once to learn the approximation-coefficient size.
        dummy_data = torch.randn(1, 1, *size)
        dwt_ = DWT(J=self.level, mode='symmetric', wave=self.wavelet)
        mode_data, mode_coef = dwt_(dummy_data)
        self.modes1, self.modes2 = mode_data.shape[-2], mode_data.shape[-1]

        # Parameter initialization
        self.scale = (1 / (in_channels * out_channels))
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2)
        self.weights1 = nn.Parameter(self.scale * torch.rand(*weight_shape))
        self.weights2 = nn.Parameter(self.scale * torch.rand(*weight_shape))
        self.weights3 = nn.Parameter(self.scale * torch.rand(*weight_shape))
        self.weights4 = nn.Parameter(self.scale * torch.rand(*weight_shape))

    # Convolution
    def mul2d(self, input, weights):
        """Channel mixing applied independently at every spatial mode.

        (batch, in_channel, x, y), (in_channel, out_channel, x, y)
        -> (batch, out_channel, x, y)
        """
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        """
        Input parameters:
        -----------------
        x : tensor, shape-[Batch * Channel * x * y]
        Output parameters:
        ------------------
        x : tensor, shape-[Batch * Channel * x * y]
        """
        # Compute single tree Discrete Wavelet coefficients using some wavelet
        dwt = DWT(J=self.level, mode='symmetric', wave=self.wavelet).to(x.device)
        x_ft, x_coeff = dwt(x)

        # Multiply the final approximate Wavelet modes
        out_ft = self.mul2d(x_ft, self.weights1)

        # Multiply the final detailed wavelet coefficients. The sub-bands are
        # overwritten in place, hence the clone of each slice before it is read.
        last_detail = x_coeff[-1]
        for band, weights in enumerate((self.weights2, self.weights3, self.weights4)):
            last_detail[:, :, band, :, :] = self.mul2d(
                last_detail[:, :, band, :, :].clone(), weights)

        # Return to physical space
        idwt = IDWT(mode='symmetric', wave=self.wavelet).to(x.device)
        return idwt((out_ft, x_coeff))

In [None]:
class WNO2d(nn.Module):
    """Wavelet neural operator with three discrete-wavelet integral layers.

    The input is concatenated with its (x, y) grid coordinates, lifted to
    ``width`` channels (``fc0``), passed through three layers of the form
    wavelet layer + 1x1 convolution (GELU after the first two), and projected
    to a single output channel (``fc1``, ``fc2``).

    input shape: (batchsize, x, y, c) with ``in_channel == c + 2``
    output shape: (batchsize, x, y, 1)
    """

    def __init__(self, width, level, size, wavelet, in_channel, grid_range):
        super(WNO2d, self).__init__()

        self.level = level
        self.width = width
        self.size = size
        self.wavelet = wavelet
        self.in_channel = in_channel
        self.grid_range = grid_range
        # Rows/columns appended at the end of each spatial axis before the layers.
        self.padding = 1

        # Lifting layer; in_channel includes the two grid coordinates.
        self.fc0 = nn.Linear(self.in_channel, self.width)

        wavelet_args = (self.width, self.width, self.level, self.size, self.wavelet)
        self.conv0 = WaveConv2d(*wavelet_args)
        self.conv1 = WaveConv2d(*wavelet_args)
        self.conv2 = WaveConv2d(*wavelet_args)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.w2 = nn.Conv2d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Apply the operator to a (batch, x, y, c) tensor."""
        coords = self.get_grid(x.shape, x.device)
        x = self.fc0(torch.cat((x, coords), dim=-1))
        x = x.permute(0, 3, 1, 2)
        pad = self.padding
        if pad != 0:
            x = F.pad(x, [0, pad, 0, pad])

        # First two integral layers are followed by a GELU ...
        for conv, w in ((self.conv0, self.w0), (self.conv1, self.w1)):
            x = F.gelu(conv(x) + w(x))
        # ... the last one is not.
        x = self.conv2(x) + self.w2(x)

        if pad != 0:
            x = x[..., :-pad, :-pad]
        x = x.permute(0, 2, 3, 1)
        hidden = F.gelu(self.fc1(x))
        return self.fc2(hidden)

    def get_grid(self, shape, device):
        """Build the (batch, x, y, 2) coordinate grid spanning [0, grid_range]."""
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]

        def axis(limit, n_points):
            # Evenly spaced coordinates from 0 to limit as a float32 tensor.
            return torch.tensor(np.linspace(0, limit, n_points), dtype=torch.float)

        gridx = axis(self.grid_range[0], size_x).reshape(1, size_x, 1, 1)
        gridx = gridx.repeat([batchsize, 1, size_y, 1])
        gridy = axis(self.grid_range[1], size_y).reshape(1, 1, size_y, 1)
        gridy = gridy.repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)
    

In [None]:
# Settings of the HF-only model (these overwrite the MF values set earlier).
wavelet = 'db4'      # wavelet basis function
level = 2            # level of wavelet decomposition
width = 32           # uplifting dimension
s = side
grid_range = [1, 1]
in_channel = 3
epochs, step_size = 105, 20


In [None]:
# Create the HF-only input and output dataset
x_hf = torch.tensor( x_or_h, dtype=torch.float )
y_hf = torch.tensor( y_or_h, dtype=torch.float )

# Use the freshly seeded generator created here. The original passed the MF
# generator, whose state had already been advanced by the MF split, and left
# generator_hf unused.
generator_hf = torch.Generator().manual_seed(453)
dataset_hf = torch.utils.data.random_split(torch.utils.data.TensorDataset(x_hf, y_hf),
                                    [ntrain, ntest], generator=generator_hf)
train_data_hf, test_data_hf = dataset_hf[0], dataset_hf[1]

# Split the training and testing datasets (each subset is gathered once)
x_train_hf, y_train_hf = train_data_hf[:]
x_test_hf, y_test_hf = test_data_hf[:]

# Define the dataloaders
train_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train_hf, y_train_hf),
                                             batch_size=batch_size, shuffle=True)
test_loader_hf = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_hf, y_test_hf),
                                            batch_size=batch_size, shuffle=False)


In [None]:
# Display the shape of the HF training targets.
y_train_hf.shape

In [None]:
# Build the HF-only network on the compute device and report its size.
hf_model_kwargs = dict(width=width, level=level, size=[s, s], wavelet=wavelet,
                       in_channel=in_channel, grid_range=grid_range)
model = WNO2d(**hf_model_kwargs).to(device)
print(count_params(model))

# Fresh optimizer and scheduler bound to the HF model's parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [None]:
# Train the HF-WNO model
# The Lp loss is back-propagated; the MSE is only tracked for reporting.

myloss = LpLoss(size_average=False)
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader_hf:
        x, y = x.to(device), y.to(device)
        # Use the actual batch size: the last batch is smaller than batch_size
        # whenever the dataset size is not a multiple of it.
        n_batch = x.shape[0]

        optimizer.zero_grad()
        out = model(x).reshape(n_batch, s, s)

        mse = F.mse_loss(out.view(n_batch, -1), y.view(n_batch, -1), reduction='mean')
        loss = myloss(out.view(n_batch, -1), y.view(n_batch, -1))
        loss.backward()
        optimizer.step()

        train_mse += mse.item()
        train_l2 += loss.item()

    scheduler.step()

    # Evaluate on the held-out pairs without tracking gradients.
    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader_hf:
            x, y = x.to(device), y.to(device)
            n_batch = x.shape[0]

            out = model(x).reshape(n_batch, s, s)

            test_l2 += myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item()

    # MSE is averaged per batch; the summed Lp losses are averaged per sample.
    train_mse /= len(train_loader_hf)
    train_l2/= ntrain
    test_l2 /= ntest
    t2 = default_timer()
    print('Epoch-{}, Time-{:0.4f}, Train-MSE-{:0.4f}, Train-L2-{:0.4f}, Test-L2-{:0.4f}'
          .format(ep, t2-t1, train_mse, train_l2, test_l2))

In [None]:
# Save the HF-WNO model (the whole module object is stored, not a state_dict)

hf_model_path = 'model/HF_WNO_AC2D_1000samples'
torch.save(model, hf_model_path)

In [None]:
# Predict on HF data using HF-WNO: report the per-sample Lp loss of every test
# batch and collect the predictions on the CPU.
pred_hf = []
with torch.no_grad():
    index = 0
    for x, y in test_loader_hf:
        x, y = x.to(device), y.to(device)
        n_batch = x.shape[0]

        out = model(x).reshape(n_batch, s, s)
        test_l2 = myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item() / n_batch
        print('Batch-{}, Test-L2-{:0.4f}'.format(index, test_l2))

        pred_hf.append(out.cpu())
        index += 1

pred_hf = torch.cat(pred_hf, dim=0)

print('Mean mse_hf-{}'.format(F.mse_loss(y_test_hf, pred_hf).item()))

In [None]:
# Display the shape of the stacked HF predictions.
pred_hf.shape

In [None]:
# Regroup the HF predictions into 20 blocks of 50 samples and show the shape.
output_hf = torch.reshape(pred_hf, (20, 50, s, s))
output_hf.shape

In [None]:
# Mean squared error of the HF-only predictions on the test targets.
hf_mse = F.mse_loss(pred_hf, y_test_hf)
mse_pred_hf = hf_mse.item()

print('MSE-Predicted solution-{:0.4f}'.format(mse_pred_hf))


In [None]:
# Three rows (truth, HF-WNO prediction, absolute error) for every 10th of the
# first 50 test samples.
fig5, axs = plt.subplots(nrows=3, ncols=5, figsize=(16, 6), facecolor='w', edgecolor='k')
fig5.subplots_adjust(hspace=0.35, wspace=0.2)

# This figure shows the HF-only model; the original title was copied from the
# MF figure and said "MFWNO".
fig5.suptitle('Predictions HF-WNO AC2d', fontsize=16)
index = 0
for i in range(50):
    if i % 10 == 0:
        im = axs[0, index].imshow(y_test_hf[i, :, :], cmap='jet', vmin=-1, vmax=1)
        plt.colorbar(im, ax=axs[0, index])
        im = axs[1, index].imshow(pred_hf[i, :, :], cmap='jet', vmin=-1, vmax=1)
        plt.colorbar(im, ax=axs[1, index])
        # The error panel is auto-scaled (no fixed colour range).
        im = axs[2, index].imshow(torch.abs(y_test_hf[i, :, :] - pred_hf[i, :, :]),
                                    cmap='jet')
        plt.colorbar(im, ax=axs[2, index])
        index += 1
        

In [None]:
# Define rollout function for time-dependent prediction

def rollout(model, vel_in, steps, device='cpu'):
    """Apply ``model`` autoregressively and stack all states on the last axis.

    Parameters
    ----------
    model : callable mapping a state tensor to the next state.
    vel_in : tensor holding the initial state.
    steps : number of times the model is applied.
    device : device the initial state is moved to before the first call.

    Returns
    -------
    NumPy array: the initial state followed by the ``steps`` predicted states,
    concatenated along the last axis.
    """
    with torch.no_grad():
        state = vel_in.to(device)
        frames = [state.cpu().numpy()]
        for _ in range(steps):
            state = model(state)
            frames.append(state.cpu().numpy())

    return np.concatenate(frames, axis=-1)

In [None]:
# Load initial conditions and rollout the HF-time predictions

u_pred = reader.read_field('uhr')[:100]
# First time step of every trajectory, with a trailing channel axis.
u_init = torch.unsqueeze(u_pred[:,0,:,:],axis=-1)

# Coarse grid: every second point of the stored coordinates.
x = reader.read_field('x')
y = reader.read_field('y')
x_low = x.reshape(-1,)[::2]
y_low = y.reshape(-1,)[::2]

epsilon = reader.read_field('epsilon')
time = reader.read_field('time')
dt = float(reader.read_field('dtlarge'))

nx = x_low.shape[0]
# Take ny from the y grid (the original read x_low here, which is only correct
# on a square grid).
ny = y_low.shape[0]
dx= float(x_low[1]-x_low[0])
dy = float(y_low[1]-y_low[0])

# NOTE(review): the model was moved to `device` earlier while the rollout is
# pinned to 'cuda' -- confirm both refer to the same device.
trajectory_hf = rollout(model, u_init, 50, device='cuda')

In [None]:
# Display the shape of the HF rollout (time steps are stacked on the last axis).
trajectory_hf.shape

In [None]:
def laplacian(x,y,dx,dy,epsilon,nx,ny):
  """Return the Fourier symbol -epsilon * (kx**2 + ky**2) for use with rfft2.

  Parameters
  ----------
  x, y : 1D coordinate arrays; only their lengths are used.
  dx, dy : grid spacings along x and y.
  epsilon : scalar multiplying the symbol.
  nx, ny : kept for backward compatibility; the shape is now taken from
      ``x`` and ``y`` directly, so non-square grids are handled correctly.

  Returns
  -------
  Tensor of shape (1, len(x), len(y) // 2 + 1), on the GPU when CUDA is
  available and on the CPU otherwise.
  """
  dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  # Full frequency range along x, half range along y (layout of rfft2 output).
  kx = 2*torch.pi*torch.fft.fftfreq(x.shape[0],d=dx).to(dev)
  ky = 2*torch.pi*torch.fft.rfftfreq(y.shape[0],d=dy).to(dev)
  kxx,kyy = torch.meshgrid(kx,ky, indexing='ij')
  # Add a leading broadcast (batch) axis. The original reshaped kyy with ny as
  # its second dimension, which fails or scrambles entries when len(x) != len(y).
  lapl = -epsilon*(kxx**2+kyy**2)
  return lapl.unsqueeze(0)

def ac2d_solver(u,laplace,dt):
  """Advance ``u`` by one explicit step of size ``dt``.

  The diffusion term is evaluated spectrally: ``laplace`` multiplies the
  rfft2 of ``u`` and the product is transformed back to the spatial size of
  ``u``. The reaction term is ``u - u**3``.
  """
  spatial_size = (u.size(-2), u.size(-1))
  diffusion = torch.fft.irfft2(laplace*torch.fft.rfft2(u),s=spatial_size)
  reaction = u - u**3
  return u + dt*(diffusion + reaction)

def rollout_mf(model,solver,lapl,dt,vel_in,steps,device='cuda'):
  """Autoregressive multi-fidelity rollout.

  Each step (1) subsamples the current state by a factor of two in both
  spatial directions, (2) advances it with the low-fidelity ``solver``,
  (3) upsamples the result back to the full grid with bicubic interpolation,
  (4) feeds [state, upsampled LF state] to ``model`` and (5) adds the model
  output to the upsampled LF state to obtain the next state.

  Parameters
  ----------
  model : callable mapping a (batch, x, y, 2) tensor to a (batch, x, y, 1) residual.
  solver : callable ``solver(u, lapl, dt)`` advancing a (batch, x, y) field.
      The original ignored this argument and always called ``ac2d_solver``.
  lapl : passed through unchanged to ``solver``.
  dt : time step passed to ``solver``.
  vel_in : initial state, shape (batch, x, y, 1).
  steps : number of steps to roll out.
  device : device on which the rollout runs.

  Returns
  -------
  NumPy array with the initial state and the ``steps`` predicted states
  concatenated along the last axis.
  """
  with torch.no_grad():
    vel = vel_in.to(device)
    velocities = [vel.cpu().numpy()]
    for _ in range(steps):
      # Drop only the channel axis so a batch of size one keeps its batch axis.
      vel_low = vel[:, ::2, ::2, :].squeeze(-1)
      vel_lout = solver(vel_low, lapl, dt)
      # interpolate expects (batch, channel, x, y); go back to channels-last after.
      vel_loutup = F.interpolate(torch.unsqueeze(vel_lout, dim=1),
                                 size=(vel.shape[1], vel.shape[2]),
                                 mode='bicubic', align_corners=True).permute(0, 2, 3, 1)
      vel_min = torch.cat((vel, vel_loutup), dim=-1)
      # The model predicts the residual with respect to the upsampled LF state.
      vel = model(vel_min) + vel_min[:, :, :, 1:]
      velocities.append(vel.cpu().numpy())

  return np.concatenate(velocities,axis=-1)

In [None]:
# Rollout the MF-time predictions: build the spectral operator on the coarse
# grid, then roll the MF model out for 50 steps from the initial conditions.

eps_value = float(epsilon)
lapl = laplacian(x_low, y_low, dx, dy, eps_value, nx, ny)
n_rollout_steps = 50
trajectory_mf = rollout_mf(model_mf, ac2d_solver, lapl, dt, u_init, n_rollout_steps, device='cuda')

In [None]:
plt.rcParams["font.family"] = "Serif"
plt.rcParams['font.size'] = 10

# Rows: ground truth, HF rollout, HF absolute error, MF rollout, MF absolute
# error. One column per plotted time step.
# NOTE(review): six columns are created but range(50) in steps of 10 fills only
# five -- confirm whether step 50 was meant to be shown as well.
fig6, ax = plt.subplots(nrows=5, ncols=6, figsize=(12, 10), dpi=300)
# plt.subplots_adjust(hspace=0.25, wspace=0.3)

sample = 0
index = 0
for i in range(50):
    if i % 10 != 0:
        continue

    truth = u_pred[sample,i,:,:]
    hf_frame = trajectory_hf[sample,:,:,i]
    mf_frame = trajectory_mf[sample,:,:,i]

    # (title prefix, image, fixed colour range?) -- the error panels are auto-scaled.
    panels = (
        ('Ground Truth', truth, True),
        ('HFSM', hf_frame, True),
        ('HFSM abs. error', np.abs(truth - hf_frame), False),
        ('MFSM', mf_frame, True),
        ('MFSM abs. error', np.abs(truth - mf_frame), False),
    )
    for row, (label, image, fixed_range) in enumerate(panels):
        limits = dict(vmin=-1, vmax=1) if fixed_range else {}
        im = ax[row,index].imshow(image, extent=[0,1,0,1], interpolation='Gaussian',
                                  cmap='seismic', **limits)
        # Title the panel itself: plt.title() acts on the current pyplot axes,
        # not on the subplot that was just drawn, and the original text was a
        # fixed "Time 5s" for every column.
        ax[row,index].set_title(f'{label}, step {i}')
        plt.colorbar(im, ax=ax[row,index], orientation="horizontal", fraction=0.04, pad=0.2)
    index += 1
        
# figure1.savefig(f'predictions_allencahn_{ntrain_m}.png', format='png', dpi=300, bbox_inches='tight')  