In [1]:
#Importing necessary packages
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pickle
import pandas as pd

In [2]:
!pip install neptune-client
!pip install neptune-notebooks

Collecting neptune-client
  Obtaining dependency information for neptune-client from https://files.pythonhosted.org/packages/c2/7f/1b7d7c0faffc6ed69d4982662f415a70b4bac9ea4506296c27cf42142f27/neptune_client-1.8.6-py3-none-any.whl.metadata
  Downloading neptune_client-1.8.6-py3-none-any.whl.metadata (17 kB)
Collecting bravado<12.0.0,>=11.0.0 (from neptune-client)
  Downloading bravado-11.0.3-py2.py3-none-any.whl (38 kB)
Collecting swagger-spec-validator>=2.7.4 (from neptune-client)
  Downloading swagger_spec_validator-3.0.3-py2.py3-none-any.whl (27 kB)
Collecting botocore<1.30.0,>=1.29.100 (from boto3>=1.16.0->neptune-client)
  Obtaining dependency information for botocore<1.30.0,>=1.29.100 from https://files.pythonhosted.org/packages/46/20/e7a9a8e6746872afcc4e3ad5ab503702c38813b3a532df27cce95c98b8cb/botocore-1.29.165-py3-none-any.whl.metadata
  Downloading botocore-1.29.165-py3-none-any.whl.metadata (5.9 kB)
