<a href="https://colab.research.google.com/github/legacyai/legacyai_notebooks/blob/master/tfrecord_utils_example_with_nlp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [22]:
# coding=utf-8
# Copyright 2020 The legacyai Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import six
import tensorflow as tf
import json
import sys
import collections
from absl import logging
logging.set_verbosity("INFO")

# The following functions can be used to convert a value to a type compatible
# with tf.Example.


def _bytes_feature(value):
    """Build a bytes-list `tf.train.Feature` from one or many strings.

    Args:
        value: a `bytes`/`str` scalar, a list/tuple of them, a numpy array of
            them, or a `tf.Tensor` whose `.numpy()` yields one of those.

    Returns:
        A `tf.train.Feature` whose `bytes_list` holds the encoded value(s).
    """
    if isinstance(value, tf.Tensor):
        # BytesList won't unpack a string from an EagerTensor.
        value = value.numpy()
    if hasattr(value, "tolist"):
        # numpy arrays / numpy scalars -> plain Python list / scalar, so a
        # multi-element tensor is not wrapped as a single (invalid) element.
        value = value.tolist()
    if isinstance(value, (list, tuple)):
        encoded = [six.ensure_binary(token) for token in value]
    else:
        # `ensure_binary` passes bytes through and utf-8 encodes text.
        encoded = [six.ensure_binary(value)]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded))


