## Imports

In [1]:
import os
import random
import time
from typing import Callable, Dict, List

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text  # imported for its side effect: registers TF Text ops used by the BERT preprocessor
import wandb

SEED = 42

# Seed TensorFlow, NumPy and Python's `random` (in that order) so runs are repeatable.
for _seed_fn in (tf.random.set_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

## Constants

In [2]:
# Location of the `{split}-*.tfrecord` shards. Set the TFRECORDS_DIR environment
# variable to read from elsewhere (e.g. "gs://variable-length-sequences-tf/tfrecords")
# instead of editing this cell.
TFRECORDS_DIR = os.environ.get("TFRECORDS_DIR", "tfrecords")
# Upper bound on a packed sequence (tokens + [CLS]/[SEP]) fed to the BERT encoder.
BERT_MAX_SEQLEN = 512
BATCH_SIZE = 64

## TFRecord parsing utilities

In [7]:
# Two levels of row splits that rebuild the ragged `summary_tokens` tensor
# from its flat value list.
_summary_token_row_splits = [
    tf.io.RaggedFeature.RowSplits("summary_tokens_splits_0"),
    tf.io.RaggedFeature.RowSplits("summary_tokens_splits_1"),
]

# Parsing spec for one serialized example, passed to `tf.io.parse_single_example`.
feature_descriptions = dict(
    summary=tf.io.FixedLenFeature([], dtype=tf.string),
    summary_tokens=tf.io.RaggedFeature(
        dtype=tf.int64,
        value_key="summary_tokens_values",
        partitions=_summary_token_row_splits,
    ),
    # The length and label are stored as length-1 vectors, hence shape [1].
    summary_tokens_len=tf.io.FixedLenFeature([1], dtype=tf.int64),
    label=tf.io.FixedLenFeature([1], dtype=tf.int64),
)

In [4]:
def read_example(example):
    """Parses one serialized example using `feature_descriptions`.

    Arguments:
        example {tf.Tensor} -- a scalar string tensor holding one serialized
            record, e.g. an element yielded by `tf.data.TFRecordDataset`.

    Returns:
        Dict -- the parsed features, keyed as in `feature_descriptions`
            (`summary_tokens` comes back as a `tf.RaggedTensor`).
    """
    return tf.io.parse_single_example(example, feature_descriptions)

## Preprocessing function for fixed length batching.

In [5]:
# Peek at one raw (still serialized) record from the training split to check
# the TFRecord layout before parsing it.
raw_train_records = tf.data.TFRecordDataset(
    tf.data.Dataset.list_files(f"{TFRECORDS_DIR}/train-*.tfrecord")
)
for raw_record in raw_train_records.take(1):
    print(raw_record)

tf.Tensor(b'\n\xb2\x08\n\xc5\x04\n\x07summary\x12\xb9\x04\n\xb6\x04\n\xb3\x04This NZBC religious programme goes where TV cameras had never gone before: behind the walls of the Carmelite monastery in Christchurch. There, it finds a community of 16 Catholic nuns, members of a 400-year-old order, who have shut themselves off from the outside world to lead lives devoted to prayer, contemplation and simple manual work. Despite their seclusion, the sisters are unphased by the intrusion and happy to discuss their lives and their beliefs; while the simplicity and ceremony of their world provides fertile ground for the monochrome camerawork.\n!\n\x17summary_tokens_splits_0\x12\x06\x1a\x04\n\x02\x00g\n\x8d\x02\n\x15summary_tokens_values\x12\xf3\x01\x1a\xf0\x01\n\xed\x01\xe7\x0f\xa8\x9c\x01\xdaL\xd4\x1a\x8a%\xb0\x1c\x99\x10\x86\x15\xb5C\xe2\x0f\x94\x11\xdc\x16\x9d\x10\x80\x08\xc1\x12\xcc\x0f\xe1\x1c\xcd\x0f\xcc\x0f\xf3\x97\x01\xfd \x882\xcf\x0f\x88|\xf4\x07\xfd\x0f\xf2\x07\xd9\x0f\xfa%\x8d\x08\x9

2022-01-20 03:32:43.617302: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [8]:
# Find the longest token sequence in the training split; it fixes the padded
# length used by the fixed-length batching baseline below.
train_examples = tf.data.TFRecordDataset(
    tf.data.Dataset.list_files(f"{TFRECORDS_DIR}/train-*.tfrecord")
).map(read_example, num_parallel_calls=tf.data.AUTOTUNE)

# Reduce inside tf.data instead of materializing a Python list of every length.
# The 0 initial state also gives a well-defined answer for an empty split.
max_seq_len = tf.cast(
    train_examples.map(
        lambda datum: tf.reduce_max(datum["summary_tokens_len"])
    ).reduce(tf.constant(0, dtype=tf.int64), tf.maximum),
    tf.int32,
)
print(f"Longest token sequence in the training split: {max_seq_len.numpy()}")

Longest token sequence in the training split: 2947


In [9]:
# Packed length for the fixed-length baseline: the longest training sequence
# plus room for [CLS]/[SEP], capped at BERT_MAX_SEQLEN.
fixed_seq_length = tf.minimum(max_seq_len + 2, BERT_MAX_SEQLEN)

preprocessor = hub.load("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
fixed_len_bert_packer = hub.KerasLayer(
    preprocessor.bert_pack_inputs, arguments={"seq_length": fixed_seq_length}
)

#### Note: We add 2 to the maximum length to account for the `[CLS]` and `[SEP]` tokens that the BERT packer (`bert_pack_inputs`) inserts around each sequence; the result is capped at `BERT_MAX_SEQLEN`.

In [10]:
def preprocess_fixed_batch(batch):
    """Converts a parsed batch into (BERT inputs, labels) at a fixed length.

    Pops `summary_tokens` and `label` from `batch` (mutating it) and packs the
    tokens with the module-level `fixed_len_bert_packer`.
    """
    # Squeeze out axis 1, which must have size 1, so the packer receives one
    # token segment per example.
    token_segments = tf.squeeze(batch.pop("summary_tokens"), axis=1)
    packed_inputs = fixed_len_bert_packer([token_segments])
    return packed_inputs, batch.pop("label")

## Preprocessing function for variable length batching using the BERT packer from TF Hub.

In [11]:
def set_text_preprocessor(preprocessor_path: str) -> Callable:
    """Builds a decorator that attaches a TF-Hub text preprocessor to a function.

    Arguments:
        preprocessor_path {str} -- URL of the TF-Hub preprocessor.

    Returns:
        Callable -- a decorator. When applied, it loads the preprocessor,
            stores it on the decorated function as the `preprocessor`
            attribute, and returns that same function object.
    """

    def attach_preprocessor(fn: Callable):
        # The model is loaded once, at decoration time, not on every call.
        setattr(fn, "preprocessor", hub.load(preprocessor_path))
        return fn

    return attach_preprocessor

In [12]:
@set_text_preprocessor(
    preprocessor_path="https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
)
def preprocess_variable_batch_tfh(batch):
    """Converts a parsed batch into (BERT inputs, labels) padded per batch.

    The sequence length is the longest `summary_tokens_len` in this batch plus
    2 (for [CLS]/[SEP]), capped at BERT_MAX_SEQLEN. Packing uses the TF-Hub
    `bert_pack_inputs` attached by the decorator. Pops `summary_tokens` and
    `label` from `batch`.
    """
    longest_in_batch = tf.cast(
        tf.math.reduce_max(batch["summary_tokens_len"]), dtype=tf.int32
    )
    seq_length = tf.minimum(longest_in_batch + 2, BERT_MAX_SEQLEN)

    packer = hub.KerasLayer(
        preprocess_variable_batch_tfh.preprocessor.bert_pack_inputs,
        arguments={"seq_length": seq_length},
    )
    token_segments = tf.squeeze(batch.pop("summary_tokens"), axis=1)
    return packer([token_segments]), batch.pop("label")

## Preprocessing function for variable length batching using a custom written BERT packer.

In [13]:
def prepare_bert_inputs(
    batch_tokens: tf.RaggedTensor,
    batch_lens: tf.Tensor,
    max_len: int = tf.constant(512),
) -> Dict[str, tf.Tensor]:
    """Packs pre-tokenized sequences into the three BERT encoder inputs.

    Each sequence is truncated to at most `max_len - 2` tokens, wrapped as
    `[CLS] tokens [SEP]`, and right-padded with PAD. The output width is the
    longest packed sequence in the batch, so it adapts to each batch.

    Arguments:
        batch_tokens {tf.RaggedTensor} -- token ids. Only `flat_values` is
            used; it is regrouped into rows using `batch_lens`.
        batch_lens {tf.Tensor} -- 1-D integer token count per example. Its
            dtype must match `max_len` (int32 for the default).
        max_len {int} -- upper bound on the packed length, including
            [CLS]/[SEP].

    Returns:
        Dict[str, tf.Tensor] -- int32 tensors of shape [batch, width] under
            `input_word_ids`, `input_type_ids` and `input_mask`.
    """

    # Regroup the flat token ids into one row per example. The ids are cast to
    # int32 for two reasons: the parsed TFRecord values are int64, while the
    # special-token tensors below are int32, and tf.concat refuses to mix
    # dtypes. int32 also matches the model's Keras inputs.
    batch_tokens = tf.RaggedTensor.from_row_lengths(
        tf.cast(batch_tokens.flat_values, tf.int32), batch_lens
    )

    # Calculate batch size.
    batch_size = tf.shape(batch_lens)[0]

    # Special token ids of the uncased BERT vocabulary.
    CLS = 101
    SEP = 102
    PAD = 0

    # One [CLS] and one [SEP] per row, with row splits matching batch_tokens
    # so the three ragged tensors can be concatenated.
    batch_cls = tf.repeat(tf.constant([[CLS]]), batch_size, axis=0)
    batch_cls = tf.RaggedTensor.from_tensor(batch_cls).with_row_splits_dtype(
        batch_tokens.row_splits.dtype
    )
    batch_sep = tf.repeat(tf.constant([[SEP]]), batch_size, axis=0)
    batch_sep = tf.RaggedTensor.from_tensor(batch_sep).with_row_splits_dtype(
        batch_tokens.row_splits.dtype
    )

    # Truncate sequences longer than max_len - 2 so that, with [CLS] and [SEP],
    # no packed row exceeds max_len.
    max_batch_len = tf.minimum(tf.reduce_max(batch_lens) + 2, max_len)
    truncated_tokens = batch_tokens[:, : max_batch_len - 2]

    # Sandwich the truncated tokens in between the special tokens.
    prepared_tokens = tf.concat([batch_cls, truncated_tokens, batch_sep], axis=1)

    # Densify, padding shorter rows with PAD. The width equals max_batch_len.
    padded_tokens = prepared_tokens.to_tensor(PAD)

    # Single-segment input: every token belongs to segment 0.
    segment_ids = tf.zeros_like(padded_tokens)

    # Attention mask. Rows that were truncated have batch_lens + 2 greater than
    # max_batch_len, which simply yields an all-ones row.
    mask = tf.sequence_mask(batch_lens + 2, max_batch_len, dtype=tf.int32)

    return {
        "input_word_ids": padded_tokens,
        "input_type_ids": segment_ids,
        "input_mask": mask,
    }

In [14]:
def preprocess_variable_batch_cust(batch):
    """Converts a parsed batch into (BERT inputs, labels) with the custom packer.

    Padding is per batch (see `prepare_bert_inputs`). Only `label` is popped
    from `batch`; `summary_tokens` is read in place.
    """
    # Flatten the per-example lengths (presumably [batch, 1] after batching the
    # shape-[1] feature) into a 1-D int32 vector.
    lengths = tf.reshape(
        tf.cast(batch["summary_tokens_len"], dtype=tf.int32), (-1,)
    )
    tokens = tf.squeeze(batch["summary_tokens"], axis=1)
    packed = prepare_bert_inputs(tokens, lengths)
    return packed, batch.pop("label")

## Dataset preparation

In [15]:
def get_dataset(
    split: str, batch_size: int, batch_preprocessor: Callable, shuffle: bool
):
    """Builds a batched, preprocessed tf.data pipeline for one split.

    Arguments:
        split {str} -- shard prefix, matched as `{TFRECORDS_DIR}/{split}-*.tfrecord`.
        batch_size {int} -- examples per batch.
        batch_preprocessor {Callable} -- maps a parsed batch dict to (inputs, labels).
        shuffle {bool} -- shuffle parsed examples with a buffer of 10 batches.

    Returns:
        tf.data.Dataset -- prefetched (inputs, labels) batches.
    """
    autotune = tf.data.AUTOTUNE

    shard_files = tf.data.Dataset.list_files(f"{TFRECORDS_DIR}/{split}-*.tfrecord")
    records = shard_files.interleave(
        tf.data.TFRecordDataset,
        cycle_length=autotune,
        num_parallel_calls=autotune,
    )
    # Parse once and cache, so later epochs skip file reading and parsing.
    examples = records.map(
        read_example, num_parallel_calls=autotune, deterministic=False
    ).cache()

    if shuffle:
        examples = examples.shuffle(batch_size * 10)

    batches = examples.batch(batch_size).map(
        batch_preprocessor, num_parallel_calls=autotune
    )
    return batches.prefetch(autotune)

## Model Building

In [16]:
def genre_classifier(
    encoder_path: str,
    input_features: List[str],
    train_encoder: bool,
    proj_dim: int,
    num_labels: int,
):
    """Builds a classifier: TF-Hub text encoder -> ReLU projection -> softmax.

    Arguments:
        encoder_path {str} -- TF-Hub handle of the text encoder.
        input_features {List[str]} -- names of the int32, shape-(None,) inputs.
        train_encoder {bool} -- whether the encoder weights are fine-tuned.
        proj_dim {int} -- width of the hidden ReLU layer.
        num_labels {int} -- number of output classes.

    Returns:
        tf.keras.Model -- maps the input dict to class probabilities, computed
            from the encoder's `pooled_output`.
    """
    encoder = hub.KerasLayer(encoder_path)
    encoder.trainable = train_encoder

    model_inputs = {}
    for name in input_features:
        model_inputs[name] = tf.keras.Input(shape=(None,), dtype=tf.int32, name=name)

    pooled = encoder(model_inputs)["pooled_output"]
    hidden = tf.keras.layers.Dense(proj_dim, activation="relu")(pooled)
    class_probs = tf.keras.layers.Dense(num_labels, activation="softmax")(hidden)
    return tf.keras.Model(inputs=model_inputs, outputs=class_probs)

## Training routine

In [None]:
def train(
    train_ds: tf.data.Dataset,
    valid_ds: tf.data.Dataset,
    test_ds: tf.data.Dataset,
    num_epochs: int,
    run_name: str,
    group_name: str,
):
    """Trains and evaluates one genre classifier, logging to Weights & Biases.

    Arguments:
        train_ds {tf.data.Dataset} -- training batches of (inputs, labels).
        valid_ds {tf.data.Dataset} -- validation batches used during fit().
        test_ds {tf.data.Dataset} -- held-out batches evaluated after training.
        num_epochs {int} -- number of training epochs.
        run_name {str} -- W&B run name.
        group_name {str} -- W&B group, used to compare batching strategies.

    The W&B run is always closed. It is marked failed (non-zero exit code) if
    model construction, training or evaluation raises.
    """
    tfhub_model_uri = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"
    bert_inputs = ["input_word_ids", "input_type_ids", "input_mask"]
    proj_dim = 128
    num_labels = 27

    wandb.init(
        project="batching-experiments",
        entity="carted",
        name=run_name,
        group=group_name,
    )

    exit_code = 1  # set to 0 only once the whole run has completed
    try:
        model = genre_classifier(
            tfhub_model_uri,
            bert_inputs,
            train_encoder=False,
            proj_dim=proj_dim,
            num_labels=num_labels,
        )
        model.compile(
            optimizer="adam",
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )

        # Wall-clock time of fit(), including the per-epoch validation passes.
        start = time.perf_counter()
        model.fit(
            train_ds,
            epochs=num_epochs,
            validation_data=valid_ds,
            callbacks=[wandb.keras.WandbCallback()],
        )
        wandb.log({"model_training_time (seconds)": time.perf_counter() - start})

        loss, acc = model.evaluate(test_ds)
        # Log both test metrics in one call so they share a single W&B step.
        wandb.log({"test_loss": loss, "test_acc": acc})
        exit_code = 0
    finally:
        wandb.finish(exit_code=exit_code)

## Experiment parameters

In [None]:
NUM_RUNS: int = 10  # repetitions of each batching strategy
NUM_EPOCHS: int = 5  # training epochs per run

## Training with fixed batch length

In [None]:
# Only the training split is shuffled.
train_ds, valid_ds, test_ds = (
    get_dataset(split, BATCH_SIZE, preprocess_fixed_batch, split == "train")
    for split in ("train", "val", "test")
)

In [None]:
group_name = "fixed-length-batching"

# Run numbers in the W&B run names are 1-based.
for run_number in range(1, NUM_RUNS + 1):
    run_name = f"fixed-length-run:{run_number}"
    train(
        train_ds, valid_ds, test_ds, NUM_EPOCHS,
        run_name=run_name, group_name=group_name,
    )

## Training with variable batch length

### Using TF Hub's BERT packer

In [None]:
# Only the training split is shuffled.
train_ds, valid_ds, test_ds = (
    get_dataset(split, BATCH_SIZE, preprocess_variable_batch_tfh, split == "train")
    for split in ("train", "val", "test")
)

In [None]:
group_name = "variable-length-batching-tfh"

# Run numbers in the W&B run names are 1-based.
for run_number in range(1, NUM_RUNS + 1):
    run_name = f"variable-length-run-tfh:{run_number}"
    train(
        train_ds, valid_ds, test_ds, NUM_EPOCHS,
        run_name=run_name, group_name=group_name,
    )

### Using custom BERT packer

In [None]:
# Only the training split is shuffled.
train_ds, valid_ds, test_ds = (
    get_dataset(split, BATCH_SIZE, preprocess_variable_batch_cust, split == "train")
    for split in ("train", "val", "test")
)

In [None]:
group_name = "variable-length-batching-cust"

# Run numbers in the W&B run names are 1-based.
for run_number in range(1, NUM_RUNS + 1):
    run_name = f"variable-length-run-cust:{run_number}"
    train(
        train_ds, valid_ds, test_ds, NUM_EPOCHS,
        run_name=run_name, group_name=group_name,
    )