##### Copyright 2020 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

## Measuring Signal Properties of Various Initializations
For a random signal x ~ normal(0, 1) and a neural network denoted by f(x) = y, ensuring std(y) = 1 at initialization is a common goal of popular NN initialization schemes. Here we measure signal propagation for different sparse initializations.

In [None]:
#@title Imports and Definitions
import numpy as np
import os
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import gin
from rigl import sparse_utils
from rigl.rigl_tf2 import init_utils
from rigl.rigl_tf2 import utils
from rigl.rigl_tf2 import train
from rigl.rigl_tf2 import networks
from rigl.rigl_tf2 import mask_updaters

import functools

# Pruning configuration handed to every wrapped layer: constant mode, zero
# final sparsity and a begin_step of 1e10.
# NOTE(review): presumably chosen so the wrapper's own schedule never prunes
# and masks are only set manually below -- confirm against utils.
pruning_params = utils.get_pruning_params(
    mode='constant',
    final_sparsity=0.,
    begin_step=int(1e10),
)
# Per-example input shape (no batch dimension) used by all experiments.
INPUT_SHAPE = (28, 28, 3)
class Lenet5(tf.keras.Model):
  """LeNet-5 style network that exposes every intermediate activation.

  Each convolutional and dense layer is passed through
  `utils.maybe_prune_layer` with the module-level `pruning_params`. Calling
  the model returns a dict holding the output of every layer, keyed by the
  layer's attribute name.

  Args:
    input_shape: tuple, shape of a single example (without batch dimension).
    num_classes: int, number of units of the last dense layer.
    activation: str, activation passed to every conv and dense layer,
      including the last one.
    hidden_sizes: sequence of 4 ints; filters of the two conv layers followed
      by units of the first two dense layers.
  """

  # Order in which the layers are applied in `call`; also the result keys.
  _LAYER_ORDER = ('conv1', 'pool1', 'conv2', 'pool2', 'flatten', 'dense1',
                  'dense2', 'dense3')

  def __init__(self,
               input_shape,
               num_classes,
               activation: str,
               hidden_sizes = (6, 16, 120, 84)):
    super().__init__()
    layers = tf.keras.layers
    n_conv1, n_conv2, n_dense1, n_dense2 = (hidden_sizes[i] for i in range(4))
    # Every layer is accepted by the filter, i.e. all are offered for wrapping.
    prunable = functools.partial(
        utils.maybe_prune_layer,
        params=pruning_params,
        filter_fn=lambda _: True)

    self.conv1 = prunable(
        layers.Conv2D(n_conv1, 5, input_shape=input_shape, activation=activation))
    self.pool1 = layers.MaxPool2D(pool_size=(2, 2))
    # NOTE(review): conv2 also receives the network-level `input_shape`; this
    # looks redundant for a non-first layer but is kept as in the original.
    self.conv2 = prunable(
        layers.Conv2D(n_conv2, 5, input_shape=input_shape, activation=activation))
    self.pool2 = layers.MaxPool2D(pool_size=(2, 2))
    self.flatten = layers.Flatten()
    self.dense1 = prunable(layers.Dense(n_dense1, activation=activation))
    self.dense2 = prunable(layers.Dense(n_dense2, activation=activation))
    self.dense3 = prunable(layers.Dense(num_classes, activation=activation))
    # Build eagerly with a batch dimension of 1 so variables exist right away.
    self.build((1,) + input_shape)

  def call(self, inputs):
    """Runs the layers in sequence and returns all intermediate outputs.

    Args:
      inputs: batch of inputs fed to `conv1`.

    Returns:
      dict mapping layer attribute name to that layer's output.
    """
    activations = {}
    signal = inputs
    for layer_name in self._LAYER_ORDER:
      signal = getattr(self, layer_name)(signal)
      activations[layer_name] = signal
    return activations

