## Imports

In [21]:
from typing import Callable, Dict
import tensorflow_hub as hub
import tensorflow_text
import tensorflow as tf
import numpy as np
import wandb
import random
import time

SEED = 42
# Seed every random source the notebook touches (Python, NumPy, TensorFlow)
# so shuffling and weight initialisation are repeatable across runs.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

## Constants

In [27]:
# Location of the "{split}-*.tfrecord" shards read by the data pipeline.
TFRECORDS_DIR: str = "gs://variable-length-sequences-tf/tfrecords-sentence-splitter"
# Upper bound on the packed sequence length handed to the encoder.
BERT_MAX_SEQLEN: int = 512
# Width of the sentence embeddings consumed by the classifier head.
BERT_DIM: int = 768
BATCH_SIZE: int = 64
NUM_EPOCHS: int = 5
# Number of repeated training runs per batching strategy.
NUM_RUNS: int = 10

## TFRecord parsing utilities

In [3]:
# Schema for tf.io.parse_single_example. The three ragged fields are stored as
# serialized string tensors and decoded later by `read_example`.
feature_descriptions = dict(
    summary=tf.io.FixedLenFeature([], dtype=tf.string),
    summary_tokens=tf.io.FixedLenFeature([], dtype=tf.string),
    summary_sentence_indices=tf.io.FixedLenFeature([], dtype=tf.string),
    summary_num_sentences=tf.io.FixedLenFeature([], dtype=tf.int64),
    summary_tokens_len=tf.io.FixedLenFeature([], dtype=tf.string),
    label=tf.io.FixedLenFeature([1], dtype=tf.int64),
)

In [4]:
def deserialize_composite(
    serialized: bytes, type_spec: tf.RaggedTensorSpec
) -> tf.Tensor:
    """Rebuilds a composite (e.g. ragged) tensor from serialized components.

    Args:
        serialized: A serialized string tensor whose i-th element holds the
            serialized i-th flat component of the composite tensor.
        type_spec: Spec describing the composite tensor to rebuild; its
            flattened component specs supply the dtype of each component.

    Returns:
        The composite tensor described by `type_spec`.
    """
    flat_specs = tf.nest.flatten(type_spec, expand_composites=True)
    encoded_parts = tf.io.parse_tensor(serialized, tf.string)

    decoded_parts = []
    for position, component_spec in enumerate(flat_specs):
        decoded_parts.append(
            tf.io.parse_tensor(encoded_parts[position], component_spec.dtype)
        )

    return tf.nest.pack_sequence_as(type_spec, decoded_parts, expand_composites=True)


def read_example(example):
    """Parses one serialized tf.train.Example into a feature dict.

    Scalar fields are decoded by `feature_descriptions`; the serialized ragged
    fields are then rebuilt as int32 RaggedTensors via `deserialize_composite`.
    """
    # ragged_rank of every serialized ragged field, in decoding order.
    ragged_fields = {
        "summary_tokens": 2,
        "summary_sentence_indices": 1,
        "summary_tokens_len": 1,
    }

    parsed = tf.io.parse_single_example(example, feature_descriptions)
    for field_name, rank in ragged_fields.items():
        parsed[field_name] = deserialize_composite(
            parsed.get(field_name),
            tf.RaggedTensorSpec(dtype=tf.int32, ragged_rank=rank),
        )
    return parsed

