In [27]:
import matplotlib.pyplot as plt
import numpy as np
import os, sys; sys.path.append(os.path.join('../..'))
from timeit import default_timer

from models import FNO2d, FNO1d, FNF2d, FNF1d, FNN2d, FNN1d, DNN, FND1d, FND2d
from util import Adam
from util.utilities_module import LpLoss, LppLoss, count_params, validate, dataset_with_indices
from torch.utils.data import TensorDataset, DataLoader
import torch


device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device is", device)


cnx1 = 50
cnx2 = 120
cy = 50


# prefix = "../Data/"
prefix = "/media/nnelsen/SharedHDD2TB/datasets/Geo-FNO/airfoil/naca/"
XC = np.load(prefix+"NACA_Cylinder_X.npy")
YC = np.load(prefix+"NACA_Cylinder_Y.npy")

Pressure = np.load(prefix+"NACA_Cylinder_Q.npy")[:, 3, :, :]
theta = np.load(prefix+"NACA_theta.npy")                    

Device is cuda


In [2]:
# compute drag and lift
def compute_F_coeff(XC, YC, p, cnx1 = 50, cnx2 = 120, cny = 50):
    xx, yy, p = XC[cnx1:-cnx1,0], YC[cnx1:-cnx1,0], p[cnx1:-cnx1,0]
     
    drag  = np.dot(yy[0:cnx2]-yy[1:cnx2+1], (p[0:cnx2] + p[1:cnx2+1])/2.0)
    lift  = np.dot(xx[1:cnx2+1]-xx[0:cnx2], (p[0:cnx2] + p[1:cnx2+1])/2.0)
    F = np.array([drag, lift])
    
    # F_ref = 0.5 rho_oo * u_oo^2 A
    rho_oo = 1.0
    A = 1.0
    u_oo   = 0.8*np.sqrt(1.4*1.0/1.0)
    F_ref  = 0.5*rho_oo*u_oo**2 * A
    
    return F/F_ref

    
n_data = theta.shape[0]
F_coeff = np.zeros((n_data, 2))

for i in range(n_data):
    F_coeff[i, :] = compute_F_coeff(XC[i, :, :], YC[i, :, :], Pressure[i, :, :])
   

batch_size = 20
learning_rate = 0.001
epochs = 1001
step_size = 100
gamma = 0.5


n_train = 1000 #250
n_test = 400

In [3]:
print(XC.shape, YC.shape, F_coeff.shape, theta.shape)

(2490, 221, 51) (2490, 221, 51) (2490, 2) (2490, 8)


# FNM0D ($R^m \rightarrow R^n$, latent space 1D)

In [4]:
# FNF
modes1 = 16
width = 64
s_latentspace=256
batch_size = 128
n_layers = 4
nonlinear_first=False

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(theta[0:n_train, 1:], dtype=torch.float), torch.tensor(F_coeff[0:n_train, :], dtype=torch.float)), batch_size=batch_size,
                                           shuffle=True)                                 
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(theta[-n_test:, 1:], dtype=torch.float), torch.tensor(F_coeff[-n_test:, :], dtype=torch.float)), batch_size=batch_size,
                                          shuffle=False)


model = FNN1d(d_in=7, d_out=2, n_layers = n_layers, s_latentspace=s_latentspace, modes1=modes1, width=width,nonlinear_first=nonlinear_first).to(device)
print("FNM0D (1D latent space) #params : ", count_params(model))

FNM0D (1D latent space) #params :  599618


In [5]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

loss_data = np.zeros((3, epochs))

myloss = LpLoss(size_average=False)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        _batch_size = x.shape[0]
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(x)
        loss = myloss(out.view(_batch_size, -1), y.view(_batch_size, -1))
        loss.backward()

        optimizer.step()
        train_l2 += loss.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            _batch_size = x.shape[0]
            x, y = x.to(device), y.to(device)
            out =  model(x)
            test_l2 += myloss(out.view(_batch_size, -1), y.view(_batch_size, -1)).item()
            
    train_l2 /= n_train
    test_l2  /= n_test

    t2 = default_timer()
    print(ep, t2 - t1, train_l2, test_l2)

    loss_data[:, ep] = ep, train_l2, test_l2




0 0.6840702990011778 1.0764505157470703 0.9949059963226319
1 0.053657251002732664 0.9815802688598633 0.973803699016571
2 0.051592924995929934 0.9688782196044922 0.9667192792892456
3 0.05205939800362103 0.9673923034667968 0.9719081425666809
4 0.05386857800476719 0.9645298080444336 0.9633858156204224
5 0.06645544299681205 0.9681538162231446 0.9679714441299438
6 0.04918150599405635 0.9676561889648437 0.9630731534957886
7 0.04849678500613663 0.9648946075439453 0.9640014386177063
8 0.047956167996744625 0.9622657623291015 0.964152705669403
9 0.048616333995596506 0.9625171890258789 0.962591118812561
10 0.056559173986897804 0.9623003616333008 0.9619541239738464
11 0.04803472200001124 0.9625604705810547 0.9623198986053467
12 0.0478152629948454 0.9630263671875 0.9656809806823731
13 0.04906326199125033 0.9636554260253907 0.9628073787689209
14 0.05631389599875547 0.9612860412597656 0.9687408590316773
15 0.0473667199985357 0.9675287322998047 0.9643996405601502
16 0.04749281099066138 0.9612079849243

133 0.06039316899841651 0.16523183822631837 0.16860940873622896
134 0.058930055995006114 0.17040018463134765 0.17193758606910706
135 0.05263549699157011 0.1700702533721924 0.16967666536569595
136 0.04735414299648255 0.16424561691284179 0.17082450211048125
137 0.04740940799820237 0.1596133394241333 0.1901146340370178
138 0.04858331099967472 0.16597323417663573 0.1720765572786331
139 0.055847626994363964 0.16289443016052246 0.17917216062545777
140 0.04748599299637135 0.16246581554412842 0.16750799238681793
141 0.047675533991423436 0.16444933891296387 0.17468180060386657
142 0.054245148014160804 0.16329688453674315 0.16791570276021958
143 0.07087199100351427 0.1693087100982666 0.17759468406438828
144 0.06163628499780316 0.17248099327087402 0.16956380128860474
145 0.06451674099662341 0.16194283294677733 0.17160483121871947
146 0.06916685699252412 0.16087868118286133 0.16925136864185333
147 0.058242468992830254 0.16638480758666993 0.17461106836795806
148 0.047461092995945364 0.1609729843139

264 0.048138304002350196 0.15052704048156737 0.16099167943000794
265 0.04944365999836009 0.1485174570083618 0.16453336149454117
266 0.056404116010526195 0.14790299797058107 0.16338395595550537
267 0.047531636999337934 0.14960706520080566 0.17076061129570008
268 0.04854916699696332 0.1492221574783325 0.16607104897499084
269 0.050145456989412196 0.15027812576293945 0.16575902193784714
270 0.05746029999863822 0.14857364463806152 0.17703128665685652
271 0.04826081098872237 0.1506438789367676 0.1635724914073944
272 0.048322189002647065 0.15032095909118653 0.1602168071269989
273 0.050038764005876146 0.15093146705627442 0.17672167658805849
274 0.05806419799046125 0.15174656391143798 0.16210790812969209
275 0.049499562999699265 0.15218856143951415 0.16153755843639372
276 0.04966546400100924 0.1543472671508789 0.16514950275421142
277 0.050734313001157716 0.14977682018280028 0.1764268869161606
278 0.05730226398736704 0.15268416595458983 0.16942212700843812
279 0.047915127011947334 0.150356866836

393 0.049336942989612 0.15118790245056152 0.17476493060588838
394 0.05627580800501164 0.14872781562805176 0.16073133319616317
395 0.04749682899273466 0.14473535537719726 0.16288233041763306
396 0.04749559400079306 0.14514737510681153 0.17263093888759612
397 0.04742445101146586 0.14706201171875 0.15889493107795716
398 0.049046277010347694 0.14603295421600343 0.17502104133367538
399 0.05673214799026027 0.1464431858062744 0.15869269281625747
400 0.04746186299598776 0.14532829189300536 0.17011932224035264
401 0.04747827899700496 0.1453404655456543 0.15959871262311937
402 0.04739370199968107 0.145050931930542 0.16206579506397248
403 0.04897904999961611 0.1442418804168701 0.1638196086883545
404 0.05663845800154377 0.14513769435882568 0.16090729415416719
405 0.04746359400451183 0.14450196647644042 0.1621964943408966
406 0.04739777299982961 0.14435249519348145 0.16230326890945435
407 0.047355794988106936 0.14421555519104004 0.16436954379081725
408 0.04976211900066119 0.14473878002166748 0.1615

