Copyright 2021 The Powerpropagation Authors. All rights reserved

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

In [None]:
#@title Imports

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import tensorflow as tf
import tensorflow_probability as tfp

#@title Installing and Importing Dependencies

print('Installing necessary libraries...')

def install_libraries():
  """Installs the Sonnet library (`dm-sonnet`) with pip.

  Uses IPython shell-escape syntax (`!`), so this only works inside an
  IPython/Jupyter/Colab kernel, not as a plain Python script.
  """
  !pip install dm-sonnet

import IPython

# Run the install with its console output captured (hidden).
# NOTE(review): `captured` is never inspected, so a failed install is silent
# here; the first visible symptom would be the `import sonnet` below failing.
with IPython.utils.io.capture_output() as captured:
  install_libraries()

import sonnet as snt

### Implementation

In [None]:
#@title Training & Pruning functions

@tf.function
def train_fn(model, train_x, train_y, optimizer_to_use):
  """Performs one optimisation step of `model` on a single batch.

  Args:
    model: Object exposing `loss(x, y) -> (loss, stats)` and a
      `trainable_variables()` *method* (e.g. `DensityNetwork`).
    train_x: Batch of inputs.
    train_y: Batch of targets.
    optimizer_to_use: Optimizer whose `apply(grads, vars)` updates `vars`.

  Returns:
    The stats dict produced by `model.loss` for this batch, computed before
    the update is applied.
  """
  with tf.GradientTape() as grad_tape:
    objective, batch_stats = model.loss(train_x, train_y)

  variables = model.trainable_variables()
  optimizer_to_use.apply(grad_tape.gradient(objective, variables), variables)
  return batch_stats

@tf.function
def eval_fn(model, eval_x, eval_y):
  """Evaluates `model` on a batch without touching its variables.

  Returns:
    Only the stats dict (second element) of `model.loss(eval_x, eval_y)`;
    the separately returned scalar loss is dropped.
  """
  return model.loss(eval_x, eval_y)[1]

def _bottom_k_mask(percent_to_keep, condition):
  """Returns a 0/1 mask keeping the largest entries of `condition`.

  Despite its name, this keeps the *top* `int(percent_to_keep * condition.size)`
  entries (the largest values) and zeroes everything else. Ties at the cut-off
  are broken in favour of the lower (flattened) index, matching
  `tf.nn.top_k`.

  Args:
    percent_to_keep: Fraction in [0, 1] of entries to keep.
    condition: NumPy array of real-valued scores (e.g. weight magnitudes).
      Any shape is accepted; selection is done over all entries.

  Returns:
    A float32 array shaped like `condition` containing exactly
    `int(percent_to_keep * condition.size)` ones.

  Raises:
    ValueError: If `percent_to_keep` lies outside [0, 1].
  """
  if not 0.0 <= percent_to_keep <= 1.0:
    raise ValueError(
        'percent_to_keep must be in [0, 1], got {}'.format(percent_to_keep))

  scores = np.asarray(condition)
  how_many = int(percent_to_keep * scores.size)

  # Stable sort of the negated scores = descending order with equal scores
  # kept in index order, so the selected set matches `tf.nn.top_k`.
  # float64 makes negation safe for any real dtype (including unsigned).
  flat_scores = scores.ravel().astype(np.float64)
  order = np.argsort(-flat_scores, kind='stable')

  mask = np.zeros(scores.size, dtype=np.float32)
  mask[order[:how_many]] = 1
  return mask.reshape(scores.shape)

def prune_by_magnitude(percent_to_keep, weight):
  """Builds a magnitude-pruning mask for one weight array.

  Args:
    percent_to_keep: Fraction of the entries of `weight` to retain.
    weight: NumPy array of weights (any shape).

  Returns:
    A float32 0/1 array shaped like `weight`, with ones at the entries of
    largest absolute value.
  """
  magnitudes = np.abs(weight.flatten())
  flat_mask = _bottom_k_mask(percent_to_keep, magnitudes)
  return flat_mask.reshape(weight.shape)

In [None]:
#@title Initialisers