In [24]:
class ModelInputUtils:
    """Builds BERT encoder inputs from batches of pre-tokenized, sentence-split text.

    Two packing modes are supported:

    * ``dynamic_batching=True``: each batch is padded only up to its own longest
      sentence (token length + 2), capped at ``encoder_max_seqlen``.
    * ``dynamic_batching=False``: every batch is padded to the same fixed length,
      ``min(data_max_seq_len + 2, encoder_max_seqlen)``.
    """

    def __init__(
        self,
        bert_preprocessor_path: str = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
        encoder_max_seqlen: int = BERT_MAX_SEQLEN,
        dynamic_batching: bool = True,
        data_max_seq_len: int = None,
    ):
        """Initializes a BERT model input preprocessing utility class.

        Args:
            bert_preprocessor_path: TF Hub handle of the BERT preprocessing model
                whose `bert_pack_inputs` signature is used for packing.
            encoder_max_seqlen: Upper bound on the packed sequence length.
            dynamic_batching: If True, pad per batch; otherwise pad to a fixed length.
            data_max_seq_len: Longest token sequence in the data. Required when
                `dynamic_batching` is False; ignored otherwise.

        Raises:
            ValueError: If `dynamic_batching` is False and `data_max_seq_len` is None.
        """
        self.bert_preprocessor_path = bert_preprocessor_path
        self.preprocessor_module = hub.load(bert_preprocessor_path)
        self.encoder_max_seqlen = encoder_max_seqlen
        self.dynamic_batching = dynamic_batching
        if not self.dynamic_batching:
            if data_max_seq_len is None:
                raise ValueError(
                    "data_max_seq_len must be provided when dynamic_batching=False."
                )
            # +2 leaves room for the special tokens the packer inserts.
            fixed_seq_len = tf.minimum(data_max_seq_len + 2, encoder_max_seqlen)
            # `bert_pack_inputs` takes its padded length as `seq_length`
            # (the same keyword used by `init_packer_and_pack_inputs`).
            self.packer = hub.KerasLayer(
                self.preprocessor_module.bert_pack_inputs,
                arguments={"seq_length": fixed_seq_len},
            )

    def pack_inputs(self, batch_tokens: tf.Tensor) -> tf.Tensor:
        """Packs tokens to the fixed length chosen at construction time.

        Only valid when `dynamic_batching` is False (the packer is built then).
        """
        return self.packer([batch_tokens])

    def init_packer_and_pack_inputs(
        self, batch_tokens: tf.Tensor, batch_token_lens: tf.Tensor
    ) -> tf.Tensor:
        """Packs tokens, padding only to the longest sequence in this batch.

        Args:
            batch_tokens: Tokenized sentences to pack.
            batch_token_lens: Token count of each sentence in `batch_tokens`.
        """
        max_token_len = tf.reduce_max(batch_token_lens)
        packer = hub.KerasLayer(
            self.preprocessor_module.bert_pack_inputs,
            arguments={
                "seq_length": tf.math.minimum(
                    max_token_len + 2, self.encoder_max_seqlen
                )
            },
        )
        return packer([batch_tokens])

    def unravel_ragged_batch(self, ragged_batch, ragged_idx, batch_lens, batch_size):
        """Flattens out a batch of ragged tensors by one level.

        Args:
            ragged_batch: Ragged tensor indexed as [example, item, ...].
            ragged_idx: Ragged tensor listing, per example, which items to take.
            batch_lens: Number of items taken from each example.
            batch_size: Number of examples in the batch.

        Returns:
            The selected items of all examples, concatenated along axis 0.
        """
        # Example index of every selected item: example i repeated batch_lens[i] times.
        batch_idx = tf.repeat(tf.range(batch_size), batch_lens, axis=0)

        # Pair each example index with its item index -> (num_items, 2) matrix.
        gather_nd_idx = tf.stack([batch_idx, ragged_idx.flat_values], axis=1)

        return tf.gather_nd(ragged_batch, indices=gather_nd_idx, batch_dims=0)

    def get_bert_inputs(self, batch, batch_size):
        """Generates padded BERT inputs for a batch of tokenized text features.

        Mutates `batch`: replaces "summary_tokens" with the flattened tokens and
        removes "summary_sentence_indices" and "summary_tokens_len".
        """
        # Flatten out the RaggedTensor token batch (one row per sentence).
        tokens = self.unravel_ragged_batch(
            batch.pop("summary_tokens"),
            batch.pop("summary_sentence_indices"),
            batch["summary_num_sentences"],
            batch_size,
        )
        batch["summary_tokens"] = tokens
        # Popped in both modes so the returned feature dict has the same keys.
        token_lens = batch.pop("summary_tokens_len").flat_values

        # `dynamic_batching` is a Python bool fixed at construction, so a plain
        # `if` picks the packing path at trace time. `tf.cond` would need
        # callables, and the fixed packer only exists when dynamic_batching=False.
        if self.dynamic_batching:
            return self.init_packer_and_pack_inputs(tokens, token_lens)
        return self.pack_inputs(tokens)

    def preprocess_batch(self, batch: Dict[str, tf.Tensor]):
        """Applies batch level transformations to the data.

        Returns:
            A `(features, label)` tuple; `features["bert_inputs"]` holds the
            packed encoder inputs.
        """
        batch_size = tf.shape(batch["label"])[0]

        # Generate padded BERT inputs for all the text features.
        batch["bert_inputs"] = self.get_bert_inputs(batch, batch_size)

        label = batch.pop("label")
        return batch, label

