## Imports

In [1]:
from typing import Callable, List
import tensorflow_hub as hub
import tensorflow_text as tft
import tensorflow as tf
import wandb

import matplotlib.pyplot as plt
import numpy as np
import random

# Seed every random number generator the notebook touches (Python, NumPy,
# TensorFlow) from one constant so repeated runs start from the same state.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

## Constants

In [2]:
# Location (GCS prefix) under which the `{split}-*.tfrecord` shards are listed.
TFRECORDS_DIR = "gs://variable-length-sequences-tf/tfrecords"
# Upper bound applied to the `seq_length` handed to the BERT input packer.
BERT_MAX_SEQLEN = 512
# Number of examples per batch for the train/val/test datasets.
BATCH_SIZE = 64

## TFRecord parsing utilities

In [3]:
# Schema handed to `tf.io.parse_single_example` to decode each serialized record.
feature_descriptions = {
    # Scalar string feature.
    "summary": tf.io.FixedLenFeature([], dtype=tf.string),
    # Scalar string holding a serialized ragged tensor; `read_example` decodes it
    # with `deserialize_composite`.
    "summary_tokens": tf.io.FixedLenFeature([], dtype=tf.string),
    # int64 of shape [1]; the batch preprocessors and the longest-sequence scan
    # read it as the token sequence length.
    "summary_tokens_len": tf.io.FixedLenFeature([1], dtype=tf.int64),
    # int64 of shape [1]; returned as the training target.
    "label": tf.io.FixedLenFeature([1], dtype=tf.int64),
}

In [4]:
def deserialize_composite(serialized, type_spec):
    """Rebuilds a composite tensor (e.g. a RaggedTensor) from its serialized form.

    Arguments:
        serialized -- string tensor that parses (via `tf.io.parse_tensor`) to a
            string tensor with one serialized entry per flattened component
            of `type_spec`.
        type_spec -- spec describing the structure to restore.

    Returns:
        The parsed components packed back into the structure of `type_spec`.
    """
    component_blobs = tf.io.parse_tensor(serialized, tf.string)
    flat_specs = tf.nest.flatten(type_spec, expand_composites=True)

    # Decode each component with the dtype its flattened spec declares.
    parsed_components = []
    for position, flat_spec in enumerate(flat_specs):
        parsed_components.append(
            tf.io.parse_tensor(component_blobs[position], flat_spec.dtype)
        )

    return tf.nest.pack_sequence_as(
        type_spec, parsed_components, expand_composites=True
    )


def read_example(example):
    """Parses one serialized example and restores its ragged token feature.

    The record is decoded according to `feature_descriptions`; the
    `summary_tokens` entry is then replaced by the tensor obtained from
    `deserialize_composite` with an int32 `RaggedTensorSpec` of ragged rank 2.

    Returns:
        dict -- the parsed features, with `summary_tokens` deserialized.
    """
    ragged_tokens_spec = tf.RaggedTensorSpec(dtype=tf.int32, ragged_rank=2)

    parsed = tf.io.parse_single_example(example, feature_descriptions)
    parsed["summary_tokens"] = deserialize_composite(
        parsed["summary_tokens"], ragged_tokens_spec
    )
    return parsed

In [5]:
def set_text_preprocessor(preprocessor_path: str) -> Callable:
    """ Builds a decorator that attaches a TensorFlow Hub preprocessor
        to the decorated function.

    The preprocessor is loaded with `hub.load` at the moment the decorator is
    applied and stored on the function as its `preprocessor` attribute; the
    function itself is returned unwrapped.

    Arguments:
        preprocessor_path {str} -- URL of the TF-Hub preprocessor.

    Returns:
        Callable -- A decorator returning the same function with the
        `preprocessor` attribute set.
    """
    def attach_preprocessor(func: Callable):
        # Load from TF-Hub and hang the result on the function object.
        setattr(func, "preprocessor", hub.load(preprocessor_path))
        return func

    return attach_preprocessor

