## Imports

In [None]:
from typing import Callable, Dict
import tensorflow_hub as hub
import tensorflow_text
import tensorflow as tf
import numpy as np
import wandb
import random
import time

# One seed shared by every random number generator the notebook touches,
# so repeated runs start from the same state.
SEED = 42
random.seed(SEED)  # Python's built-in RNG
np.random.seed(SEED)  # NumPy's global RNG
tf.random.set_seed(SEED)  # TensorFlow's global seed

## Constants

In [None]:
# Directory holding the `<split>-*.tfrecord` shards read by `get_dataset`.
TFRECORDS_DIR = "tfrecords-sentence-splitter"
# Default upper bound on the packed sequence length (`encoder_max_seqlen`).
BERT_MAX_SEQLEN = 512
# Width of the sentence embeddings fed to the classifier head.
BERT_DIM = 768
# Number of examples per batch passed to `get_dataset`.
BATCH_SIZE = 64
# Shorthand for letting tf.data pick parallelism / buffer sizes.
AUTO = tf.data.AUTOTUNE

## TFRecord parsing utilities

In [None]:
# Parsing spec handed to `tf.io.parse_single_example` for every record.
feature_descriptions = {
    # Scalar string feature.
    "summary": tf.io.FixedLenFeature(shape=[], dtype=tf.string),
    # Ragged int64 feature rebuilt from flat values plus two levels of
    # row splits (listed outermost first).
    "summary_sentences": tf.io.RaggedFeature(
        dtype=tf.int64,
        value_key="summary_sentences_values",
        partitions=[
            tf.io.RaggedFeature.RowSplits("summary_sentences_splits_0"),
            tf.io.RaggedFeature.RowSplits("summary_sentences_splits_1"),
        ],
    ),
    # Ragged int64 feature rebuilt from flat values plus one level of
    # row splits.
    "summary_sentence_lens": tf.io.RaggedFeature(
        dtype=tf.int64,
        value_key="summary_sentence_lens_values",
        partitions=[
            tf.io.RaggedFeature.RowSplits("summary_sentence_lens_splits_0"),
        ],
    ),
    # Single int64 value per record.
    "label": tf.io.FixedLenFeature(shape=[1], dtype=tf.int64),
}

In [None]:
def read_example(example):
    """Parses one serialized example according to `feature_descriptions`.

    The two ragged features are returned as int32 RaggedTensors whose row
    splits are int64; every other feature is returned as parsed.
    """
    parsed = tf.io.parse_single_example(example, feature_descriptions)

    # Normalise the row-splits dtype first, then down-cast the values.
    for ragged_key in ("summary_sentences", "summary_sentence_lens"):
        with_int64_splits = parsed[ragged_key].with_row_splits_dtype(tf.int64)
        parsed[ragged_key] = tf.cast(with_int64_splits, tf.int32)

    return parsed