class PowerPropVarianceScaling(snt.initializers.VarianceScaling):
  """Variance-scaling initialiser adapted to Powerpropagation.

  A Powerprop layer uses the effective weight `sign(w) * |w|**alpha`. This
  initialiser samples `u` from the parent `VarianceScaling` distribution and
  returns `sign(u) * |u|**(1 / alpha)`, so that the *effective* weights at
  initialisation equal `u` regardless of `alpha`.

  Args:
    alpha: Powerprop exponent; must be positive.
    *args, **kwargs: Forwarded unchanged to
      `snt.initializers.VarianceScaling`.

  Raises:
    ValueError: If `alpha` is not positive (1 / alpha would be undefined or
      invert the transform).
  """

  def __init__(self, alpha, *args, **kwargs):
    super(PowerPropVarianceScaling, self).__init__(*args, **kwargs)
    if alpha <= 0:
      raise ValueError('alpha must be positive, got {}'.format(alpha))
    self._alpha = alpha

  def __call__(self, shape, dtype):
    # Stay in TensorFlow (no `.numpy()`), so this initialiser also works when
    # a module is first built inside a `tf.function` (symbolic tensors have
    # no `.numpy()`). In eager mode the result is unchanged.
    u = super(PowerPropVarianceScaling, self).__call__(shape, dtype)

    return tf.sign(u) * tf.pow(tf.abs(u), 1.0 / self._alpha)

In [None]:
#@title Models

class PowerPropLinear(snt.Linear):
  """Powerpropagation Linear module.

  Stores a raw parameter `w` but computes outputs with the re-parameterised
  weight `w * |w|**(alpha - 1)` (equal to `sign(w) * |w|**alpha`). With
  `alpha == 1` this is a standard linear layer.

  Args:
    alpha: Powerprop exponent.
    *args, **kwargs: Forwarded unchanged to `snt.Linear`.
  """

  def __init__(self, alpha, *args, **kwargs):
    super(PowerPropLinear, self).__init__(*args, **kwargs)
    self._alpha = alpha

  def get_weights(self):
    """Returns the effective weight tensor `sign(w) * |w|**alpha`."""
    return tf.sign(self.w) * tf.pow(tf.abs(self.w), self._alpha)

  def __call__(self, inputs, mask=None):
    """Applies the layer.

    Args:
      inputs: Input batch, multiplied on the right by the effective weights.
      mask: Optional array multiplied element-wise into the effective
        weights (zeros prune the corresponding connections).

    Returns:
      `inputs @ effective_weights`, plus the bias when the layer has one.
    """
    self._initialize(inputs)
    params = self.w * tf.pow(tf.abs(self.w), self._alpha - 1)

    if mask is not None:
      params *= mask

    outputs = tf.matmul(inputs, params)
    # Honour `with_bias=False`: snt.Linear does not create `self.b` then, so
    # adding it unconditionally would raise AttributeError.
    if self.with_bias:
      outputs += self.b

    return outputs

class MLP(snt.Module):
  """A multi-layer perceptron built from `PowerPropLinear` layers.

  Every layer except the last is followed by a ReLU; the final layer's raw
  outputs are returned.

  Args:
    alpha: Powerprop exponent shared by all layers.
    w_init: Weight initialiser shared by all layers.
    output_sizes: Output size of each layer, in order (default 300-100-10).
      A tuple default avoids the shared-mutable-default pitfall; any
      iterable of ints is accepted.
    name: Module name.
  """

  def __init__(self, alpha, w_init, output_sizes=(300, 100, 10), name='MLP'):

    super(MLP, self).__init__(name=name)
    self._alpha = alpha
    self._w_init = w_init

    self._layers = [
        PowerPropLinear(
            output_size=output_size,
            alpha=alpha,
            w_init=w_init,
            name="linear_{}".format(index))
        for index, output_size in enumerate(output_sizes)]

  def get_weights(self):
    """Returns each layer's effective weights as NumPy arrays, in order."""
    return [layer.get_weights().numpy() for layer in self._layers]

  def __call__(self, inputs, masks=None):
    """Runs the network.

    Args:
      inputs: Input batch.
      masks: Optional sequence indexed by layer position; `masks[i]` is
        passed as the `mask` of layer `i`.

    Returns:
      The output of the final layer (no activation applied).
    """
    num_layers = len(self._layers)

    for i, layer in enumerate(self._layers):
      if masks is not None:
        inputs = layer(inputs, masks[i])
      else:
        inputs = layer(inputs)
      if i < (num_layers - 1):
        inputs = tf.nn.relu(inputs)

    return inputs


