## Imports

In [1]:
from typing import Callable, Dict
import tensorflow_hub as hub
import tensorflow_text as tft
import tensorflow as tf

import matplotlib.pyplot as plt
import numpy as np
import random

# Single seed shared by every random number generator the notebook touches,
# so that shuffling and any sampling are repeatable across kernel restarts.
SEED = 42
random.seed(SEED)  # Python's built-in RNG
np.random.seed(SEED)  # NumPy's legacy global RNG
tf.random.set_seed(SEED)  # TensorFlow's global seed

## Constants

In [2]:
# GCS prefix under which the `<split>-*.tfrecord` shards are globbed by `get_dataset`.
TFRECORDS_DIR = "gs://variable-length-sequences-tf/tfrecords-sentence-splitter"
# Default upper bound on the packed sequence length (see `ModelInputUtils`).
BERT_MAX_SEQLEN = 512
# Number of examples per batch passed to `get_dataset`.
BATCH_SIZE = 64

## TFRecord parsing utilities

In [18]:
# Schema handed to `tf.io.parse_single_example` to decode each serialized
# `tf.train.Example`. The string-typed `summary_tokens`,
# `summary_sentence_indices` and `summary_tokens_len` entries are decoded
# further by `deserialize_composite` inside `read_example`.
#
# NOTE(review): parsing with this schema currently fails with
# "Key: summary_num_sentences. Can't parse serialized Example." (see the cell
# outputs further down). The spec declared for that key presumably does not
# match what the records actually contain (e.g. a different number of values)
# -- TODO confirm against the code that wrote the TFRecords.
feature_descriptions = {
    "summary": tf.io.FixedLenFeature([], dtype=tf.string),
    "summary_tokens": tf.io.FixedLenFeature([], dtype=tf.string),
    "summary_sentence_indices": tf.io.FixedLenFeature([], dtype=tf.string),
    "summary_num_sentences": tf.io.FixedLenFeature([], dtype=tf.int64),
    "summary_tokens_len": tf.io.FixedLenFeature([], dtype=tf.string),
    "label": tf.io.FixedLenFeature([1], dtype=tf.int64),
}

In [4]:
def deserialize_composite(
    serialized: bytes, type_spec: tf.RaggedTensorSpec
) -> tf.Tensor:
    """Rebuilds a composite (ragged) tensor from its serialized form.

    Args:
        serialized: A serialized string tensor whose i-th element is itself the
            serialized i-th flat component of the composite tensor.
        type_spec: Spec describing the composite tensor to rebuild; it supplies
            both the dtype of every component and the final structure.

    Returns:
        The components packed back into the structure given by `type_spec`.
    """
    # The spec dictates how many components there are and the dtype of each.
    flat_specs = tf.nest.flatten(type_spec, expand_composites=True)

    # Outer layer: a string tensor holding one serialized tensor per component.
    serialized_components = tf.io.parse_tensor(serialized, tf.string)

    # Inner layer: decode each component with the dtype its spec asks for.
    components = []
    for position, spec in enumerate(flat_specs):
        components.append(
            tf.io.parse_tensor(serialized_components[position], spec.dtype)
        )

    return tf.nest.pack_sequence_as(type_spec, components, expand_composites=True)


def read_example(example):
    """Parses one serialized `tf.train.Example` into a dict of features.

    The example is first decoded with `feature_descriptions`; the three
    features that were stored as serialized ragged tensors are then turned
    back into int32 ragged tensors via `deserialize_composite`.
    """
    # Feature name -> ragged rank of the tensor that was serialized under it.
    ragged_feature_ranks = {
        "summary_tokens": 2,
        "summary_sentence_indices": 1,
        "summary_tokens_len": 1,
    }

    features = tf.io.parse_single_example(example, feature_descriptions)
    for name, rank in ragged_feature_ranks.items():
        spec = tf.RaggedTensorSpec(dtype=tf.int32, ragged_rank=rank)
        features[name] = deserialize_composite(features.get(name), spec)

    return features

