<a href="https://colab.research.google.com/github/mcnica89/DNNs/blob/main/Simulating_Sparse_DNNs_on_initialization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#Import packages we use!  mostly jax (which is like numpy but beefed up)
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import nn
import time
import math

import numpy as np

import pandas as pd  #data frames for use in plotting
from plotnine import * #this is the ggplot package!

In [2]:
#These helper functions compute the squared L2 norm ||v||^2, the L2 norm ||v||, and the unit vector v/||v||.
# They assume that many vectors are stored in a single array, so that
# v is of shape (dim, N_samples) where dim is the dimension and N_samples is the number of samples
def norm2(v):
  '''Squared L2 norm of every column of v.

  v has shape (dim, N_samples); the result has shape (N_samples,),
  one squared norm per sample.
  '''
  #Contract v with itself over the dimension axis "d", keeping the sample axis "s"
  squared_norms = jnp.einsum("ds,ds->s", v, v)
  return squared_norms

def norm(v):
  '''L2 norm of every column of v: the square root of norm2(v), one entry per sample.'''
  squared_norms = norm2(v)
  return jnp.sqrt(squared_norms)

def unit_vector(v):
  '''Rescale every column of v to unit L2 norm.

  The per-sample norms are broadcast against v, so each column is divided
  by its own norm. A zero column is not special-cased.
  '''
  lengths = norm(v)
  return v / lengths

# Simulating a simple fully connected network

In [None]:
N_width = 2**6 #network width = number of neurons per layer
N_depth = 2**6 #network depth = number of layers
N_samples =  int(2**32/(N_width**2 * N_depth)) #number of samples to run simultaneously
#(we use a number of samples that grows/shrinks with depth and width so that we can automatically fill the computers memory with samples)

key = random.PRNGKey(int(time.time())) #random key for generating random numbers

#A JAX key must only be consumed once: derive one key for the input and a
# separate one for the layers instead of using `key` for both.
input_key, layers_key = random.split(key)

#Initial input vector is a random unit vector
#NOTE: float64 is only honoured when JAX's x64 mode is enabled; otherwise JAX falls back to float32
input = unit_vector(random.normal(input_key,(N_width,N_samples), dtype=jnp.float64))
keys = random.split(layers_key, N_depth) #get a random key used to generate each layer of the network

z = input
for layer in range(N_depth):
  #Setup the weight matrix (an independent one per sample) and normalize by the fan-in
  W = random.normal(keys[layer],(N_width,N_width,N_samples),dtype=jnp.float64)*math.sqrt(2/N_width)

  phi = nn.relu(z) #vector after applying the activation function
  z = jnp.einsum("ijs,js->is",W,phi) #Apply the weight matrix W in each sample

output = z #a vector of shape (N_width, N_samples) with the outputs!

In [None]:
#Squared norm of every simulated output: a vector of shape (N_samples,)
output_norm2 = norm2(output)
norm2_mean = jnp.mean(output_norm2)
norm2_var = jnp.var(output_norm2)
print(f'Mean norm2 of output is {norm2_mean} using {N_samples} samples')
print(f'Var norm2 of output is {norm2_var} using {N_samples} samples')

#Same statistics for the logarithm of the squared norm
ln_output_norm2 = jnp.log(output_norm2)
ln_mean = jnp.mean(ln_output_norm2)
ln_var = jnp.var(ln_output_norm2)
print(f'Mean ln(norm2) of output is {ln_mean} using {N_samples} samples')
print(f'Var ln(norm2) of output is {ln_var} using {N_samples} samples')

#Sanity check: this combination should be close to 0!
print(f'Var + 2*Mean of ln(norm2) is {ln_var+2*ln_mean}')

Mean norm2 of output is 0.9477857351303101 using 16384 samples
Var norm2 of output is 38.29972839355469 using 16384 samples
Mean ln(norm2) of output is -2.565073013305664 using 16384 samples
Var ln(norm2) of output is 5.311448574066162 using 16384 samples
Var + 2*Mean of ln(norm2) is 0.18130254745483398


## Plotting the output with ggplot

In [None]:
#def NormalCDF(mean,var,x):
#  return 0.5*math.erfc(-(x-mean)/math.sqrt(2*var))

#def NormalPDF(mean,var,x):
#  return math.exp(-(x-mean)**2/(2*var))/math.sqrt(2*math.pi*var)


#df = pd.DataFrame(jnp.log(output_norm2),columns=['ln_norm2']) #load data into a DataFrame object
#plot = (
#       ggplot(df, aes(x='ln_norm2')) +
#       labs(title='Histogram of log(norm2(output))') +
#       geom_histogram(aes(y='..density..'), color='black', bins=30) +
#       stat_function(fun=lambda x: NormalPDF(-2.5*N_depth/N_width,5*N_depth/N_width,x), color='red')
#)

In [None]:
#print(plot)

# Sparse Networks

