In [None]:
# Collect dataset


def hf_dump_chars_to_textfile(file, dataset, data_keys, max_char=-1):
    """Append text fields of a dataset to a text file, one record per line.

    Args:
      file: path of the text file to write. It is opened in append mode, so
        re-running this appends the same records again.
      dataset: iterable of dict-like examples (e.g. a Hugging Face
        ``datasets.Dataset``) whose ``data_keys`` fields hold ``str`` values.
      data_keys: which keys of each example to dump.
      max_char: stop before the total number of characters written would
        exceed this value; a negative value (the default) means no limit.

    Returns:
      (file, char_count): the path written to and the exact number of
      characters dumped, counting the newline separators added here.
    """
    line_count = 0
    char_count = 0
    limit_reached = False
    # Explicit UTF-8 so non-ASCII Wikipedia text does not depend on the
    # platform's default encoding.
    with open(file, "a+", encoding="utf-8") as outfp:
        for example in tqdm.tqdm(dataset):
            for k in data_keys:
                line = example[k]
                # Skip near-empty records (fewer than 10 characters); this also
                # covers a lone "\n".
                if len(line) < 10:
                    continue
                # Terminate every record: without a separator the end of one
                # record fuses with the start of the next when the file is
                # later read line by line (tf.data.TextLineDataset).
                if not line.endswith("\n"):
                    line += "\n"
                if 0 <= max_char < char_count + len(line):
                    limit_reached = True
                    break
                outfp.write(line)
                char_count += len(line)
                line_count += 1
            if limit_reached:
                break

    print("Total lines {}, chars {}".format(line_count, char_count))
    return file, char_count


import tqdm
from datasets import load_dataset

# Dump the text of the first 1000 articles of English Wikipedia (20200501.en).
# NOTE(review): this writes /home/sidhu/Datasets/wikipedia.txt, while the
# SentencePiece and TFRecord cells read DATA/wikipedia.txt — confirm the paths.
dataset = load_dataset("wikipedia", "20200501.en")
hf_dump_chars_to_textfile(
    "/home/sidhu/Datasets/wikipedia.txt",
    dataset["train"].select(range(1000)),
    ("text",),
)

In [None]:
# NOTE(review): numbers recorded from an earlier run and kept for reference.
# Their unit (lines vs. characters) is not stated here; the second one is
# presumably the count left after a minimum-length filter of 5 — confirm.
212763665

3187609 # After filtering length of 5

In [None]:
### Train Sentencepiece tokenizer (Albert)

import sentencepiece as spm

# Unigram SentencePiece model: ids 0/1 are <pad>/<unk>, BOS is disabled
# (bos_id=-1), and [CLS]/[SEP]/[MASK] are registered as control symbols.
spm_train_kwargs = dict(
    input='DATA/wikipedia.txt',
    model_prefix='bert-joint',
    vocab_size=32000,
    pad_id=0,
    unk_id=1,
    bos_id=-1,
    user_defined_symbols=['(', ')', '"', '-', '.', '–', '£', '€'],
    control_symbols=['[CLS]', '[SEP]', '[MASK]'],
    shuffle_input_sentence=True,
    input_sentence_size=10000000,
    character_coverage=0.99995,
    model_type='unigram',
)
spm.SentencePieceTrainer.train(**spm_train_kwargs)

In [None]:
# Clone tft

git clone -b modification https://github.com/legacyai/tf-transformers.git

In [None]:
import tensorflow as tf
import tensorflow_text as tf_text
import tqdm
from tf_transformers.data import TFWriter, TFReader, TFProcessor


In [None]:
# Read tokenizer
# Path to the trained SentencePiece model (model_prefix 'bert-joint').
# NOTE(review): absolute local path — consider making it configurable.
model_file_path = '/home/sidhu/Datasets/PRETRAIN_DATA/vocab/bert-joint.model'
dtype = tf.int32  # dtype of the token ids produced by the tokenizer
# Per the tf_text.SentencepieceTokenizer docs, nbest_size 0 (or 1) performs no
# subword sampling, so tokenization is deterministic and alpha is unused.
nbest_size = 0
alpha = 1.0

