In [1]:
import os

# Maven coordinate of the Spark-Cassandra connector pulled in via --packages.
# This has to be set before the SparkSession in the next cell launches the JVM.
CASSANDRA_CONNECTOR_PACKAGE = 'com.datastax.spark:spark-cassandra-connector_2.11:2.0.1'
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages {} pyspark-shell'.format(CASSANDRA_CONNECTOR_PACKAGE)

In [2]:
from pyspark.sql import SparkSession

# Session against the spark:// master, with the Cassandra connector pointed at node1.
spark = (
    SparkSession.builder
    .appName("workshop-analytics")
    .master("spark://sparkmaster:7077")
    .config("spark.cassandra.connection.host", "node1")
    .getOrCreate()
)

In [3]:
# DataFrame over the Cassandra table energydata.generation.
ct = (
    spark.read
    .format("org.apache.spark.sql.cassandra")
    .option("keyspace", "energydata")
    .option("table", "generation")
    .load()
)

In [4]:
# Physical plan only: a plain scan of the Cassandra relation with no filters.
ct.explain(extended=False)

== Physical Plan ==
*Scan org.apache.spark.sql.cassandra.CassandraSourceRelation@10b23498 [region#0,type#1,ts#2,value#3] ReadSchema: struct<region:string,type:string,ts:timestamp,value:double>


In [5]:
# Filter on the timestamp column and print the physical plan, to check whether
# the predicate shows up under PushedFilters for the Cassandra scan.
ct_before_2012 = ct.filter("ts < cast('2012-01-11' as timestamp)")
ct_before_2012.explain()

Grouped aggregates: monthly solar generation for region DE

In [84]:
# Namespaced import: a star import here would shadow the builtins max/sum/round
# for every later cell.
from pyspark.sql import functions as F

# The generation table holds one MW reading per 15 minutes, so each reading is
# worth MW / 4 MWh; dividing further by 1000 converts MWh to GWh.
READINGS_PER_HOUR = 4
MWH_PER_GWH = 10**3

# Monthly peak (MW) and total energy (GWh, rounded) of solar generation in DE,
# largest monthly total first.
ct_agg = (
    ct
    .withColumn('year', F.year(ct.ts))
    .withColumn('month', F.month(ct.ts))
    .filter("type == 'solar' AND region == 'DE'")
    .groupBy('type', 'region', 'year', 'month')
    .agg(
        F.max("value").alias("max_generation_MW"),
        F.sum(F.col("value") / (READINGS_PER_HOUR * MWH_PER_GWH)).alias("sum_generation_GWh"),
    )
    .withColumn('sum_generation_GWh', F.round('sum_generation_GWh', 0))
    .sort(F.desc('sum_generation_GWh'))
)

In [37]:
# NOTE: the plan saved below was produced before ct_agg was redefined (this ran
# as In[37], the current definition as In[84]). It groups without `month` and
# shows an avg/`sum_generation` aggregate. Re-run after the aggregation cell to refresh it.
ct_agg.explain(extended=False)

== Physical Plan ==
*Sort [sum_generation#214 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(sum_generation#214 DESC NULLS LAST, 200)
   +- *HashAggregate(keys=[type#1, region#0, year#186], functions=[avg(value#3), max(value#3), sum(value#3)])
      +- Exchange hashpartitioning(type#1, region#0, year#186, 200)
         +- *HashAggregate(keys=[type#1, region#0, year#186], functions=[partial_avg(value#3), partial_max(value#3), partial_sum(value#3)])
            +- *Project [region#0, type#1, value#3, year(cast(ts#2 as date)) AS year#186]
               +- *Filter (isnotnull(type#1) && isnotnull(region#0))
                  +- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation@10b23498 [region#0,type#1,value#3,ts#2] PushedFilters: [IsNotNull(type), IsNotNull(region), *EqualTo(type,solar), *EqualTo(region,DE)], ReadSchema: struct<region:string,type:string,value:double,year:int>


In [85]:
#ct_agg.explain()
%time ct_agg.show()

+-----+------+----+-----+-----------------+------------------+
| type|region|year|month|max_generation_MW|sum_generation_GWh|
+-----+------+----+-----+-----------------+------------------+
|solar|    DE|2013|    7|          23998.0|            5129.0|
|solar|    DE|2016|    7|          25688.0|            4943.0|
|solar|    DE|2015|    7|          24731.0|            4918.0|
|solar|    DE|2014|    6|          24244.0|            4834.0|
|solar|    DE|2016|    6|          26201.0|            4767.0|
|solar|    DE|2016|    8|          25371.0|            4720.0|
|solar|    DE|2016|    5|          26252.0|            4717.0|
|solar|    DE|2015|    8|          24429.0|            4613.0|
|solar|    DE|2015|    6|          24847.0|            4553.0|
|solar|    DE|2015|    4|          25928.0|            4435.0|
|solar|    DE|2014|    7|          23624.0|            4417.0|
|solar|    DE|2015|    5|          22453.0|            4412.0|
|solar|    DE|2013|    6|          23203.0|            

In [11]:
def load_energydata_table(table, keyspace="energydata"):
    """Return a DataFrame over a Cassandra table via the spark-cassandra-connector.

    Args:
        table: Cassandra table name.
        keyspace: Cassandra keyspace (defaults to "energydata").
    """
    return (
        spark.read
        .format("org.apache.spark.sql.cassandra")
        .options(table=table, keyspace=keyspace)
        .load()
    )


ct_station = load_energydata_table("weather_station")
ct_sensor = load_energydata_table("weather_sensor")

In [30]:
# Stations with lat < 50 and lon > 10 (presumably degrees — confirm against the schema).
station_subset = ct_station.filter((ct_station.lat < 50) & (ct_station.lon > 10))

# Readings from sensor 'h2' only.
sensor_subset = ct_sensor.filter(ct_sensor.sensor == 'h2')

In [31]:
# Number of 'h2' sensor rows (runs a Spark job).
h2_row_count = sensor_subset.count()
h2_row_count

2248704

In [36]:
# Number of stations that pass the lat/lon filter (runs a Spark job).
station_row_count = station_subset.count()
station_row_count

40

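Joins through the DataFrame/Dataset API are limited: they are not pushed down to Cassandra. Both tables are scanned and then shuffled and joined inside Spark (see the SortMergeJoin plan below).

To push a join down to Cassandra, you would need the RDD API (Scala only) with `rdd.joinWithCassandraTable()`.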

In [35]:
from pyspark.sql.functions import *

%time station_subset.join(sensor_subset, station_subset.id == sensor_subset.id).count()

CPU times: user 12 ms, sys: 4 ms, total: 16 ms
Wall time: 44.7 s


351360

In [38]:
# Plan of the same join: both Cassandra scans are hash-partitioned, sorted and
# sort-merge joined in Spark rather than pushed down.
id_match = station_subset.id == sensor_subset.id
station_subset.join(sensor_subset, id_match).explain()

== Physical Plan ==
*SortMergeJoin [id#247], [id#254], Inner
:- *Sort [id#247 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(id#247, 200)
:     +- *Filter ((((isnotnull(lon#249) && isnotnull(lat#248)) && isnotnull(id#247)) && (lat#248 < 50.0)) && (lon#249 > 10.0))
:        +- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation@7a27bb71 [id#247,lat#248,lon#249] PushedFilters: [IsNotNull(lon), IsNotNull(lat), IsNotNull(id), LessThan(lat,50.0), GreaterThan(lon,10.0)], ReadSchema: struct<id:string,lat:double,lon:double>
+- *Sort [id#254 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(id#254, 200)
      +- *Filter ((isnotnull(sensor#255) && (sensor#255 = h2)) && isnotnull(id#254))
         +- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation@d32519b [id#254,sensor#255,ts#256,value#257] PushedFilters: [IsNotNull(sensor), EqualTo(sensor,h2), IsNotNull(id)], ReadSchema: struct<id:string,sensor:string,ts:timestamp,value:double>


In [65]:
# NOTE(review): where 1.11549526e8 comes from is not shown in this notebook. The /4
# resembles the 15-minute-reading conversion used in the aggregation cell — confirm intent.
1.11549526e8 / 4

27887381.5