In [1]:
import pickle 
import torch
import torch.nn.functional as F
import torch.nn as nn

In [2]:
from dataloader import DataLoader2D
from plot_utils import animate_dyn,plot_eig,plot_perturbation,plot_wavepacket,comp_dyn

In [3]:
class SpectralConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d, self).__init__()

        """
        2D Fourier layer. It does FFT, linear transform, and Inverse FFT.    
        """

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul2d(self, input, weights):
        # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        #Compute Fourier coeffcients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft2(x)

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels,  x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        #Return to physical space
        x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)))
        return x

class MLP(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels):
        super(MLP, self).__init__()
        self.mlp1 = nn.Conv2d(in_channels, mid_channels, 1)
        self.mlp2 = nn.Conv2d(mid_channels, out_channels, 1)

    def forward(self, x):
        x = self.mlp1(x)
        x = F.gelu(x)
        x = self.mlp2(x)
        return x


In [4]:
class LpLoss(object):
    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        #Dimension and Lp-norm type are postive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def abs(self, x, y):
        num_examples = x.size()[0]

        #Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0)

        all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(all_norms)
            else:
                return torch.sum(all_norms)

        return all_norms

    def rel(self, x, y):
        num_examples = x.size()[0]

        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms/y_norms)
            else:
                return torch.sum(diff_norms/y_norms)

        return diff_norms/y_norms

    def __call__(self, x, y):
        return self.rel(x, y)


