In [1]:
import numpy as np
import pandas as pd
from copy import deepcopy

from identification_shallownn.SNN import SNN, generate_random_SNN
from identification_shallownn.identifying_SNN import a_cand_SNN, gradient_descent_A
from identification_shallownn.matrix_manip import normalize_col
from identification_shallownn.modules.experiment import Experiment 

In [2]:
def run_once(d, m, n_rep, eps_A, number_hessians, repetition, n_steps_apgd, delta):
    """Run one weight-recovery trial on a randomly generated shallow network.

    A random SNN is generated from ``d``, ``m`` and ``eps_A`` (seeded by
    ``repetition``). ``number_hessians`` random points are drawn and normalized,
    and ``a_cand_SNN`` produces candidate weight vectors from them. A column
    a_i of ``net.A`` counts as found when some candidate v satisfies
    |<v, a_i>| >= 1 - delta**2 / 2. For unit-length vectors this is equivalent
    to min(||v - a_i||, ||v + a_i||) <= delta, i.e. recovery up to sign.

    Parameters
    ----------
    d, m : int
        Size parameters forwarded to ``generate_random_SNN``. ``m`` is also the
        number of weights that must be found for ``found_all_weights``.
        The sample points are drawn with dimension ``d``.
    n_rep : int
        Forwarded to ``a_cand_SNN``.
    eps_A : float
        Forwarded to ``generate_random_SNN``.
    number_hessians : int
        Number of random sample points passed to ``a_cand_SNN``.
    repetition : int
        Seed for both the network and the sample points.
    n_steps_apgd : int
        Forwarded to ``a_cand_SNN``.
    delta : float
        Recovery tolerance (see above).

    Returns
    -------
    pd.Series
        ``number_of_found_weights``: how many columns of ``net.A`` were matched.
        ``found_all_weights``: whether that count equals ``m``.
    """
    metrics = pd.Series(dtype='float')

    # Random network; seeding with `repetition` makes it identical across all
    # other varying parameters that share the same repetition index.
    net = generate_random_SNN(d=d, m=m, eps_A=eps_A, seed=repetition)

    # Global NumPy seed so the sample points (and any later global-RNG draws)
    # are reproducible per repetition.
    np.random.seed(repetition)

    # Random points at which the network is probed. They live in the input
    # space, so their dimension is `d`. The previous `net.m` only coincided
    # with `d` because the sweep uses d == m.
    X = np.random.normal(size=(number_hessians, d))
    X = normalize_col(X.T).T  # normalize each point (row) to unit length

    # Candidate estimates of the weight vectors, one per row.
    V = np.array(a_cand_SNN(net, X, n_rep=n_rep, n_steps_apgd=n_steps_apgd))

    if V.size == 0:
        # No candidates: nothing can be matched. This also avoids a matmul
        # shape error on the empty (0,)-shaped array.
        number_of_found_weights = 0
    else:
        # Max |inner product| over candidates, per column of net.A. For unit
        # vectors, |<v, a>| >= 1 - delta^2/2  <=>  distance (up to sign) <= delta.
        best_overlap = np.max(np.abs(V @ net.A), axis=0)
        found_weights_up_to_delta = best_overlap >= 1 - delta**2 / 2
        number_of_found_weights = np.sum(found_weights_up_to_delta)

    metrics["number_of_found_weights"] = number_of_found_weights
    metrics["found_all_weights"] = (number_of_found_weights == m)

    return metrics

In [3]:
# Identifier under which Experiment stores this sweep.
handle = 'recovery_heatmap'

# Output location handed to Experiment.
host_config = dict(output_dir='./data')

# Parameters held constant across every run_once call.
fixed_params = dict(
    d=20,              # forwarded to generate_random_SNN
    m=20,              # forwarded to generate_random_SNN; target count of found weights
    n_rep=180,         # forwarded to a_cand_SNN
    delta=0.05,        # recovery tolerance used in run_once
    n_steps_apgd=100,  # forwarded to a_cand_SNN
)

# Parameter grid that is swept over. The results table has one row per
# combination of these values.
varying_params = dict(
    number_hessians=np.arange(2, 40, 2, dtype="int"),  # 2, 4, ..., 38
    repetition=np.arange(0, 20, dtype="int"),           # seeds 0..19
    eps_A=[0, 0.25, 0.5, 1, 2, 3],
)

experiment = Experiment(
    run_once=run_once,
    fixed_params=fixed_params,
    varying_params=varying_params,
    host_config=host_config,
    handle=handle,
    use_pickle=True,
)

In [4]:
# Run the parameter sweep configured above.
# NOTE(review): with use_pickle=True, Experiment presumably caches/reloads
# results under host_config['output_dir'] — confirm in Experiment.
results = experiment()

In [5]:
# Results table: one row per (number_hessians, repetition, eps_A) combination,
# with the fixed parameters and the run_once metrics as columns.
results

Unnamed: 0,d,m,n_rep,delta,n_steps_apgd,number_hessians,repetition,eps_A,number_of_found_weights,found_all_weights
0,20.0,20.0,180.0,0.05,100.0,2.0,0.0,0.00,0.0,0.0
1,20.0,20.0,180.0,0.05,100.0,2.0,0.0,0.25,0.0,0.0
2,20.0,20.0,180.0,0.05,100.0,2.0,0.0,0.50,0.0,0.0
3,20.0,20.0,180.0,0.05,100.0,2.0,0.0,1.00,0.0,0.0
4,20.0,20.0,180.0,0.05,100.0,2.0,0.0,2.00,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...
2275,20.0,20.0,180.0,0.05,100.0,38.0,19.0,0.25,20.0,1.0
2276,20.0,20.0,180.0,0.05,100.0,38.0,19.0,0.50,20.0,1.0
2277,20.0,20.0,180.0,0.05,100.0,38.0,19.0,1.00,20.0,1.0
2278,20.0,20.0,180.0,0.05,100.0,38.0,19.0,2.00,20.0,1.0
