In [1]:
from __future__ import division, print_function, unicode_literals

import math
from functools import reduce

import numpy as np

In [2]:
def make_data(n_features=10, n_classes=2):
    """Generate one random (features, one-hot label) record.

    Draws from NumPy's global random state, so results are only
    reproducible if the caller seeds ``np.random`` beforehand.

    Args:
        n_features: Number of feature columns in the generated row.
            Defaults to 10.
        n_classes: Number of classes in the one-hot label. Defaults to 2.

    Returns:
        dict with two entries:
            'features': float array of shape (1, n_features), values drawn
                uniformly from [0, 1).
            'label': float array of shape (1, n_classes) holding a single 1
                at a uniformly chosen class index and 0 elsewhere.
    """
    # Keep a leading axis of length 1 so records can be stacked along axis 0.
    x = np.random.random(n_features).reshape(1, -1)
    y = np.zeros((1, n_classes))
    y[0, np.random.randint(n_classes)] = 1
    return {'features': x, 'label': y}

In [3]:
# Build a 10000-element RDD in which every element is one freshly generated
# record; the integer coming from range() is only a placeholder and is ignored.
data = (
    sc.parallelize(range(10000))
    .map(lambda _: make_data())
)

In [4]:
# Mark the RDD for caching. NOTE(review): presumably this is so that later
# actions reuse the same generated records instead of re-running the random
# make_data() for every job -- confirm that this is the intent.
data.persist()

PythonRDD[1] at RDD at PythonRDD.scala:48

In [5]:
# Target number of records per batch.
batch_size = 128
# Total number of records in the dataset (this runs a Spark job).
data_cnt = data.count()
# Floor division: n_batch counts only the batches that can be filled completely.
n_batch = data_cnt // batch_size

(data_cnt, batch_size, n_batch)

(10000, 128, 78)

# Random Split

In [6]:
# One equal weight per batch; the batch count is rounded up so that a final,
# partially filled batch also gets a weight.
batch_ratios = (
    [batch_size / data_cnt]
    * int(math.ceil(data_cnt / batch_size))
)

In [7]:
def concate_data(x, y):
    """Stack two record dicts row-wise.

    Args:
        x: dict holding array values under the keys 'features' and 'label'.
        y: dict with the same two keys; its arrays must be concatenable
            with those of ``x`` along axis 0.

    Returns:
        A new dict with the same two keys, where each value is ``x``'s array
        followed by ``y``'s array along axis 0. Neither input is modified.
    """
    return {
        field: np.concatenate((x[field], y[field]), axis=0)
        for field in ('features', 'label')
    }


