In [133]:
import os
import numpy as np
from multiprocessing import Pool
from time import time
from functools import partial
from matplotlib import pyplot as plt

MB = 1048576

In [139]:
inputfile = '/gpfs/work/greenber/5_625deg_all_zscored.npy'  # original data to be loaded with xarray
outputdir = '/gpfs/work/greenber/iotest/'
assert os.path.exists(inputfile)
assert(os.path.isdir(outputdir))

n_files = 16

target_read_size = 256 * MB  # bytes

rewrite = False  # replace existing files

tmax = 60  # test time in seconds

In [140]:
indata = np.load(inputfile, mmap_mode='r')
dtype = indata.dtype
n_points = indata.shape[0]
datapointsize = np.prod(indata.shape[1:])
datapointbytes = datapointsize * indata.dtype.itemsize

points_per_file = int(np.ceil(n_points / n_files))

testfiles = [os.path.join(outputdir, f'{i + 1}_of_{n_files}.dat') for i in range(n_files)]

points_per_read = np.maximum(1, int(np.round(target_read_size / datapointbytes)))

print(f'{n_points} data points, {datapointsize} values per data point, {indata.dtype.itemsize} bytes per value')
print(f'Dividing data into {n_files} files of up to {points_per_file} data points each')
del indata

350640 data points, 192512 values per data point, 4 bytes per value
Dividing data into 16 files of up to 21915 data points each


### Divide data into multiple files on disk, grouping together everything for the same time point

In [141]:
for i, outputfile in enumerate(testfiles):  # use multiprocessing here too?
    
    i_start, i_end = i * points_per_file, np.minimum((i + 1) * points_per_file, n_points)
    
    # skip existing files
    if os.path.exists(outputfile) and not rewrite:
        expected_bytes = (i_end - i_start) * datapointsize * dtype.itemsize
        actual_bytes = os.path.getsize(outputfile)
        if expected_bytes == actual_bytes:
            continue
    
    print(f'Creating file {outputfile}')
    indata = np.load(inputfile, mmap_mode='r')
    y = np.empty(dtype=indata.dtype, shape=(i_end - istart, *indata.shape[1:]))  # read data from disk
    y[:] = indata[i_start:i_end]
    y_out = np.memmap(outputfile, mode='w+', dtype=y.dtype, shape=y.shape)
    y_out[:] = y
    y_out.flush()
    # clear memory
    del indata, y, y_out
    print('done with file')
print('Data is ready.')

Data is ready.


### define infinite random permutation through each file's data points

In [142]:
def infinite_randperm(n):
    while True:
        ii = np.random.permutation(n)
        for idx in ii:
            yield idx

### define task of reading each file endlessly in random order

In [159]:
def readloop(file, datapointsize=None, dtype=np.float32, tmax=60, rngseed=None, read_sorted=True,
             mmap=True):
    assert file is not None and datapointsize is not None
    assert os.path.exists(file)
    
    n = os.path.getsize(file) / (datapointsize * dtype.itemsize)
    assert np.abs(n - np.round(n)) / n < 1e-6 and n > 0
    n = int(np.round(n))
    
    if rngseed is None:
        np.random.seed(int.from_bytes(os.urandom(4), byteorder='little'))
    else:
        np.random.seed(rngseed)
    
    t0 = time()
    
    x = np.empty((points_per_read, datapointsize), dtype=dtype)
    if mmap:
        y = np.memmap(file, mode='r', dtype=dtype, shape=(n, datapointsize))
    else:
        y = open(file, 'rb')
    
    t_read, bytes_read = [], []
    
    read_order = infinite_randperm(n)
    
    while True:
        
        t1 = time()
        
        ii = [next(read_order) for _ in range(points_per_read)]
        if read_sorted:
            ii = np.sort(ii)
        
        if mmap:
            x[:len(ii)] = y[ii]
        else:
            for u, offset in enumerate(ii * datapointsize * dtype.itemsize):
                y.seek(offset, 0)
                x[u] = np.frombuffer(y.read(datapointsize * dtype.itemsize), dtype=dtype)
        
        t2 = time()
        
        t_read.append(t2 - t1)
        bytes_read.append(len(ii) * datapointsize * dtype.itemsize)
        
        if t2 - t0 > tmax:
            break
    
    if not mmap:
        y.close()
    return np.array(t_read), np.array(bytes_read)
        
f = partial(readloop, datapointsize=datapointsize, dtype=dtype, tmax=60)
fr = partial(readloop, datapointsize=datapointsize, dtype=dtype, tmax=60, mmap=False)
fu = partial(readloop, datapointsize=datapointsize, dtype=dtype, tmax=60, read_sorted=False)

In [161]:
tr, br = f(testfiles[0])

print(f'{(np.sum(br)/np.sum(tr)) / MB} MB / s with 1 process')

1324.7933786262479 MB / s with 1 process


In [160]:
tr_r, br_r = fr(testfiles[0])

print(f'{(np.sum(br)/np.sum(tr)) / MB} MB / s with 1 process, no memmap')

1331.9092235947592 MB / s with 1 process, no memmap


In [146]:
pool = Pool(processes=n_files)

In [None]:
results = pool.map(f, testfiles)
print(f'{np.sum([np.sum(br)/np.sum(tr) for tr, br in results]) / MB} MB / s with {len(results)} processes')

In [147]:
results = pool.map(fr, testfiles)
print(f'{np.sum([np.sum(br)/np.sum(tr) for tr, br in results]) / MB} MB / s with {len(results)} processes, no memmap')

In [None]:
"""tr_indiv, br_indiv = [], []
for testfile in testfiles:
    tr, br = f(testfiles)
    tr_indiv.append(tr)
    br_indiv.append(br)"""