In [1]:
## 1D General elliptic PDE of the following form: 
## -div( a(x) grad u(x)) + b(x) grad u(x) + c(x) u(x) = f(x) in [0,1] 
## a(x), b(x), c(x) are set to be constant functions 
## du_dn = g on the boundary 
## this version also contains using the tanh-activated shallow neural network to solve the PDE 
"""
log
Nov 17th 2024 Modified by Xiaofeng: 
added three functions   
1. select_discrete_dictionary
2. compute_l2_error 
3. compute_gradient_error

Nov 20th 2024 Modified by Xiaofeng 
1. use an efficient way to assemble the matrix that reuses previous matrices 
    - minimize_linear_layer_efficient

Todo: 
1. remove some redundant variable to save memory 
2. test a huge dictionary and huge quadrature points  
"""
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import time
import sys
import os 
from scipy.sparse import linalg
from pathlib import Path
if torch.cuda.is_available():  
    device = "cuda" 
else:  
    device = "cpu" 

torch.set_default_dtype(torch.float64)
pi = torch.tensor(np.pi)
ZERO = torch.tensor([0.]).to(device)

###===============model parameters below================================
LAMBDA = -4 # c(x) = LAMBDA, if negative Helmholtz equation parameters
BETA = 5 ## convection term parameters 
DIMENSION = 3  ## dimension of the problem 
###===============model parameters above================================

## Define the neural network model
## already general in any dimension
class model_tanh(nn.Module):
    """ cosine shallow neural network
    Parameters: 
    input size: input dimension
    hidden_size1 : number of hidden layers 
    num_classes: output classes 
    """
    def __init__(self, input_size, hidden_size1, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, num_classes,bias = False)
    def forward(self, x):
        u1 = self.fc2( torch.tanh(self.fc1(x)) )
        return u1
    
    def tanh_activation_dx(self,x): 
        return 1/torch.cosh(x)**2  
      
    def evaluate_derivative(self, x, i):
        u1 = self.fc2( self.tanh_activation_dx(self.fc1(x)) *self.fc1.weight.t()[i-1:i,:] )  
        return u1

class model(nn.Module):
    """ ReLU k shallow neural network
    Parameters: 
    input size: input dimension
    hidden_size1 : number of hidden layers 
    num_classes: output classes 
    k: degree of relu functions
    """
    def __init__(self, input_size, hidden_size1, num_classes,k = 1):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, num_classes,bias = False)
        self.k = k 
    def forward(self, x):
        u1 = self.fc2(F.relu(self.fc1(x))**self.k)
        return u1
    def evaluate_derivative(self, x, i):
        if self.k == 1:
            ## ZERO = torch.tensor([0.]).to(device)
            u1 = self.fc2(torch.heaviside(self.fc1(x),ZERO) * self.fc1.weight.t()[i-1:i,:] )
        else:
            u1 = self.fc2(self.k*F.relu(self.fc1(x))**(self.k-1) *self.fc1.weight.t()[i-1:i,:] )  
        return u1


In [2]:
def plot_solution_modified(r1,r2,model,x_test,u_true,name=None): 
    # Plot function: test results 
    u_model_cpu = model(x_test).cpu().detach()
    
    w = model.fc1.weight.data.squeeze()
    b = model.fc1.bias.data.squeeze()
    x_model_pt = (-b/w).view(-1,1)
    x_model_pt = x_model_pt[x_model_pt>=r1].reshape(-1,1)
    u_model_pt = model(x_model_pt).cpu().detach()
    plt.figure(dpi = 100)
    plt.plot(x_test.cpu(),u_model_cpu,'-.',label = "nn function")
    plt.plot(x_test.cpu(),u_true.cpu(),label = "true")
    # plt.plot(x_model_pt.cpu(),u_model_pt.cpu(),'.r')
    if name!=None: 
        plt.title(name)
    plt.legend()
    plt.show()

In [3]:
def PiecewiseGQ1D_weights_points(x_l,x_r,Nx, order):
    """ Output the coeffients and weights for piecewise Gauss Quadrature 
    Parameters
    ----------
    x_l : float 
    left endpoint of an interval 
    x_r: float
    right endpoint of an interval 
    Nx: int 
    number of subintervals for integration
    order: int
    order of Gauss Quadrature 
    Returns
    -------
    vectorized quadrature weights and integration points
    """
    x,w = np.polynomial.legendre.leggauss(order)
    gx = torch.tensor(x).to(device)
    gx = gx.view(1,-1) # row vector 
    gw = torch.tensor(w).to(device)    
    gw = gw.view(-1,1) # Column vector 
    nodes = torch.linspace(x_l,x_r,Nx+1).view(-1,1).to(device) 
    coef1 = ((nodes[1:,:] - nodes[:-1,:])/2) # n by 1  
    coef2 = ((nodes[1:,:] + nodes[:-1,:])/2) # n by 1  
    coef2_expand = coef2.expand(-1,gx.size(1)) # Expand to n by p shape, -1: keep the first dimension n , expand the 2nd dim (columns)
    integration_points = coef1@gx + coef2_expand
    integration_points = integration_points.flatten().view(-1,1) # Make it a column vector
    gw_expand = torch.tile(gw,(Nx,1)) # rows: n copies of current tensor, columns: 1 copy, no change
    # Modify coef1 to be compatible with func_values
    coef1_expand = coef1.expand(coef1.size(0),gx.size(1))    
    coef1_expand = coef1_expand.flatten().view(-1,1)
    return coef1_expand.to(device) * gw_expand.to(device), integration_points.to(device)

def PiecewiseGQ2D_weights_points(Nx, order): 
    """ A slight modification of PiecewiseGQ2D function that only needs the weights and integration points.
    Parameters
    ----------
    Nx: int 
        number of intervals along the dimension. No Ny, assume Nx = Ny
    order: int 
        order of the Gauss Quadrature
    Returns
    -------
    long_weights: torch.tensor
    integration_points: torch.tensor
    """
#     print("order: ",order )
    x, w = np.polynomial.legendre.leggauss(order)
    gauss_pts = np.array(np.meshgrid(x,x,indexing='ij')).reshape(2,-1).T
    weights =  (w*w[:,None]).ravel()

    gauss_pts =torch.tensor(gauss_pts)
    weights = torch.tensor(weights)

    h = 1/Nx # 100 intervals 
    long_weights =  torch.tile(weights,(Nx**2,1))
    long_weights = long_weights.reshape(-1,1)
    long_weights = long_weights * h**2 /4 

    integration_points = torch.tile(gauss_pts,(Nx**2,1))
    scale_factor = h/2 
    integration_points = scale_factor * integration_points

    index = np.arange(1,Nx+1)-0.5
    ordered_pairs = np.array(np.meshgrid(index,index,indexing='ij'))
    ordered_pairs = ordered_pairs.reshape(2,-1).T

    # print(ordered_pairs)
    # print()
    ordered_pairs = torch.tensor(ordered_pairs)
    # print(ordered_pairs.size())
    ordered_pairs = torch.tile(ordered_pairs, (1,order**2)) # number of GQ points
    # print(ordered_pairs)

    ordered_pairs =  ordered_pairs.reshape(-1,2)
    # print(ordered_pairs)
    translation = ordered_pairs*h 
    # print(translation)

    integration_points = integration_points + translation 
#     print(integration_points.size())
    # func_values = integrand2_torch(integration_points)
    return long_weights.to(device), integration_points.to(device)

