In [3]:
# In this version of OGA with random dictionaries, we use QMC to evaluate the loss function. 
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import time
import sys
from scipy.sparse import linalg
from pathlib import Path
import itertools
if torch.cuda.is_available():  
    device = "cuda" 
else:  
    device = "cpu" 
pi = torch.tensor(np.pi,dtype=torch.float64)
torch.set_default_dtype(torch.float64)

class model(nn.Module):
    """ ReLU k shallow neural network
    Parameters: 
    input size: input dimension
    hidden_size1 : number of hidden layers 
    num_classes: output classes 
    k: degree of relu functions
    """
    def __init__(self, input_size, hidden_size1, num_classes,k = 1):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, num_classes,bias = False)
        self.k = k 
    def forward(self, x):
        u1 = self.fc2(F.relu(self.fc1(x))**self.k)
        return u1


def generate_relu_dict4plusD_QMC(dim, s,N0):
    # s*N0 is the total number of samples 
    samples = torch.rand(s*N0,dim)  

    # for i in range(s-1):
        # samples = torch.cat([samples,Sob.draw(N0).double()],0)
    # Form the transformation matrix and shift vector 
    diagonal = torch.ones(dim)*pi  
    diagonal[-1] =  2*dim**0.5
    diagonal[-2] = 2*pi 
    T = torch.diag(diagonal)

    shift = torch.zeros(dim)
    shift[-1] = -dim**0.5 
    samples = samples@T + shift 

    Wb_tensor = torch.ones(s*N0,dim+1) # each neuron parameter stored in rows  
    for i in range(dim): # 0, 1, ... dim-1 
        for j in range(i+1): # 0, 1, ... i 
            if i == 0: 
                Wb_tensor[:,i] = Wb_tensor[:,i]*torch.cos(samples[:,j])
            if i == (dim - 1):
                if j != i:
                    Wb_tensor[:,i] = Wb_tensor[:,i] * torch.sin(samples[:,j]) 
            if i != 0 and i != (dim - 1): 
                if j != i: 
                    Wb_tensor[:,i] = Wb_tensor[:,i] * torch.sin(samples[:,j]) 
                else: 
                    Wb_tensor[:,i] = Wb_tensor[:,i] * torch.cos(samples[:,j]) 
            
    Wb_tensor[:,dim] = samples[:,-1] 

    return Wb_tensor.to(device)


def MonteCarlo_Sobol_dDim_weights_points(M ,d = 4):
    Sob_integral = torch.quasirandom.SobolEngine(dimension =d, scramble= False, seed=None) 
    integration_points = Sob_integral.draw(M).double() 
    integration_points = integration_points.to(device)
    weights = torch.ones(M,1).to(device)/M 
    return weights.to(device), integration_points.to(device) 


def minimize_linear_layer_explicit_assemble(model,target,weights, integration_points,solver="direct"):
    """
    calls the following functions (dependency): 
    1. GQ_piecewise_2D
    input: the nn model containing parameter 
    1. define the loss function  
    2. take derivative to extract the linear system A
    3. call the cg solver in scipy to solve the linear system 
    output: sol. solution of Ax = b
    """
    start_time = time.time() 
    w = model.fc1.weight.data 
    b = model.fc1.bias.data 
    
    # new batched operation 
    n = b.size(0)
    M = integration_points.size(0)
    
    total_size = n * M # memory, number of floating numbers 
    num_batch = total_size//(2**30) + 1 # divide according to memory
    batch_size = M//num_batch
    start_ind = 0
    end_ind = 0 
    jac = torch.zeros(b.size(0),b.size(0)).to(device)
    rhs = torch.zeros(b.size(0),1).to(device)
#     print("mat assemble, number batches: ",num_batch)
    for j in range(0,M,batch_size): 
        end_ind = j + batch_size
        basis_value_col = F.relu(integration_points[j:end_ind] @ w.t()+ b)**(model.k) 
        weighted_basis_value_col = basis_value_col * weights[j:end_ind] 
        jac += weighted_basis_value_col.t() @ basis_value_col 
        rhs += weighted_basis_value_col.t() @ (target(integration_points[j:end_ind,:])) 
        
    print("jac: ", jac.device)
    print("assembling the matrix time taken: ", time.time()-start_time) 
    start_time = time.time()    
    if solver == "cg": 
        sol, exit_code = linalg.cg(np.array(jac.detach().cpu()),np.array(rhs.detach().cpu()),tol=1e-12)
        sol = torch.tensor(sol).view(1,-1)
    elif solver == "direct": 
#         sol = np.linalg.inv( np.array(jac.detach().cpu()) )@np.array(rhs.detach().cpu())
        sol = (torch.linalg.solve( jac.detach(), rhs.detach())).view(1,-1)
    elif solver == "ls":
        sol = (torch.linalg.lstsq(jac.detach().cpu(),rhs.detach().cpu(),driver='gelsd').solution).view(1,-1)
        # sol = (torch.linalg.lstsq(jac.detach(),rhs.detach()).solution).view(1,-1) # gpu/cpu, driver = 'gels', cannot solve singular
    print("solving Ax = b time taken: ", time.time()-start_time)
    return sol 

def OGAL2FittingReLU4Dplus_QMC(my_model,target,s,N0,num_epochs, M, k =1, linear_solver = "direct", num_batches = 1): 
    
    """ Orthogonal greedy algorithm using 1D ReLU dictionary over [-pi,pi]
    Parameters
    ----------
    my_model: 
        nn model
    target: 
        target function
    num_epochs: int 
        number of training epochs 
    integration_intervals: int 
        number of subintervals for piecewise numerical quadrature 

    Returns
    -------
    err: tensor 
        rank 1 torch tensor to record the L2 error history  
    model: 
        trained nn model 
    """
    #Todo Done
    # samples for QMC integral
    dim = 10 
    start_time = time.time()
    # Sob_integral = torch.quasirandom.SobolEngine(dimension =4, scramble= False, seed=None) 
    # integration_points = Sob_integral.draw(M).double() 
    # integration_points = integration_points.to(device)
    weights, integration_points = MonteCarlo_Sobol_dDim_weights_points(M ,d = dim) 
    _, integration_points_test = MonteCarlo_Sobol_dDim_weights_points(M*2 ,d = dim) 
    print("generate sob sequence:", time.time() - start_time) 

    err = torch.zeros(num_epochs+1)
    if my_model == None: 
        func_values = target(integration_points)
        num_neuron = 0

        list_b = []
        list_w = []

    else: 
        func_values = target(integration_points) - my_model(integration_points).detach()
        bias = my_model.fc1.bias.detach().data
        weights = my_model.fc1.weight.detach().data
        num_neuron = int(bias.size(0))

        list_b = list(bias)
        list_w = list(weights)
    
    # initial error Todo Done

    func_values_sqrd = func_values*func_values
    # print(func_values_sqrd.size())
    # print(gw_expand.size() ) 

    err[0]= torch.mean(func_values_sqrd)**0.5
    all_start_time = time.time()
    
    solver = linear_solver
    print("using linear solver: ",solver)
    for i in range(num_epochs): 
        relu_dict_parameters = generate_relu_dict4plusD_QMC(dim, s,N0).t()   
        print("epoch: ",i+1, end = '\t')
        if num_neuron == 0: 
            func_values = target(integration_points)
        else: 
            

            total_size = num_neuron * M 
            num_batch = total_size//(2**30) + 1 # divide according to memory
            batch_size = M//num_batch
            end_ind = 0 

            func_values = torch.zeros(M,1).to(device)

            for j in range(0,M,batch_size): 
                end_ind = j + batch_size
                func_values[j:end_ind,:] = target(integration_points[j:end_ind,:]) - my_model(integration_points[j:end_ind,:]).detach()

        start_time = time.time() 
        
        M = integration_points.size(0)
        N = s*N0 
        output = torch.zeros(N,1)
        num_batches = (N*M)//2**30 + 1 # decide num_batches according to memory 
        batch_size = N//num_batches 
        print("argmax batch num, ", num_batches)
        for j in range(0,N,batch_size): 

            end_index = j + batch_size  
            basis_values_batch = (F.relu( torch.matmul(integration_points,relu_dict_parameters[0:dim, j:end_index] ) - relu_dict_parameters[dim, j:end_index])**k).T # uses broadcasting    
            output[j:end_index,0]  = (torch.abs(torch.matmul(basis_values_batch,func_values))/M)[:,0]
            
        neuron_index = torch.argmax(output.flatten())
        
#         basis_values = (F.relu( torch.matmul(integration_points,relu_dict_parameters[:,0:4].T ) - relu_dict_parameters[:,4])**k).T # uses broadcasting
#         output = torch.abs(torch.matmul(basis_values,func_values))/M # 
#         neuron_index = torch.argmax(output.flatten())
        print("argmax time taken, ", time.time() - start_time)
        
        list_w.append(relu_dict_parameters[0:dim, neuron_index]) # 
        list_b.append(-relu_dict_parameters[dim,neuron_index])
        num_neuron += 1
        my_model = model(dim,num_neuron,1,k).to(device)
        w_tensor = torch.stack(list_w, 0 ) 
        b_tensor = torch.tensor(list_b)
        my_model.fc1.weight.data[:,:] = w_tensor[:,:]
        my_model.fc1.bias.data[:] = b_tensor[:]

        start_time = time.time() 
        sol = minimize_linear_layer_explicit_assemble(my_model,target,weights,integration_points, solver)
        print("\t\t time taken minimize linear layer: ",time.time() - start_time) 
        my_model.fc2.weight.data[0,:] = sol[:]

        # calculate the test error
#         func_values = target(integration_points_test) - my_model(integration_points_test).detach()
        
        M2 = integration_points_test.size(0)
        total_size = num_neuron * M2 
        num_batch = total_size//(2**30) + 1 # divide according to memory
        batch_size = M//num_batch
        end_ind = 0 

        func_values_sqrd = torch.zeros(M2,1)  
        
        for j in range(0,M2,batch_size): 
            end_ind = j + batch_size
            func_values_sqrd[j:end_ind,:] = target(integration_points_test[j:end_ind,:]) - my_model(integration_points_test[j:end_ind,:]).detach()
        
        func_values_sqrd = func_values_sqrd**2 

        #Todo Done 
        err[i+1]= torch.mean(func_values_sqrd)**0.5
        print("current error: ",err[i+1]) 
    print("total duration: ",time.time() - all_start_time)
    return err, my_model

def print_convergence_order(err, neuron_num_exponent): 

    neuron_nums = [2**j for j in range(2,neuron_num_exponent)]
    err_list = [err[i] for i in neuron_nums ] 

    print("neuron num \t\t error \t\t order")
    for i, item in enumerate(err_list):
        if i == 0: 
            print(neuron_nums[i], end = "\t\t")
            print(item, end = "\t\t")
            print("*")
        else: 
            print(neuron_nums[i], end = "\t\t")
            print(item, end = "\t\t") 
            print(np.log(err_list[i-1]/err_list[i])/np.log(2))



