In [1]:
import gc
import tensorflow as tf
import tensorflow_datasets
import numpy as np
import tensorflow.keras as keras
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Dense, Input
from utils import downconvert_tf_dataset
import wandb
from wandb.keras import WandbCallback

from transformers import (TFBertModel, 
                          BertTokenizer,
                          glue_convert_examples_to_features)

In [2]:
# Constants
BATCH_SIZE = 32
MAX_SEQ_LEN = 128  # maximum sequence length in tokens; also the width of the model inputs
EPOCHS = 3
SEED = 42
# NOTE(review): BATCH_SIZE and EPOCHS are not referenced by the training code in this
# notebook (fine_tune_task defaults to epochs=4; the sweep chooses its own batch sizes).

# Seed the numpy and TensorFlow global RNGs so that weight initialisation, dropout and
# shuffling are repeatable where TF supports it (some GPU kernels remain non-deterministic).
np.random.seed(SEED)
tf.random.set_seed(SEED)

# FP16 settings
fp16 = True
if fp16:
    # Globally enable TensorFlow's automatic mixed-precision graph rewrite.
    tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
    # Larger batch under fp16 — presumably because activations use less memory; confirm.
    BATCH_SIZE = 48

In [3]:
# Fetch pre-trained models
# Tokenizer for the same "bert-base-cased" checkpoint that fine_tune_task loads; it is
# used both to encode the training datasets and to encode ad-hoc phrases for prediction.
tokenizer = BertTokenizer.from_pretrained(
    "bert-base-cased",
)

In [4]:
# Memoises tokenized datasets by name so repeated runs skip re-downloading/re-encoding.
dataset_cache = {}


def _load_dataset(dataset_name):
    """Return (train_x, train_y, val_x, val_y) for a TFDS dataset, caching the result in ``dataset_cache``."""
    if dataset_name in dataset_cache:
        cached = dataset_cache[dataset_name]
        print("Restored dataset from cache.")
        return cached["tx"], cached["ty"], cached["vx"], cached["vy"]

    data = tensorflow_datasets.load(dataset_name)
    train_x, train_y = downconvert_tf_dataset(data["train"], tokenizer, MAX_SEQ_LEN)
    val_x, val_y = downconvert_tf_dataset(data["validation"], tokenizer, MAX_SEQ_LEN)
    dataset_cache[dataset_name] = {"tx": train_x, "ty": train_y, "vx": val_x, "vy": val_y}
    print("Dataset %s train_sz=%i val_sz=%i" %
          (dataset_name, train_y.shape[0], val_y.shape[0]))
    return train_x, train_y, val_x, val_y


def create_new_classification_head(dataset_name, base_model_cls_head, dense_config=(256, 2)):
    """Load ``dataset_name`` and stack a dense classification head on ``base_model_cls_head``.

    Args:
        dataset_name: TFDS dataset name (e.g. "glue/sst2"); it must provide "train"
            and "validation" splits. Also used to name the new layers.
        base_model_cls_head: Keras tensor the head is built on top of.
        dense_config: sequence of layer widths. Every entry but the last becomes a
            ReLU Dense layer; the last entry is the width of the final softmax layer.

    Returns:
        (train_x, train_y, val_x, val_y, tensor) where ``tensor`` is the softmax output.

    Raises:
        ValueError: if ``dense_config`` is empty.
    """
    if len(dense_config) == 0:
        raise ValueError("dense_config must contain at least the output layer size")

    train_x, train_y, val_x, val_y = _load_dataset(dataset_name)

    tensor = base_model_cls_head
    for index, layer_units in enumerate(dense_config[:-1]):
        # The layer index keeps names unique when two hidden layers share a width;
        # Keras refuses to build a model with duplicate layer names.
        tensor = Dense(units=layer_units, activation="relu",
                       name="%s_%i_%i" % (dataset_name, index, layer_units))(tensor)
    tensor = Dense(units=dense_config[-1], activation="softmax",
                   name="final_%s" % (dataset_name))(tensor)

    return train_x, train_y, val_x, val_y, tensor

