In [None]:
# Install/upgrade the pip packages used by this notebook (quiet mode).
# NOTE(review): tensorflow_addons is imported in the next cell but is not
# installed here — confirm the runtime already provides it.
!pip install -q -U tensorflow-text
!pip install -q -U tf-models-official
!pip install -q -U tfds-nightly

In [None]:
import os
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import tensorflow_text as text
import tensorflow_addons as tfa

from official.nlp import optimization

# Restrict TensorFlow's Python logger to messages of level ERROR.
tf.get_logger().setLevel('ERROR')

In [None]:
# Set the TFHUB_MODEL_LOAD_FORMAT environment variable before any hub model is loaded.
# NOTE(review): presumably "UNCOMPRESSED" makes TF Hub read models directly from
# remote storage, which TPU workers require — confirm against the TF Hub docs.
os.environ["TFHUB_MODEL_LOAD_FORMAT"]="UNCOMPRESSED"

In [None]:
# Choose the distribution strategy for the available accelerator and bind it
# to `strategy`, which later cells use via `strategy.scope()`.
#
# `os.environ.get` is used so that a missing COLAB_TPU_ADDR (any non-TPU
# runtime) falls through to the GPU check instead of raising KeyError.
if os.environ.get('COLAB_TPU_ADDR'):
    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(cluster_resolver)
    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
    strategy = tf.distribute.TPUStrategy(cluster_resolver)
    print("Using TPU")
# `tf.config.list_physical_devices` replaces the deprecated
# `tf.test.is_gpu_available`; a non-empty list means at least one GPU.
elif tf.config.list_physical_devices('GPU'):
    strategy = tf.distribute.MirroredStrategy()
    print("Using GPU")
else:
    # Deliberate hard stop: the notebook refuses to continue without an accelerator.
    raise ValueError('Running on CPU is not recommended')

In [None]:
# Name of the encoder to fine-tune; it is used as the lookup key into the
# handle tables defined in this cell.
bert_model_name = 'bert_en_uncased_L-12_H-768_A-12' 