523 0.049074253998696804 0.14362325382232666 0.16208986818790436
524 0.0568499939981848 0.14366118240356446 0.16299823611974718
525 0.047531734002404846 0.14375927352905274 0.16376825273036957
526 0.04736206500092521 0.1434236650466919 0.16095506489276887
527 0.047458023007493466 0.14375810718536378 0.16120642185211181
528 0.04905411199433729 0.14370274162292482 0.16347744345664977
529 0.05655619199387729 0.14379435539245605 0.16082270741462706
530 0.04750322600011714 0.14359601402282715 0.16278299450874328
531 0.04742676300520543 0.14364091873168947 0.16391307473182679
532 0.047583610998117365 0.14354507446289064 0.1613112586736679
533 0.048971978001645766 0.14353080749511718 0.16315038055181502
534 0.056129858989152126 0.14376767253875733 0.16265789538621903
535 0.047454597006435506 0.14366584587097167 0.1610081547498703
536 0.047502030996838585 0.1439749479293823 0.1632133439183235
537 0.04747733699332457 0.1438436231613159 0.16177212774753572
538 0.04899506900983397 0.1434475421905

653 0.04915022799104918 0.14319426441192626 0.16282583236694337
654 0.05603240799973719 0.14326760673522948 0.16184038609266282
655 0.04767193499719724 0.14344047927856446 0.1626716697216034
656 0.04760746500687674 0.1430745334625244 0.16215938568115235
657 0.047521482993033715 0.14361643981933594 0.16082921385765075
658 0.04915560998779256 0.14310682487487794 0.16327590972185135
659 0.05672080800286494 0.14323748588562013 0.16246600568294525
660 0.04788667199318297 0.14327327156066894 0.1622420972585678
661 0.047883822000585496 0.14348433685302733 0.16141082376241683
662 0.048053081991383806 0.143961838722229 0.1647440469264984
663 0.049418013004469685 0.1434656867980957 0.160786714553833
664 0.05640895600663498 0.14349629402160644 0.1625462853908539
665 0.04805779499292839 0.1432693853378296 0.1621309131383896
666 0.047778633001144044 0.14315773010253907 0.16150676727294921
667 0.04939332000503782 0.14321214962005616 0.16228203654289244
668 0.05615020099503454 0.1431381607055664 0.16

782 0.04771214800712187 0.1430013132095337 0.1617571869492531
783 0.04871682499651797 0.143000018119812 0.16245993316173554
784 0.05635547400743235 0.14304798030853272 0.16208527892827987
785 0.047632710993639193 0.1430023193359375 0.16249681353569032
786 0.047792185010621324 0.14297355461120606 0.16203132510185242
787 0.04985572899749968 0.14302243328094483 0.1616622006893158
788 0.05629979800141882 0.14295797729492188 0.16228356540203095
789 0.047599443001672626 0.14310088729858397 0.1629789626598358
790 0.04753762899781577 0.1429315595626831 0.16190445333719253
791 0.04988521400082391 0.14299476051330567 0.1618807351589203
792 0.05642105299921241 0.14303167152404786 0.16207520842552184
793 0.04738966000149958 0.1429417953491211 0.16181144654750823
794 0.047576281998772174 0.14298357677459717 0.16233291149139403
795 0.049845809990074486 0.14298505687713622 0.16205225884914398
796 0.056372028004261665 0.14296099853515626 0.16190451979637147
797 0.04755982699862216 0.14301963138580323 

912 0.049091608001617715 0.14287854194641114 0.1622222089767456
913 0.05652804599958472 0.1428822660446167 0.16225486844778061
914 0.04760644100315403 0.1428816909790039 0.1622233971953392
915 0.04747181800485123 0.14287066841125487 0.1622602605819702
916 0.047573535004630685 0.14286908054351807 0.16216352909803392
917 0.049174111991305836 0.14288041400909424 0.16221894562244416
918 0.057075072996667586 0.14288202476501466 0.16220263183116912
919 0.04759431899583433 0.14288712310791016 0.16225357711315155
920 0.04742865500156768 0.14287746334075926 0.16214276790618898
921 0.04756732200621627 0.14288469123840333 0.1620493757724762
922 0.049210376004339196 0.1428818244934082 0.16203190594911576
923 0.0566019870020682 0.14288253116607666 0.16222165554761886
924 0.04751088400371373 0.1429118766784668 0.16244007557630538
925 0.04753213399089873 0.14288651466369628 0.16216354310512543
926 0.04759275499964133 0.14289655017852784 0.1621488732099533
927 0.04900471899600234 0.142918288230896 0.1

# FNM0D ($R^m \rightarrow R^n$, latent space 2D)

In [91]:
batch_size = 128

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(theta[0:n_train, 1:], dtype=torch.float), torch.tensor(F_coeff[0:n_train, :], dtype=torch.float)), batch_size=batch_size,
                                           shuffle=True)                                 
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(theta[-n_test:, 1:], dtype=torch.float), torch.tensor(F_coeff[-n_test:, :], dtype=torch.float)), batch_size=batch_size,
                                          shuffle=False)

# FNF
modes1 = 12
modes2 = 12
width = 32

model = FNN2d(d_in=7, d_out=2, s_latentspace=(32,32), modes1=modes1, modes2=modes2, width=width, nonlinear_first=nonlinear_first).to(device)

print("FNM0D (2D latent space) #params : ", count_params(model))

FNM0D (2D latent space) #params :  2485474


In [92]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

loss_data = np.zeros((3, epochs))

myloss = LpLoss(size_average=False)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        _batch_size = x.shape[0]

        optimizer.zero_grad()
        out = model(x)

        loss = myloss(out.view(_batch_size, -1), y.view(_batch_size, -1))
        loss.backward()

        optimizer.step()
        train_l2 += loss.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            _batch_size = x.shape[0]
            
            out =  model(x)
            test_l2 += myloss(out.view(_batch_size, -1), y.view(_batch_size, -1)).item()
            
    train_l2 /= n_train
    test_l2  /= n_test

    t2 = default_timer()
    print(ep, t2 - t1, train_l2, test_l2)

    loss_data[:, ep] = ep, train_l2, test_l2
   

0 0.16436054999940097 0.9772943267822266 0.9833748602867126
1 0.13380335900001228 0.9679924011230469 0.9664535665512085
2 0.1246541339205578 0.9643018264770508 0.9633041620254517
3 0.13258995302021503 0.9624719619750977 0.968457818031311
4 0.12424277199897915 0.9628967514038086 0.9679471206665039
5 0.1323688980191946 0.9618406982421875 0.9658880233764648
6 0.12453161994926631 0.9635021438598633 0.965856773853302
7 0.13316319102887064 0.9618898315429687 0.9698789405822754
8 0.1244408639613539 0.9644494094848632 0.9663843536376953
9 0.13342329789884388 0.9624898071289063 0.9638825869560241
10 0.12480263004545122 0.9632035751342773 0.9670960521697998
11 0.13338705303613096 0.9622467346191407 0.962691068649292
12 0.12523801892530173 0.9621221694946289 0.9651953268051148
13 0.13344191201031208 0.9642826538085938 0.9677716183662415
14 0.12497337104286999 0.9649410400390624 0.9605616044998169
15 0.1332403909182176 0.9657206497192383 0.9690972781181335
16 0.12464364501647651 0.9620711364746094

136 0.12334779603406787 0.9613872528076172 0.9618182802200317
137 0.12473609496373683 0.9619893264770508 0.9633649349212646
138 0.12427901697810739 0.961352409362793 0.9626261115074157
139 0.12408055202104151 0.9618689956665039 0.9645348358154296
140 0.12431442399974912 0.9620199661254882 0.9623000860214234
141 0.12431963090784848 0.9614817047119141 0.9621399402618408
142 0.13055622996762395 0.9619256057739258 0.96587730884552
143 0.138720108079724 0.9617266311645508 0.9621605825424194
144 0.13862892996985465 0.9614283218383789 0.9640002727508545
145 0.12924051494337618 0.9625168762207031 0.9640305542945862
146 0.1494942760327831 0.9609120178222657 0.9619552254676819
147 0.1495072280522436 0.9617934875488281 0.9624018573760986
148 0.12553391105029732 0.9625567398071289 0.9666607046127319
149 0.1240007389569655 0.9620232009887695 0.9619471955299378
150 0.12537213403265923 0.9613862533569336 0.9631934213638306
151 0.13621617597527802 0.961818862915039 0.9638522696495057
152 0.14540593500