def _create_tokenizer(model_serialized_proto, dtype, nbest_size, alpha):
    """Build a tf_text SentencepieceTokenizer from serialized model bytes.

    Every argument is forwarded unchanged to the tokenizer constructor
    (``model``, ``out_type``, ``nbest_size``, ``alpha``).
    """
    tokenizer_kwargs = {
        "model": model_serialized_proto,
        "out_type": dtype,
        "nbest_size": nbest_size,
        "alpha": alpha,
    }
    return tf_text.SentencepieceTokenizer(**tokenizer_kwargs)

# Read the serialized SentencePiece model. The context manager closes the
# GFile handle; the original `.read()` on an anonymous GFile left it open.
with tf.io.gfile.GFile(model_file_path, "rb") as model_fp:
    model_serialized_proto = model_fp.read()

tokenizer = _create_tokenizer(model_serialized_proto,
                              dtype,
                              nbest_size,
                              alpha)

In [None]:
# Read wikipedia data: one element per line of the dumped text file, grouped
# into batches of BATCH_SIZE lines (the final batch may be smaller).

file_paths = ['DATA/wikipedia.txt']
BATCH_SIZE = 1024
dataset = tf.data.TextLineDataset(file_paths).batch(BATCH_SIZE, drop_remainder=False)


def parse_train():
    """Yield one ``{"input_ids": [...]}`` record per line of the global ``dataset``.

    Each batch of lines is tokenized with the global ``tokenizer`` and the
    result is converted to nested Python lists, so every yielded value is a
    plain list of token ids for TFWriter.
    """
    for lines in tqdm.tqdm(dataset):
        # NOTE(review): for a 2-D ragged result merge_dims(-1, 1) resolves to
        # merge_dims(1, 1), i.e. a no-op — kept unchanged.
        token_rows = tokenizer.tokenize(lines).merge_dims(-1, 1).to_list()
        yield from ({"input_ids": row} for row in token_rows)
        
            
# Let's write the tokenized records using TFWriter (n_files=100 shards under
# TFRECORD/, overwrite=True).
# Use TFProcessor for smaller data.

schema = {
    "input_ids": ("var_len", "int"),
}

tfrecord_train_dir = 'TFRECORD'
tfrecord_filename = 'wikipedia'
writer_kwargs = dict(
    schema=schema,
    file_name=tfrecord_filename,
    model_dir=tfrecord_train_dir,
    tag='train',
    n_files=100,
    overwrite=True,
)
tfwriter = TFWriter(**writer_kwargs)
tfwriter.process(parse_fn=parse_train())

In [None]:
# HF way
# from transformers import AlbertTokenizer

# tokenizer_hf = AlbertTokenizer(vocab_file='vocab/bert-joint.model')

# def parse_train_hf():
#     for batch_input in tqdm.tqdm(dataset):
#         batch_input = [item.decode() for item in batch_input.numpy()]
#         batch_tokenized = tokenizer_hf(batch_input)['input_ids']
#         for input_ids in batch_tokenized:
            
#             yield {"input_ids": input_ids}
        
            
# schema = {
#     "input_ids": ("var_len", "int"),
# }

# tfrecord_train_dir = 'DUMMY'
# tfrecord_filename = 'wikipedia'
# tfwriter = TFWriter(schema=schema, 
#                     file_name=tfrecord_filename, 
#                     model_dir=tfrecord_train_dir,
#                     tag='train',
#                     n_files=100,
#                     overwrite=True
#                     )
# tfwriter.process(parse_fn=parse_train_hf())

In [None]:
# MLM preprocessing configuration.
_MAX_SEQ_LEN = 128
_MAX_PREDICTIONS_PER_BATCH = 20  # max masked items; also the padded length of masked_lm_* tensors
_VOCAB_SIZE = 32000
_MIN_SEN_LEN = 5
_SELECTION_RATE = 0.2   # fraction of selectable tokens chosen for the MLM objective
_MASK_TOKEN_RATE = 0.8  # fraction of chosen tokens replaced by [MASK]

# Special-token ids are looked up once from the SentencePiece model, so they
# cannot drift from it. Previously they were looked up and then silently
# overwritten by hard-coded 3/4/5/1/0. int() turns the scalar tensors into
# plain Python ints.
_START_TOKEN = int(tokenizer.string_to_id('[CLS]'))
_END_TOKEN = int(tokenizer.string_to_id('[SEP]'))
_MASK_TOKEN = int(tokenizer.string_to_id('[MASK]'))
_UNK_TOKEN = int(tokenizer.string_to_id('<unk>'))
_PAD_TOKEN = int(tokenizer.string_to_id('<pad>'))

