In [None]:
import torch
import tensorflow as tf
import time
import multiprocessing as mp

In [None]:
# Benchmark configuration shared by every cell below.
batch_size = 32
dataset_length = batch_size * 10  # exactly 10 full batches (320 samples)
cpu_count = mp.cpu_count()  # used as worker count / num_parallel_calls below
print(f'CPUs available: {cpu_count}')

def do_work(iterations=1_000_000):
    """Simulate CPU-bound preprocessing with a pure-Python busy loop.

    Args:
        iterations: number of loop iterations to spin. The default keeps the
            original hard-coded workload of 1_000_000 iterations.

    Returns:
        None; only the time spent matters.
    """
    for i in range(iterations):
        # Rebinding the loop variable is intentional: it is just busy work and
        # keeps the per-iteration cost identical to the original benchmark.
        i = i + 1

def do_sleep(seconds=0.05):
    """Simulate I/O-bound work by blocking the calling thread.

    Args:
        seconds: how long to sleep. The default keeps the original 0.05 s delay.

    Returns:
        None.
    """
    time.sleep(seconds)

In [None]:
class TFDataLoaderThin():
    """Minimal sample source for ``tf.data.Dataset.from_generator``.

    Emits the integers ``0 .. dataset_length - 1`` as dummy samples; any real
    per-sample work is left to the dataset's ``map`` function.
    """

    def gen_sample(self):
        # Delegate to range: one dummy index per sample, evaluated lazily.
        yield from range(dataset_length)

def map_fn_io(x):
    """tf.data map function simulating I/O latency.

    Runs the Python ``do_sleep`` helper through ``tf.py_function`` (no inputs,
    no outputs) and passes the element ``x`` through unchanged.
    """
    tf.py_function(func=do_sleep, inp=[], Tout=())  # called only for its delay
    return x
    
def map_fn_cpu(x):
    """tf.data map function simulating CPU-bound preprocessing.

    Runs the Python ``do_work`` busy loop through ``tf.py_function`` (no inputs,
    no outputs) and passes the element ``x`` through unchanged.
    """
    tf.py_function(func=do_work, inp=[], Tout=())  # called only for its CPU cost
    return x

In [None]:
dataloader = TFDataLoaderThin()
datagen = dataloader.gen_sample
tf_dataset = tf.data.Dataset.from_generator(datagen, output_signature=tf.TensorSpec([], tf.int32))
tf_dataset = tf_dataset.map(map_fn_io).batch(batch_size)
# time the dataloader; perf_counter is monotonic and high-resolution, unlike time.time()
start = time.perf_counter()
for _ in tf_dataset:
    pass
print("Time taken for TF vanilla IO dataset: {}".format(time.perf_counter() - start))

In [None]:
dataloader = TFDataLoaderThin()
datagen = dataloader.gen_sample
tf_dataset = tf.data.Dataset.from_generator(datagen, output_signature=tf.TensorSpec([], tf.int32))
tf_dataset = tf_dataset.map(map_fn_io, num_parallel_calls=cpu_count).batch(batch_size)
# time the dataloader; perf_counter is monotonic and high-resolution, unlike time.time()
start = time.perf_counter()
for _ in tf_dataset:
    pass
print("Time taken for TF parallel IO dataset: {}".format(time.perf_counter() - start))

In [None]:
dataloader = TFDataLoaderThin()
datagen = dataloader.gen_sample
tf_dataset = tf.data.Dataset.from_generator(datagen, output_signature=tf.TensorSpec([], tf.int32))
tf_dataset = tf_dataset.map(map_fn_cpu).batch(batch_size)
# time the dataloader; perf_counter is monotonic and high-resolution, unlike time.time()
start = time.perf_counter()
for _ in tf_dataset:
    pass
print("Time taken for TF vanilla CPU dataset: {}".format(time.perf_counter() - start))

In [None]:
dataloader = TFDataLoaderThin()
datagen = dataloader.gen_sample
tf_dataset = tf.data.Dataset.from_generator(datagen, output_signature=tf.TensorSpec([], tf.int32))
tf_dataset = tf_dataset.map(map_fn_cpu, num_parallel_calls=cpu_count).batch(batch_size)
# time the dataloader; perf_counter is monotonic and high-resolution, unlike time.time()
start = time.perf_counter()
for _ in tf_dataset:
    pass
print("Time taken for TF parallel CPU dataset: {}".format(time.perf_counter() - start))

