In [1]:
"""
@author: Zongyi Li
This file is the Fourier Neural Operator for 2D problem such as the Navier-Stokes equation discussed in Section 5.3 in the [paper](https://arxiv.org/pdf/2010.08895.pdf),
which uses a recurrent structure to propagates in time.
"""


import torch
import numpy as np
import gc
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from utilities3 import *

import operator
from functools import reduce
from functools import partial

from timeit import default_timer
import scipy.io

gc.collect()
torch.cuda.empty_cache()
torch.manual_seed(0)
np.random.seed(0)

#Complex multiplication
def compl_mul2d(a, b):
    op = partial(torch.einsum, "bctq,dctq->bdtq")
    return torch.stack([
        op(a[..., 0], b[..., 0]) - op(a[..., 1], b[..., 1]),
        op(a[..., 1], b[..., 0]) + op(a[..., 0], b[..., 1])
    ], dim=-1)

################################################################
# fourier layer
################################################################

class SpectralConv2d_fast(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d_fast, self).__init__()

        """
        2D Fourier layer. It does FFT, linear transform, and Inverse FFT.    
        """

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2))

    def forward(self, x):
        batchsize = x.shape[0]
        #Compute Fourier coeffcients up to factor of e^(- something constant)
        x_ft = torch.rfft(x, 2, normalized=True, onesided=True)

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.in_channels, x.size(-2), x.size(-1)//2 + 1, 2, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        #Return to physical space
        x = torch.irfft(out_ft, 2, normalized=True, onesided=True, signal_sizes=(x.size(-2), x.size(-1)))
        return x

class SimpleBlock2d(nn.Module):
    def __init__(self, modes1, modes2, width):
        super(SimpleBlock2d, self).__init__()

        """
        The overall network. It contains 4 layers of the Fourier layer.
        1. Lift the input to the desire channel dimension by self.fc0 .
        2. 4 layers of the integral operators u' = (W + K)(u).
            W defined by self.w; K defined by self.conv .
        3. Project from the channel space to the output space by self.fc1 and self.fc2 .
        
        input: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y),  x, y)
        input shape: (batchsize, x=64, y=64, c=12)
        output: the solution of the next timestep
        output shape: (batchsize, x=64, y=64, c=1)
        """

        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        self.fc0 = nn.Linear(12, self.width)
        # input channel is 12: the solution of the previous 10 timesteps + 2 locations (u(t-10, x, y), ..., u(t-1, x, y),  x, y)

        self.conv0 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv1 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv2 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.conv3 = SpectralConv2d_fast(self.width, self.width, self.modes1, self.modes2)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)
        self.bn0 = torch.nn.BatchNorm2d(self.width)
        self.bn1 = torch.nn.BatchNorm2d(self.width)
        self.bn2 = torch.nn.BatchNorm2d(self.width)
        self.bn3 = torch.nn.BatchNorm2d(self.width)


        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        batchsize = x.shape[0]
        size_x, size_y = x.shape[1], x.shape[2]

        x = self.fc0(x)
        x = x.permute(0, 3, 1, 2)

        x1 = self.conv0(x)
        x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y)
        x = self.bn0(x1 + x2)
        x = F.relu(x)
        x1 = self.conv1(x)
        x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y)
        x = self.bn1(x1 + x2)
        x = F.relu(x)
        x1 = self.conv2(x)
        x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y)
        x = self.bn2(x1 + x2)
        x = F.relu(x)
        x1 = self.conv3(x)
        x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y)
        x = self.bn3(x1 + x2)


        x = x.permute(0, 2, 3, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

class Net2d(nn.Module):
    def __init__(self, modes, width):
        super(Net2d, self).__init__()

        """
        A wrapper function
        """

        self.conv1 = SimpleBlock2d(modes, modes, width)


    def forward(self, x):
        x = self.conv1(x)
        return x


    def count_params(self):
        c = 0
        for p in self.parameters():
            c += reduce(operator.mul, list(p.size()))

        return c


################################################################
# configs
################################################################
#TRAIN_PATH = 'data/NavierStokes_V1e-5_N1200_T20.mat'
#TEST_PATH = 'data/NavierStokes_V1e-5_N1200_T20.mat'
TRAIN_PATH = 'data/Vortex_dynamics_64_64_grid.mat'
TEST_PATH = 'data/Vortex_dynamics_64_64_grid.mat'

ntrain = 1000
ntest = 200

modes = 12
width = 20

batch_size = 1
batch_size2 = batch_size

epochs = 500
learning_rate = 0.0025
scheduler_step = 100
scheduler_gamma = 0.5

print(epochs, learning_rate, scheduler_step, scheduler_gamma)

path = 'ns_fourier_2d_rnn_V10000_T20_N'+str(ntrain)+'_ep' + str(epochs) + '_m' + str(modes) + '_w' + str(width)
path_model = 'model/'+path
path_train_err = 'results/'+path+'train.txt'
path_test_err = 'results/'+path+'test.txt'
path_image = 'image/'+path

runtime = np.zeros(2, )
t1 = default_timer()

sub = 1
S = 64
T_in = 10
T = 2
step = 1

################################################################
# load data
################################################################

reader = MatReader(TRAIN_PATH)
train_a = reader.read_field('u')[:ntrain,::sub,::sub,:T_in]
train_u = reader.read_field('u')[:ntrain,::sub,::sub,T_in:T+T_in]

modes = reader.read_field('Modetensorabridged')

reader = MatReader(TEST_PATH)
test_a = reader.read_field('u')[-ntest:,::sub,::sub,:T_in]
test_u = reader.read_field('u')[-ntest:,::sub,::sub,T_in:T+T_in]

print(train_u.shape)
print(test_u.shape)
assert (S == train_u.shape[-2])
assert (T == train_u.shape[-1])

train_a = train_a.reshape(ntrain,S,S,T_in)
test_a = test_a.reshape(ntest,S,S,T_in)

# pad the location (x,y)
gridx = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridx = gridx.reshape(1, S, 1, 1).repeat([1, 1, S, 1])
gridy = torch.tensor(np.linspace(0, 1, S), dtype=torch.float)
gridy = gridy.reshape(1, 1, S, 1).repeat([1, S, 1, 1])

train_a = torch.cat((gridx.repeat([ntrain,1,1,1]), gridy.repeat([ntrain,1,1,1]), train_a), dim=-1)
test_a = torch.cat((gridx.repeat([ntest,1,1,1]), gridy.repeat([ntest,1,1,1]), test_a), dim=-1)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=batch_size, shuffle=False)

