In [None]:
import gc
import tensorflow as tf
import tensorflow_datasets
import numpy as np
import tensorflow.keras as keras
from tensorflow.keras.layers import Dense, Input

from transformers import (TFBertModel, 
                          BertTokenizer,
                          glue_convert_examples_to_features)

In [None]:
# Constants: sequence length and number of training epochs.
MAX_SEQ_LEN = 128
EPOCHS = 3

# FP16 settings. The batch size depends on this flag: 48 with fp16, 32 without.
fp16 = True
BATCH_SIZE = 48 if fp16 else 32
if fp16:
    # NOTE(review): this is a grappler graph rewrite; confirm it has any effect on eagerly-run
    # code such as the custom training loop (Keras fit() runs inside tf.function).
    tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

In [None]:
# Fetch pre-trained models

# Encoder and tokenizer are loaded from the same checkpoint name so their vocabularies agree.
PRETRAINED_CHECKPOINT = "bert-base-cased"
bert_base_model = TFBertModel.from_pretrained(PRETRAINED_CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_CHECKPOINT)

In [None]:
# Fetch and format dataset.
def downconvert_tf_dataset(dataset, tok, pad_token=0, max_seq_len=None):
    """Tokenize a single-sentence TFDS split into fixed-length numpy arrays.

    Each example's ``'sentence'`` is UTF-8 decoded, encoded with ``tok.encode_plus``
    (special tokens added), and right-padded to exactly ``max_seq_len`` positions.

    Args:
        dataset: iterable of examples; each must provide ``example['sentence'].numpy()``
            (UTF-8 bytes) and ``example['label'].numpy()``.
        tok: tokenizer whose ``encode_plus`` result has ``input_ids`` and ``token_type_ids``.
        pad_token: id appended to ``input_ids`` as padding.
        max_seq_len: target sequence length; ``None`` means the notebook-level ``MAX_SEQ_LEN``.

    Returns:
        ``([input_ids, attention_mask, token_type_ids], labels)``: three arrays with one
        ``max_seq_len``-long row per example, and an array with one label per example.
        ``attention_mask`` is 1 on real tokens and 0 on padding.

    Raises:
        AssertionError: if an encoded example does not end up exactly ``max_seq_len`` long
            (e.g. the tokenizer returned more tokens than ``max_seq_len``).
    """
    if max_seq_len is None:
        max_seq_len = MAX_SEQ_LEN
    id_rows, mask_rows, type_rows, labels = [], [], [], []
    for example in dataset:
        encoded = tok.encode_plus(example['sentence'].numpy().decode("utf-8"),
                                  add_special_tokens=True, max_length=max_seq_len)
        # Copy so padding below never mutates the tokenizer's returned lists.
        input_ids = list(encoded["input_ids"])
        token_type_ids = list(encoded["token_type_ids"])
        # 1 marks real tokens the model should attend to; padding positions get 0 below.
        attention_mask = [1] * len(input_ids)

        # Pad strings to exactly max_seq_len.
        padding_length = max_seq_len - len(input_ids)
        input_ids += [pad_token] * padding_length
        attention_mask += [0] * padding_length
        token_type_ids += [0] * padding_length

        # Double-check results.
        assert len(input_ids) == max_seq_len, "Error with input length {} vs {}".format(
            len(input_ids), max_seq_len
        )
        assert len(attention_mask) == max_seq_len, "Error with input length {} vs {}".format(
            len(attention_mask), max_seq_len
        )
        assert len(token_type_ids) == max_seq_len, "Error with input length {} vs {}".format(
            len(token_type_ids), max_seq_len
        )

        # Form lists.
        id_rows.append(np.asarray(input_ids))
        mask_rows.append(np.asarray(attention_mask))
        type_rows.append(np.asarray(token_type_ids))
        labels.append(example['label'].numpy())
    return [np.asarray(id_rows), np.asarray(mask_rows), np.asarray(type_rows)], np.asarray(labels)

# Load GLUE SST-2 and convert both splits into ([input_ids, attention_mask, token_type_ids], labels)
# numpy arrays via downconvert_tf_dataset.
sst_data = tensorflow_datasets.load("glue/sst2")
sst_train_x, sst_train_y = downconvert_tf_dataset(sst_data["train"], tokenizer)
sst_val_x, sst_val_y = downconvert_tf_dataset(sst_data["validation"], tokenizer)

# NOTE(review): the triple-quoted string below is inert (evaluated and discarded); it only records an
# abandoned tf.data pipeline. As written it refers to `data`, which is not defined in this notebook
# (the loaded dataset is `sst_data`), so it would not run if un-quoted.
'''
# We can work in-place with a TF dataset, but working with them is a pain and there are bugs:
#  Doesn't complete some epochs with "out of data" error
#  Doesn't track training properly
#  Doesn't use validation set properly
train_dataset = glue_convert_examples_to_features(data["train"], tokenizer, MAX_SEQ_LEN, 'sst-2')\
    .shuffle(1337)\
    .batch(BATCH_SIZE)\
    .repeat(EPOCHS)
validation_dataset = glue_convert_examples_to_features(data["validation"], tokenizer, MAX_SEQ_LEN, 'sst-2')\
    .batch(BATCH_SIZE)
    '''