# Truncate inputs so that, after [CLS] and [SEP] are added, the sequence still
# fits in _MAX_SEQ_LEN.
trimmer = tf_text.RoundRobinTrimmer(max_seq_length=_MAX_SEQ_LEN - 2)

# Random Selector
random_selector = tf_text.RandomItemSelector(
    max_selections_per_batch=_MAX_PREDICTIONS_PER_BATCH,
    selection_rate=_SELECTION_RATE,
    unselectable_ids=[_START_TOKEN, _END_TOKEN, _UNK_TOKEN, _PAD_TOKEN]
)

# Mask Value chooser (Encapsulates the BERT MLM token selection logic)
mask_values_chooser = tf_text.MaskValuesChooser(_VOCAB_SIZE, _MASK_TOKEN, _MASK_TOKEN_RATE)


In [None]:
def map_mlm(item):
    """Turn a batch of token-id sequences into BERT-style MLM (inputs, labels).

    Args:
      item: dict whose 'input_ids' entry holds a batch of token-id sequences
        (a RaggedTensor after ``dense_to_ragged_batch`` — TODO confirm for
        other callers).

    Returns:
      (inputs, labels). ``inputs`` has 'input_ids', 'input_type_ids' and
      'input_mask' (padded/truncated to _MAX_SEQ_LEN) plus
      'masked_lm_positions'. ``labels`` has 'masked_lm_labels' and
      'masked_lm_weights'. The masked_lm_* tensors are padded/truncated to
      _MAX_PREDICTIONS_PER_BATCH.
    """
    # Reserve two slots for [CLS] and [SEP]. Without this, a sequence at
    # _MAX_SEQ_LEN grows to _MAX_SEQ_LEN + 2 in combine_segments. The
    # pad_model_inputs truncation below would then drop the final [SEP] and
    # could leave masked positions pointing past the end of input_ids.
    segments = item['input_ids'][:, :_MAX_SEQ_LEN - 2]
    # The trimmer takes a list of segments.
    trimmed_segments = trimmer.trim([segments])

    # Combine segments, get segment ids and add special tokens.
    segments_combined, segment_ids = tf_text.combine_segments(
        trimmed_segments,
        start_of_sequence_id=_START_TOKEN,
        end_of_segment_id=_END_TOKEN)

    # Apply dynamic masking.
    masked_token_ids, masked_pos, masked_lm_ids = tf_text.mask_language_model(
        segments_combined,
        item_selector=random_selector,
        mask_values_chooser=mask_values_chooser)

    # Pad the token and segment inputs to _MAX_SEQ_LEN.
    input_word_ids, input_mask = tf_text.pad_model_inputs(
        masked_token_ids, max_seq_length=_MAX_SEQ_LEN)
    input_type_ids, _ = tf_text.pad_model_inputs(
        segment_ids, max_seq_length=_MAX_SEQ_LEN)

    # Pad the masking-task tensors. The returned mask doubles as per-position
    # loss weights, so padded prediction slots get weight 0.
    masked_lm_positions, masked_lm_weights = tf_text.pad_model_inputs(
        masked_pos, max_seq_length=_MAX_PREDICTIONS_PER_BATCH)
    masked_lm_ids, _ = tf_text.pad_model_inputs(
        masked_lm_ids, max_seq_length=_MAX_PREDICTIONS_PER_BATCH)

    inputs = {
        'input_ids': input_word_ids,
        'input_type_ids': input_type_ids,
        'input_mask': input_mask,
        'masked_lm_positions': masked_lm_positions,
    }
    labels = {
        'masked_lm_labels': masked_lm_ids,
        'masked_lm_weights': masked_lm_weights,
    }
    return (inputs, labels)

In [None]:
# Read dataset and check timings

import json
import glob

tfrecord_train_dir = '/home/sidhu/Datasets/PRETRAIN_DATA/TFRECORD'
# Close the schema file instead of leaking the handle from a bare open().
with open("{}/schema.json".format(tfrecord_train_dir)) as schema_fp:
    schema = json.load(schema_fp)
all_files = glob.glob("{}/*.tfrecord".format(tfrecord_train_dir))
# Fail loudly: an empty file list would otherwise give a silently empty dataset.
if not all_files:
    raise FileNotFoundError(
        "No *.tfrecord files found in {}".format(tfrecord_train_dir))
tf_reader = TFReader(schema=schema,
                     tfrecord_files=all_files)