In [16]:
@set_text_preprocessor(
    preprocessor_path="https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
)
def preprocess_variable_batch(batch):
    """Packs a batch into BERT inputs sized to the batch's own longest sequence.

    The packing length is the largest `summary_tokens_len` in the batch,
    capped at `BERT_MAX_SEQLEN`. The `summary_tokens` and `label` entries are
    popped from `batch`.

    Returns:
        tuple -- (packed BERT inputs, labels).
    """
    # Longest sequence present in this batch, capped at the BERT limit.
    longest_in_batch = tf.math.reduce_max(batch["summary_tokens_len"])
    seq_length = tf.minimum(
        tf.cast(longest_in_batch, dtype=tf.int32), BERT_MAX_SEQLEN
    )

    # Packer that generates the inputs for the BERT model at that length.
    packer = hub.KerasLayer(
        preprocess_variable_batch.preprocessor.bert_pack_inputs,
        arguments={"seq_length": seq_length},
    )

    summary_tokens = tf.squeeze(batch.pop("summary_tokens"), axis=1)
    packed_inputs = packer([summary_tokens])
    targets = batch.pop("label")
    return packed_inputs, targets

In [8]:
# Find the longest sequence in the training set
ds = tf.data.Dataset.list_files(f"{TFRECORDS_DIR}/train-*.tfrecord")
ds = tf.data.TFRecordDataset(ds).map(
    lambda example: tf.io.parse_single_example(
        example, feature_descriptions
    )
)
# Fold the maximum inside the tf.data pipeline instead of collecting one
# tensor per example into a Python list, and keep the result as a plain int
# so it can be compared with BERT_MAX_SEQLEN and used as a static length.
max_seq_len = int(
    ds.reduce(
        tf.constant(0, dtype=tf.int64),
        lambda running_max, datum: tf.maximum(
            running_max, tf.reduce_max(datum["summary_tokens_len"])
        ),
    )
)
print(f"Longest token sequence in the training split: {max_seq_len}")


# Use the bert packer for packing input batches
@set_text_preprocessor(
    preprocessor_path="https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
)
def preprocess_fixed_batch(batch):
    """Packs a batch into BERT inputs padded to one fixed sequence length.

    The length is the longest training sequence (`max_seq_len`), capped at
    `BERT_MAX_SEQLEN`. The packer layer is created on first use and cached on
    the function as `fixed_len_bert_packer`. The `summary_tokens` and `label`
    entries are popped from `batch`.

    Returns:
        tuple -- (packed BERT inputs, labels).
    """
    # Create the packer once. `hasattr` needs the object to inspect as its
    # first argument; the single-argument call raised TypeError.
    if not hasattr(preprocess_fixed_batch, "fixed_len_bert_packer"):
        # `max_seq_len` may be a Python int or a one-element eager tensor;
        # reduce it to a plain int so `min` yields a static length.
        longest_train_seq = int(np.asarray(max_seq_len).max())
        preprocess_fixed_batch.fixed_len_bert_packer = hub.KerasLayer(
            preprocess_fixed_batch.preprocessor.bert_pack_inputs,
            arguments={"seq_length": min(longest_train_seq, BERT_MAX_SEQLEN)},
        )
    # Generating the inputs for the BERT model.
    bert_packed_text = preprocess_fixed_batch.fixed_len_bert_packer(
        [tf.squeeze(batch.pop("summary_tokens"), axis=1)]
    )

    labels = batch.pop("label")
    return bert_packed_text, labels

## Dataset preparation