In [None]:
class TFDataLoader():
    """Produce samples in worker processes and feed them to ``tf.data``.

    Each worker runs ``do_work()`` once per sample in its index range and puts
    the sample index on a shared, bounded ``mp.Queue``. ``gen_sample`` drains
    exactly ``dataset_length`` items from that queue and is meant to be wrapped
    with ``tf.data.Dataset.from_generator``.
    """

    def __init__(self, num_workers=1):
        self.num_workers = num_workers
        # Bounded so workers block instead of buffering the whole dataset.
        self.queue = mp.Queue(maxsize=10)
        self.processes = []
        # Minimum samples per worker (kept for backward compatibility).
        self.chunk_size = dataset_length // self.num_workers
        # Contiguous [start, stop) ranges, one per worker. Unlike chunk_size alone,
        # these also cover the remainder, so workers always produce
        # dataset_length items in total and gen_sample cannot block forever.
        self.chunks = self._split(dataset_length, self.num_workers)

    @staticmethod
    def _split(total, num_workers):
        """Split ``range(total)`` into ``num_workers`` contiguous (start, stop) ranges.

        The first ``total % num_workers`` ranges get one extra item, so the
        ranges cover every index even when ``total`` is not divisible.

        Raises:
            ValueError: if ``num_workers`` is less than 1.
        """
        if num_workers < 1:
            raise ValueError(f"num_workers must be >= 1, got {num_workers}")
        base, extra = divmod(total, num_workers)
        bounds = []
        start = 0
        for worker_id in range(num_workers):
            stop = start + base + (1 if worker_id < extra else 0)
            bounds.append((start, stop))
            start = stop
        return bounds

    def initialize(self):
        """Start one daemon worker process per index range."""
        processes = [
            mp.Process(target=self._worker, args=(start, stop), daemon=True)
            for start, stop in self.chunks
        ]
        for p in processes:
            p.start()
        # Assigned only after start(), as in the original, so `self` is pickled
        # without Process objects when workers are launched.
        self.processes = processes

    def _worker(self, start=0, stop=None):
        """Run ``do_work()`` per index in [start, stop) and enqueue each index.

        With no arguments, processes ``range(self.chunk_size)`` (original behavior).
        """
        if stop is None:
            stop = start + self.chunk_size
        for idx in range(start, stop):
            do_work()
            self.queue.put(idx)

    def gen_sample(self):
        """Yield exactly ``dataset_length`` items taken from the worker queue."""
        for _ in range(dataset_length):
            yield self.queue.get()

    def close(self):
        """Terminate all workers and wait for them to exit."""
        for p in self.processes:
            p.terminate()
        # join() reaps the terminated processes so none linger as zombies.
        for p in self.processes:
            p.join()
        self.processes = []

In [None]:
    # Queue-fed loader with a single worker process.
    dataloader = TFDataLoader(num_workers=1)
    try:
        # Start timing before initialize(): workers begin computing as soon as
        # they start, so their start-up and queue pre-fill belong to the measurement.
        start = time.perf_counter()
        dataloader.initialize()
        datagen = dataloader.gen_sample
        tf_dataset = tf.data.Dataset.from_generator(datagen, output_signature=tf.TensorSpec([], tf.int32))
        tf_dataset = tf_dataset.batch(batch_size)
        for _ in tf_dataset:
            pass
        print("Time taken for TF multiprocess CPU dataset (1 worker): {}".format(time.perf_counter() - start))
    finally:
        # Always stop the workers, even if iteration fails or is interrupted.
        dataloader.close()

In [None]:
    # Queue-fed loader with one worker process per CPU.
    dataloader = TFDataLoader(num_workers=cpu_count)
    try:
        # Start timing before initialize(): workers begin computing as soon as
        # they start, so their start-up and queue pre-fill belong to the measurement.
        start = time.perf_counter()
        dataloader.initialize()
        datagen = dataloader.gen_sample
        tf_dataset = tf.data.Dataset.from_generator(datagen, output_signature=tf.TensorSpec([], tf.int32))
        tf_dataset = tf_dataset.batch(batch_size)
        for _ in tf_dataset:
            pass
        print("Time taken for TF multiprocess CPU dataset ({} workers): {}".format(cpu_count, time.perf_counter() - start))
    finally:
        # Always stop the workers, even if iteration fails or is interrupted.
        dataloader.close()

In [None]:
class CPUDataset(torch.utils.data.Dataset):
    """Map-style dataset whose item lookup performs CPU-bound busy work.

    Has ``dataset_length`` items; item ``idx`` is just ``idx`` itself, returned
    after one ``do_work()`` call.
    """

    def __len__(self):
        return dataset_length

    def __getitem__(self, idx):
        sample = idx  # the dummy sample is the index itself
        do_work()  # simulated CPU-bound preprocessing
        return sample
    
class IODataset(torch.utils.data.Dataset):
    """Map-style dataset whose item lookup simulates I/O latency.

    Has ``dataset_length`` items; item ``idx`` is just ``idx`` itself, returned
    after one ``do_sleep()`` call.
    """

    def __len__(self):
        return dataset_length

    def __getitem__(self, idx):
        sample = idx  # the dummy sample is the index itself
        do_sleep()  # simulated blocking I/O
        return sample

In [None]:
cpu_dataset = CPUDataset()
dataloader = torch.utils.data.DataLoader(cpu_dataset, batch_size=batch_size)
# time the dataloader; perf_counter is monotonic and high-resolution, unlike time.time()
start = time.perf_counter()
for _ in dataloader:
    pass
print("Time taken for torch vanilla CPU dataset: {}".format(time.perf_counter() - start))

In [None]:
cpu_dataset = CPUDataset()
dataloader = torch.utils.data.DataLoader(cpu_dataset, batch_size=batch_size, num_workers=cpu_count)
# time the dataloader; perf_counter is monotonic and high-resolution, unlike time.time()
start = time.perf_counter()
for _ in dataloader:
    pass
print("Time taken for torch parallel CPU dataset: {}".format(time.perf_counter() - start))

In [None]:
io_dataset = IODataset()
dataloader = torch.utils.data.DataLoader(io_dataset, batch_size=batch_size)
# time the dataloader; perf_counter is monotonic and high-resolution, unlike time.time()
start = time.perf_counter()
for _ in dataloader:
    pass
print("Time taken for torch vanilla IO dataset: {}".format(time.perf_counter() - start))

In [None]:
io_dataset = IODataset()
dataloader = torch.utils.data.DataLoader(io_dataset, batch_size=batch_size, num_workers=cpu_count)
# time the dataloader; perf_counter is monotonic and high-resolution, unlike time.time()
start = time.perf_counter()
for _ in dataloader:
    pass
print("Time taken for torch parallel IO dataset: {}".format(time.perf_counter() - start))