## Dataset preparation

In [22]:
def get_dataset(split: str, batch_size: int, batch_preprocessor: Callable, shuffle: bool):
    """Prepares a batched, preprocessed tf.data.Dataset from TFRecords.

    Args:
        split: Shard prefix to read ("{TFRECORDS_DIR}/{split}-*.tfrecord").
        batch_size: Number of examples per batch.
        batch_preprocessor: Function mapped over each batch (e.g.
            `ModelInputUtils.preprocess_batch`).
        shuffle: Whether to shuffle examples (use for the training split).

    Returns:
        A dataset yielding the outputs of `batch_preprocessor`.
    """
    AUTO = tf.data.AUTOTUNE

    ds = tf.data.Dataset.list_files(f"{TFRECORDS_DIR}/{split}-*.tfrecord")
    ds = ds.interleave(
        tf.data.TFRecordDataset, cycle_length=3, num_parallel_calls=AUTO
    )

    # Parsed examples are cached so later epochs skip TFRecord decoding.
    ds = ds.map(read_example, num_parallel_calls=AUTO, deterministic=False).cache()
    if shuffle:
        ds = ds.shuffle(batch_size * 10)
    ds = ds.batch(batch_size)
    ds = ds.map(batch_preprocessor, num_parallel_calls=AUTO)
    # Prefetch last so the (expensive) preprocessed batches overlap with the
    # training step instead of only prefetching raw records.
    return ds.prefetch(AUTO)

## Model building

In [18]:
def genre_classifier(
    proj_dim: int,
    num_labels: int,
    input_dim: int = BERT_DIM,
):
    """Creates a simple classification model over sentence embeddings.

    Args:
        proj_dim: Width of the hidden ReLU projection.
        num_labels: Number of output classes.
        input_dim: Width of the input embedding vector (defaults to BERT_DIM).

    Returns:
        A Keras functional model mapping (batch, input_dim) float embeddings
        to (batch, num_labels) softmax probabilities.
    """
    # Embeddings are floating point. Functional models cast inputs to the
    # declared dtype, so declaring int32 here would truncate them.
    embeddings = tf.keras.Input(
        shape=(input_dim,), dtype=tf.float32, name="cmlm_embeddings"
    )

    projections = tf.keras.layers.Dense(proj_dim, activation="relu")(embeddings)
    probs = tf.keras.layers.Dense(num_labels, activation="softmax")(projections)
    return tf.keras.Model(inputs=embeddings, outputs=probs)

## Training routine