Collecting bravado-core>=5.16.1 (from bravado<12.0.0,>=11.0.0->neptu

In [3]:
import neptune.new as neptune
project = neptune.init_project(project="geometricintegrationntnu/Transport-Equation", api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI2YzM5YjI2My1kYTI2LTRhNmMtOWI5Ni1lYzlmYzBiZWZiNzIifQ==")

  from neptune.version import version as neptune_client_version
  import neptune.new as neptune


https://app.neptune.ai/geometricintegrationntnu/Transport-Equation/


In [4]:
torch.manual_seed(7)
np.random.seed(7)

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [6]:
import os

def download_data(pde_name):

    #Downloading and saving the dataset into the data directory
    if not os.path.exists('data'):
        
        try:
            import zenodo_get
        except:
            input('To download the data the package ''zenodo_get'' needs to be imported.\n Press enter to agree on dowloading it.') #if you press enter you go on
            os.system('pip install zenodo_get')
        
        working_directory = os.getcwd()
        os.mkdir('data')
        os.chdir('data')
        os.system(f'zenodo_get 7665159')
        os.chdir(working_directory)
    else:
        working_directory = os.getcwd()
        os.chdir('data')
        if not os.path.exists(f'data_{pde_name}.pickle') or not os.path.exists(f'data_{pde_name}_verification.pickle'):
            try:
                import zenodo_get
            except:
                input('To download the data the package ''zenodo_get'' needs to be imported.\n Press enter to agree on dowloading it.') #if you press enter you go on
                os.system('pip install zenodo_get')
            
            os.system(f'zenodo_get 7665159')
            
        os.chdir(working_directory)

In [7]:
class dataset(Dataset):
  def __init__(self,x,y,device):
    self.x = torch.from_numpy(x.astype(np.float64)).to(device)
    self.y = torch.from_numpy(y.astype(np.float64)).to(device)
    self.length = self.x.shape[0]
 
  def __getitem__(self,idx):
    return self.x[idx],self.y[idx]
  def __len__(self):
    return self.length

#Method splitting the dataset into train and test sets
def get_train_test_split(pde_name='linadv',timesteps=5,device=device):
    
    #Load the data points
    if pde_name=="linadv":
        #with open(f'data/data_{pde_name}_dt001.pickle','rb') as file:
        with open(f'data/data_{pde_name}.pickle','rb') as file:
          data_train = pickle.load(file)
        #with open(f'data/data_{pde_name}_dt001_validation.pickle','rb') as file:
        with open(f'data/data_{pde_name}_verification.pickle','rb') as file:
          data_test = pickle.load(file)
    else:
        with open(f'data/data_{pde_name}.pickle','rb') as file:
              data_train = pickle.load(file)
        with open(f'data/data_{pde_name}_verification.pickle','rb') as file:
              data_test = pickle.load(file)
    
    if pde_name=="fisher":
      new_data_train = []
      new_data_test = []
      for i in range(len(data_train)):
        if np.linalg.norm(data_train[i][0].reshape(-1),ord=2).item()>10:
          new_data_train.append(data_train[i])
      
      for i in range(len(data_test)):
        if np.linalg.norm(data_test[i][0].reshape(-1),ord=2).item()>10:
          new_data_test.append(data_test[i])
      
      data_train = new_data_train
      data_test = new_data_test
    
    
    #Split the loaded data into training and testing sets
    dim1,dim2 = data_train[1][0].shape
    timesteps_test = len(data_test[1])-1 #we subtract one because we also have the IC
    
    n_train = len(data_train)
    n_test = len(data_test)
    input_train = np.zeros((n_train,1,dim1,dim2))
    label_train = np.zeros((n_train,timesteps,dim1,dim2))
    input_test = np.zeros((n_test,1,dim1,dim2))
    label_test = np.zeros((n_test,timesteps_test,dim1,dim2))
    
    #Store the initial condition into the input variables
    #and the remaining updates into the label variables.
    for i in range(n_train):
        input_train[i,0] = data_train[i][0]
        for j in range(timesteps):
            label_train[i,j] = data_train[i][j+1]
    for i in range(n_test):
        input_test[i,0] = data_test[i][0]
        for j in range(timesteps_test):
            label_test[i,j] = data_test[i][j+1]
    
    #Create the dataloaders given the obtained splitting
    trainset = dataset(input_train,label_train,device)
    testset = dataset(input_test,label_test,device)
    
    return trainset, testset

In [8]:
def getLambda(y,norm_0,sum_0):
    a,b,c = y.shape
    vec_y = y.reshape(a,b*c)
    ones = torch.ones_like(vec_y,dtype=torch.float64)
    dg = torch.cat((ones,vec_y),dim=1)
    norm_now = torch.linalg.norm(vec_y,ord=2,dim=1,keepdim=True)
    sum_now = vec_y.sum(dim=1,keepdim=True)
    num_elements = b*c * torch.ones_like(norm_now,dtype=torch.float64)
    
    scaling = 1/(4*num_elements*norm_now**2-4*sum_now**2)
    row1 = torch.cat((4*norm_now**2,-2*sum_now),dim=1).unsqueeze(1)
    row2 = torch.cat((-2*sum_now,num_elements),dim=1).unsqueeze(1)
    mat = torch.cat((row1,row2),dim=1)
    
    g = torch.zeros((len(sum_now),2),dtype=torch.float64,device=device)
    g[:,0:1] = sum_now-sum_0
    g[:,1:2] = norm_now**2-norm_0**2
    return -scaling*torch.einsum('ijk,ik->ij',mat,g)

In [9]:
class relu_squared(nn.Module):
    def __init__(self,):
        super().__init__()
    def forward(self,x):
        return torch.relu(x)**2
    
class network(nn.Module):
    def __init__(self,n_layers=3,kernel_size=3,bias=False,preserve_norm=True,is_linear=False):
        super().__init__()

        self.dt = 0.01 #Temporal step
        self.dx = 1/99.
        
        self.s = n_layers
        self.preserve_norm = preserve_norm
        
        pad = 1 if kernel_size==3 else 2
        
        self.lift = nn.Conv2d(1,2,kernel_size,padding=pad,padding_mode='circular',bias=bias,dtype=torch.float64)
        self.proj = nn.Conv2d(2,1,1,padding=0,bias=False,dtype=torch.float64)
        
        self.linear = nn.Conv2d(1,1,kernel_size,padding=pad,padding_mode='circular',bias=bias,dtype=torch.float64)
        
        self.is_linear = is_linear
        
        #self.cc = nn.Conv2d(1, 1, 5,  stride=1, padding=0, bias=False,dtype=torch.float64)
        #self.ccT = nn.ConvTranspose2d(1, 1, 5, stride=1, padding=4, bias=False,dtype=torch.float64)
    
    def F(self,U):
        
        if self.is_linear:
            return self.linear(U)
        else:
            U = torch.relu(self.lift(U))
            return self.proj(U)
        
        '''self.cc.weight.data -=  torch.mean(self.cc.weight.data.view(-1))
        self.ccT.weight.data = self.cc.weight.data
        p1d = (2,2,2,2)
        pd = lambda x : F.pad(x, p1d, "circular")
        return self.cc(pd(U))#-self.ccT(pd(U))'''
    
    def timestep(self,dt,yhat,norm_0,sum_0,k_max=2):
        yhat = yhat + dt*self.F(yhat)
        ones = torch.ones_like(yhat)
        for k in range(k_max):
            multiplier = getLambda(yhat[:,0],norm_0,sum_0)
            yhat = yhat + multiplier[:,0].view(-1,1,1,1) * ones + 2 * multiplier[:,1].view(-1,1,1,1) * yhat
        return yhat
        
        
    def forward(self,U,norm=None):
        
        if self.preserve_norm:
            if norm==None:
                no = torch.linalg.norm(U.view(len(U),-1),dim=1,ord=2).reshape(-1,1,1,1)
            else:
                no = norm.reshape(-1,1,1,1) #This occurs when we are in the training loop and we inject noise
        
        
        
        #a,_,b,c = U.shape
        #norm_0 = torch.linalg.norm(U.reshape(-1,b*c),ord=2,dim=1,keepdim=True)
        #sum_0 = U.reshape(-1,b*c).sum(dim=1,keepdim=True)
        
        for i in range(self.s):

            #U = self.timestep(self.dt/self.s,U,norm_0,sum_0)
            
            U = U + self.dt/self.s * self.F((U))
            if self.preserve_norm:
                no_current = torch.linalg.norm(U.view(len(U),-1),ord=2,dim=1).reshape(-1,1,1,1)
                U = U / no_current * no

        return U

In [10]:
def train(run,model,lr,weight_decay,epochs,trainloader,timesteps=3,gamma=1e-4,is_cyclic=True,is_noise=True):
    
    criterion = nn.MSELoss()
    
    
    for max_t in np.arange(2,timesteps):
        
        print(f"Training with {max_t} timesteps")

        optimizer = torch.optim.Adam(model.parameters(),lr=lr)
        if is_cyclic:
            scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=2000, mode='exp_range',cycle_momentum=False)
        else:
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=int(0.45*epochs))
        epoch = 1
    
        while epoch < epochs:
            
            for i, inp in enumerate(trainloader):
                inputs, labels = inp
                inputs, labels = inputs.to(device), labels.to(device)
                
                optimizer.zero_grad()
                loss = 0.
                
                epsilon = 0.01
                if is_noise:
                    noise = (torch.rand_like(inputs)*2*epsilon-epsilon)
                else:
                    noise = 0.
                res = inputs.clone()
                
                no_initial = torch.linalg.norm(res.reshape(len(res),-1),dim=1,ord=2)
                
                for tt in range(max_t):
                    if tt==0:
                        res = model(res + noise,no_initial)
                    else:
                        res = model(res)
                    loss += criterion(res,labels[:,tt:tt+1]) / max_t
                    
                    if gamma>0:
                        no_current = torch.linalg.norm(res.reshape(len(res),-1),dim=1,ord=2)
                        loss += gamma * criterion(no_current,no_initial) / max_t
                loss.backward()
                
                optimizer.step()
                if is_cyclic:
                    scheduler.step()
            
            if is_cyclic==False:
                scheduler.step()
            
            run[f"train/loss_t_max={max_t}"].log(loss.item())
            epoch += 1
            
            if epoch%99==0:
                print(f'Loss [{epoch}](epoch): ', loss.item())
        
        lr = lr/2
        
    print('Training Done')