def next_batch_rdd_maker(data_rdd, batch_size, epoch_limit=None):
    """Build a generator function that yields random mini-batches from an RDD.

    At the start of every epoch the RDD is split with ``randomSplit`` into
    ``ceil(count / batch_size)`` equally weighted parts, and each part is
    merged into one batch. Because the split is random, the size of each
    batch only approximates ``batch_size``.

    Args:
        data_rdd: RDD whose elements are dicts with 'features' and 'label'
            arrays (the shape accepted by ``concate_data``).
        batch_size: Target number of records per batch; must be positive.
        epoch_limit: Number of epochs to run, or None to loop forever.

    Returns:
        A zero-argument function. Calling it returns a generator of
        ``(features, label)`` array pairs, one pair per non-empty batch.
        If the RDD is empty the generator yields nothing.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(
            'batch_size must be positive, got {!r}'.format(batch_size))

    data_cnt = data_rdd.count()
    if data_cnt:
        n_splits = int(math.ceil(data_cnt / batch_size))
        per_batch_ratios = [batch_size / data_cnt] * n_splits
    else:
        # Nothing to split; avoids dividing by zero below.
        per_batch_ratios = []

    def _merge(acc, record):
        # None is the fold's zero value: it stands for "nothing seen yet" and
        # lets an empty split (or empty partition) fold to None instead of
        # raising the way reduce() does on an empty RDD.
        if acc is None:
            return record
        if record is None:
            return acc
        return concate_data(acc, record)

    def next_batch():
        if data_cnt == 0:
            # Without this guard an unlimited epoch loop would spin forever
            # without ever yielding.
            return
        current_epoch = 0
        # Check the limit before announcing the epoch so no message is
        # printed for an epoch that never runs.
        while epoch_limit is None or current_epoch < epoch_limit:
            current_epoch += 1
            print('{} epoch starts!'.format(current_epoch))
            for batch in data_rdd.randomSplit(per_batch_ratios):
                dataset = batch.fold(None, _merge)
                if dataset is None:
                    # randomSplit may leave a split empty; skip it.
                    continue
                yield dataset['features'], dataset['label']
    return next_batch

In [8]:
# Run the random-split batch generator for two epochs and show, for every
# batch, its shapes and its first five feature rows.
next_batch = next_batch_rdd_maker(data, batch_size, epoch_limit=2)

i = 0  # stays 0 if the generator yields nothing
for i, (X_batch, Y_batch) in enumerate(next_batch(), 1):
    print(X_batch.shape, Y_batch.shape)
    print(X_batch[:5])
    # print(Y_batch[:5])  # uncomment to inspect the labels as well
    print("run n_batch {}".format(i))

1 epoch starts!
(127, 10) (127, 2)
[[ 0.25252696  0.41724904  0.44630857  0.2069601   0.03229878  0.06939463
   0.59074478  0.68617637  0.03219128  0.54117129]
 [ 0.84372339  0.52440099  0.95344453  0.71779087  0.44073932  0.88811669
   0.37496778  0.2806755   0.85170141  0.37348766]
 [ 0.76818315  0.26598804  0.43424022  0.98093593  0.32674888  0.44976776
   0.70247556  0.87963998  0.50071456  0.45026658]
 [ 0.30037785  0.76651934  0.46979486  0.60799379  0.77998812  0.42435623
   0.41930535  0.93376665  0.21532599  0.50425121]
 [ 0.36594622  0.03870143  0.72891181  0.94295222  0.38666163  0.30870084
   0.70895249  0.4504453   0.12489906  0.71528951]]
run n_batch 1
(132, 10) (132, 2)
[[ 0.85368349  0.57921794  0.76261841  0.69031531  0.36236678  0.70331238
   0.76396086  0.32481559  0.10455959  0.1023778 ]
 [ 0.72435051  0.12272952  0.49339522  0.6782928   0.44689924  0.30538643
   0.80680405  0.28872152  0.549231    0.07197832]
 [ 0.10449342  0.54509095  0.99371129  0.09765837  0.137

# Group

In [9]:
# Shuffle the records by sorting on a random key, number them, deal them
# round-robin into n_batch groups by index, and merge every group into a
# single batch dict.
group_rdd = (
    data
    .sortBy(lambda x: np.random.random())
    .zipWithIndex()
    .map(lambda x: (x[1] % n_batch, x[0]))
    .groupByKey()
    # mapValues (unlike map) keeps the key partitioner produced by
    # groupByKey, which lets lookup() below read only the partition that
    # owns the requested key.
    # `reduce` comes from functools (imported at the top of the notebook),
    # since it is not a builtin on Python 3.
    .mapValues(lambda records: reduce(concate_data, records))
    .persist()
)

keys = group_rdd.keys().collect()

for k in keys:
    # Each key maps to exactly one merged batch, so take the first match.
    batch = group_rdd.lookup(k)[0]
    print('k: {}'.format(k))
    print(batch['features'].shape)
    print(batch['features'][:5])
    

k: 0
(129, 10)
[[ 0.78501512  0.56017705  0.30462473  0.68040336  0.50387372  0.45689395
   0.80915636  0.20056722  0.39638871  0.55388146]
 [ 0.22252252  0.54097303  0.50630497  0.48272007  0.75244439  0.634032
   0.38018154  0.8786065   0.12309362  0.18657233]
 [ 0.72272993  0.86646173  0.19509717  0.23202074  0.80086955  0.16238841
   0.94578341  0.80434744  0.2021004   0.72598015]
 [ 0.99941824  0.63340058  0.88798617  0.36485692  0.23521533  0.17096113
   0.72482803  0.02034187  0.15953868  0.21014049]
 [ 0.3846662   0.96219071  0.64202337  0.76254764  0.11341212  0.27397736
   0.60575208  0.608381    0.81714754  0.66929304]]
k: 32
(128, 10)
[[ 0.64791992  0.09447862  0.90404576  0.56873209  0.50264014  0.07581645
   0.75699235  0.77422555  0.77174946  0.45099605]
 [ 0.82558558  0.02109823  0.26718176  0.82808172  0.75947788  0.05114998
   0.48831289  0.26689163  0.91471005  0.6690597 ]
 [ 0.56841459  0.30662869  0.07563535  0.62511602  0.56473858  0.73135859
   0.63044813  0.2998

# MapPartition

In [10]:
# Re-display the record count and the batch count used by the cells below.
data_cnt, n_batch

(10000, 78)

In [11]:
# Redistribute the records with a shuffling coalesce, number them, assign
# each one to batch (index modulo n_batch), group by that batch id, and keep
# only the grouped records (the batch id itself is dropped).
batch_rdd = (
    data
    .coalesce(n_batch, shuffle=True)
    .zipWithIndex()
    .map(lambda pair: (pair[1] % n_batch, pair[0]))
    .groupByKey()
    .values()
)

In [12]:
# Show how many partitions the grouped RDD ended up with.
batch_rdd.getNumPartitions()

78

In [13]:
# Peek at one element: after groupByKey + dropping the key, each element is
# the grouped iterable of records for one batch, not yet a merged dict.
batch_rdd.take(1)

[<pyspark.resultiterable.ResultIterable at 0x107424a50>]

In [18]:
def concat_data(it):
    """Merge all records of one partition into a single batch.

    Intended for use with ``mapPartitions``.

    Args:
        it: Iterator over the elements of one partition. When it holds two
            or more elements, each must be a dict with 'features' and
            'label' arrays that can be concatenated along axis 0.

    Yields:
        Exactly one value:
          * None if the partition is empty (callers filter these out);
          * the sole element itself, unchanged and uncopied, if the partition
            holds exactly one element;
          * otherwise a new dict based on the first record, with 'features'
            and 'label' replaced by the row-wise concatenation over all
            records. The input records are not modified.

    NOTE(review): in this notebook the elements of ``batch_rdd`` look like
    grouped iterables rather than dicts (see the ``take(1)`` output), one per
    partition, so the single-element pass-through is the path actually taken
    -- confirm whether merging dicts was the intent.
    """
    records = list(it)
    if not records:
        yield None
    elif len(records) == 1:
        yield records[0]
    else:
        # Shallow-copy the first record so the caller's dict is left intact,
        # then concatenate each field once instead of once per record.
        merged = dict(records[0])
        for f in ['features', 'label']:
            merged[f] = np.concatenate([d[f] for d in records], axis=0)
        yield merged

# Merge each partition, drop the placeholders produced by empty partitions,
# pull the results to the driver, and inspect only the first batch.
for batch in (
    batch_rdd
    .mapPartitions(concat_data)
    .filter(lambda part: part is not None)
    .collect()
):
    batch = list(batch)  # process it!
    print(len(batch))
    break

129
