In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import time
import sys
from scipy.sparse import linalg
from pathlib import Path
import itertools
if torch.cuda.is_available():  
    device = "cuda" 
else:  
    device = "cpu"    


torch.set_default_dtype(torch.float64)
pi = torch.tensor(np.pi,dtype=torch.float64)
zero = torch.tensor([0.]).to(device)
torch.set_printoptions(precision=6)

class model(nn.Module):
    """ ReLU k shallow neural network
    Parameters: 
    input size: input dimension
    hidden_size1 : number of hidden layers 
    num_classes: output classes 
    k: degree of relu functions
    """
    def __init__(self, input_size, hidden_size1, num_classes,k = 1):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, num_classes,bias = False)
        self.k = k 
    def forward(self, x):
        u1 = self.fc2(F.relu(self.fc1(x))**self.k)
        return u1
    def evaluate_derivative(self, x, i):
        if self.k == 1:
            u1 = self.fc2(torch.heaviside(self.fc1(x),zero) * self.fc1.weight.t()[i-1:i,:] )
        else:
            u1 = self.fc2(self.k*F.relu(self.fc1(x))**(self.k-1) *self.fc1.weight.t()[i-1:i,:] )  
        return u1

def plot_2D(f): 
    
    Nx = 400
    Ny = 400 
    xs = np.linspace(0, 1, Nx)
    ys = np.linspace(0, 1, Ny)
    x, y = np.meshgrid(xs, ys, indexing='xy')
    xy_comb = np.stack((x.flatten(),y.flatten())).T
    xy_comb = torch.tensor(xy_comb)
    z = f(xy_comb).reshape(Nx,Ny)
    z = z.detach().numpy()
    plt.figure(dpi=200)
    ax = plt.axes(projection='3d')
    ax.plot_surface(x , y , z )

    plt.show()

def plot_subdomains(my_model):
    x_coord =torch.linspace(0,1,200)
    wi = my_model.fc1.weight.data
    bi = my_model.fc1.bias.data 
    for i, bias in enumerate(bi):  
        if wi[i,1] !=0: 
            plt.plot(x_coord, - wi[i,0]/wi[i,1]*x_coord - bias/wi[i,1])
        else: 
            plt.plot(x_coord,  - bias/wi[i,0]*torch.ones(x_coord.size()))

    plt.xlim([0,1])
    plt.ylim([0,1])
    plt.legend()
    plt.show()
    return 0   

def adjust_neuron_position(my_model, dims = 3):

    def create_mesh_grid(dims, pts):
        mesh = torch.tensor(list(itertools.product(pts,repeat=dims)))
        vertices = mesh.reshape(len(pts) ** dims, -1) 
        return vertices
    counter = 0 
    # positions = torch.tensor([[0.,0.],[0.,1.],[1.,1.],[1.,0.]])
    pts = torch.tensor([0.,1.])
    positions = create_mesh_grid(dims,pts) 
    neuron_num = my_model.fc1.bias.size(0)
    for i in range(neuron_num): 
        w = my_model.fc1.weight.data[i:i+1,:]
        b = my_model.fc1.bias.data[i]
    #     print(w,b)
        values = torch.matmul(positions,w.T) # + b
        left_end = - torch.max(values)
        right_end = - torch.min(values)
        offset = (right_end - left_end)/50
        if b <= left_end + offset/2 : 
            b = torch.rand(1)*(right_end - left_end - offset) + left_end + offset/2 
            my_model.fc1.bias.data[i] = b 
        if b >= right_end - offset/2 :
            if counter < (dims+1):
#                 print("here")
                counter += 1
            else: # (d + 1) or more 
                b = torch.rand(1)*(right_end - left_end - offset) + left_end + offset/2 
                my_model.fc1.bias.data[i] = b 
    return my_model



In [2]:
def PiecewiseGQ2D_weights_points(Nx, order): 
    """ A slight modification of PiecewiseGQ2D function that only needs the weights and integration points.
    Parameters
    ----------

    Nx: int 
        number of intervals along the dimension. No Ny, assume Nx = Ny
    order: int 
        order of the Gauss Quadrature

    Returns
    -------
    long_weights: torch.tensor
    integration_points: torch.tensor
    """

#     print("order: ",order )
    x, w = np.polynomial.legendre.leggauss(order)
    gauss_pts = np.array(np.meshgrid(x,x,indexing='ij')).reshape(2,-1).T
    weights =  (w*w[:,None]).ravel()

    gauss_pts =torch.tensor(gauss_pts)
    weights = torch.tensor(weights)

    h = 1/Nx # 100 intervals 
    long_weights =  torch.tile(weights,(Nx**2,1))
    long_weights = long_weights.reshape(-1,1)
    long_weights = long_weights * h**2 /4 

    integration_points = torch.tile(gauss_pts,(Nx**2,1))
    scale_factor = h/2 
    integration_points = scale_factor * integration_points

    index = np.arange(1,Nx+1)-0.5
    ordered_pairs = np.array(np.meshgrid(index,index,indexing='ij'))
    ordered_pairs = ordered_pairs.reshape(2,-1).T

    # print(ordered_pairs)
    # print()
    ordered_pairs = torch.tensor(ordered_pairs)
    # print(ordered_pairs.size())
    ordered_pairs = torch.tile(ordered_pairs, (1,order**2)) # number of GQ points
    # print(ordered_pairs)

    ordered_pairs =  ordered_pairs.reshape(-1,2)
    # print(ordered_pairs)
    translation = ordered_pairs*h 
    # print(translation)

    integration_points = integration_points + translation 
#     print(integration_points.size())
    # func_values = integrand2_torch(integration_points)
    return long_weights.to(device), integration_points.to(device)


def PiecewiseGQ3D_weights_points(Nx, order): 
    """ A slight modification of PiecewiseGQ2D function that only needs the weights and integration points.
    Parameters
    ----------

    Nx: int 
        number of intervals along the dimension. No Ny, assume Nx = Ny
    order: int 
        order of the Gauss Quadrature

    Returns
    -------
    long_weights: torch.tensor
    integration_points: torch.tensor
    """

    """
    Parameters
    ----------
    target : 
        Target function 
    Nx: int 
        number of intervals along the dimension. No Ny, assume Nx = Ny
    order: int 
        order of the Gauss Quadrature
    """

    # print("order: ",order )
    x, w = np.polynomial.legendre.leggauss(order)
    gauss_pts = np.array(np.meshgrid(x,x,x,indexing='ij')).reshape(3,-1).T
    weight_list = np.array(np.meshgrid(w,w,w,indexing='ij'))
    weights =   (weight_list[0]*weight_list[1]*weight_list[2]).ravel() 

    gauss_pts =torch.tensor(gauss_pts)
    weights = torch.tensor(weights)

    h = 1/Nx # 100 intervals 
    long_weights =  torch.tile(weights,(Nx**3,1))
    long_weights = long_weights.reshape(-1,1)
    long_weights = long_weights * h**3 /8 

    integration_points = torch.tile(gauss_pts,(Nx**3,1))
    # print("shape of integration_points", integration_points.size())
    scale_factor = h/2 
    integration_points = scale_factor * integration_points

    index = np.arange(1,Nx+1)-0.5
    ordered_pairs = np.array(np.meshgrid(index,index,index,indexing='ij'))
    ordered_pairs = ordered_pairs.reshape(3,-1).T

    # print(ordered_pairs)
    # print()
    ordered_pairs = torch.tensor(ordered_pairs)
    # print(ordered_pairs.size())
    ordered_pairs = torch.tile(ordered_pairs, (1,order**3)) # number of GQ points
    # print(ordered_pairs)

    ordered_pairs =  ordered_pairs.reshape(-1,3)
    # print(ordered_pairs)
    translation = ordered_pairs*h 
    # print(translation)

    integration_points = integration_points + translation 

    return long_weights.to(device), integration_points.to(device)


def generate_relu_dict3D(N_list):
    N1 = N_list[0]
    N2 = N_list[1]
    N3 = N_list[2]
    
    N = N1*N2*N3 
    theta1 = np.linspace(0, pi, N1, endpoint= True).reshape(N1,1)
    theta2 = np.linspace(0, 2*pi, N2, endpoint= False).reshape(N2,1)
    b = np.linspace(-1.732, 1.732, N3,endpoint=False).reshape(N3,1) # threshold: 3**0.5  
    coord3 = np.array(np.meshgrid(theta1,theta2,b,indexing='ij'))
    coord3 = coord3.reshape(3,-1).T # N1*N2*N3 x 3. coordinates for the grid points 
    coord3 = torch.tensor(coord3) 

    f1 = torch.zeros(N,1) 
    f2 = torch.zeros(N,1)
    f3 = torch.zeros(N,1)
    f4 = torch.zeros(N,1)

    f1[:,0] = torch.cos(coord3[:,0]) 
    f2[:,0] = torch.sin(coord3[:,0]) * torch.cos(coord3[:,1])
    f3[:,0] = torch.sin(coord3[:,0]) * torch.sin(coord3[:,1])
    f4[:,0] = coord3[:,2] 

    Wb_tensor = torch.cat([f1,f2,f3,f4],1) # N x 4 
    return Wb_tensor