In [10]:
if __name__ == "__main__": 

    def target(x): ## Gaussian function in dimension 10  
        d = 10 
        cn =   7.03/d 
        return torch.exp(-torch.sum( cn**2 * (x - 0.5)**2,dim = 1, keepdim = True)) 
    
    save = False 
    experiment_label = "ex1"
    for k in [1]: 
        s = 1 
        for N0 in [2**16]: 
            print()
            print() 
            exponent = 9    
            num_epochs=  2**exponent 
            M = 2**20 # around2**19 50w 
            print(M)
            my_model = None 
            
            err, my_model = OGAL2FittingReLU4Dplus_QMC(my_model,target, \
                        s,N0,num_epochs, M, k = k, linear_solver = "direct", num_batches = 4)
            
            if save: 
                filename = experiment_label + "_err_randDict_relu_{}_size_{}_num_neurons_{}.pt".format(k,s * N0,num_epochs)
                torch.save(err,filename)
                filename = experiment_label + "_model_randDict_relu_{}_size_{}_num_neurons_{}.pt".format(k,s * N0,num_epochs)
                torch.save(my_model.state_dict(),filename) 

            print_convergence_order(err,exponent+1) 




1048576
generate sob sequence: 0.07876062393188477
using linear solver:  direct
epoch:  1	argmax batch num,  65
argmax time taken,  2.69319748878479
jac:  cuda:0
assembling the matrix time taken:  0.00037550926208496094
solving Ax = b time taken:  0.00036978721618652344
		 time taken minimize linear layer:  0.0007686614990234375
current error:  tensor(0.0894)
epoch:  2	argmax batch num,  65
argmax time taken,  2.6932663917541504
jac:  cuda:0
assembling the matrix time taken:  0.00042700767517089844
solving Ax = b time taken:  0.0007123947143554688
		 time taken minimize linear layer:  0.001165151596069336
current error:  tensor(0.0786)
epoch:  3	argmax batch num,  65
argmax time taken,  2.6931920051574707
jac:  cuda:0
assembling the matrix time taken:  0.0004069805145263672
solving Ax = b time taken:  0.0008862018585205078
		 time taken minimize linear layer:  0.0013189315795898438
current error:  tensor(0.0775)
epoch:  4	argmax batch num,  65
argmax time taken,  2.69321608543396
jac

argmax time taken,  2.6937544345855713
jac:  cuda:0
assembling the matrix time taken:  0.00043320655822753906
solving Ax = b time taken:  0.0023927688598632812
		 time taken minimize linear layer:  0.0028524398803710938
current error:  tensor(0.0324)
epoch:  31	argmax batch num,  65
argmax time taken,  2.6943254470825195
jac:  cuda:0
assembling the matrix time taken:  0.00041294097900390625
solving Ax = b time taken:  0.0024297237396240234
		 time taken minimize linear layer:  0.002869129180908203
current error:  tensor(0.0315)
epoch:  32	argmax batch num,  65
argmax time taken,  2.6940066814422607
jac:  cuda:0
assembling the matrix time taken:  0.00042176246643066406
solving Ax = b time taken:  0.002459287643432617
		 time taken minimize linear layer:  0.0029075145721435547
current error:  tensor(0.0307)
epoch:  33	argmax batch num,  65
argmax time taken,  2.6941652297973633
jac:  cuda:0
assembling the matrix time taken:  0.00041103363037109375
solving Ax = b time taken:  0.0027780532

argmax time taken,  2.6984663009643555
jac:  cuda:0
assembling the matrix time taken:  0.0004355907440185547
solving Ax = b time taken:  0.004200458526611328
		 time taken minimize linear layer:  0.0046634674072265625
current error:  tensor(0.0186)
epoch:  61	argmax batch num,  65
argmax time taken,  2.6981489658355713
jac:  cuda:0
assembling the matrix time taken:  0.0004048347473144531
solving Ax = b time taken:  0.0043108463287353516
		 time taken minimize linear layer:  0.004742145538330078
current error:  tensor(0.0184)
epoch:  62	argmax batch num,  65
argmax time taken,  2.698760509490967
jac:  cuda:0
assembling the matrix time taken:  0.0004379749298095703
solving Ax = b time taken:  0.004316091537475586
		 time taken minimize linear layer:  0.004781961441040039
current error:  tensor(0.0181)
epoch:  63	argmax batch num,  65
argmax time taken,  2.6987991333007812
jac:  cuda:0
assembling the matrix time taken:  0.0004093647003173828
solving Ax = b time taken:  0.00443005561828613

argmax time taken,  2.702324628829956
jac:  cuda:0
assembling the matrix time taken:  0.00044274330139160156
solving Ax = b time taken:  0.0077898502349853516
		 time taken minimize linear layer:  0.008260011672973633
current error:  tensor(0.0133)
epoch:  91	argmax batch num,  65
argmax time taken,  2.7023987770080566
jac:  cuda:0
assembling the matrix time taken:  0.00042700767517089844
solving Ax = b time taken:  0.0075376033782958984
		 time taken minimize linear layer:  0.007992744445800781
current error:  tensor(0.0133)
epoch:  92	argmax batch num,  65
argmax time taken,  2.7025132179260254
jac:  cuda:0
assembling the matrix time taken:  0.0004382133483886719
solving Ax = b time taken:  0.00743556022644043
		 time taken minimize linear layer:  0.007900714874267578
current error:  tensor(0.0131)
epoch:  93	argmax batch num,  65
argmax time taken,  2.699371576309204
jac:  cuda:0
assembling the matrix time taken:  0.00041985511779785156
solving Ax = b time taken:  0.0076279640197753

argmax time taken,  2.7018260955810547
jac:  cuda:0
assembling the matrix time taken:  0.0004932880401611328
solving Ax = b time taken:  0.008857488632202148
		 time taken minimize linear layer:  0.009394407272338867
current error:  tensor(0.0105)
epoch:  121	argmax batch num,  65
argmax time taken,  2.7034261226654053
jac:  cuda:0
assembling the matrix time taken:  0.00043272972106933594
solving Ax = b time taken:  0.009099721908569336
		 time taken minimize linear layer:  0.009558677673339844
current error:  tensor(0.0105)
epoch:  122	argmax batch num,  65
argmax time taken,  2.7039101123809814
jac:  cuda:0
assembling the matrix time taken:  0.00046443939208984375
solving Ax = b time taken:  0.00958251953125
		 time taken minimize linear layer:  0.010074615478515625
current error:  tensor(0.0104)
epoch:  123	argmax batch num,  65
argmax time taken,  2.703563690185547
jac:  cuda:0
assembling the matrix time taken:  0.00043463706970214844
solving Ax = b time taken:  0.00920915603637695

argmax time taken,  2.702862024307251
jac:  cuda:0
assembling the matrix time taken:  0.00046944618225097656
solving Ax = b time taken:  0.013495922088623047
		 time taken minimize linear layer:  0.013992786407470703
current error:  tensor(0.0089)
epoch:  151	argmax batch num,  65
argmax time taken,  2.7039825916290283
jac:  cuda:0
assembling the matrix time taken:  0.0004506111145019531
solving Ax = b time taken:  0.012959957122802734
		 time taken minimize linear layer:  0.013437509536743164
current error:  tensor(0.0088)
epoch:  152	argmax batch num,  65
argmax time taken,  2.703115701675415
jac:  cuda:0
assembling the matrix time taken:  0.00046539306640625
solving Ax = b time taken:  0.013007164001464844
		 time taken minimize linear layer:  0.013501882553100586
current error:  tensor(0.0088)
epoch:  153	argmax batch num,  65
argmax time taken,  2.704200029373169
jac:  cuda:0
assembling the matrix time taken:  0.0004451274871826172
solving Ax = b time taken:  0.013104677200317383


argmax time taken,  2.7048749923706055
jac:  cuda:0
assembling the matrix time taken:  0.0004849433898925781
solving Ax = b time taken:  0.014626741409301758
		 time taken minimize linear layer:  0.015139102935791016
current error:  tensor(0.0078)
epoch:  181	argmax batch num,  65
argmax time taken,  2.705024003982544
jac:  cuda:0
assembling the matrix time taken:  0.0004668235778808594
solving Ax = b time taken:  0.01437520980834961
		 time taken minimize linear layer:  0.014869451522827148
current error:  tensor(0.0077)
epoch:  182	argmax batch num,  65
argmax time taken,  2.7048161029815674
jac:  cuda:0
assembling the matrix time taken:  0.0005042552947998047
solving Ax = b time taken:  0.015302181243896484
		 time taken minimize linear layer:  0.01583552360534668
current error:  tensor(0.0077)
epoch:  183	argmax batch num,  65
argmax time taken,  2.7047157287597656
jac:  cuda:0
assembling the matrix time taken:  0.00048279762268066406
solving Ax = b time taken:  0.01448965072631836

argmax time taken,  2.700137138366699
jac:  cuda:0
assembling the matrix time taken:  0.0005252361297607422
solving Ax = b time taken:  0.021664857864379883
		 time taken minimize linear layer:  0.022217750549316406
current error:  tensor(0.0069)
epoch:  211	argmax batch num,  65
argmax time taken,  2.6996889114379883
jac:  cuda:0
assembling the matrix time taken:  0.0005154609680175781
solving Ax = b time taken:  0.020191192626953125
		 time taken minimize linear layer:  0.020734071731567383
current error:  tensor(0.0069)
epoch:  212	argmax batch num,  65
argmax time taken,  2.699878215789795
jac:  cuda:0
assembling the matrix time taken:  0.0005323886871337891
solving Ax = b time taken:  0.018824338912963867
		 time taken minimize linear layer:  0.019385337829589844
current error:  tensor(0.0069)
epoch:  213	argmax batch num,  65
argmax time taken,  2.6998214721679688
jac:  cuda:0
assembling the matrix time taken:  0.00051116943359375
solving Ax = b time taken:  0.020293474197387695


argmax time taken,  2.700634717941284
jac:  cuda:0
assembling the matrix time taken:  0.0004596710205078125
solving Ax = b time taken:  0.02147650718688965
		 time taken minimize linear layer:  0.02196192741394043
current error:  tensor(0.0063)
epoch:  241	argmax batch num,  65
argmax time taken,  2.7005631923675537
jac:  cuda:0
assembling the matrix time taken:  0.00045943260192871094
solving Ax = b time taken:  0.021741867065429688
		 time taken minimize linear layer:  0.02222728729248047
current error:  tensor(0.0062)
epoch:  242	argmax batch num,  65
argmax time taken,  2.7007954120635986
jac:  cuda:0
assembling the matrix time taken:  0.00044989585876464844
solving Ax = b time taken:  0.021756410598754883
		 time taken minimize linear layer:  0.0222322940826416
current error:  tensor(0.0062)
epoch:  243	argmax batch num,  65
argmax time taken,  2.7006680965423584
jac:  cuda:0
assembling the matrix time taken:  0.00045680999755859375
solving Ax = b time taken:  0.021806001663208008

argmax time taken,  2.7068560123443604
jac:  cuda:0
assembling the matrix time taken:  0.0005016326904296875
solving Ax = b time taken:  0.0306551456451416
		 time taken minimize linear layer:  0.03118276596069336
current error:  tensor(0.0057)
epoch:  271	argmax batch num,  65
argmax time taken,  2.7071688175201416
jac:  cuda:0
assembling the matrix time taken:  0.0004942417144775391
solving Ax = b time taken:  0.02717304229736328
		 time taken minimize linear layer:  0.027692556381225586