def PiecewiseGQ3D_weights_points(Nx, order): 
    """ A slight modification of PiecewiseGQ2D function that only needs the weights and integration points.
    Parameters
    ----------

    Nx: int 
        number of intervals along the dimension. No Ny, assume Nx = Ny
    order: int 
        order of the Gauss Quadrature

    Returns
    -------
    long_weights: torch.tensor
    integration_points: torch.tensor
    """

    """
    Parameters
    ----------
    target : 
        Target function 
    Nx: int 
        number of intervals along the dimension. No Ny, assume Nx = Ny
    order: int 
        order of the Gauss Quadrature
    """

    # print("order: ",order )
    x, w = np.polynomial.legendre.leggauss(order)
    gauss_pts = np.array(np.meshgrid(x,x,x,indexing='ij')).reshape(3,-1).T
    weight_list = np.array(np.meshgrid(w,w,w,indexing='ij'))
    weights =   (weight_list[0]*weight_list[1]*weight_list[2]).ravel() 

    gauss_pts =torch.tensor(gauss_pts)
    weights = torch.tensor(weights)

    h = 1/Nx # 100 intervals 
    long_weights =  torch.tile(weights,(Nx**3,1))
    long_weights = long_weights.reshape(-1,1)
    long_weights = long_weights * h**3 /8 

    integration_points = torch.tile(gauss_pts,(Nx**3,1))
    # print("shape of integration_points", integration_points.size())
    scale_factor = h/2 
    integration_points = scale_factor * integration_points

    index = np.arange(1,Nx+1)-0.5
    ordered_pairs = np.array(np.meshgrid(index,index,index,indexing='ij'))
    ordered_pairs = ordered_pairs.reshape(3,-1).T

    # print(ordered_pairs)
    # print()
    ordered_pairs = torch.tensor(ordered_pairs)
    # print(ordered_pairs.size())
    ordered_pairs = torch.tile(ordered_pairs, (1,order**3)) # number of GQ points
    # print(ordered_pairs)

    ordered_pairs =  ordered_pairs.reshape(-1,3)
    # print(ordered_pairs)
    translation = ordered_pairs*h 
    # print(translation)

    integration_points = integration_points + translation 

    return long_weights.to(device), integration_points.to(device)

def MonteCarlo_Sobol_dDim_weights_points(M ,d = 4):
    Sob_integral = torch.quasirandom.SobolEngine(dimension =d, scramble= False, seed=None) 
    integration_points = Sob_integral.draw(M).double() 
    integration_points = integration_points.to(device)
    weights = torch.ones(M,1).to(device)/M 
    return weights, integration_points 

def Neumann_boundary_quadrature_points_weights(M,d):
    def generate_quadpts_on_boundary(gw_expand_bd, integration_points_bd,d):
        size_pts_bd = integration_points_bd.size(0) 
        gw_expand_bd_faces = torch.tile(gw_expand_bd,(2*d,1)) # 2d boundaries, 拉成长条

        integration_points_bd_faces = torch.zeros(2*d*integration_points_bd.size(0),d).to(device)
        for ind in range(d): 
            integration_points_bd_faces[2 *ind * size_pts_bd :(2 *ind +1) * size_pts_bd,ind:ind+1] = 0 
            integration_points_bd_faces[(2 *ind)*size_pts_bd :(2 * ind +1) * size_pts_bd,:ind] = integration_points_bd[:,:ind]
            integration_points_bd_faces[(2 *ind)*size_pts_bd :(2 * ind +1) * size_pts_bd,ind+1:] = integration_points_bd[:,ind:]

            integration_points_bd_faces[(2 *ind +1) * size_pts_bd:(2 *ind +2)*size_pts_bd,ind:ind+1] = 1
            integration_points_bd_faces[(2 *ind +1) * size_pts_bd:(2 *ind +2)*size_pts_bd,:ind] = integration_points_bd[:,:ind]        
            integration_points_bd_faces[(2 *ind +1) * size_pts_bd:(2 *ind +2)*size_pts_bd,ind+1:] = integration_points_bd[:,ind:]
        return gw_expand_bd_faces, integration_points_bd_faces
    
    if d == 1: 
        print('dim',d)
        gw_expand_bd_faces = torch.tensor([1.,1.]).view(-1,1).to(device)
        integration_points_bd_faces = torch.tensor([0.,1.]).view(-1,1).to(device) 
    elif d == 2: 
        print('dim',d)
        gw_expand_bd, integration_points_bd = PiecewiseGQ1D_weights_points(0,1,8192, order = 3) 
    elif d == 3: 
        gw_expand_bd, integration_points_bd = PiecewiseGQ2D_weights_points(100, order = 3) 
    elif d == 4: 
        gw_expand_bd, integration_points_bd = PiecewiseGQ3D_weights_points(25, order = 3) 
        print('dim',d)
    else: 
        gw_expand_bd, integration_points_bd = MonteCarlo_Sobol_dDim_weights_points(M ,d = d)
        print('dim >=5 ')
    gw_expand_bd_faces, integration_points_bd_faces = generate_quadpts_on_boundary(gw_expand_bd, integration_points_bd,d)
    return gw_expand_bd_faces.to(device), integration_points_bd_faces.to(device) 

def generate_relu_dict3D(N_list):
    N1 = N_list[0]
    N2 = N_list[1]
    N3 = N_list[2]
    
    N = N1*N2*N3 
    theta1 = np.linspace(0, pi, N1, endpoint= True).reshape(N1,1)
    theta2 = np.linspace(0, 2*pi, N2, endpoint= False).reshape(N2,1)
    b = np.linspace(-3**0.5, 3**0.5, N3,endpoint=False).reshape(N3,1) # threshold: 3**0.5  
    coord3 = np.array(np.meshgrid(theta1,theta2,b,indexing='ij'))
    coord3 = coord3.reshape(3,-1).T # N1*N2*N3 x 3. coordinates for the grid points 
    coord3 = torch.tensor(coord3) 

    f1 = torch.zeros(N,1) 
    f2 = torch.zeros(N,1)
    f3 = torch.zeros(N,1)
    f4 = torch.zeros(N,1)

    f1[:,0] = torch.cos(coord3[:,0]) 
    f2[:,0] = torch.sin(coord3[:,0]) * torch.cos(coord3[:,1])
    f3[:,0] = torch.sin(coord3[:,0]) * torch.sin(coord3[:,1])
    f4[:,0] = coord3[:,2] 

    Wb_tensor = torch.cat([f1,f2,f3,f4],1) # N x 4 
    return Wb_tensor

def generate_relu_dict3D_QMC(s,N0):
    # Monte Carlo 
    samples = torch.rand(s*N0,3) 
    T =torch.tensor([[pi,0,0],[0,2*pi,0],[0,0,3**0.5 *2]])
    shift = torch.tensor([0,0,-3**0.5])
    samples = samples@T + shift 

    f1 = torch.zeros(s*N0,1) 
    f2 = torch.zeros(s*N0,1)
    f3 = torch.zeros(s*N0,1)
    f4 = torch.zeros(s*N0,1)

    f1[:,0] = torch.cos(samples[:,0]) 
    f2[:,0] = torch.sin(samples[:,0]) * torch.cos(samples[:,1])
    f3[:,0] = torch.sin(samples[:,0]) * torch.sin(samples[:,1])
    f4[:,0] = samples[:,2] 

    Wb_tensor = torch.cat([f1,f2,f3,f4],1) # N x 4 
    return Wb_tensor

def generate_tanh_dict3D_QMC(s,N0,Rm):
    # Monte Carlo 
    samples = torch.randn(s*N0,4)

    T =torch.tensor([[2*Rm,0.,0.,0.],[0.,2*Rm,0.,0.],[0.,0.,2*Rm,0.],[0.,0.,0.,2*Rm]])
    shift = torch.tensor([-Rm,-Rm,-Rm,-Rm],dtype = torch.float64) 
    samples = samples@T  + shift 

    return samples 

def generate_tanh_dict3D_QMC_normal(s,N0,var):
    # Monte Carlo 
    samples = torch.normal(0,var,(4,s*N0))

    return samples 