class DensityNetwork(snt.Module):
  """Wraps a logits network and produces a categorical distribution."""

  def __init__(self, network=None, name="DensityNetwork", *args, **kwargs):
    # Extra *args/**kwargs are accepted but ignored.
    super(DensityNetwork, self).__init__(name=name)
    self._network = network

  def __call__(self, inputs, masks=None, *args, **kwargs):
    """Returns `(Categorical(logits=logits), logits)` for `inputs`."""
    logits = self._network(inputs, masks, *args, **kwargs)

    return tfp.distributions.Categorical(logits=logits), logits

  def trainable_variables(self):
    """Returns the wrapped network's trainable variables.

    NOTE: this *method* shadows the `trainable_variables` property inherited
    from `tf.Module`, so callers must write `model.trainable_variables()`.
    """
    return self._network.trainable_variables

  def get_weights(self):
    """Delegates to the wrapped network's `get_weights()`."""
    return self._network.get_weights()

  def loss(self, inputs, targets, masks=None, *args, **kwargs):
    """Computes mean negative log-likelihood and accuracy.

    Args:
      inputs: Input batch.
      targets: Integer class labels (any integer dtype).
      masks: Optional per-layer masks forwarded to the network.

    Returns:
      `(loss, {'loss': loss, 'acc': accuracy})`.
    """
    dist, logits = self(inputs, masks, *args, **kwargs)
    loss = -tf.reduce_mean(dist.log_prob(targets))

    predictions = tf.argmax(logits, axis=1)
    # `tf.equal` needs matching dtypes; cast the labels to the prediction
    # dtype so int32 targets work as well as int64 ones.
    correct = tf.equal(predictions, tf.cast(targets, predictions.dtype))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

    return loss, {'loss': loss, 'acc': accuracy}

### Training

In [None]:
#@title Training configuration

# Seed for TensorFlow's global RNG (applied by tf.random.set_seed at the end
# of this cell, before any model is built).
model_seed = 0  #@param

# Powerprop exponents to compare; the set-up cell labels alpha == 1.0 runs as
# the 'Baseline'.
alphas = [1.0, 2.0, 3.0, 4.0, 5.0]  #@param
# Initialiser settings.
# NOTE(review): as written, the set-up cell constructs
# PowerPropVarianceScaling(alpha) without these three values; they appear to
# coincide with snt.initializers.VarianceScaling's defaults — confirm.
init_distribution = 'truncated_normal'
init_mode = 'fan_in'
init_scale = 1.0

# Fixed values taken from the Lottery Ticket Hypothesis paper
train_batch_size = 60
num_train_steps = 50000
learning_rate = 0.1

# Print train/validation statistics every `report_interval` steps.
report_interval = 2500

tf.random.set_seed(model_seed)

In [None]:
#@title Get data

(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()

# Flatten each image to a 784-dim vector and scale pixel values by 1/255.
# `-1` infers the number of examples instead of hard-coding split sizes.
train_x = train_x.reshape([-1, 784]).astype(np.float32) / 255.0
test_x = test_x.reshape([-1, 784]).astype(np.float32) / 255.0

# int64 labels match tf.argmax's default output dtype used for accuracy.
train_y = train_y.astype(np.int64)
test_y = test_y.astype(np.int64)

# Reserve the last `num_valid` training examples as a validation set.
num_valid = 5000
valid_x = train_x[-num_valid:]
valid_y = train_y[-num_valid:]
train_x = train_x[:-num_valid]
train_y = train_y[:-num_valid]

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))

# Infinite stream of random mini-batches. Shuffling after `repeat()` lets
# batches straddle epoch boundaries. The 10000-element shuffle buffer is
# smaller than the training set, so sampling is only approximately uniform.
train_iterator = iter(
    train_dataset.repeat().shuffle(10000).batch(train_batch_size))