In [11]:
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.animation as animation
from mpl_toolkits.axes_grid1 import make_axes_locatable
import torch
import numpy as np
import pandas as pd
import warnings
import os

def generate_gif_predicted(run,pde_name,model,X,timesteps_test):
    
    if not os.path.exists('saved_plots'):
        os.mkdir('saved_plots')
    
    plt.rcParams["figure.autolayout"] = True
    fig = plt.figure(figsize=(10,10))

    ax = fig.add_subplot(111)
    div = make_axes_locatable(ax)
    cax = div.append_axes('right', '5%', '5%')

    def animate(i):
        cax.cla()
        res = X.clone()
        res = (res.unsqueeze(0))
        for j in range(i):
            res = model(res)
        res = res[0].detach().cpu().numpy()
        im = ax.imshow(res[-1], cmap = 'hot')
        fig.colorbar(im, cax=cax)
        ax.set_title('Prediction, Frame {0}'.format(i))

    ani = animation.FuncAnimation(fig, animate, frames=timesteps_test)
    ani.save(f"saved_plots/predicted_{pde_name}.gif", writer='pillow')
      
    run["predicted_dynamics"].upload(f"saved_plots/predicted_{pde_name}.gif")

#Show the true time evolution of the difference between
#the true and predicted time evolutions for timesteps_test steps
#of size dt, startinf from the initial condition X