In [5]:
class ModelInputUtils:
    """Batch-level preprocessing that turns parsed features into BERT inputs.

    The class loads a TF Hub BERT preprocessing module once and exposes
    `preprocess_batch`, which is meant to be mapped over an already-batched
    dataset of features produced by `read_example`.
    """

    def __init__(
        self,
        bert_preprocessor_path: str = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
        encoder_max_seqlen: int = BERT_MAX_SEQLEN,
    ):
        """Initializes a BERT model input preprocessing utility class.

        Args:
            bert_preprocessor_path: TF Hub handle of the preprocessing module;
                it is loaded eagerly with `hub.load`.
            encoder_max_seqlen: Upper bound applied to the per-batch packed
                sequence length.
        """
        self.bert_preprocessor_path = bert_preprocessor_path
        self.preprocessor_module = hub.load(bert_preprocessor_path)
        self.encoder_max_seqlen = encoder_max_seqlen

    def init_packer_and_pack_inputs(
        self, batch_tokens: tf.RaggedTensor, batch_token_lens: tf.Tensor
    ) -> Dict[str, tf.Tensor]:
        """Packs a batch of token sequences into padded BERT encoder inputs.

        The packing length is chosen per batch: the longest sequence in the
        batch, capped at `self.encoder_max_seqlen`.

        Args:
            batch_tokens: Token ids, one row per sequence to pack.
            batch_token_lens: Token count of every sequence in `batch_tokens`.

        Returns:
            Whatever the module's `bert_pack_inputs` returns; per the TF Hub
            preprocessor documentation this is a dict of dense tensors
            (`input_word_ids`, `input_mask`, `input_type_ids`).
        """
        max_token_len = tf.reduce_max(batch_token_lens)
        # NOTE(review): `bert_pack_inputs` adds start/end special tokens inside
        # `seq_length`. If `batch_token_lens` does not already count them, the
        # longest sequence of each batch loses its last tokens -- TODO confirm.
        packer = hub.KerasLayer(
            self.preprocessor_module.bert_pack_inputs,
            arguments={
                "seq_length": tf.math.minimum(max_token_len, self.encoder_max_seqlen)
            },
        )
        return packer([batch_tokens])

    def unravel_ragged_batch(
        self,
        ragged_batch: tf.RaggedTensor,
        ragged_idx: tf.RaggedTensor,
        batch_lens: tf.Tensor,
        batch_size: tf.Tensor,
    ) -> tf.RaggedTensor:
        """Flattens out a batch of ragged tensors by one level.

        Args:
            ragged_batch: Ragged batch whose two outermost dimensions are
                merged into one.
            ragged_idx: For every batch entry, the indices of its inner
                tensors; only `flat_values` is read.
            batch_lens: Number of inner tensors held by each batch entry.
            batch_size: Number of entries in the batch.

        Returns:
            The inner tensors of all batch entries, gathered in order.
        """
        # Index of the owning batch entry, repeated once per inner tensor.
        batch_idx = tf.repeat(tf.range(batch_size), batch_lens, axis=0)

        # Pair each batch index with the matching inner index, giving one
        # (batch_index, inner_index) row per inner tensor.
        gather_nd_idx = tf.stack([batch_idx, ragged_idx.flat_values], axis=1)

        # Pull the inner tensors out using the index matrix.
        return tf.gather_nd(ragged_batch, indices=gather_nd_idx, batch_dims=0)

    def get_bert_inputs(
        self, batch: Dict[str, tf.Tensor], batch_size: tf.Tensor
    ) -> Dict[str, tf.Tensor]:
        """Generates padded BERT inputs for a given batch of tokenized
        text features.

        Mutates `batch`: "summary_sentence_indices" and "summary_tokens_len"
        are removed, and "summary_tokens" is replaced by its flattened form.
        """
        # Flatten out the RaggedTensor token batch.
        tokens = self.unravel_ragged_batch(
            batch.pop("summary_tokens"),
            batch.pop("summary_sentence_indices"),
            batch["summary_num_sentences"],
            batch_size,
        )
        # Keep the flattened tokens in the batch and pack them for BERT.
        batch["summary_tokens"] = tokens
        bert_inputs = self.init_packer_and_pack_inputs(
            tokens, batch.pop("summary_tokens_len").flat_values
        )
        return bert_inputs

    def preprocess_batch(self, batch: Dict[str, tf.Tensor]):
        """Applies batch level transformations to the data.

        Args:
            batch: Batched features; must contain "label" plus the keys read
                by `get_bert_inputs`.

        Returns:
            A `(features, label)` tuple, where `features` is `batch` with
            "bert_inputs" added and "label" removed.
        """
        # The leading dimension of the labels is the actual batch size (the
        # final batch of a dataset may be smaller than the nominal one).
        batch_size = tf.shape(batch["label"])[0]

        # Generate padded BERT inputs for all the text features.
        batch["bert_inputs"] = self.get_bert_inputs(batch, batch_size)

        label = batch.pop("label")
        return batch, label