In [None]:
#@title Training set-up

model_types = []
models = []
n_models = len(alphas)

for alpha in alphas:
  # Use the initialiser settings from the configuration cell.
  w_init = PowerPropVarianceScaling(
      alpha,
      scale=init_scale,
      mode=init_mode,
      distribution=init_distribution)
  models.append(DensityNetwork(MLP(alpha=alpha, w_init=w_init)))
  # alpha == 1 is the plain (non-Powerprop) network; any other exponent,
  # including alpha < 1, is a Powerprop variant.
  if alpha == 1.0:
    model_types.append('Baseline')
  else:
    model_types.append('Powerprop. ($\\alpha={}$)'.format(alpha))

# Initialise variables: the layers create their parameters on first call.
for m in models:
  m(valid_x)

initial_weights = [m.get_weights() for m in models]


optimizers = [snt.optimizers.SGD(learning_rate=learning_rate)
              for _ in range(n_models)]

In [None]:
#@title Training loop

all_train_stats = [[] for _ in range(n_models)]
all_valid_stats = [[] for _ in range(n_models)]

# Every model is updated on the same mini-batch at each step.
for step in range(num_train_steps + 1):
  batch_x, batch_y = next(train_iterator)
  for net, opt, train_history in zip(models, optimizers, all_train_stats):
    train_history.append(train_fn(net, batch_x, batch_y, opt))

  if step % report_interval:
    continue

  # Periodic report: latest training-batch stats plus a validation pass.
  for alpha_value, net, train_history, valid_history in zip(
      alphas, models, all_train_stats, all_valid_stats):
    last_train = train_history[-1]
    print('[Train Step {}, Alpha {}] Loss: {:1.3f}. Acc: {:1.3f}'.format(
        step, alpha_value, last_train['loss'], last_train['acc']))

    valid_history.append(eval_fn(net, valid_x, valid_y))
    last_valid = valid_history[-1]
    print('[Eval Step {}, Alpha {}] Loss: {:1.3f}. Acc: {:1.3f}'.format(
        step, alpha_value, last_valid['loss'], last_valid['acc']))
  print('---')

### Results

In [None]:
#@title Pruning

final_weights = [m.get_weights() for m in models]

eval_at_sparsity_level = np.geomspace(0.01, 1.0, 20).tolist()
acc_at_sparsity = [[] for _ in range(n_models)]

for p_to_use in eval_at_sparsity_level:
  for m_id, model_to_use in enumerate(models):
    layer_weights = final_weights[m_id]

    # Keep fraction `p_to_use` of the weights in every layer except the
    # output layer, which keeps twice that fraction (capped at 1.0). The
    # list length follows the model's layer count instead of assuming 3.
    percent = ([p_to_use] * (len(layer_weights) - 1)
               + [min(1.0, p_to_use * 2)])
    masks = [prune_by_magnitude(keep, w)
             for keep, w in zip(percent, layer_weights)]

    _, stats = model_to_use.loss(test_x, test_y, masks=masks)

    acc_at_sparsity[m_id].append(stats['acc'].numpy())
    print(' Performance @ {:1.0f}% of weights [Alpha {}]: Acc {:1.3f} NLL {:1.3f} '.format(
        100*p_to_use, alphas[m_id], stats['acc'], stats['loss']))
  print('---')

In [None]:
#@title Plot
sns.set_style("whitegrid")
sns.set_context("paper")

fig, ax = plt.subplots(1, 1, figsize=(7, 5))

# Both axis labels are in percent, while the stored values are fractions in
# [0, 1]; convert them before plotting.
weights_remaining_pct = [100 * p for p in eval_at_sparsity_level]
for acc, label in zip(acc_at_sparsity, model_types):
  ax.plot(weights_remaining_pct, [100 * a for a in acc],
          label=label, marker='o', lw=2)

ax.set_xscale('log')
# Reversed x-axis: fewer weights remaining further to the right.
ax.set_xlim([100.0, 1.0])
ax.set_ylim([0.0, 100.0])
ax.legend(frameon=False)
ax.set_xlabel('Weights Remaining (%)')
ax.set_ylabel('Test Accuracy (%)')

sns.despine()