# Encoder name -> TF Hub handle (URL) of the corresponding encoder model.
map_name_to_handle = {
    # BERT encoders.
    'bert_en_uncased_L-12_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_L-12_H-768_A-12/3',
    'bert_en_uncased_L-24_H-1024_A-16': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_L-24_H-1024_A-16/3',
    'bert_en_wwm_uncased_L-24_H-1024_A-16': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/3',
    'bert_en_cased_L-12_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_cased_L-12_H-768_A-12/3',
    'bert_en_cased_L-24_H-1024_A-16': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_cased_L-24_H-1024_A-16/3',
    'bert_en_wwm_cased_L-24_H-1024_A-16': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/3',
    'bert_multi_cased_L-12_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/bert_multi_cased_L-12_H-768_A-12/3',

    # Small BERT encoders.
    'small_bert/bert_en_uncased_L-2_H-128_A-2': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/1',
    'small_bert/bert_en_uncased_L-2_H-256_A-4': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1',
    'small_bert/bert_en_uncased_L-2_H-512_A-8': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-2_H-512_A-8/1',
    'small_bert/bert_en_uncased_L-2_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-2_H-768_A-12/1',
    'small_bert/bert_en_uncased_L-4_H-128_A-2': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-4_H-128_A-2/1',
    'small_bert/bert_en_uncased_L-4_H-256_A-4': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-4_H-256_A-4/1',
    'small_bert/bert_en_uncased_L-4_H-512_A-8': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1',
    'small_bert/bert_en_uncased_L-4_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-4_H-768_A-12/1',
    'small_bert/bert_en_uncased_L-6_H-128_A-2': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-6_H-128_A-2/1',
    'small_bert/bert_en_uncased_L-6_H-256_A-4': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1',
    'small_bert/bert_en_uncased_L-6_H-512_A-8': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-6_H-512_A-8/1',
    'small_bert/bert_en_uncased_L-6_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-6_H-768_A-12/1',
    'small_bert/bert_en_uncased_L-8_H-128_A-2': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-8_H-128_A-2/1',
    'small_bert/bert_en_uncased_L-8_H-256_A-4': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-8_H-256_A-4/1',
    'small_bert/bert_en_uncased_L-8_H-512_A-8': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-8_H-512_A-8/1',
    'small_bert/bert_en_uncased_L-8_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-8_H-768_A-12/1',
    'small_bert/bert_en_uncased_L-10_H-128_A-2': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-10_H-128_A-2/1',
    'small_bert/bert_en_uncased_L-10_H-256_A-4': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-10_H-256_A-4/1',
    'small_bert/bert_en_uncased_L-10_H-512_A-8': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-10_H-512_A-8/1',
    'small_bert/bert_en_uncased_L-10_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-10_H-768_A-12/1',
    'small_bert/bert_en_uncased_L-12_H-128_A-2': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-12_H-128_A-2/1',
    'small_bert/bert_en_uncased_L-12_H-256_A-4': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-12_H-256_A-4/1',
    'small_bert/bert_en_uncased_L-12_H-512_A-8': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-12_H-512_A-8/1',
    'small_bert/bert_en_uncased_L-12_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/small_bert/bert_en_uncased_L-12_H-768_A-12/1',

    # ALBERT encoders.
    'albert_en_base': 'https://hub.tensorflow.google.cn/tensorflow/albert_en_base/2',
    'albert_en_large': 'https://hub.tensorflow.google.cn/tensorflow/albert_en_large/2',
    'albert_en_xlarge': 'https://hub.tensorflow.google.cn/tensorflow/albert_en_xlarge/2',
    'albert_en_xxlarge': 'https://hub.tensorflow.google.cn/tensorflow/albert_en_xxlarge/2',

    # ELECTRA encoders.
    'electra_small': 'https://hub.tensorflow.google.cn/google/electra_small/2',
    'electra_base': 'https://hub.tensorflow.google.cn/google/electra_base/2',

    # BERT "experts" encoders.
    'experts_pubmed': 'https://hub.tensorflow.google.cn/google/experts/bert/pubmed/2',
    'experts_wiki_books': 'https://hub.tensorflow.google.cn/google/experts/bert/wiki_books/2',

    # Talking-heads encoders.
    'talking-heads_base': 'https://hub.tensorflow.google.cn/tensorflow/talkheads_ggelu_bert_en_base/1',
    'talking-heads_large': 'https://hub.tensorflow.google.cn/tensorflow/talkheads_ggelu_bert_en_large/1',
}

# Encoder name -> TF Hub handle (URL) of the preprocessing model paired with it.
map_model_to_preprocess = {
    # BERT encoders.
    'bert_en_uncased_L-24_H-1024_A-16': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'bert_en_uncased_L-12_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'bert_en_wwm_cased_L-24_H-1024_A-16': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_cased_preprocess/3',
    'bert_en_cased_L-24_H-1024_A-16': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_cased_preprocess/3',
    'bert_en_cased_L-12_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_cased_preprocess/3',
    'bert_en_wwm_uncased_L-24_H-1024_A-16': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',

    # Small BERT encoders.
    'small_bert/bert_en_uncased_L-2_H-128_A-2': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-2_H-256_A-4': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-2_H-512_A-8': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-2_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-4_H-128_A-2': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-4_H-256_A-4': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-4_H-512_A-8': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-4_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-6_H-128_A-2': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-6_H-256_A-4': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-6_H-512_A-8': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-6_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-8_H-128_A-2': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-8_H-256_A-4': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-8_H-512_A-8': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-8_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-10_H-128_A-2': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-10_H-256_A-4': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-10_H-512_A-8': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-10_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-12_H-128_A-2': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-12_H-256_A-4': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-12_H-512_A-8': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-12_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',

    # Multilingual BERT encoder.
    'bert_multi_cased_L-12_H-768_A-12': 'https://hub.tensorflow.google.cn/tensorflow/bert_multi_cased_preprocess/3',

    # ALBERT encoders.
    'albert_en_base': 'https://hub.tensorflow.google.cn/tensorflow/albert_en_preprocess/3',
    'albert_en_large': 'https://hub.tensorflow.google.cn/tensorflow/albert_en_preprocess/3',
    'albert_en_xlarge': 'https://hub.tensorflow.google.cn/tensorflow/albert_en_preprocess/3',
    'albert_en_xxlarge': 'https://hub.tensorflow.google.cn/tensorflow/albert_en_preprocess/3',

    # ELECTRA, BERT "experts" and talking-heads encoders.
    'electra_small': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'electra_base': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'experts_pubmed': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'experts_wiki_books': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'talking-heads_base': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
    'talking-heads_large': 'https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/3',
}

