# Convert text into tfrecords

In [0]:
import collections
import re
import unicodedata
import six
import tensorflow as tf
import numpy as np
import pandas as pd
import os
import csv

Download the `train.csv` file from this link: https://drive.google.com/open?id=1iLyI3BgjO8H5tf96WI__C1MmOrujzAtE

In [0]:
# Raw comments; the next cell reads its 'id' and 'comment_text' columns.
toxic = pd.read_csv(filepath_or_buffer='train.csv')

#### Clean the comments and save them to `train.tsv` (written with `DataFrame.to_csv`, so the file is actually comma-separated)

In [0]:
# Pair each raw comment in `toxic` with the prediction at the same row of
# `toxic_label`, remove line breaks from the comment, and collect
# {'comment_text', 'target'} rows for the training file.
# NOTE(review): `toxic_label` is not created in any visible cell -- load it
# (with 'id' and 'prediction' columns) before running this one.
train_examples = []

for i in range(97320):  # hard-coded row count; presumably len(toxic_label) -- confirm
  text_id = toxic['id'][i]
  text = toxic['comment_text'][i]
  label_id = toxic_label['id'][i]
  label = toxic_label['prediction'][i]

  # Keep the row only if the ids line up. This is a regex/substring search of
  # the comment id inside the label id, not strict equality; presumably chosen
  # to tolerate formatting differences such as '123' vs '123.0' -- confirm.
  if re.search(str(text_id), str(label_id)):
    example = collections.OrderedDict()
    # Remove line breaks so each comment stays on a single CSV row. The pieces
    # are joined with a space: concatenating them directly glued together the
    # words on either side of every newline ("end\nstart" -> "endstart").
    example['comment_text'] = ' '.join(part for part in text.split('\n') if part != '')
    example['target'] = float(label)
    train_examples.append(example)
      


In [0]:
# Notebook cell: display the collected rows (a no-op when run as a script).
train_examples

In [0]:
# One row per kept comment, with columns 'comment_text' and 'target'.
train = pd.DataFrame(data=train_examples)

In [0]:
# NOTE: to_csv's default separator is ',', so despite the .tsv name this file
# is comma-separated, with a header row and the index as the first column;
# read_tsv below parses it in that format.
train.to_csv(path_or_buf='train.tsv')

### **Utility functions for preprocessing the data before converting it to TFRecords**


In [0]:
def convert_to_unicode(text):
  """Return `text` as a Unicode (text) string, decoding UTF-8 bytes if needed.

  Bytes that are not valid UTF-8 are dropped (the 'ignore' error handler).

  Args:
    text: a text string or a UTF-8 encoded byte string.

  Returns:
    The text as the interpreter's Unicode string type.

  Raises:
    ValueError: if `text` is neither a text nor a byte string, or the
      interpreter is neither Python 2 nor Python 3.
  """
  if six.PY3:
    if isinstance(text, str):
      return text
    elif isinstance(text, bytes):
      return text.decode('utf-8', 'ignore')
    else:
      raise ValueError('Unsupported string type: %s' % (type(text),))
  elif six.PY2:
    # On Python 2 `str` is the byte type and `unicode` the text type, so byte
    # strings are decoded and unicode strings are returned unchanged.
    if isinstance(text, str):
      return text.decode('utf-8', 'ignore')
    elif isinstance(text, unicode):  # noqa: F821 -- Python 2 builtin
      return text
    else:
      raise ValueError('Unsupported string type: %s' % (type(text),))
  else:
    raise ValueError('Not running on Python2 or Python3?')
      
  
  
def printable_text(text):
  """Return `text` in the interpreter's native `str` type for printing/logging.

  On Python 3 bytes are decoded as UTF-8 (dropping undecodable bytes); on
  Python 2 unicode strings are encoded to UTF-8 byte strings.

  Raises:
    ValueError: for non-string input or an unrecognised interpreter.
  """
  if six.PY3:
    if isinstance(text, bytes):
      return text.decode('utf-8', 'ignore')
    if isinstance(text, str):
      return text
    raise ValueError('Unsupported string type: ', type(text))

  if six.PY2:
    if isinstance(text, unicode):
      return text.encode("utf-8")
    if isinstance(text, str):
      return text
    raise ValueError("Unsupported string type: %s" % (type(text)))

  raise ValueError("Not running on Python2 or Python 3?")
      
