In [1]:
from __future__ import division, print_function, unicode_literals
import numpy as np
import math

In [2]:
def make_data():
    x = np.random.random(10).reshape(1, -1)
    y = np.zeros((1, 2))
    y[0, np.random.randint(2)] = 1
    return {'features': x, 'label': y}

In [3]:
data = sc.parallelize(range(10000)).map(lambda i: make_data())

In [4]:
data.persist()

PythonRDD[1] at RDD at PythonRDD.scala:48

In [9]:
batch_size = 128
data_cnt = data.count()
n_batch = data_cnt // batch_size

data_cnt, batch_size, n_batch

(10000, 128, 78)

# Random Split

In [7]:
batch_ratios = [batch_size / data_cnt] * int(math.ceil(data_cnt / batch_size))

In [8]:
def concate_data(x, y):
    dataset = {}
    for f in ['features', 'label']:
        dataset[f] = np.concatenate((x[f], y[f]), axis=0)
    return dataset


def next_batch_rdd_maker(data_rdd, batch_size, epoch_limit=None):
    data_cnt = data_rdd.count()
    per_batch_ratios = [batch_size / data_cnt] * int(math.ceil(data_cnt / batch_size))
    def next_batch():
        current_epoch = 0
        while True:
            current_epoch += 1
            print('{} epoch starts!'.format(current_epoch))
            if epoch_limit is not None and current_epoch > epoch_limit: 
                break 
            batches = data_rdd.randomSplit(per_batch_ratios)
            for batch in batches:
                dataset = batch.reduce(concate_data)
                yield dataset['features'], dataset['label']
    return next_batch

In [20]:
next_batch = next_batch_rdd_maker(data, batch_size, epoch_limit=2)

i = 0
for X_batch, Y_batch in next_batch():
    print(X_batch.shape, Y_batch.shape)
    print(X_batch[0:5])
    #print(Y_batch[0:5])
    i += 1
    print("run n_batch {}".format(i))

1 epoch starts!
(135, 10) (135, 2)
[[  8.28128172e-01   7.21331148e-01   6.63957775e-02   2.32426736e-01
    4.56041157e-01   8.09835432e-02   7.14077437e-01   8.08210628e-01
    9.86063871e-01   1.42992556e-01]
 [  1.55681195e-01   1.80606717e-01   7.62144989e-01   7.24782388e-01
    8.61966248e-01   5.50625453e-01   3.00610000e-01   1.82758683e-01
    7.76822457e-01   4.62146044e-01]
 [  5.38513563e-01   7.91854510e-01   3.25569815e-01   2.74219205e-01
    1.12040933e-01   8.72800603e-01   7.03616396e-01   6.40697977e-01
    9.67057467e-02   8.39538563e-01]
 [  9.08446212e-02   4.05568746e-01   6.85400665e-01   1.75869853e-01
    1.93149088e-01   5.93925187e-02   2.21232991e-04   7.83807873e-01
    3.67174324e-01   9.03314647e-01]
 [  4.24758321e-01   4.64913288e-01   6.33092228e-01   5.23019306e-01
    1.02766232e-01   7.57264995e-02   9.73231216e-01   1.42144926e-02
    8.92033913e-01   1.07064273e-01]]
run n_batch 1
(101, 10) (101, 2)
[[ 0.4436542   0.32137206  0.507971    0.68325

# Group

In [21]:
group_rdd = data \
    .sortBy(lambda x: np.random.random()) \
    .zipWithIndex() \
    .map(lambda x: (x[1]%n_batch, x[0])) \
    .groupByKey() \
    .map(lambda d: (d[0], reduce(concate_data, d[1]))) \
    .persist()

keys = group_rdd.keys().collect()
group_rdd.take(2)

for k in keys:
    batch = group_rdd \
        .filter(lambda d: d[0] == k) \
        .map(lambda d: d[1]) \
        .collect()[0]
    print('k: {}'.format(k))
    print(batch['features'].shape)
    print(batch['features'][:5])
    

k: 0
(129, 10)
[[ 0.87786529  0.04727796  0.62709721  0.83383163  0.15334956  0.19790839
   0.61024732  0.95576599  0.75069871  0.55432745]
 [ 0.62677412  0.70220903  0.27833413  0.40516508  0.12311422  0.23409317
   0.25449853  0.71733499  0.62502829  0.11113304]
 [ 0.69211195  0.54020922  0.72156288  0.59183229  0.8638497   0.34688601
   0.86409396  0.42955325  0.84519961  0.78410707]
 [ 0.97813002  0.62104713  0.51728035  0.61586434  0.82256776  0.24161209
   0.66169541  0.15684381  0.36445558  0.91647954]
 [ 0.40361483  0.62138179  0.97406546  0.28116676  0.61807476  0.70349917
   0.77001276  0.3856793   0.56875715  0.76643392]]
k: 32
(128, 10)
[[ 0.10563144  0.61677443  0.46455055  0.9867425   0.11994604  0.95391356
   0.86936882  0.98947822  0.71993566  0.30848526]
 [ 0.79063067  0.1122924   0.54537708  0.24566419  0.33251647  0.89661054
   0.88023691  0.24135247  0.7445243   0.335607  ]
 [ 0.12144652  0.59002106  0.97972167  0.24926965  0.93362859  0.26183522
   0.94944294  0.11

# MapPartition

In [22]:
data_cnt, n_batch

(10000, 78)

In [26]:
batch_rdd = data \
    .coalesce(n_batch, shuffle=True) \
    .zipWithIndex() \
    .map(lambda x: (x[1]%n_batch , x[0])) \
    .groupByKey() \
    .map(lambda x: x[1])

In [28]:
batch_rdd.getNumPartitions()

78

In [29]:
batch_rdd.take(1)

[<pyspark.resultiterable.ResultIterable at 0x10736f950>]

In [32]:
def concat_data(it):
    dataset = None
    for d in it:
        if dataset is None:
            dataset = d
        else:
            for f in ['features', 'label']:
                dataset[f] = np.concatenate((dataset[f], d[f]), axis=0)
    yield dataset

for batch in batch_rdd.mapPartitions(concat_data).filter(lambda x: x is not None).collect():
    print(list[batch])

TypeError: 'type' object has no attribute '__getitem__'