In [20]:
from __future__ import division, print_function, unicode_literals
import math
from functools import reduce

import numpy as np

In [21]:
def make_data():
    """Create one synthetic record.

    Returns a dict with:
      'features': a (1, 10) array of uniform [0, 1) values,
      'labels':   a (1, 2) one-hot array whose hot column is drawn at random.
    """
    features = np.random.random(10)[np.newaxis, :]
    labels = np.zeros((1, 2))
    hot_class = np.random.randint(2)
    labels[0, hot_class] = 1
    return {'features': features, 'labels': labels}

In [22]:
# 10000 independent synthetic records; the parallelized index itself is ignored.
data = sc.parallelize(range(10000)) \
    .map(lambda _idx: make_data())

In [23]:
# Cache the generated records. make_data() draws from np.random, so if Spark had to
# recompute this map for a later action (count, randomSplit, sortBy, ...) it would
# produce a different dataset; caching keeps it stable (evicted partitions can
# still be recomputed).
data.persist()

PythonRDD[38] at RDD at PythonRDD.scala:48

In [24]:
batch_size = 128
data_cnt = data.count()
# Number of full batches; the leftover data_cnt % batch_size rows are spread over
# the buckets by the modulo bucketing in later cells. Clamp to 1 so `% n_batch`
# never divides by zero when the dataset is smaller than one batch.
n_batch = max(1, data_cnt // batch_size)

data_cnt, batch_size, n_batch

(10000, 128, 78)

# Random Split

In [6]:
# One equal weight per (possibly partial) batch, for RDD.randomSplit.
# NOTE(review): not referenced by the visible cells; next_batch_rdd_maker computes
# its own per_batch_ratios.
batch_ratios = int(math.ceil(data_cnt / batch_size)) * [batch_size / data_cnt]

In [39]:
def concate_data(x, y):
    """Stack two records row-wise (axis 0).

    Returns a new dict holding only 'features' and 'labels', each the
    concatenation of x's array followed by y's. Neither input is modified.
    """
    return {field: np.concatenate((x[field], y[field]), axis=0)
            for field in ('features', 'labels')}

def next_batch_rdd_maker(data_rdd, batch_size, epoch_limit=None, seed=None):
    """Build a generator factory that yields random mini-batches from an RDD.

    Each epoch splits ``data_rdd`` with ``randomSplit`` into
    ``ceil(count / batch_size)`` equally weighted parts and yields one batch per
    non-empty part.

    Args:
        data_rdd: RDD of dicts whose 'features' and 'labels' arrays are
            concatenated along axis 0 to form a batch.
        batch_size: target records per batch (must be > 0). Split sizes are
            random, so actual batch sizes vary around this value.
        epoch_limit: number of epochs to run; ``None`` loops forever.
        seed: optional base seed for ``randomSplit``. Epoch ``e`` (1-based) uses
            ``seed + e``, so epochs differ but a run is reproducible.

    Returns:
        A zero-argument function returning a generator of
        ``(features, labels)`` array pairs.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))
    data_cnt = data_rdd.count()
    if data_cnt == 0:
        per_batch_ratios = []
    else:
        per_batch_ratios = [batch_size / data_cnt] * int(math.ceil(data_cnt / batch_size))

    def next_batch():
        if not per_batch_ratios:
            # Empty RDD: nothing to yield. Returning also avoids an endless,
            # batch-less loop when epoch_limit is None.
            return
        current_epoch = 0
        # Check the limit before announcing an epoch, so no message is printed
        # for an epoch that never runs.
        while epoch_limit is None or current_epoch < epoch_limit:
            current_epoch += 1
            print('{} epoch starts!'.format(current_epoch))
            epoch_seed = None if seed is None else seed + current_epoch
            for batch in data_rdd.randomSplit(per_batch_ratios, seed=epoch_seed):
                rows = batch.collect()
                if not rows:
                    # A random split can come out empty; reducing it would raise.
                    continue
                # Concatenate once per field (linear) rather than pairwise,
                # which re-copies the growing arrays on every step.
                yield (np.concatenate([r['features'] for r in rows], axis=0),
                       np.concatenate([r['labels'] for r in rows], axis=0))
    return next_batch

In [8]:
next_batch = next_batch_rdd_maker(data, batch_size, epoch_limit=2)

# Smoke test: pull the first three batches and show their shapes and a few rows.
i = 0
for i, (X_batch, Y_batch) in enumerate(next_batch(), start=1):
    print(X_batch.shape, Y_batch.shape)
    print(X_batch[0:5])
    print("run n_batch {}".format(i))
    if i > 2:
        break

1 epoch starts!
(138, 10) (138, 2)
[[ 0.27177158  0.09338174  0.92922062  0.36163314  0.73897763  0.92920518
   0.38691253  0.72721468  0.70133953  0.69426265]
 [ 0.08568551  0.73242113  0.05241102  0.52411065  0.65283585  0.63729227
   0.52427596  0.17664183  0.87864636  0.12582561]
 [ 0.07634491  0.65796259  0.95377654  0.58873708  0.9534006   0.80688632
   0.6015825   0.00614666  0.88245379  0.34886694]
 [ 0.25666931  0.72250646  0.04606842  0.40685677  0.24097207  0.49946615
   0.25957772  0.60548888  0.36035151  0.86844128]
 [ 0.2342444   0.78365248  0.17697823  0.41353043  0.2105957   0.65231435
   0.28149155  0.62406179  0.65004654  0.09867522]]
run n_batch 1
(121, 10) (121, 2)
[[ 0.66221035  0.61361795  0.40311533  0.05496737  0.43807495  0.26238466
   0.38744381  0.95097486  0.95174443  0.54242663]
 [ 0.35378197  0.92210629  0.60737942  0.38028022  0.18423603  0.36786244
   0.4846854   0.22169417  0.91627218  0.08162744]
 [ 0.62853929  0.28698232  0.93687132  0.44806423  0.962

# Group

In [9]:
# Shuffle by a random sort key, then deal rows round-robin into n_batch buckets
# (the data_cnt % n_batch leftover rows make some buckets one row larger).
# mapValues, unlike map, keeps groupByKey's partitioner, so lookup() below only
# scans the single partition that owns each key instead of the whole RDD.
group_rdd = data \
    .sortBy(lambda x: np.random.random()) \
    .zipWithIndex() \
    .map(lambda x: (x[1] % n_batch, x[0])) \
    .groupByKey() \
    .mapValues(lambda rows: reduce(concate_data, rows)) \
    .persist()

keys = group_rdd.keys().collect()

# Inspect the first four batches only.
for i, k in enumerate(keys, 1):
    batch = group_rdd.lookup(k)[0]
    print('k: {}'.format(k))
    print(batch['features'].shape)
    print(batch['features'][:5])
    if i > 3:
        break

# Drop the cache: later cells rebind group_rdd, which would otherwise leave this
# copy pinned in memory.
group_rdd.unpersist();
    

k: 0
(129, 10)
[[ 0.48635838  0.98921333  0.91002174  0.10785167  0.52518158  0.67597269
   0.43819769  0.29130476  0.20639989  0.43597756]
 [ 0.18104039  0.5238487   0.9512993   0.32673712  0.15810872  0.58777654
   0.27607022  0.24480085  0.11503928  0.26171353]
 [ 0.95290001  0.50672181  0.83296424  0.39207746  0.00743004  0.4126532
   0.98544884  0.26875377  0.20756379  0.50919635]
 [ 0.66578263  0.08125266  0.11051948  0.87752661  0.19079544  0.31896252
   0.87644394  0.76199772  0.85945378  0.87912064]
 [ 0.77659382  0.18806926  0.06847789  0.82411926  0.54146513  0.75857833
   0.85752319  0.53143017  0.03749443  0.56385056]]
k: 32
(128, 10)
[[ 0.49769865  0.35492167  0.73923198  0.54776133  0.53627974  0.7961213
   0.7117204   0.8003712   0.05167417  0.4436853 ]
 [ 0.64673222  0.89587558  0.14287286  0.30088512  0.93246049  0.83798712
   0.04624728  0.35743467  0.67240937  0.23058261]
 [ 0.47284257  0.22737204  0.6183083   0.83636114  0.8050909   0.04081084
   0.70676419  0.4559

# Aggregate by key

In [36]:
# Zero value for aggregateByKey below: one list of per-record arrays per field.
# NOTE(review): seq() appends to these lists in place, so this relies on
# aggregateByKey copying the zero value for each key — confirm for the Spark
# version in use.
init_data = {'features': [], 'labels': []}

def seq(data_arr, data):
    """aggregateByKey sequence function: append one record to an accumulator.

    Pushes ``data['features']`` and ``data['labels']`` onto the matching lists
    of ``data_arr``. Mutates and returns ``data_arr`` itself.
    """
    data_arr['features'].append(data['features'])
    data_arr['labels'].append(data['labels'])
    return data_arr

def comp(data_arr1, data_arr2):
    """aggregateByKey combine function: merge two accumulators.

    Returns a new dict whose 'features'/'labels' lists are data_arr1's entries
    followed by data_arr2's. Neither input is modified.
    """
    return {field: data_arr1[field] + data_arr2[field]
            for field in ('features', 'labels')}

def concate(data_arr):
    """Collapse an accumulator of per-record arrays into one batch.

    Each field's list of arrays is concatenated row-wise (axis 0); returns a new
    dict with 'features' and 'labels'.
    """
    return {field: np.concatenate(data_arr[field], axis=0)
            for field in ('features', 'labels')}

In [43]:
for epoch in range(2):
    print('===epoch {}==='.format(epoch+1))

    # Fresh random order each epoch, bucketed round-robin into n_batch batches.
    group_rdd = data \
        .sortBy(lambda x: np.random.random()) \
        .zipWithIndex() \
        .map(lambda x: (x[1]%n_batch, x[0])) \
        .aggregateByKey(init_data, seq, comp) \
        .mapValues(concate) \
        .persist()

    try:
        # Inspect the first two batches of the epoch.
        for i, k in enumerate(group_rdd.keys().collect(), 1):
            # mapValues keeps aggregateByKey's partitioner, so lookup() only
            # scans the partition that owns k rather than filtering all of them.
            batch = group_rdd.lookup(k)[0]
            print('\nbatch_key: {}'.format(k))
            print('--features--')
            print(batch['features'].shape)
            print(batch['features'][:2])
            print('--labels--')
            print(batch['labels'].shape)
            print(batch['labels'][:2])
            if i >= 2:
                break
    finally:
        # Release this epoch's cache even if inspection raises.
        group_rdd.unpersist()

===epoch 1===

batch_key: 0
--features--
(129, 10)
[[ 0.00990009  0.5423004   0.49345573  0.99874654  0.37192966  0.61709061
   0.75635913  0.2223606   0.19015739  0.42220939]
 [ 0.63431151  0.61970168  0.83143894  0.23898155  0.67763027  0.03630547
   0.09853002  0.07944521  0.99443968  0.85384792]]
--labels--
(129, 2)
[[ 1.  0.]
 [ 0.  1.]]

batch_key: 32
--features--
(128, 10)
[[ 0.69581008  0.29972252  0.71522633  0.64813835  0.22056874  0.53705215
   0.40476738  0.02412487  0.45648007  0.19110131]
 [ 0.40190421  0.99927003  0.28505162  0.59169471  0.32652689  0.79453915
   0.06946079  0.46745301  0.89838656  0.76408965]]
--labels--
(128, 2)
[[ 1.  0.]
 [ 1.  0.]]
===epoch 2===

batch_key: 0
--features--
(129, 10)
[[ 0.40414736  0.51240878  0.25732977  0.34385521  0.72576452  0.90960537
   0.72805053  0.50107734  0.23724009  0.35689232]
 [ 0.52275893  0.14185051  0.40293815  0.4839048   0.47405422  0.23736155
   0.75616067  0.01124846  0.11510762  0.61535321]]
--labels--
(129, 2)
[

# mapPartitions

In [10]:
# Sanity check: total record count and the number of full batches (n_batch).
data_cnt, n_batch

(10000, 78)

In [11]:
# Bucket rows round-robin into n_batch groups, one group per partition, for the
# mapPartitions step below. numPartitions is passed explicitly so the partition
# count is n_batch regardless of spark.default.parallelism; with integer keys
# 0..n_batch-1 hashed modulo n_batch, each key lands in its own partition.
batch_rdd = data \
    .coalesce(n_batch, shuffle=True) \
    .zipWithIndex() \
    .map(lambda x: (x[1] % n_batch, x[0])) \
    .groupByKey(numPartitions=n_batch) \
    .values()

In [12]:
# Should equal n_batch: concat_data below yields one batch per partition.
batch_rdd.getNumPartitions()

78

In [13]:
# Peek at one element: each is a ResultIterable holding one bucket's raw records
# (not yet concatenated into arrays).
batch_rdd.take(1)

[<pyspark.resultiterable.ResultIterable at 0x1075418d0>]

In [14]:
def concat_data(it):
    """mapPartitions helper: merge all elements of one partition into one value.

    Yields exactly one value per partition:
      * ``None`` for an empty partition (the caller filters these out),
      * the element itself when the partition holds a single element,
      * otherwise a new dict, copied from the first record, whose 'features'
        and 'labels' arrays are concatenated row-wise across all records.
    Input records are never modified.

    NOTE(review): in this notebook batch_rdd's elements are ResultIterable
    groups, one per partition, so the single-element branch is the one taken.
    """
    records = list(it)
    if not records:
        yield None
        return
    if len(records) == 1:
        yield records[0]
        return
    merged = dict(records[0])  # shallow copy so the first input record is not mutated
    # Field names must match the record keys produced by make_data ('labels').
    for f in ['features', 'labels']:
        # One concatenation per field instead of re-copying a growing array.
        merged[f] = np.concatenate([r[f] for r in records], axis=0)
    yield merged

# take(1) evaluates only as many partitions as needed; collect() would pull every
# batch to the driver just to look at the first one.
for batch in batch_rdd.mapPartitions(concat_data).filter(lambda x: x is not None).take(1):
    batch = list(batch)  # materialize the batch's records
    print(len(batch))

129