In [4]:
def show_convergence_order(err_l2,err_h10,exponent,dict_size, filename,write2file = False):
    
    if write2file:
        file_mode = "a" if os.path.exists(filename) else "w"
        f_write = open(filename, file_mode)
    
    neuron_nums = [2**j for j in range(2,exponent+1)]
    err_list = [err_l2[i] for i in neuron_nums ]
    err_list2 = [err_h10[i] for i in neuron_nums ] 
    # f_write.write('M:{}, relu {} \n'.format(M,k))
    if write2file:
        f_write.write('dictionary size: {}\n'.format(dict_size))
        f_write.write("neuron num \t\t error \t\t order \t\t h10 error \\ order \n")
    print("neuron num \t\t error \t\t order")
    for i, item in enumerate(err_list):
        if i == 0: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t")
            
            # print("*")
            print("{} \t\t {:.6f} \t\t * \t\t {:.6f} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ) )
            if write2file: 
                f_write.write("{} \t\t {} \t\t * \t\t {} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ))
        else: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t") 
            # print(np.log(err_list[i-1]/err_list[i])/np.log(2))
            print("{} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ) )
            if write2file: 
                f_write.write("{} \t\t {} \t\t {} \t\t {} \t\t {} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ))
    if write2file:     
        f_write.write("\n")
        f_write.close()

def show_convergence_order_latex(err_l2,err_h10,exponent,k=1,d=1): 
    neuron_nums = [2**j for j in range(2,exponent+1)]
    err_list = [err_l2[i] for i in neuron_nums ]
    err_list2 = [err_h10[i] for i in neuron_nums ] 
    # f_write.write('M:{}, relu {} \n'.format(M,k))
    # f_write.write('randomized dictionary size: {}\n'.format(N))
    # f_write.write("neuron num \t\t error \t\t order \t\t h10 error \\ order \n")
    l2_order = -1/2-(2*k + 1)/(2*d)
    h10_order = -1/2-(2*(k-1) + 1)/(2*d)
#     print("neuron num  & \t $\|u-u_n \|_{L^2}$ & \t order $O(n^{{{}})$ & \t $ | u -u_n |_{H^1}$ & \t order $O(n^{{{}})$ \\\ \hline \hline ".format(l2_order,h10_order))
    print("neuron num  & \t $\\|u-u_n \\|_{{L^2}}$ & \t order $O(n^{{{:.2f}}})$ & \t $ | u -u_n |_{{H^1}}$ & \t order $O(n^{{{:.2f}}})$ \\\\ \\hline \\hline ".format(l2_order, h10_order))
    for i, item in enumerate(err_list):
        if i == 0: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t")

            # print("*")
            print("{} \t\t & {:.6f} &\t\t * & \t\t {:.6f} & \t\t *  \\\ \hline  \n".format(neuron_nums[i],item, err_list2[i] ) )   
            # f_write.write("{} \t\t {} \t\t * \t\t {} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ))
        else: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t") 
            # print(np.log(err_list[i-1]/err_list[i])/np.log(2))
            print("{} \t\t &  {:.3e} &  \t\t {:.2f} &  \t\t {:.3e} &  \t\t {:.2f} \\\ \hline  \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ) )
            # f_write.write("{} \t\t {} \t\t {} \t\t {} \t\t {} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ))
    # f_write.write("\n")
    # f_write.close()

In [5]:
# def relu_dict(x_l,x_r,N):
#     """generate relu dictionary parameters 
    
#     Parameters
#     ----------
#     x_l: float 
#     x_r: float
#     N: int 
#         number of dictionary elements 
        
#     Returns
#     torch tensor
#         containing relu dictionary parameters, corresponds to nodal points
        
#     """
#     # w = 1 
#     relu_dict_parameters = torch.zeros((2*N,2)).to(device)
#     relu_dict_parameters[:N,0] = torch.ones(N)[:]
#     relu_dict_parameters[:N,1] = torch.linspace(x_l,x_r,N+1)[:-1] # relu(x-bi)  
#     relu_dict_parameters[N:2*N,0] = -torch.ones(N)[:]
#     relu_dict_parameters[N:2*N,1] = -torch.linspace(x_l,x_r,N+1)[1:] + 1/(2*N) # relu(-x - -bi) 
    
#     return relu_dict_parameters

# # relu dictionary
# def relu_dict_MC(x_l,x_r,N):
#     """generate relu dictionary parameters 
    
#     Parameters
#     ----------
#     x_l: float 
#     x_r: float
#     N: int 
#        number of dictionary elements  
        
#     Returns
#     torch tensor
#         containing relu dictionary parameters, corresponds to nodal points
#     """
#     # w = 1 
#     random_value = torch.randint(0, 2, (N,)) * 2 - 1 # +1 or -1  
#     relu_dict_parameters = torch.zeros((N,2)).to(device)
#     relu_dict_parameters[:N,0] = random_value[:]
#     relu_dict_parameters[:N,1] = (torch.rand(N)*(x_r-x_l) + x_l)*random_value # relu(x-bi) 

#     return relu_dict_parameters

# # relu dictionary
# def tanh_dict_MC(x_l,x_r,N):
#     """generate relu dictionary parameters 
    
#     Parameters
#     ----------
#     x_l: float 
#     x_r: float
#     N: int 
#        number of dictionary elements  
        
#     Returns
#     torch tensor
#         containing relu dictionary parameters, corresponds to nodal points
#     """
#     # w = 1 
#     # random_value = torch.randint(0, 2, (N,)) * 2 - 1 # +1 or -1  

#     tanh_dict_parameters = torch.zeros((N,2)).to(device)
#     tanh_dict_parameters[:N,0] = (torch.rand(N)*(x_r-x_l) + x_l)
#     tanh_dict_parameters[:N,1] = (torch.rand(N)*(x_r-x_l) + x_l) # relu(x-bi) 

#     return tanh_dict_parameters

def select_discrete_dictionary(activation,rand_deter,N0,R):
    if isinstance(N0, list):
        N0_num = np.prod(N0)
    else: 
        N0_num = N0 

    if rand_deter == 'deter': 
        if activation == 'relu':
            dict_parameters = generate_relu_dict3D(N0).to(device) 
        elif activation == 'tanh':
            print("for tanh, automatically use randomized dictionary")
            dict_parameters = generate_tanh_dict3D_QMC(1,N0_num,R).to(device)
            
    if rand_deter == 'rand': 
        if activation == 'relu':
            dict_parameters = generate_relu_dict3D_QMC(1,N0_num).to(device)   
        elif activation == 'tanh':
            dict_parameters = generate_tanh_dict3D_QMC(1,N0_num,R).to(device)
    return dict_parameters 

In [6]:
## helper functions 
def compute_l2_error(u_exact,my_model,M,batch_size_2,weights,integration_points): 
    err = 0 
    if my_model == None: 
        for jj in range(0,M,batch_size_2): 
            end_index = jj + batch_size_2 
            func_values = u_exact(integration_points[jj:end_index,:])
            err += torch.sum(func_values**2 * weights[jj:end_index,:])**0.5
    else: 
        for jj in range(0,M,batch_size_2): 
            end_index = jj + batch_size_2 
            func_values = u_exact(integration_points[jj:end_index,:]) - my_model(integration_points[jj:end_index,:]).detach()
            err += torch.sum(func_values**2 * weights[jj:end_index,:])**0.5	
    return err 

def compute_gradient_error(u_exact_grad,my_model,M,batch_size_2,weights,integration_points):
    """
    Parameters
    ----------
    u_exact_grad: list or None
        a list that contains ways of evaluating partial derivatives that gives the gradient  
    """
    err_h10 = 0 
     # initial gradient error 
    if u_exact_grad != None and my_model!=None:
        u_grad = u_exact_grad() 
        for ii, grad_i in enumerate(u_grad): 
            for jj in range(0,M,batch_size_2): 
                end_index = jj + batch_size_2 
                my_model_dxi = my_model.evaluate_derivative(integration_points[jj:end_index,:],ii+1).detach() 
                err_h10 += torch.sum((grad_i(integration_points[jj:end_index,:]) - my_model_dxi)**2 * weights[jj:end_index,:])**0.5
    elif u_exact_grad != None and my_model==None:
        u_grad = u_exact_grad() 
        for grad_i in u_grad: 
            for jj in range(0,M,batch_size_2): 
                end_index = jj + batch_size_2 
                err_h10 += torch.sum((grad_i(integration_points[jj:end_index,:]))**2 * weights[jj:end_index,:])**0.5
    return err_h10


In [7]:
def minimize_linear_layer_efficient(Mat, rhs_vec, model,target,weights, integration_points,weights_bd, integration_points_bd, g_N, activation = 'relu', solver = 'direct',memory=2**29):  
    """
    calls the following functions (dependency): 
    1. GQ_piecewise_2D
    input: the nn model containing parameter 
    1. define the loss function  
    2. take derivative to extract the linear system A
    3. call the cg solver in scipy to solve the linear system 
    output: sol. solution of Ax = b
    """
    start_time = time.time() 
    w = model.fc1.weight.data 
    b = model.fc1.bias.data 
    neuron_num = b.size(0) 
    M = integration_points.size(0)

    total_size = neuron_num * M # memory, number of floating numbers 
    print('total size: {} {} = {}'.format(neuron_num,M,total_size))
    num_batch = total_size//memory + 1 # divide according to memory
    print("num batches: ",num_batch)
    batch_size = M//num_batch

    coef_func = LAMBDA # 3 * model(integration_points).detach()**2 #changing after each newton iteration 
    # jac = torch.zeros(neuron_num,neuron_num).to(device)
    # rhs = torch.zeros(neuron_num,1).to(device)
    col_k_sym = torch.zeros(neuron_num,1).to(device)
    col_k_nonsym = torch.zeros(neuron_num,1).to(device) # nonsymmetric part from convection term 
    row_k_nonsym = torch.zeros(neuron_num,1).to(device) # nonsymmetric part from convection term 
    rhs_k = torch.zeros(1,1).to(device) 
    for j in range(0,M,batch_size): 
        end_index = j + batch_size
        if activation == 'relu':
            basis_value_col = F.relu(integration_points[j:end_index] @ (w.t())+ b)**(model.k) 
        if activation == 'tanh':
            basis_value_col = torch.tanh(integration_points[j:end_index] @ w.t()+ b)
        weighted_basis_value_col = basis_value_col * weights[j:end_index] 
        
        if activation == 'relu' and model.k == 1:  
            derivative_comm_part = torch.heaviside(integration_points[j:end_index] @ w.t()+ b, ZERO) 
        elif activation == 'relu' and model.k > 1: 
            derivative_comm_part = model.k * F.relu(integration_points[j:end_index] @ w.t()+ b)**(model.k-1)
        elif activation == 'tanh':
            derivative_comm_part = torch.cosh(integration_points[j:end_index] @ w.t()+ b)**(-2)  
            
        # jac += weighted_basis_value_col.t() @ (coef_func * basis_value_col) # mass matrix 
        # rhs += weighted_basis_value_col.t() @ (target(integration_points[j:end_index]) ) #rhs 
        col_k_sym += weighted_basis_value_col.t() @ (coef_func * basis_value_col[:,neuron_num-1:neuron_num])
        rhs_k += (weighted_basis_value_col[:,neuron_num-1:neuron_num]).t() @ (target(integration_points[j:end_index])) #rhs 

        for d in range(DIMENSION): 
            basis_value_dxi_col = derivative_comm_part * w.t()[d:d+1,:]
            weighted_basis_value_dxi_col = basis_value_dxi_col * weights[j:end_index] 
            
            col_k_sym += weighted_basis_value_dxi_col.t() @ basis_value_dxi_col[:,neuron_num-1:neuron_num] # stifness matrix 
            col_k_nonsym += BETA* weighted_basis_value_col.t()@basis_value_dxi_col[:,neuron_num-1:neuron_num] # convection term (grad u, v)
            row_k_nonsym += BETA * weighted_basis_value_dxi_col.t() @ basis_value_col[:,neuron_num-1:neuron_num]

    # Neumman boundary condition
    if DIMENSION == 1: 
        if activation == 'relu':
            basis_value_col_bd = F.relu(integration_points_bd @ w.t()+ b)**(model.k) 
        elif activation == 'tanh':
            basis_value_col_bd = torch.tanh(integration_points_bd @ w.t()+ b) 
        weighted_basis_value_col_bd = basis_value_col_bd *weights_bd 
        dudn = g_N(integration_points_bd)* (torch.tensor([-1,1]).view(-1,1)).to(device) 
        rhs_gN =  (weighted_basis_value_col_bd[:,neuron_num-1:neuron_num]).t() @ dudn
    # neumann boundary condition 
    if DIMENSION > 1 and g_N != None:
        size_pts_bd = int(integration_points_bd.size(0)/(2*DIMENSION))
        bcs_N = g_N(DIMENSION)
        for ii, g_ii in bcs_N:
            #Another for loop needed if we need to divide the integration points into batches 
            weighted_g_N = -g_ii(integration_points_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:])* weights_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
            if activation == 'relu':
                basis_value_bd_col = F.relu(integration_points_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:] @ w.t()+ b)**(model.k)
            elif activation == 'tanh':
                basis_value_bd_col = torch.tanh(integration_points_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:] @ w.t()+ b) 
            rhs_gN += (basis_value_bd_col[:,neuron_num-1:neuron_num]).t() @ weighted_g_N

            weighted_g_N = g_ii(integration_points_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:])* weights_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:]
            if activation == 'relu':
                basis_value_bd_col = F.relu(integration_points_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:] @ w.t()+ b)**(model.k)
            elif activation == 'tanh':
                basis_value_bd_col = torch.tanh(integration_points_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:] @ w.t()+ b)
            rhs_gN += (basis_value_bd_col[:,neuron_num-1:neuron_num]) @ weighted_g_N
        rhs_k += rhs_gN 

    ## form the linear system by adding the last column and row 
