In [1]:
# In this version of OGA with random dictionaries, we use QMC to evaluate the loss function. 
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import time
import sys
import os 
from scipy.sparse import linalg
from pathlib import Path
import itertools
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# pi as a float64 0-dim tensor so that products with tensors stay in double precision.
pi = torch.tensor(np.pi, dtype=torch.float64)
# Every tensor created without an explicit dtype from here on is float64.
torch.set_default_dtype(torch.float64)


class model(nn.Module):
    """Shallow ReLU^k network: x -> fc2( relu(fc1(x)) ** k ).

    Parameters:
    input_size: dimension of the input points
    hidden_size1: number of neurons in the single hidden layer
    num_classes: dimension of the output
    k: power applied to the ReLU activation (k = 1 is the plain ReLU)

    The output layer ``fc2`` is created without a bias term.
    """
    def __init__(self, input_size, hidden_size1, num_classes,k = 1):
        super().__init__()
        # Hidden layer: one row of weights and one bias per neuron.
        self.fc1 = nn.Linear(input_size, hidden_size1)
        # Output layer: pure linear combination of the hidden activations.
        self.fc2 = nn.Linear(hidden_size1, num_classes,bias = False)
        self.k = k

    def forward(self, x):
        pre_activation = self.fc1(x)
        hidden = torch.pow(F.relu(pre_activation), self.k)
        return self.fc2(hidden)
    

def show_convergence_order(err_l2,exponent,dict_size, filename,write2file = False):
    """Print (and optionally append to a file) a table of errors and observed orders.

    The error history is sampled at neuron counts 4, 8, ..., 2**exponent and the
    observed order between consecutive rows is log2(previous error / current error).

    Parameters
    ----------
    err_l2 : indexable
        Error history; ``err_l2[n]`` is the error with n neurons. It must be
        indexable up to ``2**exponent``.
    exponent : int
        The last row of the table uses ``2**exponent`` neurons.
    dict_size : int
        Dictionary size, only written to the file header.
    filename : str
        File that the table is appended to when ``write2file`` is true.
    write2file : bool
        When false, nothing is written and ``filename`` is never opened.
    """
    neuron_nums = [2**j for j in range(2,exponent+1)]
    err_list = [err_l2[i] for i in neuron_nums ]
    # Observed order for every row except the first one, which has no predecessor.
    orders = [np.log(err_list[i-1]/err_list[i])/np.log(2) for i in range(1, len(err_list))]

    print("neuron num \t\t error \t\t order")
    for i, item in enumerate(err_list):
        if i == 0:
            print("{} \t\t {:.6f} \t\t *  \n".format(neuron_nums[i],item ) )
        else:
            print("{} \t\t {:.6f} \t\t {:.6f} \n".format(neuron_nums[i],item,orders[i-1] ) )

    if not write2file:
        return

    # Mode "a" creates the file when it does not exist yet, and the context
    # manager closes the handle even if formatting a row raises.
    with open(filename, "a") as f_write:
        f_write.write('dictionary size: {}\n'.format(dict_size))
        f_write.write("neuron num \t\t error \t\t order \t\t h10 error \\ order \n")
        for i, item in enumerate(err_list):
            if i == 0:
                f_write.write("{} \t\t {} \t\t * \t\t \n".format(neuron_nums[i],item ))
            else:
                f_write.write("{} \t\t {} \t\t {} \n".format(neuron_nums[i],item,orders[i-1] ))
        f_write.write("\n")

def show_convergence_order_latex(err_l2,exponent,k=1,d=1): 
    """Print the convergence table as rows of a LaTeX tabular.

    The error history is sampled at neuron counts 4, 8, ..., 2**exponent. The
    header shows the exponent -1/2 - (2k+1)/(2d) computed from ``k`` and ``d``;
    each later row shows the observed order log2(previous error / current error).

    Parameters
    ----------
    err_l2 : indexable
        Error history; ``err_l2[n]`` is the error with n neurons.
    exponent : int
        The last row of the table uses ``2**exponent`` neurons.
    k : int
        ReLU power, used only in the header exponent.
    d : int
        Input dimension, used only in the header exponent.
    """
    neuron_nums = [2**j for j in range(2,exponent+1)]
    err_list = [err_l2[i] for i in neuron_nums ]
    l2_order = -1/2-(2*k + 1)/(2*d)
    print("neuron num  & \t $\\|u-u_n \\|_{{L^2}}$ & \t order $O(n^{{{:.2f}}})$  \\\\ \\hline \\hline ".format(l2_order))
    # Every backslash is written as an explicit "\\" escape: the LaTeX row end
    # "\\ \hline" is "\\\\ \\hline" in source, which avoids invalid-escape warnings.
    for i, item in enumerate(err_list):
        if i == 0: 
            print("{} \t\t & {:.6f} &\t\t *  \\\\ \\hline  \n".format(neuron_nums[i],item) )   
        else: 
            print("{} \t\t &  {:.3e} &  \t\t {:.2f} \\\\ \\hline  \n".format(neuron_nums[i],item,np.log(err_list[i-1]/err_list[i])/np.log(2) ) )


def print_convergence_order(err, neuron_num_exponent): 
    """Print a tab-separated table of errors and observed convergence orders.

    Rows are taken at neuron counts 2**j for j = 2, ..., neuron_num_exponent - 1
    (the upper exponent itself is excluded). The order column is
    log2(previous error / current error); the first row shows "*" instead.
    """
    sizes = [2**power for power in range(2, neuron_num_exponent)]
    sampled = [err[size] for size in sizes]

    print("neuron num \t\t error \t\t order")
    for row, (size, current) in enumerate(zip(sizes, sampled)):
        print(size, end = "\t\t")
        print(current, end = "\t\t")
        if row > 0:
            previous = sampled[row - 1]
            print(np.log(previous/current)/np.log(2))
        else:
            # No predecessor to compare with on the first row.
            print("*")



def generate_relu_dict4D(N_list):
    """Build a deterministic ReLU dictionary for 4-D inputs on a tensor grid.

    A uniform grid on [0,1)^4 with ``N_list[i]`` points along axis i is mapped to
    three angles (phi1, phi2 in [0, pi), phi3 in [0, 2*pi)) and one offset in [-2, 2).
    The angles are turned into a unit vector in R^4 through hyperspherical
    coordinates.

    Returns a tensor of shape (prod(N_list), 5) on ``device``: columns 0-3 hold the
    unit direction and column 4 holds the offset.
    """
    total = np.prod(N_list) 

    # Uniform grid on the unit cube, one row per grid node ('ij' ordering).
    axes = [np.linspace(0,1,count,endpoint=False) for count in N_list]
    mesh = np.meshgrid(*axes,indexing='ij')
    unit_samples = torch.tensor(np.column_stack([component.ravel() for component in mesh]))

    # Stretch the cube to [0,pi) x [0,pi) x [0,2pi) x [-2,2); 2 = sqrt(d) for d = 4.
    scale = torch.tensor([[pi,0,0,0],[0,pi,0,0],[0,0,2*pi,0],[0,0,0,2*2]]) # 2 * sqrt(d)
    offset = torch.tensor([0,0,0,-2])
    params = unit_samples@scale + offset 

    phi1 = params[:,0]
    phi2 = params[:,1]
    phi3 = params[:,2]
    sin1 = torch.sin(phi1)
    sin12 = sin1 * torch.sin(phi2)

    Wb_tensor = torch.zeros(total,5) # N x 5: four direction components + offset
    Wb_tensor[:,0] = torch.cos(phi1)
    Wb_tensor[:,1] = sin1 * torch.cos(phi2)
    Wb_tensor[:,2] = sin12 * torch.cos(phi3)
    Wb_tensor[:,3] = sin12 * torch.sin(phi3)
    Wb_tensor[:,4] = params[:,3]
    return Wb_tensor.to(device)


def generate_relu_dict4D_QMC(s,N0):
    """Draw a random ReLU dictionary with ``s*N0`` elements for 4-D inputs.

    Despite the name, the samples currently come from ``torch.rand`` (plain Monte
    Carlo); the Sobol variant is kept below as a comment. The uniform samples in
    [0,1)^4 are mapped to three angles (phi1, phi2 in [0, pi), phi3 in [0, 2*pi))
    and one offset in [-2, 2); the angles become a unit vector in R^4 through
    hyperspherical coordinates.

    Returns a tensor of shape (s*N0, 5) on ``device``: columns 0-3 hold the unit
    direction and column 4 holds the offset.
    """
    # Sob = torch.quasirandom.SobolEngine(dimension =4, scramble= True, seed=None) 
    # samples = Sob.draw(N0).double() 

    # for i in range(s-1):
    #     samples = torch.cat([samples,Sob.draw(N0).double()],0)

    total = s*N0

    # Monte Carlo 
    unit_samples = torch.rand(total,4) 

    # Stretch the cube to [0,pi) x [0,pi) x [0,2pi) x [-2,2).
    scale = torch.tensor([[pi,0,0,0],[0,pi,0,0],[0,0,2*pi,0],[0,0,0,2*2]])
    offset = torch.tensor([0,0,0,-2])
    params = unit_samples@scale + offset 

    phi1 = params[:,0]
    phi2 = params[:,1]
    phi3 = params[:,2]
    sin1 = torch.sin(phi1)
    sin12 = sin1 * torch.sin(phi2)

    Wb_tensor = torch.zeros(total,5) # N x 5: four direction components + offset
    Wb_tensor[:,0] = torch.cos(phi1)
    Wb_tensor[:,1] = sin1 * torch.cos(phi2)
    Wb_tensor[:,2] = sin12 * torch.cos(phi3)
    Wb_tensor[:,3] = sin12 * torch.sin(phi3)
    Wb_tensor[:,4] = params[:,3]
    return Wb_tensor.to(device)

def adjust_neuron_position(my_model, dims = 3):
    """Re-draw hidden-layer biases whose hyperplane misses the unit cube [0,1]^dims.

    For neuron i with weight row w and bias b, w.x + b changes sign inside the cube
    only when b lies strictly between -max_v(w.v) and -min_v(w.v), where v runs over
    the cube vertices. Biases at or outside a slightly shrunken version of that
    interval are replaced by a uniform random value inside it, except that the first
    ``dims + 1`` neurons found at the upper end are left untouched.

    Parameters:
    my_model: network with an ``fc1`` linear layer; its bias is modified in place
    dims: input dimension used to build the cube vertices (default 3;
          NOTE(review): other functions in this file work in 4-D - confirm callers pass dims)

    Returns the same ``my_model`` object.
    """

    def create_mesh_grid(dims, pts):
        # All len(pts)**dims combinations of the 1-D points, one combination per row.
        mesh = torch.tensor(list(itertools.product(pts,repeat=dims)))
        vertices = mesh.reshape(len(pts) ** dims, -1) 
        return vertices
    # Number of neurons at the upper end that have been kept so far.
    counter = 0 
    # positions = torch.tensor([[0.,0.],[0.,1.],[1.,1.],[1.,0.]])
    pts = torch.tensor([0.,1.])
    # Vertices of the unit cube [0,1]^dims.
    positions = create_mesh_grid(dims,pts) 
    neuron_num = my_model.fc1.bias.size(0)
    for i in range(neuron_num): 
        w = my_model.fc1.weight.data[i:i+1,:]
        b = my_model.fc1.bias.data[i]
    #     print(w,b)
        # w.v at every cube vertex (the bias is deliberately not added here).
        values = torch.matmul(positions,w.T) # + b
        # For b <= left_end, w.v + b <= 0 at every vertex; for b >= right_end it is >= 0.
        left_end = - torch.max(values)
        right_end = - torch.min(values)
        # Safety margin: 1/50 of the admissible bias range, half applied at each end.
        offset = (right_end - left_end)/50
        if b <= left_end + offset/2 : 
            # Too close to (or beyond) the lower end: resample uniformly inside the shrunken interval.
            b = torch.rand(1)*(right_end - left_end - offset) + left_end + offset/2 
            my_model.fc1.bias.data[i] = b 
        # This check sees the freshly resampled b when the branch above fired.
        if b >= right_end - offset/2 :
            if counter < (dims+1):