t2 = default_timer()

print('preprocessing finished, time used:', t2-t1)
device = torch.device('cuda')

################################################################
# training and evaluation
################################################################

model = Net2d(modes, width).cuda()
# model = torch.load('model/ns_fourier_V100_N1000_ep100_m8_w20')

print(model.count_params())
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)


myloss = LpLoss(size_average=False)
gridx = gridx.to(device)
gridy = gridy.to(device)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2_step = 0
    train_l2_full = 0
    for xx, yy in train_loader:
        loss = 0
        xx = xx.to(device)
        yy = yy.to(device)

        for t in range(0, T, step):
            y = yy[..., t:t + step]
            im = model(xx)
            loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1))

            if t == 0:
                pred = im
            else:
                pred = torch.cat((pred, im), -1)

            xx = torch.cat((xx[..., step:-2], im,
                            gridx.repeat([batch_size, 1, 1, 1]), gridy.repeat([batch_size, 1, 1, 1])), dim=-1)

        train_l2_step += loss.item()
        l2_full = myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1))
        train_l2_full += l2_full.item()

        optimizer.zero_grad()
        loss.backward()
        # l2_full.backward()
        optimizer.step()

    test_l2_step = 0
    test_l2_full = 0
    with torch.no_grad():
        for xx, yy in test_loader:
            loss = 0
            xx = xx.to(device)
            yy = yy.to(device)

            for t in range(0, T, step):
                y = yy[..., t:t + step]
                im = model(xx)
                loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1))

                if t == 0:
                    pred = im
                else:
                    pred = torch.cat((pred, im), -1)

                xx = torch.cat((xx[..., step:-2], im,
                                gridx.repeat([batch_size, 1, 1, 1]), gridy.repeat([batch_size, 1, 1, 1])), dim=-1)


            test_l2_step += loss.item()
            test_l2_full += myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)).item()

    t2 = default_timer()
    scheduler.step()
    print(ep, t2 - t1, train_l2_step / ntrain / (T / step), train_l2_full / ntrain, test_l2_step / ntest / (T / step),
          test_l2_full / ntest)
    
torch.save(model, path_model)