current error:  tensor(0.0057)
epoch:  272	argmax batch num,  65
argmax time taken,  2.706916093826294
jac:  cuda:0
assembling the matrix time taken:  0.0005145072937011719
solving Ax = b time taken:  0.025882244110107422
		 time taken minimize linear layer:  0.02642369270324707
current error:  tensor(0.0057)
epoch:  273	argmax batch num,  65
argmax time taken,  2.7068371772766113
jac:  cuda:0
assembling the matrix time taken:  0.0004885196685791016
solving Ax = b time taken:  0.02724933624267578
		 

argmax time taken,  2.7079505920410156
jac:  cuda:0
assembling the matrix time taken:  0.0004897117614746094
solving Ax = b time taken:  0.02854442596435547
		 time taken minimize linear layer:  0.029061079025268555
current error:  tensor(0.0053)
epoch:  301	argmax batch num,  65
argmax time taken,  2.711232900619507
jac:  cuda:0
assembling the matrix time taken:  0.0005035400390625
solving Ax = b time taken:  0.028650283813476562
		 time taken minimize linear layer:  0.02918100357055664
current error:  tensor(0.0053)
epoch:  302	argmax batch num,  65
argmax time taken,  2.7081565856933594
jac:  cuda:0
assembling the matrix time taken:  0.0005092620849609375
solving Ax = b time taken:  0.02867579460144043
		 time taken minimize linear layer:  0.029211997985839844
current error:  tensor(0.0052)
epoch:  303	argmax batch num,  65
argmax time taken,  2.7079811096191406
jac:  cuda:0
assembling the matrix time taken:  0.0005114078521728516
solving Ax = b time taken:  0.02871870994567871
		 t

argmax time taken,  2.7032904624938965
jac:  cuda:0
assembling the matrix time taken:  0.0005323886871337891
solving Ax = b time taken:  0.04538154602050781
		 time taken minimize linear layer:  0.04593992233276367
current error:  tensor(0.0049)
epoch:  331	argmax batch num,  65
argmax time taken,  2.7031705379486084
jac:  cuda:0
assembling the matrix time taken:  0.0005171298980712891
solving Ax = b time taken:  0.03817129135131836
		 time taken minimize linear layer:  0.03871440887451172
current error:  tensor(0.0049)
epoch:  332	argmax batch num,  65
argmax time taken,  2.703589916229248
jac:  cuda:0
assembling the matrix time taken:  0.0005450248718261719
solving Ax = b time taken:  0.03529810905456543
		 time taken minimize linear layer:  0.03587007522583008
current error:  tensor(0.0049)
epoch:  333	argmax batch num,  65
argmax time taken,  2.70355224609375
jac:  cuda:0
assembling the matrix time taken:  0.0005154609680175781
solving Ax = b time taken:  0.038330793380737305
		 ti

argmax time taken,  2.7039527893066406
jac:  cuda:0
assembling the matrix time taken:  0.0005331039428710938
solving Ax = b time taken:  0.03948521614074707
		 time taken minimize linear layer:  0.040045976638793945
current error:  tensor(0.0046)
epoch:  361	argmax batch num,  65
argmax time taken,  2.704180955886841
jac:  cuda:0
assembling the matrix time taken:  0.00054931640625
solving Ax = b time taken:  0.039766550064086914
		 time taken minimize linear layer:  0.04034304618835449
current error:  tensor(0.0046)
epoch:  362	argmax batch num,  65
argmax time taken,  2.70415997505188
jac:  cuda:0
assembling the matrix time taken:  0.0005609989166259766
solving Ax = b time taken:  0.03978848457336426
		 time taken minimize linear layer:  0.040377140045166016
current error:  tensor(0.0046)
epoch:  363	argmax batch num,  65
argmax time taken,  2.7041547298431396
jac:  cuda:0
assembling the matrix time taken:  0.0005381107330322266
solving Ax = b time taken:  0.03987884521484375
		 time 

argmax time taken,  2.7101173400878906
jac:  cuda:0
assembling the matrix time taken:  0.0005655288696289062
solving Ax = b time taken:  0.04700446128845215
		 time taken minimize linear layer:  0.04759621620178223
current error:  tensor(0.0044)
epoch:  391	argmax batch num,  65
argmax time taken,  2.710482597351074
jac:  cuda:0
assembling the matrix time taken:  0.0005710124969482422
solving Ax = b time taken:  0.044320106506347656
		 time taken minimize linear layer:  0.04491686820983887
current error:  tensor(0.0044)
epoch:  392	argmax batch num,  65
argmax time taken,  2.7104218006134033
jac:  cuda:0
assembling the matrix time taken:  0.0005731582641601562
solving Ax = b time taken:  0.046678781509399414
		 time taken minimize linear layer:  0.04727768898010254
current error:  tensor(0.0043)
epoch:  393	argmax batch num,  65
argmax time taken,  2.7104318141937256
jac:  cuda:0
assembling the matrix time taken:  0.0005602836608886719
solving Ax = b time taken:  0.04441976547241211
		

argmax time taken,  2.7114784717559814
jac:  cuda:0
assembling the matrix time taken:  0.0005724430084228516
solving Ax = b time taken:  0.04573702812194824
		 time taken minimize linear layer:  0.046335458755493164
current error:  tensor(0.0041)
epoch:  421	argmax batch num,  65
argmax time taken,  2.7114784717559814
jac:  cuda:0
assembling the matrix time taken:  0.0005757808685302734
solving Ax = b time taken:  0.04589390754699707
		 time taken minimize linear layer:  0.04649686813354492
current error:  tensor(0.0041)
epoch:  422	argmax batch num,  65
argmax time taken,  2.7132070064544678
jac:  cuda:0
assembling the matrix time taken:  0.0005753040313720703
solving Ax = b time taken:  0.04588007926940918
		 time taken minimize linear layer:  0.04648184776306152
current error:  tensor(0.0041)
epoch:  423	argmax batch num,  65
argmax time taken,  2.711934804916382
jac:  cuda:0
assembling the matrix time taken:  0.0005717277526855469
solving Ax = b time taken:  0.04599261283874512
		 

argmax time taken,  2.7122609615325928
jac:  cuda:0
assembling the matrix time taken:  0.0006036758422851562
solving Ax = b time taken:  0.05433344841003418
		 time taken minimize linear layer:  0.05496382713317871
current error:  tensor(0.0039)
epoch:  451	argmax batch num,  65
argmax time taken,  2.71886944770813
jac:  cuda:0
assembling the matrix time taken:  0.0005955696105957031
solving Ax = b time taken:  0.05462050437927246
		 time taken minimize linear layer:  0.055242061614990234
current error:  tensor(0.0039)
epoch:  452	argmax batch num,  65
argmax time taken,  2.7121076583862305
jac:  cuda:0
assembling the matrix time taken:  0.0006110668182373047
solving Ax = b time taken:  0.05276083946228027
		 time taken minimize linear layer:  0.05339813232421875
current error:  tensor(0.0039)
epoch:  453	argmax batch num,  65
argmax time taken,  2.7158777713775635
jac:  cuda:0
assembling the matrix time taken:  0.0005972385406494141
solving Ax = b time taken:  0.05464744567871094
		 t

argmax time taken,  2.7130892276763916
jac:  cuda:0
assembling the matrix time taken:  0.0006103515625
solving Ax = b time taken:  0.055737972259521484
		 time taken minimize linear layer:  0.05637550354003906
current error:  tensor(0.0038)
epoch:  481	argmax batch num,  65
argmax time taken,  2.7127902507781982
jac:  cuda:0
assembling the matrix time taken:  0.000606536865234375
solving Ax = b time taken:  0.056188344955444336
		 time taken minimize linear layer:  0.05682182312011719
current error:  tensor(0.0038)
epoch:  482	argmax batch num,  65
argmax time taken,  2.7130002975463867
jac:  cuda:0
assembling the matrix time taken:  0.0006003379821777344
solving Ax = b time taken:  0.05615854263305664
		 time taken minimize linear layer:  0.05678606033325195
current error:  tensor(0.0037)
epoch:  483	argmax batch num,  65
argmax time taken,  2.7195184230804443
jac:  cuda:0
assembling the matrix time taken:  0.0006079673767089844
solving Ax = b time taken:  0.05636906623840332
		 time 

argmax time taken,  2.7140228748321533
jac:  cuda:0
assembling the matrix time taken:  0.0006470680236816406
solving Ax = b time taken:  0.05839419364929199
		 time taken minimize linear layer:  0.05906867980957031
current error:  tensor(0.0036)
epoch:  511	argmax batch num,  65
argmax time taken,  2.7206010818481445
jac:  cuda:0
assembling the matrix time taken:  0.0006389617919921875
solving Ax = b time taken:  0.05861830711364746
		 time taken minimize linear layer:  0.05928444862365723
current error:  tensor(0.0036)
epoch:  512	argmax batch num,  65
argmax time taken,  2.7140636444091797
jac:  cuda:0
assembling the matrix time taken:  0.0006456375122070312
solving Ax = b time taken:  0.05810046195983887
		 time taken minimize linear layer:  0.05877685546875
current error:  tensor(0.0036)
total duration:  1431.2131884098053
neuron num 		 error 		 order
4		tensor(0.0760)		*
8		tensor(0.0707)		tensor(0.1040)
16		tensor(0.0517)		tensor(0.4519)
32		tensor(0.0307)		tensor(0.7518)
64		ten

In [2]:
if __name__ == "__main__": 

    def target(x): ## Gaussian function in dimension 10  
        d = 10 
        cn =   7.03/d 
        return torch.exp(-torch.sum( cn**2 * (x - 0.5)**2,dim = 1, keepdim = True)) 
    
    save = False 
    experiment_label = "ex1"
    for k in [1]: 
        s = 1 
        for N0 in [2**9]: 
            print()
            print() 
            exponent = 9    
            num_epochs=  2**exponent 
            M = 2**23 # around2**19 50w 
            print(M)
            my_model = None 
            
            err, my_model = OGAL2FittingReLU4Dplus_QMC(my_model,target, \
                        s,N0,num_epochs, M, k = k, linear_solver = "direct", num_batches = 4)
            
            if save: 
                filename = experiment_label + "_err_randDict_relu_{}_size_{}_num_neurons_{}.pt".format(k,s * N0,num_epochs)
                torch.save(err,filename)
                filename = experiment_label + "_model_randDict_relu_{}_size_{}_num_neurons_{}.pt".format(k,s * N0,num_epochs)
                torch.save(my_model.state_dict(),filename) 

            print_convergence_order(err,exponent+1) 




8388608
generate sob sequence: 1.4505701065063477
using linear solver:  direct
epoch:  1	argmax batch num,  5
argmax time taken,  0.4784426689147949
assembling the matrix time taken:  0.0004487037658691406
solving Ax = b time taken:  0.07554054260253906
		 time taken minimize linear layer:  0.07604265213012695
current error:  tensor(0.0916)
epoch:  2	argmax batch num,  5
argmax time taken,  0.18694663047790527
assembling the matrix time taken:  0.0005314350128173828
solving Ax = b time taken:  0.007007598876953125
		 time taken minimize linear layer:  0.007581949234008789
current error:  tensor(0.0773)
epoch:  3	argmax batch num,  5
argmax time taken,  0.18919754028320312
assembling the matrix time taken:  0.0004317760467529297
solving Ax = b time taken:  0.007836103439331055
		 time taken minimize linear layer:  0.008306264877319336
