## Imports

In [None]:
from typing import Callable, Dict
import tensorflow_hub as hub
import tensorflow_text as tft
import tensorflow as tf

import matplotlib.pyplot as plt
import numpy as np
import random

# Seed TensorFlow, NumPy and Python's `random` module with the same fixed
# value so that random operations are repeatable across runs.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

## Constants

In [None]:
# Directory that get_dataset() searches for "<split>-*.tfrecord" shards.
TFRECORDS_DIR = "tfrecords-sentence-splitter"
# Default upper bound on the packed sequence length
# (default value of ModelInputUtils' `encoder_max_seqlen`).
BERT_MAX_SEQLEN = 512
# Number of examples per batch used for the train/val/test datasets.
BATCH_SIZE = 64

## TFRecord parsing utilities

In [None]:
# Schema handed to tf.io.parse_single_example() in read_example().
# The three string features "summary_tokens", "summary_sentence_indices" and
# "summary_tokens_len" hold serialized ragged tensors; read_example() turns
# them back into int32 ragged tensors via deserialize_composite().
feature_descriptions = {
    # scalar string feature
    "summary": tf.io.FixedLenFeature([], dtype=tf.string),
    # serialized ragged tensor (restored with ragged_rank=2)
    "summary_tokens": tf.io.FixedLenFeature([], dtype=tf.string),
    # serialized ragged tensor (restored with ragged_rank=1)
    "summary_sentence_indices": tf.io.FixedLenFeature([], dtype=tf.string),
    # scalar int64; used as the per-example repeat count when the batch is
    # flattened in ModelInputUtils.unravel_ragged_batch()
    "summary_num_sentences": tf.io.FixedLenFeature([], dtype=tf.int64),
    # serialized ragged tensor (restored with ragged_rank=1)
    "summary_tokens_len": tf.io.FixedLenFeature([], dtype=tf.string),
    # int64 feature of shape [1]; returned as the label by preprocess_batch()
    "label": tf.io.FixedLenFeature([1], dtype=tf.int64),
}

In [None]:
def deserialize_composite(
    serialized: bytes, type_spec: tf.RaggedTensorSpec
) -> tf.Tensor:
    """Rebuild a composite tensor from its serialized component tensors.

    The input is decoded in two stages: `serialized` is first parsed into a
    string tensor, and entry ``i`` of that tensor is then parsed into the
    ``i``-th flat component of `type_spec`.

    Args:
        serialized: Serialized string tensor holding one serialized tensor
            per flat component of `type_spec`.
        type_spec: Spec describing the value to rebuild; its flattened
            component specs supply the dtype used to parse each entry.

    Returns:
        The parsed components packed back into the structure described by
        `type_spec` (a ragged tensor for the specs used in this notebook).
    """
    # Stage 1: the outer payload is a string tensor of serialized components.
    packed_components = tf.io.parse_tensor(serialized, tf.string)

    # Expanding composites yields one plain tensor spec per component.
    flat_specs = tf.nest.flatten(type_spec, expand_composites=True)

    # Stage 2: decode every component with the dtype its spec asks for.
    restored = []
    for position, flat_spec in enumerate(flat_specs):
        restored.append(
            tf.io.parse_tensor(packed_components[position], flat_spec.dtype)
        )

    return tf.nest.pack_sequence_as(type_spec, restored, expand_composites=True)


def read_example(example):
    """Parse one serialized example and restore its ragged features.

    The example is parsed against `feature_descriptions`; the three features
    that store serialized ragged tensors are then replaced in the returned
    dict by int32 ragged tensors of the listed ragged rank.

    Args:
        example: A single serialized example, as yielded by the TFRecord
            dataset.

    Returns:
        The parsed feature dict with "summary_tokens",
        "summary_sentence_indices" and "summary_tokens_len" deserialized.
    """
    features = tf.io.parse_single_example(example, feature_descriptions)

    # (feature key, ragged rank of the tensor stored under that key)
    ragged_features = (
        ("summary_tokens", 2),
        ("summary_sentence_indices", 1),
        ("summary_tokens_len", 1),
    )
    for key, rank in ragged_features:
        spec = tf.RaggedTensorSpec(dtype=tf.int32, ragged_rank=rank)
        features[key] = deserialize_composite(features[key], spec)

    return features