pred = torch.zeros(test_u.shape)
index = 0
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False)
with torch.no_grad():
     for x, y in test_loader:
         test_l2 = 0;
         x, y = x.cuda(), y.cuda()

         out = model(x)
         out = y_normalizer.decode(out)
         pred[index] = out

         test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item()
         print(index, test_l2)
         index = index + 1

scipy.io.savemat('pred/'+path+'.mat', mdict={'pred': pred.cpu().numpy()})


500 0.0025 100 0.5
torch.Size([1000, 64, 64, 2])
torch.Size([200, 64, 64, 2])
preprocessing finished, time used: 17.848776034006733
926517


  x_ft = torch.rfft(x, 2, normalized=True, onesided=True)
  x = torch.irfft(out_ft, 2, normalized=True, onesided=True, signal_sizes=(x.size(-2), x.size(-1)))


0 40.30181629200524 0.15195404033362866 0.1534847814925015 0.18532634247094393 0.1879283130913973
1 37.768681269997614 0.09257938171736896 0.093930448429659 0.19621368784457446 0.20083853511139751
2 38.340476108001894 0.07909611418657005 0.08035548323951662 0.15545687023550273 0.15818860834464432
3 38.53105846600374 0.07394480232708156 0.07532105411402881 0.12190016409382225 0.12285329280421138
4 38.922719450012664 0.0679126758929342 0.06916237312741577 0.11893218696117401 0.12081456452608108
5 40.1269338009879 0.06773151951655745 0.0690442645996809 0.15337530113756656 0.15517105186358093
6 38.07031911899685 0.062467125833034515 0.06372603782825172 0.13082848897203803 0.13192732591181994
7 38.1772986970027 0.05980601353012025 0.06099431964661926 0.13600995594635606 0.13735650323331355
8 38.627788452999084 0.057235793277621266 0.05837741803005338 0.1002057046815753 0.10187120993621647
9 37.83480367201264 0.05776745788380504 0.05899507644213736 0.1046176533959806 0.1063590706139803
10 35

81 37.77166063900222 0.0331540653295815 0.03400484789349139 0.06983676979318261 0.07347140084952115
82 37.88113692600746 0.03320329341758042 0.034085401161573825 0.07031118498183787 0.07414906612597406
83 37.94149020900659 0.03283234870899469 0.03372477998584509 0.06871005768887699 0.07257300942204893
84 37.74588328199752 0.03298021431826055 0.033859852323774246 0.07390320308506489 0.07729255553334952
85 34.08302597900911 0.03450994278583676 0.03538823889940977 0.06957835726439952 0.07334734609350563
86 37.39043876499636 0.032474637245759365 0.03331810601521283 0.07105123563669621 0.07363605263642967
87 37.75613068300299 0.033231713806279005 0.034076931741088626 0.07521517585963011 0.0784935836866498
88 37.66727789500146 0.0325579370111227 0.03339468791987747 0.06448936807923018 0.06790241499431432
89 37.96183521700732 0.032545828469097614 0.03342990148440003 0.06653055410832166 0.0697534792125225
90 38.686560599002405 0.03186016974505037 0.03276857077330351 0.06438972802832722 0.06798

161 37.99483574599435 0.022297966711688787 0.02288541356381029 0.05541401349939406 0.058329702354967594
162 38.17882819000806 0.0218577189729549 0.022441416333429515 0.05735689737368375 0.060323863229714336
163 38.32262857499882 0.02225218509323895 0.02286179817095399 0.05770019092131406 0.06031518123112619
164 38.6477690069878 0.021882315090391784 0.022493863785173745 0.054411659128963945 0.0575623282417655
165 38.48711149599694 0.022017861200962215 0.022604506806470453 0.05632419831585139 0.05927390816621483
166 38.570795442996314 0.022126265715807676 0.022734963460825385 0.05625644093379378 0.05917652796022594
167 38.356090591987595 0.021807012697216122 0.02237436453672126 0.055644183033145965 0.05852049422450364
168 39.7668663039949 0.02214305458171293 0.022726885804906487 0.054521206384524706 0.05741831302642822
169 40.369370650005294 0.021608782428782433 0.02219988899677992 0.05599273135885596 0.05896750946063548
170 38.95971862199076 0.021841272704303263 0.02242560574784875 0.05