x_keys = ['input_ids']

MAX_LEN = 128  # NOTE(review): not used by the pipeline below (it uses _MAX_SEQ_LEN)
batch_size = 1024
# auto_batch=False: batching is done later with dense_to_ragged_batch, after
# short sequences are filtered out.
train_dataset = tf_reader.read_record(auto_batch=False,
                                      keys=x_keys,
                                      batch_size=batch_size,
                                      x_keys=x_keys,
                                      shuffle=True,
                                      drop_remainder=True
                                      )



def filter_by_length(x):
    """Keep only examples with at least _MIN_SEN_LEN token ids.

    Args:
      x: dict whose 'input_ids' entry is one example's token ids.

    Returns:
      Scalar boolean tensor, as ``tf.data.Dataset.filter`` requires.
    """
    # tf.size gives a scalar element count directly. This avoids the
    # shape -> compare -> squeeze round-trip, which only worked for rank-1 input.
    return tf.size(x['input_ids']) >= _MIN_SEN_LEN
# Drop short examples, batch into ragged batches, apply dynamic masking, then
# print the first batch.
train_dataset = (
    train_dataset
    .filter(filter_by_length)
    .apply(tf.data.experimental.dense_to_ragged_batch(batch_size=batch_size))
    .map(map_mlm)
)

for batch_inputs, batch_labels in tqdm.tqdm(train_dataset):
    print(batch_inputs, batch_labels)
    break

In [None]:
# Now check the pipeline with the Hugging Face tokenizer, which loads the same
# SentencePiece model file.

from transformers import AlbertTokenizer

hf_vocab_file = '/home/sidhu/Datasets/PRETRAIN_DATA/vocab/bert-joint.model'
tokenizer_hf = AlbertTokenizer(vocab_file=hf_vocab_file)

file_paths = ['/home/sidhu/Datasets/wikipedia.txt']
dataset = tf.data.TextLineDataset(file_paths)


def filter_empty_string(line):
    """Return a boolean tensor that is True where ``line`` is non-empty."""
    line_length = tf.strings.length(line)
    return tf.math.not_equal(line_length, 0)

def normalize_string(line):
    """Strip leading and trailing whitespace from ``line``."""
    stripped = tf.strings.strip(line)
    return stripped

def hf_tokenize(item):
    """Tokenize one scalar string tensor with the global Hugging Face tokenizer.

    Intended to run eagerly inside ``tf.py_function``. No special tokens are
    added. The ids are returned wrapped in a one-element list, matching the
    single ``tf.int32`` output declared by the ``tf.py_function`` call in
    ``map_mlm_hf``.
    """
    text = item.numpy().decode()
    encoded = tokenizer_hf(text, add_special_tokens=False)
    return [encoded['input_ids']]

def map_mlm_hf(item):
    """Build one unpadded MLM example from a raw text line via the HF tokenizer.

    Args:
      item: scalar string tensor holding one line of text.

    Returns:
      Dict of 1-D tensors:
        - 'input_ids': masked, with [CLS]/[SEP], at most _MAX_SEQ_LEN long.
        - 'input_mask': all ones.
        - 'input_type_ids': all zeros.
        - 'masked_lm_labels' and 'masked_lm_positions'.
        - 'masked_lm_weights': all ones.
      Padding to fixed lengths is left to the batching step.
    """
    input_ids = tf.py_function(hf_tokenize, [item], tf.int32)
    # py_function outputs carry no static shape; declare a 1-D vector so the
    # slice/concat/ragged ops below see a known rank.
    input_ids.set_shape([None])
    # Reserve two positions for [CLS] and [SEP].
    input_ids = input_ids[:_MAX_SEQ_LEN - 2]
    input_ids = tf.concat([[_START_TOKEN], input_ids, [_END_TOKEN]], axis=0)
    # mask_language_model expects batched ragged input: wrap as a batch of one.
    segments_combined = tf.RaggedTensor.from_tensor([input_ids])

    # Apply dynamic masking.
    masked_token_ids, masked_pos, masked_lm_ids = tf_text.mask_language_model(
        segments_combined,
        item_selector=random_selector,
        mask_values_chooser=mask_values_chooser)

    # Remove the batch-of-one dimension again.
    masked_token_ids = tf.squeeze(masked_token_ids.to_tensor(), axis=0)
    masked_pos = tf.squeeze(masked_pos.to_tensor(), axis=0)
    masked_lm_ids = tf.squeeze(masked_lm_ids.to_tensor(), axis=0)

    # Mask/weights are all ones here; padding added at batch time supplies the zeros.
    model_inputs = {
        "input_ids": masked_token_ids,
        "input_mask": tf.ones_like(masked_token_ids),
        "input_type_ids": tf.zeros_like(masked_token_ids),
        "masked_lm_labels": masked_lm_ids,
        "masked_lm_positions": masked_pos,
        "masked_lm_weights": tf.ones_like(masked_pos),
    }
    return model_inputs