#     print("size col_k",col_k)
    Mat[:neuron_num,neuron_num-1] = col_k_sym.view(-1)  + col_k_nonsym.view(-1)#col 
    Mat[neuron_num-1,:neuron_num-1] =  col_k_sym[:-1,:].view(-1) + row_k_nonsym[:-1,:].view(-1) # row 
#     Mat[neuron_num-1:neuron_num,:neuron_num-1:neuron_num] += row_k_nonsym[-1,-1]
    rhs_vec[neuron_num-1] = rhs_k.view(-1) 

    jac = Mat[:neuron_num,:neuron_num]
    rhs = rhs_vec[:neuron_num,:] 
    print("assembling the matrix time taken: ", time.time()-start_time) 
    start_time = time.time()    
    if solver == "cg": 
        sol, exit_code = linalg.cg(np.array(jac.detach().cpu()),np.array(rhs.detach().cpu()),tol=1e-12)
        sol = torch.tensor(sol).view(1,-1)
    elif solver == "direct": 
#         sol = np.linalg.inv( np.array(jac.detach().cpu()) )@np.array(rhs.detach().cpu())
        sol = (torch.linalg.solve( jac.detach(), rhs.detach())).view(1,-1)
    elif solver == "ls":
        sol = (torch.linalg.lstsq(jac.detach().cpu(),rhs.detach().cpu(),driver='gelsd').solution).view(1,-1)
        # sol = (torch.linalg.lstsq(jac.detach(),rhs.detach()).solution).view(1,-1) # gpu/cpu, driver = 'gels', cannot solve singular
    print("solving Ax = b time taken: ", time.time()-start_time)
    ## update the solution 
    return sol 

