In [1]:
# python script to run the pre-processed learning curves for a single step of GD
import numpy as np
from scipy.special import erf   
import os 
import time 

##### Useful functions

In this notebook we investigate the generalization performance after one step for two-neuron teachers ($k=2$). 

Let us define some useful functions that we will need:

In [2]:
# few lines to define relevant functions 
def perpendicular_vector(v):
    d = len(v)
    w = np.zeros(d)
    w[0] = np.random.randn()  
    w -= np.dot(w, v) * v / np.dot(v, v) 
    w /= np.linalg.norm(w)
    return w
def sample_data(n,ntest,d):
    # function to sample data 
    Z = np.random.randn(n,d) ; Ztest = np.random.randn(ntest,d)
    Y =  (f01(Z@theta1) + f02(Z@theta2)) / 2  + np.sqrt(noise)*np.random.randn(n)
    Ytest =  (f01(Ztest@theta1) + f02(Ztest@theta2)) / 2 + np.sqrt(noise)*np.random.randn(ntest)
    return Z,Ztest,Y,Ytest
def ridge_estimator(X, y, lamb=0.1):
    # Implements the pseudo-inverse ridge estimator.
    m, n = X.shape
    if m >= n:
        return np.linalg.inv(X.T @ X + lamb*np.identity(n)) @ X.T @ y
    elif m < n:
        return X.T @ np.linalg.inv(X @ X.T + lamb*np.identity(m)) @ y
def get_errors_ridge(Xtrain,Xtest,Ytrain,Ytest,lamb,flag = True):
    ' get errors for ridge regression with fixed data matrices'
    ' normalize the data by dividing by sqrt(p) in the ridge estimator while label are of O(1) already'
    n,p = Xtrain.shape
    eg, et = [], [] 
    # Iterate over different realisations of the problem.
    w = ridge_estimator(Xtrain / np.sqrt(p), Ytrain, lamb)
    yhat_train = Xtrain @ w / np.sqrt(p)
    yhat_test = Xtest @ w   / np.sqrt(p)
    # Train loss
    train_loss = np.mean((Ytrain - yhat_train)**2)
    # Fresh samples
    test_error = np.mean((Ytest - yhat_test)**2) 
    eg.append(test_error)   ;  et.append(train_loss)
    if flag:
        print(f' we have train loss {train_loss} and test error {test_error}')
    # Return average and standard deviation of both errors
    return (np.mean(et), np.mean(eg) , np.std(et), 
             np.std(eg),w)    

##### Parameter setup 

- We fix the relevant dimensions $(p,d)$
- We sample the orthonormal teacher vectors $\vec{w}^*_1,\vec{w}^*_2$. 
- We fix the first and second layer at initialization $W_0,\vec{a}_0$.
- We define the different activations we may consider in Teacher-Student setup.
- We build the array of sample sizes $\vec{n_s}$ from which we compute the generalization error curve.

In [3]:
# Dictionary of possible student activation functions and its derivatives 
stud_acts = { 'relu': lambda x: np.maximum(x,0),  'hermite1+2+4_norm': lambda x: (x**4 - 6*x**2 + 3) /24 + (x**2 -1) /2 + x}
stud_ders = { 'relu': lambda x: (x>0).astype(int), 'hermite1+2+4_norm': lambda x: x**3 - 3*x}
# Dictionary of possible teacher activation functions
teach_acts = {'erf': lambda x: erf(x),'hermite1+2+4_norm': lambda x: (x**4 - 6*x**2 + 3) /24 + (x**2 -1) /2 + x}
stud_act = 'relu' 
f = stud_acts[stud_act]
fprime = stud_ders[stud_act] 
fnn0 = lambda D , W , a : 1/np.sqrt(p)*f(D@W.T)@a
teach_act1 = 'hermite1+2+4_norm'   ; f01 = teach_acts[teach_act1]
teach_act2 = 'hermite1+2+4_norm'    ; f02 = teach_acts[teach_act2]
number_ns = 3 
d = 512 
ntest = int(1e5) ; lamb = 1 ; noise = 0 
exp_min_n = 1.3 ; exp_max_n = 2
p = 1024
# choose sample size array 
ns = np.logspace(exp_min_n,exp_max_n,number_ns,base=d,dtype=int) # Take the range for th eplot of log-normalized sample complexity
# Initialize the weights
W0 = 1/np.sqrt(d)*np.random.randn(p,d)
# sample the orthonormal teacher vectors -  choose k=2 for simplicity
v1 = np.random.randn(d) 
theta1 = v1/np.linalg.norm(v1)
v2 = perpendicular_vector(v1)
theta2 = v2/np.linalg.norm(v2)
# sample second layer and fix it 
a0 = 1/np.sqrt(p)*np.random.randn(p) 