def generate_relu_dict3D_QMC(s,N0):
#     Sob = torch.quasirandom.SobolEngine(dimension =3, scramble= True, seed=None) 
#     samples = Sob.draw(N0).double() 

#     for i in range(s-1):
#         samples = torch.cat([samples,Sob.draw(N0).double()],0)

    # Monte Carlo 
    samples = torch.rand(s*N0,3) 
    T =torch.tensor([[pi,0,0],[0,2*pi,0],[0,0,1.732*2]])
    shift = torch.tensor([0,0,-1.732])
    samples = samples@T + shift 

    f1 = torch.zeros(s*N0,1) 
    f2 = torch.zeros(s*N0,1)
    f3 = torch.zeros(s*N0,1)
    f4 = torch.zeros(s*N0,1)

    f1[:,0] = torch.cos(samples[:,0]) 
    f2[:,0] = torch.sin(samples[:,0]) * torch.cos(samples[:,1])
    f3[:,0] = torch.sin(samples[:,0]) * torch.sin(samples[:,1])
    f4[:,0] = samples[:,2] 

    Wb_tensor = torch.cat([f1,f2,f3,f4],1) # N x 4 
    return Wb_tensor


def minimize_linear_layer_H1_explicit_assemble_efficient(model,alpha, target, g_N, weights, integration_points, w_bd, pts_bd, activation = 'relu',solver="direct" ,memory=2**29):
    """ -div alpha grad u(x) + u = f 
    Parameters
    ----------
    model: 
        nn model
    alpha:
        alpha function
    target:
        rhs function f 
    pts_bd:
        integration points on the boundary, embdedded in the domain 
    """ 
    zero = torch.tensor([0.]).to(device)
    start_time = time.time() 
    w = model.fc1.weight.data 
    b = model.fc1.bias.data 
    neuron_num = b.size(0) 
    dim = integration_points.size(1) 
    M = integration_points.size(0)
    coef_alpha = alpha(integration_points) # alpha  

    
    total_size = neuron_num * M # memory, number of floating numbers 
    print('total size: {} {} = {}'.format(neuron_num,M,total_size))
    num_batch = total_size//memory + 1 # divide according to memory
    print("num batches: ",num_batch)
    batch_size = M//num_batch
    jac = torch.zeros(b.size(0),b.size(0)).to(device)
    rhs = torch.zeros(b.size(0),1).to(device)
    
    for j in range(0,M,batch_size): 
        end_index = j + batch_size
        basis_value_col = F.relu(integration_points[j:end_index] @ w.t()+ b)**(model.k) 
        weighted_basis_value_col = basis_value_col * weights[j:end_index] 
        jac += weighted_basis_value_col.t() @ basis_value_col 
        rhs += weighted_basis_value_col.t() @ (target(integration_points[j:end_index,:])) 
            
#     basis_value_col = F.relu(integration_points @ w.t()+ b)**(model.k)
#     weighted_basis_value_col = basis_value_col * weights 

#     jac = weighted_basis_value_col.t() @ basis_value_col  # mass matrix 
#     rhs = weighted_basis_value_col.t() @ (target(integration_points)) 

    # Assemble the boundary condition term <g,v>_{\Gamma_N} 
    size_pts_bd = int(pts_bd.size(0)/(2*dim))
    if g_N != None: # no batch operations for the boundary part, since it is only rhs on the boundary 
        bcs_N = g_N(dim)
        for ii, g_ii in bcs_N:
            # pts_bd_ii = pts_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
            weighted_g_N = -g_ii(pts_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:])* w_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
            basis_value_bd_col = F.relu(pts_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:] @ w.t()+ b)**(model.k)
            rhs += basis_value_bd_col.t() @ weighted_g_N

            weighted_g_N = g_ii(pts_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:])* w_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:]
            basis_value_bd_col = F.relu(pts_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:] @ w.t()+ b)**(model.k)
            rhs += basis_value_bd_col.t() @ weighted_g_N

    # Stiffness matrix term in the jacobian 
     
    for d in range(dim):
        if model.k == 1:  
            for j in range(0,M,batch_size):  
                end_index = j + batch_size 
                basis_value_dxi_col = torch.heaviside(integration_points[j:end_index] @ w.t()+ b, zero) * w.t()[d:d+1,:]
                weighted_basis_value_dx_col = basis_value_dxi_col * weights[j:end_index] * coef_alpha[j:end_index] 
                jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 
#             basis_value_dxi_col = torch.heaviside(integration_points @ w.t()+ b, zero) * w.t()[d:d+1,:]
#             weighted_basis_value_dx_col = basis_value_dxi_col * weights * coef_alpha 
#             jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 

        else:
            for j in range(0,M,batch_size):  
                end_index = j + batch_size 
                basis_value_dxi_col = model.k * F.relu(integration_points[j:end_index] @ w.t()+ b)**(model.k-1) * w.t()[d:d+1,:]
                weighted_basis_value_dx_col = basis_value_dxi_col * weights[j:end_index] * coef_alpha[j:end_index] 
                jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 
#             basis_value_dxi_col = model.k * F.relu(integration_points @ w.t()+ b)**(model.k-1) * w.t()[d:d+1,:]
#             weighted_basis_value_dx_col = basis_value_dxi_col * weights * coef_alpha  
#             jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 

    print("assembling the mass matrix time taken: ", time.time()-start_time) 


    start_time = time.time()    
    if solver == "cg": 
        sol, exit_code = linalg.cg(np.array(jac.detach().cpu()),np.array(rhs.detach().cpu()),tol=1e-12)
        sol = torch.tensor(sol).view(1,-1)
    elif solver == "direct": 
#         sol = np.linalg.inv( np.array(jac.detach().cpu()) )@np.array(rhs.detach().cpu())
        sol = (torch.linalg.solve( jac.detach(), rhs.detach())).view(1,-1)
    elif solver == "ls":
        sol = (torch.linalg.lstsq(jac.detach().cpu(),rhs.detach().cpu(),driver='gelsd').solution).view(1,-1)
        # sol = (torch.linalg.lstsq(jac.detach(),rhs.detach()).solution).view(1,-1) # gpu/cpu, driver = 'gels', cannot solve singular
    print("solving Ax = b time taken: ", time.time()-start_time)
    return sol 


# def minimize_linear_layer_H1_explicit_assemble_efficient(model,alpha, target, g_N, weights, integration_points, w_bd, pts_bd, activation = 'relu',solver="direct" ):
#     """ -div alpha grad u(x) + u = f 
#     Parameters
#     ----------
#     model: 
#         nn model
#     alpha:
#         alpha function
#     target:
#         rhs function f 
#     pts_bd:
#         integration points on the boundary, embdedded in the domain 
#     """ 
#     zero = torch.tensor([0.]).to(device)
#     start_time = time.time() 
#     w = model.fc1.weight.data 
#     b = model.fc1.bias.data 
#     neuron_num = b.size(0) 
#     dim = integration_points.size(1) 
#     M = integration_points.size(0)
#     coef_alpha = alpha(integration_points) # alpha  
    
#     if activation == 'relu':
#         basis_value_col = F.relu(integration_points @ w.t()+ b)**(model.k)
#         weighted_basis_value_col = basis_value_col * weights 
#         jac = weighted_basis_value_col.t() @ basis_value_col  # mass matrix 
#         rhs = weighted_basis_value_col.t() @ (target(integration_points)) 

#         # Todo1: assemble the boundary condition term <g,v>_{\Gamma_N} 
#         size_pts_bd = int(pts_bd.size(0)/(2*dim))
#         if g_N != None:
#             bcs_N = g_N(dim)
#             for ii, g_ii in bcs_N:
#                 # pts_bd_ii = pts_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
#                 weighted_g_N = -g_ii(pts_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:])* w_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
#                 basis_value_bd_col = F.relu(pts_bd[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:] @ w.t()+ b)**(model.k)
#                 rhs += basis_value_bd_col.t() @ weighted_g_N

#                 weighted_g_N = g_ii(pts_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:])* w_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:]
#                 basis_value_bd_col = F.relu(pts_bd[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:] @ w.t()+ b)**(model.k)
#                 rhs += basis_value_bd_col.t() @ weighted_g_N
        
#         if model.k == 1:  
#             for d in range(dim):
#                 basis_value_dxi_col = torch.heaviside(integration_points @ w.t()+ b, zero) * w.t()[d:d+1,:]
#                 weighted_basis_value_dx_col = basis_value_dxi_col * weights * coef_alpha 
#                 jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 
# #             basis_value_dx_all_col = torch.stack([torch.heaviside(integration_points @ w.t()+ b, zero) * w.t()[d:d+1,:] for d in range(dim)])
            
#         else: 
#             for d in range(dim):
#                 basis_value_dxi_col = model.k * F.relu(integration_points @ w.t()+ b)**(model.k-1) * w.t()[d:d+1,:]
#                 weighted_basis_value_dx_col = basis_value_dxi_col * weights * coef_alpha  
#                 jac += weighted_basis_value_dx_col.t() @ basis_value_dxi_col 
# #             basis_value_dx_all_col = torch.stack([model.k * F.relu(integration_points @ w.t()+ b)**(model.k-1) * w.t()[d:d+1,:] for d in range(dim)]) 
#             # basis_value_dx_col = model.k * F.relu(integration_points @ w.t()+ b)**(model.k-1) * w.t()[0:1,:]
#             # basis_value_dy_col = model.k * F.relu(integration_points @ w.t()+ b)**(model.k-1) * w.t()[1:2,:] 

