In [3]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import time
import sys
import os 
from scipy.sparse import linalg
from pathlib import Path
import itertools
if torch.cuda.is_available():  
    device = "cuda" 
else:  
    device = "cpu"    


torch.set_default_dtype(torch.float64)
pi = torch.tensor(np.pi,dtype=torch.float64)
ZERO = torch.tensor([0.]).to(device)
torch.set_printoptions(precision=6)

class model(nn.Module):
    """ ReLU k shallow neural network
    Parameters: 
    input size: input dimension
    hidden_size1 : number of hidden layers 
    num_classes: output classes 
    k: degree of relu functions
    """
    def __init__(self, input_size, hidden_size1, num_classes,k = 1):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, num_classes,bias = False)
        self.k = k 
    def forward(self, x):
        u1 = self.fc2(F.relu(self.fc1(x))**self.k)
        return u1
    def evaluate_derivative(self, x, i):
        if self.k == 1:
            u1 = self.fc2(torch.heaviside(self.fc1(x),ZERO) * self.fc1.weight.t()[i-1:i,:] )
        else:
            u1 = self.fc2(self.k*F.relu(self.fc1(x))**(self.k-1) *self.fc1.weight.t()[i-1:i,:] )  
        return u1

def adjust_neuron_position(my_model, dims = 3):

    def create_mesh_grid(dims, pts):
        mesh = torch.tensor(list(itertools.product(pts,repeat=dims)))
        vertices = mesh.reshape(len(pts) ** dims, -1) 
        return vertices
    counter = 0 
    # positions = torch.tensor([[0.,0.],[0.,1.],[1.,1.],[1.,0.]])
    pts = torch.tensor([0.,1.])
    positions = create_mesh_grid(dims,pts) 
    neuron_num = my_model.fc1.bias.size(0)
    for i in range(neuron_num): 
        w = my_model.fc1.weight.data[i:i+1,:]
        b = my_model.fc1.bias.data[i]
    #     print(w,b)
        values = torch.matmul(positions,w.T) # + b
        left_end = - torch.max(values)
        right_end = - torch.min(values)
        offset = (right_end - left_end)/50
        if b <= left_end + offset/2 : 
            b = torch.rand(1)*(right_end - left_end - offset) + left_end + offset/2 
            my_model.fc1.bias.data[i] = b 
        if b >= right_end - offset/2 :
            if counter < (dims+1):
#                 print("here")
                counter += 1
            else: # (d + 1) or more 
                b = torch.rand(1)*(right_end - left_end - offset) + left_end + offset/2 
                my_model.fc1.bias.data[i] = b 
    return my_model