# Resolve the selected model name to its encoder handle and to the handle of
# the preprocessing model paired with it. Both lookups raise KeyError if
# `bert_model_name` is not a key of the respective table.
tfhub_handle_encoder = map_name_to_handle[bert_model_name]
tfhub_handle_preprocess = map_model_to_preprocess[bert_model_name]

print('BERT model selected           :', tfhub_handle_encoder)
print('Preprocessing model auto-selected:', tfhub_handle_preprocess)

In [None]:
# Load the preprocessing model selected above so its functions can be called directly.
bert_preprocess = hub.load(tfhub_handle_preprocess)

In [None]:
# Tokenize one example sentence with the preprocessing model and display the result.
tok = bert_preprocess.tokenize(tf.constant(['Hello TensorFlow!']))
tok

In [None]:
# Pack the same tokens twice (as two segments) with a sequence length of 20
# and display the packed result.
text_preprocessed = bert_preprocess.bert_pack_inputs([tok, tok], tf.constant(20))
text_preprocessed

In [None]:
# List the keys of the packed output; the following cell indexes it by key.
print(text_preprocessed.keys())

In [None]:
# For each packed tensor, print its shape and the first 12 entries of row 0.
for shape_caption, values_caption, tensor_key in (
    ('Shape of word ids: ', 'Word ids         : ', 'input_word_ids'),
    ('Shape of mask    : ', 'Mask             : ', 'input_mask'),
    ('Shape of type ids: ', 'Type ids         : ', 'input_type_ids'),
):
    packed_tensor = text_preprocessed[tensor_key]
    print(shape_caption, packed_tensor.shape)
    print(values_caption, packed_tensor[0, :12])

In [None]:
def make_bert_preprocess_model(sentence_features, seq_length=128):
    """Build a Keras model that maps raw string features to packed encoder inputs.

    Args:
        sentence_features: iterable of feature names; one scalar string
            `Input` named after each feature is created.
        seq_length: forwarded as the `seq_length` argument of the
            preprocessing model's `bert_pack_inputs`.

    Returns:
        A `tf.keras.Model` whose inputs are the string `Input`s (in the order
        of `sentence_features`) and whose output is the packer's result.
    """
    string_inputs = []
    for feature_name in sentence_features:
        string_inputs.append(
            tf.keras.layers.Input((), dtype=tf.string, name=feature_name))

    # The preprocessing SavedModel comes from the module-level
    # `tfhub_handle_preprocess` handle.
    preprocessor = hub.load(tfhub_handle_preprocess)

    # One shared tokenizer layer is applied to every input segment.
    tokenize_layer = hub.KerasLayer(preprocessor.tokenize, name='tokenizer')
    token_segments = [tokenize_layer(string_input) for string_input in string_inputs]

    # No explicit trimming step: the tokenized segments go to the packer as-is.
    pack_layer = hub.KerasLayer(
        preprocessor.bert_pack_inputs,
        arguments=dict(seq_length=seq_length),
        name='packer',
    )
    packed_inputs = pack_layer(token_segments)

    return tf.keras.Model(string_inputs, packed_inputs)

In [None]:
# Build a two-input preprocessing model and run it on one sentence pair.
# Note: this rebinds `text_preprocessed` to the model's output.
test_preprocess_model = make_bert_preprocess_model(['input_1', 'input_2'])
test_text1 = [np.array(['some random test sentence']), np.array(['another sentence'])]
text_preprocessed = test_preprocess_model(test_text1)