#     print("assembling the mass matrix time taken: ", time.time()-start_time) 


#     start_time = time.time()    
#     if solver == "cg": 
#         sol, exit_code = linalg.cg(np.array(jac.detach().cpu()),np.array(rhs.detach().cpu()),tol=1e-12)
#         sol = torch.tensor(sol).view(1,-1)
#     elif solver == "direct": 
# #         sol = np.linalg.inv( np.array(jac.detach().cpu()) )@np.array(rhs.detach().cpu())
#         sol = (torch.linalg.solve( jac.detach(), rhs.detach())).view(1,-1)
#     elif solver == "ls":
#         sol = (torch.linalg.lstsq(jac.detach().cpu(),rhs.detach().cpu(),driver='gelsd').solution).view(1,-1)
#         # sol = (torch.linalg.lstsq(jac.detach(),rhs.detach()).solution).view(1,-1) # gpu/cpu, driver = 'gels', cannot solve singular
#     print("solving Ax = b time taken: ", time.time()-start_time)
#     return sol 


def OGANeumannReLU3D(my_model,alpha, target,g_N, u_exact, u_exact_grad, N_list,num_epochs,plot_freq, M, k =1, rand_deter = 'deter', linear_solver = "direct",memory = 2**29): 
    """ Orthogonal greedy algorithm using 1D ReLU dictionary over [-pi,pi]
    Parameters
    ----------
    my_model: 
        nn model 
    target: 
        target function
    num_epochs: int 
        number of training epochs 
    integration_intervals: int 
        number of subintervals for piecewise numerical quadrature 

    Returns
    -------
    err: tensor 
        rank 1 torch tensor to record the L2 error history  
    model: 
        trained nn model 
    """
    #Todo Done
    dim = 3 
    gw_expand, integration_points = PiecewiseGQ3D_weights_points(50, order = 3) 
    gw_expand = gw_expand.to(device)
    integration_points = integration_points.to(device)

    # define integration on the boundary 
    gw_expand_bd, integration_points_bd = PiecewiseGQ2D_weights_points(50, order = 3) 
    size_pts_bd = integration_points_bd.size(0) 
    gw_expand_bd_faces = torch.tile(gw_expand_bd,(2*dim,1))

    integration_points_bd_faces = torch.zeros(2*dim*integration_points_bd.size(0),dim).to(device)
    for ind in range(dim): 
        integration_points_bd_faces[2 *ind * size_pts_bd :(2 *ind +1) * size_pts_bd,ind:ind+1] = 0 
        integration_points_bd_faces[(2 *ind)*size_pts_bd :(2 * ind +1) * size_pts_bd,:ind] = integration_points_bd[:,:ind]
        integration_points_bd_faces[(2 *ind)*size_pts_bd :(2 * ind +1) * size_pts_bd,ind+1:] = integration_points_bd[:,ind:]

        integration_points_bd_faces[(2 *ind +1) * size_pts_bd:(2 *ind +2)*size_pts_bd,ind:ind+1] = 1
        integration_points_bd_faces[(2 *ind +1) * size_pts_bd:(2 *ind +2)*size_pts_bd,:ind] = integration_points_bd[:,:ind]        
        integration_points_bd_faces[(2 *ind +1) * size_pts_bd:(2 *ind +2)*size_pts_bd,ind+1:] = integration_points_bd[:,ind:]

    err = torch.zeros(num_epochs+1)
    err_h10 = torch.zeros(num_epochs+1).to(device) 
    if my_model == None: 
        func_values = - target(integration_points)
        num_neuron = 0

        list_b = []
        list_w = []
    else: 
        func_values = - (target(integration_points) - my_model(integration_points).detach())
        bias = my_model.fc1.bias.detach().data
        weights = my_model.fc1.weight.detach().data
        num_neuron = int(bias.size(0))

        list_b = list(bias)
        list_w = list(weights)
    
    # initial error Todo Done
    func_values_sqrd = func_values*func_values
    # print(func_values_sqrd.size())
    # print(gw_expand.size()) 
    err[0]= torch.sum(func_values_sqrd*gw_expand)**0.5
    ## h1 seminorm 
    if u_exact_grad != None:
        u_grad = u_exact_grad() 
        for grad_i in u_grad: 
            err_h10[0] += torch.sum((grad_i(integration_points))**2 * gw_expand)**0.5
    
    start_time = time.time()
    solver = linear_solver

    N0 = np.prod(N_list)
    if rand_deter == 'deter':
        relu_dict_parameters = generate_relu_dict3D(N_list).to(device)
    print("using linear solver: ",solver)
    M2 = integration_points.size(0) # add
    for i in range(num_epochs): 
        start_time = time.time()
        print("epoch: ",i+1, end = '\t')
        if rand_deter == 'rand':
            relu_dict_parameters = generate_relu_dict3D_QMC(1,N0).to(device) 
        if num_neuron == 0: 
            func_values = - target(integration_points)
        else: 
            func_values = - target(integration_points) + my_model(integration_points).detach()

        weight_func_values = func_values*gw_expand  

        ### ======================= 
        total_size = M2 * N0 
        num_batch = total_size//memory + 1 
        batch_size = N0//num_batch

        output = torch.zeros(N0,1).to(device)
        print("argmax batch num, ", num_batch) 
        
        for j in range(0,N0,batch_size):  
            end_index = j + batch_size  
            basis_values_batch = (F.relu( torch.matmul(integration_points,relu_dict_parameters[j:end_index,0:dim].T ) - relu_dict_parameters[j:end_index,dim])**k).T # uses broadcasting    
            output[j:end_index,:]  = torch.matmul(basis_values_batch,weight_func_values)[:,:]
        ### ======================= 
        
#         basis_values = (F.relu( torch.matmul(integration_points,relu_dict_parameters[:,0:dim].T ) - relu_dict_parameters[:,dim])**k).T # uses broadcasting, # dimension 4 
#         output = torch.matmul(basis_values,weight_func_values) #

        ##========
        # grad u part
        alpha_coef = alpha(integration_points) # alpha 
        if my_model!= None:

            if k == 1:  
                for j in range(0,N0,batch_size):  
                    end_index = j + batch_size 
                    derivative_part = torch.heaviside(integration_points @ (relu_dict_parameters[j:end_index,0:dim].T) - relu_dict_parameters[j:end_index,dim], zero) # dimension 4 
                    derivative_part *= alpha_coef # alpha 
                    for dx_i in range(dim): 

                        weight_dbasis_values_dxi =  (derivative_part * relu_dict_parameters.t()[dx_i:dx_i+1,j:end_index]) *gw_expand   
                        dmy_model_dxi = my_model.evaluate_derivative(integration_points,dx_i+1).detach()
                        output[j:end_index,:] += torch.matmul(weight_dbasis_values_dxi.t(), dmy_model_dxi) 

            else:  
                for j in range(0,N0,batch_size):  
                    end_index = j + batch_size 
                    derivative_part = k * F.relu(integration_points @ (relu_dict_parameters[j:end_index,0:dim].T) - relu_dict_parameters[j:end_index,dim])**(k-1) # dimension 4 
                    derivative_part *= alpha_coef # alpha
                    for dx_i in range(dim): 

                        weight_dbasis_values_dxi =  (derivative_part * relu_dict_parameters.t()[dx_i:dx_i+1,j:end_index]) * gw_expand    
                        dmy_model_dxi = my_model.evaluate_derivative(integration_points,dx_i+1).detach()
                        output[j:end_index,:] += torch.matmul(weight_dbasis_values_dxi.t(), dmy_model_dxi) 

        ##========
    
        #Boundary condition term -<g,v>_{\Gamma_N}  
        if g_N != None:
            bcs_N = g_N(dim) 
            for ii, g_ii in bcs_N: 
                
                weighted_g_N = -g_ii(integration_points_bd_faces[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:])* gw_expand_bd_faces[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:]
                basis_values_bd_faces = (F.relu( torch.matmul(integration_points_bd_faces[2*ii*size_pts_bd:(2*ii+1)*size_pts_bd,:],relu_dict_parameters[:,0:dim].T ) - relu_dict_parameters[:,dim])**k).T
                output -= torch.matmul(basis_values_bd_faces,weighted_g_N)
                
                weighted_g_N = g_ii(integration_points_bd_faces[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:])* gw_expand_bd_faces[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:]
                basis_values_bd_faces = (F.relu( torch.matmul(integration_points_bd_faces[(2*ii+1)*size_pts_bd:(2*ii+2)*size_pts_bd,:],relu_dict_parameters[:,0:dim].T ) - relu_dict_parameters[:,dim])**k).T
                output -= torch.matmul(basis_values_bd_faces,weighted_g_N)

        
        # output = torch.abs(torch.matmul(basis_values,weight_func_values)) # 
        output = torch.abs(output)
        neuron_index = torch.argmax(output.flatten())
        print("argmax time taken, ", time.time() - start_time)
        
        
        # print(neuron_index)
        list_w.append(relu_dict_parameters[neuron_index,0:dim]) # dimension 4 
        list_b.append(-relu_dict_parameters[neuron_index,dim])
        num_neuron += 1
        my_model = model(dim,num_neuron,1,k).to(device)
        w_tensor = torch.stack(list_w, 0 ) 
        b_tensor = torch.tensor(list_b)
        my_model.fc1.weight.data[:,:] = w_tensor[:,:]
        my_model.fc1.bias.data[:] = b_tensor[:]

        #Todo Done 