# auto_batch is not among the tf_transformers names imported earlier
# (TFWriter, TFReader, TFProcessor); without this import the cell raises
# NameError. NOTE(review): confirm the module path for your tf_transformers version.
from tf_transformers.data.utils import auto_batch

# Strip whitespace, then drop lines that are empty after stripping.
dataset = dataset.map(normalize_string)
dataset = dataset.filter(filter_empty_string)

train_dataset = dataset.map(map_mlm_hf)
batch_size = 1024
# Every feature is padded to a fixed length, so batches have static shapes.
padded_shapes = {
    "input_ids": [_MAX_SEQ_LEN],
    "input_mask": [_MAX_SEQ_LEN],
    "input_type_ids": [_MAX_SEQ_LEN],
    "masked_lm_labels": [_MAX_PREDICTIONS_PER_BATCH],
    "masked_lm_positions": [_MAX_PREDICTIONS_PER_BATCH],
    "masked_lm_weights": [_MAX_PREDICTIONS_PER_BATCH],
}
train_dataset = auto_batch(
    train_dataset,
    batch_size,
    padded_values=None,
    padded_shapes=padded_shapes,
    x_keys=['input_ids', 'input_type_ids', 'input_mask', 'masked_lm_positions'],
    y_keys=['masked_lm_labels', 'masked_lm_weights'],
    shuffle=False,
    drop_remainder=False,
    shuffle_buffer_size=100,
    prefetch_buffer_size=100,
)

for (batch_inputs, batch_labels) in train_dataset:
    print(batch_inputs, batch_labels)
    break

In [None]:
#### Build a dummy dataset and see whether it works
# Random tensors shaped like real MLM batches, to smoke-test the training loop
# without the data pipeline. NOTE: this replaces any earlier `train_dataset`.
# Shapes and vocab are tied to the MLM config constants so they cannot drift
# from the model; op-level seeds make the dummy data reproducible.

_NUM_DUMMY_EXAMPLES = 4000
_DUMMY_SEED = 42

input_ids = tf.random.uniform(minval=0, maxval=_VOCAB_SIZE,
                              shape=(_NUM_DUMMY_EXAMPLES, _MAX_SEQ_LEN),
                              dtype=tf.int32, seed=_DUMMY_SEED)
masked_lm_positions = tf.random.uniform(minval=0, maxval=_MAX_SEQ_LEN,
                                        shape=(_NUM_DUMMY_EXAMPLES, _MAX_PREDICTIONS_PER_BATCH),
                                        dtype=tf.int32, seed=_DUMMY_SEED + 1)
masked_lm_labels = tf.random.uniform(minval=0, maxval=_VOCAB_SIZE,
                                     shape=(_NUM_DUMMY_EXAMPLES, _MAX_PREDICTIONS_PER_BATCH),
                                     dtype=tf.int32, seed=_DUMMY_SEED + 2)
# Weights are 0/1 (maxval is exclusive).
masked_lm_weights = tf.random.uniform(minval=0, maxval=2,
                                      shape=(_NUM_DUMMY_EXAMPLES, _MAX_PREDICTIONS_PER_BATCH),
                                      dtype=tf.int32, seed=_DUMMY_SEED + 3)
dummy_batch_inputs = {"input_ids": input_ids,
                      "input_mask": tf.ones_like(input_ids),
                      "input_type_ids": tf.zeros_like(input_ids),
                      "masked_lm_positions": masked_lm_positions}

dummy_batch_labels = {"masked_lm_labels": masked_lm_labels,
                      "masked_lm_weights": masked_lm_weights}


batch_size = 1024
train_dataset = tf.data.Dataset.from_tensor_slices((dummy_batch_inputs, dummy_batch_labels)).batch(batch_size)

In [None]:
#### Modeling

from tf_transformers.losses import cross_entropy_loss
from tf_transformers.optimization import create_optimizer
from tf_transformers.core import Trainer, TPUTrainer

