In [16]:
# !pip install tensorflow==2.4.1
# !pip install transformers
# !pip install pyarrow

In [4]:
import tensorflow as tf
import pandas as pd
import numpy as np
import os
from math import ceil
from transformers import AlbertTokenizerFast, TFAlbertModel

In [5]:
def check_targets(targs):
    """Flag a row whose target token sequence is missing.

    Parameters
    ----------
    targs : sequence of int or None
        Target token ids for one row (a list or 1-D numpy array). A leading
        ``-1`` is the "no target" sentinel.

    Returns
    -------
    int
        1 if the row has no target (``None``, empty, or first token ``-1``),
        otherwise 0.
    """
    # len() instead of truthiness: bool() of a numpy array with >1 element is
    # ambiguous. Empty/None previously crashed on targs[0].
    if targs is None or len(targs) == 0:
        return 1
    return 1 if targs[0] == -1 else 0

In [6]:
def create_tfrecords_dataset(data, iter_num, dataset_type='train',
                             output_dir='./iteration_1_500_test/tfrecords'):
    """Filter one tokenized part file and write it as a single TFRecord shard.

    Rows flagged by ``check_targets`` (no target) are dropped. Paper titles and
    masks are kept at their native length (ragged, no padding); targets are
    padded/truncated to 20 tokens with 0.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs columns ``paper_title_tok``, ``paper_title_mask``, ``journal_tok``,
        ``doc_type_tok`` and ``target_tok``. The caller's frame is not modified.
    iter_num : int
        Shard index; zero-padded to 4 characters in the output file name.
    dataset_type : str
        Split name (e.g. 'train', 'val', 'test'); used as the sub-directory.
    output_dir : str
        Root directory for the shards (defaults to the previously hard-coded path).

    Returns
    -------
    str
        Path of the written ``.tfrecord`` file.
    """
    # Build the filter without writing a helper column into the caller's frame.
    has_target = data['target_tok'].apply(check_targets) == 0
    data = data.loc[has_target]

    # Titles and masks stay variable-length (no padding).
    paper_title = tf.ragged.constant(data['paper_title_tok'].to_list())
    paper_mask = tf.ragged.constant(data['paper_title_mask'].to_list())

    targets = tf.keras.preprocessing.sequence.pad_sequences(data['target_tok'].to_list(), maxlen=20,
                                                            dtype='int64', padding='post',
                                                            truncating='post', value=0)

    ds = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(paper_title),
                              tf.data.Dataset.from_tensor_slices(paper_mask),
                              tf.data.Dataset.from_tensor_slices(data['journal_tok'].to_list()),
                              tf.data.Dataset.from_tensor_slices(data['doc_type_tok'].to_list()),
                              tf.data.Dataset.from_tensor_slices(targets)))

    serialized_features_dataset = ds.map(tf_serialize_example)

    shard_dir = os.path.join(output_dir, dataset_type)
    # The writer does not create missing parent directories.
    os.makedirs(shard_dir, exist_ok=True)
    filename = os.path.join(shard_dir, f"{str(iter_num).zfill(4)}.tfrecord")
    writer = tf.data.experimental.TFRecordWriter(filename)
    writer.write(serialized_features_dataset)
    return filename

In [7]:
def tf_serialize_example(f0, f1, f2, f3, f4):
    """Graph-mode wrapper around ``serialize_example`` for use in ``Dataset.map``.

    ``serialize_example`` calls ``.numpy()`` and therefore needs eager tensors,
    so it is executed through ``tf.py_function``. Returns a scalar ``tf.string``
    tensor holding one serialized ``tf.train.Example``.
    """
    serialized = tf.py_function(
        func=serialize_example,
        inp=[f0, f1, f2, f3, f4],
        Tout=tf.string,
    )
    # py_function output carries no static shape; pin it to a scalar.
    return tf.reshape(serialized, ())