#         alpha = None 
        sol = minimize_linear_layer_H1_explicit_assemble_efficient(my_model,alpha, target, g_N, gw_expand, integration_points,gw_expand_bd_faces, integration_points_bd_faces,activation = 'relu',solver = solver)

        my_model.fc2.weight.data[0,:] = sol[:]

        model_values = my_model(integration_points).detach()
        # L2 error ||u - u_n||
        diff_values_sqrd = (u_exact(integration_points) - model_values)**2 
        err[i+1]= torch.sum(diff_values_sqrd*gw_expand)**0.5

        # H10 error || grad(u) - grad(u_n) ||
        if u_exact_grad != None:
            for ind, grad_i in enumerate(u_grad):  
                my_model_dxi = my_model.evaluate_derivative(integration_points,ind+1).detach() 
                err_h10[i+1] += torch.sum((grad_i(integration_points) - my_model_dxi)**2 * gw_expand)**0.5
        print("l2 error {:.6f}, h1 error {:.6f}".format(err[i+1],err_h10[i+1]))
    print("time taken: ",time.time() - start_time)
    return err, err_h10.cpu(), my_model




In [3]:

def u_exact(x):
    return torch.cos(pi*x[:,0:1])*torch.cos( pi*x[:,1:2]) * torch.cos(pi*x[:,2:3])  
def alpha(x): 
    return torch.ones(x.size(0),1).to(device)

def u_exact_grad():
    d = 3 

    def grad_1(x):
        return - pi* torch.sin(pi*x[:,0:1])*torch.cos( pi*x[:,1:2]) * torch.cos(pi*x[:,2:3])   
    def grad_2(x):
        return - pi* torch.cos(pi*x[:,0:1])*torch.sin( pi*x[:,1:2]) * torch.cos(pi*x[:,2:3])  
    def grad_3(x):
        return - pi* torch.cos(pi*x[:,0:1])*torch.cos( pi*x[:,1:2]) * torch.sin(pi*x[:,2:3])   

    
    u_grad=[grad_1, grad_2,grad_3] 

    return u_grad

def target(x):
    z = (  3 * (pi)**2 + 1)*torch.cos( pi*x[:,0:1])*torch.cos( pi*x[:,1:2] ) * torch.cos(pi*x[:,2:3]) 
    return z 

g_N = None 