def get_mask_random_numpy(mask_shape, sparsity):
  """Creates a random sparse mask with deterministic overall sparsity.

  Exactly `floor(sparsity * size)` entries of the mask are set to zero. The
  entries are chosen by drawing one uniform random value per entry and zeroing
  the smallest draws. The upper bound of the uniform distribution grows from
  1 to `mask_shape[-1]` along the last axis, so entries at small last-axis
  indices tend to draw smaller values and are zeroed more often: the overall
  sparsity is exact, but it is not uniform across the last axis.

  Args:
    mask_shape: list, used to obtain shape of the random mask.
    sparsity: float, between 0 and 1.

  Returns:
    numpy.ndarray of zeros and ones with shape `mask_shape`.
  """
  mask = np.ones(mask_shape)
  n_zeros = int(np.floor(sparsity * mask.size))
  # `high` broadcasts along the last axis: column j is drawn from U(0, j + 1).
  rand_vals = np.random.uniform(size=mask_shape, high=range(1,mask_shape[-1]+1))
  # Zero the positions of the n_zeros smallest draws. Selecting by rank rather
  # than by `<= threshold` gives exactly n_zeros zeros (the threshold form
  # zeroed one entry too many and failed for sparsity == 1).
  zero_positions = np.argsort(rand_vals, axis=None)[:n_zeros]
  mask.flat[zero_positions] = 0
  return mask

def create_convnet(sparsity=0, weight_init_method = None, scale=2, method='fanin_normal', num_classes=10):
  """Builds a ReLU Lenet5 and optionally sparsifies and re-initializes it.

  When `sparsity > 0`, the mask of every pruning-wrapped layer is replaced by
  a random mask from `get_mask_random_numpy`. If `weight_init_method` is also
  given, the weights of those layers are re-initialized from their masks.
  With `sparsity == 0` the model is returned untouched and
  `weight_init_method` is ignored.

  Args:
    sparsity: float, fraction of zeros in each mask.
    weight_init_method: None, 'unit' or 'layer'. Selects
      `init_utils.unit_scaled_init` or `init_utils.layer_scaled_init`; None
      keeps the weights the layers were created with.
    scale: forwarded to the selected init function.
    method: str, forwarded to the selected init function.
    num_classes: int, number of units of the last dense layer. Previously read
      from an undefined name; the default of 10 is an assumption -- TODO
      confirm.

  Returns:
    The constructed Lenet5 model.

  Raises:
    ValueError: if `weight_init_method` is set to an unsupported value.
  """
  model = Lenet5(INPUT_SHAPE, num_classes, 'relu')
  if sparsity > 0:
    pruned_layers = [layer for layer in model.layers if isinstance(layer, utils.PRUNING_WRAPPER)]
    # The first entry of `pruning_vars` is used as (weight, mask, ...).
    all_masks = [layer.pruning_vars[0][1] for layer in pruned_layers]
    for mask in all_masks:
      new_mask = tf.cast(get_mask_random_numpy(mask.shape, sparsity), dtype=mask.dtype)
      mask.assign(new_mask)
    if weight_init_method:
      all_weights = [layer.pruning_vars[0][0] for layer in pruned_layers]
      for mask, param in zip(all_masks, all_weights):
        if weight_init_method == 'unit':
          new_init = init_utils.unit_scaled_init(mask, method=method, scale=scale)
        elif weight_init_method == 'layer':
          new_init = init_utils.layer_scaled_init(mask, method=method, scale=scale)
        else:
          raise ValueError(
              "weight_init_method must be None, 'unit' or 'layer', got %r" %
              (weight_init_method,))
        param.assign(new_init)
  return model

Here we demonstrate how to calculate the standard deviation of the signal produced by random noise at initialization, using the `layer-wise` scaled initialization of Liu et al.

In [None]:
# Let's create a 95% sparse Lenet-5.
model = create_convnet(sparsity=0.95, weight_init_method='layer', scale=2, method='fanin_normal')
# Random input signal
random_input = tf.random.normal((1000,) + INPUT_SHAPE)
# Dict of every layer's output, keyed by layer name.
output_dict = model(random_input)
# Per-neuron standard deviation (over the batch) for each dense layer.
all_stds = []
for k in ['dense1', 'dense2', 'dense3']:
  out_dim = output_dict[k].shape[-1]
  stds = np.std(np.reshape(output_dict[k], (-1,out_dim)),axis=0)
  all_stds.append(stds)