def load_vocab(vocab_file):
  """Read a vocabulary file into an ordered token -> id mapping.

  Each line holds one token (surrounding whitespace stripped); its id is its
  0-based line number. A token that appears more than once keeps the id of
  its last occurrence.
  """
  token_to_id = collections.OrderedDict()
  next_id = 0
  with tf.gfile.GFile(vocab_file, 'r') as vocab_reader:
    line = convert_to_unicode(vocab_reader.readline())
    while line:
      token_to_id[line.strip()] = next_id
      next_id += 1
      line = convert_to_unicode(vocab_reader.readline())
  return token_to_id


def convert_by_vocab(vocab, items):
  """Look up every element of `items` in `vocab` and return the results as a list.

  Raises:
    KeyError: if an element is missing from `vocab`.
  """
  return [vocab[item] for item in items]


def convert_tokens_to_ids(vocab, tokens):
  """Map each token to its id via `vocab` (KeyError for unknown tokens)."""
  return convert_by_vocab(vocab=vocab, items=tokens)

def convert_ids_to_tokens(inv_vocab, ids):
  """Map each id back to its token via `inv_vocab` (KeyError for unknown ids)."""
  return convert_by_vocab(vocab=inv_vocab, items=ids)

def whitespace_tokenize(text):
  """Split `text` on runs of whitespace; empty or blank text gives []."""
  stripped = text.strip()
  return stripped.split() if stripped else []


class FullTokenizer(object):
  """End-to-end tokenizer: BasicTokenizer splitting followed by WordPiece.

  Attributes:
    vocab: ordered token -> id mapping loaded from the vocabulary file.
    inv_vocab: id -> token mapping (the inverse of `vocab`).
    basic_tokenizer: BasicTokenizer configured with `do_lower_case`.
    wordpiece_tokenizer: WordpieceTokenizer sharing `vocab`.
  """

  def __init__(self, vocab_file, do_lower_case=True):
    self.vocab = load_vocab(vocab_file)
    self.inv_vocab = dict((idx, tok) for tok, idx in self.vocab.items())
    self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
    self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

  def tokenize(self, text):
    """Return the flat list of WordPiece tokens for every basic token of `text`."""
    return [piece
            for word in self.basic_tokenizer.tokenize(text)
            for piece in self.wordpiece_tokenizer.tokenize(word)]

  def convert_tokens_to_ids(self, tokens):
    """Map tokens to ids with this tokenizer's vocabulary."""
    return convert_by_vocab(vocab=self.vocab, items=tokens)

  def convert_ids_to_tokens(self, ids):
    """Map ids back to tokens with this tokenizer's inverse vocabulary."""
    return convert_by_vocab(vocab=self.inv_vocab, items=ids)
  
  