In [7]:
def show_convergence_order2(err_l2,err_h10,exponent,dict_size, filename,write2file = False):
    
    if write2file:
        file_mode = "a" if os.path.exists(filename) else "w"
        f_write = open(filename, file_mode)
    
    neuron_nums = [2**j for j in range(2,exponent+1)]
    err_list = [err_l2[i] for i in neuron_nums ]
    err_list2 = [err_h10[i] for i in neuron_nums ] 
    # f_write.write('M:{}, relu {} \n'.format(M,k))
    if write2file:
        f_write.write('dictionary size: {}\n'.format(dict_size))
        f_write.write("neuron num \t\t error \t\t order \t\t h10 error \\ order \n")
    print("neuron num \t\t error \t\t order")
    for i, item in enumerate(err_list):
        if i == 0: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t")
            
            # print("*")
            print("{} \t\t {:.6f} \t\t * \t\t {:.6f} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ) )
            if write2file: 
                f_write.write("{} \t\t {} \t\t * \t\t {} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ))
        else: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t") 
            # print(np.log(err_list[i-1]/err_list[i])/np.log(2))
            print("{} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ) )
            if write2file: 
                f_write.write("{} \t\t {} \t\t {} \t\t {} \t\t {} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ))
    if write2file:     
        f_write.write("\n")
        f_write.close()

def show_convergence_order_latex2(err_l2,err_h10,exponent,k=1,d=1): 
    neuron_nums = [2**j for j in range(2,exponent+1)]
    err_list = [err_l2[i] for i in neuron_nums ]
    err_list2 = [err_h10[i] for i in neuron_nums ] 
    l2_order = -1/2-(2*k + 1)/(2*d)
    h1_order =  -1/2-(2*(k-1)+ 1)/(2*d)
    print("neuron num  & \t $\|u-u_n \|_{{L^2}}$ & \t order $O(n^{{{:.2f}}})$  & \t $ | u -u_n |_{{H^1}}$ & \t order $O(n^{{{:.2f}}})$  \\\ \hline \hline ".format(l2_order,h1_order))
    for i, item in enumerate(err_list):
        if i == 0: 
            print("{} \t\t & {:.6f} &\t\t * & \t\t {:.6f} & \t\t *  \\\ \hline  \n".format(neuron_nums[i],item, err_list2[i] ) )   
        else: 
            print("{} \t\t &  {:.3e} &  \t\t {:.2f} &  \t\t {:.3e} &  \t\t {:.2f} \\\ \hline  \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ) )


In [8]:
def MonteCarlo_Sobol_dDim_weights_points(M ,d = 4):
    Sob_integral = torch.quasirandom.SobolEngine(dimension =d, scramble= False, seed=None) 
    integration_points = Sob_integral.draw(M).double() 
    integration_points = integration_points.to(device)
    weights = torch.ones(M,1).to(device)/M 
    return weights, integration_points 

def generate_relu_dict4D(N_list):
    
    N = np.prod(N_list) 

    grid_indices = [np.linspace(0,1,N_item,endpoint=False) for N_item in N_list]
    grid = np.meshgrid(*grid_indices,indexing='ij')
    grid_coordinates = np.column_stack([axis.ravel() for axis in grid]) 
    samples = torch.tensor(grid_coordinates) 

    T =torch.tensor([[pi,0,0,0],[0,pi,0,0],[0,0,2*pi,0],[0,0,0,2*2]]) # 2 * sqrt(d)
    shift = torch.tensor([0,0,0,-2])
    samples = samples@T + shift 

    f1 = torch.zeros(N,1) 
    f2 = torch.zeros(N,1)
    f3 = torch.zeros(N,1)
    f4 = torch.zeros(N,1)
    f5 = torch.zeros(N,1)

    f1[:,0] = torch.cos(samples[:,0]) 
    f2[:,0] = torch.sin(samples[:,0]) * torch.cos(samples[:,1])
    f3[:,0] = torch.sin(samples[:,0]) * torch.sin(samples[:,1]) * torch.cos(samples[:,2])
    f4[:,0] = torch.sin(samples[:,0]) * torch.sin(samples[:,1]) * torch.sin(samples[:,2])  
    f5[:,0] = samples[:,3]

    Wb_tensor = torch.cat([f1,f2,f3,f4,f5],1) # N x 4 
    return Wb_tensor

def generate_relu_dict4plusD_QMC(dim, s,N0):
    # s*N0 is the total number of samples 
    samples = torch.rand(s*N0,dim)  

    # for i in range(s-1):
        # samples = torch.cat([samples,Sob.draw(N0).double()],0)
    # Form the transformation matrix and shift vector 
    diagonal = torch.ones(dim)*pi  
    diagonal[-1] =  2*dim**0.5
    diagonal[-2] = 2*pi 
    T = torch.diag(diagonal)

    shift = torch.zeros(dim)
    shift[-1] = -dim**0.5 
    samples = samples@T + shift 

    Wb_tensor = torch.ones(s*N0,dim+1) # each neuron parameter stored in rows  
    for i in range(dim): # 0, 1, ... dim-1 
        for j in range(i+1): # 0, 1, ... i 
            if i == 0: 
                Wb_tensor[:,i] = Wb_tensor[:,i]*torch.cos(samples[:,j])
            if i == (dim - 1):
                if j != i:
                    Wb_tensor[:,i] = Wb_tensor[:,i] * torch.sin(samples[:,j]) 
            if i != 0 and i != (dim - 1): 
                if j != i: 
                    Wb_tensor[:,i] = Wb_tensor[:,i] * torch.sin(samples[:,j]) 
                else: 
                    Wb_tensor[:,i] = Wb_tensor[:,i] * torch.cos(samples[:,j]) 
            
    Wb_tensor[:,dim] = samples[:,-1] 

    return Wb_tensor.to(device)

def minimize_linear_layer_H1_explicit_assemble_efficient(model,alpha, target, g_N, weights, integration_points, w_bd, pts_bd, activation = 'relu',solver="direct",memory = 2**29 ):
    """ -div alpha grad u(x) + u = f 
    Parameters
    ----------
    model: 
        nn model
    alpha:
        alpha function
    target:
        rhs function f 
    pts_bd:
        integration points on the boundary, embdedded in the domain 
    """ 
    start_time = time.time() 
    w = model.fc1.weight.data 
    b = model.fc1.bias.data 
    neuron_num = b.size(0) 
    dim = integration_points.size(1) 
    M = integration_points.size(0)
    coef_alpha = alpha(integration_points) # alpha  
    
    total_size = neuron_num * M # memory, number of floating numbers 
    print('total size: {} {} = {}'.format(neuron_num,M,total_size))
    num_batch = total_size//memory + 1 # divide according to memory
    print("num batches: ",num_batch)
    batch_size = M//num_batch
    start_ind = 0
    end_ind = 0 
    jac = torch.zeros(b.size(0),b.size(0)).to(device)
    rhs = torch.zeros(b.size(0),1).to(device)

    for j in range(0,M,batch_size): # batch operation in data points 
        end_ind = j + batch_size
        basis_value_col = F.relu(integration_points[j:end_ind] @ w.t()+ b)**(model.k) 
        weighted_basis_value_col = basis_value_col * weights[j:end_ind] 
        jac += weighted_basis_value_col.t() @ basis_value_col 
        rhs += weighted_basis_value_col.t() @ (target(integration_points[j:end_ind,:])) 

    # Assemble the boundary condition term <g,v>_{\Gamma_N} 
    size_pts_bd = int(pts_bd.size(0)/(2*dim))
    # M_bc = size_pts_bd 
    # total_size = M_bc * neuron_num 
    # num_batch = total_size//memory + 1 
    # batch_size = M_bc//num_batch
    if g_N != None:
        bcs_N = g_N(dim)
        for ii, g_ii in bcs_N:
            weighted_g_N = -g_ii(pts_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:])* w_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
            basis_value_bd_col = F.relu(pts_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:] @ w.t()+ b)**(model.k)
            rhs += basis_value_bd_col.t() @ weighted_g_N

            weighted_g_N = g_ii(pts_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:])* w_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:]
            basis_value_bd_col = F.relu(pts_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:] @ w.t()+ b)**(model.k)
            rhs += basis_value_bd_col.t() @ weighted_g_N
            
    # Stiffness matrix term in the jacobian 
    for d in range(dim):
        end_ind = 0 
        if model.k == 1:  
            for j in range(0,M,batch_size): 
                end_ind = j + batch_size 
                basis_value_dxi_col = torch.heaviside(integration_points[j:end_ind] @ w.t()+ b, ZERO) * w.t()[d:d+1,:]
                weighted_basis_value_dx_col = basis_value_dxi_col * weights[j:end_ind] * coef_alpha[j:end_ind] 
                jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 
