In [1]:
# findspark (github.com/minrk/findspark) locates SPARK_HOME and puts pyspark on sys.path
import findspark
findspark.init()

import pyspark
sc = pyspark.SparkContext(appName="Spark1")

### Create a large RDD and a small RDD, both indexed by some value.

In [2]:
# 1,000,000 (user_id, name) pairs in 20 partitions, cached
users = sc.parallelize([(uid, "user {}".format(uid))
                        for uid in range(1000000)], 20).cache()

In [3]:
# 10,000 events keyed by user ids 1, 101, 201, ..., 999901
events = sc.parallelize([(uid, "did something")
                         for uid in range(1, 1000000, 100)])

### Joining without co-partitioning causes a large shuffle
Check the Shuffle Read / Shuffle Write columns at http://localhost:4040/stages

In [4]:
users_join_events = users.join(events).cache()
print(users_join_events.count())  # Force a full computation
print(users_join_events.take(5))

10000
[(1, ('user 1', 'did something')), (524301, ('user 524301', 'did something')), (393401, ('user 393401', 'did something')), (939401, ('user 939401', 'did something')), (262501, ('user 262501', 'did something'))]


### Co-partitioning removes shuffle
Note, however, that co-partitioned RDDs are not necessarily co-located: matching partitions may still live on different executors.
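To see why co-partitioning removes the shuffle, here is a pure-Python sketch (no Spark) of a partition-wise join. It assumes both datasets are split with the same function, `hash(key) % n`, which is roughly what PySpark's default partitioning does for integer keys; `hash_partition`, `join_bucket` and `partitionwise_join` are hypothetical helpers. Because a given key maps to the same partition index on both sides, partition i of one dataset only ever needs partition i of the other, so no records have to be redistributed. The sketch says nothing about where partitions physically live, which is the co-location question above.

```python
from collections import defaultdict


def hash_partition(pairs, num_partitions):
    """Split (key, value) pairs into buckets by hash(key) % num_partitions."""
    buckets = [[] for _ in range(num_partitions)]
    for key, value in pairs:
        buckets[hash(key) % num_partitions].append((key, value))
    return buckets


def join_bucket(left, right):
    """Inner-join two lists of (key, value) pairs held in a single partition."""
    index = defaultdict(list)
    for key, value in right:
        index[key].append(value)
    return [(key, (lv, rv)) for key, lv in left for rv in index.get(key, [])]


def partitionwise_join(left_buckets, right_buckets):
    """Join co-partitioned data: bucket i is only ever joined with bucket i."""
    assert len(left_buckets) == len(right_buckets)
    result = []
    for lb, rb in zip(left_buckets, right_buckets):
        result.extend(join_bucket(lb, rb))
    return result


# Small stand-ins for the users and events RDDs.
user_pairs = [(uid, "user {}".format(uid)) for uid in range(1000)]
event_pairs = [(uid, "did something") for uid in range(1, 1000, 100)]

joined = partitionwise_join(hash_partition(user_pairs, 20),
                            hash_partition(event_pairs, 20))
print(len(joined))  # 10
```

The result is the same as joining the unpartitioned lists directly; the only difference is that each bucket pair can be processed independently.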

In [5]:
users20 = users.partitionBy(20).cache()
users20.foreach(lambda x: x)  # run an action so users20 is computed and cached now
events20 = events.partitionBy(20)

Check http://localhost:4040/stages again.

The shuffle cost for these counts should be much lower than for the unpartitioned join above (~70 KB vs. ~10 MB).

Note that you should see one 

In [6]:
joined20 = users20.join(events20)
joined20.count()  
joined20.count()

10000

### Confusion from class: rdd.take() doesn't cause a full recomputation.

I expected this to have 20 jobs (corresponding to the 20 partitions).

Note that events20 is not cached, so by default it should be recomputed. But we only see 4 jobs at http://localhost:4040/stages

Why?

In [7]:
joined20.take(10)

[(1, ('user 1', 'did something')),
 (524301, ('user 524301', 'did something')),
 (240301, ('user 240301', 'did something')),
 (131101, ('user 131101', 'did something')),
 (262201, ('user 262201', 'did something')),
 (393301, ('user 393301', 'did something')),
 (101, ('user 101', 'did something')),
 (524401, ('user 524401', 'did something')),
 (87401, ('user 87401', 'did something')),
 (131201, ('user 131201', 'did something'))]

It turns out you have to look at the Spark source to figure this out.

https://github.com/apache/spark/blob/master/python/pyspark/rdd.py#L1298

The answer: rdd.take() first scans a single partition, and if that doesn't yield enough rows, it runs further jobs over increasingly many partitions until it has enough.
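The scan-and-grow loop can be sketched in plain Python. This is a simplified re-implementation of the idea behind `RDD.take` in PySpark's rdd.py, not the real code: it assumes the growth rule found in recent versions (quadruple the number of partitions while nothing has been found, otherwise extrapolate from the yield so far with a 50% overestimate, capped at 4x), and the exact rule has varied between Spark versions. `simulate_take` is a hypothetical helper that takes per-partition record counts and returns the partition indices each job would scan.

```python
def simulate_take(partition_sizes, num):
    """Return the partition indices scanned by each job of take(num).

    partition_sizes[i] is the number of records in partition i.
    """
    total_parts = len(partition_sizes)
    found = 0          # records collected so far
    parts_scanned = 0  # partitions scanned so far
    jobs = []
    while found < num and parts_scanned < total_parts:
        num_parts_to_try = 1  # the first job scans a single partition
        if parts_scanned > 0:
            if found == 0:
                # Nothing found yet: scan four times as many partitions.
                num_parts_to_try = parts_scanned * 4
            else:
                # Extrapolate from the yield so far, overestimating by 50%,
                # and never scan more than 4x the partitions already scanned.
                num_parts_to_try = int(1.5 * num * parts_scanned / found) - parts_scanned
                num_parts_to_try = min(max(num_parts_to_try, 1), parts_scanned * 4)
        end = min(parts_scanned + num_parts_to_try, total_parts)
        scanned = list(range(parts_scanned, end))
        jobs.append(scanned)
        left = num - found
        # Each partition contributes at most `left` records to this job.
        found += sum(min(partition_sizes[p], left) for p in scanned)
        parts_scanned += num_parts_to_try
    return jobs


print(len(simulate_take([100] * 20, 10)))            # 1 job
print(len(simulate_take([0] * 5 + [100] * 15, 10)))  # 3 jobs: [0], [1..4], [5..19]
```

The takeaway: take() stays cheap only when the first partitions it looks at are non-empty; empty leading partitions force extra, progressively larger jobs.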

But this raises another question: why isn't there enough data in the first partition?

Let's see how data is distributed between partitions.

In [8]:
joined20.mapPartitions(lambda seq: [len(list(seq))]).collect()

[0, 10000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

### Suboptimal interaction between hash values and partition count.

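Event keys run from 1 to 999,901 in steps of 100.  
Hashes of integer keys are (usually) just the integer itself.  
The destination partition for key K is hash(K) % num_partitions.  

--> Our problem: every event key has the form 1 + 100*i, and 100 is a multiple of 20, so K % 20 is always 1. All 10,000 events, and therefore all joined rows, land in partition 1.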
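This is easy to confirm without Spark by computing hash(K) % n for every event key in plain Python. The sketch assumes, as above, that an integer key hashes to itself (true in CPython for small non-negative ints); `partition_histogram` is a hypothetical helper.

```python
from collections import Counter


def partition_histogram(keys, num_partitions):
    """Count how many keys land in each partition under hash(key) % n."""
    counts = Counter(hash(k) % num_partitions for k in keys)
    return [counts.get(p, 0) for p in range(num_partitions)]


# Same keys as the events RDD: 1, 101, 201, ..., 999901.
event_keys = range(1, 1000000, 100)

print(partition_histogram(event_keys, 20))
# [0, 10000, 0, 0, ...]: every key satisfies K % 20 == 1
```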

### Solution: use a prime number of partitions

In [9]:
users23 = users.partitionBy(23).cache()
users23.foreach(lambda x: x)  # run an action so users23 is computed and cached now
events23 = events.partitionBy(23)

In [10]:
joined23 = users23.join(events23)
joined23.count()  
joined23.count()

10000

Now let's check how balanced the partitions are.

In [11]:
joined23.mapPartitions(lambda seq: [len(list(seq))]).collect()

[434,
 435,
 435,
 435,
 435,
 435,
 435,
 434,
 434,
 435,
 435,
 435,
 435,
 435,
 435,
 434,
 434,
 435,
 435,
 435,
 435,
 435,
 435]
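Why does a prime help? For keys in an arithmetic progression with step s, K % n cycles through only n / gcd(s, n) distinct partitions. With s = 100, n = 20 gives gcd 20 and a single partition, while n = 23 gives gcd 1 and all 23. A prime n is coprime to every step that is not a multiple of n, but it is not a guarantee: keys spaced 23 apart would all collide again with 23 partitions. `partitions_reached` is a hypothetical helper; `gcd` is written out so the sketch runs on Python 2 and 3 alike.

```python
def gcd(a, b):
    """Euclid's algorithm (math.gcd does the same on Python 3.5+)."""
    while b:
        a, b = b, a % b
    return a


def partitions_reached(step, num_partitions):
    """Distinct values of key % num_partitions for keys start, start + step,
    start + 2*step, ... (given at least num_partitions keys, so the
    residues complete a full cycle)."""
    return num_partitions // gcd(step, num_partitions)


for n in (20, 23, 25, 32):
    print("{} partitions: step-100 keys reach {}".format(n, partitions_reached(100, n)))
# reaches 1, 23, 1 and 8 partitions respectively
```

Any partition count sharing a factor with the key spacing (20, 25, 32 here) wastes partitions; a prime count only fails when the spacing is a multiple of that prime.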