In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import scipy
import scipy.io  # `import scipy` alone does not expose `scipy.io` on older SciPy releases

# --- reproducibility --------------------------------------------------------
# One seed drives both numpy's legacy global RNG and torch's default generator,
# so Restart & Run All reproduces the same weight initialisation and shuffling.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# --- compute device ---------------------------------------------------------
# Use CUDA when torch reports it is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
class ConvBlock(nn.Module):
    """Pointwise two-layer channel MLP: Conv1d(k=1) -> activation -> Conv1d(k=1).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        mid_channels: hidden channel count between the two convolutions.
        activation: module inserted between the two convolutions. ``None``
            (the default) builds a fresh ``nn.ReLU()`` for this block, so no
            module instance is shared between blocks through a default argument.

    Uses ``nn.Conv1d`` layout: ``(batch, in_channels, length)`` ->
    ``(batch, out_channels, length)``; ``kernel_size=1`` leaves ``length`` unchanged.
    """

    def __init__(self, in_channels, out_channels, mid_channels, activation=None):
        super().__init__()
        # A module default like `activation=nn.ReLU()` is evaluated once at class
        # definition and would be registered as a child of every block; create it per instance.
        if activation is None:
            activation = nn.ReLU()
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, mid_channels, kernel_size=1),
            activation,
            nn.Conv1d(mid_channels, out_channels, kernel_size=1),
        )

    def forward(self, x):
        return self.conv(x)

class FFTBlock(nn.Module):
    """Fourier layer: learned spectral multiplication on low modes plus a pointwise skip path.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
        modes: number of lowest rFFT coefficients that receive a learned complex
            weight; all higher coefficients are set to zero before the inverse FFT.
            Clamped at run time to the number of available rFFT bins.
        activation: optional module applied to the summed output; ``None``
            returns the sum unchanged.

    Layout: ``(batch, in_channels, n_points)`` -> ``(batch, out_channels, n_points)``.
    """

    def __init__(self, in_channels, out_channels, modes, activation=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes
        # NOTE(review): `.cfloat()` casts a real tensor, so imaginary parts start at 0.
        # The reference FNO code draws both parts at random (dtype=torch.cfloat) — confirm intended.
        weights = torch.rand(in_channels, out_channels, modes).cfloat()
        scale = 1 / in_channels / out_channels
        self.weights = nn.Parameter(weights * scale)
        self.conv = ConvBlock(out_channels, out_channels, out_channels, activation=nn.GELU())
        # The skip path consumes the block *input*, so it maps in_channels -> out_channels
        # (out_channels -> out_channels only worked when the two were equal).
        self.skip = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.activation = activation

    def forward(self, x):
        n_points = x.shape[-1]
        x_ft = torch.fft.rfft(x)                      # last dim: n_points // 2 + 1 bins
        n_bins = x_ft.shape[-1]
        # Short inputs have fewer bins than `modes`; only use what exists.
        n_modes = min(self.modes, n_bins)
        # Allocate directly on the input's device instead of CPU-then-copy.
        out_ft = torch.zeros(x.shape[0], self.out_channels, n_bins,
                             dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :n_modes] = torch.einsum(
            "bix,iox->box", x_ft[:, :, :n_modes], self.weights[:, :, :n_modes]
        )
        x1 = self.conv(torch.fft.irfft(out_ft, n=n_points))
        x2 = self.skip(x)
        out = x1 + x2
        return out if self.activation is None else self.activation(out)

class FNO1d(nn.Module):
    """1-D Fourier Neural Operator: lift (u, grid) -> four Fourier layers -> project.

    Args:
        modes: number of retained Fourier modes in every FFTBlock.
        width: hidden channel width.
        activation: module applied after the first three Fourier layers.
            ``None`` (the default) builds a fresh ``nn.ReLU()``. The same
            instance is shared by those three layers, which is fine for
            stateless activations.
        pad: if True, zero-pad the spatial axis by ``self.padding`` points
            before the Fourier layers and crop it afterwards (useful when the
            input is not periodic). Defaults to False, i.e. no padding.

    ``forward`` expects ``x`` of shape ``(batch, n_points, 1)`` and returns
    ``(batch, n_points, 1)``.
    """

    def __init__(self, modes, width, activation=None, pad=False):
        super().__init__()
        if activation is None:
            activation = nn.ReLU()
        self.padding = 8  # number of points appended on the right when `pad` is True
        self.pad = pad

        self.p_layer = nn.Linear(2, width)  # lifts (u(x), x) -> width channels
        self.fft_block0 = FFTBlock(width, width, modes, activation=activation)
        self.fft_block1 = FFTBlock(width, width, modes, activation=activation)
        self.fft_block2 = FFTBlock(width, width, modes, activation=activation)
        self.fft_block3 = FFTBlock(width, width, modes, activation=None)
        self.q_layer = ConvBlock(width, 1, width * 2, activation=nn.ReLU())

    def forward(self, x):
        batch_size, n_points = x.shape[0], x.shape[1]
        # Uniform coordinates on [0, 1], created on x's device/dtype (no CPU round-trip).
        grid = torch.linspace(0, 1, n_points, device=x.device, dtype=x.dtype)
        grid = grid.view(1, n_points, 1).expand(batch_size, n_points, 1)
        x = torch.cat((x, grid), dim=-1)        # (batch_size, n_points, 2)
        x = self.p_layer(x).permute(0, 2, 1)    # (batch_size, width, n_points)
        if self.pad:
            x = F.pad(x, [0, self.padding])     # (batch_size, width, n_points + padding)

        x = self.fft_block0(x)
        x = self.fft_block1(x)
        x = self.fft_block2(x)
        x = self.fft_block3(x)

        if self.pad:
            x = x[..., :-self.padding]          # drop the padded tail
        x = self.q_layer(x).permute(0, 2, 1)    # (batch_size, n_points, 1)
        return x

In [3]:
## Data
DATA_PATH = "data/burgers_data_R10.mat"
n_train = 1000
n_test = 100

sub_sampling = 2**3          # subsampling rate: keep every 8th grid point
batch_size = 64

data = scipy.io.loadmat(DATA_PATH)
x_data_np = data["a"]   # a(x): initial condition
y_data_np = data["u"]   # u(x): PDE solution

# Train comes from the front and test from the back; make sure they cannot overlap.
n_samples = x_data_np.shape[0]
assert n_train + n_test <= n_samples, (
    f"need {n_train + n_test} samples for disjoint train/test splits, file has {n_samples}"
)

# Subsample in numpy *before* converting, so no full-resolution float tensor is built
# and the split tensors below do not keep one alive as views.
x_data = torch.tensor(x_data_np[:, ::sub_sampling], dtype=torch.float32)
y_data = torch.tensor(y_data_np[:, ::sub_sampling], dtype=torch.float32)

x_train = x_data[:n_train].unsqueeze(-1)
y_train = y_data[:n_train].unsqueeze(-1)
x_test = x_data[-n_test:].unsqueeze(-1)
y_test = y_data[-n_test:].unsqueeze(-1)
print(x_train.shape, x_test.shape)

train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)

torch.Size([1000, 1024, 1]) torch.Size([100, 1024, 1])
