In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import time
import sys
import os 
from scipy.sparse import linalg
from pathlib import Path
import itertools
if torch.cuda.is_available():  
    device = "cuda" 
else:  
    device = "cpu"    


torch.set_default_dtype(torch.float64)
pi = torch.tensor(np.pi)
ZERO = torch.tensor([0.]).to(device)
torch.set_printoptions(precision=6)

class model(nn.Module):
    """ ReLU k shallow neural network
    Parameters: 
    input size: input dimension
    hidden_size1 : number of hidden layers 
    num_classes: output classes 
    k: degree of relu functions
    """
    def __init__(self, input_size, hidden_size1, num_classes,k = 1):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, num_classes,bias = False)
        self.k = k 
    def forward(self, x):
        u1 = self.fc2(F.relu(self.fc1(x))**self.k)
        return u1
    def evaluate_derivative(self, x, i):
        if self.k == 1:
            u1 = self.fc2(torch.heaviside(self.fc1(x),ZERO) * self.fc1.weight.t()[i-1:i,:] )
        else:
            u1 = self.fc2(self.k*F.relu(self.fc1(x))**(self.k-1) *self.fc1.weight.t()[i-1:i,:] )  
        return u1

def adjust_neuron_position(my_model, dims = 3):

    def create_mesh_grid(dims, pts):
        mesh = torch.tensor(list(itertools.product(pts,repeat=dims)))
        vertices = mesh.reshape(len(pts) ** dims, -1) 
        return vertices
    counter = 0 
    # positions = torch.tensor([[0.,0.],[0.,1.],[1.,1.],[1.,0.]])
    pts = torch.tensor([0.,1.])
    positions = create_mesh_grid(dims,pts) 
    neuron_num = my_model.fc1.bias.size(0)
    for i in range(neuron_num): 
        w = my_model.fc1.weight.data[i:i+1,:]
        b = my_model.fc1.bias.data[i]
    #     print(w,b)
        values = torch.matmul(positions,w.T) # + b
        left_end = - torch.max(values)
        right_end = - torch.min(values)
        offset = (right_end - left_end)/50
        if b <= left_end + offset/2 : 
            b = torch.rand(1)*(right_end - left_end - offset) + left_end + offset/2 
            my_model.fc1.bias.data[i] = b 
        if b >= right_end - offset/2 :
            if counter < (dims+1):
#                 print("here")
                counter += 1
            else: # (d + 1) or more 
                b = torch.rand(1)*(right_end - left_end - offset) + left_end + offset/2 
                my_model.fc1.bias.data[i] = b 
    return my_model



In [2]:
def show_convergence_order2(err_l2,err_h10,exponent,dict_size, filename,write2file = False):
    
    if write2file:
        file_mode = "a" if os.path.exists(filename) else "w"
        f_write = open(filename, file_mode)
    
    neuron_nums = [2**j for j in range(2,exponent+1)]
    err_list = [err_l2[i] for i in neuron_nums ]
    err_list2 = [err_h10[i] for i in neuron_nums ] 
    # f_write.write('M:{}, relu {} \n'.format(M,k))
    if write2file:
        f_write.write('dictionary size: {}\n'.format(dict_size))
        f_write.write("neuron num \t\t error \t\t order \t\t h10 error \\ order \n")
    print("neuron num \t\t error \t\t order")
    for i, item in enumerate(err_list):
        if i == 0: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t")
            
            # print("*")
            print("{} \t\t {:.6f} \t\t * \t\t {:.6f} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ) )
            if write2file: 
                f_write.write("{} \t\t {} \t\t * \t\t {} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ))
        else: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t") 
            # print(np.log(err_list[i-1]/err_list[i])/np.log(2))
            print("{} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ) )
            if write2file: 
                f_write.write("{} \t\t {} \t\t {} \t\t {} \t\t {} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ))
    if write2file:     
        f_write.write("\n")
        f_write.close()

def show_convergence_order_latex2(err_l2,err_h10,exponent,k=1,d=1): 
    neuron_nums = [2**j for j in range(2,exponent+1)]
    err_list = [err_l2[i] for i in neuron_nums ]
    err_list2 = [err_h10[i] for i in neuron_nums ] 
    l2_order = -1/2-(2*k + 1)/(2*d)
    h1_order =  -1/2-(2*(k-1)+ 1)/(2*d)
    print("neuron num  & \t $\|u-u_n \|_{{L^2}}$ & \t order $O(n^{{{:.2f}}})$  & \t $ | u -u_n |_{{H^1}}$ & \t order $O(n^{{{:.2f}}})$  \\\ \hline \hline ".format(l2_order,h1_order))
    for i, item in enumerate(err_list):
        if i == 0: 
            print("{} \t\t & {:.6f} &\t\t * & \t\t {:.6f} & \t\t *  \\\ \hline  \n".format(neuron_nums[i],item, err_list2[i] ) )   
        else: 
            print("{} \t\t &  {:.3e} &  \t\t {:.2f} &  \t\t {:.3e} &  \t\t {:.2f} \\\ \hline  \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ) )


In [3]:
def PiecewiseGQ2D_weights_points(Nx, order): 
    """ A slight modification of PiecewiseGQ2D function that only needs the weights and integration points.
    Parameters
    ----------

    Nx: int 
        number of intervals along the dimension. No Ny, assume Nx = Ny
    order: int 
        order of the Gauss Quadrature

    Returns
    -------
    long_weights: torch.tensor
    integration_points: torch.tensor
    """

#     print("order: ",order )
    x, w = np.polynomial.legendre.leggauss(order)
    gauss_pts = np.array(np.meshgrid(x,x,indexing='ij')).reshape(2,-1).T
    weights =  (w*w[:,None]).ravel()

    gauss_pts =torch.tensor(gauss_pts)
    weights = torch.tensor(weights)

    h = 1/Nx # 100 intervals 
    long_weights =  torch.tile(weights,(Nx**2,1))
    long_weights = long_weights.reshape(-1,1)
    long_weights = long_weights * h**2 /4 

    integration_points = torch.tile(gauss_pts,(Nx**2,1))
    scale_factor = h/2 
    integration_points = scale_factor * integration_points

    index = np.arange(1,Nx+1)-0.5
    ordered_pairs = np.array(np.meshgrid(index,index,indexing='ij'))
    ordered_pairs = ordered_pairs.reshape(2,-1).T

    # print(ordered_pairs)
    # print()
    ordered_pairs = torch.tensor(ordered_pairs)
    # print(ordered_pairs.size())
    ordered_pairs = torch.tile(ordered_pairs, (1,order**2)) # number of GQ points
    # print(ordered_pairs)

    ordered_pairs =  ordered_pairs.reshape(-1,2)
    # print(ordered_pairs)
    translation = ordered_pairs*h 
    # print(translation)

    integration_points = integration_points + translation 
#     print(integration_points.size())
    # func_values = integrand2_torch(integration_points)
    return long_weights.to(device), integration_points.to(device)


def PiecewiseGQ3D_weights_points(Nx, order): 
    """ A slight modification of PiecewiseGQ2D function that only needs the weights and integration points.
    Parameters
    ----------

    Nx: int 
        number of intervals along the dimension. No Ny, assume Nx = Ny
    order: int 
        order of the Gauss Quadrature

    Returns
    -------
    long_weights: torch.tensor
    integration_points: torch.tensor
    """

    """
    Parameters
    ----------
    target : 
        Target function 
    Nx: int 
        number of intervals along the dimension. No Ny, assume Nx = Ny
    order: int 
        order of the Gauss Quadrature
    """

    # print("order: ",order )
    x, w = np.polynomial.legendre.leggauss(order)
    gauss_pts = np.array(np.meshgrid(x,x,x,indexing='ij')).reshape(3,-1).T
    weight_list = np.array(np.meshgrid(w,w,w,indexing='ij'))
    weights =   (weight_list[0]*weight_list[1]*weight_list[2]).ravel() 

    gauss_pts =torch.tensor(gauss_pts)
    weights = torch.tensor(weights)

    h = 1/Nx # 100 intervals 
    long_weights =  torch.tile(weights,(Nx**3,1))
    long_weights = long_weights.reshape(-1,1)
    long_weights = long_weights * h**3 /8 

    integration_points = torch.tile(gauss_pts,(Nx**3,1))
    # print("shape of integration_points", integration_points.size())
    scale_factor = h/2 
    integration_points = scale_factor * integration_points

    index = np.arange(1,Nx+1)-0.5
    ordered_pairs = np.array(np.meshgrid(index,index,index,indexing='ij'))
    ordered_pairs = ordered_pairs.reshape(3,-1).T

    # print(ordered_pairs)
    # print()
    ordered_pairs = torch.tensor(ordered_pairs)
    # print(ordered_pairs.size())
    ordered_pairs = torch.tile(ordered_pairs, (1,order**3)) # number of GQ points
    # print(ordered_pairs)

    ordered_pairs =  ordered_pairs.reshape(-1,3)
    # print(ordered_pairs)
    translation = ordered_pairs*h 
    # print(translation)

    integration_points = integration_points + translation 

    return long_weights.to(device), integration_points.to(device)


def generate_relu_dict3D(N_list):
    N1 = N_list[0]
    N2 = N_list[1]
    N3 = N_list[2]
    
    N = N1*N2*N3 
    theta1 = np.linspace(0, pi, N1, endpoint= True).reshape(N1,1)
    theta2 = np.linspace(0, 2*pi, N2, endpoint= False).reshape(N2,1)
    b = np.linspace(-1.732, 1.732, N3,endpoint=False).reshape(N3,1) # threshold: 3**0.5  
    coord3 = np.array(np.meshgrid(theta1,theta2,b,indexing='ij'))
    coord3 = coord3.reshape(3,-1).T # N1*N2*N3 x 3. coordinates for the grid points 
    coord3 = torch.tensor(coord3) 

    f1 = torch.zeros(N,1) 
    f2 = torch.zeros(N,1)
    f3 = torch.zeros(N,1)
    f4 = torch.zeros(N,1)

    f1[:,0] = torch.cos(coord3[:,0]) 
    f2[:,0] = torch.sin(coord3[:,0]) * torch.cos(coord3[:,1])
    f3[:,0] = torch.sin(coord3[:,0]) * torch.sin(coord3[:,1])
    f4[:,0] = coord3[:,2] 

    Wb_tensor = torch.cat([f1,f2,f3,f4],1) # N x 4 
    return Wb_tensor


def generate_relu_dict3D_QMC(s,N0):
#     Sob = torch.quasirandom.SobolEngine(dimension =3, scramble= True, seed=None) 
#     samples = Sob.draw(N0).double() 

#     for i in range(s-1):
#         samples = torch.cat([samples,Sob.draw(N0).double()],0)

    # Monte Carlo 
    samples = torch.rand(s*N0,3) 
    T =torch.tensor([[pi,0,0],[0,2*pi,0],[0,0,1.732*2]])
    shift = torch.tensor([0,0,-1.732])
    samples = samples@T + shift 

    f1 = torch.zeros(s*N0,1) 
    f2 = torch.zeros(s*N0,1)
    f3 = torch.zeros(s*N0,1)
    f4 = torch.zeros(s*N0,1)

    f1[:,0] = torch.cos(samples[:,0]) 
    f2[:,0] = torch.sin(samples[:,0]) * torch.cos(samples[:,1])
    f3[:,0] = torch.sin(samples[:,0]) * torch.sin(samples[:,1])
    f4[:,0] = samples[:,2] 

    Wb_tensor = torch.cat([f1,f2,f3,f4],1) # N x 4 
    return Wb_tensor


def minimize_linear_layer_H1_explicit_assemble_efficient(model,alpha, target, g_N, weights, integration_points, w_bd, pts_bd, activation = 'relu',solver="direct" ,memory=2**29):
    """ -div alpha grad u(x) + u = f 
    Parameters
    ----------
    model: 
        nn model
    alpha:
        alpha function
    target:
        rhs function f 
    pts_bd:
        integration points on the boundary, embdedded in the domain 
    """ 
    ZERO = torch.tensor([0.]).to(device)
    start_time = time.time() 
    w = model.fc1.weight.data 
    b = model.fc1.bias.data 
    neuron_num = b.size(0) 
    dim = integration_points.size(1) 
    M = integration_points.size(0)
    coef_alpha = alpha(integration_points) # alpha  

    
    total_size = neuron_num * M # memory, number of floating numbers 
    print('total size: {} {} = {}'.format(neuron_num,M,total_size))
    num_batch = total_size//memory + 1 # divide according to memory
    print("num batches: ",num_batch)
    batch_size = M//num_batch
    jac = torch.zeros(b.size(0),b.size(0)).to(device)
    rhs = torch.zeros(b.size(0),1).to(device)
    
    for j in range(0,M,batch_size): 
        end_index = j + batch_size
        basis_value_col = F.relu(integration_points[j:end_index] @ w.t()+ b)**(model.k) 
        weighted_basis_value_col = basis_value_col * weights[j:end_index] 
        jac += weighted_basis_value_col.t() @ basis_value_col 
        rhs += weighted_basis_value_col.t() @ (target(integration_points[j:end_index,:])) 

    # Assemble the boundary condition term <g,v>_{\Gamma_N} 
    size_pts_bd = int(pts_bd.size(0)/(2*dim))
    if g_N != None: # no batch operations for the boundary part, since it is only rhs on the boundary 
        bcs_N = g_N(dim)
        for ii, g_ii in bcs_N:
            # pts_bd_ii = pts_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
            weighted_g_N = -g_ii(pts_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:])* w_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
            basis_value_bd_col = F.relu(pts_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:] @ w.t()+ b)**(model.k)
            rhs += basis_value_bd_col.t() @ weighted_g_N

            weighted_g_N = g_ii(pts_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:])* w_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:]
            basis_value_bd_col = F.relu(pts_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:] @ w.t()+ b)**(model.k)
            rhs += basis_value_bd_col.t() @ weighted_g_N

    # Stiffness matrix term in the jacobian 
     
    for d in range(dim):
        if model.k == 1:  
            for j in range(0,M,batch_size):  
                end_index = j + batch_size 
                basis_value_dxi_col = torch.heaviside(integration_points[j:end_index] @ w.t()+ b, ZERO) * w.t()[d:d+1,:]
                weighted_basis_value_dx_col = basis_value_dxi_col * weights[j:end_index] * coef_alpha[j:end_index] 
                jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 