In [9]:
def get_dataset(
    split: str,
    batch_size: int,
    batch_preprocessor: Callable,
    shuffle: bool
):
    """Builds the batched, preprocessed tf.data pipeline for one split.

    Arguments:
        split {str} -- shard prefix; files matching
            `{TFRECORDS_DIR}/{split}-*.tfrecord` are read.
        batch_size {int} -- number of examples per batch.
        batch_preprocessor {Callable} -- function mapped over each batch.
        shuffle {bool} -- whether to shuffle examples before batching.

    Returns:
        tf.data.Dataset -- prefetched dataset of preprocessed batches.
    """
    autotune = tf.data.AUTOTUNE

    shard_files = tf.data.Dataset.list_files(
        f"{TFRECORDS_DIR}/{split}-*.tfrecord"
    )
    raw_records = shard_files.interleave(
        tf.data.TFRecordDataset,
        cycle_length=autotune,
        num_parallel_calls=autotune,
    )
    # Parsed examples are cached so later epochs skip reading and parsing.
    examples = raw_records.map(
        read_example,
        num_parallel_calls=autotune,
        deterministic=False,
    ).cache()

    if shuffle:
        examples = examples.shuffle(batch_size * 10)

    batches = examples.batch(batch_size)
    batches = batches.map(batch_preprocessor, num_parallel_calls=autotune)
    return batches.prefetch(autotune)

## Analyzing the maximum sequence lengths over training batches

In [10]:
# analysis_ds = tf.data.Dataset.list_files(f"{TFRECORDS_DIR}/train-*.tfrecord")
# analysis_ds = analysis_ds.interleave(
#     tf.data.TFRecordDataset, cycle_length=3, num_parallel_calls=tf.data.AUTOTUNE
# )
# analysis_ds = analysis_ds.map(
#     read_example, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False
# )
# analysis_ds = analysis_ds.batch(BATCH_SIZE)

In [11]:
# max_seqlens = []
# batches = 0


# for batch in analysis_ds:
#     max_seqlens.append(
#         int(tf.cast(tf.math.reduce_max(batch["summary_tokens_len"]), dtype=tf.int32,))
#     )
#     batches += 1

# plt.plot(np.arange(batches), max_seqlens)
# plt.xlabel("Batch #", fontsize=14)
# plt.ylabel("Maximum sequence lengths", fontsize=14)
# plt.show()

## Model Building

In [12]:
def genre_classifier(
    encoder_path: str,
    input_features: List[str],
    train_encoder: bool,
    proj_dim: int,
    num_labels: int
):
    """Builds a classifier: TF-Hub encoder -> ReLU projection -> softmax.

    Arguments:
        encoder_path {str} -- TF-Hub handle of the encoder layer.
        input_features {List[str]} -- names of the int32 model inputs.
        train_encoder {bool} -- value assigned to the encoder's `trainable`.
        proj_dim {int} -- units in the projection layer.
        num_labels {int} -- units in the softmax output layer.

    Returns:
        tf.keras.Model -- maps the named inputs to class probabilities.
    """
    encoder = hub.KerasLayer(encoder_path)
    encoder.trainable = train_encoder

    # One variable-length int32 input per requested feature name.
    model_inputs = {}
    for name in input_features:
        model_inputs[name] = tf.keras.Input(
            shape=(None,), dtype=tf.int32, name=name
        )

    pooled = encoder(model_inputs)["pooled_output"]
    hidden = tf.keras.layers.Dense(proj_dim, activation="relu")(pooled)
    class_probs = tf.keras.layers.Dense(num_labels, activation="softmax")(hidden)
    return tf.keras.Model(inputs=model_inputs, outputs=class_probs)

In [13]:
# TF-Hub handle of the sentence encoder passed to `genre_classifier`.
cmlm_uri = "https://tfhub.dev/google/universal-sentence-encoder-cmlm/en-base/1"
# Names of the model inputs; `genre_classifier` creates one Keras input per name.
cmlm_inputs = ["input_word_ids", "input_type_ids", "input_mask"]
# Units in the ReLU projection layer on top of the encoder's pooled output.
proj_dim = 128
# Units in the softmax output layer.
num_labels = 27

# Intended number of training epochs per run.
num_epochs = 5
# Number of times each training loop repeats the experiment.
num_runs = 1

## Training with variable batch width