In [27]:
def simulate_sparseDNNs(key,N_per_neuron,N_width,N_depth,N_samples,sparse_type='bernoulli'):
  '''Simulate sparse ReLU DNNs at initialization with N_per_neuron connections per neuron (on average).

  N_samples independent networks are simulated at once; every array carries
  the sample index as its last axis.

  Args:
    key: JAX PRNG key. It is split internally so that the input vector, each
      layer's mask and each layer's weights all use different keys.
    N_per_neuron: number of incoming connections kept per neuron. Exact for
      'const-fan-in', only the average for 'bernoulli'. Must satisfy
      0 < N_per_neuron <= N_width.
    N_width: number of neurons per layer.
    N_depth: number of layers.
    N_samples: number of independent networks simulated simultaneously.
    sparse_type: how the sparsity mask is drawn:
      'bernoulli'    - every connection is kept independently with
                       probability N_per_neuron/N_width.
      'const-fan-in' - every neuron keeps exactly N_per_neuron incoming
                       connections, at independently shuffled positions.

  Returns:
    Array of shape (N_width, N_samples) holding the output of every network.

  Raises:
    NotImplementedError: for sparse_type='const-per-layer', which has no mask yet.
    ValueError: for any other unknown sparse_type, or an out-of-range N_per_neuron.
  '''
  #Validate up front: otherwise an unsupported sparse_type would leave the mask
  # M undefined and only fail with a confusing NameError inside the layer loop.
  if sparse_type == 'const-per-layer':
    raise NotImplementedError("sparse_type='const-per-layer' is not implemented yet")
  if sparse_type not in ('bernoulli', 'const-fan-in'):
    raise ValueError(f"Unknown sparse_type {sparse_type!r}; expected 'bernoulli' or 'const-fan-in'")
  if not 0 < N_per_neuron <= N_width:
    raise ValueError(f"N_per_neuron must satisfy 0 < N_per_neuron <= N_width, got {N_per_neuron} with N_width={N_width}")

  #A JAX key must only be consumed once: derive one key for the input and a
  # separate one for the layers instead of using `key` for both.
  input_key, layers_key = random.split(key)

  #Initial input vector is a random unit vector
  x0 = unit_vector(random.normal(input_key,(N_width,N_samples), dtype=jnp.float64))
  keys = random.split(layers_key, 2*N_depth) #two keys per layer: one for the mask, one for the weights

  #Weights are normalized by the *expected* fan-in N_per_neuron, not the realized one
  weight_scale = math.sqrt(2/N_per_neuron)

  z = x0
  for layer in range(N_depth):
    mask_key = keys[2*layer]
    weight_key = keys[2*layer+1]

    #Setup the masks for the sparsity
    #(the const-fan-in mask is rebuilt every layer rather than kept alive, to limit peak memory)
    if sparse_type == 'const-fan-in':
      #Each row starts as N_per_neuron ones followed by zeros, then is shuffled independently along axis 1
      M = jnp.concatenate( (jnp.ones((N_width, N_per_neuron,N_samples)), jnp.zeros((N_width, N_width-N_per_neuron,N_samples))), axis=1 )
      M = random.permutation( mask_key, M, axis=1, independent = True )
    else: #'bernoulli'
      p = N_per_neuron/N_width #This is the right probability to average N_per_neuron for each neuron
      M = random.bernoulli(mask_key, p=p, shape=(N_width,N_width,N_samples))
      #REMARK: if N_per_neuron is close to zero, there is a chance to get a fan-in of 0, which will mess up the mean 1 property

    #Setup the weight matrix (an independent one per sample)
    W = random.normal(weight_key,(N_width,N_width,N_samples),dtype=jnp.float64)*weight_scale

    phi = nn.relu(z) #vector after applying the activation function
    z = jnp.einsum("ijs,js->is",M*W,phi) #Apply the masked weight matrix in each sample

  output = z #vector of shape (N_width, N_samples) with the outputs!
  return output

In [47]:
N_width = 2**6 #network width = number of neurons per layer
N_depth = 1 #network depth = number of layers
N_samples = int(2**28/(N_width**2 * N_depth)) #number of samples to run simultaneously
N_per_neuron = 12 #2**2
sparse_type_list = ['bernoulli','const-fan-in']
N_types = len(sparse_type_list)

#Theory values that the empirical variance of norm2 is compared against below,
# listed in the same order as sparse_type_list
ber_theory_var = N_depth*(5*N_width - 8 + 18*N_width/N_per_neuron)/N_width/(N_width+2)
#print(f'Ber Theory predicts:\n {ber_theory_var}')
cfi_theory_var = ber_theory_var - 3*(N_width - N_per_neuron)/N_width/N_per_neuron/(N_width+2)
#print(f'CFI Theory predicts:\n {cfi_theory_var}')
theory = [ber_theory_var, cfi_theory_var]