## Dataset preparation

In [6]:
# Built once at module level: the constructor loads the TF Hub preprocessing
# module, and `get_dataset` maps `input_utils.preprocess_batch` over batches.
input_utils = ModelInputUtils()

2021-12-16 08:44:09.287371: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-12-16 08:44:10.815284: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


In [7]:
def get_dataset(split: str, batch_size: int, shuffle: bool) -> tf.data.Dataset:
    """Prepares a batched, preprocessed `tf.data.Dataset` from TFRecords.

    Args:
        split: Shard prefix; files matching `<TFRECORDS_DIR>/<split>-*.tfrecord`
            are read.
        batch_size: Number of examples per batch.
        shuffle: Whether to randomize the order of examples. When False the
            file listing and the parsing map are kept deterministic as well,
            so the examples come out in a stable order.

    Returns:
        A dataset yielding `(features, label)` tuples as produced by
        `input_utils.preprocess_batch`.
    """
    # `list_files` shuffles the matched file names by default; tie that to
    # `shuffle` so non-shuffled splits are read in a stable order.
    files = tf.data.Dataset.list_files(
        f"{TFRECORDS_DIR}/{split}-*.tfrecord", shuffle=shuffle
    )
    ds = files.interleave(
        tf.data.TFRecordDataset, cycle_length=3, num_parallel_calls=tf.data.AUTOTUNE
    )

    # Relaxing the ordering guarantee is only acceptable when the data is
    # going to be shuffled anyway.
    ds = ds.map(
        read_example,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=not shuffle,
    ).cache()
    if shuffle:
        ds = ds.shuffle(batch_size * 10)
    ds = ds.batch(batch_size)
    ds = ds.map(input_utils.preprocess_batch, num_parallel_calls=tf.data.AUTOTUNE)

    # Prefetch last so that preprocessing of the next batches overlaps with
    # whatever consumes the dataset.
    return ds.prefetch(tf.data.AUTOTUNE)

In [8]:
# Only the shuffled training split is built for now; the validation and test
# splits are left disabled below.
train_ds = get_dataset("train", BATCH_SIZE, True)
# valid_ds = get_dataset("val", BATCH_SIZE, False)
# test_ds = get_dataset("test", BATCH_SIZE, False)

In [9]:
# Smoke test: pull a single batch and inspect the feature keys and label shape.
# NOTE(review): this currently raises InvalidArgumentError while parsing the
# `summary_num_sentences` key (see the output of this cell).
for features, labels in train_ds.take(1):
    print(features.keys())
    print(labels.shape)

2021-12-16 08:44:23.278668: W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at example_parsing_ops.cc:94 : Invalid argument: Key: summary_num_sentences.  Can't parse serialized Example.
2021-12-16 08:44:23.502315: W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at example_parsing_ops.cc:94 : Invalid argument: Key: summary_num_sentences.  Can't parse serialized Example.
2021-12-16 08:44:23.568683: W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at example_parsing_ops.cc:94 : Invalid argument: Key: summary_num_sentences.  Can't parse serialized Example.
2021-12-16 08:44:23.568712: W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at example_parsing_ops.cc:94 : Invalid argument: Key: summary_num_sentences.  Can't parse serialized Example.
2021-12-16 08:44:23.568761: W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at example_parsing_ops.cc:94 : Invalid argument: Key: summary_num_sentences.  Can't parse 

InvalidArgumentError: Key: summary_num_sentences.  Can't parse serialized Example.
	 [[{{node ParseSingleExample/ParseExample/ParseExampleV2}}]] [Op:IteratorGetNext]

In [19]:
# Debugging: read the training shards as raw serialized records, bypassing
# `read_example`, to look at what the TFRecords actually contain.
ds = tf.data.Dataset.list_files(f"{TFRECORDS_DIR}/train-*.tfrecord")
raw_dataset = tf.data.TFRecordDataset(ds)
raw_dataset

<TFRecordDatasetV2 shapes: (), types: tf.string>

In [20]:
# https://www.tensorflow.org/tutorials/load_data/tfrecord
# Decode one raw record into a `tf.train.Example` proto without any schema and
# print it. `example` is kept around and inspected again in the next cell.
for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)

