In [2]:
import tensorflow as tf
import tensorflow_text as tf_text
import tqdm, tempfile, glob, os    
from tf_transformers.data import TFWriter, TFReader, TFProcessor
from transformers import AlbertTokenizer

In [None]:
from datasets import load_dataset
dataset = load_dataset("mc4", "en")


In [24]:
def get_tf_text_tokenizer():
    temp_dir = tempfile.mkdtemp()
    tokenizer_hf = AlbertTokenizer.from_pretrained("albert-base-v2")
    tokenizer_hf.save_pretrained(temp_dir)
    
    dtype = tf.int32
    nbest_size = 0
    alpha = 1.0

    def _create_tokenizer(model_serialized_proto, dtype, nbest_size, alpha):
        return tf_text.SentencepieceTokenizer(
            model=model_serialized_proto,
            out_type=dtype,
            nbest_size=nbest_size,
            alpha=alpha)

    model_file_path = '{}/spiece.model'.format(temp_dir)
    model_serialized_proto = tf.io.gfile.GFile(model_file_path,
                                                           "rb").read()

    tokenizer_sp = _create_tokenizer(model_serialized_proto, 
                                 dtype,
                                 nbest_size,
                                 alpha)
    
    return tokenizer_sp



def batch(iterable, n):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]
        
        
def text_normalize(line):
    """Exclude empty string"""
    line = tf.strings.strip(line)
    return tf.not_equal(tf.strings.length(line),0)

In [None]:
INPUT_DIR = "/home/sidhu/Datasets/C4_newsalike_text"
BATCH_SIZE = 50
DATA_BATCH_SIZE = 1024
tokenizer_sp = get_tf_text_tokenizer()


all_files = glob.glob("{}/*.txt".format(INPUT_DIR))
schema = {
    "input_ids": ("var_len", "int"),
}

tfrecord_train_dir = '/home/sidhu/Datasets/TFRECORD_C4'
tfrecord_filename = 'c4'
tfwriter = TFWriter(schema=schema, 
                    file_name=tfrecord_filename, 
                    model_dir=tfrecord_train_dir,
                    tag='train',
                    n_files=100,
                    overwrite=True
                    )

for (batch_files) in tqdm.tqdm(batch(all_files, BATCH_SIZE)):
    
    dataset = tf.data.TextLineDataset(batch_files)
    dataset = dataset.filter(text_normalize)
    dataset = dataset.batch(DATA_BATCH_SIZE, drop_remainder=False)

    def parse_train():
        for batch_input in tqdm.tqdm(dataset):
            batch_tokenized = tokenizer_sp.tokenize(batch_input).merge_dims(-1,1).to_list()
            for input_ids in batch_tokenized:

                yield {"input_ids": input_ids}
    # Process
    tfwriter.process(parse_fn=parse_train())
    
    # After that delete files to save memory
    
    #for _file in batch_files:
    #    os.remove(_file)
        
# 12462160

In [None]:
12462160

45504648

74004228

In [None]:
tfrecord_train_dir = '/home/sidhu/Datasets/TFRECORD_WIKI/'


schema = json.load(open("{}/schema.json".format(tfrecord_train_dir)))
all_files = glob.glob("{}/*.tfrecord".format(tfrecord_train_dir))
tf_reader = TFReader(schema=schema, 
                    tfrecord_files=all_files)

x_keys = ['input_ids']
# padded_shapes = {'input_ids': [MAX_LEN], 
#                  'labels': [MAX_LEN], 
#                  'labels_mask': [MAX_LEN]}
train_dataset = tf_reader.read_record(
                                  )

all_seq_len = []
for item in tqdm.tqdimport m(train_dataset):
    all_seq_len.append(item['input_ids'].shape[0])
    
from collections import Counter
counter = Counter(all_seq_len)

In [28]:
tokenizer_hf.mask_token

Using mask_token, but it is not set yet.


In [None]:
# Read Wikipedia and Write to text as before

In [None]:
dataset = tf.data.TextLineDataset(all_files)
dataset = dataset.filter(text_normalize)
dataset = dataset.batch(DATA_BATCH_SIZE, drop_remainder=False)

def parse_train():
    for batch_input in tqdm.tqdm(dataset):
        batch_tokenized = tokenizer_sp.tokenize(batch_input).merge_dims(-1,1).to_list()
        for input_ids in batch_tokenized:

            yield {"input_ids": input_ids}
# Process
tfwriter.process(parse_fn=parse_train())


In [29]:
text = '''CC-News dataset contains news articles from news sites all over the world. The data is available on AWS S3 in the Common Crawl bucket at /crawl-data/CC-NEWS/. This version of the dataset has been prepared using news-please - an integrated web crawler and information extractor for news.
It contains 708241 English language news articles published between Jan 2017 and December 2019. It represents a small portion of the English language subset of the CC-News dataset.

'''

In [3]:
text_breaker = tf_text.StateBasedSentenceBreaker()

In [None]:
text_breaker.break_sentences(text)