N_trials = 2**3
#One independent key per (trial, sparse_type) run. Indexing the keys by the
# sparse type alone would make every trial reuse the same key and therefore
# reproduce the same samples N_trials times.
keys = random.split( random.PRNGKey(int(time.time())), N_trials*N_types)

output_norm2 = np.zeros((N_types,N_trials, N_samples))
for trial in range(N_trials):
  for i,sparse_type in enumerate(sparse_type_list):
    run_key = keys[trial*N_types + i]
    output = simulate_sparseDNNs(run_key, N_per_neuron,N_width,N_depth,N_samples,sparse_type)
    output_norm2[i,trial,:] = norm2(output)

print(f'Num simulations={N_trials*N_samples}')
print(f'N_width={N_width}')
print(f'N_depth={N_depth}')
print(f'N_per_neuron={N_per_neuron}')
output_ln_norm2 = np.log(output_norm2) #log of every squared norm, same shape as output_norm2: (N_types, N_trials, N_samples)
for i,sparse_type in enumerate(sparse_type_list):
  print(f'----Results for {sparse_type}-----')
  print(f'  Mean norm2:\n {np.nanmean(output_norm2[i,:,:])}')
  var = np.nanvar(output_norm2[i,:,:]) #pooled over all trials and samples
  print(f'  Var norm2:\n {var}')
  print(f'  Theory prediction:\n {theory[i]}')
  print(f'  % error:\n  {100*(theory[i] - var)/var:.2f}%')
  #print(f'Var ln(norm2): {np.nanvar(output_ln_norm2[i,:,:])}')
  #print(f'Var + 2*Mean of ln(norm2): {np.nanvar(output_ln_norm2[i,:,:])+2*np.mean(output_ln_norm2[i,:,:])}') #This should be close to 0!


#ber_theory_var = N_depth*(5*N_width - 8 + 18*N_width/N_per_neuron)/N_width/(N_width+2)
#print(f'Ber Theory predicts:\n {ber_theory_var}')
#cfi_theory_var = ber_theory_var - 3*(N_width - N_per_neuron)/N_width/N_per_neuron/(N_width+2)
#print(f'Ber Theory predicts:\n {cfi_theory_var}')


Num simulations=524288
N_width=64
N_depth=1
N_per_neuron=12
----Results for bernoulli-----
  Mean norm2:
 1.0009366690351271
  Var norm2:
 0.09670697405598401
  Theory prediction:
 0.09659090909090909
  % error:
  -0.12%
----Results for const-fan-in-----
  Mean norm2:
 0.9989254558533958
  Var norm2:
 0.09373992462839252
  Theory prediction:
 0.09351325757575757
  % error:
  -0.24%


In [None]:
#Ratio of the variance of ln(norm2) between the two sparsity types:
# index 1 of output_ln_norm2 (const-fan-in) over index 0 (bernoulli)
print("Empircal Variance Ratio:")
ln_var_bernoulli = np.nanvar(output_ln_norm2[0,:,:])
ln_var_const_fan_in = np.nanvar(output_ln_norm2[1,:,:])
print(ln_var_const_fan_in/ln_var_bernoulli)

Empircal Variance Ratio:
0.8596238174938919
Theoretical Variance Ratio:
0.7795918367346939


In [None]:
#Evaluate the sparse-network variance formula for one example configuration
N_width = 50 #2**5 #network width = number of neurons per layer
N_depth = 20 #2**5 #network depth = number of layers
N_per_neuron = 5 #2**3

#Bracketed factor of the formula, kept separate for readability
#(evaluation order is unchanged, so the printed value is identical)
per_layer_term = 5*N_width - 8 + 18*N_width/N_per_neuron
theory_var = N_depth*per_layer_term/N_width/(N_width+2)
print(theory_var)

#Second quantity printed for comparison: 5*N_depth/N_width
print(5*N_depth/ N_width)

3.246153846153846
2.0


In [None]:
#df = pd.DataFrame(jnp.log(output_ln_norm2[0,0,:]),columns=['ln_norm2']) #load data into a DataFrame object
#plot = (
#       ggplot(df, aes(x='ln_norm2')) +
#       labs(title='Histogram of log(norm2(output)) BERNOULLI') +
#       geom_histogram(aes(y='..density..'), color='black', bins=30)
#       #stat_function(fun=lambda x: NormalPDF(-2.5*N_depth/N_width,5*N_depth/N_width,x), color='red')
#)

In [None]:
#df = pd.DataFrame(jnp.log(output_ln_norm2[1,0,:]),columns=['ln_norm2']) #load data into a DataFrame object
#plot = (
#       ggplot(df, aes(x='ln_norm2')) +
#       labs(title='Histogram of log(norm2(output))') +
#       geom_histogram(aes(y='..density..'), color='black', bins=30)
       #stat_function(fun=lambda x: NormalPDF(-2.5*N_depth/N_width,5*N_depth/N_width,x), color='red')
#)