In [None]:
class ModelInputUtils:
    """Builds padded BERT encoder inputs from batches of tokenized sentences.

    Two padding strategies are available, selected by `dynamic_batching`:

    * fixed: a single packer is created up front whose sequence length is
      `min(data_max_seq_len + 2, encoder_max_seqlen)`;
    * dynamic: a packer is created for every batch, with the sequence
      length derived from the longest sentence in that batch.
    """

    def __init__(
        self,
        bert_preprocessor_path: str = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
        encoder_max_seqlen: int = BERT_MAX_SEQLEN,
        dynamic_batching: bool = True,
        data_max_seq_len: int = 0,
    ):
        """Initializes a BERT model input preprocessing utility class.

        Args:
            bert_preprocessor_path: Location of the BERT preprocessing
                module loaded with `hub.load`.
            encoder_max_seqlen: Upper bound on the packed sequence length.
            dynamic_batching: If True, pad each batch to its own longest
                sentence; otherwise pad every batch to the fixed length.
            data_max_seq_len: Longest token sequence in the data; only
                used to size the fixed-length packer.
        """
        self.bert_preprocessor_path = bert_preprocessor_path
        self.preprocessor_module = hub.load(bert_preprocessor_path)
        self.encoder_max_seqlen = encoder_max_seqlen

        # The "+ 2" presumably leaves room for the special tokens added by
        # `bert_pack_inputs` (verify against the preprocessor docs).
        max_seq_len = tf.minimum(data_max_seq_len + 2, encoder_max_seqlen)
        self.packer = hub.KerasLayer(
            self.preprocessor_module.bert_pack_inputs,
            arguments={"seq_length": max_seq_len},
        )
        # Stored as a tensor because it is used as the `tf.cond` predicate.
        self.dynamic_batching = tf.constant(dynamic_batching)

    def pack_inputs(self, batch_tokens: tf.Tensor):
        """Packs `batch_tokens` with the fixed-length packer."""
        return self.packer([batch_tokens])

    def init_packer_and_pack_inputs(
        self, batch_tokens: tf.Tensor, batch_token_lens: tf.Tensor
    ):
        """Packs `batch_tokens` with a packer sized for this batch.

        The sequence length is `min(max(batch_token_lens) + 2,
        encoder_max_seqlen)`.
        """
        max_token_len = tf.reduce_max(batch_token_lens)
        packer = hub.KerasLayer(
            self.preprocessor_module.bert_pack_inputs,
            arguments={
                "seq_length": tf.math.minimum(
                    max_token_len + 2, self.encoder_max_seqlen
                )
            },
        )
        return packer([batch_tokens])

    def get_bert_inputs(self, batch, batch_size):
        """Generates padded BERT inputs for a given batch of tokenized
        text features.

        Removes `summary_sentences` and `summary_sentence_lens` from
        `batch`. `batch_size` is accepted for interface compatibility but
        is not used.
        """

        # Unravelling a batch of RaggedTensors: the batch and sentence
        # dimensions are merged so every row is one sentence.
        tokens = batch.pop("summary_sentences").merge_dims(0, 1)

        # Popped here rather than inside a `tf.cond` branch so that the
        # resulting feature structure does not depend on which branches
        # happen to be traced.
        token_lens = batch.pop("summary_sentence_lens").flat_values

        # obtain the BERT inputs
        bert_inputs = tf.cond(
            self.dynamic_batching,
            lambda: self.init_packer_and_pack_inputs(tokens, token_lens),
            lambda: self.pack_inputs(tokens),
        )

        return bert_inputs

    def preprocess_batch(self, batch: Dict[str, tf.Tensor]):
        """Applies batch level transformations to the data.

        Returns:
            A `(features, label)` tuple. `features` gains `bert_inputs`
            and `summary_num_sentences` (number of sentences per example)
            and loses the raw token features and the label.
        """
        batch_size = tf.shape(batch["label"])[0]

        # The model averages sentence embeddings per example and therefore
        # needs to know how many sentences each example contributed. This
        # must be read before the sentences are flattened below. An
        # existing value is left untouched.
        if "summary_num_sentences" not in batch:
            batch["summary_num_sentences"] = batch[
                "summary_sentences"
            ].row_lengths()

        # generate padded BERT inputs for all the text features
        batch["bert_inputs"] = self.get_bert_inputs(batch, batch_size)

        label = batch.pop("label")
        return batch, label

## Dataset preparation