print('Mean deviation per neuron', np.mean(np.concatenate(all_stds, axis=0)))
print('Mean deviation per output neuron', np.mean(all_stds[-1]))
# Std over all output values of the last layer. This used to report the std of
# the *input* signal, which does not match the label.
print('Deviation at output', np.std(output_dict['dense3']))

Now we define the code above as a function and use it on a grid to plot signal propagation at different sparsities.

In [1]:
def propagate_signal(sparsity, init_method, batch_size=500):
  """Feeds random normal noise through a fresh network and measures its spread.

  Args:
    sparsity: float, forwarded to `create_convnet`.
    init_method: forwarded to `create_convnet` as `weight_init_method`.
    batch_size: int, number of random inputs to propagate.

  Returns:
    Tuple `(mean_neuron_std, output_std)`: the mean of the per-neuron standard
    deviations over the three dense layers, and the standard deviation over
    all values of the 'dense3' output.
  """
  net = create_convnet(sparsity=sparsity, weight_init_method=init_method)
  noise = tf.random.normal((batch_size,) + INPUT_SHAPE)
  activations = net(noise)

  # One std per neuron (computed across the batch) for every dense layer.
  per_neuron_stds = [
      np.std(
          np.reshape(activations[name], (-1, activations[name].shape[-1])),
          axis=0)
      for name in ('dense1', 'dense2', 'dense3')
  ]
  mean_neuron_std = np.mean(np.concatenate(per_neuron_stds, axis=0))
  output_std = np.std(activations['dense3'])
  return mean_neuron_std, output_std

In [None]:
import itertools, collections
import numpy as np
# all_results[label][sparsity] -> list of N_EXP (mean_neuron_std, output_std)
# tuples, one per independent run of propagate_signal.
all_results = collections.defaultdict(dict)

N_EXP = 3
for s in np.linspace(0.8, 0.98, 5):
  print(s)
  # (weight_init_method, legend label) pairs, evaluated in this order.
  for method, name in [(None, 'Masked Dense'), ('unit', 'Ours'), ('layer', 'Scaled-Init')]:
    runs = []
    for _ in range(N_EXP):
      runs.append(propagate_signal(s, method))
    all_results[name][s] = runs

In [None]:
import matplotlib.pyplot as plt

def _mean_metric_curve(runs_by_sparsity, metric_index):
  """Averages one metric over the repeated runs of every sparsity level.

  Args:
    runs_by_sparsity: dict mapping sparsity to a list of result tuples.
    metric_index: int, position of the metric inside each result tuple.

  Returns:
    Tuple `(sparsities, means)`: the sorted sparsities and the mean of the
    metric at each one, offset by 1e-5.
  """
  sparsities = sorted(runs_by_sparsity.keys())
  means = []
  for level in sparsities:
    samples = [run[metric_index] for run in runs_by_sparsity[level]]
    # The small offset is part of the plotted value, as in the original.
    means.append(np.mean(samples) + 1e-5)
  return sparsities, means


# Figure 1: standard deviation of the network output (tuple index 1).
for label, runs_by_sparsity in all_results.items():
  xs, ys = _mean_metric_curve(runs_by_sparsity, 1)
  plt.plot(xs, ys, label=label)
plt.hlines(y=1, color='r', xmin=0, xmax=1)
plt.yscale('log')
plt.title('std(output)')
plt.legend()
plt.show()

# Figure 2: mean per-neuron standard deviation (tuple index 0).
for label, runs_by_sparsity in all_results.items():
  xs, ys = _mean_metric_curve(runs_by_sparsity, 0)
  plt.plot(xs, ys, label=label)
plt.yscale('log')
plt.hlines(y=1, color='r', xmin=0, xmax=1)
plt.title('mean(std_per_neuron)')
plt.legend()
plt.show()