In [None]:
# List the keys of the preprocessing model's output.
print(text_preprocessed.keys())

In [None]:
# For each output tensor, print its shape and the first 12 entries of row 0.
for shape_caption, values_caption, tensor_key in (
    ('Shape of word ids: ', 'Word ids         : ', 'input_word_ids'),
    ('Shape of mask    : ', 'Mask             : ', 'input_mask'),
    ('Shape of type ids: ', 'Type ids         : ', 'input_type_ids'),
):
    packed_tensor = text_preprocessed[tensor_key]
    print(shape_caption, packed_tensor.shape)
    print(values_caption, packed_tensor[0, :12])

In [None]:
# Render the layer graph of the test preprocessing model.
tf.keras.utils.plot_model(test_preprocess_model)

In [None]:
# Build a second two-input preprocessing model, run it on the same sentence
# pair, and print each output tensor's shape plus the first 16 entries of row 0.
test_text2 = [np.array(['some random test sentence']), np.array(['another sentence'])]
test_preprocess_model2 = make_bert_preprocess_model(['input_1', 'input_2'])
text_preprocessed2 = test_preprocess_model2(test_text2)

for shape_caption, values_caption, tensor_key in (
    ('Shape of word ids: ', 'Word ids         : ', 'input_word_ids'),
    ('Shape of mask    : ', 'Mask             : ', 'input_mask'),
    ('Shape of type ids: ', 'Type ids         : ', 'input_type_ids'),
):
    packed_tensor = text_preprocessed2[tensor_key]
    print(shape_caption, packed_tensor.shape)
    print(values_caption, packed_tensor[0, :16])

In [None]:
def load_dataset_from_tfds(in_memory_ds, info, split, batch_size, bert_preprocess_model):
    """Build the batched, preprocessed input pipeline for one dataset split.

    Args:
        in_memory_ds: mapping from split name to the data of that split, in a
            form accepted by `tf.data.Dataset.from_tensor_slices`; each
            example must contain a 'label' entry.
        info: dataset info object; `info.splits[split].num_examples` is read.
        split: split name. Names starting with 'train' are treated as
            training data (shuffled and repeated indefinitely).
        batch_size: number of examples per batch.
        bert_preprocess_model: callable applied to each batch of examples to
            produce the model inputs.

    Returns:
        A `(dataset, num_examples)` tuple. `dataset` yields
        `(preprocessed_inputs, labels)` batches; `num_examples` is the number
        of examples in the split.
    """
    is_training = split.startswith('train')

    dataset = tf.data.Dataset.from_tensor_slices(in_memory_ds[split])
    num_examples = info.splits[split].num_examples

    if is_training:
        # Shuffle over the whole split, then repeat forever; the caller bounds
        # an epoch with `steps_per_epoch`.
        dataset = dataset.shuffle(num_examples)
        dataset = dataset.repeat()

    dataset = dataset.batch(batch_size)
    dataset = dataset.map(lambda ex: (bert_preprocess_model(ex), ex['label']))

    if not is_training:
        # Cache only finite splits. The training stream is infinite after
        # `repeat()`: caching it would buffer every batch ever produced in
        # memory and the cache would never be finalized or reused.
        dataset = dataset.cache()

    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset, num_examples

In [None]:
def build_classifier_model(num_classes):
    """Build the classifier: hub encoder -> dropout -> dense layer.

    Args:
        num_classes: number of units of the final `Dense` layer.

    Returns:
        A `tf.keras.Model` named 'prediction' that takes a dict with the keys
        'input_word_ids', 'input_mask' and 'input_type_ids' and returns the
        output of the final `Dense` layer (no activation is applied).
    """
    encoder_inputs = {
        'input_word_ids': tf.keras.layers.Input(shape=(None,), dtype=tf.int32),
        'input_mask': tf.keras.layers.Input(shape=(None,), dtype=tf.int32),
        'input_type_ids': tf.keras.layers.Input(shape=(None,), dtype=tf.int32),
    }

    # The encoder is loaded from the module-level `tfhub_handle_encoder`
    # handle and is trainable, so fine-tuning updates its weights.
    bert_encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='encoder')

    pooled = bert_encoder(encoder_inputs)['pooled_output']
    regularized = tf.keras.layers.Dropout(0.1)(pooled)
    logits = tf.keras.layers.Dense(num_classes, name='classifier')(regularized)

    return tf.keras.Model(encoder_inputs, logits, name='prediction')

