In [1]:
%pylab inline

%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib


In [2]:
import torch
import torchvision
import torchvision.transforms as transforms

  warn(f"Failed to load image Python extension: {e}")


In [3]:
# Preprocessing applied to every MNIST sample: convert to a tensor, then
# normalise with a single-channel mean / std pair (0.1307 / 0.3081 --
# presumably the MNIST training-set statistics; confirm if changed).
transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training split of MNIST; downloaded into ./mnist/ when not already present.
mnist_dataset = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    download=True,
    transform=transform_train,
)

# Writing benchmark 

In [4]:
from time import time

In [5]:
import h5py
class HDF5DatasetWriter(object):
    """Copy an indexable ``(inputs, label)`` dataset into an HDF5 file.

    The file is opened in append mode and two datasets are allocated up
    front, sized from ``len(dataset)`` and from the first sample:

    * ``'inputs'`` with shape ``(n, c, h, w)`` and the dtype of the first sample
    * ``'labels'`` with shape ``(n,)`` and dtype ``int64``

    Parameters
    ----------
    filename : str
        Path of the HDF5 file to create or reopen.
    dataset : indexable
        ``dataset[i]`` must return ``(inputs, label)`` where ``inputs`` is a
        3-D torch tensor or array-like and ``label`` is an integer.

    The writer can be used as a context manager so the file is always closed.
    """

    def __init__(self, filename, dataset):
        super().__init__()

        # This is a raster type dataset
        self.dataset = dataset

        # Find dtype, shape and length of the array from the first sample.
        n = len(self.dataset)
        inputs, labels = self.dataset[0]
        sample = self._to_numpy(inputs)
        c, h, w = sample.shape
        inputs_dtype = sample.dtype

        labels_dtype = np.int64

        self.hdf5 = h5py.File(filename, "a")
        # Tailored for MNIST.
        # require_dataset (instead of create_dataset) keeps re-running the cell
        # on an existing file from failing with "name already exists"; it
        # still raises if the stored shape/dtype are incompatible.
        self.hdf5.require_dataset('inputs', (n, c, h, w), dtype=inputs_dtype)
        self.hdf5.require_dataset('labels', (n,), dtype=labels_dtype)

    @staticmethod
    def _to_numpy(inputs):
        """Return ``inputs`` as a numpy array (accepts torch tensors or array-likes)."""
        if hasattr(inputs, 'numpy'):
            return inputs.numpy()
        return np.asarray(inputs)

    def write_element(self, i):
        """Write sample ``i`` of the wrapped dataset into row ``i`` of the file."""
        inputs, labels = self.dataset[i]
        # Hand h5py a real numpy array rather than relying on implicit tensor conversion.
        self.hdf5['inputs'][i] = self._to_numpy(inputs)
        self.hdf5['labels'][i] = labels

    def close(self):
        """Close the HDF5 file if it is still referenced; safe to call repeatedly."""
        if self.hdf5 is not None:
            self.hdf5.close()
            self.hdf5 = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

In [7]:
# Open "mnist.hdf5" and allocate the output datasets sized for the whole MNIST training set.
myHDF5MNISTWriter = HDF5DatasetWriter(filename="mnist.hdf5",dataset=mnist_dataset)

In [8]:
# Serial, element-by-element HDF5 write benchmark over the full dataset.
NData = len(mnist_dataset)
tic = time()
for i in range(NData):
    myHDF5MNISTWriter.write_element(i)
Dt = time() - tic

# Report per-sample latency, throughput and total wall time.
print(
    "Time per datum:{}sec".format(Dt/NData),
    "Datums per Second::{} - SERIAL".format(NData/Dt),
    "time to WRITE {} data::{}sec".format(NData,Dt),
    sep="\n",
)


Time per datum:0.0003399996558825175sec
Datums per Second::2941.1794473860796 - SERIAL
time to WRITE 60000 data::20.39997935295105sec


In [9]:
# Close the writer's underlying HDF5 file handle now that the write benchmark is done.
myHDF5MNISTWriter.hdf5.close()

In [28]:
# NOTE(review): HDF5DatasetWriterChunks is not defined in any cell visible in this
# notebook -- this line only works if the class was defined in a since-removed
# cell or elsewhere; it will raise NameError on Restart & Run All. TODO: restore
# the definition (presumably a chunked-dataset variant of HDF5DatasetWriter).
myHDF5MNISTWriter = HDF5DatasetWriterChunks(filename="mnist_chunked.hdf5",dataset=mnist_dataset)

In [29]:
# Same serial write benchmark, this time against the chunked HDF5 writer.
NData = len(mnist_dataset)
tic = time()
for i in range(NData):
    myHDF5MNISTWriter.write_element(i)
Dt = time() - tic

per_datum = Dt/NData
per_second = NData/Dt
print("Time per datum:{}sec".format(per_datum))
print("Datums per Second::{} - SERIAL".format(per_second))
print("time to WRITE {} data::{}sec".format(NData,Dt))

Time per datum:0.0003361504395802816sec
Datums per Second::2974.858522418126 - SERIAL
time to WRITE 60000 data::20.169026374816895sec


# Now let's see RocksDB 

In [6]:
import rocksdb
import pickle
class RocksDBDatasetWriter(object):
    """Write ``(inputs, label)`` samples of an indexable dataset into RocksDB.

    Keys are ASCII byte strings: ``inputs_<idx>`` and ``labels_<idx>`` (or
    ``inputs<i>_<idx>`` for the i-th tensor when a sample holds a list of
    tensors). Values are the raw bytes of the arrays. The pickled ``metadata``
    object is stored under the key ``meta`` each time the database is opened.

    Parameters
    ----------
    flname_db : str
        Path of the RocksDB database directory.
    metadata : picklable object
        Description of the stored arrays (shapes / dtypes) for readers.
    dataset : indexable
        ``dataset[idx]`` must return ``(inputs, label)``; ``inputs`` is a torch
        tensor (or a list of tensors) and ``label`` an integer that fits uint8.
    """

    def __init__(self, flname_db, metadata, dataset):
        # The database is opened lazily on the first write.
        self.db = None
        self.dataset = dataset

        self.flname_db = flname_db
        self.meta = metadata

    def _open_rocks(self):
        """Open (creating if needed) the database and store the metadata record."""
        # ==========================================================================
        # TODO : they need further investigation
        # Some good behaving defaults
        opts = rocksdb.Options()
        opts.create_if_missing = True
        opts.max_open_files = 300000
        opts.write_buffer_size = 67108864
        opts.max_write_buffer_number = 30 # 3 default
        opts.target_file_size_base = 67108864  # default 67108864, value starting 7: 7340032 input.nbytes
        opts.paranoid_checks = False

        opts.IncreaseParallelism()

        opts.table_factory = rocksdb.BlockBasedTableFactory(
                filter_policy=rocksdb.BloomFilterPolicy(10), # 10 default
                block_cache=rocksdb.LRUCache( 2 * (1024 ** 3)), # 2 * (1024 ** 3) default
                block_cache_compressed=rocksdb.LRUCache(500 * (1024 ** 2)))
        # ==================================================================================

        self.db = rocksdb.DB(self.flname_db, opts, read_only=False)

        # REMARK: it is possible to create a "column_family" so as to keep
        # separate categories of variables apart (as many as desired), which
        # should improve search capabilities.

        meta_key = 'meta'.encode('ascii')
        meta_dumps = pickle.dumps(self.meta)
        self.db.put(meta_key, meta_dumps)

    def write_element(self, idx):
        """Write sample ``idx`` of the wrapped dataset (one put per array)."""
        if self.db is None:
            self._open_rocks()

        # Tailored for mnist
        imgs, labels = self.dataset[idx]
        labels = np.array([labels]).astype(np.uint8).tobytes()

        # The list check must happen BEFORE any conversion to bytes; previously
        # it ran on the already-serialised value and could never be true.
        if isinstance(imgs, list):
            for i, img in enumerate(imgs):
                key_img = 'inputs{}_{}'.format(i, idx).encode('ascii')
                self.db.put(key_img, img.numpy().tobytes())
        else:
            key_img = 'inputs_{}'.format(idx).encode('ascii')
            self.db.put(key_img, imgs.numpy().tobytes())

        key_mask = 'labels_{}'.format(idx).encode('ascii')
        self.db.put(key_mask, labels)

    # Optionally one can write in batches, to increase parallelism.
    def write_batch(self, batch_idx, databatch):
        """Write a batch of ``(input, label)`` pairs in one ``WriteBatch``.

        Parameters
        ----------
        batch_idx : int
            Position of this batch; global sample index is
            ``idx + batch_idx * len(databatch)``.
        databatch : sequence of (array, array) pairs
            Arrays of identical shape per position so they can be stacked.

        NOTE(review): the key offset uses the size of *this* batch, so a
        shorter final batch would get keys that overlap earlier ones -- callers
        should use a constant batch size.
        """
        if self.db is None:
            self._open_rocks()

        inputs, labels = zip(*databatch)
        inputs = np.array(inputs)
        # Bug fix: this used to be np.array(inputs), which stored image bytes
        # under the labels_* keys.
        labels = np.array(labels)
        batch_size = inputs.shape[0]
        bb = rocksdb.WriteBatch()
        for idx, (tinput, tlabel) in enumerate(zip(inputs, labels)):
            global_idx = idx + batch_idx * batch_size
            key_input = 'inputs_{}'.format(global_idx).encode('ascii')
            bb.put(key_input, tinput.tobytes())
            key_label = 'labels_{}'.format(global_idx).encode('ascii')
            bb.put(key_label, tlabel.tobytes())
        self.db.write(bb)


In [32]:
# Element-wise RocksDB writer for MNIST; the metadata records how to decode
# the raw bytes stored for each sample.
myRocksDBWriter = RocksDBDatasetWriter(
    flname_db="mnist_rocks_batch.db",
    metadata=dict(
        inputs_shape=(1,28,28),
        inputs_dtype=np.float32,
        labels_shape=None,
        labels_dtype=np.uint8,
    ),
    dataset=mnist_dataset,
)

In [33]:
# Serial, element-by-element RocksDB write benchmark (the first call also
# opens the database, so that cost is part of the measurement).
NData = len(mnist_dataset)
tic = time()
for i in range(NData):
    myRocksDBWriter.write_element(i)
Dt = time() - tic

print("\n".join([
    "Time per datum:{}sec".format(Dt/NData),
    "Datums per Second::{} - SERIAL".format(NData/Dt),
    "time to WRITE {} data::{}sec".format(NData,Dt),
]))


Time per datum:0.00010444822311401367sec
Datums per Second::9574.121705339297 - SERIAL
time to WRITE 60000 data::6.26689338684082sec


## Therefore RocksDB is about 20.40 / 6.27 ≈ 3.3 times faster than HDF5 when writing element-wise (timings from the runs shown above)


How about writing in batches? 

In [5]:
import sys
sys.path.append(r'../../')
# Here I am using the class provided with this repo
from ssg2.data.rocksdbutils import RocksDBWriter 

In [6]:
from collections import OrderedDict as odict

# Side length of an MNIST image in pixels.
F = 28

# Per-array decoding information (shape and dtype) stored alongside the data,
# grouped by array name; insertion order is inputs first, then labels.
meta = odict([
    ('inputs', odict([
        ('inputs_shape', (1, F, F)),
        ('inputs_dtype', np.float32),
    ])),
    ('labels', odict([
        ('labels_shape', None),
        ('labels_dtype', np.uint8),
    ])),
])


In [7]:
# Batched writer from the repository's ssg2 package, targeting a fresh database
# directory ('mnist_RDB_v2.db') and configured with the `meta` description above.
# The num_workers is a hyperparameter, 4 works best for this dataset
myDBWriter = RocksDBWriter(flname_db='mnist_RDB_v2.db',metadata=meta,num_workers=4)

In [8]:
# Data preprocessing: materialise every sample as [image ndarray, uint8 label]
# and group the samples into fixed-size write batches. This pass is serial and
# therefore slow; a parallel version follows further below.
temp = [
    [image.numpy(), np.array(target, dtype=np.uint8)]
    for image, target in mnist_dataset
]
n = 100  # number of datums written per batch
databatches = [temp[lo:lo + n] for lo in range(0, len(temp), n)]

In [9]:
# Timing only the batched write operation (preprocessing was done beforehand).
NData = len(mnist_dataset)
tic = time()
for datum in databatches:
    myDBWriter.write_batch(datum)
Dt = time() - tic

print(
    "Time per datum:{}sec".format(Dt/NData),
    "Datums per Second::{} - WriteBatch".format(NData/Dt),
    "time to WRITE {} data::{}sec".format(NData,Dt),
    sep="\n",
)

Time per datum:2.99858291943868e-05sec
Datums per Second::33349.08611388993 - WriteBatch
time to WRITE 60000 data::1.799149751663208sec


In [12]:
from multiprocessing import Pool, cpu_count

def prepare_data(start_end):
    """Convert dataset samples ``[start, end)`` into ``[image ndarray, uint8 label]`` pairs.

    Parameters
    ----------
    start_end : tuple of (int, int)
        Half-open index range into the module-level ``mnist_dataset``. ``end``
        is clamped to the dataset length so an over-long range cannot raise
        IndexError.

    Returns
    -------
    list of [numpy.ndarray, numpy.ndarray]
    """
    start, end = start_end
    end = min(end, len(mnist_dataset))
    temp = []
    for i in range(start, end):
        datum = mnist_dataset[i]
        temp.append([datum[0].numpy(), np.array(datum[1], dtype=np.uint8)])
    return temp

# Number of processes to spawn. You can also use fewer processes if you'd like.
num_processes = cpu_count()

# Calculate indices for splitting the dataset.
# Ceiling division gives at most `num_processes` chunks, and the last chunk is
# clamped to the dataset size. The previous floor division produced a final
# range running past the end of the dataset whenever the size was not an exact
# multiple of the process count (and a zero step for tiny datasets).
dataset_size = len(mnist_dataset)
chunk_size = max(1, -(-dataset_size // num_processes))
indices = [
    (start, min(start + chunk_size, dataset_size))
    for start in range(0, dataset_size, chunk_size)
]

# Use multiprocessing to prepare data in parallel.
# NOTE(review): passing a notebook-defined function to Pool presumably relies on
# the fork start method -- confirm on platforms that default to spawn.
with Pool(num_processes) as pool:
    parallel_temp = pool.map(prepare_data, indices)

# Flatten the list of lists
temp = [datum for sublist in parallel_temp for datum in sublist]


In [13]:
# End-to-end benchmark: parallel preprocessing AND the batched writes are both
# inside the timed region, so Dt covers data preparation plus writing.
NData = len(mnist_dataset)
tic = time()

# Prepare the samples in worker processes, one index range per task.
with Pool(num_processes) as pool:
    parallel_temp = pool.map(prepare_data, indices)

# Flatten the per-worker lists, then regroup into write batches.
temp = [datum for sublist in parallel_temp for datum in sublist]
n = 100  # number of datums written per batch
databatches = [temp[lo:lo + n] for lo in range(0, len(temp), n)]

for datum in databatches:
    myDBWriter.write_batch(datum)
Dt = time() - tic

print("\n".join([
    "Time per datum:{}sec".format(Dt/NData),
    "Datums per Second::{} - WriteBatch".format(NData/Dt),
    "time to WRITE {} data::{}sec".format(NData,Dt),
]))

Time per datum:6.083787679672241e-05sec
Datums per Second::16437.12852342464 - WriteBatch
time to WRITE 60000 data::3.6502726078033447sec


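## Therefore, with data preparation included, batched writing gives a further ~x2 speed-up over element-wise RocksDB writes (6.27 s → 3.65 s), i.e. ~x6 overall in comparison with standard HDF5 (20.40 s → 3.65 s) for writing the data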

# Reading Benchmark

In [4]:
import h5py
class HDF5Dataset(torch.utils.data.Dataset):
    """Read ``(inputs, label)`` samples from an HDF5 file with 'inputs' and 'labels' datasets.

    Parameters
    ----------
    filepath : str
        Path of the HDF5 file to read.
    transform : optional
        Stored on the instance but not applied by this class.

    Notes
    -----
    The dataset length is read from the file at construction time through a
    short-lived handle, instead of being hard-coded to 60000. The handle used
    for reading samples is still created lazily on the first ``__getitem__``,
    so no open handle exists on the instance right after construction
    (presumably so that each DataLoader worker opens its own -- confirm).
    """

    def __init__(self, filepath, transform=None):
        self.filepath = filepath
        self.transform = transform
        self.hdf5 = None
        # Query the real number of samples without keeping the file open.
        with h5py.File(self.filepath, 'r') as handle:
            self.length = handle['inputs'].shape[0]

    def open_hdf5(self):
        """Open the persistent read handle and refresh the cached length."""
        self.hdf5 = h5py.File(self.filepath, 'r')
        self.length = self.hdf5['inputs'].shape[0]

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``(inputs, label)``, both cast to float32."""
        if self.hdf5 is None:
            self.open_hdf5()
        img = self.hdf5['inputs'][idx][:] # Do loading here
        labels = self.hdf5['labels'][idx]

        return img.astype(np.float32), labels.astype(np.float32)

    def __len__(self):
        """Number of samples stored in the 'inputs' dataset."""
        return self.length


In [11]:
# Shuffled, multi-worker DataLoader over the HDF5 copy of MNIST.
num_workers = 16
batch_size = 128
trainset_hdf5 = HDF5Dataset(filepath='mnist.hdf5')
dataloader_hdf5 = torch.utils.data.DataLoader(
    trainset_hdf5,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
    pin_memory=False,
)


In [13]:
# Time one full pass (epoch) over the HDF5-backed DataLoader.
NSteps = len(dataloader_hdf5)
tic = time()
for data in dataloader_hdf5:
    img, label = data
Dt = time() - tic

print(
    "Time per batch:{}sec".format(Dt/NSteps),
    "Batches per Second::{} - SERIAL".format(NSteps/Dt),
    "time for 1 epoch, BatchSteps{}, Dt::{}sec".format(NSteps,Dt),
    sep="\n",
)

Time per batch:0.0028857471596481455sec
Batches per Second::346.5307058023505 - SERIAL
time for 1 epoch, BatchSteps468, Dt::1.350529670715332sec


In [12]:
import rocksdb
import pickle



# See https://github.com/pytorch/vision/issues/689
class RocksDBDatasetDemo(torch.utils.data.Dataset):
    """Read-only torch Dataset over a RocksDB database of raw numpy bytes.

    The database must hold a pickled metadata dict under a key containing
    ``meta`` with the entries ``inputs_shape``, ``inputs_dtype``,
    ``labels_shape`` and ``labels_dtype``. Every key containing ``inputs`` is
    treated as an input record and every key containing ``labels`` as a label
    record; sample ``idx`` pairs the idx-th key of each list, in the order the
    database iterator yields them.

    Parameters
    ----------
    filepath : str
        Path of the RocksDB database directory.
    transform : optional
        Stored on the instance but not applied by this class.
    num_workers : int
        Passed to ``Options.IncreaseParallelism``.
    """

    def __init__(self, filepath,  transform=None, num_workers=6):
        super().__init__()
        self.flname_db = filepath
        self.transform = transform
        self.num_workers = num_workers

        # Unlike the writer, the database is opened eagerly at construction.
        self.db = None
        self._open_rocks()

    def _open_rocks(self):
        """Open the database read-only, load the metadata and index the keys."""
        # ==========================================================================
        # TODO : they need further investigation
        # Some good behaving defaults
        opts = rocksdb.Options()
        opts.create_if_missing = True
        opts.max_open_files = 300000
        opts.write_buffer_size = 67108864
        opts.max_write_buffer_number = 30 # 3 default
        opts.target_file_size_base = 67108864  # default 67108864, value starting 7: 7340032 input.nbytes
        opts.paranoid_checks = False

        opts.table_factory = rocksdb.BlockBasedTableFactory(
                format_version=5,
                filter_policy=rocksdb.BloomFilterPolicy(10), # 10 default
                block_cache=rocksdb.LRUCache( 32 * (1024 ** 3)), # 2 * (1024 ** 3) default
                block_cache_compressed=rocksdb.LRUCache(16 * (1024 ** 2)))
        # ==================================================================================

        opts.IncreaseParallelism(self.num_workers)

        self.db = rocksdb.DB(self.flname_db, opts, read_only=True)

        # Snapshot every key once, in iterator order.
        key_iter = self.db.iterkeys()
        key_iter.seek_to_first()
        self.keys = list(key_iter)

        # The first key containing 'meta' holds the pickled description.
        # NOTE: pickle.loads executes arbitrary code for crafted payloads --
        # only open databases from trusted sources.
        meta_token = 'meta'.encode('ascii')
        meta_key = [key for key in self.keys if meta_token in key][0]
        meta = pickle.loads(self.db.get(meta_key))

        self.inputs_shape = meta['inputs_shape']
        self.inputs_dtype = meta['inputs_dtype']

        self.labels_shape = meta['labels_shape']
        self.labels_dtype = meta['labels_dtype']

        # Split the remaining keys by record kind (substring match).
        inputs_token = 'inputs'.encode('ascii')
        labels_token = 'labels'.encode('ascii')
        self.inputs_keys = [key for key in self.keys if inputs_token in key]
        self.labels_keys = [key for key in self.keys if labels_token in key]

        self.length = len(self.inputs_keys)

    def get_inputs_labels(self,idx):
        """Fetch and decode the raw ``(inputs, labels)`` arrays of sample ``idx``."""
        raw_inputs = self.db.get(self.inputs_keys[idx])
        raw_labels = self.db.get(self.labels_keys[idx])

        # Inputs are reshaped from the metadata; labels stay a flat array.
        inputs = np.frombuffer(raw_inputs, dtype=self.inputs_dtype).reshape(*self.inputs_shape)
        labels = np.frombuffer(raw_labels, dtype=self.labels_dtype)

        return inputs, labels

    def __getitem__(self, idx):
        """Return sample ``idx`` with both arrays cast to float32."""
        inputs, labels = self.get_inputs_labels(idx)
        as_float = np.float32
        return inputs.astype(as_float), labels.astype(as_float)

    def __len__(self):
        """Number of input records found in the database."""
        return self.length

In [13]:
# Shuffled, multi-worker DataLoader over the element-wise RocksDB copy of MNIST.
num_workers = 16
trainset_rocks = RocksDBDatasetDemo('mnist_rocks_batch.db/', num_workers=num_workers)
batch_size = 128
dataloader_rocks = torch.utils.data.DataLoader(
    trainset_rocks,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
    pin_memory=False,
)


In [15]:
# Time one full pass (epoch) over the RocksDB-backed DataLoader.
NSteps = len(dataloader_rocks)
tic = time()
for data in dataloader_rocks:
    img, label = data
Dt = time() - tic

print("\n".join([
    "Time per batch:{}sec".format(Dt/NSteps),
    "Batches per Second::{} - SERIAL".format(NSteps/Dt),
    "time for 1 epoch, BatchSteps{}, Dt::{}sec".format(NSteps,Dt),
]))

Time per batch:0.0012967566139677651sec
Batches per Second::771.1547326835984 - SERIAL
time for 1 epoch, BatchSteps468, Dt::0.6068820953369141sec


In [6]:
import sys
sys.path.append(r'../../')
from ssg2.data.rocksdbutils import RocksDBDataset

In [18]:
import torch

# Same loader settings as before, now reading the batch-written database
# through the repository's RocksDBDataset class.
num_workers = 16
trainset_rocks = RocksDBDataset('mnist_RDB_v2.db', num_workers=num_workers)
batch_size = 128
dataloader_rocks = torch.utils.data.DataLoader(
    trainset_rocks,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
    pin_memory=False,
)


In [82]:
## Note: the first 2 epochs are slower because the data is cached automatically, and this caching happens progressively.

In [10]:
# Time one full pass (epoch) over the RocksDBDataset-backed DataLoader.
NSteps = len(dataloader_rocks)
tic = time()
for data in dataloader_rocks:
    img, label = data
Dt = time() - tic

print(
    "Time per batch:{}sec".format(Dt/NSteps),
    "Batches per Second::{} - SERIAL".format(NSteps/Dt),
    "time for 1 epoch, BatchSteps{}, Dt::{}sec".format(NSteps,Dt),
    "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
    sep="\n",
)

Time per batch:0.0012125912894550553sec
Batches per Second::824.6801776461756 - SERIAL
time for 1 epoch, BatchSteps468, Dt::0.5674927234649658sec
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX


## Speedup over HDF5 ~ 1.35/0.56 = 2.4 times 

In [21]:
# Read speed-up of RocksDB over HDF5: approximate epoch times (in seconds) taken
# from the benchmark outputs recorded above, HDF5 divided by RocksDB.
1.35/0.56

2.4107142857142856