KeyboardInterrupt: 

# Fully connected neural network: $\theta \rightarrow F_{coeff}$

In [93]:
batch_size = 256

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(theta[0:n_train, 1:], dtype=torch.float), torch.tensor(F_coeff[0:n_train, :], dtype=torch.float)), batch_size=batch_size,
                                           shuffle=True)                                 
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(theta[-n_test:, 1:], dtype=torch.float), torch.tensor(F_coeff[-n_test:, :], dtype=torch.float)), batch_size=batch_size,
                                          shuffle=False)

szs = [7,32,32,32,2]
# szs = [7,64,64,64,2]
model = DNN(szs).to(device)
print("FNN #params : ", count_params(model))

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

loss_data = np.zeros((3, epochs))

myloss = LpLoss(size_average=False)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        _batch_size = x.shape[0]
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(x)

        loss = myloss(out.view(_batch_size, -1), y.view(_batch_size, -1))
        loss.backward()

        optimizer.step()
        train_l2 += loss.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            _batch_size = x.shape[0]
            x, y = x.to(device), y.to(device)
            out =  model(x)
            test_l2 += myloss(out.view(_batch_size, -1), y.view(_batch_size, -1)).item()
            
    train_l2 /= n_train
    test_l2  /= n_test

    t2 = default_timer()
    print(ep, t2 - t1, train_l2, test_l2)

    loss_data[:, ep] = ep, train_l2, test_l2
   

FNN #params :  2434
0 0.01253644993994385 1.0787103881835938 1.077671890258789
1 0.012475225026719272 1.0103127593994141 1.0000753021240234
2 0.009758187923580408 0.969468505859375 0.9609344863891601
3 0.012952739954926074 0.9686070098876953 0.9662876510620118
4 0.009375931927934289 0.9756031036376953 0.9653215408325195
5 0.009172615013085306 0.9692238006591797 0.9609554672241211
6 0.017273338045924902 0.9620688323974609 0.9663556289672851
7 0.010036719031631947 0.9626688537597656 0.9724649047851562
8 0.009279936086386442 0.9634005737304687 0.9708606719970703
9 0.009892607922665775 0.9619302520751953 0.965996322631836
10 0.009093534084968269 0.9599649505615234 0.9613489532470703
11 0.009076304966583848 0.9604500122070313 0.9599333953857422
12 0.00909824890550226 0.9615067596435547 0.9603832244873047
13 0.00906923902221024 0.9597523345947265 0.9638791656494141
14 0.009069186053238809 0.9606455993652344 0.9639596176147461
15 0.009056785958819091 0.958857162475586 0.9625605392456055
16 0.

145 0.00924194697290659 0.16471956634521484 0.17733579158782958
146 0.012402379070408642 0.1641395797729492 0.179791202545166
147 0.00933538400568068 0.16538015747070312 0.17595766067504884
148 0.010493081994354725 0.16643105697631835 0.18156668186187744
149 0.00933422101661563 0.16983978652954101 0.17447239875793458
150 0.009120742091909051 0.17245608520507813 0.1792990255355835
151 0.015624846098944545 0.16972193145751954 0.17770919799804688
152 0.0097806939156726 0.16530996322631836 0.17443380355834961
153 0.009083189070224762 0.1640480728149414 0.18126265525817872
154 0.009057975024916232 0.16446641159057618 0.17404633045196533
155 0.009061726974323392 0.16577687072753905 0.18496581077575683
156 0.009091040003113449 0.16450586700439454 0.17411194801330565
157 0.009057980962097645 0.1661543426513672 0.18599197387695313
158 0.009073882945813239 0.16612582397460937 0.1742246913909912
159 0.00909409299492836 0.16641123580932618 0.1858174228668213
160 0.009091682964935899 0.166953590393

279 0.01394767896272242 0.16319393157958983 0.17634706020355226
280 0.010038654902018607 0.16407617950439454 0.17190966606140137
281 0.020040994975715876 0.16182236480712892 0.18002121925354003
282 0.00978925998788327 0.1628011703491211 0.17356390476226807
283 0.014213092043064535 0.16218254852294922 0.17420601844787598
284 0.015485561918467283 0.16301799392700195 0.178355655670166
285 0.014856732916086912 0.16225139617919923 0.17282384872436524
286 0.010792025947012007 0.16196783065795897 0.17798418045043946
287 0.009266547043807805 0.16188995361328126 0.17343650817871092
288 0.015368991065770388 0.16155712509155273 0.17554003715515137
289 0.009247799054719508 0.1614332046508789 0.17473522186279297
290 0.00916742894332856 0.16153717803955078 0.1749977207183838
291 0.01546251296531409 0.16147857666015625 0.17574442386627198
292 0.009083051001653075 0.16146368789672852 0.17432160854339598
293 0.009195256978273392 0.16153249740600586 0.17415326118469238
294 0.009084405959583819 0.1621113

419 0.009774680947884917 0.1610128059387207 0.1740172290802002
420 0.016051026061177254 0.16086574554443359 0.17389330863952637
421 0.014160018065012991 0.16086578369140625 0.17377769947052002
422 0.010990515002049506 0.16092778396606444 0.1742125129699707
423 0.016371665988117456 0.16098508071899414 0.1745924949645996
424 0.009387742960825562 0.16084550857543944 0.17391396522521974
425 0.00925156706944108 0.16098796844482421 0.17335860729217528
426 0.01633867807686329 0.16115771102905274 0.17474568367004395
427 0.00922335498034954 0.16084859848022462 0.1740433692932129
428 0.009215942933224142 0.16089306640625 0.17384867191314698
429 0.009217254002578557 0.16085784912109374 0.17392078399658203
430 0.009206553106196225 0.1608271484375 0.17416593551635742
431 0.00922196893952787 0.16094973373413085 0.1748926258087158
432 0.009180239983834326 0.16087599563598634 0.17414185523986817
433 0.00921262800693512 0.16083033752441406 0.1734307098388672
434 0.009212080971337855 0.16100265502929687

564 0.009432032937183976 0.16069067764282227 0.17402731418609618
565 0.013188331038691103 0.16078277587890624 0.174135422706604
566 0.009679391980171204 0.16067676162719727 0.17338720321655274
567 0.010168926091864705 0.1607890739440918 0.17343040943145752
568 0.009332309011369944 0.16073554229736328 0.17369571685791016
569 0.009270838927477598 0.1606505126953125 0.1741177558898926
570 0.015385368023999035 0.16077019119262695 0.17459790229797365
571 0.009200887056067586 0.16073082733154298 0.1742626667022705
572 0.009555739001370966 0.16074733352661133 0.17347821712493897
573 0.009610629989765584 0.16069789123535155 0.17346715927124023
574 0.00929926207754761 0.16067823791503907 0.1736888074874878
575 0.009316792944446206 0.16068954467773439 0.17408124923706056
576 0.009319365955889225 0.16068283462524413 0.17394968032836913
577 0.009270094917155802 0.1607094955444336 0.17393712043762208
578 0.009304231032729149 0.16063665390014648 0.17355928421020508
579 0.009281218983232975 0.1607364

704 0.009354202076792717 0.1606120948791504 0.17376691341400147
705 0.01117875799536705 0.16060764694213867 0.1737986660003662
706 0.011739476001821458 0.16060319137573242 0.17373281002044677
707 0.010463230079039931 0.16062281036376952 0.17383675575256347
708 0.009639707044698298 0.1606317138671875 0.1736881446838379
709 0.009299520985223353 0.16059771728515626 0.17374961853027343
710 0.017434786073863506 0.16061528778076173 0.17368316650390625
711 0.009885410079732537 0.16067844009399415 0.173925724029541
712 0.013461881899274886 0.16061004638671875 0.17380515098571778
713 0.014052802929654717 0.16060778427124023 0.17368893146514894
714 0.01013907603919506 0.1606273422241211 0.17384231567382813
715 0.01990808197297156 0.16060988616943359 0.17379636764526368
716 0.012749732006341219 0.16060702896118165 0.17378803253173827
717 0.01114800397772342 0.16061565780639647 0.17367269992828369
718 0.019861910957843065 0.1606024398803711 0.17364176750183105
719 0.012114960933104157 0.1606597518