def generate_gif_true(run,pde_name,X,Y,timesteps_test):
    
    if not os.path.exists('saved_plots'):
        os.mkdir('saved_plots')
    
    plt.rcParams["figure.autolayout"] = True

    Traj = torch.cat((X,Y),dim=0)

    fig = plt.figure(figsize=(10,10))

    ax = fig.add_subplot(111)
    div = make_axes_locatable(ax)
    cax = div.append_axes('right', '5%', '5%')

    def animate(i):
        cax.cla()
        im = ax.imshow((Traj[i]).detach().cpu().numpy(), cmap = 'hot')
        fig.colorbar(im,cax=cax)
        ax.set_title('True dynamics, Frame {0}'.format(i))
    ani = animation.FuncAnimation(fig, animate, frames=timesteps_test)
    ani.save(f"saved_plots/true_{pde_name}.gif", writer='pillow')
    
    run["true_dynamics"].upload(f"saved_plots/true_{pde_name}.gif")

#Show the time evolution of the initial condition X
#for timesteps_test steps of size dt
def generate_gif_error(run,pde_name,model,X,Y,timesteps_test):
    
    if not os.path.exists('saved_plots'):
        os.mkdir('saved_plots')
    
    dim = X.shape[0]
    
    plt.rcParams["figure.figsize"] = [10,10]
    plt.rcParams["figure.autolayout"] = True
    
    fig = plt.figure()
    Traj = torch.cat((X,Y),dim=0)
    ax = fig.add_subplot(111)
    div = make_axes_locatable(ax)
    cax = div.append_axes('right', '5%', '5%')

    def animate(i):
        cax.cla()
        res = X.clone()
        res = (res.unsqueeze(0))
        for j in range(i):
            res = model(res)
        res = res[0,0]
        im = ax.imshow((res-Traj[i]).detach().cpu().numpy(), cmap='hot')
        fig.colorbar(im, cax=cax)
        ax.set_title('Difference of matrices, Frame {0}'.format(i))

    ani = animation.FuncAnimation(fig, animate, frames=timesteps_test)
    ani.save(f"saved_plots/error_{pde_name}.gif", writer='pillow')
    
    run["error_dynamics"].upload(f"saved_plots/error_{pde_name}.gif")