print("Done.")

In [None]:
# Configure and compile model.

# Later cells might set trainable=False; which we don't necessarily want here.
bert_base_model.trainable = True
sst_inputs = [Input(shape=(MAX_SEQ_LEN,), dtype='int32', name='input_ids'),
              Input(shape=(MAX_SEQ_LEN,), dtype='int32', name='attention_mask'),
              Input(shape=(MAX_SEQ_LEN,), dtype='int32', name='token_type_ids')]
# Fetch the CLS head of the BERT model; index 1.
sst_tensor = bert_base_model(sst_inputs)[1]
# The classifier ends in softmax, so the model outputs class probabilities (not logits).
sst_tensor = Dense(activation='softmax', units=2)(sst_tensor)
sst_bert_model = keras.Model(inputs=sst_inputs, outputs=sst_tensor)
# summary() prints the table itself and returns None; wrapping it in print() only adds "None".
sst_bert_model.summary()

# Configure optimizer, loss function and metrics.
# The encoder gets a small fine-tuning learning rate; the fresh head uses Adam's default.
sst_optimizer_base_model = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)
sst_optimizer_head = tf.keras.optimizers.Adam()
if fp16:
    # The rewrite returns a loss-scaling wrapper around the optimizer; it must be kept,
    # otherwise the wrapper (and its loss scaling) is discarded.
    sst_optimizer_base_model = tf.train.experimental.enable_mixed_precision_graph_rewrite(sst_optimizer_base_model)
    sst_optimizer_head = tf.train.experimental.enable_mixed_precision_graph_rewrite(sst_optimizer_head)
# The model already applies softmax, so the loss must treat its outputs as probabilities;
# from_logits=True here would apply a second softmax.
sst_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
sst_metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

# Not compiled: the encoder and head use separate optimizers, trained by the custom loop below.

In [None]:
## This cell defines a utility function for easily tracking progress in jupyter.

import time, sys
from IPython.display import clear_output

def update_progress(epoch, progress, msg):
    """Redraw a one-line text progress bar in the notebook output area.

    ``progress`` is a fraction in [0, 1]. Ints are converted to float, and any other
    non-float type is treated as 0. Values are clamped into [0, 1]. The previous
    output is cleared (waiting for the new output) before printing
    ``Epoch <epoch>: [####----] <pct>% <msg>``.
    """
    width = 20
    if isinstance(progress, int):
        progress = float(progress)
    fraction = progress if isinstance(progress, float) else 0
    if fraction < 0:
        fraction = 0
    elif fraction >= 1:
        fraction = 1
    filled = int(round(width * fraction))
    bar = "#" * filled + "-" * (width - filled)
    clear_output(wait=True)
    print("Epoch {0}: [{1}] {2:.1f}% {3}".format(epoch, bar, fraction * 100, msg))

In [None]:
from statistics import mean 

def train_step(model, loss, optimizers, optimizer_vars, x_vals, y_vals):
    """Run one forward/backward pass and let each optimizer update its own variables.

    Args:
        model: callable model, invoked as ``model(x_vals, training=True)``.
        loss: callable ``loss(y_true, y_pred)`` returning a tensor.
        optimizers: sequence of optimizers; ``optimizers[i]`` updates ``optimizer_vars[i]``.
        optimizer_vars: sequence of variable collections, paired positionally with
            ``optimizers`` (extra entries on either side are ignored, as with ``zip``).
        x_vals: model inputs for this batch.
        y_vals: labels for this batch.

    Returns:
        The mean of the loss tensor, as a numpy scalar.
    """
    groups = [(optimizer, list(variables)) for optimizer, variables in zip(optimizers, optimizer_vars)]
    with tf.GradientTape() as tape:
        predictions = model(x_vals, training=True)
        loss_value = loss(y_vals, predictions)
    loss_scalar = loss_value.numpy().mean()
    # One backward pass over every variable, then split per optimizer. This avoids a persistent
    # tape, which recomputed the backward pass per optimizer and kept the forward activations
    # alive until the tape was deleted.
    all_variables = [v for _, variables in groups for v in variables]
    all_grads = tape.gradient(loss_value, all_variables)
    offset = 0
    for optimizer, variables in groups:
        grads = all_grads[offset:offset + len(variables)]
        offset += len(variables)
        optimizer.apply_gradients(zip(grads, variables))
    return loss_scalar