In [None]:
# Build a two-class classifier to try out on the preprocessed test input.
test_classifier_model = build_classifier_model(2)

In [None]:
# Display the preprocessed inputs that the next cell feeds to the classifier.
text_preprocessed

In [None]:
# Run the test classifier on the preprocessed inputs and print the sigmoid
# of its raw outputs.
bert_raw_result = test_classifier_model(text_preprocessed)
print(tf.sigmoid(bert_raw_result))

In [None]:
# Render the layer graph of the test classifier.
tf.keras.utils.plot_model(test_classifier_model)

In [None]:
# TFDS dataset name of the task to fine-tune on; later cells branch on this value.
tfds_name = 'glue/cola'

In [None]:
# Display the dataset's info object.
tfds.builder(tfds_name).info

In [None]:
# Keep the dataset info; later cells read its features and splits.
tfds_info = tfds.builder(tfds_name).info

In [None]:
# Start from all feature names of the dataset and display them.
sentence_features = list(tfds_info.features.keys())
sentence_features

In [None]:
# Keep only the text features by dropping 'idx' and 'label'.
# The list is rebuilt by filtering instead of calling `.remove` in place, so
# re-running this cell is harmless (a second `.remove('idx')` would raise
# ValueError because the entry is already gone).
sentence_features = [ft for ft in sentence_features if ft not in ('idx', 'label')]
sentence_features

In [None]:
# Names of the splits the dataset provides (printed in the summary cell below).
avaliable_splits = list(tfds_info.splits.keys())
avaliable_splits

In [None]:
# Default split names; the next cell overrides two of them for 'glue/mnli'.
train_split = 'train'
validation_split = 'validation'
test_split = 'test'

In [None]:
# For 'glue/mnli', use the "matched" validation and test splits instead of
# the default split names.
if tfds_name == 'glue/mnli':
    validation_split = 'validation_matched'
    test_split = 'test_matched'

In [None]:
# Number of label classes, and the example count summed over all splits
# (`total_num_examples`), used for the summary printed below.
num_classes = tfds_info.features['label'].num_classes
num_examples = tfds_info.splits.total_num_examples

In [None]:
# Print a short summary of the selected dataset.
print(f'Using {tfds_name} from TFDS')
print(f'This dataset has {num_examples} examples')
print(f'Number classes: {num_classes}')
print(f'Features: {sentence_features}')
print(f'Splits: {avaliable_splits}')

In [None]:
# Load the dataset with device placement pinned to the local host job.
# NOTE(review): batch_size=-1 presumably returns every split as full in-memory
# tensors (later cells slice `in_memory_ds[split]`) — confirm against TFDS docs.
with tf.device('/job:localhost'):
    in_memory_ds = tfds.load(tfds_name, batch_size=-1, shuffle_files=True)

In [None]:
# Build a per-example dataset over the training split for inspection, and
# look up the human-readable label names.
print(f'Here are some sample rows from {tfds_name} dataset')

sample_dataset = tf.data.Dataset.from_tensor_slices(in_memory_ds[train_split])
label_names = tfds_info.features['label'].names
print(label_names)

In [None]:
# Display the dataset object produced by taking one element (not the element itself).
sample_dataset.take(1)

In [None]:
# Print the text feature(s) and the label (id and name) of the first five
# training rows, numbering the rows from 1.
sample_i = 1

for sample_row in sample_dataset.take(5):
    print(f'sample row {sample_i}:')
    for feature in sentence_features:
        print(sample_row[feature])

    sample_label = sample_row['label']
    print(f'label: {sample_label} ({label_names[sample_label]})')
    print()
    sample_i += 1

