In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from copy import deepcopy
from identification_shallownn.SNN import SNN, generate_random_SNN
from identification_shallownn.identifying_SNN import identify_weights, gradient_descent_A
from identification_shallownn.matrix_manip import normalize_col
from identification_shallownn.modules.experiment import Experiment 
import pickle
import pathlib
import time
import pandas as pd

In [2]:
def run_once(d,m,n_rep,eps_A,number_evaluation_samples, number_hessians, repetition):
    metrics = pd.Series(dtype='float')
    
    #Create a random network which remains over repetition
    net = generate_random_SNN(d=d,m=m,eps_A=eps_A, seed=repetition)

    #Creating a test set for evaluation
    np.random.seed(repetition)
    X_test = np.random.normal(size=(number_evaluation_samples, net.d))
    X_test = normalize_col(X_test.T).T
    rad = np.random.uniform(size=number_evaluation_samples)**(1/d)
    X_test = (X_test.T*rad).T
    Y_test = net.eval(X_test)
    
    X = np.random.normal(size=(number_hessians, net.m))
    X = normalize_col(X.T).T
    A_FSV = identify_weights(net, X, n_rep)
    metrics["error_A_FVD"] = np.linalg.norm(net.A - A_FSV)
    #Creating new network with Anew as weights
    net_approximation_FVD = SNN(A=A_FSV, b=net.b, s1=net.s1,s2=net.s2, act=net.g, dact=net.dg)
    metrics["MSE_net_FVD"] = np.mean((net.eval(X_test) - net_approximation_FVD.eval(X_test))**2)
    metrics["Linf_net_FVD"] = np.max(np.abs(net.eval(X_test) - net_approximation_FVD.eval(X_test)))

    #Samples used for gradient descent are number_hessians*the number of samples used for finite differences
    N_samples_gd = int(number_hessians * (d*(d+1)/2 + 1))
    X_train = np.random.normal(size=(N_samples_gd, m))
    X_train = normalize_col(X.T).T
    Y_train = net.eval(X_train)
    
    net_approximation_gd = SNN(A=np.zeros(shape=(d,m)), b=net.b, s1=net.s1,s2=net.s2, act=net.g, dact=net.dg)
    
    losses = gradient_descent_A(X_train, Y_train, net_approximation_gd, 1000, 0.1)
    metrics['losses_gd'] = losses
    metrics["error_A_gd"] = np.linalg.norm(net.A-net_approximation_gd.A)

    metrics["MSE_net_gd"] = np.mean((net.eval(X_test) - net_approximation_gd.eval(X_test))**2)
    metrics["Linf_net_gd"] = np.max(np.abs(net.eval(X_test) - net_approximation_gd.eval(X_test)))
    
    return metrics 

In [3]:
handle = 'comparison_to_gd'

host_config = {
    'output_dir': '/media/data/sherlock/results'
}

fixed_params = {
    'd':20,
    'm':20,
    'n_rep':180,
    'eps_A':1,
    'number_evaluation_samples':10**5
}

varying_params = {
    'number_hessians': np.arange(start=1, stop=32,dtype="int"),
    'repetition':np.arange(start=0, stop=10, dtype="int")
}

experiment = Experiment(
    run_once = run_once,
    fixed_params = fixed_params,
    varying_params = varying_params,
    host_config = host_config,
    handle = handle,
    use_pickle = True)

In [4]:
results = experiment()

In [5]:
results

Unnamed: 0,d,m,n_rep,eps_A,number_evaluation_samples,number_hessians,repetition,error_A_FVD,MSE_net_FVD,Linf_net_FVD,losses_gd,error_A_gd,MSE_net_gd,Linf_net_gd
0,20,20,200,1,100000,1,0,4.497664,4.042486e-02,0.741095,"[0.06248046911288859, 0.04149466503688739, 0.0...",4.462948,0.037661,0.721143
1,20,20,200,1,100000,1,1,4.480960,4.018760e-02,0.803699,"[0.0035582588795612623, 0.002363187118807135, ...",4.471796,0.040450,0.801174
2,20,20,200,1,100000,1,2,4.456673,4.303922e-02,0.739573,"[0.0008208665609346762, 0.0005517116142599006,...",4.472195,0.043083,0.742945
3,20,20,200,1,100000,1,3,4.462657,3.981723e-02,0.711758,"[0.01404902534968421, 0.009296597228174799, 0....",4.470086,0.039359,0.723995
4,20,20,200,1,100000,1,4,4.426523,3.534983e-02,0.753536,"[0.020695492228828483, 0.01353154932302645, 0....",4.469635,0.035042,0.747953
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
305,20,20,200,1,100000,31,5,0.001089,2.382692e-09,0.000184,"[0.048761442988098765, 0.04742283523617502, 0....",4.355248,0.000752,0.103085
306,20,20,200,1,100000,31,6,0.001351,5.163331e-09,0.000264,"[0.04456171333973311, 0.04328511237115131, 0.0...",4.368060,0.000451,0.089380
307,20,20,200,1,100000,31,7,0.001193,1.145970e-09,0.000123,"[0.05706278143721869, 0.05506381016984862, 0.0...",4.348644,0.000516,0.086040
308,20,20,200,1,100000,31,8,0.001215,8.856076e-10,0.000111,"[0.03341564238700435, 0.03256599417589382, 0.0...",4.364455,0.000560,0.109501