def _float_feature(value):
    """Build a float-list `tf.train.Feature` from one or many floats.

    Args:
        value: a float scalar, a list/tuple of floats, or a numpy
            array/scalar of floats.

    Returns:
        A `tf.train.Feature` whose `float_list` holds the value(s).
    """
    if hasattr(value, "tolist"):
        # numpy arrays / numpy scalars -> plain Python list / scalar.
        value = value.tolist()
    # Sequences are stored as-is; anything else is treated as one scalar.
    values = list(value) if isinstance(value, (list, tuple)) else [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def _int_feature(values):
    """Build an int64-list `tf.train.Feature` from one or many integers.

    Args:
        values: an int scalar, any iterable of ints, or a numpy
            array/scalar of ints.

    Returns:
        A `tf.train.Feature` whose `int64_list` holds the value(s).
    """
    if hasattr(values, "tolist"):
        # numpy integer scalars are not `int` instances and are not iterable;
        # `tolist()` turns them (and arrays) into plain Python values first.
        values = values.tolist()
    if isinstance(values, int):
        values = [values]
    return tf.train.Feature(
        int64_list=tf.train.Int64List(value=list(values)))


# Schema "kind" string -> tf.io feature-spec class.
# (Not referenced by the code visible in this file.)
TF_SCHEMA = {"var_len": tf.io.VarLenFeature,
             "fixed_len": tf.io.FixedLenFeature}

# Schema "type" string -> TensorFlow dtype used when building the parsing spec.
TF_VALUE  = {"bytes": tf.string, "int": tf.int64, "float": tf.float32}

# Schema "type" string -> helper that wraps a Python value in a tf.train.Feature
# when writing.
TF_FUNC   = {'bytes': _bytes_feature,
           'int': _int_feature, 'float': _float_feature}


class TFWriter(object):
    """Write dict-shaped examples to TFRecord files according to a schema.

    A schema maps each feature key to a spec: ``("var_len", <type>)`` or
    ``("fixed_len", <type>, <shape list>)`` where ``<type>`` is one of
    ``"bytes"``, ``"int"`` or ``"float"``. The schema is also dumped to a
    ``schema.json`` file next to the records.

    The writer can be used as a context manager; leaving the ``with`` block
    closes every underlying ``tf.io.TFRecordWriter``.
    """

    def __init__(self,
                 schema,
                 file_name,
                 model_dir=None,
                 tag='dev',
                 n_files=10,
                 overwrite=False,
                 verbose_counter=1000):
        """Validate the schema, open the record files and save the schema.

        Args:
            schema: dict - feature key -> spec (see class docstring).
            file_name: str - base name of the record files.
            model_dir: str - directory the records and `schema.json` are
                written to. If not given, paths are relative to the current
                working directory.
            tag: str - 'train' or 'dev'.
            n_files: int - if `tag` == 'train', records are spread over
                `n_files` files; for 'dev' a single file is always used.
            overwrite: bool - if False, an existing `model_dir` or record
                file raises instead of being reused.
            verbose_counter: int - log progress every `verbose_counter`
                records; a falsy value disables progress logging.

        Raises:
            AssertionError: a spec has the wrong length, type name or shape.
            ValueError: unknown tag, unknown spec kind/type, or
                `n_files` < 1 for 'train'.
            FileExistsError: `overwrite` is False and `model_dir` or one of
                the record files already exists.
        """
        # Everything that can fail on bad arguments runs before any
        # directory or file is created, so a bad call leaves no debris.
        self.is_schema_valid(schema)

        if tag not in ('train', 'dev'):
            logging.info("Unknown tag {} found".format(tag))
            raise ValueError("Unknwon Tag")

        if tag == 'train' and n_files < 1:
            # With no writers, `write_record` would fail later with an
            # unhelpful IndexError from `random.choice`.
            raise ValueError(
                "`n_files` must be >= 1 for tag 'train', got {}".format(n_files))

        self.schema = schema
        self.schema_writer_fn = self.generate_schema_from_dict(schema)
        self.verbose_counter = verbose_counter
        self.global_counter = 0

        # we need this file to write the schema to the model_dir
        schema_file_name = 'schema.json'
        if model_dir:
            if not overwrite and os.path.exists(model_dir):
                logging.info("Model directory {} exists".format(model_dir))
                raise FileExistsError(model_dir)
            os.makedirs(model_dir, exist_ok=True)
            file_name = os.path.join(model_dir, file_name)
            schema_file_name = os.path.join(model_dir, schema_file_name)

        if tag != 'train':
            n_files = 1
        self.all_files = self._build_file_names(file_name, tag, n_files)
        if not overwrite:
            self._raise_if_any_exists(self.all_files)

        self.all_writer = []
        try:
            for path in self.all_files:
                self.all_writer.append(tf.io.TFRecordWriter(path))
            with open(schema_file_name, "w") as f:
                json.dump(schema, f, indent=2)
        except Exception:
            # Do not leak the writers that were already opened.
            for opened_writer in self.all_writer:
                opened_writer.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close_sess()

    @staticmethod
    def _build_file_names(file_name, tag, n_files):
        """Return the record file paths for `tag`, one per output file."""
        if tag == 'train':
            # Strip 'tfrecord' from the base name only: applying the replace
            # to the whole path would also rewrite a directory component
            # (e.g. 'tfrecords_out/') and point at a directory that was
            # never created.
            base_name = os.path.basename(file_name)
            prefix = file_name[:len(file_name) - len(base_name)]
            stem = prefix + base_name.replace('tfrecord', '')
            return ['{}_{}_{}_{}.tfrecord'.format(stem, tag, i, n_files)
                    for i in range(n_files)]
        # NOTE(review): non-train files get no '.tfrecord' suffix, so a
        # '*.tfrecord' glob will not find them. Kept as-is because existing
        # users may rely on the current file names.
        return ['{}_{}_{}_{}'.format(file_name, i, tag, n_files)
                for i in range(n_files)]

    @staticmethod
    def _raise_if_any_exists(all_files):
        """Raise FileExistsError for the first path that already exists."""
        for file_ in all_files:
            if os.path.exists(file_):
                logging.info(
                    "File exists, overwrite is not recommended. If you want to overwrite, pass `overwrite`=True")
                raise FileExistsError(file_)

    @staticmethod
    def _num_elements(shape):
        """Return the number of values a feature of `shape` holds."""
        count = 1
        for dim in shape:
            count *= dim
        return count

    def is_schema_valid(self, schema):
        """Check the structure of every 'var_len' / 'fixed_len' spec.

        Specs of any other kind are not rejected here; they are caught by
        `generate_schema_from_dict`.

        Raises:
            AssertionError: raised explicitly (not via `assert`) so the
                check is still active under `python -O`.
        """
        for key, spec in schema.items():
            kind = spec[0]
            if kind == 'var_len':
                if len(spec) != 2:
                    raise AssertionError(
                        "`{}`: a var_len spec needs 2 entries, got {}".format(key, len(spec)))
                if spec[1] not in TF_VALUE:
                    raise AssertionError(
                        "`{}`: unknown type `{}`".format(key, spec[1]))
            elif kind == 'fixed_len':
                if len(spec) != 3:
                    raise AssertionError(
                        "`{}`: a fixed_len spec needs 3 entries, got {}".format(key, len(spec)))
                if spec[1] not in TF_VALUE:
                    raise AssertionError(
                        "`{}`: unknown type `{}`".format(key, spec[1]))
                if not isinstance(spec[2], list):
                    raise AssertionError(
                        "`{}`: fixed_len shape must be a list".format(key))

    def close_sess(self):
        """Close every underlying record writer."""
        for file_writer in self.all_writer:
            file_writer.close()
        logging.info("All writer objects closed")

    def generate_schema_from_dict(self, schema_dict):
        """Map each schema key to the function that builds its tf.train.Feature.

        Args:
            schema_dict: dict - feature key -> spec.

        Returns:
            dict - feature key -> feature-building function from `TF_FUNC`.

        Raises:
            ValueError: a spec kind or type name is not one of the allowed ones.
        """
        allowed_schema_types = ["var_len", "fixed_len"]
        allowed_schema_values = ["bytes", "int", "float"]

        for spec in schema_dict.values():
            schema_key, schema_value = spec[0], spec[1]
            if schema_key not in allowed_schema_types:
                raise ValueError("{} not in {}".format(
                    schema_key, allowed_schema_types))
            if schema_value not in allowed_schema_values:
                raise ValueError("{} not in {}".format(
                    schema_value, allowed_schema_values))

        return {key: TF_FUNC[spec[1]] for key, spec in schema_dict.items()}

    def write_record(self, input):
        """Serialize one example and append it to a randomly chosen file.

        Args:
            input: dict - feature key -> value to store. Every key must be
                present in the schema.

        Raises:
            ValueError: a 'fixed_len' value does not contain as many
                elements as its schema shape requires.
        """
        features = collections.OrderedDict()
        for key, value in input.items():
            spec = self.schema[key]
            if spec[0] == 'fixed_len' and spec[2] != []:
                # Values are stored flat, so the required length is the
                # product of all dimensions, not just the first one.
                expected = self._num_elements(spec[2])
                if len(value) != expected:
                    raise ValueError("`{}` has schema shape `{}`, but provided values `{}` has shape `{}`".format(
                        key, expected, value, len(value)))

            if isinstance(value, six.text_type):
                value = six.ensure_binary(value, "utf-8")
            features[key] = self.schema_writer_fn[key](value)
        example_proto = tf.train.Example(
            features=tf.train.Features(feature=features))

        # Picking the target file at random spreads records over the files.
        the_writer = random.choice(self.all_writer)
        the_writer.write(example_proto.SerializeToString())
        self.global_counter += 1

        # A falsy `verbose_counter` disables logging instead of dividing by zero.
        if self.verbose_counter and self.global_counter % self.verbose_counter == 0:
            logging.info("Wrote {} tfrecods".format(self.global_counter))


class TFReader(object):
    """Read TFRecord files back into a `tf.data.Dataset` according to a schema.

    The schema has the same form the writer uses: each feature key maps to
    ``("var_len", <type>)`` or ``("fixed_len", <type>, <shape list>)``.
    """

    def __init__(self, schema, tfrecord_files, keys=None):
        """Store the inputs and build the parsing spec.

        Args:
            schema: dict - feature key -> spec.
            tfrecord_files: list or tuple of record file paths.
            keys: optional iterable of feature keys to load. `None` or an
                empty collection loads every key in the schema.

        Raises:
            Exception: `tfrecord_files` is not a list or tuple.
            KeyError: a requested key is not in the schema.
            ValueError: a spec kind or type name is not one of the allowed ones.
        """
        if not isinstance(tfrecord_files, (list, tuple)):
            raise Exception("input must be a list or tuple of files")
        self.schema = schema
        self.tfrecord_files = tfrecord_files
        self.keys = self._resolve_keys(schema, keys)
        self.schema_reader_fn, self.schema_writer_fn = self.generate_schema_from_dict(
            schema)

    @staticmethod
    def _resolve_keys(schema, keys):
        """Return the list of keys to load, defaulting to every schema key."""
        if not keys:
            return list(schema)
        missing = [key for key in keys if key not in schema]
        if missing:
            # Fail here rather than with a bare KeyError while tf.data
            # traces the decode function.
            raise KeyError("keys {} not found in schema".format(missing))
        return list(keys)

    def generate_schema_from_dict(self, schema_dict):
        """Build the parsing spec and the feature-writer mapping.

        Args:
            schema_dict: dict - feature key -> spec.

        Returns:
            (reader_dict, writer_dict): `reader_dict` maps each selected key
            to a `tf.io.VarLenFeature` / `tf.io.FixedLenFeature`;
            `writer_dict` maps every schema key to its function from `TF_FUNC`.

        Raises:
            ValueError: a spec kind or type name is not one of the allowed ones.
        """
        allowed_schema_types = ["var_len", "fixed_len"]
        allowed_schema_values = ["bytes", "int", "float"]

        for spec in schema_dict.values():
            schema_key, schema_value = spec[0], spec[1]
            if schema_key not in allowed_schema_types:
                raise ValueError("{} not in {}".format(
                    schema_key, allowed_schema_types))
            if schema_value not in allowed_schema_values:
                raise ValueError("{} not in {}".format(
                    schema_value, allowed_schema_values))

        schema_reader_dict = {}
        for key, spec in schema_dict.items():
            if self.keys and key not in self.keys:
                continue
            dtype = TF_VALUE[spec[1]]
            if spec[0] == 'var_len':
                schema_reader_dict[key] = tf.io.VarLenFeature(dtype)
            else:
                # Fixed len should have shape mentioned in the schema
                schema_reader_dict[key] = tf.io.FixedLenFeature(
                    shape=spec[2], dtype=dtype, default_value=None)

        schema_writer_dict = {
            key: TF_FUNC[spec[1]] for key, spec in schema_dict.items()}
        return schema_reader_dict, schema_writer_dict

    def decode_record_var(self, record, keys=None):
        """Parse one serialized example into a dict of dense tensors.

        Args:
            record: a serialized `tf.train.Example`.
            keys: accepted for backward compatibility and not used; the keys
                given to the constructor decide what is parsed.

        Returns:
            dict - feature key -> tensor, with 'var_len' features converted
            from sparse to dense.
        """
        feature_dict = tf.io.parse_single_example(
            record, self.schema_reader_fn)

        parse_dict = feature_dict.copy()
        for k in self.keys:
            if self.schema[k][0] == 'var_len':
                parse_dict[k] = tf.sparse.to_dense(feature_dict[k])

        return parse_dict

    def read_record(self, keys=None, shuffle=True, cycle_length=8):
        """Return a dataset of decoded examples from all record files.

        Args:
            keys: forwarded to `decode_record_var`, which does not use it.
            shuffle: bool - whether the list of files is shuffled.
            cycle_length: int - number of files interleaved at once.

        Returns:
            A `tf.data.Dataset` whose elements are the dicts produced by
            `decode_record_var`.
        """
        dataset = tf.data.Dataset.list_files(
            self.tfrecord_files, shuffle=shuffle)
        dataset = dataset.interleave(
            tf.data.TFRecordDataset, cycle_length=cycle_length,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)

        def decode_fn(record):
            return self.decode_record_var(record, keys)

        return dataset.map(decode_fn)

In [2]:
!pip install transformers
!pip install nlp

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/48/35/ad2c5b1b8f99feaaf9d7cdadaeef261f098c6e1a6a2935d4d07662a6b780/transformers-2.11.0-py3-none-any.whl (674kB)
[K     |████████████████████████████████| 675kB 9.0MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 30.7MB/s 
Collecting tokenizers==0.7.0
[?25l  Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)
[K     |████████████████████████████████| 3.8MB 48.5MB/s 
Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |███

In [1]:
# Load tokenizer for GPT2
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




In [6]:
# Load CNN dailymail data and prepare it for LM task

[15496, 703, 389, 345, 1220]

In [5]:
from nlp import load_dataset
cnn_dailymail = load_dataset("cnn_dailymail", "3.0.0", split="train")

Downloading and preparing dataset cnn_dailymail/3.0.0 (download: 558.32 MiB, generated: 1.26 GiB, total: 1.81 GiB) to /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=572061.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=12259516.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=660943.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset cnn_dailymail downloaded and prepared to /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0. Subsequent calls will reuse this data.


In [6]:
for item in cnn_dailymail:
  print(item)
  break

{'article': 'It\'s official: U.S. President Barack Obama wants lawmakers to weigh in on whether to use military force in Syria. Obama sent a letter to the heads of the House and Senate on Saturday night, hours after announcing that he believes military action against Syrian targets is the right step to take over the alleged use of chemical weapons. The proposed legislation from Obama asks Congress to approve the use of military force "to deter, disrupt, prevent and degrade the potential for future uses of chemical weapons or other weapons of mass destruction." It\'s a step that is set to turn an international crisis into a fierce domestic political battle. There are key questions looming over the debate: What did U.N. weapons inspectors find in Syria? What happens if Congress votes no? And how will the Syrian government react? In a televised address from the White House Rose Garden earlier Saturday, the president said he would take his case to Congress, not because he has to -- but bec

In [27]:
# TFRecords are of 2 types 
# tf.io.VarLenFeature https://www.tensorflow.org/api_docs/python/tf/io/VarLenFeature
# tf.io.FixedLenFeature https://www.tensorflow.org/api_docs/python/tf/io/FixedLenFeature

# Reserved Keywords

# `var_len` - tf.io.VarLenFeature 
# `fixed_len` -  tf.io.FixedLenFeature
# `int` - tf.int64
# `bytes` - tf.string
# `float` - tf.float32      

# A schema maps each key (key of the json data returned by the model) to a tuple:
#   ("var_len", <type>)  or  ("fixed_len", <type>, [<shape>])

tf_schema = {
    "input_word_ids": ("var_len", "int"),  # int here represents tf.int64
    "labels": ("var_len", "int"),
    # 102 is the fixed size of this feature: tf.io.FixedLenFeature requires
    # fixed length features.
    "dummy_value": ("fixed_len", "float", [102]),
}


model_dir = 'cnn_dailymail_tf'  # model_dir
file_name = 'cnn_dailymail_tf_record'  # tfrecord filenames which will be saved inside model_dir
# `train` or `dev`. With `train`, records are spread at random over `n_files`
# files (10 by default). With `dev`, everything is written to a single file.
tag = 'train'
n_files = 20  # override the default of 10

# Pass the `tag` variable defined above instead of repeating the literal, so
# changing it in one place is enough.
tf_writer = TFWriter(tf_schema,
                     model_dir=model_dir,
                     file_name=file_name,
                     tag=tag,
                     n_files=n_files,
                     overwrite=False)

logging.info("TFWriter initiated")

INFO:absl:TFWriter initiated


In [28]:
def process(text_list, max_len=1000):
    """Yield one language-modelling record per text in `text_list`.

    Each text is tokenized with the notebook-level `tokenizer` and cut to at
    most `max_len` tokens. The ids without the last one become
    `input_word_ids`, the ids without the first one become `labels`, and
    `dummy_value` is a list of 102 uniform random floats.
    """
    for article in text_list:
        truncated_tokens = tokenizer.tokenize(article)[:max_len]
        token_ids = tokenizer.convert_tokens_to_ids(truncated_tokens)[:max_len]
        # Next-token targets: labels are the inputs shifted left by one.
        shifted_labels = token_ids[1:]
        model_inputs = token_ids[:-1]
        random_floats = tf.random.uniform(shape=(102,)).numpy().tolist()
        yield {
            "input_word_ids": model_inputs,
            "labels": shifted_labels,
            "dummy_value": random_floats,
        }

import time
batch_size = 1000

# Stop when number of data is 10000
max_records = 10000

try:
    # Bounding the range (rather than breaking after a batch has already
    # been written) keeps the total at `max_records`.
    for i in range(0, min(len(cnn_dailymail), max_records), batch_size):
        batch = cnn_dailymail[i: min(i + batch_size, max_records)]
        batch = batch['article']
        batch_generator = process(batch)
        start_time = time.time()
        n_written = 0
        for record in batch_generator:
            tf_writer.write_record(record)
            n_written += 1
        end_time = time.time()
        logging.info("CNN daily email, {} data written in {} seconds".format(
            n_written, end_time - start_time))
finally:
    # Close the writers so the record files are complete on disk before
    # anything tries to read them back.
    tf_writer.close_sess()

INFO:absl:Wrote 1000 tfrecods
INFO:absl:CNN daily email, 1000 data written in 3.3962457180023193 seconds
INFO:absl:Wrote 2000 tfrecods
INFO:absl:CNN daily email, 1000 data written in 3.355628252029419 seconds
INFO:absl:Wrote 3000 tfrecods
INFO:absl:CNN daily email, 1000 data written in 3.458329439163208 seconds
INFO:absl:Wrote 4000 tfrecods
INFO:absl:CNN daily email, 1000 data written in 3.4665002822875977 seconds
INFO:absl:Wrote 5000 tfrecods
INFO:absl:CNN daily email, 1000 data written in 3.4179909229278564 seconds
INFO:absl:Wrote 6000 tfrecods
INFO:absl:CNN daily email, 1000 data written in 3.3905141353607178 seconds
INFO:absl:Wrote 7000 tfrecods
INFO:absl:CNN daily email, 1000 data written in 3.4685847759246826 seconds
INFO:absl:Wrote 8000 tfrecods
INFO:absl:CNN daily email, 1000 data written in 3.38470721244812 seconds
INFO:absl:Wrote 9000 tfrecods
INFO:absl:CNN daily email, 1000 data written in 3.358107328414917 seconds
INFO:absl:Wrote 10000 tfrecods
INFO:absl:CNN daily email, 10

In [29]:
# We finished writing 10000 records in about 30 seconds

In [31]:
# How to read it back.

# The schema is saved as a JSON file in the folder the TFRecords were written to.
# We can either load it from there or define it manually. Make sure it is the
# same schema that was used for writing.

import json
import glob

# `with` closes the schema file instead of leaving the handle to the GC.
with open("cnn_dailymail_tf/schema.json") as schema_file:
    tf_schema = json.load(schema_file)
all_files = glob.glob("cnn_dailymail_tf/*.tfrecord")
if not all_files:
    raise FileNotFoundError("No .tfrecord files found in cnn_dailymail_tf/")

# Load data

# In the above case we do not want `dummy_value` while loading tf records. So we can define the keys we need.

keys_required = ['input_word_ids', 'labels']
# If keys required is empty, we will load all the keys
tf_reader = TFReader(tf_schema, all_files, keys=keys_required)
dataset = tf_reader.read_record()
# NOTE(review): `ignore_errors` silently drops elements that raise, which can
# also hide truncated record files — confirm it is still needed.
dataset = dataset.apply(tf.data.experimental.ignore_errors())

In [32]:
# Look at sample data 

for item in dataset:
  print(item)
  break

{'input_word_ids': <tf.Tensor: shape=(707,), dtype=int64, numpy=
array([    7, 18474,     8,  1377, 30405,   319,  3909, 11468,   262,
        1989,   287, 46065, 37499,  8545,   810,  3434,  1364,  3598,
         471,    13,    45,    13,  4167, 24952,   290,  3624, 10380,
        2636,    11,  1864,   284,   257,   471,    13,    45,    13,
        1743,    13,  1881,  1368,  5091,  2739,  3635,   290,   656,
        3217,  1474,  2547,    64, 14812,    11,   407,  1290,   422,
         262,  7421,    12, 31463,  5510,  3277,   338,  4865,   351,
       33208,    11,  1864,   284,   262,  1578,  7973,    13,  5524,
        8353,  5745,  2098,  3909,   484,   547, 12451,   546,   604,
          11,   830,   661,   287, 11144,    11,   531,  3982,    72,
         360,   454, 26487,    11,   257,  6523,   329,   262,   471,
          13,    45,    13,  4452,   329,   262, 22819,  1883,   286,
        5524,  8353, 10665,    13, 12168,  3470,   550,  5284,   416,
       47168,  3909,   28

In [33]:
# So we have loaded only necessary keys
# Lets pad and batch it
PAD_TOKEN = tf.constant(0, tf.int64)


# Separate inputs and labels to dict
def map_to_dict(item):
    """Split a parsed example into an `(inputs, labels)` pair of dicts.

    `inputs` keeps only the 'input_word_ids' entry and `labels` keeps only
    the 'labels' entry; every other key is dropped.
    """
    input_keys = ('input_word_ids',)
    label_keys = ('labels',)
    inputs = {name: tensor for name, tensor in item.items()
              if name in input_keys}
    labels = {name: tensor for name, tensor in item.items()
              if name in label_keys}
    return inputs, labels

# Group 5 examples per batch; both features are padded with PAD_TOKEN (0).
# No `padded_shapes` is passed, so the padded length is left to tf.data
# (presumably the longest sequence in each batch — see the padded_batch docs).
batch_size = 5
dataset = dataset.padded_batch(batch_size=batch_size,
                               padding_values={'input_word_ids': PAD_TOKEN, 
                                              'labels': PAD_TOKEN
                               })
# Turn each batched dict into an (inputs, labels) tuple.
dataset = dataset.map(map_to_dict, num_parallel_calls=tf.data.experimental.AUTOTUNE)


In [34]:
for item in dataset:
  print(item)
  break

({'input_word_ids': <tf.Tensor: shape=(5, 999), dtype=int64, numpy=
array([[   48,  8107,    11, ...,    11,   257,  5313],
       [ 1026,   338,  1743, ...,  3275,   284,   262],
       [48790,  2254,   357, ...,     0,     0,     0],
       [33363,  4146, 11357, ...,  4068,    11,  3908],
       [17402,   357, 18474, ...,     0,     0,     0]])>}, {'labels': <tf.Tensor: shape=(5, 999), dtype=int64, numpy=
array([[ 8107,    11,  6365, ...,   257,  5313,    12],
       [  338,  1743,    25, ...,   284,   262, 13032],
       [ 2254,   357, 18474, ...,     0,     0,     0],
       [ 4146, 11357,    11, ...,    11,  3908,    11],
       [  357, 18474,     8, ...,     0,     0,     0]])>})


In [None]:
# That's it, we have padded and batched it
# We are planning to support bucketing also in coming days
# Stay Tuned :-)