class BasicTokenizer(object):
  """Splits text on whitespace and punctuation, with optional lower-casing.

  `tokenize` drops NUL/U+FFFD/control characters and maps whitespace to
  spaces, pads CJK ideographs with spaces, splits on whitespace, optionally
  lower-cases and strips combining marks, and finally splits each piece so
  that every punctuation character becomes its own token.
  """

  def __init__(self, do_lower_case=True):
    # When True, tokens are lower-cased and have Mn combining marks removed.
    self.do_lower_case = do_lower_case

  def tokenize(self, text):
    """Return the list of basic tokens for `text`."""
    normalized = self._tokenize_chinese_chars(
        self._clean_text(convert_to_unicode(text)))

    pieces = []
    for word in whitespace_tokenize(normalized):
      if self.do_lower_case:
        word = self._run_strip_accents(word.lower())
      pieces.extend(self._run_split_on_punc(word))

    # Re-split so that empty pieces disappear.
    return whitespace_tokenize(" ".join(pieces))

  def _run_strip_accents(self, text):
    """Remove non-spacing marks (category Mn) after NFD decomposition."""
    decomposed = unicodedata.normalize("NFD", text)
    return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")

  def _run_split_on_punc(self, text):
    """Split `text` so every punctuation character is a separate piece."""
    groups = []
    begin_group = True
    for ch in text:
      if _is_punctuation(ch):
        groups.append([ch])
        begin_group = True
        continue
      if begin_group:
        groups.append([])
        begin_group = False
      groups[-1].append(ch)
    return ["".join(group) for group in groups]

  def _tokenize_chinese_chars(self, text):
    """Surround every CJK ideograph with a space on each side."""
    return "".join(
        " " + ch + " " if self._is_chinese_char(ord(ch)) else ch
        for ch in text)

  def _is_chinese_char(self, cp):
    """Return True if code point `cp` lies in one of the CJK ideograph blocks.

    Covers CJK Unified Ideographs, Extensions A-E, CJK Compatibility
    Ideographs and the Compatibility Ideographs Supplement.
    """
    cjk_ranges = (
        (0x4E00, 0x9FFF),
        (0x3400, 0x4DBF),
        (0x20000, 0x2A6DF),
        (0x2A700, 0x2B73F),
        (0x2B740, 0x2B81F),
        (0x2B820, 0x2CEAF),
        (0xF900, 0xFAFF),
        (0x2F800, 0x2FA1F),
    )
    return any(low <= cp <= high for low, high in cjk_ranges)

  def _clean_text(self, text):
    """Drop NUL, U+FFFD and control characters; turn whitespace into ' '."""
    kept = []
    for ch in text:
      code = ord(ch)
      if code == 0 or code == 0xFFFD or _is_control(ch):
        continue
      kept.append(" " if _is_whitespace(ch) else ch)
    return "".join(kept)
  
  
class WordpieceTokenizer(object):
  """Greedy longest-match-first WordPiece tokenizer over a fixed vocabulary.

  Pieces after the first in a word are looked up with a '##' prefix. A word
  that is longer than `max_input_chars_per_word`, or that cannot be fully
  covered by vocabulary pieces, becomes a single `unk_token`.
  """

  def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
    self.vocab = vocab
    self.unk_token = unk_token
    self.max_input_chars_per_word = max_input_chars_per_word

  def tokenize(self, text):
    """Return the WordPiece tokens for every whitespace-separated word of `text`."""
    result = []
    for word in whitespace_tokenize(convert_to_unicode(text)):
      n_chars = len(word)
      if n_chars > self.max_input_chars_per_word:
        result.append(self.unk_token)
        continue

      pieces = []
      pos = 0
      while pos < n_chars:
        # Shrink the candidate from the right until it is in the vocabulary.
        match = None
        stop = n_chars
        while stop > pos:
          candidate = word[pos:stop]
          if pos > 0:
            candidate = "##" + candidate
          if candidate in self.vocab:
            match = candidate
            break
          stop -= 1
        if match is None:
          pieces = None  # word cannot be covered by vocabulary pieces
          break
        pieces.append(match)
        pos = stop

      if pieces is None:
        result.append(self.unk_token)
      else:
        result.extend(pieces)
    return result
  
  
  
def _is_whitespace(char):
  if char == " " or char == "\t" or char == "\n" or char == "\r":
    return True
  cat = unicodedata.category(char)
  if cat == "Zs":
    return True
  return False

def _is_control(char):
  
  if char == "\t" or char == "\n" or char == "\r":
    return False
  cat = unicodedata.category(char)
  if cat in ("Cc", "Cf"):
    return True
  return False