current error:  tensor(0.0768)
epoch:  4	argmax batch num,  5
argmax time taken,  0.1894538402557373
assembling the matrix time taken:  0.00049161911010

argmax time taken,  0.19782805442810059
assembling the matrix time taken:  0.0004546642303466797
solving Ax = b time taken:  0.01996469497680664
		 time taken minimize linear layer:  0.02045464515686035
current error:  tensor(0.0492)
epoch:  33	argmax batch num,  5
argmax time taken,  0.19699788093566895
assembling the matrix time taken:  0.0005345344543457031
solving Ax = b time taken:  0.022403955459594727
		 time taken minimize linear layer:  0.022974729537963867
current error:  tensor(0.0489)
epoch:  34	argmax batch num,  5
argmax time taken,  0.19895553588867188
assembling the matrix time taken:  0.0005235671997070312
solving Ax = b time taken:  0.021874427795410156
		 time taken minimize linear layer:  0.022434234619140625
current error:  tensor(0.0487)
epoch:  35	argmax batch num,  5
argmax time taken,  0.1982860565185547
assembling the matrix time taken:  0.0004458427429199219
solving Ax = b time taken:  0.023133039474487305
		 time taken minimize linear layer:  0.0236153602600

argmax time taken,  0.20822763442993164
assembling the matrix time taken:  0.0004253387451171875
solving Ax = b time taken:  0.03510141372680664
		 time taken minimize linear layer:  0.03556227684020996
current error:  tensor(0.0402)
epoch:  64	argmax batch num,  5
argmax time taken,  0.20887398719787598
assembling the matrix time taken:  0.00047206878662109375
solving Ax = b time taken:  0.034700632095336914
		 time taken minimize linear layer:  0.03521132469177246
current error:  tensor(0.0400)
epoch:  65	argmax batch num,  5
argmax time taken,  0.2068624496459961
assembling the matrix time taken:  0.0005087852478027344
solving Ax = b time taken:  0.05073356628417969
		 time taken minimize linear layer:  0.05127859115600586
current error:  tensor(0.0396)
epoch:  66	argmax batch num,  5
argmax time taken,  0.21026992797851562
assembling the matrix time taken:  0.0005474090576171875
solving Ax = b time taken:  0.044821977615356445
		 time taken minimize linear layer:  0.045404434204101

argmax time taken,  0.21852397918701172
assembling the matrix time taken:  0.0005040168762207031
solving Ax = b time taken:  0.06676793098449707
		 time taken minimize linear layer:  0.06730914115905762
current error:  tensor(0.0313)
epoch:  95	argmax batch num,  5
argmax time taken,  0.21925806999206543
assembling the matrix time taken:  0.000453948974609375
solving Ax = b time taken:  0.06391119956970215
		 time taken minimize linear layer:  0.06440258026123047
current error:  tensor(0.0312)
epoch:  96	argmax batch num,  5
argmax time taken,  0.21921038627624512
assembling the matrix time taken:  0.0005068778991699219
solving Ax = b time taken:  0.0590205192565918
		 time taken minimize linear layer:  0.05956435203552246
current error:  tensor(0.0310)
epoch:  97	argmax batch num,  5
argmax time taken,  0.2182178497314453
assembling the matrix time taken:  0.00045609474182128906
solving Ax = b time taken:  0.05980563163757324
		 time taken minimize linear layer:  0.06029915809631348
c

argmax time taken,  0.22894001007080078
assembling the matrix time taken:  0.011406421661376953
solving Ax = b time taken:  0.06434035301208496
		 time taken minimize linear layer:  0.07578563690185547
current error:  tensor(0.0274)
epoch:  126	argmax batch num,  5
argmax time taken,  0.22927188873291016
assembling the matrix time taken:  0.015300273895263672
solving Ax = b time taken:  0.0661470890045166
		 time taken minimize linear layer:  0.08148789405822754
current error:  tensor(0.0272)
epoch:  127	argmax batch num,  5
argmax time taken,  0.23112869262695312
assembling the matrix time taken:  1.306748628616333
solving Ax = b time taken:  0.06603264808654785
		 time taken minimize linear layer:  1.3728771209716797
current error:  tensor(0.0272)
epoch:  128	argmax batch num,  5
argmax time taken,  0.23656558990478516
assembling the matrix time taken:  0.008951187133789062
solving Ax = b time taken:  0.06231546401977539
		 time taken minimize linear layer:  0.07130742073059082
curre

argmax time taken,  0.23835372924804688
assembling the matrix time taken:  0.005047798156738281
solving Ax = b time taken:  0.09860873222351074
		 time taken minimize linear layer:  0.10369586944580078
current error:  tensor(0.0234)
epoch:  157	argmax batch num,  5
argmax time taken,  0.2402329444885254
assembling the matrix time taken:  0.0049860477447509766
solving Ax = b time taken:  0.09785103797912598
		 time taken minimize linear layer:  0.10287761688232422
current error:  tensor(0.0232)
epoch:  158	argmax batch num,  5
argmax time taken,  0.23892998695373535
assembling the matrix time taken:  0.0052242279052734375
solving Ax = b time taken:  0.10381555557250977
		 time taken minimize linear layer:  0.10907912254333496
current error:  tensor(0.0230)
epoch:  159	argmax batch num,  5
argmax time taken,  0.24123239517211914
assembling the matrix time taken:  0.005036115646362305
solving Ax = b time taken:  0.09854841232299805
		 time taken minimize linear layer:  0.10362553596496582

argmax time taken,  0.2502715587615967
assembling the matrix time taken:  0.005728960037231445
solving Ax = b time taken:  0.10845327377319336
		 time taken minimize linear layer:  0.11422109603881836
current error:  tensor(0.0209)
epoch:  188	argmax batch num,  5
argmax time taken,  0.24831652641296387
assembling the matrix time taken:  0.007691383361816406
solving Ax = b time taken:  0.1095578670501709
		 time taken minimize linear layer:  0.11728835105895996
current error:  tensor(0.0209)
epoch:  189	argmax batch num,  5
argmax time taken,  0.25060343742370605
assembling the matrix time taken:  0.7045152187347412
solving Ax = b time taken:  0.05406832695007324
		 time taken minimize linear layer:  0.7586758136749268
current error:  tensor(0.0209)
epoch:  190	argmax batch num,  5
argmax time taken,  0.25827598571777344
assembling the matrix time taken:  0.010567188262939453
solving Ax = b time taken:  0.11644506454467773
		 time taken minimize linear layer:  0.12704992294311523
curre

argmax time taken,  0.25875258445739746
assembling the matrix time taken:  0.8463339805603027
solving Ax = b time taken:  0.09175443649291992
		 time taken minimize linear layer:  0.9381852149963379
current error:  tensor(0.0185)
epoch:  219	argmax batch num,  5
argmax time taken,  0.2607877254486084
assembling the matrix time taken:  0.009994029998779297
solving Ax = b time taken:  0.15397906303405762
		 time taken minimize linear layer:  0.16403794288635254
current error:  tensor(0.0183)
epoch:  220	argmax batch num,  5
argmax time taken,  0.2598574161529541
assembling the matrix time taken:  0.006783962249755859
solving Ax = b time taken:  0.15509343147277832
		 time taken minimize linear layer:  0.16192007064819336
current error:  tensor(0.0183)
epoch:  221	argmax batch num,  5
argmax time taken,  0.25977444648742676
assembling the matrix time taken:  0.0066890716552734375
solving Ax = b time taken:  0.15817856788635254
		 time taken minimize linear layer:  0.16491127014160156
curr

argmax time taken,  0.2679600715637207
assembling the matrix time taken:  0.013508319854736328
solving Ax = b time taken:  0.15937042236328125
		 time taken minimize linear layer:  0.17292237281799316
current error:  tensor(0.0170)
epoch:  250	argmax batch num,  5
argmax time taken,  0.26866674423217773
assembling the matrix time taken:  0.553380012512207
solving Ax = b time taken:  0.06939411163330078
		 time taken minimize linear layer:  0.6228604316711426
current error:  tensor(0.0170)
epoch:  251	argmax batch num,  5
argmax time taken,  0.2783060073852539
assembling the matrix time taken:  0.023034095764160156
solving Ax = b time taken:  0.15481305122375488
		 time taken minimize linear layer:  0.1778874397277832
current error:  tensor(0.0169)
epoch:  252	argmax batch num,  5
argmax time taken,  0.2690880298614502
assembling the matrix time taken:  0.018996477127075195
solving Ax = b time taken:  0.15399694442749023
		 time taken minimize linear layer:  0.17304420471191406
current 

argmax time taken,  0.27947545051574707
assembling the matrix time taken:  0.0012531280517578125
solving Ax = b time taken:  0.23931288719177246
		 time taken minimize linear layer:  0.24075770378112793
current error:  tensor(0.0160)
epoch:  281	argmax batch num,  5
argmax time taken,  0.2794215679168701
assembling the matrix time taken:  0.0012135505676269531
solving Ax = b time taken:  0.21843552589416504
		 time taken minimize linear layer:  0.2198340892791748
current error:  tensor(0.0160)
epoch:  282	argmax batch num,  5
argmax time taken,  0.2799229621887207
assembling the matrix time taken:  0.0012502670288085938
solving Ax = b time taken:  0.2814626693725586
		 time taken minimize linear layer:  0.28291797637939453
current error:  tensor(0.0159)
epoch:  283	argmax batch num,  5
argmax time taken,  0.2801060676574707
assembling the matrix time taken:  0.0013310909271240234
solving Ax = b time taken:  0.21772241592407227
		 time taken minimize linear layer:  0.21924805641174316
c

argmax time taken,  0.28906679153442383
assembling the matrix time taken:  0.0012199878692626953
solving Ax = b time taken:  0.22762608528137207
		 time taken minimize linear layer:  0.2290937900543213
current error:  tensor(0.0150)
epoch:  312	argmax batch num,  5
argmax time taken,  0.2895350456237793
assembling the matrix time taken:  0.0012173652648925781
solving Ax = b time taken:  0.22662901878356934
		 time taken minimize linear layer:  0.22809290885925293
current error:  tensor(0.0150)
epoch:  313	argmax batch num,  5
argmax time taken,  0.29007673263549805
assembling the matrix time taken:  0.0012199878692626953
solving Ax = b time taken:  0.2282414436340332
		 time taken minimize linear layer:  0.22969937324523926
current error:  tensor(0.0149)
epoch:  314	argmax batch num,  5
argmax time taken,  0.29048943519592285
assembling the matrix time taken:  0.001256704330444336
solving Ax = b time taken:  0.22740745544433594
		 time taken minimize linear layer:  0.22890615463256836


argmax time taken,  0.29954075813293457
assembling the matrix time taken:  0.0012819766998291016
solving Ax = b time taken:  0.3928954601287842
		 time taken minimize linear layer:  0.3943662643432617
current error:  tensor(0.0132)
epoch:  343	argmax batch num,  5
argmax time taken,  0.2991154193878174
assembling the matrix time taken:  0.0012421607971191406
solving Ax = b time taken:  0.30583882331848145
		 time taken minimize linear layer:  0.3073246479034424