features {
  feature {
    key: "label"
    value {
      int64_list {
        value: 20
      }
    }
  }
  feature {
    key: "summary"
    value {
      bytes_list {
        value: "After his defeat in the original Star Fox, the game\'s antagonist, Andross, returns to the Lylat system and launches an all-out attack against Corneria, using his new fleet of battleships and giant missiles launched from hidden bases to destroy the planet. General Pepper again calls upon the Star Fox team for help. Armed with new custom Arwings, a Mothership, and two new recruits (Miyu, a lynx, and Fay, a dog), the Star Fox team sets out to defend Corneria by destroying Andross\'s forces before they can inflict critical damage on the planet. Along the way, Star Fox must also combat giant boss enemies, bases on planets throughout the Lylat system, members of the Star Wolf team and finally Andross himself."
      }
    }
  }
  feature {
    key: "summary_num_sentences"
    value {
      int64_list {
      

In [21]:
# https://www.tensorflow.org/tutorials/load_data/tfrecord
# `example.features.feature` maps feature names to `Feature` protos. Each proto
# stores its values in one of three fields (bytes_list, float_list or
# int64_list); `WhichOneof("kind")` names the field in use, whose values are
# then copied into a NumPy array keyed by feature name.
result = {
    name: np.array(getattr(proto, proto.WhichOneof("kind")).value)
    for name, proto in example.features.feature.items()
}

result

{'summary_tokens_len': array([b'\x08\x07\x12\x04\x12\x02\x08\x02B\x1a\x08\x03\x12\x04\x12\x02\x08\x04"\x106\x00\x00\x00\x0c\x00\x00\x00:\x00\x00\x00#\x00\x00\x00B\x1a\x08\t\x12\x04\x12\x02\x08\x02"\x10\x00\x00\x00\x00\x00\x00\x00\x00\x04'],
       dtype='|S64'),
 'summary_sentence_indices': array([b'\x08\x07\x12\x04\x12\x02\x08\x02B\x1a\x08\x03\x12\x04\x12\x02\x08\x04"\x10\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00B\x1a\x08\t\x12\x04\x12\x02\x08\x02"\x10\x00\x00\x00\x00\x00\x00\x00\x00\x04'],
       dtype='|S64'),
 'summary_tokens': array([b'\x08\x07\x12\x04\x12\x02\x08\x03B\x88\x05\x08\x03\x12\x05\x12\x03\x08\x9f\x01"\xfc\x04\xfc\x07\x00\x00\xda\x07\x00\x00:\x10\x00\x00\xcf\x07\x00\x00\xcc\x07\x00\x00\x82\t\x00\x00\xac\n\x00\x00C\x11\x00\x00\xf2\x03\x00\x00\xcc\x07\x00\x00\xa0\x08\x00\x00\xed\x03\x00\x00\x1f\x04\x00\x00\xe3C\x00\x00\xf2\x03\x00\x00\xce\x07\x00\x00}d\x00\x00\xf2\x03\x00\x00\x13\x16\x00\x00\xd0\x07\x00\x00\xcc\x07\x00\x00\x18\x04\x00\x00\x87]\x00\x0

In [24]:
# Minimal reproduction of the parsing failure: apply only
# `tf.io.parse_single_example` with `feature_descriptions`, skipping the
# ragged-tensor deserialization done in `read_example`.
# NOTE(review): this still fails on the `summary_num_sentences` key (see the
# output of this cell), so the ragged-tensor decoding is not the cause.
raw_dataset_mapped = raw_dataset.map(lambda x: tf.io.parse_single_example(x, feature_descriptions))
# raw_dataset_mapped
for parsed_record in raw_dataset_mapped.take(10):
    print(repr(parsed_record))

2021-12-16 08:59:05.354274: W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at example_parsing_ops.cc:94 : Invalid argument: Key: summary_num_sentences.  Can't parse serialized Example.
2021-12-16 08:59:05.354305: W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at example_parsing_ops.cc:94 : Invalid argument: Key: summary_num_sentences.  Can't parse serialized Example.
2021-12-16 08:59:05.354326: W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at example_parsing_ops.cc:94 : Invalid argument: Key: summary_num_sentences.  Can't parse serialized Example.
2021-12-16 08:59:05.354353: W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at example_parsing_ops.cc:94 : Invalid argument: Key: summary_num_sentences.  Can't parse serialized Example.
2021-12-16 08:59:05.354367: W tensorflow/core/framework/op_kernel.cc:1692] OP_REQUIRES failed at example_parsing_ops.cc:94 : Invalid argument: Key: summary_num_sentences.  Can't parse 

InvalidArgumentError: Key: summary_num_sentences.  Can't parse serialized Example.
	 [[{{node ParseSingleExample/ParseExample/ParseExampleV2}}]] [Op:IteratorGetNext]