836 0.019369599991478026 0.16057901000976563 0.17374344825744628
837 0.016557747032493353 0.16060700607299805 0.17363025665283202
838 0.012460650992579758 0.1605825653076172 0.1736668062210083
839 0.022333361906930804 0.1605815124511719 0.17371126174926757
840 0.010438028955832124 0.16061252975463866 0.17379909992218018
841 0.01026643009390682 0.1605809211730957 0.17373992443084718
842 0.015597349032759666 0.16057965850830078 0.17369853019714354
843 0.009457463049329817 0.16058431243896484 0.17370279312133788
844 0.009346632985398173 0.16059598541259765 0.17363048553466798
845 0.016980889020487666 0.16060153198242189 0.17357380867004393
846 0.009314983966760337 0.16058236312866211 0.1736318826675415
847 0.009279469028115273 0.16057863235473632 0.173719539642334
848 0.009272458963096142 0.160607364654541 0.17383822441101074
849 0.009278464945964515 0.16058688354492187 0.17380928993225098
850 0.00931924197357148 0.16058356475830077 0.17376988410949706
851 0.009314659982919693 0.160581249

965 0.009959659073501825 0.16057721710205078 0.17373371124267578
966 0.010985867003910244 0.16056869506835938 0.17372252464294433
967 0.011075105983763933 0.16057276153564454 0.17367896556854248
968 0.010708182002417743 0.1605714912414551 0.173691987991333
969 0.009365261998027563 0.16056783294677734 0.1736874485015869
970 0.009305092971771955 0.16056784057617188 0.17368433475494385
971 0.01689576799981296 0.1605687370300293 0.17369858741760255
972 0.010608393931761384 0.16056990432739257 0.17370466709136964
973 0.01284518395550549 0.16057658004760741 0.17372886657714845
974 0.01514900999609381 0.16057431411743164 0.17367419242858886
975 0.014269881998188794 0.16057403182983399 0.17364141464233399
976 0.016546929953619838 0.16057176208496093 0.17365092277526856
977 0.014233967987820506 0.1605697937011719 0.17366745948791504
978 0.01691652601584792 0.16056884765625 0.1736907958984375
979 0.01350875897333026 0.16056647872924804 0.1737233257293701
980 0.014572415966540575 0.16056919479370

# FNM (1D $\rightarrow R^n$)

In [98]:
r1 = 1
r2 = 1

cnx1 = 50
cnx2 = 120
cny = 50

batch_size = 64

input_data  = torch.stack([torch.tensor(XC[:,cnx1:-cnx1,0], dtype=torch.float), torch.tensor(YC[:,cnx1:-cnx1,0], dtype=torch.float)], dim=-1)
output_data = torch.tensor(F_coeff, dtype=torch.float)

input_data = input_data.permute(0,2,1)

r1 = r2 = 1
x_train = input_data[:n_train,  :, :] 
y_train = output_data[:n_train, :] 
x_test  = input_data[-n_test:,  :, :] 
y_test  = output_data[-n_test:, :]



train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size,
                                          shuffle=False)
print(x_train.shape,y_train.shape)

# FNF
modes1 = 12
width = 32
model = FNF1d(modes1=modes1, width=width, d_in=2, d_out=2).to(device)
print("FNM1D #params : ", count_params(model))

torch.Size([1000, 2, 121]) torch.Size([1000, 2])
FNM1D #params :  129154


In [99]:

optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

myloss = LpLoss(size_average=False)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        _batch_size = x.shape[0]
        optimizer.zero_grad()
        out = model(x)
        loss = myloss(out.view(_batch_size, -1), y.view(_batch_size, -1))
        loss.backward()

        optimizer.step()
        train_l2 += loss.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            _batch_size = x.shape[0]
            out = model(x)
            test_l2 += myloss(out.view(_batch_size, -1), y.view(_batch_size, -1)).item()
            
            
    train_l2 /= n_train
    test_l2 /= n_test
   
    t2 = default_timer()
    print(ep, t2 - t1, train_l2, test_l2)




0 0.16734354302752763 0.987732307434082 0.9646367835998535
1 0.14060463092755526 0.9715389823913574 0.9738849782943726
2 0.14122604194562882 0.9707960662841797 0.9686512184143067
3 0.13637726206798106 0.9627419395446777 0.9673435616493226
4 0.16509815398603678 0.9648503799438477 0.9762897777557373
5 0.1326198980677873 0.9678789978027343 0.9656484317779541
6 0.106199000030756 0.9730558700561523 0.9615264511108399
7 0.1060594889568165 0.9639936332702637 0.9668790173530578
8 0.1009813459822908 0.9624527969360351 0.966039776802063
9 0.11112901801243424 0.9629422721862793 0.9689901328086853
10 0.09577560902107507 0.9640664329528809 0.9700339651107788
11 0.10957609699107707 0.9658569755554199 0.9611167669296264
12 0.09545094799250364 0.9630511856079101 1.002035880088806
13 0.10980944300536066 0.9719897232055664 0.9607142066955566
14 0.09510254603810608 0.9652768783569335 0.9775474214553833
15 0.10777762706857175 0.9673100128173828 0.9712821340560913
16 0.09514045994728804 0.968341911315918 0

134 0.09570322697982192 0.13599125385284425 0.1373429489135742
135 0.10525812802370638 0.12246507549285889 0.11310829550027847
136 0.09536411403678358 0.1238144669532776 0.12164003610610961
137 0.10875741601921618 0.11915045166015625 0.11456546813249588
138 0.10010670393239707 0.11564577293395996 0.11611554086208344
139 0.11336805706378073 0.11349274110794068 0.12303289741277695
140 0.10703448601998389 0.13971601152420043 0.14528441727161406
141 0.11817433102987707 0.12143093585968018 0.13016355335712432
142 0.12329027999658138 0.10850148391723632 0.11392236411571503
143 0.1137188239954412 0.11044713020324708 0.13998566150665284
144 0.11645715998020023 0.14562784075737 0.14604117214679718
145 0.11210548796225339 0.1288865170478821 0.1168709021806717
146 0.0945176100358367 0.11347709751129151 0.14483636736869812
147 0.10829501692205667 0.12094431257247924 0.15171657085418702
148 0.0943588960217312 0.12348167109489441 0.11373471081256867
149 0.10775974998250604 0.10493070125579834 0.1119

265 0.10199226601980627 0.09513275957107543 0.09982583224773407
266 0.10668522899504751 0.09476108813285827 0.10467712819576264
267 0.11158793093636632 0.09135641551017762 0.10155368626117706
268 0.09608510602265596 0.09861391234397889 0.12014741271734238
269 0.09330624900758266 0.09829413104057312 0.1050865113735199
270 0.09616769792046398 0.1109815411567688 0.18958449602127075
271 0.0990276129450649 0.11818044257164001 0.14427328288555144
272 0.10187010304071009 0.10870126914978027 0.11631370067596436
273 0.10690271505154669 0.09945558714866638 0.10474965453147889
274 0.10845389403402805 0.09255493927001954 0.09905475124716759
275 0.11272378102876246 0.09152830505371094 0.10366320133209228
276 0.11824526800774038 0.10598154258728028 0.097563988417387
277 0.11264977604150772 0.10090970087051392 0.09916060924530029
278 0.11591155896894634 0.09702985787391663 0.12096094042062759
279 0.11425818898715079 0.09512250566482544 0.09884703040122986
280 0.1215003440156579 0.10187948179244995 0.

394 0.10945204203017056 0.0735175495147705 0.09178543180227279
395 0.0944339670240879 0.07452891302108765 0.07908704370260239
396 0.09408899303525686 0.07275181221961975 0.08677739500999451
397 0.10790799208916724 0.0746357901096344 0.07922840774059296
398 0.09440499299671501 0.0700799298286438 0.07935787379741668
399 0.09434188203886151 0.0714209496974945 0.08089149922132492
400 0.10610145307146013 0.07332325172424316 0.08194489598274231
401 0.09444507502485067 0.0738268165588379 0.07820388376712799
402 0.09542149305343628 0.0696736512184143 0.07821240276098251
403 0.10807087505236268 0.06984223127365112 0.08246791929006576
404 0.09448439592961222 0.07143877506256104 0.079213205575943
405 0.09944159898441285 0.06940615200996399 0.07864136576652526
406 0.1329371629981324 0.06983187651634216 0.07881556272506714
407 0.13312670798040926 0.07036797618865967 0.07973578065633774
408 0.10551766201388091 0.07239233899116516 0.07819590747356414
409 0.10346809006296098 0.06993594002723694 0.0776

