In [26]:
import sys
import shutil
import time
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import argparse
import matplotlib.pyplot as plt
from networks import *
from normal_train import *
import numpy as np
import copy
import math

In [30]:
# Experiment settings. Alternative settings are kept as comments; switch to one
# by commenting/uncommenting the corresponding lines.

args = {
    # training inputs, one 2-D point per row
    'xp': np.array([[-1, -1], [1, 1], [0, 0], [-1, 1], [1, -1]], dtype=np.float32),
    # other training inputs:
    # 'xp': np.array([[-1.3, -0.7], [0.5, 0.9], [-0.8, 0.3], [-0.4, 1.6], [1.6, -0.4]], dtype=np.float32),

    # training outputs, one target per input row
    'fp': np.array([[1.5], [1.5], [0.5], [-0.5], [-0.5]], dtype=np.float32),
    # other training outputs:
    # 'fp': np.array([1.5, 1.5, 0.5, -0.5, -0.5], dtype=np.float32),
    # 'fp': np.array([[0.05], [0.02], [0], [-0.01], [-0.03]], dtype=np.float32),

    # number of inputs
    'input_dim': 2,
    # base name of the saved model files
    'tag': "2d_toy",
    # whether to train only the output layer
    'train_only_output': True,
    # bias initialization, passed to the network as `bias_tune_tuple`
    # other choice: ("uniform", 2)
    'bias': ("uniform", 2),
    # weight initialization, passed to the network as `initialization`
    'weight': "unit_vector",
    # NOTE(review): absolute, user-specific path. File names are built by plain
    # string concatenation (path + tag + ...), so no separator is inserted.
    'path': '/Users/jinhui/Documents/GitHub/model',
    'epochs': 10000,
    # 'nums_of_neurons': [10, 20, 40, 80, 160, 320, 640, 1280, 2560],
    'nums_of_neurons': [10],
}

# Gaussian initialization (alternative settings):

# args = dict()
# args['xp'] = np.array([[-1,-1], [1,1], [0,0], [-1,1], [1,-1]],dtype=np.float32)
# args['fp'] = np.array([[1.5], [1.5], [0.5], [-0.5], [-0.5]],dtype=np.float32)
# args['input_dim'] = 2
# args['tag'] = "2d_toy_gaussian_try"
# args['train_only_output'] = True
# # args['bias'] = ("uniform", 2)
# sigmab = 0.1
# args['bias'] = ("normal", sigmab)
# args['weight'] = "normal"
# args['path'] = '/Users/jinhui/Documents/GitHub/model'
# args['epochs'] = 10000
# args['nums_of_neurons']=[10,20,40,80,160,320,640,1280,2560]

In [31]:
# Solution of the variational problem for the uniform initialization, i.e. the
# weights uniformly distributed on the sphere and the biases uniformly
# distributed on an interval.
#
# This is the solution with linear correction ((eq.10) in the paper), valid
# only under the conditions of Theorem 7. If weights and biases are initialized
# from a Gaussian distribution, use the next cell instead.
#
# The solution is evaluated as
#     f(x) = sum_j lambda_j * |x - x_j|^3 + u . x + v,
# where (lambda, u, v) solve the linear system assembled below: one
# interpolation row per training point, plus the side conditions
# sum_j lambda_j x_j = 0 and sum_j lambda_j = 0.

numgrid = 21  # choose a larger numgrid for a denser grid
R = np.abs(args['xp']).max()
x = np.linspace(-R, R, numgrid, dtype=np.float32)
h = x[1] - x[0]  # grid spacing
y = np.linspace(-R, R, numgrid, dtype=np.float32)
X, Y = np.meshgrid(x, y)
X = np.expand_dims(X, axis=2)
Y = np.expand_dims(Y, axis=2)


def _cubic_kernel_features(points, centers):
    """Return one row [|p - c_1|^3, ..., |p - c_M|^3, p_1, ..., p_d, 1] per point p.

    points: (P, d) array; centers: (M, d) array. Result has shape (P, M + d + 1).
    """
    dist_cubed = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=2) ** 3
    ones = np.ones((points.shape[0], 1), dtype=dist_cubed.dtype)
    return np.hstack((dist_cubed, points, ones))


M, d = args['xp'].shape
coeff = np.zeros((M + d + 1, M + d + 1))
# Interpolation rows: f(x_i) = fp_i.
coeff[:M, :] = _cubic_kernel_features(args['xp'], args['xp'])
# Side conditions: sum_j lambda_j x_j = 0 (d rows) and sum_j lambda_j = 0.
coeff[M:M + d, :M] = args['xp'].T
coeff[M + d, :M] = 1

# Targets followed by d + 1 zeros for the side conditions.
rhs = np.concatenate((args['fp'].reshape(-1), np.zeros(d + 1)))
lambda_u_v = np.linalg.solve(coeff, rhs)

u_first = lambda_u_v[M:M + d]  # coefficients of the linear term
v_first = lambda_u_v[M + d]    # constant term

fx = x.copy()
fy = y.copy()
fX, fY = np.meshgrid(fx, fy)

# The evaluation grid built above is two-dimensional.
assert d == 2, "this cell evaluates the solution on a 2-D grid; args['xp'] must have 2 columns"

print("compute f")
# Evaluate all grid nodes at once, in float64. Entry [row, col] is the value at
# (fx[col], fy[row]), matching the default 'xy' indexing of np.meshgrid.
grid_points = np.column_stack((fX.ravel(), fY.ravel())).astype(np.float64)
f_variation_exact = (
    _cubic_kernel_features(grid_points, args['xp'].astype(np.float64)) @ lambda_u_v
).reshape(fX.shape)

compute f


In [32]:
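# # This block computes the solution of the optimization problem for Gaussian initialization.
# # If you use Gaussian initialization, uncomment the whole block and
# # follow the instructions in the comments below.
# # NOTE(review): before uncommenting, (1) import `LinearConstraint`
# # (e.g. `from scipy.optimize import LinearConstraint, minimize`); it is used but never imported;
# # (2) obj_der*/obj_hess_p* return 2*x and 2*p, while obj* weights every term by `weight`,
# # so the supplied gradient/Hessian do not match the objective unless `weight` is all ones
# # (consistent versions multiply by `weight.reshape(-1)`).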

# # solution of the optimization problem with linear correction ((eq.24) in paper)

# def continuous_version_with_linear(args):
#     linear_constraint_weight = np.zeros((args["xp"].shape[0], Angle.shape[0]+3))
#     for idx, train_sample in enumerate(args["xp"]):
#         linear_constraint_weight[idx,0:-3] = np.maximum(np.cos(Angle)*train_sample[0]+np.sin(Angle)*train_sample[1]+Bias,0)*weight.reshape(-1)
#         linear_constraint_weight[idx,-3] = train_sample[0]
#         linear_constraint_weight[idx,-2] = train_sample[1]
#         linear_constraint_weight[idx,-1] = 1

#     linear_constraint_with_linear = LinearConstraint(linear_constraint_weight, args["fp"].reshape(-1), args["fp"].reshape(-1))

#     from scipy.optimize import minimize
#     def obj_with_linear(x):
#         return np.sum(x[0:-3]**2*weight.reshape(-1))

#     def obj_der_with_linear(x):
#         return np.concatenate((2*x[0:-3],np.array([0,0,0])))


#     def obj_hess_p_with_linear(x, p):
#         return np.concatenate((2*p[0:-3],np.array([0,0,0])))

#     x0 = np.zeros(angle.shape[0]*bias.shape[0]+3)


#     return minimize(obj_with_linear, x0, method='trust-constr', jac=obj_der_with_linear, hessp=obj_hess_p_with_linear,
#                 constraints=[linear_constraint_with_linear],
#                 options={'verbose': 1, 'xtol':1e-20, 'gtol':1e-10})


# # solution of the optimization problem without linear correction ((eq.19) in paper)

# def continuous_version(args):

#     linear_constraint_weight = np.zeros((args["xp"].shape[0], Angle.shape[0]))
#     for idx, train_sample in enumerate(args["xp"]):
#         linear_constraint_weight[idx] = np.maximum(np.cos(Angle)*train_sample[0]+np.sin(Angle)*train_sample[1]+Bias,0)*weight.reshape(-1)


#     linear_constraint = LinearConstraint(linear_constraint_weight, args["fp"].reshape(-1), args["fp"].reshape(-1))

#     from scipy.optimize import minimize
#     def obj(x):
#         return np.sum(x**2*weight.reshape(-1))

#     def obj_der(x):
#         return 2*x


#     def obj_hess_p(x, p):
#         return 2*p

#     x0 = np.zeros(angle.shape[0]*bias.shape[0])


#     return minimize(obj, x0, method='trust-constr', jac=obj_der, hessp=obj_hess_p,
#                     constraints=[linear_constraint],
#                     options={'verbose': 1})

# num_angle = 601
# angle = np.linspace(0, 2*math.pi, num_angle, dtype=np.float32)
# angle = angle[0:-1]
# num_bias = 601
# bias = np.linspace(-2, 2, num_bias, dtype=np.float32) 

# sigmab = 1 # standard deviation of initialization of bias, change it if you use different standard deviation
# sigmaw = 1 # standard deviation of initialization of weight, change it if you use different standard deviation

# weight = np.ones((angle.shape[0], bias.shape[0]))
# weight[0] = 1/2
# weight[-1] = 1/2
# weight[:,0] = 1/2
# weight[:,-1] = 1/2
# weight = weight/(sigmab*sigmab+bias*bias*sigmaw*sigmaw)**2.5
# weight = weight.T

# Angle, Bias = np.meshgrid(angle, bias)
# Angle = Angle.reshape(-1)
# Bias = Bias.reshape(-1)

# # change to the code in the comment if you compute the solution of optimization problem with linear correction
# # res_with_linear = continuous_version_with_linear(args)
# res_non_linear = continuous_version(args)

# fx = x.copy()
# fy = y.copy()

# # change to the code in the comment if you compute the solution of optimization problem with linear correction
# # f_variation_linear_adjust = np.zeros((fx.shape[0], fy.shape[0]))
# f_variation_no_linear_adjust = np.zeros((fx.shape[0], fy.shape[0]))

# fX, fY = np.meshgrid(fx, fy)
# xi = np.zeros((2))

# print("compute f")

# # change to the code in the comment if you compute the solution of optimization problem with linear correction
# # for idx, ix in enumerate(fx):
# #     for idy, iy in enumerate(fy):
# #         xi[0] = ix
# #         xi[1] = iy
# #         first_layer = np.maximum(np.cos(Angle)*ix+np.sin(Angle)*iy+Bias,0)*weight.reshape(-1)
# #         f_variation_linear_adjust[idy,idx]=np.sum(first_layer*res_with_linear.x[0:-3])+\
# #             res_with_linear.x[-3]*ix+res_with_linear.x[-2]*iy+res_with_linear.x[-1]

# for idx, ix in enumerate(fx):
#     for idy, iy in enumerate(fy):
#         xi[0] = ix
#         xi[1] = iy
#         first_layer = np.maximum(np.cos(Angle)*ix+np.sin(Angle)*iy+Bias,0)*weight.reshape(-1)
#         f_variation_no_linear_adjust[idy,idx]=np.sum(first_layer*res_non_linear.x)

In [None]:
# train the network
%matplotlib inline
import sys
import shutil
import time
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import argparse
import matplotlib.pyplot as plt
from networks import *
from normal_train import *
import numpy as np
import copy
import math

def sum_of_all_parameter(model):
    """Return the sum of the norms of the weights of all ``Linear`` layers in ``model``.

    For a ``TwoLayerReluASI`` the layers of ``model.features1`` and then
    ``model.features2`` are visited; for any other model, those of
    ``model.features``. Each weight is squeezed and passed to
    ``np.linalg.norm`` (Frobenius norm for a matrix, Euclidean norm for a
    vector). Biases are not included. Returns ``0`` if no ``Linear`` layer is found.
    """
    if isinstance(model, TwoLayerReluASI):
        layer_groups = (model.features1, model.features2)
    else:
        layer_groups = (model.features,)

    total = 0
    for group in layer_groups:
        for layer in group:
            # Compare by class name: only layers whose class is literally named "Linear" count.
            if layer.__class__.__name__ != "Linear":
                continue
            total += np.linalg.norm(layer.weight.detach().numpy().squeeze())
    return total

def two_relu_layer_train(args, num_neurons=100):
    """Train a ``TwoLayerReluASI`` network on the toy data in ``args`` with full-batch SGD.

    The learning rate starts at 0.1. After every epoch the training loss is
    re-evaluated; if it increased, the step is undone (the previous parameters
    are restored) and the learning rate is halved. Training stops after
    ``args['epochs']`` epochs, or when an accepted step changes the loss by
    less than 1e-8.

    Args:
        args: settings dict; reads 'epochs', 'tag', 'input_dim', 'weight',
            'bias', 'xp', 'fp', 'train_only_output' and 'path'.
        num_neurons: width passed to ``TwoLayerReluASI``; also appended to
            the file-name tag.

    Returns:
        ``(initial_weight_norm, final_weight_norm, 0, final_train_loss, model)``.
        The norms come from ``sum_of_all_parameter``. The third entry is a
        placeholder kept for callers that unpack five values.
        ``final_train_loss`` is the training loss of the returned model
        (never the loss of a rejected step).

    Side effects:
        Saves the state dict of the best accepted model to
        ``args['path'] + tag + '_best' + 'normal [-1,1]'`` and the final
        state dict to ``args['path'] + tag`` (plain string concatenation).
        Prints one progress line per epoch.
    """
    epochs = args['epochs']
    tag = args['tag'] + str(num_neurons)

    model = TwoLayerReluASI(input_dim=args['input_dim'], num_neurons=num_neurons,
                            initialization=args['weight'], bias_tune_tuple=args['bias'])

    start_epoch = 0  # start from epoch 0
    use_cuda = False
    learning_rate = 0.1
    # Every `decay_epoch`-th epoch whose step is accepted also halves the LR.
    decay_epoch = 100000
    # The dataset is a handful of in-memory samples; worker processes would only
    # add start-up cost on every one of the many passes over the loader.
    workers = 0

    train_input = args['xp']
    train_output = args['fp']
    trainset = [(train_input[i:i + 1], train_output[i:i + 1]) for i in range(train_input.shape[0])]
    # One batch holding every sample: full-batch gradient descent.
    trainloader = data.DataLoader(trainset, batch_size=train_input.shape[0], shuffle=False,
                                  num_workers=workers)

    criterion = nn.MSELoss()
    if args['train_only_output']:
        # Freeze features1[0] and features2[0]; the remaining parameters are trained.
        for frozen_layer in (model.features1[0], model.features2[0]):
            frozen_layer.weight.requires_grad = False
            frozen_layer.bias.requires_grad = False

    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    inital_norm_weight = sum_of_all_parameter(model)

    # Loss of the parameters currently held by `model`.
    current_loss = test(trainloader, model, criterion, 0, use_cuda)
    # Reference for accepting a step; the +0.01 tolerates a small increase on the first step.
    old_loss = current_loss + 0.01
    best_loss = float('inf')

    for epoch in range(start_epoch, epochs):
        old_state_dict = copy.deepcopy(model.state_dict())
        train(trainloader, model, criterion, optimizer, epoch, use_cuda)
        # Re-evaluate after the update so the loss reflects the new parameters.
        train_loss = test(trainloader, model, criterion, epoch, use_cuda)
        print('Epoch: [%d | %d] LR: %f; Train Loss %f'
              % (epoch + 1, epochs, optimizer.param_groups[0]['lr'], train_loss))
        sys.stdout.flush()

        if train_loss > old_loss:
            # Reject the step: restore the previous parameters and halve the LR.
            # The best-model checkpoint is deliberately not touched here.
            model.load_state_dict(old_state_dict)
            learning_rate = learning_rate * 0.5
            optimizer = optim.SGD(model.parameters(), lr=learning_rate)
            continue

        # Accepted step: the model now has loss `train_loss`.
        current_loss = train_loss
        if train_loss < best_loss:
            best_loss = train_loss
            torch.save(model.state_dict(), args['path'] + tag + '_best' + "normal [-1,1]")

        if epoch % decay_epoch == (decay_epoch - 1):
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5
        if abs(old_loss - train_loss) < 1e-8:
            break
        old_loss = train_loss

    torch.save(model.state_dict(), args['path'] + tag)

    end_norm_weight = sum_of_all_parameter(model)

    return (inital_norm_weight, end_norm_weight, 0, current_loss, model)

class AverageMeter(object):
    """Running mean and (population) variance of a stream of weighted values.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    and extended with a running sum of squares.

    Attributes:
        val: most recent value passed to ``update``.
        sum, count, avg: weighted running sum, total weight, and mean.
        square_sum: weighted running sum of squares.
        var: population variance ``E[v^2] - E[v]^2``, clamped at 0.
        std: same value as ``var``. Despite its name it holds the VARIANCE,
            not the standard deviation; the name is kept for existing callers,
            which apply ``np.sqrt`` themselves.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.square_sum = 0
        # Defined up front so reading them before the first update does not raise AttributeError.
        self.var = 0
        self.std = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh ``avg``/``var``/``std``."""
        self.val = val
        self.sum += val * n
        self.square_sum += val * val * n
        self.count += n
        self.avg = self.sum / self.count
        # E[v^2] - E[v]^2 can come out slightly negative through floating-point
        # cancellation when the values are (nearly) equal; clamp it so a later
        # square root does not produce NaN.
        self.var = max(self.square_sum / self.count - self.avg * self.avg, 0.0)
        self.std = self.var
        
        
# For every width in args['nums_of_neurons'], train `number_of_try` networks and
# record the mean and variance (AverageMeter.avg / .std) of the initial and final
# weight norms, the squared L2 distance to the variational solution on the grid,
# and the final training loss.
# args['nums_of_neurons']=[10,20,40,80,160,320,640,1280,2560]

number_of_try = 3
# Base seed: try k of every width uses SEED + k, so reruns reproduce the results.
SEED = 0

# The evaluation grid is the same for every trained network, so build it once.
test_input = np.concatenate((X, Y), axis=2).reshape(-1, 2)
# Shape (N, 1, 2): the layout a DataLoader produces when it stacks the (1, 2) rows.
test_tensor = torch.from_numpy(test_input.reshape(-1, 1, 2))

inital_norm_weight_vs_num = []
inital_norm_weight_std_vs_num = []
end_norm_weight_vs_num = []
end_norm_weight_std_vs_num = []
away_from_variational_vs_num = []
away_from_variational_std_vs_num = []
train_loss_vs_num = []
train_loss_std_vs_num = []

for num_of_neurons in args['nums_of_neurons']:
    inital_norm_weights = AverageMeter()
    end_norm_weights = AverageMeter()
    away_from_variationals = AverageMeter()
    train_losses = AverageMeter()
    for init in range(number_of_try):
        seed = SEED + init
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        inital_norm_weight, end_norm_weight, _unused, train_loss, model = two_relu_layer_train(args, num_of_neurons)

        # Evaluate the trained network on the grid; no autograd graph is needed.
        with torch.no_grad():
            outputs = model(test_tensor)
        test_output = outputs.squeeze().detach().numpy().reshape(numgrid, numgrid)

        # Squared L2 distance to the variational solution, approximated by a
        # Riemann sum with cell area h*h.
        away_from_variational = np.sum((test_output - f_variation_exact) ** 2) * h * h

        inital_norm_weights.update(inital_norm_weight)
        end_norm_weights.update(end_norm_weight)
        away_from_variationals.update(away_from_variational)
        train_losses.update(train_loss)

    inital_norm_weight_vs_num.append(inital_norm_weights.avg)
    end_norm_weight_vs_num.append(end_norm_weights.avg)
    away_from_variational_vs_num.append(away_from_variationals.avg)

    inital_norm_weight_std_vs_num.append(inital_norm_weights.std)
    end_norm_weight_std_vs_num.append(end_norm_weights.std)
    away_from_variational_std_vs_num.append(away_from_variationals.std)

    train_loss_vs_num.append(train_losses.avg)
    train_loss_std_vs_num.append(train_losses.std)

print(inital_norm_weight_vs_num)
print(end_norm_weight_vs_num)
print(away_from_variational_vs_num)
print(train_loss_vs_num)

print(inital_norm_weight_std_vs_num)
print(end_norm_weight_std_vs_num)
print(away_from_variational_std_vs_num)
print(train_loss_std_vs_num)


In [None]:
#plotting 3D surfaces of output of the network output and the solution of the variational problem
fx = x.copy()
fy = y.copy()
fX, fY = np.meshgrid(fx, fy)
%matplotlib notebook
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# X, Y, Z = axes3d.get_test_data(0.05)
# ax.plot_wireframe(fX[120:-120,120:-120], fY[120:-120,120:-120], f[120:-120,120:-120])
ax.plot_wireframe(fX, fY, f_variation_exact)
ax.plot_wireframe(fX, fY, test_output, color="red")
ax.scatter3D(args['xp'][:,0].squeeze(), args['xp'][:,1].squeeze(), args['fp'].squeeze(), c="red");
plt.show()

In [None]:
# Contour plot of the exact variational solution, with the training inputs in red.
contour_levels = np.linspace(-0.4, 1.4, 10, dtype=np.float32)
train_points = args['xp']
fig, ax = plt.subplots()
CS2 = ax.contour(fX, fY, f_variation_exact, levels=contour_levels)
ax.clabel(CS2, inline=True, fontsize=15)
ax.scatter(train_points[:, 0], train_points[:, 1], c="red");
ax.set_title("Exact Solution", fontsize=22)

In [None]:
# Fit  log(sqrt(err)) = a * log(n) + b  by least squares, i.e. the power law
# sqrt(err) ~ exp(b) * n^a, where err is the mean squared distance between the
# network output and the variational solution for width n.
# NOTE: with a single width (e.g. nums_of_neurons=[10]) the fit is degenerate
# and the R^2 score is undefined.
from sklearn.linear_model import LinearRegression

# NOTE(review): `x` and `y` are read by the next cell, but they shadow the grid
# coordinates of the variational-solution cell; re-run that cell before
# re-plotting anything that uses the grid.
x = args['nums_of_neurons']
y = away_from_variational_vs_num
error = away_from_variational_std_vs_num  # per-width variance of y (AverageMeter.std)
start = 0  # index of the first width included in the fit

# Dedicated name for the design matrix, so the evaluation grid `X` used by the
# training cell is not overwritten.
log_width_design = np.log(np.asarray(x[start:], dtype=np.float64)).reshape(-1, 1)
y_s = np.log(np.sqrt(y[start:]))
reg = LinearRegression().fit(log_width_design, y_s)
reg.score(log_width_design, y_s)

In [None]:
#plot the difference between the network output and the solution of the variational problem against the number of neurons
%matplotlib inline
import matplotlib.pyplot as plt
fig, ax0= plt.subplots()
# ax0.errorbar(x[1:], np.sqrt(y[1:]), yerr=np.sqrt(error[1:])/np.sqrt(y[1:])/2, fmt='-o')
# ax0.loglog(x[start:], np.sqrt(y[start:]), '-o')
plt.xlabel("Number of neurons")
plt.ylabel("Error")
n_line = np.linspace(10,5120,1000).shape[0]
X_line = np.zeros((n_line,1))
X_line[:,0] = np.linspace(10,5120,1000)
ax0.set_xscale("log", nonposx='clip')
ax0.set_yscale("log", nonposy='clip')
# ax4.set(title='Errorbars go negative')
ax0.errorbar(x[start:], np.sqrt(y[start:]), yerr=(np.sqrt(error)/np.sqrt(y)/2)[start:], fmt='-o')
ax0.loglog(X_line[:,0], np.exp(reg.predict(np.log(X_line))))
ax0.legend([f"$y={np.exp(reg.intercept_):.4f}x^{{{reg.coef_[0]:.4f}}}$",'Error'],fontsize='large')

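## The following cells load a trained model from disk, so they can be re-run without retraining once the training cell above has saved the model. They still need the import cell, the settings cell (for `args`) and the variational-solution cell (for `f_variation_exact`) to have been run.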

In [None]:
#plotting 3D surfaces of output of the network output and the solution of the variational problem
from scipy.interpolate import CubicSpline

args['xp'] = np.array([[-1,-1], [1,1], [0,0], [-1,1], [1,-1]],dtype=np.float32)
args['fp'] = np.array([[1.5], [1.5], [0.5], [-0.5], [-0.5]],dtype=np.float32)
args['input_dim'] = 2
args['tag'] = "2d_toy"

# number of neurons in hidden layers, you can change this value
num_neurons = 20
# num_neurons = 160

model = TwoLayerReluASI(input_dim=2, num_neurons=num_neurons)
args['path'] = '/Users/jinhui/Documents/GitHub/model'
tag = args['tag']+str(num_neurons)

model_state = torch.load(args['path']+tag+'_best'+"normal [-1,1]")
model.load_state_dict(model_state)

numgrid = 21
R = np.abs(args['xp']).max()
x = np.linspace(-R, R, numgrid, dtype=np.float32)
h = x[1]-x[0]
y = np.linspace(-R, R, numgrid, dtype=np.float32)
X, Y = np.meshgrid(x, y)
X = np.expand_dims(X, axis=2)
Y = np.expand_dims(Y, axis=2)
test_input = np.concatenate((X, Y), axis=2)
test_input = test_input.reshape(-1,2)

testset = [(test_input[i:i+1]) for i in range(test_input.shape[0])]
testloader = data.DataLoader(testset, batch_size=numgrid*numgrid, shuffle=False, num_workers=1)
for batch_idx, inputs in enumerate(testloader):
    outputs = model(inputs)
test_output = outputs.squeeze().detach().numpy()
test_output = test_output.reshape(numgrid, numgrid)


%matplotlib inline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# X, Y, Z = axes3d.get_test_data(0.05)
# ax.plot_wireframe(fX[120:-120,120:-120], fY[120:-120,120:-120], f[120:-120,120:-120])
ax.plot_wireframe(fX, fY, f_variation_exact)
ax.plot_wireframe(fX, fY, test_output, color="red")
ax.scatter3D(args['xp'][:,0].squeeze(), args['xp'][:,1].squeeze(), args['fp'].squeeze(), c="red");
plt.show()

In [None]:
# Contour plot of the loaded network's output, using the same levels as the
# exact-solution contour plot, with the training inputs in red.
contour_levels = np.linspace(-0.4, 1.4, 10, dtype=np.float32)
train_points = args['xp']
fig, ax = plt.subplots()
CS2 = ax.contour(fX, fY, test_output, levels=contour_levels)
ax.clabel(CS2, inline=True, fontsize=15)
ax.scatter(train_points[:, 0], train_points[:, 1], c="red");
ax.set_title(f"n={num_neurons}", fontsize=22)