def get_model():
    """Build the 12-layer BERT model (via tf_transformers' BertModel) for MLM pretraining.

    Reads two globals:
      - ``batch_size``: a fixed batch size avoids dynamic shapes.
      - ``_VOCAB_SIZE``: the embedding table matches the tokenizer and the
        MLM mask-value chooser instead of a duplicated literal.

    Returns:
      The model built by ``BertModel.from_config`` with masked-LM positions
      enabled and all layer outputs returned.
    """
    config = {
        "attention_probs_dropout_prob": 0.1,
        "hidden_act": "gelu",
        "intermediate_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "embedding_size": 768,
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "max_position_embeddings": 512,
        "num_attention_heads": 12,
        "attention_head_size": 64,
        "num_hidden_layers": 12,
        "type_vocab_size": 2,
        "vocab_size": _VOCAB_SIZE,
        "layer_norm_epsilon": 1e-12,
        "mask_mode": "user_defined",
    }

    from tf_transformers.models import BertModel
    model = BertModel.from_config(config,
                                  batch_size=batch_size,  # fixed batch size avoids dynamic shapes
                                  use_masked_lm_positions=True,
                                  return_all_layer_outputs=True)
    return model
def get_optimizer(init_lr=2e-05, num_train_steps=100000, num_warmup_steps=10000):
    """Create the training optimizer with tf_transformers' ``create_optimizer``.

    The defaults reproduce the previously hard-coded values, so calling
    ``get_optimizer()`` with no arguments behaves as before.

    Args:
      init_lr: learning rate passed as ``init_lr``.
      num_train_steps: total steps of the learning-rate schedule.
      num_warmup_steps: warmup steps of the schedule.

    Returns:
      The optimizer. The learning-rate function returned alongside it is discarded.

    NOTE(review): trainer.run in this notebook uses epochs=2 and
    steps_per_epoch=1000. That appears to be far fewer steps than
    num_warmup_steps, so warmup may never complete. Confirm the schedule.
    """
    optimizer, _ = create_optimizer(init_lr=init_lr,
                                    num_train_steps=num_train_steps,
                                    num_warmup_steps=num_warmup_steps)
    return optimizer

def lm_loss(y_true_dict, y_pred_dict):
    """Masked-LM cross-entropy computed on every layer's token logits.

    Returns a dict containing:
      - 'loss_<k>': one entry per element of
        ``y_pred_dict['all_layer_token_logits']`` (k counts from 1, in order).
      - 'loss': the mean of those per-layer losses.
    """
    per_layer_losses = [
        cross_entropy_loss(labels=y_true_dict['masked_lm_labels'],
                           logits=layer_logits,
                           label_weights=y_true_dict['masked_lm_weights'])
        for layer_logits in y_pred_dict['all_layer_token_logits']
    ]
    loss_dict = {
        'loss_{}'.format(layer_number): layer_loss
        for layer_number, layer_loss in enumerate(per_layer_losses, start=1)
    }
    loss_dict['loss'] = tf.reduce_mean(per_layer_losses, axis=0)
    return loss_dict




# bfloat16 training on the local TPU.
tpu_address = 'local'
trainer = TPUTrainer(tpu_address=tpu_address, dtype='bf16')

GLOBAL_BATCH_SIZE = batch_size

# One tracked loss per layer: 'loss_1' ... 'loss_12', matching the keys that
# lm_loss produces for the 12 hidden layers configured in get_model.
training_loss_names = ['loss_{}'.format(layer) for layer in range(1, 13)]

run_kwargs = dict(
    model_fn=get_model,
    optimizer_fn=get_optimizer,
    train_dataset=train_dataset,
    train_loss_fn=lm_loss,
    epochs=2,
    steps_per_epoch=1000,
    model_checkpoint_dir="temp_dir",  # gs://tft_free/
    batch_size=GLOBAL_BATCH_SIZE,
    training_loss_names=training_loss_names,
    validation_loss_names=None,
    validation_dataset=None,
    validation_loss_fn=None,
    validation_interval_steps=None,
    steps_per_call=100,
    enable_xla=False,
    callbacks=None,
    callbacks_interval_steps=None,
    overwrite_checkpoint_dir=True,
    max_number_of_models=10,
    model_save_interval_steps=None,
    repeat_dataset=True,
)
trainer.run(**run_kwargs)