In [6]:
"""
@author: Zongyi Li
This file is the Fourier Neural Operator for 2D problem such as the Navier-Stokes equation discussed in Section 5.3 in the [paper](https://arxiv.org/pdf/2010.08895.pdf),
which uses a recurrent structure to propagates in time.
"""


import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

import operator
from functools import reduce
from functools import partial

from timeit import default_timer


torch.manual_seed(0)
np.random.seed(0)


################################################################
# fourier layer
################################################################

class SpectralConv2d_fast(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d_fast, self).__init__()

        """
        2D Fourier layer. It does FFT, linear transform, and Inverse FFT.    
        """

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul2d(self, input, weights):
        # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        #Compute Fourier coeffcients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft2(x)

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels,  x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        #Return to physical space
        x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
        return x

class FNO2d(nn.Module):
    def __init__(self, input_len, modes1, modes2, pred_len, width):
        super(FNO2d, self).__init__()

        """
        The overall network. It contains 4 layers of the Fourier layer.
        1. Lift the input to the desire channel dimension by self.fc0 .
        2. 4 layers of the integral operators u' = (W + K)(u).
            W defined by self.w; K defined by self.conv .
        3. Project from the channel space to the output space by self.fc1 and self.fc2 .
        
        input: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y),  x, y)
        input shape: (batchsize, x=64, y=64, c=12)
        output: the solution of the next timestep
        output shape: (batchsize, x=64, y=64, c=1)
        """

        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        self.pred_len = pred_len
        self.padding = 2 # pad the domain if input is non-periodic
        self.input_len = input_len
        inputs_dim = self.input_len + self.padding
        self.fc0 = nn.Linear(inputs_dim, self.width)
        # input channel is 12: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y),  x, y)

        self.conv0 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv1 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv2 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv3 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.w2 = nn.Conv2d(self.width, self.width, 1)
        self.w3 = nn.Conv2d(self.width, self.width, 1)
        self.bn0 = torch.nn.BatchNorm2d(self.width)
        self.bn1 = torch.nn.BatchNorm2d(self.width)
        self.bn2 = torch.nn.BatchNorm2d(self.width)
        self.bn3 = torch.nn.BatchNorm2d(self.width)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, pred_len)

    def forward(self, x):
        # b, t, c, h, w = x.shape
        # x = x.reshape(b, t*c, h, w) # b h w t
        # x = x.permute(0, 2, 3, 1)
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.fc0(x)
        x = x.permute(0, 3, 1, 2)
        # x = F.pad(x, [0,self.padding, 0,self.padding]) # pad the domain if input is non-periodic

        x1 = self.conv0(x)
        x2 = self.w0(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv1(x)
        x2 = self.w1(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv2(x)
        x2 = self.w2(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv3(x)
        x2 = self.w3(x)
        x = x1 + x2

        # x = x[..., :-self.padding, :-self.padding] # pad the domain if input is non-periodic
        x = x.permute(0, 2, 3, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x) # batch_size, h, w, t
        # x = x.permute(0, 3, 1, 2)
        # x = x.reshape(b, t, c, h, w)
        return x

    def get_grid(self, shape, device):
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)


In [7]:
model = FNO2d(input_len = 10, modes1 = 12, modes2 = 12, pred_len = 10, width = 20)
inputs = torch.rand(1, 64, 64,10)
output = model(inputs)
print(output.shape)

torch.Size([1, 64, 64, 10])


In [8]:
import scipy.io
import torch
import torch.utils.data

def load_navier_stokes_data(path, sub=1, T_in=10, T_out=10, batch_size=20, reshape=None):
    ntrain = 1000
    neval = 100
    ntest = 100
    total = ntrain + neval + ntest
    f = scipy.io.loadmat(path)
    data = f['u'][..., 0:total]
    data = torch.tensor(data, dtype=torch.float32)

    # Training data
    train_a = data[:ntrain, ::sub, ::sub, :T_in]
    train_u = data[:ntrain, ::sub, ::sub, T_in:T_out+T_in] # [N, H,W,T]
    # train_a = train_a.unsqueeze(-1).permute(0, 3, 1, 2, 4).permute(0, 1, 4, 2, 3)  # From [N, H, W, T] to [N, T, H, W, C]
    # train_u = train_u.unsqueeze(-1).permute(0, 3, 1, 2, 4).permute(0, 1, 4, 2, 3)
    #print(train_a.shape, train_u.shape)
    # Evaluation data
    eval_a = data[ntrain:ntrain + neval, ::sub, ::sub, :T_in]
    eval_u = data[ntrain:ntrain + neval, ::sub, ::sub, T_in:T_out+T_in]
    # eval_a = eval_a.unsqueeze(-1).permute(0, 3, 1, 2, 4).permute(0, 1, 4, 2, 3)  # From [N, H, W, T] to [N, T, H, W, C]
    # eval_u = eval_u.unsqueeze(-1).permute(0, 3, 1, 2, 4).permute(0, 1, 4, 2, 3)
    # Testing data
    test_a = data[ntrain + neval:ntrain + neval + ntest, ::sub, ::sub, :T_in]
    test_u = data[ntrain + neval:ntrain + neval + ntest, ::sub, ::sub, T_in:T_out+T_in]
    # test_a = test_a.unsqueeze(-1).permute(0, 3, 1, 2, 4).permute(0, 1, 4, 2, 3)  # From [N, H, W, T] to [N, T, H, W, C] to [N, T, C, H, W]
    # test_u = test_u.unsqueeze(-1).permute(0, 3, 1, 2, 4).permute(0, 1, 4, 2, 3)

    if reshape:
        train_a = train_a.permute(reshape)
        train_u = train_u.permute(reshape)
        eval_a = eval_a.permute(reshape)
        eval_u = eval_u.permute(reshape)
        test_a = test_a.permute(reshape)
        test_u = test_u.permute(reshape)
        
    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)
    eval_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(eval_a, eval_u), batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False)

    return train_loader, eval_loader, test_loader


