In [1]:
# !pip install tf-models-official
# !pip install tfds-nightly

In [2]:
# Imports
import os
import tensorflow as tf
import tensorflow_models as tfm
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import tensorflow_text
from typing import Any

# Only let TensorFlow allocate RAM as needed
physical_devices = tf.config.experimental.list_physical_devices("GPU")
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

# Configuration
tf.random.set_seed(0)
batch_size = 64
seq_len = 128
snli_ds_name = "snli"
sentence_features = ["premise", "hypothesis"]
num_classes = 3
bert_encoder_url = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"
bert_preprocessor_url = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"

In [3]:
def build_bert_preprocessor(sentence_features: "list[str]", seq_length: int = 128) -> tf.keras.Model:
    input_segments = [
        tf.keras.layers.Input(shape=(), dtype=tf.string, name=ft)
        for ft in sentence_features]

    # Tokenize the text to word pieces.
    bert_preprocess = hub.load(bert_preprocessor_url)
    tokenizer = hub.KerasLayer(bert_preprocess.tokenize, name="tokenizer")
    segments = [tokenizer(s) for s in input_segments]

    # Optional: Trim segments in a smart way to fit seq_length.
    # Simple cases (like this example) can skip this step and let
    # the next step apply a default truncation to approximately equal lengths.
    truncated_segments = segments

    # Pack inputs. The details (start/end token ids, dict of output tensors)
    # are model-dependent, so this gets loaded from the SavedModel.
    packer = hub.KerasLayer(bert_preprocess.bert_pack_inputs,
                            arguments=dict(seq_length=seq_length),
                            name="packer")
    model_inputs = packer(truncated_segments)
    return tf.keras.Model(input_segments, model_inputs)

In [4]:
def get_data_from_dataset(full_dataset: "dict[Any, tf.data.Dataset]", info: tfds.core.DatasetBuilder, split: str,
                          batch_size: int, preprocessor: tf.keras.Model) -> "tuple[tf.data.Dataset, int]":
    is_training = split.startswith("train")
    dataset = tf.data.Dataset.from_tensor_slices(full_dataset[split])
    num_examples = info.splits[split].num_examples

    if is_training:
        dataset = dataset.shuffle(num_examples)
        dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(lambda ex: (preprocessor(ex), ex["label"]))
    dataset = dataset.cache().prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset, num_examples

In [5]:
def build_bert_classifier(num_classes: int) -> tf.keras.Model:
    class Classifier(tf.keras.Model):
        def __init__(self, num_classes):
            super(Classifier, self).__init__(name="prediction")
            self.encoder = hub.KerasLayer(bert_encoder_url, trainable=True)
            self.dropout = tf.keras.layers.Dropout(0.1)
            self.dense = tf.keras.layers.Dense(num_classes)

        def call(self, preprocessed_text):
            encoder_outputs = self.encoder(preprocessed_text)
            pooled_output = encoder_outputs["pooled_output"]
            x = self.dropout(pooled_output)
            x = self.dense(x)
            return x

    model = Classifier(num_classes)
    return model

In [6]:
bert_preprocessor = build_bert_preprocessor(sentence_features, seq_length=seq_len)

# Create train, test, and validation data
snli_ds: "dict[tfds.Split, tf.data.Dataset]" = tfds.load(snli_ds_name, batch_size=-1, shuffle_files=True)  # type: ignore
snli_ds_info = tfds.builder(snli_ds_name).info

train_data, train_data_size = get_data_from_dataset(snli_ds, snli_ds_info,
    "train", batch_size, bert_preprocessor)
steps_per_epoch = train_data_size // batch_size

validation_data, validation_data_size = get_data_from_dataset(snli_ds, snli_ds_info,
    "validation", batch_size, bert_preprocessor)
validation_steps = validation_data_size // batch_size

test_data, test_data_size = get_data_from_dataset(snli_ds, snli_ds_info,
    "test", batch_size, bert_preprocessor)

  inputs = self._flatten_to_reference_inputs(inputs)


In [7]:
optimizer = tf.keras.optimizers.Adam(2e-5)
metrics = [tf.keras.metrics.SparseCategoricalAccuracy("accuracy", dtype=tf.float32)]
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Build model
bert_classifier = build_bert_classifier(num_classes)
bert_classifier.compile(optimizer=optimizer, metrics=metrics, loss=loss)

# Evaluate model
bert_classifier.evaluate(test_data)



[nan, 0.3179999887943268]