524 0.1369394469074905 0.06656652522087098 0.07522201776504517
525 0.12016416597180068 0.06598582792282104 0.07551455408334733
526 0.09855573205277324 0.06623088192939758 0.07736121356487274
527 0.12765586900059134 0.06717182183265687 0.07628883212804795
528 0.1374735439894721 0.06658027458190918 0.07873175919055939
529 0.10242389107588679 0.06716368889808655 0.07935721695423126
530 0.0975355509435758 0.06726167821884155 0.07529845237731933
531 0.1344055139925331 0.06717169976234436 0.0748468029499054
532 0.13812714908272028 0.0659248685836792 0.07445603907108307
533 0.10663973307237029 0.06632762002944946 0.07905122101306915
534 0.11067338893190026 0.06639634799957275 0.07980513215065002
535 0.14255792391486466 0.07014114880561828 0.08442993193864823
536 0.13031576899811625 0.06967639064788818 0.07803461968898773
537 0.09848672593943775 0.06741694927215576 0.07472883820533753
538 0.12927314406260848 0.06662536787986756 0.07499634444713593
539 0.13781418895814568 0.06578347873687744 0.

654 0.0988677330315113 0.06413650035858154 0.07408621370792388
655 0.1091869220836088 0.0650265986919403 0.07415493309497834
656 0.0973011499736458 0.06473237299919128 0.07485154449939728
657 0.10873544809874147 0.06554813909530639 0.07640239715576172
658 0.09676014096476138 0.06406102681159973 0.07387466192245483
659 0.10833187005482614 0.06480866026878357 0.07384783565998078
660 0.09701647004112601 0.06551933598518371 0.07543526411056518
661 0.10882919898722321 0.06504711937904357 0.074112868309021
662 0.09954467497300357 0.06474848103523255 0.07426764845848083
663 0.10896160395350307 0.06461706495285034 0.07412221133708954
664 0.10041043092496693 0.06476575136184692 0.07488530039787293
665 0.10843491600826383 0.06523103165626526 0.07584881693124772
666 0.09991130698472261 0.06448538112640381 0.07400847375392913
667 0.11051866097841412 0.06431512808799744 0.07348765850067139
668 0.10022471006959677 0.06401126599311828 0.07458082914352417
669 0.10708841902669519 0.06465715432167053 0.

784 0.12116411689203233 0.06367706394195556 0.0742495396733284
785 0.12956934608519077 0.0637117235660553 0.07340554177761077
786 0.10543772601522505 0.06330558705329895 0.07359595835208893
787 0.10978461999911815 0.06337515997886657 0.07342579692602158
788 0.09456519596278667 0.0636035499572754 0.07352654486894608
789 0.1076002970803529 0.06348512434959412 0.07359283536672592
790 0.09645713993813843 0.06347052812576294 0.07344749957323074
791 0.10897320206277072 0.06360742402076722 0.0738428807258606
792 0.09830584598239511 0.06370644354820251 0.0734507179260254
793 0.11621620308142155 0.06417337846755981 0.07354623079299927
794 0.09723183594178408 0.06344507837295532 0.07405238062143325
795 0.10961318504996598 0.06386658358573914 0.07344010323286057
796 0.09760525601450354 0.0635975182056427 0.07437959432601929
797 0.10935091599822044 0.06378691291809083 0.07324659436941147
798 0.09569640399422497 0.06391581010818481 0.07395718574523925
799 0.10884684801567346 0.06354681897163392 0.0

913 0.09817916201427579 0.06304029703140258 0.07312093019485473
914 0.1080372539581731 0.06309312796592713 0.0731552791595459
915 0.09865826193708926 0.06307971787452697 0.07308723568916321
916 0.10627703997306526 0.0630090663433075 0.07313595831394196
917 0.09418869507499039 0.06294324851036072 0.07308776795864105
918 0.10997399606276304 0.06302151954174041 0.0730792161822319
919 0.09438628901261836 0.06312320327758789 0.07307617396116256
920 0.09422553307376802 0.06297001934051513 0.07305067062377929
921 0.10797097499016672 0.06299688816070556 0.07305419653654098
922 0.09465349197853357 0.06294814443588256 0.07307933449745178
923 0.10280052805319428 0.06295479011535644 0.07302781313657761
924 0.10747813002672046 0.06301315474510193 0.0731487312912941
925 0.09754909202456474 0.06306011962890624 0.07309191465377808
926 0.10299390705768019 0.06296452164649963 0.07308751314878464
927 0.10840579902287573 0.06303803312778473 0.07311391443014145
928 0.10449472803156823 0.06302479743957519 0

# FNM (2D $\rightarrow R^n$)

In [104]:
input_data  = torch.stack([torch.tensor(XC, dtype=torch.float), torch.tensor(YC, dtype=torch.float)], dim=-1)
output_data = torch.tensor(F_coeff, dtype=torch.float)

input_data = input_data.permute(0,3,1,2)


batch_size = 32

r1 = r2 = 1
x_train = input_data[:n_train, ::r1, ::r2, :] 
y_train = output_data[:n_train, :] 
x_test  = input_data[-n_test:, ::r1, ::r2, :] 
y_test  = output_data[-n_test:, :]


train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size,
                                          shuffle=False)

print(x_train.shape)


# FNF
modes1 = 12
modes2 = 12
width = 32
################################################################
# training and evaluation
################################################################
model = FNF2d(modes1, modes2, width, d_in=2, d_out=2).to(device)
print("FNM2D #params : ", count_params(model))

torch.Size([1000, 2, 221, 51])
FNM2D #params :  2437282


In [105]:
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

myloss = LpLoss(size_average=False)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        _batch_size = x.shape[0]
        optimizer.zero_grad()
        out = model(x)

        loss = myloss(out.view(_batch_size, -1), y.view(_batch_size, -1))
        loss.backward()

        optimizer.step()
        train_l2 += loss.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            _batch_size = x.shape[0]
            x, y = x.to(device), y.to(device)
            out = model(x)
            test_l2 += myloss(out.view(_batch_size, -1), y.view(_batch_size, -1)).item()
            
    train_l2 /= n_train
    test_l2 /= n_test

    t2 = default_timer()
    print(ep, t2 - t1, train_l2, test_l2)


0 2.519204652053304 0.9857515277862549 0.9622814416885376
1 2.5242341450648382 0.9792099409103393 0.993147165775299
2 2.5772896870039403 0.9648603119850159 0.9859591245651245
3 2.4612223679432645 0.9683745450973511 0.9617404961585998
4 2.4309598829131573 0.966803159236908 0.9771071290969848
5 2.4291684520430863 0.9674914083480834 0.9847726488113403
6 2.4956229880917817 0.9654786243438721 0.9648423767089844
7 2.50905702507589 0.9751085147857667 0.9800618886947632
8 2.428600223036483 0.9711348152160645 0.9627724313735961
9 2.527794894995168 0.9730864944458008 0.9719726133346558
10 2.5862422649515793 0.9657096815109253 0.9697159194946289
11 2.4290900570340455 0.9641968183517456 0.9627891969680786
12 2.602275384007953 0.9685567216873169 0.9672888565063477
13 2.5553762710187584 0.9642606830596924 0.9781955218315125
14 2.460706577054225 0.9699424314498901 0.9776313066482544
15 2.45587472000625 0.9673001108169555 0.9675318241119385
16 2.4295533050317317 0.9706518869400025 0.9732275867462158
1

KeyboardInterrupt: 

# FNO (2D $\rightarrow$ 2D)


In [14]:
# compute drag and lift
def torch_compute_F_coeff(XYC, p, cnx1 = 50, cnx2 = 120, cny = 50):
    xx, yy, p = XYC[:,cnx1:-cnx1,0,0], XYC[:,cnx1:-cnx1,0,1], p[:,cnx1:-cnx1,0,0]

    drag  = torch.einsum('ik,ik->i', (yy[:,0:cnx2]-yy[:,1:cnx2+1]), (p[:,0:cnx2] + p[:,1:cnx2+1])/2.0)
    lift  = torch.einsum('ik,ik->i', (xx[:,1:cnx2+1]-xx[:,0:cnx2]), (p[:,0:cnx2] + p[:,1:cnx2+1])/2.0)
    F = torch.column_stack((drag, lift))
    
    # F_ref = 0.5 rho_oo * u_oo^2 A
    rho_oo = 1.0
    A = 1.0
    u_oo   = 0.8*np.sqrt(1.4*1.0/1.0)
    F_ref  = 0.5*rho_oo*u_oo**2 * A
    
    return F/F_ref

