In [None]:
import tensorflow as tf
#import tensorflow_text as tf_text
import tqdm
from tf_transformers.data import TFWriter, TFReader, TFProcessor
from tf_transformers.core import  TPUTrainer
from tf_transformers.optimization import create_optimizer
from tf_transformers.losses import cross_entropy_loss

In [None]:
# Create the TFRecords, or read existing ones (e.g. from a GCS bucket)

# Read TFRecords

import json
import glob

tfrecord_train_dir = '/home/sidhu/Datasets/tfrecord_cnn_train'
# Use a context manager so the schema file handle is closed after loading.
with open("{}/schema.json".format(tfrecord_train_dir)) as schema_file:
    schema = json.load(schema_file)
# glob.glob order depends on the file system; sort so runs are reproducible.
all_files = sorted(glob.glob("{}/*.tfrecord".format(tfrecord_train_dir)))
tf_reader = TFReader(schema=schema,
                     tfrecord_files=all_files)

x_keys = ['input_ids']
y_keys = ['labels', 'labels_mask']
MAX_LEN = 128


batch_size = 256
# One padded shape per model input and target, each of length MAX_LEN
# (presumably every batch is padded to MAX_LEN tokens by read_record).
padded_shapes = {key: [MAX_LEN] for key in x_keys + y_keys}
train_dataset = tf_reader.read_record(auto_batch=True,
                                      keys=x_keys,
                                      batch_size=batch_size,
                                      x_keys=x_keys,
                                      y_keys=y_keys,
                                      shuffle=True,
                                      drop_remainder=True,
                                      padded_shapes=padded_shapes
                                      )


# Use this map when the model is built with mask_mode="user_defined".

def map_add_input_mask(x, y):
    """Add an all-ones ``input_mask`` shaped like ``x['input_ids']``.

    ``x`` is updated in place and returned; ``y`` passes through untouched.
    """
    token_ids = x['input_ids']
    x.update({'input_mask': tf.ones_like(token_ids)})
    return x, y

# NOTE(review): map_add_input_mask_bert (applied further down in this cell)
# also sets x['input_mask'] to ones, so this map is redundant when both run;
# keep only the map that matches the model returned by get_model.
train_dataset = train_dataset.map(map_add_input_mask)


def map_add_input_mask_bert(x, y):
    """Add the auxiliary inputs used with the BERT model.

    Sets ``input_mask`` (all ones) and ``input_type_ids`` (all zeros), both
    shaped like ``x['input_ids']``. ``x`` is updated in place and returned;
    ``y`` passes through untouched.
    """
    token_ids = x['input_ids']
    x.update({
        'input_mask': tf.ones_like(token_ids),
        'input_type_ids': tf.zeros_like(token_ids),
    })
    return x, y

# Adds input_type_ids (zeros) alongside input_mask; presumably required by the
# BertModel returned by the active get_model — verify against model.input.
train_dataset = train_dataset.map(map_add_input_mask_bert)

In [None]:

# GPT-2 base-sized configuration (model is randomly initialised from it).
_GPT2_CONFIG = {
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "intermediate_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "embedding_size": 768,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "max_position_embeddings": 1024,
    "num_attention_heads": 12,
    "attention_head_size": 64,
    "num_hidden_layers": 12,
    "type_vocab_size": -1,
    "vocab_size": 50257,
    "layer_norm_epsilon": 1e-05
}

# BERT base-sized configuration; vocab size changed to 50257 as in GPT2.
_BERT_CONFIG = {
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "intermediate_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "embedding_size": 768,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "max_position_embeddings": 512,
    "num_attention_heads": 12,
    "attention_head_size": 64,
    "num_hidden_layers": 12,
    "type_vocab_size": 2,
    "vocab_size": 50257,
    "layer_norm_epsilon": 1e-12
}


def get_gpt2_model():
    """Build a randomly initialised GPT2Model from ``_GPT2_CONFIG``.

    Built with ``mask_mode="user_defined"``; this notebook pairs that mode with
    ``map_add_input_mask``. For pretrained weights use instead:
    ``GPT2Model.from_pretrained("gpt2", save_checkpoint_cache=False)``.

    Returns:
        The GPT2Model (its inputs are printed for inspection).
    """
    from tf_transformers.models import GPT2Model
    # Pass a copy so from_config can never mutate the shared constant.
    model = GPT2Model.from_config(dict(_GPT2_CONFIG), mask_mode="user_defined")
    print("Model inputs --->", model.input)
    return model