In [None]:
def get_dataset(
    split: str, batch_size: int, batch_preprocessor: Callable, shuffle: bool
):
    """Prepares tf.data.Dataset objects from TFRecords.

    Args:
        split: Prefix of the shard names, i.e. files matching
            `{TFRECORDS_DIR}/{split}-*.tfrecord` are read.
        batch_size: Number of examples per batch.
        batch_preprocessor: Function applied to every batch.
        shuffle: Whether to randomise the order of files and examples.
            When False the pipeline yields examples in a deterministic
            order.

    Returns:
        A batched, preprocessed and prefetched `tf.data.Dataset`.
    """

    # `list_files` shuffles by default, so the flag is forwarded explicitly
    # to keep un-shuffled splits in a stable order.
    files = tf.data.Dataset.list_files(
        f"{TFRECORDS_DIR}/{split}-*.tfrecord", shuffle=shuffle
    )
    ds = files.interleave(
        tf.data.TFRecordDataset, cycle_length=AUTO, num_parallel_calls=AUTO
    )

    # Parsed examples are cached so later epochs skip reading and parsing.
    # Element order only needs to be preserved for un-shuffled splits.
    ds = ds.map(
        read_example, num_parallel_calls=AUTO, deterministic=not shuffle
    ).cache()
    if shuffle:
        ds = ds.shuffle(batch_size * 10)
    ds = ds.batch(batch_size)
    ds = ds.map(batch_preprocessor, num_parallel_calls=AUTO)

    # Prefetch last so batch preprocessing overlaps with model execution.
    return ds.prefetch(AUTO)

## Model building

In [None]:
def genre_classifier(
    proj_dim: int,
    num_labels: int,
):
    """Builds the classification head applied to text embeddings.

    The model takes float32 vectors of width `BERT_DIM`, projects them to
    `proj_dim` units with a ReLU and outputs a softmax distribution over
    `num_labels` classes.
    """
    embedding_input = tf.keras.Input(
        shape=(BERT_DIM,), dtype=tf.float32, name="cmlm_embeddings"
    )

    hidden_layer = tf.keras.layers.Dense(proj_dim, activation="relu")
    output_layer = tf.keras.layers.Dense(num_labels, activation="softmax")

    class_probs = output_layer(hidden_layer(embedding_input))
    return tf.keras.Model(inputs=embedding_input, outputs=class_probs)

## Training routine