In [None]:
def get_configuration(glue_task):
    """Return the `(metrics, loss)` pair used to compile the classifier.

    Args:
        glue_task: TFDS dataset name, e.g. 'glue/cola'.

    Returns:
        A `(metrics, loss)` tuple. The loss is always sparse categorical
        cross-entropy computed from logits. The metric is the Matthews
        correlation coefficient (2 classes) for 'glue/cola' and sparse
        categorical accuracy (named 'accuracy') for every other task.
    """
    logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    if glue_task != 'glue/cola':
        return (
            tf.keras.metrics.SparseCategoricalAccuracy('accuracy', dtype=tf.float32),
            logits_loss,
        )

    return tfa.metrics.MatthewsCorrelationCoefficient(num_classes=2), logits_loss

In [None]:
# Fine-tuning hyperparameters: number of epochs, batch size, and the initial
# learning rate passed to the optimizer factory.
epochs=3
batch_size=32
init_lr=2e-5

print(f'Fine tuning {tfhub_handle_encoder} model')
# Preprocessing model with one string input per sentence feature of the task.
bert_preprocess_model = make_bert_preprocess_model(sentence_features)
tf.keras.utils.plot_model(bert_preprocess_model)

In [None]:
# Build the datasets, model and optimizer under the distribution strategy's
# scope, then fine-tune.
with strategy.scope():
    # Metric and loss for the selected GLUE task.
    metrics, loss = get_configuration(tfds_name)

    train_dataset, train_data_size = load_dataset_from_tfds(in_memory_ds, tfds_info, train_split, batch_size, bert_preprocess_model)

    # Step counts for the schedule: whole batches per epoch, total steps over
    # all epochs, and a warm-up covering one tenth of the total steps.
    steps_per_epoch = train_data_size // batch_size
    num_train_steps = steps_per_epoch * epochs
    num_warmup_steps = num_train_steps // 10

    validation_dataset, validation_data_size = load_dataset_from_tfds(in_memory_ds, tfds_info, validation_split, batch_size, bert_preprocess_model)
    # Integer division: only whole validation batches are counted.
    validation_steps = validation_data_size // batch_size

    classifier_model = build_classifier_model(num_classes)

    optimizer = optimization.create_optimizer(init_lr, num_train_steps, num_warmup_steps)

    classifier_model.compile(optimizer, loss=loss, metrics=metrics)

    classifier_model.fit(x=train_dataset, validation_data=validation_dataset, epochs=epochs, steps_per_epoch= steps_per_epoch, validation_steps=validation_steps)

In [None]:
# Directory under which the exported model is saved.
main_save_path = './my_models'

In [None]:
# The second-to-last path component of the encoder handle is the model name
# (the last component is the version number).
bert_type = tfhub_handle_encoder.split('/')[-2]
# '/' in the dataset name is replaced so that the result is a single directory name.
saved_model_name = f"{tfds_name.replace('/', '_')}_{bert_type}"

save_model_path = os.path.join(main_save_path, saved_model_name)
save_model_path

In [None]:
# Chain preprocessing and the fine-tuned classifier into one model that
# takes the raw string inputs and returns the classifier output.
preprocess_inputs = bert_preprocess_model.inputs
bert_encoder_inputs = bert_preprocess_model(preprocess_inputs)
bert_outputs = classifier_model(bert_encoder_inputs)
model_for_export = tf.keras.Model(preprocess_inputs, bert_outputs)

print('Saving ', save_model_path)

# Perform the save I/O on the local host job; the optimizer state is not saved.
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model_for_export.save(save_model_path, include_optimizer=False, options=save_options)

In [None]:
# Load the exported SavedModel back, with device placement pinned to the local host job.
with tf.device('/job:localhost'):
    reloaded_model = tf.saved_model.load(save_model_path)

In [None]:
def prepare(record):
    """Return the sentence features of `record`, each wrapped in a one-element list.

    The result follows the order of the module-level `sentence_features`.
    """
    wrapped_features = []
    for feature_name in sentence_features:
        wrapped_features.append([record[feature_name]])
    return wrapped_features