def _is_punctuation(char):
  
  cp = ord(char)
  
  if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
      (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
    return True
  cat = unicodedata.category(char)
  if cat.startswith("P"):
    return True
  return False
    

In [0]:
class InputFeatures(object):
  """Plain container for the features of one example.

  Attributes:
    input_ids: list of token ids.
    input_mask: list marking real tokens (1) vs padding (0), as built by
      convert_single_example.
    segment_ids: list of segment ids, one per position.
    label_id: integer id of the example's label.
    is_real_example: boolean flag stored as given (defaults to True).
  """

  def __init__(self, input_ids, input_mask, segment_ids, label_id, is_real_example=True):
    # Values are stored unchanged; no validation or copying happens here.
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.label_id = label_id
    self.is_real_example = is_real_example

In [0]:
class InputExample(object):
  """One raw text example.

  Attributes:
    guid: unique identifier string for the example.
    text_a: the first (or only) text.
    text_b: optional second text for sentence-pair inputs; None if absent.
    label: label string, or None.
  """

  def __init__(self, guid, text_a, text_b = None, label = None):
    # Stored as-is; tokenization happens later in convert_single_example.
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label

In [0]:
def read_tsv(input_file, quotechar=None):
  """Read the cleaned comment file written by ``train.to_csv('train.tsv')``.

  Despite its ``.tsv`` name, that file is produced by ``DataFrame.to_csv``
  with default settings. It is comma-separated, with a header row
  (``,comment_text,target``) and the DataFrame index as the first column.
  Fields containing commas or quotes are wrapped in ``"``. Parsing it with
  the ``csv`` module in that dialect keeps commas and quote characters inside
  comments intact. Splitting each line on commas by hand instead would delete
  every comma from the comment text and leave the CSV quote characters in it.

  Args:
    input_file: path readable by ``tf.gfile``.
    quotechar: CSV quote character; ``None`` (the default) means ``'"'``,
      which is what ``to_csv`` writes.

  Returns:
    A list of ``[comment_text, target]`` pairs (``target`` as ``float``),
    one per data row, header excluded.

  Raises:
    ValueError: if a row has fewer than three fields or a non-numeric target.
  """
  if quotechar is None:
    quotechar = '"'

  lines = []
  with tf.gfile.Open(input_file, 'r') as f:
    reader = csv.reader(f, delimiter=',', quotechar=quotechar)
    next(reader, None)  # skip the ",comment_text,target" header row
    for row in reader:
      if not row:
        continue  # blank line: not an example
      if len(row) < 3:
        raise ValueError(
            'Malformed row, expected index, comment_text, target: %r' % (row,))
      # The first field is the index and the last is the target. Everything in
      # between is the comment; it spans several fields only if a comma was
      # written unquoted, so re-join with the commas it was split on.
      lines.append([','.join(row[1:-1]), float(row[-1])])
  return lines
        
        

In [0]:
def create_examples(lines, set_type):
  """Wrap ``[comment_text, score]`` rows as single-text InputExamples.

  The label is 'toxic' when the score is >= 0.5, otherwise 'non-toxic'.
  Each guid is '<set_type>-<row index>'.
  """
  return [
      InputExample(
          guid='%s-%s' % (set_type, row_index),
          text_a=convert_to_unicode(row[0]),
          text_b=None,
          label=convert_to_unicode('toxic' if row[1] >= 0.5 else 'non-toxic'))
      for row_index, row in enumerate(lines)
  ]

#### Truncate token sequences so they fit within `max_seq_length` (512 here); shorter sequences are zero-padded up to that length later, in `convert_single_example`

In [0]:
def truncate_seq_pair(tokens_a, tokens_b, max_length):
  """Truncate a pair of token lists in place to a combined length <= max_length.

  Tokens are popped one at a time from the end of whichever list is currently
  longer (from `tokens_b` on ties). This keeps the two lengths balanced
  instead of cutting a fixed share from each.

  Args:
    tokens_a: first token list; modified in place.
    tokens_b: second token list; modified in place.
    max_length: maximum allowed len(tokens_a) + len(tokens_b).
  """
  # The `(tokens_a or tokens_b)` guard stops once both lists are empty, so a
  # negative max_length cannot make us pop from an empty list.
  while len(tokens_a) + len(tokens_b) > max_length and (tokens_a or tokens_b):
    if len(tokens_a) > len(tokens_b):
      tokens_a.pop()
    else:
      tokens_b.pop()

#### Convert each example's text into token IDs, an input mask and segment IDs

In [0]:
def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer):
  """Convert one InputExample into fixed-length model input features.

  Builds ``[CLS] tokens_a [SEP]``, followed by ``tokens_b [SEP]`` when
  ``example.text_b`` yields tokens. It truncates so the sequence fits in
  ``max_seq_length``, then zero-pads ``input_ids``, ``input_mask`` and
  ``segment_ids`` to exactly ``max_seq_length``.

  Args:
    ex_index: position of the example; the first two examples are logged.
    example: InputExample providing ``text_a``, ``text_b``, ``label`` and ``guid``.
    label_list: label strings; a label's id is its index in this list.
    max_seq_length: fixed output length, including the special tokens.
    tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.

  Returns:
    InputFeatures for the example (``is_real_example=True``).

  Raises:
    KeyError: if ``example.label`` is not in ``label_list``.
  """
  label_map = {label: i for (i, label) in enumerate(label_list)}

  tokens_a = tokenizer.tokenize(example.text_a)
  tokens_b = None
  if example.text_b:
    tokens_b = tokenizer.tokenize(example.text_b)

  if tokens_b:
    # Reserve room for [CLS], [SEP] and [SEP]. (This must use max_seq_length;
    # the undefined name max_length raised NameError for every pair example.)
    truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
  elif len(tokens_a) > max_seq_length - 2:
    # Reserve room for [CLS] and [SEP].
    tokens_a = tokens_a[0:(max_seq_length - 2)]

  tokens = []
  segment_ids = []
  tokens.append("[CLS]")
  segment_ids.append(0)
  for token in tokens_a:
    tokens.append(token)
    segment_ids.append(0)
  tokens.append("[SEP]")
  segment_ids.append(0)

  if tokens_b:
    for token in tokens_b:
      tokens.append(token)
      segment_ids.append(1)
    tokens.append("[SEP]")
    segment_ids.append(1)

  input_ids = tokenizer.convert_tokens_to_ids(tokens)
  # 1 marks a real token; padding positions added below get 0.
  input_mask = [1] * len(input_ids)

  while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
    segment_ids.append(0)

  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length

  label_id = label_map[example.label]
  if ex_index < 2:
    tf.logging.info("*** Example ***")
    tf.logging.info('guid: %s' % (example.guid))
    tf.logging.info('tokens: %s' % " ".join([printable_text(x) for x in tokens]))
    tf.logging.info('input_ids: %s' % " ".join([str(x) for x in input_ids]))
    tf.logging.info('input_mask: %s' % " ".join([str(x) for x in input_mask]))
    tf.logging.info('segment_ids: %s' % " ".join([str(x) for x in segment_ids]))
    tf.logging.info('label: %s (id = %d)' % (example.label, label_id))

  feature = InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_id=label_id, is_real_example=True)
  return feature