#Generate .csv files where the three test metrics are stored
#We measure the mean squared error, the  relative error and
#the maximum absolute error for 30 test initial conditions and
#store in these files the mean of such values
def save_test_results(run,pde_name,model,testloader,preserve_norm=None):
    
    if not os.path.exists('saved_test_results'):
        os.mkdir('saved_test_results')
    
    if pde_name!='linadv':
        if preserve_norm!=None:
            warnings.warn('Projected Euler has been implemented only for linadv problem')
            preserve_norm = None
    
    model.to('cpu');
    X,Y = next(iter(testloader))
    X,Y = (X.to('cpu')), (Y.to('cpu'))
    Traj = torch.cat((X,Y),dim=1)
    res = X
    
    #Initialize the lists where we store the values
    mseList = []
    max_errorList = []
    relative_error_list = []
    
    mseList.append(torch.mean((res-X)**2).item())
    max_errorList.append(torch.mean(torch.max((torch.abs(res-X)).reshape(len(X),-1),dim=1)[0], dim=0).item())
    relative_error_list.append(torch.mean((torch.linalg.norm((res-X).view(len(X),-1),dim=1,ord=2)  / torch.linalg.norm((X).view(len(X),-1),dim=1,ord=2))).item())
    
    res = model(res)
    
    #Compute the quantities for the successive iterations
    for j in np.arange(1,40):
        mseList.append(torch.mean((res-Traj[:,j:j+1])**2).item())
        max_errorList.append(torch.mean(torch.max((torch.abs(res-Traj[:,j:j+1])).reshape(len(X),-1),dim=1)[0], dim=0).item())
        relative_error_list.append(torch.mean((torch.linalg.norm((res-Traj[:,j:j+1]).view(len(X),-1),dim=1,ord=2)  / torch.linalg.norm((Traj[:,j:j+1]).view(len(X),-1),dim=1,ord=2))).item())
        res = model(res)
    
    #Save the results
    if pde_name == 'linadv':
        if preserve_norm==True:
            np.savetxt(f"saved_test_results/{pde_name}_AverageMSE_Conserved_Test30.csv", mseList, delimiter=",")
            np.savetxt(f"saved_test_results/{pde_name}_MaxError_Conserved_Test30.csv", max_errorList, delimiter=",")
            np.savetxt(f"saved_test_results/{pde_name}_RelativeError_Conserved_Test30.csv", relative_error_list, delimiter=",")
            
            run["AverageMSE"].upload(f"saved_test_results/{pde_name}_AverageMSE_Conserved_Test30.csv")
            run["MaxError"].upload(f"saved_test_results/{pde_name}_MaxError_Conserved_Test30.csv")
            run["RelativeError"].upload(f"saved_test_results/{pde_name}_RelativeError_Conserved_Test30.csv")
            
        else:
            np.savetxt(f"saved_test_results/{pde_name}_AverageMSE_NonConserved_Test30.csv", mseList, delimiter=",")
            np.savetxt(f"saved_test_results/{pde_name}_MaxError_NonConserved_Test30.csv", max_errorList, delimiter=",")
            np.savetxt(f"saved_test_results/{pde_name}_RelativeError_NonConserved_Test30.csv", relative_error_list, delimiter=",")
            
            run["AverageMSE"].upload(f"saved_test_results/{pde_name}_AverageMSE_NonConserved_Test30.csv")
            run["MaxError"].upload(f"saved_test_results/{pde_name}_MaxError_NonConserved_Test30.csv")
            run["RelativeError"].upload(f"saved_test_results/{pde_name}_RelativeError_NonConserved_Test30.csv")
    else:
        np.savetxt(f"saved_test_results/{pde_name}_AverageMSE_Test30.csv", mseList, delimiter=",")
        np.savetxt(f"saved_test_results/{pde_name}_MaxError_Test30.csv", max_errorList, delimiter=",")
        np.savetxt(f"saved_test_results/{pde_name}_RelativeError_Test30.csv", relative_error_list, delimiter=",")
        
        run["AverageMSE"].upload(f"saved_test_results/{pde_name}_AverageMSE_Test30.csv")
        run["MaxError"].upload(f"saved_test_results/{pde_name}_MaxError_Test30.csv")
        run["RelativeError"].upload(f"saved_test_results/{pde_name}_RelativeError_Test30.csv")
        
    print("Results saved or updated in the directory 'saved_test_results'")
    
    
    
