In [2]:
from __future__ import division, print_function, unicode_literals
import numpy as np
import math

In [3]:
def make_data():
    x = np.random.random(10).reshape(1, -1)
    y = np.zeros((1, 2))
    y[0, np.random.randint(2)] = 1
    return {'features': x, 'label': y}

In [4]:
data = sc.parallelize(range(10000)).map(lambda i: make_data())

In [5]:
data.persist()

PythonRDD[1] at RDD at PythonRDD.scala:48

In [6]:
batch_size = 128
data_cnt = data.count()
n_batch = data_cnt // batch_size

data_cnt, batch_size, n_batch

(10000, 128, 78)

# Random Split

In [7]:
batch_ratios = [batch_size / data_cnt] * int(math.ceil(data_cnt / batch_size))

In [8]:
def concate_data(x, y):
    dataset = {}
    for f in ['features', 'label']:
        dataset[f] = np.concatenate((x[f], y[f]), axis=0)
    return dataset


def next_batch_rdd_maker(data_rdd, batch_size, epoch_limit=None):
    data_cnt = data_rdd.count()
    per_batch_ratios = [batch_size / data_cnt] * int(math.ceil(data_cnt / batch_size))
    def next_batch():
        current_epoch = 0
        while True:
            current_epoch += 1
            print('{} epoch starts!'.format(current_epoch))
            if epoch_limit is not None and current_epoch > epoch_limit: 
                break 
            batches = data_rdd.randomSplit(per_batch_ratios)
            for batch in batches:
                dataset = batch.reduce(concate_data)
                yield dataset['features'], dataset['label']
    return next_batch

In [16]:
next_batch = next_batch_rdd_maker(data, batch_size, epoch_limit=2)

i = 0
for X_batch, Y_batch in next_batch():
    print(X_batch.shape, Y_batch.shape)
    print(X_batch[0:5])
    #print(Y_batch[0:5])
    i += 1
    print("run n_batch {}".format(i))
    if i > 2:
        break

1 epoch starts!
(122, 10) (122, 2)
[[ 0.43046535  0.99487139  0.25418304  0.54413216  0.87114291  0.57977683
   0.68006257  0.38922832  0.17447601  0.35889889]
 [ 0.10120172  0.71357518  0.42522073  0.57954529  0.56651004  0.89643308
   0.78255156  0.74294374  0.40264734  0.28018941]
 [ 0.9990186   0.21404378  0.23978797  0.55670715  0.41915767  0.86740202
   0.13936329  0.79298216  0.14045334  0.57676947]
 [ 0.20625397  0.0958593   0.93350915  0.69372181  0.4208059   0.9217431
   0.16450364  0.46395371  0.63320958  0.46280943]
 [ 0.79866558  0.09432513  0.01022666  0.71969341  0.08851809  0.42621472
   0.66410322  0.33755107  0.29852467  0.7638972 ]]
run n_batch 1
(132, 10) (132, 2)
[[ 0.34428665  0.1014612   0.16825857  0.88735713  0.04422784  0.09846397
   0.41227217  0.67991749  0.65634828  0.11652937]
 [ 0.83660833  0.62204563  0.82858429  0.3523604   0.99881155  0.26858747
   0.77184224  0.84774388  0.82233347  0.16062356]
 [ 0.18852273  0.7652111   0.3679878   0.46833854  0.0524

# Group

In [18]:
group_rdd = data \
    .sortBy(lambda x: np.random.random()) \
    .zipWithIndex() \
    .map(lambda x: (x[1]%n_batch, x[0])) \
    .groupByKey() \
    .map(lambda d: (d[0], reduce(concate_data, d[1]))) \
    .persist()

keys = group_rdd.keys().collect()
group_rdd.take(2)

i = 0
for k in keys:
    batch = group_rdd \
        .filter(lambda d: d[0] == k) \
        .map(lambda d: d[1]) \
        .collect()[0]
    print('k: {}'.format(k))
    print(batch['features'].shape)
    print(batch['features'][:5])
    i += 1
    if i > 3:
        break
    

k: 0
(129, 10)
[[ 0.40449649  0.64712576  0.95220635  0.26743756  0.4897337   0.65779447
   0.13467376  0.26498336  0.25519918  0.60751199]
 [ 0.85376512  0.36794005  0.60129744  0.67956998  0.36438509  0.32216485
   0.72992167  0.4002749   0.69590466  0.15238375]
 [ 0.90109728  0.02802473  0.83910834  0.44807331  0.48835972  0.39959256
   0.23281237  0.91168778  0.97544587  0.37338592]
 [ 0.34907741  0.70095545  0.00257921  0.9355607   0.28127654  0.06429361
   0.29068068  0.81295699  0.21773784  0.95107649]
 [ 0.04054203  0.51538163  0.34785669  0.85792816  0.80504133  0.96387436
   0.91655568  0.56169304  0.01087596  0.14014302]]
k: 32
(128, 10)
[[ 0.58077789  0.13659288  0.76546657  0.50202154  0.32030147  0.09888194
   0.26569836  0.1522073   0.75436084  0.25273888]
 [ 0.34200587  0.8968246   0.30968019  0.83710076  0.28225554  0.07718871
   0.61926568  0.22449407  0.29440704  0.71991896]
 [ 0.73213367  0.70505257  0.24336376  0.03463896  0.04672491  0.62215894
   0.19099494  0.41

# MapPartition

In [11]:
data_cnt, n_batch

(10000, 78)

In [12]:
batch_rdd = data \
    .coalesce(n_batch, shuffle=True) \
    .zipWithIndex() \
    .map(lambda x: (x[1]%n_batch , x[0])) \
    .groupByKey() \
    .map(lambda x: x[1])

In [13]:
batch_rdd.getNumPartitions()

78

In [14]:
batch_rdd.take(1)

[<pyspark.resultiterable.ResultIterable at 0x10767d750>]

In [15]:
def concat_data(it):
    dataset = None
    for d in it:
        if dataset is None:
            dataset = d
        else:
            for f in ['features', 'label']:
                dataset[f] = np.concatenate((dataset[f], d[f]), axis=0)
    yield dataset

for batch in batch_rdd.mapPartitions(concat_data).filter(lambda x: x is not None).collect():
    batch = list(batch) # process it!
    print(len(batch))
    break

129