current error:  tensor(0.0131)
epoch:  344	argmax batch num,  5
argmax time taken,  0.30012035369873047
assembling the matrix time taken:  0.0012350082397460938
solving Ax = b time taken:  0.3055436611175537
		 time taken minimize linear layer:  0.30701160430908203
current error:  tensor(0.0131)
epoch:  345	argmax batch num,  5
argmax time taken,  0.29990553855895996
assembling the matrix time taken:  0.001241922378540039
solving Ax = b time taken:  0.30759501457214355
		 time taken minimize linear layer:  0.30901384353637695
cu

argmax time taken,  0.30854249000549316
assembling the matrix time taken:  0.0012786388397216797
solving Ax = b time taken:  0.3198831081390381
		 time taken minimize linear layer:  0.32135891914367676
current error:  tensor(0.0124)
epoch:  374	argmax batch num,  5
argmax time taken,  0.3094358444213867
assembling the matrix time taken:  0.0012390613555908203
solving Ax = b time taken:  0.3183119297027588
		 time taken minimize linear layer:  0.319751501083374
current error:  tensor(0.0123)
epoch:  375	argmax batch num,  5
argmax time taken,  0.3090827465057373
assembling the matrix time taken:  0.0012924671173095703
solving Ax = b time taken:  0.3203601837158203
		 time taken minimize linear layer:  0.3218498229980469
current error:  tensor(0.0123)
epoch:  376	argmax batch num,  5
argmax time taken,  0.3099980354309082
assembling the matrix time taken:  0.0012698173522949219
solving Ax = b time taken:  0.3190910816192627
		 time taken minimize linear layer:  0.3205399513244629
current

argmax time taken,  0.3199574947357178
assembling the matrix time taken:  0.0013892650604248047
solving Ax = b time taken:  0.370516300201416
		 time taken minimize linear layer:  0.37210607528686523
current error:  tensor(0.0116)
epoch:  405	argmax batch num,  5
argmax time taken,  0.32013750076293945
assembling the matrix time taken:  0.01425313949584961
solving Ax = b time taken:  0.3536512851715088
		 time taken minimize linear layer:  0.3680415153503418
current error:  tensor(0.0116)
epoch:  406	argmax batch num,  5
argmax time taken,  0.3205554485321045
assembling the matrix time taken:  0.001367807388305664
solving Ax = b time taken:  0.3728652000427246
		 time taken minimize linear layer:  0.37442946434020996
current error:  tensor(0.0116)
epoch:  407	argmax batch num,  5
argmax time taken,  0.3206794261932373
assembling the matrix time taken:  0.0013554096221923828
solving Ax = b time taken:  0.3526804447174072
		 time taken minimize linear layer:  0.35423874855041504
current 

argmax time taken,  0.3292350769042969
assembling the matrix time taken:  0.001341104507446289
solving Ax = b time taken:  0.3705477714538574
		 time taken minimize linear layer:  0.3721301555633545
current error:  tensor(0.0110)
epoch:  436	argmax batch num,  5
argmax time taken,  0.3297741413116455
assembling the matrix time taken:  0.001360177993774414
solving Ax = b time taken:  0.3713865280151367
		 time taken minimize linear layer:  0.37294721603393555
current error:  tensor(0.0110)
epoch:  437	argmax batch num,  5
argmax time taken,  0.3302123546600342
assembling the matrix time taken:  0.0013623237609863281
solving Ax = b time taken:  0.37183189392089844
		 time taken minimize linear layer:  0.37339162826538086
current error:  tensor(0.0109)
epoch:  438	argmax batch num,  5
argmax time taken,  0.3310079574584961
assembling the matrix time taken:  0.0013713836669921875
solving Ax = b time taken:  0.37201547622680664
		 time taken minimize linear layer:  0.37358808517456055
curre

argmax time taken,  0.3399083614349365
assembling the matrix time taken:  0.0014705657958984375
solving Ax = b time taken:  0.4444568157196045
		 time taken minimize linear layer:  0.4461374282836914
current error:  tensor(0.0104)
epoch:  467	argmax batch num,  5
argmax time taken,  0.3399336338043213
assembling the matrix time taken:  0.0014104843139648438
solving Ax = b time taken:  0.4452555179595947
		 time taken minimize linear layer:  0.44687700271606445
current error:  tensor(0.0104)
epoch:  468	argmax batch num,  5
argmax time taken,  0.3407301902770996
assembling the matrix time taken:  0.0014352798461914062
solving Ax = b time taken:  0.44608020782470703
		 time taken minimize linear layer:  0.4477205276489258
current error:  tensor(0.0104)
epoch:  469	argmax batch num,  5
argmax time taken,  0.34090232849121094
assembling the matrix time taken:  0.0013859272003173828
solving Ax = b time taken:  0.4471311569213867
		 time taken minimize linear layer:  0.44872331619262695
curr

argmax time taken,  0.3489680290222168
assembling the matrix time taken:  0.0013849735260009766
solving Ax = b time taken:  0.44704151153564453
		 time taken minimize linear layer:  0.44862937927246094
current error:  tensor(0.0099)
epoch:  498	argmax batch num,  5
argmax time taken,  0.34982872009277344
assembling the matrix time taken:  0.0014259815216064453
solving Ax = b time taken:  0.44997525215148926
		 time taken minimize linear layer:  0.4515993595123291
current error:  tensor(0.0099)
epoch:  499	argmax batch num,  5
argmax time taken,  0.34975457191467285
assembling the matrix time taken:  0.0013964176177978516
solving Ax = b time taken:  0.44997310638427734
		 time taken minimize linear layer:  0.4515669345855713
current error:  tensor(0.0098)
epoch:  500	argmax batch num,  5
argmax time taken,  0.3506042957305908
assembling the matrix time taken:  0.0014290809631347656
solving Ax = b time taken:  0.44968080520629883
		 time taken minimize linear layer:  0.45130252838134766


## diff target

In [5]:
if __name__ == "__main__": 

    def target(x): # product of sin(pi x_i) 
#         return torch.prod(torch.sin(pi * x), dim=1,keepdim = True)
        return torch.sum(torch.sin(pi * x), dim=1,keepdim = True)
        
    save = False 
    experiment_label = "ex2"
    for k in [1]: 
        s = 1 
        for N0 in [2**12]: 
            print()
            print() 
            exponent = 9    
            num_epochs=  2**exponent 
            M = 2**20 #  2**19 around 50w 
            print(M)
            my_model = None 
            
            err, my_model = OGAL2FittingReLU4Dplus_QMC(my_model,target, \
                        s,N0,num_epochs, M, k = k, linear_solver = "direct", num_batches = 8)
            
            if save: 
                filename = experiment_label + "_err_randDict_relu_{}_size_{}_num_neurons_{}.pt".format(k,s * N0,num_epochs)
                torch.save(err,filename)
                filename = experiment_label + "_model_randDict_relu_{}_size_{}_num_neurons_{}.pt".format(k,s * N0,num_epochs)
                torch.save(my_model.state_dict(),filename) 

            print_convergence_order(err,exponent+1) 




1048576
generate sob sequence: 0.08808159828186035
using linear solver:  direct
epoch:  1	argmax time taken,  0.17486262321472168
assembling the matrix time taken:  0.00017833709716796875
solving Ax = b time taken:  0.00040650367736816406
		 time taken minimize linear layer:  0.0006225109100341797
current error:  tensor(1.0755)
epoch:  2	argmax time taken,  0.1720871925354004
assembling the matrix time taken:  0.0002238750457763672
solving Ax = b time taken:  0.0007169246673583984
		 time taken minimize linear layer:  0.0009715557098388672
current error:  tensor(0.9818)
epoch:  3	argmax time taken,  0.16898059844970703
assembling the matrix time taken:  0.00019598007202148438
solving Ax = b time taken:  0.0007650852203369141
		 time taken minimize linear layer:  0.0009918212890625
current error:  tensor(0.9741)
epoch:  4	argmax time taken,  0.16912007331848145
assembling the matrix time taken:  0.0002040863037109375
solving Ax = b time taken:  0.0008275508880615234
		 time taken mini

epoch:  35	argmax time taken,  0.17318296432495117
assembling the matrix time taken:  0.00019431114196777344
solving Ax = b time taken:  0.002843618392944336
		 time taken minimize linear layer:  0.0030689239501953125
current error:  tensor(0.5285)
epoch:  36	argmax time taken,  0.17325186729431152
assembling the matrix time taken:  0.00020432472229003906
solving Ax = b time taken:  0.0027463436126708984
		 time taken minimize linear layer:  0.002988100051879883
current error:  tensor(0.5252)
epoch:  37	argmax time taken,  0.17330193519592285
assembling the matrix time taken:  0.00019359588623046875
solving Ax = b time taken:  0.002917766571044922
		 time taken minimize linear layer:  0.0031430721282958984
current error:  tensor(0.5179)
epoch:  38	argmax time taken,  0.17338967323303223
assembling the matrix time taken:  0.00020265579223632812
solving Ax = b time taken:  0.0028696060180664062
		 time taken minimize linear layer:  0.0031032562255859375
current error:  tensor(0.5095)
epo

epoch:  69	argmax time taken,  0.17430782318115234
assembling the matrix time taken:  0.00019550323486328125
solving Ax = b time taken:  0.006151437759399414
		 time taken minimize linear layer:  0.006379842758178711
current error:  tensor(0.3368)
epoch:  70	argmax time taken,  0.17505407333374023
assembling the matrix time taken:  0.0002377033233642578
solving Ax = b time taken:  0.0059833526611328125
		 time taken minimize linear layer:  0.0063877105712890625
current error:  tensor(0.3302)
epoch:  71	argmax time taken,  0.17501521110534668
assembling the matrix time taken:  0.00021195411682128906
solving Ax = b time taken:  0.006273031234741211
		 time taken minimize linear layer:  0.006519317626953125
current error:  tensor(0.3207)
epoch:  72	argmax time taken,  0.17506027221679688
assembling the matrix time taken:  0.0002522468566894531
solving Ax = b time taken:  0.005761146545410156
		 time taken minimize linear layer:  0.006047964096069336
current error:  tensor(0.3172)
epoch:  

epoch:  103	argmax time taken,  0.17525506019592285
assembling the matrix time taken:  0.00021696090698242188
solving Ax = b time taken:  0.007984399795532227
		 time taken minimize linear layer:  0.008235931396484375
current error:  tensor(0.2341)
epoch:  104	argmax time taken,  0.17878055572509766
assembling the matrix time taken:  0.0002701282501220703
solving Ax = b time taken:  0.007860660552978516
		 time taken minimize linear layer:  0.008170366287231445
current error:  tensor(0.2327)
epoch:  105	argmax time taken,  0.1786656379699707
assembling the matrix time taken:  0.00025343894958496094
solving Ax = b time taken:  0.007946014404296875
		 time taken minimize linear layer:  0.008237123489379883
current error:  tensor(0.2308)
epoch:  106	argmax time taken,  0.178696870803833
assembling the matrix time taken:  0.00026607513427734375
solving Ax = b time taken:  0.00835728645324707
		 time taken minimize linear layer:  0.008664131164550781
current error:  tensor(0.2236)
epoch:  1

solving Ax = b time taken:  0.011966705322265625
		 time taken minimize linear layer:  0.012353897094726562
current error:  tensor(0.1885)
epoch:  137	argmax time taken,  0.17689251899719238
assembling the matrix time taken:  0.00030994415283203125
solving Ax = b time taken:  0.01205897331237793
		 time taken minimize linear layer:  0.01254725456237793