function_name = "cospix" 
filename_write = "data-neumann/3DOGA-{}-order.txt".format(function_name)
f_write = open(filename_write, "a")
f_write.write("\n")
f_write.close() 
save = False 
for N_list in [[2**8]]: # ,[2**6,2**6],[2**7,2**7] 
    # save = True 
    f_write = open(filename_write, "a")
    my_model = None 
    # Nx = 50   
    # order = 3   
    M = 300000  
    exponent = 7
    num_epochs = 2**exponent  
    plot_freq = num_epochs 
    N = np.prod(N_list)
    err_QMC2, err_h10, my_model = OGANeumannReLU3D(my_model,alpha, target,g_N, u_exact,u_exact_grad, N_list,num_epochs,plot_freq, M, k = 3, rand_deter= 'rand', linear_solver = "direct")
    
    if save: 
        folder = 'data-neumann/'
        filename = folder + 'err_OGA_4D_{}_neuron_{}_N_{}_deterministic.pt'.format(function_name,num_epochs,N)
        torch.save(err_QMC2,filename) 
        folder = 'data-neumann/'
        filename = folder + 'model_OGA_4D_{}_neuron_{}_N_{}_deterministic.pt'.format(function_name,num_epochs,N)
        torch.save(my_model,filename)

    neuron_nums = [2**j for j in range(2,exponent+1)]
    err_list = [err_QMC2[i] for i in neuron_nums ]
    err_list2 = [err_h10[i] for i in neuron_nums ] 
    f_write.write('deterministic dictionary size: {}\n'.format(N))
    f_write.write("neuron num \t\t error \t\t order \t\t h10 error \\ order \n")
    print("neuron num \t\t error \t\t order")
    for i, item in enumerate(err_list):
        if i == 0: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t")
            
            # print("*")
            print("{} \t\t {:.6f} \t\t * \t\t {:.6f} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ) )   
            f_write.write("{} \t\t {} \t\t * \t\t {} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ))
        else: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t") 
            # print(np.log(err_list[i-1]/err_list[i])/np.log(2))
            print("{} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ) )
            f_write.write("{} \t\t {} \t\t {} \t\t {} \t\t {} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ))
    f_write.write("\n")
    f_write.close()


using linear solver:  direct
epoch:  1	argmax batch num,  2
argmax time taken,  0.3547232151031494
total size: 1 3375000 = 3375000
num batches:  1
assembling the mass matrix time taken:  0.003964424133300781
solving Ax = b time taken:  0.07999873161315918
l2 error 0.353918, h1 error 3.331313
epoch:  2	argmax batch num,  2
argmax time taken,  0.04062366485595703
total size: 2 3375000 = 6750000
num batches:  1
assembling the mass matrix time taken:  0.010634183883666992
solving Ax = b time taken:  0.008073091506958008
l2 error 0.353592, h1 error 3.330594
epoch:  3	argmax batch num,  2
argmax time taken,  0.04001593589782715
total size: 3 3375000 = 10125000
num batches:  1
assembling the mass matrix time taken:  0.004316091537475586
solving Ax = b time taken:  0.007856369018554688
l2 error 0.353881, h1 error 3.326875
epoch:  4	argmax batch num,  2
argmax time taken,  0.03920435905456543
total size: 4 3375000 = 13500000
num batches:  1
assembling the mass matrix time taken:  0.004181861877

total size: 32 3375000 = 108000000
num batches:  1
assembling the mass matrix time taken:  0.010801315307617188
solving Ax = b time taken:  0.03394365310668945
l2 error 0.031748, h1 error 0.639582
epoch:  33	argmax batch num,  2
argmax time taken,  0.04247140884399414
total size: 33 3375000 = 111375000
num batches:  1
assembling the mass matrix time taken:  0.0041446685791015625
solving Ax = b time taken:  0.03676581382751465
l2 error 0.030042, h1 error 0.612379
epoch:  34	argmax batch num,  2
argmax time taken,  0.04286694526672363
total size: 34 3375000 = 114750000
num batches:  1
assembling the mass matrix time taken:  0.011685848236083984
solving Ax = b time taken:  0.037505388259887695
l2 error 0.028852, h1 error 0.597848
epoch:  35	argmax batch num,  2
argmax time taken,  0.04281306266784668
total size: 35 3375000 = 118125000
num batches:  1
assembling the mass matrix time taken:  0.010434150695800781
solving Ax = b time taken:  0.038402557373046875
l2 error 0.026507, h1 error 0.

total size: 63 3375000 = 212625000
num batches:  1
assembling the mass matrix time taken:  0.00475311279296875
solving Ax = b time taken:  0.06406116485595703
l2 error 0.005209, h1 error 0.154808
epoch:  64	argmax batch num,  2
argmax time taken,  0.04658055305480957
total size: 64 3375000 = 216000000
num batches:  1
assembling the mass matrix time taken:  0.0041658878326416016
solving Ax = b time taken:  0.06455135345458984
l2 error 0.005125, h1 error 0.150838
epoch:  65	argmax batch num,  2
argmax time taken,  0.04613614082336426
total size: 65 3375000 = 219375000
num batches:  1
assembling the mass matrix time taken:  0.0061511993408203125
solving Ax = b time taken:  0.07966399192810059
l2 error 0.005038, h1 error 0.149120
epoch:  66	argmax batch num,  2
argmax time taken,  0.04713320732116699
total size: 66 3375000 = 222750000
num batches:  1
assembling the mass matrix time taken:  0.10678529739379883
solving Ax = b time taken:  0.040819644927978516
l2 error 0.004830, h1 error 0.14

epoch:  94	argmax batch num,  2
argmax time taken,  0.05141329765319824
total size: 94 3375000 = 317250000
num batches:  1
assembling the mass matrix time taken:  0.4382913112640381
solving Ax = b time taken:  0.09020209312438965
l2 error 0.002077, h1 error 0.067641
epoch:  95	argmax batch num,  2
argmax time taken,  0.0652613639831543
total size: 95 3375000 = 320625000
num batches:  1
assembling the mass matrix time taken:  0.023000001907348633
solving Ax = b time taken:  0.11284112930297852
l2 error 0.001983, h1 error 0.064576
epoch:  96	argmax batch num,  2
argmax time taken,  0.05147123336791992
total size: 96 3375000 = 324000000
num batches:  1
assembling the mass matrix time taken:  0.0050890445709228516
solving Ax = b time taken:  0.10824131965637207
l2 error 0.001886, h1 error 0.061563
epoch:  97	argmax batch num,  2
argmax time taken,  0.051061153411865234
total size: 97 3375000 = 327375000
num batches:  1
assembling the mass matrix time taken:  0.005532503128051758
solving Ax

l2 error 0.001185, h1 error 0.043459
epoch:  125	argmax batch num,  2
argmax time taken,  0.0550837516784668
total size: 125 3375000 = 421875000
num batches:  1
assembling the mass matrix time taken:  0.006732940673828125
solving Ax = b time taken:  0.13318634033203125
l2 error 0.001139, h1 error 0.042485
epoch:  126	argmax batch num,  2
argmax time taken,  0.056194305419921875
total size: 126 3375000 = 425250000
num batches:  1
assembling the mass matrix time taken:  0.24336767196655273
solving Ax = b time taken:  0.07239460945129395
l2 error 0.001122, h1 error 0.041081
epoch:  127	argmax batch num,  2
argmax time taken,  0.07007980346679688
total size: 127 3375000 = 428625000
num batches:  1
assembling the mass matrix time taken:  0.022411584854125977
solving Ax = b time taken:  0.13437795639038086
l2 error 0.001095, h1 error 0.040390
epoch:  128	argmax batch num,  2
argmax time taken,  0.05597734451293945
total size: 128 3375000 = 432000000
num batches:  1
assembling the mass matrix

In [4]:
4 		 0.354911 		 * 		 3.321857 		 * 

8 		 0.349854 		 0.020703 		 3.112930 		 0.093717 

16 		 0.095583 		 1.871927 		 1.543768 		 1.011817 

32 		 0.023397 		 2.030455 		 0.493063 		 1.646611 

64 		 0.005072 		 2.205648 		 0.140732 		 1.808820 

128 		 0.000975 		 2.378595 		 0.037545 		 1.906242 

256 		 0.000242 		 2.013625 		 0.012579 		 1.577634 

512 		 0.00006747 		 1.839831 		 0.004548 		 1.467792 

[tensor(0.354911),
 tensor(0.349854),
 tensor(0.095583),
 tensor(0.023397),
 tensor(0.005072),
 tensor(0.000975),
 tensor(0.000242),
 tensor(6.747653e-05)]

## oscillatory coefficient 

In [7]:

def u_exact(x):
    return torch.cos(pi*x[:,0:1])*torch.cos( pi*x[:,1:2]) * torch.cos(pi*x[:,2:3]) 
def alpha(x): 
#     return torch.ones(x.size(0),1).to(device)
    return 0.5 * torch.sin(6 * pi*x[:,0:1]) + 1. 

def u_exact_grad():
    d = 3 

    def grad_1(x):
        return - pi* torch.sin(pi*x[:,0:1])*torch.cos( pi*x[:,1:2]) * torch.cos(pi*x[:,2:3])   
    def grad_2(x):
        return - pi* torch.cos(pi*x[:,0:1])*torch.sin( pi*x[:,1:2]) * torch.cos(pi*x[:,2:3])  
    def grad_3(x):
        return - pi* torch.cos(pi*x[:,0:1])*torch.cos( pi*x[:,1:2]) * torch.sin(pi*x[:,2:3])   

    u_grad=[grad_1, grad_2,grad_3] 

    return u_grad

def target(x):

    z_c = torch.cos( pi*x[:,0:1])*torch.cos( pi*x[:,1:2] ) * torch.cos(pi*x[:,2:3]) 
    z1 = 3 * pi**2 * torch.sin(pi * x[:,0:1]) * torch.cos( 6*pi*x[:,0:1] ) * torch.cos( pi*x[:,1:2] )* torch.cos(pi*x[:,2:3]) 
    z2 = 0.5 * pi**2 * torch.sin(6*pi * x[:,0:1])* z_c 
    z = z1 + z2 + 2/2*pi**2 * torch.sin(6 * pi * x[:,0:1]) * z_c 
    z += ( 3 * (pi)**2 + 1)*z_c 
    return z 

g_N = None 


function_name = "cospix" 
filename_write = "data-neumann/3DOGA-{}-order.txt".format(function_name)
f_write = open(filename_write, "a")
f_write.write("\n")
f_write.close() 
save = False 
for N_list in [[2**8]]: # ,[2**6,2**6],[2**7,2**7] 
    # save = True 
    f_write = open(filename_write, "a")
    my_model = None 
    # Nx = 50   
    # order = 3   
    M = 300000  
    exponent = 9
    num_epochs = 2**exponent  
    plot_freq = num_epochs 
    N = np.prod(N_list)
    err_QMC2, err_h10, my_model = OGANeumannReLU3D(my_model,alpha, target,g_N, u_exact,u_exact_grad, N_list,num_epochs,plot_freq, M, k = 3, rand_deter= 'rand', linear_solver = "direct")
    
    if save: 
        folder = 'data-neumann/'
        filename = folder + 'err_OGA_4D_{}_neuron_{}_N_{}_deterministic.pt'.format(function_name,num_epochs,N)
        torch.save(err_QMC2,filename) 
        folder = 'data-neumann/'
        filename = folder + 'model_OGA_4D_{}_neuron_{}_N_{}_deterministic.pt'.format(function_name,num_epochs,N)
        torch.save(my_model,filename)

    neuron_nums = [2**j for j in range(2,exponent+1)]
    err_list = [err_QMC2[i] for i in neuron_nums ]
    err_list2 = [err_h10[i] for i in neuron_nums ] 
    f_write.write('deterministic dictionary size: {}\n'.format(N))
    f_write.write("neuron num \t\t error \t\t order \t\t h10 error \\ order \n")
    print("neuron num \t\t error \t\t order")
    for i, item in enumerate(err_list):
        if i == 0: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t")
            
            # print("*")
            print("{} \t\t {:.6f} \t\t * \t\t {:.6f} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ) )   
            f_write.write("{} \t\t {} \t\t * \t\t {} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ))
        else: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t") 
            # print(np.log(err_list[i-1]/err_list[i])/np.log(2))
            print("{} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ) )
            f_write.write("{} \t\t {} \t\t {} \t\t {} \t\t {} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ))
    f_write.write("\n")
    f_write.close()



using linear solver:  direct
epoch:  1	argmax batch num,  2
argmax time taken,  0.0036385059356689453
total size: 1 3375000 = 3375000
num batches:  1
assembling the mass matrix time taken:  0.0011341571807861328
solving Ax = b time taken:  0.0030100345611572266
l2 error 0.354379, h1 error 3.330438
epoch:  2	argmax batch num,  2
argmax time taken,  0.00457000732421875
total size: 2 3375000 = 6750000
num batches:  1
assembling the mass matrix time taken:  0.0013885498046875
solving Ax = b time taken:  0.009688377380371094
l2 error 0.353762, h1 error 3.330066
epoch:  3	argmax batch num,  2
argmax time taken,  0.0059392452239990234
total size: 3 3375000 = 10125000
num batches:  1
assembling the mass matrix time taken:  0.001375436782836914
solving Ax = b time taken:  0.00920867919921875
l2 error 0.355635, h1 error 3.327806
epoch:  4	argmax batch num,  2
argmax time taken,  0.005828380584716797
total size: 4 3375000 = 13500000
num batches:  1
assembling the mass matrix time taken:  0.001529

total size: 31 3375000 = 104625000
num batches:  1
assembling the mass matrix time taken:  0.001363515853881836
solving Ax = b time taken:  0.03434395790100098
l2 error 0.037423, h1 error 0.757754
epoch:  32	argmax batch num,  2
argmax time taken,  0.009147167205810547
total size: 32 3375000 = 108000000
num batches:  1
assembling the mass matrix time taken:  0.0013141632080078125
solving Ax = b time taken:  0.03507256507873535
l2 error 0.033077, h1 error 0.670404
epoch:  33	argmax batch num,  2
argmax time taken,  0.00873565673828125
total size: 33 3375000 = 111375000
num batches:  1
assembling the mass matrix time taken:  0.0012848377227783203
solving Ax = b time taken:  0.037856340408325195
l2 error 0.030766, h1 error 0.641923
epoch:  34	argmax batch num,  2
argmax time taken,  0.00903630256652832
total size: 34 3375000 = 114750000
num batches:  1
assembling the mass matrix time taken:  0.0013196468353271484
solving Ax = b time taken:  0.03862905502319336
l2 error 0.029273, h1 error 

total size: 62 3375000 = 209250000
num batches:  1
assembling the mass matrix time taken:  0.0012328624725341797
solving Ax = b time taken:  0.06453537940979004
l2 error 0.005207, h1 error 0.138958
epoch:  63	argmax batch num,  2
argmax time taken,  0.012609004974365234
total size: 63 3375000 = 212625000
num batches:  1
assembling the mass matrix time taken:  0.0012362003326416016
solving Ax = b time taken:  0.06568264961242676
l2 error 0.004961, h1 error 0.134527
epoch:  64	argmax batch num,  2
argmax time taken,  0.01297140121459961
total size: 64 3375000 = 216000000
num batches:  1
assembling the mass matrix time taken:  0.001226663589477539
solving Ax = b time taken:  0.06610703468322754
l2 error 0.004861, h1 error 0.131810
epoch:  65	argmax batch num,  2
argmax time taken,  0.012532711029052734
total size: 65 3375000 = 219375000
num batches:  1
assembling the mass matrix time taken:  0.0012481212615966797
solving Ax = b time taken:  0.08424854278564453
l2 error 0.004709, h1 error 

epoch:  93	argmax batch num,  2
argmax time taken,  0.01712775230407715
total size: 93 3375000 = 313875000
num batches:  1
assembling the mass matrix time taken:  0.0013458728790283203
solving Ax = b time taken:  0.11499309539794922
l2 error 0.002290, h1 error 0.072503
epoch:  94	argmax batch num,  2
argmax time taken,  0.01738452911376953
total size: 94 3375000 = 317250000
num batches:  1
assembling the mass matrix time taken:  0.0013782978057861328
solving Ax = b time taken:  0.11504173278808594
l2 error 0.002235, h1 error 0.071146
epoch:  95	argmax batch num,  2
argmax time taken,  0.01749396324157715
total size: 95 3375000 = 320625000
num batches:  1
assembling the mass matrix time taken:  0.0013933181762695312
solving Ax = b time taken:  0.11611127853393555
l2 error 0.002176, h1 error 0.069992
epoch:  96	argmax batch num,  2
argmax time taken,  0.017624855041503906
total size: 96 3375000 = 324000000
num batches:  1
assembling the mass matrix time taken:  0.0013837814331054688
solv

total size: 123 3375000 = 415125000
num batches:  1
assembling the mass matrix time taken:  0.001352071762084961
solving Ax = b time taken:  0.13482451438903809
l2 error 0.001175, h1 error 0.043967
epoch:  124	argmax batch num,  2
argmax time taken,  0.021141529083251953
total size: 124 3375000 = 418500000
num batches:  1
assembling the mass matrix time taken:  0.0013659000396728516
solving Ax = b time taken:  0.13514113426208496
l2 error 0.001155, h1 error 0.043583
epoch:  125	argmax batch num,  2
argmax time taken,  0.021178245544433594
total size: 125 3375000 = 421875000
num batches:  1
assembling the mass matrix time taken:  0.0013704299926757812
solving Ax = b time taken:  0.13677597045898438
l2 error 0.001146, h1 error 0.043335
epoch:  126	argmax batch num,  2
argmax time taken,  0.021949291229248047
total size: 126 3375000 = 425250000
num batches:  1
assembling the mass matrix time taken:  0.0013766288757324219
solving Ax = b time taken:  0.13763976097106934
l2 error 0.001136, h

l2 error 0.000767, h1 error 0.031472
epoch:  154	argmax batch num,  2
argmax time taken,  0.025828838348388672
total size: 154 3375000 = 519750000
num batches:  1
assembling the mass matrix time taken:  0.06348991394042969
solving Ax = b time taken:  0.1463327407836914
l2 error 0.000750, h1 error 0.031127
epoch:  155	argmax batch num,  2
argmax time taken,  0.025208234786987305
total size: 155 3375000 = 523125000
num batches:  1
assembling the mass matrix time taken:  0.608839750289917
solving Ax = b time taken:  0.10088086128234863
l2 error 0.000744, h1 error 0.030903
epoch:  156	argmax batch num,  2
argmax time taken,  0.02608799934387207
total size: 156 3375000 = 526500000
num batches:  1
assembling the mass matrix time taken:  0.16785812377929688
solving Ax = b time taken:  0.14628100395202637
l2 error 0.000721, h1 error 0.030088
epoch:  157	argmax batch num,  2
argmax time taken,  0.025497913360595703
total size: 157 3375000 = 529875000
num batches:  1
assembling the mass matrix t

total size: 184 3375000 = 621000000
num batches:  2
assembling the mass matrix time taken:  0.008706331253051758
solving Ax = b time taken:  0.2162618637084961
l2 error 0.000523, h1 error 0.023537
epoch:  185	argmax batch num,  2
argmax time taken,  0.04831218719482422
total size: 185 3375000 = 624375000
num batches:  2
assembling the mass matrix time taken:  0.008716344833374023
solving Ax = b time taken:  0.21632933616638184
l2 error 0.000512, h1 error 0.023253
epoch:  186	argmax batch num,  2
argmax time taken,  0.5396640300750732
total size: 186 3375000 = 627750000
num batches:  2
assembling the mass matrix time taken:  0.0048828125
solving Ax = b time taken:  0.21872639656066895
l2 error 0.000509, h1 error 0.023161
epoch:  187	argmax batch num,  2
argmax time taken,  0.029909372329711914
total size: 187 3375000 = 631125000
num batches:  2
assembling the mass matrix time taken:  0.009265661239624023
solving Ax = b time taken:  0.21821188926696777
l2 error 0.000504, h1 error 0.02304

solving Ax = b time taken:  0.3312516212463379
l2 error 0.000390, h1 error 0.018773
epoch:  215	argmax batch num,  2
argmax time taken,  0.05280351638793945
total size: 215 3375000 = 725625000
num batches:  2
assembling the mass matrix time taken:  0.00895237922668457
solving Ax = b time taken:  0.29556822776794434
l2 error 0.000388, h1 error 0.018663
epoch:  216	argmax batch num,  2
argmax time taken,  0.2749361991882324
total size: 216 3375000 = 729000000
num batches:  2
assembling the mass matrix time taken:  0.010624408721923828
solving Ax = b time taken:  0.3067142963409424
l2 error 0.000388, h1 error 0.018517
epoch:  217	argmax batch num,  2
argmax time taken,  0.05286693572998047
total size: 217 3375000 = 732375000
num batches:  2
assembling the mass matrix time taken:  0.008910655975341797
solving Ax = b time taken:  0.29888367652893066
l2 error 0.000382, h1 error 0.018266
epoch:  218	argmax batch num,  2
argmax time taken,  0.049506187438964844
total size: 218 3375000 = 735750

epoch:  245	argmax batch num,  2
argmax time taken,  0.037142038345336914
total size: 245 3375000 = 826875000
num batches:  2
assembling the mass matrix time taken:  0.008951902389526367
solving Ax = b time taken:  0.32144737243652344
l2 error 0.000290, h1 error 0.014430
epoch:  246	argmax batch num,  2
argmax time taken,  0.23492121696472168
total size: 246 3375000 = 830250000
num batches:  2
assembling the mass matrix time taken:  0.003839254379272461
solving Ax = b time taken:  0.32258057594299316
l2 error 0.000289, h1 error 0.014382
epoch:  247	argmax batch num,  2
argmax time taken,  0.06314539909362793
total size: 247 3375000 = 833625000
num batches:  2
assembling the mass matrix time taken:  0.01519775390625
solving Ax = b time taken:  0.3231208324432373
l2 error 0.000285, h1 error 0.014297
epoch:  248	argmax batch num,  2
argmax time taken,  0.4725971221923828
total size: 248 3375000 = 837000000
num batches:  2
assembling the mass matrix time taken:  0.015470743179321289
solvin

solving Ax = b time taken:  0.3928682804107666
l2 error 0.000238, h1 error 0.012211
epoch:  276	argmax batch num,  2
argmax time taken,  0.04359102249145508
total size: 276 3375000 = 931500000
num batches:  2
assembling the mass matrix time taken:  0.018390178680419922
solving Ax = b time taken:  0.39505815505981445
l2 error 0.000235, h1 error 0.012124
epoch:  277	argmax batch num,  2
argmax time taken,  0.06718254089355469
total size: 277 3375000 = 934875000
num batches:  2
assembling the mass matrix time taken:  0.008798360824584961
solving Ax = b time taken:  0.39461255073547363
l2 error 0.000234, h1 error 0.012096
epoch:  278	argmax batch num,  2
argmax time taken,  0.5715336799621582
total size: 278 3375000 = 938250000
num batches:  2
assembling the mass matrix time taken:  0.19377446174621582
solving Ax = b time taken:  0.45917201042175293
l2 error 0.000231, h1 error 0.011980
epoch:  279	argmax batch num,  2
argmax time taken,  0.04183697700500488
total size: 279 3375000 = 941625

argmax time taken,  0.720334529876709
total size: 306 3375000 = 1032750000
num batches:  2
assembling the mass matrix time taken:  0.016310453414916992
solving Ax = b time taken:  0.4221053123474121
l2 error 0.000191, h1 error 0.010288
epoch:  307	argmax batch num,  2
argmax time taken,  0.07110071182250977
total size: 307 3375000 = 1036125000
num batches:  2
assembling the mass matrix time taken:  0.015295982360839844
solving Ax = b time taken:  0.42160892486572266
l2 error 0.000189, h1 error 0.010175
epoch:  308	argmax batch num,  2
argmax time taken,  0.07350349426269531
total size: 308 3375000 = 1039500000
num batches:  2
assembling the mass matrix time taken:  0.017555713653564453
solving Ax = b time taken:  0.42092013359069824
l2 error 0.000187, h1 error 0.010113
epoch:  309	argmax batch num,  2
argmax time taken,  0.07138633728027344
total size: 309 3375000 = 1042875000
num batches:  2
assembling the mass matrix time taken:  0.015250205993652344
solving Ax = b time taken:  0.422

solving Ax = b time taken:  0.4918220043182373
l2 error 0.000156, h1 error 0.008851
epoch:  337	argmax batch num,  2
argmax time taken,  0.04926347732543945
total size: 337 3375000 = 1137375000
num batches:  3
assembling the mass matrix time taken:  0.003347635269165039
solving Ax = b time taken:  0.5490190982818604
l2 error 0.000155, h1 error 0.008815
epoch:  338	argmax batch num,  2
argmax time taken,  0.0778806209564209
total size: 338 3375000 = 1140750000
num batches:  3
assembling the mass matrix time taken:  0.01790308952331543
solving Ax = b time taken:  0.6583359241485596
l2 error 0.000151, h1 error 0.008694
epoch:  339	argmax batch num,  2
argmax time taken,  0.07517719268798828
total size: 339 3375000 = 1144125000
num batches:  3
assembling the mass matrix time taken:  0.01335906982421875
solving Ax = b time taken:  0.5511074066162109
l2 error 0.000151, h1 error 0.008682
epoch:  340	argmax batch num,  2
argmax time taken,  0.052585601806640625
total size: 340 3375000 = 114750

total size: 367 3375000 = 1238625000
num batches:  3
assembling the mass matrix time taken:  0.009946346282958984
solving Ax = b time taken:  0.576026439666748
l2 error 0.000131, h1 error 0.007710
epoch:  368	argmax batch num,  2
argmax time taken,  0.056493520736694336
total size: 368 3375000 = 1242000000
num batches:  3
assembling the mass matrix time taken:  0.003325223922729492
solving Ax = b time taken:  0.5717921257019043
l2 error 0.000131, h1 error 0.007684
epoch:  369	argmax batch num,  2
argmax time taken,  0.053543806076049805
total size: 369 3375000 = 1245375000
num batches:  3
assembling the mass matrix time taken:  0.009874343872070312
solving Ax = b time taken:  0.5768337249755859
l2 error 0.000130, h1 error 0.007631
epoch:  370	argmax batch num,  2
argmax time taken,  0.24591374397277832
total size: 370 3375000 = 1248750000
num batches:  3
assembling the mass matrix time taken:  0.009907245635986328
solving Ax = b time taken:  0.5784878730773926
l2 error 0.000130, h1 err

solving Ax = b time taken:  0.6367778778076172
l2 error 0.000109, h1 error 0.006564
epoch:  398	argmax batch num,  2
argmax time taken,  0.0780494213104248
total size: 398 3375000 = 1343250000
num batches:  3
assembling the mass matrix time taken:  0.016501188278198242
solving Ax = b time taken:  0.6674461364746094
l2 error 0.000109, h1 error 0.006531
epoch:  399	argmax batch num,  2
argmax time taken,  0.0814371109008789
total size: 399 3375000 = 1346625000
num batches:  3
assembling the mass matrix time taken:  0.016449928283691406
solving Ax = b time taken:  0.6381778717041016
l2 error 0.000108, h1 error 0.006510
epoch:  400	argmax batch num,  2
argmax time taken,  0.08383607864379883
total size: 400 3375000 = 1350000000
num batches:  3
assembling the mass matrix time taken:  0.016404151916503906
solving Ax = b time taken:  0.6656761169433594
l2 error 0.000107, h1 error 0.006480
epoch:  401	argmax batch num,  2
argmax time taken,  0.08531308174133301
total size: 401 3375000 = 135337

total size: 428 3375000 = 1444500000
num batches:  3
assembling the mass matrix time taken:  0.010048866271972656
solving Ax = b time taken:  0.6927731037139893
l2 error 0.000092, h1 error 0.005748
epoch:  429	argmax batch num,  2
argmax time taken,  0.08139491081237793
total size: 429 3375000 = 1447875000
num batches:  3
assembling the mass matrix time taken:  0.009823322296142578
solving Ax = b time taken:  0.6736657619476318
l2 error 0.000091, h1 error 0.005727
epoch:  430	argmax batch num,  2
argmax time taken,  0.07498955726623535
total size: 430 3375000 = 1451250000
num batches:  3
assembling the mass matrix time taken:  0.009858369827270508
solving Ax = b time taken:  0.7288162708282471
l2 error 0.000091, h1 error 0.005653
epoch:  431	argmax batch num,  2
argmax time taken,  0.0682532787322998
total size: 431 3375000 = 1454625000
num batches:  3
assembling the mass matrix time taken:  0.0162045955657959
solving Ax = b time taken:  0.6748805046081543
l2 error 0.000091, h1 error 0

solving Ax = b time taken:  0.8089194297790527
l2 error 0.000080, h1 error 0.005139
epoch:  459	argmax batch num,  2
argmax time taken,  0.09106016159057617
total size: 459 3375000 = 1549125000
num batches:  3
assembling the mass matrix time taken:  0.016320466995239258
solving Ax = b time taken:  0.7869887351989746
l2 error 0.000080, h1 error 0.005130
epoch:  460	argmax batch num,  2
argmax time taken,  0.09144735336303711
total size: 460 3375000 = 1552500000
num batches:  3
assembling the mass matrix time taken:  0.01651477813720703
solving Ax = b time taken:  0.7730100154876709
l2 error 0.000080, h1 error 0.005124
epoch:  461	argmax batch num,  2
argmax time taken,  0.09211277961730957
total size: 461 3375000 = 1555875000
num batches:  3
assembling the mass matrix time taken:  0.01637434959411621
solving Ax = b time taken:  0.7836394309997559
l2 error 0.000080, h1 error 0.005115
epoch:  462	argmax batch num,  2
argmax time taken,  0.0914299488067627
total size: 462 3375000 = 1559250

total size: 489 3375000 = 1650375000
num batches:  4
assembling the mass matrix time taken:  0.004324197769165039
solving Ax = b time taken:  0.8077080249786377
l2 error 0.000071, h1 error 0.004644
epoch:  490	argmax batch num,  2
argmax time taken,  0.06918120384216309
total size: 490 3375000 = 1653750000
num batches:  4
assembling the mass matrix time taken:  0.004352092742919922
solving Ax = b time taken:  0.806098222732544
l2 error 0.000071, h1 error 0.004637
epoch:  491	argmax batch num,  2
argmax time taken,  0.06935572624206543
total size: 491 3375000 = 1657125000
num batches:  4
assembling the mass matrix time taken:  0.004335880279541016
solving Ax = b time taken:  0.8098258972167969
l2 error 0.000070, h1 error 0.004627
epoch:  492	argmax batch num,  2
argmax time taken,  0.06948280334472656
total size: 492 3375000 = 1660500000
num batches:  4
assembling the mass matrix time taken:  0.01726555824279785
solving Ax = b time taken:  0.8107399940490723
l2 error 0.000070, h1 error 

In [None]:
4 		 0.355877 		 * 		 3.322601 		 * 

8 		 0.332912 		 0.096240 		 3.179593 		 0.063471 

16 		 0.095527 		 1.801162 		 1.544408 		 1.041788 

32 		 0.033077 		 1.530097 		 0.670404 		 1.203951 

64 		 0.004861 		 2.766485 		 0.131810 		 2.346571 

128 		 0.001110 		 2.131217 		 0.042079 		 1.647304 

256 		 0.000265 		 2.065511 		 0.013437 		 1.646892 

512 		 0.000065 		 2.028127 		 0.004318 		 1.637881 

## gaussian example 

In [3]:
def u_exact(x):
    d = 3 
    cn =   7.03/d 
    return torch.exp(-torch.sum( cn**2 * (x - 0.5)**2,dim = 1, keepdim = True))  

def u_exact_grad():
    d = 3 
    def make_grad_i(i):
        def grad_i(x):
            d = 3  
            cn = 7.03/d
            return torch.exp(-torch.sum(cn**2 * (x - 0.5)**2, dim=1, keepdim=True)) * (-2 * cn**2 * (x[:, i:i+1] - 0.5))
        return grad_i 
    
    u_grad=[] 
    for i in range(d):
        u_grad.append(make_grad_i(i))
    return u_grad
                                                                
                                                                    
def alpha(x): 
    return torch.ones(x.size(0),1).to(device)
    # return 0.5 * torch.sin(6 * pi*x[:,0:1]) + 1. 

# def target(x):
# #     z = (  4 * (pi)**2 + 1)*torch.cos( pi*x[:,0:1])*torch.cos( pi*x[:,1:2] ) * torch.cos(pi*x[:,2:3]) * torch.cos( pi*x[:,3:4]) 
# #     return z 
#     z_c = torch.cos( pi*x[:,0:1])*torch.cos( pi*x[:,1:2] ) * torch.cos(pi*x[:,2:3]) * torch.cos( pi*x[:,3:4]) 
#     z1 = 3 * pi**2 * torch.sin(pi * x[:,0:1]) * torch.cos( 6*pi*x[:,1:2] ) * torch.cos(pi*x[:,2:3]) * torch.cos( pi*x[:,3:4]) 
#     z2 = 0.5 * pi**2 * torch.sin(6*pi * x[:,0:1])* z_c 
#     z = z1 + z2 + 3/2*pi**2 * torch.sin(6 * pi * x[:,0:1]) * z_c 
#     z += ( 4 * (pi)**2 + 1)*z_c 
#     return z 

def target(x):
    d = 3 
    cn =   7.03/d 
    z = torch.exp(-torch.sum( cn**2 * (x - 0.5)**2,dim = 1, keepdim = True)) 
    return z* ( -torch.sum(  (2 *cn**2 * (x - 0.5))**2 - 2*cn**2 ,dim = 1, keepdim = True) +1)

def g_N(dim):
    def make_g(i):
        def g_i(x):
            d = 3 
            cn = 7.03 / d
            return torch.exp(-torch.sum(cn**2 * (x - 0.5)**2, dim=1, keepdim=True)) * (-2 * cn**2 * (x[:, i:i+1] - 0.5))
        return g_i

    bcs_N = []
    for i in range(dim):
        bcs_N.append((i, make_g(i)))
    
    return bcs_N


function_name = "cospix" 
filename_write = "data-neumann/3DOGA-{}-order.txt".format(function_name)
f_write = open(filename_write, "a")
f_write.write("\n")
f_write.close() 
save = False 
for N_list in [[2**2,2**3,2**3]]: # ,[2**6,2**6],[2**7,2**7] 
    # save = True 
    f_write = open(filename_write, "a")
    my_model = None 
    # Nx = 50   
    # order = 3   
    M = 300000  
    exponent = 8
    num_epochs = 2**exponent  
    plot_freq = num_epochs 
    N = np.prod(N_list)
    err_QMC2, err_h10, my_model = OGANeumannReLU3D(my_model,alpha, target,g_N, u_exact,u_exact_grad, N_list,num_epochs,plot_freq, M, k =2, rand_deter= 'rand', linear_solver = "direct")
    
    if save: 
        folder = 'data-neumann/'
        filename = folder + 'err_OGA_4D_{}_neuron_{}_N_{}_deterministic.pt'.format(function_name,num_epochs,N)
        torch.save(err_QMC2,filename) 
        folder = 'data-neumann/'
        filename = folder + 'model_OGA_4D_{}_neuron_{}_N_{}_deterministic.pt'.format(function_name,num_epochs,N)
        torch.save(my_model,filename)

    neuron_nums = [2**j for j in range(2,exponent+1)]
    err_list = [err_QMC2[i] for i in neuron_nums ]
    err_list2 = [err_h10[i] for i in neuron_nums ] 
    f_write.write('deterministic dictionary size: {}\n'.format(N))
    f_write.write("neuron num \t\t error \t\t order \t\t h10 error \\ order \n")
    print("neuron num \t\t error \t\t order")
    for i, item in enumerate(err_list):
        if i == 0: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t")
            
            # print("*")
            print("{} \t\t {:.6f} \t\t * \t\t {:.6f} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ) )   
            f_write.write("{} \t\t {} \t\t * \t\t {} \t\t * \n".format(neuron_nums[i],item, err_list2[i] ))
        else: 
            # print(neuron_nums[i], end = "\t\t")
            # print(item, end = "\t\t") 
            # print(np.log(err_list[i-1]/err_list[i])/np.log(2))
            print("{} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \t\t {:.6f} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ) )
            f_write.write("{} \t\t {} \t\t {} \t\t {} \t\t {} \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2),err_list2[i] , np.log(err_list2[i-1]/err_list2[i])/np.log(2) ))
    f_write.write("\n")
    f_write.close()


using linear solver:  direct
epoch:  1	assembling the mass matrix time taken:  0.004244804382324219
solving Ax = b time taken:  0.1321420669555664
l2 error 0.274375, h1 error 2.522434
epoch:  2	assembling the mass matrix time taken:  0.005568742752075195
solving Ax = b time taken:  0.004712104797363281
l2 error 0.213882, h1 error 2.411771
epoch:  3	assembling the mass matrix time taken:  0.0041046142578125
solving Ax = b time taken:  0.005236625671386719
l2 error 0.204598, h1 error 2.114279
epoch:  4	assembling the mass matrix time taken:  0.0047397613525390625
solving Ax = b time taken:  0.005539417266845703
l2 error 0.187524, h1 error 2.010464
epoch:  5	assembling the mass matrix time taken:  0.004512786865234375
solving Ax = b time taken:  0.005780696868896484
l2 error 0.201255, h1 error 1.951563
epoch:  6	assembling the mass matrix time taken:  0.0047681331634521484
solving Ax = b time taken:  0.00668644905090332
l2 error 0.167274, h1 error 1.863414
epoch:  7	assembling the mass ma

epoch:  53	assembling the mass matrix time taken:  0.017168521881103516
solving Ax = b time taken:  0.04207324981689453
l2 error 0.009000, h1 error 0.234303
epoch:  54	assembling the mass matrix time taken:  0.004387378692626953
solving Ax = b time taken:  0.042840003967285156
l2 error 0.008760, h1 error 0.229233
epoch:  55	assembling the mass matrix time taken:  0.017067670822143555
solving Ax = b time taken:  0.04342222213745117
l2 error 0.008520, h1 error 0.223530
epoch:  56	assembling the mass matrix time taken:  0.0171816349029541
solving Ax = b time taken:  0.043412208557128906
l2 error 0.008266, h1 error 0.218048
epoch:  57	assembling the mass matrix time taken:  0.016373395919799805
solving Ax = b time taken:  0.04475760459899902
l2 error 0.007885, h1 error 0.212589
epoch:  58	assembling the mass matrix time taken:  0.01629638671875
solving Ax = b time taken:  0.04501008987426758
l2 error 0.007483, h1 error 0.207104
epoch:  59	assembling the mass matrix time taken:  0.016647100

epoch:  106	assembling the mass matrix time taken:  0.004843235015869141
solving Ax = b time taken:  0.08366155624389648
l2 error 0.002967, h1 error 0.108168
epoch:  107	assembling the mass matrix time taken:  0.0048754215240478516
solving Ax = b time taken:  0.08427834510803223
l2 error 0.002941, h1 error 0.107546
epoch:  108	assembling the mass matrix time taken:  0.004924774169921875
solving Ax = b time taken:  0.08354330062866211
l2 error 0.002922, h1 error 0.106939
epoch:  109	assembling the mass matrix time taken:  0.004589080810546875
solving Ax = b time taken:  0.08572793006896973
l2 error 0.002891, h1 error 0.105978
epoch:  110	assembling the mass matrix time taken:  0.00465846061706543
solving Ax = b time taken:  0.08625912666320801
l2 error 0.002825, h1 error 0.104590
epoch:  111	assembling the mass matrix time taken:  0.0048999786376953125
solving Ax = b time taken:  0.08682680130004883
l2 error 0.002750, h1 error 0.102649
epoch:  112	assembling the mass matrix time taken: 

epoch:  158	assembling the mass matrix time taken:  0.18477892875671387
solving Ax = b time taken:  0.09980916976928711
l2 error 0.001650, h1 error 0.069962
epoch:  159	assembling the mass matrix time taken:  0.010071039199829102
solving Ax = b time taken:  0.13744473457336426
l2 error 0.001616, h1 error 0.069251
epoch:  160	assembling the mass matrix time taken:  0.005541324615478516
solving Ax = b time taken:  0.1387178897857666
l2 error 0.001572, h1 error 0.068309
epoch:  161	assembling the mass matrix time taken:  0.00565791130065918
solving Ax = b time taken:  0.14049720764160156
l2 error 0.001556, h1 error 0.067910
epoch:  162	assembling the mass matrix time taken:  0.007859468460083008
solving Ax = b time taken:  0.1388721466064453
l2 error 0.001539, h1 error 0.067381
epoch:  163	assembling the mass matrix time taken:  0.22953104972839355
solving Ax = b time taken:  0.07926440238952637
l2 error 0.001520, h1 error 0.067023
epoch:  164	assembling the mass matrix time taken:  0.009

l2 error 0.001030, h1 error 0.050361
epoch:  211	assembling the mass matrix time taken:  0.008592605590820312
solving Ax = b time taken:  0.1985642910003662
l2 error 0.001024, h1 error 0.050189
epoch:  212	assembling the mass matrix time taken:  0.3036653995513916
solving Ax = b time taken:  0.1352989673614502
l2 error 0.001019, h1 error 0.049999
epoch:  213	assembling the mass matrix time taken:  0.00924372673034668
solving Ax = b time taken:  0.19916582107543945
l2 error 0.001013, h1 error 0.049817
epoch:  214	assembling the mass matrix time taken:  0.006047725677490234
solving Ax = b time taken:  0.20202136039733887
l2 error 0.001014, h1 error 0.049608
epoch:  215	assembling the mass matrix time taken:  0.23576831817626953
solving Ax = b time taken:  0.11326241493225098
l2 error 0.001010, h1 error 0.049514
epoch:  216	assembling the mass matrix time taken:  0.00902414321899414
solving Ax = b time taken:  0.19700837135314941
l2 error 0.001007, h1 error 0.049396
epoch:  217	assembling