#### Serialize the features into a TFRecord file

In [0]:
def _int64_feature(value):
  """Wrap an int, or a list of ints, in a tf.train.Feature holding an Int64List."""
  values = value if isinstance(value, list) else [value]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def file_based_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, output_file):
  """Convert `examples` to features and write them as tf.train.Example records.

  Each record has the int64 features 'input_ids', 'input_mask',
  'segment_ids', 'label_ids' and 'is_real_example'. Progress is logged every
  10000 examples.

  Args:
    examples: sequence of InputExample (``len`` is used for progress logging).
    label_list: label strings passed to convert_single_example.
    max_seq_length: fixed sequence length passed to convert_single_example.
    tokenizer: tokenizer passed to convert_single_example.
    output_file: path of the TFRecord file to create.
  """
  writer = tf.python_io.TFRecordWriter(output_file)
  try:
    for (ex_index, example) in enumerate(examples):
      if ex_index % 10000 == 0:
        tf.logging.info('Writing example %d of %d' % (ex_index, len(examples)))

      feature = convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer)

      features = collections.OrderedDict()
      features['input_ids'] = _int64_feature(feature.input_ids)
      features['input_mask'] = _int64_feature(feature.input_mask)
      features['segment_ids'] = _int64_feature(feature.segment_ids)
      features['label_ids'] = _int64_feature(feature.label_id)
      features['is_real_example'] = _int64_feature([int(feature.is_real_example)])

      tf_example = tf.train.Example(features=tf.train.Features(feature=features))
      writer.write(tf_example.SerializeToString())
  finally:
    # Close the writer even if a conversion fails part-way, so the file
    # handle is not leaked.
    writer.close()