In [17]:
# Datasets whose batches are packed to each batch's own longest sequence;
# only the training split is shuffled.
train_ds = get_dataset(
    split="train",
    batch_size=BATCH_SIZE,
    batch_preprocessor=preprocess_variable_batch,
    shuffle=True,
)
valid_ds = get_dataset(
    split="val",
    batch_size=BATCH_SIZE,
    batch_preprocessor=preprocess_variable_batch,
    shuffle=False,
)
test_ds = get_dataset(
    split="test",
    batch_size=BATCH_SIZE,
    batch_preprocessor=preprocess_variable_batch,
    shuffle=False,
)

In [22]:
!wandb lo--relogin

Usage: wandb [OPTIONS] COMMAND [ARGS]...
Try 'wandb --help' for help.

Error: No such option: --relogin


In [21]:
# Train and evaluate the classifier `num_runs` times on variable-width batches,
# tracking each repetition as its own W&B run.
for i in range(num_runs):
    wandb.init(
        project="batching-experiments",
        entity="carted",
        name=f"variable-length-run:{i + 1}",
        group="variable-length-batching",
        # Record the hyperparameters so runs can be compared in the W&B UI.
        config={
            "encoder": cmlm_uri,
            "train_encoder": False,
            "proj_dim": proj_dim,
            "num_labels": num_labels,
            "batch_size": BATCH_SIZE,
            "num_epochs": num_epochs,
        },
    )
    try:
        model = genre_classifier(cmlm_uri, cmlm_inputs, False, proj_dim, num_labels)
        # `metrics` is passed as a list, the documented form for `compile`.
        model.compile(
            optimizer="adam",
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )
        # Use the configured epoch count rather than a hard-coded literal.
        model.fit(train_ds, epochs=num_epochs, validation_data=valid_ds)
        model.evaluate(test_ds)
    finally:
        # Close the run even when training fails or is interrupted, so the
        # next `wandb.init` starts from a clean state.
        wandb.finish()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msayak-carted[0m (use `wandb login --relogin` to force relogin)
  warn("The `IPython.html` package has been deprecated since IPython 4.0. "


KeyboardInterrupt: 

## Training with fixed length batch width

In [None]:
# Datasets whose batches are all packed to one fixed sequence length;
# only the training split is shuffled.
train_ds = get_dataset(
    split="train",
    batch_size=BATCH_SIZE,
    batch_preprocessor=preprocess_fixed_batch,
    shuffle=True,
)
valid_ds = get_dataset(
    split="val",
    batch_size=BATCH_SIZE,
    batch_preprocessor=preprocess_fixed_batch,
    shuffle=False,
)
test_ds = get_dataset(
    split="test",
    batch_size=BATCH_SIZE,
    batch_preprocessor=preprocess_fixed_batch,
    shuffle=False,
)

In [None]:
# Train and evaluate the classifier `num_runs` times on fixed-width batches,
# tracking each repetition as its own W&B run.
for i in range(num_runs):
    wandb.init(
        project="batching-experiments",
        entity="carted",
        # Both string literals below were missing their closing quote.
        name=f"fixed-length-run:{i + 1}",
        group="fixed-length-batching",
        # `training_config` is not defined in this notebook; log the
        # hyperparameters that are defined here instead.
        config={
            "encoder": cmlm_uri,
            "train_encoder": False,
            "proj_dim": proj_dim,
            "num_labels": num_labels,
            "batch_size": BATCH_SIZE,
            "num_epochs": num_epochs,
        },
    )
    try:
        model = genre_classifier(cmlm_uri, cmlm_inputs, False, proj_dim, num_labels)
        # `metrics` is passed as a list, the documented form for `compile`.
        model.compile(
            optimizer="adam",
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )
        # Use the configured epoch count rather than a hard-coded literal.
        model.fit(train_ds, epochs=num_epochs, validation_data=valid_ds)
        model.evaluate(test_ds)
    finally:
        # Close the run even when training fails or is interrupted, so the
        # next `wandb.init` starts from a clean state.
        wandb.finish()