In [9]:
train_loader, eval_loader, test_loader = load_navier_stokes_data(path="/root/autodl-tmp/dataset/NavierStokes_V1e-5_N1200_T20.mat")
for inputs, targets in iter(train_loader):
    print(inputs.shape, targets.shape)
    break

torch.Size([20, 64, 64, 10]) torch.Size([20, 64, 64, 10])


In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import math
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def FDM_NS_vorticity(w, v=1/40, t_interval=1.0):
    batchsize = w.size(0)  # w.shape = b h w t
    nx = w.size(1)
    ny = w.size(2)
    nt = w.size(3)
    device = w.device
    w = w.reshape(batchsize, nx, ny, nt)

    w_h = torch.fft.fft2(w, dim=[1, 2])
    k_max = nx // 2
    N = nx
    k_x = torch.cat((torch.arange(0, k_max, device=device), torch.arange(-k_max, 0, device=device)), 0).reshape(N, 1).repeat(1, N).reshape(1, N, N, 1)
    k_y = torch.cat((torch.arange(0, k_max, device=device), torch.arange(-k_max, 0, device=device)), 0).reshape(1, N).repeat(N, 1).reshape(1, N, N, 1)

    lap = (k_x ** 2 + k_y ** 2)
    lap[0, 0, 0, 0] = 1.0
    f_h = w_h / lap

    ux_h = 1j * k_y * f_h
    uy_h = -1j * k_x * f_h
    wx_h = 1j * k_x * w_h
    wy_h = 1j * k_y * w_h
    wlap_h = -lap * w_h

    ux = torch.fft.irfft2(ux_h[:, :, :k_max + 1], dim=[1, 2])
    uy = torch.fft.irfft2(uy_h[:, :, :k_max + 1], dim=[1, 2])
    wx = torch.fft.irfft2(wx_h[:, :, :k_max + 1], dim=[1, 2])
    wy = torch.fft.irfft2(wy_h[:, :, :k_max + 1], dim=[1, 2])
    wlap = torch.fft.irfft2(wlap_h[:, :, :k_max + 1], dim=[1, 2])

    dt = t_interval / (nt - 1)
    wt = (w[:, :, :, 2:] - w[:, :, :, :-2]) / (2 * dt)

    Du1 = wt + (ux * wx + uy * wy - v * wlap)[..., 1:-1]
    return Du1

# LpLoss 
class LpLoss:
    def __init__(self, size_average=True):
        self.size_average = size_average

    def __call__(self, pred, target):
        pred = pred.to(device)
        target = target.to(device)
        loss = torch.abs(pred - target)
        if self.size_average:
            return loss.mean()
        return loss.sum()

#  PINO_loss_2d 
def PINO_loss_2d(u, u0, forcing, v=1/40, t_interval=1.0):
    batchsize = u.size(0)
    nx = u.size(1)
    ny = u.size(2)
    nt = u.size(3)

    u = u.reshape(batchsize, nx, ny, nt)
    lploss = LpLoss(size_average=True)

    u_in = u[:, :, :, 0].to(device)
    u0 = u0.to(device)
    loss_ic = lploss(u_in, u0)

    Du = FDM_NS_vorticity(u, v, t_interval)
    f = forcing.repeat(batchsize, 1, 1, 1)
    loss_f = lploss(Du, f)

    return loss_ic, loss_f


s = 64
t = torch.linspace(0, 1, s+1, device=device)
t = t[0:-1]

X, Y = torch.meshgrid(t, t)
f = 0.1 * (torch.sin(2 * math.pi * (X + Y)) + torch.cos(2 * math.pi * (X + Y)))

nt = 10
forcing = f.unsqueeze(-1).repeat(1, 1, nt).to(device)