def train(model, loss, optimizers, optimizer_vars, x_vals, y_vals, epochs, batch_sz):
    """Mini-batch training loop over in-memory arrays, with a notebook progress bar.

    Args:
        model, loss, optimizers, optimizer_vars: forwarded to ``train_step`` for every batch.
        x_vals: list of input arrays, each sliced along axis 0 in lockstep with ``y_vals``.
        y_vals: label array; ``y_vals.shape[0]`` is the number of samples.
        epochs: number of passes over the data.
        batch_sz: samples per batch; the last batch of each epoch may be smaller.

    Returns:
        List of per-batch loss values, in training order, across all epochs.
    """
    # Garbage collect before starting training to attempt to free up GPU memory.
    gc.collect()
    loss_history = []
    sample_count = y_vals.shape[0]
    for epoch in range(epochs):
        # Visit every sample, including a trailing partial batch; the stop index is never
        # pulled below sample_count, so no sample is skipped.
        for start in range(0, sample_count, batch_sz):
            stop = min(start + batch_sz, sample_count)
            batch_x = [x_vals_ele[start:stop] for x_vals_ele in x_vals]
            batch_y = y_vals[start:stop]
            loss_history.append(train_step(model, loss, optimizers, optimizer_vars, batch_x, batch_y))
            # Running mean over the most recent (up to) 10 batches, including the latest one.
            recent_mean = mean(float(v) for v in loss_history[-10:])
            update_progress(epoch, stop / sample_count,
                            "Loss=%f (avg10=%f)" % (loss_history[-1], recent_mean))
    return loss_history
            
def find_tf_variables_not_in_list(full_variable_list, diff_list):
    """Return the variables of ``full_variable_list`` whose names do not occur in ``diff_list``.

    Variables are matched by ``.name`` because tensors can't be compared directly.
    The order (and any duplicates) of ``full_variable_list`` is preserved.
    """
    # A set gives O(1) membership checks instead of scanning name lists for every variable.
    excluded_names = {v.name for v in diff_list}
    return [v for v in full_variable_list if v.name not in excluded_names]

# The head is every trainable variable of the SST model that the BERT encoder does not own.
head_variables = find_tf_variables_not_in_list(sst_bert_model.trainable_variables, bert_base_model.trainable_variables)
# Encoder variables are updated by sst_optimizer_base_model (learning_rate=3e-5) and head variables
# by sst_optimizer_head, via the custom loop.
# NOTE(review): batch size is hard-coded to 16 rather than BATCH_SIZE (48 when fp16) — confirm intent.
sst_bert_history = train(sst_bert_model, sst_loss, [sst_optimizer_base_model, sst_optimizer_head],\
                        [bert_base_model.trainable_variables, head_variables], sst_train_x, sst_train_y,\
                        EPOCHS, 16)
# Earlier Keras-native alternative (single optimizer), kept for reference:
#sst_bert_history = sst_bert_model.fit(sst_train_x, sst_train_y, epochs=EPOCHS, \
#                                      validation_data=(sst_val_x, sst_val_y))


In [None]:
# Number of training steps recorded (train() appends one loss per batch, across all epochs).
print(len(sst_bert_history))

In [None]:
# Transition dataset to another task; freeze BERT model; re-train new head.

data_cola = tensorflow_datasets.load("glue/cola")
cola_train_x, cola_train_y = downconvert_tf_dataset(data_cola["train"], tokenizer)
cola_val_x, cola_val_y = downconvert_tf_dataset(data_cola["validation"], tokenizer)

# NOTE(review): this previously used `VariantRateAdam`, which is neither defined nor imported in this
# notebook (NameError on a fresh kernel). Plain Adam is used until that class is restored.
cola_optimizer = tf.keras.optimizers.Adam(name="Adam")
if fp16:
    # Keep the returned loss-scaling wrapper; Keras fit() applies its loss scaling.
    cola_optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(cola_optimizer)
# The model ends in softmax (probabilities), so the loss must not apply softmax again.
cola_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
cola_metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

# Lock down the BERT encoder: the SST fine-tuning above is meant to have trained it already,
# so only the new CoLA head is trained here.
bert_base_model.trainable = False
cola_inputs = [Input(shape=(MAX_SEQ_LEN,), dtype='int32', name='input_ids'),
               Input(shape=(MAX_SEQ_LEN,), dtype='int32', name='attention_mask'),
               Input(shape=(MAX_SEQ_LEN,), dtype='int32', name='token_type_ids')]
# Fetch the CLS head of the BERT model; index 1.
cola_tensor = bert_base_model(cola_inputs)[1]
# Hidden layer uses relu: a softmax here would force the 256 activations to sum to 1.
cola_tensor = Dense(activation='relu', units=256)(cola_tensor)
cola_tensor = Dense(activation='softmax', units=2)(cola_tensor)
cola_bert_model = keras.Model(inputs=cola_inputs, outputs=cola_tensor)
# summary() prints the table itself and returns None.
cola_bert_model.summary()

cola_bert_model.compile(optimizer=cola_optimizer, loss=cola_loss, metrics=[cola_metric])

In [None]:
# Train the CoLA head with the batch size from the config cell (48 with fp16, 32 otherwise);
# previously BATCH_SIZE was defined but never used and fit() fell back to its default.
cola_bert_history = cola_bert_model.fit(cola_train_x, cola_train_y, epochs=EPOCHS, batch_size=BATCH_SIZE,
                                        validation_data=(cola_val_x, cola_val_y))