In [1]:
# Environment check: report whether PyTorch can see CUDA and list the devices JAX sees.
import torch
import jax

print(torch.cuda.is_available())
print(jax.devices())

True
[GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]


In [2]:
import numpy as np
import jax.numpy as jnp
import jax.nn as relu
from jax.config import config
from jax import jit, grad, random
from jax import grad, jit, vmap, hessian, jvp, lax
jnp.set_printoptions(suppress=True)
key = random.PRNGKey(1)

# ---------------------------------------------------------------------------
# Model and data sizes
# ---------------------------------------------------------------------------
sigma_w = 1.5   # scale factor applied to the hidden-layer Gaussian draws
sigma_v = 1     # scale factor applied to the output-layer Gaussian draws
n_h = 100000    # hidden-layer width
d = 10          # number of input features
N = 250         # number of samples in each of the train and test sets

# ---------------------------------------------------------------------------
# Initial weights: W is (d, n_h) scaled by sigma_w / sqrt(n_h); v is (n_h, 1).
# NOTE(review): `key` is used for sampling here and then split on the next
# line. JAX's PRNG guidance is to split before consuming a key; changing it
# would alter every random draw below, so it is left untouched.
# ---------------------------------------------------------------------------
W = (sigma_w / jnp.sqrt(n_h)) * random.normal(key, shape=(d, n_h))
key, split = random.split(key)
v = sigma_v * random.normal(split, shape=(n_h, 1))

# Keep references to the initial weights; training rebinds W and v.
W_init, v_init = W, v

# ---------------------------------------------------------------------------
# Gaussian inputs (row normalisation is available but disabled)
# ---------------------------------------------------------------------------
key, split = random.split(key)
X_train = random.normal(split, shape=(N, d))
# X_train = ( X_train.T / jnp.linalg.norm(X_train, axis = 1) ).T
key, split = random.split(key)
X_test = random.normal(split, shape=(N, d))
# X_test = ( X_test.T / jnp.linalg.norm(X_test, axis = 1) ).T

# ---------------------------------------------------------------------------
# Targets: linear function of the inputs with coefficients 1/d, 2/d, ..., 1,
# plus Gaussian noise scaled by 0.125.
# ---------------------------------------------------------------------------
key, split = random.split(key)
gr_truth = jnp.arange(1, d + 1) / d
Y_train = jnp.dot(X_train, gr_truth) + 0.125 * random.normal(split, shape=(N,))
key, split = random.split(key)
Y_test = jnp.dot(X_test, gr_truth) + 0.125 * random.normal(split, shape=(N,))

data_train = (X_train, Y_train)
data_test = (X_test, Y_test)


In [3]:
def ReLU(x, l):
    """Leaky rectifier: element-wise maximum of `x` and `l * x`.

    With `l = 0` this is the plain ReLU and with `l = 1` it is the identity.
    """
    scaled = l * x
    return jnp.maximum(scaled, x)


# Compiled variant used inside the forward pass.
jit_ReLU = jit(ReLU)

def forward_dyn(params, features, l):
  """Two-layer network output for a batch of inputs.

  Args:
    params: pair `(v, W)` of output weights and hidden weights.
    features: batch of inputs; its leading axis is the batch axis.
    l: slope passed to the leaky activation.

  Returns:
    One prediction per row of `features`, reshaped to a 1-D array.
  """
  v, W = params
  hidden = jit_ReLU(jnp.dot(features, W), l)
  outputs = jnp.dot(hidden, v)
  return outputs.reshape(features.shape[0])



def lin_forward_dyn(params, features, l):
  """Split the network output into a "linearized" part and a remainder.

  The network is evaluated twice: once on `(v, W)` and once on the
  displacement `(v - v_init, W - W_init)` from the stored initial weights.

  Returns:
    `(full - displaced, displaced)`. The notebook treats the first value as
    the output of the linearized network and the second as the linearization
    error.
    NOTE(review): the displaced forward pass applies the activation to
    `features @ (W - W_init)`, not to the pre-activations at initialization;
    confirm this is the intended definition of the linearization error.
  """
  v, W = params
  displacement = (v - v_init, W - W_init)
  full_preds = forward_dyn((v, W), features, l)
  displaced_preds = forward_dyn(displacement, features, l)
  return full_preds - displaced_preds, displaced_preds