def generate_error_plots(run,pde_name,model,testloader,preserve_norm=None):
    
    #Generate the results
    save_test_results(run,pde_name,model,testloader,preserve_norm)
    
    #List of plots we show
    namePlots = ["MaxError", "AverageMSE", "RelativeError"]
    #labels = [r"$\texttt{maxE}(j)$", r"$\texttt{mse}(j)$", r"$\texttt{rE}(j)$"]
    labels = [r"maxE(j)", r"mse(j)", r"rE(j)"]
    
    if not os.path.exists('saved_plots'):
        os.mkdir('saved_plots')
    
    #Generation of the plots for the 3 different PDEs
    if pde_name=="linadv":
        for it, name in enumerate(namePlots):
            is_data_constr = False
            is_data_unconstr = False
            
            try:
                dfConstr = pd.read_csv(f'saved_test_results/{pde_name}_{name}_Conserved_Test30.csv',header=None)
                is_data_constr = True
            except:
                print(f"Missing data for the {name} of the conserved case")
            try:
                dfUnconstr = pd.read_csv(f'saved_test_results/{pde_name}_{name}_NonConserved_Test30.csv',header=None)
                is_data_unconstr = True
            except:
                print(f"Missing data for the {name} of the non conserved case")
                
            xx = np.arange(0,40)
            if is_data_unconstr or is_data_constr:
                fig = plt.figure(figsize=[10,10],dpi=300)
            if is_data_constr:
                plt.plot(xx,dfConstr.iloc[:len(xx),-1],'r-o',label="Constrained")
            if is_data_unconstr:
                plt.plot(xx,dfUnconstr.iloc[:len(xx),-1],'b-o',label="Unconstrained")
            
            if is_data_constr==False and is_data_unconstr==False:
                print(f"There is nothing saved to generate the plots of {name}")
            else:
                plt.yticks(fontsize=45)
                plt.xticks(fontsize=45)
                plt.legend(fontsize=45)
                plt.xlabel("Number of iterations",fontsize=45)
                plt.ylabel(f"{labels[it]}",fontsize=45)
                if is_data_unconstr:
                    ymax = np.max(dfUnconstr.iloc[:,-1])
                else:
                    ymax = np.max(dfConstr.iloc[:,-1])
                ymin = 0
                ylist = np.linspace(ymin,ymax,4)
                plt.yticks(ylist)
                plt.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
                plt.savefig(f"saved_plots/{pde_name}_{name}_BatchOf30.pdf", format="pdf",bbox_inches='tight')
                
                
                run[f"{name}"].upload(f"saved_plots/{pde_name}_{name}_BatchOf30.pdf")
                
    else:
        for it, name in enumerate(namePlots):
            is_data = False
            
            try:
                df = pd.read_csv(f'saved_test_results/{pde_name}_{name}_Test30.csv',header=None)
                is_data = True
            except:
                print(f"Missing data for the {name} of {pde_name}")
                
            xx = np.arange(0,40)
            if is_data:
                fig = plt.figure(figsize=[10,10],dpi=300)
                plt.plot(xx,df.iloc[:len(xx),-1],'r-o')
                plt.yticks(fontsize=45)
                plt.xticks(fontsize=45)
                plt.xlabel("Number of iterations",fontsize=45)
                plt.ylabel(f"{labels[it]}",fontsize=45)
                ymax = np.max(df.iloc[:,-1])
                ymin = 0
                ylist = np.linspace(ymin,ymax,4)
                plt.yticks(ylist)
                plt.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
                plt.savefig(f"saved_plots/{pde_name}_{name}_BatchOf30.pdf", format="pdf",bbox_inches='tight')
            else:
                print(f"There is nothing saved to generate the plots of {name} in {pde_name}")
    print("Plots saved or updated in the directory 'saved plots'")