In [None]:
class GenreModelTrainer(tf.keras.Model):
    """Encapsulates the core model training logic.

    A sentence encoder loaded from TF-Hub embeds every sentence, the
    sentence embeddings of each example are averaged, and a small
    classifier head (`genre_classifier`) predicts the label.
    """

    def __init__(self, proj_dim, num_labels, encoder_path, train_encoder=False):
        """Creates the classifier head and loads the sentence encoder.

        Args:
            proj_dim: Width of the hidden layer of the classifier head.
            num_labels: Number of output classes.
            encoder_path: TF-Hub handle or path of the sentence encoder.
            train_encoder: Whether the encoder weights are updated during
                training. Note that `save_weights` / `load_weights` only
                cover the classifier head.
        """
        super().__init__()
        self.predictor = genre_classifier(proj_dim, num_labels)
        self.sentence_encoder = hub.KerasLayer(encoder_path)
        self.sentence_encoder.trainable = train_encoder

    def contiguous_group_average_vectors(self, vectors, groups):
        """Works iff sum(groups) == len(vectors)
        Example:
            Inputs: vectors: A dense 2D tensor of shape = (13, 3)
                    groups : A dense 1D tensor with values [2, 5, 1, 4, 1]
                    indicating that there are 5 groups.
            Objective: Compute a 5x3 matrix where the first row
                        is the average of the rows 0-1 of `vectors`,
                        the second row is the average of rows 2-6 of
                        vectors, the third row is the row 7 of vectors,
                        the fourth row is the average of rows 8-11 of
                        vectors and the fifth and final row is the row
                        12 of vectors.
            Logic: A selection mask matrix is generated
                    mask = [[1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
                            [0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
                            [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
                            [0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0.]
                            [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
                    This mask is then multiplied with `vectors` to get a
                    matrix of shape (5, 3) called `summed_vectors` where
                    each row contains the group sums.
                    `summed_vectors` is then divided by `groups` to
                    obtain the averages.
        A group of size 0 yields an all-zero row instead of NaN.
        """
        # Column vector of group sizes and their running totals.
        groups = tf.expand_dims(tf.cast(groups, dtype=tf.int32), axis=1)
        group_cumsum = tf.cumsum(groups)

        # Row g initially selects every vector index below the end of
        # group g, i.e. groups 0..g.
        mask = tf.repeat(
            tf.expand_dims(tf.range(tf.shape(vectors)[0]), axis=0),
            repeats=tf.shape(groups)[0],
            axis=0,
        )
        mask = tf.cast(mask < group_cumsum, dtype=tf.float32)

        def complete_mask(mask):
            # Remove from row g everything already selected by row g - 1,
            # leaving only the members of group g. The first row has no
            # predecessor and is kept as is.
            neg_mask = tf.concat(
                (tf.expand_dims(tf.ones_like(mask[0]), axis=0), 1 - mask[:-1]), axis=0
            )
            return mask * neg_mask

        # With a single group there is nothing to subtract.
        mask = tf.cond(
            tf.greater(tf.shape(groups)[0], 1),
            true_fn=lambda: complete_mask(mask),
            false_fn=lambda: mask,
        )

        summed_vectors = tf.matmul(mask, vectors)
        # An empty group has a zero sum and a zero count; a plain division
        # would give NaN and poison the loss.
        averaged_vectors = tf.math.divide_no_nan(
            summed_vectors, tf.cast(groups, dtype=tf.float32)
        )

        return averaged_vectors

    def compute_text_embeddings(self, features):
        """Returns one embedding per example.

        `features` must contain `bert_inputs` (the encoder inputs for all
        sentences of the batch) and `summary_num_sentences` (the number of
        sentences of each example). The latter is not among the parsed
        TFRecord features, so the input pipeline has to supply it.
        """
        embeddings = self.sentence_encoder(features["bert_inputs"])["pooled_output"]
        embeddings = self.contiguous_group_average_vectors(
            embeddings, features["summary_num_sentences"]
        )

        return embeddings

    def train_step(self, batch):
        """Runs one optimisation step and returns the current metrics."""
        # Unpack the features and the labels.
        features, labels = batch

        # Main loop. The encoder forward pass runs under the tape so that
        # its weights receive gradients when `train_encoder` is enabled;
        # with a frozen encoder there is nothing for the tape to track.
        with tf.GradientTape() as tape:
            embeddings = self.compute_text_embeddings(features)
            predictions = self.predictor(embeddings, training=True)
            loss = self.compiled_loss(labels, predictions)

        # Compute gradients and update the parameters. The encoder
        # contributes no variables while it is frozen.
        learnable_params = (
            self.predictor.trainable_variables
            + self.sentence_encoder.trainable_variables
        )
        gradients = tape.gradient(loss, learnable_params)

        # Apply the gradients to the parameters.
        self.optimizer.apply_gradients(zip(gradients, learnable_params))

        # Report progress.
        self.compiled_metrics.update_state(labels, predictions)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, batch):
        """Evaluates one batch and returns the current metrics."""
        # Unpack the features and the labels.
        features, labels = batch

        # Get the predictions.
        predictions = self.call(features)

        # Report progress.
        self.compiled_loss(labels, predictions)
        self.compiled_metrics.update_state(labels, predictions)
        return {m.name: m.result() for m in self.metrics}

    def call(self, features):
        """Returns class probabilities for `features` in inference mode."""
        embeddings = self.compute_text_embeddings(features)
        return self.predictor(embeddings, training=False)

    def predict_step(self, data):
        """Returns class probabilities for one batch of features."""
        return self.call(data)

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """Saves the weights of the classifier head only."""
        self.predictor.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """Loads the weights of the classifier head only."""
        self.predictor.load_weights(
            filepath, by_name=by_name, skip_mismatch=skip_mismatch, options=options
        )