In [8]:
def select_greedy_neuron_ind(relu_dict_parameters,my_model,target,weights, integration_points,g_N,weights_bd, integration_points_bd,k,activation = 'relu',memory = 2**29): 
    dim = integration_points.size(1) 
    M = integration_points.size(0)
    N0 = relu_dict_parameters.size(0)   
    neuron_num = my_model.fc2.weight.size(1) if my_model != None else 0

    output = torch.zeros(N0,1).to(device) 
    s_time = time.time()
    total_size2 = M*(neuron_num+1)
    num_batch2 = total_size2//memory + 1 
    batch_size_2 = M//num_batch2 # integration points 
    residual_values = torch.zeros(M,1).to(device) 

    if my_model!= None:
        for jj in range(0,M,batch_size_2): 
            end_index = jj + batch_size_2
            residual_values[jj:end_index] += - target(integration_points[jj:end_index]) 
            residual_values[jj:end_index] += LAMBDA * my_model(integration_points[jj:end_index,:]).detach()
    else:  
        for jj in range(0,M,batch_size_2): 
            end_index = jj + batch_size_2
            residual_values[jj:end_index] += - target(integration_points[jj:end_index])
    weight_func_values = residual_values*weights


    total_size = M * N0 
    num_batch = total_size//memory + 1 
    batch_size_1 = N0//num_batch # dictionary elements
    print("======argmax subproblem:f and N(u) terms, num batches: ",num_batch)
    for j in range(0,N0,batch_size_1):
        end_index = j + batch_size_1 
        if activation == 'relu':
            basis_values = (F.relu( torch.matmul(integration_points,relu_dict_parameters[j:end_index,0:dim].T ) - relu_dict_parameters[j:end_index,dim])**k) # uses broadcasting
        elif activation == 'tanh':
            basis_values = (torch.tanh( torch.matmul(integration_points,relu_dict_parameters[j:end_index,0:dim].T ) - relu_dict_parameters[j:end_index,dim]))
        output[j:end_index] += basis_values.t()@weight_func_values #
    print('======TIME=======f and N(u) terms time :',time.time()-s_time)

    s_time =time.time() 
    if my_model!= None:
        #compute the derivative of the model 
        model_derivative_values = torch.zeros(M,dim).to(device) 
        for d in range(DIMENSION): ## there is a more efficient way 
            for jj in range(0,M,batch_size_2):
                end_index = jj + batch_size_2 
                model_derivative_values[jj:end_index,d:d+1] = my_model.evaluate_derivative(integration_points[jj:end_index,:],d+1).detach()
            #compute the derivative of the dictionary elements 
        for j in range(0,N0,batch_size_1): 
            end_index = j + batch_size_1 
            if activation == 'relu' and my_model.k == 1: 
                weighted_derivative_part = weights * torch.heaviside(integration_points@ (relu_dict_parameters[j:end_index,0:dim].T) - relu_dict_parameters[j:end_index,dim], ZERO)
                weighted_basis_value_col = weights *  (F.relu( torch.matmul(integration_points,relu_dict_parameters[j:end_index,0:dim].T ) - relu_dict_parameters[j:end_index,dim])**k) # uses broadcasting
            elif activation == 'relu' and my_model.k > 1:
                weighted_derivative_part = weights * my_model.k * F.relu(integration_points@ (relu_dict_parameters[j:end_index,0:dim].T) - relu_dict_parameters[j:end_index,dim])**(my_model.k-1)
                weighted_basis_value_col = weights *  (F.relu( torch.matmul(integration_points,relu_dict_parameters[j:end_index,0:dim].T ) - relu_dict_parameters[j:end_index,dim])**k) # uses broadcasting
            elif activation == 'tanh':
                weighted_derivative_part = weights * (1/torch.cosh(integration_points@ (relu_dict_parameters[j:end_index,0:dim].T) - relu_dict_parameters[j:end_index,dim])**2)
                weighted_basis_value_col = weights *  (torch.tanh( torch.matmul(integration_points,relu_dict_parameters[j:end_index,0:dim].T ) - relu_dict_parameters[j:end_index,dim])) # uses broadcasting
            for d in range(DIMENSION):
                weighted_basis_value_dx_col = weighted_derivative_part * relu_dict_parameters.t()[d:d+1,j:end_index] 
                output[j:end_index] += weighted_basis_value_dx_col.t() @ model_derivative_values[:,d:d+1]  # diffusion term
                output[j:end_index] += BETA * weighted_basis_value_col.t() @ model_derivative_values[:,d:d+1] # convection term 
    print("======argmax subproblem:< grad u_n, grad g> terms, num batches: ",num_batch)
    print('======TIME=======< grad u_n, grad g> terms time :',time.time()-s_time)
    
    
    # Neumann boundary condition
    s_time =time.time()  
    if g_N != None:  
        if DIMENSION == 1:
            if activation == 'relu':
                basis_values_bd_col = (F.relu(relu_dict_parameters[:,0] *integration_points_bd - relu_dict_parameters[:,1])**k) 
            elif activation == 'tanh':
                basis_values_bd_col = (torch.tanh(relu_dict_parameters[:,0] *integration_points_bd - relu_dict_parameters[:,1])) 
            weighted_basis_value_col_bd = basis_values_bd_col * weights_bd
            dudn = g_N(integration_points_bd)* (torch.tensor([-1,1]).view(-1,1)).to(device)
            output -=  weighted_basis_value_col_bd.t() @ dudn
        else: 
            size_pts_bd = int(integration_points_bd.size(0)/(2*DIMENSION)) # pre-defined rules for integration points on bdries
            bcs_N = g_N(dim)
            for ii, g_ii in bcs_N:
                # pts_bd_ii = pts_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
                weighted_g_N = -g_ii(integration_points_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:])* weights_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
                basis_value_bd_col = F.relu(integration_points_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:] @ (relu_dict_parameters[:,0:dim].T) - relu_dict_parameters[:,dim] )**(k)
                output -= basis_value_bd_col.t() @ weighted_g_N

                weighted_g_N = g_ii(integration_points_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:])* weights_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:]
                basis_value_bd_col = F.relu(integration_points_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:] @ (relu_dict_parameters[:,0:dim].T) - relu_dict_parameters[:,dim])**(k)
                output -= basis_value_bd_col.t() @ weighted_g_N
    print('======TIME=======Neumann boundary condition time :',time.time()-s_time)
    output = torch.abs(output) 
    neuron_index = torch.argmax(output.flatten())
    
    return neuron_index 

def OGAGeneralEllipticReLUNDim(my_model,target,u_exact,u_exact_grad,g_N, N_list,num_epochs,plot_freq = 10,Nx = 1024,order =5, activation = 'relu',k = 1,rand_deter = 'deter', solver = 'direct',memory = 2**29): 
    """ Orthogonal greedy algorithm to solve a general elliptic PDE over [0,1]^d
    two choices of activation: tanh, relu_k 
    """

    if DIMENSION == 1:
        weights, integration_points = PiecewiseGQ1D_weights_points(x_l= 0,x_r=1, Nx = Nx,order =order)
    elif DIMENSION == 2:
        weights, integration_points = PiecewiseGQ2D_weights_points(Nx = Nx, order = order)
    elif DIMENSION == 3:
        weights, integration_points = PiecewiseGQ3D_weights_points(Nx = Nx, order = order) 
    else:
        weights, integration_points = MonteCarlo_Sobol_dDim_weights_points(M = 2**14 ,d = 4)
    weights_bd, integration_points_bd = Neumann_boundary_quadrature_points_weights(M = 2**14,d = DIMENSION)
    M = integration_points.size(0) 

    # Compute initial L2 error and the gradient error 
    err = torch.zeros(num_epochs+1).to(device)
    err_h10 = torch.zeros(num_epochs+1).to(device)
    
    num_neuron = 0 if my_model == None else int(my_model.fc1.bias.detach().data.size(0))
    total_size2 = M*(num_neuron+1)
    num_batch2 = total_size2//memory + 1 
    batch_size_2 = M//num_batch2 # in
      
    if my_model == None: 
        list_b,list_w = [],[]
    else:
        bias = my_model.fc1.bias.detach().data
        weights = my_model.fc1.weight.detach().data
        list_b,list_w = list(bias), list(weights)
        
    err[0] = compute_l2_error(u_exact,my_model,M,batch_size_2,weights,integration_points)
    err_h10[0] = compute_gradient_error(u_exact_grad,my_model,M,batch_size_2,weights,integration_points)

    start_time = time.time()
    solver = "direct"
    print("using linear solver: ",solver)
    
    N0 = np.prod(N_list) 
    dict_parameters = None 

    Mat = torch.zeros(num_epochs,num_epochs).to(device)  # size of the final matrix 
    rhs_vec = torch.zeros(num_epochs,1).to(device) # size of the final vector 

    for i in range(num_epochs): 
        print('epoch: ',i+1)

        if (rand_deter == 'deter' and i == 0) or (rand_deter == 'rand'): 
            dict_parameters = select_discrete_dictionary(activation,rand_deter,N_list,R = 0.4)
    
        neuron_index = select_greedy_neuron_ind(dict_parameters,my_model,target,weights, integration_points,g_N,weights_bd, integration_points_bd, k,activation = activation, memory=memory) 
        
        list_w.append(dict_parameters[neuron_index,0:DIMENSION])
        list_b.append(-dict_parameters[neuron_index,DIMENSION]) # different sign convention 
        num_neuron += 1
        if activation == 'relu':
            my_model = model(DIMENSION,num_neuron,1,k).to(device)
        elif activation == 'tanh':
            my_model = model_tanh(DIMENSION,num_neuron,1).to(device) 

        w_tensor = torch.stack(list_w, 0) 
        b_tensor = torch.tensor(list_b)
        my_model.fc1.weight.data[:,:] = w_tensor[:,:]
        my_model.fc1.bias.data[:] = b_tensor[:]