################################################################
# load data and data normalization
################################################################
modes = 12
width = 32

r1 = 1
r2 = 1

input_data  = torch.stack([torch.tensor(XC, dtype=torch.float), torch.tensor(YC, dtype=torch.float)], dim=-1)
output_data = torch.tensor(Pressure, dtype=torch.float)
output_data_preprocess = torch.tensor(F_coeff, dtype=torch.float)


r1 = r2 = 1
x_train = input_data[:n_train, ::r1, ::r2, :] 
y_train = output_data[:n_train, ::r1, ::r2] 
x_test  = input_data[-n_test:, ::r1, ::r2, :] 
y_test  = output_data[-n_test:, ::r1, ::r2]



train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size,
                                          shuffle=False)
force_test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, torch.tensor(F_coeff[-n_test:, :], dtype=torch.float)), batch_size=batch_size,
                                          shuffle=False)

# FNO
modes1 = 12
modes2 = 12
width = 32
s_outputspace = None
################################################################
# training and evaluation
################################################################
model = FNO2d(modes1, modes2, width, s_outputspace, d_in=2, d_out=1).to(device)
print("FNO2D #parmas: ", count_params(model))

FNO2D #parmas:  2367969


In [15]:
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

myloss = LpLoss(size_average=False)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(x)

        loss = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
        loss.backward()

        optimizer.step()
        train_l2 += loss.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    force_test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()
            
        for x, y in force_test_loader:
            x, y = x.to(device), y.to(device)
            out =  model(x)
            force = torch_compute_F_coeff(x, out)
            force_test_l2 += myloss(force.view(batch_size, -1), y.view(batch_size, -1)).item()

            
    train_l2 /= n_train
    test_l2 /= n_test
    force_test_l2 /= n_test

    t2 = default_timer()
    print(ep, t2 - t1, train_l2, test_l2, force_test_l2)




RuntimeError: shape '[64, -1]' is invalid for input of size 653718

# FNO (1D $\rightarrow$ 1D)

In [19]:
# compute drag and lift
def torch_compute_F_coeff(xyc, p, cnx1 = 50, cnx2 = 120, cny = 50):
    xx, yy, p = xyc[:,0,:], xyc[:,1,:], p[:,0,:]
    
    drag  = torch.einsum('ik,ik->i', (yy[:,0:cnx2]-yy[:,1:cnx2+1]), (p[:,0:cnx2] + p[:,1:cnx2+1])/2.0)
    lift  = torch.einsum('ik,ik->i', (xx[:,1:cnx2+1]-xx[:,0:cnx2]), (p[:,0:cnx2] + p[:,1:cnx2+1])/2.0)
    F = torch.column_stack((drag, lift))
    
    # F_ref = 0.5 rho_oo * u_oo^2 A
    rho_oo = 1.0
    A = 1.0
    u_oo   = 0.8*np.sqrt(1.4*1.0/1.0)
    F_ref  = 0.5*rho_oo*u_oo**2 * A
    
    return F/F_ref

################################################################
# load data and data normalization
################################################################
batch_size = 32

r1 = 1
r2 = 1

cnx1 = 50
cnx2 = 120
cny = 50

input_data  = torch.stack([torch.tensor(XC[:,cnx1:-cnx1,0], dtype=torch.float), torch.tensor(YC[:,cnx1:-cnx1,0], dtype=torch.float)], dim=-1)
output_data = torch.tensor(Pressure[:,cnx1:-cnx1,0], dtype=torch.float)
output_data_preprocess = torch.tensor(F_coeff, dtype=torch.float)

input_data = input_data.permute(0,2,1)

r1 = r2 = 1
x_train = input_data[:n_train,  :, :] 
y_train = output_data[:n_train, :] 
x_test  = input_data[-n_test:,  :, :] 
y_test  = output_data[-n_test:, :]



train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size,
                                          shuffle=False)
force_test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, torch.tensor(F_coeff[-n_test:, :], dtype=torch.float)), batch_size=batch_size,
                                          shuffle=False)

print(x_train.shape,y_train.shape, output_data_preprocess.shape, batch_size)


# FNO
modes1 = 12
width = 32
s_outputspace = None
################################################################
# training and evaluation
################################################################
model = FNO1d(modes1, width, s_outputspace, d_in=2, d_out=1).to(device)
print("FNO1D #parmas: ", count_params(model))

torch.Size([1000, 2, 121]) torch.Size([1000, 121]) torch.Size([2490, 2]) 32
FNO1D #parmas:  107009


In [26]:
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

myloss = LpLoss(size_average=False)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(x)
        loss = myloss(out, y)
#         loss = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
        loss.backward()

        optimizer.step()
        train_l2 += loss.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    force_test_l2 = 0.0
    with torch.no_grad():
#         for x, y in test_loader:
#             x, y = x.to(device), y.to(device)
#             out = model(x)
#             test_l2 += myloss(out, y).item()
# #             test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()
            
#         for x, y in force_test_loader:
#             x, y = x.to(device), y.to(device)
#             out =  model(x)
#             force = torch_compute_F_coeff(x, out)
#             force_test_l2 += myloss(force, y).item()
# #             force_test_l2 += myloss(force.view(batch_size, -1), y.view(batch_size, -1)).item()

        for (xy, fxfy) in zip(test_loader, force_test_loader):
            x, y = xy
            _, f = fxfy
            x, y, f = x.to(device), y.to(device), f.to(device)
            out = model(x)
            test_l2 += myloss(out, y).item()
            force = torch_compute_F_coeff(x, out)
            force_test_l2 += myloss(force, f).item()

            
    train_l2 /= n_train
    test_l2 /= n_test
    force_test_l2 /= n_test

    t2 = default_timer()
    print(ep, t2 - t1, train_l2, test_l2, force_test_l2)




0 0.2541494140023133 0.053761429846286776 0.0321498641371727 0.10786643385887146
1 0.1435694190004142 0.024485072195529936 0.02309550382196903 0.06975386187434196
2 0.1571024760050932 0.02292741133272648 0.03503973960876465 0.2647298049926758
3 0.1436250449914951 0.02248505875468254 0.03744513645768165 0.20545204639434814
4 0.15459505100443494 0.02164475552737713 0.02509940929710865 0.19933870792388916
5 0.14354261600237805 0.021711419209837914 0.02188177056610584 0.08282350182533264
6 0.170968467995408 0.016464761838316917 0.02398195631802082 0.13951799929141998
7 0.14708824100671336 0.016883624903857707 0.02583765283226967 0.09814885377883911
8 0.16644436201022472 0.017828315049409866 0.02253212884068489 0.08772081643342972
9 0.1493269599886844 0.02441179096698761 0.025112350061535834 0.10263323128223419
10 0.17471530799230095 0.017781514286994933 0.020874824225902557 0.07169460624456406
11 0.17683354999462608 0.018002871558070184 0.022478236258029936 0.13971574366092682
12 0.1957945

99 0.2061421729886206 0.014817210584878922 0.01853541724383831 0.11076139986515045
100 0.2245507159968838 0.011533446043729783 0.015453255772590636 0.0639218170940876
101 0.1803236289997585 0.011424385763704777 0.015935209318995475 0.07544055998325348
102 0.1650690440001199 0.011448189213871956 0.014077361598610878 0.06355500847101211
103 0.1704839829908451 0.01022281700372696 0.013675259873270989 0.049500246345996854
104 0.18200892000459135 0.01083687374740839 0.01395881611853838 0.045870780050754546
105 0.20803797199914698 0.010617672599852085 0.013780624568462373 0.049141202867031095
106 0.22054804100480396 0.009905524007976054 0.017622163519263268 0.10532263994216919
107 0.16663911700015888 0.0108448371514678 0.01710538886487484 0.09622622132301331
108 0.1684878850064706 0.011901195272803306 0.016124160066246986 0.06611480489373207
109 0.1686168779997388 0.01090238930284977 0.015248648039996624 0.054586066007614134
110 0.15207402600208297 0.011083685040473938 0.01745644748210907 0.