In [8]:
def _int64_feature(tensor):
    """Wrap the integer values of an eager tensor in a ``tf.train.Feature``."""
    # np.ravel flattens to 1-D, so 0-d (scalar) tensors also work; for the
    # 1-D tensors used here it yields the same list as .tolist() did.
    values = np.ravel(tensor.numpy()).tolist()
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def serialize_example(paper_title, paper_mask, journal, doc_type, targets):
    """Serialize one row of features into a ``tf.train.Example`` byte string.

    Every argument is an eager integer tensor; each becomes an Int64List feature
    under the matching key ('paper_title', 'paper_mask', 'journal', 'doc_type',
    'targets').

    Returns
    -------
    bytes
        The serialized ``tf.train.Example`` proto.
    """
    features_for_example = {
        'paper_title': _int64_feature(paper_title),
        'paper_mask': _int64_feature(paper_mask),
        'journal': _int64_feature(journal),
        'doc_type': _int64_feature(doc_type),
        'targets': _int64_feature(targets),
    }

    example_proto = tf.train.Example(features=tf.train.Features(feature=features_for_example))

    return example_proto.SerializeToString()

In [9]:
def turn_part_file_into_tfrecord(base_path, dataset_type='train'):
    """Convert every parquet part file of one split into numbered TFRecord shards.

    Files are processed in sorted-name order; the i-th file is passed to
    ``create_tfrecords_dataset`` as shard ``i``.

    Parameters
    ----------
    base_path : str
        Directory containing one sub-directory per split. A trailing path
        separator is optional.
    dataset_type : str
        Split sub-directory to convert (e.g. 'train', 'val', 'test').
    """
    # os.path.join works with or without a trailing '/' on base_path; plain
    # concatenation looked for e.g. 'tokenized_datatrain' without one.
    split_dir = os.path.join(base_path, dataset_type)
    file_list = sorted(x for x in os.listdir(split_dir) if x.endswith('parquet'))
    print(f"There are {len(file_list)} files for {dataset_type}")
    for i, file_name in enumerate(file_list):
        data = pd.read_parquet(os.path.join(split_dir, file_name))
        print(f"_____File number: {i} ({data.shape[0]} samples)")
        create_tfrecords_dataset(data, i, dataset_type)

In [10]:
# Root folder of the tokenized parquet part files; one sub-folder per split.
base_file_path = "./iteration_1_500_test/tokenized_data/"

#### Without padding (paper titles and masks are stored at their native length; only targets are padded to 20 tokens)

In [11]:
%%time
turn_part_file_into_tfrecord(base_file_path, 'train')

There are 50 files for train
_____File number: 0 (333934 samples)


2021-10-28 16:44:50.599508: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2021-10-28 16:44:50.599987: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-10-28 16:45:12.701094: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)


_____File number: 1 (500782 samples)
_____File number: 2 (289574 samples)
_____File number: 3 (731312 samples)
_____File number: 4 (452938 samples)
_____File number: 5 (332182 samples)
_____File number: 6 (334258 samples)
_____File number: 7 (332690 samples)
_____File number: 8 (334133 samples)
_____File number: 9 (500947 samples)
_____File number: 10 (500685 samples)
_____File number: 11 (333661 samples)
_____File number: 12 (333736 samples)
_____File number: 13 (334030 samples)
_____File number: 14 (502205 samples)
_____File number: 15 (500262 samples)
_____File number: 16 (333853 samples)
_____File number: 17 (333452 samples)
_____File number: 18 (334802 samples)
_____File number: 19 (500348 samples)
_____File number: 20 (333153 samples)
_____File number: 21 (501138 samples)
_____File number: 22 (500516 samples)
_____File number: 23 (334218 samples)
_____File number: 24 (333879 samples)
_____File number: 25 (499942 samples)
_____File number: 26 (333754 samples)
_____File number: 27 

In [12]:
%%time
turn_part_file_into_tfrecord(base_file_path, 'val')

There are 10 files for val
_____File number: 0 (81042 samples)
_____File number: 1 (73535 samples)
_____File number: 2 (152905 samples)
_____File number: 3 (133827 samples)
_____File number: 4 (82014 samples)
_____File number: 5 (99433 samples)
_____File number: 6 (72605 samples)
_____File number: 7 (59826 samples)
_____File number: 8 (72843 samples)
_____File number: 9 (64708 samples)
CPU times: user 5min 48s, sys: 19.6 s, total: 6min 8s
Wall time: 5min 14s


In [13]:
%%time
turn_part_file_into_tfrecord(base_file_path, 'test')

There are 5 files for test
_____File number: 0 (32310 samples)
_____File number: 1 (29691 samples)
_____File number: 2 (31983 samples)
_____File number: 3 (31242 samples)
_____File number: 4 (25404 samples)
CPU times: user 1min, sys: 3.26 s, total: 1min 3s
Wall time: 54.4 s