def get_bert_model(return_all_layer_outputs=False):
    """Build a randomly initialised BertModel from ``_BERT_CONFIG``.

    Args:
        return_all_layer_outputs: when True, forwarded to
            ``BertModel.from_config`` so the model returns per-layer outputs.

    Returns:
        The BertModel (its inputs and variable count are printed).
    """
    from tf_transformers.models import BertModel
    extra_kwargs = {}
    if return_all_layer_outputs:
        # Only forwarded when requested, so the default call is exactly
        # BertModel.from_config(config).
        extra_kwargs['return_all_layer_outputs'] = True
    model = BertModel.from_config(dict(_BERT_CONFIG), **extra_kwargs)
    print("Model inputs --->", model.input)
    print("Model variables --->", len(model.variables))
    return model


def get_model():
    """Model factory passed to ``TPUTrainer.run``.

    Builds BERT with ``return_all_layer_outputs=True``. ``lm_loss`` iterates
    over ``y_pred_dict['all_layer_token_logits']``, which these per-layer
    outputs presumably provide.
    """
    return get_bert_model(return_all_layer_outputs=True)


def get_optimizer(init_lr=2e-05, num_train_steps=100000, num_warmup_steps=10000):
    """Build the optimizer handed to ``TPUTrainer.run``.

    The defaults reproduce the original hard-coded schedule, so calling
    ``get_optimizer()`` with no arguments behaves exactly as before.

    NOTE(review): the trainer cell runs epochs * steps_per_epoch = 2 * 300 = 600
    steps, far fewer than ``num_warmup_steps`` (10000), so the learning rate
    presumably never leaves warm-up — confirm this is intended.

    Args:
        init_lr: peak learning rate passed to ``create_optimizer``.
        num_train_steps: total steps of the learning-rate schedule.
        num_warmup_steps: warm-up steps of the learning-rate schedule.

    Returns:
        The optimizer (the schedule that ``create_optimizer`` also returns is
        not needed here).
    """
    optimizer, _learning_rate_fn = create_optimizer(init_lr=init_lr,
                                                    num_train_steps=num_train_steps,
                                                    num_warmup_steps=num_warmup_steps)
    return optimizer

def _masked_lm_loss(y_true_dict, logits):
    """Cross-entropy of ``logits`` against ``labels``, weighted by ``labels_mask``."""
    return cross_entropy_loss(labels=y_true_dict['labels'],
                              logits=logits,
                              label_weights=y_true_dict['labels_mask'])


def lm_loss_token_logits(y_true_dict, y_pred_dict):
    """LM loss on ``y_pred_dict['token_logits']`` only.

    Returns:
        ``{"loss": loss}``.
    """
    return {"loss": _masked_lm_loss(y_true_dict, y_pred_dict['token_logits'])}


def lm_loss(y_true_dict, y_pred_dict):
    """LM loss computed on every layer's token logits and averaged.

    Args:
        y_true_dict: must contain ``labels`` and ``labels_mask``.
        y_pred_dict: must contain ``all_layer_token_logits``, an iterable of
            logits (one entry per layer).

    Returns:
        A dict with ``loss_1`` ... ``loss_N`` (one per layer, 1-based, in
        order) and ``loss``, the mean of those per-layer losses.

    Raises:
        ValueError: if ``all_layer_token_logits`` is empty (the mean would
            otherwise silently be NaN).
    """
    loss_dict = {}
    for layer_number, layer_logits in enumerate(y_pred_dict['all_layer_token_logits'], start=1):
        loss_dict['loss_{}'.format(layer_number)] = _masked_lm_loss(y_true_dict, layer_logits)
    if not loss_dict:
        raise ValueError(
            "y_pred_dict['all_layer_token_logits'] is empty; expected one logits entry per layer.")
    # Snapshot the per-layer losses before adding the aggregate key.
    per_layer_losses = list(loss_dict.values())
    loss_dict['loss'] = tf.reduce_mean(per_layer_losses, axis=0)
    return loss_dict


In [None]:
tpu_address = 'local'
trainer = TPUTrainer(
    tpu_address=tpu_address,
    dtype='bf16'
)

GLOBAL_BATCH_SIZE = batch_size

# lm_loss reports one loss per layer ('loss_1' ... 'loss_N') plus their mean
# under 'loss'. Keep NUM_LOSS_LAYERS equal to num_hidden_layers in the active
# get_model config (12).
NUM_LOSS_LAYERS = 12
training_loss_names = ['loss_{}'.format(i) for i in range(1, NUM_LOSS_LAYERS + 1)]
# To skip the per-layer names, set: training_loss_names = None

trainer.run(
    model_fn=get_model,
    optimizer_fn=get_optimizer,
    train_dataset=train_dataset,
    train_loss_fn=lm_loss,
    epochs=2,
    steps_per_epoch=300,
    model_checkpoint_dir="temp_dir", # gs://tft_free/
    batch_size=GLOBAL_BATCH_SIZE,
    training_loss_names=training_loss_names,
    validation_loss_names=None,
    validation_dataset=None,
    validation_loss_fn=None,
    validation_interval_steps=None,
    steps_per_call=100,
    enable_xla=False,
    callbacks=None,
    callbacks_interval_steps=None,
    overwrite_checkpoint_dir=True,
    max_number_of_models=10,
    model_save_interval_steps=None,
    repeat_dataset=True
)