current error:  tensor(0.1877)
epoch:  138	argmax time taken,  0.17948675155639648
assembling the matrix time taken:  0.0002491474151611328
solving Ax = b time taken:  0.012727975845336914
		 time taken minimize linear layer:  0.013162612915039062
current error:  tensor(0.1867)
epoch:  139	argmax time taken,  0.17971181869506836
assembling the matrix time taken:  0.0002067089080810547
solving Ax = b time taken:  0.012231111526489258
		 time taken minimize linear layer:  0.012604475021362305
current error:  tensor(0.1860)
epoch:  140	argmax time taken,  0.17742300033569336
assembling the matrix time taken:  0.00022411346435546875
solving 

epoch:  170	argmax time taken,  0.18062877655029297
assembling the matrix time taken:  0.00020575523376464844
solving Ax = b time taken:  0.014531612396240234
		 time taken minimize linear layer:  0.014769792556762695
current error:  tensor(0.1632)
epoch:  171	argmax time taken,  0.18053531646728516
assembling the matrix time taken:  0.00019621849060058594
solving Ax = b time taken:  0.013812541961669922
		 time taken minimize linear layer:  0.01404118537902832
current error:  tensor(0.1629)
epoch:  172	argmax time taken,  0.18053627014160156
assembling the matrix time taken:  0.0002079010009765625
solving Ax = b time taken:  0.01411581039428711
		 time taken minimize linear layer:  0.014355659484863281
current error:  tensor(0.1625)
epoch:  173	argmax time taken,  0.18051719665527344
assembling the matrix time taken:  0.00020074844360351562
solving Ax = b time taken:  0.013914823532104492
		 time taken minimize linear layer:  0.014147520065307617
current error:  tensor(0.1621)
epoch: 

epoch:  204	argmax time taken,  0.17548465728759766
assembling the matrix time taken:  0.00020384788513183594
solving Ax = b time taken:  0.019103288650512695
		 time taken minimize linear layer:  0.019340991973876953
current error:  tensor(0.1482)
epoch:  205	argmax time taken,  0.17542791366577148
assembling the matrix time taken:  0.0001938343048095703
solving Ax = b time taken:  0.019782066345214844
		 time taken minimize linear layer:  0.020008563995361328
current error:  tensor(0.1477)
epoch:  206	argmax time taken,  0.1755058765411377
assembling the matrix time taken:  0.0002110004425048828
solving Ax = b time taken:  0.021068572998046875
		 time taken minimize linear layer:  0.02131199836730957
current error:  tensor(0.1475)
epoch:  207	argmax time taken,  0.1755976676940918
assembling the matrix time taken:  0.0001914501190185547
solving Ax = b time taken:  0.01987147331237793
		 time taken minimize linear layer:  0.0200955867767334
current error:  tensor(0.1464)
epoch:  208	a

epoch:  238	argmax time taken,  0.1764507293701172
assembling the matrix time taken:  0.00019359588623046875
solving Ax = b time taken:  0.02137303352355957
		 time taken minimize linear layer:  0.02159881591796875
current error:  tensor(0.1338)
epoch:  239	argmax time taken,  0.17644357681274414
assembling the matrix time taken:  0.00019359588623046875
solving Ax = b time taken:  0.02145862579345703
		 time taken minimize linear layer:  0.021684885025024414
current error:  tensor(0.1334)
epoch:  240	argmax time taken,  0.1764082908630371
assembling the matrix time taken:  0.00020170211791992188
solving Ax = b time taken:  0.021393299102783203
		 time taken minimize linear layer:  0.021626949310302734
current error:  tensor(0.1330)
epoch:  241	argmax time taken,  0.17644882202148438
assembling the matrix time taken:  0.00019049644470214844
solving Ax = b time taken:  0.021603107452392578
		 time taken minimize linear layer:  0.021826744079589844
current error:  tensor(0.1321)
epoch:  2

epoch:  272	argmax time taken,  0.1866135597229004
assembling the matrix time taken:  0.00021696090698242188
solving Ax = b time taken:  0.025792837142944336
		 time taken minimize linear layer:  0.026043176651000977
current error:  tensor(0.1232)
epoch:  273	argmax time taken,  0.18330669403076172
assembling the matrix time taken:  0.0002105236053466797
solving Ax = b time taken:  0.02711772918701172
		 time taken minimize linear layer:  0.02736067771911621
current error:  tensor(0.1227)
epoch:  274	argmax time taken,  0.18321776390075684
assembling the matrix time taken:  0.00021791458129882812
solving Ax = b time taken:  0.03142404556274414
		 time taken minimize linear layer:  0.03167414665222168
current error:  tensor(0.1225)
epoch:  275	argmax time taken,  0.18337130546569824
assembling the matrix time taken:  0.00021028518676757812
solving Ax = b time taken:  0.02721714973449707
		 time taken minimize linear layer:  0.02746105194091797
current error:  tensor(0.1223)
epoch:  276	

epoch:  306	argmax time taken,  0.18446874618530273
assembling the matrix time taken:  0.00026869773864746094
solving Ax = b time taken:  0.028763771057128906
		 time taken minimize linear layer:  0.029073476791381836
current error:  tensor(0.1138)
epoch:  307	argmax time taken,  0.18592333793640137
assembling the matrix time taken:  0.00024890899658203125
solving Ax = b time taken:  0.028828859329223633
		 time taken minimize linear layer:  0.029119014739990234
current error:  tensor(0.1132)
epoch:  308	argmax time taken,  0.18422245979309082
assembling the matrix time taken:  0.0002429485321044922
solving Ax = b time taken:  0.0288236141204834
		 time taken minimize linear layer:  0.029102802276611328
current error:  tensor(0.1130)
epoch:  309	argmax time taken,  0.18770718574523926
assembling the matrix time taken:  0.00022912025451660156
solving Ax = b time taken:  0.02890753746032715
		 time taken minimize linear layer:  0.029171228408813477
current error:  tensor(0.1129)
epoch:  

epoch:  340	argmax time taken,  0.17972087860107422
assembling the matrix time taken:  0.0003314018249511719
solving Ax = b time taken:  0.034017086029052734
		 time taken minimize linear layer:  0.03439617156982422
current error:  tensor(0.1060)
epoch:  341	argmax time taken,  0.18002009391784668
assembling the matrix time taken:  0.00024437904357910156
solving Ax = b time taken:  0.038521528244018555
		 time taken minimize linear layer:  0.038803815841674805
current error:  tensor(0.1059)
epoch:  342	argmax time taken,  0.17979907989501953
assembling the matrix time taken:  0.0002579689025878906
solving Ax = b time taken:  0.04428362846374512
		 time taken minimize linear layer:  0.04457736015319824
current error:  tensor(0.1053)
epoch:  343	argmax time taken,  0.1797780990600586
assembling the matrix time taken:  0.00023484230041503906
solving Ax = b time taken:  0.03861665725708008
		 time taken minimize linear layer:  0.03888726234436035
current error:  tensor(0.1051)
epoch:  344	

epoch:  374	argmax time taken,  0.18078875541687012
assembling the matrix time taken:  0.00021195411682128906
solving Ax = b time taken:  0.0402069091796875
		 time taken minimize linear layer:  0.04045224189758301
current error:  tensor(0.1004)
epoch:  375	argmax time taken,  0.18071293830871582
assembling the matrix time taken:  0.0002040863037109375
solving Ax = b time taken:  0.04028725624084473
		 time taken minimize linear layer:  0.04052901268005371
current error:  tensor(0.1002)
epoch:  376	argmax time taken,  0.1808016300201416
assembling the matrix time taken:  0.00021266937255859375
solving Ax = b time taken:  0.0401759147644043
		 time taken minimize linear layer:  0.040421247482299805
current error:  tensor(0.1001)
epoch:  377	argmax time taken,  0.1806790828704834
assembling the matrix time taken:  0.00021123886108398438
solving Ax = b time taken:  0.04035139083862305
		 time taken minimize linear layer:  0.04059481620788574
current error:  tensor(0.0999)
epoch:  378	argm

epoch:  408	argmax time taken,  0.192901611328125
assembling the matrix time taken:  0.0002071857452392578
solving Ax = b time taken:  0.04731869697570801
		 time taken minimize linear layer:  0.04755854606628418
current error:  tensor(0.0948)
epoch:  409	argmax time taken,  0.18689393997192383
assembling the matrix time taken:  0.00021314620971679688
solving Ax = b time taken:  0.0450291633605957
		 time taken minimize linear layer:  0.04527616500854492
current error:  tensor(0.0946)
epoch:  410	argmax time taken,  0.19344329833984375
assembling the matrix time taken:  0.0002148151397705078
solving Ax = b time taken:  0.0450439453125
		 time taken minimize linear layer:  0.04529118537902832
current error:  tensor(0.0946)
epoch:  411	argmax time taken,  0.18707656860351562
assembling the matrix time taken:  0.00020885467529296875
solving Ax = b time taken:  0.04518437385559082
		 time taken minimize linear layer:  0.045426368713378906
current error:  tensor(0.0945)
epoch:  412	argmax t

epoch:  442	argmax time taken,  0.1880793571472168
assembling the matrix time taken:  0.0002110004425048828
solving Ax = b time taken:  0.05055093765258789
		 time taken minimize linear layer:  0.0507960319519043
current error:  tensor(0.0901)
epoch:  443	argmax time taken,  0.19541215896606445
assembling the matrix time taken:  0.0002071857452392578
solving Ax = b time taken:  0.04756498336791992
		 time taken minimize linear layer:  0.047805070877075195
current error:  tensor(0.0899)
epoch:  444	argmax time taken,  0.18813109397888184
assembling the matrix time taken:  0.00020813941955566406
solving Ax = b time taken:  0.04918527603149414
		 time taken minimize linear layer:  0.049425601959228516
current error:  tensor(0.0898)
epoch:  445	argmax time taken,  0.19485878944396973
assembling the matrix time taken:  0.00020551681518554688
solving Ax = b time taken:  0.047445058822631836
		 time taken minimize linear layer:  0.0476832389831543
current error:  tensor(0.0897)
epoch:  446	ar

epoch:  476	argmax time taken,  0.1890881061553955
assembling the matrix time taken:  0.00020456314086914062
solving Ax = b time taken:  0.05562591552734375
		 time taken minimize linear layer:  0.055864810943603516
current error:  tensor(0.0861)
epoch:  477	argmax time taken,  0.18904447555541992
assembling the matrix time taken:  0.00021266937255859375
solving Ax = b time taken:  0.05583763122558594
		 time taken minimize linear layer:  0.056082963943481445
current error:  tensor(0.0860)
epoch:  478	argmax time taken,  0.1958918571472168
assembling the matrix time taken:  0.00020623207092285156
solving Ax = b time taken:  0.05583667755126953
		 time taken minimize linear layer:  0.0560755729675293
current error:  tensor(0.0859)
epoch:  479	argmax time taken,  0.1891491413116455
assembling the matrix time taken:  0.0002079010009765625
solving Ax = b time taken:  0.055855751037597656
		 time taken minimize linear layer:  0.056096553802490234
current error:  tensor(0.0858)
epoch:  480	a

