In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from scipy import integrate
from scipy import linalg
from scipy import interpolate
from sklearn import gaussian_process as gp
from sklearn.gaussian_process import GaussianProcessRegressor
from mpl_toolkits.mplot3d import Axes3D
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import operator
from functools import reduce
from functools import partial
from timeit import default_timer
from utilities3 import *
from Adam import Adam
from fourier_2d import *
from util import *

In [None]:
class GPR:

    def __init__(self, optimize=True):
        self.is_fit = False
        self.train_X, self.train_y = None, None
        self.params = {"l": 0.2, "sigma_f": 1}
        self.optimize = optimize

    def fit(self, X, y):
        # store train data
        self.train_X = np.asarray(X)
        self.train_y = np.asarray(y)
        self.is_fit = True

    def predict(self, X, samples):
        if not self.is_fit:
            print("GPR Model not fit yet.")
            return

        X = np.asarray(X)
        Kff = self.kernel(self.train_X, self.train_X)  # (N, N)
        Kyy = self.kernel(X, X)  # (k, k)
        Kfy = self.kernel(self.train_X, X)  # (N, k)
        Kff_inv = np.linalg.inv(Kff + 1e-8 * np.eye(len(self.train_X)))  # (N, N)
        
        mu = Kfy.T.dot(Kff_inv).dot(self.train_y)
        cov = Kyy - Kfy.T.dot(Kff_inv).dot(Kfy)
        
        gp_samples = np.random.multivariate_normal(
            mean=mu.ravel(), 
            cov=cov, 
            size=samples)
        return gp_samples

    def kernel(self, x1, x2):
        dist_matrix = np.sum(x1**2, 1).reshape(-1, 1) + np.sum(x2**2, 1) - 2 * np.dot(x1, x2.T)
        return self.params["sigma_f"] ** 2 * np.exp(-0.5 / self.params["l"] ** 2 * dist_matrix)
    
dim = 51
grid = np.linspace(0,1,dim)
X, Y = np.meshgrid(grid, grid)
res = np.zeros((4*(dim-1),2))
for i in range(dim-1):
    res[i,0] = 0
    res[i,1] = grid[i]
for i in range(dim-1):
    res[dim-1+i,0] = grid[i]
    res[dim-1+i,1] = 1
for i in range(dim-1):
    res[(dim-1)*2+i,0] = 1
    res[(dim-1)*2+i,1] = grid[i+1]
for i in range(dim-1):
    res[(dim-1)*3+i,0] = grid[i+1]
    res[(dim-1)*3+i,1] = 0

z0 = np.zeros(4*(dim-1))
gpr = GPR(optimize=False)
gpr.fit(res, z0)
res = np.zeros((dim**2,2))
for i in range(dim):
    for j in range(dim):
        res[i*dim+j,0] = grid[i]
        res[i*dim+j,1] = grid[j]
z = gpr.predict(res,1000)
z = z.reshape((z.shape[0],X.shape[0],X.shape[1]))

plt.pcolormesh(Y, X, z[0], shading='auto')
plt.colorbar()
plt.show()

def solve(u0):
    N = u0.shape[0]
    dimx = u0.shape[2]
    dimy = u0.shape[1]
    hx = 1/(dimx-1)
    hy = 1/(dimy-1)
    T = 2
    ht = 1/100
    u = np.zeros((N,T,dimy,dimx))
    for i in range(N):
        u[i,0,:,:] = u0[i,:,:]
        for j in range(1,T):
            u[i,j,:,:] = u0[i,:,:]
            for k in range(1,dimy-1):
                for l in range(1,dimx-1):
                    u[i,j,k,l] = u[i,j-1,k,l]+ht*((u[i,j-1,k-1,l]+u[i,j-1,k+1,l]-2*u[i,j-1,k,l])/hy/hy+(u[i,j-1,k,l-1]+u[i,j-1,k,l+1]-2*u[i,j-1,k,l])/hx/hx)
    return u

