In [None]:
!pip install -q tensorflow_text

In [None]:
!pip install -q tf-models-official

In [None]:
import os
import shutil

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text

from official.nlp import optimization

import matplotlib.pyplot as plt

# Raise TensorFlow's logger level to ERROR to keep the cell outputs readable.
tf.get_logger().setLevel('ERROR')

In [None]:
# Source archive of the sentiment dataset used throughout this notebook.
url = 'https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'

# Download into the current directory and unpack the archive there.
# Keyword arguments make the meaning of each flag explicit.
dataset = tf.keras.utils.get_file(
    fname='aclImdb_v1.tar.gz',
    origin=url,
    untar=True,
    cache_dir='.',
    cache_subdir='',
)

In [None]:
# List the folder that holds the path returned by get_file.
os.listdir(os.path.dirname(dataset))

In [None]:
# The unpacked data is expected in an `aclImdb` folder next to the download.
download_root = os.path.dirname(dataset)
dataset_dir = os.path.join(download_root, 'aclImdb')
dataset_dir

In [None]:
# Show the top-level contents of the dataset directory.
os.listdir(dataset_dir)

In [None]:
# Path of the training portion of the dataset; echoed for inspection.
train_dir = os.path.join(dataset_dir, 'train')
train_dir

In [None]:
# Contents of train/ before the 'unsup' folder is removed.
os.listdir(train_dir)

In [None]:
# Remove the 'unsup' folder from train/ so it is not seen by the directory
# loader. Guarded so that re-running this cell after the folder is gone does
# not raise FileNotFoundError.
unsup_dir = os.path.join(train_dir, 'unsup')
if os.path.isdir(unsup_dir):
    shutil.rmtree(unsup_dir)

In [None]:
# Contents of train/ again, to confirm the 'unsup' folder is gone.
os.listdir(train_dir)

In [None]:
# Shared input-pipeline settings.
AUTOTUNE = tf.data.AUTOTUNE  # passed to prefetch() below
BATCH_SIZE = 32
SEED = 42  # used for both the training and the validation subset

In [None]:
# Training subset: 80% of the reviews under train/. The path comes from
# `train_dir` rather than a second hard-coded copy of it, so this cell keeps
# working if the download location changes.
raw_train_ds = tf.keras.preprocessing.text_dataset_from_directory(
    train_dir,
    batch_size=BATCH_SIZE,
    validation_split=0.2,
    subset='training',
    seed=SEED,
)

# Read the class names from the raw dataset, before it is wrapped by
# cache()/prefetch().
class_names = raw_train_ds.class_names

# Note: the dataset was already batched when it was loaded above.
train_ds = raw_train_ds.cache().prefetch(AUTOTUNE)

In [None]:
# Validation subset: the remaining 20% of train/. It must use the same
# directory and seed as the training subset; the path comes from `train_dir`
# rather than a second hard-coded copy of it.
raw_val_ds = tf.keras.preprocessing.text_dataset_from_directory(
    train_dir,
    batch_size=BATCH_SIZE,
    validation_split=0.2,
    subset='validation',
    seed=SEED,
)

val_ds = raw_val_ds.cache().prefetch(AUTOTUNE)

In [None]:
# Test set, loaded from the test/ folder of the same dataset directory
# (derived from `dataset_dir` instead of a hard-coded relative path).
test_dir = os.path.join(dataset_dir, 'test')
raw_test_ds = tf.keras.preprocessing.text_dataset_from_directory(
    test_dir,
    batch_size=BATCH_SIZE,
)

test_ds = raw_test_ds.cache().prefetch(AUTOTUNE)

In [None]:
# Peek at the first three reviews and labels of one training batch.
for text_batch, label_batch in train_ds.take(1):
    sample_texts = text_batch[:3].numpy()
    sample_labels = label_batch[:3].numpy()
    print("Text: ", sample_texts)
    print("Label: ", sample_labels)

In [None]:
# Every handle below lives on the same TF Hub mirror; change this one constant
# to point the notebook at a different host.
TFHUB_BASE_URL = 'https://hub.tensorflow.google.cn'

# The 24 small_bert variants: each layer count combined with each hidden size.
# In every name the A-<n> suffix equals hidden_size // 64.
SMALL_BERT_NAMES = [
    f'small_bert/bert_en_uncased_L-{num_layers}_H-{hidden_size}_A-{hidden_size // 64}'
    for num_layers in (2, 4, 6, 8, 10, 12)
    for hidden_size in (128, 256, 512, 768)
]

# Preprocessing handle used by most of the entries below.
UNCASED_PREPROCESS_HANDLE = f'{TFHUB_BASE_URL}/tensorflow/bert_en_uncased_preprocess/3'