#             basis_value_dxi_col = torch.heaviside(integration_points @ w.t()+ b, zero) * w.t()[d:d+1,:]
#             weighted_basis_value_dx_col = basis_value_dxi_col * weights * coef_alpha 
#             jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 

        else:
            for j in range(0,M,batch_size):  
                end_index = j + batch_size 
                basis_value_dxi_col = model.k * F.relu(integration_points[j:end_index] @ w.t()+ b)**(model.k-1) * w.t()[d:d+1,:]
                weighted_basis_value_dx_col = basis_value_dxi_col * weights[j:end_index] * coef_alpha[j:end_index] 
                jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 
#             basis_value_dxi_col = model.k * F.relu(integration_points @ w.t()+ b)**(model.k-1) * w.t()[d:d+1,:]
#             weighted_basis_value_dx_col = basis_value_dxi_col * weights * coef_alpha  
#             jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 

    print("assembling the mass matrix time taken: ", time.time()-start_time) 


    start_time = time.time()    
    if solver == "cg": 
        sol, exit_code = linalg.cg(np.array(jac.detach().cpu()),np.array(rhs.detach().cpu()),tol=1e-12)
        sol = torch.tensor(sol).view(1,-1)
    elif solver == "direct": 
#         sol = np.linalg.inv( np.array(jac.detach().cpu()) )@np.array(rhs.detach().cpu())
        sol = (torch.linalg.solve( jac.detach(), rhs.detach())).view(1,-1)
    elif solver == "ls":
        sol = (torch.linalg.lstsq(jac.detach().cpu(),rhs.detach().cpu(),driver='gelsd').solution).view(1,-1)
        # sol = (torch.linalg.lstsq(jac.detach(),rhs.detach()).solution).view(1,-1) # gpu/cpu, driver = 'gels', cannot solve singular
    print("solving Ax = b time taken: ", time.time()-start_time)
    return sol 


def OGANeumannReLU3D(my_model,alpha, target,g_N, u_exact, u_exact_grad, N_list,num_epochs,plot_freq, Nx,order, k =1, rand_deter = 'deter', linear_solver = "direct",memory = 2**29): 
    """ Orthogonal greedy algorithm using 1D ReLU dictionary over [-pi,pi]
    Parameters
    ----------
    my_model: 
        nn model 
    target: 
        target function
    num_epochs: int 
        number of training epochs 
    integration_intervals: int 
        number of subintervals for piecewise numerical quadrature 

    Returns
    -------
    err: tensor 
        rank 1 torch tensor to record the L2 error history  
    model: 
        trained nn model 
    """
    #Todo Done
    dim = 3 
    gw_expand, integration_points = PiecewiseGQ3D_weights_points(Nx, order = order) 
    gw_expand = gw_expand.to(device)
    integration_points = integration_points.to(device)

    # define integration on the boundary 
    gw_expand_bd, integration_points_bd = PiecewiseGQ2D_weights_points(50, order = 3) 
    size_pts_bd = integration_points_bd.size(0) 
    gw_expand_bd_faces = torch.tile(gw_expand_bd,(2*dim,1))

    integration_points_bd_faces = torch.zeros(2*dim*integration_points_bd.size(0),dim).to(device)
    for ind in range(dim): 
        integration_points_bd_faces[2 *ind * size_pts_bd :(2 *ind +1) * size_pts_bd,ind:ind+1] = 0 
        integration_points_bd_faces[(2 *ind)*size_pts_bd :(2 * ind +1) * size_pts_bd,:ind] = integration_points_bd[:,:ind]
        integration_points_bd_faces[(2 *ind)*size_pts_bd :(2 * ind +1) * size_pts_bd,ind+1:] = integration_points_bd[:,ind:]

        integration_points_bd_faces[(2 *ind +1) * size_pts_bd:(2 *ind +2)*size_pts_bd,ind:ind+1] = 1
        integration_points_bd_faces[(2 *ind +1) * size_pts_bd:(2 *ind +2)*size_pts_bd,:ind] = integration_points_bd[:,:ind]        
        integration_points_bd_faces[(2 *ind +1) * size_pts_bd:(2 *ind +2)*size_pts_bd,ind+1:] = integration_points_bd[:,ind:]

    err = torch.zeros(num_epochs+1)
    err_h10 = torch.zeros(num_epochs+1).to(device) 
    if my_model == None: 
        func_values = - target(integration_points)
        num_neuron = 0

        list_b = []
        list_w = []
    else: 
        func_values = - (target(integration_points) - my_model(integration_points).detach())
        bias = my_model.fc1.bias.detach().data
        weights = my_model.fc1.weight.detach().data
        num_neuron = int(bias.size(0))

        list_b = list(bias)
        list_w = list(weights)
    
    # initial error Todo Done
    func_values_sqrd = func_values*func_values
    # print(func_values_sqrd.size())
    # print(gw_expand.size()) 
    err[0]= torch.sum(func_values_sqrd*gw_expand)**0.5
    ## h1 seminorm 
    if u_exact_grad != None:
        u_grad = u_exact_grad() 
        for grad_i in u_grad: 
            err_h10[0] += torch.sum((grad_i(integration_points))**2 * gw_expand)**0.5
    
    start_time = time.time()
    solver = linear_solver

    N0 = np.prod(N_list)
    if rand_deter == 'deter':
        relu_dict_parameters = generate_relu_dict3D(N_list).to(device)
    print("using linear solver: ",solver)
    M2 = integration_points.size(0) # add
    for i in range(num_epochs): 
        start_time = time.time()
        print("epoch: ",i+1, end = '\t')
        if rand_deter == 'rand':
            relu_dict_parameters = generate_relu_dict3D_QMC(1,N0).to(device) 
        if num_neuron == 0: 
            func_values = - target(integration_points)
        else: 
            func_values = - target(integration_points) + my_model(integration_points).detach()

        weight_func_values = func_values*gw_expand  

        ### ======================= 
        total_size = M2 * N0 
        num_batch = total_size//memory + 1 
        batch_size = N0//num_batch

        output = torch.zeros(N0,1).to(device)
        print("argmax batch num, ", num_batch) 
        
        for j in range(0,N0,batch_size):  
            end_index = j + batch_size  
            basis_values_batch = (F.relu( torch.matmul(integration_points,relu_dict_parameters[j:end_index,0:dim].T ) - relu_dict_parameters[j:end_index,dim])**k).T # uses broadcasting    
            output[j:end_index,:]  = torch.matmul(basis_values_batch,weight_func_values)[:,:]
        ### ======================= 
        