# Copy the training logs to GCS (`gsutil cp`, with a gs:// destination URL):
# !gsutil cp -r temp_dir/logs gs://tft_free/logs_sample

In [None]:
# Read dataset and check timings
# NOTE(review): this cell reassigns globals used by the training cells
# (schema, all_files, tf_reader, x_keys, batch_size, train_dataset).
tfrecord_train_dir = '/home/sidhu/Datasets/PRETRAIN_DATA/TFRECORD'
# Use a context manager so the schema file handle is closed after loading.
with open("{}/schema.json".format(tfrecord_train_dir)) as schema_file:
    schema = json.load(schema_file)


index = 1
# glob.glob order depends on the file system; sort so that all_files[index]
# selects the same shard on every run.
all_files = sorted(glob.glob("{}/*.tfrecord".format(tfrecord_train_dir)))

# Uncomment to time a single shard instead of the whole directory:
#all_files = [all_files[index]]

tf_reader = TFReader(schema=schema,
                     tfrecord_files=all_files)

x_keys = ['input_ids']

MAX_LEN = 128
batch_size = 128
# auto_batch=False: examples are presumably yielded unbatched; they are batched
# later with dense_to_ragged_batch, so no padded_shapes are passed here.
train_dataset = tf_reader.read_record(auto_batch=False,
                                      keys=x_keys,
                                      batch_size=batch_size,
                                      x_keys=x_keys,
                                      shuffle=True,
                                      drop_remainder=True
                                      )



def filter_by_length(x, *, min_len=None):
    """``Dataset.filter`` predicate: keep examples with at least ``min_len`` tokens.

    Args:
        x: a single example dict; ``x['input_ids']`` is assumed unbatched
            (its leading dimension is the token count) — TODO confirm.
        min_len: keyword-only minimum length. Defaults to the notebook global
            ``_MIN_SEN_LEN``, looked up at call time as before.

    Returns:
        A scalar boolean tensor.
    """
    if min_len is None:
        # NOTE(review): _MIN_SEN_LEN is not defined in any earlier cell of this
        # notebook, so this raises NameError on a fresh kernel; define it in a
        # config cell or pass min_len explicitly.
        min_len = _MIN_SEN_LEN
    num_tokens = tf.shape(x['input_ids'])[0]
    return tf.greater_equal(num_tokens, min_len)

def filter_by_batch(x, y):
    """``Dataset.filter`` predicate: keep only batches of exactly ``batch_size`` rows.

    Compares the leading dimension of ``x['input_ids']`` with the notebook
    global ``batch_size``; ``y`` is accepted for the (x, y) element signature
    but ignored.
    """
    expected_rows = tf.constant(batch_size)
    actual_rows = tf.shape(x['input_ids'])[0]
    return tf.equal(actual_rows, expected_rows)

# Pipeline: drop short examples -> batch into ragged tensors -> map_mlm ->
# keep only full-size batches.
# NOTE(review): map_mlm (and _MIN_SEN_LEN, used by filter_by_length) are not
# defined in any earlier cell, so this cell fails on a fresh kernel; they
# presumably came from a deleted cell — restore them.
train_dataset = train_dataset.filter(filter_by_length)
train_dataset = train_dataset.apply(
    tf.data.experimental.dense_to_ragged_batch(batch_size=batch_size))
# After map_mlm, elements are (inputs, labels) pairs (see the loop below).
train_dataset = train_dataset.map(map_mlm)

train_dataset = train_dataset.filter(filter_by_batch)
# Iterate the whole pipeline once (tqdm shows throughput) and record the
# leading dimension of each batch's input_ids.
all_batches = []
for (batch_inputs, batch_labels) in tqdm.tqdm(train_dataset):
    all_batches.append(batch_inputs['input_ids'].shape[0])

In [None]:
# Peek at one distributed batch to inspect what each replica receives.
# (Dataset.repeat(1) was a no-op, so the dataset is distributed as-is.)
train_dataset_distributed = iter(
    trainer.distribution_strategy.experimental_distribute_dataset(train_dataset))

dist_inputs = next(train_dataset_distributed)
# dist_inputs[0] is presumably the inputs dict of the (inputs, labels) pair;
# unpack it into per-replica values on this worker.
batch_inputs = trainer.distribution_strategy.experimental_local_results(dist_inputs[0])
# Removed a stray `self.input_ids_cache` line: there is no `self` at notebook
# top level, so it always raised NameError.