197 0.16390714100270998 0.01221104408055544 0.013295087106525898 0.04490462005138397
198 0.17831696900248062 0.01119177969545126 0.015062116757035256 0.07774864509701729
199 0.1979851950018201 0.01115748219192028 0.015090944468975067 0.05292420014739037
200 0.22947093300172128 0.00902847684174776 0.01272857528179884 0.0449276439845562
201 0.20903205499053001 0.00820735737681389 0.012487045973539352 0.04257190808653832
202 0.21619742699840572 0.008368585474789143 0.013961264416575431 0.0649372974038124
203 0.14670833600393962 0.009034357979893685 0.012545819729566574 0.04781666904687881
204 0.14868763599952217 0.008865170247852803 0.012625975608825684 0.043092683255672455
205 0.14638876399840228 0.008110161885619164 0.012593756690621375 0.04223308339715004
206 0.15132735000224784 0.008113115191459656 0.012707768827676774 0.04891161233186722
207 0.1453037289902568 0.008312413141131401 0.012403251454234123 0.042063880413770675
208 0.14733685800456442 0.008219641722738744 0.012476080209016

294 0.16880124699673615 0.008040307287126779 0.013194323368370533 0.0507808655500412
295 0.17448422600864433 0.007935645520687103 0.012284024730324744 0.04315495803952217
296 0.18564479300403036 0.008107494037598372 0.012422912642359734 0.05157548666000366
297 0.18311151099624112 0.008772484429180622 0.012442441210150718 0.04108775705099106
298 0.16074116399977356 0.007887514166533946 0.012320363596081733 0.04397599086165428
299 0.14496449200669304 0.00785576181486249 0.01193316537886858 0.04260484859347344
300 0.1504094050033018 0.007193402260541916 0.011618967838585377 0.03939818352460861
301 0.14736776798963547 0.007144858412444591 0.011800434850156307 0.045337064117193224
302 0.15289041800133418 0.007362317759543657 0.012302692830562591 0.04832807749509811
303 0.16113681100250687 0.007081707514822483 0.011732704266905784 0.03946731761097908
304 0.18509807800001 0.007299295570701361 0.011993713565170766 0.04394831985235214
305 0.20417882499168627 0.007194657638669014 0.0119505329802

KeyboardInterrupt: 

# FNM ( $\mathbb{R}^n \rightarrow$ 1D)

In [23]:
print("test", n_train, n_test)

test 1000 400


In [35]:
################################################################
# load data and data normalization
################################################################
batch_size = 32

r1 = 1
r2 = 1

cnx1 = 50
cnx2 = 120
cny = 50


input_data  = torch.tensor(theta[..., 1:], dtype=torch.float)
output_data = torch.tensor(Pressure[:,cnx1:-cnx1,0], dtype=torch.float)
output_data_preprocess = torch.tensor(F_coeff, dtype=torch.float)


r1 = r2 = 1
x_train = input_data[:n_train,  :] 
y_train = output_data[:n_train, :] 
x_test  = input_data[-n_test:,  :] 
y_test  = output_data[-n_test:, :]

x_test_force = torch.stack([torch.tensor(XC[:,cnx1:-cnx1,0], dtype=torch.float), torch.tensor(YC[:,cnx1:-cnx1,0], dtype=torch.float)], dim=-1).permute(0,2,1)[-n_test:,:,:]


train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size,
                                          shuffle=False)
force_test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_force, torch.tensor(F_coeff[-n_test:, :], dtype=torch.float)), batch_size=batch_size,
                                          shuffle=False)

print(x_train.shape,y_train.shape, x_test_force.shape)


# FNO
modes1 = 12
width = 32
s_outputspace = None
################################################################
# training and evaluation
################################################################
model = FND1d(x_train.shape[-1], y_train.shape[-1], modes1, width).to(device)
print("FND1D #parms: ", count_params(model))

torch.Size([1000, 7]) torch.Size([1000, 121]) torch.Size([400, 2, 121])
FND1D #parms:  113025


In [37]:
optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

myloss = LpLoss(size_average=False)

for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(x)
        loss = myloss(out, y)
        loss.backward()

        optimizer.step()
        train_l2 += loss.item()

    scheduler.step()

    model.eval()
    test_l2 = 0.0
    force_test_l2 = 0.0
    with torch.no_grad():
        for (xy, fxfy) in zip(test_loader, force_test_loader):
            x, y = xy
            xf, f = fxfy
            x, y, xf, f = x.to(device), y.to(device), xf.to(device), f.to(device)
            out = model(x)
            test_l2 += myloss(out, y).item()
            force = torch_compute_F_coeff(xf, out)
            force_test_l2 += myloss(force, f).item()

            
    train_l2 /= n_train
    test_l2 /= n_test
    force_test_l2 /= n_test

    t2 = default_timer()
    print(ep, t2 - t1, train_l2, test_l2, force_test_l2)

0 0.285212381000747 0.21120629000663757 0.1934337204694748 0.9808246493339539
1 0.1337898430065252 0.19559205102920532 0.1937147396802902 1.0639258337020874
2 0.14826083999651019 0.19433857214450836 0.19159241318702697 1.017140624523163
3 0.13357333300518803 0.19365377628803254 0.19025884687900543 0.9550221228599548
4 0.14782183400529902 0.19211223292350768 0.18814390480518342 0.9812829732894898
5 0.13414225399901625 0.18954140830039978 0.18813357830047608 1.0248295879364013
6 0.15681708000192884 0.18879542219638826 0.1862078106403351 0.9902427744865417
7 0.1369492070080014 0.18837529230117797 0.18699317276477814 0.9881100988388062
8 0.14823067300312687 0.18791635751724242 0.18672664165496827 0.9635186338424683
9 0.14222098200116307 0.1873065311908722 0.18475012838840485 1.000235652923584
10 0.17165130900684744 0.18685393834114075 0.18413114428520203 0.9506078004837036
11 0.1807921889994759 0.18612444305419923 0.18378519177436828 1.0358399057388306
12 0.2202010749897454 0.1854563333988

100 0.1637400549952872 0.041361848056316375 0.04200110882520676 0.11232524871826172
101 0.13472599700617138 0.04042341607809067 0.04139821290969849 0.10163583487272262
102 0.1558732669946039 0.04022328448295593 0.04222980231046677 0.11331417262554169
103 0.14081329900363926 0.040127609193325045 0.042347884476184844 0.14191751539707184
104 0.16847208500257693 0.03995635604858398 0.04092529386281967 0.10463450878858566
105 0.17736728899762966 0.03936978393793106 0.04008064836263656 0.10289043426513672
106 0.20302427900605835 0.03885286128520966 0.03965743601322174 0.10235872268676757
107 0.15696565799589735 0.038752481341362 0.04070866048336029 0.13813451439142227
108 0.15569508500630036 0.03865527758002281 0.03920517086982727 0.10655901551246644
109 0.14393446000758559 0.03811205071210861 0.03905785322189331 0.10904264211654663
110 0.1478982729895506 0.03766370669007301 0.039055967032909394 0.10926628798246384
111 0.1422164889954729 0.037367963433265684 0.03899911522865295 0.10040148049

198 0.1914609879895579 0.026624391108751297 0.028663170412182807 0.10755163490772247
199 0.21943746799661312 0.026840561628341676 0.028488469794392585 0.09835070133209228
200 0.16972130499198101 0.02553358696401119 0.02668317496776581 0.09551004260778427
201 0.1543451159959659 0.025347057670354844 0.02700918883085251 0.09341023355722428
202 0.14575563299877103 0.02532624328136444 0.027129443287849428 0.08877576500177384
203 0.145033560009324 0.025307377010583876 0.02671598583459854 0.09033508211374283
204 0.1578773670044029 0.02530533653497696 0.027237383574247362 0.09272601187229157
205 0.17760055301187094 0.0257781041264534 0.02739767372608185 0.09977829366922379
206 0.17216629200265743 0.025261697471141816 0.02651710659265518 0.08955466777086257
207 0.16493137000361457 0.025343031242489815 0.027513737082481383 0.09328834235668182
208 0.14326394400268327 0.025522722631692887 0.026817223727703093 0.09433021694421768
209 0.13805073300318327 0.02517577764391899 0.02714217446744442 0.088

