### Window Functions

In [68]:
import findspark
findspark.init('/home/rich/spark/spark-2.4.3-bin-hadoop2.7')
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql import Row
import pandas as pd

In [69]:
# Entry point for the DataFrame reads and SQL queries used throughout this notebook
spark = SparkSession.builder.appName('SparkSQL').getOrCreate()

### Create a SQL table from a dataframe

A dataframe can be used to create a temporary table. A temporary table is one that will not exist after the session ends. Spark documentation also refers to this type of table as a SQL temporary view. 

In [70]:
# Read the train schedule, taking the column names from the file's first line
df = spark.read.option("header", True).csv("./data/trainsched.txt")

# Make the dataframe queryable from SQL under the name table1
df.createOrReplaceTempView('table1')

### Query a table

In [71]:
# No .show() here, so the cell displays only the DataFrame's schema summary, not its rows
spark.sql('select * from table1')

DataFrame[train_id: string, station: string, time: string]

In [72]:
# Keep only the rows whose station is 'San Jose' and print them
spark.sql("select * from table1 where station = 'San Jose'").show()

+--------+--------+-----+
|train_id| station| time|
+--------+--------+-----+
|     324|San Jose|9:05a|
|     217|San Jose|6:59a|
+--------+--------+-----+



In [73]:
# List every column of table1 with its data type and comment
spark.sql("describe table1").show()

+--------+---------+-------+
|col_name|data_type|comment|
+--------+---------+-------+
|train_id|   string|   null|
| station|   string|   null|
|    time|   string|   null|
+--------+---------+-------+



### Running sums using window function SQL

A window function is like an aggregate function, except that it gives an output for every row in the dataset instead of a single row per group.

You can do aggregation along with window functions. A running sum using a window function is simpler than what is required using joins. The query duration can also be much faster. 

In [74]:
# Load the schedule file that also carries a diff_min column, then print it
df = spark.read.option("header", True).csv("./data/sched1.txt")
df.show()

+--------+-------------+-----+--------+
|train_id|      station| time|diff_min|
+--------+-------------+-----+--------+
|     217|       Gilroy|6:06a|     9.0|
|     217|   San Martin|6:15a|     6.0|
|     217|  Morgan Hill|6:21a|    15.0|
|     217| Blossom Hill|6:36a|     6.0|
|     217|      Capitol|6:42a|     8.0|
|     217|       Tamien|6:50a|     9.0|
|     217|     San Jose|6:59a|    null|
|     324|San Francisco|7:59a|     4.0|
|     324|  22nd Street|8:03a|    13.0|
|     324|     Millbrae|8:16a|     8.0|
|     324|    Hillsdale|8:24a|     7.0|
|     324| Redwood City|8:31a|     6.0|
|     324|    Palo Alto|8:37a|    28.0|
|     324|     San Jose|9:05a|    null|
+--------+-------------+-----+--------+



In [75]:
# Register df as the temporary view 'schedule' so the SQL cells below can query it
df.createOrReplaceTempView('schedule')

Here is a query that adds a column called running_total to the records in this dataset. The running_total column uses SUM() to accumulate, for each train, the time differences between consecutive stations given by the diff_min column.

In [76]:
# Add col running_total that sums diff_min col in each group
# PARTITION BY train_id restarts the sum for every train; ORDER BY time makes the
# sum cumulative from the train's first row up to the current row.
# A row whose diff_min is null leaves the running total unchanged.
# NOTE(review): time looks like a string column here, so ORDER BY time would sort
# it as text rather than chronologically — confirm the ordering holds for all rows.
query = """
SELECT train_id, station, time, diff_min,
SUM(diff_min) OVER (PARTITION BY train_id ORDER BY time) AS running_total
FROM schedule
"""

# Run the query and display the result
spark.sql(query).show(10)

+--------+-------------+-----+--------+-------------+
|train_id|      station| time|diff_min|running_total|
+--------+-------------+-----+--------+-------------+
|     217|       Gilroy|6:06a|     9.0|          9.0|
|     217|   San Martin|6:15a|     6.0|         15.0|
|     217|  Morgan Hill|6:21a|    15.0|         30.0|
|     217| Blossom Hill|6:36a|     6.0|         36.0|
|     217|      Capitol|6:42a|     8.0|         44.0|
|     217|       Tamien|6:50a|     9.0|         53.0|
|     217|     San Jose|6:59a|    null|         53.0|
|     324|San Francisco|7:59a|     4.0|          4.0|
|     324|  22nd Street|8:03a|    13.0|         17.0|
|     324|     Millbrae|8:16a|     8.0|         25.0|
+--------+-------------+-----+--------+-------------+
only showing top 10 rows



It takes 53 minutes to go from Gilroy to San Jose on train 217.

### Add a next-stop time column and a row-number column using window functions

In [77]:
# row: position of each record when the whole table is ordered by time
#      (this window has no PARTITION BY, so numbering runs across all trains)
# time_next: the time found one row ahead within the same train_id, ordered by time
query = """
SELECT 
ROW_NUMBER() OVER (ORDER BY time) AS row,
train_id, 
station, 
time, 
LEAD(time,1) OVER (PARTITION BY train_id ORDER BY time) AS time_next 
FROM schedule
"""
spark.sql(query).show(5)

+---+--------+------------+-----+---------+
|row|train_id|     station| time|time_next|
+---+--------+------------+-----+---------+
|  1|     217|      Gilroy|6:06a|    6:15a|
|  2|     217|  San Martin|6:15a|    6:21a|
|  3|     217| Morgan Hill|6:21a|    6:36a|
|  4|     217|Blossom Hill|6:36a|    6:42a|
|  5|     217|     Capitol|6:42a|    6:50a|
+---+--------+------------+-----+---------+
only showing top 5 rows



### Comparing dot notation with SQL queries

In [78]:
# these two commands give the same result
spark.sql('SELECT train_id, MIN(time) AS start FROM schedule GROUP BY train_id').show()
df.groupBy('train_id').agg({'time':'min'}).withColumnRenamed('min(time)', 'start').show()

# Print the second column of the result
# The SQL query returns both MIN(time) and MAX(time). The dict passed to agg()
# repeats the key 'time', and a Python dict literal keeps only the last value given
# for a repeated key, so only the 'max' aggregation is applied and min(time) is
# missing from the dot-notation result.
spark.sql('SELECT train_id, MIN(time), MAX(time) FROM schedule GROUP BY train_id').show()
result = df.groupBy('train_id').agg({'time':'min', 'time':'max'})
result.show()
# With only one aggregate left, columns[1] is 'max(time)'
print(result.columns[1])

+--------+-----+
|train_id|start|
+--------+-----+
|     217|6:06a|
|     324|7:59a|
+--------+-----+

+--------+-----+
|train_id|start|
+--------+-----+
|     217|6:06a|
|     324|7:59a|
+--------+-----+

+--------+---------+---------+
|train_id|min(time)|max(time)|
+--------+---------+---------+
|     217|    6:06a|    6:59a|
|     324|    7:59a|    9:05a|
+--------+---------+---------+

+--------+---------+
|train_id|max(time)|
+--------+---------+
|     217|    6:59a|
|     324|    9:05a|
+--------+---------+

max(time)


I think aggregation is simpler to express in SQL than in dot notation.

In [79]:
#calculate first and last times for each train line

# Use a namespaced import: `from pyspark.sql.functions import min, max` would
# replace Python's built-in min() and max() for every later cell in the notebook.
from pyspark.sql import functions as F

# One aggregate expression per output column, aliased to readable names
expr = [F.min(F.col("time")).alias('start'), F.max(F.col("time")).alias('end')]
dot_df = df.groupBy("train_id").agg(*expr)
dot_df.show()

+--------+-----+-----+
|train_id|start|  end|
+--------+-----+-----+
|     217|6:06a|6:59a|
|     324|7:59a|9:05a|
+--------+-----+-----+



In [80]:
#SQL query giving a result identical to dot_df
# MIN/MAX are aliased to start/end so the column names match the dot-notation result
query = "SELECT train_id, MIN(time) AS start, MAX(time) AS end FROM schedule GROUP BY train_id"
sql_df = spark.sql(query)
sql_df.show()

+--------+-----+-----+
|train_id|start|  end|
+--------+-----+-----+
|     217|6:06a|6:59a|
|     324|7:59a|9:05a|
+--------+-----+-----+



In [81]:
#window function sql
# LEAD(time,1) returns the time from the next row of the same train_id (ordered by
# time); a train's last row has no next row, so its time_next is null.
# NOTE: df is rebound to this query result, so later cells that use df also see
# the time_next column.

df = spark.sql("""
SELECT *, 
LEAD(time,1) OVER(PARTITION BY train_id ORDER BY time) AS time_next 
FROM schedule
""")
df.show(10)

+--------+-------------+-----+--------+---------+
|train_id|      station| time|diff_min|time_next|
+--------+-------------+-----+--------+---------+
|     217|       Gilroy|6:06a|     9.0|    6:15a|
|     217|   San Martin|6:15a|     6.0|    6:21a|
|     217|  Morgan Hill|6:21a|    15.0|    6:36a|
|     217| Blossom Hill|6:36a|     6.0|    6:42a|
|     217|      Capitol|6:42a|     8.0|    6:50a|
|     217|       Tamien|6:50a|     9.0|    6:59a|
|     217|     San Jose|6:59a|    null|     null|
|     324|San Francisco|7:59a|     4.0|    8:03a|
|     324|  22nd Street|8:03a|    13.0|    8:16a|
|     324|     Millbrae|8:16a|     8.0|    8:24a|
+--------+-------------+-----+--------+---------+
only showing top 10 rows



In [82]:
from pyspark.sql import Window
from pyspark.sql.functions import row_number,lead,unix_timestamp

#same as above window function using dot notation
# Keep the DataFrame in dot_df and display it in a separate statement: show()
# returns None, so chaining .show(10) onto the assignment would store None in dot_df.
dot_df = df.withColumn(
    'time_next',
    lead('time', 1).over(Window.partitionBy('train_id').orderBy('time')))
dot_df.show(10)

+--------+-------------+-----+--------+---------+
|train_id|      station| time|diff_min|time_next|
+--------+-------------+-----+--------+---------+
|     217|       Gilroy|6:06a|     9.0|    6:15a|
|     217|   San Martin|6:15a|     6.0|    6:21a|
|     217|  Morgan Hill|6:21a|    15.0|    6:36a|
|     217| Blossom Hill|6:36a|     6.0|    6:42a|
|     217|      Capitol|6:42a|     8.0|    6:50a|
|     217|       Tamien|6:50a|     9.0|    6:59a|
|     217|     San Jose|6:59a|    null|     null|
|     324|San Francisco|7:59a|     4.0|    8:03a|
|     324|  22nd Street|8:03a|    13.0|    8:16a|
|     324|     Millbrae|8:16a|     8.0|    8:24a|
+--------+-------------+-----+--------+---------+
only showing top 10 rows



In [None]:
#add a column to a train schedule so that each row contains the number of minutes 
#for the train to reach its next stop. 

#using dot notation
# Each train's rows are ordered by time so lead() picks up the following stop
window = Window.partitionBy('train_id').orderBy('time')

# Seconds between the next stop's time and this stop's time, divided by 60 for minutes.
# NOTE(review): the 'H:m' pattern has no am/pm field while the time values end in
# 'a' — confirm the timestamps parse as intended.
# The DataFrame is stored first and displayed separately: show() returns None, so
# chaining .show() onto the assignment would leave dot_df set to None.
dot_df = df.withColumn(
    'diff_min',
    (unix_timestamp(lead('time', 1).over(window), 'H:m')
     - unix_timestamp('time', 'H:m')) / 60)
dot_df.show()

In [None]:
# SQL version of the diff_min calculation (minutes until the train's next stop).
# The columns are listed explicitly instead of using SELECT *: schedule already
# contains a diff_min column, so * plus the computed diff_min would return two
# columns with the same name.
query = """
SELECT train_id, station, time,
(UNIX_TIMESTAMP(LEAD(time, 1) OVER (PARTITION BY train_id ORDER BY time),'H:m') 
 - UNIX_TIMESTAMP(time, 'H:m'))/60 AS diff_min 
FROM schedule 
"""
sql_df = spark.sql(query)
sql_df.show()

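Note: combining `SELECT *` with a computed column aliased `diff_min` produces two `diff_min` columns, because `schedule` already contains a column with that name. Listing the columns explicitly avoids the duplicate.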