In [1]:
%matplotlib inline
from __future__ import print_function
import sys
import numpy as np
import matplotlib.pyplot as plt

from operator import itemgetter

In [2]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import Row

# Reuse an already-running SparkContext when there is one: constructing
# SparkContext(...) a second time in the same kernel raises an error, which
# made this cell fail whenever it was re-executed.
sc = SparkContext.getOrCreate(SparkConf().setMaster('local[*]'))
sc.setLogLevel("WARN")

# The builder attaches to the active context and returns the existing session
# if one was already created, so temp views survive a re-run of this cell.
spark = SparkSession.builder.getOrCreate()

In [3]:
# Show the physical plan Spark generates for a simple filter + aggregate.
(
    spark.range(1000)
    .where("id > 100")
    .selectExpr("sum(id)")
    .explain()
)

== Physical Plan ==
*HashAggregate(keys=[], functions=[sum(id#0L)])
+- Exchange SinglePartition
   +- *HashAggregate(keys=[], functions=[partial_sum(id#0L)])
      +- *Filter (id#0L > 100)
         +- *Range (0, 1000, step=1, splits=Some(2))


In [35]:
%%time
def make_measurements(input_id):
    """Simulate a random number of noisy measurements of one object.

    The number of measurements is drawn from a Poisson distribution with
    mean 15; each measured value is ``input_id`` plus standard-normal noise.

    Returns an iterable of ``(input_id, value)`` pairs, one per measurement.
    """
    count = np.random.poisson(15)
    noise = np.random.randn(count)
    values = (input_id + noise).tolist()
    # Pair every simulated value with the id of the object it belongs to.
    return zip([input_id] * count, values)
    
# Simulate measurements for 100000 objects. collect() brings every generated
# row back to the driver, so despite its name `rdd` is a plain Python list.
# Materializing the rows once also freezes the unseeded random draws; a lazy
# RDD could be recomputed (with different values) by each later query.
rdd = sc.range(100000).flatMap(make_measurements).collect()
meas_table = spark.createDataFrame(rdd, schema=("obj_id", "meas_value"))
# registerTempTable is deprecated since Spark 2.0; createOrReplaceTempView is
# its replacement and also makes re-running this cell safe.
meas_table.createOrReplaceTempView("meas_table")

CPU times: user 37.7 s, sys: 110 ms, total: 37.8 s
Wall time: 38.8 s


In [36]:
# Total number of simulated measurement rows across all objects.
meas_table.count()

1503414

In [37]:
%%time
# Per-object summary: number of measurements plus the smallest and largest
# measured value. spark.sql only builds the query; nothing is executed until
# an action is called on the result.
summary_table = spark.sql("SELECT obj_id, count(*) as n_epochs, "
                          "min(meas_value) as min_val, max(meas_value) as max_val "
                          "FROM meas_table GROUP BY obj_id")
# registerTempTable is deprecated since Spark 2.0; createOrReplaceTempView is
# its replacement and can be re-run without error.
summary_table.createOrReplaceTempView("summary_table")

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 36.1 ms


In [21]:
# Preview only a handful of rows: toPandas() on the full summary would pull one
# row per object (100000 with the current setup) onto the driver just to be
# truncated by the notebook display.
summary_table.limit(10).toPandas()

Unnamed: 0,obj_id,n_epochs,min_val,max_val
0,0,17,-0.961687,1.668768
1,7,23,5.064148,9.762524
2,6,15,5.193034,7.944696
3,9,12,6.324873,10.110992
4,5,17,4.038313,6.668768
5,1,15,0.193034,2.944696
6,3,9,1.954731,4.158932
7,8,9,6.954731,9.158932
8,2,23,0.064148,4.762524
9,4,12,1.324873,5.110992


In [38]:
%%time
# Select every measurement belonging to an object whose smallest measured
# value exceeds 10000, by joining the raw measurements against the per-object
# summary on obj_id.
targetObjects = spark.sql("SELECT summary_table.obj_id, meas_value FROM  meas_table "
                          "JOIN summary_table ON (summary_table.obj_id = meas_table.obj_id) "
                          "WHERE summary_table.min_val > 10000")
# count() is the action that actually triggers execution of the join.
print(targetObjects.count())

1353020
CPU times: user 20 ms, sys: 10 ms, total: 30 ms
Wall time: 4.38 s


In [44]:
# Inspect the physical plan of the join query. Reusing the DataFrame built
# earlier (instead of pasting the SQL text a second time) guarantees the plan
# shown here always belongs to the query that is actually executed.
targetObjects.explain()

== Physical Plan ==
*Project [obj_id#443L, meas_value#272]
+- *SortMergeJoin [obj_id#271L], [obj_id#443L], Inner
   :- *Sort [obj_id#271L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(obj_id#271L, 200)
   :     +- *Filter isnotnull(obj_id#271L)
   :        +- Scan ExistingRDD[obj_id#271L,meas_value#272]
   +- *Sort [obj_id#443L ASC NULLS FIRST], false, 0
      +- *Project [obj_id#443L]
         +- *Filter (isnotnull(min_val#287) && (min_val#287 > 10000.0))
            +- *HashAggregate(keys=[obj_id#443L], functions=[min(meas_value#444)])
               +- Exchange hashpartitioning(obj_id#443L, 200)
                  +- *HashAggregate(keys=[obj_id#443L], functions=[partial_min(meas_value#444)])
                     +- *Filter isnotnull(obj_id#443L)
                        +- Scan ExistingRDD[obj_id#443L,meas_value#444]


In [41]:
%%time
# Average measured value per selected object, collected to the driver as a
# list of Row objects.
y = (
    targetObjects
    .groupBy("obj_id")
    .avg("meas_value")
    .collect()
)

CPU times: user 370 ms, sys: 20 ms, total: 390 ms
Wall time: 4.76 s


In [42]:
# Number of distinct objects that passed the min_val cut (one Row per obj_id).
len(y)

89998