In [None]:
class GenreModelTrainer(tf.keras.Model):
    """Encapsulates the core model training logic.

    A TF Hub sentence encoder embeds every sentence, the sentence embeddings of
    each example are averaged into one vector, and `genre_classifier` predicts
    the label from that vector.
    """

    def __init__(self,
        proj_dim, num_labels, encoder_path, train_encoder=False
    ):
        """
        Args:
            proj_dim: Hidden width of the classifier head.
            num_labels: Number of output classes.
            encoder_path: TF Hub handle of the sentence encoder.
            train_encoder: Whether the encoder weights are updated in training.
        """
        super().__init__()
        self.predictor = genre_classifier(proj_dim, num_labels)
        self.sentence_encoder = hub.KerasLayer(encoder_path)
        self.sentence_encoder.trainable = train_encoder

    def contiguous_group_average_vectors(self, vectors, groups):
        """Averages contiguous row groups of `vectors`; assumes sum(groups) == len(vectors).

        Example:
            Inputs: vectors: A dense 2D tensor of shape = (13, 3)
                    groups : A dense 1D tensor with values [2, 5, 1, 4, 1]
                    indicating that there are 5 groups.
            Objective: Compute a 5x3 matrix where the first row
                        is the average of the rows 0-1 of `vectors`,
                        the second row is the average of rows 2-6 of
                        vectors, the third row is the row 7 of vectors,
                        the fourth row is the average of rows 8-11 of
                        vectors and the fifth and final row is the row
                        12 of vectors.
            Logic: Each group i spans rows [start_i, end_i), where end is the
                    cumulative sum of `groups` and start the exclusive one.
                    A selection mask is built from those bounds:
                    mask = [[1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
                            [0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
                            [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
                            [0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0.]
                            [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
                    `mask @ vectors` gives the per-group sums, which are then
                    divided by `groups` to obtain the averages. Groups of size
                    0 yield a zero row instead of NaN.
        """
        groups = tf.cast(groups, dtype=tf.int32)
        group_ends = tf.expand_dims(tf.cumsum(groups), axis=1)
        group_starts = tf.expand_dims(tf.cumsum(groups, exclusive=True), axis=1)
        row_positions = tf.expand_dims(tf.range(tf.shape(vectors)[0]), axis=0)

        # (num_groups, num_rows) selection mask; broadcasting replaces tf.repeat.
        mask = tf.cast(
            (row_positions >= group_starts) & (row_positions < group_ends),
            dtype=vectors.dtype,
        )

        summed_vectors = tf.matmul(mask, vectors)
        group_sizes = tf.expand_dims(tf.cast(groups, dtype=vectors.dtype), axis=1)
        return tf.math.divide_no_nan(summed_vectors, group_sizes)

    def compute_text_embeddings(self, features, training=False):
        """Encodes every sentence and averages them per example."""
        encoder_outputs = self.sentence_encoder(
            features["bert_inputs"], training=training
        )
        # Hub sentence encoders such as CMLM return a dict of outputs; the
        # sentence embedding is published under "default".
        if isinstance(encoder_outputs, dict):
            encoder_outputs = encoder_outputs["default"]
        return self.contiguous_group_average_vectors(
            encoder_outputs, features["summary_num_sentences"]
        )

    def train_step(self, batch):
        # Unpack the features and the labels.
        features, labels = batch

        with tf.GradientTape() as tape:
            if self.sentence_encoder.trainable:
                # Fine-tuning: the encoder forward pass must be recorded.
                embeddings = self.compute_text_embeddings(features, training=True)
            else:
                # Frozen encoder: skip recording it to save memory.
                with tape.stop_recording():
                    embeddings = self.compute_text_embeddings(features)
            predictions = self.predictor(embeddings, training=True)
            loss = self.compiled_loss(labels, predictions)

        # Only trainable weights are updated: the head, plus the encoder when
        # `train_encoder` was set.
        learnable_params = self.trainable_variables
        gradients = tape.gradient(loss, learnable_params)
        self.optimizer.apply_gradients(zip(gradients, learnable_params))

        # Report progress.
        self.compiled_metrics.update_state(labels, predictions)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, batch):
        # Unpack the features and the labels.
        features, labels = batch

        # Get the predictions.
        predictions = self.call(features)

        # Report progress.
        self.compiled_loss(labels, predictions)
        self.compiled_metrics.update_state(labels, predictions)
        return {m.name: m.result() for m in self.metrics}

    def call(self, features):
        embeddings = self.compute_text_embeddings(features)
        return self.predictor(embeddings, training=False)

    def predict_step(self, data):
        # Datasets built for fit/evaluate yield (features, labels); drop the labels.
        if isinstance(data, tuple):
            data = data[0]
        return self.call(data)

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """Saves only the classifier head's weights."""
        self.predictor.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """Loads weights into the classifier head only."""
        self.predictor.load_weights(
            filepath, by_name=by_name, skip_mismatch=skip_mismatch, options=options
        )