def solve0(u0):
    N = u0.shape[0]
    dimx = u0.shape[2]
    dimy = u0.shape[1]
    hx = 1/(dimx-1)
    hy = 1/(dimy-1)
    T = 2
    ht = 1/100
    u = np.zeros((N,T,dimy,dimx))
    for i in range(N):
        u[i,0,:,:] = u0[i,:,:]
        for j in range(1,T):
            u[i,j,:,:] = u0[i,:,:]
            for k in range(1,dimy-1):
                for l in range(1,dimx-1):
                    u[i,j,k,l] = u[i,j-1,k,l]+ht*((u[i,j-1,k-1,l]+u[i,j-1,k+1,l]-2*u[i,j-1,k,l])/hy/hy+(u[i,j-1,k,l-1]+u[i,j-1,k,l+1]-2*u[i,j-1,k,l])/hx/hx+1)
    return u
u = solve(z)
plt.pcolormesh(Y, X, u[0,0], shading='auto')
plt.colorbar()
plt.show()
plt.pcolormesh(Y, X, u[0,-1], shading='auto')
plt.colorbar()
plt.show()
u0 = solve0(z)
plt.pcolormesh(Y, X, u0[0,0], shading='auto')
plt.colorbar()
plt.show()
plt.pcolormesh(Y, X, u0[0,-1], shading='auto')
plt.colorbar()
plt.show()

In [40]:
x_grid = np.linspace(0, 1, u.shape[-1])
y_grid = x_grid

ntrain = 1000
ntest = 100

batch_size = 20
learning_rate = 0.001
epochs = 600
step_size = 100
gamma = 0.5
modes1 = 12
modes2 = 12
width = 16

f_train = torch.Tensor(z)
u_train = torch.Tensor(u[:,-1,:,:])
f_train = torch.reshape(f_train,(f_train.shape[0],f_train.shape[1],f_train.shape[2],1))
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(f_train, u_train), batch_size=batch_size, shuffle=True)

model = FNO2d(modes1, modes2, width).cuda()
print('Total parameters:',count_params(model))

optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

start = default_timer()

MSE = torch.zeros(epochs)
L2 = torch.zeros(epochs)

myloss = LpLoss(size_average=False)
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse = 0
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()

        optimizer.zero_grad()
        out = model(x)
        mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1), reduction='mean')
        l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
        l2.backward()
        
        optimizer.step()
        train_mse += mse.item()
        train_l2 += l2.item()
        
    scheduler.step()
    
    train_mse /= len(train_loader)
    train_l2 /= ntrain
    t2 = default_timer()
    
    MSE[ep] = train_mse
    L2[ep] = train_l2
    print('\repoch {:d}/{:d} L2 = {:.6f}, MSE = {:.6f}, using {:.6f}s'.format(ep+1,epochs,train_l2,train_mse,t2-t1), end='', flush=True)

print('Total training time:',default_timer()-start,'s')

Total parameters: 592129
epoch 1/800 L2 = 0.896381, MSE = 0.140787, using 3.082539s
epoch 2/800 L2 = 0.512259, MSE = 0.040863, using 0.514900s
epoch 3/800 L2 = 0.386295, MSE = 0.021401, using 0.498771s
epoch 4/800 L2 = 0.248561, MSE = 0.008900, using 0.529106s
epoch 5/800 L2 = 0.189453, MSE = 0.005131, using 0.510045s
epoch 6/800 L2 = 0.155853, MSE = 0.003480, using 0.491968s
epoch 7/800 L2 = 0.136065, MSE = 0.002654, using 0.493360s
epoch 8/800 L2 = 0.122151, MSE = 0.002132, using 0.498883s
epoch 9/800 L2 = 0.112241, MSE = 0.001798, using 0.496598s
epoch 10/800 L2 = 0.105716, MSE = 0.001594, using 0.550238s
epoch 11/800 L2 = 0.098448, MSE = 0.001382, using 0.495199s
epoch 12/800 L2 = 0.093598, MSE = 0.001247, using 0.572611s
epoch 13/800 L2 = 0.089100, MSE = 0.001129, using 0.516145s
epoch 14/800 L2 = 0.084491, MSE = 0.001016, using 0.493396s
epoch 15/800 L2 = 0.083336, MSE = 0.000985, using 0.495825s
epoch 16/800 L2 = 0.079432, MSE = 0.000898, using 0.490315s
epoch 17/800 L2 = 0.0742