model = FNO2d(input_len=10, modes1=12, modes2=12, pred_len=10, width=20).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def train_model(model, train_loader, eval_loader, criterion, optimizer, num_epochs=100):
    best_loss = np.inf
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            u0 = inputs[:, :, :, 0]
            time_steps = inputs.shape[-1]
            u = torch.zeros(inputs.shape[0], inputs.shape[1], inputs.shape[2], time_steps + 2, device=device)
            u[:, :, :, :time_steps] = inputs
            loss_ic, loss_f = PINO_loss_2d(u, u0, forcing, v=1/40, t_interval=1.0)

            loss = criterion(outputs, targets) + loss_ic + loss_f
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * inputs.size(0)

        train_loss /= len(train_loader.dataset)
        model.eval()
        eval_loss = 0.0
        with torch.no_grad():
            for inputs, targets in eval_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                eval_loss += loss.item() * inputs.size(0)

        eval_loss /= len(eval_loader.dataset)
        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Eval Loss: {eval_loss:.4f}')

        if eval_loss < best_loss:
            best_loss = eval_loss
            torch.save(model.state_dict(), 'fnoPINO_best_model.pth')

In [11]:
train_model(model, train_loader, eval_loader, criterion, optimizer, num_epochs=100)

Epoch 1, Train Loss: 3.2179, Eval Loss: 0.8256
Epoch 2, Train Loss: 2.5167, Eval Loss: 0.5893
Epoch 3, Train Loss: 2.2442, Eval Loss: 0.3789
Epoch 4, Train Loss: 2.1553, Eval Loss: 0.3386
Epoch 5, Train Loss: 2.1195, Eval Loss: 0.3054
Epoch 6, Train Loss: 2.0909, Eval Loss: 0.2809
Epoch 7, Train Loss: 2.0696, Eval Loss: 0.2625
Epoch 8, Train Loss: 2.0528, Eval Loss: 0.2483
Epoch 9, Train Loss: 2.0398, Eval Loss: 0.2376
Epoch 10, Train Loss: 2.0301, Eval Loss: 0.2298
Epoch 11, Train Loss: 2.0222, Eval Loss: 0.2215
Epoch 12, Train Loss: 2.0149, Eval Loss: 0.2165
Epoch 13, Train Loss: 2.0090, Eval Loss: 0.2126
Epoch 14, Train Loss: 2.0039, Eval Loss: 0.2073
Epoch 15, Train Loss: 1.9988, Eval Loss: 0.2025
Epoch 16, Train Loss: 1.9940, Eval Loss: 0.1978
Epoch 17, Train Loss: 1.9899, Eval Loss: 0.1944
Epoch 18, Train Loss: 1.9864, Eval Loss: 0.1915
Epoch 19, Train Loss: 1.9831, Eval Loss: 0.1884
Epoch 20, Train Loss: 1.9792, Eval Loss: 0.1866
Epoch 21, Train Loss: 1.9764, Eval Loss: 0.1847
E

In [12]:
import torch
import torch.nn as nn
import numpy as np


def test_model(model, test_loader, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    model.load_state_dict(torch.load('fnoPINO_best_model.pth'))
    
    model.eval()
    test_loss = 0.0
    all_inputs = []
    all_targets = []
    all_preds = []
    
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            
            loss = criterion(outputs, targets)
            test_loss += loss.item() * inputs.size(0)
            
            # Move data to CPU for saving
            all_inputs.append(inputs.cpu().numpy())
            all_targets.append(targets.cpu().numpy())
            all_preds.append(outputs.cpu().numpy())
    
    test_loss /= len(test_loader.dataset)
    print(f'Test Loss: {test_loss:.4f}')
    
    all_inputs = np.concatenate(all_inputs, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)
    all_preds = np.concatenate(all_preds, axis=0)
    
    np.save('inputs.npy', all_inputs)
    np.save('targets.npy', all_targets)
    np.save('preds.npy', all_preds)
    
    return all_inputs, all_targets, all_preds

all_inputs, all_targets, all_preds = test_model(model, test_loader, criterion)

Test Loss: 0.1249


In [13]:
all_inputs, all_targets, all_preds = test_model(model, test_loader, criterion)

Test Loss: 0.1249


In [1]:
# import numpy as np
# import matplotlib.pyplot as plt
# print(all_preds.shape)
# sample_index = np.random.randint(0, 100)

# sample_preds = all_preds[sample_index]
# sample_inputs = all_inputs[sample_index]
# sample_targets = all_targets[sample_index]

# fig, axes = plt.subplots(3, 10, figsize=(20, 6))

# for i, ax in enumerate(axes[0]):
#     ax.imshow(sample_inputs[i, 0, :, :], cmap='jet')
#     ax.axis('off')
#     ax.set_title(f'Input {i+1}')


# for i, ax in enumerate(axes[1]):
#     ax.imshow(sample_targets[i, 0, :, :], cmap='jet')
#     ax.axis('off')
#     ax.set_title(f'Target {i+1}')

# for i, ax in enumerate(axes[2]):
#     ax.imshow(sample_preds[i, 0, :, :], cmap='jet')
#     ax.axis('off')
#     ax.set_title(f'Pred {i+1}')

# plt.tight_layout()
# plt.show()