#         sol = minimize_linear_layer_general_elliptic(my_model,target,weights, integration_points,weights_bd, integration_points_bd,g_N,activation =activation, solver = solver,memory = memory)
        sol = minimize_linear_layer_efficient(Mat, rhs_vec, my_model,target,weights, integration_points,weights_bd, integration_points_bd,g_N,activation =activation, solver = solver,memory = memory)
        sol = sol.flatten() 
        my_model.fc2.weight.data[0,:] = sol[:]

        #plot the solution 
        if DIMENSION == 1 and (i+1)%plot_freq == 0:  
            x_test = torch.linspace(0,1,200).view(-1,1).to(device)
            u_true = u_exact(x_test)
            plot_solution_modified(0,1,my_model,x_test,u_true)

        # Get L2 error and gradient error 
        total_size2 = M*(num_neuron+1)
        num_batch2 = total_size2//memory + 1 
        batch_size_2 = M//num_batch2 # integration points 
        err[i+1] = compute_l2_error(u_exact,my_model,M,batch_size_2,weights,integration_points)
        err_h10[i+1] = compute_gradient_error(u_exact_grad,my_model,M,batch_size_2,weights,integration_points)

    print("time taken: ",time.time() - start_time)
    return err.cpu(), err_h10.cpu(), my_model


##  example 

In [None]:
#
m1 = 2
m2 = 2
m3 = 2 
def u_exact(x):
    return torch.cos(m1*pi*x[:,0:1])*torch.cos( m2*pi*x[:,1:2]) * torch.cos(m3*pi*x[:,2:3])  

def u_exact_grad():

    def grad_1(x):
        return - m1*pi* torch.sin(m1*pi*x[:,0:1])*torch.cos( m2*pi*x[:,1:2]) * torch.cos(m3*pi*x[:,2:3])   
    def grad_2(x):
        return - m2*pi* torch.cos(m1*pi*x[:,0:1])*torch.sin( m2*pi*x[:,1:2]) * torch.cos(m3*pi*x[:,2:3])  
    def grad_3(x):
        return - m3*pi* torch.cos(m1*pi*x[:,0:1])*torch.cos( m2*pi*x[:,1:2]) * torch.sin(m3*pi*x[:,2:3])   
    
    u_grad=[grad_1, grad_2,grad_3] 

    return u_grad

def laplace_u_exact(x):
    return -((m1*pi)**2 + (m2*pi)**2 +(m3*pi)**2 )  * torch.cos(m1*pi*x[:,0:1])*torch.cos( m2*pi*x[:,1:2]) * torch.cos(m3*pi*x[:,2:3])

def convection_term(x):
    grad1 = - m1*pi* torch.sin(m1*pi*x[:,0:1])*torch.cos( m2*pi*x[:,1:2]) * torch.cos(m3*pi*x[:,2:3])   
    grad2 = - m2*pi* torch.cos(m1*pi*x[:,0:1])*torch.sin( m2*pi*x[:,1:2]) * torch.cos(m3*pi*x[:,2:3])  
    grad3 = - m3*pi* torch.cos(m1*pi*x[:,0:1])*torch.cos( m2*pi*x[:,1:2]) * torch.sin(m3*pi*x[:,2:3])    
    return BETA * grad1 + BETA * grad2 + BETA * grad3  

def rhs(x):
    return  -laplace_u_exact(x) + convection_term(x) + LAMBDA * u_exact(x)  

g_N = None 



function_name = "cosine"
Nx = 50 
order = 3
exponent = 9
num_epochs = 2**exponent  
plot_freq = num_epochs 
rand_deter = 'rand'
memory = 2**27
activation = 'tanh' 
relu_k = 3 # not used if activation != relu 

filename_write = "3DOGA-{}-{}-general-elliptic-a{}-b{}-c{}-order.txt".format(activation,function_name,1,BETA,LAMBDA)
f_write = open(filename_write, "a")
f_write.write("Integration points: Nx {}, order {} \n".format(Nx,order))
f_write.close() 

save = True 
for N_list in [[2**3,2**3,2**3]]: # ,[2**6,2**6],[2**7,2**7] 
    f_write = open(filename_write, "a")
    my_model = None 
    N = np.prod(N_list)
    err_QMC2, err_h10, my_model = OGAGeneralEllipticReLUNDim(my_model,rhs, u_exact, u_exact_grad,g_N, N_list,num_epochs,plot_freq, Nx, order, activation= activation, k = relu_k, rand_deter = rand_deter, solver = "direct",memory = memory)
    if save: 
        folder = 'data/'
        filename = folder + 'err_OGA_3D_{}_neuron_{}_N_{}_rand.pt'.format(function_name,num_epochs,N)
        torch.save(err_QMC2,filename) 
        folder = 'data/'
        filename = folder + 'model_OGA_3D_{}_neuron_{}_N_{}_rand.pt'.format(function_name,num_epochs,N)
        torch.save(my_model,filename)
    
    show_convergence_order(err_QMC2,err_h10,exponent,N,filename_write,True)
    show_convergence_order_latex(err_QMC2,err_h10,exponent,k =relu_k,d = DIMENSION)

using linear solver:  direct
epoch:  1
total size: 1 3375000 = 3375000
num batches:  1
assembling the matrix time taken:  0.003572225570678711
solving Ax = b time taken:  0.13354730606079102
epoch:  2
total size: 2 3375000 = 6750000
num batches:  1
assembling the matrix time taken:  0.0035834312438964844
solving Ax = b time taken:  0.05768251419067383
epoch:  3
total size: 3 3375000 = 10125000
num batches:  1
assembling the matrix time taken:  0.0033750534057617188
solving Ax = b time taken:  0.06227993965148926
epoch:  4
total size: 4 3375000 = 13500000
num batches:  1
assembling the matrix time taken:  0.0033218860626220703
solving Ax = b time taken:  0.06561636924743652
epoch:  5
total size: 5 3375000 = 16875000
num batches:  1
assembling the matrix time taken:  0.0034694671630859375
solving Ax = b time taken:  0.07160663604736328
epoch:  6
total size: 6 3375000 = 20250000
num batches:  1
assembling the matrix time taken:  0.003237009048461914
solving Ax = b time taken:  0.076147317

total size: 17 3375000 = 57375000
num batches:  1
assembling the matrix time taken:  0.003171682357788086
solving Ax = b time taken:  0.1407918930053711
epoch:  18
total size: 18 3375000 = 60750000
num batches:  1
assembling the matrix time taken:  0.0032672882080078125
solving Ax = b time taken:  0.1466679573059082
epoch:  19
total size: 19 3375000 = 64125000
num batches:  1
assembling the matrix time taken:  0.003213644027709961
solving Ax = b time taken:  0.15276193618774414
epoch:  20
total size: 20 3375000 = 67500000
num batches:  1
assembling the matrix time taken:  0.003373384475708008
solving Ax = b time taken:  0.15808749198913574
epoch:  21
total size: 21 3375000 = 70875000
num batches:  1
assembling the matrix time taken:  0.0032198429107666016
solving Ax = b time taken:  0.16562700271606445
epoch:  22
total size: 22 3375000 = 74250000
num batches:  1
assembling the matrix time taken:  0.003259420394897461
solving Ax = b time taken:  0.17389702796936035
epoch:  23
total size