295 0.1559833010105649 0.023139019429683685 0.025081625431776045 0.09206177920103073
296 0.1444749289948959 0.0234414199590683 0.0259340725839138 0.09590116560459137
297 0.14165325499197934 0.023324150204658507 0.0252185420691967 0.0844648939371109
298 0.1351492570101982 0.023186672657728197 0.024694772213697435 0.08597306668758392
299 0.1395560799865052 0.022800326511263846 0.02506051704287529 0.08798267751932144
300 0.14070252599776722 0.02282713857293129 0.02448842130601406 0.0872376149892807
301 0.15308870599255897 0.022554246366024017 0.02536960244178772 0.09779767960309982
302 0.15869044499413576 0.02268389278650284 0.024636113196611405 0.08745556414127349
303 0.15787984199414495 0.022489822208881377 0.024453263208270074 0.08889654815196991
304 0.15055971599940676 0.022420092791318895 0.024414157271385194 0.08627009689807892
305 0.15184691300964914 0.022511165797710418 0.024486532732844353 0.08456520348787308
306 0.14933201800158713 0.022377839028835295 0.02434107430279255 0.0849

392 0.16237143200123683 0.021621006295084955 0.023709773868322372 0.08062734752893448
393 0.16783863700402435 0.021506889313459395 0.023629285395145416 0.08210808783769608
394 0.14930972199363168 0.021607812583446503 0.02404606096446514 0.08788613766431809
395 0.16019064099236857 0.021558520086109638 0.024024604111909865 0.08265852987766266
396 0.14930966698739212 0.02175434610247612 0.02386498063802719 0.08200856149196625
397 0.1353408229915658 0.021647732734680177 0.02360066719353199 0.08228611230850219
398 0.14987182499316987 0.021586060643196105 0.023738040328025817 0.07866148263216019
399 0.14360730799671728 0.021869403839111328 0.024211894497275354 0.08143712788820266
400 0.1493275859975256 0.02143440407514572 0.02345206379890442 0.0801150444149971
401 0.16525077100959606 0.021280276104807854 0.02343967393040657 0.07994054079055786
402 0.18713575800938997 0.021195071935653686 0.02342372477054596 0.08082470655441285
403 0.1328622470027767 0.02126605460047722 0.023443420678377153 0

489 0.21144876201287843 0.0207453134059906 0.023011276200413704 0.07842201113700867
490 0.18578704800165724 0.02078836804628372 0.0230996523052454 0.07909952998161315
491 0.18091231799917296 0.02079218429327011 0.02301281288266182 0.078430215716362
492 0.20668484699854162 0.020758651196956635 0.022967724502086638 0.07879174202680587
493 0.23288322100415826 0.020784450948238373 0.02305672138929367 0.07719303280115128
494 0.2654512790031731 0.020795231223106385 0.022932993471622466 0.07661387890577316
495 0.1455224039964378 0.020856549754738808 0.023038063049316406 0.07783291012048721
496 0.13864024898794014 0.020745143502950668 0.02302366465330124 0.07994665265083313
497 0.14122014099848457 0.020879130184650423 0.023109099119901656 0.07891511082649232
498 0.14129220499307849 0.020821971774101257 0.02303364545106888 0.07601275354623795
499 0.13863827000022866 0.02079410594701767 0.023108764141798018 0.07830144435167313
500 0.1452850589994341 0.020656853206455706 0.022885275036096574 0.07

586 0.13732014800189063 0.020428976684808732 0.022797378450632094 0.07693365871906281
587 0.13520105500356294 0.020415125876665115 0.022736806347966196 0.07552491575479507
588 0.13718736000009812 0.020446453332901002 0.02269511103630066 0.07588735044002533
589 0.14430631200957578 0.020445968076586725 0.022753665670752526 0.07725430011749268
590 0.14366629300639033 0.02041027541458607 0.022721638306975364 0.07546713024377823
591 0.13559076699311845 0.020389782294631006 0.022619586363434793 0.07623890966176987
592 0.13729869200324174 0.02040524271130562 0.022817732691764833 0.07744575262069703
593 0.1352910479909042 0.020428408026695252 0.022760076373815538 0.07532872080802917
594 0.13675635500112548 0.020393781334161757 0.02261297047138214 0.07639591604471206
595 0.1351139019971015 0.0204094041287899 0.022732499241828918 0.07576611518859863
596 0.13637185499828774 0.020385212644934654 0.022671079859137536 0.07540064603090287
597 0.13521434100402985 0.02038355666399002 0.0227810685336589

684 0.13537224399624392 0.02016808070242405 0.02254867844283581 0.07546028912067414
685 0.1332377970102243 0.02017120972275734 0.022568385899066925 0.07487248480319977
686 0.13531286599754822 0.020181933104991914 0.022563586458563805 0.07539875954389572
687 0.1331879250064958 0.02019579264521599 0.02252268575131893 0.0749311849474907
688 0.1374262499884935 0.020186274379491807 0.022547787502408028 0.07517036348581314
689 0.15585270900919568 0.020186183467507363 0.022509504109621048 0.0752699014544487
690 0.15429249599401373 0.020173560246825217 0.022527685612440108 0.0754172995686531
691 0.15685014599876013 0.020186655759811403 0.02261232480406761 0.07562089651823044
692 0.1576937630015891 0.02018862546980381 0.022530355378985404 0.07524571150541305
693 0.1508926169917686 0.020182866603136063 0.022534557804465293 0.0752885264158249
694 0.1632809989969246 0.020191090375185013 0.02251900054514408 0.07622346222400665
695 0.1849182519945316 0.02019790029525757 0.022514559626579285 0.075354

781 0.13715510300244205 0.02007721309363842 0.022501915767788885 0.07520038336515426
782 0.13752258200838696 0.020065623551607133 0.022478202134370805 0.07518347203731537
783 0.13985357400088105 0.02008266603946686 0.022471060529351234 0.07499074399471282
784 0.1381093779928051 0.020064317002892493 0.022482106164097786 0.07525596499443055
785 0.14773225600947626 0.02007528781890869 0.022454034239053726 0.07484476685523987
786 0.16840529600449372 0.020070373862981796 0.02247364491224289 0.07504562437534332
787 0.17617559600330424 0.020065369755029677 0.022487289309501647 0.07514080822467804
788 0.21204181799839716 0.02007003979384899 0.02250359982252121 0.07471180886030197
789 0.1960357380012283 0.020062065809965134 0.022505536079406738 0.07577604919672012
790 0.14870535499358084 0.02007058946788311 0.02249988354742527 0.07492238610982895
791 0.1825671889964724 0.020060689955949784 0.022491456791758537 0.07511581271886826
792 0.15707537598791532 0.020063382416963576 0.022478257715702058

878 0.15977820300031453 0.02001817286014557 0.02244497075676918 0.07495136350393296
879 0.13596515198878478 0.020011847525835038 0.022448131740093233 0.07494827568531036
880 0.13891999601037242 0.02001964071393013 0.022454024702310563 0.07486263632774354
881 0.18732081100461073 0.020014953047037126 0.022434334754943847 0.07481310844421386
882 0.18350749899400398 0.020005801022052766 0.022440498471260072 0.07488380074501037
883 0.13783065999450628 0.020008627325296402 0.022457935288548468 0.0751477313041687
884 0.13872356999490876 0.020011230379343034 0.02244805611670017 0.07499263703823089
885 0.14933533099247143 0.02001784121990204 0.02243284672498703 0.07499257564544677
886 0.14041782400454395 0.020006002247333527 0.02243438094854355 0.07476327478885651
887 0.13987729800282978 0.02000916349887848 0.022456386238336564 0.07503292977809906
888 0.1434762909921119 0.020008312314748766 0.022427773252129555 0.07493233352899552
889 0.14458874400588684 0.020008257001638413 0.02243872292339802

976 0.1340640540001914 0.019979197800159453 0.02242550626397133 0.07488423824310303
977 0.1341104040038772 0.019978947877883912 0.022426208555698393 0.0749564129114151
978 0.13700426698778756 0.01997981996834278 0.02242514632642269 0.07498975425958633
979 0.13291186500282492 0.019981910407543182 0.022421549260616302 0.07492153286933899
980 0.13971706500160508 0.019977909594774246 0.02242998093366623 0.07497051686048507
981 0.13526379900577012 0.019976912170648575 0.022428828403353692 0.07491267323493958
982 0.14208638000127394 0.01997976478934288 0.022419779002666472 0.07478352665901183
983 0.13400217299931683 0.01998854298889637 0.022408171370625497 0.07482787311077117
984 0.14295618300093338 0.019974798381328582 0.02243887759745121 0.07484065502882004
985 0.1575882350007305 0.019978568598628045 0.02242933914065361 0.07488178849220276
986 0.14167862800240982 0.019976584121584894 0.022423175871372224 0.07488572031259537
987 0.1392514180042781 0.019975629955530167 0.022427247166633607 0