# Setup

In [None]:
!pip install tensorflow_model_optimization

In [1]:
import os
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot
from matplotlib import pyplot as plt

# Data Preprocessing (from previous [tutorial](https://substack.com/inbox/post/135885628))

In [None]:
# download tiny_shakespeare dataset
dataset_dict = tfds.load(name='tiny_shakespeare')

# get train/validation/test data
train_data, validation_data, test_data = (
    dataset_dict[split_name] for split_name in ('train', 'validation', 'test')
)


def _split_into_chars(example):
  # Each split holds a single string example; break its 'text' field into a
  # sequence of Unicode code points (as UTF-8 byte strings).
  return tf.strings.unicode_split(example['text'], 'UTF-8')


train_dataset = train_data.map(_split_into_chars)
validation_dataset = validation_data.map(_split_into_chars)
test_dataset = test_data.map(_split_into_chars)

# Character vocabulary, built from the (single) training example only.
vocabulary = sorted(set(next(iter(train_dataset)).numpy()))

ids_to_tokens = dict(enumerate(vocabulary))
tokens_to_ids = {token: idx for idx, token in enumerate(vocabulary)}

# In-graph lookup table; characters missing from the vocabulary map to -1.
keys_tensor = tf.constant(list(tokens_to_ids))
vals_tensor = tf.constant(list(tokens_to_ids.values()))
tokens_to_ids_loopup = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), default_value=-1
)

def decode(line):
  """Convert an iterable of token ids back into a text string.

  Each id is looked up in ``ids_to_tokens`` (bytes) and decoded as UTF-8.
  """
  characters = map(lambda token_id: ids_to_tokens[token_id].decode('UTF-8'), line)
  return ''.join(characters)

def tokenize(line):
  """Map a tensor of UTF-8 characters to integer ids via the static lookup table.

  Characters not present in the table map to the table's default value (-1).
  """
  token_ids = tokens_to_ids_loopup.lookup(line)
  return token_ids

# configuration to store model parameters

MODEL_CONFIG = dict(
    vocab_size=len(vocabulary),  # number of distinct characters
    batch_size=32,               # sequences per batch
    context_window=16,           # tokens per training sequence
    d_model=128,                 # embedding / hidden width
    epochs=10,
    n_heads=8,
    n_layers=4,                  # number of AttentionBlocks
)

# tokenize and create sequence
def process_dataset(dataset, model_config, is_train=False):
  """Turn a dataset of character sequences into batched next-character pairs.

  Args:
    dataset: dataset whose elements are character tensors (see unicode_split).
    model_config: dict providing 'context_window' and 'batch_size'.
    is_train: currently unused by this function.

  Returns:
    A dataset of (inputs, targets) batches, each of shape
    (batch_size, context_window), where targets are inputs shifted by one.
  """
  window = model_config['context_window']
  batch = model_config['batch_size']

  token_ids = dataset.map(tokenize)
  # Next-character prediction: target is the input shifted left by one.
  shifted_pairs = token_ids.map(lambda seq: (seq[:-1], seq[1:]))
  # Flatten to single (char, next_char) pairs, regroup into fixed-length windows,
  # then group windows into batches; incomplete trailing groups are dropped.
  windows = shifted_pairs.unbatch().batch(window, drop_remainder=True)
  return windows.batch(batch, drop_remainder=True)

processed_train_dataset, processed_validation_dataset, processed_test_dataset = (
    process_dataset(split, MODEL_CONFIG, is_train=flag)
    for split, flag in ((train_dataset, True), (validation_dataset, False), (test_dataset, False))
)

In [None]:
# [OPTIONAL] store training results in google drive.
# `model_dir` is required later (to locate the saved model), so outside Colab
# fall back to a local directory, overridable via the MODEL_DIR env variable.
try:
  from google.colab import drive
except ImportError:
  model_dir = os.environ.get('MODEL_DIR', 'llama')
else:
  drive.mount('/content/drive')
  model_dir = '/content/drive/MyDrive/Colab Notebooks/llama'

# BabyLLaMA Model (from previous [tutorial](https://substack.com/inbox/post/135885628))

In [11]:
# model return logits rather than normalized probabilities
def custom_loss(y_true, y_pred):
    """Per-token sparse categorical cross-entropy computed from raw logits."""
    per_token_loss = tf.keras.losses.sparse_categorical_crossentropy(
        y_true, y_pred, from_logits=True
    )
    return per_token_loss

def generate_text(model, config, sentence_count=5, max_new_tokens=50):
    """Autoregressively sample `sentence_count` sequences from `model`.

    Every sequence is seeded with token id 0 and extended one token at a time
    by sampling from the logits at the last position. Only the most recent
    `config['context_window']` tokens are fed back into the model.

    Returns:
        list of decoded strings (each includes the seed token).
    """
    window = config['context_window']
    tokens = tf.zeros([sentence_count, 1], dtype=tf.int64)
    step = 0
    while step < max_new_tokens:
        last_logits = model(tokens[:, -window:])[:, -1, :]
        sampled = tf.random.categorical(last_logits, num_samples=1, dtype=tf.int64)
        tokens = tf.concat([tokens, sampled], axis=-1)
        step += 1
    return list(map(decode, tokens.numpy()))

## RMSNormLayer

added `get_prunable_weights` for future pruning

In [14]:
class RMSNormLayer(tf.keras.layers.Layer):
  """Root-mean-square normalization over each example's (sequence, feature) slab.

  Every example ``x[b]`` is divided by the RMS of *all* of its elements (one
  scalar per example, not one per token). It is then multiplied element-wise
  by a learned ``scale`` and optionally shifted by a learned ``bias``.

  Args:
    layer_shape: shape of the learned ``scale`` (and ``bias``). Its first axis
      is sliced to the input's sequence length in ``call``, so it acts as the
      maximum supported sequence length.
    eps: constant added to the RMS before dividing, to avoid division by zero.
    bias: if True, create and add a learned ``bias`` of shape ``layer_shape``.
  """

  def __init__(self, layer_shape, eps=1e-8, bias=False):
    super().__init__()
    self.scale = self.add_weight("scale", shape=layer_shape, initializer="ones", trainable=True)
    self.eps = eps
    if bias:
      self.bias = self.add_weight("bias", shape=layer_shape, initializer="zeros", trainable=True)
    else:
      self.bias = None

  def get_prunable_weights(self):
    # Hook used by tfmot's prune_low_magnitude. NOTE(review): this also exposes
    # the normalization scale (and bias) to magnitude pruning.
    return self.trainable_weights

  def call(self, x):
    # assumes x is (batch, seq, dim): the norm below reduces axes 1 and 2.
    # Use the dynamic length so shorter inputs (e.g. during generation) and an
    # unknown static sequence length are both handled.
    seq_len = tf.shape(x)[1]
    # Frobenius norm / sqrt(#elements per example) == RMS of the example.
    n_elements = tf.cast(tf.reduce_prod(tf.shape(x)[1:]), x.dtype)
    rms = tf.norm(x, ord='fro', axis=[1, 2]) * tf.math.pow(n_elements, -0.5)
    normalized = x / (rms[:, tf.newaxis, tf.newaxis] + self.eps)
    scaled = tf.expand_dims(self.scale[:seq_len, :], 0) * normalized

    if self.bias is not None:
      # Slice the bias exactly like the scale so it broadcasts when
      # seq_len < layer_shape[0].
      return scaled + self.bias[:seq_len, :]
    return scaled

## RoPE
added `get_prunable_weights` for future pruning

In [18]:
def generate_square_subsequent_mask(size):
  """Return a (size, size) float mask of ones on and below the main diagonal."""
  all_ones = tf.ones((size, size))
  # keep every sub-diagonal (num_lower=-1) and no super-diagonal (num_upper=0)
  return tf.linalg.band_part(all_ones, num_lower=-1, num_upper=0)

class RoPEAttentionLayer(tf.keras.layers.Layer):
  """Causal multi-head self-attention with rotary position embeddings (RoPE).

  The input is projected by three bias-free Dense layers. Each projection is
  multiplied position-by-position by a precomputed block-diagonal rotation
  matrix, and the results are fed to ``tf.keras.layers.MultiHeadAttention``
  with a causal mask.

  Args:
    config: dict with 'd_model', 'n_heads' and 'context_window' (the maximum
      sequence length covered by the precomputed rotation matrix).
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    self.w_q = tf.keras.layers.Dense(config['d_model'], use_bias=False)
    self.w_k = tf.keras.layers.Dense(config['d_model'], use_bias=False)
    self.w_v = tf.keras.layers.Dense(config['d_model'], use_bias=False)

    self.multihead = tf.keras.layers.MultiHeadAttention(config['n_heads'], config['d_model'], dropout=0.1)

    # Constant (non-trainable) tensor of shape (context_window, d_model, d_model).
    self.rotary_matrix = self.get_rotary_matrix(config['context_window'], config['d_model'])

  def get_prunable_weights(self):
    # Hook used by tfmot's prune_low_magnitude.
    return self.trainable_weights

  def get_rotary_matrix(self, context_window, embedding_dim):
    """Build one rotation matrix per position.

    Returns a float32 tensor R of shape (context_window, embedding_dim,
    embedding_dim). For position p and pair index i, the 2x2 block on rows/cols
    (2i, 2i+1) is [[cos(p*theta_i), -sin(p*theta_i)],
    [sin(p*theta_i), cos(p*theta_i)]]; all other entries are zero.

    The matrix is filled with vectorized NumPy indexing. This replaces one
    tensor_scatter_nd_update per entry, which created a new full
    (context_window, D, D) tensor for every single update.
    """
    pair = np.arange(embedding_dim // 2)
    # NOTE(review): the exponent uses (i - 1) with 0-based i, so pair 0 gets
    # theta = 10000 ** (2 / d) instead of 1 as in the RoPE paper. Kept unchanged
    # because existing trained weights were fitted with these frequencies.
    theta = 10000. ** (-2. * (pair - 1) / embedding_dim)
    # Angles are computed in float64, then rounded to float32 before cos/sin.
    m_theta = (np.arange(context_window)[:, None] * theta[None, :]).astype(np.float32)
    cos = np.cos(m_theta)
    sin = np.sin(m_theta)

    even = 2 * pair
    odd = even + 1
    matrix = np.zeros((context_window, embedding_dim, embedding_dim), dtype=np.float32)
    matrix[:, even, even] = cos
    matrix[:, even, odd] = -sin
    matrix[:, odd, even] = sin
    matrix[:, odd, odd] = cos
    return tf.constant(matrix)

  def call(self, x):
    # assumes x is (batch, seq, d_model). The dynamic length also works when the
    # sequence axis is unknown at graph-construction time (Input(shape=(None,))).
    context_window = tf.shape(x)[1]
    rotation = self.rotary_matrix[:context_window]

    def rotate(t):
      # (B, T, D) -> (T, B, D) so each position is multiplied by its own
      # rotation matrix, then back to (B, T, D).
      rotated = tf.linalg.matmul(tf.transpose(t, perm=[1, 0, 2]), rotation)
      return tf.transpose(rotated, perm=[1, 0, 2])

    q_out = rotate(self.w_q(x))
    k_out = rotate(self.w_k(x))
    # NOTE(review): standard RoPE rotates only queries and keys; values are
    # rotated here as well. Kept for compatibility with trained weights.
    v_out = rotate(self.w_v(x))

    # NOTE(review): Keras MultiHeadAttention.call takes (query, value, key)
    # positionally, so k_out is used as the *value* and v_out as the *key*.
    # Kept as-is for compatibility with trained weights.
    activations = self.multihead(
        q_out, k_out, v_out,
        attention_mask=generate_square_subsequent_mask(context_window),
        return_attention_scores=False,
        use_causal_mask=True
    )
    return activations


## SwiGLU
added `get_prunable_weights` for future pruning

In [19]:
class SwiGLU(tf.keras.layers.Layer):
  """SwiGLU gating: Swish_beta(gate(x)) * linear(x).

  Swish_beta(z) = z * sigmoid(beta * z), where beta is a single learned scalar.
  ``gate`` and ``linear`` are separate Dense layers with ``size`` units each.

  Args:
    size: number of output units of both Dense projections.
  """

  def __init__(self, size):
    super().__init__()
    self.linear_gate = tf.keras.layers.Dense(units=size)
    self.linear = tf.keras.layers.Dense(units=size)
    self.beta = self.add_weight(name='beta', shape=(1,), initializer='random_normal', trainable=True)

  def get_prunable_weights(self):
    # Hook used by tfmot's prune_low_magnitude.
    return self.trainable_weights

  def call(self, x):
    # Project once and reuse; the gate projection is deterministic, so
    # evaluating it twice only doubled the matmul cost.
    gate = self.linear_gate(x)
    swish_gate = gate * tf.sigmoid(self.beta * gate)
    return swish_gate * self.linear(x)

## AttentionBlock
added `get_prunable_weights` for future pruning

In [20]:
class AttentionBlock(tf.keras.layers.Layer):
  """One transformer block: RMSNorm -> RoPE attention, RMSNorm -> feed-forward.

  Each sub-layer's output is added back to its (normalized) input.

  NOTE(review): a single RMSNormLayer instance (``self.rms``) is reused for both
  normalizations. The residual also adds the *normalized* input rather than the
  block input (pre-norm LLaMA is x + f(norm(x))). Kept as-is so saved weights
  still load.
  """

  def __init__(self, config):
    super().__init__()
    d_model = config['d_model']
    self.rms = RMSNormLayer((config['context_window'], d_model))
    self.rope_attention = RoPEAttentionLayer(config)
    feed_forward_layers = [
        tf.keras.layers.Dense(units=d_model),
        SwiGLU(d_model),
    ]
    self.dense = tf.keras.models.Sequential(feed_forward_layers)

  def get_prunable_weights(self):
    # Hook used by tfmot's prune_low_magnitude.
    return self.trainable_weights

  def call(self, x):
    for sublayer in (self.rope_attention, self.dense):
      normed = self.rms(x)
      x = normed + sublayer(normed)
    return x

## BabyLLaMA Model
**Difference**: converted the model from `subclassing` model to `Functional API`.  


In [None]:
def BabyLLaMAFunctionalModel(config):
    """Build the BabyLLaMA language model with the Keras Functional API.

    Architecture: token Embedding -> config['n_layers'] AttentionBlocks ->
    Dense -> SwiGLU -> Dense projection to config['vocab_size'] logits.

    Args:
        config: dict providing 'vocab_size', 'd_model', 'n_layers', plus the
            keys required by AttentionBlock.

    Returns:
        An uncompiled tf.keras.Model mapping token ids of shape (batch, seq)
        to unnormalized logits of shape (batch, seq, vocab_size).
    """
    d_model = config['d_model']
    vocab_size = config['vocab_size']

    token_ids = tf.keras.layers.Input(shape=(None,))
    hidden = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=d_model)(token_ids)

    for _ in range(config['n_layers']):
        hidden = AttentionBlock(config)(hidden)

    # feed-forward head followed by the vocabulary projection
    hidden = tf.keras.layers.Dense(units=d_model)(hidden)
    hidden = SwiGLU(d_model)(hidden)
    logits = tf.keras.layers.Dense(units=vocab_size)(hidden)

    return tf.keras.models.Model(inputs=token_ids, outputs=logits)

# Pruning (New Content)

First we load the original BabyLLaMA model (functional API version). `tf.keras.models.load_model()` restores the model's architecture, weights, and training configuration.

Note that we need to add `custom_objects` here. Our model includes custom-defined components (custom layers, custom activation functions, custom losses, custom metrics, etc.). Keras cannot recognize them by default when we load the model, because these custom components are not part of the standard Keras library. This is where `custom_objects` comes into play.

`custom_objects` is a dictionary that maps the names of our custom components to their respective Python objects (classes or functions), which tells Keras how to handle and instantiate these components when restoring the model.

In our case, we have the custom layers `AttentionBlock` (which internally uses `RoPEAttentionLayer`), `SwiGLU` and `RMSNormLayer`, and the custom loss `custom_loss`.

In [77]:
# Restore the trained model from its .keras archive. load_model rebuilds the
# architecture and weights itself, so no fresh model needs to be built first.
# Every custom class/function used by the model must be listed in custom_objects.
keras_file = os.path.join(model_dir, 'unpruned_model.keras')
loaded_model = tf.keras.models.load_model(
    keras_file,
    custom_objects={
        'AttentionBlock': AttentionBlock,
        'RoPEAttentionLayer': RoPEAttentionLayer,
        'SwiGLU': SwiGLU,
        'RMSNormLayer': RMSNormLayer,
        'custom_loss': custom_loss,
    },
)

We inspect the model's weights. Note that currently none of the weights is 0.

In [71]:
# Count exactly-zero vs non-zero entries across all trainable weights of the
# loaded (unpruned) baseline model.
num_zeros = sum(np.sum(w.numpy() == 0) for w in loaded_model.trainable_weights)
print(f"Number of zero weights: {num_zeros}")
num_non_zeros = sum(np.sum(w.numpy() != 0) for w in loaded_model.trainable_weights)
print(f"Number of non zero weights: {num_non_zeros}")
print(f"Zero rate: {num_zeros/(num_zeros+num_non_zeros)}")

Number of zero weights: 0
Number of non zero weights: 2579142
Zero rate: 0.0


We evaluate the model's performance on the test data. The model accuracy will be used as the baseline.

In [81]:
loaded_model.compile(
    optimizer='adam',
    loss=custom_loss,
    metrics=['accuracy'],
)
# Displays [loss, accuracy] on the test split: the unpruned baseline.
loaded_model.evaluate(x=processed_test_dataset)



[0.7568248510360718, 0.7673249244689941]

Now we define a pruned_model.

`tfmot.sparsity.keras.prune_low_magnitude` is a function provided by TensorFlow Model Optimization Toolkit (TF-MOT) to apply weight pruning to Keras models and layers. The primary purpose of weight pruning is to set certain weights in the model to zero, thereby reducing the number of effective parameters and making the model smaller and faster. This is useful for deploying models on resource-constrained devices.

When applied, prune_low_magnitude function modifies a model or layer to include the necessary operations for pruning and sets the initial configuration.


In this example, we set the `initial_sparsity` to 0.5 (50% of weights will be 0) and the `final_sparsity` to 0.8 (80% of weights will be 0).


`tfmot.sparsity.keras.PruningSummaries` is a Keras callback for adding pruning summaries to tensorboard.

In [82]:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Training steps (batches) per epoch. The dataset's cardinality is unknown after
# unbatch(), so count the batches once instead of hard-coding the number (1960
# for the default batch_size/context_window); the schedule then stays correct
# if MODEL_CONFIG changes.
steps_per_epoch = int(
    processed_train_dataset.reduce(np.int64(0), lambda count, *_: count + 1).numpy()
)
end_step = steps_per_epoch * MODEL_CONFIG['epochs']

# Sparsity ramps polynomially from 50% at step 0 to 80% at end_step.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

pruned_model = prune_low_magnitude(loaded_model, **pruning_params)

Now we recompile and train the pruned model

`tfmot.sparsity.keras.UpdatePruningStep`
* This is a callback that updates the pruning step during training. When we're training a pruned model, we need to inform the pruning algorithm about the current step so that it can decide whether to prune weights or not based on the pruning_schedule we've defined.
* Essentially, at every training step (batch) this callback updates an internal step counter. The pruning algorithm uses that counter to determine the current sparsity level from our defined schedule (e.g., PolynomialDecay), which is why `end_step` is expressed in batches rather than epochs.
* It's crucial to include this callback during training to ensure that pruning happens correctly.



In [None]:
pruned_model.compile(optimizer='adam',
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                     metrics=['accuracy'])

logdir = tempfile.mkdtemp()

callbacks = [
  # Required when training a model wrapped by prune_low_magnitude.
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir, update_freq='epoch'),
]

# The datasets are already batched by process_dataset, so batch_size must not
# be passed to fit (Keras batches only array inputs, not tf.data datasets).
pruned_model.fit(processed_train_dataset,
                 epochs=MODEL_CONFIG['epochs'],
                 validation_data=processed_validation_dataset,
                 callbacks=callbacks)

In [None]:
%load_ext tensorboard
%tensorboard --logdir "$log_dir"

Now we check the fraction of zero weights again. It is now about 80%, matching `final_sparsity`.

In [85]:
# Materialize each trainable weight once, then count exact zeros.
weight_arrays = [w.numpy() for w in pruned_model.trainable_weights]
num_zeros = sum(np.sum(arr == 0) for arr in weight_arrays)
print(f"Number of zero weights: {num_zeros}")
# `!= 0` is the exact complement of `== 0`, so the rest are non-zero.
num_non_zeros = sum(arr.size for arr in weight_arrays) - num_zeros
print(f"Number of non zero weights: {num_non_zeros}")
print(f"Prune rate: {num_zeros/(num_zeros+num_non_zeros)}")

Number of zero weights: 2063145
Number of non zero weights: 515997
Prune rate: 0.7999346294232733


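Let's check the performance of the pruned model. The impact is modest: with 80% of the weights set to zero, test accuracy drops from about 0.767 to 0.735, and the loss rises from about 0.757 to 0.844.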

In [87]:
# [loss, accuracy] of the pruned model on the same test split as the baseline.
pruned_model.evaluate(x=processed_test_dataset)



[0.8442904949188232, 0.7346643805503845]

To see the compression advantages of pruning, we must utilize both `tfmot.sparsity.keras.strip_pruning` and a conventional compression method like `gzip`.

`strip_pruning` is crucial because it eliminates the `tf.Variable` elements that are only relevant during training, preventing them from inflating the model's size during inference.

On the other hand, even after pruning, the serialized weight matrices retain their original size. The distinction now is that a significant portion of these weights are zeros. This increased redundancy introduced by pruning is what standard compression algorithms can capitalize on to achieve further model compression. That's why we also need `gzip` here.

Now let's export the stripped model in Keras format.

In [None]:
# Remove the pruning wrappers so only the (now sparse) plain layers remain.
pruned_tf_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

fd, pruned_keras_file = tempfile.mkstemp('.keras')
os.close(fd)  # only the unique path is needed; save_model reopens it by name
tf.keras.models.save_model(pruned_tf_model, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)

And let's also prepare a model for `TFLite`

In [None]:
converter = tf.lite.TFLiteConverter.from_keras_model(pruned_tf_model)
pruned_tflite_model = converter.convert()

fd, pruned_tflite_file = tempfile.mkstemp('.tflite')

# Write through the descriptor mkstemp already opened, so it is closed afterwards.
with os.fdopen(fd, 'wb') as f:
  f.write(pruned_tflite_model)

print('Saved pruned TFLite model to:', pruned_tflite_file)


Here is a helper function for gzipping the model

In [90]:
def get_gzipped_model_size(file):
  """Return the size in bytes of `file` after DEFLATE compression.

  The file is written into a temporary zip archive using ZIP_DEFLATED (the
  same DEFLATE algorithm gzip uses). The archive's size is measured and the
  archive is then deleted.

  Args:
    file: path of the file to compress.

  Returns:
    Size in bytes of the compressed archive.
  """
  import os
  import zipfile

  fd, zipped_file = tempfile.mkstemp('.zip')
  # mkstemp returns an open OS-level handle; close it (ZipFile reopens by path).
  os.close(fd)
  try:
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
      f.write(file)
    return os.path.getsize(zipped_file)
  finally:
    # Only the size is needed; don't leave archives behind in the temp dir.
    os.remove(zipped_file)

Let's compare the model sizes for
* unpruned TensorFlow model
* pruned TensorFlow model
* pruned TFLite model

In [91]:
size_reports = [
    ("Size of gzipped baseline Keras model: %.2f bytes", keras_file),
    ("Size of gzipped pruned Keras model: %.2f bytes", pruned_keras_file),
    ("Size of gzipped pruned TFlite model: %.2f bytes", pruned_tflite_file),
]
for template, model_path in size_reports:
    print(template % get_gzipped_model_size(model_path))

Size of gzipped baseline Keras model: 9621758.00 bytes
Size of gzipped pruned Keras model: 3113416.00 bytes
Size of gzipped pruned TFlite model: 3141117.00 bytes