In [26]:
def train(
    train_ds: tf.data.Dataset,
    valid_ds: tf.data.Dataset,
    test_ds: tf.data.Dataset,
    num_epochs: int,
    run_name: str,
    group_name: str,
    tfhub_model_uri: str = "https://tfhub.dev/google/universal-sentence-encoder-cmlm/en-base/1",
    proj_dim: int = 128,
    num_labels: int = 27,
):
    """Trains and evaluates one GenreModelTrainer, logging to Weights & Biases.

    Args:
        train_ds, valid_ds, test_ds: Datasets yielding (features, labels).
        num_epochs: Number of training epochs.
        run_name: W&B run name.
        group_name: W&B group the run belongs to.
        tfhub_model_uri: TF Hub handle of the sentence encoder.
        proj_dim: Hidden width of the classifier head.
        num_labels: Number of output classes.
    """
    wandb.init(
        project="batching-experiments",
        entity="carted",
        name=run_name,
        group=group_name,
    )
    try:
        trainer = GenreModelTrainer(proj_dim, num_labels, tfhub_model_uri, False)
        trainer.compile(
            optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
        )

        start = time.perf_counter()
        trainer.fit(
            train_ds,
            epochs=num_epochs,
            validation_data=valid_ds,
            callbacks=[wandb.keras.WandbCallback()],
        )
        wandb.log({"model_training_time (seconds)": time.perf_counter() - start})

        loss, acc = trainer.evaluate(test_ds)
        wandb.log({"test_loss": loss})
        wandb.log({"test_acc": acc})
    finally:
        # Always close the run so a failure doesn't leak into the next run.
        wandb.finish()

## Training with fixed batch length

In [None]:
# Find the longest per-sentence token sequence in the training split.
# Dataset.reduce streams over the data in graph mode instead of building a
# Python list with one eager tensor per example.
train_examples = tf.data.TFRecordDataset(
    tf.data.Dataset.list_files(f"{TFRECORDS_DIR}/train-*.tfrecord")
).map(read_example, num_parallel_calls=tf.data.AUTOTUNE)
max_seq_len = train_examples.map(
    lambda datum: tf.reduce_max(datum["summary_tokens_len"].flat_values)
).reduce(tf.constant(tf.int32.min, dtype=tf.int32), tf.maximum)
print(f"Longest token sequence in the training split: {max_seq_len.numpy()}")

Longest token sequence in the training split: 549


In [25]:
# Fixed-length padding: every batch is padded to the training split's longest sequence.
input_utility = ModelInputUtils(data_max_seq_len=max_seq_len, dynamic_batching=False)

In [None]:
# Only the training split is shuffled.
train_ds, valid_ds, test_ds = (
    get_dataset(
        split=split_name,
        batch_size=BATCH_SIZE,
        batch_preprocessor=input_utility.preprocess_batch,
        shuffle=split_name == "train",
    )
    for split_name in ("train", "val", "test")
)

In [None]:
group_name = "split-sentences-fixed-length-batching"

# Repeat the experiment NUM_RUNS times; run numbers in the names are 1-based.
for run_number in range(1, NUM_RUNS + 1):
    run_name = f"cmlm-fixed-length-run:{run_number}"
    train(train_ds, valid_ds, test_ds, NUM_EPOCHS, run_name, group_name)

## Training with variable batch length

In [28]:
# Dynamic padding: each batch is padded only to its own longest sequence.
input_utility = ModelInputUtils(data_max_seq_len=None, dynamic_batching=True)

In [None]:
# Only the training split is shuffled.
train_ds, valid_ds, test_ds = (
    get_dataset(
        split=split_name,
        batch_size=BATCH_SIZE,
        batch_preprocessor=input_utility.preprocess_batch,
        shuffle=split_name == "train",
    )
    for split_name in ("train", "val", "test")
)

In [None]:
group_name = "split-sentences-variable-length-batching"

# Repeat the experiment NUM_RUNS times; run numbers in the names are 1-based.
for run_number in range(1, NUM_RUNS + 1):
    run_name = f"cmlm-variable-length-run:{run_number}"
    train(train_ds, valid_ds, test_ds, NUM_EPOCHS, run_name, group_name)