# Encoder name -> TF Hub handle of the encoder.
map_name_to_handle = {
    'bert_en_uncased_L-12_H-768_A-12':
        f'{TFHUB_BASE_URL}/tensorflow/bert_en_uncased_L-12_H-768_A-12/3',
    'bert_en_cased_L-12_H-768_A-12':
        f'{TFHUB_BASE_URL}/tensorflow/bert_en_cased_L-12_H-768_A-12/3',
    'bert_multi_cased_L-12_H-768_A-12':
        f'{TFHUB_BASE_URL}/tensorflow/bert_multi_cased_L-12_H-768_A-12/3',
    **{name: f'{TFHUB_BASE_URL}/tensorflow/{name}/1' for name in SMALL_BERT_NAMES},
    'albert_en_base':
        f'{TFHUB_BASE_URL}/tensorflow/albert_en_base/2',
    'electra_small':
        f'{TFHUB_BASE_URL}/google/electra_small/2',
    'electra_base':
        f'{TFHUB_BASE_URL}/google/electra_base/2',
    'experts_pubmed':
        f'{TFHUB_BASE_URL}/google/experts/bert/pubmed/2',
    'experts_wiki_books':
        f'{TFHUB_BASE_URL}/google/experts/bert/wiki_books/2',
    'talking-heads_base':
        f'{TFHUB_BASE_URL}/tensorflow/talkheads_ggelu_bert_en_base/1',
}

# Encoder name -> TF Hub handle of the preprocessing model to pair with it.
map_model_to_preprocess = {
    'bert_en_uncased_L-12_H-768_A-12':
        UNCASED_PREPROCESS_HANDLE,
    'bert_en_cased_L-12_H-768_A-12':
        f'{TFHUB_BASE_URL}/tensorflow/bert_en_cased_preprocess/3',
    **{name: UNCASED_PREPROCESS_HANDLE for name in SMALL_BERT_NAMES},
    'bert_multi_cased_L-12_H-768_A-12':
        f'{TFHUB_BASE_URL}/tensorflow/bert_multi_cased_preprocess/3',
    'albert_en_base':
        f'{TFHUB_BASE_URL}/tensorflow/albert_en_preprocess/3',
    'electra_small':
        UNCASED_PREPROCESS_HANDLE,
    'electra_base':
        UNCASED_PREPROCESS_HANDLE,
    'experts_pubmed':
        UNCASED_PREPROCESS_HANDLE,
    'experts_wiki_books':
        UNCASED_PREPROCESS_HANDLE,
    'talking-heads_base':
        UNCASED_PREPROCESS_HANDLE,
}

# Encoder to fine-tune; must be a key of both maps above.
bert_model_name = 'small_bert/bert_en_uncased_L-4_H-512_A-8'

tfhub_handle_encoder = map_name_to_handle[bert_model_name]
tfhub_handle_preprocess = map_model_to_preprocess[bert_model_name]

print(f'BERT model selected           : {tfhub_handle_encoder}')
print(f'Preprocess model auto-selected: {tfhub_handle_preprocess}')

In [None]:
# Load the selected preprocessing model as a Keras layer.
bert_preprocess_layer = hub.KerasLayer(tfhub_handle_preprocess)

In [None]:
# Run the preprocessing layer on a single sentence to inspect its output.
text_test = ["this is such an amazing movie!"]
text_preprocessed = bert_preprocess_layer(text_test)
text_preprocessed

In [None]:
# Summarise the preprocessing output: its keys, the shape of the word-id
# tensor, and the first 12 positions of each tensor for the sample sentence.
print(f"Keys      : {list(text_preprocessed.keys())}")
print(f"Shape     : {text_preprocessed['input_word_ids'].shape}")
for label, key in (
    ("Input word", "input_word_ids"),
    ("Input mask", "input_mask"),
    ("Input type", "input_type_ids"),
):
    print(f"{label}: {text_preprocessed[key][0, :12]}")

In [None]:
# Load the selected encoder as a Keras layer (no `trainable` flag is passed here).
bert_model = hub.KerasLayer(tfhub_handle_encoder)

In [None]:
# Feed the preprocessed sample sentence through the encoder.
bert_result = bert_model(text_preprocessed)
# bert_result

In [None]:
# Names of the outputs returned by the encoder.
bert_result.keys()

In [None]:
# Summarise the encoder outputs for the sample sentence.
print(f"Loaded Bert: {tfhub_handle_encoder}")
print(f"Pooled output shape: {bert_result['pooled_output'].shape}")
print(f"Pooled output: {bert_result['pooled_output'][0, :12]}")
print(f"Sequence output shape:{bert_result['sequence_output'].shape}")
# print(f'Sequence Outputs Values:{bert_result["sequence_output"][0, :12]}')
# len() gives the number of entries in `encoder_outputs`, not a tensor shape,
# so the label says so.
print(f"Number of encoder outputs: {len(bert_result['encoder_outputs'])}")

In [None]:
def build_classifier_model():
    """Assemble the end-to-end text classifier.

    Raw strings go through the TF Hub preprocessing layer and then through the
    TF Hub encoder, which is created with ``trainable=True``. The encoder's
    ``pooled_output`` is passed through dropout and a single-unit dense layer
    without activation, so the model returns logits.

    The hub handles are read from the module-level variables
    ``tfhub_handle_preprocess`` and ``tfhub_handle_encoder``.

    Returns:
        A ``tf.keras.Model`` taking a batch of strings (input named ``text``)
        and producing the output of the ``classifier`` dense layer.
    """
    raw_text = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')

    tokenize = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encode = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')

    pooled = encode(tokenize(raw_text))['pooled_output']
    regularized = tf.keras.layers.Dropout(0.1)(pooled)
    logits = tf.keras.layers.Dense(1, name='classifier')(regularized)

    return tf.keras.Model(inputs=raw_text, outputs=logits)