In [None]:
class ModelInputUtils:
    """Batch-level preprocessing that turns parsed features into BERT inputs.

    The class loads a TF Hub BERT preprocessing module once and uses its
    ``bert_pack_inputs`` function to pack the tokens of each batch.
    `preprocess_batch` is the entry point intended for ``Dataset.map`` on a
    batched dataset.
    """

    def __init__(
        self,
        bert_preprocessor_path: str = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
        encoder_max_seqlen: int = BERT_MAX_SEQLEN,
    ):
        """Initializes a BERT model input preprocessing utility class.

        Args:
            bert_preprocessor_path: Handle passed to ``hub.load`` to obtain
                the preprocessing module.
            encoder_max_seqlen: Upper bound on the packed sequence length.
        """
        self.bert_preprocessor_path = bert_preprocessor_path
        self.preprocessor_module = hub.load(bert_preprocessor_path)
        self.encoder_max_seqlen = encoder_max_seqlen

    def init_packer_and_pack_inputs(
        self, batch_tokens: tf.Tensor, batch_token_lens: tf.Tensor
    ) -> Dict[str, tf.Tensor]:
        """Prepares inputs for the BERT encoder.

        The packed sequence length is chosen per batch: the longest token
        sequence plus two, capped at ``self.encoder_max_seqlen``.

        Args:
            batch_tokens: Tokens to pack; passed to the packer as a single
                segment.
            batch_token_lens: Token counts of the sequences in
                `batch_tokens`; only their maximum is used.

        Returns:
            The output of the preprocessor's ``bert_pack_inputs``. Per the
            TF Hub preprocessor documentation this is a dict with
            "input_word_ids", "input_mask" and "input_type_ids" -- confirm
            for the model handle in use.
        """
        max_token_len = tf.reduce_max(batch_token_lens)
        packer = hub.KerasLayer(
            self.preprocessor_module.bert_pack_inputs,
            arguments={
                # +2 presumably leaves room for the start/end special tokens
                # added by the packer -- TODO confirm against the
                # preprocessor docs.
                "seq_length": tf.math.minimum(
                    max_token_len + 2, self.encoder_max_seqlen
                )
            },
        )
        return packer([batch_tokens])

    def unravel_ragged_batch(self, ragged_batch, ragged_idx, batch_lens, batch_size):
        """Flattens out a batch of ragged tensors by one level.

        Builds one ``[batch index, inner index]`` pair per inner tensor and
        gathers those entries from `ragged_batch`, so the batch dimension
        and the next dimension are merged into one.

        Args:
            ragged_batch: Batched ragged tensor to gather from.
            ragged_idx: Ragged tensor whose ``flat_values`` give, for every
                inner tensor, its index within its batch entry.
            batch_lens: Number of inner tensors in each batch entry; must
                sum to the length of ``ragged_idx.flat_values``.
            batch_size: Number of entries in the batch.

        Returns:
            The gathered tensors, one per ``[batch index, inner index]``
            pair, in batch order.
        """
        # Repeat each batch entry's index once per inner tensor it holds.
        batch_idx = tf.repeat(tf.range(batch_size), batch_lens, axis=0)

        # Pair every batch index with the matching inner index, giving an
        # (N, 2) index matrix with N == sum(batch_lens).
        gather_nd_idx = tf.stack([batch_idx, ragged_idx.flat_values], axis=1)

        # Obtain the flattened ragged batch using the index matrix.
        return tf.gather_nd(ragged_batch, indices=gather_nd_idx, batch_dims=0)

    def get_bert_inputs(self, batch, batch_size):
        """Generates padded BERT inputs for a given batch of tokenized
        text features.

        Note that `batch` is modified in place: "summary_sentence_indices"
        and "summary_tokens_len" are removed, and "summary_tokens" is
        replaced by its flattened version.

        Args:
            batch: Feature dict of a batched dataset element.
            batch_size: Number of entries in the batch.

        Returns:
            The packed BERT inputs for the flattened tokens.
        """
        # Flatten out the RaggedTensor token batch.
        tokens = self.unravel_ragged_batch(
            batch.pop("summary_tokens"),
            batch.pop("summary_sentence_indices"),
            batch["summary_num_sentences"],
            batch_size,
        )
        # Keep the flattened tokens in the batch and obtain the BERT inputs.
        batch["summary_tokens"] = tokens
        bert_inputs = self.init_packer_and_pack_inputs(
            tokens, batch.pop("summary_tokens_len").flat_values
        )
        return bert_inputs

    def preprocess_batch(self, batch: Dict[str, tf.Tensor]):
        """Applies batch level transformations to the data.

        Args:
            batch: Feature dict of a batched dataset element; modified in
                place (see `get_bert_inputs`).

        Returns:
            A ``(features, label)`` tuple: `batch` with "bert_inputs" added
            and "label" removed, and the removed label tensor.
        """
        batch_size = tf.shape(batch["label"])[0]

        # Generate padded BERT inputs for all the text features.
        batch["bert_inputs"] = self.get_bert_inputs(batch, batch_size)

        label = batch.pop("label")
        return batch, label