240 37.56333586599794 0.015203140083700419 0.015614557089284062 0.04954386357218027 0.051877580471336844
241 37.735169077001046 0.015475406141020357 0.015865154921542852 0.05052788988221437 0.052839747541584076
242 37.6194890600018 0.015607377959415317 0.016024288162589072 0.050724620022810994 0.05302715890109539
243 37.94680406799307 0.015605079185217618 0.01603695209417492 0.05024752852041274 0.0525517117837444
244 37.681819359000656 0.015512653038371354 0.01589118398563005 0.04938664698973298 0.051733399727381765
245 37.51299787800235 0.015722881925292313 0.01613391580292955 0.04968333681114018 0.05204035886563361
246 37.5734483779961 0.01572993179410696 0.016131482211174445 0.049940204499289396 0.05229064258746803
247 37.68979750700237 0.015612662099767477 0.016009932402987034 0.050574143631383774 0.05300748202949762
248 37.48537980800029 0.015976466255728156 0.016377731259446592 0.047869607810862365 0.05005187680013478
249 37.5619136759924 0.015576572511577978 0.015970747417537495

319 37.441661805991316 0.011575568135362119 0.011842807365581394 0.045790949366055426 0.047463819300755855
320 37.45280774199637 0.011681953271385282 0.011954784823581576 0.04588805203326046 0.04765141657087952
321 37.35219816099561 0.011811963323503732 0.01208187802741304 0.0472404720261693 0.04897529166657478
322 37.54948165800306 0.01175030776602216 0.01202919317339547 0.046343565792776646 0.047972823209129274
323 36.655121431002044 0.011758755278075114 0.012030958950519562 0.04529581900220364 0.04699986042454839
324 37.17967567899905 0.011641712277662008 0.011919186796527356 0.045950199477374554 0.04759293383918703
325 37.577082326999516 0.011706781913060694 0.011970596643164754 0.04652690504211932 0.048195860972628
326 37.50193164499069 0.011675486621214077 0.011960330336820334 0.045114258299581704 0.04681364716496319
327 37.405144017000566 0.011672837302321569 0.011938574668485672 0.04445018329191953 0.04611197818536311
328 37.48708146699937 0.011528158620465547 0.011797690200386

398 37.4997773340001 0.011491305708419532 0.011776332583278417 0.042243880801834166 0.04362450065556914
399 37.63056781400519 0.011413158765295521 0.011693040956277401 0.04234240132384002 0.04374241141602397
400 37.5410479830025 0.00942485629604198 0.009615845763357356 0.042596731879748405 0.04393608387093991
401 37.5860365289991 0.009358518529916182 0.009542685653083026 0.043068398493342104 0.04440975723322481
402 37.340352370010805 0.009326255577383563 0.00951489898446016 0.04280638671480119 0.0441350009245798
403 37.55376783700194 0.009384810558985919 0.009578738931100816 0.04198799685575068 0.043347430774010716
404 37.317179638994276 0.0092888173032552 0.009476377652958036 0.04328293187078089 0.04460384063888341
405 38.063680585997645 0.009280548786045983 0.00947114520217292 0.04337089715059847 0.044717081123963
406 38.135504789999686 0.009346276646712795 0.009534614796983079 0.04231380494311452 0.04364928556140512
407 37.93574082799023 0.009142969794105739 0.009323594520101324 0.0

477 37.79613288800465 0.008857273531612008 0.00905586438975297 0.04267853506375104 0.043928272402845324
478 37.62502300000051 0.008767531829653308 0.008954925556201488 0.04287892064545304 0.044108998496085405
479 37.26595963000727 0.008950260508339853 0.009149293887894601 0.043067218181677164 0.04431315880268812
480 37.654457276003086 0.008866188909392805 0.009063081626431085 0.04327726666815579 0.044503894853405654
481 37.50488031699206 0.008863316526869312 0.009050029962789268 0.04318904687184841 0.044406088166870174
482 37.664560727993376 0.008840830168221146 0.009033685553236864 0.04196514209266752 0.04317535526119173
483 37.62037392199272 0.008763918036827818 0.0089472558114212 0.04301288319285959 0.044211823889054355
484 37.429969189994154 0.008767901436425746 0.008956682387040928 0.04252258236519992 0.04376926614437252
485 37.80207029799931 0.008871357984142379 0.00906225029216148 0.04280390663072467 0.044043415170162914
486 37.996363782003755 0.008806250913534314 0.009000250886

FileNotFoundError: [Errno 2] No such file or directory: 'model/ns_fourier_2d_rnn_V10000_T20_N1000_ep500_m12_w20'