In [None]:
def train(
    train_ds: tf.data.Dataset,
    valid_ds: tf.data.Dataset,
    test_ds: tf.data.Dataset,
    num_epochs: int,
    run_name: str,
    group_name: str,
    proj_dim: int = 128,
    num_labels: int = 27,
    tfhub_model_uri: str = (
        "https://tfhub.dev/google/universal-sentence-encoder-cmlm/en-base/1"
    ),
):
    """Trains and evaluates one model, logging everything to W&B.

    Args:
        train_ds: Dataset used for fitting.
        valid_ds: Dataset used for per-epoch validation.
        test_ds: Dataset used for the final evaluation.
        num_epochs: Number of training epochs.
        run_name: Name of the W&B run.
        group_name: W&B group the run belongs to.
        proj_dim: Hidden width of the classifier head.
        num_labels: Number of target classes.
        tfhub_model_uri: TF-Hub handle of the sentence encoder.
    """
    # Using the run as a context manager guarantees it is closed even if
    # training or evaluation raises, so a failure cannot leak an open run
    # into the next call.
    with wandb.init(
        project="batching-experiments",
        entity="carted",
        name=run_name,
        group=group_name,
    ):
        trainer = GenreModelTrainer(proj_dim, num_labels, tfhub_model_uri, False)
        trainer.compile(
            optimizer="adam",
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments.
        start = time.perf_counter()
        trainer.fit(
            train_ds,
            epochs=num_epochs,
            validation_data=valid_ds,
            callbacks=[wandb.keras.WandbCallback()],
        )
        end = time.perf_counter()
        wandb.log({"model_training_time (seconds)": end - start})

        loss, acc = trainer.evaluate(test_ds)
        wandb.log({"test_loss": loss})
        wandb.log({"test_acc": acc})

## Training with fixed batch length

In [None]:
# Find the longest sequence in the training set
ds = tf.data.Dataset.list_files(f"{TFRECORDS_DIR}/train-*.tfrecord")
ds = tf.data.TFRecordDataset(ds).map(read_example)
# Reduce inside tf.data: one maximum per record, then a running maximum
# across records. This avoids building a Python list with one eager tensor
# per training example.
max_seq_len = tf.cast(
    ds.map(
        lambda datum: tf.reduce_max(datum["summary_sentence_lens"].flat_values)
    ).reduce(
        tf.constant(0, dtype=tf.int32),
        lambda longest, current: tf.maximum(longest, current),
    ),
    tf.int32,
)
print(f"Longest token sequence in the training split: {max_seq_len.numpy()}")

In [None]:
# Fixed-length mode: every batch is padded to the same length, derived from
# the longest training sequence found above.
input_utility = ModelInputUtils(dynamic_batching=False, data_max_seq_len=max_seq_len)

In [None]:
# Build the three splits with the current `input_utility`; only the
# training split is shuffled.
train_ds, valid_ds, test_ds = (
    get_dataset(
        split=split_name,
        batch_size=BATCH_SIZE,
        batch_preprocessor=input_utility.preprocess_batch,
        shuffle=(split_name == "train"),
    )
    for split_name in ("train", "val", "test")
)

In [None]:
group_name = "split-sentences-fixed-length-batching"

for i in range(NUM_RUNS):
    run_name = f"cmlm-fixed-length-run:{i + 1}"
    train(train_ds, valid_ds, test_ds, NUM_EPOCHS, run_name, group_name)

## Experiment parameters

In [None]:
NUM_EPOCHS = 5
NUM_RUNS = 10

## Training with variable batch length

In [None]:
# Dynamic mode: each batch is padded to the length of its own longest
# sentence (capped at the encoder maximum).
input_utility = ModelInputUtils(dynamic_batching=True)

In [None]:
# Rebuild the three splits with the current `input_utility`; only the
# training split is shuffled.
train_ds, valid_ds, test_ds = (
    get_dataset(
        split=split_name,
        batch_size=BATCH_SIZE,
        batch_preprocessor=input_utility.preprocess_batch,
        shuffle=(split_name == "train"),
    )
    for split_name in ("train", "val", "test")
)

In [None]:
group_name = "split-sentences-variable-length-batching"

# Repeat the experiment NUM_RUNS times; run names are numbered from 1.
for i in range(NUM_RUNS):
    run_name = f"cmlm-variable-length-run:{i + 1}"
    train(
        train_ds=train_ds,
        valid_ds=valid_ds,
        test_ds=test_ds,
        num_epochs=NUM_EPOCHS,
        run_name=run_name,
        group_name=group_name,
    )