In [None]:
# Instantiate the classifier; this creates both TF Hub layers.
classifier_model = build_classifier_model()

In [None]:
# Smoke-test the model on the sample sentence. fit() has not been called yet,
# so the score only shows that the pipeline runs end to end.
bert_raw_result = classifier_model(tf.constant(text_test))
tf.sigmoid(bert_raw_result)

In [None]:
# Render the layer graph of the classifier.
# NOTE(review): plot_model typically needs pydot and graphviz installed — confirm in your environment.
tf.keras.utils.plot_model(classifier_model)

In [None]:
# The classifier's final Dense layer has no activation, so the model emits
# logits and the loss must be told so.
loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# BinaryAccuracy compares the raw predictions against its threshold. For
# logits the decision boundary is 0.0 (probability 0.5), not the default 0.5.
metrics = tf.metrics.BinaryAccuracy(threshold=0.0)

In [None]:
epochs = 5

# Total number of optimisation steps is batches-per-epoch times epochs;
# the first 10% of them are used for warm-up.
steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()
num_train_steps = epochs * steps_per_epoch
num_warmup_steps = int(num_train_steps * 0.1)

init_lr = 3e-5
optimizer = optimization.create_optimizer(
    init_lr=init_lr,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps,
    optimizer_type='adamw',
)

In [None]:
# Attach optimizer, loss and metric to the model before training.
classifier_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

In [None]:
# Train on the training subset, validating against the validation subset;
# the object returned by fit() is kept in `history` for the plots below.
print("Training model with ", tfhub_handle_encoder)
history = classifier_model.fit(x=train_ds, validation_data=val_ds, epochs=epochs)

In [None]:
# Evaluate on the test set. The results get their own names so that the
# `loss` object passed to compile() is not overwritten by a float, which
# would break a re-run of the compile cell.
test_loss, test_acc = classifier_model.evaluate(test_ds)

print("Loss: ", test_loss)
print("Acc: ", test_acc)

In [None]:
# Values recorded by fit(), keyed by metric name.
history_dict = history.history
print(history_dict.keys())

In [None]:
acc = history_dict['binary_accuracy']
val_acc = history_dict['val_binary_accuracy']
loss = history_dict['loss']
val_loss = history_dict['val_loss']

In [None]:
epochs = range(1, len(acc) + 1)

fig = plt.figure(figsize=(10,6))
fig.tight_layout()

plt.subplot(2,1,1)
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.ylabel('Loss')
plt.legend()

plt.subplot(2,1,2)
plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epochs')
plt.legend(loc='lower right')

In [None]:
# Build the export directory name; slashes in the dataset name are replaced
# so the name stays a single path component.
dataset_name = 'imdb'
safe_name = dataset_name.replace('/', '_')
save_model_path = f'./{safe_name}_bert'
save_model_path

In [None]:
# Export the model to `save_model_path`, leaving out the optimizer state.
classifier_model.save(save_model_path, include_optimizer=False)

In [None]:
# Load the export back through tf.saved_model.load (not the Keras loader).
reloaded_model = tf.saved_model.load(save_model_path)

In [None]:
def print_my_examples(inputs, results):
    """Print one ``input ... : score`` line per example, then a blank line.

    Args:
        inputs: Sequence of input strings; each is left-aligned in a
            30-character column.
        results: Indexable of per-example results; element ``[i][0]`` is
            printed with six decimals next to ``inputs[i]``.
    """
    lines = []
    for idx, sentence in enumerate(inputs):
        lines.append(f'input: {sentence:<30} : score: {results[idx][0]:.6f}')
    # All lines are formatted before anything is printed.
    print('\n'.join(lines))
    print()

In [None]:
# Sample sentences used to compare the in-memory model with the reloaded one.
examples = [
    'this is such an amazing movie!',
    'The movie was great!',
    'The movie was meh.',
    'The movie was okish.',
    'The movie was terrible...'
]

In [None]:
# Score the examples with both models; sigmoid maps the raw outputs into (0, 1).
reloaded_results = tf.sigmoid(reloaded_model(tf.constant(examples)))
origin_results = tf.sigmoid(classifier_model(tf.constant(examples)))

In [None]:
# Print both result sets one after the other for comparison.
print("Reloaded model result:")
print_my_examples(examples, reloaded_results)
print("Origin model result:")
print_my_examples(examples, origin_results)

In [None]:
# Call the exported 'serving_default' signature directly; the score tensor is
# read from its 'classifier' entry.
serve_fn = reloaded_model.signatures['serving_default']
serving_outputs = serve_fn(tf.constant(examples))
serving_results = tf.sigmoid(serving_outputs['classifier'])
print_my_examples(examples, serving_results)