## Dataset preparation

In [None]:
# Shared preprocessing helper built with the default arguments; constructing
# it calls hub.load() on the default BERT preprocessor handle.
# get_dataset() maps its preprocess_batch() over every batch.
input_utils = ModelInputUtils()

In [None]:
def get_dataset(split, batch_size, shuffle):
    """Prepares tf.data.Dataset objects from TFRecords.

    Args:
        split: File-name prefix of the split; shards are matched with
            ``{TFRECORDS_DIR}/{split}-*.tfrecord``.
        batch_size: Number of examples per batch.
        shuffle: Whether example order should be randomized. When false,
            shard listing and parsing are kept deterministic so the
            dataset yields examples in a reproducible order.

    Returns:
        A dataset of ``(features, label)`` tuples as produced by
        ``input_utils.preprocess_batch``.
    """
    # list_files() shuffles file names by default; tie that to `shuffle` so
    # an unshuffled split really is read in a fixed order.
    ds = tf.data.Dataset.list_files(
        f"{TFRECORDS_DIR}/{split}-*.tfrecord", shuffle=shuffle
    )
    ds = ds.interleave(
        tf.data.TFRecordDataset, cycle_length=3, num_parallel_calls=tf.data.AUTOTUNE
    )

    # Out-of-order parsing is only allowed when the caller asked for
    # shuffled data anyway. Parsed examples are cached after the first pass.
    ds = ds.map(
        read_example, num_parallel_calls=tf.data.AUTOTUNE, deterministic=not shuffle
    ).cache()
    if shuffle:
        ds = ds.shuffle(batch_size * 10)
    ds = ds.batch(batch_size)
    ds = ds.map(input_utils.preprocess_batch, num_parallel_calls=tf.data.AUTOTUNE)

    # Prefetch last so batch preprocessing overlaps with the consumer.
    return ds.prefetch(tf.data.AUTOTUNE)

In [None]:
# Build the three splits with the same batch size; only the training split
# is shuffled.
train_ds = get_dataset("train", BATCH_SIZE, True)
valid_ds = get_dataset("val", BATCH_SIZE, False)
test_ds = get_dataset("test", BATCH_SIZE, False)

In [None]:
# Sanity check: pull a single training batch and print the feature keys
# and the shape of the label tensor.
for batch_features, batch_labels in train_ds.take(1):
    print(batch_features.keys())
    print(batch_labels.shape)