##### Learning curve construction 

Here we implement the GD training protocol with a preprocessing step. The preprocessing step is crucial in order to learn in one giant step of GD. Indeed, as Theorem 1 provably states, it is not possible to get fully specialized hidden student units with one giant step of GD in the $n = \mathcal{O}(d^l)$ regime, if directions associated to teacher Hermite coefficients lower than $l$ are not suppressed, or equivalently, if the leap index of the problem is lower than $l$. In the following we consider sample sizes up to $n = \mathcal{O}(d^2)$, and a teacher function with leap index $l=1$. Therefore, we need to remove an estimate of the first teacher hermite coefficient, in order to obtain a problem with effective leap index $l=2$.  

- We iterate over the value in $\vec{n_s}$ 
- For each sample we set the learning rate adaptively to have $\eta = \mathcal{O}(p\sqrt{\frac{n}{d}})$
- We average over 10 random draws 
- We use standard deviation to give confidence interval

In [4]:
# iterate over array of sample sizes
for j,n in enumerate(ns):
    start = time.time()
    # set the (giant) learning rate adaptively with the current sample size
    eta = 10*np.sqrt(n)*np.sqrt(p)
    print(f'START  --- for regime exponent {np.log(n)/np.log(d)} d ={d} and p={p}  ')
    # GD on the RF weights 
    nseeds = 10 ;  errgs = []
    for seed in range(nseeds):
        Z,Ztest,Y,Ytest = sample_data(n, ntest, d)
        # preprocess data removing the estimation of the first hermite coefficient
        A = np.mean(Y) ; A_test = np.mean(Ytest)
        B = np.mean(Y.reshape(-1,1)*Z,axis=0) ; B_test = np.mean(Ytest.reshape(-1,1)*Ztest,axis=0)
        Y_touse = Y - A - Z@B ; Ytest_touse = Ytest - A_test - Ztest@B_test
        # compute gradient 
        G = 1/n * Z.T @ (1/np.sqrt(p)*np.outer( ( Y_touse - fnn0(Z,W0,a0) ) , a0) * fprime(Z@W0.T))
        Wgd = W0 + eta*G.T
        # generate features 
        X = f(Z@Wgd.T) ; Xtest = f(Ztest@Wgd.T)
        # compute ridge estimator 
        e1,e2,s1,s2,w = get_errors_ridge(X,Xtest,Y_touse,Ytest_touse,lamb,flag=False)
        # compute predictions injecting back the preprocessed part
        Yhat = A + Z@B + 1/np.sqrt(p)*X@w
        Yhat_test = A_test + Ztest@B_test + 1/np.sqrt(p)*Xtest@w
        errgs.append(np.mean((Yhat_test - Ytest)**2))
    test_error = np.mean(errgs) ; test_error_std = np.std(errgs)
    end = time.time()
    print(f'FINISH --- for regime exponent & n,d = {np.log(n)/np.log(d)} & {n,d} the gen error is {test_error}, the std is {test_error_std} and it took t={end - start}') 

START  --- for regime exponent 1.2999524948096965 d =512 and p=1024  
FINISH --- for regime exponent & n,d = 1.2999524948096965 & (3326, 512) the gen error is 0.3360002036384977, the std is 0.0063229646780440935 and it took t=41.65862417221069
START  --- for regime exponent 1.6499990492421464 d =512 and p=1024  
FINISH --- for regime exponent & n,d = 1.6499990492421464 & (29532, 512) the gen error is 0.2067522390105553, the std is 0.005893831748022804 and it took t=77.16522908210754
START  --- for regime exponent 2.0 d =512 and p=1024  
FINISH --- for regime exponent & n,d = 2.0 & (262144, 512) the gen error is 0.07741778601141551, the std is 0.002084362460098963 and it took t=445.02054595947266