#             basis_value_dxi_col = torch.heaviside(integration_points @ w.t()+ b, zero) * w.t()[d:d+1,:]
#             weighted_basis_value_dx_col = basis_value_dxi_col * weights * coef_alpha 
#             jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 

        else: 
            for j in range(0,M,batch_size):  
                end_ind = j + batch_size 
                basis_value_dxi_col = model.k * F.relu(integration_points[j:end_ind] @ w.t()+ b)**(model.k-1) * w.t()[d:d+1,:]
                weighted_basis_value_dx_col = basis_value_dxi_col * weights[j:end_ind] * coef_alpha[j:end_ind] 
                jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 
#             basis_value_dxi_col = model.k * F.relu(integration_points @ w.t()+ b)**(model.k-1) * w.t()[d:d+1,:]
#             weighted_basis_value_dx_col = basis_value_dxi_col * weights * coef_alpha  
#             jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 

    print("assembling the mass matrix time taken: ", time.time()-start_time) 

    start_time = time.time()    
    if solver == "cg": 
        sol, exit_code = linalg.cg(np.array(jac.detach().cpu()),np.array(rhs.detach().cpu()),tol=1e-12)
        sol = torch.tensor(sol).view(1,-1)
    elif solver == "direct": 
#         sol = np.linalg.inv( np.array(jac.detach().cpu()) )@np.array(rhs.detach().cpu())
        sol = (torch.linalg.solve( jac.detach(), rhs.detach())).view(1,-1)
    elif solver == "ls":
        sol = (torch.linalg.lstsq(jac.detach().cpu(),rhs.detach().cpu(),driver='gelsd').solution).view(1,-1)
        # sol = (torch.linalg.lstsq(jac.detach(),rhs.detach()).solution).view(1,-1) # gpu/cpu, driver = 'gels', cannot solve singular
    print("solving Ax = b time taken: ", time.time()-start_time)
    return sol 