epoch:  510	argmax time taken,  0.19333767890930176
assembling the matrix time taken:  0.00021123886108398438
solving Ax = b time taken:  0.05849123001098633
		 time taken minimize linear layer:  0.058734893798828125
current error:  tensor(0.0825)
epoch:  511	argmax time taken,  0.1901099681854248
assembling the matrix time taken:  0.0002067089080810547
solving Ax = b time taken:  0.058522701263427734
		 time taken minimize linear layer:  0.0587615966796875
current error:  tensor(0.0824)
epoch:  512	argmax time taken,  0.19326353073120117
assembling the matrix time taken:  0.00020742416381835938
solving Ax = b time taken:  0.05806303024291992
		 time taken minimize linear layer:  0.058303117752075195
current error:  tensor(0.0822)
total duration:  121.47982430458069
neuron num 		 error 		 order
4		tensor(0.9622)		*
8		tensor(0.8602)		tensor(0.1617)
16		tensor(0.7299)		tensor(0.2369)
32		tensor(0.5455)		tensor(0.4202)
64		tensor(0.3641)		tensor(0.5832)
128		tensor(0.1968)		tensor(0.8876

In [None]:
k = 1, N0 = 2**12

k = 3, N0 = 2**12  

4		tensor(1.1036)		*
8		tensor(1.0284)		tensor(0.1018)
16		tensor(0.9599)		tensor(0.0995)
32		tensor(0.5711)		tensor(0.7491)
64		tensor(0.1325)		tensor(2.1082)
128		tensor(0.0591)		tensor(1.1641)
256		tensor(0.0439)		tensor(0.4284)
512		tensor(0.0335)		tensor(0.3913)

## Random dictionary 

In [2]:
if __name__ == "__main__": 

    def target(x): ## Gaussian function in dimension 10  
        d = 10 
        cn =   7.03/d 
        return torch.exp(-torch.sum( cn**2 * (x - 0.5)**2,dim = 1, keepdim = True)) 
    
    save = False 
    experiment_label = "ex1"
    for k in [1]: 
        s = 1 
        for N0 in [2**14]: 
            print()
            print() 
            exponent = 9    
            num_epochs=  2**exponent 
            M = 2**20 # around 50w 
            print(M)
            my_model = None 
            
            err, my_model = OGAL2FittingReLU4Dplus_QMC(my_model,target, \
                        s,N0,num_epochs, M, k = k, linear_solver = "direct", num_batches = 10)
            
            if save: 
                filename = experiment_label + "_err_randDict_relu_{}_size_{}_num_neurons_{}.pt".format(k,s * N0,num_epochs)
                torch.save(err,filename)
                filename = experiment_label + "_model_randDict_relu_{}_size_{}_num_neurons_{}.pt".format(k,s * N0,num_epochs)
                torch.save(my_model.state_dict(),filename) 

            print_convergence_order(err,exponent+1) 




1048576
generate sob sequence: 0.8055334091186523
using linear solver:  direct
epoch:  1	argmax time taken,  1.062718152999878
assembling the matrix time taken:  0.00039696693420410156
solving Ax = b time taken:  0.3515477180480957
		 time taken minimize linear layer:  0.3520951271057129
current error:  tensor(0.0892)
epoch:  2	argmax time taken,  0.6840980052947998
assembling the matrix time taken:  0.00042176246643066406
solving Ax = b time taken:  0.0007321834564208984
		 time taken minimize linear layer:  0.0011892318725585938
current error:  tensor(0.0777)
epoch:  3	argmax time taken,  0.6844673156738281
assembling the matrix time taken:  0.00029850006103515625
solving Ax = b time taken:  0.0008749961853027344
		 time taken minimize linear layer:  0.0012097358703613281
current error:  tensor(0.0773)
epoch:  4	argmax time taken,  0.6845643520355225
assembling the matrix time taken:  0.00032019615173339844
solving Ax = b time taken:  0.0008721351623535156
		 time taken minimize li

epoch:  34	argmax time taken,  0.6887755393981934
assembling the matrix time taken:  0.00034427642822265625
solving Ax = b time taken:  0.00276947021484375
		 time taken minimize linear layer:  0.0031473636627197266
current error:  tensor(0.0373)
epoch:  35	argmax time taken,  0.6884863376617432
assembling the matrix time taken:  0.0002486705780029297
solving Ax = b time taken:  0.002931833267211914
		 time taken minimize linear layer:  0.0032134056091308594
current error:  tensor(0.0366)
epoch:  36	argmax time taken,  0.6884329319000244
assembling the matrix time taken:  0.0003044605255126953
solving Ax = b time taken:  0.0028264522552490234
		 time taken minimize linear layer:  0.0031633377075195312
current error:  tensor(0.0354)
epoch:  37	argmax time taken,  0.6886386871337891
assembling the matrix time taken:  0.00023174285888671875
solving Ax = b time taken:  0.0030410289764404297
		 time taken minimize linear layer:  0.003304004669189453
current error:  tensor(0.0344)
epoch:  38

epoch:  68	argmax time taken,  0.6895732879638672
assembling the matrix time taken:  0.0002627372741699219
solving Ax = b time taken:  0.0057566165924072266
		 time taken minimize linear layer:  0.006050825119018555
current error:  tensor(0.0197)
epoch:  69	argmax time taken,  0.68971848487854
assembling the matrix time taken:  0.0002295970916748047
solving Ax = b time taken:  0.006205558776855469
		 time taken minimize linear layer:  0.006466865539550781
current error:  tensor(0.0194)
epoch:  70	argmax time taken,  0.6904699802398682
assembling the matrix time taken:  0.0002701282501220703
solving Ax = b time taken:  0.006038188934326172
		 time taken minimize linear layer:  0.00634002685546875
current error:  tensor(0.0191)
epoch:  71	argmax time taken,  0.6895003318786621
assembling the matrix time taken:  0.00023174285888671875
solving Ax = b time taken:  0.006457090377807617
		 time taken minimize linear layer:  0.0067195892333984375
current error:  tensor(0.0189)
epoch:  72	argma

epoch:  102	argmax time taken,  0.6937558650970459
assembling the matrix time taken:  0.00024890899658203125
solving Ax = b time taken:  0.008424758911132812
		 time taken minimize linear layer:  0.008704185485839844
current error:  tensor(0.0140)
epoch:  103	argmax time taken,  0.6938464641571045
assembling the matrix time taken:  0.00023126602172851562
solving Ax = b time taken:  0.008147716522216797
		 time taken minimize linear layer:  0.008410930633544922
current error:  tensor(0.0139)
epoch:  104	argmax time taken,  0.6902213096618652
assembling the matrix time taken:  0.00024628639221191406
solving Ax = b time taken:  0.008059501647949219
		 time taken minimize linear layer:  0.008336782455444336
current error:  tensor(0.0138)
epoch:  105	argmax time taken,  0.6944634914398193
assembling the matrix time taken:  0.00022840499877929688
solving Ax = b time taken:  0.008157730102539062
		 time taken minimize linear layer:  0.008417367935180664
current error:  tensor(0.0137)
epoch:  

epoch:  136	argmax time taken,  0.6923656463623047
assembling the matrix time taken:  0.00031113624572753906
solving Ax = b time taken:  0.012061595916748047
		 time taken minimize linear layer:  0.012409448623657227
current error:  tensor(0.0112)
epoch:  137	argmax time taken,  0.6921942234039307
assembling the matrix time taken:  0.00028514862060546875
solving Ax = b time taken:  0.012258291244506836
		 time taken minimize linear layer:  0.01258397102355957
current error:  tensor(0.0111)
epoch:  138	argmax time taken,  0.6926376819610596
assembling the matrix time taken:  0.00033974647521972656
solving Ax = b time taken:  0.012917757034301758
		 time taken minimize linear layer:  0.013293981552124023
current error:  tensor(0.0110)
epoch:  139	argmax time taken,  0.6951138973236084
assembling the matrix time taken:  0.00030541419982910156
solving Ax = b time taken:  0.01233673095703125
		 time taken minimize linear layer:  0.0126800537109375
current error:  tensor(0.0110)
epoch:  140	

epoch:  170	argmax time taken,  0.6960878372192383
assembling the matrix time taken:  0.00033736228942871094
solving Ax = b time taken:  0.014649391174316406
		 time taken minimize linear layer:  0.015022993087768555
current error:  tensor(0.0094)
epoch:  171	argmax time taken,  0.6955776214599609
assembling the matrix time taken:  0.0002536773681640625
solving Ax = b time taken:  0.013988494873046875
		 time taken minimize linear layer:  0.014275074005126953
current error:  tensor(0.0093)
epoch:  172	argmax time taken,  0.6959133148193359
assembling the matrix time taken:  0.0002636909484863281
solving Ax = b time taken:  0.014261007308959961
		 time taken minimize linear layer:  0.014556169509887695
current error:  tensor(0.0093)
epoch:  173	argmax time taken,  0.6957755088806152
assembling the matrix time taken:  0.00023055076599121094
solving Ax = b time taken:  0.014091253280639648
		 time taken minimize linear layer:  0.014353513717651367
current error:  tensor(0.0092)
epoch:  17

epoch:  204	argmax time taken,  0.6909492015838623
assembling the matrix time taken:  0.0003285408020019531
solving Ax = b time taken:  0.02018570899963379
		 time taken minimize linear layer:  0.020555973052978516
current error:  tensor(0.0081)
epoch:  205	argmax time taken,  0.6907587051391602
assembling the matrix time taken:  0.0003554821014404297
solving Ax = b time taken:  0.019855976104736328
		 time taken minimize linear layer:  0.02025437355041504
current error:  tensor(0.0080)
epoch:  206	argmax time taken,  0.6910419464111328
assembling the matrix time taken:  0.00033974647521972656
solving Ax = b time taken:  0.021225690841674805
		 time taken minimize linear layer:  0.0216062068939209
current error:  tensor(0.0080)
epoch:  207	argmax time taken,  0.690950870513916
assembling the matrix time taken:  0.00031375885009765625
solving Ax = b time taken:  0.01999807357788086
		 time taken minimize linear layer:  0.020351886749267578
current error:  tensor(0.0080)
epoch:  208	argm

epoch:  238	argmax time taken,  0.691765546798706
assembling the matrix time taken:  0.0003159046173095703
solving Ax = b time taken:  0.021488428115844727
		 time taken minimize linear layer:  0.021845102310180664
current error:  tensor(0.0071)
epoch:  239	argmax time taken,  0.6916708946228027
assembling the matrix time taken:  0.0003223419189453125
solving Ax = b time taken:  0.02156972885131836
		 time taken minimize linear layer:  0.021932601928710938
current error:  tensor(0.0071)
epoch:  240	argmax time taken,  0.6918814182281494
assembling the matrix time taken:  0.0003287792205810547
solving Ax = b time taken:  0.021457195281982422
		 time taken minimize linear layer:  0.021826982498168945
current error:  tensor(0.0071)
epoch:  241	argmax time taken,  0.691918134689331
assembling the matrix time taken:  0.0003504753112792969
solving Ax = b time taken:  0.02171635627746582
		 time taken minimize linear layer:  0.022110700607299805
current error:  tensor(0.0070)
epoch:  242	argm

epoch:  272	argmax time taken,  0.698575496673584
assembling the matrix time taken:  0.00026726722717285156
solving Ax = b time taken:  0.02593231201171875
		 time taken minimize linear layer:  0.026231765747070312
current error:  tensor(0.0064)
epoch:  273	argmax time taken,  0.7020196914672852
assembling the matrix time taken:  0.0002493858337402344
solving Ax = b time taken:  0.02730393409729004
		 time taken minimize linear layer:  0.02758502960205078
current error:  tensor(0.0064)
epoch:  274	argmax time taken,  0.6982755661010742
assembling the matrix time taken:  0.0002627372741699219
solving Ax = b time taken:  0.03230142593383789
		 time taken minimize linear layer:  0.03259539604187012
current error:  tensor(0.0064)
epoch:  275	argmax time taken,  0.702418327331543
assembling the matrix time taken:  0.0002467632293701172
solving Ax = b time taken:  0.02742457389831543
		 time taken minimize linear layer:  0.027702808380126953
current error:  tensor(0.0064)
epoch:  276	argmax 

epoch:  306	argmax time taken,  0.6995425224304199
assembling the matrix time taken:  0.00024437904357910156
solving Ax = b time taken:  0.028951406478881836
		 time taken minimize linear layer:  0.029227256774902344
current error:  tensor(0.0058)
epoch:  307	argmax time taken,  0.6993460655212402
assembling the matrix time taken:  0.00025010108947753906
solving Ax = b time taken:  0.029062747955322266
		 time taken minimize linear layer:  0.02934575080871582
current error:  tensor(0.0058)
epoch:  308	argmax time taken,  0.6995155811309814
assembling the matrix time taken:  0.00024509429931640625
solving Ax = b time taken:  0.02904343605041504
		 time taken minimize linear layer:  0.029320240020751953
current error:  tensor(0.0058)
epoch:  309	argmax time taken,  0.6993622779846191
assembling the matrix time taken:  0.00024437904357910156
solving Ax = b time taken:  0.029126644134521484
		 time taken minimize linear layer:  0.029402732849121094
current error:  tensor(0.0058)
epoch:  31

epoch:  340	argmax time taken,  0.6948950290679932
assembling the matrix time taken:  0.0002646446228027344
solving Ax = b time taken:  0.034265995025634766
		 time taken minimize linear layer:  0.03456282615661621
current error:  tensor(0.0053)
epoch:  341	argmax time taken,  0.6950094699859619
assembling the matrix time taken:  0.0003159046173095703
solving Ax = b time taken:  0.03871750831604004
		 time taken minimize linear layer:  0.03907418251037598
current error:  tensor(0.0053)
epoch:  342	argmax time taken,  0.695249080657959
assembling the matrix time taken:  0.0003490447998046875
solving Ax = b time taken:  0.04459643363952637
		 time taken minimize linear layer:  0.04498887062072754
current error:  tensor(0.0053)
epoch:  343	argmax time taken,  0.6952817440032959
assembling the matrix time taken:  0.0003685951232910156
solving Ax = b time taken:  0.03872799873352051
		 time taken minimize linear layer:  0.03914046287536621
current error:  tensor(0.0053)
epoch:  344	argmax t

epoch:  374	argmax time taken,  0.696277379989624
assembling the matrix time taken:  0.0003147125244140625
solving Ax = b time taken:  0.04036688804626465
		 time taken minimize linear layer:  0.04071807861328125
current error:  tensor(0.0050)
epoch:  375	argmax time taken,  0.6962518692016602
assembling the matrix time taken:  0.000335693359375
solving Ax = b time taken:  0.040474653244018555
		 time taken minimize linear layer:  0.04084944725036621
current error:  tensor(0.0050)
epoch:  376	argmax time taken,  0.6962780952453613
assembling the matrix time taken:  0.0003237724304199219
solving Ax = b time taken:  0.04033184051513672
		 time taken minimize linear layer:  0.04069805145263672
current error:  tensor(0.0050)
epoch:  377	argmax time taken,  0.6964418888092041
assembling the matrix time taken:  0.00030541419982910156
solving Ax = b time taken:  0.040513038635253906
		 time taken minimize linear layer:  0.04085588455200195
current error:  tensor(0.0050)
epoch:  378	argmax tim

epoch:  408	argmax time taken,  0.7024860382080078
assembling the matrix time taken:  0.0003819465637207031
solving Ax = b time taken:  0.04736781120300293
		 time taken minimize linear layer:  0.04779648780822754
current error:  tensor(0.0047)
epoch:  409	argmax time taken,  0.7024760246276855
assembling the matrix time taken:  0.0003905296325683594
solving Ax = b time taken:  0.04513287544250488
		 time taken minimize linear layer:  0.04557299613952637
current error:  tensor(0.0047)
epoch:  410	argmax time taken,  0.7026867866516113
assembling the matrix time taken:  0.000331878662109375
solving Ax = b time taken:  0.04520606994628906
		 time taken minimize linear layer:  0.045578956604003906
current error:  tensor(0.0046)
epoch:  411	argmax time taken,  0.702542781829834
assembling the matrix time taken:  0.00040078163146972656
solving Ax = b time taken:  0.045198678970336914
		 time taken minimize linear layer:  0.045644283294677734
current error:  tensor(0.0046)
epoch:  412	argmax

epoch:  442	argmax time taken,  0.7035996913909912
assembling the matrix time taken:  0.0003821849822998047
solving Ax = b time taken:  0.049219608306884766
		 time taken minimize linear layer:  0.04964900016784668
current error:  tensor(0.0044)
epoch:  443	argmax time taken,  0.7102220058441162
assembling the matrix time taken:  0.00032901763916015625
solving Ax = b time taken:  0.047466278076171875
		 time taken minimize linear layer:  0.04784202575683594
current error:  tensor(0.0044)
epoch:  444	argmax time taken,  0.7071793079376221
assembling the matrix time taken:  0.0003819465637207031
solving Ax = b time taken:  0.04926943778991699
		 time taken minimize linear layer:  0.04970073699951172
current error:  tensor(0.0044)
epoch:  445	argmax time taken,  0.7102422714233398
assembling the matrix time taken:  0.00035071372985839844
solving Ax = b time taken:  0.04750704765319824
		 time taken minimize linear layer:  0.04790472984313965
current error:  tensor(0.0044)
epoch:  446	argm

epoch:  476	argmax time taken,  0.7046592235565186
assembling the matrix time taken:  0.00024628639221191406
solving Ax = b time taken:  0.05572509765625
		 time taken minimize linear layer:  0.056003570556640625
current error:  tensor(0.0041)
epoch:  477	argmax time taken,  0.705568790435791
assembling the matrix time taken:  0.00024962425231933594
solving Ax = b time taken:  0.05590677261352539
		 time taken minimize linear layer:  0.05618715286254883
current error:  tensor(0.0041)
epoch:  478	argmax time taken,  0.7044093608856201
assembling the matrix time taken:  0.000255584716796875
solving Ax = b time taken:  0.0558931827545166
		 time taken minimize linear layer:  0.056180715560913086
current error:  tensor(0.0041)
epoch:  479	argmax time taken,  0.7043704986572266
assembling the matrix time taken:  0.00025200843811035156
solving Ax = b time taken:  0.055971622467041016
		 time taken minimize linear layer:  0.05625557899475098
current error:  tensor(0.0041)
epoch:  480	argmax t

epoch:  510	argmax time taken,  0.7053430080413818
assembling the matrix time taken:  0.0002455711364746094
solving Ax = b time taken:  0.05845379829406738
		 time taken minimize linear layer:  0.058730125427246094
current error:  tensor(0.0039)
epoch:  511	argmax time taken,  0.7054958343505859
assembling the matrix time taken:  0.0002453327178955078
solving Ax = b time taken:  0.05852556228637695
		 time taken minimize linear layer:  0.05880308151245117
current error:  tensor(0.0039)
epoch:  512	argmax time taken,  0.712437629699707
assembling the matrix time taken:  0.00024390220642089844
solving Ax = b time taken:  0.05817365646362305
		 time taken minimize linear layer:  0.058449745178222656
current error:  tensor(0.0039)
total duration:  389.90699076652527
neuron num 		 error 		 order
4		tensor(0.0764)		*
8		tensor(0.0681)		tensor(0.1651)
16		tensor(0.0559)		tensor(0.2861)
32		tensor(0.0384)		tensor(0.5424)
64		tensor(0.0206)		tensor(0.8942)
128		tensor(0.0118)		tensor(0.8095)
25

In [10]:
def compute_rate(k,d):
    return 0.5 + (2*k +1)/(2*d)

for k in range(1,6): 
    print(k, compute_rate(k,10))

1 0.65
2 0.75
3 0.85
4 0.95
5 1.05


In [33]:
dim = 10 
points = torch.rand(12,dim).to(device)
approx_values = my_model(points)
print(approx_values)


print("ture values")
print(target(points))

tensor([[0.7069],
        [0.5646],
        [0.5421],
        [0.6234],
        [0.7319],
        [0.6844],
        [0.5897],
        [0.7394],
        [0.5824],
        [0.5579],
        [0.6385],
        [0.7090]], device='cuda:0', grad_fn=<MmBackward0>)
ture values
tensor([[0.7066],
        [0.5647],
        [0.5421],
        [0.6238],
        [0.7315],
        [0.6843],
        [0.5901],
        [0.7392],
        [0.5826],
        [0.5581],
        [0.6386],
        [0.7088]], device='cuda:0')


In [None]:
k = 1, N0 = 2**7 
4		tensor(0.0749)		*
8		tensor(0.0689)		tensor(0.1203)
16		tensor(0.0619)		tensor(0.1563)
32		tensor(0.0547)		tensor(0.1775)
64		tensor(0.0445)		tensor(0.2971)
128		tensor(0.0367)		tensor(0.2776)
256		tensor(0.0261)		tensor(0.4910)
512		tensor(0.0182)		tensor(0.5241) 

k = 1, N0 = 2**9, # test_points = M 
4		tensor(0.0750)		*
8		tensor(0.0701)		tensor(0.0981)
16		tensor(0.0599)		tensor(0.2269)
32		tensor(0.0500)		tensor(0.2615)
64		tensor(0.0402)		tensor(0.3141)
128		tensor(0.0279)		tensor(0.5245)
256		tensor(0.0168)		tensor(0.7364)
512		tensor(0.0100)		tensor(0.7507)

k = 1, N0 = 2**9, # test_points = 2 * M 
4		tensor(0.0749)		*
8		tensor(0.0682)		tensor(0.1355)
16		tensor(0.0567)		tensor(0.2666)
32		tensor(0.0491)		tensor(0.2071)
64		tensor(0.0397)		tensor(0.3076)
128		tensor(0.0258)		tensor(0.6189)
256		tensor(0.0153)		tensor(0.7537)
512		tensor(0.0094)		tensor(0.7103)


k = 4, N0 =2**10, test_points = 2*M 
4		tensor(0.1333)		*
8		tensor(0.0888)		tensor(0.5859)
16		tensor(0.0810)		tensor(0.1327)
32		tensor(0.0504)		tensor(0.6835)
64		tensor(0.0204)		tensor(1.3081)
128		tensor(0.0087)		tensor(1.2292)
256		tensor(0.0064)		tensor(0.4499)
512		tensor(0.0034)		tensor(0.9121)
1024		tensor(0.0004)		tensor(3.0958)