def prepare_serving(record):
    """Return a dict holding only the sentence features of `record`.

    Keys are the names in the module-level `sentence_features`; values are
    taken from `record` unchanged.
    """
    serving_inputs = {}
    for feature_name in sentence_features:
        serving_inputs[feature_name] = record[feature_name]
    return serving_inputs

def print_bert_results(test, bert_result, dataset_name):
    """Print the input text(s) and a readable verdict for one prediction.

    Args:
        test: sequence of input tensors; `test[0]` (and `test[1]` for
            sentence-pair tasks) must support `.numpy()`.
        bert_result: model output; the predicted class is the argmax over
            axis 1 of its first row.
        dataset_name: TFDS dataset name (e.g. 'glue/cola') that selects which
            captions and verdict texts are printed. For names not handled
            below, only the raw result is printed.
    """
    bert_result_class = tf.argmax(bert_result, axis=1)[0]

    if dataset_name == 'glue/cola':
        print('sentence: ', test[0].numpy())
        if bert_result_class == 1:
            print('This sentence is Acceptable')
        else:
            print('This sentence is Unacceptable')

    elif dataset_name == 'glue/sst2':
        print('sentence: ', test[0].numpy())
        if bert_result_class == 1:
            print('This sentence is POSITIVE')
        else:
            print('This sentence is NEGATIVE')

    elif dataset_name == 'glue/mrpc':
        print('sentence1: ', test[0].numpy())
        print('sentence2: ', test[1].numpy())
        if bert_result_class == 1:
            print('Are a paraphrase')
        else:
            print('Are NOT a paraphrase')

    # The TFDS config is 'glue/qqp'; the previously used misspelling
    # 'glue/qqb' is still accepted so existing callers keep working.
    elif dataset_name in ('glue/qqp', 'glue/qqb'):
        print('question1: ', test[0].numpy())
        print('question2: ', test[1].numpy())
        if bert_result_class == 1:
            print('Questions are similar')
        else:
            print('Questions are NOT similar')

    elif dataset_name == 'glue/mnli':
        print('premise: ', test[0].numpy())
        print('hypothesis: ', test[1].numpy())
        if bert_result_class == 1:
            print('This premise is NEUTRAL to the hypothesis')
        elif bert_result_class == 2:
            print('This premise is CONTRADICT to the hypothesis')
        else:
            print('This premise is ENTAILS to the hypothesis')

    elif dataset_name == 'glue/qnli':
        print('question: ', test[0].numpy())
        print('sentence: ', test[1].numpy())
        if bert_result_class == 1:
            print('This question is NOT answerable by the sentence')
        else:
            print('This question is answerable by the sentence')

    elif dataset_name in ('glue/rte', 'glue/wnli'):
        # Both tasks print the same captions and verdict texts.
        print('sentence1: ', test[0].numpy())
        print('sentence2: ', test[1].numpy())
        if bert_result_class == 1:
            print('sentence1 DOES NOT entail sentence2')
        else:
            print('sentence1 entail sentence2')

    print('Bert raw results: ', bert_result[0])
    print()

In [None]:
# Run the reloaded model on five test rows drawn through a 1000-element
# shuffle buffer, and print the results.
with tf.device('/job:localhost'):
    test_dataset = tf.data.Dataset.from_tensor_slices(in_memory_ds[test_split])
    for test_row in test_dataset.shuffle(1000).map(prepare).take(5):
        # A single-feature task passes its one input directly; multi-feature
        # tasks pass the list of inputs.
        if len(sentence_features) == 1:
            result = reloaded_model(test_row[0])
        else:
            result = reloaded_model(list(test_row))
        
        print_bert_results(test_row, result, tfds_name)

In [None]:
# Same check through the SavedModel's 'serving_default' signature, which is
# called with keyword arguments named after the sentence features.
# Uses `test_dataset` created in the previous cell.
with tf.device('/job:localhost'):
    serving_model = reloaded_model.signatures['serving_default']
    for test_row in test_dataset.shuffle(1000).map(prepare_serving).take(5):
        result = serving_model(**test_row)
        # The 'prediction' key is the classifier's defined model name.
        print_bert_results(list(test_row.values()), result['prediction'], tfds_name)