total size: 34 3375000 = 114750000
num batches:  1
assembling the matrix time taken:  0.003050088882446289
solving Ax = b time taken:  0.24721717834472656
epoch:  35
total size: 35 3375000 = 118125000
num batches:  1
assembling the matrix time taken:  0.0030570030212402344
solving Ax = b time taken:  0.2996811866760254
epoch:  36
total size: 36 3375000 = 121500000
num batches:  1
assembling the matrix time taken:  0.003307819366455078
solving Ax = b time taken:  0.2974233627319336
epoch:  37
total size: 37 3375000 = 124875000
num batches:  1
assembling the matrix time taken:  0.003004312515258789
solving Ax = b time taken:  0.25941967964172363
epoch:  38
total size: 38 3375000 = 128250000
num batches:  1
assembling the matrix time taken:  0.003001689910888672
solving Ax = b time taken:  0.30832338333129883
epoch:  39
total size: 39 3375000 = 131625000
num batches:  1
assembling the matrix time taken:  0.003246784210205078
solving Ax = b time taken:  0.313647985458374
epoch:  40
total s

solving Ax = b time taken:  0.3385603427886963
epoch:  51
total size: 51 3375000 = 172125000
num batches:  2
assembling the matrix time taken:  0.0055844783782958984
solving Ax = b time taken:  0.3480954170227051
epoch:  52
total size: 52 3375000 = 175500000
num batches:  2
assembling the matrix time taken:  0.005388736724853516
solving Ax = b time taken:  0.3491191864013672
epoch:  53
total size: 53 3375000 = 178875000
num batches:  2
assembling the matrix time taken:  0.005605936050415039
solving Ax = b time taken:  0.3643462657928467
epoch:  54
total size: 54 3375000 = 182250000
num batches:  2
assembling the matrix time taken:  0.005517721176147461
solving Ax = b time taken:  0.36507153511047363
epoch:  55
total size: 55 3375000 = 185625000
num batches:  2
assembling the matrix time taken:  0.00548553466796875
solving Ax = b time taken:  0.37130165100097656
epoch:  56
total size: 56 3375000 = 189000000
num batches:  2
assembling the matrix time taken:  0.005349397659301758
solving 

total size: 67 3375000 = 226125000
num batches:  2
assembling the matrix time taken:  0.005630016326904297
solving Ax = b time taken:  0.4464857578277588
epoch:  68
total size: 68 3375000 = 229500000
num batches:  2
assembling the matrix time taken:  0.0055696964263916016
solving Ax = b time taken:  0.44429636001586914
epoch:  69
total size: 69 3375000 = 232875000
num batches:  2
assembling the matrix time taken:  0.0055370330810546875
solving Ax = b time taken:  0.45058107376098633
epoch:  70
total size: 70 3375000 = 236250000
num batches:  2
assembling the matrix time taken:  0.005434274673461914
solving Ax = b time taken:  0.4590725898742676
epoch:  71
total size: 71 3375000 = 239625000
num batches:  2
assembling the matrix time taken:  0.005758523941040039
solving Ax = b time taken:  0.45715975761413574
epoch:  72
total size: 72 3375000 = 243000000
num batches:  2
assembling the matrix time taken:  0.005455493927001953
solving Ax = b time taken:  0.44545507431030273
epoch:  73
tota

total size: 84 3375000 = 283500000
num batches:  3
assembling the matrix time taken:  0.008020877838134766
solving Ax = b time taken:  0.5451152324676514
epoch:  85
total size: 85 3375000 = 286875000
num batches:  3
assembling the matrix time taken:  0.007678508758544922
solving Ax = b time taken:  0.5353231430053711
epoch:  86
total size: 86 3375000 = 290250000
num batches:  3
assembling the matrix time taken:  0.007690906524658203
solving Ax = b time taken:  0.5407495498657227
epoch:  87
total size: 87 3375000 = 293625000
num batches:  3
assembling the matrix time taken:  0.00790548324584961
solving Ax = b time taken:  0.5469341278076172
epoch:  88
total size: 88 3375000 = 297000000
num batches:  3
assembling the matrix time taken:  0.008014202117919922
solving Ax = b time taken:  0.5146908760070801
epoch:  89
total size: 89 3375000 = 300375000
num batches:  3
assembling the matrix time taken:  0.007932424545288086
solving Ax = b time taken:  0.554020881652832
epoch:  90
total size: 

solving Ax = b time taken:  0.6280653476715088
epoch:  101
total size: 101 3375000 = 340875000
num batches:  3
assembling the matrix time taken:  0.00781393051147461
solving Ax = b time taken:  0.7120165824890137
epoch:  102
total size: 102 3375000 = 344250000
num batches:  3
assembling the matrix time taken:  0.007995367050170898
solving Ax = b time taken:  0.6521446704864502
epoch:  103
total size: 103 3375000 = 347625000
num batches:  3
assembling the matrix time taken:  0.007874488830566406
solving Ax = b time taken:  0.6266615390777588
epoch:  104
total size: 104 3375000 = 351000000
num batches:  3
assembling the matrix time taken:  0.008048295974731445
solving Ax = b time taken:  0.7426064014434814
epoch:  105
total size: 105 3375000 = 354375000
num batches:  3
assembling the matrix time taken:  0.008085966110229492
solving Ax = b time taken:  0.6699981689453125
epoch:  106
total size: 106 3375000 = 357750000
num batches:  3
assembling the matrix time taken:  0.008187055587768555

total size: 117 3375000 = 394875000
num batches:  3
assembling the matrix time taken:  0.007730007171630859
solving Ax = b time taken:  0.7715890407562256
epoch:  118
total size: 118 3375000 = 398250000
num batches:  3
assembling the matrix time taken:  0.00811314582824707
solving Ax = b time taken:  0.7740874290466309
epoch:  119
total size: 119 3375000 = 401625000
num batches:  3
assembling the matrix time taken:  0.23799896240234375
solving Ax = b time taken:  0.6437318325042725
epoch:  120
total size: 120 3375000 = 405000000
num batches:  4
assembling the matrix time taken:  0.010775566101074219
solving Ax = b time taken:  0.7405948638916016
epoch:  121
total size: 121 3375000 = 408375000
num batches:  4
assembling the matrix time taken:  0.010544538497924805
solving Ax = b time taken:  0.7784318923950195
epoch:  122
total size: 122 3375000 = 411750000
num batches:  4
assembling the matrix time taken:  0.010511398315429688
solving Ax = b time taken:  0.770697832107544
epoch:  123
t

total size: 134 3375000 = 452250000
num batches:  4
assembling the matrix time taken:  0.010240554809570312
solving Ax = b time taken:  0.8430321216583252
epoch:  135
total size: 135 3375000 = 455625000
num batches:  4
assembling the matrix time taken:  0.010749340057373047
solving Ax = b time taken:  0.8460338115692139
epoch:  136
total size: 136 3375000 = 459000000
num batches:  4
assembling the matrix time taken:  0.010540485382080078
solving Ax = b time taken:  0.8051691055297852
epoch:  137
total size: 137 3375000 = 462375000
num batches:  4
assembling the matrix time taken:  0.010526418685913086
solving Ax = b time taken:  0.8287115097045898
epoch:  138
total size: 138 3375000 = 465750000
num batches:  4
assembling the matrix time taken:  0.011127948760986328
solving Ax = b time taken:  0.826549768447876
epoch:  139
total size: 139 3375000 = 469125000
num batches:  4
assembling the matrix time taken:  0.010802745819091797
solving Ax = b time taken:  0.8339967727661133
epoch:  140

solving Ax = b time taken:  0.9045925140380859
epoch:  151
total size: 151 3375000 = 509625000
num batches:  4
assembling the matrix time taken:  0.01164865493774414
solving Ax = b time taken:  0.9115991592407227
epoch:  152
total size: 152 3375000 = 513000000
num batches:  4
assembling the matrix time taken:  0.010254144668579102
solving Ax = b time taken:  0.8848788738250732
epoch:  153
total size: 153 3375000 = 516375000
num batches:  4
assembling the matrix time taken:  0.010431528091430664
solving Ax = b time taken:  0.9162294864654541
epoch:  154
total size: 154 3375000 = 519750000
num batches:  4
assembling the matrix time taken:  0.009942293167114258
solving Ax = b time taken:  0.912346601486206
epoch:  155
total size: 155 3375000 = 523125000
num batches:  4
assembling the matrix time taken:  0.010895967483520508
solving Ax = b time taken:  0.919522762298584
epoch:  156
total size: 156 3375000 = 526500000
num batches:  4
assembling the matrix time taken:  0.010469436645507812
s