In [5]:
def fine_tune_task(dataset_id, optimizer, batch_sz=32, epochs=4):
    """Fine-tune a freshly loaded "bert-base-cased" model with a classification head.

    Args:
        dataset_id: dataset name forwarded to ``create_new_classification_head``
            (e.g. "glue/sst2").
        optimizer: Keras optimizer instance. When ``fp16`` is set it is wrapped by the
            mixed-precision graph rewrite before compiling.
        batch_sz: training batch size.
        epochs: number of training epochs.

    Returns:
        (model, history): the trained Keras model and the History returned by ``fit``.
    """
    # Re-load base model weights so every run starts from the pre-trained checkpoint.
    bert_base_model = TFBertModel.from_pretrained("bert-base-cased")

    # The BERT layer is called with this list, so its order is significant.
    inputs = [Input(shape=(MAX_SEQ_LEN,), dtype='int32', name='input_ids'),
              Input(shape=(MAX_SEQ_LEN,), dtype='int32', name='attention_mask'),
              Input(shape=(MAX_SEQ_LEN,), dtype='int32', name='token_type_ids')]

    # Fetch the CLS head of the BERT model; index 1 of its outputs.
    cls_head = bert_base_model(inputs)[1]

    # Fetch and format dataset and classification head (single 2-way softmax layer).
    train_x, train_y, val_x, val_y, tensor = \
        create_new_classification_head(dataset_id, cls_head, dense_config=[2])
    model = keras.Model(inputs=inputs, outputs=tensor)
    # summary() prints by itself and returns None.
    model.summary()

    if fp16:
        # The rewrite returns a loss-scaling wrapper; it must replace the original
        # optimizer, otherwise fp16 gradients are never loss-scaled.
        optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
    # The head ends in a softmax, so the model outputs probabilities, not logits.
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

    model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

    # Log to W&B about every 1024 training examples, but never less than once per batch.
    log_every = max(1, 1024 // batch_sz)
    history = model.fit(train_x, train_y, batch_size=batch_sz, epochs=epochs,
                        validation_data=(val_x, val_y),
                        callbacks=[WandbCallback(log_batch_frequency=log_every)])
    return model, history

# Batch-size / learning-rate sweep on SST-2; each run gets its own W&B run and a
# freshly loaded BERT model.
dataset = "glue/sst2"
# (batch size, learning rate) combinations to skip, e.g. runs that already completed.
SKIP_RUNS = {(16, 1e-5)}

model = history = None
for bsi in [16, 32]:
    for lri in [1e-5, 5e-6]:
        if (bsi, lri) in SKIP_RUNS:
            continue
        # Drop the previous run's model and reset Keras global state before loading
        # another ~108M-parameter BERT, so memory does not accumulate across runs.
        model = history = None
        K.clear_session()
        gc.collect()

        name = "sst2-%i-%f" % (bsi, lri)
        wandb.init(project="nonint-transformers",
                   name=name,
                   config={"dataset": dataset, "learning_rate": lri, "epsilon": 1e-08, "batch_sz": bsi})
        optimizer = tf.keras.optimizers.Adam(learning_rate=lri, epsilon=1e-08)
        # fine_tune_task is responsible for the fp16 mixed-precision setup of this optimizer.
        model, history = fine_tune_task(dataset, optimizer, batch_sz=bsi)

INFO:absl:Overwrite dataset info from restored data version.
INFO:absl:Reusing dataset glue (C:\Users\jbetk\tensorflow_datasets\glue\sst2\0.0.2)
INFO:absl:Constructing tf.data.Dataset for split None, from C:\Users\jbetk\tensorflow_datasets\glue\sst2\0.0.2


Dataset glue/sst2 train_sz=67349 val_sz=872
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_ids (InputLayer)          [(None, 128)]        0                                            
__________________________________________________________________________________________________
attention_mask (InputLayer)     [(None, 128)]        0                                            
__________________________________________________________________________________________________
token_type_ids (InputLayer)     [(None, 128)]        0                                            
__________________________________________________________________________________________________
tf_bert_model (TFBertModel)     ((None, 128, 768), ( 108310272   input_ids[0][0]                  
                                                  

wandb: ERROR Can't save model, h5py returned error: 


Epoch 2/4
Epoch 3/4
Epoch 4/4


Restored dataset from cache.
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_ids (InputLayer)          [(None, 128)]        0                                            
__________________________________________________________________________________________________
attention_mask (InputLayer)     [(None, 128)]        0                                            
__________________________________________________________________________________________________
token_type_ids (InputLayer)     [(None, 128)]        0                                            
__________________________________________________________________________________________________
tf_bert_model_1 (TFBertModel)   ((None, 128, 768), ( 108310272   input_ids[0][0]                  
                                                               

wandb: ERROR Can't save model, h5py returned error: 


Epoch 2/4
Epoch 3/4
Epoch 4/4


Restored dataset from cache.
Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_ids (InputLayer)          [(None, 128)]        0                                            
__________________________________________________________________________________________________
attention_mask (InputLayer)     [(None, 128)]        0                                            
__________________________________________________________________________________________________
token_type_ids (InputLayer)     [(None, 128)]        0                                            
__________________________________________________________________________________________________
tf_bert_model_2 (TFBertModel)   ((None, 128, 768), ( 108310272   input_ids[0][0]                  
                                                               

wandb: ERROR Can't save model, h5py returned error: 


Epoch 2/4
Epoch 3/4
Epoch 4/4


In [None]:
# Example input sentence for the prediction below.
# Alternative: phrase = "I was disappointed to see the credits roll, the film really had me."
phrase = "A human there was she walked it"

def pad_zero(inputs, seq_len):
    """Right-pad every sequence in ``inputs`` with zeros to exactly ``seq_len`` entries.

    ``inputs`` maps names to sequences of ints (e.g. the result of
    ``tokenizer.encode_plus``). Each value is replaced in place by an int32 numpy
    array of length ``seq_len``, and the same mapping is returned.

    Raises:
        ValueError: if any sequence is longer than ``seq_len``.
    """
    for key in inputs:
        values = np.asarray(inputs[key])
        if len(values) > seq_len:
            raise ValueError("%s has %i entries, more than seq_len=%i"
                             % (key, len(values), seq_len))
        # Exactly seq_len long so the result matches fixed-length model Input layers.
        padded = np.zeros(seq_len, dtype='int32')
        padded[:len(values)] = values
        inputs[key] = padded
    return inputs
 
# `sst_bert_model` is not defined anywhere in this notebook; use the most recent model
# produced by the fine-tuning sweep above.
sst_bert_model = model

phrase_encoded = pad_zero(
    tokenizer.encode_plus(phrase, add_special_tokens=True, max_length=MAX_SEQ_LEN),
    MAX_SEQ_LEN)

# Feed inputs by Input-layer name: a positional list would have to follow the model's
# input order (input_ids, attention_mask, token_type_ids) and contain nothing extra.
# Each row is limited to MAX_SEQ_LEN to match the fixed-width Input layers and given a
# leading batch axis of size 1 (reshape infers -1; np.resize does not).
phrase_encoded_formatted = {
    input_name: np.asarray(phrase_encoded[input_name])[:MAX_SEQ_LEN].reshape(1, -1)
    for input_name in ("input_ids", "attention_mask", "token_type_ids")
}
print(sst_bert_model.predict(phrase_encoded_formatted))