In [7]:
import tf.records as tfr
import tensorflow as tf
from glob import glob

In [8]:
schema = tfr.Dataset.read('records/mr').meta.schema_

In [9]:
files = glob('records/mr/*.tfrecord.gz')

In [11]:
from typing import Iterable, Mapping, Literal
from tf.records import Field, parse

def batched_read(
  schema: Mapping[str, Field], filepaths: Iterable[str], *,
  compression: Literal['GZIP', 'ZLIB'] | None = None,
  keep_order: bool = True, batch_size: int = 16
) -> tf.data.Dataset:
  """Parse a series of TFRecord files into a single dataset"""
  ignore_order = tf.data.Options()
  ignore_order.experimental_deterministic = keep_order
  return (
    tf.data.TFRecordDataset(filepaths, compression_type=compression, num_parallel_reads=tf.data.AUTOTUNE)
    .with_options(ignore_order)
    .batch(batch_size)
    .map(parse(schema).batch, num_parallel_calls=tf.data.AUTOTUNE)
  )

ds = batched_read(schema, files, compression='GZIP', keep_order=False)

In [12]:
for i, x in ds.enumerate():
  if i % 1000 == 0:
    print(f'\r{i}', end='', flush=True)

4000