total size: 167 3375000 = 563625000
num batches:  5
assembling the matrix time taken:  0.012295007705688477
solving Ax = b time taken:  0.9321398735046387
epoch:  168
total size: 168 3375000 = 567000000
num batches:  5
assembling the matrix time taken:  0.013109445571899414
solving Ax = b time taken:  0.9171051979064941
epoch:  169
total size: 169 3375000 = 570375000
num batches:  5
assembling the matrix time taken:  0.013286828994750977
solving Ax = b time taken:  0.937999963760376
epoch:  170
total size: 170 3375000 = 573750000
num batches:  5
assembling the matrix time taken:  0.013319730758666992
solving Ax = b time taken:  0.9377231597900391
epoch:  171
total size: 171 3375000 = 577125000
num batches:  5
assembling the matrix time taken:  0.013387918472290039
solving Ax = b time taken:  0.9365627765655518
epoch:  172
total size: 172 3375000 = 580500000
num batches:  5
assembling the matrix time taken:  0.013069868087768555
solving Ax = b time taken:  0.9408972263336182
epoch:  173

total size: 184 3375000 = 621000000
num batches:  5
assembling the matrix time taken:  0.012892484664916992
solving Ax = b time taken:  0.9845032691955566
epoch:  185
total size: 185 3375000 = 624375000
num batches:  5
assembling the matrix time taken:  0.012794971466064453
solving Ax = b time taken:  1.0206835269927979
epoch:  186
total size: 186 3375000 = 627750000
num batches:  5
assembling the matrix time taken:  0.013285160064697266
solving Ax = b time taken:  1.103344440460205
epoch:  187
total size: 187 3375000 = 631125000
num batches:  5
assembling the matrix time taken:  0.0127410888671875
solving Ax = b time taken:  1.0210282802581787
epoch:  188
total size: 188 3375000 = 634500000
num batches:  5
assembling the matrix time taken:  0.013023138046264648
solving Ax = b time taken:  1.0020380020141602
epoch:  189
total size: 189 3375000 = 637875000
num batches:  5
assembling the matrix time taken:  0.013012170791625977
solving Ax = b time taken:  1.1121981143951416
epoch:  190
t

solving Ax = b time taken:  1.0518066883087158
epoch:  201
total size: 201 3375000 = 678375000
num batches:  6
assembling the matrix time taken:  0.015815258026123047
solving Ax = b time taken:  1.0779087543487549
epoch:  202
total size: 202 3375000 = 681750000
num batches:  6
assembling the matrix time taken:  0.015456914901733398
solving Ax = b time taken:  1.0705089569091797
epoch:  203
total size: 203 3375000 = 685125000
num batches:  6
assembling the matrix time taken:  0.015195369720458984
solving Ax = b time taken:  1.0785365104675293
epoch:  204
total size: 204 3375000 = 688500000
num batches:  6
assembling the matrix time taken:  0.015323162078857422
solving Ax = b time taken:  1.0724115371704102
epoch:  205
total size: 205 3375000 = 691875000
num batches:  6
assembling the matrix time taken:  0.015318632125854492
solving Ax = b time taken:  1.0789015293121338
epoch:  206
total size: 206 3375000 = 695250000
num batches:  6
assembling the matrix time taken:  0.01525664329528808

total size: 217 3375000 = 732375000
num batches:  6
assembling the matrix time taken:  0.015337467193603516
solving Ax = b time taken:  1.134000301361084
epoch:  218
total size: 218 3375000 = 735750000
num batches:  6
assembling the matrix time taken:  0.01517033576965332
solving Ax = b time taken:  1.1336426734924316
epoch:  219
total size: 219 3375000 = 739125000
num batches:  6
assembling the matrix time taken:  0.015435218811035156
solving Ax = b time taken:  1.1404788494110107
epoch:  220
total size: 220 3375000 = 742500000
num batches:  6
assembling the matrix time taken:  0.01543736457824707
solving Ax = b time taken:  1.1328768730163574
epoch:  221
total size: 221 3375000 = 745875000
num batches:  6
assembling the matrix time taken:  0.01605367660522461
solving Ax = b time taken:  1.1363334655761719
epoch:  222
total size: 222 3375000 = 749250000
num batches:  6
assembling the matrix time taken:  0.015380620956420898
solving Ax = b time taken:  1.137610912322998
epoch:  223
tot

total size: 234 3375000 = 789750000
num batches:  6
assembling the matrix time taken:  0.015450239181518555
solving Ax = b time taken:  1.189544439315796
epoch:  235
total size: 235 3375000 = 793125000
num batches:  6
assembling the matrix time taken:  0.01571941375732422
solving Ax = b time taken:  1.2398748397827148
epoch:  236
total size: 236 3375000 = 796500000
num batches:  6
assembling the matrix time taken:  0.015827417373657227
solving Ax = b time taken:  1.2782647609710693
epoch:  237
total size: 237 3375000 = 799875000
num batches:  6
assembling the matrix time taken:  0.015651702880859375
solving Ax = b time taken:  1.1953699588775635
epoch:  238
total size: 238 3375000 = 803250000
num batches:  6
assembling the matrix time taken:  0.015483856201171875
solving Ax = b time taken:  1.2574443817138672
epoch:  239
total size: 239 3375000 = 806625000
num batches:  7
assembling the matrix time taken:  0.020823240280151367
solving Ax = b time taken:  1.285778284072876
epoch:  240
t

In [None]:


function_name = "cosine"
Nx = 50 
order = 3
exponent = 9
num_epochs = 2**exponent  
plot_freq = num_epochs 
rand_deter = 'rand'
memory = 2**27
activation = 'relu' 
relu_k = 3 # not used 

filename_write = "3DOGA-{}-{}-general-elliptic-a{}-b{}-c{}-order.txt".format(activation,function_name,1,BETA,LAMBDA)
f_write = open(filename_write, "a")
f_write.write("\n")
f_write.close() 

save = True 
for N_list in [[2**3,2**3,2**3]]: # ,[2**6,2**6],[2**7,2**7] 
    f_write = open(filename_write, "a")
    my_model = None 
    N = np.prod(N_list)
    err_QMC2, err_h10, my_model = OGAGeneralEllipticReLUNDim(my_model,rhs, u_exact, u_exact_grad,g_N, N_list,num_epochs,plot_freq, Nx, order, activation= activation, k = relu_k, rand_deter = rand_deter, solver = "direct",memory = memory)
    if save: 
        folder = 'data/'
        filename = folder + 'err_OGA_3D_{}_neuron_{}_N_{}_rand.pt'.format(function_name,num_epochs,N)
        torch.save(err_QMC2,filename) 
        folder = 'data/'
        filename = folder + 'model_OGA_3D_{}_neuron_{}_N_{}_rand.pt'.format(function_name,num_epochs,N)
        torch.save(my_model,filename)
    
    show_convergence_order(err_QMC2,err_h10,exponent,N,filename_write,True)
    show_convergence_order_latex(err_QMC2,err_h10,exponent,k =relu_k,d = DIMENSION)

In [None]:


4 		 & 2.214259 &		 * & 		 7.282636 & 		 *  \\ \hline  

8 		 &  1.830e+00 &  		 0.27 &  		 6.876e+00 &  		 0.08 \\ \hline  

16 		 &  3.747e+00 &  		 -1.03 &  		 7.204e+00 &  		 -0.07 \\ \hline  

32 		 &  9.894e-01 &  		 1.92 &  		 6.380e+00 &  		 0.18 \\ \hline  

64 		 &  1.015e+00 &  		 -0.04 &  		 4.274e+00 &  		 0.58 \\ \hline  

128 		 &  4.000e-01 &  		 1.34 &  		 1.816e+00 &  		 1.23 \\ \hline  

256 		 &  9.382e-03 &  		 5.41 &  		 3.855e-01 &  		 2.24 \\ \hline  

512 		 &  1.435e-03 &  		 2.71 &  		 7.405e-02 &  		 2.38 \\ \hline