epoch 137/800 L2 = 0.025157, MSE = 0.000090, using 0.490926s
epoch 138/800 L2 = 0.024502, MSE = 0.000086, using 0.492633s
epoch 139/800 L2 = 0.026083, MSE = 0.000096, using 0.496109s
epoch 140/800 L2 = 0.025548, MSE = 0.000092, using 0.494271s
epoch 141/800 L2 = 0.024968, MSE = 0.000089, using 0.495342s
epoch 142/800 L2 = 0.024446, MSE = 0.000085, using 0.491420s
epoch 143/800 L2 = 0.023750, MSE = 0.000080, using 0.495661s
epoch 144/800 L2 = 0.023839, MSE = 0.000081, using 0.491438s
epoch 145/800 L2 = 0.023746, MSE = 0.000080, using 0.496606s
epoch 146/800 L2 = 0.024296, MSE = 0.000084, using 0.489234s
epoch 147/800 L2 = 0.025082, MSE = 0.000089, using 0.497558s
epoch 148/800 L2 = 0.023105, MSE = 0.000076, using 0.493872s
epoch 149/800 L2 = 0.023825, MSE = 0.000081, using 0.497162s
epoch 150/800 L2 = 0.024547, MSE = 0.000086, using 0.494723s
epoch 151/800 L2 = 0.024598, MSE = 0.000085, using 0.492305s
epoch 152/800 L2 = 0.023574, MSE = 0.000079, using 0.500711s
epoch 153/800 L2 = 0.024

epoch 272/800 L2 = 0.017470, MSE = 0.000043, using 0.498940s
epoch 273/800 L2 = 0.017299, MSE = 0.000042, using 0.491750s
epoch 274/800 L2 = 0.017434, MSE = 0.000043, using 0.494058s
epoch 275/800 L2 = 0.017371, MSE = 0.000043, using 0.491718s
epoch 276/800 L2 = 0.017902, MSE = 0.000046, using 0.495757s
epoch 277/800 L2 = 0.017980, MSE = 0.000046, using 0.493572s
epoch 278/800 L2 = 0.017554, MSE = 0.000044, using 0.496390s
epoch 279/800 L2 = 0.017840, MSE = 0.000045, using 0.493544s
epoch 280/800 L2 = 0.018248, MSE = 0.000047, using 0.496748s
epoch 281/800 L2 = 0.018122, MSE = 0.000046, using 0.492032s
epoch 282/800 L2 = 0.017091, MSE = 0.000041, using 0.494663s
epoch 283/800 L2 = 0.017473, MSE = 0.000043, using 0.503292s
epoch 284/800 L2 = 0.017555, MSE = 0.000044, using 0.500794s
epoch 285/800 L2 = 0.017364, MSE = 0.000043, using 0.502726s
epoch 286/800 L2 = 0.017546, MSE = 0.000044, using 0.492499s
epoch 287/800 L2 = 0.019183, MSE = 0.001529, using 0.498901s
epoch 288/800 L2 = 0.016

epoch 407/800 L2 = 0.014645, MSE = 0.000030, using 0.497123s
epoch 408/800 L2 = 0.014603, MSE = 0.000030, using 0.492905s
epoch 409/800 L2 = 0.014683, MSE = 0.000031, using 0.495250s
epoch 410/800 L2 = 0.014741, MSE = 0.000031, using 0.491790s
epoch 411/800 L2 = 0.014537, MSE = 0.000030, using 0.495410s
epoch 412/800 L2 = 0.014654, MSE = 0.000030, using 0.492533s
epoch 413/800 L2 = 0.014696, MSE = 0.000031, using 0.499029s
epoch 414/800 L2 = 0.014646, MSE = 0.000030, using 0.490144s
epoch 415/800 L2 = 0.014940, MSE = 0.000032, using 0.504497s
epoch 416/800 L2 = 0.014633, MSE = 0.000030, using 0.497582s
epoch 417/800 L2 = 0.014585, MSE = 0.000030, using 0.506377s
epoch 418/800 L2 = 0.014526, MSE = 0.000030, using 0.494202s
epoch 419/800 L2 = 0.014743, MSE = 0.000031, using 0.499098s
epoch 420/800 L2 = 0.014606, MSE = 0.000030, using 0.494423s
epoch 421/800 L2 = 0.014775, MSE = 0.000031, using 0.498835s
epoch 422/800 L2 = 0.014555, MSE = 0.000030, using 0.492358s
epoch 423/800 L2 = 0.014