In [0]:
if __name__ == '__main__':
  # Input and output paths.
  tsv_file = 'train.tsv'
  vocab_file = 'vocab.txt'
  train_file = 'bert_train.tfrecords'

  # Tokenizer / model settings. Label ids follow label_list order
  # (toxic -> 0, non-toxic -> 1) because convert_single_example enumerates it.
  max_seq_length = 512
  do_lower_case = True
  label_list = ['toxic', 'non-toxic']

  tokenizer = FullTokenizer(vocab_file, do_lower_case=do_lower_case)

  lines = read_tsv(tsv_file)
  train_examples = create_examples(lines, set_type='train')

  file_based_convert_examples_to_features(
      examples=train_examples,
      label_list=label_list,
      max_seq_length=max_seq_length,
      tokenizer=tokenizer,
      output_file=train_file)
  

INFO:tensorflow:Writing example 0 of 97320
INFO:tensorflow:*** Example ***
INFO:tensorflow:guid: pred-0
INFO:tensorflow:tokens: [CLS] jeff sessions is another one of trump ' s or ##well ##ian choices . he believes and has believed his entire career the exact opposite of what the position requires . [SEP]
INFO:tensorflow:input_ids: 101 5076 6521 2003 2178 2028 1997 8398 1005 1055 2030 4381 2937 9804 1012 2002 7164 1998 2038 3373 2010 2972 2476 1996 6635 4500 1997 2054 1996 2597 5942 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 

### Validate the TFRecord file by checking every record's CRCs, so that no corrupted record goes unnoticed

In [0]:
import struct
from crcmod.predefined import mkPredefinedCrcFun

# CRC-32C (Castagnoli) function from crcmod. TFRecord length and data checksums
# are masked CRC-32C values; see calc_masked_crc.
_crc_fn = mkPredefinedCrcFun('crc-32c')


def calc_masked_crc(data):
    """Return the masked CRC-32C of `data` as stored in TFRecord framing.

    The mask rotates the 32-bit CRC right by 15 bits and then adds
    0xa282ead8, modulo 2**32.
    """
    crc = _crc_fn(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xa282ead8) & 0xFFFFFFFF


def validate_dataset(filename):
    """Check every record of a TFRecord file against its stored CRCs.

    Each record is framed as (little-endian)::

        uint64 length
        uint32 masked_crc32c(length bytes)
        byte   data[length]
        uint32 masked_crc32c(data)

    A message is printed for every mismatch, followed by a summary. Reading
    stops early, with a message, in two cases. The first is a length field
    that fails its CRC: the length can no longer be trusted, and reading that
    many bytes could exhaust memory. The second is a file that ends part-way
    through a record.

    Args:
        filename: local path of the TFRecord file.

    Returns:
        Tuple ``(records_checked, bad_length_crcs, bad_data_crcs)``.
    """
    total_bad_len_crc = 0
    total_bad_data_crc = 0
    i = 0
    print('validating ', filename)

    with open(filename, 'rb') as f:
        while True:
            len_bytes = f.read(8)
            if not len_bytes:
                break  # clean end of file
            len_crc_bytes = f.read(4)
            if len(len_bytes) < 8 or len(len_crc_bytes) < 4:
                print('truncated length header at record', i)
                break

            length, = struct.unpack('<Q', len_bytes)  # u64: protobuf data length
            len_crc, = struct.unpack('<I', len_crc_bytes)  # u32: masked crc32c of length bytes
            if len_crc != calc_masked_crc(len_bytes):
                print('bad crc on len at record', i)
                total_bad_len_crc += 1
                i += 1
                print('stopping: record boundaries are unreliable after a bad length crc')
                break

            data = f.read(length)  # protobuf data
            data_crc_bytes = f.read(4)
            if len(data) < length or len(data_crc_bytes) < 4:
                print('truncated data at record', i)
                break

            data_crc, = struct.unpack('<I', data_crc_bytes)  # u32: masked crc32c of data
            if data_crc != calc_masked_crc(data):
                print('bad crc on data at record', i)
                total_bad_data_crc += 1

            i += 1

    print('checked', i, 'records')
    print('checked', i, 'total records')
    print('total with bad length crc: ', total_bad_len_crc)
    print('total with bad data crc: ', total_bad_data_crc)
    return i, total_bad_len_crc, total_bad_data_crc


In [0]:
if __name__ == '__main__':
  # Re-read the file written above and report any CRC mismatches.
  validate_dataset(filename='bert_train.tfrecords')