In [12]:
pde_name = 'linadv'
timesteps = 5

download_data(pde_name)

trainset, testset = get_train_test_split(pde_name,timesteps=timesteps,device=device)

#Create the model

n_layers = 3
bias = False
dim = 100
lr = 1e-3
epochs = 300
batch_size = 32


kernel_size_list = [3,5]
preserve_norm_list = [True,False]
weight_decay_list = [0] #Maybe also try 1e-5
is_linear_list = [True,False]
is_cyclic_list = [True,False]
is_noise_list = [True,False]

for kernel_size in kernel_size_list:
    for weight_decay in weight_decay_list:
        for is_linear in is_linear_list:
            for is_cyclic in is_cyclic_list:
                for preserve_norm in preserve_norm_list:
                    for is_noise in is_noise_list:

                        gamma_reg = 0

                        config = {
                          "dim":dim, 
                          "timesteps": timesteps,
                          "learning_rate":lr, 
                          "preserve_norm":preserve_norm, 
                          "epochs":epochs, 
                          "batch_size":batch_size, 
                          "weight_decay":weight_decay,
                          "optimizer":"adam",
                          "n_layers":n_layers,
                          "bias":bias,
                          "kernel_size":kernel_size,
                          "gamma_reg":gamma_reg,
                          "is_linear":is_linear,
                          "is_cyclic_scheduler":is_cyclic,
                          "is_added_noise":is_noise}

                        print("Current test with : ",pd.DataFrame.from_dict(config,orient='index',columns=["Value"]))

                        run = neptune.init_run(
                            project="geometricintegrationntnu/Transport-Equation",
                            api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI2YzM5YjI2My1kYTI2LTRhNmMtOWI5Ni1lYzlmYzBiZWZiNzIifQ=="
                        )

                        run["parameters"] = config


                        model = network(n_layers,kernel_size,bias,preserve_norm,is_linear)
                        model.to(device);

                        #Create the dataloaders splitting the dataset into batches
                        trainloader = torch.utils.data.DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=0)
                        testloader = torch.utils.data.DataLoader(testset,batch_size=30,shuffle=True,num_workers=0)

                        timeMax = timesteps
                        loss = train(run,model,lr,weight_decay,epochs,trainloader,timeMax,gamma_reg,is_cyclic,is_noise)

                        model.to('cpu')
                        model.eval();

                        X,Y = next(iter(testloader))

                        X = X[0].to('cpu')
                        Y = Y[0].to('cpu')

                        timesteps_test = len(Y)

                        #Generation of the plots reported also in the paper
                        generate_gif_predicted(run,pde_name,model,X,timesteps_test)
                        generate_gif_true(run,pde_name,X,Y,timesteps_test)
                        generate_gif_error(run,pde_name,model,X,Y,timesteps_test)
                        if pde_name=='linadv':
                            generate_error_plots(run,pde_name,model,testloader,preserve_norm)
                        else:
                            generate_error_plots(run,pde_name,model,testloader)

                        print("Lifting weight : ",model.lift.weight.data)
                        print("Projection weight : ",model.proj.weight.data)

                        torch.save(model.state_dict(), "trained_model.pt")

                        run["trained_model"].upload("trained_model.pt")
                        run.stop()

                        try:
                            os.remove("saved_test_results/linadv_RelativeError_Conserved_Test30.csv")
                            os.remove("saved_test_results/linadv_MaxError_Conserved_Test30.csv")
                            os.remove("saved_test_results/linadv_AverageMSE_Conserved_Test30.csv")
                            os.remove("saved_test_results/linadv_RelativeError_NonConserved_Test30.csv")
                            os.remove("saved_test_results/linadv_MaxError_NonConserved_Test30.csv")
                            os.remove("saved_test_results/linadv_AverageMSE_NonConserved_Test30.csv")
                        except:
                            pass

StdinNotImplementedError: raw_input was called, but this frontend does not support input requests.