In [1]:
import numpy
import multiprocessing
import timeit
import h5py
import random
import time
# Report how many CPUs are visible; relevant when choosing `workers` for the
# process-pool helpers defined below.
print(multiprocessing.cpu_count())

4


In [2]:
def load_data(data_file, num_classes=10, image_shape=(28, 28, 1)):
    """Read images and labels from an HDF5 file into memory.

    Args:
        data_file: path to an HDF5 file containing datasets 'X' and 'Y'
            (presumably a MATLAB v7.3 .mat file, given the 1-based labels).
        num_classes: width of the one-hot label encoding (default 10).
        image_shape: shape each example of X is reshaped to
            (default (28, 28, 1)).

    Returns:
        dict with 'X' of shape (N,) + image_shape and 'Y' one-hot encoded
        with shape (N, num_classes).
    """
    # Open read-only and close deterministically. Older h5py defaulted to
    # mode 'a', which silently creates a missing file instead of raising.
    with h5py.File(data_file, 'r') as D:
        # [()] materializes each dataset as an in-memory array while the
        # file is still open.
        X = D['X'][()].transpose()
        raw_labels = D['Y'][()]
    # The -1 shift implies labels are stored 1-based; convert to 0-based
    # indices. ravel (not squeeze) keeps a 1-D array even when N == 1.
    Y_ = numpy.ravel(raw_labels - 1).astype(int)
    X = numpy.reshape(X, (X.shape[0],) + tuple(image_shape))
    Y = numpy.eye(num_classes)[Y_]
    return {'X': X, 'Y': Y}

# Real data would be loaded with: dataset = load_data('MNIST.mat')
# Synthetic stand-in for benchmarking: 1000 random 128x128x3 float32 images
# with integer labels (note: integer labels, not the one-hot Y of load_data).
dataset = {'X': numpy.random.rand(1000, 128, 128, 3).astype(numpy.float32),
           # randint's upper bound is exclusive, so 10 yields labels 0..9
           # (ten classes, matching the numpy.eye(10) width in load_data);
           # the previous bound of 9 could never produce label 9.
           'Y': numpy.random.randint(0, 10, size=1000)}

In [30]:
def random_flips(x, p=0.5):
    """Randomly mirror an image along its second axis.

    Args:
        x: array indexed as (H, W, C); axis 1 is reversed when flipping.
        p: probability of flipping (default 0.5).

    Returns:
        With probability `p` a reversed *view* of `x` (x[:, ::-1, :]),
        otherwise `x` itself. No copy is made in either case.
    """
    # Draws from the stdlib `random` generator (not numpy.random), so seeding
    # numpy alone does not make this reproducible.
    if random.random() < p:
        return x[:, ::-1, :]
    return x


def random_contrast(x, scale=(0.8, 1.2)):
    """Rescale `x` by a single factor drawn uniformly from `scale`.

    Args:
        x: array (or number) to rescale.
        scale: (low, high) bounds passed to numpy.random.uniform.

    Returns:
        x * factor, where factor ~ U(low, high).
    """
    lo, hi = scale
    factor = numpy.random.uniform(lo, hi)
    return x * factor


def random_tint(x, scale=(-10, 10)):
    """Add an independent random offset to each channel of `x`.

    One bias per channel is drawn from U(low, high) and broadcast over the
    last axis, so any channel-last array works ((H, W, C), (W, C), ...).

    Args:
        x: array whose last axis is the channel axis.
        scale: (low, high) bounds for the per-channel bias.

    Returns:
        A new array with the same shape and dtype as `x`. For integer dtypes
        the biased values are cast back to that dtype (fractional part
        dropped), as the previous per-channel assignment did.
    """
    low, high = scale
    bias = numpy.random.uniform(low, high, x.shape[-1])
    # Broadcasting replaces the per-channel Python loop; cast back so the
    # result keeps the input dtype.
    return (x + bias).astype(x.dtype, copy=False)

In [31]:
def pow2(x, delay=0.1):
    """Return x squared after sleeping `delay` seconds.

    The sleep presumably simulates an expensive per-example operation for
    the process-pool benchmark; pass delay=0 to skip it.

    Args:
        x: number or array.
        delay: seconds to sleep before computing (default 0.1).

    Returns:
        x ** 2
    """
    time.sleep(delay)
    return x ** 2

def pow3(x, delay=0.1):
    """Return x cubed after sleeping `delay` seconds.

    Previously returned x ** 2 (copy-paste from pow2), contradicting its name.
    The sleep presumably simulates an expensive per-example operation for
    the process-pool benchmark; pass delay=0 to skip it.

    Args:
        x: number or array.
        delay: seconds to sleep before computing (default 0.1).

    Returns:
        x ** 3
    """
    time.sleep(delay)
    return x ** 3