def loss(params, data, l):
  """Half the mean squared error of the network on `data`.

  Args:
    params: pair `(v, W)` forwarded to `forward_dyn`.
    data: pair `(X, Y)` of inputs and targets.
    l: slope passed to the leaky activation.
  """
  inputs, targets = data
  residual = targets - forward_dyn(params, inputs, l)
  return 0.5 * jnp.mean(jnp.square(residual))






def hvp(loss_fn, params, v):
  """Hessian-vector product of `loss_fn` at `params` in direction `v`.

  Computed as the forward-mode derivative of the gradient function, so the
  Hessian itself is never formed. `v` must have the same structure as
  `params`; the result does too.
  """
  gradient_fn = grad(loss_fn)
  _, tangent_out = jvp(gradient_fn, (params,), (v,))
  return tangent_out



In [4]:
def _lanczos_helper(w, z, z_old, vecs, beta):
    """Single Lanczos update applied to a fresh Hessian-vector product.

    Args:
        w: Hessian-vector product for the current basis vector `z`.
        z: current basis vector.
        z_old: previous basis vector (zeros on the first step).
        vecs: array whose rows are the basis vectors stored so far; rows not
            yet filled are zero and therefore do not affect the projection.
        beta: off-diagonal coefficient produced by the previous step.

    Returns:
        `(next_vector, alpha, new_beta)`. `next_vector` is `w / new_beta`, so
        it is not finite when the residual norm is exactly zero; the caller
        must check `new_beta` before storing it.
    """
    alpha = jnp.dot(w, z)

    # Three-term recurrence.
    w = w - alpha * z - beta * z_old

    # Full re-orthogonalisation against every stored basis vector.
    coeffs = jnp.dot(vecs, w)
    w = w - jnp.dot(coeffs, vecs)

    beta = jnp.linalg.norm(w)
    return w / beta, alpha, beta


def lanczos_alg(loss_fn, params, order, split, tol=1.0):
    """Lanczos tridiagonalisation of the Hessian of `loss_fn` at `params`.

    Args:
        loss_fn: scalar function of `params`; only Hessian-vector products
            (via `hvp`) are evaluated.
        params: pair `(v, W)`; `v` is reshaped as `(v.shape[0], 1)` and `W` as
            a 2-D array when unflattening basis vectors.
        order: maximum number of Lanczos iterations.
        split: PRNG key used to draw the (Laplace-distributed) start vector.
        tol: iteration stops once the residual norm `beta` drops below this
            value. The default of 1 is the cutoff the notebook has always
            used; NOTE(review): presumably chosen relative to the scale of
            this Hessian - confirm before reusing elsewhere.

    Returns:
        `(tridiag, vecs)` where `tridiag` is `(k, k)` and `vecs` is `(k, dim)`
        with `k <= order` the number of iterations performed. Every iteration
        contributes its diagonal coefficient, including the one on which the
        stop criterion fires, so `k == order` when no early stop happens and
        `k == 0` when `order == 0`.
    """
    v, W = params
    n_v = v.shape[0]
    dim = n_v + W.shape[0] * W.shape[1]
    tridiag = jnp.zeros(shape=(order + 1, order + 1))
    vecs = jnp.zeros(shape=(order + 1, dim))

    init_vec = random.laplace(split, shape=(dim,))
    init_vec = init_vec / jnp.linalg.norm(init_vec)
    vecs = vecs.at[0].set(init_vec)

    # Wrap once; array-valued initial state keeps a single compiled signature.
    lanczos_step = jit(_lanczos_helper)
    z_old = jnp.zeros_like(init_vec)
    beta = jnp.zeros((), dtype=init_vec.dtype)

    n_kept = 0  # number of valid rows/columns accumulated so far
    for i in range(order):
        z = vecs[i, :]
        vpart = z[:n_v].reshape(n_v, 1)
        Wpart = z[n_v:].reshape(W.shape[0], W.shape[1])

        vpart, Wpart = hvp(loss_fn, params, (vpart, Wpart))
        w = jnp.concatenate((vpart.flatten(), Wpart.flatten()))

        w, alpha, beta = lanczos_step(w, z, z_old, vecs, beta)

        # alpha belongs to basis vector i, which is already stored, so it is
        # recorded before the stop test; otherwise the last iteration's work
        # would be discarded.
        tridiag = tridiag.at[(i, i)].set(alpha)
        n_kept = i + 1

        if beta < tol:
            # Residual too small: do not store the next vector.
            break

        z_old = z
        tridiag = tridiag.at[(i, i + 1)].set(beta)
        tridiag = tridiag.at[(i + 1, i)].set(beta)
        vecs = vecs.at[i + 1].set(w)

    return tridiag[:n_kept, :n_kept], vecs[:n_kept, :]