#                 print("here")
                # Keep the first dims+1 such neurons unchanged
                # (presumably enough to represent affine functions - TODO confirm).
                counter += 1
            else: # (d + 1) or more 
                b = torch.rand(1)*(right_end - left_end - offset) + left_end + offset/2 
                my_model.fc1.bias.data[i] = b 
    return my_model

def MonteCarlo_Sobol_dDim_weights_points(M ,d = 4):
    """Return equal quadrature weights and M unscrambled Sobol points in [0,1)^d.

    Returns
    -------
    weights : tensor of shape (M, 1) on ``device``, every entry equal to 1/M
    integration_points : float64 tensor of shape (M, d) on ``device``
    """
    engine = torch.quasirandom.SobolEngine(dimension =d, scramble= False, seed=None) 
    sobol_points = engine.draw(M).double().to(device)
    # Quasi-Monte Carlo rule: every point carries the same weight 1/M.
    equal_weights = torch.ones(M,1).to(device)/M 
    return equal_weights, sobol_points


def minimize_linear_layer_explicit_assemble(model,target,weights, integration_points,solver="direct"):
    """Solve the least-squares problem for the output layer via the normal equations.

    With basis functions phi_j(x) = relu(w_j . x + b_j) ** k taken from ``model.fc1``,
    the Gram matrix G_ij = sum_m weights_m phi_i(x_m) phi_j(x_m) and the right-hand
    side r_i = sum_m weights_m phi_i(x_m) target(x_m) are assembled explicitly and
    G c = r is solved for the output coefficients c.

    Parameters
    ----------
    model : object with ``fc1.weight.data``, ``fc1.bias.data`` and ``k``
    target : callable
        Evaluated on ``integration_points``; assumed to return one column of values.
    weights : tensor
        Quadrature weights, broadcast against the (points, neurons) basis matrix.
    integration_points : tensor of shape (points, input dimension)
    solver : {"direct", "cg", "ls"}
        "direct" uses ``torch.linalg.solve``; "cg" uses SciPy conjugate gradients on
        the CPU; "ls" uses ``torch.linalg.lstsq`` with the 'gelsd' driver on the CPU
        (that driver also handles rank-deficient systems).

    Returns
    -------
    tensor of shape (1, number of neurons) on the device of the Gram matrix.

    Raises
    ------
    ValueError
        If ``solver`` is not one of the supported names.
    """
    if solver not in ("direct", "cg", "ls"):
        # Previously an unknown name fell through and failed with UnboundLocalError.
        raise ValueError("unknown linear solver: {!r} (expected 'direct', 'cg' or 'ls')".format(solver))

    start_time = time.time() 
    w = model.fc1.weight.data 
    b = model.fc1.bias.data 
    basis_value_col = F.relu(integration_points @ w.t()+ b)**(model.k) 
    weighted_basis_value_col = basis_value_col * weights 
    jac = weighted_basis_value_col.t() @ basis_value_col 

    rhs = weighted_basis_value_col.t() @ (target(integration_points)) 
    print("assembling the matrix time taken: ", time.time()-start_time) 
    start_time = time.time()    
    if solver == "cg": 
        jac_np = jac.detach().cpu().numpy()
        rhs_np = rhs.detach().cpu().numpy().reshape(-1)
        try:
            # Newer SciPy releases name the relative tolerance `rtol`.
            sol, exit_code = linalg.cg(jac_np, rhs_np, rtol=1e-12)
        except TypeError:
            # Older SciPy releases only know `tol`.
            sol, exit_code = linalg.cg(jac_np, rhs_np, tol=1e-12)
        if exit_code != 0:
            print("warning: cg did not converge, exit code: ", exit_code)
        sol = torch.tensor(sol).view(1,-1).to(jac.device)
    elif solver == "direct": 
        sol = (torch.linalg.solve( jac.detach(), rhs.detach())).view(1,-1)
    else:  # solver == "ls"
        # 'gelsd' is only available on the CPU, so solve there and move the result back.
        sol = (torch.linalg.lstsq(jac.detach().cpu(),rhs.detach().cpu(),driver='gelsd').solution).view(1,-1)
        sol = sol.to(jac.device)
    print("solving Ax = b time taken: ", time.time()-start_time)
    return sol 