def pow4(x, delay=0.1):
    """Return x to the fourth power after sleeping `delay` seconds.

    Previously returned x ** 2 (copy-paste from pow2), contradicting its name.
    The sleep presumably simulates an expensive per-example operation for
    the process-pool benchmark; pass delay=0 to skip it.

    Args:
        x: number or array.
        delay: seconds to sleep before computing (default 0.1).

    Returns:
        x ** 4
    """
    time.sleep(delay)
    return x ** 4

def augment(x):
    """Apply one randomly chosen augmentation to each example of `x`, in place.

    For every index along axis 0, one of random_flips / random_contrast /
    random_tint is picked uniformly with the stdlib `random.choice`, and its
    result is written back into x[i].

    Args:
        x: array of examples indexed along the first axis; modified in place.

    Returns:
        The same object `x`.
    """
    transforms = [random_flips, random_contrast, random_tint]
    count = x.shape[0]
    for idx in range(count):
        transform = random.choice(transforms)
        x[idx] = transform(x[idx])
    return x

def get_batch(X, Y, batch_size):
    """Draw `batch_size` random examples (with replacement) and augment them.

    Args:
        X: inputs indexed along axis 0.
        Y: labels aligned with X along axis 0.
        batch_size: number of examples to draw.

    Returns:
        (x_batch, y_batch). For NumPy arrays, fancy indexing makes x_batch a
        copy, so augment's in-place writes do not touch X.
    """
    idx = numpy.random.choice(X.shape[0], batch_size)
    x_sel = X[idx]
    y_sel = Y[idx]
    return augment(x_sel), y_sel

# Smoke test: draw one augmented batch and show the array shapes.
x_batch, y_batch = get_batch(dataset['X'], dataset['Y'], batch_size=256)
print('%s %s' % (x_batch.shape, y_batch.shape))

(256, 128, 128, 3) (256,)


In [32]:
def augment_worker(x):
    """Apply one of pow2 / pow3 / pow4, chosen uniformly at random, to `x`.

    Executed inside pool worker processes by augment_parallel; the choice is
    made with the stdlib `random` module.
    """
    transform = random.choice((pow2, pow3, pow4))
    return transform(x)

def _reseed_worker():
    """Pool initializer: give each worker process independent random state."""
    # Under the fork start method every child inherits a copy of the parent's
    # `random` and `numpy.random` state; Python 2 does not reseed either, so
    # all workers would otherwise draw identical sequences.
    random.seed()
    numpy.random.seed()


def augment_parallel(x, workers):
    """Apply augment_worker to every example of `x` in a process pool.

    Args:
        x: iterable of examples (for an ndarray, iteration is over axis 0).
        workers: number of worker processes, passed to multiprocessing.Pool.

    Returns:
        list of per-example results, in the same order as `x`.
    """
    # A fresh pool is created per call; process start-up cost is paid on
    # every batch.
    pool = multiprocessing.Pool(processes=workers, initializer=_reseed_worker)
    try:
        return pool.map(augment_worker, list(x))
    finally:
        # Always reap the workers, including when map raises (the previous
        # close() without join() leaked them). This mirrors Pool's
        # context-manager exit on Python 3.
        pool.terminate()
        pool.join()

def get_batch_parallel(X, Y, batch_size, workers=4):
    """Draw `batch_size` random examples (with replacement); augment them in a pool.

    Args:
        X: inputs indexed along axis 0.
        Y: labels aligned with X along axis 0.
        batch_size: number of examples to draw.
        workers: number of worker processes handed to augment_parallel.

    Returns:
        (x_batch, y_batch), where x_batch is a new array built from the
        per-example results returned by the pool.
    """
    idx = numpy.random.choice(X.shape[0], batch_size)
    x_sel = X[idx]
    y_sel = Y[idx]
    augmented = augment_parallel(x_sel, workers)
    return numpy.array(augmented), y_sel

In [9]:
# Serial baseline: total seconds for 10 calls of get_batch.
print(timeit.timeit(stmt="get_batch(dataset['X'], dataset['Y'], batch_size=256)",
                    setup="from __main__ import get_batch, dataset",
                    number=10))

0.381237030029


In [10]:
# Process-pool variant: total seconds for 10 calls of get_batch_parallel.
print(timeit.timeit(stmt="get_batch_parallel(dataset['X'], dataset['Y'], batch_size=256)",
                    setup="from __main__ import get_batch_parallel, dataset",
                    number=10))

2.93027019501


In [33]:
# timeit.default_timer is the recommended interval clock; time.time() is a
# wall clock that can be coarse on some platforms and can jump.
start_time = timeit.default_timer()
get_batch(dataset['X'], dataset['Y'], batch_size=256)
print(timeit.default_timer() - start_time)

0.0624949932098


In [34]:
# timeit.default_timer is the recommended interval clock; time.time() is a
# wall clock that can be coarse on some platforms and can jump.
start_time = timeit.default_timer()
x, y = get_batch_parallel(dataset['X'], dataset['Y'], batch_size=256)
print(timeit.default_timer() - start_time)

6.76107311249