In [5]:
# Run configuration.
T = 2000                               # gradient-descent steps per leak value
freq = 20                              # diagnostics are recorded every `freq` steps
l_vec = [0, 0.1, 0.25, 0.5, 0.75, 1]   # activation slopes to sweep over

# Total number of trainable parameters (output weights plus hidden weights).
dim = v.shape[0] + W.shape[0] * W.shape[1]
print("Dim:", dim)

# Result buffers: one row per leak value, one column per recorded step.
_per_record = (len(l_vec), T // freq)
Hg_product_vec = jnp.zeros(shape=_per_record)
overlap_vec = jnp.zeros(shape=_per_record)
eig_vecs_diffs = jnp.zeros(shape=_per_record + (3,))    # 3 = number of lagged comparisons
test_loss_vec = jnp.zeros(shape=_per_record)
train_loss_vec = jnp.zeros(shape=_per_record)
eig_vals_vec = jnp.zeros(shape=_per_record + (100,))    # 100 = eigenvalues kept per record
lin_err_vec = jnp.zeros(shape=_per_record)

# End-of-run summaries: one entry per leak value.
rel_diff_vec = jnp.zeros(shape=(len(l_vec),))
diff_vec = jnp.zeros(shape=(len(l_vec),))


Dim: 1100000


In [6]:
learning_rate = 0.000001

# Spectrum-logging settings. n_top_vals and n_lags must agree with the trailing
# dimensions of eig_vals_vec (100) and eig_vecs_diffs (3) allocated before this cell.
lanczos_order = 250   # maximum number of Lanczos iterations per spectrum estimate
n_top_vals = 100      # eigenvalues stored per record
n_top_vecs = 15       # Ritz vectors used for overlap / drift measurements
n_lags = 3            # number of previous records each new record is compared with


def _pad_leading_zeros(arr, min_len):
    """Prepend zeros along axis 0 until `arr` has at least `min_len` entries.

    `jnp.linalg.eigh` returns eigenvalues in ascending order and the code below
    reads the largest ones with `[-k:]` slices. Padding at the front keeps the
    genuine values in those trailing slots and keeps short spectra aligned with
    full-length ones; padding at the end would put the zeros where the largest
    entries are expected.
    """
    missing = min_len - arr.shape[0]
    if missing <= 0:
        return arr
    filler = jnp.zeros((missing,) + tuple(arr.shape[1:]), dtype=arr.dtype)
    return jnp.concatenate((filler, arr))


for i, leak in enumerate(l_vec):
    print('Leak:', leak)
    # Every leak value restarts from the same initial weights.
    W = W_init
    v = v_init
    # Top Ritz vectors from the last n_lags records (index 0 = most recent).
    eig_vecs_priors = jnp.zeros(shape=(n_lags, n_top_vecs, dim))
    loss_cl = lambda x: loss(x, data_train, leak)
    grad_cl = grad(loss_cl)  # built once per leak value instead of once per step

    for t in range(T):
        grad_v, grad_W = grad_cl((v, W))

        if t % freq == 0:
            rec = t // freq
            flat_grad = jnp.concatenate((grad_v.flatten(), grad_W.flatten()))

            # Approximate Hessian spectrum at the current weights.
            key, split = random.split(key)
            tridiag, lancsoz_vecs = lanczos_alg(loss_cl, (v, W), lanczos_order, split)
            eigs_vals, evecs_tridiag = jnp.linalg.eigh(tridiag)
            # Rows are Ritz vectors: tridiagonal eigenvectors mapped through the Lanczos basis.
            eig_vecs = jnp.dot(evecs_tridiag.T, lancsoz_vecs)
            eigs_vals = _pad_leading_zeros(eigs_vals, n_top_vals)
            eig_vecs = _pad_leading_zeros(eig_vecs, n_top_vecs)
            top_vecs = eig_vecs[-n_top_vecs:, :]

            # Drift of the top Ritz vectors relative to each stored record;
            # absolute values make the comparison insensitive to sign flips.
            for r in range(n_lags):
                prior_eig_vecs = eig_vecs_priors[r, :, :]
                diff = jnp.linalg.norm(jnp.abs(prior_eig_vecs) - jnp.abs(top_vecs))
                print("Diff:", diff)
                eig_vecs_diffs = eig_vecs_diffs.at[(i, rec, r)].set(diff)
            # Shift the history by one slot and put the newest record first.
            eig_vecs_priors = jnp.concatenate((top_vecs[None], eig_vecs_priors[:-1]))

            # Fraction of the gradient norm captured by the top Ritz vectors.
            proj = jnp.dot(top_vecs, flat_grad)
            overlap = jnp.linalg.norm(proj) / (jnp.linalg.norm(flat_grad) + 1e-6)
            print("Overlap:", overlap)

            train_loss = loss_cl((v, W))
            test_loss = loss((v, W), data_test, leak)
            lin_err = lin_forward_dyn((v, W), data_test[0], leak)[1]
            print('Train loss:', train_loss, 'Test loss:', test_loss, 'Linearization Err:', jnp.max(jnp.abs(lin_err)))

            # Save
            test_loss_vec = test_loss_vec.at[(i, rec)].set(test_loss)
            train_loss_vec = train_loss_vec.at[(i, rec)].set(train_loss)
            lin_err_vec = lin_err_vec.at[(i, rec)].set(jnp.max(jnp.abs(lin_err)))
            overlap_vec = overlap_vec.at[(i, rec)].set(overlap)
            eig_vals_vec = eig_vals_vec.at[(i, rec)].set(eigs_vals[-n_top_vals:])

        # Plain gradient-descent step.
        v -= learning_rate * grad_v
        W -= learning_rate * grad_W

        if t == T - 1:
            # Spectrum at the final weights.
            key, split = random.split(key)
            tridiag, lancsoz_vecs = lanczos_alg(loss_cl, (v, W), lanczos_order, split)
            eig_vals, evecs_tridiag = jnp.linalg.eigh(tridiag)
            eig_vecs = jnp.dot(evecs_tridiag.T, lancsoz_vecs)
            eig_vals = _pad_leading_zeros(eig_vals, n_top_vals)
            eig_vecs = _pad_leading_zeros(eig_vecs, n_top_vecs)

            # Reference spectrum at the initial weights, with the targets replaced
            # by the network's own initial predictions (zero residual there).
            # The same `split` is reused, so both Lanczos runs start from the
            # same random vector.
            y_hat_init = forward_dyn((v_init, W_init), X_train, leak)
            data_hat_init = (X_train, y_hat_init)
            loss_cl2 = lambda x: loss(x, data_hat_init, leak)
            tridiag2, lancsoz_vecs2 = lanczos_alg(loss_cl2, (v_init, W_init), lanczos_order, split)
            eig_vals2, evecs_tridiag2 = jnp.linalg.eigh(tridiag2)
            eig_vecs2 = jnp.dot(evecs_tridiag2.T, lancsoz_vecs2)
            eig_vals2 = _pad_leading_zeros(eig_vals2, n_top_vals)
            eig_vecs2 = _pad_leading_zeros(eig_vecs2, n_top_vecs)

            # Largest relative eigenvalue change and top-eigenvector drift.
            rel_diffs = jnp.divide(
                jnp.abs(eig_vals[-n_top_vals:] - eig_vals2[-n_top_vals:]),
                eig_vals2[-n_top_vals:] + 1e-6,
            )
            max_rel_diff = jnp.max(rel_diffs)
            rel_diff_vec = rel_diff_vec.at[i].set(max_rel_diff)
            diff = jnp.linalg.norm(jnp.abs(eig_vecs[-n_top_vecs:, :]) - jnp.abs(eig_vecs2[-n_top_vecs:, :]))
            diff_vec = diff_vec.at[i].set(diff)
            print('RelDiff', max_rel_diff)
            print('Diff', diff)



Leak: 0
Train loss: 4.928611 Test loss: 5.0729313 Linearization Err: 0.0
Train loss: 2.2293527 Test loss: 2.4990025 Linearization Err: 1.01607505e-08
Train loss: 1.22905 Test loss: 1.528849 Linearization Err: 2.1226572e-08
Train loss: 0.81985444 Test loss: 1.1263728 Linearization Err: 2.8292204e-08
Train loss: 0.6269091 Test loss: 0.93383527 Linearization Err: 3.2829313e-08
Train loss: 0.5188389 Test loss: 0.823931 Linearization Err: 3.584963e-08
Train loss: 0.4476239 Test loss: 0.74971205 Linearization Err: 3.7964885e-08
Train loss: 0.3947232 Test loss: 0.69304055 Linearization Err: 3.9540332e-08
Train loss: 0.35242164 Test loss: 0.64643383 Linearization Err: 4.0766967e-08
Train loss: 0.31716833 Test loss: 0.60651135 Linearization Err: 4.1758998e-08
Train loss: 0.28711632 Test loss: 0.57155126 Linearization Err: 4.2586713e-08
Train loss: 0.26115823 Test loss: 0.54055095 Linearization Err: 4.329604e-08
Train loss: 0.23855029 Test loss: 0.5128439 Linearization Err: 4.3916035e-08
Train l

Train loss: 0.41485298 Test loss: 0.6559087 Linearization Err: 3.0159896e-08
Train loss: 0.37370437 Test loss: 0.6138608 Linearization Err: 3.1330195e-08
Train loss: 0.33940762 Test loss: 0.57804286 Linearization Err: 3.2272396e-08
Train loss: 0.30989808 Test loss: 0.54655135 Linearization Err: 3.3062157e-08
Train loss: 0.28409377 Test loss: 0.5184186 Linearization Err: 3.374407e-08
Train loss: 0.2613188 Test loss: 0.493052 Linearization Err: 3.434769e-08
Train loss: 0.24109751 Test loss: 0.470042 Linearization Err: 3.4890363e-08
Train loss: 0.22306828 Test loss: 0.44907975 Linearization Err: 3.538265e-08
Train loss: 0.20693804 Test loss: 0.4299186 Linearization Err: 3.5833313e-08
Train loss: 0.19246274 Test loss: 0.4123521 Linearization Err: 3.6249062e-08
Train loss: 0.17943811 Test loss: 0.39620936 Linearization Err: 3.6631796e-08
Train loss: 0.16768935 Test loss: 0.38134217 Linearization Err: 3.6986457e-08
Train loss: 0.15706645 Test loss: 0.36762503 Linearization Err: 3.731685e-08


Train loss: 0.23251978 Test loss: 0.3960712 Linearization Err: 2.3332595e-08
Train loss: 0.2185319 Test loss: 0.38136798 Linearization Err: 2.3777528e-08
Train loss: 0.20572895 Test loss: 0.36770192 Linearization Err: 2.4192826e-08
Train loss: 0.19397978 Test loss: 0.35496286 Linearization Err: 2.4581398e-08
Train loss: 0.18317105 Test loss: 0.34305897 Linearization Err: 2.4947138e-08
Train loss: 0.17320807 Test loss: 0.33191434 Linearization Err: 2.5291145e-08
Train loss: 0.16400975 Test loss: 0.32146275 Linearization Err: 2.5615215e-08
Train loss: 0.15550354 Test loss: 0.31164753 Linearization Err: 2.5921302e-08
Train loss: 0.14762588 Test loss: 0.3024178 Linearization Err: 2.6208788e-08
Train loss: 0.14032039 Test loss: 0.29372907 Linearization Err: 2.6477643e-08
Train loss: 0.13353682 Test loss: 0.2855414 Linearization Err: 2.6731392e-08
Train loss: 0.1272303 Test loss: 0.2778187 Linearization Err: 2.6971588e-08
Train loss: 0.121361 Test loss: 0.2705278 Linearization Err: 2.719808e

Train loss: 0.1458347 Test loss: 0.21798748 Linearization Err: 9.957581e-09
Train loss: 0.14127256 Test loss: 0.21360528 Linearization Err: 1.021985e-08
Train loss: 0.13693045 Test loss: 0.20941016 Linearization Err: 1.0472954e-08
Train loss: 0.13279203 Test loss: 0.20538752 Linearization Err: 1.07170255e-08
Train loss: 0.12884292 Test loss: 0.20152472 Linearization Err: 1.0952444e-08
Train loss: 0.12506975 Test loss: 0.19781083 Linearization Err: 1.117981e-08
Train loss: 0.121461004 Test loss: 0.19423579 Linearization Err: 1.1400029e-08
Train loss: 0.118006155 Test loss: 0.19079116 Linearization Err: 1.1612906e-08
Train loss: 0.11469568 Test loss: 0.187469 Linearization Err: 1.1818689e-08
Train loss: 0.11152086 Test loss: 0.1842618 Linearization Err: 1.20179715e-08
Train loss: 0.10847393 Test loss: 0.18116343 Linearization Err: 1.2211119e-08
Train loss: 0.10554771 Test loss: 0.17816809 Linearization Err: 1.239855e-08
Train loss: 0.10273554 Test loss: 0.1752705 Linearization Err: 1.258

Train loss: 0.06264225 Test loss: 0.07905552 Linearization Err: 1.9910598e-09
Train loss: 0.06201596 Test loss: 0.07849797 Linearization Err: 2.0439659e-09
Train loss: 0.061401647 Test loss: 0.07795129 Linearization Err: 2.0960644e-09
Train loss: 0.06079894 Test loss: 0.07741488 Linearization Err: 2.1475641e-09
Train loss: 0.060207468 Test loss: 0.07688842 Linearization Err: 2.198316e-09
Train loss: 0.05962687 Test loss: 0.07637181 Linearization Err: 2.2484112e-09
Train loss: 0.059056927 Test loss: 0.07586475 Linearization Err: 2.2979294e-09
Train loss: 0.058497194 Test loss: 0.075366855 Linearization Err: 2.3468618e-09
Train loss: 0.057947468 Test loss: 0.07487769 Linearization Err: 2.3952573e-09
Train loss: 0.057407405 Test loss: 0.07439722 Linearization Err: 2.4432902e-09
Train loss: 0.056876812 Test loss: 0.07392501 Linearization Err: 2.4907705e-09
Train loss: 0.056355316 Test loss: 0.07346103 Linearization Err: 2.5381692e-09
Train loss: 0.055842765 Test loss: 0.07300481 Linearizat

Train loss: 0.007262076 Test loss: 0.009859583 Linearization Err: 1.0524748e-09
Train loss: 0.007262073 Test loss: 0.009859563 Linearization Err: 1.0524739e-09
Train loss: 0.0072620735 Test loss: 0.0098595405 Linearization Err: 1.0524743e-09
Train loss: 0.0072620735 Test loss: 0.009859526 Linearization Err: 1.0524738e-09
Train loss: 0.007262072 Test loss: 0.009859513 Linearization Err: 1.0524736e-09
Train loss: 0.007262072 Test loss: 0.009859496 Linearization Err: 1.0524734e-09
Train loss: 0.007262075 Test loss: 0.009859482 Linearization Err: 1.052473e-09
Train loss: 0.0072620735 Test loss: 0.00985947 Linearization Err: 1.0524733e-09
Train loss: 0.007262075 Test loss: 0.009859458 Linearization Err: 1.0524732e-09
Train loss: 0.007262075 Test loss: 0.009859439 Linearization Err: 1.0524733e-09
Train loss: 0.0072620786 Test loss: 0.009859433 Linearization Err: 1.0524729e-09
Train loss: 0.0072620795 Test loss: 0.009859427 Linearization Err: 1.0524729e-09
Train loss: 0.007262076 Test loss: 0

In [7]:
import os

# Earlier export of the full per-record diagnostics, kept for reference:
# np.savez('results12.npz', lin_err0 = lin_err_vec, overlap0 = overlap_vec, test_loss0 = test_loss_vec, train_loss0 = train_loss_vec, eig_vals0 = eig_vals_vec, eig_vecs_diff0 = eig_vecs_diffs, Hg0 = Hg_product_vec)

# Persist the end-of-run spectrum comparison together with the loss curves.
_saved_arrays = {
    'rel0': rel_diff_vec,
    'dif0': diff_vec,
    'test_loss0': test_loss_vec,
    'train_loss0': train_loss_vec,
}
np.savez('approx_res.npz', **_saved_arrays)
print('Done')



Done