In [5]:
class FNO2d(nn.Module):
    def __init__(self, modes1, modes2,  width):
        super(FNO2d, self).__init__()

        """
        The overall network. It contains 4 layers of the Fourier layer.
        1. Lift the input to the desire channel dimension by self.fc0 .
        2. 4 layers of the integral operators u' = (W + K)(u).
            W defined by self.w; K defined by self.conv .
        3. Project from the channel space to the output space by self.fc1 and self.fc2 .
        
        input: the solution of the coefficient function and locations (a(x, y), x, y)
        input shape: (batchsize, x=s, y=s, c=3)
        output: the solution 
        output shape: (batchsize, x=s, y=s, c=1)
        """

        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        self.padding = 9 # pad the domain if input is non-periodic

        self.p = nn.Linear(4, self.width) # input channel is 3: (a(x, y), x, y)
        self.conv0 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2)
        self.conv1 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2)
        self.conv2 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2)
        self.conv3 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2)
        self.mlp0 = MLP(self.width, self.width, self.width)
        self.mlp1 = MLP(self.width, self.width, self.width)
        self.mlp2 = MLP(self.width, self.width, self.width)
        self.mlp3 = MLP(self.width, self.width, self.width)
        self.w0 = nn.Conv2d(self.width, self.width, 1)
        self.w1 = nn.Conv2d(self.width, self.width, 1)
        self.w2 = nn.Conv2d(self.width, self.width, 1)
        self.w3 = nn.Conv2d(self.width, self.width, 1)
        self.q = MLP(self.width, 2, self.width * 4) # output channel is 1: u(x, y)

    def forward(self, x):
        #grid = self.get_grid(x.shape, x.device)
        #x = torch.cat((x,V, grid), dim=-1)
       # print(x.shape)
        x = self.p(x)
       # print(x.shape)
        x = x.permute(0, 3, 1, 2)
       # x = F.pad(x, [0,self.padding, 0,self.padding])

        x1 = self.conv0(x)
        x1 = self.mlp0(x1)
        x2 = self.w0(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv1(x)
        x1 = self.mlp1(x1)
        x2 = self.w1(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv2(x)
        x1 = self.mlp2(x1)
        x2 = self.w2(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv3(x)
        x1 = self.mlp3(x1)
        x2 = self.w3(x)
        x = x1 + x2

       # x = x[..., :-self.padding, :-self.padding]
        x = self.q(x)
        x = x.permute(0, 2, 3, 1)
        #print(x.shape)
        return x

In [6]:
n_samp=1000
ntrain=900
ntest=100
timestep=990
batchsize=20
epochs=100
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [7]:
with open('data.pickle', 'rb') as handle:
    data_dict = pickle.load(handle)
#data_dict = torch.load('data.pt')

In [8]:
V=data_dict.get('V')
u0=data_dict.get('u0')
#D=data_dict.get('D')
#U=data_dict.get('U')
ux=data_dict.get('ux')
U_real=data_dict.get('U_real')
gridx=data_dict.get('gridx')
gridt=data_dict.get('gridt')

In [9]:
print(V.shape,u0.shape,ux.shape,U_real.shape,gridx.shape,gridt.shape)
#print(V.shape,u0.shape,D.shape,U.shape,ux.shape,U_real.shape,gridx.shape,gridt.shape)
dataset=DataLoader2D(u0,U_real,V,timestep)

torch.Size([1000, 300, 300]) torch.Size([1000, 300]) torch.Size([1000, 300]) torch.Size([1000, 300, 300, 2]) (300,) (300,)


In [10]:
train_loader = dataset.make_loader(gridx,gridt,ntrain, batchsize, start=0, train=True)
test_loader = dataset.make_loader(gridx,gridt,ntest, batchsize, start=ntrain, train=False)

torch.Size([900, 300, 300]) torch.Size([900, 300, 300])
torch.Size([900, 300, 300]) torch.Size([900, 300, 300]) torch.Size([900, 300, 300])
torch.Size([900, 300, 300, 4]) torch.Size([900, 300, 300, 2]) torch.Size([900, 300, 300])
torch.Size([100, 300, 300]) torch.Size([100, 300, 300])
torch.Size([100, 300, 300]) torch.Size([100, 300, 300]) torch.Size([100, 300, 300])
torch.Size([100, 300, 300, 4]) torch.Size([100, 300, 300, 2]) torch.Size([100, 300, 300])


In [11]:
model = FNO2d(modes1=20,
                  modes2=20,width=32)
model=model.to(device)

In [12]:
def train_model(model,train_loader,epochs,device,batch_size):
    
    
    
    myloss = LpLoss(size_average=True)
    learning_rate = 0.01
    scheduler_step = 30
    scheduler_gamma = 0.9
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler =  torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma)
    train_loss=[]

    
    for e in range(epochs):
        train_mse = 0
        train_l2 = 0
        model.train()        
        for x,v,y in train_loader:
            #print(x.shape)

            x = x.to(device).float()
            v=v.to(device)
            y=y.to(device)
            
            #x.requires_grad=True
            optimizer.zero_grad()
            
            out = model(x).view(y.shape[0],300,300,2)
            #data_loss = LpLoss(out, y)
       
            #out = model(x).view(batch_size, S, S, T)

            mse = F.mse_loss(out, y, reduction='mean')
        
            l2 = myloss(out, y)
            
            l2.backward()

            optimizer.step()
            scheduler.step()
       
            train_mse += mse.item()
            train_l2 += l2.item()
        
        train_mse/= ntrain
        train_l2 /= ntrain    
        
        
        if (e%10==0):
            print(e,train_l2,train_mse)
        
        train_loss.append(train_l2)    
          
   
    return train_loss
  

In [None]:
l1=train_model(model,train_loader,epochs,device,batchsize)

In [None]:
import numpy as np
model.eval()
test_l2 = 0.0
test_mse=0.0
V_out=[]
pred_out=[]
true_out=[]
myloss = LpLoss(size_average=True)

with torch.no_grad():
    for x,v, y in test_loader:
            x = x.to(device).float()
            
            y=y.to(device)
            
            out = model(x).view(y.shape[0],300,300,2)
            #data_loss = LpLoss(out, y)
       
            #out = model(x).view(batch_size, S, S, T)

            mse = F.mse_loss(out, y, reduction='mean')
        
            l2 = myloss(out, y)
            
            #l2.backward()
            
            pred_out.append(out)
            V_out.append(v)
            true_out.append(y)
            

       
            test_mse += mse.item()
            test_l2 += l2.item()
    
    print(test_l2,test_mse)

In [None]:
print(len(pred_out),len(V_out))

In [None]:
#print(pred_out[0][0])


In [None]:
#print(pred_out[0][0].shape,V_out[0][0].shape,density_out[0].shape)
#p1=pred_out[0][2]
x=torch.view_as_complex(pred_out[1][1].contiguous())
x_true=torch.view_as_complex(true_out[1][1].contiguous())
dens=(x.abs()) ** 2
dens_true=(x_true.abs()) ** 2
#print(type(x),x.shape,dens.shape)
v1=V_out[0][0]

In [None]:
x_min, x_max = -90*np.pi/180, 90*np.pi/180
x_size, y_size = 300, 300
x_grid = np.linspace(x_min, x_max, x_size)



In [None]:
%matplotlib widget
#%matplotlib notebook
t=comp_dyn(x_min,x_max,0,100,x_grid,dens.cpu(),dens_true.cpu(),v1)
#t.event_source.stop()


In [21]:
ax=dens.cpu()
print(ax[3])

tensor([4.9262e-04, 1.7045e-04, 9.5592e-05, 1.6971e-04, 2.9066e-04, 3.9186e-04,
        4.4398e-04, 4.5064e-04, 4.1558e-04, 3.9547e-04, 3.6431e-04, 3.3881e-04,
        3.3323e-04, 3.2329e-04, 3.2746e-04, 3.2245e-04, 3.3011e-04, 3.2590e-04,
        3.4214e-04, 3.5962e-04, 3.9443e-04, 4.2680e-04, 5.0329e-04, 6.1999e-04,
        7.9587e-04, 1.0441e-03, 1.3796e-03, 1.7601e-03, 2.1084e-03, 2.3696e-03,
        2.4555e-03, 2.3029e-03, 1.9174e-03, 1.4616e-03, 9.8224e-04, 5.8332e-04,
        3.5149e-04, 2.6451e-04, 2.8929e-04, 3.7436e-04, 5.0656e-04, 6.8434e-04,
        9.0760e-04, 1.2114e-03, 1.5449e-03, 1.8229e-03, 1.9682e-03, 1.9159e-03,
        1.6844e-03, 1.3566e-03, 9.9369e-04, 6.9544e-04, 5.0036e-04, 3.8112e-04,
        3.0310e-04, 2.1815e-04, 1.3367e-04, 5.4436e-05, 1.6126e-05, 1.4316e-05,
        2.5201e-05, 6.0932e-05, 2.0214e-04, 6.5902e-04, 1.7866e-03, 4.1033e-03,
        8.3917e-03, 1.5608e-02, 2.7214e-02, 4.5159e-02, 7.2209e-02, 1.1400e-01,
        1.8141e-01, 3.0013e-01, 5.1494e-