In [8]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import time
import sys
import os 
from scipy.sparse import linalg
from pathlib import Path
import itertools
# Device used for the tensors in this notebook: the GPU when PyTorch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Double precision everywhere; must run before the constants below so they are float64.
torch.set_default_dtype(torch.float64)

pi = torch.tensor(np.pi)            # 0-d float64 tensor holding pi
ZERO = torch.zeros(1).to(device)    # shape (1,) zero, used as torch.heaviside's value at 0

class model(nn.Module):
    """Shallow (one hidden layer) ReLU^k neural network.

        u(x) = fc2( relu(fc1(x))**k ),   i.e.  u(x) = sum_j c_j * relu(w_j . x + b_j)**k

    Parameters
    ----------
    input_size: int
        input dimension
    hidden_size1: int
        number of neurons in the single hidden layer
    num_classes: int
        output dimension
    k: int
        power applied to the ReLU activation (default 1)
    """
    def __init__(self, input_size, hidden_size1, num_classes,k = 1):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        # output layer is a pure linear combination of the neurons (no bias)
        self.fc2 = nn.Linear(hidden_size1, num_classes,bias = False)
        self.k = k

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)) ** self.k)

    def evaluate_derivative(self, x, i):
        """Partial derivative of the network output with respect to input coordinate i (1-based) at x.

        Uses d/dz relu(z)**k = k * relu(z)**(k-1) for k > 1 and the Heaviside step H(z),
        with H(0) = 0, for k == 1.
        """
        pre_activation = self.fc1(x)
        # inner weights of every neuron for coordinate i, shape (1, hidden_size1)
        w_i = self.fc1.weight.t()[i-1:i, :]
        if self.k == 1:
            # value at 0 built from the pre-activation so device and dtype always match x
            activation_derivative = torch.heaviside(pre_activation, pre_activation.new_zeros(1))
        else:
            activation_derivative = self.k * F.relu(pre_activation) ** (self.k - 1)
        return self.fc2(activation_derivative * w_i)


def plot_subdomains(my_model):
    """Draw, over the unit square, the line on which each hidden neuron switches on.

    Neuron j of ``my_model.fc1`` is active where w_j0 * x + w_j1 * y + b_j > 0; the boundary line
    w_j0 * x + w_j1 * y + b_j = 0 is plotted. Only the first two input coordinates are used, so this
    is meant for models with input_size == 2.

    Returns 0.
    """
    x_coord = torch.linspace(0, 1, 200)
    # matplotlib cannot consume CUDA tensors
    wi = my_model.fc1.weight.data.cpu()
    bi = my_model.fc1.bias.data.cpu()
    for i, bias in enumerate(bi):
        if wi[i, 1] != 0:
            # y = -(w0 / w1) x - b / w1
            plt.plot(x_coord, - wi[i, 0] / wi[i, 1] * x_coord - bias / wi[i, 1])
        elif wi[i, 0] != 0:
            # w1 == 0: the line is vertical, x = -b / w0
            plt.plot(- bias / wi[i, 0] * torch.ones(x_coord.size()), x_coord)
        # w == 0: the neuron is constant and has no switching line, nothing to draw

    plt.xlim([0, 1])
    plt.ylim([0, 1])
    # no plt.legend(): the lines carry no labels, so a legend would be empty (and warn)
    plt.show()
    return 0

def adjust_neuron_position(my_model, dims = 3):
    """Re-sample the biases of hidden neurons whose switching hyperplane does not cut the unit cube [0,1]^dims.

    For a neuron with inner weights w and bias b, w . x + b ranges over [min_v w.v + b, max_v w.v + b]
    on the cube (v running over its vertices), so the hyperplane w . x + b = 0 meets the cube iff
    b lies in [left_end, right_end] = [-max_v w.v, -min_v w.v]. With offset = (right_end - left_end) / 50:

    - b <= left_end + offset/2: the neuron is (essentially) inactive on the cube; a new bias is drawn
      uniformly from [left_end + offset/2, right_end - offset/2).
    - b >= right_end - offset/2: the neuron is (essentially) active on the whole cube; the first
      dims + 1 such neurons are kept, any further ones are re-sampled the same way.

    ``dims`` must equal the input dimension of ``my_model.fc1``. The biases are modified in place and
    ``my_model`` is returned.
    """
    weight = my_model.fc1.weight.data
    bias = my_model.fc1.bias.data
    # vertices of [0,1]^dims on the weights' device/dtype so the products below also work on GPU
    vertices = torch.tensor(list(itertools.product([0., 1.], repeat=dims)),
                            dtype=weight.dtype, device=weight.device)

    def draw_interior_bias(left_end, right_end, offset):
        # uniform in [left_end + offset/2, right_end - offset/2); drawn from the CPU generator, then moved
        return torch.rand(1).to(bias) * (right_end - left_end - offset) + left_end + offset / 2

    always_active_kept = 0
    for i in range(bias.size(0)):
        values = vertices @ weight[i]   # w . v for every vertex v (no bias)
        left_end = - torch.max(values)
        right_end = - torch.min(values)
        offset = (right_end - left_end) / 50
        b = bias[i]
        if b <= left_end + offset / 2:
            # dead on the whole cube
            b = draw_interior_bias(left_end, right_end, offset)
            bias[i] = b
        if b >= right_end - offset / 2:
            # active on the whole cube
            if always_active_kept < (dims + 1):
                always_active_kept += 1
            else:  # (d + 1) or more
                b = draw_interior_bias(left_end, right_end, offset)
                bias[i] = b
    return my_model



In [9]:
def show_convergence_order2(err_l2,err_h10,exponent,dict_size, filename,write2file = False):
    """Print a table of L2 and H1-seminorm errors with observed convergence orders; optionally append it to a file.

    Rows correspond to n = 4, 8, ..., 2**exponent neurons. Each "order" column is log2(e_{n/2} / e_n),
    the observed rate r in e_n ~ n^{-r} between consecutive rows ("*" on the first row).

    Parameters
    ----------
    err_l2, err_h10: indexable by neuron count
        error histories, e.g. 1D tensors or lists of length >= 2**exponent + 1
    exponent: int
        the last row is n = 2**exponent
    dict_size:
        dictionary size, written to the file header
    filename: str
        text file the table is appended to (created if missing)
    write2file: bool
        also append the table to ``filename``
    """
    neuron_nums = [2**j for j in range(2, exponent+1)]
    # float(): "{}".format of a 0-d tensor would write "tensor(0.0123)" into the file
    l2_errs = [float(err_l2[n]) for n in neuron_nums]
    h10_errs = [float(err_h10[n]) for n in neuron_nums]

    def observed_order(errs, i):
        # np.float64 division: a zero error yields inf (with a numpy warning) instead of ZeroDivisionError
        return float(np.log(np.float64(errs[i-1]) / errs[i]) / np.log(2))

    file_lines = []
    print("neuron num \t\t error \t\t order \t\t h10 error \t\t order")
    for i, n in enumerate(neuron_nums):
        l2, h10 = l2_errs[i], h10_errs[i]
        if i == 0:
            print("{} \t\t {:.6f} \t\t * \t\t {:.6f} \t\t * \n".format(n, l2, h10))
            file_lines.append("{} \t\t {} \t\t * \t\t {} \t\t * \n".format(n, l2, h10))
        else:
            l2_order = observed_order(l2_errs, i)
            h10_order = observed_order(h10_errs, i)
            print("{} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \n".format(n, l2, l2_order, h10, h10_order))
            file_lines.append("{} \t\t {} \t\t {} \t\t {} \t\t {} \n".format(n, l2, l2_order, h10, h10_order))

    if write2file:
        # "a" creates the file when missing; `with` closes it even if a write fails
        with open(filename, "a") as f_write:
            f_write.write('dictionary size: {}\n'.format(dict_size))
            f_write.write("neuron num \t\t error \t\t order \t\t h10 error \\ order \n")
            f_write.writelines(file_lines)
            f_write.write("\n")

def show_convergence_order_latex2(err_l2,err_h10,exponent,k=1,d=1):
    """Print the error / convergence-order table as LaTeX tabular rows.

    Rows are n = 4, 8, ..., 2**exponent neurons; order columns are log2(e_{n/2} / e_n).
    The header shows the expected rates n^{-1/2-(2k+1)/(2d)} (L2) and n^{-1/2-(2k-1)/(2d)}
    (H1 seminorm) for ReLU^k networks in d dimensions.
    """
    neuron_nums = [2**j for j in range(2, exponent+1)]
    err_list = [err_l2[i] for i in neuron_nums]
    err_list2 = [err_h10[i] for i in neuron_nums]
    l2_order = -1/2-(2*k + 1)/(2*d)
    h1_order = -1/2-(2*(k-1)+ 1)/(2*d)
    # LaTeX backslashes are escaped explicitly: "\|", "\ " and "\h" are invalid Python escape
    # sequences (SyntaxWarning since Python 3.12); the printed text is unchanged.
    header = ("neuron num  & \t $\\|u-u_n \\|_{{L^2}}$ & \t order $O(n^{{{:.2f}}})$  & "
              "\t $ | u -u_n |_{{H^1}}$ & \t order $O(n^{{{:.2f}}})$  \\\\ \\hline \\hline ")
    first_row = "{} \t\t & {:.6f} &\t\t * & \t\t {:.6f} & \t\t *  \\\\ \\hline  \n"
    row = "{} \t\t &  {:.3e} &  \t\t {:.2f} &  \t\t {:.3e} &  \t\t {:.2f} \\\\ \\hline  \n"
    print(header.format(l2_order, h1_order))
    for i, item in enumerate(err_list):
        if i == 0:
            print(first_row.format(neuron_nums[i], item, err_list2[i]))
        else:
            print(row.format(neuron_nums[i], item, np.log(err_list[i-1]/err_list[i])/np.log(2),
                             err_list2[i], np.log(err_list2[i-1]/err_list2[i])/np.log(2)))


In [10]:
def MonteCarlo_Sobol_dDim_weights_points(M ,d = 4):
    """Equal-weight quasi-Monte Carlo quadrature rule on the unit cube [0,1]^d.

    Parameters
    ----------
    M: int
        number of points: the first M points of an unscrambled Sobol sequence
    d: int
        dimension of the cube

    Returns
    -------
    weights: torch.tensor
        (M, 1), every entry equal to 1/M, on `device`
    integration_points: torch.tensor
        (M, d) float64 Sobol points, on `device`
    """
    sobol_engine = torch.quasirandom.SobolEngine(dimension=d, scramble=False, seed=None)
    points = sobol_engine.draw(M).double()
    points = points.to(device)
    uniform_weights = torch.ones(M, 1).to(device) / M
    return uniform_weights, points

def generate_relu_dict4D(N_list):
    """Deterministic dictionary of ReLU neurons relu(w . x - c) on R^4, built from a uniform grid.

    A uniform grid on [0,1)^4 with N_list[j] points along axis j (endpoint excluded) is mapped
    affinely to (theta_1, theta_2, theta_3, c) in [0, pi) x [0, pi) x [0, 2 pi) x [-2, 2), and
    the angles are turned into a unit vector with hyperspherical coordinates:
        w = (cos t1, sin t1 cos t2, sin t1 sin t2 cos t3, sin t1 sin t2 sin t3).

    Parameters
    ----------
    N_list: sequence of 4 ints
        grid resolution per parameter

    Returns
    -------
    torch.tensor of shape (prod(N_list), 5): columns 0-3 are w, column 4 is c.
    """
    grid_axes = [np.linspace(0, 1, n, endpoint=False) for n in N_list]
    grid = np.meshgrid(*grid_axes, indexing='ij')
    samples = torch.tensor(np.column_stack([axis.ravel() for axis in grid]))

    # diagonal affine map of [0,1)^4 onto the parameter box (computed in the samples' dtype)
    scale = torch.tensor([np.pi, np.pi, 2 * np.pi, 4.0], dtype=samples.dtype)
    shift = torch.tensor([0.0, 0.0, 0.0, -2.0], dtype=samples.dtype)
    theta1, theta2, theta3, offset = (samples * scale + shift).unbind(dim=1)

    sin12 = torch.sin(theta1) * torch.sin(theta2)
    return torch.stack([
        torch.cos(theta1),
        torch.sin(theta1) * torch.cos(theta2),
        sin12 * torch.cos(theta3),
        sin12 * torch.sin(theta3),
        offset,
    ], dim=1)  # N x 5

def PiecewiseGQ3D_weights_points(Nx, order):
    """Composite tensor-product Gauss-Legendre quadrature on the unit cube [0,1]^3.

    The cube is split into Nx^3 congruent sub-cubes of side h = 1/Nx, and an order^3-point
    Gauss-Legendre rule is mapped onto each of them.

    Parameters
    ----------
    Nx: int
        number of subintervals along each coordinate direction (the same in all three)
    order: int
        number of Gauss-Legendre points per direction in every sub-cube

    Returns
    -------
    long_weights: torch.tensor
        (Nx^3 * order^3, 1) quadrature weights, on `device`
    integration_points: torch.tensor
        (Nx^3 * order^3, 3) quadrature points, on `device`; grouped sub-cube by sub-cube,
        with each sub-cube's order^3 points consecutive
    """
    # reference rule on [-1,1]^3
    x, w = np.polynomial.legendre.leggauss(order)
    ref_points = torch.tensor(np.array(np.meshgrid(x, x, x, indexing='ij')).reshape(3, -1).T)
    wx, wy, wz = np.meshgrid(w, w, w, indexing='ij')
    ref_weights = torch.tensor((wx * wy * wz).ravel())

    h = 1 / Nx
    num_cells = Nx**3
    # the affine map [-1,1]^3 -> sub-cube has Jacobian (h/2)^3 = h^3 / 8
    long_weights = torch.tile(ref_weights, (num_cells, 1)).reshape(-1, 1) * h**3 / 8

    # sub-cube centres ((i - 0.5) h, (j - 0.5) h, (l - 0.5) h) for i, j, l = 1..Nx
    centre_index = np.arange(1, Nx + 1) - 0.5
    centres = torch.tensor(
        np.array(np.meshgrid(centre_index, centre_index, centre_index, indexing='ij')).reshape(3, -1).T)

    # (1, order^3, 3) + (num_cells, 1, 3): every scaled reference point shifted into every sub-cube
    integration_points = ((h / 2) * ref_points).unsqueeze(0) + (centres * h).unsqueeze(1)
    integration_points = integration_points.reshape(-1, 3)

    return long_weights.to(device), integration_points.to(device)


def generate_relu_dict4D_QMC(s,N0):
    """Random dictionary of s * N0 ReLU neurons relu(w . x - c) on R^4.

    Despite the name this is plain Monte Carlo: the parameters (theta_1, theta_2, theta_3, c) are
    drawn i.i.d. uniformly (torch.rand) from [0, pi) x [0, pi) x [0, 2 pi) x [-2, 2); a scrambled
    Sobol variant existed previously but is disabled. The angles give a unit vector via
        w = (cos t1, sin t1 cos t2, sin t1 sin t2 cos t3, sin t1 sin t2 sin t3).

    Returns
    -------
    torch.tensor of shape (s * N0, 5): columns 0-3 are w, column 4 is c.
    """
    samples = torch.rand(s * N0, 4)

    # diagonal affine map of [0,1)^4 onto the parameter box (computed in the samples' dtype)
    scale = torch.tensor([np.pi, np.pi, 2 * np.pi, 4.0], dtype=samples.dtype)
    shift = torch.tensor([0.0, 0.0, 0.0, -2.0], dtype=samples.dtype)
    theta1, theta2, theta3, offset = (samples * scale + shift).unbind(dim=1)

    sin12 = torch.sin(theta1) * torch.sin(theta2)
    return torch.stack([
        torch.cos(theta1),
        torch.sin(theta1) * torch.cos(theta2),
        sin12 * torch.cos(theta3),
        sin12 * torch.sin(theta3),
        offset,
    ], dim=1)  # N x 5

def minimize_linear_layer_H1_explicit_assemble_efficient(model,alpha, target, g_N, weights, integration_points, w_bd, pts_bd, activation = 'relu',solver="direct" ):
    """Galerkin solve for the output-layer coefficients of a shallow ReLU^k network.

    For  -div(alpha grad u) + u = f  with Neumann data, the hidden neurons
    phi_j = relu(w_j . x + b_j)**k of ``model`` are used as basis and the system A c = rhs is solved with
        A_ij  = (phi_j, phi_i) + (alpha grad phi_j, grad phi_i)
        rhs_i = (f, phi_i) + <g, phi_i>_boundary.

    Parameters
    ----------
    model:
        network with ``fc1`` (Linear) and integer power ``k``; only fc1 is read
    alpha: callable
        diffusion coefficient, points (P, dim) -> (P, 1)
    target: callable
        right-hand side f, points (P, dim) -> (P, 1)
    g_N: callable or None
        g_N(dim) returns a list of (i, g_i) pairs. On the face x_i = 0 the term -g_i * phi is
        integrated, on the face x_i = 1 the term +g_i * phi. None skips the boundary term.
    weights, integration_points:
        interior quadrature weights (P, 1) and points (P, dim)
    w_bd, pts_bd:
        boundary quadrature weights and points, stacked face by face as
        (x_0 = 0, x_0 = 1, x_1 = 0, x_1 = 1, ...), all faces with the same number of points
    activation: str
        only 'relu' is supported
    solver: str
        'direct' (torch.linalg.solve), 'cg' (SciPy conjugate gradient on CPU) or
        'ls' (torch.linalg.lstsq with the 'gelsd' driver on CPU)

    Returns
    -------
    sol: torch.tensor of shape (1, number of neurons)

    Raises
    ------
    ValueError
        for an unsupported ``activation`` or an unknown ``solver``
    """
    if activation != 'relu':
        raise ValueError("unsupported activation: {!r}".format(activation))
    if solver not in ("cg", "direct", "ls"):
        raise ValueError("unknown solver: {!r}".format(solver))

    start_time = time.time()
    w = model.fc1.weight.data
    b = model.fc1.bias.data
    k = model.k
    dim = integration_points.size(1)

    coef_alpha = alpha(integration_points)

    # mass matrix (phi_j, phi_i) and load vector (f, phi_i)
    basis_value_col = F.relu(integration_points @ w.t() + b) ** k
    weighted_basis_value_col = basis_value_col * weights
    jac = weighted_basis_value_col.t() @ basis_value_col
    rhs = weighted_basis_value_col.t() @ target(integration_points)

    # Neumann term <g, phi>: outward normal is -e_i on x_i = 0 and +e_i on x_i = 1
    if g_N is not None:
        face_size = pts_bd.size(0) // (2 * dim)
        for ii, g_ii in g_N(dim):
            for side, sign in ((0, -1.0), (1, 1.0)):
                rows = slice((2 * ii + side) * face_size, (2 * ii + side + 1) * face_size)
                face_pts = pts_bd[rows, :]
                weighted_g_N = sign * g_ii(face_pts) * w_bd[rows, :]
                basis_value_bd_col = F.relu(face_pts @ w.t() + b) ** k
                rhs += basis_value_bd_col.t() @ weighted_g_N

    # stiffness (alpha grad phi_j, grad phi_i); d/dx_d phi = sigma'(w.x + b) * w_d
    pre_activation = integration_points @ w.t() + b
    if k == 1:
        # Heaviside with H(0) = 0; zeros built from pre_activation so device/dtype match
        activation_derivative = torch.heaviside(pre_activation, pre_activation.new_zeros(1))
    else:
        activation_derivative = k * F.relu(pre_activation) ** (k - 1)
    del pre_activation
    for d in range(dim):
        basis_value_dxi_col = activation_derivative * w.t()[d:d+1, :]
        weighted_basis_value_dx_col = basis_value_dxi_col * weights * coef_alpha
        jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col

    print("assembling the mass matrix time taken: ", time.time()-start_time)

    start_time = time.time()
    if solver == "cg":
        jac_np = jac.detach().cpu().numpy()
        rhs_np = rhs.detach().cpu().numpy().ravel()
        try:
            sol_np, exit_code = linalg.cg(jac_np, rhs_np, rtol=1e-12)
        except TypeError:
            # SciPy < 1.12 has no `rtol`; there `tol` plays the same role (removed in SciPy 1.14)
            sol_np, exit_code = linalg.cg(jac_np, rhs_np, tol=1e-12)
        if exit_code != 0:
            print("warning: cg did not converge (exit code {})".format(exit_code))
        sol = torch.tensor(sol_np).view(1, -1)
    elif solver == "direct":
        sol = torch.linalg.solve(jac.detach(), rhs.detach()).view(1, -1)
    else:  # "ls": minimum-norm least squares, tolerates a singular matrix
        sol = torch.linalg.lstsq(jac.detach().cpu(), rhs.detach().cpu(), driver='gelsd').solution.view(1, -1)
    print("solving Ax = b time taken: ", time.time()-start_time)
    return sol


def _embed_on_cube_faces(face_points, dim):
    """Place (P, dim-1) face quadrature points on all 2*dim faces of the unit cube [0,1]^dim.

    Returns a (2*dim*P, dim) tensor on `device`. Faces are stacked as x_0 = 0, x_0 = 1, x_1 = 0,
    x_1 = 1, ...; on the face x_i = s the remaining coordinates, in increasing index order, take
    the columns of ``face_points``.
    """
    num_face_points = face_points.size(0)
    faces = torch.zeros(2 * dim * num_face_points, dim).to(device)
    for i in range(dim):
        free_coords = [j for j in range(dim) if j != i]
        for side in (0, 1):
            rows = slice((2 * i + side) * num_face_points, (2 * i + side + 1) * num_face_points)
            faces[rows, i] = side
            faces[rows, free_coords] = face_points
    return faces


def OGANeumannReLU4D(my_model,alpha, target,g_N, u_exact, u_exact_grad, N_list,num_epochs,plot_freq, M, k =1, rand_deter = 'deter', linear_solver = "direct"):
    """Orthogonal greedy algorithm (OGA) with a ReLU^k dictionary for a 4D Neumann problem.

    Approximates the solution of  -div(alpha grad u) + u = f  on [0,1]^4 (Neumann data from g_N)
    by a shallow ReLU^k network that grows by one neuron per epoch. Each epoch
      1. evaluates, for every dictionary element phi = relu(w . x - c)**k, the residual functional
         (u_n - f, phi) + (alpha grad u_n, grad phi) - <g, phi>_boundary,
      2. appends the element with the largest absolute value as a new hidden neuron, and
      3. recomputes all output weights by Galerkin projection
         (minimize_linear_layer_H1_explicit_assemble_efficient).

    Parameters
    ----------
    my_model: model or None
        network to continue from; None starts from u_0 = 0
    alpha: callable
        diffusion coefficient, points (P, 4) -> (P, 1)
    target: callable
        right-hand side f, points (P, 4) -> (P, 1)
    g_N: callable or None
        g_N(dim) -> list of (i, g_i); see minimize_linear_layer_H1_explicit_assemble_efficient
    u_exact: callable
        exact solution, used only for the error history
    u_exact_grad: callable or None
        returns the list of partial derivatives of u_exact; None skips the H1 error
    N_list: list of 4 ints
        dictionary grid size per parameter; prod(N_list) elements are scanned per epoch
    num_epochs: int
        number of greedy steps (= neurons added)
    plot_freq:
        currently unused
    M: int
        number of Sobol points of the interior quadrature
    k: int
        ReLU power
    rand_deter: str
        'deter': one fixed grid dictionary (generate_relu_dict4D);
        'rand': a fresh random dictionary of the same size each epoch (generate_relu_dict4D_QMC)
    linear_solver: str
        'direct', 'cg' or 'ls'

    Returns
    -------
    err: tensor (num_epochs + 1,)
        L2 errors ||u - u_n||; entry 0 is the starting approximation
    err_h10: tensor (num_epochs + 1,)
        H1 seminorm errors |u - u_n|_{H^1}; zeros when u_exact_grad is None
    my_model:
        the trained network

    Raises
    ------
    ValueError
        if rand_deter is neither 'deter' nor 'rand'
    """
    dim = 4
    if rand_deter not in ('deter', 'rand'):
        raise ValueError("rand_deter must be 'deter' or 'rand', got {!r}".format(rand_deter))

    # interior quadrature: M equal-weight Sobol points
    gw_expand, integration_points = MonteCarlo_Sobol_dDim_weights_points(M, d=dim)
    gw_expand = gw_expand.to(device)
    integration_points = integration_points.to(device)

    # boundary quadrature: a 3D composite Gauss rule copied onto each of the 2*dim faces
    gw_expand_bd, integration_points_bd = PiecewiseGQ3D_weights_points(25, order = 2)
    size_pts_bd = integration_points_bd.size(0)
    gw_expand_bd_faces = torch.tile(gw_expand_bd, (2*dim, 1))
    integration_points_bd_faces = _embed_on_cube_faces(integration_points_bd, dim)

    bcs_N = g_N(dim) if g_N is not None else []
    u_grad = u_exact_grad() if u_exact_grad is not None else None

    err = torch.zeros(num_epochs+1)
    err_h10 = torch.zeros(num_epochs+1).to(device)

    def record_errors(step, net):
        # err[step] = ||u - u_n||_{L2}; err_h10[step] = |u - u_n|_{H1} (u_n = 0 when net is None)
        u_n = net(integration_points).detach() if net is not None else 0.
        err[step] = torch.sum((u_exact(integration_points) - u_n)**2 * gw_expand)**0.5
        if u_grad is not None:
            # H1 seminorm = sqrt of the SUM of the squared partial-derivative norms
            squared = 0.
            for ii, grad_i in enumerate(u_grad):
                du_n = net.evaluate_derivative(integration_points, ii+1).detach() if net is not None else 0.
                squared = squared + torch.sum((grad_i(integration_points) - du_n)**2 * gw_expand)
            err_h10[step] = squared**0.5

    if my_model is None:
        num_neuron = 0
        list_b = []
        list_w = []
    else:
        num_neuron = int(my_model.fc1.bias.size(0))
        list_b = list(my_model.fc1.bias.detach().data)
        list_w = list(my_model.fc1.weight.detach().data)
    record_errors(0, my_model)

    start_time = time.time()
    solver = linear_solver

    N0 = np.prod(N_list)
    if rand_deter == 'deter':
        relu_dict_parameters = generate_relu_dict4D(N_list).to(device)
    print("using linear solver: ",solver)

    for i in range(num_epochs):
        print("epoch: ",i+1, end = '\t')
        if rand_deter == 'rand':
            relu_dict_parameters = generate_relu_dict4D_QMC(1,N0).to(device)
        dict_w = relu_dict_parameters[:, 0:dim]   # unit directions
        dict_c = relu_dict_parameters[:, dim]     # offsets: phi = relu(w . x - c)**k

        # u_n - f at the quadrature points
        if num_neuron == 0:
            func_values = - target(integration_points)
        else:
            func_values = - target(integration_points) + my_model(integration_points).detach()

        # (u_n - f, phi) for every dictionary element
        weight_func_values = func_values*gw_expand
        basis_values = (F.relu(integration_points @ dict_w.T - dict_c)**k).T
        output = basis_values @ weight_func_values

        # + (alpha grad u_n, grad phi)
        alpha_coef = alpha(integration_points)
        if my_model is not None:
            pre_activation = integration_points @ dict_w.T - dict_c
            if k == 1:
                derivative_part = torch.heaviside(pre_activation, pre_activation.new_zeros(1))
            else:
                derivative_part = k * F.relu(pre_activation)**(k-1)
            del pre_activation
            derivative_part *= alpha_coef
            for dx_i in range(dim):
                weight_dbasis_values_dxi = (derivative_part * relu_dict_parameters.t()[dx_i:dx_i+1,:]) * gw_expand
                dmy_model_dxi = my_model.evaluate_derivative(integration_points,dx_i+1).detach()
                output += weight_dbasis_values_dxi.t() @ dmy_model_dxi

        # - <g, phi> on the boundary (outward normal -e_i on x_i = 0, +e_i on x_i = 1)
        for ii, g_ii in bcs_N:
            for side, sign in ((0, -1.0), (1, 1.0)):
                rows = slice((2*ii + side)*size_pts_bd, (2*ii + side + 1)*size_pts_bd)
                face_pts = integration_points_bd_faces[rows, :]
                weighted_g_N = sign * g_ii(face_pts) * gw_expand_bd_faces[rows, :]
                basis_values_bd_faces = (F.relu(face_pts @ dict_w.T - dict_c)**k).T
                output -= basis_values_bd_faces @ weighted_g_N

        # greedy selection: element most correlated with the residual
        neuron_index = torch.argmax(torch.abs(output).flatten())
        list_w.append(relu_dict_parameters[neuron_index, 0:dim])
        list_b.append(-relu_dict_parameters[neuron_index, dim])
        num_neuron += 1
        my_model = model(dim, num_neuron, 1, k).to(device)
        my_model.fc1.weight.data[:,:] = torch.stack(list_w, 0)
        my_model.fc1.bias.data[:] = torch.stack(list_b)

        # orthogonal projection: re-solve for all output weights
        sol = minimize_linear_layer_H1_explicit_assemble_efficient(my_model,alpha, target, g_N, gw_expand, integration_points,gw_expand_bd_faces, integration_points_bd_faces,activation = 'relu',solver = solver)
        my_model.fc2.weight.data[0,:] = sol[:]

        record_errors(i + 1, my_model)

    print("time taken: ",time.time() - start_time)
    return err, err_h10.cpu(), my_model




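In the following tests, we compare deterministic dictionaries with random dictionaries for the following three target functions. (The cell below runs the four-dimensional Gaussian test case, $\exp(-c_n^2 |x - 0.5|^2)$ with $c_n = 7.03/4$, using a random dictionary.)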

- $\sin(\pi x_1) \sin(\pi x_2)$ 
- $\sin(4\pi x_1) \sin(8\pi x_2)$ 
- Gabor function 

In [15]:
def u_exact(x):
    """Exact solution: Gaussian bump exp(-cn^2 |x - 0.5|^2) centred in [0,1]^4, with cn = 7.03 / 4.

    Maps points (P, 4) to values (P, 1).
    """
    dimension = 4
    width = 7.03 / dimension
    squared_dist = torch.sum(width**2 * (x - 0.5)**2, dim=1, keepdim=True)
    return torch.exp(-squared_dist)

def u_exact_grad():
    """Partial derivatives of u_exact = exp(-cn^2 |x - 0.5|^2), cn = 7.03 / 4.

    Returns a list of 4 callables; entry i maps points (P, 4) to
    d u / d x_i = u(x) * (-2 cn^2 (x_i - 0.5)), shape (P, 1).
    """
    d = 4
    cn = 7.03/d

    def partial_derivative(i):
        # bind i now so each callable keeps its own coordinate index
        return lambda x: torch.exp(-torch.sum(cn**2 * (x - 0.5)**2, dim=1, keepdim=True)) * (-2 * cn**2 * (x[:, i:i+1] - 0.5))

    return [partial_derivative(i) for i in range(d)]


def alpha(x):
    """Diffusion coefficient alpha(x) = 1, returned as (P, 1) on the same device and dtype as x."""
    # follow x instead of the global `device`, so the coefficient always matches its input
    return x.new_ones(x.size(0), 1)
    # variable-coefficient alternative: return 0.5 * torch.sin(6 * pi*x[:,0:1]) + 1.

def target(x):
    """Right-hand side f = -Laplace(u) + u for u = exp(-cn^2 |x - 0.5|^2), cn = 7.03 / 4.

    Since Laplace(u) = u * sum_i [(2 cn^2 (x_i - 0.5))^2 - 2 cn^2], this returns
    u * (1 - sum_i [(2 cn^2 (x_i - 0.5))^2 - 2 cn^2]) with shape (P, 1).
    """
    d = 4
    cn = 7.03 / d
    shifted = x - 0.5
    u = torch.exp(-torch.sum(cn**2 * shifted**2, dim=1, keepdim=True))
    laplacian_over_u = torch.sum((2 * cn**2 * shifted)**2 - 2 * cn**2, dim=1, keepdim=True)
    return u * (-laplacian_over_u + 1)

def g_N(dim):
    """Neumann data as a list of (i, g_i) pairs for i = 0, ..., dim-1.

    g_i(x) = exp(-cn^2 |x - 0.5|^2) * (-2 cn^2 (x_i - 0.5)) with cn = 7.03 / 4, i.e. the i-th partial
    derivative of the Gaussian exact solution; each g_i maps points (P, 4) to (P, 1).
    """
    def flux_component(i):
        def g_i(x):
            d = 4
            cn = 7.03 / d
            bump = torch.exp(-torch.sum(cn**2 * (x - 0.5)**2, dim=1, keepdim=True))
            return bump * (-2 * cn**2 * (x[:, i:i+1] - 0.5))
        return g_i

    return [(i, flux_component(i)) for i in range(dim)]

dim = 4
function_name = "gaussian"
filename_write = "data/4DOGA-{}-order.txt".format(function_name)
M = 500000
save = True
write2file = True
rand_deter = 'rand'
# tag stored in the saved file names, so deterministic runs are not labelled "randomized"
dict_tag = {'rand': 'randomized', 'deter': 'deterministic'}[rand_deter]

# every output file lives under data/; create it so a fresh checkout runs
os.makedirs("data", exist_ok=True)
with open(filename_write, "a") as f_write:
    f_write.write("Numerical Integration MC points: {} \n".format(M))


for N_list in [[2**2,2**2,2**3,2**3]]: # ,[2**6,2**6],[2**7,2**7]

    my_model = None
    exponent = 9
    num_epochs = 2**exponent
    plot_freq = num_epochs
    N = np.prod(N_list)
    relu_k = 4
    err_QMC2, err_h10, my_model = OGANeumannReLU4D(my_model,alpha, target,g_N, u_exact,u_exact_grad, N_list,num_epochs,plot_freq, M, k = relu_k, rand_deter= rand_deter, linear_solver = "direct")

    if save:
        folder = 'data/'
        filename = folder + 'err_OGA_4D_{}_neuron_{}_N_{}_{}.pt'.format(function_name,num_epochs,N,dict_tag)
        torch.save(err_QMC2,filename)
        filename = folder + 'model_OGA_4D_{}_neuron_{}_N_{}_{}.pt'.format(function_name,num_epochs,N,dict_tag)
        torch.save(my_model,filename)
    show_convergence_order2(err_QMC2,err_h10,exponent,N,filename_write,write2file = write2file)
    show_convergence_order_latex2(err_QMC2,err_h10,exponent,k=relu_k,d = dim)


using linear solver:  direct
epoch:  1	assembling the mass matrix time taken:  0.002882719039916992
solving Ax = b time taken:  0.000209808349609375
epoch:  2	assembling the mass matrix time taken:  0.003682374954223633
solving Ax = b time taken:  0.02903151512145996
epoch:  3	assembling the mass matrix time taken:  0.002955913543701172
solving Ax = b time taken:  0.0007290840148925781
epoch:  4	assembling the mass matrix time taken:  0.0030450820922851562
solving Ax = b time taken:  0.0008268356323242188
epoch:  5	assembling the mass matrix time taken:  0.0029900074005126953
solving Ax = b time taken:  0.0010879039764404297
epoch:  6	assembling the mass matrix time taken:  0.003939628601074219
solving Ax = b time taken:  0.0012705326080322266
epoch:  7	assembling the mass matrix time taken:  0.002982616424560547
solving Ax = b time taken:  0.0014963150024414062
epoch:  8	assembling the mass matrix time taken:  0.0035283565521240234
solving Ax = b time taken:  0.001605987548828125
epoc

solving Ax = b time taken:  0.017873764038085938
epoch:  69	assembling the mass matrix time taken:  0.002904176712036133
solving Ax = b time taken:  0.0184478759765625
epoch:  70	assembling the mass matrix time taken:  0.0031137466430664062
solving Ax = b time taken:  0.01894831657409668
epoch:  71	assembling the mass matrix time taken:  0.0029048919677734375
solving Ax = b time taken:  0.018827199935913086
epoch:  72	assembling the mass matrix time taken:  0.0038046836853027344
solving Ax = b time taken:  0.01790595054626465
epoch:  73	assembling the mass matrix time taken:  0.002897500991821289
solving Ax = b time taken:  0.019748210906982422
epoch:  74	assembling the mass matrix time taken:  0.0030260086059570312
solving Ax = b time taken:  0.020589113235473633
epoch:  75	assembling the mass matrix time taken:  0.0029213428497314453
solving Ax = b time taken:  0.020639419555664062
epoch:  76	assembling the mass matrix time taken:  0.0030210018157958984
solving Ax = b time taken:  0.

epoch:  204	assembling the mass matrix time taken:  0.0029823780059814453
solving Ax = b time taken:  0.06055784225463867
epoch:  205	assembling the mass matrix time taken:  0.002912282943725586
solving Ax = b time taken:  0.07292699813842773
epoch:  206	assembling the mass matrix time taken:  0.002944469451904297
solving Ax = b time taken:  0.061168670654296875
epoch:  207	assembling the mass matrix time taken:  0.0028481483459472656
solving Ax = b time taken:  0.0732576847076416
epoch:  208	assembling the mass matrix time taken:  0.0029191970825195312
solving Ax = b time taken:  0.061415672302246094
epoch:  209	assembling the mass matrix time taken:  0.002913236618041992
solving Ax = b time taken:  0.07345843315124512
epoch:  210	assembling the mass matrix time taken:  0.003117084503173828
solving Ax = b time taken:  0.06318831443786621
epoch:  211	assembling the mass matrix time taken:  0.0028696060180664062
solving Ax = b time taken:  0.07346749305725098
epoch:  212	assembling the 

epoch:  272	assembling the mass matrix time taken:  0.0030748844146728516
solving Ax = b time taken:  0.08507180213928223
epoch:  273	assembling the mass matrix time taken:  0.0029408931732177734
solving Ax = b time taken:  0.10663080215454102
epoch:  274	assembling the mass matrix time taken:  0.0030422210693359375
solving Ax = b time taken:  0.08751893043518066
epoch:  275	assembling the mass matrix time taken:  0.0029611587524414062
solving Ax = b time taken:  0.1071772575378418
epoch:  276	assembling the mass matrix time taken:  0.0031290054321289062
solving Ax = b time taken:  0.08605432510375977
epoch:  277	assembling the mass matrix time taken:  0.002936124801635742
solving Ax = b time taken:  0.10754084587097168
epoch:  278	assembling the mass matrix time taken:  0.0031194686889648438
solving Ax = b time taken:  0.08812808990478516
epoch:  279	assembling the mass matrix time taken:  0.003152132034301758
solving Ax = b time taken:  0.10803055763244629
epoch:  280	assembling the 

epoch:  340	assembling the mass matrix time taken:  0.003019571304321289
solving Ax = b time taken:  0.1117863655090332
epoch:  341	assembling the mass matrix time taken:  0.002967357635498047
solving Ax = b time taken:  0.12160134315490723
epoch:  342	assembling the mass matrix time taken:  0.003018617630004883
solving Ax = b time taken:  0.11345171928405762
epoch:  343	assembling the mass matrix time taken:  0.003153562545776367
solving Ax = b time taken:  0.12209558486938477
epoch:  344	assembling the mass matrix time taken:  0.0029997825622558594
solving Ax = b time taken:  0.11236691474914551
epoch:  345	assembling the mass matrix time taken:  0.0029718875885009766
solving Ax = b time taken:  0.1227574348449707
epoch:  346	assembling the mass matrix time taken:  0.0030202865600585938
solving Ax = b time taken:  0.11456656455993652
epoch:  347	assembling the mass matrix time taken:  0.0029850006103515625
solving Ax = b time taken:  0.12281537055969238
epoch:  348	assembling the mas

epoch:  408	assembling the mass matrix time taken:  0.0030035972595214844
solving Ax = b time taken:  0.14913225173950195
epoch:  409	assembling the mass matrix time taken:  0.002951383590698242
solving Ax = b time taken:  0.14328455924987793
epoch:  410	assembling the mass matrix time taken:  0.003084897994995117
solving Ax = b time taken:  0.14583420753479004
epoch:  411	assembling the mass matrix time taken:  0.0028562545776367188
solving Ax = b time taken:  0.1437702178955078
epoch:  412	assembling the mass matrix time taken:  0.0029866695404052734
solving Ax = b time taken:  0.14441514015197754
epoch:  413	assembling the mass matrix time taken:  0.0029456615447998047
solving Ax = b time taken:  0.14415240287780762
epoch:  414	assembling the mass matrix time taken:  0.0030393600463867188
solving Ax = b time taken:  0.1467418670654297
epoch:  415	assembling the mass matrix time taken:  0.002885580062866211
solving Ax = b time taken:  0.14458465576171875
epoch:  416	assembling the ma

epoch:  476	assembling the mass matrix time taken:  0.0030977725982666016
solving Ax = b time taken:  0.17630958557128906
epoch:  477	assembling the mass matrix time taken:  0.0029299259185791016
solving Ax = b time taken:  0.1761155128479004
epoch:  478	assembling the mass matrix time taken:  0.0029137134552001953
solving Ax = b time taken:  0.17699408531188965
epoch:  479	assembling the mass matrix time taken:  0.002920389175415039
solving Ax = b time taken:  0.17639398574829102
epoch:  480	assembling the mass matrix time taken:  0.003180980682373047
solving Ax = b time taken:  0.17546534538269043
epoch:  481	assembling the mass matrix time taken:  0.0028803348541259766
solving Ax = b time taken:  0.17678117752075195
epoch:  482	assembling the mass matrix time taken:  0.0028467178344726562
solving Ax = b time taken:  0.1771996021270752
epoch:  483	assembling the mass matrix time taken:  0.0029115676879882812
solving Ax = b time taken:  0.17748451232910156
epoch:  484	assembling the m

## $\cos(\pi x_1)\cos(\pi x_2)\cos(\pi x_3)\cos(\pi x_4)$