def OGANeumannReLU10D(my_model,alpha, target,g_N, u_exact, u_exact_grad, N_list,num_epochs,plot_freq, M, k =1, rand_deter = 'deter', linear_solver = "direct",memory = 2**29): 
    """ Orthogonal greedy algorithm using 1D ReLU dictionary over [-pi,pi]
    Parameters
    ----------
    my_model: 
        nn model 
    target: 
        target function
    num_epochs: int 
        number of training epochs 
    integration_intervals: int 
        number of subintervals for piecewise numerical quadrature 

    Returns
    -------
    err: tensor 
        rank 1 torch tensor to record the L2 error history  
    model: 
        trained nn model 
    """
    #Todo Done
    dim = 10 
    gw_expand, integration_points = MonteCarlo_Sobol_dDim_weights_points(M, d=dim)
    gw_expand = gw_expand.to(device)
    integration_points = integration_points.to(device)

    # define integration on the boundary 
    gw_expand_bd, integration_points_bd = MonteCarlo_Sobol_dDim_weights_points(M//10, d=dim-1)
    size_pts_bd = integration_points_bd.size(0) 
    gw_expand_bd_faces = torch.tile(gw_expand_bd,(2*dim,1))
    
    integration_points_bd_faces = torch.zeros(2*dim*integration_points_bd.size(0),dim).to(device)
    for ind in range(dim): 
        integration_points_bd_faces[2 *ind * size_pts_bd :(2 *ind +1) * size_pts_bd,ind:ind+1] = 0 
        integration_points_bd_faces[(2 *ind)*size_pts_bd :(2 * ind +1) * size_pts_bd,:ind] = integration_points_bd[:,:ind]
        integration_points_bd_faces[(2 *ind)*size_pts_bd :(2 * ind +1) * size_pts_bd,ind+1:] = integration_points_bd[:,ind:]

        integration_points_bd_faces[(2 *ind +1) * size_pts_bd:(2 *ind +2)*size_pts_bd,ind:ind+1] = 1
        integration_points_bd_faces[(2 *ind +1) * size_pts_bd:(2 *ind +2)*size_pts_bd,:ind] = integration_points_bd[:,:ind]        
        integration_points_bd_faces[(2 *ind +1) * size_pts_bd:(2 *ind +2)*size_pts_bd,ind+1:] = integration_points_bd[:,ind:]


    err = torch.zeros(num_epochs+1)
    err_h10 = torch.zeros(num_epochs+1).to(device) 
    if my_model == None: 
        func_values = - target(integration_points)
        num_neuron = 0

        list_b = []
        list_w = []
    else: 
        func_values = - (target(integration_points) + my_model(integration_points).detach())
        bias = my_model.fc1.bias.detach().data
        weights = my_model.fc1.weight.detach().data
        num_neuron = int(bias.size(0))

        list_b = list(bias)
        list_w = list(weights)
    
    ## l2 error 
    func_values_sqrd = func_values*func_values
    err[0]= torch.sum(func_values_sqrd*gw_expand)**0.5
    ## h1 seminorm 
    if u_exact_grad != None:
        u_grad = u_exact_grad() 
        for grad_i in u_grad: 
            err_h10[0] += torch.sum((grad_i(integration_points))**2 * gw_expand)**0.5
    
    start_time = time.time()
    solver = linear_solver

    N0 = np.prod(N_list)
    if rand_deter == 'deter':
#         relu_dict_parameters = generate_relu_dict4D(N_list).to(device)
        assert rand_deter == "rand", "no deterministic dictionary, dimension"+str(dim)+"too large"
    
    print("using linear solver: ",solver)
    M2 = integration_points.size(0) # add 
    for i in range(num_epochs): 
        start_time = time.time()
        print("epoch: ",i+1, end = '\t')
        if rand_deter == 'rand':
            relu_dict_parameters = generate_relu_dict4plusD_QMC(dim, 1,N0).to(device) 
        if num_neuron == 0: 
            func_values = - target(integration_points)
        else: 
            func_values = - target(integration_points) + my_model(integration_points).detach()

        weight_func_values = func_values*gw_expand 

        ### ======================= 
        total_size = M2 * N0 
        num_batch = total_size//memory + 1 
        batch_size = N0//num_batch
        output = torch.zeros(N0,1).to(device)
        print("argmax batch num, ", num_batch) 
        
        for j in range(0,N0,batch_size):  
            end_index = j + batch_size  
            basis_values_batch = (F.relu( torch.matmul(integration_points,relu_dict_parameters[j:end_index,0:dim].T ) - relu_dict_parameters[j:end_index,dim])**k).T # uses broadcasting    
            output[j:end_index,:]  = torch.matmul(basis_values_batch,weight_func_values)[:,:] 
        
        # grad u part
        alpha_coef = alpha(integration_points) # alpha 
        if my_model!= None:
            if k == 1:  
                for j in range(0,N0,batch_size):  
                    end_index = j + batch_size 
                    derivative_part = torch.heaviside(integration_points @ (relu_dict_parameters[j:end_index,0:dim].T) - relu_dict_parameters[j:end_index,dim], ZERO) # dimension 4 
                    derivative_part *= alpha_coef # alpha 
                    for dx_i in range(dim): 

                        weight_dbasis_values_dxi =  (derivative_part * relu_dict_parameters.t()[dx_i:dx_i+1,j:end_index]) *gw_expand   
                        dmy_model_dxi = my_model.evaluate_derivative(integration_points,dx_i+1).detach()
                        output[j:end_index,:] += torch.matmul(weight_dbasis_values_dxi.t(), dmy_model_dxi) 


            else:  
                for j in range(0,N0,batch_size):  
                    end_index = j + batch_size 
                    derivative_part = k * F.relu(integration_points @ (relu_dict_parameters[j:end_index,0:dim].T) - relu_dict_parameters[j:end_index,dim])**(k-1) # dimension 4 
                    derivative_part *= alpha_coef # alpha
                    for dx_i in range(dim): 

                        weight_dbasis_values_dxi =  (derivative_part * relu_dict_parameters.t()[dx_i:dx_i+1,j:end_index]) * gw_expand    
                        dmy_model_dxi = my_model.evaluate_derivative(integration_points,dx_i+1).detach()
                        output[j:end_index,:] += torch.matmul(weight_dbasis_values_dxi.t(), dmy_model_dxi) 

        
        #Boundary condition term -<g,v>_{\Gamma_N}  
        M_bc = size_pts_bd 
        total_size = M_bc * N0 
        num_batch = total_size//memory + 1 
        batch_size = N0//num_batch
        if g_N != None:
            bcs_N = g_N(dim) 
            for ii, g_ii in bcs_N: 
                
                weighted_g_N = -g_ii(integration_points_bd_faces[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:])* gw_expand_bd_faces[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
                ## todo 
                for j in range(0,N0,batch_size):  
                    end_index = j + batch_size 
                    basis_values_bd_faces = (F.relu( torch.matmul(integration_points_bd_faces[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:],relu_dict_parameters[j:end_index,0:dim].T ) - relu_dict_parameters[j:end_index,dim])**k).T
                    output[j:end_index,:] -= torch.matmul(basis_values_bd_faces,weighted_g_N)
                # basis_values_bd_faces = (F.relu( torch.matmul(integration_points_bd_faces[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:],relu_dict_parameters[:,0:dim].T ) - relu_dict_parameters[:,dim])**k).T
                # output -= torch.matmul(basis_values_bd_faces,weighted_g_N)
                
                weighted_g_N = g_ii(integration_points_bd_faces[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:])* gw_expand_bd_faces[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:]
                ## todo  
                for j in range(0,N0,batch_size): 
                    end_index = j + batch_size 
                    basis_values_bd_faces = (F.relu( torch.matmul(integration_points_bd_faces[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:],relu_dict_parameters[j:end_index,0:dim].T ) - relu_dict_parameters[j:end_index,dim])**k).T
                    output[j:end_index,:] -= torch.matmul(basis_values_bd_faces,weighted_g_N)

                # basis_values_bd_faces = (F.relu( torch.matmul(integration_points_bd_faces[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:],relu_dict_parameters[:,0:dim].T ) - relu_dict_parameters[:,dim])**k).T
                # output -= torch.matmul(basis_values_bd_faces,weighted_g_N)
        
        # output = torch.abs(torch.matmul(basis_values,weight_func_values)) # 
        output = torch.abs(output)
        neuron_index = torch.argmax(output.flatten())
        print("argmax time taken, ", time.time() - start_time)
        # print(neuron_index)
        list_w.append(relu_dict_parameters[neuron_index,0:dim]) # dimension 4 
        list_b.append(-relu_dict_parameters[neuron_index,dim])
        num_neuron += 1
        my_model = model(dim,num_neuron,1,k).to(device)
        w_tensor = torch.stack(list_w, 0 ) 
        b_tensor = torch.tensor(list_b)
        my_model.fc1.weight.data[:,:] = w_tensor[:,:]
        my_model.fc1.bias.data[:] = b_tensor[:]

        #Todo Done 
#         alpha = None 
        sol = minimize_linear_layer_H1_explicit_assemble_efficient(my_model,alpha, target, g_N, gw_expand, integration_points,gw_expand_bd_faces, integration_points_bd_faces,activation = 'relu',solver = solver)

        my_model.fc2.weight.data[0,:] = sol[:]
        with torch.no_grad():
            model_values = my_model(integration_points)
        # L2 error ||u - u_n||
        diff_values_sqrd = (u_exact(integration_points) - model_values)**2 
        err[i+1]= torch.sum(diff_values_sqrd*gw_expand)**0.5

        # H10 error || grad(u) - grad(u_n) ||
        if u_exact_grad != None:
            for ind, grad_i in enumerate(u_grad):  
                with torch.no_grad():
                    my_model_dxi = my_model.evaluate_derivative(integration_points,ind+1).detach() 
                err_h10[i+1] += torch.sum((grad_i(integration_points) - my_model_dxi)**2 * gw_expand)**0.5
        print("l2 error {:.6f}, h1 error {:.6f}".format(err[i+1],err_h10[i+1]))
        print()
    print("time taken: ",time.time() - start_time)
    return err, err_h10.cpu(), my_model



In the following tests, we compare using deterministic dictionaries with using random dictionary for the following three target functions. 

- $\sin(\pi x_1) \sin(\pi x_2)$ 
- $\sin(4\pi x_1) \sin(8\pi x_2)$ 
- Gabor function 

In [10]:
def u_exact(x):
    d = 10  
    cn =   7.03/d 
    return torch.exp(-torch.sum( cn**2 * (x - 0.5)**2,dim = 1, keepdim = True))  

def u_exact_grad():

    def make_grad_i(i):
        def grad_i(x):
            d = 10  
            cn = 7.03/d
            return torch.exp(-torch.sum(cn**2 * (x - 0.5)**2, dim=1, keepdim=True)) * (-2 * cn**2 * (x[:, i:i+1] - 0.5))
        return grad_i 
    
    u_grad=[] 
    for i in range(10):
        u_grad.append(make_grad_i(i))
    return u_grad
                                                                
                                                                    
def alpha(x): 
    return torch.ones(x.size(0),1).to(device)

def target(x):
    d = 10 
    cn =   7.03/d 
    z = torch.exp(-torch.sum( cn**2 * (x - 0.5)**2,dim = 1, keepdim = True)) 
    return z* ( -torch.sum(  (2 *cn**2 * (x - 0.5))**2 - 2*cn**2 ,dim = 1, keepdim = True) +1)

def g_N(dim):
    def make_g(i):
        def g_i(x):
            d = 10  
            cn = 7.03 / d
            return torch.exp(-torch.sum(cn**2 * (x - 0.5)**2, dim=1, keepdim=True)) * (-2 * cn**2 * (x[:, i:i+1] - 0.5))
        return g_i

    bcs_N = []
    for i in range(dim):
        bcs_N.append((i, make_g(i)))
    
    return bcs_N

dim = 10 
function_name = "gaussian" 
filename_write = "data/10DOGA-{}-order.txt".format(function_name)
M = int(2**19) 
f_write = open(filename_write, "a")
f_write.write("Numerical Integration MC points: {} \n".format(M))
f_write.close() 
save = True 
write2file = True
rand_deter = 'rand'

for N_list in [[2**10]]: # ,[2**6,2**6],[2**7,2**7] 
    # save = True 
    f_write = open(filename_write, "a")
    my_model = None 
    exponent = 9 
    num_epochs = 2**exponent  
    plot_freq = num_epochs 
    N = np.prod(N_list)
    relu_k = 4 
    err_QMC2, err_h10, my_model = OGANeumannReLU10D(my_model,alpha, target,g_N, u_exact,u_exact_grad, N_list,num_epochs,plot_freq, M, k = relu_k, rand_deter= rand_deter, linear_solver = "direct")
    
    if save: 
        folder = 'data/'
        filename = folder + 'err_OGA_10D_{}_neuron_{}_N_{}_randomized.pt'.format(function_name,num_epochs,N)
        torch.save(err_QMC2,filename) 
        folder = 'data/'
        filename = folder + 'model_OGA_10D_{}_neuron_{}_N_{}_randomized.pt'.format(function_name,num_epochs,N)
        torch.save(my_model,filename)
        
    show_convergence_order2(err_QMC2,err_h10,exponent,N,filename_write,write2file = write2file)
    show_convergence_order_latex2(err_QMC2,err_h10,exponent,k=relu_k,d = dim)


using linear solver:  direct
epoch:  1	argmax batch num,  2
argmax time taken,  0.03701353073120117
total size: 1 524288 = 524288
num batches:  1
assembling the mass matrix time taken:  0.0060312747955322266
solving Ax = b time taken:  0.00018405914306640625
l2 error 0.386720, h1 error 2.107200

epoch:  2	argmax batch num,  2
argmax time taken,  0.0413057804107666
total size: 2 524288 = 1048576
num batches:  1
assembling the mass matrix time taken:  0.0075109004974365234
solving Ax = b time taken:  0.0015666484832763672
l2 error 0.257125, h1 error 2.122811

epoch:  3	argmax batch num,  2
argmax time taken,  0.03828740119934082
total size: 3 524288 = 1572864
num batches:  1
assembling the mass matrix time taken:  0.006812572479248047
solving Ax = b time taken:  0.0018534660339355469
l2 error 0.196146, h1 error 2.100717

epoch:  4	argmax batch num,  2
argmax time taken,  0.038239479064941406
total size: 4 524288 = 2097152
num batches:  1
assembling the mass matrix time taken:  0.00715708

total size: 32 524288 = 16777216
num batches:  1
assembling the mass matrix time taken:  0.01963639259338379
solving Ax = b time taken:  0.015250921249389648
l2 error 0.057717, h1 error 1.073740

epoch:  33	argmax batch num,  2
argmax time taken,  0.06463003158569336
total size: 33 524288 = 17301504
num batches:  1
assembling the mass matrix time taken:  0.019628286361694336
solving Ax = b time taken:  0.01607203483581543
l2 error 0.056793, h1 error 1.052732

epoch:  34	argmax batch num,  2
argmax time taken,  0.05412626266479492
total size: 34 524288 = 17825792
num batches:  1
assembling the mass matrix time taken:  0.006832599639892578
solving Ax = b time taken:  0.0156705379486084
l2 error 0.056269, h1 error 1.038782

epoch:  35	argmax batch num,  2
argmax time taken,  0.04066658020019531
total size: 35 524288 = 18350080
num batches:  1
assembling the mass matrix time taken:  0.006727933883666992
solving Ax = b time taken:  0.017001628875732422
l2 error 0.054811, h1 error 1.022528



total size: 63 524288 = 33030144
num batches:  1
assembling the mass matrix time taken:  0.007050037384033203
solving Ax = b time taken:  0.02982163429260254
l2 error 0.021455, h1 error 0.464228

epoch:  64	argmax batch num,  2
argmax time taken,  0.04116559028625488
total size: 64 524288 = 33554432
num batches:  1
assembling the mass matrix time taken:  0.006968498229980469
solving Ax = b time taken:  0.030119895935058594
l2 error 0.020773, h1 error 0.454340

epoch:  65	argmax batch num,  2
argmax time taken,  0.041120052337646484
total size: 65 524288 = 34078720
num batches:  1
assembling the mass matrix time taken:  0.00667119026184082
solving Ax = b time taken:  0.03687596321105957
l2 error 0.020173, h1 error 0.442320

epoch:  66	argmax batch num,  2
argmax time taken,  0.0411074161529541
total size: 66 524288 = 34603008
num batches:  1
assembling the mass matrix time taken:  0.007039546966552734
solving Ax = b time taken:  0.03758645057678223
l2 error 0.019469, h1 error 0.433239



total size: 94 524288 = 49283072
num batches:  1
assembling the mass matrix time taken:  0.0070934295654296875
solving Ax = b time taken:  0.053807735443115234
l2 error 0.011679, h1 error 0.300405

epoch:  95	argmax batch num,  2
argmax time taken,  0.043485403060913086
total size: 95 524288 = 49807360
num batches:  1
assembling the mass matrix time taken:  0.0067138671875
solving Ax = b time taken:  0.05333733558654785
l2 error 0.011447, h1 error 0.296838

epoch:  96	argmax batch num,  2
argmax time taken,  0.04183387756347656
total size: 96 524288 = 50331648
num batches:  1
assembling the mass matrix time taken:  0.007111549377441406
solving Ax = b time taken:  0.04956793785095215
l2 error 0.011198, h1 error 0.292853

epoch:  97	argmax batch num,  2
argmax time taken,  0.041672706604003906
total size: 97 524288 = 50855936
num batches:  1
assembling the mass matrix time taken:  0.006546497344970703
solving Ax = b time taken:  0.05324864387512207
l2 error 0.011088, h1 error 0.291725

e

total size: 125 524288 = 65536000
num batches:  1
assembling the mass matrix time taken:  0.006627559661865234
solving Ax = b time taken:  0.06675529479980469
l2 error 0.008792, h1 error 0.253500

epoch:  126	argmax batch num,  2
argmax time taken,  0.04093289375305176
total size: 126 524288 = 66060288
num batches:  1
assembling the mass matrix time taken:  0.007048130035400391
solving Ax = b time taken:  0.06939387321472168
l2 error 0.008775, h1 error 0.253103

epoch:  127	argmax batch num,  2
argmax time taken,  0.04092741012573242
total size: 127 524288 = 66584576
num batches:  1
assembling the mass matrix time taken:  0.0066564083099365234
solving Ax = b time taken:  0.06754350662231445
l2 error 0.008719, h1 error 0.252224

epoch:  128	argmax batch num,  2
argmax time taken,  0.04097270965576172
total size: 128 524288 = 67108864
num batches:  1
assembling the mass matrix time taken:  0.01989603042602539
solving Ax = b time taken:  0.0671541690826416
l2 error 0.008683, h1 error 0.25

argmax time taken,  0.04441332817077637
total size: 156 524288 = 81788928
num batches:  1
assembling the mass matrix time taken:  0.006920337677001953
solving Ax = b time taken:  0.0944364070892334
l2 error 0.007501, h1 error 0.232390

epoch:  157	argmax batch num,  2
argmax time taken,  0.044724464416503906
total size: 157 524288 = 82313216
num batches:  1
assembling the mass matrix time taken:  0.006536245346069336
solving Ax = b time taken:  0.09430384635925293
l2 error 0.007479, h1 error 0.232020

epoch:  158	argmax batch num,  2
argmax time taken,  0.04460024833679199
total size: 158 524288 = 82837504
num batches:  1
assembling the mass matrix time taken:  0.006813764572143555
solving Ax = b time taken:  0.098388671875
l2 error 0.007461, h1 error 0.231772

epoch:  159	argmax batch num,  2
argmax time taken,  0.04462552070617676
total size: 159 524288 = 83361792
num batches:  1
assembling the mass matrix time taken:  0.0065746307373046875
solving Ax = b time taken:  0.0951316356658

l2 error 0.006923, h1 error 0.224426

epoch:  187	argmax batch num,  2
argmax time taken,  0.04526257514953613
total size: 187 524288 = 98041856
num batches:  1
assembling the mass matrix time taken:  0.006667137145996094
solving Ax = b time taken:  0.1067805290222168
l2 error 0.006915, h1 error 0.224325

epoch:  188	argmax batch num,  2
argmax time taken,  0.04524540901184082
total size: 188 524288 = 98566144
num batches:  1
assembling the mass matrix time taken:  0.007062435150146484
solving Ax = b time taken:  0.10886812210083008
l2 error 0.006903, h1 error 0.224116

epoch:  189	argmax batch num,  2
argmax time taken,  0.04536032676696777
total size: 189 524288 = 99090432
num batches:  1
assembling the mass matrix time taken:  0.006708621978759766
solving Ax = b time taken:  0.10763764381408691
l2 error 0.006875, h1 error 0.223802

epoch:  190	argmax batch num,  2
argmax time taken,  0.04533743858337402
total size: 190 524288 = 99614720
num batches:  1
assembling the mass matrix tim

total size: 217 524288 = 113770496
num batches:  1
assembling the mass matrix time taken:  0.0066564083099365234
solving Ax = b time taken:  0.16094112396240234
l2 error 0.006642, h1 error 0.219492

epoch:  218	argmax batch num,  2
argmax time taken,  0.04286503791809082
total size: 218 524288 = 114294784
num batches:  1
assembling the mass matrix time taken:  0.007087230682373047
solving Ax = b time taken:  0.13832426071166992
l2 error 0.006637, h1 error 0.219259

epoch:  219	argmax batch num,  2
argmax time taken,  0.0429234504699707
total size: 219 524288 = 114819072
num batches:  1
assembling the mass matrix time taken:  0.0071027278900146484
solving Ax = b time taken:  0.16078472137451172
l2 error 0.006634, h1 error 0.219217

epoch:  220	argmax batch num,  2
argmax time taken,  0.042969465255737305
total size: 220 524288 = 115343360
num batches:  1
assembling the mass matrix time taken:  0.007140636444091797
solving Ax = b time taken:  0.13627982139587402
l2 error 0.006629, h1 err

l2 error 0.006506, h1 error 0.213333

epoch:  248	argmax batch num,  2
argmax time taken,  0.04348158836364746
total size: 248 524288 = 130023424
num batches:  1
assembling the mass matrix time taken:  0.006671905517578125
solving Ax = b time taken:  0.15656089782714844
l2 error 0.006500, h1 error 0.212611

epoch:  249	argmax batch num,  2
argmax time taken,  0.04346442222595215
total size: 249 524288 = 130547712
num batches:  1
assembling the mass matrix time taken:  0.006676912307739258
solving Ax = b time taken:  0.1571958065032959
l2 error 0.006490, h1 error 0.212360

epoch:  250	argmax batch num,  2
argmax time taken,  0.043488264083862305
total size: 250 524288 = 131072000
num batches:  1
assembling the mass matrix time taken:  0.006770133972167969
solving Ax = b time taken:  0.15730690956115723
l2 error 0.006484, h1 error 0.212244

epoch:  251	argmax batch num,  2
argmax time taken,  0.043504953384399414
total size: 251 524288 = 131596288
num batches:  1
assembling the mass matr

total size: 278 524288 = 145752064
num batches:  1
assembling the mass matrix time taken:  0.009807586669921875
solving Ax = b time taken:  0.18529725074768066
l2 error 0.006320, h1 error 0.202842

epoch:  279	argmax batch num,  2
argmax time taken,  0.04676079750061035
total size: 279 524288 = 146276352
num batches:  1
assembling the mass matrix time taken:  0.007429838180541992
solving Ax = b time taken:  0.23420929908752441
l2 error 0.006322, h1 error 0.202703

epoch:  280	argmax batch num,  2
argmax time taken,  0.04795575141906738
total size: 280 524288 = 146800640
num batches:  1
assembling the mass matrix time taken:  0.007909059524536133
solving Ax = b time taken:  0.18524575233459473
l2 error 0.006309, h1 error 0.201358

epoch:  281	argmax batch num,  2
argmax time taken,  0.04675102233886719
total size: 281 524288 = 147324928
num batches:  1
assembling the mass matrix time taken:  0.008116006851196289
solving Ax = b time taken:  0.23461055755615234
l2 error 0.006297, h1 error

solving Ax = b time taken:  0.2018589973449707
l2 error 0.005955, h1 error 0.186771

epoch:  309	argmax batch num,  2
argmax time taken,  0.04729580879211426
total size: 309 524288 = 162004992
num batches:  1
assembling the mass matrix time taken:  0.007451295852661133
solving Ax = b time taken:  0.20435047149658203
l2 error 0.005942, h1 error 0.186466

epoch:  310	argmax batch num,  2
argmax time taken,  0.04731035232543945
total size: 310 524288 = 162529280
num batches:  1
assembling the mass matrix time taken:  0.007796049118041992
solving Ax = b time taken:  0.20386695861816406
l2 error 0.005940, h1 error 0.186274

epoch:  311	argmax batch num,  2
argmax time taken,  0.050925493240356445
total size: 311 524288 = 163053568
num batches:  1
assembling the mass matrix time taken:  0.19172048568725586
solving Ax = b time taken:  0.1602485179901123
l2 error 0.005931, h1 error 0.185569

epoch:  312	argmax batch num,  2
argmax time taken,  0.048760175704956055
total size: 312 524288 = 1635

total size: 339 524288 = 177733632
num batches:  1
assembling the mass matrix time taken:  0.20041871070861816
solving Ax = b time taken:  0.21712279319763184
l2 error 0.005563, h1 error 0.168611

epoch:  340	argmax batch num,  2
argmax time taken,  0.04534339904785156
total size: 340 524288 = 178257920
num batches:  1
assembling the mass matrix time taken:  0.009178876876831055
solving Ax = b time taken:  0.2373661994934082
l2 error 0.005545, h1 error 0.167564

epoch:  341	argmax batch num,  2
argmax time taken,  0.04522132873535156
total size: 341 524288 = 178782208
num batches:  1
assembling the mass matrix time taken:  0.007292032241821289
solving Ax = b time taken:  0.2621488571166992
l2 error 0.005538, h1 error 0.167280

epoch:  342	argmax batch num,  2
argmax time taken,  0.045456886291503906
total size: 342 524288 = 179306496
num batches:  1
assembling the mass matrix time taken:  0.008649110794067383
solving Ax = b time taken:  0.2438364028930664
l2 error 0.005530, h1 error 0.

solving Ax = b time taken:  0.27397847175598145
l2 error 0.005170, h1 error 0.153872

epoch:  370	argmax batch num,  2
argmax time taken,  0.04595637321472168
total size: 370 524288 = 193986560
num batches:  1
assembling the mass matrix time taken:  0.008638858795166016
solving Ax = b time taken:  0.2736032009124756
l2 error 0.005136, h1 error 0.152905

epoch:  371	argmax batch num,  2
argmax time taken,  0.04595136642456055
total size: 371 524288 = 194510848
num batches:  1
assembling the mass matrix time taken:  0.4872257709503174
solving Ax = b time taken:  0.2279949188232422
l2 error 0.005115, h1 error 0.151779

epoch:  372	argmax batch num,  2
argmax time taken,  0.046065568923950195
total size: 372 524288 = 195035136
num batches:  1
assembling the mass matrix time taken:  0.008751869201660156
solving Ax = b time taken:  0.2741525173187256
l2 error 0.005104, h1 error 0.151263

epoch:  373	argmax batch num,  2
argmax time taken,  0.04593634605407715
total size: 373 524288 = 1955594

total size: 400 524288 = 209715200
num batches:  1
assembling the mass matrix time taken:  0.009222745895385742
solving Ax = b time taken:  0.315885066986084
l2 error 0.004725, h1 error 0.137203

epoch:  401	argmax batch num,  2
argmax time taken,  0.04686880111694336
total size: 401 524288 = 210239488
num batches:  1
assembling the mass matrix time taken:  0.008269548416137695
solving Ax = b time taken:  0.3058357238769531
l2 error 0.004702, h1 error 0.136601

epoch:  402	argmax batch num,  2
argmax time taken,  0.04687833786010742
total size: 402 524288 = 210763776
num batches:  1
assembling the mass matrix time taken:  0.2068629264831543
solving Ax = b time taken:  0.25124335289001465
l2 error 0.004693, h1 error 0.136316

epoch:  403	argmax batch num,  2
argmax time taken,  0.04687666893005371
total size: 403 524288 = 211288064
num batches:  1
assembling the mass matrix time taken:  0.05293083190917969
solving Ax = b time taken:  0.2823820114135742
l2 error 0.004671, h1 error 0.1358

solving Ax = b time taken:  0.33437299728393555
l2 error 0.004348, h1 error 0.126248

epoch:  431	argmax batch num,  2
argmax time taken,  0.050022125244140625
total size: 431 524288 = 225968128
num batches:  1
assembling the mass matrix time taken:  0.06447339057922363
solving Ax = b time taken:  0.27531981468200684
l2 error 0.004339, h1 error 0.125914

epoch:  432	argmax batch num,  2
argmax time taken,  0.05012154579162598
total size: 432 524288 = 226492416
num batches:  1
assembling the mass matrix time taken:  0.25324058532714844
solving Ax = b time taken:  0.27544069290161133
l2 error 0.004337, h1 error 0.125778

epoch:  433	argmax batch num,  2
argmax time taken,  0.052690982818603516
total size: 433 524288 = 227016704
num batches:  1
assembling the mass matrix time taken:  0.0077266693115234375
solving Ax = b time taken:  0.32352352142333984
l2 error 0.004309, h1 error 0.125231

epoch:  434	argmax batch num,  2
argmax time taken,  0.05012249946594238
total size: 434 524288 = 22

total size: 461 524288 = 241696768
num batches:  1
assembling the mass matrix time taken:  0.009204626083374023
solving Ax = b time taken:  0.3710513114929199
l2 error 0.003928, h1 error 0.113093

epoch:  462	argmax batch num,  2
argmax time taken,  0.047733306884765625
total size: 462 524288 = 242221056
num batches:  1
assembling the mass matrix time taken:  0.27526330947875977
solving Ax = b time taken:  0.31056833267211914
l2 error 0.003918, h1 error 0.112919

epoch:  463	argmax batch num,  2
argmax time taken,  0.04781508445739746
total size: 463 524288 = 242745344
num batches:  1
assembling the mass matrix time taken:  0.009284734725952148
solving Ax = b time taken:  0.371793270111084
l2 error 0.003900, h1 error 0.112524

epoch:  464	argmax batch num,  2
argmax time taken,  0.04795265197753906
total size: 464 524288 = 243269632
num batches:  1
assembling the mass matrix time taken:  0.009221315383911133
solving Ax = b time taken:  0.3663482666015625
l2 error 0.003898, h1 error 0.1

solving Ax = b time taken:  0.3850104808807373
l2 error 0.003645, h1 error 0.103081

epoch:  492	argmax batch num,  2
argmax time taken,  0.04840278625488281
total size: 492 524288 = 257949696
num batches:  1
assembling the mass matrix time taken:  0.24079251289367676
solving Ax = b time taken:  0.30164551734924316
l2 error 0.003643, h1 error 0.103003

epoch:  493	argmax batch num,  2
argmax time taken,  0.04868125915527344
total size: 493 524288 = 258473984
num batches:  1
assembling the mass matrix time taken:  0.008938789367675781
solving Ax = b time taken:  0.3861818313598633
l2 error 0.003643, h1 error 0.102947

epoch:  494	argmax batch num,  2
argmax time taken,  0.048447370529174805
total size: 494 524288 = 258998272
num batches:  1
assembling the mass matrix time taken:  0.2565915584564209
solving Ax = b time taken:  0.30173587799072266
l2 error 0.003641, h1 error 0.102814

epoch:  495	argmax batch num,  2
argmax time taken,  0.04975485801696777
total size: 495 524288 = 2595225