epoch 542/800 L2 = 0.013823, MSE = 0.000027, using 0.500139s
epoch 543/800 L2 = 0.013868, MSE = 0.000027, using 0.493335s
epoch 544/800 L2 = 0.013803, MSE = 0.000027, using 0.497062s
epoch 545/800 L2 = 0.013837, MSE = 0.000027, using 0.500562s
epoch 546/800 L2 = 0.013857, MSE = 0.000027, using 0.500638s
epoch 547/800 L2 = 0.013769, MSE = 0.000027, using 0.500161s
epoch 548/800 L2 = 0.013901, MSE = 0.000027, using 0.495999s
epoch 549/800 L2 = 0.013784, MSE = 0.000027, using 0.494676s
epoch 550/800 L2 = 0.013831, MSE = 0.000027, using 0.496455s
epoch 551/800 L2 = 0.013936, MSE = 0.000027, using 0.494755s
epoch 552/800 L2 = 0.013796, MSE = 0.000027, using 0.495422s
epoch 553/800 L2 = 0.013799, MSE = 0.000027, using 0.494735s
epoch 554/800 L2 = 0.013897, MSE = 0.000027, using 0.495969s
epoch 555/800 L2 = 0.013784, MSE = 0.000027, using 0.494275s
epoch 556/800 L2 = 0.013852, MSE = 0.000027, using 0.499218s
epoch 557/800 L2 = 0.015382, MSE = 0.000316, using 0.503387s
epoch 558/800 L2 = 0.013

epoch 677/800 L2 = 0.013422, MSE = 0.000025, using 0.495412s
epoch 678/800 L2 = 0.013469, MSE = 0.000026, using 0.492763s
epoch 679/800 L2 = 0.013437, MSE = 0.000026, using 0.490853s
epoch 680/800 L2 = 0.013461, MSE = 0.000026, using 0.501487s
epoch 681/800 L2 = 0.013468, MSE = 0.000026, using 0.497204s
epoch 682/800 L2 = 0.013461, MSE = 0.000026, using 0.493944s
epoch 683/800 L2 = 0.013436, MSE = 0.000025, using 0.497044s
epoch 684/800 L2 = 0.013440, MSE = 0.000025, using 0.496431s
epoch 685/800 L2 = 0.013385, MSE = 0.000025, using 0.492953s
epoch 686/800 L2 = 0.013411, MSE = 0.000025, using 0.496230s
epoch 687/800 L2 = 0.013431, MSE = 0.000025, using 0.496175s
epoch 688/800 L2 = 0.013406, MSE = 0.000025, using 0.497044s
epoch 689/800 L2 = 0.013441, MSE = 0.000026, using 0.497081s
epoch 690/800 L2 = 0.013407, MSE = 0.000025, using 0.495585s
epoch 691/800 L2 = 0.013418, MSE = 0.000025, using 0.494515s
epoch 692/800 L2 = 0.013383, MSE = 0.000025, using 0.498299s
epoch 693/800 L2 = 0.013

In [1]:
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False)
index = 0
test_l2 = 0
test_mse = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.cuda(), y.cuda()

        out = model(x).view(-1)
        pred[index] = out.reshape(y_test.shape[1], y_test.shape[2])
        mse = F.mse_loss(out.view(1, -1), y.view(1, -1), reduction='mean')
        test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item()
        test_mse += mse.item()
        index += 1
        
    test_mse /= len(test_loader)
    test_l2 /= ntest
    print('test error: L2 =', test_l2,', MSE =',test_mse)

NameError: name 'torch' is not defined