#         basis_values = (F.relu( torch.matmul(integration_points,relu_dict_parameters[:,0:dim].T ) - relu_dict_parameters[:,dim])**k).T # uses broadcasting, # dimension 4 
#         output = torch.matmul(basis_values,weight_func_values) #

        ##========
        # grad u part
        alpha_coef = alpha(integration_points) # alpha 
        if my_model!= None:

            if k == 1:  
                for j in range(0,N0,batch_size):  
                    end_index = j + batch_size 
                    derivative_part = torch.heaviside(integration_points @ (relu_dict_parameters[j:end_index,0:dim].T) - relu_dict_parameters[j:end_index,dim], ZERO) # dimension 4 
                    derivative_part *= alpha_coef # alpha 
                    for dx_i in range(dim): 

                        weight_dbasis_values_dxi =  (derivative_part * relu_dict_parameters.t()[dx_i:dx_i+1,j:end_index]) *gw_expand   
                        dmy_model_dxi = my_model.evaluate_derivative(integration_points,dx_i+1).detach()
                        output[j:end_index,:] += torch.matmul(weight_dbasis_values_dxi.t(), dmy_model_dxi) 

            else:  
                for j in range(0,N0,batch_size):  
                    end_index = j + batch_size 
                    derivative_part = k * F.relu(integration_points @ (relu_dict_parameters[j:end_index,0:dim].T) - relu_dict_parameters[j:end_index,dim])**(k-1) # dimension 4 
                    derivative_part *= alpha_coef # alpha
                    for dx_i in range(dim): 

                        weight_dbasis_values_dxi =  (derivative_part * relu_dict_parameters.t()[dx_i:dx_i+1,j:end_index]) * gw_expand    
                        dmy_model_dxi = my_model.evaluate_derivative(integration_points,dx_i+1).detach()
                        output[j:end_index,:] += torch.matmul(weight_dbasis_values_dxi.t(), dmy_model_dxi) 

        ##========
    
        #Boundary condition term -<g,v>_{\Gamma_N}  
        if g_N != None:
            bcs_N = g_N(dim) 
            for ii, g_ii in bcs_N: 
                
                weighted_g_N = -g_ii(integration_points_bd_faces[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:])* gw_expand_bd_faces[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
                basis_values_bd_faces = (F.relu( torch.matmul(integration_points_bd_faces[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:],relu_dict_parameters[:,0:dim].T ) - relu_dict_parameters[:,dim])**k).T
                output -= torch.matmul(basis_values_bd_faces,weighted_g_N)
                
                weighted_g_N = g_ii(integration_points_bd_faces[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:])* gw_expand_bd_faces[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:]
                basis_values_bd_faces = (F.relu( torch.matmul(integration_points_bd_faces[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:],relu_dict_parameters[:,0:dim].T ) - relu_dict_parameters[:,dim])**k).T
                output -= torch.matmul(basis_values_bd_faces,weighted_g_N)

        
        # output = torch.abs(torch.matmul(basis_values,weight_func_values)) # 
        output = torch.abs(output)
        neuron_index = torch.argmax(output.flatten())
        print("argmax time taken, ", time.time() - start_time)
        
        
        # print(neuron_index)
        list_w.append(relu_dict_parameters[neuron_index,0:dim]) # dimension 4 
        list_b.append(-relu_dict_parameters[neuron_index,dim])
        num_neuron += 1
        my_model = model(dim,num_neuron,1,k).to(device)
        w_tensor = torch.stack(list_w, 0 ) 
        b_tensor = torch.tensor(list_b)
        my_model.fc1.weight.data[:,:] = w_tensor[:,:]
        my_model.fc1.bias.data[:] = b_tensor[:]

        #Todo Done 
#         alpha = None 
        sol = minimize_linear_layer_H1_explicit_assemble_efficient(my_model,alpha, target, g_N, gw_expand, integration_points,gw_expand_bd_faces, integration_points_bd_faces,activation = 'relu',solver = solver)

        my_model.fc2.weight.data[0,:] = sol[:]

        model_values = my_model(integration_points).detach()
        # L2 error ||u - u_n||
        diff_values_sqrd = (u_exact(integration_points) - model_values)**2 
        err[i+1]= torch.sum(diff_values_sqrd*gw_expand)**0.5

        # H10 error || grad(u) - grad(u_n) ||
        if u_exact_grad != None:
            for ind, grad_i in enumerate(u_grad):  
                my_model_dxi = my_model.evaluate_derivative(integration_points,ind+1).detach() 
                err_h10[i+1] += torch.sum((grad_i(integration_points) - my_model_dxi)**2 * gw_expand)**0.5
        print("l2 error {:.6f}, h1 error {:.6f}".format(err[i+1],err_h10[i+1]))
    print("time taken: ",time.time() - start_time)
    return err, err_h10.cpu(), my_model




In [4]:

def u_exact(x):
    return torch.cos(pi*x[:,0:1])*torch.cos( pi*x[:,1:2]) * torch.cos(pi*x[:,2:3])  
def alpha(x): 
    return torch.ones(x.size(0),1).to(device)

def u_exact_grad():
    d = 3 

    def grad_1(x):
        return - pi* torch.sin(pi*x[:,0:1])*torch.cos( pi*x[:,1:2]) * torch.cos(pi*x[:,2:3])   
    def grad_2(x):
        return - pi* torch.cos(pi*x[:,0:1])*torch.sin( pi*x[:,1:2]) * torch.cos(pi*x[:,2:3])  
    def grad_3(x):
        return - pi* torch.cos(pi*x[:,0:1])*torch.cos( pi*x[:,1:2]) * torch.sin(pi*x[:,2:3])   

    u_grad=[grad_1, grad_2,grad_3] 

    return u_grad

def target(x):
    z = (  3 * (pi)**2 + 1)*torch.cos( pi*x[:,0:1])*torch.cos( pi*x[:,1:2] ) * torch.cos(pi*x[:,2:3]) 
    return z 
g_N = None 

dim = 3 
function_name = "cospix" 
filename_write = "data/3DOGA-{}-order.txt".format(function_name)
Nx = 50   
order = 3   
f_write = open(filename_write, "a")
f_write.write("Integration points: Nx {}, order {} \n".format(Nx,order))
f_write.close() 
save = False 
write2file = True
rand_deter = 'rand'

for N_list in [[2**3,2**3,2**3]]: # ,[2**6,2**6],[2**7,2**7] 
    # save = True 
    f_write = open(filename_write, "a")
    my_model = None 
    exponent = 9
    num_epochs = 2**exponent  
    plot_freq = num_epochs 
    N = np.prod(N_list)
    relu_k = 3 
    err_QMC2, err_h10, my_model = OGANeumannReLU3D(my_model,alpha, target,g_N, u_exact,u_exact_grad, N_list,num_epochs,plot_freq, Nx = Nx, order = order, k = relu_k, rand_deter= rand_deter, linear_solver = "direct")
    
    if save: 
        folder = 'data/'
        filename = folder + 'err_NeumannOGA_3D_{}_neuron_{}_N_{}_randomized.pt'.format(function_name,num_epochs,N)
        torch.save(err_QMC2,filename) 
        folder = 'data/'
        filename = folder + 'model_NeumannOGA_3D_{}_neuron_{}_N_{}_randomized.pt'.format(function_name,num_epochs,N)
        torch.save(my_model,filename)
        
    show_convergence_order2(err_QMC2,err_h10,exponent,N,filename_write,write2file = write2file)
    show_convergence_order_latex2(err_QMC2,err_h10,exponent,k=relu_k,d = dim)


using linear solver:  direct
epoch:  1	argmax batch num,  4
argmax time taken,  0.38147830963134766
total size: 1 3375000 = 3375000
num batches:  1
assembling the mass matrix time taken:  0.003135204315185547
solving Ax = b time taken:  0.07679963111877441
l2 error 0.354029, h1 error 3.330833
epoch:  2	argmax batch num,  4
argmax time taken,  0.07454872131347656
total size: 2 3375000 = 6750000
num batches:  1
assembling the mass matrix time taken:  0.004509449005126953
solving Ax = b time taken:  0.006899833679199219
l2 error 0.353664, h1 error 3.329471
epoch:  3	argmax batch num,  4
argmax time taken,  0.07352828979492188
total size: 3 3375000 = 10125000
num batches:  1
assembling the mass matrix time taken:  0.0035092830657958984
solving Ax = b time taken:  0.007887125015258789
l2 error 0.353577, h1 error 3.326542
epoch:  4	argmax batch num,  4
argmax time taken,  0.07374310493469238
total size: 4 3375000 = 13500000
num batches:  1
assembling the mass matrix time taken:  0.0033743381

total size: 32 3375000 = 108000000
num batches:  1
assembling the mass matrix time taken:  0.004259586334228516
solving Ax = b time taken:  0.03401041030883789
l2 error 0.027010, h1 error 0.563074
epoch:  33	argmax batch num,  4
argmax time taken,  0.07702279090881348
total size: 33 3375000 = 111375000
num batches:  1
assembling the mass matrix time taken:  0.0038568973541259766
solving Ax = b time taken:  0.03678083419799805
l2 error 0.025602, h1 error 0.524875
epoch:  34	argmax batch num,  4
argmax time taken,  0.07747530937194824
total size: 34 3375000 = 114750000
num batches:  1
assembling the mass matrix time taken:  0.004060506820678711
solving Ax = b time taken:  0.037648916244506836
l2 error 0.020836, h1 error 0.448546
epoch:  35	argmax batch num,  4
argmax time taken,  0.07709288597106934
total size: 35 3375000 = 118125000
num batches:  1
assembling the mass matrix time taken:  0.010573863983154297
solving Ax = b time taken:  0.038519859313964844
l2 error 0.019391, h1 error 0.

total size: 63 3375000 = 212625000
num batches:  1
assembling the mass matrix time taken:  0.010473251342773438
solving Ax = b time taken:  0.06408333778381348
l2 error 0.005174, h1 error 0.144304
epoch:  64	argmax batch num,  4
argmax time taken,  0.08098912239074707
total size: 64 3375000 = 216000000
num batches:  1
assembling the mass matrix time taken:  0.0035371780395507812
solving Ax = b time taken:  0.064666748046875
l2 error 0.004953, h1 error 0.140089
epoch:  65	argmax batch num,  4
argmax time taken,  0.08073067665100098
total size: 65 3375000 = 219375000
num batches:  1
assembling the mass matrix time taken:  0.004433631896972656
solving Ax = b time taken:  0.0818319320678711
l2 error 0.004838, h1 error 0.132460
epoch:  66	argmax batch num,  4
argmax time taken,  0.08137774467468262
total size: 66 3375000 = 222750000
num batches:  1
assembling the mass matrix time taken:  0.004454135894775391
solving Ax = b time taken:  0.07827234268188477
l2 error 0.004556, h1 error 0.12614

epoch:  94	argmax batch num,  4
argmax time taken,  0.08564567565917969
total size: 94 3375000 = 317250000
num batches:  1
assembling the mass matrix time taken:  0.004421710968017578
solving Ax = b time taken:  0.11345028877258301
l2 error 0.001617, h1 error 0.056136
epoch:  95	argmax batch num,  4
argmax time taken,  0.08590388298034668
total size: 95 3375000 = 320625000
num batches:  1
assembling the mass matrix time taken:  0.004485607147216797
solving Ax = b time taken:  0.1136772632598877
l2 error 0.001586, h1 error 0.055451
epoch:  96	argmax batch num,  4
argmax time taken,  0.08561968803405762
total size: 96 3375000 = 324000000
num batches:  1
assembling the mass matrix time taken:  0.00494837760925293
solving Ax = b time taken:  0.10762405395507812
l2 error 0.001538, h1 error 0.054152
epoch:  97	argmax batch num,  4
argmax time taken,  0.08570528030395508
total size: 97 3375000 = 327375000
num batches:  1
assembling the mass matrix time taken:  0.4025535583496094
solving Ax = 

l2 error 0.000931, h1 error 0.036754
epoch:  125	argmax batch num,  4
argmax time taken,  0.09037399291992188
total size: 125 3375000 = 421875000
num batches:  1
assembling the mass matrix time taken:  0.006649017333984375
solving Ax = b time taken:  0.13268136978149414
l2 error 0.000902, h1 error 0.035846
epoch:  126	argmax batch num,  4
argmax time taken,  0.09021759033203125
total size: 126 3375000 = 425250000
num batches:  1
assembling the mass matrix time taken:  0.004466533660888672
solving Ax = b time taken:  0.1353132724761963
l2 error 0.000896, h1 error 0.035504
epoch:  127	argmax batch num,  4
argmax time taken,  0.09039568901062012
total size: 127 3375000 = 428625000
num batches:  1
assembling the mass matrix time taken:  0.004610300064086914
solving Ax = b time taken:  0.135847806930542
l2 error 0.000887, h1 error 0.035292
epoch:  128	argmax batch num,  4
argmax time taken,  0.09053206443786621
total size: 128 3375000 = 432000000
num batches:  1
assembling the mass matrix t

total size: 155 3375000 = 523125000
num batches:  1
assembling the mass matrix time taken:  0.4397435188293457
solving Ax = b time taken:  0.10106611251831055
l2 error 0.000604, h1 error 0.026407
epoch:  156	argmax batch num,  4
argmax time taken,  0.10968446731567383
total size: 156 3375000 = 526500000
num batches:  1
assembling the mass matrix time taken:  0.024174213409423828
solving Ax = b time taken:  0.19137096405029297
l2 error 0.000589, h1 error 0.025895
epoch:  157	argmax batch num,  4
argmax time taken,  0.09343981742858887
total size: 157 3375000 = 529875000
num batches:  1
assembling the mass matrix time taken:  0.5952253341674805
solving Ax = b time taken:  0.10180997848510742
l2 error 0.000583, h1 error 0.025705
epoch:  158	argmax batch num,  4
argmax time taken,  0.10939240455627441
total size: 158 3375000 = 533250000
num batches:  1
assembling the mass matrix time taken:  0.013896465301513672
solving Ax = b time taken:  0.19051790237426758
l2 error 0.000579, h1 error 0.

solving Ax = b time taken:  0.21547746658325195
l2 error 0.000420, h1 error 0.019640
epoch:  186	argmax batch num,  4
argmax time taken,  0.09740185737609863
total size: 186 3375000 = 627750000
num batches:  2
assembling the mass matrix time taken:  0.003840208053588867
solving Ax = b time taken:  0.21821093559265137
l2 error 0.000419, h1 error 0.019583
epoch:  187	argmax batch num,  4
argmax time taken,  0.5458340644836426
total size: 187 3375000 = 631125000
num batches:  2
assembling the mass matrix time taken:  0.017077207565307617
solving Ax = b time taken:  0.21307015419006348
l2 error 0.000416, h1 error 0.019456
epoch:  188	argmax batch num,  4
argmax time taken,  0.09793710708618164
total size: 188 3375000 = 634500000
num batches:  2
assembling the mass matrix time taken:  0.004194021224975586
solving Ax = b time taken:  0.21943926811218262
l2 error 0.000412, h1 error 0.019338
epoch:  189	argmax batch num,  4
argmax time taken,  0.09817767143249512
total size: 189 3375000 = 6378

argmax time taken,  0.10197043418884277
total size: 216 3375000 = 729000000
num batches:  2
assembling the mass matrix time taken:  0.00424647331237793
solving Ax = b time taken:  0.3065946102142334
l2 error 0.000303, h1 error 0.015231
epoch:  217	argmax batch num,  4
argmax time taken,  0.11925029754638672
total size: 217 3375000 = 732375000
num batches:  2
assembling the mass matrix time taken:  0.008087396621704102
solving Ax = b time taken:  0.2943699359893799
l2 error 0.000302, h1 error 0.015183
epoch:  218	argmax batch num,  4
argmax time taken,  0.10184168815612793
total size: 218 3375000 = 735750000
num batches:  2
assembling the mass matrix time taken:  0.00420832633972168
solving Ax = b time taken:  0.34064435958862305
l2 error 0.000294, h1 error 0.014724
epoch:  219	argmax batch num,  4
argmax time taken,  0.10237503051757812
total size: 219 3375000 = 739125000
num batches:  2
assembling the mass matrix time taken:  0.004799365997314453
solving Ax = b time taken:  0.29868745

solving Ax = b time taken:  0.32300281524658203
l2 error 0.000224, h1 error 0.011890
epoch:  247	argmax batch num,  4
argmax time taken,  0.3360164165496826
total size: 247 3375000 = 833625000
num batches:  2
assembling the mass matrix time taken:  0.007292509078979492
solving Ax = b time taken:  0.319016695022583
l2 error 0.000222, h1 error 0.011808
epoch:  248	argmax batch num,  4
argmax time taken,  0.10596537590026855
total size: 248 3375000 = 837000000
num batches:  2
assembling the mass matrix time taken:  0.004318952560424805
solving Ax = b time taken:  0.323805570602417
l2 error 0.000222, h1 error 0.011775
epoch:  249	argmax batch num,  4
argmax time taken,  0.2877635955810547
total size: 249 3375000 = 840375000
num batches:  2
assembling the mass matrix time taken:  0.005568742752075195
solving Ax = b time taken:  0.3224911689758301
l2 error 0.000221, h1 error 0.011734
epoch:  250	argmax batch num,  4
argmax time taken,  0.1064302921295166
total size: 250 3375000 = 843750000
n

total size: 277 3375000 = 934875000
num batches:  2
assembling the mass matrix time taken:  0.007010936737060547
solving Ax = b time taken:  0.39112424850463867
l2 error 0.000185, h1 error 0.010092
epoch:  278	argmax batch num,  4
argmax time taken,  0.11007547378540039
total size: 278 3375000 = 938250000
num batches:  2
assembling the mass matrix time taken:  0.004267692565917969
solving Ax = b time taken:  0.4762117862701416
l2 error 0.000183, h1 error 0.010012
epoch:  279	argmax batch num,  4
argmax time taken,  0.25432300567626953
total size: 279 3375000 = 941625000
num batches:  2
assembling the mass matrix time taken:  0.005148887634277344
solving Ax = b time taken:  0.39428138732910156
l2 error 0.000183, h1 error 0.009999
epoch:  280	argmax batch num,  4
argmax time taken,  0.1103661060333252
total size: 280 3375000 = 945000000
num batches:  2
assembling the mass matrix time taken:  0.004380464553833008
solving Ax = b time taken:  0.38947391510009766
l2 error 0.000182, h1 error 

solving Ax = b time taken:  0.41967105865478516
l2 error 0.000147, h1 error 0.008315
epoch:  308	argmax batch num,  4
argmax time taken,  0.11382746696472168
total size: 308 3375000 = 1039500000
num batches:  2
assembling the mass matrix time taken:  0.0044209957122802734
solving Ax = b time taken:  0.4215099811553955
l2 error 0.000147, h1 error 0.008302
epoch:  309	argmax batch num,  4
argmax time taken,  0.11759090423583984
total size: 309 3375000 = 1042875000
num batches:  2
assembling the mass matrix time taken:  0.005808830261230469
solving Ax = b time taken:  0.42120909690856934
l2 error 0.000146, h1 error 0.008287
epoch:  310	argmax batch num,  4
argmax time taken,  0.11424756050109863
total size: 310 3375000 = 1046250000
num batches:  2
assembling the mass matrix time taken:  0.004689693450927734
solving Ax = b time taken:  0.4240756034851074
l2 error 0.000145, h1 error 0.008251
epoch:  311	argmax batch num,  4
argmax time taken,  0.13251090049743652
total size: 311 3375000 = 1

total size: 338 3375000 = 1140750000
num batches:  3
assembling the mass matrix time taken:  0.005368471145629883
solving Ax = b time taken:  0.6591813564300537
l2 error 0.000122, h1 error 0.007218
epoch:  339	argmax batch num,  4
argmax time taken,  0.13686203956604004
total size: 339 3375000 = 1144125000
num batches:  3
assembling the mass matrix time taken:  0.005014896392822266
solving Ax = b time taken:  0.5509727001190186
l2 error 0.000121, h1 error 0.007182
epoch:  340	argmax batch num,  4
argmax time taken,  0.11853694915771484
total size: 340 3375000 = 1147500000
num batches:  3
assembling the mass matrix time taken:  0.005226850509643555
solving Ax = b time taken:  0.5012857913970947
l2 error 0.000121, h1 error 0.007171
epoch:  341	argmax batch num,  4
argmax time taken,  0.12137937545776367
total size: 341 3375000 = 1150875000
num batches:  3
assembling the mass matrix time taken:  0.0050373077392578125
solving Ax = b time taken:  0.5529687404632568
l2 error 0.000119, h1 err

solving Ax = b time taken:  0.5714950561523438
l2 error 0.000103, h1 error 0.006270
epoch:  369	argmax batch num,  4
argmax time taken,  0.21888184547424316
total size: 369 3375000 = 1245375000
num batches:  3
assembling the mass matrix time taken:  0.005078315734863281
solving Ax = b time taken:  0.5771896839141846
l2 error 0.000102, h1 error 0.006241
epoch:  370	argmax batch num,  4
argmax time taken,  0.21867609024047852
total size: 370 3375000 = 1248750000
num batches:  3
assembling the mass matrix time taken:  0.005015850067138672
solving Ax = b time taken:  0.5776469707489014
l2 error 0.000101, h1 error 0.006229
epoch:  371	argmax batch num,  4
argmax time taken,  0.22054147720336914
total size: 371 3375000 = 1252125000
num batches:  3
assembling the mass matrix time taken:  0.005779743194580078
solving Ax = b time taken:  0.578622579574585
l2 error 0.000100, h1 error 0.006156
epoch:  372	argmax batch num,  4
argmax time taken,  0.22106003761291504
total size: 372 3375000 = 12555

total size: 399 3375000 = 1346625000
num batches:  3
assembling the mass matrix time taken:  0.005170106887817383
solving Ax = b time taken:  0.6378481388092041
l2 error 0.000087, h1 error 0.005480
epoch:  400	argmax batch num,  4
argmax time taken,  0.1899111270904541
total size: 400 3375000 = 1350000000
num batches:  3
assembling the mass matrix time taken:  0.005717039108276367
solving Ax = b time taken:  0.6650171279907227
l2 error 0.000087, h1 error 0.005468
epoch:  401	argmax batch num,  4
argmax time taken,  0.24487090110778809
total size: 401 3375000 = 1353375000
num batches:  3
assembling the mass matrix time taken:  0.004967212677001953
solving Ax = b time taken:  0.639331579208374
l2 error 0.000086, h1 error 0.005451
epoch:  402	argmax batch num,  4
argmax time taken,  0.13384413719177246
total size: 402 3375000 = 1356750000
num batches:  3
assembling the mass matrix time taken:  0.0053288936614990234
solving Ax = b time taken:  0.6745636463165283
l2 error 0.000086, h1 error

solving Ax = b time taken:  0.6720566749572754
l2 error 0.000077, h1 error 0.004902
epoch:  430	argmax batch num,  4
argmax time taken,  0.13730192184448242
total size: 430 3375000 = 1451250000
num batches:  3
assembling the mass matrix time taken:  0.005206584930419922
solving Ax = b time taken:  0.7238998413085938
l2 error 0.000076, h1 error 0.004895
epoch:  431	argmax batch num,  4
argmax time taken,  0.13576388359069824
total size: 431 3375000 = 1454625000
num batches:  3
assembling the mass matrix time taken:  0.005331277847290039
solving Ax = b time taken:  0.6724405288696289
l2 error 0.000076, h1 error 0.004888
epoch:  432	argmax batch num,  4
argmax time taken,  0.13604307174682617
total size: 432 3375000 = 1458000000
num batches:  3
assembling the mass matrix time taken:  0.005321979522705078
solving Ax = b time taken:  0.6930818557739258
l2 error 0.000076, h1 error 0.004880
epoch:  433	argmax batch num,  4
argmax time taken,  0.1361079216003418
total size: 433 3375000 = 14613

total size: 460 3375000 = 1552500000
num batches:  3
assembling the mass matrix time taken:  0.005427837371826172
solving Ax = b time taken:  0.7725052833557129
l2 error 0.000068, h1 error 0.004448
epoch:  461	argmax batch num,  4
argmax time taken,  0.14148664474487305
total size: 461 3375000 = 1555875000
num batches:  3
assembling the mass matrix time taken:  0.005268573760986328
solving Ax = b time taken:  0.7836649417877197
l2 error 0.000068, h1 error 0.004438
epoch:  462	argmax batch num,  4
argmax time taken,  0.1399991512298584
total size: 462 3375000 = 1559250000
num batches:  3
assembling the mass matrix time taken:  0.005447864532470703
solving Ax = b time taken:  0.8010311126708984
l2 error 0.000067, h1 error 0.004426
epoch:  463	argmax batch num,  4
argmax time taken,  0.14040017127990723
total size: 463 3375000 = 1562625000
num batches:  3
assembling the mass matrix time taken:  0.005193948745727539
solving Ax = b time taken:  0.7843663692474365
l2 error 0.000067, h1 error

solving Ax = b time taken:  0.804964542388916
l2 error 0.000059, h1 error 0.003989
epoch:  491	argmax batch num,  4
argmax time taken,  0.14428067207336426
total size: 491 3375000 = 1657125000
num batches:  4
assembling the mass matrix time taken:  0.006038188934326172
solving Ax = b time taken:  0.8070082664489746
l2 error 0.000059, h1 error 0.003981
epoch:  492	argmax batch num,  4
argmax time taken,  0.14438462257385254
total size: 492 3375000 = 1660500000
num batches:  4
assembling the mass matrix time taken:  0.005875587463378906
solving Ax = b time taken:  0.8086912631988525
l2 error 0.000059, h1 error 0.003975
epoch:  493	argmax batch num,  4
argmax time taken,  0.14393067359924316
total size: 493 3375000 = 1663875000
num batches:  4
assembling the mass matrix time taken:  0.005926370620727539
solving Ax = b time taken:  0.8088419437408447
l2 error 0.000058, h1 error 0.003929
epoch:  494	argmax batch num,  4
argmax time taken,  0.14265966415405273
total size: 494 3375000 = 16672

NameError: name 'os' is not defined

In [6]:

show_convergence_order2(err_QMC2,err_h10,exponent,2*N,filename_write,write2file = write2file)
show_convergence_order_latex2(err_QMC2,err_h10,exponent,k=relu_k,d = dim)


neuron num 		 error 		 order
4 		 0.358174 		 * 		 3.318033 		 * 

8 		 0.373086 		 -0.058849 		 3.054734 		 0.119282 

16 		 0.094217 		 1.985454 		 1.523022 		 1.004110 

32 		 0.027010 		 1.802477 		 0.563074 		 1.435541 

64 		 0.004953 		 2.447055 		 0.140089 		 2.006984 

128 		 0.000860 		 2.525422 		 0.034445 		 2.023984 

256 		 0.000205 		 2.066843 		 0.010990 		 1.648127 

512 		 0.000054 		 1.936588 		 0.003723 		 1.561493 

neuron num  & 	 $\|u-u_n \|_{L^2}$ & 	 order $O(n^{-1.67})$  & 	 $ | u -u_n |_{H^1}$ & 	 order $O(n^{-1.33})$  \\ \hline \hline 
4 		 & 0.358174 &		 * & 		 3.318033 & 		 *  \\ \hline  

8 		 &  3.731e-01 &  		 -0.06 &  		 3.055e+00 &  		 0.12 \\ \hline  

16 		 &  9.422e-02 &  		 1.99 &  		 1.523e+00 &  		 1.00 \\ \hline  

32 		 &  2.701e-02 &  		 1.80 &  		 5.631e-01 &  		 1.44 \\ \hline  

64 		 &  4.953e-03 &  		 2.45 &  		 1.401e-01 &  		 2.01 \\ \hline  

128 		 &  8.603e-04 &  		 2.53 &  		 3.444e-02 &  		 2.02 \\ \hline  

256 		 &  2.053e-04 & 

## oscillatory coefficient 

In [6]:

def u_exact(x):
    return torch.cos(pi*x[:,0:1])*torch.cos( pi*x[:,1:2]) * torch.cos(pi*x[:,2:3]) 
def alpha(x): 
#     return torch.ones(x.size(0),1).to(device)
    return 0.5 * torch.sin(6 * pi*x[:,0:1]) + 1. 

def u_exact_grad():
    d = 3 

    def grad_1(x):
        return - pi* torch.sin(pi*x[:,0:1])*torch.cos( pi*x[:,1:2]) * torch.cos(pi*x[:,2:3])   
    def grad_2(x):
        return - pi* torch.cos(pi*x[:,0:1])*torch.sin( pi*x[:,1:2]) * torch.cos(pi*x[:,2:3])  
    def grad_3(x):
        return - pi* torch.cos(pi*x[:,0:1])*torch.cos( pi*x[:,1:2]) * torch.sin(pi*x[:,2:3])   

    u_grad=[grad_1, grad_2,grad_3] 

    return u_grad

def target(x):

    z_c = torch.cos( pi*x[:,0:1])*torch.cos( pi*x[:,1:2] ) * torch.cos(pi*x[:,2:3]) 
    z1 = 3 * pi**2 * torch.sin(pi * x[:,0:1]) * torch.cos( 6*pi*x[:,0:1] ) * torch.cos( pi*x[:,1:2] )* torch.cos(pi*x[:,2:3]) 
    z2 = 0.5 * pi**2 * torch.sin(6*pi * x[:,0:1])* z_c 
    z = z1 + z2 + 2/2*pi**2 * torch.sin(6 * pi * x[:,0:1]) * z_c 
    z += ( 3 * (pi)**2 + 1)*z_c 
    return z 

g_N = None 

dim = 3 
function_name = "cospix-osci-coef" 
filename_write = "data/3DOGA-{}-order.txt".format(function_name)
Nx = 50   
order = 3   
f_write = open(filename_write, "a")
f_write.write("Integration points: Nx {}, order {} \n".format(Nx,order))
f_write.close() 
save = True 
write2file = True
rand_deter = 'rand'

for N_list in [[2**3,2**3,2**3]]: # 
    # save = True 
    f_write = open(filename_write, "a")
    my_model = None 

    exponent = 9 
    num_epochs = 2**exponent  
    plot_freq = num_epochs 
    N = np.prod(N_list)
    relu_k = 3 
    err_QMC2, err_h10, my_model = OGANeumannReLU3D(my_model,alpha, target,g_N, u_exact,u_exact_grad, N_list,num_epochs,plot_freq, Nx,order, k = relu_k, rand_deter= 'rand', linear_solver = "direct")
    
    if save: 
        folder = 'data/'
        filename = folder + 'err_NeumannOGA_OsciCoeff_3D_{}_neuron_{}_N_{}_deterministic.pt'.format(function_name,num_epochs,N)
        torch.save(err_QMC2,filename) 
        filename = folder + 'model_NeumannOGA_OsciCoeff_3D_{}_neuron_{}_N_{}_deterministic.pt'.format(function_name,num_epochs,N)
        torch.save(my_model,filename)

    show_convergence_order2(err_QMC2,err_h10,exponent,N,filename_write,write2file = write2file)
    show_convergence_order_latex2(err_QMC2,err_h10,exponent,k=relu_k,d = dim)



using linear solver:  direct
epoch:  1	argmax batch num,  4
argmax time taken,  0.0039675235748291016
total size: 1 3375000 = 3375000
num batches:  1
assembling the mass matrix time taken:  0.001203775405883789
solving Ax = b time taken:  0.0029745101928710938
l2 error 0.354324, h1 error 3.331350
epoch:  2	argmax batch num,  4
argmax time taken,  0.006653785705566406
total size: 2 3375000 = 6750000
num batches:  1
assembling the mass matrix time taken:  0.001329183578491211
solving Ax = b time taken:  0.008353948593139648
l2 error 0.353771, h1 error 3.331324
epoch:  3	argmax batch num,  4
argmax time taken,  0.007819890975952148
total size: 3 3375000 = 10125000
num batches:  1
assembling the mass matrix time taken:  0.0014786720275878906
solving Ax = b time taken:  0.009195804595947266
l2 error 0.355522, h1 error 3.330669
epoch:  4	argmax batch num,  4
argmax time taken,  0.008081674575805664
total size: 4 3375000 = 13500000
num batches:  1
assembling the mass matrix time taken:  0.001

total size: 32 3375000 = 108000000
num batches:  1
assembling the mass matrix time taken:  0.0013408660888671875
solving Ax = b time taken:  0.03499341011047363
l2 error 0.033666, h1 error 0.661404
epoch:  33	argmax batch num,  4
argmax time taken,  0.011051654815673828
total size: 33 3375000 = 111375000
num batches:  1
assembling the mass matrix time taken:  0.0013408660888671875
solving Ax = b time taken:  0.037834882736206055
l2 error 0.029557, h1 error 0.592801
epoch:  34	argmax batch num,  4
argmax time taken,  0.011394739151000977
total size: 34 3375000 = 114750000
num batches:  1
assembling the mass matrix time taken:  0.0013391971588134766
solving Ax = b time taken:  0.03859853744506836
l2 error 0.026958, h1 error 0.552359
epoch:  35	argmax batch num,  4
argmax time taken,  0.011512279510498047
total size: 35 3375000 = 118125000
num batches:  1
assembling the mass matrix time taken:  0.0013737678527832031
solving Ax = b time taken:  0.03949618339538574
l2 error 0.023896, h1 err

total size: 63 3375000 = 212625000
num batches:  1
assembling the mass matrix time taken:  0.001291036605834961
solving Ax = b time taken:  0.06549692153930664
l2 error 0.004946, h1 error 0.139728
epoch:  64	argmax batch num,  4
argmax time taken,  0.015197038650512695
total size: 64 3375000 = 216000000
num batches:  1
assembling the mass matrix time taken:  0.0013914108276367188
solving Ax = b time taken:  0.06596779823303223
l2 error 0.004313, h1 error 0.123774
epoch:  65	argmax batch num,  4
argmax time taken,  0.014883756637573242
total size: 65 3375000 = 219375000
num batches:  1
assembling the mass matrix time taken:  0.002452373504638672
solving Ax = b time taken:  0.08246493339538574
l2 error 0.004204, h1 error 0.121599
epoch:  66	argmax batch num,  4
argmax time taken,  0.015760421752929688
total size: 66 3375000 = 222750000
num batches:  1
assembling the mass matrix time taken:  0.0022134780883789062
solving Ax = b time taken:  0.07950854301452637
l2 error 0.004164, h1 error 

epoch:  94	argmax batch num,  4
argmax time taken,  0.020126819610595703
total size: 94 3375000 = 317250000
num batches:  1
assembling the mass matrix time taken:  0.005449771881103516
solving Ax = b time taken:  0.11117339134216309
l2 error 0.001853, h1 error 0.061319
epoch:  95	argmax batch num,  4
argmax time taken,  0.0198209285736084
total size: 95 3375000 = 320625000
num batches:  1
assembling the mass matrix time taken:  0.0023162364959716797
solving Ax = b time taken:  0.1149909496307373
l2 error 0.001790, h1 error 0.060473
epoch:  96	argmax batch num,  4
argmax time taken,  0.019987106323242188
total size: 96 3375000 = 324000000
num batches:  1
assembling the mass matrix time taken:  0.0023145675659179688
solving Ax = b time taken:  0.10947060585021973
l2 error 0.001730, h1 error 0.058914
epoch:  97	argmax batch num,  4
argmax time taken,  0.019956588745117188
total size: 97 3375000 = 327375000
num batches:  1
assembling the mass matrix time taken:  0.0022804737091064453
solvi

total size: 124 3375000 = 418500000
num batches:  1
assembling the mass matrix time taken:  0.5269227027893066
solving Ax = b time taken:  0.07139921188354492
l2 error 0.001019, h1 error 0.038503
epoch:  125	argmax batch num,  4
argmax time taken,  0.024055004119873047
total size: 125 3375000 = 421875000
num batches:  1
assembling the mass matrix time taken:  0.007004261016845703
solving Ax = b time taken:  0.1347641944885254
l2 error 0.001007, h1 error 0.038283
epoch:  126	argmax batch num,  4
argmax time taken,  0.0242462158203125
total size: 126 3375000 = 425250000
num batches:  1
assembling the mass matrix time taken:  0.003780364990234375
solving Ax = b time taken:  0.13522100448608398
l2 error 0.000988, h1 error 0.037752
epoch:  127	argmax batch num,  4
argmax time taken,  0.0244600772857666
total size: 127 3375000 = 428625000
num batches:  1
assembling the mass matrix time taken:  0.0024259090423583984
solving Ax = b time taken:  0.13711786270141602
l2 error 0.000979, h1 error 0

solving Ax = b time taken:  0.09865570068359375
l2 error 0.000625, h1 error 0.026430
epoch:  155	argmax batch num,  4
argmax time taken,  0.02741408348083496
total size: 155 3375000 = 523125000
num batches:  1
assembling the mass matrix time taken:  0.4181020259857178
solving Ax = b time taken:  0.10487699508666992
l2 error 0.000619, h1 error 0.026286
epoch:  156	argmax batch num,  4
argmax time taken,  0.028360366821289062
total size: 156 3375000 = 526500000
num batches:  1
assembling the mass matrix time taken:  0.010354042053222656
solving Ax = b time taken:  0.1893911361694336
l2 error 0.000610, h1 error 0.026054
epoch:  157	argmax batch num,  4
argmax time taken,  0.02776646614074707
total size: 157 3375000 = 529875000
num batches:  1
assembling the mass matrix time taken:  0.6172435283660889
solving Ax = b time taken:  0.10598635673522949
l2 error 0.000598, h1 error 0.025755
epoch:  158	argmax batch num,  4
argmax time taken,  0.02868485450744629
total size: 158 3375000 = 5332500

total size: 185 3375000 = 624375000
num batches:  2
assembling the mass matrix time taken:  0.008816003799438477
solving Ax = b time taken:  0.21657490730285645
l2 error 0.000457, h1 error 0.020790
epoch:  186	argmax batch num,  4
argmax time taken,  0.03154778480529785
total size: 186 3375000 = 627750000
num batches:  2
assembling the mass matrix time taken:  0.0024085044860839844
solving Ax = b time taken:  0.21924066543579102
l2 error 0.000446, h1 error 0.020230
epoch:  187	argmax batch num,  4
argmax time taken,  0.545050859451294
total size: 187 3375000 = 631125000
num batches:  2
assembling the mass matrix time taken:  0.018192529678344727
solving Ax = b time taken:  0.21554970741271973
l2 error 0.000445, h1 error 0.020133
epoch:  188	argmax batch num,  4
argmax time taken,  0.03182101249694824
total size: 188 3375000 = 634500000
num batches:  2
assembling the mass matrix time taken:  0.0024466514587402344
solving Ax = b time taken:  0.22082781791687012
l2 error 0.000442, h1 erro

solving Ax = b time taken:  0.29657554626464844
l2 error 0.000342, h1 error 0.016020
epoch:  216	argmax batch num,  4
argmax time taken,  0.03623700141906738
total size: 216 3375000 = 729000000
num batches:  2
assembling the mass matrix time taken:  0.0026493072509765625
solving Ax = b time taken:  0.3067936897277832
l2 error 0.000336, h1 error 0.015905
epoch:  217	argmax batch num,  4
argmax time taken,  0.03829312324523926
total size: 217 3375000 = 732375000
num batches:  2
assembling the mass matrix time taken:  0.0049555301666259766
solving Ax = b time taken:  0.29732179641723633
l2 error 0.000336, h1 error 0.015784
epoch:  218	argmax batch num,  4
argmax time taken,  0.035915374755859375
total size: 218 3375000 = 735750000
num batches:  2
assembling the mass matrix time taken:  0.009011030197143555
solving Ax = b time taken:  0.3389778137207031
l2 error 0.000332, h1 error 0.015630
epoch:  219	argmax batch num,  4
argmax time taken,  0.036824941635131836
total size: 219 3375000 = 7

epoch:  246	argmax batch num,  4
argmax time taken,  0.03958559036254883
total size: 246 3375000 = 830250000
num batches:  2
assembling the mass matrix time taken:  0.0024607181549072266
solving Ax = b time taken:  0.324704647064209
l2 error 0.000237, h1 error 0.012105
epoch:  247	argmax batch num,  4
argmax time taken,  0.9466047286987305
total size: 247 3375000 = 833625000
num batches:  2
assembling the mass matrix time taken:  0.002823352813720703
solving Ax = b time taken:  0.32369279861450195
l2 error 0.000236, h1 error 0.012077
epoch:  248	argmax batch num,  4
argmax time taken,  0.03986048698425293
total size: 248 3375000 = 837000000
num batches:  2
assembling the mass matrix time taken:  0.002499818801879883
solving Ax = b time taken:  0.3252890110015869
l2 error 0.000232, h1 error 0.011958
epoch:  249	argmax batch num,  4
argmax time taken,  0.3102116584777832
total size: 249 3375000 = 840375000
num batches:  2
assembling the mass matrix time taken:  0.0028221607208251953
solv

solving Ax = b time taken:  0.4059290885925293
l2 error 0.000186, h1 error 0.010160
epoch:  277	argmax batch num,  4
argmax time taken,  0.16365313529968262
total size: 277 3375000 = 934875000
num batches:  2
assembling the mass matrix time taken:  0.007270336151123047
solving Ax = b time taken:  0.39104580879211426
l2 error 0.000185, h1 error 0.010140
epoch:  278	argmax batch num,  4
argmax time taken,  0.04396772384643555
total size: 278 3375000 = 938250000
num batches:  2
assembling the mass matrix time taken:  0.009063482284545898
solving Ax = b time taken:  0.478161096572876
l2 error 0.000184, h1 error 0.010117
epoch:  279	argmax batch num,  4
argmax time taken,  0.4161972999572754
total size: 279 3375000 = 941625000
num batches:  2
assembling the mass matrix time taken:  0.003484010696411133
solving Ax = b time taken:  0.3965418338775635
l2 error 0.000184, h1 error 0.010102
epoch:  280	argmax batch num,  4
argmax time taken,  0.044303178787231445
total size: 280 3375000 = 9450000

total size: 307 3375000 = 1036125000
num batches:  2
assembling the mass matrix time taken:  0.010316848754882812
solving Ax = b time taken:  0.42138123512268066
l2 error 0.000159, h1 error 0.009011
epoch:  308	argmax batch num,  4
argmax time taken,  0.04783797264099121
total size: 308 3375000 = 1039500000
num batches:  2
assembling the mass matrix time taken:  0.009149789810180664
solving Ax = b time taken:  0.42466139793395996
l2 error 0.000159, h1 error 0.008997
epoch:  309	argmax batch num,  4
argmax time taken,  0.39443492889404297
total size: 309 3375000 = 1042875000
num batches:  2
assembling the mass matrix time taken:  0.0037889480590820312
solving Ax = b time taken:  0.4232609272003174
l2 error 0.000157, h1 error 0.008870
epoch:  310	argmax batch num,  4
argmax time taken,  0.04824972152709961
total size: 310 3375000 = 1046250000
num batches:  2
assembling the mass matrix time taken:  0.009220361709594727
solving Ax = b time taken:  0.4265477657318115
l2 error 0.000157, h1 e

solving Ax = b time taken:  0.5507969856262207
l2 error 0.000135, h1 error 0.007883
epoch:  338	argmax batch num,  4
argmax time taken,  0.05232739448547363
total size: 338 3375000 = 1140750000
num batches:  3
assembling the mass matrix time taken:  0.004706144332885742
solving Ax = b time taken:  0.6525266170501709
l2 error 0.000135, h1 error 0.007836
epoch:  339	argmax batch num,  4
argmax time taken,  0.2255549430847168
total size: 339 3375000 = 1144125000
num batches:  3
assembling the mass matrix time taken:  0.0034868717193603516
solving Ax = b time taken:  0.5524611473083496
l2 error 0.000135, h1 error 0.007825
epoch:  340	argmax batch num,  4
argmax time taken,  0.052580833435058594
total size: 340 3375000 = 1147500000
num batches:  3
assembling the mass matrix time taken:  0.0036973953247070312
solving Ax = b time taken:  0.5085654258728027
l2 error 0.000134, h1 error 0.007802
epoch:  341	argmax batch num,  4
argmax time taken,  0.053635597229003906
total size: 341 3375000 = 1

total size: 368 3375000 = 1242000000
num batches:  3
assembling the mass matrix time taken:  0.010232210159301758
solving Ax = b time taken:  0.5739295482635498
l2 error 0.000121, h1 error 0.007178
epoch:  369	argmax batch num,  4
argmax time taken,  0.37646031379699707
total size: 369 3375000 = 1245375000
num batches:  3
assembling the mass matrix time taken:  0.003633260726928711
solving Ax = b time taken:  0.5784952640533447
l2 error 0.000120, h1 error 0.007153
epoch:  370	argmax batch num,  4
argmax time taken,  0.05654406547546387
total size: 370 3375000 = 1248750000
num batches:  3
assembling the mass matrix time taken:  0.010188579559326172
solving Ax = b time taken:  0.5791749954223633
l2 error 0.000120, h1 error 0.007148
epoch:  371	argmax batch num,  4
argmax time taken,  0.06007695198059082
total size: 371 3375000 = 1252125000
num batches:  3
assembling the mass matrix time taken:  0.010071039199829102
solving Ax = b time taken:  0.5809073448181152
l2 error 0.000120, h1 erro

solving Ax = b time taken:  0.6719799041748047
l2 error 0.000103, h1 error 0.006296
epoch:  399	argmax batch num,  4
argmax time taken,  0.19944381713867188
total size: 399 3375000 = 1346625000
num batches:  3
assembling the mass matrix time taken:  0.01022481918334961
solving Ax = b time taken:  0.6392042636871338
l2 error 0.000102, h1 error 0.006260
epoch:  400	argmax batch num,  4
argmax time taken,  0.2446765899658203
total size: 400 3375000 = 1350000000
num batches:  3
assembling the mass matrix time taken:  0.0034949779510498047
solving Ax = b time taken:  0.6678256988525391
l2 error 0.000101, h1 error 0.006241
epoch:  401	argmax batch num,  4
argmax time taken,  0.27631640434265137
total size: 401 3375000 = 1353375000
num batches:  3
assembling the mass matrix time taken:  0.0035529136657714844
solving Ax = b time taken:  0.6421823501586914
l2 error 0.000101, h1 error 0.006235
epoch:  402	argmax batch num,  4
argmax time taken,  0.06701159477233887
total size: 402 3375000 = 1356

argmax time taken,  0.260988712310791
total size: 429 3375000 = 1447875000
num batches:  3
assembling the mass matrix time taken:  0.003536701202392578
solving Ax = b time taken:  0.6749138832092285
l2 error 0.000088, h1 error 0.005578
epoch:  430	argmax batch num,  4
argmax time taken,  0.2708752155303955
total size: 430 3375000 = 1451250000
num batches:  3
assembling the mass matrix time taken:  0.003550291061401367
solving Ax = b time taken:  0.7288651466369629
l2 error 0.000088, h1 error 0.005567
epoch:  431	argmax batch num,  4
argmax time taken,  0.0718233585357666
total size: 431 3375000 = 1454625000
num batches:  3
assembling the mass matrix time taken:  0.0035729408264160156
solving Ax = b time taken:  0.6753559112548828
l2 error 0.000087, h1 error 0.005543
epoch:  432	argmax batch num,  4
argmax time taken,  0.369596004486084
total size: 432 3375000 = 1458000000
num batches:  3
assembling the mass matrix time taken:  0.0035104751586914062
solving Ax = b time taken:  0.6954288

solving Ax = b time taken:  0.785332441329956
l2 error 0.000078, h1 error 0.005088
epoch:  460	argmax batch num,  4
argmax time taken,  0.07382822036743164
total size: 460 3375000 = 1552500000
num batches:  3
assembling the mass matrix time taken:  0.010565519332885742
solving Ax = b time taken:  0.7722632884979248
l2 error 0.000078, h1 error 0.005084
epoch:  461	argmax batch num,  4
argmax time taken,  0.07303404808044434
total size: 461 3375000 = 1555875000
num batches:  3
assembling the mass matrix time taken:  0.010335683822631836
solving Ax = b time taken:  0.7862858772277832
l2 error 0.000078, h1 error 0.005079
epoch:  462	argmax batch num,  4
argmax time taken,  0.07445478439331055
total size: 462 3375000 = 1559250000
num batches:  3
assembling the mass matrix time taken:  0.010364770889282227
solving Ax = b time taken:  0.7980732917785645
l2 error 0.000078, h1 error 0.005072
epoch:  463	argmax batch num,  4
argmax time taken,  0.0747532844543457
total size: 463 3375000 = 156262

total size: 490 3375000 = 1653750000
num batches:  4
assembling the mass matrix time taken:  0.011005163192749023
solving Ax = b time taken:  0.8076202869415283
l2 error 0.000070, h1 error 0.004642
epoch:  491	argmax batch num,  4
argmax time taken,  0.08140707015991211
total size: 491 3375000 = 1657125000
num batches:  4
assembling the mass matrix time taken:  0.011159658432006836
solving Ax = b time taken:  0.8101785182952881
l2 error 0.000069, h1 error 0.004626
epoch:  492	argmax batch num,  4
argmax time taken,  0.07673907279968262
total size: 492 3375000 = 1660500000
num batches:  4
assembling the mass matrix time taken:  0.011013507843017578
solving Ax = b time taken:  0.8099699020385742
l2 error 0.000069, h1 error 0.004615
epoch:  493	argmax batch num,  4
argmax time taken,  0.0794837474822998
total size: 493 3375000 = 1663875000
num batches:  4
assembling the mass matrix time taken:  0.010980606079101562
solving Ax = b time taken:  0.8123819828033447
l2 error 0.000069, h1 error

## gaussian example 

In [3]:
def u_exact(x):
    d = 3 
    cn =   7.03/d 
    return torch.exp(-torch.sum( cn**2 * (x - 0.5)**2,dim = 1, keepdim = True))  

def u_exact_grad():
    d = 3 
    def make_grad_i(i):
        def grad_i(x):
            d = 3  
            cn = 7.03/d
            return torch.exp(-torch.sum(cn**2 * (x - 0.5)**2, dim=1, keepdim=True)) * (-2 * cn**2 * (x[:, i:i+1] - 0.5))
        return grad_i 
    
    u_grad=[] 
    for i in range(d):
        u_grad.append(make_grad_i(i))
    return u_grad
                                                                
                                                                    
def alpha(x): 
    return torch.ones(x.size(0),1).to(device)
    # return 0.5 * torch.sin(6 * pi*x[:,0:1]) + 1. 

# def target(x):
# #     z = (  4 * (pi)**2 + 1)*torch.cos( pi*x[:,0:1])*torch.cos( pi*x[:,1:2] ) * torch.cos(pi*x[:,2:3]) * torch.cos( pi*x[:,3:4]) 
# #     return z 
#     z_c = torch.cos( pi*x[:,0:1])*torch.cos( pi*x[:,1:2] ) * torch.cos(pi*x[:,2:3]) * torch.cos( pi*x[:,3:4]) 
#     z1 = 3 * pi**2 * torch.sin(pi * x[:,0:1]) * torch.cos( 6*pi*x[:,1:2] ) * torch.cos(pi*x[:,2:3]) * torch.cos( pi*x[:,3:4]) 
#     z2 = 0.5 * pi**2 * torch.sin(6*pi * x[:,0:1])* z_c 
#     z = z1 + z2 + 3/2*pi**2 * torch.sin(6 * pi * x[:,0:1]) * z_c 
#     z += ( 4 * (pi)**2 + 1)*z_c 
#     return z 

def target(x):
    d = 3 
    cn =   7.03/d 
    z = torch.exp(-torch.sum( cn**2 * (x - 0.5)**2,dim = 1, keepdim = True)) 
    return z* ( -torch.sum(  (2 *cn**2 * (x - 0.5))**2 - 2*cn**2 ,dim = 1, keepdim = True) +1)

def g_N(dim):
    def make_g(i):
        def g_i(x):
            d = 3 
            cn = 7.03 / d
            return torch.exp(-torch.sum(cn**2 * (x - 0.5)**2, dim=1, keepdim=True)) * (-2 * cn**2 * (x[:, i:i+1] - 0.5))
        return g_i

    bcs_N = []
    for i in range(dim):
        bcs_N.append((i, make_g(i)))
    
    return bcs_N


function_name = "gaussian" 
filename_write = "data/3DOGA-{}-order.txt".format(function_name)
f_write = open(filename_write, "a")
f_write.write("\n")
f_write.close() 
save = False 
for N_list in [[2**2,2**3,2**3]]: # ,[2**6,2**6],[2**7,2**7] 

    f_write = open(filename_write, "a")
    my_model = None 
    Nx = 50   
    order = 3   

    exponent = 8 
    num_epochs = 2**exponent  
    plot_freq = num_epochs 
    N = np.prod(N_list)
    relu_k = 2 
    err_QMC2, err_h10, my_model = OGANeumannReLU3D(my_model,alpha, target,g_N, u_exact,u_exact_grad, N_list,num_epochs,plot_freq,Nx = Nx, order = order, k = relu_k, rand_deter= 'rand', linear_solver = "direct")
    
    if save: 
        folder = 'data/'
        filename = folder + 'err_OGA_4D_{}_neuron_{}_N_{}_deterministic.pt'.format(function_name,num_epochs,N)
        torch.save(err_QMC2,filename) 
        folder = 'data/'
        filename = folder + 'model_OGA_4D_{}_neuron_{}_N_{}_deterministic.pt'.format(function_name,num_epochs,N)
        torch.save(my_model,filename)

    show_convergence_order2(err_QMC2,err_h10,exponent,N,filename_write,write2file = write2file)
    show_convergence_order_latex2(err_QMC2,err_h10,exponent,k=relu_k,d = dim)



using linear solver:  direct
epoch:  1	assembling the mass matrix time taken:  0.004244804382324219
solving Ax = b time taken:  0.1321420669555664
l2 error 0.274375, h1 error 2.522434
epoch:  2	assembling the mass matrix time taken:  0.005568742752075195
solving Ax = b time taken:  0.004712104797363281
l2 error 0.213882, h1 error 2.411771
epoch:  3	assembling the mass matrix time taken:  0.0041046142578125
solving Ax = b time taken:  0.005236625671386719
l2 error 0.204598, h1 error 2.114279
epoch:  4	assembling the mass matrix time taken:  0.0047397613525390625
solving Ax = b time taken:  0.005539417266845703
l2 error 0.187524, h1 error 2.010464
epoch:  5	assembling the mass matrix time taken:  0.004512786865234375
solving Ax = b time taken:  0.005780696868896484
l2 error 0.201255, h1 error 1.951563
epoch:  6	assembling the mass matrix time taken:  0.0047681331634521484
solving Ax = b time taken:  0.00668644905090332
l2 error 0.167274, h1 error 1.863414
epoch:  7	assembling the mass ma

epoch:  53	assembling the mass matrix time taken:  0.017168521881103516
solving Ax = b time taken:  0.04207324981689453
l2 error 0.009000, h1 error 0.234303
epoch:  54	assembling the mass matrix time taken:  0.004387378692626953
solving Ax = b time taken:  0.042840003967285156
l2 error 0.008760, h1 error 0.229233
epoch:  55	assembling the mass matrix time taken:  0.017067670822143555
solving Ax = b time taken:  0.04342222213745117
l2 error 0.008520, h1 error 0.223530
epoch:  56	assembling the mass matrix time taken:  0.0171816349029541
solving Ax = b time taken:  0.043412208557128906
l2 error 0.008266, h1 error 0.218048
epoch:  57	assembling the mass matrix time taken:  0.016373395919799805
solving Ax = b time taken:  0.04475760459899902
l2 error 0.007885, h1 error 0.212589
epoch:  58	assembling the mass matrix time taken:  0.01629638671875
solving Ax = b time taken:  0.04501008987426758
l2 error 0.007483, h1 error 0.207104
epoch:  59	assembling the mass matrix time taken:  0.016647100

epoch:  106	assembling the mass matrix time taken:  0.004843235015869141
solving Ax = b time taken:  0.08366155624389648
l2 error 0.002967, h1 error 0.108168
epoch:  107	assembling the mass matrix time taken:  0.0048754215240478516
solving Ax = b time taken:  0.08427834510803223
l2 error 0.002941, h1 error 0.107546
epoch:  108	assembling the mass matrix time taken:  0.004924774169921875
solving Ax = b time taken:  0.08354330062866211
l2 error 0.002922, h1 error 0.106939
epoch:  109	assembling the mass matrix time taken:  0.004589080810546875
solving Ax = b time taken:  0.08572793006896973
l2 error 0.002891, h1 error 0.105978
epoch:  110	assembling the mass matrix time taken:  0.00465846061706543
solving Ax = b time taken:  0.08625912666320801
l2 error 0.002825, h1 error 0.104590
epoch:  111	assembling the mass matrix time taken:  0.0048999786376953125
solving Ax = b time taken:  0.08682680130004883
l2 error 0.002750, h1 error 0.102649
epoch:  112	assembling the mass matrix time taken: 

epoch:  158	assembling the mass matrix time taken:  0.18477892875671387
solving Ax = b time taken:  0.09980916976928711
l2 error 0.001650, h1 error 0.069962
epoch:  159	assembling the mass matrix time taken:  0.010071039199829102
solving Ax = b time taken:  0.13744473457336426
l2 error 0.001616, h1 error 0.069251
epoch:  160	assembling the mass matrix time taken:  0.005541324615478516
solving Ax = b time taken:  0.1387178897857666
l2 error 0.001572, h1 error 0.068309
epoch:  161	assembling the mass matrix time taken:  0.00565791130065918
solving Ax = b time taken:  0.14049720764160156
l2 error 0.001556, h1 error 0.067910
epoch:  162	assembling the mass matrix time taken:  0.007859468460083008
solving Ax = b time taken:  0.1388721466064453
l2 error 0.001539, h1 error 0.067381
epoch:  163	assembling the mass matrix time taken:  0.22953104972839355
solving Ax = b time taken:  0.07926440238952637
l2 error 0.001520, h1 error 0.067023
epoch:  164	assembling the mass matrix time taken:  0.009

l2 error 0.001030, h1 error 0.050361
epoch:  211	assembling the mass matrix time taken:  0.008592605590820312
solving Ax = b time taken:  0.1985642910003662
l2 error 0.001024, h1 error 0.050189
epoch:  212	assembling the mass matrix time taken:  0.3036653995513916
solving Ax = b time taken:  0.1352989673614502
l2 error 0.001019, h1 error 0.049999
epoch:  213	assembling the mass matrix time taken:  0.00924372673034668
solving Ax = b time taken:  0.19916582107543945
l2 error 0.001013, h1 error 0.049817
epoch:  214	assembling the mass matrix time taken:  0.006047725677490234
solving Ax = b time taken:  0.20202136039733887
l2 error 0.001014, h1 error 0.049608
epoch:  215	assembling the mass matrix time taken:  0.23576831817626953
solving Ax = b time taken:  0.11326241493225098
l2 error 0.001010, h1 error 0.049514
epoch:  216	assembling the mass matrix time taken:  0.00902414321899414
solving Ax = b time taken:  0.19700837135314941
l2 error 0.001007, h1 error 0.049396
epoch:  217	assembling