def OGAL2FittingReLU4D(my_model,target,N_list,num_epochs, M, k =1, linear_solver = "direct",num_batches = 1): 
    
    """ Orthogonal greedy algorithm for L2 fitting on [0,1)^4 with a fixed ReLU^k dictionary

    Each epoch selects the dictionary element with the largest absolute discrete
    inner product against the current residual, appends it as a new neuron, and
    re-solves the least-squares problem for all output coefficients.

    Parameters
    ----------
    my_model: 
        nn model to continue from, or None to start with zero neurons 
    target: 
        target function, evaluated on the (M, 4) integration points 
    N_list: list of 4 ints 
        grid resolution of the deterministic dictionary (see generate_relu_dict4D) 
    num_epochs: int 
        number of greedy steps; one neuron is added per step 
    M: int 
        number of quasi-Monte Carlo integration points 
    k: int 
        power of the ReLU activation 
    linear_solver: str 
        solver name passed to minimize_linear_layer_explicit_assemble 
    num_batches: int 
        the dictionary is scanned in about this many chunks to limit memory use 
    Returns
    -------
    err: tensor 
        rank 1 torch tensor to record the L2 error history (entry 0 is the initial error) 
    model: 
        trained nn model 
    """

    # Named quad_weights so the quadrature weights cannot be mixed up with neuron weights.
    quad_weights, integration_points = MonteCarlo_Sobol_dDim_weights_points(M ,d = 4) 

    err = torch.zeros(num_epochs+1)
    if my_model is None: 
        func_values = target(integration_points)
        num_neuron = 0

        list_b = []
        list_w = []

    else: 
        func_values = target(integration_points) - my_model(integration_points).detach()
        bias = my_model.fc1.bias.detach().data
        # The original stored this in `weights`, silently replacing the quadrature
        # weights that are later needed to assemble the linear system.
        neuron_weights = my_model.fc1.weight.detach().data
        num_neuron = int(bias.size(0))

        list_b = list(bias)
        list_w = list(neuron_weights)
    
    # initial error
    func_values_sqrd = func_values*func_values

    err[0]= torch.mean(func_values_sqrd)**0.5
    all_start_time = time.time()
    
    solver = linear_solver
    print("using linear solver: ",solver)
    N = np.prod(N_list) 
    # Rows 0-3: direction w, row 4: offset t; the dictionary element is relu(w.x - t)**k.
    relu_dict_parameters = generate_relu_dict4D(N_list).t() 
    # At least 1, so the range step below is never zero when num_batches > N.
    batch_size = max(1, N//num_batches) 
    
    for i in range(num_epochs): 
  
        print("epoch: ",i+1, end = '\t')
        if num_neuron == 0: 
            func_values = target(integration_points)
        else: 
            with torch.no_grad(): 
                func_values = target(integration_points) - my_model(integration_points)

        start_time = time.time() 

        # Score buffer lives next to the data to avoid a host copy per batch.
        output = torch.zeros(N,1, device=integration_points.device)
        
        for j in range(0,N,batch_size): 

            end_index = j + batch_size  
            basis_values_batch = (F.relu( torch.matmul(integration_points,relu_dict_parameters[0:4, j:end_index] ) - relu_dict_parameters[4, j:end_index])**k).T # uses broadcasting    
            output[j:end_index,0]  = (torch.abs(torch.matmul(basis_values_batch,func_values))/M)[:,0]
            
        neuron_index = torch.argmax(output.flatten())
        
        print("argmax time taken, ", time.time() - start_time)
        
        list_w.append(relu_dict_parameters[0:4, neuron_index]) 
        # The dictionary stores the offset t of relu(w.x - t); the layer bias is -t.
        list_b.append(-relu_dict_parameters[4,neuron_index])
        num_neuron += 1
        my_model = model(4,num_neuron,1,k).to(device)
        w_tensor = torch.stack(list_w, 0 ) 
        b_tensor = torch.tensor(list_b)
        my_model.fc1.weight.data[:,:] = w_tensor[:,:]
        my_model.fc1.bias.data[:] = b_tensor[:]

        # Orthogonal projection step: refit all output coefficients.
        start_time = time.time() 
        sol = minimize_linear_layer_explicit_assemble(my_model,target,quad_weights,integration_points, solver)
        print("\t\t time taken minimize linear layer: ",time.time() - start_time) 
        my_model.fc2.weight.data[0,:] = sol[:]

        func_values = target(integration_points) - my_model(integration_points).detach()
        func_values_sqrd = func_values*func_values

        err[i+1]= torch.mean(func_values_sqrd)**0.5
        print("current error: ",err[i+1]) 
    print("time taken: ",time.time() - all_start_time)
    return err, my_model


def OGAL2FittingReLU4D_QMC(my_model,target,s,N0,num_epochs, M, k =1, linear_solver = "direct", num_batches = 1): 
    
    """ Orthogonal greedy algorithm for L2 fitting on [0,1)^4 with a random ReLU^k dictionary

    A fresh dictionary of s*N0 elements is drawn in every epoch. The element with
    the largest absolute discrete inner product against the current residual is
    appended as a new neuron, and the least-squares problem for all output
    coefficients is re-solved.

    Parameters
    ----------
    my_model: 
        nn model to continue from, or None to start with zero neurons 
    target: 
        target function, evaluated on the (M, 4) integration points 
    s: int 
        multiplier for the dictionary size 
    N0: int 
        base dictionary size; each epoch draws s*N0 elements 
    num_epochs: int 
        number of greedy steps; one neuron is added per step 
    M: int 
        number of quasi-Monte Carlo integration points 
    k: int 
        power of the ReLU activation 
    linear_solver: str 
        solver name passed to minimize_linear_layer_explicit_assemble 
    num_batches: int 
        the dictionary is scanned in about this many chunks to limit memory use 

    Returns
    -------
    err: tensor 
        rank 1 torch tensor to record the L2 error history (entry 0 is the initial error) 
    model: 
        trained nn model 
    """
    # samples for QMC integral
    start_time = time.time()
    # Named quad_weights so the quadrature weights cannot be mixed up with neuron weights.
    quad_weights, integration_points = MonteCarlo_Sobol_dDim_weights_points(M ,d = 4) 
    print("generate sob sequence:", time.time() - start_time) 

    err = torch.zeros(num_epochs+1)
    if my_model is None: 
        func_values = target(integration_points)
        num_neuron = 0

        list_b = []
        list_w = []

    else: 
        func_values = target(integration_points) - my_model(integration_points).detach()
        bias = my_model.fc1.bias.detach().data
        # The original stored this in `weights`, silently replacing the quadrature
        # weights that are later needed to assemble the linear system.
        neuron_weights = my_model.fc1.weight.detach().data
        num_neuron = int(bias.size(0))

        list_b = list(bias)
        list_w = list(neuron_weights)
    
    # initial error
    func_values_sqrd = func_values*func_values

    err[0]= torch.mean(func_values_sqrd)**0.5
    all_start_time = time.time()
    
    solver = linear_solver
    print("using linear solver: ",solver)

    N = s*N0 
    # At least 1, so the range step below is never zero when num_batches > N.
    batch_size = max(1, N//num_batches)

    for i in range(num_epochs): 
        # New random dictionary every epoch. Rows 0-3: direction w, row 4: offset t;
        # the dictionary element is relu(w.x - t)**k.
        relu_dict_parameters = generate_relu_dict4D_QMC(s,N0).t()   
        print("epoch: ",i+1, end = '\t')
        if num_neuron == 0: 
            func_values = target(integration_points)
        else: 
            func_values = target(integration_points) - my_model(integration_points).detach()

        start_time = time.time() 
        
        # Score buffer lives next to the data to avoid a host copy per batch.
        output = torch.zeros(N,1, device=integration_points.device)
        
        for j in range(0,N,batch_size): 

            end_index = j + batch_size  
            basis_values_batch = (F.relu( torch.matmul(integration_points,relu_dict_parameters[0:4, j:end_index] ) - relu_dict_parameters[4, j:end_index])**k).T # uses broadcasting    
            output[j:end_index,0]  = (torch.abs(torch.matmul(basis_values_batch,func_values))/M)[:,0]
            
        neuron_index = torch.argmax(output.flatten())
        
        print("argmax time taken, ", time.time() - start_time)
        
        list_w.append(relu_dict_parameters[0:4, neuron_index]) 
        # The dictionary stores the offset t of relu(w.x - t); the layer bias is -t.
        list_b.append(-relu_dict_parameters[4,neuron_index])
        num_neuron += 1
        my_model = model(4,num_neuron,1,k).to(device)
        w_tensor = torch.stack(list_w, 0 ) 
        b_tensor = torch.tensor(list_b)
        my_model.fc1.weight.data[:,:] = w_tensor[:,:]
        my_model.fc1.bias.data[:] = b_tensor[:]

        # Orthogonal projection step: refit all output coefficients.
        start_time = time.time() 
        sol = minimize_linear_layer_explicit_assemble(my_model,target,quad_weights,integration_points, solver)
        print("\t\t time taken minimize linear layer: ",time.time() - start_time) 
        my_model.fc2.weight.data[0,:] = sol[:]

        func_values = target(integration_points) - my_model(integration_points).detach()
        func_values_sqrd = func_values*func_values

        err[i+1]= torch.mean(func_values_sqrd)**0.5
        print("current error: ",err[i+1]) 
    print("total duration: ",time.time() - all_start_time)
    return err, my_model


## Deterministic dictionary 

In [4]:

def target(x):
    """Product of sin(pi * x_j) over the first four columns of x, returned as one column."""
    value = torch.sin(pi*x[:,0:1])
    # Multiply the remaining coordinates in the same left-to-right order.
    for col in range(1, 4):
        value = value * torch.sin(pi*x[:,col:col+1])
    return value

# Experiment: OGA with the deterministic (grid) dictionary on the 4-D sine product.
dim = 4 
function_name = "sin-product-4d" 
folder = 'data/'
# Create the output folder up front; opening the log file fails if it is missing.
os.makedirs(folder, exist_ok=True)
filename_write = "data/2DOGA-{}-order.txt".format(function_name)
M = 2**19 # number of quasi-Monte Carlo integration points (about 5.2e5)
with open(filename_write, "a") as f_write:
    f_write.write("Integration points: Quasi Monte Carlo:  {}\n".format(M))
save = True 
write2file = True

for relu_k in [4]: 
    for N_list in [[2**4, 2**4, 2**3, 2**3]]: 
        N = np.prod(N_list)  # dictionary size
        print()
        print() 
        exponent = 9  
        num_epochs=  2**exponent  # one neuron is added per epoch
    
        my_model = None 
        err, my_model = OGAL2FittingReLU4D(my_model,target,N_list,num_epochs, M, k = relu_k, linear_solver = "ls", num_batches= 2**3)
        
        if save: 
            filename = folder + function_name + "_err_deterministic_Dict_relu_{}_size_{}_num_neurons_{}.pt".format(relu_k,N,num_epochs)
            torch.save(err,filename)
            filename = folder + function_name + "_model_deterministic_Dict_relu_{}_size_{}_num_neurons_{}.pt".format(relu_k,N,num_epochs)
            torch.save(my_model.state_dict(),filename) 
            
        show_convergence_order(err,exponent,N,filename_write,write2file = write2file)
        show_convergence_order_latex(err,exponent,k=relu_k,d=dim)






using linear solver:  ls
epoch:  1	argmax time taken,  0.43138813972473145
assembling the matrix time taken:  0.0002498626708984375
solving Ax = b time taken:  0.00015544891357421875
		 time taken minimize linear layer:  0.0004353523254394531
current error:  tensor(0.2000)
epoch:  2	argmax time taken,  0.41163158416748047
assembling the matrix time taken:  0.0002880096435546875
solving Ax = b time taken:  0.00026154518127441406
		 time taken minimize linear layer:  0.0005805492401123047
current error:  tensor(0.1988)
epoch:  3	argmax time taken,  0.4120492935180664
assembling the matrix time taken:  0.0002505779266357422
solving Ax = b time taken:  0.00029468536376953125
		 time taken minimize linear layer:  0.0005764961242675781
current error:  tensor(0.1954)
epoch:  4	argmax time taken,  0.41201114654541016
assembling the matrix time taken:  0.0002665519714355469
solving Ax = b time taken:  0.0003178119659423828
		 time taken minimize linear layer:  0.0006160736083984375
current er

epoch:  34	argmax time taken,  0.41401195526123047
assembling the matrix time taken:  0.00027680397033691406
solving Ax = b time taken:  0.0015137195587158203
		 time taken minimize linear layer:  0.0018229484558105469
current error:  tensor(0.0972)
epoch:  35	argmax time taken,  0.4142005443572998
assembling the matrix time taken:  0.0002543926239013672
solving Ax = b time taken:  0.001592397689819336
		 time taken minimize linear layer:  0.0018787384033203125
current error:  tensor(0.0925)
epoch:  36	argmax time taken,  0.4140441417694092
assembling the matrix time taken:  0.00026226043701171875
solving Ax = b time taken:  0.0015718936920166016
		 time taken minimize linear layer:  0.0018663406372070312
current error:  tensor(0.0885)
epoch:  37	argmax time taken,  0.4140913486480713
assembling the matrix time taken:  0.0002529621124267578
solving Ax = b time taken:  0.001638650894165039
		 time taken minimize linear layer:  0.0019235610961914062
current error:  tensor(0.0849)
epoch: 

epoch:  68	argmax time taken,  0.4148569107055664
assembling the matrix time taken:  0.0002636909484863281
solving Ax = b time taken:  0.003436565399169922
		 time taken minimize linear layer:  0.0037326812744140625
current error:  tensor(0.0377)
epoch:  69	argmax time taken,  0.41465282440185547
assembling the matrix time taken:  0.0002589225769042969
solving Ax = b time taken:  0.0035605430603027344
		 time taken minimize linear layer:  0.003851652145385742
current error:  tensor(0.0374)
epoch:  70	argmax time taken,  0.4147167205810547
assembling the matrix time taken:  0.00026226043701171875
solving Ax = b time taken:  0.003690958023071289
		 time taken minimize linear layer:  0.003985404968261719
current error:  tensor(0.0370)
epoch:  71	argmax time taken,  0.41517210006713867
assembling the matrix time taken:  0.0002498626708984375
solving Ax = b time taken:  0.0036847591400146484
		 time taken minimize linear layer:  0.003967761993408203
current error:  tensor(0.0369)
epoch:  72

epoch:  102	argmax time taken,  0.41730237007141113
assembling the matrix time taken:  0.00026488304138183594
solving Ax = b time taken:  0.005480766296386719
		 time taken minimize linear layer:  0.0057790279388427734
current error:  tensor(0.0167)
epoch:  103	argmax time taken,  0.4156060218811035
assembling the matrix time taken:  0.0002613067626953125
solving Ax = b time taken:  0.005346775054931641
		 time taken minimize linear layer:  0.005640745162963867
current error:  tensor(0.0165)
epoch:  104	argmax time taken,  0.4154651165008545
assembling the matrix time taken:  0.00026345252990722656
solving Ax = b time taken:  0.005247592926025391
		 time taken minimize linear layer:  0.005543708801269531
current error:  tensor(0.0163)
epoch:  105	argmax time taken,  0.41558170318603516
assembling the matrix time taken:  0.0002510547637939453
solving Ax = b time taken:  0.005369901657104492
		 time taken minimize linear layer:  0.0056536197662353516
current error:  tensor(0.0161)
epoch:

epoch:  135	argmax time taken,  0.41641688346862793
assembling the matrix time taken:  0.00025582313537597656
solving Ax = b time taken:  0.00821685791015625
		 time taken minimize linear layer:  0.008506536483764648
current error:  tensor(0.0105)
epoch:  136	argmax time taken,  0.4164714813232422
assembling the matrix time taken:  0.00026297569274902344
solving Ax = b time taken:  0.008295774459838867
		 time taken minimize linear layer:  0.008591175079345703
current error:  tensor(0.0105)
epoch:  137	argmax time taken,  0.416400671005249
assembling the matrix time taken:  0.0002484321594238281
solving Ax = b time taken:  0.008372068405151367
		 time taken minimize linear layer:  0.008653879165649414
current error:  tensor(0.0103)
epoch:  138	argmax time taken,  0.4167788028717041
assembling the matrix time taken:  0.0002589225769042969
solving Ax = b time taken:  0.008634805679321289
		 time taken minimize linear layer:  0.008926153182983398
current error:  tensor(0.0103)
epoch:  139

epoch:  169	argmax time taken,  0.41843414306640625
assembling the matrix time taken:  0.00025463104248046875
solving Ax = b time taken:  0.010219335556030273
		 time taken minimize linear layer:  0.01050710678100586
current error:  tensor(0.0078)
epoch:  170	argmax time taken,  0.418506383895874
assembling the matrix time taken:  0.00024700164794921875
solving Ax = b time taken:  0.010154247283935547
		 time taken minimize linear layer:  0.010434627532958984
current error:  tensor(0.0078)
epoch:  171	argmax time taken,  0.41843271255493164
assembling the matrix time taken:  0.00025343894958496094
solving Ax = b time taken:  0.010329723358154297
		 time taken minimize linear layer:  0.01061558723449707
current error:  tensor(0.0076)
epoch:  172	argmax time taken,  0.4186134338378906
assembling the matrix time taken:  0.00024700164794921875
solving Ax = b time taken:  0.010155677795410156
		 time taken minimize linear layer:  0.010436058044433594
current error:  tensor(0.0075)
epoch:  1

epoch:  203	argmax time taken,  0.4189717769622803
assembling the matrix time taken:  0.0002472400665283203
solving Ax = b time taken:  0.016204357147216797
		 time taken minimize linear layer:  0.016484498977661133
current error:  tensor(0.0064)
epoch:  204	argmax time taken,  0.41902995109558105
assembling the matrix time taken:  0.0002658367156982422
solving Ax = b time taken:  0.013472318649291992
		 time taken minimize linear layer:  0.013770818710327148
current error:  tensor(0.0063)
epoch:  205	argmax time taken,  0.41901111602783203
assembling the matrix time taken:  0.0002467632293701172
solving Ax = b time taken:  0.01633143424987793
		 time taken minimize linear layer:  0.01661062240600586
current error:  tensor(0.0063)
epoch:  206	argmax time taken,  0.4190068244934082
assembling the matrix time taken:  0.00025534629821777344
solving Ax = b time taken:  0.013707399368286133
		 time taken minimize linear layer:  0.013996124267578125
current error:  tensor(0.0063)
epoch:  207

epoch:  237	argmax time taken,  0.4196019172668457
assembling the matrix time taken:  0.00024771690368652344
solving Ax = b time taken:  0.01693439483642578
		 time taken minimize linear layer:  0.01721501350402832
current error:  tensor(0.0056)
epoch:  238	argmax time taken,  0.41968822479248047
assembling the matrix time taken:  0.000244140625
solving Ax = b time taken:  0.0168459415435791
		 time taken minimize linear layer:  0.01712322235107422
current error:  tensor(0.0056)
epoch:  239	argmax time taken,  0.41960930824279785
assembling the matrix time taken:  0.00024819374084472656
solving Ax = b time taken:  0.017090797424316406
		 time taken minimize linear layer:  0.01737213134765625
current error:  tensor(0.0056)
epoch:  240	argmax time taken,  0.4197089672088623
assembling the matrix time taken:  0.00024437904357910156
solving Ax = b time taken:  0.016715526580810547
		 time taken minimize linear layer:  0.016992807388305664
current error:  tensor(0.0056)
epoch:  241	argmax t

epoch:  271	argmax time taken,  0.42024707794189453
assembling the matrix time taken:  0.00024271011352539062
solving Ax = b time taken:  0.0252230167388916
		 time taken minimize linear layer:  0.02549886703491211
current error:  tensor(0.0053)
epoch:  272	argmax time taken,  0.4202232360839844
assembling the matrix time taken:  0.0002636909484863281
solving Ax = b time taken:  0.0202634334564209
		 time taken minimize linear layer:  0.020560026168823242
current error:  tensor(0.0053)
epoch:  273	argmax time taken,  0.4204127788543701
assembling the matrix time taken:  0.000247955322265625
solving Ax = b time taken:  0.025467872619628906
		 time taken minimize linear layer:  0.02574920654296875
current error:  tensor(0.0059)
epoch:  274	argmax time taken,  0.4202735424041748
assembling the matrix time taken:  0.000255584716796875
solving Ax = b time taken:  0.02083587646484375
		 time taken minimize linear layer:  0.02112412452697754
current error:  tensor(0.0057)
epoch:  275	argmax t

epoch:  305	argmax time taken,  0.42093420028686523
assembling the matrix time taken:  0.00025153160095214844
solving Ax = b time taken:  0.023869991302490234
		 time taken minimize linear layer:  0.024155139923095703
current error:  tensor(0.0054)
epoch:  306	argmax time taken,  0.4209597110748291
assembling the matrix time taken:  0.00024318695068359375
solving Ax = b time taken:  0.023818492889404297
		 time taken minimize linear layer:  0.02409505844116211
current error:  tensor(0.0053)
epoch:  307	argmax time taken,  0.4210953712463379
assembling the matrix time taken:  0.00024580955505371094
solving Ax = b time taken:  0.024198532104492188
		 time taken minimize linear layer:  0.024477005004882812
current error:  tensor(0.0055)
epoch:  308	argmax time taken,  0.4210076332092285
assembling the matrix time taken:  0.0002472400665283203
solving Ax = b time taken:  0.023749589920043945
		 time taken minimize linear layer:  0.024030208587646484
current error:  tensor(0.0053)
epoch:  3

epoch:  339	argmax time taken,  0.4215412139892578
assembling the matrix time taken:  0.0002498626708984375
solving Ax = b time taken:  0.03057384490966797
		 time taken minimize linear layer:  0.030857324600219727
current error:  tensor(0.0053)
epoch:  340	argmax time taken,  0.42157983779907227
assembling the matrix time taken:  0.0002579689025878906
solving Ax = b time taken:  0.02802562713623047
		 time taken minimize linear layer:  0.028316736221313477
current error:  tensor(0.0054)
epoch:  341	argmax time taken,  0.4216039180755615
assembling the matrix time taken:  0.0002574920654296875
solving Ax = b time taken:  0.030946731567382812
		 time taken minimize linear layer:  0.03123760223388672
current error:  tensor(0.0054)
epoch:  342	argmax time taken,  0.4217503070831299
assembling the matrix time taken:  0.0002582073211669922
solving Ax = b time taken:  0.028989791870117188
		 time taken minimize linear layer:  0.0292818546295166
current error:  tensor(0.0054)
epoch:  343	argm

epoch:  373	argmax time taken,  0.42214107513427734
assembling the matrix time taken:  0.000247955322265625
solving Ax = b time taken:  0.03377079963684082
		 time taken minimize linear layer:  0.034050941467285156
current error:  tensor(0.0050)
epoch:  374	argmax time taken,  0.42214345932006836
assembling the matrix time taken:  0.0002536773681640625
solving Ax = b time taken:  0.03343605995178223
		 time taken minimize linear layer:  0.0337224006652832
current error:  tensor(0.0050)
epoch:  375	argmax time taken,  0.42226696014404297
assembling the matrix time taken:  0.00025391578674316406
solving Ax = b time taken:  0.033879995346069336
		 time taken minimize linear layer:  0.03416752815246582
current error:  tensor(0.0049)
epoch:  376	argmax time taken,  0.42233967781066895
assembling the matrix time taken:  0.0002582073211669922
solving Ax = b time taken:  0.03339505195617676
		 time taken minimize linear layer:  0.033685922622680664
current error:  tensor(0.0050)
epoch:  377	ar

epoch:  407	argmax time taken,  0.4227898120880127
assembling the matrix time taken:  0.0002491474151611328
solving Ax = b time taken:  0.03841757774353027
		 time taken minimize linear layer:  0.038701772689819336
current error:  tensor(0.0053)
epoch:  408	argmax time taken,  0.42273950576782227
assembling the matrix time taken:  0.00025582313537597656
solving Ax = b time taken:  0.03903770446777344
		 time taken minimize linear layer:  0.03932666778564453
current error:  tensor(0.0051)
epoch:  409	argmax time taken,  0.4228515625
assembling the matrix time taken:  0.0002465248107910156
solving Ax = b time taken:  0.03870964050292969
		 time taken minimize linear layer:  0.03898882865905762
current error:  tensor(0.0053)
epoch:  410	argmax time taken,  0.4228842258453369
assembling the matrix time taken:  0.00025081634521484375
solving Ax = b time taken:  0.03845334053039551
		 time taken minimize linear layer:  0.03873729705810547
current error:  tensor(0.0050)
epoch:  411	argmax tim

epoch:  441	argmax time taken,  0.42343664169311523
assembling the matrix time taken:  0.0002465248107910156
solving Ax = b time taken:  0.042064666748046875
		 time taken minimize linear layer:  0.04234504699707031
current error:  tensor(0.0054)
epoch:  442	argmax time taken,  0.4235506057739258
assembling the matrix time taken:  0.00025010108947753906
solving Ax = b time taken:  0.04268074035644531
		 time taken minimize linear layer:  0.04296422004699707
current error:  tensor(0.0047)
epoch:  443	argmax time taken,  0.4234731197357178
assembling the matrix time taken:  0.00025177001953125
solving Ax = b time taken:  0.04217791557312012
		 time taken minimize linear layer:  0.042463064193725586
current error:  tensor(0.0047)
epoch:  444	argmax time taken,  0.4235708713531494
assembling the matrix time taken:  0.0002560615539550781
solving Ax = b time taken:  0.04253435134887695
		 time taken minimize linear layer:  0.04282331466674805
current error:  tensor(0.0048)
epoch:  445	argmax

epoch:  475	argmax time taken,  0.4276094436645508
assembling the matrix time taken:  0.0002498626708984375
solving Ax = b time taken:  0.049265384674072266
		 time taken minimize linear layer:  0.04954886436462402
current error:  tensor(0.0048)
epoch:  476	argmax time taken,  0.42748188972473145
assembling the matrix time taken:  0.00026035308837890625
solving Ax = b time taken:  0.04866170883178711
		 time taken minimize linear layer:  0.04895496368408203
current error:  tensor(0.0049)
epoch:  477	argmax time taken,  0.4258441925048828
assembling the matrix time taken:  0.00024890899658203125
solving Ax = b time taken:  0.04941272735595703
		 time taken minimize linear layer:  0.04969477653503418
current error:  tensor(0.0053)
epoch:  478	argmax time taken,  0.4274611473083496
assembling the matrix time taken:  0.0002598762512207031
solving Ax = b time taken:  0.04929208755493164
		 time taken minimize linear layer:  0.04958486557006836
current error:  tensor(0.0048)
epoch:  479	argm

epoch:  509	argmax time taken,  0.4246494770050049
assembling the matrix time taken:  0.00024819374084472656
solving Ax = b time taken:  0.05387544631958008
		 time taken minimize linear layer:  0.05415606498718262
current error:  tensor(0.0047)
epoch:  510	argmax time taken,  0.4280703067779541
assembling the matrix time taken:  0.000247955322265625
solving Ax = b time taken:  0.05375933647155762
		 time taken minimize linear layer:  0.054039955139160156
current error:  tensor(0.0055)
epoch:  511	argmax time taken,  0.4248237609863281
assembling the matrix time taken:  0.00024628639221191406
solving Ax = b time taken:  0.05480837821960449
		 time taken minimize linear layer:  0.05508732795715332
current error:  tensor(0.0047)
epoch:  512	argmax time taken,  0.42475390434265137
assembling the matrix time taken:  0.00025010108947753906
solving Ax = b time taken:  0.06059622764587402
		 time taken minimize linear layer:  0.06087994575500488
current error:  tensor(0.0048)
time taken:  232

## Random dictionary 

In [6]:


def target(x):
    """Evaluate sin(pi*x_0) * sin(pi*x_1) * sin(pi*x_2) * sin(pi*x_3).

    Parameters:
    x: tensor indexed as x[:, j]; only columns 0..3 are read
       (presumably one sample point per row -- confirm with the caller)

    Returns:
    tensor with a single column holding the product of the four sine factors
    """
    # One single-column factor per coordinate; the j:j+1 slice keeps the column axis.
    factors = [torch.sin(pi * x[:, j:j + 1]) for j in range(4)]
    # Multiply the factors left to right.
    product = factors[0]
    for factor in factors[1:]:
        product = product * factor
    return product

# Experiment driver: fit the 4D sine-product target with a ReLU^k shallow
# network using random dictionaries of increasing size, then save the error
# history and the model and report the observed convergence orders.
dim = 4
function_name = "sin-product-4d"
# NOTE(review): the "2D" prefix looks like a leftover from a lower-dimensional
# experiment; the path is kept unchanged so existing result files are still
# appended to.
filename_write = "data/2DQMCOGA-{}-order.txt".format(function_name)
M = 2**19  # number of quasi-Monte Carlo integration points (524,288, roughly 5e5)
# Make sure the output directory exists before anything is written into it.
os.makedirs("data", exist_ok=True)
# Context manager guarantees the log file is closed even if the write fails.
with open(filename_write, "a") as f_write:
    f_write.write("Integration points: Quasi Monte Carlo:  {}\n".format(M))
save = True
write2file = True

for relu_k in [1]:
    s = 1
    for N0 in [2**6, 2**7, 2**8]:
        # Dictionary size reported in the convergence table. It must match the
        # size used for this run (the same s * N0 that goes into the saved
        # file names), not a variable left over from another experiment.
        N = s * N0
        print()
        print()
        exponent = 9
        num_epochs = 2**exponent  # passed to the solver and used as "num_neurons" in file names

        # Start every dictionary size from scratch (no warm-start model).
        my_model = None
        err, my_model = OGAL2FittingReLU4D_QMC(
            my_model, target, s, N0, num_epochs, M,
            k=relu_k, linear_solver="direct", num_batches=1)

        if save:
            folder = 'data/'
            filename = folder + function_name + "_err_randDict_relu_{}_size_{}_num_neurons_{}.pt".format(relu_k, s * N0, num_epochs)
            torch.save(err, filename)
            filename = folder + function_name + "_model_randDict_relu_{}_size_{}_num_neurons_{}.pt".format(relu_k, s * N0, num_epochs)
            torch.save(my_model.state_dict(), filename)

        # show_convergence_order reads err at indices 2**j for j = 2..exponent.
        show_convergence_order(err, exponent, N, filename_write, write2file=write2file)
        show_convergence_order_latex(err, exponent, k=relu_k, d=dim)




generate sob sequence: 0.005876302719116211
using linear solver:  direct
epoch:  1	argmax time taken,  0.0014488697052001953
assembling the matrix time taken:  0.00025653839111328125
solving Ax = b time taken:  3.698288679122925
		 time taken minimize linear layer:  3.698791742324829
current error:  tensor(0.1894)
epoch:  2	argmax time taken,  0.001430511474609375
assembling the matrix time taken:  0.0002942085266113281
solving Ax = b time taken:  0.0002727508544921875
		 time taken minimize linear layer:  0.0005979537963867188
current error:  tensor(0.1856)
epoch:  3	argmax time taken,  0.0015816688537597656
assembling the matrix time taken:  0.00024771690368652344
solving Ax = b time taken:  0.0003185272216796875
		 time taken minimize linear layer:  0.0005962848663330078
current error:  tensor(0.1795)
epoch:  4	argmax time taken,  0.0015726089477539062
assembling the matrix time taken:  0.0002644062042236328
solving Ax = b time taken:  0.0003197193145751953
		 time taken minimize 

assembling the matrix time taken:  0.0006504058837890625
solving Ax = b time taken:  0.001077413558959961
		 time taken minimize linear layer:  0.0017828941345214844
current error:  tensor(0.0630)
epoch:  39	argmax time taken,  0.004014730453491211
assembling the matrix time taken:  0.0002467632293701172
solving Ax = b time taken:  0.00154876708984375
		 time taken minimize linear layer:  0.0018262863159179688
current error:  tensor(0.0621)
epoch:  40	argmax time taken,  0.0040130615234375
assembling the matrix time taken:  0.0002582073211669922
solving Ax = b time taken:  0.0015077590942382812
		 time taken minimize linear layer:  0.001796722412109375
current error:  tensor(0.0615)
epoch:  41	argmax time taken,  0.004094600677490234
assembling the matrix time taken:  0.00025725364685058594
solving Ax = b time taken:  0.0014429092407226562
		 time taken minimize linear layer:  0.0017309188842773438
current error:  tensor(0.0608)
epoch:  42	argmax time taken,  0.0036172866821289062
asse

epoch:  76	argmax time taken,  0.004743814468383789
assembling the matrix time taken:  0.00026726722717285156
solving Ax = b time taken:  0.003073453903198242
		 time taken minimize linear layer:  0.0033712387084960938
current error:  tensor(0.0394)
epoch:  77	argmax time taken,  0.004072666168212891
assembling the matrix time taken:  0.00024700164794921875
solving Ax = b time taken:  0.003312349319458008
		 time taken minimize linear layer:  0.0035898685455322266
current error:  tensor(0.0391)
epoch:  78	argmax time taken,  0.00501251220703125
assembling the matrix time taken:  0.0002701282501220703
solving Ax = b time taken:  0.0034525394439697266
		 time taken minimize linear layer:  0.0037529468536376953
current error:  tensor(0.0388)
epoch:  79	argmax time taken,  0.004117250442504883
assembling the matrix time taken:  0.00024700164794921875
solving Ax = b time taken:  0.0033762454986572266
		 time taken minimize linear layer:  0.0036542415618896484
current error:  tensor(0.0386)


epoch:  114	argmax time taken,  0.004586696624755859
assembling the matrix time taken:  0.00025272369384765625
solving Ax = b time taken:  0.0042078495025634766
		 time taken minimize linear layer:  0.004491567611694336
current error:  tensor(0.0299)
epoch:  115	argmax time taken,  0.0063059329986572266
assembling the matrix time taken:  0.00025177001953125
solving Ax = b time taken:  0.0042514801025390625
		 time taken minimize linear layer:  0.004533529281616211
current error:  tensor(0.0296)
epoch:  116	argmax time taken,  0.0046045780181884766
assembling the matrix time taken:  0.00025200843811035156
solving Ax = b time taken:  0.0042324066162109375
		 time taken minimize linear layer:  0.004515171051025391
current error:  tensor(0.0295)
epoch:  117	argmax time taken,  0.004606008529663086
assembling the matrix time taken:  0.00025391578674316406
solving Ax = b time taken:  0.004292964935302734
		 time taken minimize linear layer:  0.0045778751373291016
current error:  tensor(0.029

current error:  tensor(0.0241)
epoch:  147	argmax time taken,  0.006034374237060547
assembling the matrix time taken:  0.0002532005310058594
solving Ax = b time taken:  0.007651567459106445
		 time taken minimize linear layer:  0.007935523986816406
current error:  tensor(0.0239)
epoch:  148	argmax time taken,  0.006742238998413086
assembling the matrix time taken:  0.00025725364685058594
solving Ax = b time taken:  0.0073888301849365234
		 time taken minimize linear layer:  0.007677316665649414
current error:  tensor(0.0237)
epoch:  149	argmax time taken,  0.006804227828979492
assembling the matrix time taken:  0.00026535987854003906
solving Ax = b time taken:  0.0075190067291259766
		 time taken minimize linear layer:  0.007815361022949219
current error:  tensor(0.0234)
epoch:  150	argmax time taken,  0.006744384765625
assembling the matrix time taken:  0.0002651214599609375
solving Ax = b time taken:  0.0074465274810791016
		 time taken minimize linear layer:  0.007741689682006836
cu

current error:  tensor(0.0196)
epoch:  182	argmax time taken,  0.0073089599609375
assembling the matrix time taken:  0.00025391578674316406
solving Ax = b time taken:  0.007234334945678711
		 time taken minimize linear layer:  0.0075190067291259766
current error:  tensor(0.0195)
epoch:  183	argmax time taken,  0.007236957550048828
assembling the matrix time taken:  0.0002512931823730469
solving Ax = b time taken:  0.007270336151123047
		 time taken minimize linear layer:  0.007552623748779297
current error:  tensor(0.0194)
epoch:  184	argmax time taken,  0.007219076156616211
assembling the matrix time taken:  0.00025343894958496094
solving Ax = b time taken:  0.007215023040771484
		 time taken minimize linear layer:  0.00749969482421875
current error:  tensor(0.0194)
epoch:  185	argmax time taken,  0.007218599319458008
assembling the matrix time taken:  0.00025582313537597656
solving Ax = b time taken:  0.0072863101959228516
		 time taken minimize linear layer:  0.007572650909423828
cu

solving Ax = b time taken:  0.012078285217285156
		 time taken minimize linear layer:  0.012444257736206055
current error:  tensor(0.0164)
epoch:  220	argmax time taken,  0.007666826248168945
assembling the matrix time taken:  0.00026679039001464844
solving Ax = b time taken:  0.009650945663452148
		 time taken minimize linear layer:  0.00994873046875
current error:  tensor(0.0163)
epoch:  221	argmax time taken,  0.007681846618652344
assembling the matrix time taken:  0.0002532005310058594
solving Ax = b time taken:  0.012095928192138672
		 time taken minimize linear layer:  0.012379169464111328
current error:  tensor(0.0163)
epoch:  222	argmax time taken,  0.007694244384765625
assembling the matrix time taken:  0.0002636909484863281
solving Ax = b time taken:  0.009969711303710938
		 time taken minimize linear layer:  0.010264396667480469
current error:  tensor(0.0162)
epoch:  223	argmax time taken,  0.007780313491821289
assembling the matrix time taken:  0.0002532005310058594
solving

assembling the matrix time taken:  0.0002598762512207031
solving Ax = b time taken:  0.011172771453857422
		 time taken minimize linear layer:  0.011480569839477539
current error:  tensor(0.0145)
epoch:  255	argmax time taken,  0.008295297622680664
assembling the matrix time taken:  0.0002551078796386719
solving Ax = b time taken:  0.011217832565307617
		 time taken minimize linear layer:  0.011504173278808594
current error:  tensor(0.0144)
epoch:  256	argmax time taken,  0.008321523666381836
assembling the matrix time taken:  0.0002551078796386719
solving Ax = b time taken:  0.01121664047241211
		 time taken minimize linear layer:  0.011502504348754883
current error:  tensor(0.0144)
epoch:  257	argmax time taken,  0.00818014144897461
assembling the matrix time taken:  0.00025844573974609375
solving Ax = b time taken:  0.013185501098632812
		 time taken minimize linear layer:  0.013481855392456055
current error:  tensor(0.0143)
epoch:  258	argmax time taken,  0.008241415023803711
assem

solving Ax = b time taken:  0.014405250549316406
		 time taken minimize linear layer:  0.014812707901000977
current error:  tensor(0.0127)
epoch:  290	argmax time taken,  0.008573055267333984
assembling the matrix time taken:  0.0002655982971191406
solving Ax = b time taken:  0.014330387115478516
		 time taken minimize linear layer:  0.014627218246459961
current error:  tensor(0.0127)
epoch:  291	argmax time taken,  0.008580923080444336
assembling the matrix time taken:  0.00025582313537597656
solving Ax = b time taken:  0.014273405075073242
		 time taken minimize linear layer:  0.014560222625732422
current error:  tensor(0.0126)
epoch:  292	argmax time taken,  0.008588075637817383
assembling the matrix time taken:  0.0002779960632324219
solving Ax = b time taken:  0.014192581176757812
		 time taken minimize linear layer:  0.014501810073852539
current error:  tensor(0.0126)
epoch:  293	argmax time taken,  0.008618593215942383
assembling the matrix time taken:  0.00025534629821777344
so

epoch:  323	argmax time taken,  0.009029626846313477
assembling the matrix time taken:  0.0002586841583251953
solving Ax = b time taken:  0.019189834594726562
		 time taken minimize linear layer:  0.01947927474975586
current error:  tensor(0.0118)
epoch:  324	argmax time taken,  0.009022712707519531
assembling the matrix time taken:  0.00026488304138183594
solving Ax = b time taken:  0.01712775230407715
		 time taken minimize linear layer:  0.017423629760742188
current error:  tensor(0.0118)
epoch:  325	argmax time taken,  0.00909280776977539
assembling the matrix time taken:  0.0002574920654296875
solving Ax = b time taken:  0.019243240356445312
		 time taken minimize linear layer:  0.01953291893005371
current error:  tensor(0.0117)
epoch:  326	argmax time taken,  0.009060859680175781
assembling the matrix time taken:  0.0002980232238769531
solving Ax = b time taken:  0.017249584197998047
		 time taken minimize linear layer:  0.017579078674316406
current error:  tensor(0.0117)
epoch: 

epoch:  357	argmax time taken,  0.012867450714111328
assembling the matrix time taken:  0.0002541542053222656
solving Ax = b time taken:  0.020040273666381836
		 time taken minimize linear layer:  0.02032613754272461
current error:  tensor(0.0108)
epoch:  358	argmax time taken,  0.009535074234008789
assembling the matrix time taken:  0.0002560615539550781
solving Ax = b time taken:  0.020053386688232422
		 time taken minimize linear layer:  0.020342111587524414
current error:  tensor(0.0108)
epoch:  359	argmax time taken,  0.01282644271850586
assembling the matrix time taken:  0.0002567768096923828
solving Ax = b time taken:  0.020076274871826172
		 time taken minimize linear layer:  0.020363807678222656
current error:  tensor(0.0108)
epoch:  360	argmax time taken,  0.00951385498046875
assembling the matrix time taken:  0.0002541542053222656
solving Ax = b time taken:  0.02004528045654297
		 time taken minimize linear layer:  0.020331382751464844
current error:  tensor(0.0108)
epoch:  

epoch:  394	argmax time taken,  0.009947538375854492
assembling the matrix time taken:  0.0002560615539550781
solving Ax = b time taken:  0.02368617057800293
		 time taken minimize linear layer:  0.023973703384399414
current error:  tensor(0.0099)
epoch:  395	argmax time taken,  0.009980916976928711
assembling the matrix time taken:  0.0002567768096923828
solving Ax = b time taken:  0.02247929573059082
		 time taken minimize linear layer:  0.022766828536987305
current error:  tensor(0.0099)
epoch:  396	argmax time taken,  0.009979486465454102
assembling the matrix time taken:  0.0002551078796386719
solving Ax = b time taken:  0.02361750602722168
		 time taken minimize linear layer:  0.023903369903564453
current error:  tensor(0.0099)
epoch:  397	argmax time taken,  0.010519981384277344
assembling the matrix time taken:  0.0002644062042236328
solving Ax = b time taken:  0.02245306968688965
		 time taken minimize linear layer:  0.022748231887817383
current error:  tensor(0.0099)
epoch:  

assembling the matrix time taken:  0.00025653839111328125
solving Ax = b time taken:  0.024611711502075195
		 time taken minimize linear layer:  0.024917125701904297
current error:  tensor(0.0093)
epoch:  429	argmax time taken,  0.010413169860839844
assembling the matrix time taken:  0.0002560615539550781
solving Ax = b time taken:  0.023747920989990234
		 time taken minimize linear layer:  0.024034738540649414
current error:  tensor(0.0092)
epoch:  430	argmax time taken,  0.010394573211669922
assembling the matrix time taken:  0.000255584716796875
solving Ax = b time taken:  0.025122404098510742
		 time taken minimize linear layer:  0.02540898323059082
current error:  tensor(0.0092)
epoch:  431	argmax time taken,  0.012203216552734375
assembling the matrix time taken:  0.000255584716796875
solving Ax = b time taken:  0.023847103118896484
		 time taken minimize linear layer:  0.02413463592529297
current error:  tensor(0.0092)
epoch:  432	argmax time taken,  0.010439872741699219
assembl

current error:  tensor(0.0087)
epoch:  463	argmax time taken,  0.010896444320678711
assembling the matrix time taken:  0.00025773048400878906
solving Ax = b time taken:  0.028021812438964844
		 time taken minimize linear layer:  0.028310775756835938
current error:  tensor(0.0087)
epoch:  464	argmax time taken,  0.010909318923950195
assembling the matrix time taken:  0.00026726722717285156
solving Ax = b time taken:  0.02705669403076172
		 time taken minimize linear layer:  0.0273587703704834
current error:  tensor(0.0087)
epoch:  465	argmax time taken,  0.01087498664855957
assembling the matrix time taken:  0.00025725364685058594
solving Ax = b time taken:  0.028183698654174805
		 time taken minimize linear layer:  0.0284726619720459
current error:  tensor(0.0086)
epoch:  466	argmax time taken,  0.01096796989440918
assembling the matrix time taken:  0.00026488304138183594
solving Ax = b time taken:  0.027744054794311523
		 time taken minimize linear layer:  0.028125524520874023
current

solving Ax = b time taken:  0.029001951217651367
		 time taken minimize linear layer:  0.029373645782470703
current error:  tensor(0.0081)
epoch:  499	argmax time taken,  0.012996196746826172
assembling the matrix time taken:  0.0002613067626953125
solving Ax = b time taken:  0.029024124145507812
		 time taken minimize linear layer:  0.029316425323486328
current error:  tensor(0.0081)
epoch:  500	argmax time taken,  0.011411190032958984
assembling the matrix time taken:  0.0002567768096923828
solving Ax = b time taken:  0.029015064239501953
		 time taken minimize linear layer:  0.029303550720214844
current error:  tensor(0.0081)
epoch:  501	argmax time taken,  0.011430740356445312
assembling the matrix time taken:  0.00025582313537597656
solving Ax = b time taken:  0.029023170471191406
		 time taken minimize linear layer:  0.029309988021850586
current error:  tensor(0.0080)
epoch:  502	argmax time taken,  0.011475086212158203
assembling the matrix time taken:  0.00027179718017578125
so

epoch:  31	argmax time taken,  0.003099203109741211
assembling the matrix time taken:  0.0002524852752685547
solving Ax = b time taken:  0.0010879039764404297
		 time taken minimize linear layer:  0.0013709068298339844
current error:  tensor(0.0711)
epoch:  32	argmax time taken,  0.00315093994140625
assembling the matrix time taken:  0.00025391578674316406
solving Ax = b time taken:  0.0010900497436523438
		 time taken minimize linear layer:  0.0013742446899414062
current error:  tensor(0.0702)
epoch:  33	argmax time taken,  0.0030798912048339844
assembling the matrix time taken:  0.0002505779266357422
solving Ax = b time taken:  0.0012316703796386719
		 time taken minimize linear layer:  0.0015172958374023438
current error:  tensor(0.0697)
epoch:  34	argmax time taken,  0.004734992980957031
assembling the matrix time taken:  0.0002703666687011719
solving Ax = b time taken:  0.0012409687042236328
		 time taken minimize linear layer:  0.0015416145324707031
current error:  tensor(0.0682)

epoch:  68	argmax time taken,  0.005231142044067383
assembling the matrix time taken:  0.0002636909484863281
solving Ax = b time taken:  0.0027005672454833984
		 time taken minimize linear layer:  0.0029952526092529297
current error:  tensor(0.0418)
epoch:  69	argmax time taken,  0.005244016647338867
assembling the matrix time taken:  0.0002465248107910156
solving Ax = b time taken:  0.0028295516967773438
		 time taken minimize linear layer:  0.003106355667114258
current error:  tensor(0.0416)
epoch:  70	argmax time taken,  0.005249977111816406
assembling the matrix time taken:  0.0002589225769042969
solving Ax = b time taken:  0.0029006004333496094
		 time taken minimize linear layer:  0.0031900405883789062
current error:  tensor(0.0413)
epoch:  71	argmax time taken,  0.00564122200012207
assembling the matrix time taken:  0.0002541542053222656
solving Ax = b time taken:  0.0028541088104248047
		 time taken minimize linear layer:  0.0031385421752929688
current error:  tensor(0.0404)
ep

solving Ax = b time taken:  0.004198551177978516
		 time taken minimize linear layer:  0.004605770111083984
current error:  tensor(0.0281)
epoch:  107	argmax time taken,  0.005788564682006836
assembling the matrix time taken:  0.0002593994140625
solving Ax = b time taken:  0.0039958953857421875
		 time taken minimize linear layer:  0.0042858123779296875
current error:  tensor(0.0277)
epoch:  108	argmax time taken,  0.007490396499633789
assembling the matrix time taken:  0.00026488304138183594
solving Ax = b time taken:  0.004027605056762695
		 time taken minimize linear layer:  0.004323005676269531
current error:  tensor(0.0275)
epoch:  109	argmax time taken,  0.005796670913696289
assembling the matrix time taken:  0.0002582073211669922
solving Ax = b time taken:  0.004047393798828125
		 time taken minimize linear layer:  0.004335641860961914
current error:  tensor(0.0273)
epoch:  110	argmax time taken,  0.007513999938964844
assembling the matrix time taken:  0.0002644062042236328
solv

epoch:  148	argmax time taken,  0.007308483123779297
assembling the matrix time taken:  0.00025773048400878906
solving Ax = b time taken:  0.007380485534667969
		 time taken minimize linear layer:  0.007670164108276367
current error:  tensor(0.0209)
epoch:  149	argmax time taken,  0.00904083251953125
assembling the matrix time taken:  0.0002560615539550781
solving Ax = b time taken:  0.007604122161865234
		 time taken minimize linear layer:  0.007891178131103516
current error:  tensor(0.0208)
epoch:  150	argmax time taken,  0.0075740814208984375
assembling the matrix time taken:  0.00025343894958496094
solving Ax = b time taken:  0.007603883743286133
		 time taken minimize linear layer:  0.007889270782470703
current error:  tensor(0.0205)
epoch:  151	argmax time taken,  0.009326696395874023
assembling the matrix time taken:  0.0002503395080566406
solving Ax = b time taken:  0.007692813873291016
		 time taken minimize linear layer:  0.007974386215209961
current error:  tensor(0.0204)
ep

current error:  tensor(0.0173)
epoch:  182	argmax time taken,  0.00846719741821289
assembling the matrix time taken:  0.0002524852752685547
solving Ax = b time taken:  0.007234811782836914
		 time taken minimize linear layer:  0.007523775100708008
current error:  tensor(0.0173)
epoch:  183	argmax time taken,  0.008478403091430664
assembling the matrix time taken:  0.00025963783264160156
solving Ax = b time taken:  0.007261037826538086
		 time taken minimize linear layer:  0.0075511932373046875
current error:  tensor(0.0172)
epoch:  184	argmax time taken,  0.00845646858215332
assembling the matrix time taken:  0.00025391578674316406
solving Ax = b time taken:  0.007228851318359375
		 time taken minimize linear layer:  0.007513999938964844
current error:  tensor(0.0171)
epoch:  185	argmax time taken,  0.008476495742797852
assembling the matrix time taken:  0.00025343894958496094
solving Ax = b time taken:  0.0073010921478271484
		 time taken minimize linear layer:  0.007586240768432617
c

epoch:  218	argmax time taken,  0.008987188339233398
assembling the matrix time taken:  0.0002636909484863281
solving Ax = b time taken:  0.01002049446105957
		 time taken minimize linear layer:  0.010315418243408203
current error:  tensor(0.0149)
epoch:  219	argmax time taken,  0.009029626846313477
assembling the matrix time taken:  0.00025272369384765625
solving Ax = b time taken:  0.011832475662231445
		 time taken minimize linear layer:  0.01211690902709961
current error:  tensor(0.0149)
epoch:  220	argmax time taken,  0.00901937484741211
assembling the matrix time taken:  0.00027298927307128906
solving Ax = b time taken:  0.009598016738891602
		 time taken minimize linear layer:  0.009902238845825195
current error:  tensor(0.0149)
epoch:  221	argmax time taken,  0.008989095687866211
assembling the matrix time taken:  0.0002601146697998047
solving Ax = b time taken:  0.012079000473022461
		 time taken minimize linear layer:  0.012370109558105469
current error:  tensor(0.0148)
epoch

epoch:  257	argmax time taken,  0.00942230224609375
assembling the matrix time taken:  0.0002570152282714844
solving Ax = b time taken:  0.013172626495361328
		 time taken minimize linear layer:  0.013461112976074219
current error:  tensor(0.0131)
epoch:  258	argmax time taken,  0.009504318237304688
assembling the matrix time taken:  0.0002675056457519531
solving Ax = b time taken:  0.012743949890136719
		 time taken minimize linear layer:  0.013042926788330078
current error:  tensor(0.0131)
epoch:  259	argmax time taken,  0.009440898895263672
assembling the matrix time taken:  0.0002548694610595703
solving Ax = b time taken:  0.01324915885925293
		 time taken minimize linear layer:  0.013535499572753906
current error:  tensor(0.0131)
epoch:  260	argmax time taken,  0.009514808654785156
assembling the matrix time taken:  0.00026226043701171875
solving Ax = b time taken:  0.012707948684692383
		 time taken minimize linear layer:  0.01300191879272461
current error:  tensor(0.0130)
epoch:

epoch:  291	argmax time taken,  0.009894847869873047
assembling the matrix time taken:  0.00025534629821777344
solving Ax = b time taken:  0.01426076889038086
		 time taken minimize linear layer:  0.01454782485961914
current error:  tensor(0.0119)
epoch:  292	argmax time taken,  0.009898662567138672
assembling the matrix time taken:  0.0002655982971191406
solving Ax = b time taken:  0.014215707778930664
		 time taken minimize linear layer:  0.014511585235595703
current error:  tensor(0.0119)
epoch:  293	argmax time taken,  0.00989389419555664
assembling the matrix time taken:  0.00026154518127441406
solving Ax = b time taken:  0.01433253288269043
		 time taken minimize linear layer:  0.014624357223510742
current error:  tensor(0.0118)
epoch:  294	argmax time taken,  0.009957075119018555
assembling the matrix time taken:  0.0002636909484863281
solving Ax = b time taken:  0.014458179473876953
		 time taken minimize linear layer:  0.014752864837646484
current error:  tensor(0.0118)
epoch:

epoch:  324	argmax time taken,  0.01038670539855957
assembling the matrix time taken:  0.0002620220184326172
solving Ax = b time taken:  0.017115354537963867
		 time taken minimize linear layer:  0.017409086227416992
current error:  tensor(0.0108)
epoch:  325	argmax time taken,  0.01034092903137207
assembling the matrix time taken:  0.0002493858337402344
solving Ax = b time taken:  0.019259214401245117
		 time taken minimize linear layer:  0.019541025161743164
current error:  tensor(0.0108)
epoch:  326	argmax time taken,  0.010398149490356445
assembling the matrix time taken:  0.0002675056457519531
solving Ax = b time taken:  0.017299175262451172
		 time taken minimize linear layer:  0.01759791374206543
current error:  tensor(0.0108)
epoch:  327	argmax time taken,  0.010406970977783203
assembling the matrix time taken:  0.0002446174621582031
solving Ax = b time taken:  0.01929783821105957
		 time taken minimize linear layer:  0.019574880599975586
current error:  tensor(0.0108)
epoch:  

assembling the matrix time taken:  0.0002562999725341797
solving Ax = b time taken:  0.020043611526489258
		 time taken minimize linear layer:  0.020346641540527344
current error:  tensor(0.0099)
epoch:  359	argmax time taken,  0.010806083679199219
assembling the matrix time taken:  0.00025844573974609375
solving Ax = b time taken:  0.020060062408447266
		 time taken minimize linear layer:  0.020349502563476562
current error:  tensor(0.0099)
epoch:  360	argmax time taken,  0.010839223861694336
assembling the matrix time taken:  0.00025200843811035156
solving Ax = b time taken:  0.02002692222595215
		 time taken minimize linear layer:  0.020310401916503906
current error:  tensor(0.0098)
epoch:  361	argmax time taken,  0.010857582092285156
assembling the matrix time taken:  0.0002532005310058594
solving Ax = b time taken:  0.020148754119873047
		 time taken minimize linear layer:  0.020433425903320312
current error:  tensor(0.0098)
epoch:  362	argmax time taken,  0.01090383529663086
asse

epoch:  394	argmax time taken,  0.01177835464477539
assembling the matrix time taken:  0.0002620220184326172
solving Ax = b time taken:  0.023691892623901367
		 time taken minimize linear layer:  0.02398538589477539
current error:  tensor(0.0091)
epoch:  395	argmax time taken,  0.011263132095336914
assembling the matrix time taken:  0.00025391578674316406
solving Ax = b time taken:  0.022469758987426758
		 time taken minimize linear layer:  0.022756099700927734
current error:  tensor(0.0090)
epoch:  396	argmax time taken,  0.011836051940917969
assembling the matrix time taken:  0.0002620220184326172
solving Ax = b time taken:  0.023687124252319336
		 time taken minimize linear layer:  0.023980379104614258
current error:  tensor(0.0090)
epoch:  397	argmax time taken,  0.011799812316894531
assembling the matrix time taken:  0.0002562999725341797
solving Ax = b time taken:  0.0225064754486084
		 time taken minimize linear layer:  0.022794246673583984
current error:  tensor(0.0090)
epoch: 

assembling the matrix time taken:  0.00030803680419921875
solving Ax = b time taken:  0.025464534759521484
		 time taken minimize linear layer:  0.02582240104675293
current error:  tensor(0.0084)
epoch:  431	argmax time taken,  0.015230178833007812
assembling the matrix time taken:  0.0002582073211669922
solving Ax = b time taken:  0.023838043212890625
		 time taken minimize linear layer:  0.02412700653076172
current error:  tensor(0.0084)
epoch:  432	argmax time taken,  0.013379096984863281
assembling the matrix time taken:  0.0002560615539550781
solving Ax = b time taken:  0.024616241455078125
		 time taken minimize linear layer:  0.02490401268005371
current error:  tensor(0.0084)
epoch:  433	argmax time taken,  0.01516270637512207
assembling the matrix time taken:  0.0002582073211669922
solving Ax = b time taken:  0.023961305618286133
		 time taken minimize linear layer:  0.024251461029052734
current error:  tensor(0.0084)
epoch:  434	argmax time taken,  0.013424158096313477
assembl

solving Ax = b time taken:  0.028259754180908203
		 time taken minimize linear layer:  0.02865910530090332
current error:  tensor(0.0078)
epoch:  464	argmax time taken,  0.012157678604125977
assembling the matrix time taken:  0.00027060508728027344
solving Ax = b time taken:  0.02709507942199707
		 time taken minimize linear layer:  0.02739715576171875
current error:  tensor(0.0078)
epoch:  465	argmax time taken,  0.01214289665222168
assembling the matrix time taken:  0.0002560615539550781
solving Ax = b time taken:  0.02815079689025879
		 time taken minimize linear layer:  0.028436899185180664
current error:  tensor(0.0078)
epoch:  466	argmax time taken,  0.012140989303588867
assembling the matrix time taken:  0.0002696514129638672
solving Ax = b time taken:  0.0277559757232666
		 time taken minimize linear layer:  0.028056859970092773
current error:  tensor(0.0078)
epoch:  467	argmax time taken,  0.012271404266357422
assembling the matrix time taken:  0.0002529621124267578
solving Ax

assembling the matrix time taken:  0.0002570152282714844
solving Ax = b time taken:  0.02900242805480957
		 time taken minimize linear layer:  0.02930760383605957
current error:  tensor(0.0073)
epoch:  501	argmax time taken,  0.012640237808227539
assembling the matrix time taken:  0.0002605915069580078
solving Ax = b time taken:  0.02911376953125
		 time taken minimize linear layer:  0.02940535545349121
current error:  tensor(0.0073)
epoch:  502	argmax time taken,  0.012629508972167969
assembling the matrix time taken:  0.00025343894958496094
solving Ax = b time taken:  0.029065847396850586
		 time taken minimize linear layer:  0.02935647964477539
current error:  tensor(0.0073)
epoch:  503	argmax time taken,  0.012675762176513672
assembling the matrix time taken:  0.0002536773681640625
solving Ax = b time taken:  0.02909088134765625
		 time taken minimize linear layer:  0.0294950008392334
current error:  tensor(0.0073)
epoch:  504	argmax time taken,  0.012611865997314453
assembling the

epoch:  32	argmax time taken,  0.005687236785888672
assembling the matrix time taken:  0.0002574920654296875
solving Ax = b time taken:  0.0010979175567626953
		 time taken minimize linear layer:  0.0013854503631591797
current error:  tensor(0.0681)
epoch:  33	argmax time taken,  0.0056149959564208984
assembling the matrix time taken:  0.00025463104248046875
solving Ax = b time taken:  0.0012331008911132812
		 time taken minimize linear layer:  0.0015175342559814453
current error:  tensor(0.0668)
epoch:  34	argmax time taken,  0.007259368896484375
assembling the matrix time taken:  0.0002722740173339844
solving Ax = b time taken:  0.001233816146850586
		 time taken minimize linear layer:  0.0015358924865722656
current error:  tensor(0.0657)
epoch:  35	argmax time taken,  0.007279157638549805
assembling the matrix time taken:  0.0002486705780029297
solving Ax = b time taken:  0.001308441162109375
		 time taken minimize linear layer:  0.0015871524810791016
current error:  tensor(0.0649)


epoch:  75	argmax time taken,  0.007957935333251953
assembling the matrix time taken:  0.0002474784851074219
solving Ax = b time taken:  0.0031843185424804688
		 time taken minimize linear layer:  0.0034623146057128906
current error:  tensor(0.0384)
epoch:  76	argmax time taken,  0.008455991744995117
assembling the matrix time taken:  0.000263214111328125
solving Ax = b time taken:  0.0031342506408691406
		 time taken minimize linear layer:  0.003427267074584961
current error:  tensor(0.0378)
epoch:  77	argmax time taken,  0.008478164672851562
assembling the matrix time taken:  0.0002512931823730469
solving Ax = b time taken:  0.003307819366455078
		 time taken minimize linear layer:  0.0035877227783203125
current error:  tensor(0.0376)
epoch:  78	argmax time taken,  0.00873112678527832
assembling the matrix time taken:  0.0002608299255371094
solving Ax = b time taken:  0.0034515857696533203
		 time taken minimize linear layer:  0.0037424564361572266
current error:  tensor(0.0372)
epoc

current error:  tensor(0.0267)
epoch:  117	argmax time taken,  0.010132074356079102
assembling the matrix time taken:  0.00024962425231933594
solving Ax = b time taken:  0.004302263259887695
		 time taken minimize linear layer:  0.004582405090332031
current error:  tensor(0.0266)
epoch:  118	argmax time taken,  0.010234594345092773
assembling the matrix time taken:  0.0002484321594238281
solving Ax = b time taken:  0.004339933395385742
		 time taken minimize linear layer:  0.004618167877197266
current error:  tensor(0.0263)
epoch:  119	argmax time taken,  0.010154485702514648
assembling the matrix time taken:  0.0002510547637939453
solving Ax = b time taken:  0.0043582916259765625
		 time taken minimize linear layer:  0.004639863967895508
current error:  tensor(0.0260)
epoch:  120	argmax time taken,  0.010191202163696289
assembling the matrix time taken:  0.00025010108947753906
solving Ax = b time taken:  0.004294872283935547
		 time taken minimize linear layer:  0.004575014114379883
c

assembling the matrix time taken:  0.0005533695220947266
solving Ax = b time taken:  0.0071680545806884766
		 time taken minimize linear layer:  0.007768392562866211
current error:  tensor(0.0206)
epoch:  151	argmax time taken,  0.010592937469482422
assembling the matrix time taken:  0.0002503395080566406
solving Ax = b time taken:  0.0076045989990234375
		 time taken minimize linear layer:  0.007891178131103516
current error:  tensor(0.0205)
epoch:  152	argmax time taken,  0.01012563705444336
assembling the matrix time taken:  0.0002627372741699219
solving Ax = b time taken:  0.00721287727355957
		 time taken minimize linear layer:  0.007506370544433594
current error:  tensor(0.0204)
epoch:  153	argmax time taken,  0.010099172592163086
assembling the matrix time taken:  0.0002505779266357422
solving Ax = b time taken:  0.006535053253173828
		 time taken minimize linear layer:  0.006817340850830078
current error:  tensor(0.0203)
epoch:  154	argmax time taken,  0.010350465774536133
asse

epoch:  188	argmax time taken,  0.011026382446289062
assembling the matrix time taken:  0.00025272369384765625
solving Ax = b time taken:  0.007338762283325195
		 time taken minimize linear layer:  0.0076220035552978516
current error:  tensor(0.0172)
epoch:  189	argmax time taken,  0.011049509048461914
assembling the matrix time taken:  0.0002474784851074219
solving Ax = b time taken:  0.007389545440673828
		 time taken minimize linear layer:  0.007668495178222656
current error:  tensor(0.0172)
epoch:  190	argmax time taken,  0.011102676391601562
assembling the matrix time taken:  0.00025343894958496094
solving Ax = b time taken:  0.007413148880004883
		 time taken minimize linear layer:  0.007697105407714844
current error:  tensor(0.0170)
epoch:  191	argmax time taken,  0.01112818717956543
assembling the matrix time taken:  0.00025773048400878906
solving Ax = b time taken:  0.007452964782714844
		 time taken minimize linear layer:  0.0077419281005859375
current error:  tensor(0.0170)


solving Ax = b time taken:  0.010291099548339844
		 time taken minimize linear layer:  0.010709524154663086
current error:  tensor(0.0145)
epoch:  227	argmax time taken,  0.011586666107177734
assembling the matrix time taken:  0.0002524852752685547
solving Ax = b time taken:  0.010544776916503906
		 time taken minimize linear layer:  0.010828495025634766
current error:  tensor(0.0145)
epoch:  228	argmax time taken,  0.011646032333374023
assembling the matrix time taken:  0.0002732276916503906
solving Ax = b time taken:  0.010180473327636719
		 time taken minimize linear layer:  0.010483980178833008
current error:  tensor(0.0143)
epoch:  229	argmax time taken,  0.011603832244873047
assembling the matrix time taken:  0.0002510547637939453
solving Ax = b time taken:  0.010600566864013672
		 time taken minimize linear layer:  0.010881900787353516
current error:  tensor(0.0143)
epoch:  230	argmax time taken,  0.011632204055786133
assembling the matrix time taken:  0.00026154518127441406
sol

current error:  tensor(0.0127)
epoch:  263	argmax time taken,  0.012056350708007812
assembling the matrix time taken:  0.0002548694610595703
solving Ax = b time taken:  0.013396739959716797
		 time taken minimize linear layer:  0.013683080673217773
current error:  tensor(0.0127)
epoch:  264	argmax time taken,  0.011996030807495117
assembling the matrix time taken:  0.0002624988555908203
solving Ax = b time taken:  0.012809276580810547
		 time taken minimize linear layer:  0.01310276985168457
current error:  tensor(0.0127)
epoch:  265	argmax time taken,  0.012061595916748047
assembling the matrix time taken:  0.0002484321594238281
solving Ax = b time taken:  0.01349496841430664
		 time taken minimize linear layer:  0.013774633407592773
current error:  tensor(0.0126)
epoch:  266	argmax time taken,  0.012020587921142578
assembling the matrix time taken:  0.00026416778564453125
solving Ax = b time taken:  0.012932300567626953
		 time taken minimize linear layer:  0.013226985931396484
curre

current error:  tensor(0.0113)
epoch:  300	argmax time taken,  0.012512683868408203
assembling the matrix time taken:  0.0002510547637939453
solving Ax = b time taken:  0.014449357986450195
		 time taken minimize linear layer:  0.014731645584106445
current error:  tensor(0.0113)
epoch:  301	argmax time taken,  0.012606620788574219
assembling the matrix time taken:  0.00025916099548339844
solving Ax = b time taken:  0.01454305648803711
		 time taken minimize linear layer:  0.01483297348022461
current error:  tensor(0.0112)
epoch:  302	argmax time taken,  0.01259160041809082
assembling the matrix time taken:  0.0002541542053222656
solving Ax = b time taken:  0.014545202255249023
		 time taken minimize linear layer:  0.014830350875854492
current error:  tensor(0.0112)
epoch:  303	argmax time taken,  0.012598037719726562
assembling the matrix time taken:  0.00025916099548339844
solving Ax = b time taken:  0.014582633972167969
		 time taken minimize linear layer:  0.014873027801513672
curre

current error:  tensor(0.0102)
epoch:  334	argmax time taken,  0.013006925582885742
assembling the matrix time taken:  0.0002651214599609375
solving Ax = b time taken:  0.017666339874267578
		 time taken minimize linear layer:  0.017961978912353516
current error:  tensor(0.0102)
epoch:  335	argmax time taken,  0.012985944747924805
assembling the matrix time taken:  0.0002543926239013672
solving Ax = b time taken:  0.019516468048095703
		 time taken minimize linear layer:  0.019802093505859375
current error:  tensor(0.0102)
epoch:  336	argmax time taken,  0.013024330139160156
assembling the matrix time taken:  0.0002703666687011719
solving Ax = b time taken:  0.017322063446044922
		 time taken minimize linear layer:  0.017624378204345703
current error:  tensor(0.0101)
epoch:  337	argmax time taken,  0.01295614242553711
assembling the matrix time taken:  0.000255584716796875
solving Ax = b time taken:  0.01948857307434082
		 time taken minimize linear layer:  0.019774436950683594
current

solving Ax = b time taken:  0.020258426666259766
		 time taken minimize linear layer:  0.020667076110839844
current error:  tensor(0.0092)
epoch:  369	argmax time taken,  0.013406753540039062
assembling the matrix time taken:  0.0002624988555908203
solving Ax = b time taken:  0.02039194107055664
		 time taken minimize linear layer:  0.020684480667114258
current error:  tensor(0.0092)
epoch:  370	argmax time taken,  0.016858577728271484
assembling the matrix time taken:  0.0002529621124267578
solving Ax = b time taken:  0.020409107208251953
		 time taken minimize linear layer:  0.02069234848022461
current error:  tensor(0.0092)
epoch:  371	argmax time taken,  0.013418912887573242
assembling the matrix time taken:  0.000255584716796875
solving Ax = b time taken:  0.02042412757873535
		 time taken minimize linear layer:  0.020710468292236328
current error:  tensor(0.0091)
epoch:  372	argmax time taken,  0.016835689544677734
assembling the matrix time taken:  0.0002532005310058594
solving 

solving Ax = b time taken:  0.02262091636657715
		 time taken minimize linear layer:  0.023025989532470703
current error:  tensor(0.0085)
epoch:  402	argmax time taken,  0.015608549118041992
assembling the matrix time taken:  0.0002529621124267578
solving Ax = b time taken:  0.02390909194946289
		 time taken minimize linear layer:  0.02419257164001465
current error:  tensor(0.0085)
epoch:  403	argmax time taken,  0.013895273208618164
assembling the matrix time taken:  0.00025534629821777344
solving Ax = b time taken:  0.022704601287841797
		 time taken minimize linear layer:  0.022991418838500977
current error:  tensor(0.0084)
epoch:  404	argmax time taken,  0.015668392181396484
assembling the matrix time taken:  0.0002532005310058594
solving Ax = b time taken:  0.02387261390686035
		 time taken minimize linear layer:  0.024156570434570312
current error:  tensor(0.0084)
epoch:  405	argmax time taken,  0.013913393020629883
assembling the matrix time taken:  0.00025773048400878906
solvin

solving Ax = b time taken:  0.024007320404052734
		 time taken minimize linear layer:  0.024418354034423828
current error:  tensor(0.0078)
epoch:  438	argmax time taken,  0.01611614227294922
assembling the matrix time taken:  0.0002541542053222656
solving Ax = b time taken:  0.024944782257080078
		 time taken minimize linear layer:  0.02523016929626465
current error:  tensor(0.0078)
epoch:  439	argmax time taken,  0.014362573623657227
assembling the matrix time taken:  0.0002541542053222656
solving Ax = b time taken:  0.02411031723022461
		 time taken minimize linear layer:  0.024395227432250977
current error:  tensor(0.0078)
epoch:  440	argmax time taken,  0.016148805618286133
assembling the matrix time taken:  0.0002541542053222656
solving Ax = b time taken:  0.024921178817749023
		 time taken minimize linear layer:  0.025206565856933594
current error:  tensor(0.0078)
epoch:  441	argmax time taken,  0.01435708999633789
assembling the matrix time taken:  0.0003209114074707031
solving 

epoch:  473	argmax time taken,  0.016895055770874023
assembling the matrix time taken:  0.00025963783264160156
solving Ax = b time taken:  0.028345823287963867
		 time taken minimize linear layer:  0.028636455535888672
current error:  tensor(0.0072)
epoch:  474	argmax time taken,  0.014875650405883789
assembling the matrix time taken:  0.0002741813659667969
solving Ax = b time taken:  0.028536319732666016
		 time taken minimize linear layer:  0.02884197235107422
current error:  tensor(0.0072)
epoch:  475	argmax time taken,  0.016951799392700195
assembling the matrix time taken:  0.0002532005310058594
solving Ax = b time taken:  0.028467655181884766
		 time taken minimize linear layer:  0.02875232696533203
current error:  tensor(0.0072)
epoch:  476	argmax time taken,  0.014928579330444336
assembling the matrix time taken:  0.0002620220184326172
solving Ax = b time taken:  0.028558015823364258
		 time taken minimize linear layer:  0.028974533081054688
current error:  tensor(0.0072)
epoch

epoch:  507	argmax time taken,  0.01734137535095215
assembling the matrix time taken:  0.00025653839111328125
solving Ax = b time taken:  0.029596805572509766
		 time taken minimize linear layer:  0.02988457679748535
current error:  tensor(0.0068)
epoch:  508	argmax time taken,  0.01533198356628418
assembling the matrix time taken:  0.00025343894958496094
solving Ax = b time taken:  0.029537439346313477
		 time taken minimize linear layer:  0.029822349548339844
current error:  tensor(0.0068)
epoch:  509	argmax time taken,  0.017416715621948242
assembling the matrix time taken:  0.00025200843811035156
solving Ax = b time taken:  0.029595136642456055
		 time taken minimize linear layer:  0.029877424240112305
current error:  tensor(0.0067)
epoch:  510	argmax time taken,  0.015444278717041016
assembling the matrix time taken:  0.0002524852752685547
solving Ax = b time taken:  0.029601097106933594
		 time taken minimize linear layer:  0.030005693435668945
current error:  tensor(0.0067)
epoc